Skip to content

Commit

Permalink
add quality.stimulus_mask that flags stimulus events that overlap wit…
Browse files Browse the repository at this point in the history
…h masked periods
  • Loading branch information
emiddell committed Feb 16, 2025
1 parent 73ba87b commit 3e6b45f
Show file tree
Hide file tree
Showing 2 changed files with 83 additions and 1 deletion.
38 changes: 38 additions & 0 deletions src/cedalion/sigproc/quality.py
Original file line number Diff line number Diff line change
Expand Up @@ -1133,3 +1133,41 @@ def detect_baselineshift(ts: cdt.NDTimeSeries, outlier_mask: cdt.NDTimeSeries):
shift_mask = shift_mask.isel(time=slice(pad_samples,-pad_samples))

return shift_mask



def stimulus_mask(df_stim : pd.DataFrame, mask : xr.DataArray) -> xr.DataArray:
"""Create a mask which events overlap with periods flagged as tainted in mask.
Args:
df_stim: stimulus data frame
mask: signal quality mask. Must contain dimensions 'channel' and 'time'
Returns:
A boolean mask with dimensions "stim", "channel".
The stim dimension matches the stimulus dataframe. Stimuli are marked as
TAINTED when there is any TAINTED flag in the mask between onset and onset+
duration.
"""
assert mask.ndim == 2
assert "channel" in mask.dims
assert "time" in mask.dims

result = np.zeros((len(df_stim), mask.sizes["channel"]), dtype=bool)

for i, r in df_stim.iterrows():
tmp = mask.sel(
time=(r["onset"] <= mask.time) & (mask.time < (r["onset"] + r["duration"]))
)
result[i,:] = (tmp == CLEAN).all("time")

return xr.DataArray(
result,
dims=["stim", "channel"],
coords=xrutils.coords_from_other(
mask,
dims=["channel"],
stim=("stim", df_stim.index),
trial_type=("stim", df_stim.trial_type),
),
)
46 changes: 45 additions & 1 deletion tests/test_sigproc_quality.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,11 @@
import numpy as np
import pandas as pd
import pytest
from numpy.testing import assert_allclose
import cedalion.sigproc.quality as quality

import cedalion.dataclasses as cdc
import cedalion.datasets
import cedalion.sigproc.quality as quality
from cedalion import units


Expand Down Expand Up @@ -94,3 +98,43 @@ def test_detect_outliers(rec):
def test_detect_baselineshift(rec):
outlier_mask = quality.detect_outliers(rec["amp"], t_window_std=2 * units.s)
_ = quality.detect_baselineshift(rec["amp"], outlier_mask)



def test_stimulus_mask():
t = np.arange(10)
channel = ["S1D1", "S1D2", "S1D3"]
source = ["S1", "S1", "S1"]
detector = ["D1", "D2", "D3"]

df_stim = pd.DataFrame(
{
"onset": [1.0, 5.0],
"duration": [3.0, 3.0],
"value": [1.0, 1.0],
"trial_type": ["X", "X"],
}
)

mask = cdc.build_timeseries(
np.array([
[1, 1, 1, 1, 1, 1, 1, 1, 1, 1],
[1, 0, 1, 1, 1, 1, 1, 1, 1, 1], # stim 0 in channel 1 tainted
[1, 1, 1, 1, 1, 1, 0, 1, 1, 1], # stim 1 in channel 2 tainted
]),
dims=["channel", "time"],
time=t,
channel=channel,
value_units="1",
time_units="s",
other_coords={"source": ("channel", source), "detector": ("channel", detector)},
).astype(bool)

stim_mask = quality.stimulus_mask(df_stim, mask)

assert stim_mask.dims == ("stim", "channel")
assert stim_mask.sizes["stim"] == 2
assert stim_mask.sizes["channel"] == 3

assert all(stim_mask[0,:] == [True, False, True])
assert all(stim_mask[1, :] == [True, True, False])

0 comments on commit 3e6b45f

Please sign in to comment.