From 1979ef23fb001778b91ab13488568807e8783fe6 Mon Sep 17 00:00:00 2001 From: Marten Chaillet <58044494+McHaillet@users.noreply.github.com> Date: Thu, 7 Nov 2024 00:55:47 +0100 Subject: [PATCH] fix: add device for tensor creation that prevent gpu usage (#13) --- src/torch_image_lerp/linear_interpolation_2d.py | 4 +++- src/torch_image_lerp/linear_interpolation_3d.py | 4 +++- 2 files changed, 6 insertions(+), 2 deletions(-) diff --git a/src/torch_image_lerp/linear_interpolation_2d.py b/src/torch_image_lerp/linear_interpolation_2d.py index 12960b2..d755de5 100644 --- a/src/torch_image_lerp/linear_interpolation_2d.py +++ b/src/torch_image_lerp/linear_interpolation_2d.py @@ -114,7 +114,9 @@ def insert_into_image_2d( coordinates = coordinates.float() # only keep data and coordinates inside the image - in_image_idx = (coordinates >= 0) & (coordinates <= torch.tensor(image.shape) - 1) + in_image_idx = (coordinates >= 0) & ( + coordinates <= torch.tensor(image.shape, device=image.device) - 1 + ) in_image_idx = torch.all(in_image_idx, dim=-1) data, coordinates = data[in_image_idx], coordinates[in_image_idx] diff --git a/src/torch_image_lerp/linear_interpolation_3d.py b/src/torch_image_lerp/linear_interpolation_3d.py index eaef114..ef2dda4 100644 --- a/src/torch_image_lerp/linear_interpolation_3d.py +++ b/src/torch_image_lerp/linear_interpolation_3d.py @@ -114,7 +114,9 @@ def insert_into_image_3d( coordinates = coordinates.float() # only keep data and coordinates inside the volume - inside = (coordinates >= 0) & (coordinates <= torch.tensor(image.shape) - 1) + inside = (coordinates >= 0) & ( + coordinates <= torch.tensor(image.shape, device=image.device) - 1 + ) inside = torch.all(inside, dim=-1) data, coordinates = data[inside], coordinates[inside]