您的当前位置:首页正文

python初始化二维数据

2024-11-23 来源:个人技术集锦

1.遇到的问题

突然不知道什么原因,想起来实现一个矩阵的乘法,于是用python代码实现一下。

def matrix_multiply():
	a = [[1, 2], [3, 4]]
	b = [[5, 6, 7], [8, 9, 10]]

	m, n = len(a[0]), len(b)
	if m != n:
		print('we need a column equal b row!')
	m, t = len(a), len(a[0])
	t, n = len(b), len(b[0])

	result = []
	inner = [0 for i in range(n)]
	for i in range(m):
		result.append(inner)

	for i in range(m):
		for j in range(n):
			for k in range(t):
				result[i][j] += a[i][k] * b[k][j]

	print(result)

初步看上去,代码逻辑比较简单,好像也没有任何问题,但是运行的输出结果为

[[68, 78, 88], [68, 78, 88]]

很明显结果不符合预期,肯定是哪个环节出现了问题,于是查找问题。
通过debug,发现result[0][0]-result[0][2]的范围,输出都正确,从result[1][0]开始输出结果有问题,于是大概就猜测到了问题在哪,应该是result初始化的时候出现了问题。

2.二维数组初始化方式1

为了验证猜测,进行了一些二维数组初始化的尝试。

比如下面这种方式

def init_array():
	a = [[0 for _ in range(5)] for _ in range(3)]
	print(a)
	a[0][0] = 1
	print(a)

通过双重列表推导的方式可以对二维数组实现初始化,最后输出为

[[0, 0, 0, 0, 0], [0, 0, 0, 0, 0], [0, 0, 0, 0, 0]]
[[1, 0, 0, 0, 0], [0, 0, 0, 0, 0], [0, 0, 0, 0, 0]]

上面这种初始化方式,对a[0][0]值进行修改的时候,a[1][0], a[2][0]的值未发生改变。

3.二维数组初始化方式2

def init_array_v2():
	a = [[0] * 5 for i in range(3)]
	print(a)
	a[0][0] = 1
	print(a)

这样也可以进行二维数组的初始化,输出结果与第一种方式相同。

4.二维数组初始化方式3

def init_array_v3():
	a = [[0] * 5] * 3
	print(a)
	a[0][0] = 1
	print(a)
	a[1][2] = 2
	print(a)

如果按上面这种方式进行初始化,输出结果为

[[0, 0, 0, 0, 0], [0, 0, 0, 0, 0], [0, 0, 0, 0, 0]]
[[1, 0, 0, 0, 0], [1, 0, 0, 0, 0], [1, 0, 0, 0, 0]]
[[1, 0, 2, 0, 0], [1, 0, 2, 0, 0], [1, 0, 2, 0, 0]]

很明显,对a[0][0]进行修改以后,a[1][0], a[2][0]的值也发生了改变。
为什么会这样?
因为[0] * 5是一个一维数组的对象,*3的操作,是将这个对象的引用复制了三次。将a[0][0]的值修改为1以后,a[1][0], a[2][0]的值也变为了1。

我们前面矩阵乘法的初始化方式也跟上面这种方式类似。

def init_array_v4():
	result = []
	inner = [0 for i in range(5)]
	for i in range(3):
		result.append(inner)
	print(result)
	result[0][0] = 1
	print(result)

输出结果为:

[[1, 0, 0, 0, 0], [1, 0, 0, 0, 0], [1, 0, 0, 0, 0]]

很明显result.append(inner)的时候,也是对象的引用。将result[0][0]进行修改以后,第0列的值都跟着发生变化。

5.正确的矩阵乘法实现

那么正确的矩阵乘法实现也很容易了

def matrix_multiply():
	a = [[1, 2], [3, 4]]
	b = [[5, 6, 7], [8, 9, 10]]
	m, n = len(a[0]), len(b)
	if m != n:
		print('we need a column equal b row!')
	m, t = len(a), len(a[0])
	t, n = len(b), len(b[0])
	result = [[0 for _ in range(n)] for _ in range(m)]
	for i in range(m):
		for j in range(n):
			for k in range(t):
				result[i][j] += a[i][k] * b[k][j]

	print(result)
显示全文