机器学习系统:设计和实现
上QQ阅读APP看书,第一时间看更新

2.3.3 自定义神经网络层

2.3.1节中使用伪代码定义机器学习库中低级API,有了实现的神经网络基类抽象方法,那么就可以设计更高层次的接口,解决手动管理参数的烦琐。假设已经有了神经网络模型抽象方法Cell,构建Conv2D将继承Cell类,并重构__init__和__call__方法,在__init__方法中初始化训练参数和输入参数,在__call__方法中调用低级API实现计算逻辑。使用伪代码,如代码2.12所示,通过接口定义描述自定义卷积层的过程。

代码2.12 自定义神经网络层

有了上述定义,在使用卷积层时,就不需要创建训练变量了。假设需要对30×30大小的10个通道的输入使用3×3的卷积核做卷积,卷积后输出20个通道,调用方式如代码2.13所示。

代码2.13 使用卷积层

在执行过程中,初始化Conv2D时,__setattr__方法会判断属性,把属于Cell类的神经网络层Conv2D记录到self._cells中,filters属于参数(parameter),把参数记录到self._params中。查看神经网络层参数使用conv.parameters_and_names;查看神经网络层列表使用conv.cells_and_names;执行操作使用conv(inputs)。