From 51cc1db6e84219a08310a7a94edc88d9309342e9 Mon Sep 17 00:00:00 2001 From: neel Date: Wed, 4 Dec 2024 16:17:48 -0700 Subject: [PATCH 1/8] new non model dependent spatial map answer parsing --- .../configs/vision_language/spatial_map.py | 16 +- eureka_ml_insights/data_utils/__init__.py | 1 + .../data_utils/spatial_utils.py | 221 +++++++----------- eureka_ml_insights/metrics/__init__.py | 1 + eureka_ml_insights/metrics/metrics_base.py | 26 +++ 5 files changed, 128 insertions(+), 137 deletions(-) diff --git a/eureka_ml_insights/configs/vision_language/spatial_map.py b/eureka_ml_insights/configs/vision_language/spatial_map.py index 7a4cfdf..22b5bc4 100644 --- a/eureka_ml_insights/configs/vision_language/spatial_map.py +++ b/eureka_ml_insights/configs/vision_language/spatial_map.py @@ -9,10 +9,11 @@ DataLoader, DataReader, ExtractAnswerSpatialMap, + ExtractQuestionOptions, PrependStringTransform, SequenceTransform, ) -from eureka_ml_insights.metrics import CaseInsensitiveMatch, CountAggregator +from eureka_ml_insights.metrics import CaseInsensitiveOrMatch, CountAggregator from ..config import ( AggregatorConfig, @@ -82,24 +83,27 @@ def configure_pipeline(self, model_config: ModelConfig, resume_from: str = None) "format": ".jsonl", "transform": SequenceTransform( [ + ExtractQuestionOptions( + prompt_column_name="prompt", + extracted_options_column_name="target_options_answers", + ), ColumnRename(name_mapping={"model_output": "model_output_raw"}), ExtractAnswerSpatialMap( answer_column_name="model_output_raw", extracted_answer_column_name="model_output", - question_type_column_name="question_type", - model_name=model_config.init_args['model_name'], # passing the model name for model-specific answer extraction + extracted_options_column_name="target_options_answers", ), ], ), }, ), - metric_config=MetricConfig(CaseInsensitiveMatch), + metric_config=MetricConfig(CaseInsensitiveOrMatch), aggregator_configs=[ - AggregatorConfig(CountAggregator, {"column_names": ["CaseInsensitiveMatch_result"], "normalize": True}), + AggregatorConfig(CountAggregator, {"column_names": ["CaseInsensitiveOrMatch_result"], "normalize": True}), AggregatorConfig( CountAggregator, { - "column_names": ["CaseInsensitiveMatch_result"], + "column_names": ["CaseInsensitiveOrMatch_result"], "group_by": "task", "normalize": True, }, diff --git a/eureka_ml_insights/data_utils/__init__.py b/eureka_ml_insights/data_utils/__init__.py index 88b5921..70c3db7 100644 --- a/eureka_ml_insights/data_utils/__init__.py +++ b/eureka_ml_insights/data_utils/__init__.py @@ -17,6 +17,7 @@ ExtractAnswerGrid, ExtractAnswerMaze, ExtractAnswerSpatialMap, + ExtractQuestionOptions, ) from .transform import ( AddColumn, diff --git a/eureka_ml_insights/data_utils/spatial_utils.py b/eureka_ml_insights/data_utils/spatial_utils.py index 8a579bf..c40ea7d 100644 --- a/eureka_ml_insights/data_utils/spatial_utils.py +++ b/eureka_ml_insights/data_utils/spatial_utils.py @@ -157,22 +157,22 @@ def extract_answer_from_text_grid(text, question_type): return None # Return None if no numbers are found -def extract_answer_from_text_map(text, question_type, model_name): +def extract_answer_from_text_map(model_output_raw, options): """ - Extracts the answer from the text based on specific patterns, - and as a fallback, extracts the first number if no patterns match. - The code is from: https://github.com/alvinmingwisc/spatial_reason_vlm/tree/main/eval, - and included with minimal modifications. + Extracts the answer from the text based on known model output patterns. 
+ Searches for botha letter and whole word answer and returns both as they are not + always consistent. Args: - text (str): The text containing the model's answer. - - question_type (str): The text containing the question type. - - model_name (str): The model name. Returns: - - str or None: The extracted answer, or None if no answer could be extracted. + - str or None: The extracted answers, or empty strings if no answer could be extracted. """ - # Mapping of textual numbers to their numeric equivalents + + # replace common subsitutions for numbers in model outputs + model_output_raw = model_output_raw.replace("no objects", "0 objects") + number_mapping = { "zero": 0, "no": 0, @@ -187,127 +187,57 @@ def extract_answer_from_text_map(text, question_type, model_name): "nine": 9, } - dirs = ["southeast", "northeast", "northwest", "southwest"] - dir_pattern = rf"\b({'|'.join(dirs)})\b" - - if text is None: - return None - - question_id = int(re.search("[0-9]", re.search("Q[0-9]", question_type).group()).group()) - - if question_id == 0: - direction_match = re.search(r"\b[A-D]\.\s*(" + "|".join(dirs) + r")\b", text, re.IGNORECASE) - if direction_match: - return direction_match.group(1).lower() - - match = re.search(dir_pattern, text, re.IGNORECASE) - if match: - return match.group(1) - return None - - elif question_id == 1: - match = re.search( - rf"^([\w\s\'\']+?)\s+is\s+(?:located\s+|in\s+the\s+|located\s+to\s+the\s+)({dir_pattern})", - text, - re.IGNORECASE, - ) - - if match: - string = match.group(1) - return string - - match = re.search(r"\b[A-D]\.\s*(.*)", text) # problem with extracting . - - if match: - string = match.group(1) - string = remove_redundancy(string) - string = extract_before_is(string) - return string - - match = re.search(r"\b([ABCD][.,]|[(][abcdABCD][)])\s*(.*?)(?=\sis\b|\.|,|<|$)", text) - if match: - answer = match.group(1).strip() - # Remove trailing punctuation if any - answer = re.sub(r"[\.,\?!<]+$", "", answer) - return answer - - match = re.search( - rf"Therefore, the object in the {dir_pattern} of [\w\s\'\']+ is ([\w\s\'\']+)", text, re.IGNORECASE - ) - if match: - string = match.group(2) - return string - - if "claude" in model_name.lower(): - match = re.search(rf"^([\w\s\'\']+?)\s+is\s+(to\s+the\s+)({dir_pattern})", text, re.IGNORECASE) - if match: - string = match.group(1) - return string - - if "gemini" in model_name.lower(): - patterns = [ - rf"\*\*Concise Answer:\*\*\n([\w\s\'\']+?)\s+is\s+(?:located\s+|in\s+the\s+|in\s+|located\s+to\s+the\s+)({dir_pattern})", - rf"\*\*Answer:\*\*\s+([\w\s\'\']+?)\s+is\s+in\s+the\s+({dir_pattern})\s+of\s+([\w\s\'\']+)", - r"\*\*Answer:\*\*\n([\w\s\'\']+)", - r"\*\*Answer\*\*:\s+([\w\s\'\']+)", - r"\*\*Answer:\*\*\s+([\w\s\'\']+)", - ] - - for pattern in patterns: - match = re.search(pattern, text, re.IGNORECASE) - if match: - return match.group(1) - - if "gpt-4o" in model_name.lower() or "gpt4o" in model_name.lower(): - match = re.search( - rf"Concise Answer:\s+([\w\s\'\']+?)\s+is\s+(?:located\s+|in\s+the\s+|in\s+|located\s+to\s+the\s+)({dir_pattern})", - text, - re.IGNORECASE, - ) - if match: - string = match.group(1) - return string - - # If no match, check for an answer following "is", with specific end markers defined - match = re.search(r"\bis\b\s+(.*?)(?=\.|,|<|$)", text) - if match: - answer = match.group(1).strip() - # Remove trailing punctuation if any - answer = re.sub(r"[\.,\?!<]+$", "", answer) - return answer - - return None # Return None if no match is found - - elif question_id == 2: - match = 
re.search(r"\b[A-D]\.\s*(\d+)", text) # match number only - if match: - return match.group(1) - # Create a list to store all found numbers along with their positions - found_numbers = [] - - # Check for textual numbers and their positions - for text_num, num in number_mapping.items(): - for match in re.finditer(rf"\b{text_num}\b", text, re.IGNORECASE): - found_numbers.append((match.start(), num)) + for k, v in number_mapping.items(): + model_output_raw = re.sub(rf"\b{k}\b", str(v), model_output_raw, re.IGNORECASE) + + # get dict of options from options string + options_dict = {x.split(".")[0].strip():x.split(".")[1].strip() for x in options} + + model_output_parsed_letter = "" + model_output_parsed = "" + + # "Concise Asnwer" is a common GPT-4o pattern + if "Concise Answer" in model_output_raw: + pattern_letter = r"^\**Concise Answer:\**\s+(\w)\. (\w+)" + matches = re.search(pattern_letter, model_output_raw, re.IGNORECASE) + if matches: + match_option = matches.group(1) + model_output_parsed_letter = options_dict[match_option] + + pattern_phrase = r"\**Concise Answer:\**\s+([^\n]+)" + model_output_answer_line = re.search(pattern_phrase, model_output_raw, re.IGNORECASE).group(1) + + answers = [v for k, v in options_dict.items()] + answers_pattern = rf"\b({'|'.join(answers)})\b" + answers_match = re.search(answers_pattern, model_output_answer_line, re.IGNORECASE) + + if answers_match: + model_output_parsed = answers_match.group(1) - # Check for digit sequences and their positions, specifically ignoring list markers at the start - # Exclude numbers following "\n\n" and directly followed by ". " - text = re.sub(r"^\n\n\d+\.\s", "", text) # Remove the leading list marker if it exists + else: + pattern_letter = r'answer is:*\s*\**([\w\d]+)[\s:.]*' + + # first look for a single letter answer + matches = re.search(pattern_letter, model_output_raw, re.IGNORECASE) + if matches: + match_option = matches.group(1) + if match_option in options_dict: + model_output_parsed_letter = options_dict[match_option] + else: + model_output_parsed_letter = match_option - for match in re.finditer(r"\d+", text): - found_numbers.append((match.start(), int(match.group(0)))) + # next look if any of the options names are present in the first sentance - # Sort found numbers by their positions (smallest position first) - if found_numbers: - found_numbers.sort(key=lambda x: x[0]) - # Return the number associated with the earliest position - return str(found_numbers[0][1]) - return None + model_output_answer_line = model_output_raw.splitlines()[0] - else: - raise ValueError(f"Question ID {question_id} is not supported.") + answers = [v for k, v in options_dict.items()] + answers_pattern = rf"\b({'|'.join(answers)})\b" + answers_match = re.search(answers_pattern, model_output_answer_line, re.IGNORECASE) + + if answers_match: + model_output_parsed = answers_match.group(1) - return None # Return None if no numbers are found + return [model_output_parsed, model_output_parsed_letter] def extract_answer_from_text_maze(text, question_type): @@ -443,6 +373,34 @@ def transform(self, df: pd.DataFrame) -> pd.DataFrame: ) return df +@dataclass +class ExtractQuestionOptions(DFTransformBase): + """This class is for extracting the option list from a prompt.""" + + prompt_column_name: str + extracted_options_column_name: str + + def _extract_options_from_text_map(self, prompt): + """ + Extracts the options list from the text. + + Args: + - text (str): The text containing the prompt. 
+ + Returns: + - str or None: The extracted list of options. + """ + + # get list of options from prompt + prompt_lines = prompt.splitlines() + matches = [i for i, x in enumerate(prompt_lines) if "Available options:" in x] + options = prompt_lines[matches[0]+1:matches[0]+5] + + return options + + def transform(self, df: pd.DataFrame) -> pd.DataFrame: + df[self.extracted_options_column_name] = df[self.prompt_column_name].apply(self._extract_options_from_text_map) + return df @dataclass class ExtractAnswerGrid(ExtractAnswer): @@ -459,17 +417,18 @@ def _parse_answer_function(self, answer_text, question_type): @dataclass -class ExtractAnswerSpatialMap(ExtractAnswer): +class ExtractAnswerSpatialMap(DFTransformBase): """This class is an answer extractor for the SPATIAL_MAP benchmark.""" answer_column_name: str extracted_answer_column_name: str - question_type_column_name: str - model_name: str + extracted_options_column_name: str - @abstractmethod - def _parse_answer_function(self, answer_text, question_type): - return extract_answer_from_text_map(answer_text, question_type, self.model_name) + def transform(self, df: pd.DataFrame) -> pd.DataFrame: + df[self.extracted_answer_column_name] = df.apply( + lambda x: extract_answer_from_text_map(x[self.answer_column_name], x[self.extracted_options_column_name]), axis=1 + ) + return df @dataclass diff --git a/eureka_ml_insights/metrics/__init__.py b/eureka_ml_insights/metrics/__init__.py index 2b19ec1..56eb04f 100644 --- a/eureka_ml_insights/metrics/__init__.py +++ b/eureka_ml_insights/metrics/__init__.py @@ -2,6 +2,7 @@ from .geomtric_reasoning_metrics import GeoMCQMetric from .metrics_base import ( CaseInsensitiveMatch, + CaseInsensitiveOrMatch, ClassicMetric, CompositeMetric, ExactMatch, diff --git a/eureka_ml_insights/metrics/metrics_base.py b/eureka_ml_insights/metrics/metrics_base.py index 45d4249..4449a61 100644 --- a/eureka_ml_insights/metrics/metrics_base.py +++ b/eureka_ml_insights/metrics/metrics_base.py @@ -144,6 +144,25 @@ def __evaluate__(self, answer_text, target_text, is_valid): else: return "incorrect" +class ExactOrMatch(ExactMatch): + """This class checks for a case-insensitive, but otherwise exact match, and returns the or of them.""" + + def __evaluate__(self, answer_texts, target_text, is_valid): + + if not is_valid: + return "none" + + results = [] + for answer_text in answer_texts: + res = super().__evaluate__(str(answer_text), str(target_text), is_valid) + results.append(res) + + corrects = [x=="correct" for x in results] + + if (any(corrects)): + return "correct" + else: + return "incorrect" class CaseInsensitiveMatch(ExactMatch): """This class checks for a case-insensitive, but otherwise exact match.""" @@ -151,6 +170,13 @@ class CaseInsensitiveMatch(ExactMatch): def __evaluate__(self, answer_text, target_text, is_valid): return super().__evaluate__(str(answer_text).lower(), str(target_text).lower(), is_valid) +class CaseInsensitiveOrMatch(ExactOrMatch): + """This class checks for a case-insensitive, but otherwise exact or match.""" + + def __evaluate__(self, answer_texts, target_text, is_valid): + answer_texts = [str(answer_text).lower() for answer_text in answer_texts] + return super().__evaluate__(answer_texts, str(target_text).lower(), is_valid) + class IdentityMetric(Metric): From bb2dd0872fefe8948056c7b06c5b9747b3809e97 Mon Sep 17 00:00:00 2001 From: neel Date: Fri, 6 Dec 2024 13:21:24 -0700 Subject: [PATCH 2/8] updated maze and spatial map configs and answer extraction --- .../configs/vision_language/maze.py | 19 
+++++--- .../configs/vision_language/spatial_map.py | 4 +- eureka_ml_insights/data_utils/__init__.py | 6 +-- .../data_utils/spatial_utils.py | 48 ++++++++----------- 4 files changed, 36 insertions(+), 41 deletions(-) diff --git a/eureka_ml_insights/configs/vision_language/maze.py b/eureka_ml_insights/configs/vision_language/maze.py index 0f8da2e..3b21dae 100644 --- a/eureka_ml_insights/configs/vision_language/maze.py +++ b/eureka_ml_insights/configs/vision_language/maze.py @@ -8,11 +8,12 @@ ColumnRename, DataLoader, DataReader, - ExtractAnswerMaze, + ExtractQuestionOptions, + ExtractAnswerSpatialMapAndMaze, PrependStringTransform, SequenceTransform, ) -from eureka_ml_insights.metrics import CaseInsensitiveMatch, CountAggregator +from eureka_ml_insights.metrics import CaseInsensitiveOrMatch, CountAggregator from ..config import ( AggregatorConfig, @@ -81,23 +82,27 @@ def configure_pipeline(self, model_config: ModelConfig, resume_from: str = None) "format": ".jsonl", "transform": SequenceTransform( [ + ExtractQuestionOptions( + prompt_column_name="prompt", + extracted_options_column_name="target_options_answers", + ), ColumnRename(name_mapping={"model_output": "model_output_raw"}), - ExtractAnswerMaze( + ExtractAnswerSpatialMapAndMaze( answer_column_name="model_output_raw", extracted_answer_column_name="model_output", - question_type_column_name="question_type", + extracted_options_column_name="target_options_answers", ), ], ), }, ), - metric_config=MetricConfig(CaseInsensitiveMatch), + metric_config=MetricConfig(CaseInsensitiveOrMatch), aggregator_configs=[ - AggregatorConfig(CountAggregator, {"column_names": ["CaseInsensitiveMatch_result"], "normalize": True}), + AggregatorConfig(CountAggregator, {"column_names": ["CaseInsensitiveOrMatch_result"], "normalize": True}), AggregatorConfig( CountAggregator, { - "column_names": ["CaseInsensitiveMatch_result"], + "column_names": ["CaseInsensitiveOrMatch_result"], "group_by": "task", "normalize": True, }, diff --git a/eureka_ml_insights/configs/vision_language/spatial_map.py b/eureka_ml_insights/configs/vision_language/spatial_map.py index 22b5bc4..d18edba 100644 --- a/eureka_ml_insights/configs/vision_language/spatial_map.py +++ b/eureka_ml_insights/configs/vision_language/spatial_map.py @@ -8,7 +8,7 @@ ColumnRename, DataLoader, DataReader, - ExtractAnswerSpatialMap, + ExtractAnswerSpatialMapAndMaze, ExtractQuestionOptions, PrependStringTransform, SequenceTransform, @@ -88,7 +88,7 @@ def configure_pipeline(self, model_config: ModelConfig, resume_from: str = None) extracted_options_column_name="target_options_answers", ), ColumnRename(name_mapping={"model_output": "model_output_raw"}), - ExtractAnswerSpatialMap( + ExtractAnswerSpatialMapAndMaze( answer_column_name="model_output_raw", extracted_answer_column_name="model_output", extracted_options_column_name="target_options_answers", diff --git a/eureka_ml_insights/data_utils/__init__.py b/eureka_ml_insights/data_utils/__init__.py index 70c3db7..4e4d783 100644 --- a/eureka_ml_insights/data_utils/__init__.py +++ b/eureka_ml_insights/data_utils/__init__.py @@ -15,8 +15,7 @@ from .secret_key_utils import GetKey from .spatial_utils import ( ExtractAnswerGrid, - ExtractAnswerMaze, - ExtractAnswerSpatialMap, + ExtractAnswerSpatialMapAndMaze, ExtractQuestionOptions, ) from .transform import ( @@ -70,8 +69,7 @@ PrependStringTransform, GetKey, ExtractAnswerGrid, - ExtractAnswerSpatialMap, - ExtractAnswerMaze, + ExtractAnswerSpatialMapAndMaze, ShuffleColumnsTransform, ColumnMatchMapTransform, 
TokenCounterTransform, diff --git a/eureka_ml_insights/data_utils/spatial_utils.py b/eureka_ml_insights/data_utils/spatial_utils.py index c40ea7d..5b61224 100644 --- a/eureka_ml_insights/data_utils/spatial_utils.py +++ b/eureka_ml_insights/data_utils/spatial_utils.py @@ -157,7 +157,7 @@ def extract_answer_from_text_grid(text, question_type): return None # Return None if no numbers are found -def extract_answer_from_text_map(model_output_raw, options): +def extract_answer_from_text_map_and_maze(model_output_raw, options): """ Extracts the answer from the text based on known model output patterns. Searches for botha letter and whole word answer and returns both as they are not @@ -170,12 +170,20 @@ def extract_answer_from_text_map(model_output_raw, options): - str or None: The extracted answers, or empty strings if no answer could be extracted. """ - # replace common subsitutions for numbers in model outputs + # replace common subsitutions in model outputs + + model_output_parsed_letter = "" + model_output_parsed = "" + + if not model_output_raw: + return [model_output_parsed, model_output_parsed_letter] + model_output_raw = model_output_raw.replace("no objects", "0 objects") + model_output_raw = model_output_raw.replace("not", "no") + model_output_raw = model_output_raw.replace("should be", "is") number_mapping = { - "zero": 0, - "no": 0, + "zero": 0, "one": 1, "two": 2, "three": 3, @@ -190,18 +198,15 @@ def extract_answer_from_text_map(model_output_raw, options): for k, v in number_mapping.items(): model_output_raw = re.sub(rf"\b{k}\b", str(v), model_output_raw, re.IGNORECASE) - # get dict of options from options string - options_dict = {x.split(".")[0].strip():x.split(".")[1].strip() for x in options} - - model_output_parsed_letter = "" - model_output_parsed = "" + # get dict of options from options string + options_dict = {x.split(".")[0].strip().lower():x.split(".")[1].strip().lower() for x in options} - # "Concise Asnwer" is a common GPT-4o pattern - if "Concise Answer" in model_output_raw: + # "Concise Answer" is a common GPT-4o pattern + if "Concise Answer:".lower() in model_output_raw.lower(): pattern_letter = r"^\**Concise Answer:\**\s+(\w)\. 
(\w+)" matches = re.search(pattern_letter, model_output_raw, re.IGNORECASE) if matches: - match_option = matches.group(1) + match_option = matches.group(1).lower() model_output_parsed_letter = options_dict[match_option] pattern_phrase = r"\**Concise Answer:\**\s+([^\n]+)" @@ -220,7 +225,7 @@ def extract_answer_from_text_map(model_output_raw, options): # first look for a single letter answer matches = re.search(pattern_letter, model_output_raw, re.IGNORECASE) if matches: - match_option = matches.group(1) + match_option = matches.group(1).lower() if match_option in options_dict: model_output_parsed_letter = options_dict[match_option] else: @@ -417,7 +422,7 @@ def _parse_answer_function(self, answer_text, question_type): @dataclass -class ExtractAnswerSpatialMap(DFTransformBase): +class ExtractAnswerSpatialMapAndMaze(DFTransformBase): """This class is an answer extractor for the SPATIAL_MAP benchmark.""" answer_column_name: str @@ -426,19 +431,6 @@ class ExtractAnswerSpatialMap(DFTransformBase): def transform(self, df: pd.DataFrame) -> pd.DataFrame: df[self.extracted_answer_column_name] = df.apply( - lambda x: extract_answer_from_text_map(x[self.answer_column_name], x[self.extracted_options_column_name]), axis=1 + lambda x: extract_answer_from_text_map_and_maze(x[self.answer_column_name], x[self.extracted_options_column_name]), axis=1 ) return df - - -@dataclass -class ExtractAnswerMaze(ExtractAnswer): - """This class is an answer extractor for the MAZE benchmark.""" - - answer_column_name: str - extracted_answer_column_name: str - question_type_column_name: str - - @abstractmethod - def _parse_answer_function(self, answer_text, question_type): - return extract_answer_from_text_maze(answer_text, question_type) From c4b6bb07ea786cd466476de823f4bec241841f21 Mon Sep 17 00:00:00 2001 From: neel Date: Fri, 6 Dec 2024 14:27:16 -0700 Subject: [PATCH 3/8] parsng updates for o1 --- .../data_utils/spatial_utils.py | 65 ++++++++++++------- 1 file changed, 41 insertions(+), 24 deletions(-) diff --git a/eureka_ml_insights/data_utils/spatial_utils.py b/eureka_ml_insights/data_utils/spatial_utils.py index 5b61224..09256d5 100644 --- a/eureka_ml_insights/data_utils/spatial_utils.py +++ b/eureka_ml_insights/data_utils/spatial_utils.py @@ -178,9 +178,9 @@ def extract_answer_from_text_map_and_maze(model_output_raw, options): if not model_output_raw: return [model_output_parsed, model_output_parsed_letter] - model_output_raw = model_output_raw.replace("no objects", "0 objects") - model_output_raw = model_output_raw.replace("not", "no") - model_output_raw = model_output_raw.replace("should be", "is") + model_output_raw = re.sub(r"\bno objects\b", "0 objects", model_output_raw, re.IGNORECASE) + model_output_raw = re.sub(r"\bnot\b", "no", model_output_raw, re.IGNORECASE) + model_output_raw = re.sub(r"\bshould be\b", "is", model_output_raw, re.IGNORECASE) number_mapping = { "zero": 0, @@ -198,29 +198,46 @@ def extract_answer_from_text_map_and_maze(model_output_raw, options): for k, v in number_mapping.items(): model_output_raw = re.sub(rf"\b{k}\b", str(v), model_output_raw, re.IGNORECASE) - # get dict of options from options string + # get dict of options from options string options_dict = {x.split(".")[0].strip().lower():x.split(".")[1].strip().lower() for x in options} - # "Concise Answer" is a common GPT-4o pattern - if "Concise Answer:".lower() in model_output_raw.lower(): - pattern_letter = r"^\**Concise Answer:\**\s+(\w)\. 
(\w+)" + + model_output_parsed_letter = "" + model_output_parsed = "" + + answers = [v for k, v in options_dict.items()] + answers_pattern = rf"\b({'|'.join(answers)})\b" + + if "Answer:".lower() in model_output_raw.lower(): + pattern_letter = r"^\**Answer:\**\s+(\w)\. (\w+)" matches = re.search(pattern_letter, model_output_raw, re.IGNORECASE) if matches: match_option = matches.group(1).lower() - model_output_parsed_letter = options_dict[match_option] + if match_option in options_dict: + model_output_parsed_letter = options_dict[match_option] + else: + model_output_parsed_letter = match_option - pattern_phrase = r"\**Concise Answer:\**\s+([^\n]+)" - model_output_answer_line = re.search(pattern_phrase, model_output_raw, re.IGNORECASE).group(1) + pattern_phrase = r"Answer:\**\s+([^\n]+)" + matches = re.search(pattern_phrase, model_output_raw, re.IGNORECASE) + if matches: + model_output_answer_line = matches.group(1) - answers = [v for k, v in options_dict.items()] - answers_pattern = rf"\b({'|'.join(answers)})\b" - answers_match = re.search(answers_pattern, model_output_answer_line, re.IGNORECASE) + answers_match = re.search(answers_pattern, model_output_answer_line, re.IGNORECASE) - if answers_match: - model_output_parsed = answers_match.group(1) + if answers_match: + model_output_parsed = answers_match.group(1) + else: + letters = [k for k, v in options_dict.items()] + letters_pattern = rf"\b({'|'.join(letters)})\b" + letters_pattern_match = re.search(letters_pattern, model_output_answer_line, re.IGNORECASE) - else: - pattern_letter = r'answer is:*\s*\**([\w\d]+)[\s:.]*' + if letters_pattern_match: + match_option = letters_pattern_match.group(1).lower() + model_output_parsed_letter = options_dict[match_option] + + elif "answer is".lower() in model_output_raw.lower(): + pattern_letter = r'answer is:*\s*\**([\w\d]+)[\s:.]*\**' # first look for a single letter answer matches = re.search(pattern_letter, model_output_raw, re.IGNORECASE) @@ -231,16 +248,16 @@ def extract_answer_from_text_map_and_maze(model_output_raw, options): else: model_output_parsed_letter = match_option - # next look if any of the options names are present in the first sentance + # next look if any of the options names are present in the first sentance - model_output_answer_line = model_output_raw.splitlines()[0] + model_output_answer_line = model_output_raw.splitlines()[0] - answers = [v for k, v in options_dict.items()] - answers_pattern = rf"\b({'|'.join(answers)})\b" - answers_match = re.search(answers_pattern, model_output_answer_line, re.IGNORECASE) + answers = [v for k, v in options_dict.items()] + answers_pattern = rf"\b({'|'.join(answers)})\b" + answers_match = re.search(answers_pattern, model_output_answer_line, re.IGNORECASE) - if answers_match: - model_output_parsed = answers_match.group(1) + if answers_match: + model_output_parsed = answers_match.group(1) return [model_output_parsed, model_output_parsed_letter] From 46e35730b50c964ba3ce7640570d6565ffd0a035 Mon Sep 17 00:00:00 2001 From: neel Date: Fri, 6 Dec 2024 17:02:52 -0700 Subject: [PATCH 4/8] comments updates --- eureka_ml_insights/data_utils/spatial_utils.py | 9 +++++---- 1 file changed, 5 insertions(+), 4 deletions(-) diff --git a/eureka_ml_insights/data_utils/spatial_utils.py b/eureka_ml_insights/data_utils/spatial_utils.py index 09256d5..74352c0 100644 --- a/eureka_ml_insights/data_utils/spatial_utils.py +++ b/eureka_ml_insights/data_utils/spatial_utils.py @@ -160,11 +160,12 @@ def extract_answer_from_text_grid(text, question_type): def 
extract_answer_from_text_map_and_maze(model_output_raw, options): """ Extracts the answer from the text based on known model output patterns. - Searches for botha letter and whole word answer and returns both as they are not + Searches for both a letter and whole word answer and returns both as they are not always consistent. Args: - - text (str): The text containing the model's answer. + - model_output_raw (str): The text containing the model's answer. + - options (str): The list of options. Returns: - str or None: The extracted answers, or empty strings if no answer could be extracted. @@ -248,7 +249,7 @@ def extract_answer_from_text_map_and_maze(model_output_raw, options): else: model_output_parsed_letter = match_option - # next look if any of the options names are present in the first sentance + # next look if any of the options names are present in the first line model_output_answer_line = model_output_raw.splitlines()[0] @@ -440,7 +441,7 @@ def _parse_answer_function(self, answer_text, question_type): @dataclass class ExtractAnswerSpatialMapAndMaze(DFTransformBase): - """This class is an answer extractor for the SPATIAL_MAP benchmark.""" + """This class is an answer extractor for the SPATIAL_MAP and MAZE benchmark.""" answer_column_name: str extracted_answer_column_name: str From b890e1a0b09a526cf0117d5763a56dd0b3078bb5 Mon Sep 17 00:00:00 2001 From: neel Date: Mon, 9 Dec 2024 13:41:32 -0700 Subject: [PATCH 5/8] updated comments --- eureka_ml_insights/data_utils/spatial_utils.py | 2 +- eureka_ml_insights/metrics/metrics_base.py | 2 +- 2 files changed, 2 insertions(+), 2 deletions(-) diff --git a/eureka_ml_insights/data_utils/spatial_utils.py b/eureka_ml_insights/data_utils/spatial_utils.py index 74352c0..68dc332 100644 --- a/eureka_ml_insights/data_utils/spatial_utils.py +++ b/eureka_ml_insights/data_utils/spatial_utils.py @@ -405,7 +405,7 @@ class ExtractQuestionOptions(DFTransformBase): def _extract_options_from_text_map(self, prompt): """ - Extracts the options list from the text. + Extracts the multiple-choice options list from the text. Args: - text (str): The text containing the prompt. 
diff --git a/eureka_ml_insights/metrics/metrics_base.py b/eureka_ml_insights/metrics/metrics_base.py index 4449a61..d12af4e 100644 --- a/eureka_ml_insights/metrics/metrics_base.py +++ b/eureka_ml_insights/metrics/metrics_base.py @@ -145,7 +145,7 @@ def __evaluate__(self, answer_text, target_text, is_valid): return "incorrect" class ExactOrMatch(ExactMatch): - """This class checks for a case-insensitive, but otherwise exact match, and returns the or of them.""" + """This class checks for a case-sensitive, but otherwise exact match, and returns the or of them.""" def __evaluate__(self, answer_texts, target_text, is_valid): From c7b848ee0adc112a157999f662fd7dff1c8cb585 Mon Sep 17 00:00:00 2001 From: neel Date: Wed, 11 Dec 2024 14:52:03 -0700 Subject: [PATCH 6/8] added to comments, renamed metric classes --- eureka_ml_insights/data_utils/__init__.py | 1 + eureka_ml_insights/metrics/__init__.py | 3 ++- eureka_ml_insights/metrics/metrics_base.py | 20 +++++++++++++++---- .../user_configs/vision_language/maze.py | 4 ++-- .../vision_language/spatial_map.py | 4 ++-- 5 files changed, 23 insertions(+), 9 deletions(-) diff --git a/eureka_ml_insights/data_utils/__init__.py b/eureka_ml_insights/data_utils/__init__.py index 5935d48..415b1b7 100644 --- a/eureka_ml_insights/data_utils/__init__.py +++ b/eureka_ml_insights/data_utils/__init__.py @@ -69,6 +69,7 @@ PrependStringTransform, ExtractAnswerGrid, ExtractAnswerSpatialMapAndMaze, + ExtractQuestionOptions, ShuffleColumnsTransform, ColumnMatchMapTransform, TokenCounterTransform, diff --git a/eureka_ml_insights/metrics/__init__.py b/eureka_ml_insights/metrics/__init__.py index 56eb04f..3e6521e 100644 --- a/eureka_ml_insights/metrics/__init__.py +++ b/eureka_ml_insights/metrics/__init__.py @@ -2,12 +2,13 @@ from .geomtric_reasoning_metrics import GeoMCQMetric from .metrics_base import ( CaseInsensitiveMatch, - CaseInsensitiveOrMatch, ClassicMetric, CompositeMetric, ExactMatch, IdentityMetric, Metric, + MultiCandidateAnyExactMatch, + MultiCandidateAnyCaseInsensitiveMatch, SubstringExistsMatch, ) from .mmmu_metrics import MMMUMetric diff --git a/eureka_ml_insights/metrics/metrics_base.py b/eureka_ml_insights/metrics/metrics_base.py index d12af4e..a96146d 100644 --- a/eureka_ml_insights/metrics/metrics_base.py +++ b/eureka_ml_insights/metrics/metrics_base.py @@ -144,8 +144,14 @@ def __evaluate__(self, answer_text, target_text, is_valid): else: return "incorrect" -class ExactOrMatch(ExactMatch): - """This class checks for a case-sensitive, but otherwise exact match, and returns the or of them.""" +class MultiCandidateAnyExactMatch(ExactMatch): + """ + This class checks for a case-sensitive match for a list of answers from the model output, + and returns the or of the list of metric results. + + This is required for answers to multiple-choice questions. As many models sometimes give the letter answer + and sometimes the full word answer. This allows one to consider the answer correct if either one was correct. 
+ """ def __evaluate__(self, answer_texts, target_text, is_valid): @@ -170,8 +176,14 @@ class CaseInsensitiveMatch(ExactMatch): def __evaluate__(self, answer_text, target_text, is_valid): return super().__evaluate__(str(answer_text).lower(), str(target_text).lower(), is_valid) -class CaseInsensitiveOrMatch(ExactOrMatch): - """This class checks for a case-insensitive, but otherwise exact or match.""" +class MultiCandidateAnyCaseInsensitiveMatch(MultiCandidateAnyExactMatch): + """ + This class checks for a case-insensitive match for a list of answers from the model output, + and returns the or of the list of metric results. + + This is required for answers to multiple-choice questions. As many models sometimes give the letter answer + and sometimes the full word answer. This allows one to consider the answer correct if either one was correct. + """ def __evaluate__(self, answer_texts, target_text, is_valid): answer_texts = [str(answer_text).lower() for answer_text in answer_texts] diff --git a/eureka_ml_insights/user_configs/vision_language/maze.py b/eureka_ml_insights/user_configs/vision_language/maze.py index ee4b20d..a4d7eb1 100644 --- a/eureka_ml_insights/user_configs/vision_language/maze.py +++ b/eureka_ml_insights/user_configs/vision_language/maze.py @@ -13,7 +13,7 @@ PrependStringTransform, SequenceTransform, ) -from eureka_ml_insights.metrics import CaseInsensitiveOrMatch, CountAggregator +from eureka_ml_insights.metrics import MultiCandidateAnyCaseInsensitiveMatch, CountAggregator from eureka_ml_insights.configs import ( AggregatorConfig, @@ -96,7 +96,7 @@ def configure_pipeline(self, model_config: ModelConfig, resume_from: str = None) ), }, ), - metric_config=MetricConfig(CaseInsensitiveOrMatch), + metric_config=MetricConfig(MultiCandidateAnyCaseInsensitiveMatch), aggregator_configs=[ AggregatorConfig(CountAggregator, {"column_names": ["CaseInsensitiveOrMatch_result"], "normalize": True}), AggregatorConfig( diff --git a/eureka_ml_insights/user_configs/vision_language/spatial_map.py b/eureka_ml_insights/user_configs/vision_language/spatial_map.py index 7bac915..81f9f97 100644 --- a/eureka_ml_insights/user_configs/vision_language/spatial_map.py +++ b/eureka_ml_insights/user_configs/vision_language/spatial_map.py @@ -13,7 +13,7 @@ PrependStringTransform, SequenceTransform, ) -from eureka_ml_insights.metrics import CaseInsensitiveOrMatch, CountAggregator +from eureka_ml_insights.metrics import MultiCandidateAnyCaseInsensitiveMatch, CountAggregator from eureka_ml_insights.configs import ( AggregatorConfig, @@ -97,7 +97,7 @@ def configure_pipeline(self, model_config: ModelConfig, resume_from: str = None) ), }, ), - metric_config=MetricConfig(CaseInsensitiveOrMatch), + metric_config=MetricConfig(MultiCandidateAnyCaseInsensitiveMatch), aggregator_configs=[ AggregatorConfig(CountAggregator, {"column_names": ["CaseInsensitiveOrMatch_result"], "normalize": True}), AggregatorConfig( From ae456f17927555e4f2f3a148da36e2601675df55 Mon Sep 17 00:00:00 2001 From: neel Date: Tue, 17 Dec 2024 14:42:35 -0700 Subject: [PATCH 7/8] removed new multicandidate metrics and return or or multiple answers and substring match instead --- .../data_utils/spatial_utils.py | 2 +- eureka_ml_insights/metrics/__init__.py | 2 - eureka_ml_insights/metrics/metrics_base.py | 38 ------------------- .../user_configs/vision_language/maze.py | 8 ++-- .../vision_language/spatial_map.py | 8 ++-- 5 files changed, 9 insertions(+), 49 deletions(-) diff --git a/eureka_ml_insights/data_utils/spatial_utils.py 
b/eureka_ml_insights/data_utils/spatial_utils.py index 68dc332..f2907d2 100644 --- a/eureka_ml_insights/data_utils/spatial_utils.py +++ b/eureka_ml_insights/data_utils/spatial_utils.py @@ -260,7 +260,7 @@ def extract_answer_from_text_map_and_maze(model_output_raw, options): if answers_match: model_output_parsed = answers_match.group(1) - return [model_output_parsed, model_output_parsed_letter] + return model_output_parsed + " or " + model_output_parsed_letter def extract_answer_from_text_maze(text, question_type): diff --git a/eureka_ml_insights/metrics/__init__.py b/eureka_ml_insights/metrics/__init__.py index 3e6521e..2b19ec1 100644 --- a/eureka_ml_insights/metrics/__init__.py +++ b/eureka_ml_insights/metrics/__init__.py @@ -7,8 +7,6 @@ ExactMatch, IdentityMetric, Metric, - MultiCandidateAnyExactMatch, - MultiCandidateAnyCaseInsensitiveMatch, SubstringExistsMatch, ) from .mmmu_metrics import MMMUMetric diff --git a/eureka_ml_insights/metrics/metrics_base.py b/eureka_ml_insights/metrics/metrics_base.py index a96146d..45d4249 100644 --- a/eureka_ml_insights/metrics/metrics_base.py +++ b/eureka_ml_insights/metrics/metrics_base.py @@ -144,31 +144,6 @@ def __evaluate__(self, answer_text, target_text, is_valid): else: return "incorrect" -class MultiCandidateAnyExactMatch(ExactMatch): - """ - This class checks for a case-sensitive match for a list of answers from the model output, - and returns the or of the list of metric results. - - This is required for answers to multiple-choice questions. As many models sometimes give the letter answer - and sometimes the full word answer. This allows one to consider the answer correct if either one was correct. - """ - - def __evaluate__(self, answer_texts, target_text, is_valid): - - if not is_valid: - return "none" - - results = [] - for answer_text in answer_texts: - res = super().__evaluate__(str(answer_text), str(target_text), is_valid) - results.append(res) - - corrects = [x=="correct" for x in results] - - if (any(corrects)): - return "correct" - else: - return "incorrect" class CaseInsensitiveMatch(ExactMatch): """This class checks for a case-insensitive, but otherwise exact match.""" @@ -176,19 +151,6 @@ class CaseInsensitiveMatch(ExactMatch): def __evaluate__(self, answer_text, target_text, is_valid): return super().__evaluate__(str(answer_text).lower(), str(target_text).lower(), is_valid) -class MultiCandidateAnyCaseInsensitiveMatch(MultiCandidateAnyExactMatch): - """ - This class checks for a case-insensitive match for a list of answers from the model output, - and returns the or of the list of metric results. - - This is required for answers to multiple-choice questions. As many models sometimes give the letter answer - and sometimes the full word answer. This allows one to consider the answer correct if either one was correct. 
- """ - - def __evaluate__(self, answer_texts, target_text, is_valid): - answer_texts = [str(answer_text).lower() for answer_text in answer_texts] - return super().__evaluate__(answer_texts, str(target_text).lower(), is_valid) - class IdentityMetric(Metric): diff --git a/eureka_ml_insights/user_configs/vision_language/maze.py b/eureka_ml_insights/user_configs/vision_language/maze.py index a4d7eb1..b461a93 100644 --- a/eureka_ml_insights/user_configs/vision_language/maze.py +++ b/eureka_ml_insights/user_configs/vision_language/maze.py @@ -13,7 +13,7 @@ PrependStringTransform, SequenceTransform, ) -from eureka_ml_insights.metrics import MultiCandidateAnyCaseInsensitiveMatch, CountAggregator +from eureka_ml_insights.metrics import SubstringExistsMatch, CountAggregator from eureka_ml_insights.configs import ( AggregatorConfig, @@ -96,13 +96,13 @@ def configure_pipeline(self, model_config: ModelConfig, resume_from: str = None) ), }, ), - metric_config=MetricConfig(MultiCandidateAnyCaseInsensitiveMatch), + metric_config=MetricConfig(SubstringExistsMatch), aggregator_configs=[ - AggregatorConfig(CountAggregator, {"column_names": ["CaseInsensitiveOrMatch_result"], "normalize": True}), + AggregatorConfig(CountAggregator, {"column_names": ["SubstringExistsMatch_result"], "normalize": True}), AggregatorConfig( CountAggregator, { - "column_names": ["CaseInsensitiveOrMatch_result"], + "column_names": ["SubstringExistsMatch_result"], "group_by": "task", "normalize": True, }, diff --git a/eureka_ml_insights/user_configs/vision_language/spatial_map.py b/eureka_ml_insights/user_configs/vision_language/spatial_map.py index 81f9f97..4bb2134 100644 --- a/eureka_ml_insights/user_configs/vision_language/spatial_map.py +++ b/eureka_ml_insights/user_configs/vision_language/spatial_map.py @@ -13,7 +13,7 @@ PrependStringTransform, SequenceTransform, ) -from eureka_ml_insights.metrics import MultiCandidateAnyCaseInsensitiveMatch, CountAggregator +from eureka_ml_insights.metrics import SubstringExistsMatch, CountAggregator from eureka_ml_insights.configs import ( AggregatorConfig, @@ -97,13 +97,13 @@ def configure_pipeline(self, model_config: ModelConfig, resume_from: str = None) ), }, ), - metric_config=MetricConfig(MultiCandidateAnyCaseInsensitiveMatch), + metric_config=MetricConfig(SubstringExistsMatch), aggregator_configs=[ - AggregatorConfig(CountAggregator, {"column_names": ["CaseInsensitiveOrMatch_result"], "normalize": True}), + AggregatorConfig(CountAggregator, {"column_names": ["SubstringExistsMatch_result"], "normalize": True}), AggregatorConfig( CountAggregator, { - "column_names": ["CaseInsensitiveOrMatch_result"], + "column_names": ["SubstringExistsMatch_result"], "group_by": "task", "normalize": True, }, From 82a4eb21d2054380d388347f21b42f95f8c233f8 Mon Sep 17 00:00:00 2001 From: neel Date: Tue, 17 Dec 2024 16:24:16 -0700 Subject: [PATCH 8/8] added unittest for extract_answer_from_text_map_and_maze to test regex --- .../vision_language_data_utils_tests.py | 63 +++++++++++++++++++ 1 file changed, 63 insertions(+) create mode 100644 tests/data_utils_tests/vision_language_data_utils_tests.py diff --git a/tests/data_utils_tests/vision_language_data_utils_tests.py b/tests/data_utils_tests/vision_language_data_utils_tests.py new file mode 100644 index 0000000..bb30c93 --- /dev/null +++ b/tests/data_utils_tests/vision_language_data_utils_tests.py @@ -0,0 +1,63 @@ +# write unit tests for the classes in data_utils/transform.py + +import logging +import unittest + +import pandas as pd + +from 
eureka_ml_insights.data_utils.spatial_utils import extract_answer_from_text_map_and_maze + +log = logging.getLogger("VISION_LANGUAGE_extract_answer_from_text_map_and_maze_tests") + + +class TestExtractAnswer(unittest.TestCase): + def test_extract_answer_from_text_map_and_maze(self): + + ## Direction answer extraction tests + testcases_dir = [ + "The correct answer is A: Planetarium Prints is located Northeast of Police Supply Store.", + + "Planetarium Prints is located to the Northeast of Police Supply Store, so the answer is A.\n\nTherefore, A (Northeast) is the correct answer as it accurately describes the position of Planetarium Prints relative to Police Supply Store.", + + "A. Northeast\n\nBased on the information provided, Planetarium Prints is located to the Northeast of Police Supply Store. This can be determined through the following steps:\n\n1. Police Supply Store is used as the reference point.\n2. Coral Crafts is stated to be Northwest of Police Supply Store.\n3. Planetarium Prints is then described as being Southeast of Coral Crafts.\n4. Since Coral Crafts is Northwest of Police Supply Store, and Planetarium Prints is Southeast of Coral Crafts, this places Planetarium Prints to the Northeast of Police Supply Store.\n\nAdditionally, the description directly states \"Planetarium Prints is to the Northeast of Police Supply Store\", confirming the answer.", + + "Planetarium Prints is **Northeast** of Police Supply Store. \n\nHere's the step-by-step explanation:\n\n1. **Establish a reference point:** The question asks for the direction of Planetarium Prints *relative* to Police Supply Store. This means we consider Police Supply Store as our central point.\n\n2. **Use the given information:** We are told \"Planetarium Prints is to the Northeast of Police Supply Store.\"\n\n3. **Direct relationship:** Since we have a direct statement about the relative position of Planetarium Prints to Police Supply Store, we don't need to analyze the other clues.\n\n4. **Final answer:** Therefore, Planetarium Prints is located **Northeast** of Police Supply Store. \n", + + "Planetarium Prints is Northeast of Police Supply Store.\n\n**Reasoning:**\n\n1. **Identify the relative positions:**\n - Police Supply Store is the reference point.\n - Planetarium Prints is one of the objects we need to locate relative to the Police Supply Store.\n\n2. **Analyze the given directions:**\n - Planetarium Prints is described as being to the Northeast of Police Supply Store.\n\n3. **Verify with the map:**\n - On the map, Planetarium Prints is positioned to the upper right (Northeast) of the Police Supply Store.\n\n4. **Conclusion:**\n - Based on the description and the map, Planetarium Prints is indeed Northeast of Police Supply Store.\n\nThus, the correct answer is:\nA. Northeast", + + "Planetarium Prints is to the **Northeast** of Police Supply Store (Option A).\n\n - Based on the above placements, **Planetarium Prints is to the Northeast of Police Supply Store**.\n\nTherefore, the correct answer is **Option A: Northeast**.", + + "The correct answer is A. Northeast.\n\nTo determine the direction of Planetarium Prints relative to Police Supply Store, we can follow the given descriptions:\n\n1. Police Supply Store is in the map.\n2. Narwhal's Novelties is to the Northwest of Police Supply Store.\n3. Coral Crafts is to the Northwest of Narwhal's Novelties.\n4. Coral Crafts is to the Northwest of Police Supply Store.\n5. Planetarium Prints is to the Southeast of Coral Crafts.\n6. 
Planetarium Prints is to the Northeast of Police Supply Store.\n7. Oz Oddities is to the Southwest of Planetarium Prints.\n8. Oz Oddities is to the Southwest of Police Supply Store.\n9. Ice Queen Ice Cream is to the Northwest of Planetarium Prints.\n10. Ice Queen Ice Cream is to the Southeast of Coral Crafts.\n\nFrom these descriptions, we can see that Planetarium Prints is located to the Northeast of Police Supply Store. This is directly stated in point 6. Therefore, the correct answer is A. Northeast.", + + "A. Northeast.\n\nPlanetarium Prints is to the Northeast of Police Supply Store. The reasoning behind this answer is based on the provided directions and locations of the objects on the map. Narwhal's Novelties is to the Northwest of Police Supply Store, and Coral Crafts is to the Northwest of Narwhal's Novelties. Planetarium Prints is to the Southeast of Coral Crafts, which means it is also to the Northeast of Police Supply Store.", + ] + + target_options_dir = [["A. Northeast","B. Northwest","C. Southwest","D. Southeast."]] * len(testcases_dir) + correct_answers_dir = ["northeast"] * len(testcases_dir) + + ## Numerical extraction tests + + testcases_numerical = [ + "A. 1\n\nTo determine how many objects are in the Southeast of Oz Oddities, we need to look at the relative positions of the objects on the map:\n\n1. Oz Oddities is located at the bottom of the map.\n2. Directly to the Northeast of Oz Oddities is the Police Supply Store.\n3. To the Southeast of Oz Oddities, there is only one object, which is Planetarium Prints.\n4. All other objects are either to the North or Northwest of Oz Oddities and therefore not in the Southeast direction.\n\nBased on the map, only Planetarium Prints is in the Southeast of Oz Oddities, which means the correct answer is A. 1.", + + "There are zero objects", + + "There are no objects", + ] + + target_options_numerical= [["A. 1","B. 0","C. 2","D. 3."]] * len(testcases_numerical) + correct_answers_numerical = ["1", "0", "0"] + + target_options = target_options_dir + target_options_numerical + testcases = testcases_dir + testcases_numerical + correct_answers = correct_answers_dir + correct_answers_numerical + + results = [] + for i, test in enumerate(testcases): + extracted_answer = extract_answer_from_text_map_and_maze(test, target_options[i]) + results.append(correct_answers[i].lower() in extracted_answer.lower()) + + self.assertTrue(all(results)) + +if __name__ == "__main__": + unittest.main()