自动微分与张量操作

PyTorch 的 autograd 是一个自动求导引擎,autograd 通过在执行模型前向传播时构建反向图来工作。每次执行一个操作 Op 时,我们都会生成一个 autograd 节点 OpBackward,并将其记录在结果张量的 grad_fn 属性上。这里用图示会更有帮助。考虑下面这个程序:

x = torch.randn(8, requires_grad=True)
y = x ** 2
z = y * 2

这个过程的「前向传播」可以想象成一条流水线:

输入 x → [平方] → 中间结果 y → [乘2] → 输出 z

那么我们期望构建出如下的 autograd 图:

x -- grad_fn --> AccumulateGrad
                      ^
                      |
y -- grad_fn --> PowBackward0
                      ^
                      |
z -- grad_fn --> MulBackward0

关于上图的一些细节说明:

现在,在构建反向图的过程中,我们实际上是在为一个_隐式的_前向图编写导数,该前向图对应于前向传播中发生的纯(非可变)操作序列。这个前向图从未被具体化(除非你使用了某种追踪器),但思考这个图是有帮助的,因为从某种意义上说,autograd 所做的就是获取这个前向图,反转所有箭头的方向,并将前向节点替换为反向节点。

 前向传播           反向传播
===================================
    x               x.grad
    |                  ^
    V                  |
 [ Pow ]         [ PowBackward ]
    |                  ^
    |                  |
    y               y.grad
    |                  |
    V                  |
 [ Mul ]         [ MulBackward ]
    |                  ^
    V                  |
    z               z.grad

(注意:除非你使用 retain_grad=True 运行 autograd,否则 y.gradz.grad 实际上不会被填充。)

需要强调的是,这个隐式的前向图_始终_被视为纯函数图。不存在 MulInplaceBackward 这样的节点。要理解可变操作发生时发生了什么,可以想象一个_没有_可变操作的程序版本:autograd 图正是这个版本的反向。考虑:

x = torch.randn(8, requires_grad=True)
y = x ** 2
y.mul_(2)

这在语义上等价于:

x = torch.randn(8, requires_grad=True)
y = x ** 2
y2 = y * 2
# 假设 y 没有其他别名,现在将所有对 y 的使用替换为 y2

因此,这才是会在反向图中编码的程序。那么这究竟是如何实现的呢?在可变操作发生之前,我们有一个这样的图:

x -- grad_fn --> AccumulateGrad
                      ^
                      |
y -- grad_fn --> PowBackward0

当我们执行乘法操作时,我们会生成新的节点 MulBackward0

x -- grad_fn --> AccumulateGrad
                      ^
                      |
y -- grad_fn --> PowBackward0
                      ^
                      |
                 MulBackward0

但由于我们对 y 进行了原地修改,我们还必须原地修改 grad_fn 指针,使其指向这个新节点:

x -- grad_fn --> AccumulateGrad
                      ^
                      |
y                PowBackward0
|                     ^
|                     |
---- grad_fn --> MulBackward0

如果 y 没有任何别名,一旦我们执行了可变操作,就不再有任何张量的内容包含 pow 操作之后的值;相应地,也没有任何张量的 grad_fn 指向 PowBackward0

总结一下,当发生可变操作时,我们需要做两件事:

  1. 为在张量上发生的计算生成一个新的反向节点。
  2. 修改受可变操作影响的张量的 grad_fn,使其指向这个新节点。

处理与基张量(base tensor)的别名关系

当 y 没有别名时,只有一个张量受影响,处理起来很简单。但如果存在别名,多个张量可能会受到影响。让我们看一个非常简单的例子,其中可变操作用于修改张量的某一行:

x = torch.randn((8, 8), requires_grad=True)
y = x ** 2
v = y[0]
v.mul_(2)

在可变操作发生之前,我们有一个如下所示的图:

x -- grad_fn --> AccumulateGrad
                      ^
                      |
y -- grad_fn --> PowBackward0
                      ^
                      |
v -- grad_fn --> SelectBackward0

要执行这个可变操作,我们必须同时更新 v.grad_fny.grad_fn。这里有几个问题需要解决。首先,我们究竟要让 y 指向哪个反向节点?它肯定不是 MulBackward0,因为那对应的是整个 y 都被乘法的情况;但在这个例子中,只有 y 的某一行被乘了。PyTorch 为这类情况提供了一个巧妙的复合反向节点:CopySlices。幸运的是,PyTorch 代码库中有一段很好的注释解释了它的作用:

// CopySlices 是什么?
// ~~~~~~~~~~~~~~~~~~~
//
// 我们支持带原地可变操作的 autograd;例如,如果你写 x.mul_(2)
// autograd 的工作方式就像你在底层有多个张量,并且你执行了:
//   x = t.clone()
//   x0 = x
//   x1 = x0 * 2
//   x = x1
// 如你所见,在这个操作之后,x.grad_fn 现在指向 x1.grad_fn
// (即 MulBackward 节点),而这个节点指向 x 原来的 grad_fn(也就是
// x0.grad_fn)。重要的是要记住,在原地操作之后,
// 不再有任何张量对象表示 x0 的状态了。但它的图
// 仍然存在于 autograd 中(以防 x 在被原地修改之前被使用过)。
//
// 现在,一个棘手的情况是,如果 x 是基张量 b 的一个可微分视图(differentiable view)。
//   b = t.clone()
//   x = b.select(0, 0)
//   x *= 2
// 使用与上面相同的方法,这将变成:
//   b = t.clone()
//   x = b.select(0, 0)
//   b0 = b
//   x0 = x
//   x1 = x0 * 2
//   b1 = b0.select_scatter(x1, 0, 0)
//   x2 = b1.select(0, 0)
//   x = x2
//   b = b1
// 如你所见,我们不仅需要修改 x 的 grad_fn,还需要
// 修改 b 的 grad_fn。我们还需要确保 x 的新 grad_fn
// 与 b 的新 grad_fn 相链接。这个 select_scatter、乘法和
// select 的链式操作就是 CopySlices 所做的,全部封装在一个节点中。

好的,那么我们现在得到了这样的图状态:

x -- grad_fn --> AccumulateGrad
                      ^
                      |
y                PowBackward0
|                     ^
|                     |
+--- grad_fn --> CopySlices(MulBackward0)

那 v 会发生什么变化呢?为了避免冗余的反向计算,我们想要做的是将 v 的 grad_fn “重基”(rebase)到其基张量(y)所指向的新 CopySlices 反向节点之上。(你也可以选择直接在旧的 v 反向节点上挂一个 MulBackward0,但如果这样做,MulBackward0 就会被计算两次!)最终的图看起来是这样的:

x -- grad_fn --> AccumulateGrad
                      ^
                      |
y                PowBackward0  <--- SelectBackward0  (现已失效!)
|                     ^
|                     |
+--- grad_fn --> CopySlices(MulBackward0)
                      ^
                      |
v -------------> SelectBackward0

注意,旧的 SelectBackward0 --> PowBackward0 节点现在已经失效了!如果视图 v 的内容在可变操作之前没有被用于其他可微分计算,我们现在就可以垃圾回收(GC)这个节点,因为它不再对导数计算有贡献。然而,如果 v 在微分之前曾被其他计算使用过,那么那些值的 grad_fn 仍然会指向旧的 SelectBackward0

(另一个技术细节:如果你在这个例子中实际打印 v.grad_fn,它会显示为 AsStridedBackward0。这是因为,默认情况下,当视图操作被重基时,我们会将它们解糖(desugar)为 as_strided 调用,因为 PyTorch 中的每个视图操作都可以通过 as_strided 来表示。这给不支持 as_strided 的替代后端实现者带来了无尽的麻烦,因此有一个可选的、后端特定的模式,试图让 PyTorch 更努力地保留原始的 autograd 函数。as_strided 的一个好处是,如果你有多个视图,它们会合并成一个单一的 AsStridedBackward0 节点。)

总结一下,当我们对一个张量的别名进行可变操作时:

  1. 我们更新基张量(创建一个 CopySlices 反向节点)。
  2. 我们重基别名张量(将其反向节点重新应用到新的 CopySlices 节点之上)。

因为你需要访问基张量才能执行此操作,所以 PyTorch 中所有可微分视图都会跟踪它们的基张量。你实际上可以通过 v._base 从用户代码中访问它。某些操作会破坏可微分视图(最值得注意的是 detach()),允许你修改张量而不传播导数。然而,即使你有一个不可微分的视图,版本计数器仍然会共享(因此你仍然能够检测到为反向传播保存的值是否因可变操作而失效)。

处理多个别名

在上面的例子中,重基视图很容易,因为我们正是在调用可变操作的那个视图上进行的。但如果有多个视图呢?

x = torch.randn((8, 8), requires_grad=True)
y = x ** 2
v1 = y[0,:]
v2 = y[:,0]
v1.mul_(2)

现在我们就遇到麻烦了:我们怎么知道要重基 v2 呢?我们无法知道:v1 只维护了一个指向 y 的指针,而不是指向 v2。在这种情况下,一个明显的做法是以某种方式让 y 跟踪所有指向它的视图。事实上,当 Sam 最初计划如何实现这一点时,我们就如何处理这种情况进行了设计讨论,而这正是"显而易见"的做法。然而,这样做有缺点。y 可能有潜在的很多视图,因此如果 y 要跟踪所有这些视图,它需要一个动态分配的缓冲区来保存所有视图的引用。指向视图的指针不能是强引用,因为视图本身持有指向基张量的指针,这样会造成循环引用。最后,这样的设计对多线程不友好,因为多个线程可能正在操作共享基张量的视图,它们会竞争这些跟踪指针。

这里有一个更好的想法:在 v2 实际被用于某些计算之前,没有人需要它。所以让我们惰性重基(lazily rebase):我们将在 v2grad_fn 上记录父张量的一个版本,当访问 v2.grad_fn 时,我们检查它是否仍然与基张量保持最新。如果不是,我们就在计算之前即时生成一个重基后的 autograd 节点。简单又优雅!

结论

总结一下:

  1. 在每个可变计算内部,隐式地存在一个纯计算,这才是反映在 autograd 反向图中的内容。
  2. 当张量被可变操作修改时,它们的 grad_fn 会被修改为指向新的反向节点。
  3. 对视图进行可变操作会先修改基张量的 grad_fn(使用 CopySlices),然后再将视图重基到这个新的基张量反向节点之上。
  4. 当存在多个视图时,其他视图会在后续首次访问时被惰性重基。