# Virtual Batch Normalization in PyTorch: Implementation and Analysis
```python
import torch
from torch.nn import Module, Parameter


class VirtualBatchNorm1d(Module):
    """Module for Virtual Batch Normalization.

    Implementation borrowed and modified from Rafael_Valle's code, with help
    from SimonW in this discussion thread:
    https://discuss.pytorch.org/t/parameter-grad-of-conv-weight-is-none-after-virtual-batch-normalization/9036
    """

    def __init__(self, num_features, eps=1e-5):
        super().__init__()
        self.num_features = num_features
        self.eps = eps  # epsilon for numerical stability
        # learnable scale (gamma) and shift (beta) parameters
        self.gamma = Parameter(torch.normal(mean=1.0, std=0.02, size=(1, num_features, 1)))
        self.beta = Parameter(torch.zeros(1, num_features, 1))

    def get_stats(self, x):
        """Calculate the mean and mean square for a given batch x.

        Args:
            x: tensor containing a batch of activations

        Returns:
            mean: per-feature mean tensor
            mean_sq: per-feature mean-of-squares tensor
        """
        mean = x.mean(2, keepdim=True).mean(0, keepdim=True)
        mean_sq = (x ** 2).mean(2, keepdim=True).mean(0, keepdim=True)
        return mean, mean_sq

    def forward(self, x, ref_mean, ref_mean_sq):
        """Forward pass of virtual batch normalization.

        Virtual batch normalization requires two forward passes: one for the
        reference batch and one for the training batch.

        Args:
            x: input tensor
            ref_mean: reference per-feature mean tensor
            ref_mean_sq: reference per-feature mean-of-squares tensor

        Returns:
            out: normalized batch tensor
            mean: per-feature mean tensor used for normalization
            mean_sq: per-feature mean-of-squares tensor used for normalization
        """
        mean, mean_sq = self.get_stats(x)
        if ref_mean is None or ref_mean_sq is None:
            # reference mode - works just like batch norm
            mean = mean.clone().detach()
            mean_sq = mean_sq.clone().detach()
            out = self.normalize(x, mean, mean_sq)
        else:
            # blend the current batch statistics with the reference statistics
            batch_size = x.size(0)
            new_coeff = 1. / (batch_size + 1.)
            old_coeff = 1. - new_coeff
            mean = new_coeff * mean + old_coeff * ref_mean
            mean_sq = new_coeff * mean_sq + old_coeff * ref_mean_sq
            out = self.normalize(x, mean, mean_sq)
        return out, mean, mean_sq

    def normalize(self, x, mean, mean_sq):
        """Normalize tensor x given the statistics.

        Args:
            x: input tensor
            mean: per-feature mean
            mean_sq: per-feature mean of squares

        Returns:
            x: normalized batch tensor
        """
        assert mean is not None
        assert mean_sq is not None
        assert len(x.size()) == 3  # specific to 1d VBN
        if mean.size(1) != self.num_features:
            raise ValueError('Mean tensor size not equal to number of features: given {}, expected {}'
                             .format(mean.size(1), self.num_features))
        if mean_sq.size(1) != self.num_features:
            raise ValueError('Squared mean tensor size not equal to number of features: given {}, expected {}'
                             .format(mean_sq.size(1), self.num_features))
        std = torch.sqrt(self.eps + mean_sq - mean ** 2)
        x = x - mean
        x = x / std
        x = x * self.gamma
        x = x + self.beta
        return x

    def __repr__(self):
        return ('{name}(num_features={num_features}, eps={eps})'
                .format(name=self.__class__.__name__, **self.__dict__))
```
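The two-pass protocol can be sketched standalone. This is a minimal illustration, not part of the module above: the function names `vbn_stats` and `vbn_normalize` are hypothetical, and the learnable gamma/beta scaling is omitted for brevity.

```python
import torch

def vbn_stats(x):
    # per-feature mean and mean-of-squares over batch (dim 0) and length (dim 2)
    mean = x.mean(2, keepdim=True).mean(0, keepdim=True)
    mean_sq = (x ** 2).mean(2, keepdim=True).mean(0, keepdim=True)
    return mean, mean_sq

def vbn_normalize(x, mean, mean_sq, eps=1e-5):
    # standardize x with the supplied statistics
    std = torch.sqrt(eps + mean_sq - mean ** 2)
    return (x - mean) / std

torch.manual_seed(0)

# pass 1: the fixed reference batch defines the statistics
ref = torch.randn(8, 16, 100)
ref_mean, ref_mean_sq = vbn_stats(ref)

# pass 2: a training batch is normalized with a blend of its own
# statistics (weight 1/(B+1)) and the reference statistics (weight B/(B+1))
x = torch.randn(8, 16, 100)
mean, mean_sq = vbn_stats(x)
new_coeff = 1.0 / (x.size(0) + 1.0)
mean = new_coeff * mean + (1.0 - new_coeff) * ref_mean
mean_sq = new_coeff * mean_sq + (1.0 - new_coeff) * ref_mean_sq
out = vbn_normalize(x, mean, mean_sq)
```

Because the reference statistics dominate the blend, the normalization applied to one training example depends only weakly on the rest of its mini-batch.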
# Code Analysis
This PyTorch module implements Virtual Batch Normalization (VBN), a variant of Batch Normalization originally proposed for GAN training. Instead of normalizing each example with the statistics of its own mini-batch, VBN normalizes it with statistics computed on a fixed reference batch (blended with the current batch's statistics), so the output for one example depends only weakly on the other examples in its mini-batch. This requires two forward passes: one for the reference batch and one for the training batch.
Here is a detailed analysis of the module:
- **__init__:** Initializes the module and defines the gamma and beta parameters, where gamma is the scaling factor for normalization and beta is the offset.
- **get_stats:** Calculates the mean and mean squared values for a given batch x, returning mean and mean_sq.
- **forward:** The forward pass of virtual batch normalization. When reference statistics are supplied, it blends them with the current batch's statistics, weighting the current batch by `1 / (batch_size + 1)`, and normalizes x with the blended values. When no reference statistics are supplied (reference mode), it computes and detaches the batch's own statistics and normalizes x with them. In both cases it returns the normalized tensor together with the statistics used, so they can serve as reference values for later passes.
- **normalize:** Normalizes x using given mean and mean_sq. It first calculates the standard deviation and then uses it to normalize x, followed by scaling and offsetting x using the gamma and beta parameters. It returns the normalized x.
- **__repr__:** Returns a string representation of the module, including num_features and eps.
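In formula form, with $B$ the batch size and per-feature statistics written as $\mu$ and $\mathbb{E}[x^2]$, `forward` and `normalize` together compute:

```latex
\mu = \tfrac{1}{B+1}\,\mu_{\text{batch}} + \tfrac{B}{B+1}\,\mu_{\text{ref}},
\qquad
s = \tfrac{1}{B+1}\,\mathbb{E}[x^2]_{\text{batch}} + \tfrac{B}{B+1}\,\mathbb{E}[x^2]_{\text{ref}}
```

```latex
\sigma = \sqrt{s - \mu^2 + \epsilon},
\qquad
y = \gamma \cdot \frac{x - \mu}{\sigma} + \beta
```

Note that the variance is recovered from the mean of squares as $\mathbb{E}[x^2] - \mu^2$, which is why the module tracks `mean_sq` rather than the variance directly: means of squares can be blended linearly, while variances cannot.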
This module shows how Virtual Batch Normalization can be implemented in PyTorch. It can be useful in scenarios where traditional batch normalization is unsuitable, for example when the normalization of one example must not depend strongly on the other examples in its mini-batch.
Original source: https://www.cveoy.top/t/topic/n6BP. Copyright belongs to the author; do not repost or scrape.