From 702899ea0aeb57b5761b85c43d1416ca81d4da00 Mon Sep 17 00:00:00 2001 From: Jonathan de Bruin Date: Thu, 14 Mar 2019 20:59:51 +0100 Subject: [PATCH] resolve conflict with threshold and missing value (#85) Closes #70 --- recordlinkage/compare.py | 9 +++++---- tests/test_compare.py | 42 ++++++++++++++++++++++++++++++++++++++++ 2 files changed, 47 insertions(+), 4 deletions(-) diff --git a/recordlinkage/compare.py b/recordlinkage/compare.py index b5ea9479..1a3153fc 100644 --- a/recordlinkage/compare.py +++ b/recordlinkage/compare.py @@ -151,13 +151,14 @@ def _compute_vectorized(self, s_left, s_right): self.method)) c = str_sim_alg(s_left, s_right) - c = _fillna(c, self.missing_value) if self.threshold is not None: - return (c >= self.threshold).astype(numpy.float64) - else: - return c + c = c.where((c < self.threshold) | (pandas.isnull(c)), other=1.0) + c = c.where((c >= self.threshold) | (pandas.isnull(c)), other=0.0) + c = _fillna(c, self.missing_value) + + return c class Numeric(BaseCompareFeature): """Compute the (partial) similarity between numeric values. diff --git a/tests/test_compare.py b/tests/test_compare.py index a483a11a..c2eb10c7 100644 --- a/tests/test_compare.py +++ b/tests/test_compare.py @@ -1183,6 +1183,48 @@ def test_fuzzy(self): assert (result[result.notnull()] >= 0).all(1).all(0) assert (result[result.notnull()] <= 1).all(1).all(0) + def test_threshold(self): + + A = DataFrame({'col': [u"gretzky", u"gretzky99", u"gretzky", u"gretzky"]}) + B = DataFrame({'col': [u"gretzky", u"gretzky", nan, u"wayne"]}) + ix = MultiIndex.from_arrays([A.index.values, B.index.values]) + + comp = recordlinkage.Compare() + comp.string( + 'col', + 'col', + method="levenshtein", + threshold=0.5, + missing_value=2.0, + label="x_col1" + ) + comp.string( + 'col', + 'col', + method="levenshtein", + threshold=1.0, + missing_value=0.5, + label="x_col2" + ) + comp.string( + 'col', + 'col', + method="levenshtein", + threshold=0.0, + missing_value=nan, + label="x_col3" + ) + result = comp.compute(ix, A, B) + + expected = Series([1.0, 1.0, 2.0, 0.0], index=ix, name="x_col1") + pdt.assert_series_equal(result["x_col1"], expected) + + expected = Series([1.0, 0.0, 0.5, 0.0], index=ix, name="x_col2") + pdt.assert_series_equal(result["x_col2"], expected) + + expected = Series([1.0, 1.0, nan, 1.0], index=ix, name="x_col3") + pdt.assert_series_equal(result["x_col3"], expected) + @pytest.mark.parametrize("alg", STRING_SIM_ALGORITHMS) def test_incorrect_input(self, alg):