From ae456f17927555e4f2f3a148da36e2601675df55 Mon Sep 17 00:00:00 2001 From: neel Date: Tue, 17 Dec 2024 14:42:35 -0700 Subject: [PATCH] removed new multicandidate metrics and return or or multiple answers and substring match instead --- .../data_utils/spatial_utils.py | 2 +- eureka_ml_insights/metrics/__init__.py | 2 - eureka_ml_insights/metrics/metrics_base.py | 38 ------------------- .../user_configs/vision_language/maze.py | 8 ++-- .../vision_language/spatial_map.py | 8 ++-- 5 files changed, 9 insertions(+), 49 deletions(-) diff --git a/eureka_ml_insights/data_utils/spatial_utils.py b/eureka_ml_insights/data_utils/spatial_utils.py index 68dc332..f2907d2 100644 --- a/eureka_ml_insights/data_utils/spatial_utils.py +++ b/eureka_ml_insights/data_utils/spatial_utils.py @@ -260,7 +260,7 @@ def extract_answer_from_text_map_and_maze(model_output_raw, options): if answers_match: model_output_parsed = answers_match.group(1) - return [model_output_parsed, model_output_parsed_letter] + return model_output_parsed + " or " + model_output_parsed_letter def extract_answer_from_text_maze(text, question_type): diff --git a/eureka_ml_insights/metrics/__init__.py b/eureka_ml_insights/metrics/__init__.py index 3e6521e..2b19ec1 100644 --- a/eureka_ml_insights/metrics/__init__.py +++ b/eureka_ml_insights/metrics/__init__.py @@ -7,8 +7,6 @@ ExactMatch, IdentityMetric, Metric, - MultiCandidateAnyExactMatch, - MultiCandidateAnyCaseInsensitiveMatch, SubstringExistsMatch, ) from .mmmu_metrics import MMMUMetric diff --git a/eureka_ml_insights/metrics/metrics_base.py b/eureka_ml_insights/metrics/metrics_base.py index a96146d..45d4249 100644 --- a/eureka_ml_insights/metrics/metrics_base.py +++ b/eureka_ml_insights/metrics/metrics_base.py @@ -144,31 +144,6 @@ def __evaluate__(self, answer_text, target_text, is_valid): else: return "incorrect" -class MultiCandidateAnyExactMatch(ExactMatch): - """ - This class checks for a case-sensitive match for a list of answers from the model output, - and returns the or of the list of metric results. - - This is required for answers to multiple-choice questions. As many models sometimes give the letter answer - and sometimes the full word answer. This allows one to consider the answer correct if either one was correct. - """ - - def __evaluate__(self, answer_texts, target_text, is_valid): - - if not is_valid: - return "none" - - results = [] - for answer_text in answer_texts: - res = super().__evaluate__(str(answer_text), str(target_text), is_valid) - results.append(res) - - corrects = [x=="correct" for x in results] - - if (any(corrects)): - return "correct" - else: - return "incorrect" class CaseInsensitiveMatch(ExactMatch): """This class checks for a case-insensitive, but otherwise exact match.""" @@ -176,19 +151,6 @@ class CaseInsensitiveMatch(ExactMatch): def __evaluate__(self, answer_text, target_text, is_valid): return super().__evaluate__(str(answer_text).lower(), str(target_text).lower(), is_valid) -class MultiCandidateAnyCaseInsensitiveMatch(MultiCandidateAnyExactMatch): - """ - This class checks for a case-insensitive match for a list of answers from the model output, - and returns the or of the list of metric results. - - This is required for answers to multiple-choice questions. As many models sometimes give the letter answer - and sometimes the full word answer. This allows one to consider the answer correct if either one was correct. - """ - - def __evaluate__(self, answer_texts, target_text, is_valid): - answer_texts = [str(answer_text).lower() for answer_text in answer_texts] - return super().__evaluate__(answer_texts, str(target_text).lower(), is_valid) - class IdentityMetric(Metric): diff --git a/eureka_ml_insights/user_configs/vision_language/maze.py b/eureka_ml_insights/user_configs/vision_language/maze.py index a4d7eb1..b461a93 100644 --- a/eureka_ml_insights/user_configs/vision_language/maze.py +++ b/eureka_ml_insights/user_configs/vision_language/maze.py @@ -13,7 +13,7 @@ PrependStringTransform, SequenceTransform, ) -from eureka_ml_insights.metrics import MultiCandidateAnyCaseInsensitiveMatch, CountAggregator +from eureka_ml_insights.metrics import SubstringExistsMatch, CountAggregator from eureka_ml_insights.configs import ( AggregatorConfig, @@ -96,13 +96,13 @@ def configure_pipeline(self, model_config: ModelConfig, resume_from: str = None) ), }, ), - metric_config=MetricConfig(MultiCandidateAnyCaseInsensitiveMatch), + metric_config=MetricConfig(SubstringExistsMatch), aggregator_configs=[ - AggregatorConfig(CountAggregator, {"column_names": ["CaseInsensitiveOrMatch_result"], "normalize": True}), + AggregatorConfig(CountAggregator, {"column_names": ["SubstringExistsMatch_result"], "normalize": True}), AggregatorConfig( CountAggregator, { - "column_names": ["CaseInsensitiveOrMatch_result"], + "column_names": ["SubstringExistsMatch_result"], "group_by": "task", "normalize": True, }, diff --git a/eureka_ml_insights/user_configs/vision_language/spatial_map.py b/eureka_ml_insights/user_configs/vision_language/spatial_map.py index 81f9f97..4bb2134 100644 --- a/eureka_ml_insights/user_configs/vision_language/spatial_map.py +++ b/eureka_ml_insights/user_configs/vision_language/spatial_map.py @@ -13,7 +13,7 @@ PrependStringTransform, SequenceTransform, ) -from eureka_ml_insights.metrics import MultiCandidateAnyCaseInsensitiveMatch, CountAggregator +from eureka_ml_insights.metrics import SubstringExistsMatch, CountAggregator from eureka_ml_insights.configs import ( AggregatorConfig, @@ -97,13 +97,13 @@ def configure_pipeline(self, model_config: ModelConfig, resume_from: str = None) ), }, ), - metric_config=MetricConfig(MultiCandidateAnyCaseInsensitiveMatch), + metric_config=MetricConfig(SubstringExistsMatch), aggregator_configs=[ - AggregatorConfig(CountAggregator, {"column_names": ["CaseInsensitiveOrMatch_result"], "normalize": True}), + AggregatorConfig(CountAggregator, {"column_names": ["SubstringExistsMatch_result"], "normalize": True}), AggregatorConfig( CountAggregator, { - "column_names": ["CaseInsensitiveOrMatch_result"], + "column_names": ["SubstringExistsMatch_result"], "group_by": "task", "normalize": True, },