自动微分与张量操作
PyTorch 的 autograd 是一个自动求导引擎,autograd 通过在执行模型前向传播时构建反向图来工作。每次执行一个操作 Op 时,我们都会生成一个 autograd 节点 OpBackward,并将其记录在结果张量的 grad_fn 属性上。这里用图示会更有帮助。考虑下面这个程序:
x = torch.randn(8, requires_grad=True)
y = x ** 2
z = y * 2
这个过程的「前向传播」可以想象成一条流水线:
输入 x → [平方] → 中间结果 y → [乘2] → 输出 z
那么我们期望构建出如下的 autograd 图:
x -- grad_fn --> AccumulateGrad
^
|
y -- grad_fn --> PowBackward0
^
|
z -- grad_fn --> MulBackward0
关于上图的一些细节说明:
-
如果实际尝试打印
x.grad_fn,它会输出None。这是因为 PyTorch 隐藏了一个内部实现细节:每个叶子变量(leaf variable,即requires_grad为 True 但不是由其他变量计算得到的变量)都关联着一个AccumulateGrad节点。在底层,这个节点确实存在,实际上可以通过类似y.grad_fn.next_functions[0][0]的方式访问它。直观地说,AccumulateGrad节点的作用是将流入的梯度累加到对应叶子张量的grad字段中。这种设计使得 autograd 引擎的运行时实现更为简洁,因为无需为叶子节点设置特殊处理逻辑。 -
图中的箭头表示强引用(owning pointers)。在底层,autograd 使用引用计数指针实现(就像 Python 一样),避免内存循环和全局状态将在我们的实现策略中扮演重要角色。
-
每个反向节点名称(如 MulBackward0、PowBackward0 等)都对应一个生成的 C++ 类,用于表示该 autograd 节点。整数后缀用于区分同一函数的不同重载版本。在 PyTorch 构建目录下的
torch/csrc/autograd/generated/Functions.h或torch/csrc/autograd/generated/VariableTypeEverything.cpp文件中可以查看这些生成的 C++ 代码。
现在,在构建反向图的过程中,我们实际上是在为一个_隐式的_前向图编写导数,该前向图对应于前向传播中发生的纯(非可变)操作序列。这个前向图从未被具体化(除非你使用了某种追踪器),但思考这个图是有帮助的,因为从某种意义上说,autograd 所做的就是获取这个前向图,反转所有箭头的方向,并将前向节点替换为反向节点。
前向传播 反向传播
===================================
x x.grad
| ^
V |
[ Pow ] [ PowBackward ]
| ^
| |
y y.grad
| |
V |
[ Mul ] [ MulBackward ]
| ^
V |
z z.grad
(注意:除非你使用 retain_grad=True 运行 autograd,否则 y.grad 和 z.grad 实际上不会被填充。)
需要强调的是,这个隐式的前向图_始终_被视为纯函数图。不存在 MulInplaceBackward 这样的节点。要理解可变操作发生时发生了什么,可以想象一个_没有_可变操作的程序版本:autograd 图正是这个版本的反向。考虑:
x = torch.randn(8, requires_grad=True)
y = x ** 2
y.mul_(2)
这在语义上等价于:
x = torch.randn(8, requires_grad=True)
y = x ** 2
y2 = y * 2
# 假设 y 没有其他别名,现在将所有对 y 的使用替换为 y2
因此,这才是会在反向图中编码的程序。那么这究竟是如何实现的呢?在可变操作发生之前,我们有一个这样的图:
x -- grad_fn --> AccumulateGrad
^
|
y -- grad_fn --> PowBackward0
当我们执行乘法操作时,我们会生成新的节点 MulBackward0:
x -- grad_fn --> AccumulateGrad
^
|
y -- grad_fn --> PowBackward0
^
|
MulBackward0
但由于我们对 y 进行了原地修改,我们还必须原地修改 grad_fn 指针,使其指向这个新节点:
x -- grad_fn --> AccumulateGrad
^
|
y PowBackward0
| ^
| |
---- grad_fn --> MulBackward0
如果 y 没有任何别名,一旦我们执行了可变操作,就不再有任何张量的内容包含 pow 操作之后的值;相应地,也没有任何张量的 grad_fn 指向 PowBackward0。
总结一下,当发生可变操作时,我们需要做两件事:
- 为在张量上发生的计算生成一个新的反向节点。
- 修改受可变操作影响的张量的
grad_fn,使其指向这个新节点。
处理与基张量(base tensor)的别名关系
当 y 没有别名时,只有一个张量受影响,处理起来很简单。但如果存在别名,多个张量可能会受到影响。让我们看一个非常简单的例子,其中可变操作用于修改张量的某一行:
x = torch.randn((8, 8), requires_grad=True)
y = x ** 2
v = y[0]
v.mul_(2)
在可变操作发生之前,我们有一个如下所示的图:
x -- grad_fn --> AccumulateGrad
^
|
y -- grad_fn --> PowBackward0
^
|
v -- grad_fn --> SelectBackward0
要执行这个可变操作,我们必须同时更新 v.grad_fn 和 y.grad_fn。这里有几个问题需要解决。首先,我们究竟要让 y 指向哪个反向节点?它肯定不是 MulBackward0,因为那对应的是整个 y 都被乘法的情况;但在这个例子中,只有 y 的某一行被乘了。PyTorch 为这类情况提供了一个巧妙的复合反向节点:CopySlices。幸运的是,PyTorch 代码库中有一段很好的注释解释了它的作用:
// CopySlices 是什么?
// ~~~~~~~~~~~~~~~~~~~
//
// 我们支持带原地可变操作的 autograd;例如,如果你写 x.mul_(2)
// autograd 的工作方式就像你在底层有多个张量,并且你执行了:
// x = t.clone()
// x0 = x
// x1 = x0 * 2
// x = x1
// 如你所见,在这个操作之后,x.grad_fn 现在指向 x1.grad_fn
// (即 MulBackward 节点),而这个节点指向 x 原来的 grad_fn(也就是
// x0.grad_fn)。重要的是要记住,在原地操作之后,
// 不再有任何张量对象表示 x0 的状态了。但它的图
// 仍然存在于 autograd 中(以防 x 在被原地修改之前被使用过)。
//
// 现在,一个棘手的情况是,如果 x 是基张量 b 的一个可微分视图(differentiable view)。
// b = t.clone()
// x = b.select(0, 0)
// x *= 2
// 使用与上面相同的方法,这将变成:
// b = t.clone()
// x = b.select(0, 0)
// b0 = b
// x0 = x
// x1 = x0 * 2
// b1 = b0.select_scatter(x1, 0, 0)
// x2 = b1.select(0, 0)
// x = x2
// b = b1
// 如你所见,我们不仅需要修改 x 的 grad_fn,还需要
// 修改 b 的 grad_fn。我们还需要确保 x 的新 grad_fn
// 与 b 的新 grad_fn 相链接。这个 select_scatter、乘法和
// select 的链式操作就是 CopySlices 所做的,全部封装在一个节点中。
好的,那么我们现在得到了这样的图状态:
x -- grad_fn --> AccumulateGrad
^
|
y PowBackward0
| ^
| |
+--- grad_fn --> CopySlices(MulBackward0)
那 v 会发生什么变化呢?为了避免冗余的反向计算,我们想要做的是将 v 的 grad_fn “重基”(rebase)到其基张量(y)所指向的新 CopySlices 反向节点之上。(你也可以选择直接在旧的 v 反向节点上挂一个 MulBackward0,但如果这样做,MulBackward0 就会被计算两次!)最终的图看起来是这样的:
x -- grad_fn --> AccumulateGrad
^
|
y PowBackward0 <--- SelectBackward0 (现已失效!)
| ^
| |
+--- grad_fn --> CopySlices(MulBackward0)
^
|
v -------------> SelectBackward0
注意,旧的 SelectBackward0 --> PowBackward0 节点现在已经失效了!如果视图 v 的内容在可变操作之前没有被用于其他可微分计算,我们现在就可以垃圾回收(GC)这个节点,因为它不再对导数计算有贡献。然而,如果 v 在微分之前曾被其他计算使用过,那么那些值的 grad_fn 仍然会指向旧的 SelectBackward0!
(另一个技术细节:如果你在这个例子中实际打印 v.grad_fn,它会显示为 AsStridedBackward0。这是因为,默认情况下,当视图操作被重基时,我们会将它们解糖(desugar)为 as_strided 调用,因为 PyTorch 中的每个视图操作都可以通过 as_strided 来表示。这给不支持 as_strided 的替代后端实现者带来了无尽的麻烦,因此有一个可选的、后端特定的模式,试图让 PyTorch 更努力地保留原始的 autograd 函数。as_strided 的一个好处是,如果你有多个视图,它们会合并成一个单一的 AsStridedBackward0 节点。)
总结一下,当我们对一个张量的别名进行可变操作时:
- 我们更新基张量(创建一个
CopySlices反向节点)。 - 我们重基别名张量(将其反向节点重新应用到新的
CopySlices节点之上)。
因为你需要访问基张量才能执行此操作,所以 PyTorch 中所有可微分视图都会跟踪它们的基张量。你实际上可以通过 v._base 从用户代码中访问它。某些操作会破坏可微分视图(最值得注意的是 detach()),允许你修改张量而不传播导数。然而,即使你有一个不可微分的视图,版本计数器仍然会共享(因此你仍然能够检测到为反向传播保存的值是否因可变操作而失效)。
处理多个别名
在上面的例子中,重基视图很容易,因为我们正是在调用可变操作的那个视图上进行的。但如果有多个视图呢?
x = torch.randn((8, 8), requires_grad=True)
y = x ** 2
v1 = y[0,:]
v2 = y[:,0]
v1.mul_(2)
现在我们就遇到麻烦了:我们怎么知道要重基 v2 呢?我们无法知道:v1 只维护了一个指向 y 的指针,而不是指向 v2。在这种情况下,一个明显的做法是以某种方式让 y 跟踪所有指向它的视图。事实上,当 Sam 最初计划如何实现这一点时,我们就如何处理这种情况进行了设计讨论,而这正是"显而易见"的做法。然而,这样做有缺点。y 可能有潜在的很多视图,因此如果 y 要跟踪所有这些视图,它需要一个动态分配的缓冲区来保存所有视图的引用。指向视图的指针不能是强引用,因为视图本身持有指向基张量的指针,这样会造成循环引用。最后,这样的设计对多线程不友好,因为多个线程可能正在操作共享基张量的视图,它们会竞争这些跟踪指针。
这里有一个更好的想法:在 v2 实际被用于某些计算之前,没有人需要它。所以让我们惰性重基(lazily rebase):我们将在 v2 的 grad_fn 上记录父张量的一个版本,当访问 v2.grad_fn 时,我们检查它是否仍然与基张量保持最新。如果不是,我们就在计算之前即时生成一个重基后的 autograd 节点。简单又优雅!
结论
总结一下:
- 在每个可变计算内部,隐式地存在一个纯计算,这才是反映在 autograd 反向图中的内容。
- 当张量被可变操作修改时,它们的
grad_fn会被修改为指向新的反向节点。 - 对视图进行可变操作会先修改基张量的
grad_fn(使用 CopySlices),然后再将视图重基到这个新的基张量反向节点之上。 - 当存在多个视图时,其他视图会在后续首次访问时被惰性重基。