-
Notifications
You must be signed in to change notification settings - Fork 9
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
* Synthetic artifacts first commit * Scale spike by duration * Move artifact generation into dedicated functions * Added brief gallery thumbnail info * Synthetic artifact progress Added auto scaling Added passing of artifact functions Improved example notebook * Small bug fix * Added manual scale option * Synthetic artifact improvements and TDDR Improved scaling options for synthetic artifacts, among other fixes. Added TDDR motion correction algorithm. * Added thumbnail to notebook * Added PCA to MA notebook Also added reference to TDDR github * Resolving alex's comments moved synth artifacts example expanded docstrings made plots in motion notebook consistent * Added selection by optode to example notebook * address linter complaints and code formatting in synthetic_artifact.py * further changes - specified a protocol for artifact functions - avoid mutable default parameters - avoid frequent calls to dequantify * remove hardcoded wavelengths in conc2od. add tests * rename TDDR to tddr (PEP8) * minor improvements in 22_motion_artefacts_and_corrections - indicate true artefact positions in plots - consolidate plotting functions - fix linter complaints * address linter complaints in 61_synthetic_artifacts_example * simple test of tddr that only checks execution * Addressing Eike's comments Removed obsolete functions Raise more errors Added argument for sliding_window * Removed notebook outputs --------- Co-authored-by: Eike Middell <[email protected]>
- Loading branch information
Showing
11 changed files
with
1,076 additions
and
73 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
287 changes: 287 additions & 0 deletions
287
examples/augmentation/61_synthetic_artifacts_example.ipynb
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,287 @@ | ||
{ | ||
"cells": [ | ||
{ | ||
"cell_type": "markdown", | ||
"metadata": {}, | ||
"source": [ | ||
"# Synthetic Artifacts" | ||
] | ||
}, | ||
{ | ||
"cell_type": "code", | ||
"execution_count": null, | ||
"metadata": {}, | ||
"outputs": [], | ||
"source": [ | ||
"import matplotlib.pyplot as p\n", | ||
"import xarray as xr\n", | ||
"\n", | ||
"import cedalion\n", | ||
"import cedalion.datasets as datasets\n", | ||
"import cedalion.nirs\n", | ||
"import cedalion.sim.synthetic_artifact as sa" | ||
] | ||
}, | ||
{ | ||
"cell_type": "markdown", | ||
"metadata": {}, | ||
"source": [ | ||
"First, we'll load some example data." | ||
] | ||
}, | ||
{ | ||
"cell_type": "code", | ||
"execution_count": null, | ||
"metadata": {}, | ||
"outputs": [], | ||
"source": [ | ||
"rec = datasets.get_fingertapping()\n", | ||
"rec[\"od\"] = cedalion.nirs.int2od(rec[\"amp\"])\n", | ||
"\n", | ||
"f, ax = p.subplots(1, 1, figsize=(12, 4))\n", | ||
"ax.plot(\n", | ||
" rec[\"amp\"].time,\n", | ||
" rec[\"amp\"].sel(channel=\"S3D3\", wavelength=\"850\"),\n", | ||
" \"g-\",\n", | ||
" label=\"850nm\",\n", | ||
")\n", | ||
"ax.plot(\n", | ||
" rec[\"amp\"].time,\n", | ||
" rec[\"amp\"].sel(channel=\"S3D3\", wavelength=\"760\"),\n", | ||
" \"r-\",\n", | ||
" label=\"760nm\",\n", | ||
")\n", | ||
"p.legend()\n", | ||
"ax.set_xlabel(\"time / s\")\n", | ||
"ax.set_ylabel(\"intensity / v\")\n", | ||
"\n", | ||
"display(rec[\"od\"])" | ||
] | ||
}, | ||
{ | ||
"cell_type": "markdown", | ||
"metadata": {}, | ||
"source": [ | ||
"## Artifact Generation\n", | ||
"\n", | ||
"Artifacts are generated by functions taking as arguments: \n", | ||
"- time axis of timeseries \n", | ||
"- onset time \n", | ||
"- duration\n", | ||
"\n", | ||
"To enable proper scaling, the amplitude of the generic artifact generated by these functions should be 1." | ||
] | ||
}, | ||
{ | ||
"cell_type": "code", | ||
"execution_count": null, | ||
"metadata": {}, | ||
"outputs": [], | ||
"source": [ | ||
"time = rec[\"amp\"].time\n", | ||
"\n", | ||
"sample_bl_shift = sa.gen_bl_shift(time, 1000)\n", | ||
"sample_spike = sa.gen_spike(time, 2000, 3)\n", | ||
"\n", | ||
"display(sample_bl_shift)\n", | ||
"\n", | ||
"fig, ax = p.subplots(1, 1, figsize=(12,2))\n", | ||
"ax.plot(time, sample_bl_shift, \"r-\", label=\"bl_shift\")\n", | ||
"ax.plot(time, sample_spike, \"g-\", label=\"spike\")\n", | ||
"ax.set_xlabel('Time / s')\n", | ||
"ax.set_ylabel('Amp')\n", | ||
"ax.legend()\n", | ||
"\n", | ||
"p.tight_layout()\n", | ||
"p.show()" | ||
] | ||
}, | ||
{ | ||
"cell_type": "markdown", | ||
"metadata": {}, | ||
"source": [ | ||
"## Controlling Artifact Timing\n", | ||
"\n", | ||
"Artifacts can be placed using a timing dataframe with columns onset_time, duration, trial_type, value, and channel (extends stim dataframe).\n", | ||
"\n", | ||
"We can use the function add_event_timing to create and modify timing dataframes. The function allows precise control over each event.\n", | ||
"\n", | ||
"The function sel_chans_by_opt allows us to select a list of channels by way of a list of optodes. This reflects the fact that motion artifacts usually stem from the motion of a specific optode or set of optodes, which in turn affects all related channels.\n", | ||
"\n", | ||
"We can also use the functions random_events_num and random_events_perc to add random events to the dataframe—specifying either the number of events or the percentage of the timeseries duration, respectively." | ||
] | ||
}, | ||
{ | ||
"cell_type": "code", | ||
"execution_count": null, | ||
"metadata": {}, | ||
"outputs": [], | ||
"source": [ | ||
"# Create a list of events in the format (onset, duration)\n", | ||
"events = [(1000, 1), (2000, 1)]\n", | ||
"\n", | ||
"# Creates a new timing dataframe with the specified events.\n", | ||
"# Setting channel to None indicates that the artifact applies to all channels.\n", | ||
"timing_amp = sa.add_event_timing(events, 'bl_shift', None)\n", | ||
"\n", | ||
"# Select channels by optode\n", | ||
"chans = sa.sel_chans_by_opt([\"S1\"], rec[\"od\"])\n", | ||
"\n", | ||
"# Add random events to the timing dataframe\n", | ||
"timing_od = sa.random_events_perc(time, 0.01, [\"spike\"], chans)\n", | ||
"\n", | ||
"display(timing_amp)\n", | ||
"display(timing_od)" | ||
] | ||
}, | ||
{ | ||
"cell_type": "markdown", | ||
"metadata": {}, | ||
"source": [ | ||
"## Adding Artifacts to Data\n", | ||
"\n", | ||
"The function add_artifacts automatically scales artifacts and adds them to timeseries data. The function takes arguments\n", | ||
"- ts: cdt.NDTimeSeries\n", | ||
"- timing: pd.DataFrame\n", | ||
"- artifacts: Dict\n", | ||
"- (mode): 'auto' (default) or 'manual'\n", | ||
"- (scale): float = 1\n", | ||
"- (window_size): float = 120s\n", | ||
"\n", | ||
"The artifact functions (see above) are passed as a dictionary. Keys correspond to entries in the column trial_type of the timing dataframe, i.e. each event specified in the timing dataframe is generated using the function artifacts[trial_type]. If mode is 'manual', artifacts are scaled directly by the scale parameter, otherwise artifacts are automatically scaled by a parameter alpha which is calculated using a sliding window approach.\n", | ||
"\n", | ||
"If we want to auto scale based on concentration amplitudes but to add the artifacts to OD data, we can use the function add_chromo_artifacts_2_od. The function requires slightly different arguments because of the conversion between OD and conc:\n", | ||
"- ts: cdt.NDTimeSeries\n", | ||
"- timing: pd.DataFrame\n", | ||
"- artifacts: Dict\n", | ||
"- dpf: differential pathlength factor\n", | ||
"- geo3d: geometry of optodes (see recording object description)\n", | ||
"- (scale)\n", | ||
"- (window_size)\n" | ||
] | ||
}, | ||
{ | ||
"cell_type": "code", | ||
"execution_count": null, | ||
"metadata": { | ||
"tags": [ | ||
"nbsphinx-thumbnail" | ||
] | ||
}, | ||
"outputs": [], | ||
"source": [ | ||
"artifacts = {\"spike\": sa.gen_spike, \"bl_shift\": sa.gen_bl_shift}\n", | ||
"\n", | ||
"# Add baseline shifts to the amp data\n", | ||
"rec[\"amp2\"] = sa.add_artifacts(rec[\"amp\"], timing_amp, artifacts)\n", | ||
"\n", | ||
"# Convert the amp data to optical density\n", | ||
"rec[\"od2\"] = cedalion.nirs.int2od(rec[\"amp2\"])\n", | ||
"\n", | ||
"dpf = xr.DataArray(\n", | ||
" [6, 6],\n", | ||
" dims=\"wavelength\",\n", | ||
" coords={\"wavelength\": rec[\"amp\"].wavelength},\n", | ||
")\n", | ||
"\n", | ||
"# add spikes to od based on conc amplitudes\n", | ||
"rec[\"od2\"] = sa.add_chromo_artifacts_2_od(\n", | ||
" rec[\"od2\"], timing_od, artifacts, rec.geo3d, dpf, 1.5\n", | ||
")\n", | ||
"\n", | ||
"# Plot the OD data\n", | ||
"channels = rec[\"od\"].channel.values[0:6]\n", | ||
"fig, axes = p.subplots(len(channels), 1, figsize=(12, len(channels) * 2))\n", | ||
"if len(channels) == 1:\n", | ||
" axes = [axes]\n", | ||
"for i, channel in enumerate(channels):\n", | ||
" ax = axes[i]\n", | ||
" ax.plot(\n", | ||
" rec[\"od2\"].time,\n", | ||
" rec[\"od2\"].sel(channel=channel, wavelength=\"850\"),\n", | ||
" \"g-\",\n", | ||
" label=\"850nm + artifacts\",\n", | ||
" )\n", | ||
" ax.plot(\n", | ||
" rec[\"od\"].time,\n", | ||
" rec[\"od\"].sel(channel=channel, wavelength=\"850\"),\n", | ||
" \"r-\",\n", | ||
" label=\"850nm - od\",\n", | ||
" )\n", | ||
" ax.set_title(f\"Channel: {channel}\")\n", | ||
" ax.set_xlabel(\"Time / s\")\n", | ||
" ax.set_ylabel(\"OD\")\n", | ||
" ax.legend()\n", | ||
"p.tight_layout()\n", | ||
"p.show()" | ||
] | ||
}, | ||
{ | ||
"cell_type": "code", | ||
"execution_count": null, | ||
"metadata": {}, | ||
"outputs": [], | ||
"source": [ | ||
"# Plot the data in conc\n", | ||
"\n", | ||
"rec[\"conc\"] = cedalion.nirs.od2conc(rec[\"od\"], rec.geo3d, dpf)\n", | ||
"rec[\"conc2\"] = cedalion.nirs.od2conc(rec[\"od2\"], rec.geo3d, dpf)\n", | ||
"channels = rec[\"od\"].channel.values[0:6]\n", | ||
"fig, axes = p.subplots(len(channels), 1, figsize=(12, len(channels) * 2))\n", | ||
"if len(channels) == 1:\n", | ||
" axes = [axes]\n", | ||
"for i, channel in enumerate(channels):\n", | ||
" ax = axes[i]\n", | ||
" ax.plot(\n", | ||
" rec[\"conc2\"].time,\n", | ||
" rec[\"conc2\"].sel(channel=channel, chromo=\"HbR\"),\n", | ||
" \"g-\",\n", | ||
" label=\"HbR + artifacts\",\n", | ||
" )\n", | ||
" ax.plot(\n", | ||
" rec[\"conc\"].time,\n", | ||
" rec[\"conc\"].sel(channel=channel, chromo=\"HbR\"),\n", | ||
" \"b-\",\n", | ||
" label=\"HbR\",\n", | ||
" )\n", | ||
" ax.set_title(f\"Channel: {channel}\")\n", | ||
" ax.set_xlabel(\"Time / s\")\n", | ||
" ax.set_ylabel(\"conc\")\n", | ||
" ax.legend()\n", | ||
"p.tight_layout()\n", | ||
"p.show()" | ||
] | ||
}, | ||
{ | ||
"cell_type": "markdown", | ||
"metadata": {}, | ||
"source": [ | ||
"## Problems, improvements\n", | ||
"\n", | ||
"- One-function wrapper/interface?\n", | ||
"- More sophisticated artifacts (e.g. smooth baseline shift)" | ||
] | ||
} | ||
], | ||
"metadata": { | ||
"kernelspec": { | ||
"display_name": "cedalion", | ||
"language": "python", | ||
"name": "python3" | ||
}, | ||
"language_info": { | ||
"codemirror_mode": { | ||
"name": "ipython", | ||
"version": 3 | ||
}, | ||
"file_extension": ".py", | ||
"mimetype": "text/x-python", | ||
"name": "python", | ||
"nbconvert_exporter": "python", | ||
"pygments_lexer": "ipython3", | ||
"version": "3.11.9" | ||
} | ||
}, | ||
"nbformat": 4, | ||
"nbformat_minor": 2 | ||
} |
Oops, something went wrong.