When building neural networks in PyTorch, we often need to store tensors that aren't meant to be trained.
The most common example is the running mean and variance in batch norm (check the PyTorch codebase
here
and look for register_buffer). When I first saw this used frequently across multiple codebases,
I found it a bit confusing. After a quick Google search, I came across a really good discussion on the PyTorch
forum (see references).
Basics
In PyTorch, there are two main ways to store tensors in our model:
Parameters are learnable tensors that get updated during training. They require gradients
and are returned by model.parameters(), which means the optimizer will update them.
Buffers are fixed tensors that don't require gradients. They're not returned by
model.parameters(), so the optimizer ignores them. Think of them as constants or
non-learnable state that the model needs to remember.
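To make this concrete, here is a quick check (a minimal sketch using the built-in nn.BatchNorm1d, which carries both kinds of tensors) that lists the two groups separately:

import torch.nn as nn

bn = nn.BatchNorm1d(3)
print([name for name, _ in bn.named_parameters()])
# ['weight', 'bias']
print([name for name, _ in bn.named_buffers()])
# ['running_mean', 'running_var', 'num_batches_tracked']

Only the first group is handed to the optimizer via model.parameters(); the second is just state the model carries around.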
Why Not Just Use self.my_tensor?
The first confusion I had was: if I don't want something trained, can I just assign it directly, like
self.my_tensor = torch.randn(1), inside my nn.Module? Technically yes,
but you'll run into problems.
Let's see what happens:
import torch
import torch.nn as nn

class MyModel(nn.Module):
    def __init__(self):
        super(MyModel, self).__init__()
        self.my_tensor = torch.randn(1)                     # plain attribute
        self.register_buffer('my_buffer', torch.randn(1))   # buffer
        self.my_param = nn.Parameter(torch.randn(1))        # parameter

    def forward(self, x):
        return x

model = MyModel()
print(model.my_tensor)
# tensor([0.9329])
print(model.state_dict())
# OrderedDict([('my_param', tensor([-0.2471])),
#              ('my_buffer', tensor([1.2112]))])
Notice that my_tensor doesn't appear in the state_dict(). This means that if we
save and load the model, my_tensor won't be restored, and we lose that state. Now see
what happens when we move the model to GPU:
model.cuda()
print(model.my_tensor)
# tensor([0.9329]) # Still on CPU!
print(model.state_dict())
# OrderedDict([('my_param', tensor([-0.2471], device='cuda:0')),
# ('my_buffer', tensor([1.2112], device='cuda:0'))])
The buffer and parameter moved to CUDA, but my_tensor stayed on CPU. This will cause
errors when you try to use it in forward passes.
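For example, mixing the two tensors after the model.cuda() call above fails with a device-mismatch error (a minimal sketch; the exact message can vary across PyTorch versions):

model.my_buffer + model.my_tensor
# RuntimeError: Expected all tensors to be on the same device,
# but found at least two devices, cuda:0 and cpu!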
The major benefits of register_buffer:
- Automatically included in state_dict() for saving/loading
- Moved to the correct device when you call model.cuda() or model.to(device)
- Makes your code's intent clear to other developers
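To see the first benefit in action, here is a minimal save/load round trip (a sketch assuming the MyModel class from above): the buffer and parameter come back exactly, while my_tensor is simply re-randomized by __init__:

torch.save(model.state_dict(), 'model.pt')

restored = MyModel()
restored.load_state_dict(torch.load('model.pt', map_location='cpu'))
print(restored.my_buffer)   # same value that was saved
print(restored.my_tensor)   # freshly re-initialized, the old value is gone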
Why Not Use nn.Parameter with requires_grad=False?
Another confusion that I had was: why not just use nn.Parameter(tensor, requires_grad=False)
for buffers?
This technically works for training, but it's confusing and inefficient:
- Misleading code: Other developers expect parameters to be learnable. Seeing non-learnable "parameters" is confusing.
- Optimizer overhead: If you pass model.parameters() to your optimizer, it includes these fake parameters. The optimizer has to check and skip them on every step, which wastes computation (see the quick check after this list).
- Code clarity: Separating buffers and parameters makes your intent obvious at a glance.
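Here is that quick check (a minimal sketch with two hypothetical toy modules): the fake-parameter version hands its non-learnable tensor to the optimizer, while the buffer version doesn't:

class FakeBufferModel(nn.Module):
    def __init__(self):
        super().__init__()
        self.stats = nn.Parameter(torch.zeros(3), requires_grad=False)

class RealBufferModel(nn.Module):
    def __init__(self):
        super().__init__()
        self.register_buffer('stats', torch.zeros(3))

print(len(list(FakeBufferModel().parameters())))   # 1 -> the optimizer has to skip it
print(len(list(RealBufferModel().parameters())))   # 0 -> the optimizer never sees it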
Examples
Here's a simple example of batch normalization where we track running statistics:
class SimpleBatchNorm(nn.Module):
    def __init__(self, num_features, momentum=0.1, eps=1e-5):
        super(SimpleBatchNorm, self).__init__()
        self.momentum = momentum
        self.eps = eps
        self.weight = nn.Parameter(torch.ones(num_features))
        self.bias = nn.Parameter(torch.zeros(num_features))
        # These are not learned, but need to be saved and moved to device
        self.register_buffer('running_mean', torch.zeros(num_features))
        self.register_buffer('running_var', torch.ones(num_features))
        self.register_buffer('num_batches_tracked', torch.tensor(0))

    def forward(self, x):
        # Assumes input of shape (batch_size, num_features)
        if self.training:
            mean = x.mean(dim=0)
            var = x.var(dim=0, unbiased=False)
            with torch.no_grad():
                self.running_mean = (1 - self.momentum) * self.running_mean + self.momentum * mean
                self.running_var = (1 - self.momentum) * self.running_var + self.momentum * var
                self.num_batches_tracked += 1
        else:
            # At inference time, use the tracked statistics
            mean, var = self.running_mean, self.running_var
        x_hat = (x - mean) / torch.sqrt(var + self.eps)
        return self.weight * x_hat + self.bias
The weights and biases are parameters (learned). The running statistics are buffers (tracked but not learned).
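A short usage sketch (with made-up shapes) shows the buffers being updated in train mode and reused untouched in eval mode:

bn = SimpleBatchNorm(4)
bn.train()
out = bn(torch.randn(8, 4))   # batch stats are folded into running_mean / running_var
print(bn.running_mean)        # no longer all zeros

bn.eval()
out = bn(torch.randn(8, 4))   # normalizes with the stored statistics, no update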
When to Use register_parameter
We also see register_parameter() in PyTorch code. It works similarly to directly assigning
an nn.Parameter, but we pass the name as a string:
self.my_param = nn.Parameter(torch.randn(10))
# vs
self.register_parameter('my_param', nn.Parameter(torch.randn(10)))
Both do the same thing. The register_parameter approach is handy when we create parameters
in a loop or need dynamic naming:
for i in range(5):
    self.register_parameter(f'weight_{i}', nn.Parameter(torch.randn(10)))
Otherwise, it's just a style choice.
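For instance, in a hypothetical DynamicModel built with the loop above, the generated names show up in named_parameters() and can be fetched with getattr (a minimal sketch):

class DynamicModel(nn.Module):
    def __init__(self):
        super().__init__()
        for i in range(5):
            self.register_parameter(f'weight_{i}', nn.Parameter(torch.randn(10)))

    def forward(self, x):
        # Look up a dynamically named parameter by string
        return x * getattr(self, 'weight_0')

print([name for name, _ in DynamicModel().named_parameters()])
# ['weight_0', 'weight_1', 'weight_2', 'weight_3', 'weight_4']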
Summary
- Use nn.Parameter for learnable tensors that require gradients
- Use register_buffer for non-learnable tensors that need to be saved and moved with the model
- Don't use plain attributes like self.my_tensor for anything you want preserved in the model state
- Don't use nn.Parameter(requires_grad=False) as a hack for buffers