突然不知道什么原因,想起来实现一个矩阵的乘法,于是用python代码实现一下。
def matrix_multiply():
a = [[1, 2], [3, 4]]
b = [[5, 6, 7], [8, 9, 10]]
m, n = len(a[0]), len(b)
if m != n:
print('we need a column equal b row!')
m, t = len(a), len(a[0])
t, n = len(b), len(b[0])
result = []
inner = [0 for i in range(n)]
for i in range(m):
result.append(inner)
for i in range(m):
for j in range(n):
for k in range(t):
result[i][j] += a[i][k] * b[k][j]
print(result)
初步看上去,代码逻辑比较简单,好像也没有任何问题,但是运行的输出结果为
[[68, 78, 88], [68, 78, 88]]
很明显结果不符合预期,肯定是哪个环节出现了问题,于是查找问题。
通过debug,发现result[0][0]-result[0][2]的范围,输出都正确,从result[1][0]开始输出结果有问题,于是大概就猜测到了问题在哪,应该是result初始化的时候出现了问题。
为了验证猜测,进行了一些二维数组初始化的尝试。
比如下面这种方式
def init_array():
a = [[0 for _ in range(5)] for _ in range(3)]
print(a)
a[0][0] = 1
print(a)
通过双重列表推导的方式可以对二维数组实现初始化,最后输出为
[[0, 0, 0, 0, 0], [0, 0, 0, 0, 0], [0, 0, 0, 0, 0]]
[[1, 0, 0, 0, 0], [0, 0, 0, 0, 0], [0, 0, 0, 0, 0]]
上面这种初始化方式,对a[0][0]值进行修改的时候,a[1][0], a[2][0]的值未发生改变。
def init_array_v2():
a = [[0] * 5 for i in range(3)]
print(a)
a[0][0] = 1
print(a)
这样也可以进行二维数组的初始化,输出结果与第一种方式相同。
def init_array_v3():
a = [[0] * 5] * 3
print(a)
a[0][0] = 1
print(a)
a[1][2] = 2
print(a)
如果按上面这种方式进行初始化,输出结果为
[[0, 0, 0, 0, 0], [0, 0, 0, 0, 0], [0, 0, 0, 0, 0]]
[[1, 0, 0, 0, 0], [1, 0, 0, 0, 0], [1, 0, 0, 0, 0]]
[[1, 0, 2, 0, 0], [1, 0, 2, 0, 0], [1, 0, 2, 0, 0]]
很明显,对a[0][0]进行修改以后,a[1][0], a[2][0]的值也发生了改变。
为什么会这样?
因为[0] * 5是一个一维数组的对象,*3的操作,是将这个对象的引用复制了三次。将a[0][0]的值修改为1以后,a[1][0], a[2][0]的值也变为了1。
我们前面矩阵乘法的初始化方式也跟上面这种方式类似。
def init_array_v4():
result = []
inner = [0 for i in range(5)]
for i in range(3):
result.append(inner)
print(result)
result[0][0] = 1
print(result)
输出结果为:
[[1, 0, 0, 0, 0], [1, 0, 0, 0, 0], [1, 0, 0, 0, 0]]
很明显result.append(inner)的时候,也是对象的引用。将result[0][0]进行修改以后,第0列的值都跟着发生变化。
那么正确的矩阵乘法实现也很容易了
def matrix_multiply():
a = [[1, 2], [3, 4]]
b = [[5, 6, 7], [8, 9, 10]]
m, n = len(a[0]), len(b)
if m != n:
print('we need a column equal b row!')
m, t = len(a), len(a[0])
t, n = len(b), len(b[0])
result = [[0 for _ in range(n)] for _ in range(m)]
for i in range(m):
for j in range(n):
for k in range(t):
result[i][j] += a[i][k] * b[k][j]
print(result)