1.5 NumPy入门
NumPy是一个Python包,它代表“Numeric Python”。它是一个由多维数组对象和用于处理数组的例程集合组成的库。
1.5.1 NumPy的用法
简单地说,NumPy可以看成附加了各种代数运算的列表(List),所以NumPy数组的定义、提取、更改等都很直观。
运行程序,输出如下:
由于NumPy数组通常以高级数组的形式出现,所以除上例中通过简单下标提取元素的方法外,NumPy还有针对性地设计了一套高效的、提取相应切片的索引方法(Indexing)。
【例1-1】 以下实例获取了4×3数组中的4个角的元素。行索引是[0,0]和[3,3],而列索引是[0,2]和[0,2]。
运行程序,输出如下:
对这一套索引方法进行深入、详尽的了解是非常有必要的。由于本书不是详细介绍Python软件的,而是应用Python软件解决人工智能问题的,所以关于这个方法,读者可参考官方文档(https://docs.scipy.org/doc/numpy/reference/arrays.indexing.html)加深理解。
由于在算法的编程任务中一般不会给定一个大矩阵,然后让程序员手动将矩阵的每个值输入计算机,所以一般会给定一个大矩阵的形状让程序进行相应的初始化(初始化为全0、全1及随机数)。为此,NumPy自带了许多有用的初始化方法。详细说明也可以参考官方文档(https://docs.scipy.org/doc/numpy/user/basics.creation.html#arrays-creation),此处仅介绍几个最常用的初始化方法。
运行程序,输出如下:
接下来要说明的就是NumPy数组的加、减、乘、除了。
运行程序,输出如下:
除此之外,NumPy支持一系列对高维数组的数学运算操作,这为算法的向量化打下了坚实的基础。所有NumPy支持的数学运算都可以在官方文档(https://docs.scipy.org/doc/ numpy/reference/routines.math.html)中查到,下面展示几个比较典型的方法。
运行程序,输出如下:
这里提到了axis的概念。直观地说,不同的axis其实可以视为不同深度的for循环所对应的数据,具体而言:
• 第一个axis对应第一层for循环所能达到的数据;
• 第二个axis对应第二层for循环所能达到的数据。
以此类推至第n个axis对应第n层for循环所能达到的数据。
1.5.2 广播
广播(Broadcast)是NumPy对不同形状(Shape)的数组进行数值计算的方式,对数组的算术运算通常在相应的元素上进行。
如果两个数组a和b形状相同,即满足a.shape == b.shape,那么a*b的结果就是a与b数组对应位相乘。这要求维度相同,且各维度的长度相同。例如:
运行程序,输出如下:
当运算中2个数组的形状不同时,NumPy将自动触发广播机制。例如:
运行程序,输出结果为:
图1-17展示了数组b是如何通过广播来与数组a兼容的。
图1-17 广播与数组兼容过程
4×3的二维数组与长为3的一维数组相加,等效于把数组b在二维上重复4次再运算:
运行程序,输出如下:
需要记住的是广播的规则。
• 让所有输入数组都向其中形状最长的数组看齐,形状中不足的部分都通过在前面加1补齐。
• 输出数组的形状是输入数组形状的各个维度上的最大值。
• 如果输入数组的某个维度和输出数组的对应维度的长度相同或者其长度为1时,那么这个数组能够用来计算,否则出错。
• 当输入数组的某个维度的长度为1时,沿着此维度运算时都用此维度上的第一组值。
1.5.3 向量化与“升维”
NumPy之所以能够将性能提升那么多,很大程度上依赖着由底层语言编写的线性代数运算库,而代数运算中的基础——矩阵运算自然是这么多年来被重点反复优化的算法之一。所以如果想要写出高效的算法,将算法进行向量化是必不可少的步骤。下面我们主要介绍以下两点思想。
• 将for循环替换成NumPy运算。
• 将难以直接向量化的算法所对应的数组进行“升维”。
先看第一点,该思想是向量化思维的基石。
就上述代码而言,第一种for循环实现耗时大约为540ms,第二种利用NumPy运算实现耗时大约为4.5ms,第三种利用NumPy函数实现耗时大约为2.6ms。可以看出最快的方法比最慢的方法要快了200倍左右,由此可见向量化的效率。对于第二点,其实是对广播的高级应用,即升维的思想。
升维其实很简单:利用广播将某一段重复的运算向量化。例如,有如下两个数组:
而我们希望计算出
那么可以直接利用广播来进行实现:
运行程序,输出:
但是,不难发现y中其实有大量重复的元素,这在实际问题中常常反映为当
时,计算
的结果,这可以通过两种方式完成。第一种方式就是把y直接写成开始时那种具有大量重复元素的矩阵形式,这可以通过NumPy自带的函数——np.tile直接实现:
运行程序,输出如下:
第二种方式就是进行“升维”,利用NumPy的广播来帮助我们完成重复的运算,这种做法是更快、更省内存的。
运行程序,输出如下:
其中,y[:,None]的效果为
即y从一维数组(1×3)变化为了二维数组(3×1)。如果用此时的y减去x,由于y的“宽度”仅为1,而x的“宽度”为3,所以NumPy的广播会在内部对y进行“扩张”以“适配”x的宽度。
由于这个扩张是在NumPy内部隐性进行的,所以比起第一种方法中的显性计算,它的性能会好许多。
1.5.4 NumPy的应用思想
使用NumPy时,应该尽量避免不必要的复制。例如,在计算a1=x+1时,最快的实现方法是np.add(x,1,a1)。这种写法比直接写a1=x+1更优的原因是,可以通过拆解运算过程来直观认知:
• 对于np.add(x,1,a1),在计算x+1时会将结果直接写进a1;
• 对于a1=x+1,会先将x+1的结果放进内存,再把内存中的结果赋给a1。
这说明执行a1=x+1时我们进行了不必要的复制(将结果复制进内存)。这会引发NumPy数组复制操作,包括但不限于:
• x=x+1(建议使用x+=1);
• y=x.flatten()(建议使用y=x.ravel());
• x=x.T(无替代方案,不过这告诉我们要尽量少用转置)。
除以上3种比较常见的操作外,还有一些数组的reshape操作也会引发NumPy的复制。总之,在发觉程序运行不如想象中的高效时,检查是否引发了不必要的复制是一个重要的应用思想。