diff --git a/docs/index.html b/docs/index.html index da759e2..6867bad 100644 --- a/docs/index.html +++ b/docs/index.html @@ -2059,7 +2059,7 @@
Return
else: win = np.stack([win] * shape[i], axis=i) - w *= win + w *= win.astype(self.weights_dtype) return w diff --git a/tests/test_merger.py b/tests/test_merger.py index c602bc7..49ada9d 100644 --- a/tests/test_merger.py +++ b/tests/test_merger.py @@ -32,14 +32,14 @@ def test_init(self): merger4 = Merger( tiler=tiler, data_dtype=np.float32, - weights_dtype=np.float32, + weights_dtype=np.int16, window="boxcar", ) self.assertEqual(merger4.data.dtype, np.float32) self.assertEqual(merger4.data_dtype, np.float32) - self.assertEqual(merger4.weights_sum.dtype, np.float32) - self.assertEqual(merger4.weights_dtype, np.float32) - self.assertEqual(merger4.window.dtype, np.float32) + self.assertEqual(merger4.weights_sum.dtype, np.int16) + self.assertEqual(merger4.weights_dtype, np.int16) + self.assertEqual(merger4.window.dtype, np.int16) def test_add(self): tiler = Tiler(data_shape=self.data.shape, tile_shape=(10,)) diff --git a/tiler/merger.py b/tiler/merger.py index 5672142..675eb34 100644 --- a/tiler/merger.py +++ b/tiler/merger.py @@ -153,7 +153,7 @@ def _generate_window(self, window: str, shape: Union[Tuple, List]) -> np.ndarray else: win = np.stack([win] * shape[i], axis=i) - w *= win + w *= win.astype(self.weights_dtype) return w