Readout Layers
A multi-task readout module.
Tokenizer function for
- class MultitaskReadout(dim, readout_specs)
Bases: Module
A module that performs multi-task linear readouts from output embeddings.
- forward(output_embs, output_readout_index, unpack_output=False)
Forward pass of the multi-task readout module.
- Parameters:
  - output_embs (Tensor) – Transformer output embeddings of shape (batch, n_out, dim)
  - output_readout_index (Tensor) – Integer indices indicating which readout head to use for each output token. Shape (batch, n_out)
  - unpack_output (bool) – By default False, which concatenates all outputs into a single dictionary organized by task. Set to True to break down outputs by individual samples in the batch.
- Returns:
  - If unpack_output=False (default): Single dictionary containing outputs from all samples concatenated together, organized by task name. Shape per task: (total_queries, n_channels)
  - If unpack_output=True: List of dictionaries, where each dictionary contains the outputs for a single batch sample organized by task name. Shape per task: (n_queries, n_channels)
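The routing described above can be illustrated with a minimal standalone sketch in plain PyTorch. It does not use the library's MultitaskReadout class; the names `task_specs`, `heads`, and `readout`, the task names, the choice of index 0 as padding, and the decision to skip tasks that have no tokens in the batch are all assumptions made for illustration.

```python
import torch
import torch.nn as nn

torch.manual_seed(0)
dim = 8

# Hypothetical task table: name -> (readout index, number of output channels).
task_specs = {"task_a": (1, 2), "task_b": (2, 3)}
# One linear head per task, mapping dim -> n_channels.
heads = nn.ModuleDict({name: nn.Linear(dim, n_ch) for name, (_, n_ch) in task_specs.items()})


def readout(output_embs, output_readout_index, unpack_output=False):
    """Route each output token to the linear head selected by its index."""
    batch_size = output_embs.shape[0]
    if not unpack_output:
        outputs = {}
        for name, (idx, _) in task_specs.items():
            mask = output_readout_index == idx  # (batch, n_out)
            if mask.any():
                # Boolean indexing flattens batch and token dims: (total_queries, dim).
                outputs[name] = heads[name](output_embs[mask])
        return outputs

    per_sample = [{} for _ in range(batch_size)]
    for name, (idx, _) in task_specs.items():
        mask = output_readout_index == idx
        for b in range(batch_size):
            if mask[b].any():
                # Tokens of sample b only: (n_queries, dim).
                per_sample[b][name] = heads[name](output_embs[b][mask[b]])
    return per_sample


output_embs = torch.randn(2, 4, dim)  # (batch, n_out, dim)
# Index 0 is treated as padding in this sketch (assumption).
output_readout_index = torch.tensor([[1, 1, 2, 0], [2, 2, 2, 1]])

merged = readout(output_embs, output_readout_index)
unpacked = readout(output_embs, output_readout_index, unpack_output=True)

print(merged["task_a"].shape)       # torch.Size([3, 2])
print(unpacked[0]["task_a"].shape)  # torch.Size([2, 2])
```

In this sketch, concatenating the per-sample outputs for a task in batch order reproduces the merged output, which is the relationship between the two return modes described above.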
- forward_varlen(output_embs, output_readout_index, output_batch_index, unpack_output=False)
Forward pass of the multi-task readout module for variable length sequences.
This version handles sequences that are chained together along a single token dimension instead of being padded to a common length. Avoiding padding can make it more memory efficient.
- Parameters:
  - output_embs (Tensor) – Transformer output embeddings of shape (total_ntokens, dim), where total_ntokens is the sum of sequence lengths across the batch
  - output_readout_index (Tensor) – Integer indices indicating which readout head to use for each output token. Shape (total_ntokens,)
  - output_batch_index (Tensor) – Tensor containing batch indices for each token. Shape (total_ntokens,)
  - unpack_output (bool) – By default False, which concatenates all outputs into a single dictionary organized by task. Set to True to break down outputs by individual samples using the batch indices.
- Returns:
  - If unpack_output=False (default): Single dictionary containing outputs from all samples concatenated together, organized by task name. Shape per task: (total_queries, n_channels)
  - If unpack_output=True: List of dictionaries, where each dictionary contains the outputs for a single batch sample organized by task name. Shape per task: (n_queries, n_channels)
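The variable-length layout can be sketched in the same way. This is a standalone illustration in plain PyTorch rather than the library's implementation; `task_specs`, `heads`, `readout_varlen`, the task names, and inferring the batch size from the largest batch index are assumptions for the example.

```python
import torch
import torch.nn as nn

torch.manual_seed(0)
dim = 8

# Hypothetical task table: name -> (readout index, number of output channels).
task_specs = {"task_a": (1, 2), "task_b": (2, 3)}
heads = nn.ModuleDict({name: nn.Linear(dim, n_ch) for name, (_, n_ch) in task_specs.items()})


def readout_varlen(output_embs, output_readout_index, output_batch_index, unpack_output=False):
    """Readout over sequences chained along one token dimension (no padding)."""
    if not unpack_output:
        outputs = {}
        for name, (idx, _) in task_specs.items():
            mask = output_readout_index == idx  # (total_ntokens,)
            if mask.any():
                outputs[name] = heads[name](output_embs[mask])
        return outputs

    # Assumption: batch indices are 0..B-1, so B is the largest index plus one.
    batch_size = int(output_batch_index.max()) + 1
    per_sample = [{} for _ in range(batch_size)]
    for name, (idx, _) in task_specs.items():
        for b in range(batch_size):
            # Select tokens that belong to sample b and use this task's head.
            mask = (output_readout_index == idx) & (output_batch_index == b)
            if mask.any():
                per_sample[b][name] = heads[name](output_embs[mask])
    return per_sample


# Two sequences of lengths 3 and 2 chained together: total_ntokens = 5.
output_embs = torch.randn(5, dim)
output_readout_index = torch.tensor([1, 2, 1, 2, 2])
output_batch_index = torch.tensor([0, 0, 0, 1, 1])

merged_v = readout_varlen(output_embs, output_readout_index, output_batch_index)
unpacked_v = readout_varlen(
    output_embs, output_readout_index, output_batch_index, unpack_output=True
)

print(merged_v["task_b"].shape)       # torch.Size([3, 3])
print(unpacked_v[1]["task_b"].shape)  # torch.Size([2, 3])
```

Because no padding tokens are present, every row of `output_embs` is routed to some head here, and `output_batch_index` is the only information needed to split the results back into per-sample dictionaries.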