Task
Task
dataclass
¶
Task(
name: str,
session: dict[str, float] = lambda: {
"v": 0.5,
"a": 0.5,
}(),
stim_intensities: list[float] = lambda: [0.8, 0.9, 1](),
stim_time: int = 1000,
catch_prob: float = 0.5,
shuffle_trials: bool = True,
max_sequential: int | None = None,
*,
fix_intensity: float = 0,
fix_time: int | tuple[int, int] = 100,
iti: int | tuple[int, int] = 0,
dt: int = 20,
tau: int = 100,
n_outputs: int = 2,
output_behavior: list[float] = lambda: [0, 1](),
noise_std: float = 0.01,
scaling: bool = True
)
Bases: TaskSettingsMixin
General data class for defining a task.
A task is defined by a set of trials, each of which is characterized by a sequence of inputs and expected outputs.
Parameters:
Name | Type | Description | Default |
---|---|---|---|
name |
str
|
Name of the task. |
required |
session |
dict[str, float]
|
Configuration of the trials that can appear during a session.
It is given by a dictionary representing the ratio (values) of the different trials (keys) within the task.
Trials with a single modality (e.g., a visual trial) must be represented by single characters, while trials
with multiple modalities (e.g., an audiovisual trial) are represented by the character combination of those
trials. The capital letter X may not be used to signify a modality, as it is reserved for catch trials.
Note that values are read relative to each other, such that e.g. `{"v": 0.25, "a": 0.75}` is equivalent to `{"v": 1, "a": 3}`. |
lambda: {'v': 0.5, 'a': 0.5}()
|
stim_intensities |
list[float]
|
List of possible intensity values of each stimulus, when the stimulus is present. Note that when the stimulus is not present, the intensity is set to 0. Defaults to [0.8, 0.9, 1]. |
lambda: [0.8, 0.9, 1]()
|
stim_time |
int
|
Duration of each stimulus in ms. Defaults to 1000. |
1000
|
catch_prob |
float
|
Probability of catch trials (denoted by X) in the session. Must be between 0 and 1 (inclusive). Defaults to 0.5. |
0.5
|
shuffle_trials |
bool
|
If True (default), trial order will be randomized. If False, all trials corresponding to one
modality (e.g. visual) are run before any trial of the next modality (e.g. auditory) starts, in the order
defined in `session`. |
True
|
max_sequential |
int | None
|
If |
None
|
generate_trials
¶
generate_trials(
ntrials: int | tuple[int, int] = 20,
random_seed: int | None = None,
) -> dict[str, Any]
Method for generating trials.
Parameters:
Name | Type | Description | Default |
---|---|---|---|
ntrials |
int | tuple[int, int]
|
Number of trials to generate. If a tuple is given, it is interpreted as an interval of possible values, and a value will be randomly picked from it (the upper bound is exclusive, i.e. the maximum itself is never picked). Defaults to 20. |
20
|
random_seed |
int | None
|
Seed for numpy's random number generator (rng). If an int is given, it will be used as the seed for `np.random.default_rng()`. Defaults to None (i.e. the initial state itself is random). |
None
|
Returns:
Type | Description |
---|---|
dict[str, Any]
|
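dict containing all input parameters of `Task` ("task_settings"), the input parameters for the current `generate_trials()` method's call ("ntrials", "random_seed"), and the generated data ("modality_seq", "time", "phases", "inputs", "outputs"). |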
Source code in annubes/task.py
def generate_trials(
self,
ntrials: int | tuple[int, int] = 20,
random_seed: int | None = None,
) -> dict[str, Any]:
"""Method for generating trials.
Args:
ntrials: Number of trials to generate. If a tuple is given, it is interpreted as an interval of
possible values, and a value will be randomly picked from it.
Defaults to 20.
random_seed: Seed for numpy's random number generator (rng). If an int is given, it will be used as the seed
for `np.random.default_rng()`.
Defaults to None (i.e. the initial state itself is random).
Returns:
dict containing all input parameters of `Task` ("task_settings"), the input parameters for the current
`generate_trials()` method's call ("ntrials", "random_state"), and the generated data ("modality_seq",
"time", "phases", "inputs", "outputs").
"""
# Check input parameters
self._check_range("ntrials", ntrials, strict=True)
if random_seed is not None:
self._check_int_positive("random_seed", random_seed, strict=False)
# Set random state
if random_seed is None:
rng = np.random.default_rng(random_seed)
random_seed = rng.integers(2**32)
self._rng = np.random.default_rng(random_seed)
self._random_seed = random_seed
self._ntrials = self._rng.integers(min(ntrials), max(ntrials)) if isinstance(ntrials, tuple) else ntrials
# Generate sequence of modalities
self._modality_seq = self._build_trials_seq()
# Setup phases of trial
self._fix_time, self._iti, self._time, self._phases = self._setup_trial_phases()
# Generate inputs and outputs
self._inputs = self._build_trials_inputs()
self._outputs = self._build_trials_outputs()
# Scaling
if self.scaling:
flattened = np.concatenate(
(np.concatenate(self._inputs).reshape(-1), np.concatenate(self._outputs).reshape(-1)),
)
abs_min = np.min(flattened)
abs_max = np.max(flattened)
for n in range(self._ntrials):
self._inputs[n] = self._minmaxscaler(self._inputs[n], abs_min, abs_max)
self._outputs[n] = self._minmaxscaler(self._outputs[n], abs_min, abs_max)
# Store trials settings and data
return {
"task_settings": self._task_settings,
"ntrials": self._ntrials,
"random_seed": self._random_seed,
"modality_seq": self._modality_seq,
"time": self._time,
"phases": self._phases,
"inputs": self._inputs,
"outputs": self._outputs,
}
plot_trials
¶
plot_trials(n_plots: int = 1) -> Figure
Method for plotting generated trials.
Parameters:
Name | Type | Description | Default |
---|---|---|---|
n_plots |
int
|
Number of trials to plot (capped by number of trials generated). Defaults to 1. |
1
|
Returns:
Type | Description |
---|---|
Figure
|
go.Figure: Plotly figure of trial results. |
Source code in annubes/task.py
def plot_trials(self, n_plots: int = 1) -> go.Figure:
    """Plot the trials stored by the latest `generate_trials()` call, one subplot row per trial.

    Each row contains:
      - one star-marked input trace per entry of `self._modalities` (input column with the same index),
        coloured with evenly spaced HSV hues over `self._n_inputs` channels;
      - a green "START" input trace taken from input column `self._n_inputs - 1`;
      - the output columns 0 ("Choice 1: NO STIMULUS", orange) and 1 ("Choice 2: STIMULUS", purple);
      - a dashed red vertical line at that trial's `_fix_time` plus `self.dt`.
    Legend entries are emitted only for the first row; later rows reuse the same legend groups.

    Args:
        n_plots: Number of trials to plot, starting from the first generated trial. It is validated with
            `_check_int_positive(..., strict=True)`; if it exceeds the number of generated trials, a warning
            is issued and all trials are plotted. Defaults to 1.

    Returns:
        go.Figure: Plotly figure of trial results.
    """
    self._check_int_positive("n_plots", n_plots, strict=True)
    available = self._ntrials
    if n_plots > available:
        warnings.warn(
            f"Number of plots requested ({n_plots}) exceeds number of trials ({available}). Will plot all trials.",
            stacklevel=2,
        )
        n_plots = available

    titles = [f"Trial {k + 1} - modality {self._modality_seq[k]}" for k in range(n_plots)]
    fig = make_subplots(
        rows=n_plots,
        cols=1,
        shared_xaxes=True,
        vertical_spacing=0.5 / n_plots,
        subplot_titles=titles,
    )

    # One distinct, fully saturated hue per input channel, rendered as "#rrggbb".
    n_channels = self._n_inputs
    palette = []
    for channel in range(n_channels):
        red, green, blue = (int(level * 255) for level in colorsys.hsv_to_rgb(channel / n_channels, 1.0, 1.0))
        palette.append(f"#{red:02x}{green:02x}{blue:02x}")

    # (output column, trace name, legend group, colour) for each plotted output channel.
    output_channels = (
        (0, "Choice 1: NO STIMULUS", "Choice 1", "orange"),
        (1, "Choice 2: STIMULUS", "Choice 2", "purple"),
    )

    for trial in range(n_plots):
        row = trial + 1
        show = trial == 0  # only the first row adds legend entries
        times = self._time[trial]
        trial_inputs = self._inputs[trial]
        trial_outputs = self._outputs[trial]

        for channel, modality in enumerate(self._modalities):
            fig.add_trace(
                go.Scatter(
                    name=modality,
                    mode="markers+lines",
                    x=times,
                    y=trial_inputs[:, channel],
                    marker_symbol="star",
                    legendgroup=modality,
                    showlegend=show,
                    line_color=palette[channel],
                ),
                row=row,
                col=1,
            )

        fig.add_trace(
            go.Scatter(
                name="START",
                mode="markers+lines",
                x=times,
                y=trial_inputs[:, n_channels - 1],
                marker_symbol="star",
                legendgroup="START",
                showlegend=show,
                line_color="green",
            ),
            row=row,
            col=1,
        )

        for column, label, group, color in output_channels:
            fig.add_trace(
                go.Scatter(
                    name=label,
                    mode="lines",
                    x=times,
                    y=trial_outputs[:, column],
                    legendgroup=group,
                    showlegend=show,
                    line_color=color,
                ),
                row=row,
                col=1,
            )

        # Dashed marker at this trial's fixation time, offset by self.dt.
        fig.add_vline(
            x=self._fix_time[trial] + self.dt,
            line_width=3,
            line_dash="dash",
            line_color="red",
            row=row,
            col=1,
        )

    fig.update_layout(height=1300, width=900, title_text="Trials")
    return fig