编辑
2025-11-21
C#
00

目录

摘要
正文
nuget 安装ml.net
准备数据
创建数据类
加载和转换数据
评估模型
使用预测模型

摘要


机器学习中的回归模型是一种用于预测连续值输出的常见算法。与分类模型不同,回归模型的目标是通过学习输入特征和对应的输出值之间的关系,来预测新的输入特征对应的输出值。这种模型通常被用于解决各种实际问题,例如房价预测、销售预测、股票价格预测等。

回归模型在ML.NET中是一个广泛使用的模型,因为它们非常适合于解决许多实际问题。与其他机器学习模型不同,回归模型可以预测一个数值,这使得它们非常适合于各种领域,例如商业、医学、工业和金融等。在ML.NET中,可以使用许多回归算法来训练和评估回归模型,例如线性回归、决策树回归、支持向量机回归等。

线性回归是一种简单而常用的回归算法,它的目标是找到一个线性函数来描述输入特征和输出值之间的关系。决策树回归是一种非线性的回归算法,它使用树形结构来描述输入特征和输出值之间的关系。支持向量机回归是一种常用的回归算法,它使用核函数来将输入特征映射到高维空间,以便更好地分离不同的数据点。

正文


nuget 安装ml.net

image.png

准备数据

image.png

提供的数据集包含以下列:

  • vendor_id: 出租车供应商的 ID 是一项特征。
  • rate_code: 出租车行程的费率类型是一项特征。
  • passenger_count: 行程中的乘客人数是一项特征。
  • trip_time_in_secs: 这次行程所花的时间。 希望在行程完成前预测行程费用。 当时并不知道行程有多长。 因此,行程时间不是一项特征,需要从模型删除此列。
  • trip_distance: 行程距离是一项特征。
  • payment_type: 付款方式(现金或信用卡)是一项特征。
  • fare_amount: 支付的总出租车费用是一个标签。

创建数据类

C#
public class TaxiTrip { [LoadColumn(0)] public string? VendorId; [LoadColumn(1)] public string? RateCode; [LoadColumn(2)] public float PassengerCount; [LoadColumn(3)] public float TripTime; [LoadColumn(4)] public float TripDistance; [LoadColumn(5)] public string? PaymentType; [LoadColumn(6)] public float FareAmount; }
C#
public class TaxiTripFarePrediction { [ColumnName("Score")] public float FareAmount; }

TaxiTrip 是输入数据类且具有针对每个数据集列的定义。 使用 LoadColumnAttribute 属性在数据集中指定源列的索引。

TaxiTripFarePrediction 类表示预测的结果。 它应用了单个浮动 FareAmount 字段,附带 ScoreColumnNameAttribute 属性。 对于回归任务,“分数”列包含预测的标签值。

加载和转换数据

image.png

  • 加载数据。
  • 提取并转换数据。
  • 定型模型。
  • 返回模型。
C#
static ITransformer Train(MLContext mlContext, string dataPath) { // 加载数据集 IDataView dataView = mlContext.Data.LoadFromTextFile<TaxiTrip>(dataPath, hasHeader: true, separatorChar: ','); // 创建数据处理流水线 var pipeline = mlContext.Transforms.CopyColumns(outputColumnName: "Label", inputColumnName: "FareAmount") // 将分类特征进行独热编码 .Append(mlContext.Transforms.Categorical.OneHotEncoding(outputColumnName: "VendorIdEncoded", inputColumnName: "VendorId")) .Append(mlContext.Transforms.Categorical.OneHotEncoding(outputColumnName: "RateCodeEncoded", inputColumnName: "RateCode")) .Append(mlContext.Transforms.Categorical.OneHotEncoding(outputColumnName: "PaymentTypeEncoded", inputColumnName: "PaymentType")) // 将所有特征连接成一个特征向量 .Append(mlContext.Transforms.Concatenate("Features", "VendorIdEncoded", "RateCodeEncoded", "PassengerCount", "TripDistance", "PaymentTypeEncoded")) // 使用FastTree算法进行回归训练 .Append(mlContext.Regression.Trainers.FastTree()); // 训练模型 var model = pipeline.Fit(dataView); // 返回训练好的模型 return model; }

评估模型

  • 加载测试数据集。
  • 创建回归计算器。
  • 评估模型并创建指标。
  • 显示指标。
C#
static void Evaluate(MLContext mlContext, ITransformer model) { // 加载测试数据集 IDataView dataView = mlContext.Data.LoadFromTextFile<TaxiTrip>(_testDataPath, hasHeader: true, separatorChar: ','); // 使用训练好的模型对测试数据进行预测 var predictions = model.Transform(dataView); // 评估模型的质量指标 var metrics = mlContext.Regression.Evaluate(predictions, "Label", "Score"); // 打印评估结果 Console.WriteLine(); Console.WriteLine($"*************************************************"); Console.WriteLine($"* Model quality metrics evaluation "); Console.WriteLine($"*------------------------------------------------"); Console.WriteLine($"* RSquared Score: {metrics.RSquared:0.##}"); Console.WriteLine($"* Root Mean Squared Error: {metrics.RootMeanSquaredError:#.##}"); }

RSquared 是回归模型的另一种评估指标。 RSquared 在 0 和 1 之间取值。 值越接近 1,模型就越好。 将以下代码添加到 Evaluate 方法以显示 RSquared 值:

RMS 是回归模型的一种评估指标。 指标越低,模型就越好。

使用预测模型

  • 创建测试数据的单个注释。
  • 根据测试数据预测费用。
  • 结合测试数据和预测进行报告。
  • 显示预测结果。
C#
static void TestSinglePrediction(MLContext mlContext, ITransformer model) { // 创建用于预测的预测引擎 var predictionFunction = mlContext.Model.CreatePredictionEngine<TaxiTrip, TaxiTripFarePrediction>(model); // 创建一个样本数据用于预测 var taxiTripSample = new TaxiTrip() { VendorId = "VTS", RateCode = "1", PassengerCount = 1, TripTime = 1140, TripDistance = 3.75f, PaymentType = "CRD", FareAmount = 0 // 待预测的值。实际观测值为 15.5 }; // 使用预测引擎进行预测 var prediction = predictionFunction.Predict(taxiTripSample); // 打印预测结果和实际观测值 Console.WriteLine($"**********************************************************************"); Console.WriteLine($"Predicted fare: {prediction.FareAmount:0.####}, actual fare: 15.5"); Console.WriteLine($"**********************************************************************"); }

image.png

本文作者:技术老小子

本文链接:

版权声明:本博客所有文章除特别声明外,均采用 BY-NC-SA 许可协议。转载请注明出处!