less than 1 minute read

首先感谢有 PyTorchViz 这个项目 (依赖 Graphviz) 使得我们可以画出 PyTorch 的 computational graph。PyTorch 自己好像并没有 built-in 的机制来做这件事情。

一个简单的例子 (直接在 notebook 中运行):

import torch as t
from torchviz import make_dot

a = t.randn(3,4).requires_grad_()
b = t.zeros(3,4).requires_grad_()

c = a + b

make_dot(var=c, params={"a": a, "b": b})

查看 torchviz::make_dot()源代码,发现:

  • 上例中画图的过程,开端其实是 add_nodes(var.grad_fn),这里 var = c
  • c.grad_fn 类型是 ThAddBackward
  • c.grad_fn.next_functions 是个 tuple,包含两个 AccumulateGrad
  • AccumulateGrad.variable 可以回溯到 ab

c.grad_fn.next_functions 这棵 tree,唯一可能的构建时间是在 c = a + b。我猜测是 PyTorch 是 overload 了 Tensor+,但是这一点有点难验证,原因:

  • 继承关系有 class Tensor(torch._C._TensorBase) (源代码),所以 Tensor.__add__() 其实是在 torch._C._TensorBase 里定义的
  • torch._C._TensorBase 其实是 C++ 代码

PyTorch 的 C++ 代码目前看来有 pytorch/torch/csrc/ 这么多,目前我还不知道要从何读起。另外有人提到有些 operation 是通过一个 YAML 定义的,即 pytorch/torch/csrc/generic/methods/TensorMath.cwrap,其中的确有 add 的 entry,但是这个机制我也不懂。具体的讨论及资源有:

Categories:

Updated:

Comments