Skip to content

Commit

Permalink
added to comments, renamed metric classes
Browse files Browse the repository at this point in the history
  • Loading branch information
neel committed Dec 11, 2024
1 parent 74107ad commit c7b848e
Show file tree
Hide file tree
Showing 5 changed files with 23 additions and 9 deletions.
1 change: 1 addition & 0 deletions eureka_ml_insights/data_utils/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -69,6 +69,7 @@
PrependStringTransform,
ExtractAnswerGrid,
ExtractAnswerSpatialMapAndMaze,
ExtractQuestionOptions,
ShuffleColumnsTransform,
ColumnMatchMapTransform,
TokenCounterTransform,
Expand Down
3 changes: 2 additions & 1 deletion eureka_ml_insights/metrics/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,12 +2,13 @@
from .geomtric_reasoning_metrics import GeoMCQMetric
from .metrics_base import (
CaseInsensitiveMatch,
CaseInsensitiveOrMatch,
ClassicMetric,
CompositeMetric,
ExactMatch,
IdentityMetric,
Metric,
MultiCandidateAnyExactMatch,
MultiCandidateAnyCaseInsensitiveMatch,
SubstringExistsMatch,
)
from .mmmu_metrics import MMMUMetric
Expand Down
20 changes: 16 additions & 4 deletions eureka_ml_insights/metrics/metrics_base.py
Original file line number Diff line number Diff line change
Expand Up @@ -144,8 +144,14 @@ def __evaluate__(self, answer_text, target_text, is_valid):
else:
return "incorrect"

class ExactOrMatch(ExactMatch):
"""This class checks for a case-sensitive, but otherwise exact match, and returns the or of them."""
class MultiCandidateAnyExactMatch(ExactMatch):
"""
This class checks for a case-sensitive match for a list of answers from the model output,
and returns the or of the list of metric results.
This is required for answers to multiple-choice questions. As many models sometimes give the letter answer
and sometimes the full word answer. This allows one to consider the answer correct if either one was correct.
"""

def __evaluate__(self, answer_texts, target_text, is_valid):

Expand All @@ -170,8 +176,14 @@ class CaseInsensitiveMatch(ExactMatch):
def __evaluate__(self, answer_text, target_text, is_valid):
return super().__evaluate__(str(answer_text).lower(), str(target_text).lower(), is_valid)

class CaseInsensitiveOrMatch(ExactOrMatch):
"""This class checks for a case-insensitive, but otherwise exact or match."""
class MultiCandidateAnyCaseInsensitiveMatch(MultiCandidateAnyExactMatch):
"""
This class checks for a case-insensitive match for a list of answers from the model output,
and returns the or of the list of metric results.
This is required for answers to multiple-choice questions. As many models sometimes give the letter answer
and sometimes the full word answer. This allows one to consider the answer correct if either one was correct.
"""

def __evaluate__(self, answer_texts, target_text, is_valid):
answer_texts = [str(answer_text).lower() for answer_text in answer_texts]
Expand Down
4 changes: 2 additions & 2 deletions eureka_ml_insights/user_configs/vision_language/maze.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,7 +13,7 @@
PrependStringTransform,
SequenceTransform,
)
from eureka_ml_insights.metrics import CaseInsensitiveOrMatch, CountAggregator
from eureka_ml_insights.metrics import MultiCandidateAnyCaseInsensitiveMatch, CountAggregator

from eureka_ml_insights.configs import (
AggregatorConfig,
Expand Down Expand Up @@ -96,7 +96,7 @@ def configure_pipeline(self, model_config: ModelConfig, resume_from: str = None)
),
},
),
metric_config=MetricConfig(CaseInsensitiveOrMatch),
metric_config=MetricConfig(MultiCandidateAnyCaseInsensitiveMatch),
aggregator_configs=[
AggregatorConfig(CountAggregator, {"column_names": ["CaseInsensitiveOrMatch_result"], "normalize": True}),
AggregatorConfig(
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -13,7 +13,7 @@
PrependStringTransform,
SequenceTransform,
)
from eureka_ml_insights.metrics import CaseInsensitiveOrMatch, CountAggregator
from eureka_ml_insights.metrics import MultiCandidateAnyCaseInsensitiveMatch, CountAggregator

from eureka_ml_insights.configs import (
AggregatorConfig,
Expand Down Expand Up @@ -97,7 +97,7 @@ def configure_pipeline(self, model_config: ModelConfig, resume_from: str = None)
),
},
),
metric_config=MetricConfig(CaseInsensitiveOrMatch),
metric_config=MetricConfig(MultiCandidateAnyCaseInsensitiveMatch),
aggregator_configs=[
AggregatorConfig(CountAggregator, {"column_names": ["CaseInsensitiveOrMatch_result"], "normalize": True}),
AggregatorConfig(
Expand Down

0 comments on commit c7b848e

Please sign in to comment.