Numba算法优化相关,出现错误求助
本帖最后由 ASRAMV 于 2021-10-17 16:23 编辑尝试用Numba优化矩阵乘法算法
1为初始,后面考虑即使编译和并行计算都没有问题
再尝试把矩阵分为l*l的小块去计算出现了错误,救助一下原因
图片 抱歉图片想用超链接编辑了几次好像还是失败了,直接发出来好了
https://z3.ax1x.com/2021/10/17/5YeYcR.png m/l默认返回的是float 应该用m//l
并且你发代码建议如下发,别人好复制
for i in numba.prange(m//l): 本帖最后由 ASRAMV 于 2021-10-17 16:23 编辑
阳光肥肥 发表于 2021-10-17 10:42
m/l默认返回的是float 应该用m//l
并且你发代码建议如下发,别人好复制
感谢回复,试了一下不会报错,但是内存一直挂掉,而且不出结果
import numpy as np
import numba
@numba.njit(parallel=True)
def mat_product(mat_a, mat_b):
m = mat_a.shape
n = mat_b.shape
l = 5
assert(mat_a.shape == mat_b.shape)
ncol = mat_a.shape
mat_c = np.zeros((m, n), dtype=np.float64)
mat_b = np.asfortranarray(mat_b)
for i in numba.prange(m//l):
for j in numba.prange(n//l):
for r in range(l):
for c in range(l):
for k in range(ncol):
mat_c += mat_a * mat_b
return mat_c
a = np.random.randn(50, 50)
b = np.random.randn(50, 50)
c = mat_product(a, b)
检查我用的如下命令,正常是能返回时间的(也是检测优化结果的)
time = %timeit -o mat_product(a, b)
print(time.best)
尝试数值取小一些之后成功了,非常感谢
页:
[1]