การเข้าใจความคิดของnumpy.einsum()
เรื่องง่ายมากถ้าคุณเข้าใจมันอย่างสังหรณ์ใจ เป็นตัวอย่างที่ขอเริ่มต้นด้วยคำอธิบายง่ายๆที่เกี่ยวข้องกับการคูณเมทริกซ์
ในการใช้งานnumpy.einsum()
สิ่งที่คุณต้องทำคือส่งสตริงตัวห้อยที่เรียกว่าเป็นอาร์กิวเมนต์ตามด้วยอินพุตอาร์เรย์ของคุณ
สมมติว่าคุณมีสอง 2D อาร์เรย์A
และB
และคุณต้องการที่จะทำคูณเมทริกซ์ ดังนั้นคุณทำ:
np.einsum("ij, jk -> ik", A, B)
นี่สตริงห้อย ij
สอดคล้องกับอาร์เรย์A
ขณะที่สตริงห้อย สอดคล้องกับอาร์เรย์jk
B
สิ่งที่สำคัญที่สุดที่ควรทราบคือจำนวนอักขระในสตริงตัวห้อย แต่ละตัวต้องตรงกับขนาดของอาเรย์ (นั่นคือสองตัวอักษรสำหรับ 2D อาร์เรย์สามตัวอักษรสำหรับแบบสามมิติและอื่น ๆ ) และถ้าคุณทำซ้ำตัวอักษรระหว่างสตริงตัวห้อย ( j
ในกรณีของเรา) นั่นหมายความว่าคุณต้องการein
ผลรวมที่เกิดขึ้นตามมิติเหล่านั้น ดังนั้นพวกเขาจะได้รับผลรวมลดลง (เช่นขนาดนั้นจะหายไป )
สตริงห้อยหลังจากนี้->
จะเป็นอาร์เรย์ผลลัพธ์ของเรา หากคุณปล่อยว่างไว้ทุกอย่างจะถูกรวมและส่งคืนค่าสเกลาร์เป็นผลลัพธ์ อื่นอาร์เรย์ผลจะมีขนาดตามที่สตริงห้อย ik
ในตัวอย่างของเรามันจะเป็น นี่คือสัญชาตญาณเพราะเรารู้ว่าสำหรับการคูณเมทริกซ์จำนวนคอลัมน์ในอาร์เรย์A
ต้องตรงกับจำนวนแถวในอาเรย์B
ซึ่งเป็นสิ่งที่เกิดขึ้นที่นี่ (เช่นเราเข้ารหัสความรู้นี้โดยการทำซ้ำอักขระj
ในสตริงตัวห้อย )
ต่อไปนี้เป็นตัวอย่างเพิ่มเติมที่แสดงให้เห็นถึงการใช้งาน / กำลังของnp.einsum()
ในการดำเนินการเมตริกซ์ทั่วไปหรือการดำเนินการอาร์เรย์บางลำดับ
ปัจจัยการผลิต
# a vector
In [197]: vec
Out[197]: array([0, 1, 2, 3])
# an array
In [198]: A
Out[198]:
array([[11, 12, 13, 14],
[21, 22, 23, 24],
[31, 32, 33, 34],
[41, 42, 43, 44]])
# another array
In [199]: B
Out[199]:
array([[1, 1, 1, 1],
[2, 2, 2, 2],
[3, 3, 3, 3],
[4, 4, 4, 4]])
1) การคูณเมทริกซ์ (คล้ายกับnp.matmul(arr1, arr2)
)
In [200]: np.einsum("ij, jk -> ik", A, B)
Out[200]:
array([[130, 130, 130, 130],
[230, 230, 230, 230],
[330, 330, 330, 330],
[430, 430, 430, 430]])
2) แยกองค์ประกอบตามแนวขวางหลัก (คล้ายกับnp.diag(arr)
)
In [202]: np.einsum("ii -> i", A)
Out[202]: array([11, 22, 33, 44])
3) ผลิตภัณฑ์ Hadamard (เช่นผลิตภัณฑ์องค์ประกอบที่ชาญฉลาดของสองอาร์เรย์) (คล้ายกับarr1 * arr2
)
In [203]: np.einsum("ij, ij -> ij", A, B)
Out[203]:
array([[ 11, 12, 13, 14],
[ 42, 44, 46, 48],
[ 93, 96, 99, 102],
[164, 168, 172, 176]])
4) การยกกำลังสององค์ประกอบ (คล้ายกับnp.square(arr)
หรือarr ** 2
)
In [210]: np.einsum("ij, ij -> ij", B, B)
Out[210]:
array([[ 1, 1, 1, 1],
[ 4, 4, 4, 4],
[ 9, 9, 9, 9],
[16, 16, 16, 16]])
5) ติดตาม (เช่นผลรวมขององค์ประกอบหลักเส้นทแยงมุม) (คล้ายกับnp.trace(arr)
)
In [217]: np.einsum("ii -> ", A)
Out[217]: 110
6) เมทริกซ์ขนย้าย (คล้ายกับnp.transpose(arr)
)
In [221]: np.einsum("ij -> ji", A)
Out[221]:
array([[11, 21, 31, 41],
[12, 22, 32, 42],
[13, 23, 33, 43],
[14, 24, 34, 44]])
7) สินค้าชั้นนอก (จากเวกเตอร์) (คล้ายกับnp.outer(vec1, vec2)
)
In [255]: np.einsum("i, j -> ij", vec, vec)
Out[255]:
array([[0, 0, 0, 0],
[0, 1, 2, 3],
[0, 2, 4, 6],
[0, 3, 6, 9]])
8) ผลิตภัณฑ์ชั้นใน (ของเวกเตอร์) (คล้ายกับnp.inner(vec1, vec2)
)
In [256]: np.einsum("i, i -> ", vec, vec)
Out[256]: 14
9) ผลรวมตามแกน 0 (คล้ายกับnp.sum(arr, axis=0)
)
In [260]: np.einsum("ij -> j", B)
Out[260]: array([10, 10, 10, 10])
10) ผลรวมตามแกน 1 (คล้ายกับnp.sum(arr, axis=1)
)
In [261]: np.einsum("ij -> i", B)
Out[261]: array([ 4, 8, 12, 16])
11) การคูณเมทริกซ์แบทช์
In [287]: BM = np.stack((A, B), axis=0)
In [288]: BM
Out[288]:
array([[[11, 12, 13, 14],
[21, 22, 23, 24],
[31, 32, 33, 34],
[41, 42, 43, 44]],
[[ 1, 1, 1, 1],
[ 2, 2, 2, 2],
[ 3, 3, 3, 3],
[ 4, 4, 4, 4]]])
In [289]: BM.shape
Out[289]: (2, 4, 4)
# batch matrix multiply using einsum
In [292]: BMM = np.einsum("bij, bjk -> bik", BM, BM)
In [293]: BMM
Out[293]:
array([[[1350, 1400, 1450, 1500],
[2390, 2480, 2570, 2660],
[3430, 3560, 3690, 3820],
[4470, 4640, 4810, 4980]],
[[ 10, 10, 10, 10],
[ 20, 20, 20, 20],
[ 30, 30, 30, 30],
[ 40, 40, 40, 40]]])
In [294]: BMM.shape
Out[294]: (2, 4, 4)
12) ผลรวมตามแกน 2 (คล้ายกับnp.sum(arr, axis=2)
)
In [330]: np.einsum("ijk -> ij", BM)
Out[330]:
array([[ 50, 90, 130, 170],
[ 4, 8, 12, 16]])
13) รวมองค์ประกอบทั้งหมดในอาร์เรย์ (คล้ายกับnp.sum(arr)
)
In [335]: np.einsum("ijk -> ", BM)
Out[335]: 480
14) ผลรวมมากกว่าหลายแกน (เช่น marginalization)
(คล้ายกับnp.sum(arr, axis=(axis0, axis1, axis2, axis3, axis4, axis6, axis7))
)
# 8D array
In [354]: R = np.random.standard_normal((3,5,4,6,8,2,7,9))
# marginalize out axis 5 (i.e. "n" here)
In [363]: esum = np.einsum("ijklmnop -> n", R)
# marginalize out axis 5 (i.e. sum over rest of the axes)
In [364]: nsum = np.sum(R, axis=(0,1,2,3,4,6,7))
In [365]: np.allclose(esum, nsum)
Out[365]: True
15) ผลิตภัณฑ์ Double Dot (คล้ายกับnp.sum (Hadamard-product) cf. 3 )
In [772]: A
Out[772]:
array([[1, 2, 3],
[4, 2, 2],
[2, 3, 4]])
In [773]: B
Out[773]:
array([[1, 4, 7],
[2, 5, 8],
[3, 6, 9]])
In [774]: np.einsum("ij, ij -> ", A, B)
Out[774]: 124
16) การคูณอาร์เรย์ 2 มิติและ 3 มิติ
การคูณนั้นมีประโยชน์มากเมื่อแก้ระบบสมการเชิงเส้น ( Ax = b ) ที่คุณต้องการตรวจสอบผลลัพธ์
# inputs
In [115]: A = np.random.rand(3,3)
In [116]: b = np.random.rand(3, 4, 5)
# solve for x
In [117]: x = np.linalg.solve(A, b.reshape(b.shape[0], -1)).reshape(b.shape)
# 2D and 3D array multiplication :)
In [118]: Ax = np.einsum('ij, jkl', A, x)
# indeed the same!
In [119]: np.allclose(Ax, b)
Out[119]: True
ในทางกลับกันหากจำเป็นต้องใช้np.matmul()
สำหรับการตรวจสอบนี้เราต้องดำเนินการสองอย่างreshape
เพื่อให้ได้ผลลัพธ์เดียวกันเช่น:
# reshape 3D array `x` to 2D, perform matmul
# then reshape the resultant array to 3D
In [123]: Ax_matmul = np.matmul(A, x.reshape(x.shape[0], -1)).reshape(x.shape)
# indeed correct!
In [124]: np.allclose(Ax, Ax_matmul)
Out[124]: True
โบนัส : อ่านคณิตศาสตร์เพิ่มเติมได้ที่นี่: Einstein-Summationและที่นี่แน่นอน: Tensor-Notation
(A * B)^T
B^T * A^T