机器学习系统:设计和实现
上QQ阅读APP看书,第一时间看更新

2.2.5 训练及保存模型

MindSpore提供了回调(Callback)机制,可以在训练过程中执行自定义逻辑。代码2.6使用框架提供的ModelCheckpoint函数,ModelCheckpoint函数可以保存网络模型和参数,以便进行后续的Fine-tuning(微调)操作。

代码2.6 定义模型保存

通过MindSpore提供的model.train接口可以方便地进行网络的训练,同时使用Loss-Monitor可以监控训练过程中损失(loss)值的变化,如代码2.7所示。

代码2.7 定义模型训练

其中,dataset_sink_mode用于控制数据是否下沉,数据下沉是指数据通过通道直接传送到设备(Device)上,可以加快训练速度,dataset_sink_mode为真(True),表示数据下沉,否则为非下沉。

有了数据集、模型、损失函数、优化器后就可以进行训练了。代码2.8把train_epoch设置为1,对数据集进行1次迭代训练。在train_net方法中,加载了之前下载的训练数据集,mnist_path是MNIST数据集路径。

代码2.8 训练模型