ASRAMV 发表于 2021-10-17 06:31

Numba算法优化相关,出现错误求助

本帖最后由 ASRAMV 于 2021-10-17 16:23 编辑

尝试用Numba优化矩阵乘法算法
1为初始,后面考虑即使编译和并行计算都没有问题
再尝试把矩阵分为l*l的小块去计算出现了错误,救助一下原因

图片

ASRAMV 发表于 2021-10-17 09:12

抱歉图片想用超链接编辑了几次好像还是失败了,直接发出来好了
https://z3.ax1x.com/2021/10/17/5YeYcR.png

阳光肥肥 发表于 2021-10-17 10:42

m/l默认返回的是float 应该用m//l
并且你发代码建议如下发,别人好复制

for i in numba.prange(m//l):

ASRAMV 发表于 2021-10-17 15:58

本帖最后由 ASRAMV 于 2021-10-17 16:23 编辑

阳光肥肥 发表于 2021-10-17 10:42
m/l默认返回的是float 应该用m//l
并且你发代码建议如下发,别人好复制


感谢回复,试了一下不会报错,但是内存一直挂掉,而且不出结果
import numpy as np
import numba

@numba.njit(parallel=True)
def mat_product(mat_a, mat_b):
    m = mat_a.shape
    n = mat_b.shape
    l = 5

    assert(mat_a.shape == mat_b.shape)

    ncol = mat_a.shape

    mat_c = np.zeros((m, n), dtype=np.float64)
   
    mat_b = np.asfortranarray(mat_b)

    for i in numba.prange(m//l):
      for j in numba.prange(n//l):
            for r in range(l):
                for c in range(l):
                  for k in range(ncol):
                        mat_c += mat_a * mat_b
    return mat_c

a = np.random.randn(50, 50)
b = np.random.randn(50, 50)

c = mat_product(a, b)
检查我用的如下命令,正常是能返回时间的(也是检测优化结果的)
time = %timeit -o mat_product(a, b)
print(time.best)


尝试数值取小一些之后成功了,非常感谢
页: [1]
查看完整版本: Numba算法优化相关,出现错误求助