
TorchBrain


Loss

class torch_brain.nn.Loss

Bases: torch.nn.modules.module.Module, abc.ABC

Abstract base class for losses. Subclasses must implement forward(), and every loss should accept an optional weights argument.

abstract forward(input, target, weights=None)

Abstract method that computes the loss between input and target. Subclasses must override it; weights is optional and defaults to None.

Return type:

Tensor
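A minimal sketch of a custom loss built on this base class, assuming PyTorch is installed. WeightedL1Loss is hypothetical and not part of torch_brain, and its weighting scheme (one weight per sample, combined as a normalized weighted mean) is an assumption for illustration, not the documented meaning of weights in the library's own losses. If torch_brain is not installed, the sketch falls back to a local stand-in that mirrors the documented bases (torch.nn.Module, abc.ABC) and the forward(input, target, weights=None) signature.

```python
import abc
from typing import Optional

import torch
from torch import Tensor

try:
    from torch_brain.nn import Loss
except ImportError:
    # Assumption: stand-in mirroring the documented bases and signature,
    # used only when torch_brain is not installed.
    class Loss(torch.nn.Module, abc.ABC):
        @abc.abstractmethod
        def forward(self, input: Tensor, target: Tensor, weights: Optional[Tensor] = None) -> Tensor:
            ...


class WeightedL1Loss(Loss):
    """Hypothetical example: mean absolute error with optional per-sample weights."""

    def forward(self, input: Tensor, target: Tensor, weights: Optional[Tensor] = None) -> Tensor:
        # Element-wise absolute error, same shape as input.
        err = (input - target).abs()
        # Average over all non-batch dimensions so each sample has one error value.
        if err.dim() > 1:
            err = err.flatten(start_dim=1).mean(dim=1)
        if weights is None:
            return err.mean()
        # Assumed semantics: one weight per sample, normalized weighted mean.
        return (weights * err).sum() / weights.sum()


# The base class declares forward as abstract, so it cannot be instantiated directly.
try:
    Loss()
    base_is_abstract = False
except TypeError:
    base_is_abstract = True

loss_fn = WeightedL1Loss()
pred = torch.tensor([[1.0, 2.0], [3.0, 4.0]])
target = torch.zeros(2, 2)
unweighted = loss_fn(pred, target)  # per-sample errors [1.5, 3.5], mean 2.5
weighted = loss_fn(pred, target, weights=torch.tensor([1.0, 0.0]))  # only sample 0 counts: 1.5
```

Calling the module (loss_fn(...)) rather than loss_fn.forward(...) goes through torch.nn.Module.__call__, so any registered hooks still run. A zero weight in this sketch simply excludes a sample, which is one common way to mask out padded entries.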



© Copyright 2026, neuro-galaxy Team.

Created using Sphinx 8.1.3.

Built with the PyData Sphinx Theme 0.17.1.