Skip to content

Commit

Permalink
fix optode_dirs if not perfectly unitary (#71)
Browse files Browse the repository at this point in the history
* fix optode_dirs if not perfectly unitary
* move normal normalization to TrimeshSurface.get_vertex_normals

---------

Co-authored-by: Eike Middell <[email protected]>
  • Loading branch information
harmening and emiddell authored Jan 10, 2025
1 parent ae5fccf commit 023b09a
Show file tree
Hide file tree
Showing 2 changed files with 41 additions and 21 deletions.
22 changes: 17 additions & 5 deletions src/cedalion/dataclasses/geometry.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,7 +18,7 @@
import cedalion
import cedalion.typing as cdt
from cedalion.vtktutils import trimesh_to_vtk_polydata, pyvista_polydata_to_trimesh

import cedalion.xrutils as xrutils

@total_ordering
class PointType(Enum):
Expand Down Expand Up @@ -228,7 +228,7 @@ def smooth(self, lamb: float) -> "TrimeshSurface":
smoothed = trimesh.smoothing.filter_taubin(self.mesh, lamb=lamb)
return TrimeshSurface(smoothed, self.crs, self.units)

def get_vertex_normals(self, points: cdt.LabeledPointCloud):
def get_vertex_normals(self, points: cdt.LabeledPointCloud, normalized=True):
"""Get normals of vertices closest to the provided points."""

assert points.points.crs == self.crs
Expand All @@ -237,12 +237,22 @@ def get_vertex_normals(self, points: cdt.LabeledPointCloud):

_, vertex_indices = self.kdtree.query(points.values, workers=-1)

return xr.DataArray(
normals = xr.DataArray(
self.mesh.vertex_normals[vertex_indices],
dims=["label", self.crs],
coords={"label": points.label},
)

if normalized:
norms = xrutils.norm(normals, dim=normals.points.crs)

if not (norms > 0).all():
raise ValueError("Cannot normalize normals with zero length.")

normals /= norms

return normals

def fix_vertex_normals(self):
mesh = self.mesh
# again make sure, that normals face outside
Expand Down Expand Up @@ -384,7 +394,7 @@ def apply_transform(self, transform: cdt.AffineTransform) -> "PycortexSurface":
def decimate(self, face_count: int) -> "PycortexSurface":
raise NotImplementedError("Decimation not implemented for PycortexSurface")

def get_vertex_normals(self, points: cdt.LabeledPointCloud):
def get_vertex_normals(self, points: cdt.LabeledPointCloud, normalized=True):
assert points.points.crs == self.crs
assert points.pint.units == self.units
points = points.pint.dequantify()
Expand All @@ -401,7 +411,9 @@ def get_vertex_normals(self, points: cdt.LabeledPointCloud):
for i, poly in enumerate(self.mesh.polys):
for j in poly:
vertex_normals[j] += face_normals[i]
vertex_normals /= np.linalg.norm(vertex_normals, axis=1)[:, np.newaxis]

if normalized:
vertex_normals /= np.linalg.norm(vertex_normals, axis=1)[:, np.newaxis]

return xr.DataArray(
vertex_normals[vertex_indices],
Expand Down
40 changes: 24 additions & 16 deletions src/cedalion/imagereco/forward_model.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,7 +16,10 @@
from cedalion.geometry.registration import register_trans_rot_isoscale
import cedalion.typing as cdt
import cedalion.xrutils as xrutils
from cedalion.geometry.segmentation import surface_from_segmentation, voxels_from_segmentation
from cedalion.geometry.segmentation import (
surface_from_segmentation,
voxels_from_segmentation,
)
from cedalion.imagereco.utils import map_segmentation_mask_to_surface

from .tissue_properties import get_tissue_properties
Expand Down Expand Up @@ -220,7 +223,7 @@ def from_surfaces(
scalp_face_count: Optional[int] = 60000,
fill_holes: bool = False,
) -> "TwoSurfaceHeadModel":
"""Constructor from binary masks, brain and head surfaces as gained from MRI scans.
"""Constructor from seg.masks, brain and head surfaces as gained from MRI scans.
Args:
segmentation_dir (str): Folder containing the segmentation masks in NIFTI
Expand Down Expand Up @@ -536,9 +539,9 @@ def snap_to_scalp_voxels(

if len(voxel_idx) > 0:
# Get voxel coordinates from voxel indices
try:
try:
shape = self.segmentation_masks.shape[-3:]
except:
except AttributeError: # FIXME should not be handled here
shape = self.segmentation_masks.to_dataarray().shape[-3:]
voxels = np.array(np.unravel_index(voxel_idx, shape)).T

Expand All @@ -547,21 +550,24 @@ def snap_to_scalp_voxels(
voxel_idx = np.argmin(dist)

else:
# If no voxel maps to that scalp surface vertex,
# If no voxel maps to that scalp surface vertex,
# simply choose the closest of all scalp voxels
voxels = voxels_from_segmentation(self.segmentation_masks, ["scalp"]).voxels

sm = self.segmentation_masks

voxels = voxels_from_segmentation(sm, ["scalp"]).voxels
if len(voxels) == 0:
try:
scalp_mask = self.segmentation_masks.sel(segmentation_type="scalp").to_dataarray()
except:
scalp_mask = self.segmentation_masks.sel(segmentation_type="scalp")
scalp_mask = sm.sel(segmentation_type="scalp").to_dataarray()
except AttributeError: # FIXME same as above
scalp_mask = sm.sel(segmentation_type="scalp")
voxels = np.argwhere(np.array(scalp_mask)[0] > 0.99)

kdtree = KDTree(voxels)
dist, voxel_idx = kdtree.query(self.scalp.mesh.vertices[idx[0,0]],
workers=-1)

# Snap to closest scalp voxel
# Snap to closest scalp voxel
snapped[i] = voxels[voxel_idx]

points.values = snapped
Expand Down Expand Up @@ -617,7 +623,12 @@ def __init__(
]

# Comppute the direction of the light beam from the surface normals
self.optode_dir = -head_model.scalp.get_vertex_normals(self.optode_pos)
# pmcx fails if directions are not normalized
self.optode_dir = -head_model.scalp.get_vertex_normals(
self.optode_pos,
normalized=True,
)

# Slightly realign the optode positions to the closest scalp voxel
self.optode_pos = head_model.snap_to_scalp_voxels(self.optode_pos)

Expand Down Expand Up @@ -804,14 +815,11 @@ def compute_fluence_mcx(self, nphoton: int = 1e8):

return fluence_all, fluence_at_optodes


def compute_fluence_nirfaster(
self, meshingparam = None
):
def compute_fluence_nirfaster(self, meshingparam=None):
"""Compute fluence for each channel and wavelength using NIRFASTer package.
Args:
meshingparam (ff.utils.MeshingParam) Parameters to be used by the CGAL
meshingparam (ff.utils.MeshingParam): Parameters to be used by the CGAL
mesher. Note: they should all be double
Returns:
Expand Down

0 comments on commit 023b09a

Please sign in to comment.