MultitaskReadout
- class torch_brain.nn.MultitaskReadout(dim, readout_specs)
Bases: torch.nn.modules.module.Module
A module that performs multi-task linear readouts from output embeddings.
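The routing idea behind a multi-task linear readout can be illustrated with a toy re-implementation in plain PyTorch. This is a minimal sketch, not torch_brain's code: the class `ToyMultitaskReadout`, its `task_channels` argument (task name mapped to number of output channels), and the rule that task ids follow dictionary order are all hypothetical. The real constructor takes `readout_specs`, whose structure is not described on this page.

```python
import torch
import torch.nn as nn


class ToyMultitaskReadout(nn.Module):
    """Hypothetical stand-in: one linear head per task, chosen per token by an integer id."""

    def __init__(self, dim: int, task_channels: dict):
        super().__init__()
        # Assumption: task ids are 0, 1, 2, ... in the insertion order of task_channels
        self.task_ids = {name: i for i, name in enumerate(task_channels)}
        self.heads = nn.ModuleDict(
            {name: nn.Linear(dim, n_ch) for name, n_ch in task_channels.items()}
        )

    def forward(self, embs, readout_index):
        # embs: (..., dim); readout_index: integer tensor with embs' leading shape
        outputs = {}
        for name, task_id in self.task_ids.items():
            mask = readout_index == task_id
            if mask.any():
                # Only the tokens routed to this task pass through its head
                outputs[name] = self.heads[name](embs[mask])
        return outputs


readout = ToyMultitaskReadout(dim=8, task_channels={"velocity": 2, "position": 3})
embs = torch.randn(4, 8)  # 4 output tokens
readout_index = torch.tensor([0, 1, 0, 0])  # token 1 -> "position", the rest -> "velocity"
out = readout(embs, readout_index)
print({k: tuple(v.shape) for k, v in out.items()})  # {'velocity': (3, 2), 'position': (1, 3)}
```

Applying each head only to its masked tokens means every token is projected exactly once, by the head of the task it is assigned to, and each task can have its own number of output channels.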
- forward(output_embs, output_readout_index, unpack_output=False)
Forward pass of the multi-task readout module.
- Parameters:
  - output_embs (Tensor) – Transformer output embeddings of shape (batch, n_out, dim).
  - output_readout_index (Tensor) – Integer indices indicating which readout head to use for each output token. Shape (batch, n_out).
  - unpack_output (bool) – Defaults to False, which concatenates the outputs of all samples into a single dictionary organized by task. Set to True to split the outputs by individual sample in the batch.
- Returns:
  - If unpack_output=False (default): a single dictionary containing the outputs of all samples concatenated together, organized by task name. Shape per task: (total_queries, n_channels)
  - If unpack_output=True: a list of dictionaries, one per batch sample, each containing that sample's outputs organized by task name. Shape per task: (n_queries, n_channels)
- Return type:
  dict[str, Tensor], or list[dict[str, Tensor]] if unpack_output=True
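To make the two return layouts of forward() concrete, here is a hedged sketch of the documented behaviour for padded (batch, n_out, dim) inputs. The heads, the task names, and the convention that a padded slot carries readout index -1 (so it matches no head) are assumptions made for illustration, not documented torch_brain behaviour.

```python
import torch
import torch.nn as nn

torch.manual_seed(0)
dim = 8
# Hypothetical tasks: readout index 0 -> "velocity" (2 channels), 1 -> "position" (3 channels)
heads = {"velocity": (0, nn.Linear(dim, 2)), "position": (1, nn.Linear(dim, 3))}


def readout_padded(output_embs, output_readout_index, unpack_output=False):
    """Sketch of the documented semantics for (batch, n_out, dim) inputs."""
    if not unpack_output:
        out = {}
        for name, (task_id, head) in heads.items():
            mask = output_readout_index == task_id  # (batch, n_out)
            if mask.any():
                # Boolean indexing gathers tokens sample by sample: (total_queries, n_channels)
                out[name] = head(output_embs[mask])
        return out
    per_sample = [{} for _ in range(output_embs.shape[0])]
    for b, sample in enumerate(per_sample):
        for name, (task_id, head) in heads.items():
            mask = output_readout_index[b] == task_id  # (n_out,)
            if mask.any():
                sample[name] = head(output_embs[b][mask])  # (n_queries, n_channels)
    return per_sample


output_embs = torch.randn(2, 3, dim)
# -1 marks a padded slot (an assumption of this sketch): it matches no head
output_readout_index = torch.tensor([[0, 0, 1], [1, 0, -1]])
flat = readout_padded(output_embs, output_readout_index)
per = readout_padded(output_embs, output_readout_index, unpack_output=True)
print({k: tuple(v.shape) for k, v in flat.items()})  # {'velocity': (3, 2), 'position': (2, 3)}
print([{k: tuple(v.shape) for k, v in d.items()} for d in per])
# [{'velocity': (2, 2), 'position': (1, 3)}, {'velocity': (1, 2), 'position': (1, 3)}]
```

In this sketch, concatenating the per-sample tensors of a task in batch order recovers that task's tensor from the packed dictionary, which is the relationship between total_queries and n_queries described above.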
- forward_varlen(output_embs, output_readout_index, output_batch_index, unpack_output=False)
Forward pass of the multi-task readout module for variable-length sequences.
This version handles sequences that are chained together along a single token dimension instead of being padded to a common length. This can be more memory-efficient because no padding tokens need to be stored.
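The difference between the padded and chained layouts can be sketched as follows. The batch, the lengths, and the use of a boolean mask and torch.repeat_interleave to build the chained tensors are illustrative assumptions; torch_brain's own collation code may construct output_embs and output_batch_index differently.

```python
import torch

# Hypothetical batch: 3 samples with 5, 1 and 2 output tokens, embedding size 4
lengths = torch.tensor([5, 1, 2])
dim = 4
max_len = int(lengths.max())

# Padded layout, as consumed by forward(): (batch, n_out, dim)
padded = torch.randn(len(lengths), max_len, dim)
is_real = torch.arange(max_len)[None, :] < lengths[:, None]  # (batch, n_out), False on padding

# Chained layout, as consumed by forward_varlen(): (total_ntokens, dim)
output_embs = padded[is_real]  # drops padding, keeps samples in batch order
# One batch id per token, in the same order the samples were concatenated
output_batch_index = torch.repeat_interleave(torch.arange(len(lengths)), lengths)

print(padded.shape[0] * padded.shape[1], output_embs.shape[0])  # 15 8
print(output_batch_index)  # tensor([0, 0, 0, 0, 0, 1, 2, 2])
```

Here the padded tensor holds 15 token slots while the chained one holds only the 8 real tokens; the saving grows as sequence lengths within a batch become more uneven.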
- Parameters:
  - output_embs (Tensor) – Transformer output embeddings of shape (total_ntokens, dim), where total_ntokens is the sum of the sequence lengths across the batch.
  - output_readout_index (Tensor) – Integer indices indicating which readout head to use for each output token. Shape (total_ntokens,).
  - output_batch_index (Tensor) – Batch index of the sample each token belongs to. Shape (total_ntokens,).
  - unpack_output (bool) – Defaults to False, which concatenates the outputs of all samples into a single dictionary organized by task. Set to True to split the outputs by individual sample, using the batch indices.
- Returns:
  - If unpack_output=False (default): a single dictionary containing the outputs of all samples concatenated together, organized by task name. Shape per task: (total_queries, n_channels)
  - If unpack_output=True: a list of dictionaries, one per batch sample, each containing that sample's outputs organized by task name. Shape per task: (n_queries, n_channels)
- Return type:
  dict[str, Tensor], or list[dict[str, Tensor]] if unpack_output=True
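The variable-length path of forward_varlen() can be sketched in the same spirit. This is an illustrative assumption rather than the library's implementation: the heads and task names are hypothetical, and the batch size is inferred as output_batch_index.max() + 1.

```python
import torch
import torch.nn as nn

torch.manual_seed(0)
dim = 8
# Hypothetical tasks: readout index 0 -> "velocity" (2 channels), 1 -> "position" (3 channels)
heads = {"velocity": (0, nn.Linear(dim, 2)), "position": (1, nn.Linear(dim, 3))}


def readout_varlen(output_embs, output_readout_index, output_batch_index, unpack_output=False):
    """Sketch of the documented semantics for chained (total_ntokens, dim) inputs."""
    out, owner = {}, {}
    for name, (task_id, head) in heads.items():
        mask = output_readout_index == task_id  # (total_ntokens,)
        if mask.any():
            out[name] = head(output_embs[mask])  # (total_queries, n_channels)
            owner[name] = output_batch_index[mask]  # batch id of each query
    if not unpack_output:
        return out
    # Assumption: batch ids are 0..B-1, so B = max id + 1
    batch_size = int(output_batch_index.max()) + 1
    per_sample = [{} for _ in range(batch_size)]
    for name, values in out.items():
        for b in range(batch_size):
            sel = owner[name] == b
            if sel.any():
                per_sample[b][name] = values[sel]  # (n_queries, n_channels)
    return per_sample


output_embs = torch.randn(6, dim)
output_readout_index = torch.tensor([0, 1, 0, 1, 0, 0])
output_batch_index = torch.tensor([0, 0, 0, 1, 2, 2])
flat = readout_varlen(output_embs, output_readout_index, output_batch_index)
per = readout_varlen(output_embs, output_readout_index, output_batch_index, unpack_output=True)
print({k: tuple(v.shape) for k, v in flat.items()})  # {'velocity': (4, 2), 'position': (2, 3)}
print([{k: tuple(v.shape) for k, v in d.items()} for d in per])
# [{'velocity': (2, 2), 'position': (1, 3)}, {'position': (1, 3)}, {'velocity': (2, 2)}]
```

In this sketch, unpacking selects rows from the already-computed task outputs rather than re-running the heads per sample, so each head still runs once over all of its queries; a sample with no queries for a task simply has no entry for it.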