Skip to content

Commit

Permalink
Removed new multi-candidate metrics and return the "or" of multiple answers …
Browse files Browse the repository at this point in the history
…and substring match instead
  • Loading branch information
neel committed Dec 17, 2024
1 parent c7b848e commit ae456f1
Show file tree
Hide file tree
Showing 5 changed files with 9 additions and 49 deletions.
2 changes: 1 addition & 1 deletion eureka_ml_insights/data_utils/spatial_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -260,7 +260,7 @@ def extract_answer_from_text_map_and_maze(model_output_raw, options):
if answers_match:
model_output_parsed = answers_match.group(1)

return [model_output_parsed, model_output_parsed_letter]
return model_output_parsed + " or " + model_output_parsed_letter


def extract_answer_from_text_maze(text, question_type):
Expand Down
2 changes: 0 additions & 2 deletions eureka_ml_insights/metrics/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,8 +7,6 @@
ExactMatch,
IdentityMetric,
Metric,
MultiCandidateAnyExactMatch,
MultiCandidateAnyCaseInsensitiveMatch,
SubstringExistsMatch,
)
from .mmmu_metrics import MMMUMetric
Expand Down
38 changes: 0 additions & 38 deletions eureka_ml_insights/metrics/metrics_base.py
Original file line number Diff line number Diff line change
Expand Up @@ -144,51 +144,13 @@ def __evaluate__(self, answer_text, target_text, is_valid):
else:
return "incorrect"

class MultiCandidateAnyExactMatch(ExactMatch):
    """
    Any-of metric over several candidate answers taken from a single model output.

    Each candidate is scored against the target by the parent ``ExactMatch`` evaluation, and the
    result is the logical "or" of those scores. This suits multiple-choice questions, where a model
    may reply with either the option letter or the full option text; the row is accepted when
    either form is scored as correct.
    """

    def __evaluate__(self, answer_texts, target_text, is_valid):
        """Return "none" for an invalid row, else "correct" if any candidate is scored "correct", else "incorrect"."""
        if not is_valid:
            return "none"

        # Bind the parent evaluation once up front: zero-argument super() cannot be
        # relied on inside a comprehension's own scope on older Python versions.
        score_candidate = super().__evaluate__
        verdicts = [score_candidate(str(candidate), str(target_text), is_valid) for candidate in answer_texts]

        return "correct" if any(verdict == "correct" for verdict in verdicts) else "incorrect"

class CaseInsensitiveMatch(ExactMatch):
    """
    Exact-match metric that ignores letter case.

    Both texts are converted with ``str()`` and lower-cased, then handed to the parent
    ``ExactMatch`` evaluation unchanged otherwise.
    """

    def __evaluate__(self, answer_text, target_text, is_valid):
        lowered_answer = str(answer_text).lower()
        lowered_target = str(target_text).lower()
        return super().__evaluate__(lowered_answer, lowered_target, is_valid)

class MultiCandidateAnyCaseInsensitiveMatch(MultiCandidateAnyExactMatch):
    """
    Case-insensitive any-of metric over several candidate answers from one model output.

    Every candidate and the target are converted with ``str()`` and lower-cased, then the parent
    ``MultiCandidateAnyExactMatch`` evaluation combines the per-candidate results with a logical
    "or". This suits multiple-choice questions, where a model may reply with either the option
    letter or the full option text; the row is accepted when either form is scored as correct.
    """

    def __evaluate__(self, answer_texts, target_text, is_valid):
        """Return "none" for an invalid row, otherwise the parent's verdict on the lower-cased inputs."""
        # Check validity before iterating over the candidates. The parent also returns
        # "none" for invalid rows, but only after this method has already looped over
        # answer_texts, which would raise if an invalid row carries a non-iterable
        # answer (presumably None when the model produced no output -- confirm with callers).
        if not is_valid:
            return "none"

        lowered_candidates = [str(candidate).lower() for candidate in answer_texts]
        return super().__evaluate__(lowered_candidates, str(target_text).lower(), is_valid)


class IdentityMetric(Metric):

Expand Down
8 changes: 4 additions & 4 deletions eureka_ml_insights/user_configs/vision_language/maze.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,7 +13,7 @@
PrependStringTransform,
SequenceTransform,
)
from eureka_ml_insights.metrics import MultiCandidateAnyCaseInsensitiveMatch, CountAggregator
from eureka_ml_insights.metrics import SubstringExistsMatch, CountAggregator

from eureka_ml_insights.configs import (
AggregatorConfig,
Expand Down Expand Up @@ -96,13 +96,13 @@ def configure_pipeline(self, model_config: ModelConfig, resume_from: str = None)
),
},
),
metric_config=MetricConfig(MultiCandidateAnyCaseInsensitiveMatch),
metric_config=MetricConfig(SubstringExistsMatch),
aggregator_configs=[
AggregatorConfig(CountAggregator, {"column_names": ["CaseInsensitiveOrMatch_result"], "normalize": True}),
AggregatorConfig(CountAggregator, {"column_names": ["SubstringExistsMatch_result"], "normalize": True}),
AggregatorConfig(
CountAggregator,
{
"column_names": ["CaseInsensitiveOrMatch_result"],
"column_names": ["SubstringExistsMatch_result"],
"group_by": "task",
"normalize": True,
},
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -13,7 +13,7 @@
PrependStringTransform,
SequenceTransform,
)
from eureka_ml_insights.metrics import MultiCandidateAnyCaseInsensitiveMatch, CountAggregator
from eureka_ml_insights.metrics import SubstringExistsMatch, CountAggregator

from eureka_ml_insights.configs import (
AggregatorConfig,
Expand Down Expand Up @@ -97,13 +97,13 @@ def configure_pipeline(self, model_config: ModelConfig, resume_from: str = None)
),
},
),
metric_config=MetricConfig(MultiCandidateAnyCaseInsensitiveMatch),
metric_config=MetricConfig(SubstringExistsMatch),
aggregator_configs=[
AggregatorConfig(CountAggregator, {"column_names": ["CaseInsensitiveOrMatch_result"], "normalize": True}),
AggregatorConfig(CountAggregator, {"column_names": ["SubstringExistsMatch_result"], "normalize": True}),
AggregatorConfig(
CountAggregator,
{
"column_names": ["CaseInsensitiveOrMatch_result"],
"column_names": ["SubstringExistsMatch_result"],
"group_by": "task",
"normalize": True,
},
Expand Down

0 comments on commit ae456f1

Please sign in to comment.