Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

numeric match and topic tag #72

Open
wants to merge 6 commits into
base: main
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from 4 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 2 additions & 0 deletions eureka_ml_insights/metrics/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -28,6 +28,7 @@
SpatialAndLayoutReasoningMetric,
)

from .aime_metrics import NumericMatch
__all__ = [
Metric,
ClassicMetric,
Expand All @@ -52,4 +53,5 @@
SumAggregator,
MMMUMetric,
MaxTokenF1ScoreMetric,
NumericMatch,
]
20 changes: 20 additions & 0 deletions eureka_ml_insights/metrics/aime_metrics.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,20 @@
from tqdm.auto import tqdm

from eureka_ml_insights.metrics.metrics_base import ClassicMetric

import numpy as np

class NumericMatch(ClassicMetric):
    """Metric that scores an answer by numeric closeness to its target.

    Both values are converted with ``float``; they are considered a match
    when their absolute difference is strictly smaller than ``eps``.
    """

    # Absolute tolerance below which two numbers count as equal.
    eps = 1e-6

    def __evaluate__(self, answer_text, target_text, is_valid):
        """Compare a single answer with its target.

        Args:
            answer_text: The answer to score; any value ``float`` accepts.
            target_text: The reference value; any value ``float`` accepts.
            is_valid: When falsy, no comparison is attempted.

        Returns:
            ``"none"`` if ``is_valid`` is falsy or either value cannot be
            converted to a float, ``"correct"`` if the absolute difference
            is below ``eps``, and ``"incorrect"`` otherwise.
        """
        if not is_valid:
            return "none"
        try:
            diff = np.abs(float(target_text) - float(answer_text))
        except (TypeError, ValueError, OverflowError):
            # Only conversion failures mean "not a number". A bare ``except``
            # would also swallow KeyboardInterrupt/SystemExit and hide bugs.
            return "none"
        return "correct" if diff < self.eps else "incorrect"
Original file line number Diff line number Diff line change
@@ -0,0 +1,19 @@
You are a genius math expert in understanding math questions. Please read the following math question, and then decide which math category it falls into.

Your judgment should be one of the following:

arithmetic
algebra
counting
geometry
number theory
probability
other topics

Do not generate any text other than one of the above topics.

----------
Original question:
{{prompt}}
----------
Your judgment:
lchen001 marked this conversation as resolved.
Show resolved Hide resolved
1 change: 1 addition & 0 deletions eureka_ml_insights/user_configs/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,7 @@
AIME_PIPELINE256Run,
AIME_PIPELINE512Run,
AIME_PIPELINE1024Run,
AIME_PIPELINETag,
)
from .dna import DNA_PIPELINE
from .drop import Drop_Experiment_Pipeline
Expand Down
33 changes: 28 additions & 5 deletions eureka_ml_insights/user_configs/aime.py
Original file line number Diff line number Diff line change
Expand Up @@ -27,6 +27,8 @@
from eureka_ml_insights.data_utils.aime_utils import AIMEExtractAnswer
from eureka_ml_insights.data_utils.data import DataLoader
from eureka_ml_insights.metrics.metrics_base import ExactMatch
from eureka_ml_insights.metrics.aime_metrics import NumericMatch

from eureka_ml_insights.metrics.reports import (
BiLevelCountAggregator,
CountAggregator,
Expand Down Expand Up @@ -114,16 +116,16 @@ def configure_pipeline(
"format": ".jsonl",
},
),
metric_config=MetricConfig(ExactMatch),
metric_config=MetricConfig(NumericMatch),
aggregator_configs=[
AggregatorConfig(
CountAggregator,
{
"column_names": [
"ExactMatch_result",
"NumericMatch_result",
],
"group_by": "Year",
"filename_base": "ExactMatch_GroupBy",
"filename_base": "NumericMatch_GroupBy",
},
),
],
Expand Down Expand Up @@ -171,13 +173,13 @@ def configure_pipeline(
"format": ".jsonl",
},
),
metric_config=MetricConfig(ExactMatch),
metric_config=MetricConfig(NumericMatch),
aggregator_configs=[
AggregatorConfig(
BiLevelCountAggregator,
{
"column_names": [
"ExactMatch_result",
"NumericMatch_result",
],
"first_groupby": "ID",
"filename_base": "MajorityVote",
Expand Down Expand Up @@ -312,3 +314,24 @@ def configure_pipeline(
MultiplyTransform(n_repeats=1024)
)
return pipeline


class AIME_PIPELINETag(AIME_PIPELINE):
    """AIME pipeline config that uses the ``Template_tag1.jinja`` prompt template.

    Everything else is inherited from ``AIME_PIPELINE``; only the prompt
    template path of the data processing component is replaced.
    NOTE(review): the template is presumably the math-topic tagging prompt
    -- confirm against the template file.
    """

    def configure_pipeline(
        self, model_config: ModelConfig, resume_from: str = None, **kwargs: dict[str, Any]
    ) -> PipelineConfig:
        """Build the parent pipeline and switch its prompt template.

        Args:
            model_config: Forwarded to the parent ``configure_pipeline``.
            resume_from: Forwarded to the parent ``configure_pipeline``.
            **kwargs: Accepted but not used by this override.

        Returns:
            The pipeline config returned by the parent class.
        """
        pipeline = super().configure_pipeline(model_config=model_config, resume_from=resume_from)
        # data preprocessing
        self.data_processing_comp.prompt_template_path = os.path.join(
            os.path.dirname(__file__), "../prompt_templates/aime_templates/Template_tag1.jinja"
        )
        return pipeline
Loading