Gated Residual Network (GRN) Implementation in PyTorch

The Gated Residual Network (GRN), a building block of the Temporal Fusion Transformer, lets a model apply non-linear processing only where it is needed. A gating layer can scale the non-linear branch down toward zero, which leaves essentially the skip connection. The model can therefore adapt to data of varying complexity, and this can improve performance compared with always applying non-linear transformations. This article walks through a PyTorch implementation of the GRN and explains its structure and behavior.

GRN Architecture

The core idea behind the GRN is to combine a skip connection with a gated non-linear transformation. The gate does not make a hard either/or choice. It scales the output of the non-linear branch element-wise, so the block can behave anywhere between a plain skip connection (gate near zero) and a full non-linear transformation added to the input. The GRN is defined as follows:

GRN(a, c) = LayerNorm(a + GLU(eta_1)) # Dropout is applied to eta_1
eta_1 = W_1*eta_2 + b_1
eta_2 = ELU(W_2*a + W_3*c + b_2)

Where:

  • a is the input vector (projected to the output size by a linear skip layer when the input and output sizes differ)
  • c is an optional context vector
  • W_1, W_2, W_3 are weight matrices
  • b_1, b_2 are bias vectors
  • ELU is the Exponential Linear Unit activation function
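  • GLU is the Gated Linear Unit, GLU(gamma) = sigmoid(W_4*gamma + b_4) ⊙ (W_5*gamma + b_5), where ⊙ is element-wise multiplication and W_4, b_4, W_5, b_5 are the gate's own parameters. When the sigmoid gate is close to zero, the non-linear branch is suppressed
  • LayerNorm applies layer normalization to the sum of the skip and gated branches, which stabilizes training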
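To make the equation concrete, the following is a minimal functional sketch that evaluates GRN(a, c) step by step with plain tensors. It makes several assumptions: the input and output sizes are equal, so the skip path is the identity; the gate uses the GLU form given above; and dropout is omitted, so the result corresponds to inference. The function name `grn_step` and the weight dictionary `p` are hypothetical and are not part of the original code.

```python
import torch
import torch.nn.functional as F


def grn_step(a, c, p):
    """Evaluate GRN(a, c) for a batch of vectors (illustrative sketch).

    a: (N, d_model) input; c: (N, d_context) context or None;
    p: hypothetical dict holding one tensor per symbol (W1, b1, ..., W5, b5).
    Dropout is omitted, so this matches the equation at inference time.
    """
    # eta_2 = ELU(W_2*a + W_3*c + b_2)
    pre = a @ p["W2"].T + p["b2"]
    if c is not None:
        pre = pre + c @ p["W3"].T
    eta_2 = F.elu(pre)
    # eta_1 = W_1*eta_2 + b_1
    eta_1 = eta_2 @ p["W1"].T + p["b1"]
    # GLU(eta_1) = sigmoid(W_4*eta_1 + b_4) * (W_5*eta_1 + b_5)
    glu = torch.sigmoid(eta_1 @ p["W4"].T + p["b4"]) * (eta_1 @ p["W5"].T + p["b5"])
    # LayerNorm over the feature dimension (no learned scale or shift here)
    return F.layer_norm(a + glu, normalized_shape=(a.size(-1),))


torch.manual_seed(0)
N, d_model, d_hidden, d_context = 4, 8, 16, 3
p = {
    "W2": torch.randn(d_hidden, d_model), "b2": torch.zeros(d_hidden),
    "W3": torch.randn(d_hidden, d_context),
    "W1": torch.randn(d_model, d_hidden), "b1": torch.zeros(d_model),
    "W4": torch.randn(d_model, d_model), "b4": torch.zeros(d_model),
    "W5": torch.randn(d_model, d_model), "b5": torch.zeros(d_model),
}
a = torch.randn(N, d_model)
c = torch.randn(N, d_context)
out = grn_step(a, c, p)
print(out.shape)  # torch.Size([4, 8])
```

If W_5 and b_5 are zero, the GLU output vanishes and the block reduces to LayerNorm(a). This is the "skip the non-linear processing" behavior the GRN is designed to learn.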

PyTorch Implementation

```python
import torch
import torch.nn as nn


class GLU(nn.Module):
    '''
    Gated Linear Unit: GLU(x) = sigmoid(W_4*x + b_4) * (W_5*x + b_5).

    The article uses GLU(size) without showing its definition. This class is an
    assumed implementation that follows the Temporal Fusion Transformer formulation.
    Unlike torch.nn.GLU, it keeps the feature size unchanged.
    '''
    def __init__(self, input_size):
        super().__init__()
        self.fc_gate = nn.Linear(input_size, input_size)   # W_4, b_4
        self.fc_value = nn.Linear(input_size, input_size)  # W_5, b_5

    def forward(self, x):
        return torch.sigmoid(self.fc_gate(x)) * self.fc_value(x)


class TemporalLayer(nn.Module):
    '''
    Collapses a 3-D input of shape (D0, D1, H), e.g. batch x time x features,
    to (D0*D1, H), applies the wrapped module, and restores the first two
    dimensions. This allows handling of variable sequence lengths and
    minibatch sizes.

    Similar to TimeDistributed in Keras, it is a wrapper that makes it possible
    to apply a layer to every temporal slice of an input.
    '''
    def __init__(self, module):
        super().__init__()
        self.module = module

    def forward(self, x):
        '''
        Args:
            x (torch.Tensor): 3-D tensor whose time steps pass through the same layer
        '''
        d0, d1 = x.size(0), x.size(1)
        x = x.reshape(d0 * d1, -1)            # (D0*D1, H)
        x = self.module(x)                    # (D0*D1, H_out)
        return x.reshape(d0, d1, x.size(-1))  # (D0, D1, H_out)


class GatedResidualNetwork(nn.Module):
    '''
    The Gated Residual Network gives the model flexibility to apply non-linear
    processing only when needed. It is difficult to know beforehand which
    variables are relevant, and in some cases simpler models can be beneficial.

    GRN(a, c) = LayerNorm(a + GLU(eta_1))  # Dropout is applied to eta_1
      eta_1 = W_1*eta_2 + b_1
      eta_2 = ELU(W_2*a + W_3*c + b_2)

    Args:
        input_size (int): Size of the input
        hidden_size (int): Size of the hidden layer
        output_size (int): Size of the output layer
        dropout (float): Dropout probability, between 0 and 1
        context_size (int, optional): Size of the static context vector
        is_temporal (bool): If True, wrap the layers in TemporalLayer so that
            3-D inputs of shape (N, T, input_size) are accepted
    '''
    def __init__(self, input_size, hidden_size, output_size, dropout,
                 context_size=None, is_temporal=True):
        super().__init__()
        self.input_size = input_size
        self.hidden_size = hidden_size
        self.output_size = output_size
        self.context_size = context_size
        self.dropout_rate = dropout
        self.is_temporal = is_temporal

        # Wrap layers in TemporalLayer for 3-D inputs, use them as-is otherwise
        wrap = TemporalLayer if is_temporal else (lambda module: module)

        # Skip connection: project the input when the sizes differ
        if input_size != output_size:
            self.skip_layer = wrap(nn.Linear(input_size, output_size))

        # Context vector c (W_3, no bias)
        if context_size is not None:
            self.c = wrap(nn.Linear(context_size, hidden_size, bias=False))

        # Dense & ELU (W_2, b_2)
        self.dense1 = wrap(nn.Linear(input_size, hidden_size))
        self.elu = nn.ELU()

        # Dense & Dropout (W_1, b_1)
        self.dense2 = wrap(nn.Linear(hidden_size, output_size))
        self.dropout = nn.Dropout(dropout)

        # Gate, Add & Norm: LayerNorm over the feature dimension, as in the GRN equation
        self.gate = wrap(GLU(output_size))
        self.layer_norm = nn.LayerNorm(output_size)

    def forward(self, x, c=None):
        '''
        Args:
            x (torch.Tensor): tensor that passes through the GRN, of shape
                (N, T, input_size) if is_temporal, else (N, input_size)
            c (torch.Tensor, optional): static context vector of shape (N, context_size)
        '''
        # Skip connection a (projected if the sizes differ)
        a = self.skip_layer(x) if self.input_size != self.output_size else x

        # W_2*a + b_2, plus W_3*c when a context vector is given
        x = self.dense1(x)
        if c is not None:
            if self.is_temporal:
                # (N, context_size) -> (N, 1, context_size): broadcast over time steps
                c = c.unsqueeze(1)
            x = x + self.c(c)

        eta_2 = self.elu(x)
        eta_1 = self.dropout(self.dense2(eta_2))

        # Gate, add the skip connection, then normalize
        return self.layer_norm(self.gate(eta_1) + a)
```

The GatedResidualNetwork class implements the GRN architecture. Its constructor takes the input size, hidden size, output size, dropout rate, an optional context size, and an is_temporal flag. When the flag is set, the layers are wrapped in TemporalLayer so that the network can process sequences.
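The optional static context c has no time axis, so in the temporal case forward() inserts one with c.unsqueeze(1) before adding the projected context to the hidden activations. The short sketch below shows the resulting broadcast. It assumes batch-first inputs of shape (N, T, H), which is the layout where unsqueeze(1) lines up with the time axis. The layer names `dense1` and `ctx_proj` are illustrative stand-ins for the GRN's own layers. It also relies on the fact that nn.Linear already accepts inputs with extra leading dimensions, so no TemporalLayer wrapper is needed here.

```python
import torch
import torch.nn as nn

torch.manual_seed(0)
N, T, input_size, context_size, hidden_size = 2, 5, 6, 3, 4

dense1 = nn.Linear(input_size, hidden_size)                  # W_2, b_2
ctx_proj = nn.Linear(context_size, hidden_size, bias=False)  # W_3

x = torch.randn(N, T, input_size)   # batch-first sequence input (assumed layout)
c = torch.randn(N, context_size)    # one static context vector per sample

h = dense1(x)                        # (N, T, hidden_size)
c_term = ctx_proj(c.unsqueeze(1))    # (N, 1, hidden_size)
pre_activation = h + c_term          # broadcast over T -> (N, T, hidden_size)

print(pre_activation.shape)  # torch.Size([2, 5, 4])
```

Every time step of a sample receives the same context contribution. This is how a single static vector conditions the whole sequence.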

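The TemporalLayer class is a wrapper that applies a module (e.g., nn.Linear) to each temporal slice of an input tensor. It flattens the first two dimensions of a 3-D tensor into one, applies the module to the resulting 2-D tensor, and restores the original leading dimensions. The same weights therefore process every time step, regardless of the sequence length or minibatch size.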
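The reshape trick can be checked against an explicit loop over time steps. The sketch below re-declares a compact TemporalLayer with the same logic as above, so the block is self-contained. It then confirms that one batched call matches applying the wrapped layer slice by slice. It also shows that nn.Linear already accepts inputs with extra leading dimensions, so for plain linear layers the wrapper is a convenience rather than a requirement.

```python
import torch
import torch.nn as nn


class TemporalLayer(nn.Module):
    """Flatten the first two dims, apply `module`, then restore them."""
    def __init__(self, module):
        super().__init__()
        self.module = module

    def forward(self, x):
        d0, d1 = x.size(0), x.size(1)
        y = self.module(x.reshape(d0 * d1, -1))
        return y.reshape(d0, d1, y.size(-1))


torch.manual_seed(0)
linear = nn.Linear(6, 4)
wrapped = TemporalLayer(linear)

x = torch.randn(3, 7, 6)                 # (N, T, H) with T = 7 time steps
batched = wrapped(x)                     # one call for all time steps
looped = torch.stack([linear(x[:, t]) for t in range(x.size(1))], dim=1)

print(batched.shape)                                  # torch.Size([3, 7, 4])
print(torch.allclose(batched, looped, atol=1e-6))     # True
print(torch.allclose(batched, linear(x), atol=1e-6))  # True
```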

Addressing a Matrix Multiplication Error

A common error with this code is a RuntimeError stating that two matrices cannot be multiplied (PyTorch phrases it as "mat1 and mat2 shapes cannot be multiplied"). The error means that the last dimension of the tensor reaching a linear layer differs from that layer's in_features. In the TemporalLayer class, the traceback points to the line x = self.module(x), because that is where the reshaped input meets the wrapped nn.Linear. To fix the error, check the following:

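  1. Feature size: The last dimension of the input must equal the in_features of the wrapped module. This is input_size for dense1 and skip_layer, and context_size for the context layer c. Passing a tensor with a different feature size, or constructing the GRN with the wrong input_size or context_size, triggers the error.
  2. Input rank: TemporalLayer assumes a 3-D tensor. A 2-D tensor of shape (N, H) is reshaped to (N*H, 1), so the multiplication fails. Use is_temporal=False for 2-D inputs, or add a time axis with x.unsqueeze(1). The static context vector c should be 2-D, of shape (N, context_size), because the temporal GRN adds the time axis itself.
  3. Output shape: TemporalLayer restores the shape from x.size(-1) after the module runs, so its output is always (D0, D1, out_features). Reshaping the output cannot fix this error, because the failure happens inside self.module(x) before any output exists.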

Once every tensor that reaches a linear layer ends in the feature size that the layer was built for, the matrix multiplication succeeds and the error disappears.
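The input-rank problem from point 2 is easy to reproduce. The sketch below uses a compact re-declaration of TemporalLayer so that it runs on its own. It feeds a 2-D tensor to a wrapped nn.Linear, catches the resulting RuntimeError, and then shows two fixes: adding a time axis, or bypassing the wrapper as is_temporal=False does.

```python
import torch
import torch.nn as nn


class TemporalLayer(nn.Module):
    """Flatten the first two dims, apply `module`, then restore them."""
    def __init__(self, module):
        super().__init__()
        self.module = module

    def forward(self, x):
        d0, d1 = x.size(0), x.size(1)
        y = self.module(x.reshape(d0 * d1, -1))
        return y.reshape(d0, d1, y.size(-1))


layer = TemporalLayer(nn.Linear(10, 5))
x2d = torch.randn(8, 10)                 # (N, H): no time axis

try:
    layer(x2d)                           # reshaped to (80, 1) before the matmul
except RuntimeError as err:
    print("RuntimeError:", err)          # the 80x1 input does not match in_features=10

# Fix 1: add a time axis of length 1, then drop it again
out = layer(x2d.unsqueeze(1)).squeeze(1)
print(out.shape)                         # torch.Size([8, 5])

# Fix 2: for 2-D inputs, call the wrapped module directly (what is_temporal=False does)
print(layer.module(x2d).shape)           # torch.Size([8, 5])
```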

Conclusion

The Gated Residual Network gives deep learning models an alternative to always applying a non-linear transformation. When the gate closes, the block falls back to a normalized version of its input, and when it opens, the non-linear branch is added on top. This lets the model adapt to data of varying complexity and can improve performance. The PyTorch implementation above can be used as a building block in your own models. Check tensor dimensions and module types carefully to ensure proper functioning and avoid shape errors.

Original source: http://www.cveoy.top/t/topic/lRqN. Copyright belongs to the author. Do not reproduce or scrape.
