Skip to content

Commit

Permalink
Merge pull request #529 from prodrigues-tdx/master
Browse files Browse the repository at this point in the history
Add new argument for limiting the maximum epsilon
  • Loading branch information
lmcinnes authored Oct 27, 2024
2 parents aef934c + c101732 commit 5dab8e3
Show file tree
Hide file tree
Showing 3 changed files with 92 additions and 8 deletions.
24 changes: 18 additions & 6 deletions hdbscan/_hdbscan_tree.pyx
Original file line number Diff line number Diff line change
Expand Up @@ -515,9 +515,11 @@ cdef np.ndarray[np.intp_t, ndim=1] do_labelling(
if cluster < root_cluster:
result[n] = -1
elif cluster == root_cluster:
if len(clusters) == 1 and allow_single_cluster:
if len(clusters) == 1 and allow_single_cluster and cluster in cluster_label_map:
# check if `cluster` still exists in `cluster_label_map` and that it was not pruned
# by `max_cluster_size` or `cluster_selection_epsilon_max` before executing this
if cluster_selection_epsilon != 0.0:
if tree['lambda_val'][tree['child'] == n] >= 1 / cluster_selection_epsilon :
if tree['lambda_val'][tree['child'] == n] >= 1 / cluster_selection_epsilon:
result[n] = cluster_label_map[cluster]
else:
result[n] = -1
Expand Down Expand Up @@ -792,7 +794,8 @@ cpdef tuple get_clusters(np.ndarray tree, dict stability,
allow_single_cluster=False,
match_reference_implementation=False,
cluster_selection_epsilon=0.0,
max_cluster_size=0):
max_cluster_size=0,
cluster_selection_epsilon_max=float('inf')):
"""Given a tree and stability dict, produce the cluster labels
(and probabilities) for a flat clustering based on the chosen
cluster selection method.
Expand All @@ -819,13 +822,18 @@ cpdef tuple get_clusters(np.ndarray tree, dict stability,
certain edge cases.
cluster_selection_epsilon: float, optional (default 0.0)
A distance threshold for cluster splits.
A distance threshold for cluster splits. This is the minimum
epsilon allowed.
max_cluster_size: int, optional (default 0)
The maximum size for clusters located by the EOM clusterer. Can
be overridden by the cluster_selection_epsilon parameter in
rare cases.
cluster_selection_epsilon_max: float, optional (default inf)
A distance threshold for cluster splits. This is the maximum
epsilon allowed.
Returns
-------
labels : ndarray (n_samples,)
Expand All @@ -842,6 +850,7 @@ cpdef tuple get_clusters(np.ndarray tree, dict stability,
cdef np.ndarray child_selection
cdef dict is_cluster
cdef dict cluster_sizes
cdef dict node_eps
cdef float subtree_stability
cdef np.intp_t node
cdef np.intp_t sub_node
Expand Down Expand Up @@ -872,18 +881,21 @@ cpdef tuple get_clusters(np.ndarray tree, dict stability,
max_cluster_size = num_points + 1 # Set to a value that will never be triggered
cluster_sizes = {child: child_size for child, child_size
in zip(cluster_tree['child'], cluster_tree['child_size'])}
node_eps = {child: 1/l for child, l
in zip(cluster_tree['child'], cluster_tree['lambda_val'])}
if allow_single_cluster:
# Compute cluster size for the root node
cluster_sizes[node_list[-1]] = np.sum(
cluster_tree[cluster_tree['parent'] == node_list[-1]]['child_size'])
node_eps[node_list[-1]] = np.max(1.0 / tree['lambda_val'])

if cluster_selection_method == 'eom':
for node in node_list:
child_selection = (cluster_tree['parent'] == node)
subtree_stability = np.sum([
stability[child] for
child in cluster_tree['child'][child_selection]])
if subtree_stability > stability[node] or cluster_sizes[node] > max_cluster_size:
if subtree_stability > stability[node] or cluster_sizes[node] > max_cluster_size or node_eps[node] > cluster_selection_epsilon_max:
is_cluster[node] = False
stability[node] = subtree_stability
else:
Expand Down
34 changes: 32 additions & 2 deletions hdbscan/hdbscan_.py
Original file line number Diff line number Diff line change
Expand Up @@ -72,6 +72,7 @@ def _tree_to_labels(
match_reference_implementation=False,
cluster_selection_epsilon=0.0,
max_cluster_size=0,
cluster_selection_epsilon_max=float('inf'),
):
"""Converts a pretrained tree and cluster size into a
set of labels and probabilities.
Expand All @@ -86,6 +87,7 @@ def _tree_to_labels(
match_reference_implementation,
cluster_selection_epsilon,
max_cluster_size,
cluster_selection_epsilon_max,
)

return (labels, probabilities, stabilities, condensed_tree, single_linkage_tree)
Expand Down Expand Up @@ -529,6 +531,7 @@ def hdbscan(
cluster_selection_method="eom",
allow_single_cluster=False,
match_reference_implementation=False,
cluster_selection_epsilon_max=float('inf'),
**kwargs
):
"""Perform HDBSCAN clustering from a vector array or distance matrix.
Expand All @@ -555,7 +558,7 @@ def hdbscan(
See [3]_ for more information. Note that this should not be used
if we want to predict the cluster labels for new points in future
(e.g. using approximate_predict), as the approximate_predict function
is not aware of this argument.
is not aware of this argument. This is the minimum epsilon allowed.
alpha : float, optional (default=1.0)
A distance scaling parameter as used in robust single linkage.
Expand Down Expand Up @@ -641,6 +644,16 @@ def hdbscan(
performance cost, ensure that the clustering results match the
reference implementation.
cluster_selection_epsilon_max: float, optional (default=inf)
A distance threshold. Clusters above this value will be split.
Has no effect when using leaf clustering (where clusters are
usually small regardless) and can also be overridden in rare
cases by a high value for cluster_selection_epsilon. Note that
this should not be used if we want to predict the cluster labels
for new points in future (e.g. using approximate_predict), as
the approximate_predict function is not aware of this argument.
This is the maximum epsilon allowed.
**kwargs : optional
Arguments passed to the distance metric
Expand Down Expand Up @@ -722,6 +735,9 @@ def hdbscan(
"Minkowski metric with negative p value is not" " defined!"
)

if cluster_selection_epsilon_max < cluster_selection_epsilon:
raise ValueError("Cluster selection epsilon max must be greater than epsilon!")

if match_reference_implementation:
min_samples = min_samples - 1
min_cluster_size = min_cluster_size + 1
Expand Down Expand Up @@ -891,6 +907,7 @@ def hdbscan(
match_reference_implementation,
cluster_selection_epsilon,
max_cluster_size,
cluster_selection_epsilon_max,
)
+ (result_min_span_tree,)
)
Expand Down Expand Up @@ -934,6 +951,7 @@ class HDBSCAN(BaseEstimator, ClusterMixin):
cluster_selection_epsilon: float, optional (default=0.0)
A distance threshold. Clusters below this value will be merged.
This is the minimum epsilon allowed.
See [5]_ for more information.
algorithm : string, optional (default='best')
Expand Down Expand Up @@ -1010,6 +1028,16 @@ class HDBSCAN(BaseEstimator, ClusterMixin):
performance cost, ensure that the clustering results match the
reference implementation.
cluster_selection_epsilon_max: float, optional (default=inf)
A distance threshold. Clusters above this value will be split.
Has no effect when using leaf clustering (where clusters are
usually small regardless) and can also be overridden in rare
cases by a high value for cluster_selection_epsilon. Note that
this should not be used if we want to predict the cluster labels
for new points in future (e.g. using approximate_predict), as
the approximate_predict function is not aware of this argument.
This is the maximum epsilon allowed.
**kwargs : optional
Arguments passed to the distance metric
Expand Down Expand Up @@ -1127,6 +1155,7 @@ def __init__(
prediction_data=False,
branch_detection_data=False,
match_reference_implementation=False,
cluster_selection_epsilon_max=float('inf'),
**kwargs
):
self.min_cluster_size = min_cluster_size
Expand All @@ -1147,6 +1176,7 @@ def __init__(
self.match_reference_implementation = match_reference_implementation
self.prediction_data = prediction_data
self.branch_detection_data = branch_detection_data
self.cluster_selection_epsilon_max = cluster_selection_epsilon_max

self._metric_kwargs = kwargs

Expand Down Expand Up @@ -1296,7 +1326,7 @@ def generate_prediction_data(self):
def generate_branch_detection_data(self):
"""
Create data that caches intermediate results used for detecting
branches within clusters. This data is only useful if you are
branches within clusters. This data is only useful if you are
intending to use functions from ``hdbscan.branches``.
"""
if self.metric in FAST_METRICS:
Expand Down
42 changes: 42 additions & 0 deletions hdbscan/tests/test_hdbscan.py
Original file line number Diff line number Diff line change
Expand Up @@ -581,6 +581,7 @@ def test_hdbscan_badargs():
assert_raises(Exception, hdbscan, X, algorithm="something_else")
assert_raises(TypeError, hdbscan, X, metric="minkowski", p=None)
assert_raises(ValueError, hdbscan, X, leaf_size=0)
assert_raises(ValueError, hdbscan, X, cluster_selection_epsilon_max=-1)


def test_hdbscan_sparse():
Expand Down Expand Up @@ -648,6 +649,47 @@ def test_hdbscan_allow_single_cluster_with_epsilon():
assert counts[unique_labels == -1] == 2


def test_hdbscan_cluster_selection_epsilon_max():
    """Check that lowering ``cluster_selection_epsilon_max`` yields more,
    smaller clusters on a fixed four-blob dataset."""
    points, _ = make_blobs(
        n_samples=50,
        centers=[(1, 0), (-1, 0), (-1, 1), (1, 1)],
        cluster_std=0.2,
        random_state=42,
    )

    # (cluster_selection_epsilon_max, labels expected in the fitted model)
    expected_labels_by_eps_max = (
        (2.0, [0, 1]),
        (1.0, [-1, 0, 1, 2, 3]),
    )

    for eps_max, expected_labels in expected_labels_by_eps_max:
        model = HDBSCAN(
            cluster_selection_epsilon_max=eps_max,
            allow_single_cluster=True,
        )
        model.fit(points)

        assert_array_equal(np.unique(model.labels_), np.array(expected_labels))


def test_hdbscan_parameters_do_not_trigger_errors():
    """Smoke test: fitting with each parameter combination below must
    complete without raising."""
    points, _ = make_blobs(
        n_samples=50,
        centers=[(1, 0), (-1, 0), (-1, 1), (1, 1)],
        cluster_std=0.2,
        random_state=42,
    )

    parameter_sets = (
        {"max_cluster_size": 10},
        {"cluster_selection_epsilon_max": 0.41, "cluster_selection_epsilon": 0.4},
    )

    for params in parameter_sets:
        model = HDBSCAN(allow_single_cluster=True, **params)
        # The test passes as long as fit() returns without an exception.
        model.fit(points)

# Disable for now -- need to refactor to meet newer standards
@pytest.mark.skip(reason="need to refactor to meet newer standards")
def test_hdbscan_is_sklearn_estimator():
Expand Down

0 comments on commit 5dab8e3

Please sign in to comment.