Skip to content

Commit

Permalink
fix: add device for tensor creation that prevent gpu usage (#13)
Browse files Browse the repository at this point in the history
  • Loading branch information
McHaillet authored Nov 6, 2024
1 parent ce53e8e commit 1979ef2
Show file tree
Hide file tree
Showing 2 changed files with 6 additions and 2 deletions.
4 changes: 3 additions & 1 deletion src/torch_image_lerp/linear_interpolation_2d.py
Original file line number Diff line number Diff line change
Expand Up @@ -114,7 +114,9 @@ def insert_into_image_2d(
coordinates = coordinates.float()

# only keep data and coordinates inside the image
in_image_idx = (coordinates >= 0) & (coordinates <= torch.tensor(image.shape) - 1)
in_image_idx = (coordinates >= 0) & (
coordinates <= torch.tensor(image.shape, device=image.device) - 1
)
in_image_idx = torch.all(in_image_idx, dim=-1)
data, coordinates = data[in_image_idx], coordinates[in_image_idx]

Expand Down
4 changes: 3 additions & 1 deletion src/torch_image_lerp/linear_interpolation_3d.py
Original file line number Diff line number Diff line change
Expand Up @@ -114,7 +114,9 @@ def insert_into_image_3d(
coordinates = coordinates.float()

# only keep data and coordinates inside the volume
inside = (coordinates >= 0) & (coordinates <= torch.tensor(image.shape) - 1)
inside = (coordinates >= 0) & (
coordinates <= torch.tensor(image.shape, device=image.device) - 1
)
inside = torch.all(inside, dim=-1)
data, coordinates = data[inside], coordinates[inside]

Expand Down

0 comments on commit 1979ef2

Please sign in to comment.