Skip to content

Commit

Permalink
make GLM basis GaussianKernels behave like in Homer
Browse files Browse the repository at this point in the history
The meaning of t_pre in GaussianKernels differed from the implemenation
in Homer2. In cedalion the Gaussians were centered in [-t_pre, t_post]
with enough padding so that the Gaussian kernels could decay back to the
baseline. In Homer the first Gaussian is centered at -t_pre. In order
to avoid confusion for switching users the basis function class
GaussianKernels behaves now as in Homer. The implementation with
padding is available as GaussianKernelsWithTails.
  • Loading branch information
emiddell committed Dec 16, 2024
1 parent b192089 commit 228b6b0
Show file tree
Hide file tree
Showing 2 changed files with 110 additions and 9 deletions.
27 changes: 26 additions & 1 deletion examples/modeling/31_glm_basis_functions.ipynb
Original file line number Diff line number Diff line change
Expand Up @@ -67,6 +67,31 @@
"p.figure()\n",
"for i_comp, comp in enumerate(hrf.component.values):\n",
" p.plot(hrf.time, hrf[:, i_comp], label=comp)\n",
"\n",
"p.axvline(-5, c=\"r\", ls=\":\")\n",
"p.axvline(30, c=\"r\", ls=\":\")\n",
"p.legend(ncols=3)"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"basis = bf.GaussianKernelsWithTails(\n",
" t_pre=5 * units.s,\n",
" t_post=30 * units.s,\n",
" t_delta=3 * units.s,\n",
" t_std=3 * units.s,\n",
")\n",
"hrf = basis(ts)\n",
"\n",
"p.figure()\n",
"for i_comp, comp in enumerate(hrf.component.values):\n",
" p.plot(hrf.time, hrf[:, i_comp], label=comp)\n",
"p.axvline(-5, c=\"r\", ls=\":\")\n",
"p.axvline(30, c=\"r\", ls=\":\")\n",
"p.legend(ncols=3)"
]
},
Expand Down Expand Up @@ -162,7 +187,7 @@
],
"metadata": {
"kernelspec": {
"display_name": "cedalion_240720",
"display_name": "cedalion_241112",
"language": "python",
"name": "python3"
},
Expand Down
92 changes: 84 additions & 8 deletions src/cedalion/models/glm/basis_functions.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,6 +10,7 @@
from cedalion.sigproc.frequency import sampling_rate
import cedalion.xrutils as xrutils


class TemporalBasisFunction(ABC):
def __init__(self, convolve_over_duration: bool):
self.convolve_over_duration = convolve_over_duration
Expand All @@ -32,7 +33,7 @@ def __call__(
raise NotImplementedError()


class GaussianKernels(TemporalBasisFunction):
class GaussianKernelsWithTails(TemporalBasisFunction):
r"""A consecutive sequence of gaussian functions.
The basis functions have the form:
Expand Down Expand Up @@ -111,6 +112,84 @@ def __call__(
)


class GaussianKernels(TemporalBasisFunction):
r"""A consecutive sequence of gaussian functions.
The basis functions have the form:
.. math::
f(t) = \exp( -(t-\mu)^2/t_{std}^2)
The user specifies a time interval around the stimuls onset via the parameters
t_pre and t_post. Over this time interval a series of gaussian basis functions is
distributed:
- between the gaussian centers there is time gap of t_delta
- the width of the each gaussian is specified by t_std
- the first gaussian is centered at trial onset - t_pre.
- the model function extends strictly from -t_pre to t_post with a hard cutoff.
The number of gaussians is derived automatically from these constraints.
Args:
t_pre (:class:`Quantity`, [time]): time before trial onset
t_post (:class:`Quantity`, [time]): time after trial onset
t_delta (:class:`Quantity`, [time]): the temporal spacing between consecutive
gaussians
t_std (:class:`Quantity`, [time]): time width of the gaussians
"""

def __init__(
self,
t_pre: Annotated[Quantity, "[time]"],
t_post: Annotated[Quantity, "[time]"],
t_delta: Annotated[Quantity, "[time]"],
t_std: Annotated[Quantity, "[time]"],
):
super().__init__(convolve_over_duration=False)
self.t_pre = _to_unit(t_pre, units.s)
self.t_post = _to_unit(t_post, units.s)
self.t_delta = _to_unit(t_delta, units.s)
self.t_std = _to_unit(t_std, units.s)

def __call__(
self,
ts: cdt.NDTimeSeries,
) -> xr.DataArray:
fs = sampling_rate(ts).to(units.Hz)

# create time-axis
smpl_pre = int(np.ceil(self.t_pre * fs)) + 1
smpl_post = int(np.ceil(self.t_post * fs)) + 1
t_hrf = np.arange(-smpl_pre, smpl_post) / fs
t_hrf = t_hrf.to("s")

duration = t_hrf[-1] - t_hrf[0]

# determine number of gaussians
n_components = int(np.floor(duration / self.t_delta))

# place gaussians spaced by t_delta and starting at -t_pre.
mu = t_hrf[0] + np.arange(n_components) * self.t_delta

# build regressors. shape: (n_times, n_regressors)
regressors = np.exp(
-((t_hrf[:, None] - mu[None, :]) ** 2) / (2 * self.t_std) ** 2
)
regressors /= regressors.max(axis=0) # normalize gaussian peaks to 1
regressors = regressors.to_base_units().magnitude

component_names = _generate_component_names(n_components)

return xr.DataArray(
regressors,
dims=["time", "component"],
coords={
"time": xr.DataArray(t_hrf, dims=["time"]).pint.dequantify(),
"component": component_names,
},
)


class Gamma(TemporalBasisFunction):
r"""Modified gamma function, optionally convolved with a square-wave.
Expand Down Expand Up @@ -209,7 +288,6 @@ def __call__(
self,
ts: cdt.NDTimeSeries,
) -> xr.DataArray:

other_dim = xrutils.other_dim(ts, "time", "channel")
other_dim_values = ts[other_dim].values

Expand All @@ -229,7 +307,7 @@ def __call__(

for i_other, other in enumerate(other_dim_values):
x = (t_hrf - tau[other]) / sigma[other]
x2 = x.magnitude ** 2
x2 = x.magnitude**2
r = x2 * np.exp(-x2)
dr = (2 * x * (1 - x2)) * np.exp(-x2)
dr = dr.magnitude
Expand Down Expand Up @@ -289,7 +367,7 @@ def __call__(
T = _to_dict(self.T, other_dim_values) # noqa: N806

fs = sampling_rate(ts).to(units.Hz)
duration = 4.1 * max(q.values()) + max(T.values())
duration = 4.1 * max(q.values()) + max(T.values())
duration = (duration.to(units.s).magnitude + max(p.values())) * units.s
# add p to duration. but p has not unit and duration is in seconds.
# so we need to convert p to seconds.
Expand All @@ -303,9 +381,7 @@ def __call__(

for i_other, other in enumerate(other_dim_values):
bas = t_hrf / (p[other] * q[other])
r = np.power(bas.magnitude, p[other]) * np.exp(
p[other] - t_hrf / q[other]
)
r = np.power(bas.magnitude, p[other]) * np.exp(p[other] - t_hrf / q[other])
r[t_hrf < 0] = 0.0
r = r.magnitude

Expand All @@ -330,7 +406,7 @@ def __call__(
# FIXME: instead of defining IndividualBasis we may want to make make_hrf_regressor
# accept xr.DataArrays directly?

#class IndividualBasis(TemporalBasisFunction):
# class IndividualBasis(TemporalBasisFunction):
# """Uses individual basis functions for each channel.
#
# Args:
Expand Down

0 comments on commit 228b6b0

Please sign in to comment.