理解 NumPy 的 einsum 函数


概述

本文是理解和使用 numpy.einsum 的简要说明与实用指南,它让我们能够使用爱因斯坦记号(Einstein notation)对多维数组进行运算。本文主要关注 einsum 的显式模式(使用 -> 并在下标字符串中明确指定输出维度),以及机器学习中常见的用例,同时也会简要介绍其他模式。

核心概念:einsum 通过下标字符串描述数组运算,重复的维度标签表示收缩(求和),未重复的标签保留在输出中。


基本用例:矩阵乘法

让我们从一个基础演示开始:使用 einsum 进行矩阵乘法。在本文中,AB 将是以下矩阵:

>>> import numpy as np
>>> A = np.arange(6).reshape(2, 3)
>>> A
array([[0, 1, 2],
       [3, 4, 5]])

>>> B = np.arange(12).reshape(3, 4) + 1
>>> B
array([[ 1,  2,  3,  4],
       [ 5,  6,  7,  8],
       [ 9, 10, 11, 12]])

A 的形状为 (2,3)B 的形状为 (3,4),我们可以执行 A @ B 得到一个 (2,4) 的矩阵。这也可以用 einsum 实现:

>>> np.einsum('ij,jk->ik', A, B)
array([[ 23,  26,  29,  32],
       [ 68,  80,  92, 104]])

下标字符串解析

einsum 的第一个参数是下标字符串(subscript string),用于描述对后续操作数执行的操作:

部分 含义
ij 第一个输入 A 的维度标签,表示形状 (i,j)
jk 第二个输入 B 的维度标签,表示形状 (j,k)
-> 分隔输入与输出
ik 输出数组的维度标签,表示形状 (i,k)

关键规则

  1. 在输入中重复出现未出现在输出中的标签(如 j)会被收缩(沿该维度求和)
  2. 仅出现在输入或输出中一次的标签会被保留
  3. 输出维度的顺序由 -> 后的标签顺序决定

简化心智模型

以下是 einsum 工作原理的简化理解方式:

对于输出中的每个元素 out[i,k]:
    out[i,k] = Σⱼ (A[i,j] × B[j,k])

即:每个输出元素是 A 的第 i 行与 B 的第 k 列的点积。

转置输出

我们可以通过调整输出标签的顺序来转置结果:

>>> np.einsum('ij,jk->ki', A, B)
array([[ 23,  68],
       [ 26,  80],
       [ 29,  92],
       [ 32, 104]])

这等价于 (A @ B).T


批量矩阵乘法

当输入数组的维度数(ndim)增加时,einsum 作为文档辅助工具的价值更加明显。例如,我们可能希望在单个操作中对一整批输入执行矩阵乘法:

>>> Ab = np.arange(6*6).reshape(6, 2, 3)   # 形状: (6, 2, 3)
>>> Bb = np.arange(6*12).reshape(6, 3, 4)  # 形状: (6, 3, 4)

这里 6批量维度(batch dimension)。我们将一批 6 个 (2,3) 矩阵与一批 6 个 (3,4) 矩阵逐对相乘,结果形状为 (6,2,4)

使用 @ 运算符

NumPy 的 @ 运算符天然支持批量矩阵乘法:

>>> result = Ab @ Bb  # 形状: (6, 2, 4)

收缩发生在第一个数组的最后一个维度(d=3)和第二个数组的倒数第二个维度(d=3)之间,批量维度自动广播对齐。

使用 einsum(推荐)

>>> np.einsum('bmd,bdn->bmn', Ab, Bb)
标签 可能含义
b batch(批量)
m sequence length(序列长度)
d model dimension(模型维度/深度)
n output dimension(输出维度)

优势:下标字符串让维度语义显式化,代码可读性显著提升。

⚠️ 注意:虽然 b 在输入中重复出现,但它也出现在输出中,因此不会被收缩


排序输出维度

einsum 下标中输出维度的顺序让我们不仅能进行矩阵乘法,还可以任意重排输出维度

>>> Bb.shape
(6, 3, 4)

>>> np.einsum('ijk->kij', Bb).shape
(4, 6, 3)  # 原 (i,j,k) → 新 (k,i,j)

实际应用:Transformer 中的多头注意力

以下示例取自 Noam Shazeer 的《Fast Transformer Decoding》论文:

# 维度常量
>>> m, d, k, h, b = 4, 3, 6, 5, 10

# 随机张量
>>> Pk = np.random.randn(h, d, k)  # 形状: (h, d, k) - 键投影权重
>>> M = np.random.randn(b, m, d)   # 形状: (b, m, d) - 输入表示

# 计算所有键:einsum 一次性完成收缩 + 维度重排
>>> np.einsum('bmd,hdk->bhmk', M, Pk).shape
(10, 5, 4, 6)  # (b, h, m, k)
张量 形状 含义
M (b,m,d) 批量×序列长度×模型深度
Pk (h,d,k) 头数×模型深度×键的头维度
输出 (b,h,m,k) 批量×头数×序列长度×键维度

灵活重排:若需要不同维度顺序,只需修改输出标签:

>>> np.einsum('bmd,hdk->hbmk', M, Pk).shape  # (h, b, m, k)
(5, 10, 4, 6)

多维度收缩

单个 einsum 操作可以同时收缩多个维度

>>> b, n, d, v, h = 10, 4, 3, 6, 5
>>> O = np.random.randn(b, h, n, v)   # 形状: (b, h, n, v)
>>> Po = np.random.randn(h, d, v)      # 形状: (h, d, v)

>>> np.einsum('bhnv,hdv->bnd', O, Po).shape
(10, 4, 3)  # (b, n, d)

收缩分析

标签 出现位置 是否在输出 操作
b 输入1 保留
h 输入1, 输入2 收缩(求和)
n 输入1 保留
v 输入1, 输入2 收缩(求和)
d 输入2 保留

等价计算(伪代码):

out[b,n,d] = Σₕ Σᵥ (O[b,h,n,v] × Po[h,d,v])

转置输入

在指定 einsum 的输入时,我们可以通过重新排序维度标签来隐式转置输入:

>>> A.shape
(2, 3)

# A @ A.T:第二个输入使用 'kj' 而非 'jk'
>>> np.einsum('ij,kj->ik', A, A)
array([[ 5, 14],
       [14, 50]])

# 验证
>>> (A @ A.T)
array([[ 5, 14],
       [14, 50]])

原理j 在两个输入中重复出现但未出现在输出中,因此被收缩;kj 的顺序使第二个输入被视为转置。


多于两个参数

einsum 支持任意数量的输入,适合链式矩阵乘法:

>>> C = np.arange(20).reshape(4, 5)  # 形状: (4, 5)

# 传统写法
>>> A @ B @ C
array([[ 900, 1010, 1120, 1230, 1340],
       [2880, 3224, 3568, 3912, 4256]])

# einsum 写法
>>> np.einsum('ij,jk,kp->ip', A, B, C)
array([[ 900, 1010, 1120, 1230, 1340],
       [2880, 3224, 3568, 3912, 4256]])

优势:显式的维度名称使复杂链式运算的维度流动一目了然。


einsum 的教学式实现

理解 einsum 的最佳方式之一是手动实现其核心逻辑。以下以 'ij,jk->ik'(矩阵乘法)为例:

步骤 1:解析维度大小

def calc(__a, __b):
    # 从下标 'ij,jk->ik' 解析维度
    i_size = __a.shape[0]  # i
    j_size = __a.shape[1]  # j (also __b.shape[0])
    k_size = __b.shape[1]  # k
    
    assert j_size == __b.shape[0], "收缩维度大小不匹配"

步骤 2:初始化输出

    out = np.zeros((i_size, k_size))  # 输出形状: (i, k)

步骤 3:生成嵌套循环

    # 外层循环:输出维度 (i, k)
    for i in range(i_size):
        for k in range(k_size):
            # 内层循环:收缩维度 (j)
            for j in range(j_size):
                out[i, k] += __a[i, j] * __b[j, k]
    return out

通用模式

对于任意下标字符串,可遵循以下规则生成代码:

  1. 输出维度 → 外层循环
  2. 收缩维度(输入中出现但输出中缺失)→ 内层循环 + 累加
  3. 元素访问 → 直接按标签索引:out[i,k] += a[i,j] * b[j,k]

多维度收缩示例

对于 'bhnv,hdv->bnd'

# 输出维度: b, n, d → 外层循环
# 收缩维度: h, v → 内层循环
for b in range(b_size):
    for n in range(n_size):
        for d in range(d_size):
            for h in range(h_size):      # 收缩
                for v in range(v_size):  # 收缩
                    out[b, n, d] += __a[b, h, n, v] * __b[h, d, v]

无收缩情况:外积

对于 'i,j->ij'(外积):

def calc(__a, __b):
    i_size = __a.shape[0]
    j_size = __b.shape[0]
    out = np.zeros((i_size, j_size))
    
    # 无收缩维度,直接赋值
    for i in range(i_size):
        for j in range(j_size):
            out[i, j] = __a[i] * __b[j]  # 注意:不是 +=
    return out

爱因斯坦记号(历史背景)

这种记号以阿尔伯特·爱因斯坦命名,他在 1916 年关于广义相对论的论文中首次引入,用于简洁表达张量运算中的嵌套求和。

物理中的经典形式

在物理学中,张量通常有下标(协变)和上标(逆变),例如:

B¹ = a₁₁A¹ + a₁₂A² + a₁₃A³ = Σⱼ₌₁³ a₁ⱼAʲ
B² = a₂₁A¹ + a₂₂A² + a₂₃A³ = Σⱼ₌₁³ a₂ⱼAʲ  
B³ = a₃₁A¹ + a₃₂A² + a₃₃A³ = Σⱼ₌₁³ a₃ⱼAʲ

使用索引 i 合并为:

Bⁱ = Σⱼ₌₁³ aᵢⱼAʲ

爱因斯坦约定:当同一索引在单项中重复出现(一次上标、一次下标)时,隐含对该索引求和

Bⁱ = aᵢⱼAʲ   ← 求和符号 Σⱼ 被省略

与 NumPy einsum 的对应关系

爱因斯坦记号 NumPy einsum 含义
aᵢⱼAʲ 'ij,j->i' 矩阵 - 向量乘法
TᵢⱼₖSᵏˡ 'ijk,kl->ijl' 张量收缩
δᵢᵢ 'ii->' 矩阵的迹(对角线求和)

关键区别:物理记号依赖上下标区分求和方向,而 NumPy einsum 仅用标签重复表示收缩,更适用于数值计算。


隐式模式 einsum

隐式模式下,省略 -> 及输出标签,输出维度由输入标签按字母顺序自动推断。

基本示例

# 显式模式(推荐)
>>> np.einsum('ij,jk->ik', A, B)  # 输出: (i,k)

# 隐式模式(等价)
>>> np.einsum('ij,jk', A, B)      # 输出: (i,k) - i,k 按字母序排列

字母顺序的影响

隐式模式下,输出维度顺序由未收缩标签的字母序决定:

# 想要 (A @ B).T,即输出形状 (k,i)
>>> np.einsum('ij,jk->ki', A, B)        # 显式:直接指定 'ki'

# 隐式模式:使用 'h' 替代 'k',因 'h' < 'i',输出为 (h,i)
>>> np.einsum('ij,jh', A, B)            # 等价于 'ij,jh->hi'

隐式 vs 显式:对比总结

特性 隐式模式 显式模式
语法 np.einsum('ij,jk', A, B) np.einsum('ij,jk->ik', A, B)
输出顺序 按字母序自动排列 手动指定
可读性 较低(需记忆字母序规则) (维度意图明确)
适用场景 简单运算、快速原型 生产代码、论文复现

建议:除非有特殊需求,始终使用显式模式,以提升代码可维护性。


常见错误

❌ 错误 1:维度大小不匹配

# 错误:j 维度大小不一致
A = np.random.randn(2, 3)
B = np.random.randn(4, 5)
np.einsum('ij,jk->ik', A, B)  # ValueError: 收缩维度 3 ≠ 4

解决:确保收缩维度的大小一致,或使用广播(需显式指定 ...)。

❌ 错误 2:隐式模式的字母序陷阱

# 以为输出是 (i,j),实际是 (j,i) 因为 'j' < 'i'
result = np.einsum('ij', A)  # A 形状 (2,3)
print(result.shape)  # (2,3) ✓ 但如果是 'ji' 则输出 (3,2)

解决:始终使用显式模式避免歧义。

❌ 错误 3:广播未显式声明

# 期望广播批量维度,但未使用 '...'
A = np.random.randn(2, 3, 4)   # (b, i, j)
B = np.random.randn(3, 4, 5)   # (i, j, k) - 无批量维度

# 错误:维度不匹配
np.einsum('bij,ijk->bk', A, B)  # ValueError

# 正确:使用 '...' 启用广播
np.einsum('...ij,...jk->...ik', A, B)  # 输出: (2, 3, 5)

常用运算对照

运算 NumPy 原生 einsum 表达式
矩阵乘法 A @ B 'ij,jk->ik'
转置 A.T 'ij->ji'
对角线提取 np.diag(A) 'ii->i'
矩阵迹 np.trace(A) 'ii->'
向量内积 np.inner(a,b) 'i,i->'
向量外积 np.outer(a,b) 'i,j->ij'
批量矩阵乘 Ab @ Bb 'bmd,bdn->bmn'
多维权重求和 - 'bhnv,hdv->bnd'
逐元素乘 + 求和 np.sum(A*B) 'ij,ij->'