From 3a1fd02785861fbcfdc9f0be62f3ff0b58829bd1 Mon Sep 17 00:00:00 2001 From: Vidhisha Balachandran Date: Mon, 11 Nov 2024 20:41:00 -0800 Subject: [PATCH 01/10] onboarding calendar planning - pipeline+metrics --- eureka_ml_insights/configs/__init__.py | 2 + eureka_ml_insights/configs/ba_calendar.py | 152 ++++++++++ eureka_ml_insights/configs/model_configs.py | 3 +- .../metrics/ba_calendar_metrics.py | 276 ++++++++++++++++++ eureka_ml_insights/metrics/reports.py | 35 ++- .../calendar_scheduling.jinja | 7 + 6 files changed, 471 insertions(+), 4 deletions(-) create mode 100644 eureka_ml_insights/configs/ba_calendar.py create mode 100644 eureka_ml_insights/metrics/ba_calendar_metrics.py create mode 100644 eureka_ml_insights/prompt_templates/ba_calendar_templates/calendar_scheduling.jinja diff --git a/eureka_ml_insights/configs/__init__.py b/eureka_ml_insights/configs/__init__.py index 6f13d5a..f842dd5 100644 --- a/eureka_ml_insights/configs/__init__.py +++ b/eureka_ml_insights/configs/__init__.py @@ -16,6 +16,7 @@ from .geometer import GEOMETER_PIPELINE from .ifeval import IFEval_PIPELINE from .aime import AIME_PIPELINE +from .ba_calendar import Calendar_Schedule_PIPELINE from .image_understanding.object_detection import ( OBJECT_DETECTION_PAIRS_LOCAL_PIPELINE, OBJECT_DETECTION_PAIRS_PIPELINE, @@ -113,6 +114,7 @@ MAZE_TEXTONLY_PIPELINE, MAZE_REPORTING_PIPELINE, IFEval_PIPELINE, + Calendar_Schedule_PIPELINE, FlenQA_Experiment_Pipeline, GEOMETER_PIPELINE, MMMU_BASELINE_PIPELINE, diff --git a/eureka_ml_insights/configs/ba_calendar.py b/eureka_ml_insights/configs/ba_calendar.py new file mode 100644 index 0000000..8cd3bba --- /dev/null +++ b/eureka_ml_insights/configs/ba_calendar.py @@ -0,0 +1,152 @@ +import os +from tkinter import N + +from eureka_ml_insights.core import ( + Inference, + PromptProcessing, +) + +from eureka_ml_insights.core.eval_reporting import EvalReporting +from eureka_ml_insights.data_utils.data import ( + DataLoader, + DataReader, +) +from eureka_ml_insights.data_utils.transform import ColumnRename, SamplerTransform, SequenceTransform +from eureka_ml_insights.metrics.ba_calendar_metrics import BACalendarMetric +from eureka_ml_insights.metrics.reports import ( + AverageAggregator, + BiLevelAverageAggregator, + NAFilteredAverageAggregator, + TwoColumnSumAverageAggregator, +) + +from .config import ( + AggregatorConfig, + DataJoinConfig, + DataProcessingConfig, + DataSetConfig, + EvalReportingConfig, + InferenceConfig, + MetricConfig, + PipelineConfig, + PromptProcessingConfig, +) +from .experiment_config import ExperimentConfig + +class Calendar_Schedule_PIPELINE(ExperimentConfig): + """This class specifies the config for running any benchmark on any model""" + + def configure_pipeline(self, model_config=None, resume_from=None, **kwargs) -> PipelineConfig: + # data preprocessing + self.data_processing_comp = PromptProcessingConfig( + component_type=PromptProcessing, + prompt_template_path=os.path.join(os.path.dirname(__file__), "../prompt_templates/ba_calendar_templates/calendar_scheduling.jinja"), + data_reader_config=DataSetConfig( + DataReader, + { + "path": os.path.join("../local_benchmark_data/Natasha_benchmarks/datasets/datasets/", "ba_calendar.jsonl"), + "transform": SequenceTransform([ + ColumnRename(name_mapping={"task_prompt": "prompt"}), + SamplerTransform(random_seed=5, sample_count=10), + ]), + }, + ), + output_dir=os.path.join(self.log_dir, "data_processing_output"), + ) + + # inference component + self.inference_comp = InferenceConfig( + component_type=Inference, + model_config=model_config, + data_loader_config=DataSetConfig( + DataLoader, + {"path": os.path.join(self.data_processing_comp.output_dir, "transformed_data.jsonl")}, + ), + output_dir=os.path.join(self.log_dir, "inference_result"), + resume_from=resume_from, + ) + + # Configure the evaluation and reporting component for evaluation and dataset level aggregation + self.evalreporting_comp = EvalReportingConfig( + component_type=EvalReporting, + data_reader_config=DataSetConfig( + DataReader, + { + "path": os.path.join(self.inference_comp.output_dir, "inference_result.jsonl"), + "format": ".jsonl", + # "transform": ColumnRename(name_mapping={"model_output": "response"}), + }, + ), + metric_config=MetricConfig(BACalendarMetric), + aggregator_configs=[ + AggregatorConfig( + AverageAggregator, + { + "column_names": [ + "BACalendarMetric_all_correct", + ], + "filename_base": "BaCal_AllCorrect_Aggregated", + }, + ), + AggregatorConfig( + NAFilteredAverageAggregator, + { + "column_name": "BACalendarMetric_availability_programmatic_check", + "filename_base": "BaCal_Availability_Check_Aggregated", + }, + ), + AggregatorConfig( + NAFilteredAverageAggregator, + { + "column_name": "BACalendarMetric_meeting_duration_programmatic_check", + "filename_base": "BaCal_MeetingDuration_Check_Aggregated", + }, + ), + AggregatorConfig( + NAFilteredAverageAggregator, + { + "column_name": "BACalendarMetric_buffer_time_programmatic_check", + "filename_base": "BaCal_BufferTime_Check_Aggregated", + }, + ), + AggregatorConfig( + NAFilteredAverageAggregator, + { + "column_name": "BACalendarMetric_no_weekends_programmatic_check", + "filename_base": "BaCal_NoWeekends_Check_Aggregated", + }, + ), + AggregatorConfig( + NAFilteredAverageAggregator, + { + "column_name": "BACalendarMetric_time_restrictions_programmatic_check", + "filename_base": "BaCal_TimeRestrictions_Check_Aggregated", + }, + ), + AggregatorConfig( + NAFilteredAverageAggregator, + { + "column_name": "BACalendarMetric_specific_times_programmatic_check", + "filename_base": "BaCal_SpecificTimes_Check_Aggregated", + }, + ), + AggregatorConfig( + NAFilteredAverageAggregator, + { + "column_name": "BACalendarMetric_priority_programmatic_check", + "filename_base": "BaCal_Priority_Check_Aggregated", + }, + ), + ], + output_dir=os.path.join(self.log_dir, "eval_report"), + ) + + # Configure the pipeline + return PipelineConfig( + [ + self.data_processing_comp, + self.inference_comp, + self.evalreporting_comp + ], + self.log_dir, + ) \ No newline at end of file diff --git a/eureka_ml_insights/configs/model_configs.py b/eureka_ml_insights/configs/model_configs.py index 64bcc33..23d9b9e 100644 --- a/eureka_ml_insights/configs/model_configs.py +++ b/eureka_ml_insights/configs/model_configs.py @@ -111,7 +111,8 @@ # Claude models CLAUDE_SECRET_KEY_PARAMS = { - "key_name": "your_claude_secret_key_name", + # "key_name": "your_claude_secret_key_name", + "key_name": "aif-eval-claude", "local_keys_path": "keys/keys.json", "key_vault_url": None, } diff --git a/eureka_ml_insights/metrics/ba_calendar_metrics.py b/eureka_ml_insights/metrics/ba_calendar_metrics.py new file mode 100644 index 0000000..d690cf2 --- /dev/null +++ b/eureka_ml_insights/metrics/ba_calendar_metrics.py @@ -0,0 +1,276 @@ +import json +import re +from datetime import datetime, timedelta + +from eureka_ml_insights.metrics.metrics_base import CompositeMetric + +# Helper functions +def is_formatted(solution): + pattern = r"^(Monday|Tuesday|Wednesday|Thursday|Friday|Saturday|Sunday) ([0-9]|[01]\d|2[0-3]):[0-5]\d-([0-9]|[01]\d|2[0-3]):[0-5]\d$" + return bool(re.match(pattern, solution)) + +def generate_time_slots(start_time, end_time, granularity): + granularity=5 + slots = [] + current_time = start_time + while current_time + timedelta(minutes=granularity) <= end_time: + slots.append((current_time, current_time + timedelta(minutes=granularity))) + current_time += timedelta(minutes=granularity) + return slots + +def parse_time_block(time_block): + start_str, end_str = time_block.split('-') + start_time = datetime.strptime(start_str, "%H:%M") + end_time = datetime.strptime(end_str, "%H:%M") + return start_time, end_time + +def filter_slots_by_duration(time_slots, duration): + filtered_slots = [] + for i in range(len(time_slots)): + accumulated_duration = timedelta() + for j in range(i, len(time_slots)): + accumulated_duration += time_slots[j][1] - time_slots[j][0] + if accumulated_duration >= timedelta(minutes=duration): + filtered_slots.append((time_slots[i][0], time_slots[j][1])) + break + return filtered_slots + +def filter_slots_by_constraints(time_slots, constraints, day): + filtered_slots = [] + for slot in time_slots: + start_time, end_time = slot + if constraints['no_meetings_before']: + no_meetings_before = datetime.strptime(f"{constraints['no_meetings_before']}:00", "%H:%M") + if start_time < no_meetings_before: + continue + if constraints['no_meetings_after']: + no_meetings_after = datetime.strptime(f"{constraints['no_meetings_after']}:00", "%H:%M") + if end_time >= no_meetings_after: + continue + if constraints['no_meetings_on_weekends'] and day in ['Saturday', 'Sunday']: + continue + if constraints['no_meetings_during_specific_times']: + no_meetings_start, no_meetings_end = parse_time_block(constraints['no_meetings_during_specific_times']) + if (start_time < no_meetings_end and end_time > no_meetings_start): + continue + filtered_slots.append(slot) + return filtered_slots + +# ask_true_false returns a tuple bool, str where the bool is the answer and the str is the justification + +class BACalendarMetric(CompositeMetric): + """ + Composite metric for evaluating if a response follows instructions. + + This metric evaluates if a given response follows the provided instructions. + It calculates both strict and loose evaluation scores based on the response's adherence to the instructions. + """ + + def __init__(self): + super().__init__() + self.no_solution_response = "No common time slot available" + + def __evaluate__(self, row): + results = {} + results.update(self.run_programmatic_tests(row)) + return results + + def run_programmatic_tests(self, instance): + result = {} + solution = instance['model_output'] + if not is_formatted(solution): + result['format_programmatic'] = 1 + result.update(self.check_availability_programmatic(instance, solution)) + result.update(self.check_meeting_duration_programmatic(instance, solution)) + result.update(self.check_buffer_time_programmatic(instance, solution)) + result.update(self.check_no_weekends_programmatic(instance, solution)) + result.update(self.check_time_restrictions_programmatic(instance, solution)) + result.update(self.check_specific_times_programmatic(instance, solution)) # Added programmatic specific times check + result.update(self.check_priority_programmatic(instance, solution)) # Added model-based priority check + all_correct = 1 + for key, value in result.items(): + if value == 0: + all_correct = 0 + result['all_correct'] = all_correct + return result + + def is_formatted(self, solution): + run_tests=True + if solution == self.no_solution_response: + run_tests=False + if not is_formatted(solution): + run_tests=False + return run_tests + + def check_availability_programmatic(self, instance, solution): + if not instance['constraints'].get('availability', True): + result = {'availability_programmatic_check': 'NA'} + return result + + if not self.is_formatted(solution): + result = {'availability_programmatic_check': 0} + return result + + day, time_range = solution.split() + start_time, end_time = parse_time_block(time_range) + all_available = 1 + availability = json.loads(instance['metadata']['availability'].replace("'", '"')) + for participant, schedule in availability.items(): + if day not in schedule: + all_available = 0 + break + available_blocks = schedule[day] + available = False + for block in available_blocks: + block_start, block_end = parse_time_block(block) + if block_start <= start_time and block_end >= end_time: + available = True + break + if not available: + all_available = 0 + break + + return {'availability_programmatic_check': all_available} + + def check_meeting_duration_programmatic(self, instance, solution): + if not instance['constraints'].get('meeting_duration', True): + result = {'meeting_duration_programmatic_check': 'NA'} + return result + + if not self.is_formatted(solution): + result = {'meeting_duration_programmatic_check': 0} + return result + + _, time_range = solution.split() + start_time, end_time = parse_time_block(time_range) + meeting_duration = (end_time - start_time).total_seconds() / 60 + expected_duration = instance['constraints']['meeting_duration'] + + return {'meeting_duration_programmatic_check': int(meeting_duration == expected_duration)} + + + def check_buffer_time_programmatic(self, instance, solution): + buffer_time = instance['constraints'].get('buffer_time_before_and_after_meeting', True) + if buffer_time is None or not buffer_time: + result = {'buffer_time_programmatic_check': 'NA'} + return result + + if not self.is_formatted(solution): + result = {'buffer_time_programmatic_check': 0} + return result + + buffer_time = instance['constraints']['buffer_time_before_and_after_meeting'] + day, time_range = solution.split() + start_time, end_time = parse_time_block(time_range) + buffer_start_time = start_time - timedelta(minutes=buffer_time) + buffer_end_time = end_time + timedelta(minutes=buffer_time) + all_buffer_respected = 1 + + availability = json.loads(instance['metadata']['availability'].replace("'", '"')) + for participant, schedule in availability.items(): + if day not in schedule: + all_buffer_respected = 0 + break + available_blocks = schedule[day] + buffer_respected = False + for block in available_blocks: + block_start, block_end = parse_time_block(block) + if block_start <= buffer_start_time and block_end >= buffer_end_time: + buffer_respected = True + break + if not buffer_respected: + all_buffer_respected = 0 + break + return {'buffer_time_programmatic_check': all_buffer_respected} + + def check_no_weekends_programmatic(self, instance, solution): + if not instance['constraints'].get('no_meetings_on_weekends', True): + return {'no_weekends_programmatic_check': 'NA'} + + if not self.is_formatted(solution): + return {'no_weekends_programmatic_check': 0} + + day, _ = solution.split() + day_of_week = datetime.strptime(day, '%A').weekday() + no_weekends = day_of_week < 5 + return {'no_weekends_programmatic_check': int(no_weekends)} + + def check_time_restrictions_programmatic(self, instance, solution): + if not instance['constraints'].get('no_meetings_before', True) and not instance['constraints'].get('no_meetings_after', True): + return {'time_restrictions_programmatic_check': 'NA'} + + if not self.is_formatted(solution): + return {'time_restrictions_programmatic_check': 0} + + _, time_range = solution.split() + start_time, end_time = parse_time_block(time_range) + + no_meetings_before = instance['constraints'].get('no_meetings_before') + no_meetings_after = instance['constraints'].get('no_meetings_after') + + if no_meetings_before: + no_meetings_before = datetime.strptime(f"{no_meetings_before}:00", "%H:%M") + if start_time < no_meetings_before: + return {'time_restrictions_programmatic_check': 0} + + if no_meetings_after: + no_meetings_after = datetime.strptime(f"{no_meetings_after}:00", '%H:%M') + if end_time > no_meetings_after: + return {'time_restrictions_programmatic_check': 0} + return {'time_restrictions_programmatic_check': 1} + + def check_priority_programmatic(self, instance, solution): + if not instance['constraints'].get('high_priority_meeting', False): + return {'priority_programmatic_check': 'NA'} + + if not self.is_formatted(solution): + return {'priority_programmatic_check': 0} + + metadata = instance['metadata'] + result = False + params = instance['params'] + constraints = instance['constraints'] + if constraints['buffer_time_before_and_after_meeting']: + buffer_time = constraints['buffer_time_before_and_after_meeting'] + else: + buffer_time = 0 + for day in params['days_of_week']: + common_time_slots = None + availability = json.loads(metadata['availability'].replace("'", '"')) + for participant, schedule in availability.items(): + if day in schedule: + participant_time_slots = [] + for time_slot in schedule[day]: + start_time, end_time = parse_time_block(time_slot) + time_slots = generate_time_slots(start_time, end_time, params['granularity']) + time_slots = filter_slots_by_duration(time_slots, constraints['meeting_duration'] + 2 * buffer_time) + time_slots = filter_slots_by_constraints(time_slots, constraints, day=day) + participant_time_slots.extend(time_slots) + if common_time_slots is None: + common_time_slots = set(participant_time_slots) + else: + common_time_slots = common_time_slots.intersection(participant_time_slots) + if common_time_slots: + first_available_slot = sorted(list(common_time_slots))[0] + first_available_start = (first_available_slot[0]+timedelta(minutes=buffer_time)).strftime('%H:%M') + first_available_end = (first_available_slot[1]-timedelta(minutes=buffer_time)).strftime('%H:%M') + result = solution == f"{day} {first_available_start}-{first_available_end}" + return {'priority_programmatic_check': int(result)} + + def check_specific_times_programmatic(self, instance, solution): + if not instance['constraints'].get('no_meetings_during_specific_times', True): + return {'specific_times_programmatic_check': 'NA'} + + if not self.is_formatted(solution): + return {'specific_times_programmatic_check': 0} + + restricted_times = instance['constraints']['no_meetings_during_specific_times'] + restricted_start, restricted_end = parse_time_block(restricted_times) + day, time_range = solution.split() + start_time, end_time = parse_time_block(time_range) + + if (start_time < restricted_end and end_time > restricted_start): + result = 0 + else: + result = 1 + return {'specific_times_programmatic_check': result} \ No newline at end of file diff --git a/eureka_ml_insights/metrics/reports.py b/eureka_ml_insights/metrics/reports.py index 1edb2d1..8c6bdc4 100644 --- a/eureka_ml_insights/metrics/reports.py +++ b/eureka_ml_insights/metrics/reports.py @@ -114,14 +114,43 @@ def _aggregate_grouped(self, data): class AverageAggregator(NumericalAggregator): def _aggregate(self, data): - averages = {col: data[col].mean().round(3) for col in self.column_names} + if len(data) == 0: + averages = {col: 0 for col in self.column_names} + else: + averages = {col: data[col].mean().round(3) for col in self.column_names} self.aggregated_result = averages def _aggregate_grouped(self, data): - gb = data.groupby(self.group_by) - averages = {col: round(gb[col].mean(), 3).to_dict() for col in self.column_names} + if len(data) == 0: + averages = {col: 0 for col in self.column_names} + else: + gb = data.groupby(self.group_by) + averages = {col: round(gb[col].mean(), 3).to_dict() for col in self.column_names} self.aggregated_result = averages +class NAFilteredAverageAggregator(AverageAggregator): + def __init__(self, column_name, output_dir, group_by=None, ignore_non_numeric=False, filename_base=None, **kwargs): + """ + args: + column_name: column name to filter and aggregate + output_dir: str. directory to save the report + group_by: str. or list of str. column(s) to group by before aggregating + ignore_non_numeric: bool. if True ignore non-numeric values for average aggregator + filename_base: str. optional base string to be used in the file name for the report. If not None, the report filename will concatenate the class name, datetime, and filename_base. + """ + + self.column_name = column_name + self.group_by = group_by + self.output_dir = output_dir + self.aggregated_result = None + self.ignore_non_numeric = ignore_non_numeric + self.filename_base = filename_base + super().__init__([column_name], output_dir, group_by, ignore_non_numeric, filename_base, **kwargs) + + def aggregate(self, data): + filtered_data = data[data[self.column_name] != "NA"].copy() + super().aggregate(filtered_data) + class AverageSTDDevAggregator(NumericalAggregator): diff --git a/eureka_ml_insights/prompt_templates/ba_calendar_templates/calendar_scheduling.jinja b/eureka_ml_insights/prompt_templates/ba_calendar_templates/calendar_scheduling.jinja new file mode 100644 index 0000000..69159e6 --- /dev/null +++ b/eureka_ml_insights/prompt_templates/ba_calendar_templates/calendar_scheduling.jinja @@ -0,0 +1,7 @@ +You are a scheduling assistant. Given the availability schedules of multiple participants and some additional constraints, your task is to find a common time slot. +Make sure you use the availability schedules to generate your response. +High priority meetings should be scheduled as early as possible. +Buffer time refers to the required remaining available time before and after a meeting. For example, if buffer time is 15 minutes, a meeting from 9:00-10:00 will require availability from 8:45-10:15. +Respond with "[day] [start_time]-[end_time]" or "No common time slot available" +Do not respond with any additional information or comments. +{{prompt}} \ No newline at end of file From 26f980bf2e6396932a88b97390217e12d2f253cf Mon Sep 17 00:00:00 2001 From: Vidhisha Balachandran Date: Sun, 1 Dec 2024 11:32:53 -0800 Subject: [PATCH 02/10] ba-calendar - removing unused imports, needs metric review --- eureka_ml_insights/configs/ba_calendar.py | 11 +++---- eureka_ml_insights/configs/model_configs.py | 13 +++++++-- .../metrics/ba_calendar_metrics.py | 29 +++++++++++++++++-- eureka_ml_insights/models/models.py | 2 +- tests/pipeline_tests.py | 13 +++++++++ 5 files changed, 55 insertions(+), 13 deletions(-) diff --git a/eureka_ml_insights/configs/ba_calendar.py b/eureka_ml_insights/configs/ba_calendar.py index 8cd3bba..858fdfc 100644 --- a/eureka_ml_insights/configs/ba_calendar.py +++ b/eureka_ml_insights/configs/ba_calendar.py @@ -15,15 +15,11 @@ from eureka_ml_insights.metrics.ba_calendar_metrics import BACalendarMetric from eureka_ml_insights.metrics.reports import ( AverageAggregator, - BiLevelAverageAggregator, NAFilteredAverageAggregator, - TwoColumnSumAverageAggregator, ) from .config import ( AggregatorConfig, - DataJoinConfig, - DataProcessingConfig, DataSetConfig, EvalReportingConfig, InferenceConfig, @@ -44,7 +40,7 @@ def configure_pipeline(self, model_config=None, resume_from=None, **kwargs) -> P data_reader_config=DataSetConfig( DataReader, { - "path": os.path.join("../local_benchmark_data/Natasha_benchmarks/datasets/datasets/", "ba_calendar.jsonl"), + "path": os.path.join("../local_benchmark_data/Natasha_benchmarks/datasets/datasets/", "ba_calendar_wkey.jsonl"), "transform": SequenceTransform([ ColumnRename(name_mapping={"task_prompt": "prompt"}), SamplerTransform(random_seed=5, sample_count=10), @@ -64,6 +60,7 @@ def configure_pipeline(self, model_config=None, resume_from=None, **kwargs) -> P ), output_dir=os.path.join(self.log_dir, "inference_result"), resume_from=resume_from, + # max_concurrent=4, ) # Configure the evaluation and reporting component for evaluation and dataset level aggregation @@ -74,7 +71,6 @@ def configure_pipeline(self, model_config=None, resume_from=None, **kwargs) -> P { "path": os.path.join(self.inference_comp.output_dir, "inference_result.jsonl"), "format": ".jsonl", - # "transform": ColumnRename(name_mapping={"model_output": "response"}), }, ), metric_config=MetricConfig(BACalendarMetric), @@ -84,8 +80,9 @@ def configure_pipeline(self, model_config=None, resume_from=None, **kwargs) -> P { "column_names": [ "BACalendarMetric_all_correct", + "BACalendarMetric_fraction_passed" ], - "filename_base": "BaCal_AllCorrect_Aggregated", + "filename_base": "BaCal_OverallMetrics_Aggregated", }, ), AggregatorConfig( diff --git a/eureka_ml_insights/configs/model_configs.py b/eureka_ml_insights/configs/model_configs.py index 23d9b9e..ab68beb 100644 --- a/eureka_ml_insights/configs/model_configs.py +++ b/eureka_ml_insights/configs/model_configs.py @@ -93,6 +93,14 @@ "key_vault_url": None, } +GEMINI_EXP_1114_PRO_CONFIG = ModelConfig( + GeminiModel, + { + "model_name": "gemini-exp-1114", + "secret_key_params": GEMINI_SECRET_KEY_PARAMS, + }, +) + GEMINI_V15_PRO_CONFIG = ModelConfig( GeminiModel, { @@ -111,8 +119,7 @@ # Claude models CLAUDE_SECRET_KEY_PARAMS = { - # "key_name": "your_claude_secret_key_name", - "key_name": "aif-eval-claude", + "key_name": "your_claude_secret_key_name", "local_keys_path": "keys/keys.json", "key_vault_url": None, } @@ -194,4 +201,4 @@ }, "model_name": "Mistral-large-2407", }, -) +) \ No newline at end of file diff --git a/eureka_ml_insights/metrics/ba_calendar_metrics.py b/eureka_ml_insights/metrics/ba_calendar_metrics.py index d690cf2..7ae441b 100644 --- a/eureka_ml_insights/metrics/ba_calendar_metrics.py +++ b/eureka_ml_insights/metrics/ba_calendar_metrics.py @@ -1,7 +1,11 @@ +import ast import json import re +import numpy as np from datetime import datetime, timedelta +import pandas as pd + from eureka_ml_insights.metrics.metrics_base import CompositeMetric # Helper functions @@ -56,6 +60,20 @@ def filter_slots_by_constraints(time_slots, constraints, day): filtered_slots.append(slot) return filtered_slots +# def convert_to_bool(value): +# if isinstance(value, str): +# if value == 'True': +# return True +# elif value == 'False': +# return False +# elif value == 'na': +# return np.nan +# elif isinstance(value, (bool, np.bool_)): +# return value +# elif pd.isna(value): # Safely handle NaNs +# return np.nan +# return value + # ask_true_false returns a tuple bool, str where the bool is the answer and the str is the justification class BACalendarMetric(CompositeMetric): @@ -78,8 +96,9 @@ def __evaluate__(self, row): def run_programmatic_tests(self, instance): result = {} solution = instance['model_output'] + solution = solution.strip('"').strip('`').strip('\n') if not is_formatted(solution): - result['format_programmatic'] = 1 + result['format_programmatic'] = 1 #should be 0 result.update(self.check_availability_programmatic(instance, solution)) result.update(self.check_meeting_duration_programmatic(instance, solution)) result.update(self.check_buffer_time_programmatic(instance, solution)) @@ -88,10 +107,16 @@ def run_programmatic_tests(self, instance): result.update(self.check_specific_times_programmatic(instance, solution)) # Added programmatic specific times check result.update(self.check_priority_programmatic(instance, solution)) # Added model-based priority check all_correct = 1 + passed_constraints = [] for key, value in result.items(): if value == 0: all_correct = 0 + # if value != 'NA': + x = value + if x != 'NA' and pd.notna(x) and isinstance(x, int): + passed_constraints.append(value) result['all_correct'] = all_correct + result['fraction_passed'] = np.mean(passed_constraints) return result def is_formatted(self, solution): @@ -234,7 +259,7 @@ def check_priority_programmatic(self, instance, solution): buffer_time = constraints['buffer_time_before_and_after_meeting'] else: buffer_time = 0 - for day in params['days_of_week']: + for day in params['days_of_week']: # update this post cleaning up data! common_time_slots = None availability = json.loads(metadata['availability'].replace("'", '"')) for participant, schedule in availability.items(): diff --git a/eureka_ml_insights/models/models.py b/eureka_ml_insights/models/models.py index f8e2607..333f175 100644 --- a/eureka_ml_insights/models/models.py +++ b/eureka_ml_insights/models/models.py @@ -230,7 +230,7 @@ def __post_init__(self): } except ValueError: self.bearer_token_provider = get_bearer_token_provider( - AzureCliCredential(), "https://cognitiveservices.azure.com/.default" + DefaultAzureCredential(), "https://cognitiveservices.azure.com/.default" ) headers = { "Content-Type": "application/json", diff --git a/tests/pipeline_tests.py b/tests/pipeline_tests.py index 8fc0fd3..bdfee19 100644 --- a/tests/pipeline_tests.py +++ b/tests/pipeline_tests.py @@ -247,6 +247,19 @@ def configure_pipeline(self): ] ) return config + +# class TEST_BA_Calendar_PIPELINE(IFEval_PIPELINE): +# # Test config the IFEval benchmark with TestModel and TestDataLoader +# def configure_pipeline(self): +# config = super().configure_pipeline(model_config=ModelConfig(GenericTestModel, {})) +# self.data_processing_comp.data_reader_config.init_args["transform"] = SequenceTransform( +# [ +# RunPythonTransform("df['instruction_id_list_copy'] = df.loc[:, 'instruction_id_list']"), +# RunPythonTransform("df = df.explode(['instruction_id_list_copy'])"), +# SamplerTransform(sample_count=N_ITER, random_seed=99, stratify_by="instruction_id_list_copy"), +# ] +# ) +# return config class TEST_TOXIGEN_PIPELINE(ToxiGen_Discriminative_PIPELINE): From 858bfa4b0c6d1779d948fc22421968a3ab400368 Mon Sep 17 00:00:00 2001 From: Vidhisha Balachandran Date: Sun, 8 Dec 2024 22:16:45 -0800 Subject: [PATCH 03/10] Added HFReader for dataset reading, pipeline tests --- eureka_ml_insights/configs/ba_calendar.py | 9 ++-- .../metrics/ba_calendar_metrics.py | 28 ++--------- tests/pipeline_tests.py | 50 ++++++++++++++----- 3 files changed, 48 insertions(+), 39 deletions(-) diff --git a/eureka_ml_insights/configs/ba_calendar.py b/eureka_ml_insights/configs/ba_calendar.py index 858fdfc..5527217 100644 --- a/eureka_ml_insights/configs/ba_calendar.py +++ b/eureka_ml_insights/configs/ba_calendar.py @@ -10,6 +10,7 @@ from eureka_ml_insights.data_utils.data import ( DataLoader, DataReader, + HFDataReader, ) from eureka_ml_insights.data_utils.transform import ColumnRename, SamplerTransform, SequenceTransform from eureka_ml_insights.metrics.ba_calendar_metrics import BACalendarMetric @@ -38,12 +39,12 @@ def configure_pipeline(self, model_config=None, resume_from=None, **kwargs) -> P component_type=PromptProcessing, prompt_template_path=os.path.join(os.path.dirname(__file__), "../prompt_templates/ba_calendar_templates/calendar_scheduling.jinja"), data_reader_config=DataSetConfig( - DataReader, - { - "path": os.path.join("../local_benchmark_data/Natasha_benchmarks/datasets/datasets/", "ba_calendar_wkey.jsonl"), + HFDataReader, + { + "path": "microsoft/ba-calendar", + "split": "test", "transform": SequenceTransform([ ColumnRename(name_mapping={"task_prompt": "prompt"}), - SamplerTransform(random_seed=5, sample_count=10), ]), }, ), diff --git a/eureka_ml_insights/metrics/ba_calendar_metrics.py b/eureka_ml_insights/metrics/ba_calendar_metrics.py index 7ae441b..4012f30 100644 --- a/eureka_ml_insights/metrics/ba_calendar_metrics.py +++ b/eureka_ml_insights/metrics/ba_calendar_metrics.py @@ -60,28 +60,11 @@ def filter_slots_by_constraints(time_slots, constraints, day): filtered_slots.append(slot) return filtered_slots -# def convert_to_bool(value): -# if isinstance(value, str): -# if value == 'True': -# return True -# elif value == 'False': -# return False -# elif value == 'na': -# return np.nan -# elif isinstance(value, (bool, np.bool_)): -# return value -# elif pd.isna(value): # Safely handle NaNs -# return np.nan -# return value - -# ask_true_false returns a tuple bool, str where the bool is the answer and the str is the justification - class BACalendarMetric(CompositeMetric): """ - Composite metric for evaluating if a response follows instructions. + Composite metric for evaluating if a response for each criteria. - This metric evaluates if a given response follows the provided instructions. - It calculates both strict and loose evaluation scores based on the response's adherence to the instructions. + This metric evaluates if a given response follows the provided constraints. """ def __init__(self): @@ -98,20 +81,19 @@ def run_programmatic_tests(self, instance): solution = instance['model_output'] solution = solution.strip('"').strip('`').strip('\n') if not is_formatted(solution): - result['format_programmatic'] = 1 #should be 0 + result['format_programmatic'] = 1 result.update(self.check_availability_programmatic(instance, solution)) result.update(self.check_meeting_duration_programmatic(instance, solution)) result.update(self.check_buffer_time_programmatic(instance, solution)) result.update(self.check_no_weekends_programmatic(instance, solution)) result.update(self.check_time_restrictions_programmatic(instance, solution)) - result.update(self.check_specific_times_programmatic(instance, solution)) # Added programmatic specific times check - result.update(self.check_priority_programmatic(instance, solution)) # Added model-based priority check + result.update(self.check_specific_times_programmatic(instance, solution)) + result.update(self.check_priority_programmatic(instance, solution)) all_correct = 1 passed_constraints = [] for key, value in result.items(): if value == 0: all_correct = 0 - # if value != 'NA': x = value if x != 'NA' and pd.notna(x) and isinstance(x, int): passed_constraints.append(value) diff --git a/tests/pipeline_tests.py b/tests/pipeline_tests.py index bdfee19..c7988a1 100644 --- a/tests/pipeline_tests.py +++ b/tests/pipeline_tests.py @@ -6,6 +6,8 @@ import jsonlines +from eureka_ml_insights.configs.ba_calendar import Calendar_Schedule_PIPELINE + # setup loggers logging.basicConfig(level=logging.INFO, format="%(asctime)s - %(levelname)s - %(message)s") @@ -248,18 +250,19 @@ def configure_pipeline(self): ) return config -# class TEST_BA_Calendar_PIPELINE(IFEval_PIPELINE): -# # Test config the IFEval benchmark with TestModel and TestDataLoader -# def configure_pipeline(self): -# config = super().configure_pipeline(model_config=ModelConfig(GenericTestModel, {})) -# self.data_processing_comp.data_reader_config.init_args["transform"] = SequenceTransform( -# [ -# RunPythonTransform("df['instruction_id_list_copy'] = df.loc[:, 'instruction_id_list']"), -# RunPythonTransform("df = df.explode(['instruction_id_list_copy'])"), -# SamplerTransform(sample_count=N_ITER, random_seed=99, stratify_by="instruction_id_list_copy"), -# ] -# ) -# return config +class TEST_BA_Calendar_PIPELINE(Calendar_Schedule_PIPELINE): + # Test config the BA Calendar benchmark with TestModel and TestDataLoader + def configure_pipeline(self): + config = super().configure_pipeline(model_config=ModelConfig(GenericTestModel, {})) + self.data_processing_comp.data_reader_config.init_args["transform"].transforms.extend( + [ + RunPythonTransform("df = df.explode(['selected_constraints'])"), + SamplerTransform(sample_count=N_ITER, random_seed=99, stratify_by="selected_constraints"), + ] + ) + self.inference_comp.data_loader_config.class_name = TestDataLoader + self.inference_comp.data_loader_config.init_args["n_iter"] = N_ITER + return config class TEST_TOXIGEN_PIPELINE(ToxiGen_Discriminative_PIPELINE): @@ -433,6 +436,29 @@ def test_outputs_exist(self) -> None: n_aggregator_files = len([file for file in self.files if "aggregator" in str(file)]) self.assertEqual(n_aggregators, n_aggregator_files) +@unittest.skipIf("skip_tests_with_missing_ds" in os.environ, "Missing public dataset. TODO: revert") +class BA_Calendar_PipelineTest(PipelineTest, unittest.TestCase): + def get_config(self): + self.test_pipeline = TEST_BA_Calendar_PIPELINE() + self.config = self.test_pipeline.pipeline_config + return self.config + + def setUp(self) -> None: + super().setUp() + self.eval_configs = [self.test_pipeline.evalreporting_comp] + + def test_outputs_exist(self) -> None: + logging.info("Running test_outputs_exist test in PipelineTest") + self.assertTrue(any("transformed_data.jsonl" in str(file) for file in self.files)) + if self.data_reader_config.prompt_template_path: + self.assertTrue(any("processed_prompts.jsonl" in str(file) for file in self.files)) + self.assertTrue(any("inference_result.jsonl" in str(file) for file in self.files)) + if self.eval_config.metric_config is not None: + self.assertTrue(any("metric_results.jsonl" in str(file) for file in self.files)) + n_aggregators = len([config for eval_config in self.eval_configs for config in eval_config.aggregator_configs]) + n_aggregator_files = len([file for file in self.files if "aggregator" in str(file)]) + self.assertEqual(n_aggregators, n_aggregator_files) + class TOXIGEN_PipelineTest(PipelineTest, unittest.TestCase): def get_config(self): From 1595e98fa36c45e61acc7565ed61da73f19f2d63 Mon Sep 17 00:00:00 2001 From: Vidhisha Balachandran Date: Sun, 8 Dec 2024 22:33:39 -0800 Subject: [PATCH 04/10] Added comment on ba_calendar_metrics --- eureka_ml_insights/metrics/ba_calendar_metrics.py | 4 ++++ 1 file changed, 4 insertions(+) diff --git a/eureka_ml_insights/metrics/ba_calendar_metrics.py b/eureka_ml_insights/metrics/ba_calendar_metrics.py index 4012f30..0bfd88c 100644 --- a/eureka_ml_insights/metrics/ba_calendar_metrics.py +++ b/eureka_ml_insights/metrics/ba_calendar_metrics.py @@ -1,3 +1,7 @@ +# This file was authored by BenchAgents authors and is being reused under the MIT license. +# All code in this file is directly copied from the original source repository. +# https://github.com/microsoft/benchagents + import ast import json import re From 76567ea8b5e8c71df0baec384291caa91cc52402 Mon Sep 17 00:00:00 2001 From: Vidhisha Balachandran Date: Tue, 7 Jan 2025 01:05:19 +0530 Subject: [PATCH 05/10] addressed import issue --- tests/pipeline_tests.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/tests/pipeline_tests.py b/tests/pipeline_tests.py index 0a549f4..ebea985 100644 --- a/tests/pipeline_tests.py +++ b/tests/pipeline_tests.py @@ -6,7 +6,6 @@ import jsonlines -from eureka_ml_insights.user_configs.ba_calendar import Calendar_Schedule_PIPELINE # setup loggers logging.basicConfig(level=logging.INFO, format="%(asctime)s - %(levelname)s - %(message)s") @@ -42,6 +41,7 @@ IFEval_PIPELINE, ToxiGen_Discriminative_PIPELINE, ToxiGen_Generative_PIPELINE, + Calendar_Schedule_PIPELINE, ) from tests.test_utils import ( DetectionTestModel, From f333868ad9528c11c6741b50ed5da1c5fc83ad34 Mon Sep 17 00:00:00 2001 From: Vidhisha Balachandran Date: Tue, 7 Jan 2025 01:15:47 +0530 Subject: [PATCH 06/10] added ba_cal config --- .../calendar_scheduling_cot.jinja | 6 + .../user_configs/ba_calendar.py | 150 ++++++++++++++++++ 2 files changed, 156 insertions(+) create mode 100644 eureka_ml_insights/prompt_templates/ba_calendar_templates/calendar_scheduling_cot.jinja create mode 100644 eureka_ml_insights/user_configs/ba_calendar.py diff --git a/eureka_ml_insights/prompt_templates/ba_calendar_templates/calendar_scheduling_cot.jinja b/eureka_ml_insights/prompt_templates/ba_calendar_templates/calendar_scheduling_cot.jinja new file mode 100644 index 0000000..f83844f --- /dev/null +++ b/eureka_ml_insights/prompt_templates/ba_calendar_templates/calendar_scheduling_cot.jinja @@ -0,0 +1,6 @@ +You are a scheduling assistant. Given the availability schedules of multiple participants and some additional constraints, your task is to find a common time slot. +Make sure you use the availability schedules to generate your response. +High priority meetings should be scheduled as early as possible. +Buffer time refers to the required remaining available time before and after a meeting. For example, if buffer time is 15 minutes, a meeting from 9:00-10:00 will require availability from 8:45-10:15. +Respond with "[day] [start_time]-[end_time]" or "No common time slot available" +{{prompt}} \ No newline at end of file diff --git a/eureka_ml_insights/user_configs/ba_calendar.py b/eureka_ml_insights/user_configs/ba_calendar.py new file mode 100644 index 0000000..654e55a --- /dev/null +++ b/eureka_ml_insights/user_configs/ba_calendar.py @@ -0,0 +1,150 @@ +import os +from tkinter import N + +from eureka_ml_insights.core import ( + Inference, + PromptProcessing, +) + +from eureka_ml_insights.core.eval_reporting import EvalReporting +from eureka_ml_insights.data_utils.data import ( + DataLoader, + DataReader, + HFDataReader, +) +from eureka_ml_insights.data_utils.transform import ColumnRename, SamplerTransform, SequenceTransform +from eureka_ml_insights.metrics.ba_calendar_metrics import BACalendarMetric +from eureka_ml_insights.metrics.reports import ( + AverageAggregator, + NAFilteredAverageAggregator, +) + +from ..configs.config import ( + AggregatorConfig, + DataSetConfig, + EvalReportingConfig, + InferenceConfig, + MetricConfig, + PipelineConfig, + PromptProcessingConfig, +) +from ..configs.experiment_config import ExperimentConfig + +class Calendar_Schedule_PIPELINE(ExperimentConfig): + """This class specifies the config for running any benchmark on any model""" + + def configure_pipeline(self, model_config=None, resume_from=None, **kwargs) -> PipelineConfig: + # data preprocessing + self.data_processing_comp = PromptProcessingConfig( + component_type=PromptProcessing, + prompt_template_path=os.path.join(os.path.dirname(__file__), "../prompt_templates/ba_calendar_templates/calendar_scheduling.jinja"), + data_reader_config=DataSetConfig( + HFDataReader, + { + "path": "microsoft/ba-calendar", + "split": "test", + "transform": SequenceTransform([ + ColumnRename(name_mapping={"task_prompt": "prompt"}), + ]), + }, + ), + output_dir=os.path.join(self.log_dir, "data_processing_output"), + ) + + # inference component + self.inference_comp = InferenceConfig( + component_type=Inference, + model_config=model_config, + data_loader_config=DataSetConfig( + DataLoader, + {"path": os.path.join(self.data_processing_comp.output_dir, "transformed_data.jsonl")}, + ), + output_dir=os.path.join(self.log_dir, "inference_result"), + resume_from=resume_from, + # max_concurrent=4, + ) + + # Configure the evaluation and reporting component for evaluation and dataset level aggregation + self.evalreporting_comp = EvalReportingConfig( + component_type=EvalReporting, + data_reader_config=DataSetConfig( + DataReader, + { + "path": os.path.join(self.inference_comp.output_dir, "inference_result.jsonl"), + "format": ".jsonl", + }, + ), + metric_config=MetricConfig(BACalendarMetric), + aggregator_configs=[ + AggregatorConfig( + AverageAggregator, + { + "column_names": [ + "BACalendarMetric_all_correct", + "BACalendarMetric_fraction_passed" + ], + "filename_base": "BaCal_OverallMetrics_Aggregated", + }, + ), + AggregatorConfig( + NAFilteredAverageAggregator, + { + "column_name": "BACalendarMetric_availability_programmatic_check", + "filename_base": "BaCal_Availability_Check_Aggregated", + }, + ), + AggregatorConfig( + NAFilteredAverageAggregator, + { + "column_name": "BACalendarMetric_meeting_duration_programmatic_check", + "filename_base": "BaCal_MeetingDuration_Check_Aggregated", + }, + ), + AggregatorConfig( + NAFilteredAverageAggregator, + { + "column_name": "BACalendarMetric_buffer_time_programmatic_check", + "filename_base": "BaCal_BufferTime_Check_Aggregated", + }, + ), + AggregatorConfig( + NAFilteredAverageAggregator, + { + "column_name": "BACalendarMetric_no_weekends_programmatic_check", + "filename_base": "BaCal_NoWeekends_Check_Aggregated", + }, + ), + AggregatorConfig( + NAFilteredAverageAggregator, + { + "column_name": "BACalendarMetric_time_restrictions_programmatic_check", + "filename_base": "BaCal_TimeRestrictions_Check_Aggregated", + }, + ), + AggregatorConfig( + NAFilteredAverageAggregator, + { + "column_name": "BACalendarMetric_specific_times_programmatic_check", + "filename_base": "BaCal_SpecificTimes_Check_Aggregated", + }, + ), + AggregatorConfig( + NAFilteredAverageAggregator, + { + "column_name": "BACalendarMetric_priority_programmatic_check", + "filename_base": "BaCal_Priority_Check_Aggregated", + }, + ), + ], + output_dir=os.path.join(self.log_dir, "eval_report"), + ) + + # Configure the pipeline + return PipelineConfig( + [ + self.data_processing_comp, + self.inference_comp, + self.evalreporting_comp + ], + self.log_dir, + ) \ No newline at end of file From 36c8a874108f2243629c203e73d9d2333fc4211e Mon Sep 17 00:00:00 2001 From: Vidhisha Balachandran Date: Tue, 14 Jan 2025 18:53:06 -0800 Subject: [PATCH 07/10] PR Comments, Majority Vote + Best of N Evals --- eureka_ml_insights/configs/model_configs.py | 8 - .../metrics/ba_calendar_metrics.py | 11 +- eureka_ml_insights/metrics/reports.py | 108 +++++++--- .../calendar_scheduling.jinja | 7 - .../calendar_scheduling_cot.jinja | 6 - eureka_ml_insights/user_configs/__init__.py | 8 +- .../user_configs/ba_calendar.py | 188 ++++++++++++++---- tests/metric_utils_tests/aggregator_tests.py | 124 ++++++++++++ tests/pipeline_tests.py | 4 +- 9 files changed, 365 insertions(+), 99 deletions(-) delete mode 100644 eureka_ml_insights/prompt_templates/ba_calendar_templates/calendar_scheduling.jinja delete mode 100644 eureka_ml_insights/prompt_templates/ba_calendar_templates/calendar_scheduling_cot.jinja diff --git a/eureka_ml_insights/configs/model_configs.py b/eureka_ml_insights/configs/model_configs.py index f8bc131..779bf05 100644 --- a/eureka_ml_insights/configs/model_configs.py +++ b/eureka_ml_insights/configs/model_configs.py @@ -98,14 +98,6 @@ "key_vault_url": None, } -GEMINI_EXP_1114_PRO_CONFIG = ModelConfig( - GeminiModel, - { - "model_name": "gemini-exp-1114", - "secret_key_params": GEMINI_SECRET_KEY_PARAMS, - }, -) - GEMINI_V15_PRO_CONFIG = ModelConfig( GeminiModel, { diff --git a/eureka_ml_insights/metrics/ba_calendar_metrics.py b/eureka_ml_insights/metrics/ba_calendar_metrics.py index 0bfd88c..cfe1176 100644 --- a/eureka_ml_insights/metrics/ba_calendar_metrics.py +++ b/eureka_ml_insights/metrics/ba_calendar_metrics.py @@ -13,7 +13,7 @@ from eureka_ml_insights.metrics.metrics_base import CompositeMetric # Helper functions -def is_formatted(solution): +def check_time_slot_format(solution): pattern = r"^(Monday|Tuesday|Wednesday|Thursday|Friday|Saturday|Sunday) ([0-9]|[01]\d|2[0-3]):[0-5]\d-([0-9]|[01]\d|2[0-3]):[0-5]\d$" return bool(re.match(pattern, solution)) @@ -84,7 +84,7 @@ def run_programmatic_tests(self, instance): result = {} solution = instance['model_output'] solution = solution.strip('"').strip('`').strip('\n') - if not is_formatted(solution): + if check_time_slot_format(solution): result['format_programmatic'] = 1 result.update(self.check_availability_programmatic(instance, solution)) result.update(self.check_meeting_duration_programmatic(instance, solution)) @@ -98,8 +98,7 @@ def run_programmatic_tests(self, instance): for key, value in result.items(): if value == 0: all_correct = 0 - x = value - if x != 'NA' and pd.notna(x) and isinstance(x, int): + if value != 'NA' and pd.notna(value) and isinstance(value, int): passed_constraints.append(value) result['all_correct'] = all_correct result['fraction_passed'] = np.mean(passed_constraints) @@ -109,7 +108,7 @@ def is_formatted(self, solution): run_tests=True if solution == self.no_solution_response: run_tests=False - if not is_formatted(solution): + if not check_time_slot_format(solution): run_tests=False return run_tests @@ -245,7 +244,7 @@ def check_priority_programmatic(self, instance, solution): buffer_time = constraints['buffer_time_before_and_after_meeting'] else: buffer_time = 0 - for day in params['days_of_week']: # update this post cleaning up data! + for day in params['days_of_week']: # TODO: revisit this post data release to ensure consistency common_time_slots = None availability = json.loads(metadata['availability'].replace("'", '"')) for participant, schedule in availability.items(): diff --git a/eureka_ml_insights/metrics/reports.py b/eureka_ml_insights/metrics/reports.py index 8c6bdc4..8e78466 100644 --- a/eureka_ml_insights/metrics/reports.py +++ b/eureka_ml_insights/metrics/reports.py @@ -110,6 +110,19 @@ def _aggregate_grouped(self, data): sums = {col: gb[col].sum().to_dict() for col in self.column_names} self.aggregated_result = sums +class MaxAggregator(NumericalAggregator): + """ + This class aggregates data by taking the max of the values.""" + + def _aggregate(self, data): + sums = {col: data[col].max() for col in self.column_names} + self.aggregated_result = sums + + def _aggregate_grouped(self, data): + gb = data.groupby(self.group_by) + sums = {col: gb[col].max().to_dict() for col in self.column_names} + self.aggregated_result = sums + class AverageAggregator(NumericalAggregator): @@ -128,30 +141,6 @@ def _aggregate_grouped(self, data): averages = {col: round(gb[col].mean(), 3).to_dict() for col in self.column_names} self.aggregated_result = averages -class NAFilteredAverageAggregator(AverageAggregator): - def __init__(self, column_name, output_dir, group_by=None, ignore_non_numeric=False, filename_base=None, **kwargs): - """ - args: - column_name: column name to filter and aggregate - output_dir: str. directory to save the report - group_by: str. or list of str. column(s) to group by before aggregating - ignore_non_numeric: bool. if True ignore non-numeric values for average aggregator - filename_base: str. optional base string to be used in the file name for the report. If not None, the report filename will concatenate the class name, datetime, and filename_base. - """ - - self.column_name = column_name - self.group_by = group_by - self.output_dir = output_dir - self.aggregated_result = None - self.ignore_non_numeric = ignore_non_numeric - self.filename_base = filename_base - super().__init__([column_name], output_dir, group_by, ignore_non_numeric, filename_base, **kwargs) - - def aggregate(self, data): - filtered_data = data[data[self.column_name] != "NA"].copy() - super().aggregate(filtered_data) - - class AverageSTDDevAggregator(NumericalAggregator): def _aggregate(self, data): @@ -244,6 +233,45 @@ def _aggregate(self, data): col_std = first_result[col].std() self.aggregated_result.append({col: {"mean": col_mean, "std": col_std}}) +class BiLevelMaxAggregator(Aggregator): + """ + This class aggregates the data in two levels. It first groups the data by the first_groupby column and + aggregates the data by taking the max of the column_names. It It then groups the result by the + second_groupby column and aggregates the it again by taking the mean and standard deviation of + the column_names. + """ + + def __init__(self, column_names, first_groupby, output_dir, second_groupby=None, **kwargs): + super().__init__(column_names, output_dir, group_by=None, **kwargs) + self.first_groupby = first_groupby + self.second_groupby = second_groupby + + def _aggregate(self, data): + # take the average of the column for each group in the first groupby, + # aggregate the rest of the columns by 'first' + gb = data.groupby(self.first_groupby) + agg_map = {col: "max" for col in self.column_names} + agg_map.update( + {col: "first" for col in data.columns if col not in self.column_names and col != self.first_groupby} + ) + + first_result = gb.aggregate(agg_map).reset_index() + if self.second_groupby: + # take the average and std of the first level aggregation for each group in the second groupby + gb = first_result.groupby(self.second_groupby) + agg_map = {col: ["mean", "std"] for col in self.column_names} + # flatten the multi-level column index + second_result = gb.agg(agg_map).reset_index() + second_result.columns = [f"{col}_{agg}" if agg else col for col, agg in second_result.columns] + self.aggregated_result = second_result.to_dict(orient="records") + else: + # take the average and std of the first level aggregation + self.aggregated_result = [] + for col in self.column_names: + col_mean = first_result[col].mean() + col_std = first_result[col].std() + self.aggregated_result.append({col: {"mean": col_mean, "std": col_std}}) + class BiLevelCountAggregator(Aggregator): """ @@ -320,6 +348,38 @@ def _aggregate_grouped(self, data): divided_result = (gb[self.numerator_column_name].sum() / gb[self.denominator_column_name].sum()).to_dict() self.aggregated_result = {"ratio": divided_result} +class NAFilteredAggregator(Aggregator): + def __init__(self, agg_class, column_names, output_dir, group_by=None, ignore_non_numeric=False, filename_base=None, **kwargs): + """ + Aggregator that filters out "NA" values before aggregating the data. + args: + agg_class: Aggregator class to use for aggregation + column_names: column names to filter and aggregate + output_dir: str. directory to save the report + group_by: str. or list of str. column(s) to group by before aggregating + ignore_non_numeric: bool. if True ignore non-numeric values for average aggregator + filename_base: str. optional base string to be used in the file name for the report. If not None, the report filename will concatenate the class name, datetime, and filename_base. + """ + + self.base_aggregator = agg_class(column_names, output_dir, group_by, ignore_non_numeric, filename_base, **kwargs) + self.column_names = column_names + self.group_by = group_by + self.output_dir = output_dir + self.aggregated_result = None + self.ignore_non_numeric = ignore_non_numeric + self.filename_base = filename_base + # super().__init__(self.input_column_names, output_dir, group_by, ignore_non_numeric, filename_base, **kwargs) + + def aggregate(self, data): + agg_results = {} + for col in self.column_names: + # workaround to process one column at a time + filtered_data = data[data[col] != "NA"].copy() + self.base_aggregator.column_names = [col] + self.base_aggregator.aggregate(filtered_data) + agg_results.update(self.base_aggregator.aggregated_result) + self.aggregated_result = agg_results + class CocoDetectionAggregator(Aggregator): """This class uses the coco tools to calculated AP50 for the provided detections.""" diff --git a/eureka_ml_insights/prompt_templates/ba_calendar_templates/calendar_scheduling.jinja b/eureka_ml_insights/prompt_templates/ba_calendar_templates/calendar_scheduling.jinja deleted file mode 100644 index 69159e6..0000000 --- a/eureka_ml_insights/prompt_templates/ba_calendar_templates/calendar_scheduling.jinja +++ /dev/null @@ -1,7 +0,0 @@ -You are a scheduling assistant. Given the availability schedules of multiple participants and some additional constraints, your task is to find a common time slot. -Make sure you use the availability schedules to generate your response. -High priority meetings should be scheduled as early as possible. -Buffer time refers to the required remaining available time before and after a meeting. For example, if buffer time is 15 minutes, a meeting from 9:00-10:00 will require availability from 8:45-10:15. -Respond with "[day] [start_time]-[end_time]" or "No common time slot available" -Do not respond with any additional information or comments. -{{prompt}} \ No newline at end of file diff --git a/eureka_ml_insights/prompt_templates/ba_calendar_templates/calendar_scheduling_cot.jinja b/eureka_ml_insights/prompt_templates/ba_calendar_templates/calendar_scheduling_cot.jinja deleted file mode 100644 index f83844f..0000000 --- a/eureka_ml_insights/prompt_templates/ba_calendar_templates/calendar_scheduling_cot.jinja +++ /dev/null @@ -1,6 +0,0 @@ -You are a scheduling assistant. Given the availability schedules of multiple participants and some additional constraints, your task is to find a common time slot. -Make sure you use the availability schedules to generate your response. -High priority meetings should be scheduled as early as possible. -Buffer time refers to the required remaining available time before and after a meeting. For example, if buffer time is 15 minutes, a meeting from 9:00-10:00 will require availability from 8:45-10:15. -Respond with "[day] [start_time]-[end_time]" or "No common time slot available" -{{prompt}} \ No newline at end of file diff --git a/eureka_ml_insights/user_configs/__init__.py b/eureka_ml_insights/user_configs/__init__.py index 133441b..976a5ea 100644 --- a/eureka_ml_insights/user_configs/__init__.py +++ b/eureka_ml_insights/user_configs/__init__.py @@ -9,7 +9,10 @@ AIME_PIPELINE512Run, AIME_PIPELINE1024Run, ) -from .ba_calendar import Calendar_Schedule_PIPELINE +from .ba_calendar import ( + BA_Calendar_PIPELINE, + BA_Calendar_Parallel_PIPELINE, +) from .dna import DNA_PIPELINE from .drop import Drop_Experiment_Pipeline from .flenqa import FlenQA_Experiment_Pipeline @@ -114,7 +117,8 @@ KITAB_TWO_BOOK_CONSTRAINT_PIPELINE_WITH_CONTEXT, GPT35_KITAB_ONE_BOOK_CONSTRAINT_PIPELINE, DNA_PIPELINE, - Calendar_Schedule_PIPELINE, + BA_Calendar_PIPELINE, + BA_Calendar_Parallel_PIPELINE, ToxiGen_Discriminative_PIPELINE, ToxiGen_Generative_PIPELINE, Geo_Nondeterminism, diff --git a/eureka_ml_insights/user_configs/ba_calendar.py b/eureka_ml_insights/user_configs/ba_calendar.py index 654e55a..2b08ddd 100644 --- a/eureka_ml_insights/user_configs/ba_calendar.py +++ b/eureka_ml_insights/user_configs/ba_calendar.py @@ -1,51 +1,57 @@ import os -from tkinter import N +from typing import Any from eureka_ml_insights.core import ( Inference, PromptProcessing, ) +from eureka_ml_insights.core.data_processing import DataProcessing from eureka_ml_insights.core.eval_reporting import EvalReporting from eureka_ml_insights.data_utils.data import ( DataLoader, DataReader, HFDataReader, ) -from eureka_ml_insights.data_utils.transform import ColumnRename, SamplerTransform, SequenceTransform +from eureka_ml_insights.data_utils.transform import AddColumn, ColumnRename, CopyColumn, MajorityVoteTransform, MultiplyTransform, RunPythonTransform, SamplerTransform, SequenceTransform from eureka_ml_insights.metrics.ba_calendar_metrics import BACalendarMetric from eureka_ml_insights.metrics.reports import ( AverageAggregator, - NAFilteredAverageAggregator, + BiLevelMaxAggregator, + MaxAggregator, + NAFilteredAggregator, ) from ..configs.config import ( AggregatorConfig, + DataProcessingConfig, DataSetConfig, EvalReportingConfig, InferenceConfig, MetricConfig, PipelineConfig, PromptProcessingConfig, + ModelConfig, ) from ..configs.experiment_config import ExperimentConfig -class Calendar_Schedule_PIPELINE(ExperimentConfig): +class BA_Calendar_PIPELINE(ExperimentConfig): """This class specifies the config for running any benchmark on any model""" def configure_pipeline(self, model_config=None, resume_from=None, **kwargs) -> PipelineConfig: # data preprocessing self.data_processing_comp = PromptProcessingConfig( component_type=PromptProcessing, - prompt_template_path=os.path.join(os.path.dirname(__file__), "../prompt_templates/ba_calendar_templates/calendar_scheduling.jinja"), + prompt_template_path=os.path.join(os.path.dirname(__file__), "../prompt_templates/ba_calendar_templates/calendar_scheduling_brief.jinja"), data_reader_config=DataSetConfig( HFDataReader, { - "path": "microsoft/ba-calendar", - "split": "test", - "transform": SequenceTransform([ - ColumnRename(name_mapping={"task_prompt": "prompt"}), - ]), + "path": "microsoft/ba-calendar", + "split": "test", + "transform": SequenceTransform([ + ColumnRename(name_mapping={"task_prompt": "prompt"}), + MultiplyTransform(n_repeats=1), + ]), }, ), output_dir=os.path.join(self.log_dir, "data_processing_output"), @@ -61,7 +67,7 @@ def configure_pipeline(self, model_config=None, resume_from=None, **kwargs) -> P ), output_dir=os.path.join(self.log_dir, "inference_result"), resume_from=resume_from, - # max_concurrent=4, + max_concurrent=1, ) # Configure the evaluation and reporting component for evaluation and dataset level aggregation @@ -83,60 +89,140 @@ def configure_pipeline(self, model_config=None, resume_from=None, **kwargs) -> P "BACalendarMetric_all_correct", "BACalendarMetric_fraction_passed" ], - "filename_base": "BaCal_OverallMetrics_Aggregated", + "filename_base": "BaCal_OverallMetrics_SeparateRuns", + "group_by": "data_repeat_id", }, ), AggregatorConfig( - NAFilteredAverageAggregator, + NAFilteredAggregator, { - "column_name": "BACalendarMetric_availability_programmatic_check", - "filename_base": "BaCal_Availability_Check_Aggregated", - }, - ), - AggregatorConfig( - NAFilteredAverageAggregator, - { - "column_name": "BACalendarMetric_meeting_duration_programmatic_check", - "filename_base": "BaCal_MeetingDuration_Check_Aggregated", - }, - ), - AggregatorConfig( - NAFilteredAverageAggregator, - { - "column_name": "BACalendarMetric_buffer_time_programmatic_check", - "filename_base": "BaCal_BufferTime_Check_Aggregated", + "agg_class": AverageAggregator, + "column_names": [ + "BACalendarMetric_availability_programmatic_check", + "BACalendarMetric_meeting_duration_programmatic_check", + "BACalendarMetric_buffer_time_programmatic_check", + "BACalendarMetric_no_weekends_programmatic_check", + "BACalendarMetric_time_restrictions_programmatic_check", + "BACalendarMetric_specific_times_programmatic_check", + "BACalendarMetric_priority_programmatic_check" + ], + "filename_base": "BaCal_Constraint_Level_SeprateRuns", + "group_by": "data_repeat_id", }, ), AggregatorConfig( - NAFilteredAverageAggregator, + BiLevelMaxAggregator, { - "column_name": "BACalendarMetric_no_weekends_programmatic_check", - "filename_base": "BaCal_NoWeekends_Check_Aggregated", + "column_names": [ + "BACalendarMetric_all_correct", + "BACalendarMetric_fraction_passed" + ], + "first_groupby": "data_point_id", + "filename_base": "BaCal_BestOfN_Aggregated", + "normalize": True, }, ), AggregatorConfig( - NAFilteredAverageAggregator, + NAFilteredAggregator, { - "column_name": "BACalendarMetric_time_restrictions_programmatic_check", - "filename_base": "BaCal_TimeRestrictions_Check_Aggregated", + "agg_class": MaxAggregator, + "column_names": [ + "BACalendarMetric_availability_programmatic_check", + "BACalendarMetric_meeting_duration_programmatic_check", + "BACalendarMetric_buffer_time_programmatic_check", + "BACalendarMetric_no_weekends_programmatic_check", + "BACalendarMetric_time_restrictions_programmatic_check", + "BACalendarMetric_specific_times_programmatic_check", + "BACalendarMetric_priority_programmatic_check" + ], + "filename_base": "BaCal_Constraint_Level_BestOfN_Aggregated", + "group_by": "data_repeat_id", }, ), + + + ], + output_dir=os.path.join(self.log_dir, "eval_report"), + ) + + # Aggregate the results by a majority vote + # First, let us perform majority_vote + self.data_post_processing_addmv = DataProcessingConfig( + component_type=DataProcessing, + data_reader_config=DataSetConfig( + DataReader, + { + "path": os.path.join(self.inference_comp.output_dir, "inference_result.jsonl"), + "format": ".jsonl", + "transform": SequenceTransform( + [ + ColumnRename( + name_mapping={ + "model_output": "raw_output", + } + ), + AddColumn("model_output"), + MajorityVoteTransform(model_output_col="raw_output"), + CopyColumn("majority_vote", "model_output"), + ColumnRename( + name_mapping={ + "raw_output": "model_output_onerun", + "majority_vote": "model_output", + } + ), + ] + ), + }, + ), + output_dir=os.path.join(self.log_dir, "data_majvote_output"), + ) + # Second, compute eaxct match + self.postevalprocess_comp = EvalReportingConfig( + component_type=EvalReporting, + data_reader_config=DataSetConfig( + DataReader, + { + "path": os.path.join(self.data_post_processing_addmv.output_dir, "transformed_data.jsonl"), + "format": ".jsonl", + "transform": SequenceTransform( + [ + RunPythonTransform("df = df[df['data_repeat_id'] == 'repeat_0']"), + ] + ), + }, + ), + metric_config=MetricConfig(BACalendarMetric), + aggregator_configs=[ AggregatorConfig( - NAFilteredAverageAggregator, + AverageAggregator, { - "column_name": "BACalendarMetric_specific_times_programmatic_check", - "filename_base": "BaCal_SpecificTimes_Check_Aggregated", + "column_names": [ + "BACalendarMetric_all_correct", + "BACalendarMetric_fraction_passed" + ], + "filename_base": "BaCal_MajVote_OverallMetrics_Aggregated", + "group_by": "data_repeat_id", }, ), AggregatorConfig( - NAFilteredAverageAggregator, + NAFilteredAggregator, { - "column_name": "BACalendarMetric_priority_programmatic_check", - "filename_base": "BaCal_Priority_Check_Aggregated", + "agg_class": AverageAggregator, + "column_names": [ + "BACalendarMetric_availability_programmatic_check", + "BACalendarMetric_meeting_duration_programmatic_check", + "BACalendarMetric_buffer_time_programmatic_check", + "BACalendarMetric_no_weekends_programmatic_check", + "BACalendarMetric_time_restrictions_programmatic_check", + "BACalendarMetric_specific_times_programmatic_check", + "BACalendarMetric_priority_programmatic_check" + ], + "filename_base": "BaCal_MajVote_Constraint_Level_Aggregated", + "group_by": "data_repeat_id", }, ), ], - output_dir=os.path.join(self.log_dir, "eval_report"), + output_dir=os.path.join(self.log_dir, "majvote_eval_report"), ) # Configure the pipeline @@ -144,7 +230,21 @@ def configure_pipeline(self, model_config=None, resume_from=None, **kwargs) -> P [ self.data_processing_comp, self.inference_comp, - self.evalreporting_comp + self.evalreporting_comp, + self.data_post_processing_addmv, + self.postevalprocess_comp ], self.log_dir, - ) \ No newline at end of file + + ) + +class BA_Calendar_Parallel_PIPELINE(BA_Calendar_PIPELINE): + """This class specifies the config for running BA Calendar benchmark 5 repeated times""" + + def configure_pipeline( + self, model_config: ModelConfig, resume_from: str = None, **kwargs: dict[str, Any] + ) -> PipelineConfig: + pipeline = super().configure_pipeline(model_config=model_config, resume_from=resume_from) + # data preprocessing + self.data_processing_comp.data_reader_config.init_args["transform"].transforms[-1] = MultiplyTransform(n_repeats=5) + return pipeline diff --git a/tests/metric_utils_tests/aggregator_tests.py b/tests/metric_utils_tests/aggregator_tests.py index 04c2ffb..eb1cba0 100644 --- a/tests/metric_utils_tests/aggregator_tests.py +++ b/tests/metric_utils_tests/aggregator_tests.py @@ -15,6 +15,7 @@ SumAggregator, TwoColumnSumAverageAggregator, ) +from eureka_ml_insights.metrics.reports import BiLevelMaxAggregator, MaxAggregator, NAFilteredAggregator PRECISION = 3 @@ -83,6 +84,43 @@ def test_average_aggregator_group_by_multiple_columns(self): avg_agg.write_results() self.assertTrue(os.path.exists(avg_agg.output_file)) +class TestMaxAggregator(TestData, unittest.TestCase): + def test_max_aggregator(self): + avg_agg = MaxAggregator(["col1", "col2"], self.output_dir) + avg_agg.aggregate(self.data) + self.assertEqual( + avg_agg.aggregated_result, + { + "col1": max(self.data["col1"]), + "col2": max(self.data["col2"]), + }, + ) + + def test_max_aggregator_input_validation(self): + avg_agg = MaxAggregator("col3", self.output_dir) + self.assertRaises(ValueError, avg_agg.aggregate, self.data) + + def test_max_aggregator_group_by(self): + avg_agg = MaxAggregator(["col1", "col2"], self.output_dir, group_by="col3") + avg_agg.aggregate(self.data) + self.assertEqual(avg_agg.aggregated_result, {"col1": {"a": 3, "c": 6}, "col2": {"a": 3, "c": 3}}) + + def test_max_aggregator_group_by_multiple_columns(self): + self.output_dir = create_logdir("MaxAggregatorTests") + + avg_agg = MaxAggregator(["col1", "col2"], self.output_dir, group_by=["col3", "col4"]) + avg_agg.aggregate(self.data) + self.assertEqual( + avg_agg.aggregated_result, + { + "col1": {("a_x"): 2, ("a_y"): 3, ("c_y"): 6}, + "col2": {("a_x"): 2, ("a_y"): 3, ("c_y"): 3}, + }, + ) + + avg_agg.write_results() + self.assertTrue(os.path.exists(avg_agg.output_file)) + class TestCountAggregator(TestData, unittest.TestCase): def test_count_aggregator(self): @@ -220,6 +258,48 @@ def test_bilevel_average_aggregator_2(self): ] self.assertEqual(avg_agg.aggregated_result, expected) +class BiLevelMaxAggregatorTest(BiLevelAggregatorTestData, unittest.TestCase): + def test_bilevel_average_aggregator(self): + avg_agg = BiLevelMaxAggregator( + ["numeric_metric"], first_groupby="data_point_id", second_groupby="group", output_dir=self.output_dir + ) + avg_agg.aggregate(self.data) + expected = [ + { + "group": "a", + "numeric_metric_mean": np.mean([np.max([5, 6, 5]), np.max([8, 8, 8])]), + "numeric_metric_std": np.std([np.max([5, 6, 5]), np.max([8, 8, 8])], ddof=1), + }, + { + "group": "b", + "numeric_metric_mean": np.mean([np.max([2, 3, 4]), np.max([3, 4, 2])]), + "numeric_metric_std": np.std([np.max([2, 3, 4]), np.max([3, 4, 2])], ddof=1), + }, + ] + + for i in range(len(avg_agg.aggregated_result)): + self.assertAlmostEqual( + avg_agg.aggregated_result[i]["numeric_metric_mean"], + expected[i]["numeric_metric_mean"], + places=self.precision, + ) + self.assertAlmostEqual( + avg_agg.aggregated_result[i]["numeric_metric_std"], + expected[i]["numeric_metric_std"], + places=self.precision, + ) + + def test_bilevel_average_aggregator_2(self): + avg_agg = BiLevelMaxAggregator( + ["numeric_metric"], first_groupby="data_repeat_id", second_groupby=None, output_dir=self.output_dir + ) + avg_agg.aggregate(self.data) + expected_first_level = [np.max([5, 8, 2, 3]), np.max([6, 8, 3, 4]), np.max([5, 8, 4, 2])] + expected = [ + {"numeric_metric": {"mean": np.mean(expected_first_level), "std": np.std(expected_first_level, ddof=1)}} + ] + self.assertEqual(avg_agg.aggregated_result, expected) + class BiLevelCountAggregatorTest(BiLevelAggregatorTestData, unittest.TestCase): def test_bilevel_count_aggregator(self): @@ -339,5 +419,49 @@ def test_average_aggregator_group_by_multiple_columns(self): self.assertTrue(os.path.exists(avg_agg.output_file)) +class NAFilteredAggregatorTestData: + def setUp(self): + self.data = pd.DataFrame( + { + "data_point_id": [1, 2, 3, 4, 1, 2, 3, 4, 1, 2, 3, 4], + "data_repeat_id": [1, 1, 1, 1, 2, 2, 2, 2, 3, 3, 3, 3], + "col1": [5, 'NA', 2, 3, 6, 8, 3, 'NA', 5, 8, 4, 2], + "col2": [5, 8, 2, 'NA', 6, 8, 3, 4, 5, 8, 'NA', 2], + "col3": [5, 8, 'NA', 3, 'abc', 8, 3, 4, 5, 8, 4, 2], + "categorical_metric": ["x", "y", "z", "z", "y", "y", "z", "y", "x", "y", "y", "x"], + "group": ["a", "a", "b", "b", "a", "a", "b", "b", "a", "a", "b", "b"], + # [5, 6, 8, 5, 8, ] + # [2, 3, 3, 4, 2] + # [5, 8, 6, 8, 5, 8, ] + # [2, 3, 4, 2] + } + ) + self.output_dir = "output_dir" + self.precision = PRECISION + +class TestNAFilteredAggregator(NAFilteredAggregatorTestData, unittest.TestCase): + def test_average_aggregator(self): + avg_agg = NAFilteredAggregator(AverageAggregator, ["col1", "col2"], self.output_dir) + avg_agg.aggregate(self.data) + x = [a for a in self.data["col1"] if a != 'NA'] + y = [a for a in self.data["col2"] if a != 'NA'] + self.assertEqual( + avg_agg.aggregated_result, + {"col1": sum(x) / len(x), "col2": sum(y) / len(y)}, + ) + + def test_average_aggregator_input_validation(self): + avg_agg = NAFilteredAggregator(AverageAggregator, ["col3"], self.output_dir) + self.assertRaises(ValueError, avg_agg.aggregate, self.data) + + def test_average_aggregator_group_by(self): + self.output_dir = create_logdir("NAFilteredAggregatorTests") + avg_agg = NAFilteredAggregator(AverageAggregator, ["col1", "col2"], self.output_dir, group_by="group") + avg_agg.aggregate(self.data) + self.assertEqual(avg_agg.aggregated_result, {"col1": {"a": 6.4, "b": 2.8}, "col2": {"a": 6.667, "b": 2.75}}) + avg_agg.write_results() + self.assertTrue(os.path.exists(avg_agg.output_file)) + + if __name__ == "__main__": unittest.main() diff --git a/tests/pipeline_tests.py b/tests/pipeline_tests.py index ebea985..10edefd 100644 --- a/tests/pipeline_tests.py +++ b/tests/pipeline_tests.py @@ -41,7 +41,7 @@ IFEval_PIPELINE, ToxiGen_Discriminative_PIPELINE, ToxiGen_Generative_PIPELINE, - Calendar_Schedule_PIPELINE, + BA_Calendar_PIPELINE, ) from tests.test_utils import ( DetectionTestModel, @@ -252,7 +252,7 @@ def configure_pipeline(self): ) return config -class TEST_BA_Calendar_PIPELINE(Calendar_Schedule_PIPELINE): +class TEST_BA_Calendar_PIPELINE(BA_Calendar_PIPELINE): # Test config the BA Calendar benchmark with TestModel and TestDataLoader def configure_pipeline(self): config = super().configure_pipeline(model_config=ModelConfig(GenericTestModel, {})) From 2992d0a7d400b365dc33f899e1ad5b7591bf59d5 Mon Sep 17 00:00:00 2001 From: Vidhisha Balachandran Date: Tue, 14 Jan 2025 18:55:11 -0800 Subject: [PATCH 08/10] renamed prompt templates --- .../ba_calendar_templates/calendar_scheduling_brief.jinja | 7 +++++++ .../calendar_scheduling_regular.jinja | 6 ++++++ 2 files changed, 13 insertions(+) create mode 100644 eureka_ml_insights/prompt_templates/ba_calendar_templates/calendar_scheduling_brief.jinja create mode 100644 eureka_ml_insights/prompt_templates/ba_calendar_templates/calendar_scheduling_regular.jinja diff --git a/eureka_ml_insights/prompt_templates/ba_calendar_templates/calendar_scheduling_brief.jinja b/eureka_ml_insights/prompt_templates/ba_calendar_templates/calendar_scheduling_brief.jinja new file mode 100644 index 0000000..69159e6 --- /dev/null +++ b/eureka_ml_insights/prompt_templates/ba_calendar_templates/calendar_scheduling_brief.jinja @@ -0,0 +1,7 @@ +You are a scheduling assistant. Given the availability schedules of multiple participants and some additional constraints, your task is to find a common time slot. +Make sure you use the availability schedules to generate your response. +High priority meetings should be scheduled as early as possible. +Buffer time refers to the required remaining available time before and after a meeting. For example, if buffer time is 15 minutes, a meeting from 9:00-10:00 will require availability from 8:45-10:15. +Respond with "[day] [start_time]-[end_time]" or "No common time slot available" +Do not respond with any additional information or comments. +{{prompt}} \ No newline at end of file diff --git a/eureka_ml_insights/prompt_templates/ba_calendar_templates/calendar_scheduling_regular.jinja b/eureka_ml_insights/prompt_templates/ba_calendar_templates/calendar_scheduling_regular.jinja new file mode 100644 index 0000000..f83844f --- /dev/null +++ b/eureka_ml_insights/prompt_templates/ba_calendar_templates/calendar_scheduling_regular.jinja @@ -0,0 +1,6 @@ +You are a scheduling assistant. Given the availability schedules of multiple participants and some additional constraints, your task is to find a common time slot. +Make sure you use the availability schedules to generate your response. +High priority meetings should be scheduled as early as possible. +Buffer time refers to the required remaining available time before and after a meeting. For example, if buffer time is 15 minutes, a meeting from 9:00-10:00 will require availability from 8:45-10:15. +Respond with "[day] [start_time]-[end_time]" or "No common time slot available" +{{prompt}} \ No newline at end of file From b1cab65f5e04a3aac2a9f5883abb5192328aef82 Mon Sep 17 00:00:00 2001 From: Vidhisha Balachandran Date: Wed, 15 Jan 2025 11:31:53 -0800 Subject: [PATCH 09/10] refactor configs, add answer extraction for cot setting --- eureka_ml_insights/configs/model_configs.py | 1 + .../data_utils/ba_calendar_utils.py | 36 ++++++++++++++ .../calendar_scheduling_cot.jinja | 9 ++++ .../user_configs/ba_calendar.py | 47 ++++++++++++++----- 4 files changed, 80 insertions(+), 13 deletions(-) create mode 100644 eureka_ml_insights/data_utils/ba_calendar_utils.py create mode 100644 eureka_ml_insights/prompt_templates/ba_calendar_templates/calendar_scheduling_cot.jinja diff --git a/eureka_ml_insights/configs/model_configs.py b/eureka_ml_insights/configs/model_configs.py index 779bf05..70fbb30 100644 --- a/eureka_ml_insights/configs/model_configs.py +++ b/eureka_ml_insights/configs/model_configs.py @@ -15,6 +15,7 @@ RestEndpointModel, TestModel, ) +from eureka_ml_insights.models.models import AzureOpenAIModel from .config import ModelConfig diff --git a/eureka_ml_insights/data_utils/ba_calendar_utils.py b/eureka_ml_insights/data_utils/ba_calendar_utils.py new file mode 100644 index 0000000..5a8c1df --- /dev/null +++ b/eureka_ml_insights/data_utils/ba_calendar_utils.py @@ -0,0 +1,36 @@ +import re +from dataclasses import dataclass + +import pandas as pd + +from .transform import DFTransformBase + + +@dataclass +class BA_Calendar_ExtractAnswer(DFTransformBase): + model_output_column: str + model_answer_column: str + + def transform(self, df: pd.DataFrame) -> pd.DataFrame: + df[self.model_answer_column] = df[self.model_output_column].apply(self.parse_output_answer) + return df + + @staticmethod + def parse_output_answer(response): + """ + Parse the input string to extract answer of a given BA Calendar problems. + Parameters: + response (str): Input string containing answer X in the form of "Final Answer: X". + Returns: + numerical_value (float): A numeric value representing the model's answer. + """ + answer = "" + + # Try to find an answer in the "Final Answer: X" format + print(response) + match = re.search(r"(?i)(?<=Final Answer: ).*", response) + print(match) + if match: + answer = match.group(0) + + return answer \ No newline at end of file diff --git a/eureka_ml_insights/prompt_templates/ba_calendar_templates/calendar_scheduling_cot.jinja b/eureka_ml_insights/prompt_templates/ba_calendar_templates/calendar_scheduling_cot.jinja new file mode 100644 index 0000000..d63accd --- /dev/null +++ b/eureka_ml_insights/prompt_templates/ba_calendar_templates/calendar_scheduling_cot.jinja @@ -0,0 +1,9 @@ +You are a scheduling assistant. Given the availability schedules of multiple participants and some additional constraints, your task is to find a common time slot. +Make sure you use the availability schedules to generate your response. +High priority meetings should be scheduled as early as possible. +Buffer time refers to the required remaining available time before and after a meeting. For example, if buffer time is 15 minutes, a meeting from 9:00-10:00 will require availability from 8:45-10:15. +The final time slot solution should be "[day] [start_time]-[end_time]" or "No common time slot available". +Think through and provide your your answer in the format: +Reason: +Final Answer: +{{prompt}} \ No newline at end of file diff --git a/eureka_ml_insights/user_configs/ba_calendar.py b/eureka_ml_insights/user_configs/ba_calendar.py index 2b08ddd..0699567 100644 --- a/eureka_ml_insights/user_configs/ba_calendar.py +++ b/eureka_ml_insights/user_configs/ba_calendar.py @@ -8,6 +8,7 @@ from eureka_ml_insights.core.data_processing import DataProcessing from eureka_ml_insights.core.eval_reporting import EvalReporting +from eureka_ml_insights.data_utils.ba_calendar_utils import BA_Calendar_ExtractAnswer from eureka_ml_insights.data_utils.data import ( DataLoader, DataReader, @@ -42,7 +43,7 @@ def configure_pipeline(self, model_config=None, resume_from=None, **kwargs) -> P # data preprocessing self.data_processing_comp = PromptProcessingConfig( component_type=PromptProcessing, - prompt_template_path=os.path.join(os.path.dirname(__file__), "../prompt_templates/ba_calendar_templates/calendar_scheduling_brief.jinja"), + prompt_template_path=os.path.join(os.path.dirname(__file__), "../prompt_templates/ba_calendar_templates/calendar_scheduling_cot.jinja"), data_reader_config=DataSetConfig( HFDataReader, { @@ -52,7 +53,7 @@ def configure_pipeline(self, model_config=None, resume_from=None, **kwargs) -> P ColumnRename(name_mapping={"task_prompt": "prompt"}), MultiplyTransform(n_repeats=1), ]), - }, + } ), output_dir=os.path.join(self.log_dir, "data_processing_output"), ) @@ -78,6 +79,17 @@ def configure_pipeline(self, model_config=None, resume_from=None, **kwargs) -> P { "path": os.path.join(self.inference_comp.output_dir, "inference_result.jsonl"), "format": ".jsonl", + "transform": SequenceTransform( + [ + ColumnRename( + name_mapping={ + "model_output": "raw_output", + } + ), + AddColumn("model_output"), + BA_Calendar_ExtractAnswer("raw_output", "model_output"), + ] + ), }, ), metric_config=MetricConfig(BACalendarMetric), @@ -110,6 +122,21 @@ def configure_pipeline(self, model_config=None, resume_from=None, **kwargs) -> P "group_by": "data_repeat_id", }, ), + ], + output_dir=os.path.join(self.log_dir, "eval_report"), + ) + + # Aggregate the results by best of n + self.bon_evalreporting_comp = EvalReportingConfig( + component_type=EvalReporting, + data_reader_config=DataSetConfig( + DataReader, + { + "path": os.path.join(self.evalreporting_comp.output_dir, "metric_results.jsonl"), + "format": ".jsonl", + }, + ), + aggregator_configs=[ AggregatorConfig( BiLevelMaxAggregator, { @@ -142,34 +169,27 @@ def configure_pipeline(self, model_config=None, resume_from=None, **kwargs) -> P ], - output_dir=os.path.join(self.log_dir, "eval_report"), + output_dir=os.path.join(self.log_dir, "bestofn_eval_report"), ) # Aggregate the results by a majority vote - # First, let us perform majority_vote self.data_post_processing_addmv = DataProcessingConfig( component_type=DataProcessing, data_reader_config=DataSetConfig( DataReader, { - "path": os.path.join(self.inference_comp.output_dir, "inference_result.jsonl"), + "path": os.path.join(self.evalreporting_comp.output_dir, "metric_results.jsonl"), "format": ".jsonl", "transform": SequenceTransform( [ ColumnRename( name_mapping={ - "model_output": "raw_output", + "model_output": "model_output_onerun", } ), AddColumn("model_output"), - MajorityVoteTransform(model_output_col="raw_output"), + MajorityVoteTransform(model_output_col="model_output_onerun"), CopyColumn("majority_vote", "model_output"), - ColumnRename( - name_mapping={ - "raw_output": "model_output_onerun", - "majority_vote": "model_output", - } - ), ] ), }, @@ -231,6 +251,7 @@ def configure_pipeline(self, model_config=None, resume_from=None, **kwargs) -> P self.data_processing_comp, self.inference_comp, self.evalreporting_comp, + self.bon_evalreporting_comp, self.data_post_processing_addmv, self.postevalprocess_comp ], From b17aefd1c612d1594acecfd30874bad3bf3ec925 Mon Sep 17 00:00:00 2001 From: Vidhisha Balachandran Date: Fri, 17 Jan 2025 10:54:34 -0800 Subject: [PATCH 10/10] change NA to none, test fixes --- eureka_ml_insights/configs/model_configs.py | 2 +- .../metrics/ba_calendar_metrics.py | 23 ++++--- eureka_ml_insights/metrics/reports.py | 11 ++-- .../calendar_scheduling_brief.jinja | 5 +- .../calendar_scheduling_cot.jinja | 2 +- .../user_configs/ba_calendar.py | 60 ++++--------------- tests/metric_utils_tests/aggregator_tests.py | 18 +++--- tests/pipeline_tests.py | 2 +- 8 files changed, 47 insertions(+), 76 deletions(-) diff --git a/eureka_ml_insights/configs/model_configs.py b/eureka_ml_insights/configs/model_configs.py index 70fbb30..3b9d816 100644 --- a/eureka_ml_insights/configs/model_configs.py +++ b/eureka_ml_insights/configs/model_configs.py @@ -199,4 +199,4 @@ }, "model_name": "Mistral-large-2407", }, -) \ No newline at end of file +) diff --git a/eureka_ml_insights/metrics/ba_calendar_metrics.py b/eureka_ml_insights/metrics/ba_calendar_metrics.py index cfe1176..98f4ec6 100644 --- a/eureka_ml_insights/metrics/ba_calendar_metrics.py +++ b/eureka_ml_insights/metrics/ba_calendar_metrics.py @@ -98,7 +98,7 @@ def run_programmatic_tests(self, instance): for key, value in result.items(): if value == 0: all_correct = 0 - if value != 'NA' and pd.notna(value) and isinstance(value, int): + if value is not None and value != 'NA' and pd.notna(value) and isinstance(value, int): passed_constraints.append(value) result['all_correct'] = all_correct result['fraction_passed'] = np.mean(passed_constraints) @@ -114,7 +114,8 @@ def is_formatted(self, solution): def check_availability_programmatic(self, instance, solution): if not instance['constraints'].get('availability', True): - result = {'availability_programmatic_check': 'NA'} + # result = {'availability_programmatic_check': 'NA'} + result = {'availability_programmatic_check': None} return result if not self.is_formatted(solution): @@ -144,7 +145,8 @@ def check_availability_programmatic(self, instance, solution): def check_meeting_duration_programmatic(self, instance, solution): if not instance['constraints'].get('meeting_duration', True): - result = {'meeting_duration_programmatic_check': 'NA'} + # result = {'meeting_duration_programmatic_check': 'NA'} + result = {'meeting_duration_programmatic_check': None} return result if not self.is_formatted(solution): @@ -162,7 +164,8 @@ def check_meeting_duration_programmatic(self, instance, solution): def check_buffer_time_programmatic(self, instance, solution): buffer_time = instance['constraints'].get('buffer_time_before_and_after_meeting', True) if buffer_time is None or not buffer_time: - result = {'buffer_time_programmatic_check': 'NA'} + # result = {'buffer_time_programmatic_check': 'NA'} + result = {'buffer_time_programmatic_check': None} return result if not self.is_formatted(solution): @@ -195,7 +198,8 @@ def check_buffer_time_programmatic(self, instance, solution): def check_no_weekends_programmatic(self, instance, solution): if not instance['constraints'].get('no_meetings_on_weekends', True): - return {'no_weekends_programmatic_check': 'NA'} + # return {'no_weekends_programmatic_check': 'NA'} + return {'no_weekends_programmatic_check': None} if not self.is_formatted(solution): return {'no_weekends_programmatic_check': 0} @@ -207,7 +211,8 @@ def check_no_weekends_programmatic(self, instance, solution): def check_time_restrictions_programmatic(self, instance, solution): if not instance['constraints'].get('no_meetings_before', True) and not instance['constraints'].get('no_meetings_after', True): - return {'time_restrictions_programmatic_check': 'NA'} + # return {'time_restrictions_programmatic_check': 'NA'} + return {'time_restrictions_programmatic_check': None} if not self.is_formatted(solution): return {'time_restrictions_programmatic_check': 0} @@ -231,7 +236,8 @@ def check_time_restrictions_programmatic(self, instance, solution): def check_priority_programmatic(self, instance, solution): if not instance['constraints'].get('high_priority_meeting', False): - return {'priority_programmatic_check': 'NA'} + # return {'priority_programmatic_check': 'NA'} + return {'priority_programmatic_check': None} if not self.is_formatted(solution): return {'priority_programmatic_check': 0} @@ -269,7 +275,8 @@ def check_priority_programmatic(self, instance, solution): def check_specific_times_programmatic(self, instance, solution): if not instance['constraints'].get('no_meetings_during_specific_times', True): - return {'specific_times_programmatic_check': 'NA'} + # return {'specific_times_programmatic_check': 'NA'} + return {'specific_times_programmatic_check': None} if not self.is_formatted(solution): return {'specific_times_programmatic_check': 0} diff --git a/eureka_ml_insights/metrics/reports.py b/eureka_ml_insights/metrics/reports.py index 8e78466..9175d28 100644 --- a/eureka_ml_insights/metrics/reports.py +++ b/eureka_ml_insights/metrics/reports.py @@ -348,12 +348,13 @@ def _aggregate_grouped(self, data): divided_result = (gb[self.numerator_column_name].sum() / gb[self.denominator_column_name].sum()).to_dict() self.aggregated_result = {"ratio": divided_result} -class NAFilteredAggregator(Aggregator): - def __init__(self, agg_class, column_names, output_dir, group_by=None, ignore_non_numeric=False, filename_base=None, **kwargs): +class ValueFilteredAggregator(Aggregator): + def __init__(self, agg_class, value, column_names, output_dir, group_by=None, ignore_non_numeric=False, filename_base=None, **kwargs): """ - Aggregator that filters out "NA" values before aggregating the data. + Aggregator that filters out a particular value before aggregating the data. args: agg_class: Aggregator class to use for aggregation + value: value to filter out column_names: column names to filter and aggregate output_dir: str. directory to save the report group_by: str. or list of str. column(s) to group by before aggregating @@ -362,19 +363,19 @@ def __init__(self, agg_class, column_names, output_dir, group_by=None, ignore_no """ self.base_aggregator = agg_class(column_names, output_dir, group_by, ignore_non_numeric, filename_base, **kwargs) + self.value = value self.column_names = column_names self.group_by = group_by self.output_dir = output_dir self.aggregated_result = None self.ignore_non_numeric = ignore_non_numeric self.filename_base = filename_base - # super().__init__(self.input_column_names, output_dir, group_by, ignore_non_numeric, filename_base, **kwargs) def aggregate(self, data): agg_results = {} for col in self.column_names: # workaround to process one column at a time - filtered_data = data[data[col] != "NA"].copy() + filtered_data = data[data[col] != self.value].copy() self.base_aggregator.column_names = [col] self.base_aggregator.aggregate(filtered_data) agg_results.update(self.base_aggregator.aggregated_result) diff --git a/eureka_ml_insights/prompt_templates/ba_calendar_templates/calendar_scheduling_brief.jinja b/eureka_ml_insights/prompt_templates/ba_calendar_templates/calendar_scheduling_brief.jinja index 69159e6..e6425f5 100644 --- a/eureka_ml_insights/prompt_templates/ba_calendar_templates/calendar_scheduling_brief.jinja +++ b/eureka_ml_insights/prompt_templates/ba_calendar_templates/calendar_scheduling_brief.jinja @@ -2,6 +2,7 @@ You are a scheduling assistant. Given the availability schedules of multiple par Make sure you use the availability schedules to generate your response. High priority meetings should be scheduled as early as possible. Buffer time refers to the required remaining available time before and after a meeting. For example, if buffer time is 15 minutes, a meeting from 9:00-10:00 will require availability from 8:45-10:15. -Respond with "[day] [start_time]-[end_time]" or "No common time slot available" -Do not respond with any additional information or comments. +The final time slot solution should be "[day] [start_time]-[end_time]" or "No common time slot available". +Provide your answer in the format: +Final Answer: {{prompt}} \ No newline at end of file diff --git a/eureka_ml_insights/prompt_templates/ba_calendar_templates/calendar_scheduling_cot.jinja b/eureka_ml_insights/prompt_templates/ba_calendar_templates/calendar_scheduling_cot.jinja index d63accd..397ce42 100644 --- a/eureka_ml_insights/prompt_templates/ba_calendar_templates/calendar_scheduling_cot.jinja +++ b/eureka_ml_insights/prompt_templates/ba_calendar_templates/calendar_scheduling_cot.jinja @@ -3,7 +3,7 @@ Make sure you use the availability schedules to generate your response. High priority meetings should be scheduled as early as possible. Buffer time refers to the required remaining available time before and after a meeting. For example, if buffer time is 15 minutes, a meeting from 9:00-10:00 will require availability from 8:45-10:15. The final time slot solution should be "[day] [start_time]-[end_time]" or "No common time slot available". -Think through and provide your your answer in the format: +Think through and provide your answer in the format: Reason: Final Answer: {{prompt}} \ No newline at end of file diff --git a/eureka_ml_insights/user_configs/ba_calendar.py b/eureka_ml_insights/user_configs/ba_calendar.py index 0699567..efa5f5a 100644 --- a/eureka_ml_insights/user_configs/ba_calendar.py +++ b/eureka_ml_insights/user_configs/ba_calendar.py @@ -19,8 +19,6 @@ from eureka_ml_insights.metrics.reports import ( AverageAggregator, BiLevelMaxAggregator, - MaxAggregator, - NAFilteredAggregator, ) from ..configs.config import ( @@ -99,17 +97,7 @@ def configure_pipeline(self, model_config=None, resume_from=None, **kwargs) -> P { "column_names": [ "BACalendarMetric_all_correct", - "BACalendarMetric_fraction_passed" - ], - "filename_base": "BaCal_OverallMetrics_SeparateRuns", - "group_by": "data_repeat_id", - }, - ), - AggregatorConfig( - NAFilteredAggregator, - { - "agg_class": AverageAggregator, - "column_names": [ + "BACalendarMetric_fraction_passed", "BACalendarMetric_availability_programmatic_check", "BACalendarMetric_meeting_duration_programmatic_check", "BACalendarMetric_buffer_time_programmatic_check", @@ -118,7 +106,7 @@ def configure_pipeline(self, model_config=None, resume_from=None, **kwargs) -> P "BACalendarMetric_specific_times_programmatic_check", "BACalendarMetric_priority_programmatic_check" ], - "filename_base": "BaCal_Constraint_Level_SeprateRuns", + "filename_base": "BaCal_OverallMetrics_SeparateRuns", "group_by": "data_repeat_id", }, ), @@ -142,18 +130,7 @@ def configure_pipeline(self, model_config=None, resume_from=None, **kwargs) -> P { "column_names": [ "BACalendarMetric_all_correct", - "BACalendarMetric_fraction_passed" - ], - "first_groupby": "data_point_id", - "filename_base": "BaCal_BestOfN_Aggregated", - "normalize": True, - }, - ), - AggregatorConfig( - NAFilteredAggregator, - { - "agg_class": MaxAggregator, - "column_names": [ + "BACalendarMetric_fraction_passed", "BACalendarMetric_availability_programmatic_check", "BACalendarMetric_meeting_duration_programmatic_check", "BACalendarMetric_buffer_time_programmatic_check", @@ -162,18 +139,17 @@ def configure_pipeline(self, model_config=None, resume_from=None, **kwargs) -> P "BACalendarMetric_specific_times_programmatic_check", "BACalendarMetric_priority_programmatic_check" ], - "filename_base": "BaCal_Constraint_Level_BestOfN_Aggregated", - "group_by": "data_repeat_id", + "first_groupby": "data_point_id", + "filename_base": "BaCal_BestOfN_Aggregated", + "normalize": True, }, ), - - ], output_dir=os.path.join(self.log_dir, "bestofn_eval_report"), ) # Aggregate the results by a majority vote - self.data_post_processing_addmv = DataProcessingConfig( + self.maj_vote_data_post_processing = DataProcessingConfig( component_type=DataProcessing, data_reader_config=DataSetConfig( DataReader, @@ -197,12 +173,12 @@ def configure_pipeline(self, model_config=None, resume_from=None, **kwargs) -> P output_dir=os.path.join(self.log_dir, "data_majvote_output"), ) # Second, compute eaxct match - self.postevalprocess_comp = EvalReportingConfig( + self.majvote_evalreporting_comp = EvalReportingConfig( component_type=EvalReporting, data_reader_config=DataSetConfig( DataReader, { - "path": os.path.join(self.data_post_processing_addmv.output_dir, "transformed_data.jsonl"), + "path": os.path.join(self.maj_vote_data_post_processing.output_dir, "transformed_data.jsonl"), "format": ".jsonl", "transform": SequenceTransform( [ @@ -218,17 +194,7 @@ def configure_pipeline(self, model_config=None, resume_from=None, **kwargs) -> P { "column_names": [ "BACalendarMetric_all_correct", - "BACalendarMetric_fraction_passed" - ], - "filename_base": "BaCal_MajVote_OverallMetrics_Aggregated", - "group_by": "data_repeat_id", - }, - ), - AggregatorConfig( - NAFilteredAggregator, - { - "agg_class": AverageAggregator, - "column_names": [ + "BACalendarMetric_fraction_passed", "BACalendarMetric_availability_programmatic_check", "BACalendarMetric_meeting_duration_programmatic_check", "BACalendarMetric_buffer_time_programmatic_check", @@ -237,7 +203,7 @@ def configure_pipeline(self, model_config=None, resume_from=None, **kwargs) -> P "BACalendarMetric_specific_times_programmatic_check", "BACalendarMetric_priority_programmatic_check" ], - "filename_base": "BaCal_MajVote_Constraint_Level_Aggregated", + "filename_base": "BaCal_MajVote_OverallMetrics_Aggregated", "group_by": "data_repeat_id", }, ), @@ -252,8 +218,8 @@ def configure_pipeline(self, model_config=None, resume_from=None, **kwargs) -> P self.inference_comp, self.evalreporting_comp, self.bon_evalreporting_comp, - self.data_post_processing_addmv, - self.postevalprocess_comp + self.maj_vote_data_post_processing, + self.majvote_evalreporting_comp ], self.log_dir, diff --git a/tests/metric_utils_tests/aggregator_tests.py b/tests/metric_utils_tests/aggregator_tests.py index eb1cba0..4f08ecd 100644 --- a/tests/metric_utils_tests/aggregator_tests.py +++ b/tests/metric_utils_tests/aggregator_tests.py @@ -15,7 +15,7 @@ SumAggregator, TwoColumnSumAverageAggregator, ) -from eureka_ml_insights.metrics.reports import BiLevelMaxAggregator, MaxAggregator, NAFilteredAggregator +from eureka_ml_insights.metrics.reports import BiLevelMaxAggregator, MaxAggregator, ValueFilteredAggregator PRECISION = 3 @@ -419,7 +419,7 @@ def test_average_aggregator_group_by_multiple_columns(self): self.assertTrue(os.path.exists(avg_agg.output_file)) -class NAFilteredAggregatorTestData: +class ValueFilteredAggregatorTestData: def setUp(self): self.data = pd.DataFrame( { @@ -430,18 +430,14 @@ def setUp(self): "col3": [5, 8, 'NA', 3, 'abc', 8, 3, 4, 5, 8, 4, 2], "categorical_metric": ["x", "y", "z", "z", "y", "y", "z", "y", "x", "y", "y", "x"], "group": ["a", "a", "b", "b", "a", "a", "b", "b", "a", "a", "b", "b"], - # [5, 6, 8, 5, 8, ] - # [2, 3, 3, 4, 2] - # [5, 8, 6, 8, 5, 8, ] - # [2, 3, 4, 2] } ) self.output_dir = "output_dir" self.precision = PRECISION -class TestNAFilteredAggregator(NAFilteredAggregatorTestData, unittest.TestCase): +class TestValueFilteredAggregator(ValueFilteredAggregatorTestData, unittest.TestCase): def test_average_aggregator(self): - avg_agg = NAFilteredAggregator(AverageAggregator, ["col1", "col2"], self.output_dir) + avg_agg = ValueFilteredAggregator(AverageAggregator, "NA", ["col1", "col2"], self.output_dir) avg_agg.aggregate(self.data) x = [a for a in self.data["col1"] if a != 'NA'] y = [a for a in self.data["col2"] if a != 'NA'] @@ -451,12 +447,12 @@ def test_average_aggregator(self): ) def test_average_aggregator_input_validation(self): - avg_agg = NAFilteredAggregator(AverageAggregator, ["col3"], self.output_dir) + avg_agg = ValueFilteredAggregator(AverageAggregator, 'NA', ["col3"], self.output_dir) self.assertRaises(ValueError, avg_agg.aggregate, self.data) def test_average_aggregator_group_by(self): - self.output_dir = create_logdir("NAFilteredAggregatorTests") - avg_agg = NAFilteredAggregator(AverageAggregator, ["col1", "col2"], self.output_dir, group_by="group") + self.output_dir = create_logdir("ValueFilteredAggregatorTests") + avg_agg = ValueFilteredAggregator(AverageAggregator, 'NA', ["col1", "col2"], self.output_dir, group_by="group") avg_agg.aggregate(self.data) self.assertEqual(avg_agg.aggregated_result, {"col1": {"a": 6.4, "b": 2.8}, "col2": {"a": 6.667, "b": 2.75}}) avg_agg.write_results() diff --git a/tests/pipeline_tests.py b/tests/pipeline_tests.py index 20005bb..fe22824 100644 --- a/tests/pipeline_tests.py +++ b/tests/pipeline_tests.py @@ -487,7 +487,7 @@ def get_config(self): def setUp(self) -> None: super().setUp() - self.eval_configs = [self.test_pipeline.evalreporting_comp] + self.eval_configs = [self.test_pipeline.evalreporting_comp,self.test_pipeline.bon_evalreporting_comp, self.test_pipeline.majvote_evalreporting_comp] def test_outputs_exist(self) -> None: logging.info("Running test_outputs_exist test in PipelineTest")