Skip to content

Commit

Permalink
Fixed float/int mismatch when generating windows
Browse files Browse the repository at this point in the history
  • Loading branch information
the-lay committed Jan 6, 2022
1 parent 7d75102 commit 555698f
Show file tree
Hide file tree
Showing 3 changed files with 6 additions and 6 deletions.
2 changes: 1 addition & 1 deletion docs/index.html
Original file line number Diff line number Diff line change
Expand Up @@ -2059,7 +2059,7 @@ <h6 id="return">Return</h6>
<span class="k">else</span><span class="p">:</span>
<span class="n">win</span> <span class="o">=</span> <span class="n">np</span><span class="o">.</span><span class="n">stack</span><span class="p">([</span><span class="n">win</span><span class="p">]</span> <span class="o">*</span> <span class="n">shape</span><span class="p">[</span><span class="n">i</span><span class="p">],</span> <span class="n">axis</span><span class="o">=</span><span class="n">i</span><span class="p">)</span>

<span class="n">w</span> <span class="o">*=</span> <span class="n">win</span>
<span class="n">w</span> <span class="o">*=</span> <span class="n">win</span><span class="o">.</span><span class="n">astype</span><span class="p">(</span><span class="bp">self</span><span class="o">.</span><span class="n">weights_dtype</span><span class="p">)</span>

<span class="k">return</span> <span class="n">w</span>

Expand Down
8 changes: 4 additions & 4 deletions tests/test_merger.py
Original file line number Diff line number Diff line change
Expand Up @@ -32,14 +32,14 @@ def test_init(self):
merger4 = Merger(
tiler=tiler,
data_dtype=np.float32,
weights_dtype=np.float32,
weights_dtype=np.int16,
window="boxcar",
)
self.assertEqual(merger4.data.dtype, np.float32)
self.assertEqual(merger4.data_dtype, np.float32)
self.assertEqual(merger4.weights_sum.dtype, np.float32)
self.assertEqual(merger4.weights_dtype, np.float32)
self.assertEqual(merger4.window.dtype, np.float32)
self.assertEqual(merger4.weights_sum.dtype, np.int16)
self.assertEqual(merger4.weights_dtype, np.int16)
self.assertEqual(merger4.window.dtype, np.int16)

def test_add(self):
tiler = Tiler(data_shape=self.data.shape, tile_shape=(10,))
Expand Down
2 changes: 1 addition & 1 deletion tiler/merger.py
Original file line number Diff line number Diff line change
Expand Up @@ -153,7 +153,7 @@ def _generate_window(self, window: str, shape: Union[Tuple, List]) -> np.ndarray
else:
win = np.stack([win] * shape[i], axis=i)

w *= win
w *= win.astype(self.weights_dtype)

return w

Expand Down

0 comments on commit 555698f

Please sign in to comment.