添加链接
link之家
链接快照平台
  • 输入网页链接,自动生成快照
  • 标签化管理网页链接
相关文章推荐
活泼的打火机  ·  PyQt5 ...·  8 月前    · 

numba使用笔记

numba使用起来感觉非常费劲,看 官方文档 中,一些numpy算子明明支持,但是实际使用时怎么用都不行,可能还是不会用。在这里就使用心得记录一下。

  1. 切记不要在nb函数中使用任何numpy算子,要不然一直报错,包括np.array(), np.reshape, np.concatenate等等。
  2. 不推荐使用python list, (fun1),推荐使用numba的List (fun2), 速度比numpy list快1倍。
import numba as nb
import numpy as np
from numba.typed import List
import time
@nb.jit('List(f4)(f4[:], f4[:], i4)', nopython=True, cache=True, parallel=False)
def fun1(a, b, len):
    res = []
    for i in range(len):
        res.append(a[i]+b[i])
    return res
@nb.jit('ListType(f4)(f4[:], f4[:], i4)', nopython=True, cache=True, parallel=False)
def fun2(a, b, len):
    res = List()
    for i in range(len):
        res.append(a[i]+b[i])
    return res
def fun3(a, b, len):
    res = []
    for i in range(len):
        res.append(a[i]+b[i])
    return res
if __name__ == '__main__':
    len = 100000000
    a = np.random.randn(len).astype(np.float32)
    b = np.random.randn(len).astype(np.float32)
    t1 = time.time()
    c1 = fun1(a, b, len)
    t2 = time.time()
    c2 = fun2(a, b, len)
    t3 = time.time()
    c3 = fun3(a, b, len)
    t4 = time.time()
    print(f'fun1 cost: {t2-t1}s, \nfun2 cost: {t3-t2}s, \nfun3 cost: {t4-t3}s.')
output:
    fun1 cost: 5.080700397491455s, 
    fun2 cost: 2.126969337463379s, 
    fun3 cost: 55.357667684555054s.

3. 每个函数前都应该指才可以实现加速,如上述代码中的f4[:], i4等,f4代表float32, i4代表int32, 同理 u1代表uint8, f8代表float64。 方括号代表数据的维度,注意,f4[:], 表示numpy数组,shape为[N, ]的类型,f4[:, :]表示shape为[N, M]的numpy数组。 单f4表示一个float32的变量。

4. 需要指定数据类型,就涉及到怎么查看数据类型,具体有两种方式:

  • 一是首先把函数前的 @nb.jit 行注释掉,当做一个纯python函数运行,使用nb.typeof(var)查看变量var的数据类型
  • 二是在函数前只加@nb.jit, 后边的括号及数据类型都不要,函数运行完后,使用fun.inspect_types()函数输出函数func中各个阶段变量的数据类型。

5. 当有2个返回值时,目前还不知道如何操作,不过可以以List的形式返回多个相同类型的数组,如下:

import numba as nb
import numpy as np
from numba.typed import List
import time
@nb.jit('List(f4[:])(f4[:], f4[:], i4)', nopython=True, cache=True, parallel=False)
def func(a, b, len):
    for i in range(len):
        a[i] += i
        b[i] += i
    res = [a, b] # python list
    return res
# 也可以用numba的List
@nb.jit('ListType(f4[:])(f4[:], f4[:], i4)', nopython=True, cache=True, parallel=False)
def func_nbList(a, b, len):
    for i in range(len):
        a[i] += i
        b[i] += i
    # res = List([a, b]) # 这种表达方式与下述等价