Source code for torch_brain.utils.readout

from typing import TYPE_CHECKING
import numpy as np
from temporaldata import Data

import torch_brain
from torch_brain.utils import (
    resolve_weights_based_on_interval_membership,
    isin_interval,
)

if TYPE_CHECKING:
    from torch_brain.registry import ModalitySpec


def prepare_for_readout(
    data: Data,
    readout_spec: "ModalitySpec",
):
    """Extract readout timestamps, values, weights and an eval mask from ``data``.

    The readout configuration is read from ``data.config["readout"]``. It must
    contain ``readout_id`` and may additionally contain ``weights``,
    ``normalize_mean``, ``normalize_std``, ``timestamp_key``, ``value_key``,
    ``metrics`` and ``eval_interval``; any other key is rejected.

    Args:
        data: Object exposing ``config["readout"]`` and
            ``get_nested_attribute`` for looking up the timestamps, the values
            and (optionally) the evaluation interval.
        readout_spec: Modality specification supplying the default
            ``value_key`` and ``timestamp_key`` used when the readout config
            does not override them.

    Returns:
        A tuple ``(timestamps, values, weights, eval_mask)``:

        * ``timestamps``: the attribute found under the timestamp key.
        * ``values``: the attribute found under the value key, with the
          optional mean subtraction / std division applied and cast to
          float32 when its dtype is float64.
        * ``weights``: the result of
          ``resolve_weights_based_on_interval_membership`` for the ``weights``
          config entry (``None`` when that entry is absent).
        * ``eval_mask``: an all-``True`` boolean array with one entry per
          timestamp, or the result of ``isin_interval`` when
          ``eval_interval`` is configured.

    Raises:
        ValueError: If ``readout_id`` is missing, if the config contains a key
            outside the supported set, or if ``readout_id`` is not in
            ``torch_brain.MODALITY_REGISTRY``.
    """
    required_keys = ["readout_id"]
    optional_keys = [
        "weights",
        "normalize_mean",
        "normalize_std",
        "timestamp_key",
        "value_key",
        "metrics",
        "eval_interval",
    ]
    # Deliberately a list: it is interpolated verbatim into an error message.
    allowed_keys = required_keys + optional_keys

    readout_config = data.config["readout"]

    # Every required key has to be present.
    for required in required_keys:
        if required not in readout_config:
            raise ValueError(f"readout config is missing required key: {required}")

    # Nothing outside the supported set of keys may appear.
    if any(name not in allowed_keys for name in readout_config.keys()):
        raise ValueError(
            f"Readout {readout_config} contains invalid keys, please use only {allowed_keys}"
        )

    # The requested readout must be a registered modality.
    readout_id = readout_config["readout_id"]
    if readout_id not in torch_brain.MODALITY_REGISTRY:
        raise ValueError(
            f"Readout {readout_id} not found in modality registry, please register it "
            "using torch_brain.register_modality()"
        )

    # Config entries take precedence over the defaults carried by the spec.
    value_key = readout_config.get("value_key", readout_spec.value_key)
    timestamp_key = readout_config.get("timestamp_key", readout_spec.timestamp_key)
    timestamps = data.get_nested_attribute(timestamp_key)
    values = data.get_nested_attribute(value_key)

    # Z-scale the values when a mean and/or std is given in the config.
    # A list denotes per-channel statistics (e.g. for x, y coordinates) and is
    # converted to an array; any other value is used as-is.
    if "normalize_mean" in readout_config:
        mean = readout_config["normalize_mean"]
        if isinstance(mean, list):
            mean = np.array(mean)
        values = values - mean

    if "normalize_std" in readout_config:
        std = readout_config["normalize_std"]
        if isinstance(std, list):
            std = np.array(std)
        values = values / std

    # Models are assumed never to run at float64 precision, so downcast.
    if values.dtype == np.float64:
        values = values.astype(np.float32)

    # Resolve the weights from the (optional) "weights" config entry.
    weights_config = readout_config.get("weights", None)
    weights = resolve_weights_based_on_interval_membership(
        timestamps, data, config=weights_config
    )

    # By default every timestamp is evaluated; a configured interval replaces
    # that default with the mask computed by ``isin_interval``.
    eval_mask = np.ones(len(timestamps), dtype=np.bool_)
    eval_interval_key = readout_config.get("eval_interval", None)
    if eval_interval_key is not None:
        eval_interval = data.get_nested_attribute(eval_interval_key)
        eval_mask = isin_interval(timestamps, eval_interval)

    return timestamps, values, weights, eval_mask