Numba向量化运算
Hi! 大家好,又和大家见面了。上次给大家介绍了Numba中一句话加速for循环的@jit加速你的python脚本,今天继续给大家介绍另外一个我觉得很不错的Numba的用法。
在之前处理很小规模的for循环的时候,我没有感觉到需要加速python脚本,觉得30秒和15秒运行时间的差别对我的影响远没有大到需要我花精力去改写脚本的程度。
直到有时改写下脚本,时间可以从30小时缩小到8小时的时候,我才反应过来,原来脚本提速后给人的感觉还是很明显的。
1. For Example
前面给大家介绍过Numba很好用的@jit用法,今天给大家说一说它的另外一个我用到觉得还不错的@vectorize向量化运算。
还是举个例子吧,这些都是最近学习写模型遇到的问题,所以我就直接简化模型中的一个公式给大家介绍下它的神奇之处。公式如下图:
整体来看是由两个函数组成,一个是二项式一个是一次函数,然后求每个k下这两个函数的乘积,最后再求k从0到n下所有乘积的加和。
其中f,n为已知数,这里我设置为0.01和1000万。
首先我最开始直接写了个for循环:
numba_vectorize_example.py:
import math
import numba as nb
f=0.01
n=10000000
def func1(k):
#二项式系数取log_e
C=math.lgamma(n+1)-math.lgamma(k+1)-math.lgamma(n-k+1)
#二项式第一项取log_e
i1=math.log(f)*k
#二项式第二项取log_e
i2=math.log(1-f)*(n-k)
item=C+i1+i2
#转换回原值
result=math.exp(1)**item
return result
def func2(k):
return 3*k+2
@nb.jit
def func_sigma():
sigma=0
for k in range(n+1):
#两个函数的相乘累加到sigma
sigma+=func1(k)*func2(k)
return sigma
sigma=func_sigma()
print(sigma)
这里二项式中求阶乘,python有时直接用阶乘函数会导致溢出,可以改用math.lgamma变换一下。另外这里也用到了之前说的@jit加快for循环。
运行时间23.4秒:
$ time python3 numba_vectorize_example.py
300002.00301576033
python3 numba_vectorize_example.py 23.91s user 1.96s system 110% cpu 23.390 total
之后我用了向量化运算,所谓向量运算,就是类似于线性代数里面的两个向量的点积,点积介绍如下(wikipedia):
不同于for循环中给一个k算一次,这里是把所有k都给出来,直接同时算出所有k的结果,然后求和(有点类似于apply或者map?读者可以自行验证下这两个函数)。
具体脚本如下:
numba_vectorize_example_v1.py
import math
import numpy as np
import numba as nb
f=0.01
n=10000000
@nb.vectorize(["float64(float64)"])
def func1(k):
C=math.lgamma(n+1)-math.lgamma(k+1)-math.lgamma(n-k+1)
i1=math.log(f)*k
i2=math.log(1-f)*(n-k)
item=C+i1+i2
result=math.exp(1)**item
return result
@nb.vectorize(["float64(float64)"])
def func2(k):
return 3*k+2
@nb.jit
def func_sigma():
#0-1000万的k放到列表ki_list里面
ki_list=np.arange(n+1)
#两个函数同时对列表里面的所有值进行运算,np.dot计算向量的点积
sigma=np.dot(func1(ki_list),func2(ki_list))
return sigma
sigma=func_sigma()
print(sigma)
这里的["float64(float64)"]一般我都用float格式,因为计算的结果是浮点型。一般你的函数有几个参数就写几次float,并且类型需要一致,都是float或者都是int,不能两种混合,不然会报错。
例如你的func3有4个参数那写成@nb.vectorize(["float64(float64,float64,float64,float64)"])
运行时间2.6秒:
$ time python3 numba_vectorize_example_v1.py
300001.9999232713
python3 numba_vectorize_example_v1.py 4.36s user 1.42s system 223% cpu 2.583 total
对于Numba的用法,我也是用的时候才去快速了解了一下它的工具书,目前暂时只用到了这两个装饰器,感觉已经使我的脚本速度大大加快了。
如果大家比较感兴趣,也可以去翻翻它的官方手册,开发者也使用实例来进行了讲解,并且有些地方也配上了运行时间对比,清楚易懂。