理解 NumPy 的 einsum 函数
概述
本文是理解和使用 numpy.einsum 的简要说明与实用指南,它让我们能够使用爱因斯坦记号(Einstein notation)对多维数组进行运算。本文主要关注 einsum 的显式模式(使用 -> 并在下标字符串中明确指定输出维度),以及机器学习中常见的用例,同时也会简要介绍其他模式。
核心概念:einsum 通过下标字符串描述数组运算,重复的维度标签表示收缩(求和),未重复的标签保留在输出中。
基本用例:矩阵乘法
让我们从一个基础演示开始:使用 einsum 进行矩阵乘法。在本文中,A 和 B 将是以下矩阵:
>>> import numpy as np
>>> A = np.arange(6).reshape(2, 3)
>>> A
array([[0, 1, 2],
[3, 4, 5]])
>>> B = np.arange(12).reshape(3, 4) + 1
>>> B
array([[ 1, 2, 3, 4],
[ 5, 6, 7, 8],
[ 9, 10, 11, 12]])
A 的形状为 (2,3),B 的形状为 (3,4),我们可以执行 A @ B 得到一个 (2,4) 的矩阵。这也可以用 einsum 实现:
>>> np.einsum('ij,jk->ik', A, B)
array([[ 23, 26, 29, 32],
[ 68, 80, 92, 104]])
下标字符串解析
einsum 的第一个参数是下标字符串(subscript string),用于描述对后续操作数执行的操作:
| 部分 | 含义 |
|---|---|
ij |
第一个输入 A 的维度标签,表示形状 (i,j) |
jk |
第二个输入 B 的维度标签,表示形状 (j,k) |
-> |
分隔输入与输出 |
ik |
输出数组的维度标签,表示形状 (i,k) |
关键规则:
- 在输入中重复出现但未出现在输出中的标签(如
j)会被收缩(沿该维度求和)- 仅出现在输入或输出中一次的标签会被保留
- 输出维度的顺序由
->后的标签顺序决定
简化心智模型
以下是 einsum 工作原理的简化理解方式:
对于输出中的每个元素 out[i,k]:
out[i,k] = Σⱼ (A[i,j] × B[j,k])
即:每个输出元素是 A 的第 i 行与 B 的第 k 列的点积。
转置输出
我们可以通过调整输出标签的顺序来转置结果:
>>> np.einsum('ij,jk->ki', A, B)
array([[ 23, 68],
[ 26, 80],
[ 29, 92],
[ 32, 104]])
这等价于 (A @ B).T。
批量矩阵乘法
当输入数组的维度数(ndim)增加时,einsum 作为文档辅助工具的价值更加明显。例如,我们可能希望在单个操作中对一整批输入执行矩阵乘法:
>>> Ab = np.arange(6*6).reshape(6, 2, 3) # 形状: (6, 2, 3)
>>> Bb = np.arange(6*12).reshape(6, 3, 4) # 形状: (6, 3, 4)
这里 6 是批量维度(batch dimension)。我们将一批 6 个 (2,3) 矩阵与一批 6 个 (3,4) 矩阵逐对相乘,结果形状为 (6,2,4)。
使用 @ 运算符
NumPy 的 @ 运算符天然支持批量矩阵乘法:
>>> result = Ab @ Bb # 形状: (6, 2, 4)
收缩发生在第一个数组的最后一个维度(d=3)和第二个数组的倒数第二个维度(d=3)之间,批量维度自动广播对齐。
使用 einsum(推荐)
>>> np.einsum('bmd,bdn->bmn', Ab, Bb)
| 标签 | 可能含义 |
|---|---|
b |
batch(批量) |
m |
sequence length(序列长度) |
d |
model dimension(模型维度/深度) |
n |
output dimension(输出维度) |
✅ 优势:下标字符串让维度语义显式化,代码可读性显著提升。
⚠️ 注意:虽然
b在输入中重复出现,但它也出现在输出中,因此不会被收缩。
排序输出维度
einsum 下标中输出维度的顺序让我们不仅能进行矩阵乘法,还可以任意重排输出维度:
>>> Bb.shape
(6, 3, 4)
>>> np.einsum('ijk->kij', Bb).shape
(4, 6, 3) # 原 (i,j,k) → 新 (k,i,j)
实际应用:Transformer 中的多头注意力
以下示例取自 Noam Shazeer 的《Fast Transformer Decoding》论文:
# 维度常量
>>> m, d, k, h, b = 4, 3, 6, 5, 10
# 随机张量
>>> Pk = np.random.randn(h, d, k) # 形状: (h, d, k) - 键投影权重
>>> M = np.random.randn(b, m, d) # 形状: (b, m, d) - 输入表示
# 计算所有键:einsum 一次性完成收缩 + 维度重排
>>> np.einsum('bmd,hdk->bhmk', M, Pk).shape
(10, 5, 4, 6) # (b, h, m, k)
| 张量 | 形状 | 含义 |
|---|---|---|
M |
(b,m,d) |
批量×序列长度×模型深度 |
Pk |
(h,d,k) |
头数×模型深度×键的头维度 |
| 输出 | (b,h,m,k) |
批量×头数×序列长度×键维度 |
灵活重排:若需要不同维度顺序,只需修改输出标签:
>>> np.einsum('bmd,hdk->hbmk', M, Pk).shape # (h, b, m, k) (5, 10, 4, 6)
多维度收缩
单个 einsum 操作可以同时收缩多个维度:
>>> b, n, d, v, h = 10, 4, 3, 6, 5
>>> O = np.random.randn(b, h, n, v) # 形状: (b, h, n, v)
>>> Po = np.random.randn(h, d, v) # 形状: (h, d, v)
>>> np.einsum('bhnv,hdv->bnd', O, Po).shape
(10, 4, 3) # (b, n, d)
收缩分析
| 标签 | 出现位置 | 是否在输出 | 操作 |
|---|---|---|---|
b |
输入1 | ✅ | 保留 |
h |
输入1, 输入2 | ❌ | 收缩(求和) |
n |
输入1 | ✅ | 保留 |
v |
输入1, 输入2 | ❌ | 收缩(求和) |
d |
输入2 | ✅ | 保留 |
等价计算(伪代码):
out[b,n,d] = Σₕ Σᵥ (O[b,h,n,v] × Po[h,d,v])
转置输入
在指定 einsum 的输入时,我们可以通过重新排序维度标签来隐式转置输入:
>>> A.shape
(2, 3)
# A @ A.T:第二个输入使用 'kj' 而非 'jk'
>>> np.einsum('ij,kj->ik', A, A)
array([[ 5, 14],
[14, 50]])
# 验证
>>> (A @ A.T)
array([[ 5, 14],
[14, 50]])
原理:
j在两个输入中重复出现但未出现在输出中,因此被收缩;kj的顺序使第二个输入被视为转置。
多于两个参数
einsum 支持任意数量的输入,适合链式矩阵乘法:
>>> C = np.arange(20).reshape(4, 5) # 形状: (4, 5)
# 传统写法
>>> A @ B @ C
array([[ 900, 1010, 1120, 1230, 1340],
[2880, 3224, 3568, 3912, 4256]])
# einsum 写法
>>> np.einsum('ij,jk,kp->ip', A, B, C)
array([[ 900, 1010, 1120, 1230, 1340],
[2880, 3224, 3568, 3912, 4256]])
✅ 优势:显式的维度名称使复杂链式运算的维度流动一目了然。
einsum 的教学式实现
理解 einsum 的最佳方式之一是手动实现其核心逻辑。以下以 'ij,jk->ik'(矩阵乘法)为例:
步骤 1:解析维度大小
def calc(__a, __b):
# 从下标 'ij,jk->ik' 解析维度
i_size = __a.shape[0] # i
j_size = __a.shape[1] # j (also __b.shape[0])
k_size = __b.shape[1] # k
assert j_size == __b.shape[0], "收缩维度大小不匹配"
步骤 2:初始化输出
out = np.zeros((i_size, k_size)) # 输出形状: (i, k)
步骤 3:生成嵌套循环
# 外层循环:输出维度 (i, k)
for i in range(i_size):
for k in range(k_size):
# 内层循环:收缩维度 (j)
for j in range(j_size):
out[i, k] += __a[i, j] * __b[j, k]
return out
通用模式
对于任意下标字符串,可遵循以下规则生成代码:
- 输出维度 → 外层循环
- 收缩维度(输入中出现但输出中缺失)→ 内层循环 + 累加
- 元素访问 → 直接按标签索引:
out[i,k] += a[i,j] * b[j,k]
多维度收缩示例
对于 'bhnv,hdv->bnd':
# 输出维度: b, n, d → 外层循环
# 收缩维度: h, v → 内层循环
for b in range(b_size):
for n in range(n_size):
for d in range(d_size):
for h in range(h_size): # 收缩
for v in range(v_size): # 收缩
out[b, n, d] += __a[b, h, n, v] * __b[h, d, v]
无收缩情况:外积
对于 'i,j->ij'(外积):
def calc(__a, __b):
i_size = __a.shape[0]
j_size = __b.shape[0]
out = np.zeros((i_size, j_size))
# 无收缩维度,直接赋值
for i in range(i_size):
for j in range(j_size):
out[i, j] = __a[i] * __b[j] # 注意:不是 +=
return out
爱因斯坦记号(历史背景)
这种记号以阿尔伯特·爱因斯坦命名,他在 1916 年关于广义相对论的论文中首次引入,用于简洁表达张量运算中的嵌套求和。
物理中的经典形式
在物理学中,张量通常有下标(协变)和上标(逆变),例如:
B¹ = a₁₁A¹ + a₁₂A² + a₁₃A³ = Σⱼ₌₁³ a₁ⱼAʲ
B² = a₂₁A¹ + a₂₂A² + a₂₃A³ = Σⱼ₌₁³ a₂ⱼAʲ
B³ = a₃₁A¹ + a₃₂A² + a₃₃A³ = Σⱼ₌₁³ a₃ⱼAʲ
使用索引 i 合并为:
Bⁱ = Σⱼ₌₁³ aᵢⱼAʲ
爱因斯坦约定:当同一索引在单项中重复出现(一次上标、一次下标)时,隐含对该索引求和:
Bⁱ = aᵢⱼAʲ ← 求和符号 Σⱼ 被省略
与 NumPy einsum 的对应关系
| 爱因斯坦记号 | NumPy einsum | 含义 |
|---|---|---|
aᵢⱼAʲ |
'ij,j->i' |
矩阵 - 向量乘法 |
TᵢⱼₖSᵏˡ |
'ijk,kl->ijl' |
张量收缩 |
δᵢᵢ |
'ii->' |
矩阵的迹(对角线求和) |
关键区别:物理记号依赖上下标区分求和方向,而 NumPy einsum 仅用标签重复表示收缩,更适用于数值计算。
隐式模式 einsum
在隐式模式下,省略 -> 及输出标签,输出维度由输入标签按字母顺序自动推断。
基本示例
# 显式模式(推荐)
>>> np.einsum('ij,jk->ik', A, B) # 输出: (i,k)
# 隐式模式(等价)
>>> np.einsum('ij,jk', A, B) # 输出: (i,k) - i,k 按字母序排列
字母顺序的影响
隐式模式下,输出维度顺序由未收缩标签的字母序决定:
# 想要 (A @ B).T,即输出形状 (k,i)
>>> np.einsum('ij,jk->ki', A, B) # 显式:直接指定 'ki'
# 隐式模式:使用 'h' 替代 'k',因 'h' < 'i',输出为 (h,i)
>>> np.einsum('ij,jh', A, B) # 等价于 'ij,jh->hi'
隐式 vs 显式:对比总结
| 特性 | 隐式模式 | 显式模式 |
|---|---|---|
| 语法 | np.einsum('ij,jk', A, B) |
np.einsum('ij,jk->ik', A, B) |
| 输出顺序 | 按字母序自动排列 | 手动指定 |
| 可读性 | 较低(需记忆字母序规则) | 高(维度意图明确) |
| 适用场景 | 简单运算、快速原型 | 生产代码、论文复现 |
建议:除非有特殊需求,始终使用显式模式,以提升代码可维护性。
常见错误
❌ 错误 1:维度大小不匹配
# 错误:j 维度大小不一致
A = np.random.randn(2, 3)
B = np.random.randn(4, 5)
np.einsum('ij,jk->ik', A, B) # ValueError: 收缩维度 3 ≠ 4
✅ 解决:确保收缩维度的大小一致,或使用广播(需显式指定 ...)。
❌ 错误 2:隐式模式的字母序陷阱
# 以为输出是 (i,j),实际是 (j,i) 因为 'j' < 'i'
result = np.einsum('ij', A) # A 形状 (2,3)
print(result.shape) # (2,3) ✓ 但如果是 'ji' 则输出 (3,2)
✅ 解决:始终使用显式模式避免歧义。
❌ 错误 3:广播未显式声明
# 期望广播批量维度,但未使用 '...'
A = np.random.randn(2, 3, 4) # (b, i, j)
B = np.random.randn(3, 4, 5) # (i, j, k) - 无批量维度
# 错误:维度不匹配
np.einsum('bij,ijk->bk', A, B) # ValueError
# 正确:使用 '...' 启用广播
np.einsum('...ij,...jk->...ik', A, B) # 输出: (2, 3, 5)
常用运算对照
| 运算 | NumPy 原生 | einsum 表达式 |
|---|---|---|
| 矩阵乘法 | A @ B |
'ij,jk->ik' |
| 转置 | A.T |
'ij->ji' |
| 对角线提取 | np.diag(A) |
'ii->i' |
| 矩阵迹 | np.trace(A) |
'ii->' |
| 向量内积 | np.inner(a,b) |
'i,i->' |
| 向量外积 | np.outer(a,b) |
'i,j->ij' |
| 批量矩阵乘 | Ab @ Bb |
'bmd,bdn->bmn' |
| 多维权重求和 | - | 'bhnv,hdv->bnd' |
| 逐元素乘 + 求和 | np.sum(A*B) |
'ij,ij->' |