Skip to content

Task

Task dataclass

Task(
    name: str,
    session: dict[str, float] = {"v": 0.5, "a": 0.5},
    stim_intensities: list[float] = [0.8, 0.9, 1],
    stim_time: int = 1000,
    catch_prob: float = 0.5,
    shuffle_trials: bool = True,
    max_sequential: int | None = None,
    *,
    fix_intensity: float = 0,
    fix_time: int | tuple[int, int] = 100,
    iti: int | tuple[int, int] = 0,
    dt: int = 20,
    tau: int = 100,
    n_outputs: int = 2,
    output_behavior: list[float] = [0, 1],
    noise_std: float = 0.01,
    scaling: bool = True
)

Bases: TaskSettingsMixin

General data class for defining a task.

A task is defined by a set of trials, each of which is characterized by a sequence of inputs and expected outputs.

Parameters:

Name Type Description Default
name str

Name of the task.

required
session dict[str, float]

Configuration of the trials that can appear during a session. It is given by a dictionary representing the ratio (values) of the different trials (keys) within the task. Trials with a single modality (e.g., a visual trial) must be represented by single characters, while trials with multiple modalities (e.g., an audiovisual trial) are represented by the character combination of those trials. The capital letter X may not be used to signify a modality, as it is reserved for catch trials. Note that values are read relative to each other, such that e.g. {"v": 0.25, "a": 0.75} is equivalent to {"v": 1, "a": 3}. Defaults to {"v": 0.5, "a": 0.5}.

{'v': 0.5, 'a': 0.5}
stim_intensities list[float]

List of possible intensity values of each stimulus, when the stimulus is present. Note that when the stimulus is not present, the intensity is set to 0. Defaults to [0.8, 0.9, 1].

[0.8, 0.9, 1]
stim_time int

Duration of each stimulus in ms. Defaults to 1000.

1000
catch_prob float

Probability of catch trials (denoted by X) in the session. Must be between 0 and 1 (inclusive). Defaults to 0.5.

0.5
shuffle_trials bool

If True (default), trial order will be randomized. If False, all trials corresponding to one modality (e.g. visual) are run before any trial of the next modality (e.g. auditory) starts, in the order defined in session (catch trials will still be randomly interspersed).

True
max_sequential int | None

If shuffle_trials is True, sets the maximum number of sequential trials of the same modality. Defaults to None (no maximum).

None

generate_trials

generate_trials(
    ntrials: int | tuple[int, int] = 20,
    random_seed: int | None = None,
) -> dict[str, Any]

Method for generating trials.

Parameters:

Name Type Description Default
ntrials int | tuple[int, int]

Number of trials to generate. If a tuple is given, it is interpreted as an interval of possible values, and a value will be randomly picked from it (the lower bound is inclusive, the upper bound is exclusive). Defaults to 20.

20
random_seed int | None

Seed for numpy's random number generator (rng). If an int is given, it will be used as the seed for np.random.default_rng(). Defaults to None (i.e. the initial state itself is random).

None

Returns:

Type Description
dict[str, Any]

dict containing all input parameters of Task ("task_settings"), the input parameters for the current generate_trials() method's call ("ntrials", "random_seed"), and the generated data ("modality_seq", "time", "phases", "inputs", "outputs").

Source code in annubes/task.py
def generate_trials(
    self,
    ntrials: int | tuple[int, int] = 20,
    random_seed: int | None = None,
) -> dict[str, Any]:
    """Method for generating trials.

    Args:
        ntrials: Number of trials to generate. If a tuple is given, it is interpreted as an interval of
            possible values, and a value will be randomly picked from it (the lower bound is inclusive,
            the upper bound is exclusive).
            Defaults to 20.
        random_seed: Seed for numpy's random number generator (rng). If an int is given, it will be used as the seed
            for `np.random.default_rng()`.
            Defaults to None (i.e. the initial state itself is random).

    Returns:
        dict containing all input parameters of `Task` ("task_settings"), the input parameters for the current
        `generate_trials()` method's call ("ntrials", "random_seed"), and the generated data ("modality_seq",
        "time", "phases", "inputs", "outputs").
    """
    # Check input parameters
    self._check_range("ntrials", ntrials, strict=True)
    if random_seed is None:
        # Draw a concrete seed from a freshly (randomly) initialized generator, so that the seed actually
        # used is stored in `self._random_seed` and returned under "random_seed", making the run reproducible.
        random_seed = np.random.default_rng().integers(2**32)
    else:
        self._check_int_positive("random_seed", random_seed, strict=False)

    # Set random state
    self._rng = np.random.default_rng(random_seed)
    self._random_seed = random_seed

    # `Generator.integers(low, high)` excludes `high`, so the upper bound of a tuple is never drawn.
    self._ntrials = self._rng.integers(min(ntrials), max(ntrials)) if isinstance(ntrials, tuple) else ntrials

    # Generate sequence of modalities
    self._modality_seq = self._build_trials_seq()

    # Setup phases of trial
    self._fix_time, self._iti, self._time, self._phases = self._setup_trial_phases()

    # Generate inputs and outputs
    self._inputs = self._build_trials_inputs()
    self._outputs = self._build_trials_outputs()

    # Scaling: a single global min/max is taken over the inputs AND outputs of all trials, so every
    # trial's inputs and outputs are rescaled with the same bounds.
    if self.scaling:
        flattened = np.concatenate(
            (np.concatenate(self._inputs).reshape(-1), np.concatenate(self._outputs).reshape(-1)),
        )
        abs_min = np.min(flattened)
        abs_max = np.max(flattened)

        for n in range(self._ntrials):
            self._inputs[n] = self._minmaxscaler(self._inputs[n], abs_min, abs_max)
            self._outputs[n] = self._minmaxscaler(self._outputs[n], abs_min, abs_max)

    # Store trials settings and data
    return {
        "task_settings": self._task_settings,
        "ntrials": self._ntrials,
        "random_seed": self._random_seed,
        "modality_seq": self._modality_seq,
        "time": self._time,
        "phases": self._phases,
        "inputs": self._inputs,
        "outputs": self._outputs,
    }

plot_trials

plot_trials(n_plots: int = 1) -> Figure

Method for plotting generated trials.

Parameters:

Name Type Description Default
n_plots int

Number of trials to plot (capped by number of trials generated). Defaults to 1.

1

Returns:

Type Description
Figure

go.Figure: Plotly figure of trial results.

Source code in annubes/task.py
def plot_trials(self, n_plots: int = 1) -> go.Figure:
    """Method for plotting generated trials.

    Reads the data stored by `generate_trials()`, so that method must have been called first.

    Args:
        n_plots: Number of trials to plot (capped by number of trials generated). Defaults to 1.

    Returns:
        go.Figure: Plotly figure of trial results, with one subplot per trial showing every input channel,
        the START channel, every output channel, and a dashed red vertical line at `fix_time + dt`.
    """
    # Check input parameters
    self._check_int_positive("n_plots", n_plots, strict=True)

    if (p := n_plots) > (t := self._ntrials):
        msg = f"Number of plots requested ({p}) exceeds number of trials ({t}). Will plot all trials."
        warnings.warn(msg, stacklevel=2)
        n_plots = self._ntrials

    fig = make_subplots(
        rows=n_plots,
        cols=1,
        shared_xaxes=True,
        vertical_spacing=0.5 / n_plots,
        subplot_titles=[f"Trial {i + 1}  - modality {self._modality_seq[i]}" for i in range(n_plots)],
    )
    showlegend = True
    # One evenly spaced hue per input channel.
    colors = [
        "#{:02x}{:02x}{:02x}".format(
            *tuple(int(c * 255) for c in colorsys.hsv_to_rgb(i / self._n_inputs, 1.0, 1.0)),
        )
        for i in range(self._n_inputs)
    ]
    # (trace name, legend group, line colour) for the first two output channels (`n_outputs` defaults to 2).
    # Any further output channels get a generic label and plotly's default colour (line_color=None).
    output_styles = {
        0: ("Choice 1: NO STIMULUS", "Choice 1", "orange"),
        1: ("Choice 2: STIMULUS", "Choice 2", "purple"),
    }
    for i in range(n_plots):
        for idx, m in enumerate(self._modalities):
            fig.add_trace(
                go.Scatter(
                    name=m,
                    mode="markers+lines",
                    x=self._time[i],
                    y=self._inputs[i][:, idx],
                    marker_symbol="star",
                    legendgroup=m,
                    showlegend=showlegend,
                    line_color=colors[idx],
                ),
                row=i + 1,
                col=1,
            )
        # The START signal is stored in the last input column.
        fig.add_trace(
            go.Scatter(
                name="START",
                mode="markers+lines",
                x=self._time[i],
                y=self._inputs[i][:, self._n_inputs - 1],
                marker_symbol="star",
                legendgroup="START",
                showlegend=showlegend,
                line_color="green",
            ),
            row=i + 1,
            col=1,
        )
        # Plot every output column actually present instead of assuming exactly two, so that a Task
        # configured with a different `n_outputs` neither crashes nor silently drops outputs.
        for k in range(self._outputs[i].shape[1]):
            name, group, color = output_styles.get(k, (f"Choice {k + 1}", f"Choice {k + 1}", None))
            fig.add_trace(
                go.Scatter(
                    name=name,
                    mode="lines",
                    x=self._time[i],
                    y=self._outputs[i][:, k],
                    legendgroup=group,
                    showlegend=showlegend,
                    line_color=color,
                ),
                row=i + 1,
                col=1,
            )
        fig.add_vline(
            x=self._fix_time[i] + self.dt,
            line_width=3,
            line_dash="dash",
            line_color="red",
            row=i + 1,
            col=1,
        )
        # Legend entries are only needed once; later subplots share the same legend groups.
        showlegend = False
    fig.update_layout(height=1300, width=900, title_text="Trials")
    return fig