PyTorch Tutorial: What is PyTorch?

https://pytorch.org/tutorials/beginner/blitz/tensor_tutorial.html#sphx-glr-beginner-blitz-tensor-tutorial-py

PyTorch is a Python-based scientific computing package aimed at two audiences:

  • A replacement for NumPy that can use the power of GPUs
  • A deep learning research platform that provides maximum flexibility and speed
In [65]:
import torch
print(torch.__version__)
1.3.1

API

List of functions in torch: https://pytorch.org/docs/stable/torch.html

Tensors

Tensors are similar to NumPy's ndarrays, but can also be used on a GPU to accelerate computing.

Note: an uninitialized matrix (torch.empty) contains whatever values were in the allocated memory at the time.

In [28]:
x = torch.empty(5, 3)
print(x)
tensor([[ 0.0000e+00,  0.0000e+00, -5.6109e+34],
        [-3.2286e-42,  0.0000e+00,  0.0000e+00],
        [ 4.4842e-44,  0.0000e+00,  2.1725e-30],
        [ 4.5908e-41,  5.6052e-45,  0.0000e+00],
        [ 2.3822e-44,  0.0000e+00, -1.9014e+38]])
In [29]:
# randomly initialized matrix

x = torch.rand(5, 3)
print(x)
tensor([[0.7854, 0.2313, 0.9500],
        [0.9516, 0.6996, 0.2933],
        [0.7968, 0.6422, 0.1751],
        [0.8896, 0.8035, 0.5114],
        [0.1203, 0.9271, 0.2717]])
In [30]:
# matrix filled with zeros and of dtype long

x = torch.zeros(5, 3, dtype=torch.long)
print(x)
tensor([[0, 0, 0],
        [0, 0, 0],
        [0, 0, 0],
        [0, 0, 0],
        [0, 0, 0]])
In [31]:
# Construct a tensor directly from data

x = torch.tensor([5.5, 3])
print(x)
tensor([5.5000, 3.0000])
In [43]:
# Create a tensor based on an existing tensor. 
# Reuse or change properties of the input tensor, e.g. dtype

x = x.new_ones(5, 3, dtype=torch.double)      # new_* methods take in sizes
print(x)

x = torch.randn_like(x, dtype=torch.float)    # override dtype!
print(x)                                      # result has the same size
tensor([[1., 1., 1.],
        [1., 1., 1.],
        [1., 1., 1.],
        [1., 1., 1.],
        [1., 1., 1.]], dtype=torch.float64)
tensor([[-1.5962,  0.8143,  0.0625],
        [-1.1725, -1.7245,  1.3550],
        [-0.1825,  0.2450, -1.6819],
        [ 1.4698,  0.0891, -1.2921],
        [ 1.6773,  1.3049,  0.3965]])
In [44]:
print(x.size()) # returns a tuple
torch.Size([5, 3])

NumPy-style Slices & Indexing

In [45]:
print(x[:, 1])
tensor([ 0.8143, -1.7245,  0.2450,  0.0891,  1.3049])
In [46]:
print(x[1,:])
tensor([-1.1725, -1.7245,  1.3550])
In [47]:
print(x[1,[True, False, True]])
tensor([-1.1725,  1.3550])

Resize/Reshape Tensors with Tensor.view

In [54]:
x = torch.randn(4, 4)
y = x.view(16)
print(y)
print(x.size(), y.size())
tensor([-0.1049,  1.2708, -1.4843, -1.8760, -0.4595, -1.0606,  0.3229, -1.1334,
        -0.7912,  0.5949,  0.2558,  0.7625, -0.3833,  0.2607,  0.8765,  1.4194])
torch.Size([4, 4]) torch.Size([16])
In [55]:
z = x.view(-1, 8)  # the size -1 is inferred from other dimensions
print(x.size(), z.size())
torch.Size([4, 4]) torch.Size([2, 8])
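
A point the view examples above do not show is that a view shares storage with the tensor it came from. The following is a minimal sketch (the tensor values are chosen only for illustration) suggesting that a write through the view is visible in the original:

```python
import torch

x = torch.arange(6)      # tensor([0, 1, 2, 3, 4, 5])
v = x.view(2, 3)         # new shape, same underlying storage
v[0, 0] = 100            # writing through the view...
print(x)                 # ...also changes x
print(x.view(-1, 2).size())   # -1 is inferred as 3
```

If an independent copy is wanted instead, calling .clone() on the view is one option.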

.item() for converting one-element tensor to scalar

In [58]:
x[0,0]
Out[58]:
tensor(-0.1049)
In [59]:
x[0,0].item()
Out[59]:
-0.10485438257455826

Converting Torch Tensor to NumPy array

Note that the NumPy array shares memory with the (CPU) tensor; it is not a copy

In [61]:
y = x.numpy()
print(y)
[[-0.10485438  1.2708315  -1.4843239  -1.8760405 ]
 [-0.45950317 -1.0605842   0.32289705 -1.1333894 ]
 [-0.7912421   0.5949042   0.25584257  0.7625179 ]
 [-0.38327378  0.26073715  0.87647516  1.4194298 ]]
In [63]:
x.add_(10) # changing the torch tensor in place (the values below are +20: the execution counts suggest this cell was run twice)
print(x)
print(y) # y is changed also
tensor([[19.8951, 21.2708, 18.5157, 18.1240],
        [19.5405, 18.9394, 20.3229, 18.8666],
        [19.2088, 20.5949, 20.2558, 20.7625],
        [19.6167, 20.2607, 20.8765, 21.4194]])
[[19.895145 21.270832 18.515676 18.123959]
 [19.540497 18.939415 20.322897 18.866611]
 [19.208757 20.594904 20.255842 20.762518]
 [19.616726 20.260738 20.876476 21.41943 ]]

Converting NumPy array to Torch Tensor

In [66]:
import numpy as np
a = np.ones(5)
b = torch.from_numpy(a)
np.add(a, 1, out=a)
print(a)
print(b)
[2. 2. 2. 2. 2.]
tensor([2., 2., 2., 2., 2.], dtype=torch.float64)
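
To contrast sharing with copying, here is a small sketch (variable names are illustrative) comparing torch.from_numpy, which shares memory with the array, against torch.tensor, which copies the data:

```python
import numpy as np
import torch

a = np.ones(3)
shared = torch.from_numpy(a)   # shares memory with a
copied = torch.tensor(a)       # independent copy of the data

a += 1                         # in-place update of the NumPy array
print(shared)                  # follows the change
print(copied)                  # keeps the old values
```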

CUDA Tensor

Tensors can be moved onto any device using the .to method.

In [67]:
# let us run this cell only if CUDA is available
# We will use ``torch.device`` objects to move tensors in and out of GPU
if torch.cuda.is_available():
    device = torch.device("cuda")          # a CUDA device object
    y = torch.ones_like(x, device=device)  # directly create a tensor on GPU
    x = x.to(device)                       # or just use strings ``.to("cuda")``
    z = x + y
    print(z)
    print(z.to("cpu", torch.double))       # ``.to`` can also change dtype together!
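
A common variation, sketched here under the assumption that the same code should run with or without a GPU, is to pick the device once and pass it around:

```python
import torch

# fall back to the CPU when CUDA is not available
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

x = torch.ones(2, 2, device=device)   # created directly on the chosen device
z = x + 1                             # computed on that device
z_cpu = z.to("cpu", torch.double)     # move back to CPU and change dtype in one call
print(z_cpu.device, z_cpu.dtype)
```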

Operations

There are multiple syntaxes for operations:

  • overloaded operator, e.g. z = x + y
  • function returning a new tensor, e.g. z = torch.add(x, y)
  • function writing into a preallocated output tensor, e.g. torch.add(x, y, out=z)
  • in-place method, e.g. y.add_(x)

Note: a trailing underscore marks an operation that mutates the tensor in place, e.g. x.copy_(y), x.t_()

In [39]:
x = torch.rand(3, 2)
y = torch.rand(3, 2)
print(x + y)
tensor([[0.9406, 0.9646],
        [1.4418, 1.1993],
        [0.5169, 0.8273]])
In [40]:
print(torch.add(x, y))
tensor([[0.9406, 0.9646],
        [1.4418, 1.1993],
        [0.5169, 0.8273]])
In [41]:
result = torch.empty(3, 2)  # preallocate with the same shape as x + y
torch.add(x, y, out=result)
print(result)
tensor([[0.9406, 0.9646],
        [1.4418, 1.1993],
        [0.5169, 0.8273]])
In [42]:
y.add_(x)
print(y)
tensor([[0.9406, 0.9646],
        [1.4418, 1.1993],
        [0.5169, 0.8273]])

PyTorch AutoGrad Tutorial

https://pytorch.org/tutorials/beginner/blitz/autograd_tutorial.html#sphx-glr-beginner-blitz-autograd-tutorial-py

  • autograd package - automatic differentiation for all operations on Tensors.
  • a define-by-run framework
    • backprop is defined by how your code is run
    • every single iteration can be different.

Tensor $\verb|requires_grad|$ Attribute

  • set requires_grad = True for tracking operations
  • after tracked computation, call .backward() and have all the gradients computed automatically
  • gradient for this tensor will be accumulated into .grad attribute
  • call .detach() to prevent future computation from being tracked
  • can also wrap the code block in "with torch.no_grad():" to evaluate without tracking history
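
The accumulation and detach behaviour listed above can be sketched as follows (a minimal example with made-up values, not taken from the tutorial):

```python
import torch

x = torch.ones(2, requires_grad=True)

(x * 3).sum().backward()
first = x.grad.clone()        # gradient of sum(3x) is 3 per element

(x * 3).sum().backward()      # a second backward adds to .grad, it does not overwrite
second = x.grad.clone()

x.grad.zero_()                # reset before the next computation
d = x.detach()                # same data, but cut off from the graph
print(first, second, d.requires_grad)
```

This accumulation is the reason training loops typically zero the gradients before each backward pass.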
In [76]:
x = torch.ones(2, 2, requires_grad=True)
print(x)
tensor([[1., 1.],
        [1., 1.]], requires_grad=True)
In [77]:
y = x + 2
print(y)
tensor([[3., 3.],
        [3., 3.]], grad_fn=<AddBackward0>)
In [78]:
print(y.grad_fn)
<AddBackward0 object at 0x000001F8F2964308>
In [79]:
z = y * y * 3
out = z.mean()

print(z, out)
tensor([[27., 27.],
        [27., 27.]], grad_fn=<MulBackward0>) tensor(27., grad_fn=<MeanBackward0>)

$X = \begin{pmatrix}1&1\\1&1\end{pmatrix}$ ...requires_grad=True means keep track of $X$ in subsequent operations

$Y = X + 2 = \begin{pmatrix}3&3\\3&3\end{pmatrix}$

$Z = Y \odot Y \times 3 = \begin{pmatrix}27&27\\27&27\end{pmatrix}$ ...notice elementwise multiplication, not matmul

out = mean$(Z) = 27$ ...mean over all elements, unlike MATLAB, where mean of a matrix returns column means

In [80]:
# .requires_grad_( ... ) changes an existing Tensor's requires_grad flag in-place.
# A tensor's requires_grad flag is False unless set at creation;
# calling .requires_grad_() with no argument sets it to True.

a = torch.randn(2, 2)
a = ((a * 3) / (a - 1))
print(a.requires_grad) # without trailing underscore it is the attribute not a function

a.requires_grad_(True)
print(a.requires_grad)

b = (a * a).sum()
print(b.grad_fn)
False
True
<SumBackward0 object at 0x000001F8F2936CC8>
In [81]:
# Perform backprop now, get gradients d(out)/dx
# Because out contains a single scalar, out.backward() is equivalent to 
#  out.backward(torch.tensor(1.)).

out.backward()
print(x.grad)
tensor([[4.5000, 4.5000],
        [4.5000, 4.5000]])

Scalar Version

  1. out = $\frac{1}{4}(z_1+z_2+z_3+z_4)$
  2. $z_i = y_i^2 \times 3$
  3. $y_i = x_i + 2$

$\frac{\partial out}{\partial x_i}$ = $\frac{\partial out}{\partial z_i}$ $\frac{\partial z_i}{\partial x_i}$

$\frac{\partial out}{\partial x_i}$ = $\frac{\partial out}{\partial z_i}\frac{\partial z_i}{\partial y_i}\frac{\partial y_i}{\partial x_i}$ = $\frac{1}{4} \times$ $6 y_i$ $\times 1$ = $\frac{6}{4} (3)$ = $\frac{18}{4} = 4.5$
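
As a sanity check on the scalar derivation, the hand-computed gradient $\frac{1}{4} \times 6 y_i \times 1$ can be compared against autograd. This is a sketch that rebuilds the same x, y, out as above:

```python
import torch

x = torch.ones(2, 2, requires_grad=True)
y = x + 2
out = (y * y * 3).mean()
out.backward()

# hand-derived: d(out)/dx_i = (1/4) * (6 * y_i) * 1
manual = (6.0 / 4.0) * y.detach()
print(torch.allclose(x.grad, manual))
```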

Matrix Version

  1. out = $\frac{1}{4}(\mathbf 1^T Z \mathbf 1)$
  2. $Z = Y \odot Y \times 3$
  3. $Y = X + 2$

Matrix differential for Hadamard products: $ d (A \odot B) = dA \odot B + A \odot dB$, so $d (Y \odot Y) = dY \odot Y + Y \odot dY = 2\, Y \odot dY$

Matrix derivative for quadratic form $\frac{\partial}{\partial A} (\mathbf x^T A \mathbf y) = \mathbf x \mathbf y^T$ so $\frac{\partial}{\partial Z} (\mathbf 1^T Z \mathbf 1) = \mathbf 1 \mathbf 1^T$

$\frac{\partial out}{\partial X}$ = $\frac{\partial out}{\partial Z}$ $\frac{\partial Z}{\partial X}$

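$\frac{\partial out}{\partial X}$ = $\frac{\partial out}{\partial Z} \odot \frac{\partial Z}{\partial Y} \odot \frac{\partial Y}{\partial X}$ = $(\frac{1}{4} \mathbf 1 \mathbf 1^T) \odot (6 Y) \odot (\mathbf 1 \mathbf 1^T)$ = $\frac{6}{4} Y$ = $\frac{18}{4} \mathbf 1 \mathbf 1^T$, i.e. every entry is 4.5

Every operation here is elementwise, so entry $(i,j)$ of each factor holds the derivative with respect to the matching entry, and the chain rule multiplies the factors elementwise ($\odot$) rather than as matrix products.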

Vector Version

  1. out = $\frac{1}{4}(\mathbf 1^T z)$
  2. $z = y \odot y \times 3$
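  3. $y = x + 2$

Here $x, y, z$ are the length-4 vectors obtained by flattening $X, Y, Z$.

The differential of a Hadamard product has the same form for vectors: $ d (a \odot b) = da \odot b + a \odot db$, so $d (y \odot y) = 2\, y \odot dy = 2\, \mathrm{diag}(y)\, dy$

Jacobians: $\frac{\partial y}{\partial x} = I$ and $\frac{\partial z}{\partial y} = 6\, \mathrm{diag}(y)$; gradient: $\frac{\partial out}{\partial z} = \frac{1}{4} \mathbf 1$

$\frac{\partial out}{\partial x}$ = $\left(\frac{\partial y}{\partial x}\right)^T \left(\frac{\partial z}{\partial y}\right)^T \frac{\partial out}{\partial z}$ = $I \times 6\, \mathrm{diag}(y) \times \frac{1}{4} \mathbf 1$ = $\frac{6}{4} y$ = $\frac{18}{4} \mathbf 1 = 4.5\, \mathbf 1$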

Gradient of vector-valued function $\mathbf y = f(\mathbf x)$ is Jacobian

$$ \mathbf J = \begin{pmatrix} \frac{\partial y_1}{\partial x_1} & \dots & \frac{\partial y_1}{\partial x_n} \\ \vdots & \ddots & \vdots \\ \frac{\partial y_m}{\partial x_1} & \dots & \frac{\partial y_m}{\partial x_n} \end{pmatrix} $$

"torch.autograd is an engine for computing vector-Jacobian products"

$$ \mathbf J^T \mathbf v = \begin{pmatrix} \frac{\partial y_1}{\partial x_1} & \dots & \frac{\partial y_m}{\partial x_1} \\ \vdots & \ddots & \vdots \\ \frac{\partial y_1}{\partial x_n} & \dots & \frac{\partial y_m}{\partial x_n} \end{pmatrix} \begin{pmatrix} v_1 \\ \vdots \\ v_m \end{pmatrix} $$

Suppose there is a scalar function $g(\mathbf y)$,

with gradient $\frac{\partial g}{\partial \mathbf y} = \begin{pmatrix} \frac{\partial g}{\partial y_1} \\ \vdots \\ \frac{\partial g}{\partial y_m} \end{pmatrix} = \mathbf v$

where $\mathbf y = f(\mathbf x)$

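Then by the chain rule, with gradients written as column vectors, $\frac{\partial g}{\partial \mathbf x} = \left(\frac{\partial \mathbf y}{\partial \mathbf x}\right)^T \frac{\partial g}{\partial \mathbf y} = \mathbf J^T \mathbf v$

Equivalently, as a row vector, $\left(\frac{\partial g}{\partial \mathbf x}\right)^T = \mathbf v^T \mathbf J$, which is where the name vector-Jacobian product comes from.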

In [84]:
x = torch.randn(3, requires_grad=True)

y = x * 2
while y.data.norm() < 1000:
    y = y * 2

print(y)
tensor([1037.6091, -651.7774,  984.2659], grad_fn=<MulBackward0>)

Because y is not a scalar, y.backward() needs a vector argument. To get the vector-Jacobian product, pass the vector v to backward:

In [85]:
v = torch.tensor([0.1, 1.0, 0.0001], dtype=torch.float)
y.backward(v)

print(x.grad)
tensor([2.0480e+02, 2.0480e+03, 2.0480e-01])
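
To make the vector-Jacobian product concrete, here is a sketch using a linear map $\mathbf y = A \mathbf x$, whose Jacobian is simply $A$. The matrix A and the vector v are made-up values for illustration:

```python
import torch

A = torch.tensor([[1., 2., 3.],
                  [4., 5., 6.]])       # y = A x, so J = A (m = 2 outputs, n = 3 inputs)
x = torch.ones(3, requires_grad=True)
y = A @ x

v = torch.tensor([0.5, -1.0])          # one weight per output component
y.backward(v)                          # accumulates J^T v into x.grad

print(x.grad)
print(A.t() @ v)                       # the same product, computed by hand
```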
In [86]:
## Prevent tracking for efficiency

print(x.requires_grad)
print((x ** 2).requires_grad)

with torch.no_grad():
    print((x ** 2).requires_grad)
True
True
False

The torch.autograd.Function class

https://pytorch.org/docs/stable/autograd.html#function

  • Think of as a computation graph node
  • Methods for forward and backward compute
In [89]:
dir(torch.autograd.Function)
Out[89]:
['__call__',
 '__class__',
 '__delattr__',
 '__dict__',
 '__dir__',
 '__doc__',
 '__eq__',
 '__format__',
 '__ge__',
 '__getattribute__',
 '__gt__',
 '__hash__',
 '__init__',
 '__init_subclass__',
 '__le__',
 '__lt__',
 '__module__',
 '__ne__',
 '__new__',
 '__reduce__',
 '__reduce_ex__',
 '__repr__',
 '__setattr__',
 '__sizeof__',
 '__str__',
 '__subclasshook__',
 '__weakref__',
 '_backward_cls',
 '_do_backward',
 '_do_forward',
 '_is_legacy',
 '_register_hook',
 '_register_hook_dict',
 'apply',
 'backward',
 'dirty_tensors',
 'forward',
 'is_traceable',
 'mark_dirty',
 'mark_non_differentiable',
 'mark_shared_storage',
 'metadata',
 'needs_input_grad',
 'next_functions',
 'non_differentiable',
 'register_hook',
 'requires_grad',
 'save_for_backward',
 'saved_tensors',
 'saved_variables',
 'to_save']
In [88]:
# extend it for your own custom function by changing the forward and backward methods

class Exp(torch.autograd.Function):
    @staticmethod
    def forward(ctx, i):
        result = i.exp()
        ctx.save_for_backward(result)
        return result

    @staticmethod
    def backward(ctx, grad_output):
        result, = ctx.saved_tensors
        return grad_output * result
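
A usage sketch for the class above (repeated here so the block is self-contained): custom Functions are called through .apply rather than instantiated, and torch.autograd.gradcheck can compare the hand-written backward against numerical differences. The input values are arbitrary, and double precision is used because gradcheck expects it.

```python
import torch

class Exp(torch.autograd.Function):
    @staticmethod
    def forward(ctx, i):
        result = i.exp()
        ctx.save_for_backward(result)   # keep exp(i) for the backward pass
        return result

    @staticmethod
    def backward(ctx, grad_output):
        result, = ctx.saved_tensors
        return grad_output * result     # d/di exp(i) = exp(i)

x = torch.tensor([0.0, 1.0], dtype=torch.double, requires_grad=True)
y = Exp.apply(x)                        # call via .apply, not Exp()(x)
y.sum().backward()

print(torch.allclose(x.grad, x.detach().exp()))
print(torch.autograd.gradcheck(Exp.apply, (x,)))
```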

forward(ctx, *args, **kwargs)

https://pytorch.org/docs/stable/_modules/torch/autograd/function.html#Function.forward

  • Performs the operation.
  • This function is to be overridden by all subclasses.
  • It must accept a context ctx as the first argument, followed by any number of arguments (tensors or other types).
  • The context can be used to store tensors that can be then retrieved during the backward pass.

backward(ctx, *grad_outputs)

https://pytorch.org/docs/stable/_modules/torch/autograd/function.html#Function.backward

  • Defines a formula for differentiating the operation.

  • This function is to be overridden by all subclasses.

  • It must accept a context ctx as the first argument, followed by as many arguments as forward() returned outputs.
  • It should return as many tensors as there were inputs to forward().
  • Each argument is the gradient w.r.t. the given output.
  • Each returned value should be the gradient w.r.t. the corresponding input.

  • The context can be used to retrieve tensors saved during the forward pass.

    • It also has an attribute $\verb|ctx.needs_input_grad|$ as a tuple of booleans representing whether each input needs gradient.
    • E.g., backward() will have $\verb|ctx.needs_input_grad[0] = True|$ if the first input to forward() needs its gradient computed w.r.t. the output.
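
The rules above can be sketched with a two-input Function. The class name Scale and the input values are hypothetical, chosen only to show one gradient per input and the use of ctx.needs_input_grad to skip work:

```python
import torch

class Scale(torch.autograd.Function):
    """Hypothetical example: elementwise product x * w."""

    @staticmethod
    def forward(ctx, x, w):
        ctx.save_for_backward(x, w)
        return x * w

    @staticmethod
    def backward(ctx, grad_output):
        x, w = ctx.saved_tensors
        grad_x = grad_w = None
        print("needs_input_grad:", ctx.needs_input_grad)
        if ctx.needs_input_grad[0]:      # skip if x does not require grad
            grad_x = grad_output * w
        if ctx.needs_input_grad[1]:
            grad_w = grad_output * x
        return grad_x, grad_w            # one return value per forward() input

x = torch.tensor([1., 2.])                       # no gradient requested
w = torch.tensor([3., 4.], requires_grad=True)   # gradient requested
Scale.apply(x, w).sum().backward()
print(x.grad, w.grad)
```

Returning None for an input that does not need a gradient is allowed and avoids computing a value nobody will use.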

The staticmethod() built-in function returns a static method for a given function.

Calling staticmethod() directly is considered un-Pythonic.

Instead, use the @staticmethod decorator.

The syntax of @staticmethod is:

In [ ]:
class Class:
    @staticmethod
    def func(*args):
        pass

What is a static method?

Static methods, much like class methods, are methods that are bound to a class rather than its object.

They do not require creating a class instance, so they do not depend on the state of any object.

The difference between a static method and a class method is:

  1. Static method knows nothing about the class and just deals with the parameters.
  2. Class method works with the class since its parameter is always the class itself.

They can be called both by the class and its object.

Class.staticmethodFunc()

or even

Class().staticmethodFunc()
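
A small sketch of the difference between the two kinds of method. The class Temperature and its members are made up for illustration:

```python
class Temperature:
    scale = "C"

    @staticmethod
    def to_fahrenheit(celsius):
        # knows nothing about the class; works only with its parameters
        return celsius * 9 / 5 + 32

    @classmethod
    def describe(cls):
        # receives the class itself as its first parameter
        return "scale=" + cls.scale

print(Temperature.to_fahrenheit(100))     # called on the class
print(Temperature().to_fahrenheit(100))   # called on an instance
print(Temperature.describe())
```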

Metaprogramming with Python decorators

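metaprogramming = part of the program modifies another part of the program. In Python, a decorator does this at function definition time (when the def statement runs), not in a separate compile step.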

A decorator is a function that takes a function as input and returns a function.

It is a kind of "higher-order function".

Behind the scenes

  • All Python functions, variables, and everything else are objects (instances of some class)
In [91]:
def add1(x):
    return x + 1

dir(add1)
Out[91]:
['__annotations__',
 '__call__',
 '__class__',
 '__closure__',
 '__code__',
 '__defaults__',
 '__delattr__',
 '__dict__',
 '__dir__',
 '__doc__',
 '__eq__',
 '__format__',
 '__ge__',
 '__get__',
 '__getattribute__',
 '__globals__',
 '__gt__',
 '__hash__',
 '__init__',
 '__init_subclass__',
 '__kwdefaults__',
 '__le__',
 '__lt__',
 '__module__',
 '__name__',
 '__ne__',
 '__new__',
 '__qualname__',
 '__reduce__',
 '__reduce_ex__',
 '__repr__',
 '__setattr__',
 '__sizeof__',
 '__str__',
 '__subclasshook__']

The $\verb|__call__()|$ method tells us this object is callable (i.e. a "function")
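
The same mechanism works in the other direction: any object whose class defines $\verb|__call__|$ can be used like a function. A minimal sketch (the class Adder is hypothetical):

```python
class Adder:
    def __init__(self, n):
        self.n = n

    def __call__(self, x):
        # invoked when an instance is called like a function
        return x + self.n

add1 = Adder(1)
print(callable(add1))
print(add1(41))
```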

In [93]:
def divide(a, b):
    return a/b
In [94]:
print(divide(2,5))
print(divide(2,0))
0.4
---------------------------------------------------------------------------
ZeroDivisionError                         Traceback (most recent call last)
<ipython-input-94-dc2bd21ae966> in <module>
      1 print(divide(2,5))
----> 2 print(divide(2,0))

<ipython-input-93-7507bdc665d5> in divide(a, b)
      1 def divide(a, b):
----> 2     return a/b

ZeroDivisionError: division by zero
In [107]:
# this decorator blocks the wrapped function from receiving zero as its second input
# ...returns None if the second input is zero
# ...otherwise calls the wrapped function and returns its output

def zeroblocker(func):
    def inner(a,b):
        print("checking input",b)
        if b == 0:
            print("Whoops! b=0, doing nothing")
            return

        return func(a,b)
    return inner
In [108]:
# functional approach to using

fixeddivide = zeroblocker(divide)

print(fixeddivide(2,0))
checking input 0
Whoops! b=0, doing nothing
None
In [110]:
# "pythonic" way that implements at function definition time to simplify code

@zeroblocker
def divide(a,b):
    return a/b

print(divide(2,0))
checking input 0
Whoops! b=0, doing nothing
None
In [112]:
# example with nested ("chained") decorators and variable number of args

def star(func):
    def inner(*args, **kwargs):
        print("*" * 30)
        func(*args, **kwargs)
        print("*" * 30)
    return inner

def percent(func):
    def inner(*args, **kwargs):
        print("%" * 30)
        func(*args, **kwargs)
        print("%" * 30)
    return inner

@star
@percent
def printer(msg):
    print(msg)
printer("Hello")
******************************
%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%
Hello
%%%%%%%%%%%%%%%%%%%%%%%%%%%%%%
******************************
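
Note that the inner wrappers in star and percent discard the wrapped function's return value, and the decorated function also loses its original name. One common remedy, sketched here with a hypothetical decorator called logged, is to return the result and apply functools.wraps:

```python
import functools

def logged(func):
    @functools.wraps(func)              # copy __name__, __doc__, etc. onto inner
    def inner(*args, **kwargs):
        print("calling", func.__name__)
        return func(*args, **kwargs)    # pass the result through
    return inner

@logged
def add(a, b):
    """Return a + b."""
    return a + b

print(add(2, 3))
print(add.__name__, "-", add.__doc__)
```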

...other demos...

In [6]:
import torch
from torch.autograd import Variable
a = Variable(torch.rand(1, 4), requires_grad=True)  # Variable is deprecated since 0.4; torch.rand(1, 4, requires_grad=True) is equivalent
b = a**2
c = b*2
d = c.mean()
e = c.sum()
In [7]:
a
Out[7]:
tensor([[0.7947, 0.8047, 0.8403, 0.8212]], requires_grad=True)
In [8]:
b
Out[8]:
tensor([[0.6315, 0.6475, 0.7061, 0.6743]], grad_fn=<PowBackward0>)
In [9]:
c
Out[9]:
tensor([[1.2630, 1.2950, 1.4123, 1.3486]], grad_fn=<MulBackward0>)
In [17]:
%load_ext tensorboard
In [19]:
%tensorboard --logdir logs
ERROR: Timed out waiting for TensorBoard to start. It may still be running as pid 2696.

GraphViz

Install GraphViz via anaconda

Also need (or perhaps only need) python-graphviz

https://github.com/szagoruyko/pytorchviz ??

In [26]:
make_dot(e)
Out[26]:
[Graph: (1, 4) -> PowBackward0 -> MulBackward0 -> SumBackward0]
In [21]:
import torch
from torch import nn
In [25]:
model = nn.Sequential()
model.add_module('W0', nn.Linear(8, 16))
model.add_module('tanh', nn.Tanh())
model.add_module('W1', nn.Linear(16, 1))

x = torch.randn(1,8)

make_dot(model(x))
Out[25]:
[Graph: (16, 8) -> TBackward -> AddmmBackward; (16) -> AddmmBackward; AddmmBackward -> TanhBackward -> AddmmBackward (output); (1, 16) -> TBackward -> AddmmBackward (output); (1) -> AddmmBackward (output)]

TorchViz

Just pasting the code below rather than installing and importing it as a module

https://github.com/szagoruyko/pytorchviz/blob/master/torchviz/dot.py

Run this before above calls to make_dot()

In [24]:
# https://github.com/szagoruyko/pytorchviz/blob/master/torchviz/dot.py

from collections import namedtuple
from distutils.version import LooseVersion
from graphviz import Digraph
import torch
from torch.autograd import Variable

Node = namedtuple('Node', ('name', 'inputs', 'attr', 'op'))


def make_dot(var, params=None):
    """ Produces Graphviz representation of PyTorch autograd graph.
    Blue nodes are the Variables that require grad, orange are Tensors
    saved for backward in torch.autograd.Function
    Args:
        var: output Variable
        params: dict of (name, Variable) to add names to node that
            require grad (TODO: make optional)
    """
    if params is not None:
        assert all(isinstance(p, Variable) for p in params.values())
        param_map = {id(v): k for k, v in params.items()}

    node_attr = dict(style='filled',
                     shape='box',
                     align='left',
                     fontsize='12',
                     ranksep='0.1',
                     height='0.2')
    dot = Digraph(node_attr=node_attr, graph_attr=dict(size="12,12"))
    seen = set()

    def size_to_str(size):
        return '(' + (', ').join(['%d' % v for v in size]) + ')'

    output_nodes = (var.grad_fn,) if not isinstance(var, tuple) else tuple(v.grad_fn for v in var)

    def add_nodes(var):
        if var not in seen:
            if torch.is_tensor(var):
                # note: this used to show .saved_tensors in pytorch0.2, but stopped
                # working as it was moved to ATen and Variable-Tensor merged
                dot.node(str(id(var)), size_to_str(var.size()), fillcolor='orange')
            elif hasattr(var, 'variable'):
                u = var.variable
                name = param_map[id(u)] if params is not None else ''
                node_name = '%s\n %s' % (name, size_to_str(u.size()))
                dot.node(str(id(var)), node_name, fillcolor='lightblue')
            elif var in output_nodes:
                dot.node(str(id(var)), str(type(var).__name__), fillcolor='darkolivegreen1')
            else:
                dot.node(str(id(var)), str(type(var).__name__))
            seen.add(var)
            if hasattr(var, 'next_functions'):
                for u in var.next_functions:
                    if u[0] is not None:
                        dot.edge(str(id(u[0])), str(id(var)))
                        add_nodes(u[0])
            if hasattr(var, 'saved_tensors'):
                for t in var.saved_tensors:
                    dot.edge(str(id(t)), str(id(var)))
                    add_nodes(t)

    # handle multiple outputs
    if isinstance(var, tuple):
        for v in var:
            add_nodes(v.grad_fn)
    else:
        add_nodes(var.grad_fn)

    resize_graph(dot)

    return dot


# For traces

def replace(name, scope):
    return '/'.join([scope[name], name])


def parse(graph):
    scope = {}
    for n in graph.nodes():
        inputs = [i.uniqueName() for i in n.inputs()]
        for i in range(1, len(inputs)):
            scope[inputs[i]] = n.scopeName()

        uname = next(n.outputs()).uniqueName()
        assert n.scopeName() != '', '{} has empty scope name'.format(n)
        scope[uname] = n.scopeName()
    scope['0'] = 'input'

    nodes = []
    for n in graph.nodes():
        attrs = {k: n[k] for k in n.attributeNames()}
        attrs = str(attrs).replace("'", ' ')
        inputs = [replace(i.uniqueName(), scope) for i in n.inputs()]
        uname = next(n.outputs()).uniqueName()
        nodes.append(Node(**{'name': replace(uname, scope),
                             'op': n.kind(),
                             'inputs': inputs,
                             'attr': attrs}))

    for n in graph.inputs():
        uname = n.uniqueName()
        if uname not in scope.keys():
            scope[uname] = 'unused'
        nodes.append(Node(**{'name': replace(uname, scope),
                             'op': 'Parameter',
                             'inputs': [],
                             'attr': str(n.type())}))

    return nodes


def make_dot_from_trace(trace):
    """ Produces graphs of torch.jit.trace outputs
    Example:
    >>> trace, = torch.jit.trace(model, args=(x,))
    >>> dot = make_dot_from_trace(trace)
    """
    # from tensorboardX
    if LooseVersion(torch.__version__) >= LooseVersion("0.4.1"):
        torch.onnx._optimize_trace(trace, torch._C._onnx.OperatorExportTypes.ONNX_ATEN_FALLBACK)
    elif LooseVersion(torch.__version__) >= LooseVersion("0.4"):
        torch.onnx._optimize_trace(trace, False)
    else:
        torch.onnx._optimize_trace(trace)
    graph = trace.graph()
    list_of_nodes = parse(graph)

    node_attr = dict(style='filled',
                     shape='box',
                     align='left',
                     fontsize='12',
                     ranksep='0.1',
                     height='0.2')

    dot = Digraph(node_attr=node_attr, graph_attr=dict(size="12,12"))

    for node in list_of_nodes:
        dot.node(node.name, label=node.name.replace('/', '\n'))
        if node.inputs:
            for inp in node.inputs:
                dot.edge(inp, node.name)

    resize_graph(dot)

    return dot


def resize_graph(dot, size_per_element=0.15, min_size=12):
    """Resize the graph according to how much content it contains.
    Modify the graph in place.
    """
    # Get the approximate number of nodes and edges
    num_rows = len(dot.body)
    content_size = num_rows * size_per_element
    size = max(min_size, content_size)
    size_str = str(size) + "," + str(size)
    dot.graph_attr.update(size=size_str)