class Loss(torch.nn.Module, ABC):
    r"""Abstract parent class shared by every loss module.

    Subclasses implement :meth:`forward`. All losses should support an
    optional ``weights`` argument.
    """

    def __init__(self):
        # Delegate to torch.nn.Module so the module machinery is initialised.
        super().__init__()

    @abstractmethod
    def forward(
        self,
        input: torch.Tensor,
        target: torch.Tensor,
        weights: Optional[torch.Tensor] = None,
    ) -> torch.Tensor:
        r"""Compute the loss; concrete subclasses must override this.

        Args:
            input (Tensor): The input tensor.
            target (Tensor): The target tensor.
            weights (Tensor, optional): The weights tensor.

        Returns:
            Tensor: The loss value (defined by the subclass).
        """
def forward(
    self,
    input: TensorType["batch_size", "dim"],
    target: TensorType["batch_size", "dim"],
    weights: Optional[TensorType["batch_size"]] = None,
) -> torch.Tensor:
    r"""Compute mean squared error loss.

    Without ``weights`` this is the plain mean of the element-wise squared
    error. With ``weights`` the squared error is first averaged over the
    feature dimension per sample, and the per-sample values are then
    combined as a weighted average (normalised by ``weights.sum()``).

    Args:
        input (Tensor): The input tensor, 2-dimensional.
        target (Tensor): The target tensor, same shape as ``input``.
        weights (Tensor, optional): The per-sample weights tensor,
            1-dimensional with one entry per row of ``input``.

    Returns:
        Tensor: Scalar loss.

    Raises:
        ValueError: If ``input`` or ``target`` is not 2-dimensional, if
            their shapes differ, if ``weights`` is not 1-dimensional, or
            if ``weights`` and ``input`` disagree on the batch size.
    """
    if input.ndim != 2:
        raise ValueError("Input must have 2 dimensions")
    if target.ndim != 2:
        raise ValueError("Target must have 2 dimensions")
    # F.mse_loss would silently broadcast mismatched shapes (it only emits
    # a warning), which yields a wrong loss value; reject them up front.
    if input.shape != target.shape:
        raise ValueError("Input and target must have the same shape")

    if weights is None:
        return F.mse_loss(input, target)

    if weights.ndim != 1:
        raise ValueError("Weights must have 1 dimension")
    if input.shape[0] != weights.shape[0]:
        raise ValueError("Input and weights must have the same batch size")

    # Per-sample MSE (mean over features), then weighted average over batch.
    per_sample = F.mse_loss(input, target, reduction="none").mean(dim=1)
    return (weights * per_sample).sum() / weights.sum()
def forward(
    self,
    input: TensorType["batch_size", "dim"],
    target: TensorType["batch_size"],
    weights: Optional[TensorType["batch_size"]] = None,
) -> torch.Tensor:
    r"""Compute cross-entropy loss.

    Without ``weights`` this is ``F.cross_entropy`` with its default
    reduction. With ``weights`` the unreduced per-sample losses are
    combined as a weighted average (normalised by ``weights.sum()``).

    Args:
        input (Tensor): The input tensor, 2-dimensional.
        target (Tensor): The target tensor, 1-dimensional.
        weights (Tensor, optional): The per-sample weights tensor,
            1-dimensional with one entry per row of ``input``.

    Returns:
        Tensor: Scalar loss.

    Raises:
        ValueError: If ``input`` is not 2-dimensional, ``target`` or
            ``weights`` is not 1-dimensional, or ``weights`` and ``input``
            disagree on the batch size.
    """
    if input.ndim != 2:
        raise ValueError("Input must have 2 dimensions")
    if target.ndim != 1:
        # Message fixed from "1 dimensions" to agree with the weights message.
        raise ValueError("Target must have 1 dimension")

    if weights is None:
        return F.cross_entropy(input, target)

    if weights.ndim != 1:
        raise ValueError("Weights must have 1 dimension")
    if input.shape[0] != weights.shape[0]:
        raise ValueError("Input and weights must have the same batch size")

    # Keep one loss value per sample so each can be scaled by its weight.
    per_sample = F.cross_entropy(input, target, reduction="none")
    return (weights * per_sample).sum() / weights.sum()
def forward(
    self,
    input: TensorType["batch_size", "dim"],
    target: TensorType["batch_size"],
    weights: Optional[TensorType["batch_size"]] = None,
) -> torch.Tensor:
    r"""Compute Mallow distance loss.

    The input is turned into class probabilities with a softmax over the
    last dimension and the target into a one-hot vector. The per-sample
    loss is the mean, over classes, of the squared difference between the
    cumulative sums of the two. Per-sample losses are combined as a
    weighted average when ``weights`` is given, otherwise as a plain mean.

    Args:
        input (Tensor): The input tensor; the last dimension indexes
            the classes.
        target (Tensor): The target tensor of integer class indices,
            one per sample.
        weights (Tensor, optional): The per-sample weights tensor. When
            omitted, every sample contributes equally.

    Returns:
        Tensor: Scalar loss.
    """
    num_classes = input.size(-1)
    # reshape (rather than view) so non-contiguous tensors are accepted too.
    probs = torch.softmax(input, dim=-1).reshape(-1, num_classes)
    class_idx = target.reshape(-1, 1)

    # One-hot encode the targets: put a 1.0 at each sample's class index.
    one_hot = torch.zeros_like(probs).scatter_(1, class_idx, 1.0)

    # Mallow distance: mean over classes of the squared gap between the
    # cumulative target distribution and the cumulative predicted one.
    cdf_gap = torch.cumsum(one_hot, dim=-1) - torch.cumsum(probs, dim=-1)
    per_sample = torch.mean(torch.square(cdf_gap), dim=-1)

    if weights is None:
        return per_sample.mean()

    flat_weights = weights.reshape(-1)
    return (flat_weights * per_sample).sum() / flat_weights.sum()