ML.NET Tutorial: Taxi Fare Prediction (a Regression Problem)

Understanding the problem

A taxi fare depends not only on distance but also on factors such as the number of passengers and whether the rider pays by credit card (the taxis in this dataset are New York City taxis). It is therefore not a simple single-variable equation.

Preparing the data

Create a console application project and add a Data folder to it. Add taxi-fare-train.csv and taxi-fare-test.csv to that folder, and remember to set their Copy to Output Directory property to Copy if newer. Then add the Microsoft.ML package.

Loading the data

Create an MLContext object and a TextLoader object. The TextLoader reads data from a file.

    MLContext mlContext = new MLContext(seed: 0);

    // Describe the CSV layout: comma-separated, with a header row, seven columns.
    _textLoader = mlContext.Data.TextReader(new TextLoader.Arguments()
    {
        Separator = ",",
        HasHeader = true,
        Column = new[]
        {
            new TextLoader.Column("VendorId", DataKind.Text, 0),
            new TextLoader.Column("RateCode", DataKind.Text, 1),
            new TextLoader.Column("PassengerCount", DataKind.R4, 2),
            new TextLoader.Column("TripTime", DataKind.R4, 3),
            new TextLoader.Column("TripDistance", DataKind.R4, 4),
            new TextLoader.Column("PaymentType", DataKind.Text, 5),
            new TextLoader.Column("FareAmount", DataKind.R4, 6)
        }
    });

Extracting features

The dataset files have seven columns. The first six are features, and the last one is the label.

    // Input type: one row of the dataset.
    public class TaxiTrip
    {
        [Column("0")]
        public string VendorId;

        [Column("1")]
        public string RateCode;

        [Column("2")]
        public float PassengerCount;

        [Column("3")]
        public float TripTime;

        [Column("4")]
        public float TripDistance;

        [Column("5")]
        public string PaymentType;

        [Column("6")]
        public float FareAmount;
    }

    // Output type: the predicted fare, read from the Score column.
    public class TaxiTripFarePrediction
    {
        [ColumnName("Score")]
        public float FareAmount;
    }

Training the model

First read the training dataset, then build the pipeline. The first step of the pipeline copies the FareAmount column to the Label column so that it serves as the label. The second step uses OneHotEncoding to convert the three string columns VendorId, RateCode, and PaymentType into numeric columns. The third step concatenates the six data columns into a single Features column. The last step selects FastTreeRegressionTrainer as the training algorithm. Once the pipeline is complete, train the model.

    IDataView dataView = _textLoader.Read(dataPath);

    var pipeline = mlContext.Transforms.CopyColumns("FareAmount", "Label")
        .Append(mlContext.Transforms.Categorical.OneHotEncoding("VendorId"))
        .Append(mlContext.Transforms.Categorical.OneHotEncoding("RateCode"))
        .Append(mlContext.Transforms.Categorical.OneHotEncoding("PaymentType"))
        .Append(mlContext.Transforms.Concatenate("Features", "VendorId", "RateCode", "PassengerCount", "TripTime", "TripDistance", "PaymentType"))
        .Append(mlContext.Regression.Trainers.FastTree());

    var model = pipeline.Fit(dataView);

Evaluating the model

Evaluation uses the test dataset and the Evaluate method for regression problems.

    IDataView dataView = _textLoader.Read(_testDataPath);
    var predictions = model.Transform(dataView);
    var metrics = mlContext.Regression.Evaluate(predictions, "Label", "Score");

    Console.WriteLine();
    Console.WriteLine($"*************************************************");
    Console.WriteLine($"* Model quality metrics evaluation ");
    Console.WriteLine($"*------------------------------------------------");
    Console.WriteLine($"* R2 Score: {metrics.RSquared:0.##}");
    Console.WriteLine($"* RMS loss: {metrics.Rms:#.##}");

Saving the model

The trained model can be saved as a zip file for later use.

    using (var fileStream = new FileStream(_modelPath, FileMode.Create, FileAccess.Write, FileShare.Write))
        mlContext.Model.Save(model, fileStream);

Using the model

First load the saved model. Then create a prediction function object, with TaxiTrip as its input type and TaxiTripFarePrediction as its output type. Finally, call the Predict method with the sample to be predicted.

    ITransformer loadedModel;
    using (var stream = new FileStream(_modelPath, FileMode.Open, FileAccess.Read, FileShare.Read))
    {
        loadedModel = mlContext.Model.Load(stream);
    }

    var predictionFunction = loadedModel.MakePredictionFunction<TaxiTrip, TaxiTripFarePrediction>(mlContext);

    var taxiTripSample = new TaxiTrip()
    {
        VendorId = "VTS",
        RateCode = "1",
        PassengerCount = 1,
        TripTime = 1140,
        TripDistance = 3.75f,
        PaymentType = "CRD",
        FareAmount = 0 // To predict. Actual/Observed = 15.5
    };

    var prediction = predictionFunction.Predict(taxiTripSample);

    Console.WriteLine($"**********************************************************************");
    Console.WriteLine($"Predicted fare: {prediction.FareAmount:0.####}, actual fare: 15.5");
    Console.WriteLine($"**********************************************************************");

Complete sample code

    using Microsoft.ML;
    using Microsoft.ML.Core.Data;
    using Microsoft.ML.Runtime.Data;
    using System;
    using System.IO;

    namespace TexiFarePredictor
    {
        // TaxiTrip and TaxiTripFarePrediction are the classes defined in "Extracting features".
        class Program
        {
            static readonly string _trainDataPath = Path.Combine(Environment.CurrentDirectory, "Data", "taxi-fare-train.csv");
            static readonly string _testDataPath = Path.Combine(Environment.CurrentDirectory, "Data", "taxi-fare-test.csv");
            static readonly string _modelPath = Path.Combine(Environment.CurrentDirectory, "Data", "Model.zip");
            static TextLoader _textLoader;

            static void Main(string[] args)
            {
                MLContext mlContext = new MLContext(seed: 0);

                _textLoader = mlContext.Data.TextReader(new TextLoader.Arguments()
                {
                    Separator = ",",
                    HasHeader = true,
                    Column = new[]
                    {
                        new TextLoader.Column("VendorId", DataKind.Text, 0),
                        new TextLoader.Column("RateCode", DataKind.Text, 1),
                        new TextLoader.Column("PassengerCount", DataKind.R4, 2),
                        new TextLoader.Column("TripTime", DataKind.R4, 3),
                        new TextLoader.Column("TripDistance", DataKind.R4, 4),
                        new TextLoader.Column("PaymentType", DataKind.Text, 5),
                        new TextLoader.Column("FareAmount", DataKind.R4, 6)
                    }
                });

                var model = Train(mlContext, _trainDataPath);
                Evaluate(mlContext, model);
                TestSinglePrediction(mlContext);

                Console.Read();
            }

            public static ITransformer Train(MLContext mlContext, string dataPath)
            {
                IDataView dataView = _textLoader.Read(dataPath);

                var pipeline = mlContext.Transforms.CopyColumns("FareAmount", "Label")
                    .Append(mlContext.Transforms.Categorical.OneHotEncoding("VendorId"))
                    .Append(mlContext.Transforms.Categorical.OneHotEncoding("RateCode"))
                    .Append(mlContext.Transforms.Categorical.OneHotEncoding("PaymentType"))
                    .Append(mlContext.Transforms.Concatenate("Features", "VendorId", "RateCode", "PassengerCount", "TripTime", "TripDistance", "PaymentType"))
                    .Append(mlContext.Regression.Trainers.FastTree());

                var model = pipeline.Fit(dataView);
                SaveModelAsFile(mlContext, model);
                return model;
            }

            private static void SaveModelAsFile(MLContext mlContext, ITransformer model)
            {
                using (var fileStream = new FileStream(_modelPath, FileMode.Create, FileAccess.Write, FileShare.Write))
                    mlContext.Model.Save(model, fileStream);
            }

            private static void Evaluate(MLContext mlContext, ITransformer model)
            {
                IDataView dataView = _textLoader.Read(_testDataPath);
                var predictions = model.Transform(dataView);
                var metrics = mlContext.Regression.Evaluate(predictions, "Label", "Score");

                Console.WriteLine();
                Console.WriteLine($"*************************************************");
                Console.WriteLine($"* Model quality metrics evaluation ");
                Console.WriteLine($"*------------------------------------------------");
                Console.WriteLine($"* R2 Score: {metrics.RSquared:0.##}");
                Console.WriteLine($"* RMS loss: {metrics.Rms:#.##}");
            }

            private static void TestSinglePrediction(MLContext mlContext)
            {
                ITransformer loadedModel;
                using (var stream = new FileStream(_modelPath, FileMode.Open, FileAccess.Read, FileShare.Read))
                {
                    loadedModel = mlContext.Model.Load(stream);
                }

                var predictionFunction = loadedModel.MakePredictionFunction<TaxiTrip, TaxiTripFarePrediction>(mlContext);

                var taxiTripSample = new TaxiTrip()
                {
                    VendorId = "VTS",
                    RateCode = "1",
                    PassengerCount = 1,
                    TripTime = 1140,
                    TripDistance = 3.75f,
                    PaymentType = "CRD",
                    FareAmount = 0 // To predict. Actual/Observed = 15.5
                };

                var prediction = predictionFunction.Predict(taxiTripSample);

                Console.WriteLine($"**********************************************************************");
                Console.WriteLine($"Predicted fare: {prediction.FareAmount:0.####}, actual fare: 15.5");
                Console.WriteLine($"**********************************************************************");
            }
        }
    }

Output after running the program:

    *************************************************
    * Model quality metrics evaluation
    *------------------------------------------------
    * R2 Score: 0.92
    * RMS loss: 2.81
    **********************************************************************
    Predicted fare: 15.7855, actual fare: 15.5
    **********************************************************************

The final prediction is fairly close to the actual value.

Category: Technology
https://www.cnblogs.com/kenwoo/p/10171481.html
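For readers unfamiliar with the two numbers printed by the evaluation step, here is a sketch of how these regression metrics are usually defined. The symbols are introduced here for illustration and do not come from the article: y_i is the actual fare of test trip i, \hat{y}_i is the predicted fare, \bar{y} is the mean of the actual fares, and n is the number of test trips. The library's exact computation may differ in detail.

```latex
\mathrm{RMS} = \sqrt{\frac{1}{n}\sum_{i=1}^{n}\left(y_i - \hat{y}_i\right)^2}
\qquad
R^2 = 1 - \frac{\sum_{i=1}^{n}\left(y_i - \hat{y}_i\right)^2}{\sum_{i=1}^{n}\left(y_i - \bar{y}\right)^2}
```

Under these definitions, an RMS loss of 2.81 is a typical prediction error expressed in the same unit as FareAmount, and an R2 score of 0.92 means the model accounts for roughly 92% of the variance in the test fares. An R2 closer to 1 and an RMS loss closer to 0 both indicate a better fit.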