Source code for torch_brain.nn.feedforward
import torch.nn as nn
import torch.nn.functional as F
class GEGLU(nn.Module):
"""Gated Gaussian Error Linear Unit (GEGLU) activation function, as introduced in
the paper "GLU Variants Improve Transformer" (https://arxiv.org/abs/2002.05202).
"""
def forward(self, x):
x, gates = x.chunk(2, dim=-1)
return x * F.gelu(gates)
[docs]
class FeedForward(nn.Module):
"""A feed-forward network with GEGLU activation.
Args:
dim (int): Input and output dimension
mult (int, optional): Multiplier for hidden dimension. Defaults to 4
dropout (float, optional): Dropout probability. Defaults to 0.2
"""
def __init__(self, dim, mult=4, dropout=0.2):
super().__init__()
self.net = nn.Sequential(
nn.Linear(dim, dim * mult * 2),
GEGLU(),
nn.Dropout(p=dropout),
nn.Linear(dim * mult, dim),
)
[docs]
def forward(self, x):
return self.net(x)