PyTorch Questions

This notebook compiles experiments around three questions about PyTorch usage.

  1. Memory overflow: what happens if you don’t let the graph be garbage collected?
  2. Gradient zeroing: when should you do it, and why is it necessary?
  3. Backward a second time: when is this error triggered?
import gc
import numpy as np
import os
import psutil
import torch
import torch.nn as nn

Memory Overflow

def print_memory_usage(label):
    process = psutil.Process(os.getpid())
    mem = process.memory_info().rss / 1024 / 1024  # resident set size in MB
    print("{} using {:.2f} MB memory!".format(label, mem))
def memory_overflow():

    network = nn.Sequential(
        nn.Linear(4, 10000), nn.Linear(10000, 10000), nn.Linear(10000, 2))
    losses = []
    for epoch in range(1, 101):
        x = torch.tensor([0, 1, 2, 3]).float().unsqueeze(0)
        y = network(x)
        loss = y * y
        losses.append(loss.mean())  # Each stored loss keeps a reference to its graph, so the graphs are never freed.
        if epoch % 10 == 0:
            print_memory_usage("Epoch {}".format(epoch))

gc.collect()
memory_overflow()
Epoch 10 using 513.77 MB memory!
Epoch 20 using 514.53 MB memory!
Epoch 30 using 515.36 MB memory!
Epoch 40 using 516.20 MB memory!
Epoch 50 using 517.04 MB memory!
Epoch 60 using 517.87 MB memory!
Epoch 70 using 518.71 MB memory!
Epoch 80 using 519.55 MB memory!
Epoch 90 using 520.39 MB memory!
Epoch 100 using 521.23 MB memory!
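
Why does memory keep growing? Each stored loss tensor still points at its computation graph through grad_fn, so the graph, including the 10000-dimensional activations saved for the backward pass, can never be freed. A quick way to see that reference, on a small hypothetical model (not the network above):

net = nn.Linear(4, 2)
kept = net(torch.ones(1, 4)).pow(2).mean()   # keep the tensor itself, as memory_overflow() does

print(kept.grad_fn)            # a MeanBackward0 node: the graph is still reachable
print(kept.detach().grad_fn)   # None: detaching drops the reference to the graph
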
def no_memory_overflow():

    network = nn.Sequential(
        nn.Linear(4, 10000), nn.Linear(10000, 10000), nn.Linear(10000, 2))
    losses = []
    for epoch in range(1, 101):
        x = torch.tensor([0, 1, 2, 3]).float().unsqueeze(0)
        y = network(x)
        loss = y * y
        losses.append(loss.mean().item())  # .item() stores a plain float, so the previous graph can be freed.
        if epoch % 10 == 0:
            print_memory_usage("Epoch {}".format(epoch))

gc.collect()
no_memory_overflow()
Epoch 10 using 513.01 MB memory!
Epoch 20 using 513.03 MB memory!
Epoch 30 using 513.03 MB memory!
Epoch 40 using 513.03 MB memory!
Epoch 50 using 513.03 MB memory!
Epoch 60 using 513.03 MB memory!
Epoch 70 using 513.04 MB memory!
Epoch 80 using 513.04 MB memory!
Epoch 90 using 513.04 MB memory!
Epoch 100 using 513.04 MB memory!
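
Appending loss.mean().item() stores a plain Python float. If you would rather keep a tensor (say, to aggregate it on the GPU later), detaching works just as well; a minimal sketch of the alternative line, reusing loss and losses from the loop above:

loss_value = loss.mean().detach()   # still a tensor, but no longer attached to the graph
print(loss_value.grad_fn)           # None, so the graph behind it can be freed
losses.append(loss_value)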

Gradient Zeroing

This discussion could be relevant.
If you don’t zero out the gradients, successive backward() calls accumulate into .grad, which effectively increases the batch size. This can be useful when you want to train on a batch that is too large to process at once. An example is OpenNMT (its --accum_count option).
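
A minimal sketch of that accumulation pattern (the model, data, and accum_count below are illustrative placeholders, not taken from OpenNMT):

model = nn.Linear(4, 2)
optimizer = torch.optim.SGD(model.parameters(), lr=0.01)
accum_count = 4  # number of mini-batches to accumulate before one update

optimizer.zero_grad()
for step in range(1, 101):
    x = torch.randn(8, 4)
    loss = model(x).pow(2).mean()
    (loss / accum_count).backward()   # gradients accumulate into .grad; dividing matches the large-batch average
    if step % accum_count == 0:
        optimizer.step()              # one update covering accum_count mini-batches
        optimizer.zero_grad()         # start accumulating the next "large batch"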

def test_gradient_accumulation():
    x = torch.ones((1), requires_grad=False)  # requires_grad defaults to False
    w = torch.ones((1), requires_grad=True)
    y = w * x
    print("Before backward(): w.grad={}".format(w.grad))
    y.backward()
    print("After backward(): w.grad={} (should be 1)".format(w.grad))

    y = w * x
    y.backward()
    print("Second pass: w.grad={} (accumulated to 2)".format(w.grad))

    w.grad.data.zero_()
    print("===I zero'ed the grad===")
    y = w * x
    y.backward()
    print("Now: w.grad={} (back to 1, the correct value)".format(w.grad))

test_gradient_accumulation()
Before backward(): w.grad=None
After backward():  w.grad=tensor([1.]) (should be 1)
Second pass:       w.grad=tensor([2.]) (accumulated to 2)
===I zero'ed the grad===
Now:               w.grad=tensor([1.]) (back to 1, the correct value)
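
In practice you rarely zero w.grad by hand like this; the usual idiom is to go through an optimizer, whose zero_grad() clears the gradients of all its parameters at once. A minimal sketch (the SGD optimizer and learning rate here are just for illustration):

w = torch.ones(1, requires_grad=True)
optimizer = torch.optim.SGD([w], lr=0.1)

for _ in range(3):
    optimizer.zero_grad()             # clear w.grad before computing fresh gradients
    loss = (w * torch.ones(1)).pow(2)
    loss.backward()
    optimizer.step()                  # the update uses a non-accumulated gradient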

Backward a Second Time

This Stack Overflow question is relevant.
See this discussion for when a node is freed and when it is not.

Basically, the behavior depends on the type of node:

  • The root node (the loss) does not store a gradient (dL/dL would just be 1, which is not useful).
  • The leaf nodes (the weights) have their gradients accumulated if you backprop a second time.
  • The intermediate nodes (all other nodes) have their buffers freed during backward() unless you call backward(retain_graph=True). Freeing them is why a second backward() raises the “backward through the graph a second time” error.
def test_backprop_second_time():
    x = torch.ones((1), requires_grad=False)
    w = torch.ones((1), requires_grad=True)
    y = w * x  # y is not a leaf node
    loss = y * y
    loss.backward()
    print("After backward(): x={}, x.grad={}".format(x, x.grad))
    print("After backward(): w={}, w.grad={}".format(w, w.grad))
    print("After backward(): y={}, y.grad={}".format(y, y.grad))
    print("After backward(): loss={}, loss.grad={}".format(loss, loss.grad))
    loss.backward()  # second backward without retain_graph=True -> RuntimeError

test_backprop_second_time()
After backward():  x=tensor([1.]), x.grad=None
After backward():  w=tensor([1.], requires_grad=True), w.grad=tensor([2.])
After backward():  y=tensor([1.], grad_fn=<MulBackward0>), y.grad=None
After backward():  loss=tensor([1.], grad_fn=<MulBackward0>), loss.grad=None



---------------------------------------------------------------------------

RuntimeError                              Traceback (most recent call last)

<ipython-input-47-6b7e5cd304a0> in <module>
     11     loss.backward()
     12 
---> 13 test_backprop_second_time()


<ipython-input-47-6b7e5cd304a0> in test_backprop_second_time()
      9     print ("After backward():  y={}, y.grad={}".format(y, y.grad))
     10     print ("After backward():  loss={}, loss.grad={}".format(loss, loss.grad))
---> 11     loss.backward()
     12 
     13 test_backprop_second_time()


~/anaconda3/envs/pytorch12/lib/python3.8/site-packages/torch/tensor.py in backward(self, gradient, retain_graph, create_graph)
    193                 products. Defaults to ``False``.
    194         """
--> 195         torch.autograd.backward(self, gradient, retain_graph, create_graph)
    196 
    197     def register_hook(self, hook):


~/anaconda3/envs/pytorch12/lib/python3.8/site-packages/torch/autograd/__init__.py in backward(tensors, grad_tensors, retain_graph, create_graph, grad_variables)
     95         retain_graph = create_graph
     96 
---> 97     Variable._execution_engine.run_backward(
     98         tensors, grad_tensors, retain_graph, create_graph,
     99         allow_unreachable=True)  # allow_unreachable flag


RuntimeError: Trying to backward through the graph a second time, but the buffers have already been freed. Specify retain_graph=True when calling backward the first time.
def test_backprop_second_time():
    x = torch.ones((1), requires_grad=False)
    w = torch.ones((1), requires_grad=True)
    y = w * x  # y is not a leaf node
    loss = y * y
    loss.backward(retain_graph=True)
    print("=== backward() with retain_graph ===")
    print("After backward(): x={}, x.grad={}".format(x, x.grad))
    print("After backward(): w={}, w.grad={}".format(w, w.grad))
    print("After backward(): y={}, y.grad={}".format(y, y.grad))
    print("After backward(): loss={}, loss.grad={}".format(loss, loss.grad))
    loss.backward()
    print("=== backward() again ===")
    print("After backward(): x={}, x.grad={}".format(x, x.grad))
    print("After backward(): w={}, w.grad={} (you can see gradient accumulation too)".format(w, w.grad))
    print("After backward(): y={}, y.grad={}".format(y, y.grad))
    print("After backward(): loss={}, loss.grad={}".format(loss, loss.grad))

test_backprop_second_time()
=== backward() with retain_graph ===
After backward():  x=tensor([1.]), x.grad=None
After backward():  w=tensor([1.], requires_grad=True), w.grad=tensor([2.])
After backward():  y=tensor([1.], grad_fn=<MulBackward0>), y.grad=None
After backward():  loss=tensor([1.], grad_fn=<MulBackward0>), loss.grad=None
=== backward() again ===
After backward():  x=tensor([1.]), x.grad=None
After backward():  w=tensor([1.], requires_grad=True), w.grad=tensor([4.]) (you can see gradient accumulation too)
After backward():  y=tensor([1.], grad_fn=<MulBackward0>), y.grad=None
After backward():  loss=tensor([1.], grad_fn=<MulBackward0>), loss.grad=None
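
As a side note, the outputs above show y.grad=None even though y participates in the backward pass: by default autograd only populates .grad for leaf tensors. If you do want the gradient of an intermediate node, you can ask for it with retain_grad(); a minimal sketch:

x = torch.ones(1)
w = torch.ones(1, requires_grad=True)
y = w * x
y.retain_grad()          # populate y.grad even though y is not a leaf
loss = y * y
loss.backward()
print(y.grad)            # tensor([2.]): d(loss)/dy = 2 * y
print(w.grad)            # tensor([2.]): matches the leaf gradient seen above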