好友
阅读权限10
听众
最后登录1970-1-1
|
象相合
发表于 2018-5-17 00:11
本帖最后由 象相合 于 2018-5-17 15:53 编辑
大家好久不见,最近在摸《算法导论》,简单的算法感觉别人都写过就不想发,写到Strassen算法的时候,发现GitHub上python版本没有正宗的写法,只有捞的写法,于是就分享一下代码=w=
Strassen算法是做[2^n]*[2^n]矩阵相乘的算法,局限性很大,但比较具有挑战性,因为需要使用【分治策略】且对矩阵和下标比较熟悉, 不然总会有这样或那样的错误,而且是矩阵调试很混乱,楼主比较菜写了个前置算法 递归的matlab版本(square_matrix_multiply_recursive) 再转成python, 最后补成Strassen算法才弄好,有闲时间的同学们可以玩一玩w
贴代码:
[Python] 纯文本查看 复制代码
import numpy as np
import math
def strassen_algorithm(A, B, L1, L2):
n = int(L1[1]) - int(L1[0]) + 1
d = math.floor(n / 2 - 1) # the half length of matrix width/height minus one
# if the matrix's length is 1, then set whose value
C = np.zeros((n, n))
if n <= 0:
return
if n == 1:
C[n - 1, n - 1] = A[L1[3], L1[0]] * B[L2[3], L2[0]]
else:
a11 = [L1[0], L1[0] + d, L1[2], L1[2] + d]
a12 = [int((L1[0] + L1[1] + 1) / 2), L1[1], L1[2], L1[2] + d]
a21 = [L1[0], L1[0] + d, int((L1[2] + L1[3] + 1) / 2), L1[3]]
a22 = [int((L1[0] + L1[1] + 1) / 2), L1[1], int((L1[3] + L1[2] + 1) / 2), L1[3]]
b11 = [L2[0], L2[0] + d, L2[2], L2[2] + d]
b21 = [L2[0], L2[0] + d, int((L2[2] + L2[3] + 1) / 2), L2[3]]
b12 = [int((L2[0] + L2[1] + 1) / 2), L2[1], L2[2], L2[2] + d]
b22 = [int((L2[0] + L2[1] + 1) / 2), L2[1], int((L2[2] + L2[3] + 1) / 2), L2[3]]
P1 = strassen_algorithm(A, B, a11, b12) - strassen_algorithm(A, B, a11, b22)
P2 = strassen_algorithm(A, B, a11, b22) + strassen_algorithm(A, B, a12, b22)
P3 = strassen_algorithm(A, B, a21, b11) + strassen_algorithm(A, B, a22, b11)
P4 = strassen_algorithm(A, B, a22, b21) - strassen_algorithm(A, B, a22, b11)
P5 = strassen_algorithm(A, B, a11, b11) + strassen_algorithm(A, B, a11, b22) + \
strassen_algorithm(A, B, a22, b11) + strassen_algorithm(A, B, a22, b22)
P6 = strassen_algorithm(A, B, a12, b21) + strassen_algorithm(A, B, a12, b22) - \
strassen_algorithm(A, B, a22, b21) - strassen_algorithm(A, B, a22, b22)
P7 = strassen_algorithm(A, B, a11, b11) + strassen_algorithm(A, B, a11, b12) - \
strassen_algorithm(A, B, a21, b11) - strassen_algorithm(A, B, a21, b12)
C[0:d + 1, 0:d + 1] = P5 + P4 - P2 + P6
C[0:d + 1, int(n / 2):int(n / 2 + d + 1)] = P1 + P2
C[int(n / 2):int(n / 2 + d + 1), 0:d + 1] = P3 + P4
C[int(n / 2):int(n / 2 + d + 1), int(n / 2):int(n / 2 + d + 1)] = P5 + P1 - P3 - P7
return C
n = 8
A = np.random.randint(0, 100, size=[n, n])
B = np.random.randint(0, 100, size=[n, n])
L1 = [0, n - 1, 0, n - 1]
L2 = [0, n - 1, 0, n - 1]
ret = strassen_algorithm(A, B, L1, L2)
# ret2 = np.dot(A, B)
# print(ret2)
print(ret)
为啥这种是正宗的写法呢?因为在15-22行中,a11,a12,...,b21,b22 这些变量存放的是矩阵的下标。我看了大佬python的写法,也就随便直接实例数组了。这种实例数组的写法比用下标的更慢,甚至比矩阵乘法的标准写法三次循环O(n^3)更慢。
代码分享地址:https://github.com/EleComb/IntroductionToAlgorithms/tree/master/chapter_4_DivideStrategy
里面也有之前的matlab写的和转成python的前置算法,欢迎各位互相学习指正=w=
|
免费评分
-
查看全部评分
|