企业级AI技术内幕:深度学习框架开发+机器学习案例实战+Alluxio解密
上QQ阅读APP看书,第一时间看更新

5.1 关于损失度的思考——所有人工智能框架终身的魔咒

从过去的经历中学到一些知识,并把学到的知识用于改善下一次的行为,如此循环反复,实现能力的螺旋式上升。这样的过程是人类正常的学习过程,也是深度学习所有框架的核心思想,及绝大多数框架的实现过程(有些深度学习框架的实现不是按照前向传播、反向传播的过程来实现的)。通过前向传播去经历,通过反向传播来反思,并不断进行调整。

我们在盘古人工智能框架中输入数据,实例集只有4条数据,并采用了二维数组的方式,其实现的思想和TensorFlow是一样的。我们在写第一个“Hello World”程序的时候,一般是打印一行字符串,打印一行字符串和以后处理上万条、几十万条数据的过程几乎是一样的。可能有人会认为不一样,几十万条数据可能存储在网盘上,处理的时候可能需要多线程等,但这些不是框架运行机制的核心内容,是以后要优化的内容。因此,从本质上讲,我们使用实例集处理4条数据和处理几十万条数据的过程没太大的区别。在第8章中,我们将实现通过矩阵的方式加载外部数据源。

接下来我们回顾一下神经网络的内容。

1.用盘古人工智能框架构建整个神经网络
  • Neuron_Network_Entry.py:构建实例集输入数据,构建神经网络运行的主流程。
  • NetworkStructure.py:构建神经网络的架构,创建整个神经网络的所有节点,实现输入层、隐藏层、输出层。
  • Node.py:构建神经元节点,设置和访问神经元的层次、索引ID、名称、是否为偏爱因子、值、误差等内容。
  • Weight.py:构建权重,设置和访问权重的索引ID、来源节点、目标节点、权重值等信息。
  • NetworkConnection.py:构建神经网络中前后层次神经元节点之间的权重关系。
2.实现神经网络前向传播功能

如图5-2所示,从输入层输入的数据中提取特征,可以直接输入x1x2的特征,也可以输入x1x2、sin(x1)、sin(x2)等的特征值。这里我们构建的是最基本、最原始的结构,暂时只关注x1x2的特征值;隐藏层根据输入的数据和权重计算出神经元的值,然后根据权重及上一个隐藏层的值计算下一个隐藏层所有神经元的值;输出层根据最后一层隐藏层的值以及权重计算出最终的预测结果。一般情况下,预测结果和实际结果会有误差。

图5-2 x1x2的各种特征

3.实现神经网络反向传播功能

实现神经网络反向传播的时候要更新权重,权重的赋值包含两部分:

(1)权重初始化:在神经网络初始化的时候,根据具体的算法对权重进行初始化赋值。

(2)权重更新:反向传播时更新权重。

图5-3 前向传播及反向传播

如图5-3所示,所有数据从输入到输出运算1次是一次时代,通过前向传播算法得到一个预测值,计算预测值和实际值的差;然后在反向传播算法中,将输出层的误差值从后往前推,依次传递到前面的隐藏层,计算各个节点要负的责任(误差),对原有的激活函数Sigmoid进行求导,通过梯度下降算法更新权重。注意,调整的时候不调整输入层,输入层是我们的输入数据,不能修改输入的数据,要改的是神经网络隐藏层的神经元以及神经元关联的权重。随着时代的迭代,损失度将越来越低。

神经网络进行非线性变换的时候,我们可以改变激活函数,ReLU是我们运用得最多的激活函数之一。使用ReLU函数,会发现损失度收敛的速度会快很多,时代运行次数越多,精度越高;也可以采用Sigmoid激活函数的方式,它的收敛速度非常快,但比ReLU差了一点。可以通过Sigmoid激活函数实现ReLU,它们之间的公式有关联,就像线性回归是逻辑回归的基础,逻辑回归是神经网络的基础,神经网络是CNN的基础,而CNN是RNN的基础。线性回归是第一层的阶梯,逻辑回归是第二层的阶梯,神经网络是第三层的阶梯,CNN是第四层的阶梯,RNN是第五层的阶梯。

盘古人工智能框架的代码目前还没有改进,训练的效果不理想,预测值没有向1或者0趋近。经过改进,输出结果的准确率将达到95%。从实现神经网络的角度,我们已经完成了前向传播算法、反向传播算法,盘古人工智能框架和TensorFlow、PyTorch的第一个原型是一样的。

本节将进行数据可视化,如将神经网络训练10 000个时代,每100个时代打印一下损失度。如图5-4所示,TensorFlow可视化图给出了动态损失度的曲线图。

图5-4 损失度曲线

在计算过程中,激活函数和梯度下降的算法不一样,计算误差的方法也不一样,导致损失度的值不一样。例如,最原始的误差计算函数是用预测值减去实际值,将预测值和实际值进行比较:第1条记录将预测值0.4385711425310979减去实际值0;第2条记录将预测值0.4129240972484869减去实际值1;第3条记录将预测值0.4651651378938895减去实际值1;第4条记录将预测值0.43581722934765915减去实际值0。