Source code for torch_brain.nn.loss

from abc import ABC, abstractmethod
from typing import Optional

import torch
import torch.nn.functional as F
from torchtyping import TensorType


class Loss(torch.nn.Module, ABC):
    r"""Abstract parent of every loss module.

    Concrete losses implement :meth:`forward` and are expected to accept an
    optional ``weights`` argument alongside ``input`` and ``target``.
    """

    def __init__(self):
        super().__init__()

    @abstractmethod
    def forward(
        self,
        input: torch.Tensor,
        target: torch.Tensor,
        weights: Optional[torch.Tensor] = None,
    ) -> torch.Tensor:
        r"""Compute the loss; must be overridden by subclasses.

        Args:
            input (Tensor): The input tensor.
            target (Tensor): The target tensor.
            weights (Tensor, optional): The weights tensor.
        """
class MSELoss(Loss):
    r"""Mean squared error with optional per-sample weighting."""

    def forward(
        self,
        input: TensorType["batch_size", "dim"],
        target: TensorType["batch_size", "dim"],
        weights: Optional[TensorType["batch_size"]] = None,
    ) -> torch.Tensor:
        r"""Compute mean squared error loss.

        Without ``weights`` this is ``F.mse_loss(input, target)``. With
        ``weights`` the squared error is first averaged over dimension 1 for
        each sample, and the per-sample values are then combined as a
        weighted average normalized by ``weights.sum()``.

        Args:
            input (Tensor): The input tensor (2-D).
            target (Tensor): The target tensor (2-D).
            weights (Tensor, optional): The weights tensor (1-D, one entry
                per row of ``input``).

        Raises:
            ValueError: If ``input`` or ``target`` is not 2-D, if ``weights``
                is not 1-D, or if ``weights`` and ``input`` differ in their
                first dimension.
        """
        if input.ndim != 2:
            raise ValueError("Input must have 2 dimensions")
        if target.ndim != 2:
            raise ValueError("Target must have 2 dimensions")

        # Unweighted path: nothing about ``weights`` left to validate.
        if weights is None:
            return F.mse_loss(input, target)

        if weights.ndim != 1:
            raise ValueError("Weights must have 1 dimension")
        if input.shape[0] != weights.shape[0]:
            raise ValueError("Input and weights must have the same batch size")

        # One scalar per sample, then a weight-normalized average.
        per_sample = F.mse_loss(input, target, reduction="none").mean(dim=1)
        weighted_total = (weights * per_sample).sum()
        return weighted_total / weights.sum()
class CrossEntropyLoss(Loss):
    r"""Cross-entropy with optional per-sample weighting."""

    def forward(
        self,
        input: TensorType["batch_size", "dim"],
        target: TensorType["batch_size"],
        weights: Optional[TensorType["batch_size"]] = None,
    ) -> torch.Tensor:
        r"""Compute cross-entropy loss.

        Without ``weights`` this is ``F.cross_entropy(input, target)``. With
        ``weights`` the unreduced per-sample losses are combined as a
        weighted average normalized by ``weights.sum()``.

        Args:
            input (Tensor): The input tensor (2-D).
            target (Tensor): The target tensor (1-D).
            weights (Tensor, optional): The weights tensor (1-D, one entry
                per row of ``input``).

        Raises:
            ValueError: If ``input`` is not 2-D, if ``target`` or ``weights``
                is not 1-D, or if ``weights`` and ``input`` differ in their
                first dimension.
        """
        if input.ndim != 2:
            raise ValueError("Input must have 2 dimensions")
        if target.ndim != 1:
            raise ValueError("Target must have 1 dimensions")

        # Unweighted path: nothing about ``weights`` left to validate.
        if weights is None:
            return F.cross_entropy(input, target)

        if weights.ndim != 1:
            raise ValueError("Weights must have 1 dimension")
        if input.shape[0] != weights.shape[0]:
            raise ValueError("Input and weights must have the same batch size")

        per_sample = F.cross_entropy(input, target, reduction="none")
        weighted_total = (weights * per_sample).sum()
        return weighted_total / weights.sum()
class MallowDistanceLoss(Loss):
    r"""Mallow distance between a predicted class distribution and a target class.

    The loss compares the cumulative sum (along the class dimension) of the
    softmax of ``input`` with the cumulative sum of the one-hot encoded
    ``target``, and averages the squared difference over classes.
    """

    def forward(
        self,
        input: TensorType["batch_size", "dim"],
        target: TensorType["batch_size"],
        weights: Optional[TensorType["batch_size"]] = None,
    ) -> torch.Tensor:
        r"""Compute Mallow distance loss.

        Args:
            input (Tensor): The input tensor. Its last dimension indexes the
                classes; all leading dimensions are flattened into one.
            target (Tensor): The target tensor of class indices, flattened to
                one index per row of the flattened ``input``. It is used as
                the index argument of ``scatter_``.
            weights (Tensor, optional): The weights tensor, flattened to one
                weight per row. If omitted, every row contributes equally
                (plain mean).

        Returns:
            Tensor: Scalar loss. In the weighted case it is normalized by
            ``weights.sum()``.
        """
        num_classes = input.size(-1)
        probs = torch.softmax(input, dim=-1).reshape(-1, num_classes)

        # ``reshape`` rather than ``view`` so non-contiguous targets/weights
        # (e.g. slices or transposes) are accepted instead of raising.
        target_idx = target.reshape(-1, 1)

        # One-hot encode the target so it can be treated as a distribution.
        one_hot = torch.zeros_like(probs).scatter_(1, target_idx, 1.0)

        # Mallow distance: mean over classes of the squared difference
        # between the two cumulative sums.
        cdf_gap = torch.cumsum(one_hot, dim=-1) - torch.cumsum(probs, dim=-1)
        loss = torch.mean(torch.square(cdf_gap), dim=-1)

        if weights is None:
            return loss.mean()

        weights = weights.reshape(-1)
        return (weights * loss).sum() / weights.sum()