PyTorch: Computational Graphs
首先感谢有 PyTorchViz 这个项目 (依赖 Graphviz) 使得我们可以画出 PyTorch 的 computational graph。PyTorch 自己好像并没有 built-in 的机制来做这件事情。
一个简单的例子 (直接在 notebook 中运行):
import torch as t
from torchviz import make_dot
a = t.randn(3,4).requires_grad_()
b = t.zeros(3,4).requires_grad_()
c = a + b
make_dot(var=c, params={"a": a, "b": b})

查看 torchviz::make_dot() 的源代码,发现:
- 上例中画图的过程,开端其实是
add_nodes(var.grad_fn),这里var = c c.grad_fn类型是ThAddBackwardc.grad_fn.next_functions是个 tuple,包含两个AccumulateGradAccumulateGrad.variable可以回溯到a和b

c.grad_fn.next_functions 这棵 tree,唯一可能的构建时间是在 c = a + b。我猜测是 PyTorch 是 overload 了 Tensor 的 +,但是这一点有点难验证,原因:
- 继承关系有
class Tensor(torch._C._TensorBase)(源代码),所以Tensor.__add__()其实是在torch._C._TensorBase里定义的 - 而
torch._C._TensorBase其实是 C++ 代码
PyTorch 的 C++ 代码目前看来有 pytorch/torch/csrc/ 这么多,目前我还不知道要从何读起。另外有人提到有些 operation 是通过一个 YAML 定义的,即 pytorch/torch/csrc/generic/methods/TensorMath.cwrap,其中的确有 add 的 entry,但是这个机制我也不懂。具体的讨论及资源有:
留下评论