This notebook compiles experiments around three questions about PyTorch usage.
Memory overflow: what happens if the computation graph is never garbage-collected?
Gradient zeroing: when should you do it, and why is it necessary?
Backward a second time: when is this error triggered?
import gc
import numpy as np
import os
import psutil
import torch
import torch.nn as nn
Memory Overflow
def print_memory_usage(label):
    process = psutil.Process(os.getpid())
    # print(process.memory_info())
    mem = process.memory_info().rss / 1024 / 1024  # resident set size, in MB
    print("{} using {:.2f} MB memory!".format(label, mem))
def memory_overflow():
    network = nn.Sequential(
        nn.Linear(4, 10000),
        nn.Linear(10000, 10000),
        nn.Linear(10000, 2))
    losses = []
    for epoch in range(1, 101):
        x = torch.tensor([0, 1, 2, 3]).float().unsqueeze(0)
        y = network(x)
        loss = y * y
        # Storing the loss tensor keeps a reference to its computation graph,
        # so the graphs from previous iterations are never freed.
        losses.append(loss.mean())
        if epoch % 10 == 0:
            print_memory_usage("Epoch {}".format(epoch))
            gc.collect()

memory_overflow()
Epoch 10 using 513.77 MB memory!
Epoch 20 using 514.53 MB memory!
Epoch 30 using 515.36 MB memory!
Epoch 40 using 516.20 MB memory!
Epoch 50 using 517.04 MB memory!
Epoch 60 using 517.87 MB memory!
Epoch 70 using 518.71 MB memory!
Epoch 80 using 519.55 MB memory!
Epoch 90 using 520.39 MB memory!
Epoch 100 using 521.23 MB memory!
def no_memory_overflow():
    network = nn.Sequential(
        nn.Linear(4, 10000),
        nn.Linear(10000, 10000),
        nn.Linear(10000, 2))
    losses = []
    for epoch in range(1, 101):
        x = torch.tensor([0, 1, 2, 3]).float().unsqueeze(0)
        y = network(x)
        loss = y * y
        # .item() stores a plain Python float, so the previous graph can be discarded.
        losses.append(loss.mean().item())
        if epoch % 10 == 0:
            print_memory_usage("Epoch {}".format(epoch))
            gc.collect()

no_memory_overflow()
Epoch 10 using 513.01 MB memory!
Epoch 20 using 513.03 MB memory!
Epoch 30 using 513.03 MB memory!
Epoch 40 using 513.03 MB memory!
Epoch 50 using 513.03 MB memory!
Epoch 60 using 513.03 MB memory!
Epoch 70 using 513.04 MB memory!
Epoch 80 using 513.04 MB memory!
Epoch 90 using 513.04 MB memory!
Epoch 100 using 513.04 MB memory!
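Calling .item() is not the only fix: detaching the loss from the graph also lets each iteration's graph be freed. A minimal sketch along the lines of the loop above (no_overflow_with_detach is just an illustrative name, not a cell from this notebook):

def no_overflow_with_detach():
    network = nn.Sequential(
        nn.Linear(4, 10000),
        nn.Linear(10000, 10000),
        nn.Linear(10000, 2))
    losses = []
    for epoch in range(1, 101):
        x = torch.tensor([0, 1, 2, 3]).float().unsqueeze(0)
        loss = (network(x) ** 2).mean()
        # detach() returns a tensor with no grad_fn, so the graph built in this
        # iteration is no longer referenced and can be freed, just like with .item().
        losses.append(loss.detach())
        if epoch % 10 == 0:
            print_memory_usage("Epoch {}".format(epoch))

If no backward pass is needed at all, wrapping the forward pass in torch.no_grad() avoids building the graph in the first place.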
Gradient Zeroing
This discussion may be relevant. If you don't zero out the gradients, successive backward passes accumulate them, which effectively increases the batch size. This is useful when you want to train on batches too large to fit in memory at once. An example is OpenNMT (its --accum_count option).
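A minimal sketch of how that accumulation is typically used in a training loop; the network, dataloader, optimizer, and accum_steps below are illustrative placeholders, not the OpenNMT implementation:

def train_with_accumulation(network, dataloader, optimizer, accum_steps=4):
    optimizer.zero_grad()
    for i, (x, target) in enumerate(dataloader):
        # Dividing by accum_steps keeps the accumulated gradient at the same
        # scale as one backward pass over the larger effective batch.
        loss = ((network(x) - target) ** 2).mean() / accum_steps
        loss.backward()                      # gradients accumulate in .grad
        if (i + 1) % accum_steps == 0:
            optimizer.step()                 # update once per accum_steps batches
            optimizer.zero_grad()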
def test_gradient_accumulation():
    x = torch.ones((1), requires_grad=False)  # requires_grad defaults to False
    w = torch.ones((1), requires_grad=True)

    y = w * x
    print("Before backward(): w.grad_={}".format(w.grad))
    y.backward()
    print("After backward(): w.grad_={} (should be 1)".format(w.grad))

    y = w * x
    y.backward()
    print("Second pass: w.grad={} (accumulated to 2)".format(w.grad))

    w.grad.data.zero_()
    print("===I zero'ed the grad===")

    y = w * x
    y.backward()
    print("Now: w.grad={} (back to 1, the correct value)".format(w.grad))

test_gradient_accumulation()
Before backward(): w.grad_=None
After backward(): w.grad_=tensor([1.]) (should be 1)
Second pass: w.grad=tensor([2.]) (accumulated to 2)
===I zero'ed the grad===
Now: w.grad=tensor([1.]) (back to 1, the correct value)
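In an actual training loop you would normally not zero each .grad tensor by hand: optimizer.zero_grad() (or module.zero_grad()) clears the gradients of every registered parameter at once. A minimal sketch with a plain SGD optimizer:

w = torch.ones((1), requires_grad=True)
optimizer = torch.optim.SGD([w], lr=0.1)
for step in range(3):
    optimizer.zero_grad()   # clear the gradients accumulated by the previous step
    loss = (w * 2) ** 2
    loss.backward()         # w.grad now holds only this step's gradient
    optimizer.step()        # apply the update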
The root node (the loss) does not store a gradient (dL/dL is not useful).
The leaf nodes (the weights) have their gradients accumulated if you run a second backward pass.
The intermediate nodes (everything else) have their buffers freed during backward() unless you call backward(retain_graph=True). This is what causes the "backward through the graph a second time" error, as the next cell shows.
def test_backprop_second_time():
    x = torch.ones((1), requires_grad=False)
    w = torch.ones((1), requires_grad=True)
    y = w * x  # y is not a leaf node
    loss = y * y

    loss.backward()
    print("After backward(): x={}, x.grad={}".format(x, x.grad))
    print("After backward(): w={}, w.grad={}".format(w, w.grad))
    print("After backward(): y={}, y.grad={}".format(y, y.grad))
    print("After backward(): loss={}, loss.grad={}".format(loss, loss.grad))

    # The graph buffers were freed by the first backward(), so this raises an error.
    loss.backward()

test_backprop_second_time()
After backward(): x=tensor([1.]), x.grad=None
After backward(): w=tensor([1.], requires_grad=True), w.grad=tensor([2.])
After backward(): y=tensor([1.], grad_fn=<MulBackward0>), y.grad=None
After backward(): loss=tensor([1.], grad_fn=<MulBackward0>), loss.grad=None
---------------------------------------------------------------------------
RuntimeError Traceback (most recent call last)
<ipython-input-47-6b7e5cd304a0> in <module>
11 loss.backward()
12
---> 13 test_backprop_second_time()
<ipython-input-47-6b7e5cd304a0> in test_backprop_second_time()
9 print ("After backward(): y={}, y.grad={}".format(y, y.grad))
10 print ("After backward(): loss={}, loss.grad={}".format(loss, loss.grad))
---> 11 loss.backward()
12
13 test_backprop_second_time()
~/anaconda3/envs/pytorch12/lib/python3.8/site-packages/torch/tensor.py in backward(self, gradient, retain_graph, create_graph)
193 products. Defaults to ``False``.
194 """
--> 195 torch.autograd.backward(self, gradient, retain_graph, create_graph)
196
197 def register_hook(self, hook):
~/anaconda3/envs/pytorch12/lib/python3.8/site-packages/torch/autograd/__init__.py in backward(tensors, grad_tensors, retain_graph, create_graph, grad_variables)
95 retain_graph = create_graph
96
---> 97 Variable._execution_engine.run_backward(
98 tensors, grad_tensors, retain_graph, create_graph,
99 allow_unreachable=True) # allow_unreachable flag
RuntimeError: Trying to backward through the graph a second time, but the buffers have already been freed. Specify retain_graph=True when calling backward the first time.
def test_backprop_second_time():
    x = torch.ones((1), requires_grad=False)
    w = torch.ones((1), requires_grad=True)
    y = w * x  # y is not a leaf node
    loss = y * y

    loss.backward(retain_graph=True)
    print("=== backward() with retain_graph ===")
    print("After backward(): x={}, x.grad={}".format(x, x.grad))
    print("After backward(): w={}, w.grad={}".format(w, w.grad))
    print("After backward(): y={}, y.grad={}".format(y, y.grad))
    print("After backward(): loss={}, loss.grad={}".format(loss, loss.grad))

    loss.backward()
    print("=== backward() again ===")
    print("After backward(): x={}, x.grad={}".format(x, x.grad))
    print("After backward(): w={}, w.grad={} (you can see gradient accumulation too)".format(w, w.grad))
    print("After backward(): y={}, y.grad={}".format(y, y.grad))
    print("After backward(): loss={}, loss.grad={}".format(loss, loss.grad))

test_backprop_second_time()
=== backward() with retain_graph ===
After backward(): x=tensor([1.]), x.grad=None
After backward(): w=tensor([1.], requires_grad=True), w.grad=tensor([2.])
After backward(): y=tensor([1.], grad_fn=<MulBackward0>), y.grad=None
After backward(): loss=tensor([1.], grad_fn=<MulBackward0>), loss.grad=None
=== backward() again ===
After backward(): x=tensor([1.]), x.grad=None
After backward(): w=tensor([1.], requires_grad=True), w.grad=tensor([4.]) (you can see gradient accumulation too)
After backward(): y=tensor([1.], grad_fn=<MulBackward0>), y.grad=None
After backward(): loss=tensor([1.], grad_fn=<MulBackward0>), loss.grad=None
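One case where retain_graph=True is genuinely needed is calling backward() separately on two losses that share part of the graph. A minimal sketch (summing the losses and calling backward() once would of course avoid retain_graph altogether):

x = torch.ones((1))                     # requires_grad defaults to False
w = torch.ones((1), requires_grad=True)
y = w * x                               # intermediate node shared by both losses
loss_a = (y - 1) ** 2
loss_b = (y + 1) ** 2
loss_a.backward(retain_graph=True)      # keep the shared node's buffers alive
loss_b.backward()                       # backprop through the same graph again
print(w.grad)                           # holds the sum of both losses' gradients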