I have been using PyTorch for a while. There are a few things worth noting down when using it:
- Functionals
- Data Structure
- GPU versus CPU
- Memory
- TensorBoard
- TLS import issue
Functionals
For me, the functional is a new concept that PyTorch brings in. Each network is-a functional: the value returned by the network construction process is a callable object, which acts like a network when applied to inputs.
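For instance, a minimal sketch (the layer sizes here are made up): constructing a network with `nn.Sequential` returns an object you can call like a function.

```python
import torch
import torch.nn as nn

# Construction returns a callable object...
net = nn.Sequential(nn.Linear(4, 8), nn.ReLU(), nn.Linear(8, 2))

x = torch.randn(3, 4)  # a batch of 3 samples with 4 features each
y = net(x)             # ...which you apply to inputs like a function
print(y.shape)         # torch.Size([3, 2])
```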
Data Structure
The difference between `Variable` and `Tensor` needs attention. You can take a derivative on a `Variable`, but not on a `Tensor`. There are different `Tensor` data types as well. Note that `Tensor` in PyTorch is closely related to `numpy`: calling `numpy()` gets the numpy array out, and passing a numpy array in initializes a `Tensor`.
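A minimal sketch of the round trip (note that `torch.from_numpy` shares memory with the source array):

```python
import numpy as np
import torch

a = np.ones((2, 3), dtype=np.float32)
t = torch.from_numpy(a)  # numpy array in: initializes a Tensor (shares memory)
b = t.numpy()            # numpy() out: view the Tensor as a numpy array
```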
Update since 0.4.0: there is no longer an explicit `Variable` class. In other words, `Tensor` objects now carry their own histories. You need to call `detach()` to achieve the same thing as `.data`.
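For example, a small sketch of the post-0.4.0 idiom:

```python
import torch

x = torch.ones(3, requires_grad=True)
y = x * 2
d = y.detach()           # same data as y, but cut off from the autograd history
print(d.requires_grad)   # False
```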
RNN modules in PyTorch require the input sequence to be arranged with shape `(seq_length, batch_size, data_dimension)`. I think it would be easier to use if the first two dimensions were swapped.
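A minimal sketch of the expected layout (sizes are made up); note that the built-in RNN modules also accept `batch_first=True` if you prefer the batch dimension first:

```python
import torch
import torch.nn as nn

rnn = nn.LSTM(input_size=10, hidden_size=20)
seq = torch.randn(5, 3, 10)    # (seq_length, batch_size, data_dimension)
out, (h, c) = rnn(seq)
print(out.shape)               # torch.Size([5, 3, 20])

# With batch_first=True, inputs are (batch_size, seq_length, data_dimension) instead.
rnn_bf = nn.LSTM(input_size=10, hidden_size=20, batch_first=True)
```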
GPU versus CPU
PyTorch ships with its own CUDA support, but you still need to install the NVIDIA driver on the host machine anyway. Several combinations look workable, depending on which environment you are using:
- Choice A: virtualenv + bash scripts
- Choice B: Docker + the nvidia-docker tool + CUDA. You can modify the Dockerfile from the PyTorch GitHub repo to suit your machine.
- Choice C: conda

Running a model on GPU can be much faster than running it on CPU. To run a model on GPU:
- Shift all `Variable`s to GPU with `v = v.cuda()` immediately after creating them.
- Shift all `nn.Module`s to GPU with `model.cuda()` after initialization.
- Consider using `DataParallel` if you have multiple GPUs.
- Yank the values out of the output back to CPU during evaluation. You don't need to shift to CPU during training.
Memory
PyTorch can take up increasing amounts of memory if you run RNNs. As of November 1, the solution is to disable its cuDNN backend:
```python
import torch

torch.backends.cudnn.enabled = False
```
There may still be memory leaks. Consider explicitly invoking Python's garbage collection mechanism:
```python
import gc

gc.collect()  # invoke this periodically
```

Consider using Truncated Backpropagation Through Time (TBPTT). To cut off a variable's connections to its history, yank out its data:
```python
from torch.autograd import Variable

# Wrapping .data in a fresh Variable drops the accumulated history.
next_var = Variable(prev_var.data)
```
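A minimal TBPTT sketch under assumed sizes: train on fixed-length chunks of a long sequence and detach the hidden state between chunks, so gradients stop at the chunk boundary instead of flowing through the whole history:

```python
import torch
import torch.nn as nn
import torch.optim as optim

# Assumed setup: an LSTM over a long sequence, trained chunk by chunk.
rnn = nn.LSTM(input_size=8, hidden_size=16)
head = nn.Linear(16, 1)
opt = optim.SGD(list(rnn.parameters()) + list(head.parameters()), lr=0.01)

long_seq = torch.randn(100, 4, 8)     # (seq_length, batch_size, features)
targets = torch.randn(100, 4, 1)
hidden = None
chunk = 20

for start in range(0, long_seq.size(0), chunk):
    seg = long_seq[start:start + chunk]
    out, hidden = rnn(seg, hidden)
    loss = nn.functional.mse_loss(head(out), targets[start:start + chunk])
    opt.zero_grad()
    loss.backward()
    opt.step()
    # Detach so the next chunk's backward pass stops at this boundary.
    hidden = tuple(h.detach() for h in hidden)
```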
TensorBoard
TensorBoard is too good not to use. Although there is no official integration yet, we can use open-source plugins. I am using the tensorboard-pytorch plugin, available on GitHub.
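A minimal sketch of logging a scalar; I am assuming the `tensorboardX` import name here (the pip package `tensorboard-pytorch` was later renamed to `tensorboardX`):

```python
from tensorboardX import SummaryWriter

writer = SummaryWriter('runs/demo')   # log directory (made-up path)
for step in range(100):
    writer.add_scalar('train/loss', 1.0 / (step + 1), step)
writer.close()
```

Then run `tensorboard --logdir runs` to view the curves in the browser.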
TLS import issue
As of December 2017: when running a large project in Docker that imports PyTorch along with other Python libraries (`matplotlib` in my case), a TLS error came up. The solution was to reverse the import order of `pytorch` and `matplotlib`.
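The note does not say which order wins, so this is only a sketch of the shape of the fix:

```python
# Assumption: if this order triggers the TLS import error,
# swap the two lines (the working order depends on the environment).
import torch
import matplotlib
```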