Skip to content

Commit

Permalink
Refactor and add comments to the code.
Browse files Browse the repository at this point in the history
  • Loading branch information
shunping committed Feb 7, 2025
1 parent 24c77e9 commit 371d465
Show file tree
Hide file tree
Showing 4 changed files with 382 additions and 158 deletions.
103 changes: 100 additions & 3 deletions sdks/python/apache_beam/ml/anomaly/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,7 +15,9 @@
# limitations under the License.
#

"""Base classes for anomaly detection"""
"""
Base classes for anomaly detection
"""
from __future__ import annotations

import abc
Expand All @@ -26,51 +28,111 @@

import apache_beam as beam

# Public names of this module, i.e. what a star-import exposes.
__all__ = [
    "AnomalyPrediction",
    "AnomalyResult",
    "ThresholdFn",
    "AggregationFn",
    "AnomalyDetector",
    "EnsembleAnomalyDetector",
]


@dataclass(frozen=True)
class AnomalyPrediction:
  """A dataclass for anomaly detection predictions.

  Every field has a default (``None``, or an empty string for ``info``), so an
  instance can be created with any subset of fields. The dataclass is frozen:
  assigning to a field after construction raises an error.
  """
  #: The ID of detector (model) that generates the prediction.
  model_id: Optional[str] = None
  #: The outlier score resulting from applying the detector to the input data.
  score: Optional[float] = None
  #: The outlier label (normal or outlier) derived from the outlier score.
  label: Optional[int] = None
  #: The threshold used to determine the label.
  threshold: Optional[float] = None
  #: Additional information about the prediction.
  info: str = ""
  #: If enabled, a list of `AnomalyPrediction` objects used to derive the
  #: aggregated prediction.
  agg_history: Optional[Iterable[AnomalyPrediction]] = None


@dataclass(frozen=True)
class AnomalyResult:
  """A dataclass for the anomaly detection results.

  Pairs an input with the prediction made for it. Both fields are required,
  and the dataclass is frozen, so they cannot be reassigned after
  construction.
  """
  #: The original input data.
  example: beam.Row
  #: The `AnomalyPrediction` object containing the prediction.
  prediction: AnomalyPrediction


class ThresholdFn(abc.ABC):
  """An abstract base class for threshold functions.

  Subclasses must implement the `is_stateful` and `threshold` properties and
  the `apply` method.

  Args:
    normal_label: The integer label used to identify normal data. Defaults to 0.
    outlier_label: The integer label used to identify outlier data. Defaults to
      1.
  """
  def __init__(self, normal_label: int = 0, outlier_label: int = 1):
    self._normal_label = normal_label
    self._outlier_label = outlier_label

  @property
  @abc.abstractmethod
  def is_stateful(self) -> bool:
    """Indicates whether the threshold function is stateful or not."""
    raise NotImplementedError

  @property
  @abc.abstractmethod
  def threshold(self) -> Optional[float]:
    """Retrieves the current threshold value, or None if not set."""
    raise NotImplementedError

  @abc.abstractmethod
  def apply(self, score: Optional[float]) -> int:
    """Applies the threshold function to a given score to classify it as
    normal or outlier.

    Args:
      score: The outlier score generated from the detector (model).

    Returns:
      The label assigned to the score, either `self._normal_label`
      or `self._outlier_label`.
    """
    raise NotImplementedError


class AggregationFn(abc.ABC):
  """An abstract base class for aggregation functions.

  Subclasses must implement `apply`.
  """
  @abc.abstractmethod
  def apply(
      self, predictions: Iterable[AnomalyPrediction]) -> AnomalyPrediction:
    """Applies the aggregation function to an iterable of predictions, either on
    their outlier scores or labels.

    Args:
      predictions: An Iterable of `AnomalyPrediction` objects to aggregate.

    Returns:
      An `AnomalyPrediction` object containing the aggregated result.
    """
    raise NotImplementedError


class AnomalyDetector(abc.ABC):
  """An abstract base class for anomaly detectors.

  Subclasses must implement `learn_one` and `score_one`.

  Args:
    model_id: The ID of detector (model). Defaults to the value of the
      `spec_type` attribute, or 'unknown' if not set.
    features: An Iterable of strings representing the names of the input
      features in the `beam.Row`
    target: The name of the target field in the `beam.Row`.
    threshold_criterion: An optional `ThresholdFn` to apply to the outlier score
      and yield a label.
    **kwargs: Accepted but not used by this initializer.
  """
  def __init__(
      self,
      model_id: Optional[str] = None,
      # NOTE(review): the `features` and `target` parameters were collapsed in
      # the diff view this block was recovered from; their annotations and
      # defaults are reconstructed from the class docstring and the
      # assignments below -- confirm against the real file.
      features: Optional[Iterable[str]] = None,
      target: Optional[str] = None,
      threshold_criterion: Optional[ThresholdFn] = None,
      **kwargs):
    # An explicit ID always wins; otherwise use the `spec_type` attribute if
    # the instance/class has one, and 'unknown' as the last resort.
    if model_id is None:
      model_id = getattr(self, 'spec_type', 'unknown')
    self._model_id = model_id
    self._features = features
    self._target = target
    self._threshold_criterion = threshold_criterion

  @abc.abstractmethod
  def learn_one(self, x: beam.Row) -> None:
    """Trains the detector on a single data instance.

    Args:
      x: A `beam.Row` representing the data instance.
    """
    raise NotImplementedError

  @abc.abstractmethod
  def score_one(self, x: beam.Row) -> float:
    """Scores a single data instance for anomalies.

    Args:
      x: A `beam.Row` representing the data instance.

    Returns:
      The outlier score as a float.
    """
    raise NotImplementedError


class EnsembleAnomalyDetector(AnomalyDetector):
  """An abstract base class for an ensemble of anomaly (sub-)detectors.

  Args:
    sub_detectors: A List of `AnomalyDetector` used in this ensemble model.
    aggregation_strategy: An optional `AggregationFn` to apply to the
      predictions from all sub-detectors and yield an aggregated result.
    model_id: Inherited from `AnomalyDetector`.
    features: Inherited from `AnomalyDetector`.
    target: Inherited from `AnomalyDetector`.
    threshold_criterion: Inherited from `AnomalyDetector`.
  """
  def __init__(
      self,
      sub_detectors: Optional[List[AnomalyDetector]] = None,
      aggregation_strategy: Optional[AggregationFn] = None,
      **kwargs):
    # Resolve the ID here, before delegating, so that an ensemble without an
    # explicit ID falls back to its `spec_type` attribute or to 'custom'
    # instead of whatever default the base initializer would pick.
    if kwargs.get("model_id") is None:
      kwargs["model_id"] = getattr(self, 'spec_type', 'custom')

    super().__init__(**kwargs)

    self._aggregation_strategy = aggregation_strategy
    self._sub_detectors = sub_detectors

  def learn_one(self, x: beam.Row) -> None:
    """Inherited from `AnomalyDetector.learn_one`.

    This method is never called during ensemble detector training. The training
    process is done on each sub-detector independently and in parallel.
    """
    raise NotImplementedError

  def score_one(self, x: beam.Row) -> float:
    """Inherited from `AnomalyDetector.score_one`.

    This method is never called during ensemble detector scoring. The scoring
    process is done on sub-detector independently and in parallel, and then
    the results are aggregated in the pipeline.
    """
    raise NotImplementedError
7 changes: 2 additions & 5 deletions sdks/python/apache_beam/ml/anomaly/base_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -60,9 +60,6 @@ def __eq__(self, value) -> bool:
return isinstance(value, TestAnomalyDetector.Dummy) and \
self._my_arg == value._my_arg

def test_unknown_detector(self):
self.assertRaises(ValueError, Specifiable.from_spec, Spec(type="unknown"))

def test_model_id_on_known_detector(self):
a = self.Dummy(
my_arg="abc",
Expand All @@ -75,7 +72,7 @@ def test_model_id_on_known_detector(self):

assert isinstance(a, Specifiable)
self.assertEqual(
a._init_params, {
a.init_kwargs, {
"my_arg": "abc",
"target": "ABC",
"threshold_criterion": t1,
Expand All @@ -92,7 +89,7 @@ def test_model_id_on_known_detector(self):

assert isinstance(b, Specifiable)
self.assertEqual(
b._init_params,
b.init_kwargs,
{
"model_id": "my_dummy",
"my_arg": "efg",
Expand Down
Loading

0 comments on commit 371d465

Please sign in to comment.