import torch.nn as nn
import torch.nn.functional as F


class GEGLU(nn.Module):
    """Gated Gaussian Error Linear Unit (GEGLU) activation function.

    Introduced in the paper "GLU Variants Improve Transformer"
    (https://arxiv.org/abs/2002.05202).
    """

    def forward(self, x):
        """Gate one half of ``x`` with the GELU of the other half.

        Args:
            x: Input tensor. It is split into two chunks along its last
                dimension, so that dimension is expected to be even for
                the two halves to line up element-wise.

        Returns:
            The first chunk multiplied element-wise by ``gelu`` of the
            second chunk; for an even last dimension the result's last
            dimension is half that of the input.
        """
        # First chunk carries the values, second chunk becomes the gate.
        x, gates = x.chunk(2, dim=-1)
        return x * F.gelu(gates)
class FeedForward(nn.Module):
    """A feed-forward network with GEGLU activation.

    Args:
        dim (int): Input and output dimension
        mult (int, optional): Multiplier for hidden dimension. Defaults to 4
        dropout (float, optional): Dropout probability. Defaults to 0.2
    """

    def __init__(self, dim: int, mult: int = 4, dropout: float = 0.2) -> None:
        super().__init__()
        hidden_dim = dim * mult
        self.net = nn.Sequential(
            # Project to twice the hidden width: GEGLU splits its input in
            # two along the last dimension (values and gates).
            nn.Linear(dim, hidden_dim * 2),
            GEGLU(),
            nn.Dropout(p=dropout),
            # Project the gated hidden activations back to the input width.
            nn.Linear(hidden_dim, dim),
        )