Readout Layers

MultitaskReadout

A multi-task readout module.

prepare_for_multitask_readout()

Tokenizer function for MultitaskReadout.

class MultitaskReadout(dim, readout_specs)

Bases: Module

A module that performs multi-task linear readouts from output embeddings.
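The exact structure of `readout_specs` is not documented on this page. As a rough illustration of what "multi-task linear readouts" can look like, here is a minimal sketch that keeps one independent `nn.Linear` head per task. It assumes a hypothetical spec format (task name mapped to a `(readout_id, n_channels)` pair), and the class name `MultitaskReadoutSketch` is made up for this example. It is not the library's implementation.

```python
import torch
from torch import nn

# HYPOTHETICAL spec format: task name -> (readout_id, n_channels).
# The real `readout_specs` argument may be structured differently.
readout_specs = {"task_a": (0, 2), "task_b": (1, 3)}


class MultitaskReadoutSketch(nn.Module):
    """Illustrative stand-in: one independent linear head per task."""

    def __init__(self, dim, readout_specs):
        super().__init__()
        self.readout_specs = readout_specs
        # ModuleDict registers every head so its weights are trained and saved.
        self.heads = nn.ModuleDict(
            {name: nn.Linear(dim, n_channels) for name, (_, n_channels) in readout_specs.items()}
        )


readout = MultitaskReadoutSketch(dim=64, readout_specs=readout_specs)
x = torch.randn(5, 64)
print(readout.heads["task_b"](x).shape)  # torch.Size([5, 3])
```

Using `nn.ModuleDict` rather than a plain dict matters here: only registered submodules show up in `parameters()` and `state_dict()`.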

forward(output_embs, output_readout_index, unpack_output=False)

Forward pass of the multi-task readout module.

Parameters:
  • output_embs (Tensor) – Transformer output embeddings of shape (batch, n_out, dim)

  • output_readout_index (Tensor) – Integer indices indicating which readout head to use for each output token. Shape (batch, n_out)

  • unpack_output (bool) – Defaults to False, in which case all outputs are concatenated into a single dictionary organized by task. Set to True to split the outputs by individual sample in the batch.

Returns:

If unpack_output=False (default):

Single dictionary containing outputs from all samples concatenated together, organized by task name. Shape per task: (total_queries, n_channels)

If unpack_output=True:

List of dictionaries, where each dictionary contains the outputs for a single batch sample, organized by task name. Shape per task: (n_queries, n_channels)

Return type:

A dictionary mapping task name to output tensor, or a list of such dictionaries if unpack_output=True
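To make the routing in `forward` concrete, the sketch below selects, for each head, the tokens whose `output_readout_index` matches that head, and applies the head only to those tokens. With `unpack_output=False`, boolean-mask indexing flattens the batch and token dimensions, which gives the `(total_queries, n_channels)` shape described above. The `heads` mapping and the function name are hypothetical, and the sketch assumes every token refers to a valid head; how the real module treats padded positions is not described on this page.

```python
import torch
from torch import nn

torch.manual_seed(0)
dim = 8
# HYPOTHETICAL heads: readout index -> (task name, linear head).
heads = {0: ("task_a", nn.Linear(dim, 2)), 1: ("task_b", nn.Linear(dim, 3))}


def multitask_readout(output_embs, output_readout_index, unpack_output=False):
    """Route each output token to the head selected by its readout index."""
    if unpack_output:
        # Per-sample results: run the same routing on each batch element.
        return [
            multitask_readout(emb_b.unsqueeze(0), ridx_b.unsqueeze(0))
            for emb_b, ridx_b in zip(output_embs, output_readout_index)
        ]
    outputs = {}
    for idx, (name, head) in heads.items():
        mask = output_readout_index == idx  # (batch, n_out) boolean
        if mask.any():
            # Boolean indexing flattens batch and token dims (sample 0 first).
            outputs[name] = head(output_embs[mask])  # (total_queries, n_channels)
    return outputs


output_embs = torch.randn(2, 5, dim)  # (batch, n_out, dim)
output_readout_index = torch.tensor([[0, 0, 1, 1, 1], [0, 1, 1, 0, 0]])
packed = multitask_readout(output_embs, output_readout_index)
unpacked = multitask_readout(output_embs, output_readout_index, unpack_output=True)
print({k: tuple(v.shape) for k, v in packed.items()})  # {'task_a': (5, 2), 'task_b': (5, 3)}
```

In this sketch, a sample with no tokens for a given task simply has no key for that task in its per-sample dictionary.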

forward_varlen(output_embs, output_readout_index, output_batch_index, unpack_output=False)

Forward pass of the multi-task readout module for variable length sequences.

This version handles sequences that are concatenated (chained) along a single token dimension rather than padded to a common length. This can be more memory efficient, since no padding tokens need to be stored.
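The difference between the padded and chained layouts can be shown with plain tensor operations. The sketch below converts a padded batch plus a validity mask into the `(total_ntokens, dim)` layout and builds a matching `output_batch_index`. The mask and variable names are illustrative, and this is not necessarily how the library builds its inputs.

```python
import torch

dim = 4
# Padded layout: (batch, n_out, dim) plus a mask marking real tokens.
padded = torch.randn(2, 3, dim)
mask = torch.tensor([[True, True, True], [True, False, False]])

# Chained layout: keep only the real tokens, stacked along one dimension.
output_embs = padded[mask]  # (total_ntokens, dim) == (4, dim)

# output_batch_index records which sample each chained token came from.
output_batch_index = torch.arange(padded.shape[0]).unsqueeze(1).expand_as(mask)[mask]
print(output_batch_index)  # tensor([0, 0, 0, 1])
```

Here the padded tensor holds 6 token slots, while the chained one holds only the 4 real tokens. The savings grow with the spread of sequence lengths in the batch.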

Parameters:
  • output_embs (Tensor) – Transformer output embeddings of shape (total_ntokens, dim) where total_ntokens is the sum of sequence lengths across the batch

  • output_readout_index (Tensor) – Integer indices indicating which readout head to use for each output token. Shape (total_ntokens,)

  • output_batch_index (Tensor) – Tensor containing batch indices for each token. Shape (total_ntokens,)

  • unpack_output (bool) – Defaults to False, in which case all outputs are concatenated into a single dictionary organized by task. Set to True to split the outputs by individual sample, using output_batch_index.

Returns:

If unpack_output=False (default):

Single dictionary containing outputs from all samples concatenated together, organized by task name. Shape per task: (total_queries, n_channels)

If unpack_output=True:

List of dictionaries, where each dictionary contains the outputs for a single batch sample, organized by task name. Shape per task: (n_queries, n_channels)

Return type:

A dictionary mapping task name to output tensor, or a list of such dictionaries if unpack_output=True
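For the variable-length case, `output_batch_index` takes over the role that the batch dimension plays in `forward`. The sketch below routes chained tokens to heads and, when unpacking, splits each task's outputs by sample. The `heads` mapping and the function name are hypothetical. The sketch also infers the batch size from the largest batch index, which is an assumption: a trailing sample with no tokens would be dropped.

```python
import torch
from torch import nn

torch.manual_seed(0)
dim = 8
# HYPOTHETICAL heads: readout index -> (task name, linear head).
heads = {0: ("task_a", nn.Linear(dim, 2)), 1: ("task_b", nn.Linear(dim, 3))}


def multitask_readout_varlen(output_embs, output_readout_index, output_batch_index, unpack_output=False):
    """Route chained tokens to heads; optionally split results by batch index."""
    outputs, token_batch = {}, {}
    for idx, (name, head) in heads.items():
        mask = output_readout_index == idx  # (total_ntokens,)
        if mask.any():
            outputs[name] = head(output_embs[mask])  # (total_queries, n_channels)
            token_batch[name] = output_batch_index[mask]  # sample of each query
    if not unpack_output:
        return outputs
    # Assumption: samples are numbered 0..B-1 and the last one has tokens.
    batch_size = int(output_batch_index.max()) + 1
    per_sample = [{} for _ in range(batch_size)]
    for name, out in outputs.items():
        for b in range(batch_size):
            sel = token_batch[name] == b
            if sel.any():
                per_sample[b][name] = out[sel]  # (n_queries, n_channels)
    return per_sample


output_embs = torch.randn(6, dim)  # two samples, chained
output_readout_index = torch.tensor([0, 1, 1, 0, 0, 1])
output_batch_index = torch.tensor([0, 0, 0, 1, 1, 1])
packed = multitask_readout_varlen(output_embs, output_readout_index, output_batch_index)
unpacked = multitask_readout_varlen(
    output_embs, output_readout_index, output_batch_index, unpack_output=True
)
print({k: tuple(v.shape) for k, v in packed.items()})  # {'task_a': (3, 2), 'task_b': (3, 3)}
```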

prepare_for_multitask_readout(data, readout_registry)

Tokenizer function for MultitaskReadout.