import torch
from torch.nn import Module, Parameter


class VirtualBatchNorm1d(Module):
    """Module for Virtual Batch Normalization.

    Implementation borrowed and modified from Rafael_Valle's code, with help
    from SimonW in this discussion thread:
    https://discuss.pytorch.org/t/parameter-grad-of-conv-weight-is-none-after-virtual-batch-normalization/9036
    """

    def __init__(self, num_features, eps=1e-5):
        super().__init__()
        self.num_features = num_features
        self.eps = eps  # small constant for numerical stability
        # define gamma and beta parameters
        self.gamma = Parameter(torch.normal(mean=1.0, std=0.02, size=(1, num_features, 1)))
        self.beta = Parameter(torch.zeros(1, num_features, 1))

    def get_stats(self, x):
        """Calculate the mean and mean of squares for a given batch x.

        Args:
            x: tensor containing a batch of activations
        Returns:
            mean: mean tensor over features
            mean_sq: squared mean tensor over features
        """
        mean = x.mean(2, keepdim=True).mean(0, keepdim=True)
        mean_sq = (x ** 2).mean(2, keepdim=True).mean(0, keepdim=True)
        return mean, mean_sq

    def forward(self, x, ref_mean, ref_mean_sq):
        """Forward pass of virtual batch normalization.

        Virtual batch normalization requires two forward passes: one for the
        reference batch and one for the training batch.

        Args:
            x: input tensor
            ref_mean: reference mean tensor over features, or None
            ref_mean_sq: reference squared mean tensor over features, or None
        Returns:
            out: normalized batch tensor
            mean: mean tensor used for normalization
            mean_sq: squared mean tensor used for normalization
        """
        mean, mean_sq = self.get_stats(x)
        if ref_mean is None or ref_mean_sq is None:
            # reference mode - works just like batch norm
            mean = mean.clone().detach()
            mean_sq = mean_sq.clone().detach()
            out = self.normalize(x, mean, mean_sq)
        else:
            # calculate new mean and mean_sq
            batch_size = x.size(0)
            new_coeff = 1. / (batch_size + 1.)
            old_coeff = 1. - new_coeff
            mean = new_coeff * mean + old_coeff * ref_mean
            mean_sq = new_coeff * mean_sq + old_coeff * ref_mean_sq
            out = self.normalize(x, mean, mean_sq)
        return out, mean, mean_sq

    def normalize(self, x, mean, mean_sq):
        'Normalize tensor x given the statistics.
        Args:
            x: input tensor
            mean: mean over features
            mean_sq: squared means over features
        Result:
            x: normalized batch tensor
        '
        assert mean_sq is not None
        assert mean is not None
        assert len(x.size()) == 3  # specific for 1d VBN
        if mean.size(1) != self.num_features:
            raise ValueError('Mean tensor size not equal to number of features: given {}, expected {}'
                             .format(mean.size(1), self.num_features))
        if mean_sq.size(1) != self.num_features:
            raise ValueError('Squared mean tensor size not equal to number of features: given {}, expected {}'
                             .format(mean_sq.size(1), self.num_features))

        std = torch.sqrt(self.eps + mean_sq - mean ** 2)
        x = x - mean
        x = x / std
        x = x * self.gamma
        x = x + self.beta
        return x

    def __repr__(self):
        return ('{name}(num_features={num_features}, eps={eps})'
                .format(name=self.__class__.__name__, **self.__dict__))

# Code Analysis

This PyTorch module implements Virtual Batch Normalization (VBN). VBN is a variant of batch normalization in which normalization statistics are anchored to a fixed reference batch rather than depending solely on the current mini-batch. It requires two forward passes: one for the reference batch and one for the training batch. During training, the statistics of the current batch are blended with the reference statistics (the reference batch receives the dominant weight), and the blended values are used to normalize the training batch.
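The two-pass protocol can be sketched with plain tensor operations mirroring the module's `get_stats` and `forward` logic; the batch sizes and shapes below are illustrative, not prescribed by the original code:

```python
import torch

eps = 1e-5
num_features = 4

# illustrative reference and training batches: (batch, features, length)
ref_batch = torch.randn(8, num_features, 16)
train_batch = torch.randn(8, num_features, 16)

def stats(x):
    # mean and mean of squares over the batch and length dims,
    # kept broadcastable against x (shape: 1 x features x 1)
    mean = x.mean(2, keepdim=True).mean(0, keepdim=True)
    mean_sq = (x ** 2).mean(2, keepdim=True).mean(0, keepdim=True)
    return mean, mean_sq

# pass 1: the reference batch fixes the reference statistics
ref_mean, ref_mean_sq = stats(ref_batch)

# pass 2: the training batch blends its own statistics with the
# reference statistics, weighted 1/(B+1) vs. B/(B+1)
batch_size = train_batch.size(0)
new_coeff = 1.0 / (batch_size + 1.0)
mean, mean_sq = stats(train_batch)
mean = new_coeff * mean + (1.0 - new_coeff) * ref_mean
mean_sq = new_coeff * mean_sq + (1.0 - new_coeff) * ref_mean_sq

# normalize with the blended statistics (gamma/beta omitted here)
std = torch.sqrt(eps + mean_sq - mean ** 2)
out = (train_batch - mean) / std
```

Because `new_coeff` shrinks as the batch grows, the reference statistics dominate the blend, which is what keeps the normalization stable across training batches.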

Here is a detailed analysis of the module:

- **__init__:** Initializes the module and defines the gamma and beta parameters, where gamma is the scaling factor for normalization and beta is the offset.
- **get_stats:** Calculates the mean and mean squared values for a given batch x, returning mean and mean_sq.
- **forward:** The forward pass of virtual batch normalization. If reference mean and squared-mean values are given, it blends them with the current batch's statistics, normalizes x with the blended values, and returns the result. If no reference statistics are given (the reference pass), it computes the statistics from x alone, detaches them, normalizes x with them, and returns them for later use.
- **normalize:** Normalizes x using given mean and mean_sq. It first calculates the standard deviation and then uses it to normalize x, followed by scaling and offsetting x using the gamma and beta parameters. It returns the normalized x.
- **__repr__:** Returns a string representation of the module, including num_features and eps.
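The `normalize` step described above can be checked numerically: in reference mode (before the gamma scale and beta offset are applied), the output should have approximately zero mean and unit variance per feature. A minimal sketch using the same statistics computation as `get_stats`:

```python
import torch

eps = 1e-5
x = torch.randn(8, 3, 32)  # (batch, features, length)

# statistics over batch and length dims, as in get_stats
mean = x.mean(2, keepdim=True).mean(0, keepdim=True)
mean_sq = (x ** 2).mean(2, keepdim=True).mean(0, keepdim=True)

# normalize as in the module (gamma/beta omitted)
std = torch.sqrt(eps + mean_sq - mean ** 2)
normed = (x - mean) / std

# per-feature mean of the normalized output should be ~0, variance ~1
per_feature_mean = normed.mean(2).mean(0)
per_feature_var = (normed ** 2).mean(2).mean(0) - per_feature_mean ** 2
```

The variance comes out fractionally below 1 because `eps` is added inside the square root; that is the intended numerical-stability trade-off.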

This code demonstrates the implementation and analysis of Virtual Batch Normalization in PyTorch. Understanding these aspects can be beneficial for optimizing the performance of deep learning models in scenarios where traditional batch normalization might not be suitable.

