From ce53e8ea98e25317b418b9976bc7475756c9bedd Mon Sep 17 00:00:00 2001 From: alisterburt Date: Wed, 6 Nov 2024 15:52:58 -0800 Subject: [PATCH] fix dtype issue in complex data insertion (#14) --- src/torch_image_lerp/linear_interpolation_3d.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/src/torch_image_lerp/linear_interpolation_3d.py b/src/torch_image_lerp/linear_interpolation_3d.py index 4d089d6..eaef114 100644 --- a/src/torch_image_lerp/linear_interpolation_3d.py +++ b/src/torch_image_lerp/linear_interpolation_3d.py @@ -119,12 +119,12 @@ def insert_into_image_3d( data, coordinates = data[inside], coordinates[inside] # calculate and cache floor and ceil of coordinates for each data point being inserted - _c = torch.empty(size=(data.shape[0], 2, 3), dtype=torch.long, device=image.device) + _c = torch.empty(size=(data.shape[0], 2, 3), dtype=torch.int64, device=image.device) _c[:, 0] = torch.floor(coordinates) # for lower corners _c[:, 1] = torch.ceil(coordinates) # for upper corners # calculate linear interpolation weights for each data point being inserted - _w = torch.empty(size=(data.shape[0], 2, 3), dtype=image.dtype, device=image.device) # (b, 2, zyx) + _w = torch.empty(size=(data.shape[0], 2, 3), dtype=torch.float64, device=image.device) # (b, 2, zyx) _w[:, 1] = coordinates - _c[:, 0] # upper corner weights _w[:, 0] = 1 - _w[:, 1] # lower corner weights