PyTorch Notes

I have been using PyTorch for a while. Here are a few things worth noting down when using it.

  • Functionals
  • Data Structure
  • GPU versus CPU
  • Memory
  • TensorBoard
  • TLS import issue

Functionals

For me, the functional style is a new concept that PyTorch brings in. Each network is a functional: the object returned by the network construction process is callable like a function, and calling it runs the network.
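A minimal sketch of what this looks like with the standard nn.Module API; the class name TwoLayerNet and its layer sizes are made up for illustration. Constructing the module returns an object, and calling that object like a function runs its forward() method. The torch.nn.functional module (imported as F) holds stateless versions of operations such as relu.

```python
import torch
import torch.nn as nn
import torch.nn.functional as F


class TwoLayerNet(nn.Module):
    """Hypothetical toy network used only to illustrate the idea."""

    def __init__(self, in_dim, hidden_dim, out_dim):
        super().__init__()
        self.fc1 = nn.Linear(in_dim, hidden_dim)
        self.fc2 = nn.Linear(hidden_dim, out_dim)

    def forward(self, x):
        # F.relu is the stateless, functional form of nn.ReLU
        return self.fc2(F.relu(self.fc1(x)))


net = TwoLayerNet(4, 8, 2)   # construction returns a callable object
x = torch.randn(3, 4)        # batch of 3 inputs with 4 features each
y = net(x)                   # calling the object runs forward()
print(y.shape)               # torch.Size([3, 2])
```

Calling net(x) rather than net.forward(x) is the usual idiom, because the call goes through nn.Module's machinery (such as registered hooks) before reaching forward().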

Data Structure

  1. The difference between Variable and Tensor needs attention. You can take derivatives with respect to a Variable, but not a Tensor. There are also different Tensor data types. Note that Tensors in PyTorch are closely related to NumPy arrays: calling numpy() returns a NumPy array, and passing in a NumPy array initializes a Tensor.
    Update since 0.4.0: Variable and Tensor have been merged, so there is no longer a separate Variable class to work with (it is deprecated). In other words, Tensors carry their own history. Call detach() to get what .data used to give you.

  2. RNN modules in PyTorch expect the input sequence to have the shape (seq_length, batch_size, data_dimension). I think they would be easier to use if the first two dimensions were switched.
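Both points above can be sketched in a few lines, assuming PyTorch 0.4.0 or later; all array values and layer sizes are arbitrary examples. torch.from_numpy() shares memory with the source array (torch.tensor() would copy it instead), requires_grad=True makes a tensor record history, and detach() returns the same data without that history. For RNNs, nn.LSTM also accepts a batch_first=True constructor argument that switches the first two input dimensions.

```python
import numpy as np
import torch
import torch.nn as nn

# NumPy interop: from_numpy() and .numpy() share memory with the array
arr = np.arange(6, dtype=np.float32).reshape(2, 3)
t = torch.from_numpy(arr)
back = t.numpy()

# Since 0.4.0, a tensor with requires_grad=True tracks its own history
w = torch.ones(3, requires_grad=True)
loss = (t @ w).sum()
loss.backward()           # gradient of loss w.r.t. w is stored in w.grad
w_plain = w.detach()      # same data, no history (replaces .data)

# RNN input defaults to (seq_len, batch, input_size)
seq_len, batch, input_size, hidden_size = 5, 3, 10, 20
x = torch.randn(seq_len, batch, input_size)
lstm = nn.LSTM(input_size, hidden_size)
out, (h_n, c_n) = lstm(x)         # out: (5, 3, 20)

# batch_first=True expects (batch, seq_len, input_size) instead
lstm_bf = nn.LSTM(input_size, hidden_size, batch_first=True)
out_bf, _ = lstm_bf(x.transpose(0, 1))   # out_bf: (3, 5, 20)
```

Note that batch_first only affects the input and output tensors; the hidden state h_n keeps the (num_layers, batch, hidden_size) layout either way.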

GPU versus CPU

  1. PyTorch ships with its own built-in CUDA, but you still need to install the NVIDIA driver on the host machine. There are several combinations that look workable, depending on which environment you are using:

    Choice A: virtualenv + bash scripts
    Choice B: Docker + the nvidia-docker tool + CUDA. You can modify the Dockerfile from the PyTorch GitHub repository to suit your machine.
    Choice C: conda
  2. Running a model on a GPU can be much faster than running it on a CPU. To run a model on a GPU:

  • Move all Variables to the GPU with v = v.cuda() immediately after creating them.
  • Move all nn.Modules to the GPU with model.cuda() after initializing them.
  • Consider wrapping the model in nn.DataParallel if you have multiple GPUs.
  • Move output values back to the CPU with .cpu() when you need them during evaluation (for example, to call numpy() on them). You don't need to move them to the CPU during training.
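A minimal sketch of the checklist above, assuming PyTorch 0.4.0 or later, where the device-agnostic .to(device) idiom is available (on a GPU machine it does the same job as .cuda()). The nn.Linear model and input sizes are placeholders. It falls back to the CPU when torch.cuda.is_available() is False, which is also a quick way to check that the NVIDIA driver setup works.

```python
import torch
import torch.nn as nn

# Pick the GPU if the driver and CUDA runtime are usable, else the CPU
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

model = nn.Linear(10, 2).to(device)      # placeholder model, moved once
if torch.cuda.device_count() > 1:
    model = nn.DataParallel(model)       # split batches across GPUs

x = torch.randn(4, 10, device=device)    # create inputs directly on the device

# Training step: everything stays on the device
model.train()
loss = model(x).sum()
loss.backward()

# Evaluation: bring results back to the CPU before converting to NumPy
model.eval()
with torch.no_grad():
    preds = model(x).cpu().numpy()
print(preds.shape)  # (4, 2)
```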

Memory

  1. PyTorch's memory usage can keep growing when you run an RNN. As of November 1, the workaround is to disable the cuDNN backend:

    torch.backends.cudnn.enabled = False
  2. There may still be memory leaks. Consider triggering Python's garbage collector explicitly:

    import gc
    gc.collect() # Invoke this periodically
  3. Consider using truncated backpropagation through time (TBPTT). To cut a variable off from its history, wrap its underlying data in a new Variable:

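    import torch
    from torch.autograd import Variable
    prev_var = Variable(torch.randn(3), requires_grad=True)  # example variable with history
    next_var = Variable(prev_var.data)  # same values, history cut off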
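In PyTorch 0.4.0 and later, detach() does the same job. Below is a minimal TBPTT sketch; the random data, sizes, nn.RNN, and linear head are all made up for illustration. The hidden state is detached at every chunk boundary, so each backward() only runs through the last chunk of time steps. Without the detach() call, the second backward() would try to go through the previous chunk's graph, which has already been freed, and PyTorch raises an error.

```python
import torch
import torch.nn as nn

torch.manual_seed(0)
seq_len, batch, input_size, hidden_size, chunk = 20, 2, 4, 8, 5

rnn = nn.RNN(input_size, hidden_size)
head = nn.Linear(hidden_size, 1)
opt = torch.optim.SGD(list(rnn.parameters()) + list(head.parameters()), lr=0.01)

# Toy sequence and targets, shaped (seq_len, batch, features)
x = torch.randn(seq_len, batch, input_size)
y = torch.randn(seq_len, batch, 1)

h = torch.zeros(1, batch, hidden_size)
losses = []
for start in range(0, seq_len, chunk):
    h = h.detach()  # keep the values, drop the history: gradients stop here
    out, h = rnn(x[start:start + chunk], h)
    loss = nn.functional.mse_loss(head(out), y[start:start + chunk])
    opt.zero_grad()
    loss.backward()
    opt.step()
    losses.append(loss.item())
```

The chunk length is the truncation window: a larger chunk lets gradients flow further back in time at the cost of more memory.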

TensorBoard

TensorBoard is too good not to use. Although there is no official integration yet, we can use open-source plugins. I am using the tensorboard-pytorch plugin, available on GitHub.

TLS import issue

As of December 2017: when running a large project in Docker that imports PyTorch along with other Python libraries (matplotlib in my case), I ran into a TLS error. The fix was to reverse the order in which PyTorch and matplotlib were imported.