From 3e6b45f827d8a6e05fd06e31236d811b69c1475d Mon Sep 17 00:00:00 2001 From: Eike Middell Date: Sun, 16 Feb 2025 15:05:57 +0100 Subject: [PATCH] add quality.stimulus_mask that flags stimulus events that overlap with masked periods --- src/cedalion/sigproc/quality.py | 38 +++++++++++++++++++++++++++ tests/test_sigproc_quality.py | 46 ++++++++++++++++++++++++++++++++- 2 files changed, 83 insertions(+), 1 deletion(-) diff --git a/src/cedalion/sigproc/quality.py b/src/cedalion/sigproc/quality.py index 624abff..9c23a3a 100644 --- a/src/cedalion/sigproc/quality.py +++ b/src/cedalion/sigproc/quality.py @@ -1133,3 +1133,41 @@ def detect_baselineshift(ts: cdt.NDTimeSeries, outlier_mask: cdt.NDTimeSeries): shift_mask = shift_mask.isel(time=slice(pad_samples,-pad_samples)) return shift_mask + + + +def stimulus_mask(df_stim : pd.DataFrame, mask : xr.DataArray) -> xr.DataArray: + """Create a mask which events overlap with periods flagged as tainted in mask. + + Args: + df_stim: stimulus data frame + mask: signal quality mask. Must contain dimensions 'channel' and 'time' + + Returns: + A boolean mask with dimensions "stim", "channel". + The stim dimension matches the stimulus dataframe. Stimuli are marked as + TAINTED when there is any TAINTED flag in the mask between onset and onset+ + duration. + """ + assert mask.ndim == 2 + assert "channel" in mask.dims + assert "time" in mask.dims + + result = np.zeros((len(df_stim), mask.sizes["channel"]), dtype=bool) + + for i, r in df_stim.iterrows(): + tmp = mask.sel( + time=(r["onset"] <= mask.time) & (mask.time < (r["onset"] + r["duration"])) + ) + result[i,:] = (tmp == CLEAN).all("time") + + return xr.DataArray( + result, + dims=["stim", "channel"], + coords=xrutils.coords_from_other( + mask, + dims=["channel"], + stim=("stim", df_stim.index), + trial_type=("stim", df_stim.trial_type), + ), + ) diff --git a/tests/test_sigproc_quality.py b/tests/test_sigproc_quality.py index 6df659a..d62846a 100644 --- a/tests/test_sigproc_quality.py +++ b/tests/test_sigproc_quality.py @@ -1,7 +1,11 @@ +import numpy as np +import pandas as pd import pytest from numpy.testing import assert_allclose -import cedalion.sigproc.quality as quality + +import cedalion.dataclasses as cdc import cedalion.datasets +import cedalion.sigproc.quality as quality from cedalion import units @@ -94,3 +98,43 @@ def test_detect_outliers(rec): def test_detect_baselineshift(rec): outlier_mask = quality.detect_outliers(rec["amp"], t_window_std=2 * units.s) _ = quality.detect_baselineshift(rec["amp"], outlier_mask) + + + +def test_stimulus_mask(): + t = np.arange(10) + channel = ["S1D1", "S1D2", "S1D3"] + source = ["S1", "S1", "S1"] + detector = ["D1", "D2", "D3"] + + df_stim = pd.DataFrame( + { + "onset": [1.0, 5.0], + "duration": [3.0, 3.0], + "value": [1.0, 1.0], + "trial_type": ["X", "X"], + } + ) + + mask = cdc.build_timeseries( + np.array([ + [1, 1, 1, 1, 1, 1, 1, 1, 1, 1], + [1, 0, 1, 1, 1, 1, 1, 1, 1, 1], # stim 0 in channel 1 tainted + [1, 1, 1, 1, 1, 1, 0, 1, 1, 1], # stim 1 in channel 2 tainted + ]), + dims=["channel", "time"], + time=t, + channel=channel, + value_units="1", + time_units="s", + other_coords={"source": ("channel", source), "detector": ("channel", detector)}, + ).astype(bool) + + stim_mask = quality.stimulus_mask(df_stim, mask) + + assert stim_mask.dims == ("stim", "channel") + assert stim_mask.sizes["stim"] == 2 + assert stim_mask.sizes["channel"] == 3 + + assert all(stim_mask[0,:] == [True, False, True]) + assert all(stim_mask[1, :] == [True, True, False])