上QQ阅读APP看本书,新人免费读10天
设备和账号都新为新人
2.3 随机梯度下降及实现
在2.2.3节线性回归的实现中,我们看到在update_weights(更新权重)方法中,实际上每一次的更新都将所有输入数据处理了一次,将所有输入都作为一轮训练:
因为上例的训练数据很少,所以每次都载入所有数据进行训练并没有什么关系。然而,在实际的生产环境中训练数据可达到千万级别,要每次都载入这些数据进行训练是很不现实的。但我们可以换一种方式,每次只用一条随机数据来训练,这便是随机梯度下降的原理。
当然,这样会让训练时间变长,所以人们在实际环境中通常会采用mini batch方法,也就是每次更新权重时既不使用所有数据,也不随机挑选一条数据,而是随机挑选一个子集来训练。例如,如果有1000条数据,那么每次都可以随机挑选16条或者64条数据进行训练。对mini batch的大小设置是个很有趣的研究课题,和模型本身、数据特点等都有关系,这里不必细究。下面修改2.2节的线性回归实现,看看采用mini batch方法是如何进行训练的。
如上所示,我们只需要修改update_weights方法即可。在第6行设置了一个数组用于存储训练数据的索引下标;在第7行将有序索引数组随机打乱;在第8行设定每次训练都取4条数据。于是在第10~11行,我们只需从已经随机打乱的数据索引表中取前4条进行训练即可。
训练结果如下:
把这组结果和在2.2节中用全批量数据训练的结果相比,可以看到预测效果差别不大,这也证明了mini batch方法的有效性。至于随机梯度下降,它只是batch_size为1时的特殊情况而已。