diff --git a/eureka_ml_insights/configs/__init__.py b/eureka_ml_insights/configs/__init__.py index 7d04a3f..b5a720c 100644 --- a/eureka_ml_insights/configs/__init__.py +++ b/eureka_ml_insights/configs/__init__.py @@ -1,4 +1,3 @@ -from .aime import AIME_PIPELINE from .config import ( AggregatorConfig, DataJoinConfig, @@ -11,71 +10,7 @@ PipelineConfig, PromptProcessingConfig, ) -from .dna import DNA_PIPELINE -from .drop import Drop_Experiment_Pipeline from .experiment_config import ExperimentConfig, create_logdir -from .flenqa import FlenQA_Experiment_Pipeline -from .geometer import GEOMETER_PIPELINE -from .gpqa import GPQA_Experiment_Pipeline -from .ifeval import IFEval_PIPELINE -from .image_understanding.object_detection import ( - OBJECT_DETECTION_PAIRS_LOCAL_PIPELINE, - OBJECT_DETECTION_PAIRS_PIPELINE, - OBJECT_DETECTION_SINGLE_LOCAL_PIPELINE, - OBJECT_DETECTION_SINGLE_PIPELINE, -) -from .image_understanding.object_recognition import ( - OBJECT_RECOGNITION_PAIRS_LOCAL_PIPELINE, - OBJECT_RECOGNITION_PAIRS_PIPELINE, - OBJECT_RECOGNITION_SINGLE_LOCAL_PIPELINE, - OBJECT_RECOGNITION_SINGLE_PIPELINE, -) -from .image_understanding.spatial_reasoning import ( - SPATIAL_REASONING_PAIRS_LOCAL_PIPELINE, - SPATIAL_REASONING_PAIRS_PIPELINE, - SPATIAL_REASONING_SINGLE_LOCAL_PIPELINE, - SPATIAL_REASONING_SINGLE_PIPELINE, -) -from .image_understanding.visual_prompting import ( - VISUAL_PROMPTING_PAIRS_LOCAL_PIPELINE, - VISUAL_PROMPTING_PAIRS_PIPELINE, - VISUAL_PROMPTING_SINGLE_LOCAL_PIPELINE, - VISUAL_PROMPTING_SINGLE_PIPELINE, -) -from .kitab import ( - GPT35_KITAB_ONE_BOOK_CONSTRAINT_PIPELINE, - KITAB_ONE_BOOK_CONSTRAINT_PIPELINE, - KITAB_ONE_BOOK_CONSTRAINT_PIPELINE_SELF_CONTEXT, - KITAB_ONE_BOOK_CONSTRAINT_PIPELINE_WITH_CONTEXT, - KITAB_TWO_BOOK_CONSTRAINT_PIPELINE, - KITAB_TWO_BOOK_CONSTRAINT_PIPELINE_WITH_CONTEXT, -) -from .mmmu import MMMU_BASELINE_PIPELINE -from .nondeterminism import ( - Geo_Nondeterminism, - IFEval_Nondeterminism, - Kitab_Nondeterminism, - MMMU_Nondeterminism, -) -from .toxigen import ( - ToxiGen_Discriminative_PIPELINE, - ToxiGen_Generative_PIPELINE, -) -from .vision_language.maze import ( - MAZE_PIPELINE, - MAZE_REPORTING_PIPELINE, - MAZE_TEXTONLY_PIPELINE, -) -from .vision_language.spatial_grid import ( - SPATIAL_GRID_PIPELINE, - SPATIAL_GRID_REPORTING_PIPELINE, - SPATIAL_GRID_TEXTONLY_PIPELINE, -) -from .vision_language.spatial_map import ( - SPATIAL_MAP_PIPELINE, - SPATIAL_MAP_REPORTING_PIPELINE, - SPATIAL_MAP_TEXTONLY_PIPELINE, -) __all__ = [ PipelineConfig, @@ -89,50 +24,5 @@ DataSetConfig, EvalReportingConfig, ExperimentConfig, - OBJECT_DETECTION_PAIRS_PIPELINE, - OBJECT_DETECTION_SINGLE_PIPELINE, - OBJECT_DETECTION_PAIRS_LOCAL_PIPELINE, - OBJECT_DETECTION_SINGLE_LOCAL_PIPELINE, - OBJECT_RECOGNITION_PAIRS_PIPELINE, - OBJECT_RECOGNITION_SINGLE_PIPELINE, - OBJECT_RECOGNITION_PAIRS_LOCAL_PIPELINE, - OBJECT_RECOGNITION_SINGLE_LOCAL_PIPELINE, - SPATIAL_REASONING_PAIRS_PIPELINE, - SPATIAL_REASONING_SINGLE_PIPELINE, - SPATIAL_REASONING_PAIRS_LOCAL_PIPELINE, - SPATIAL_REASONING_SINGLE_LOCAL_PIPELINE, - VISUAL_PROMPTING_PAIRS_PIPELINE, - VISUAL_PROMPTING_SINGLE_PIPELINE, - VISUAL_PROMPTING_PAIRS_LOCAL_PIPELINE, - VISUAL_PROMPTING_SINGLE_LOCAL_PIPELINE, - SPATIAL_GRID_PIPELINE, - SPATIAL_GRID_TEXTONLY_PIPELINE, - SPATIAL_GRID_REPORTING_PIPELINE, - SPATIAL_MAP_PIPELINE, - SPATIAL_MAP_TEXTONLY_PIPELINE, - SPATIAL_MAP_REPORTING_PIPELINE, - MAZE_PIPELINE, - MAZE_TEXTONLY_PIPELINE, - MAZE_REPORTING_PIPELINE, - IFEval_PIPELINE, - FlenQA_Experiment_Pipeline, - GPQA_Experiment_Pipeline, - Drop_Experiment_Pipeline, - GEOMETER_PIPELINE, - MMMU_BASELINE_PIPELINE, - KITAB_ONE_BOOK_CONSTRAINT_PIPELINE, - KITAB_ONE_BOOK_CONSTRAINT_PIPELINE_WITH_CONTEXT, - KITAB_ONE_BOOK_CONSTRAINT_PIPELINE_SELF_CONTEXT, - KITAB_TWO_BOOK_CONSTRAINT_PIPELINE, - KITAB_TWO_BOOK_CONSTRAINT_PIPELINE_WITH_CONTEXT, - GPT35_KITAB_ONE_BOOK_CONSTRAINT_PIPELINE, - DNA_PIPELINE, - ToxiGen_Discriminative_PIPELINE, - ToxiGen_Generative_PIPELINE, - Geo_Nondeterminism, - MMMU_Nondeterminism, - IFEval_Nondeterminism, - Kitab_Nondeterminism, - AIME_PIPELINE, create_logdir, ] diff --git a/eureka_ml_insights/configs/model_configs.py b/eureka_ml_insights/configs/model_configs.py index 01e2cd1..1b83a7b 100644 --- a/eureka_ml_insights/configs/model_configs.py +++ b/eureka_ml_insights/configs/model_configs.py @@ -13,6 +13,7 @@ LLaVAModel, MistralServerlessAzureRestEndpointModel, RestEndpointModel, + TestModel, ) from .config import ModelConfig @@ -21,6 +22,10 @@ # in the secret_key_params dictionary. OR you can provide the key name and key vault URL to fetch the key from Azure Key Vault. # You don't need to provide both the key_vault_url and local_keys_path. You can provide one of them based on your setup. + +# Test model +TEST_MODEL_CONFIG = ModelConfig(TestModel, {}) + # OpenAI models OPENAI_SECRET_KEY_PARAMS = { @@ -193,4 +198,4 @@ }, "model_name": "Mistral-large-2407", }, -) \ No newline at end of file +) diff --git a/eureka_ml_insights/core/__init__.py b/eureka_ml_insights/core/__init__.py index c79de3d..82a51ec 100644 --- a/eureka_ml_insights/core/__init__.py +++ b/eureka_ml_insights/core/__init__.py @@ -1,5 +1,5 @@ from .data_join import DataJoin -from .data_processing import DataProcessing, NumpyEncoder +from .data_processing import DataProcessing from .eval_reporting import EvalReporting from .inference import Inference from .pipeline import Component, Pipeline @@ -12,6 +12,5 @@ "EvalReporting", "DataProcessing", "PromptProcessing", - "NumpyEncoder", "DataJoin", ] diff --git a/eureka_ml_insights/core/data_processing.py b/eureka_ml_insights/core/data_processing.py index 8fc45e3..56a653a 100644 --- a/eureka_ml_insights/core/data_processing.py +++ b/eureka_ml_insights/core/data_processing.py @@ -1,11 +1,10 @@ -import base64 import json import logging import os from hashlib import md5 from typing import List, Optional -import numpy as np +from eureka_ml_insights.data_utils import NumpyEncoder from .pipeline import Component from .reserved_names import ( @@ -21,36 +20,6 @@ def compute_hash(val: str) -> str: return md5(val.encode("utf-8")).hexdigest() -class NumpyEncoder(json.JSONEncoder): - """Special json encoder for numpy types""" - - def default(self, obj): - if isinstance( - obj, - ( - np.int_, - np.intc, - np.intp, - np.int8, - np.int16, - np.int32, - np.int64, - np.uint8, - np.uint16, - np.uint32, - np.uint64, - ), - ): - return int(obj) - elif isinstance(obj, (np.float_, np.float16, np.float32, np.float64)): - return float(obj) - elif isinstance(obj, (np.ndarray,)): - return obj.tolist() - elif isinstance(obj, bytes): - return base64.b64encode(obj).decode("ascii") - return json.JSONEncoder.default(self, obj) - - class DataProcessing(Component): @classmethod def from_config(cls, config): diff --git a/eureka_ml_insights/core/eval_reporting.py b/eureka_ml_insights/core/eval_reporting.py index c7e72c7..7e2eaa9 100644 --- a/eureka_ml_insights/core/eval_reporting.py +++ b/eureka_ml_insights/core/eval_reporting.py @@ -2,9 +2,9 @@ import json import os +from eureka_ml_insights.data_utils import NumpyEncoder from eureka_ml_insights.metrics import Reporter -from .data_processing import NumpyEncoder from .pipeline import Component diff --git a/eureka_ml_insights/data_utils/__init__.py b/eureka_ml_insights/data_utils/__init__.py index 88b5921..5d93b55 100644 --- a/eureka_ml_insights/data_utils/__init__.py +++ b/eureka_ml_insights/data_utils/__init__.py @@ -11,8 +11,8 @@ MMDataLoader, TXTWriter, ) +from .encoders import NumpyEncoder from .prompt_processing import JinjaPromptTemplate -from .secret_key_utils import GetKey from .spatial_utils import ( ExtractAnswerGrid, ExtractAnswerMaze, @@ -67,11 +67,11 @@ RegexTransform, ASTEvalTransform, PrependStringTransform, - GetKey, ExtractAnswerGrid, ExtractAnswerSpatialMap, ExtractAnswerMaze, ShuffleColumnsTransform, ColumnMatchMapTransform, TokenCounterTransform, + NumpyEncoder, ] diff --git a/eureka_ml_insights/data_utils/data.py b/eureka_ml_insights/data_utils/data.py index fc520e9..55fd556 100644 --- a/eureka_ml_insights/data_utils/data.py +++ b/eureka_ml_insights/data_utils/data.py @@ -14,9 +14,9 @@ from PIL import Image from tqdm import tqdm -from eureka_ml_insights.core import NumpyEncoder +from eureka_ml_insights.secret_management import get_secret -from .secret_key_utils import GetKey +from .encoders import NumpyEncoder from .transform import DFTransformBase log = logging.getLogger("data_reader") @@ -216,14 +216,14 @@ def get_query_string(self, query_string=None, secret_key_params=None): One of the two arguments must be provided. args: query_string: str, query string to authenticate with Azure Blob Storage. - secret_key_params: dict, dictionary containing the paramters to call GetKey with. + secret_key_params: dict, dictionary containing the paramters to call get_secret with. """ self.query_string = query_string self.secret_key_params = secret_key_params if self.query_string is None and self.secret_key_params is None: raise ValueError("Either provide query_string or secret_key_params to load data from Azure.") if self.query_string is None: - self.query_string = GetKey(**secret_key_params) + self.query_string = get_secret(**secret_key_params) class AzureMMDataLoader(MMDataLoader): diff --git a/eureka_ml_insights/data_utils/encoders.py b/eureka_ml_insights/data_utils/encoders.py new file mode 100644 index 0000000..a8ab9d2 --- /dev/null +++ b/eureka_ml_insights/data_utils/encoders.py @@ -0,0 +1,32 @@ +import json +import base64 +import numpy as np + +class NumpyEncoder(json.JSONEncoder): + """Special json encoder for numpy types""" + + def default(self, obj): + if isinstance( + obj, + ( + np.int_, + np.intc, + np.intp, + np.int8, + np.int16, + np.int32, + np.int64, + np.uint8, + np.uint16, + np.uint32, + np.uint64, + ), + ): + return int(obj) + elif isinstance(obj, (np.float_, np.float16, np.float32, np.float64)): + return float(obj) + elif isinstance(obj, (np.ndarray,)): + return obj.tolist() + elif isinstance(obj, bytes): + return base64.b64encode(obj).decode("ascii") + return json.JSONEncoder.default(self, obj) \ No newline at end of file diff --git a/eureka_ml_insights/metrics/kitab_metrics.py b/eureka_ml_insights/metrics/kitab_metrics.py index 8d49f5d..c11aafb 100644 --- a/eureka_ml_insights/metrics/kitab_metrics.py +++ b/eureka_ml_insights/metrics/kitab_metrics.py @@ -11,17 +11,17 @@ import numpy as np import requests from azure.ai.textanalytics import TextAnalyticsClient -from azure.identity import DefaultAzureCredential from azure.core.credentials import AzureKeyCredential from azure.core.exceptions import ( HttpResponseError, ServiceRequestError, ServiceResponseError, ) +from azure.identity import DefaultAzureCredential from fuzzywuzzy import fuzz -from eureka_ml_insights.data_utils import GetKey, kitab_utils from eureka_ml_insights.metrics import CompositeMetric +from eureka_ml_insights.secret_management import get_secret class KitabMetric(CompositeMetric): @@ -40,10 +40,10 @@ def __init__(self, temp_path_names, azure_lang_service_config): ) # requires an Azure Cognitive Services Endpoint # (ref: https://learn.microsoft.com/en-us/azure/ai-services/language-service/) - self.key = GetKey( - key_name = azure_lang_service_config["secret_key_params"].get("key_name", None), - local_keys_path = azure_lang_service_config["secret_key_params"].get("local_keys_path", None), - key_vault_url = azure_lang_service_config["secret_key_params"].get("key_vault_url", None), + self.key = get_secret( + key_name=azure_lang_service_config["secret_key_params"].get("key_name", None), + local_keys_path=azure_lang_service_config["secret_key_params"].get("local_keys_path", None), + key_vault_url=azure_lang_service_config["secret_key_params"].get("key_vault_url", None), ) self.endpoint = azure_lang_service_config["url"] self.text_analytics_credential = self.get_verified_credential() @@ -51,19 +51,14 @@ def __init__(self, temp_path_names, azure_lang_service_config): def get_verified_credential(self): model_version = "latest" try: - text_analytics_client = TextAnalyticsClient( - endpoint=self.endpoint, credential=AzureKeyCredential(self.key) - ) + text_analytics_client = TextAnalyticsClient(endpoint=self.endpoint, credential=AzureKeyCredential(self.key)) text_analytics_client.recognize_entities(["New York City"], model_version=model_version) return AzureKeyCredential(self.key) except Exception as e: logging.info(f"Failed to create the TextAnalyticsClient using AzureKeyCredential") logging.info("The error is caused by: {}".format(e)) try: - text_analytics_client = TextAnalyticsClient( - endpoint=self.endpoint, credential=DefaultAzureCredential() - - ) + text_analytics_client = TextAnalyticsClient(endpoint=self.endpoint, credential=DefaultAzureCredential()) text_analytics_client.recognize_entities(["New York City"], model_version=model_version) return DefaultAzureCredential() except Exception as e: @@ -112,17 +107,13 @@ def process_row(self, row, gpt4_names): all_books = [] raw_unmapped = [] - mapped_books = [ - self.process_title(book) for book in ast.literal_eval(row["mapped_books"]) - ] + mapped_books = [self.process_title(book) for book in ast.literal_eval(row["mapped_books"])] model_books = ( [self.process_title(book) for book in row["model_books"]] if isinstance(row["model_books"], list) else [self.process_title(book) for book in row["model_books"]["titles"]] ) - all_books = [ - self.process_title(self.process_all_books(book)) for book in ast.literal_eval(row["all_books"]) - ] + all_books = [self.process_title(self.process_all_books(book)) for book in ast.literal_eval(row["all_books"])] raw_books = [self.process_title(book) for book in ast.literal_eval(row["raw_books"])] len(model_books) @@ -468,9 +459,7 @@ def extract_persons(self, text): text_analytics_client = TextAnalyticsClient( endpoint=self.endpoint, credential=self.text_analytics_credential, api_version="2023-04-01" ) - result = text_analytics_client.recognize_entities( - input_texts, model_version="2023-04-15-preview" - ) + result = text_analytics_client.recognize_entities(input_texts, model_version="2023-04-15-preview") error_flag = any([review.is_error for review in result]) result = [review for review in result if not review.is_error] diff --git a/eureka_ml_insights/metrics/spatial_and_layout_metrics.py b/eureka_ml_insights/metrics/spatial_and_layout_metrics.py index df1d9b1..eeec2cb 100644 --- a/eureka_ml_insights/metrics/spatial_and_layout_metrics.py +++ b/eureka_ml_insights/metrics/spatial_and_layout_metrics.py @@ -5,12 +5,8 @@ import nltk from pycocotools.coco import COCO -from eureka_ml_insights.data_utils import JsonReader -from eureka_ml_insights.metrics.metrics_base import ( - ClassicMetric, - DetectionMetric, - MultipleChoiceMetric, -) +from ..data_utils.data import JsonReader +from .metrics_base import ClassicMetric, DetectionMetric, MultipleChoiceMetric def download_nltk_resources(): diff --git a/eureka_ml_insights/models/__init__.py b/eureka_ml_insights/models/__init__.py index a4bbc88..3a8d024 100644 --- a/eureka_ml_insights/models/__init__.py +++ b/eureka_ml_insights/models/__init__.py @@ -13,6 +13,7 @@ MistralServerlessAzureRestEndpointModel, Phi3HFModel, RestEndpointModel, + TestModel, ) __all__ = [ @@ -30,4 +31,5 @@ LlamaServerlessAzureRestEndpointModel, LLaVAModel, RestEndpointModel, + TestModel, ] diff --git a/eureka_ml_insights/models/models.py b/eureka_ml_insights/models/models.py index 39bff8d..5681171 100644 --- a/eureka_ml_insights/models/models.py +++ b/eureka_ml_insights/models/models.py @@ -11,7 +11,7 @@ import tiktoken from azure.identity import DefaultAzureCredential, get_bearer_token_provider -from eureka_ml_insights.data_utils import GetKey +from eureka_ml_insights.secret_management import get_secret @dataclass @@ -77,10 +77,10 @@ def get_api_key(self): """ This method is used to get the api_key for the models that require key-based authentication. Either api_key (str) or secret_key_params (dict) must be provided. - if api_key is not directly provided, secret_key_params must be provided to get the api_key using GetKey method. + if api_key is not directly provided, secret_key_params must be provided to get the api_key using get_secret method. """ if self.api_key is None: - self.api_key = GetKey(**self.secret_key_params) + self.api_key = get_secret(**self.secret_key_params) return self.api_key @@ -403,7 +403,9 @@ class AzureOpenAIClientMixIn: def get_client(self): from openai import AzureOpenAI - token_provider = get_bearer_token_provider(DefaultAzureCredential(), "https://cognitiveservices.azure.com/.default") + token_provider = get_bearer_token_provider( + DefaultAzureCredential(), "https://cognitiveservices.azure.com/.default" + ) return AzureOpenAI( azure_endpoint=self.url, api_version=self.api_version, @@ -982,3 +984,21 @@ def get_response(self, request): def handle_request_error(self, e): return False + + +@dataclass +class TestModel(Model): + # This class is used for testing purposes only. It only waits for a specified time and returns a response. + response_time: float = 0.1 + model_output: str = "This is a test response." + + def __post_init__(self): + self.n_output_tokens = self.count_tokens() + + def generate(self, text_prompt, **kwargs): + return { + "model_output": self.model_output, + "is_valid": True, + "response_time": self.response_time, + "n_output_tokens": self.n_output_tokens, + } diff --git a/eureka_ml_insights/secret_management/__init__.py b/eureka_ml_insights/secret_management/__init__.py new file mode 100644 index 0000000..a1e5a89 --- /dev/null +++ b/eureka_ml_insights/secret_management/__init__.py @@ -0,0 +1 @@ +from .secret_key_utils import get_secret diff --git a/eureka_ml_insights/data_utils/secret_key_utils.py b/eureka_ml_insights/secret_management/secret_key_utils.py similarity index 96% rename from eureka_ml_insights/data_utils/secret_key_utils.py rename to eureka_ml_insights/secret_management/secret_key_utils.py index ce3f329..ab4c61c 100644 --- a/eureka_ml_insights/data_utils/secret_key_utils.py +++ b/eureka_ml_insights/secret_management/secret_key_utils.py @@ -9,7 +9,7 @@ logging.basicConfig(level=logging.INFO, format="%(filename)s - %(funcName)s - %(message)s") -def GetKey(key_name: str, local_keys_path:Optional[str]=None, key_vault_url:Optional[str]=None) -> Optional[str]: +def get_secret(key_name: str, local_keys_path:Optional[str]=None, key_vault_url:Optional[str]=None) -> Optional[str]: """This function retrieves a key from key vault or if it is locally cached in a file. args: key_name: str, the name of the key to retrieve. @@ -129,4 +129,4 @@ def get_cached_keys_dict(local_keys_path: str) -> Dict[str, str]: if __name__ == "__main__": key_name = "aifeval-datasets" key_vault_url = "https://aifeval.vault.azure.net/" - GetKey(key_name, key_vault_url=key_vault_url) + get_secret(key_name, key_vault_url=key_vault_url) diff --git a/eureka_ml_insights/user_configs/__init__.py b/eureka_ml_insights/user_configs/__init__.py new file mode 100644 index 0000000..0896652 --- /dev/null +++ b/eureka_ml_insights/user_configs/__init__.py @@ -0,0 +1,113 @@ +from .aime import AIME_PIPELINE +from .dna import DNA_PIPELINE +from .drop import Drop_Experiment_Pipeline +from .flenqa import FlenQA_Experiment_Pipeline +from .geometer import GEOMETER_PIPELINE +from .gpqa import GPQA_Experiment_Pipeline +from .ifeval import IFEval_PIPELINE +from .image_understanding.object_detection import ( + OBJECT_DETECTION_PAIRS_LOCAL_PIPELINE, + OBJECT_DETECTION_PAIRS_PIPELINE, + OBJECT_DETECTION_SINGLE_LOCAL_PIPELINE, + OBJECT_DETECTION_SINGLE_PIPELINE, +) +from .image_understanding.object_recognition import ( + OBJECT_RECOGNITION_PAIRS_LOCAL_PIPELINE, + OBJECT_RECOGNITION_PAIRS_PIPELINE, + OBJECT_RECOGNITION_SINGLE_LOCAL_PIPELINE, + OBJECT_RECOGNITION_SINGLE_PIPELINE, +) +from .image_understanding.spatial_reasoning import ( + SPATIAL_REASONING_PAIRS_LOCAL_PIPELINE, + SPATIAL_REASONING_PAIRS_PIPELINE, + SPATIAL_REASONING_SINGLE_LOCAL_PIPELINE, + SPATIAL_REASONING_SINGLE_PIPELINE, +) +from .image_understanding.visual_prompting import ( + VISUAL_PROMPTING_PAIRS_LOCAL_PIPELINE, + VISUAL_PROMPTING_PAIRS_PIPELINE, + VISUAL_PROMPTING_SINGLE_LOCAL_PIPELINE, + VISUAL_PROMPTING_SINGLE_PIPELINE, +) +from .kitab import ( + GPT35_KITAB_ONE_BOOK_CONSTRAINT_PIPELINE, + KITAB_ONE_BOOK_CONSTRAINT_PIPELINE, + KITAB_ONE_BOOK_CONSTRAINT_PIPELINE_SELF_CONTEXT, + KITAB_ONE_BOOK_CONSTRAINT_PIPELINE_WITH_CONTEXT, + KITAB_TWO_BOOK_CONSTRAINT_PIPELINE, + KITAB_TWO_BOOK_CONSTRAINT_PIPELINE_WITH_CONTEXT, +) +from .mmmu import MMMU_BASELINE_PIPELINE +from .nondeterminism import ( + Geo_Nondeterminism, + IFEval_Nondeterminism, + Kitab_Nondeterminism, + MMMU_Nondeterminism, +) +from .toxigen import ( + ToxiGen_Discriminative_PIPELINE, + ToxiGen_Generative_PIPELINE, +) +from .vision_language.maze import ( + MAZE_PIPELINE, + MAZE_REPORTING_PIPELINE, + MAZE_TEXTONLY_PIPELINE, +) +from .vision_language.spatial_grid import ( + SPATIAL_GRID_PIPELINE, + SPATIAL_GRID_REPORTING_PIPELINE, + SPATIAL_GRID_TEXTONLY_PIPELINE, +) +from .vision_language.spatial_map import ( + SPATIAL_MAP_PIPELINE, + SPATIAL_MAP_REPORTING_PIPELINE, + SPATIAL_MAP_TEXTONLY_PIPELINE, +) + +__all__ = [ + OBJECT_DETECTION_PAIRS_PIPELINE, + OBJECT_DETECTION_SINGLE_PIPELINE, + OBJECT_DETECTION_PAIRS_LOCAL_PIPELINE, + OBJECT_DETECTION_SINGLE_LOCAL_PIPELINE, + OBJECT_RECOGNITION_PAIRS_PIPELINE, + OBJECT_RECOGNITION_SINGLE_PIPELINE, + OBJECT_RECOGNITION_PAIRS_LOCAL_PIPELINE, + OBJECT_RECOGNITION_SINGLE_LOCAL_PIPELINE, + SPATIAL_REASONING_PAIRS_PIPELINE, + SPATIAL_REASONING_SINGLE_PIPELINE, + SPATIAL_REASONING_PAIRS_LOCAL_PIPELINE, + SPATIAL_REASONING_SINGLE_LOCAL_PIPELINE, + VISUAL_PROMPTING_PAIRS_PIPELINE, + VISUAL_PROMPTING_SINGLE_PIPELINE, + VISUAL_PROMPTING_PAIRS_LOCAL_PIPELINE, + VISUAL_PROMPTING_SINGLE_LOCAL_PIPELINE, + SPATIAL_GRID_PIPELINE, + SPATIAL_GRID_TEXTONLY_PIPELINE, + SPATIAL_GRID_REPORTING_PIPELINE, + SPATIAL_MAP_PIPELINE, + SPATIAL_MAP_TEXTONLY_PIPELINE, + SPATIAL_MAP_REPORTING_PIPELINE, + MAZE_PIPELINE, + MAZE_TEXTONLY_PIPELINE, + MAZE_REPORTING_PIPELINE, + IFEval_PIPELINE, + FlenQA_Experiment_Pipeline, + GPQA_Experiment_Pipeline, + Drop_Experiment_Pipeline, + GEOMETER_PIPELINE, + MMMU_BASELINE_PIPELINE, + KITAB_ONE_BOOK_CONSTRAINT_PIPELINE, + KITAB_ONE_BOOK_CONSTRAINT_PIPELINE_WITH_CONTEXT, + KITAB_ONE_BOOK_CONSTRAINT_PIPELINE_SELF_CONTEXT, + KITAB_TWO_BOOK_CONSTRAINT_PIPELINE, + KITAB_TWO_BOOK_CONSTRAINT_PIPELINE_WITH_CONTEXT, + GPT35_KITAB_ONE_BOOK_CONSTRAINT_PIPELINE, + DNA_PIPELINE, + ToxiGen_Discriminative_PIPELINE, + ToxiGen_Generative_PIPELINE, + Geo_Nondeterminism, + MMMU_Nondeterminism, + IFEval_Nondeterminism, + Kitab_Nondeterminism, + AIME_PIPELINE, +] diff --git a/eureka_ml_insights/configs/aime.py b/eureka_ml_insights/user_configs/aime.py similarity index 95% rename from eureka_ml_insights/configs/aime.py rename to eureka_ml_insights/user_configs/aime.py index 670db9e..6408529 100644 --- a/eureka_ml_insights/configs/aime.py +++ b/eureka_ml_insights/user_configs/aime.py @@ -16,7 +16,7 @@ from eureka_ml_insights.metrics.metrics_base import ExactMatch from eureka_ml_insights.metrics.reports import CountAggregator -from .config import ( +from eureka_ml_insights.configs import ( AggregatorConfig, DataProcessingConfig, DataSetConfig, @@ -27,7 +27,7 @@ PipelineConfig, PromptProcessingConfig, ) -from .experiment_config import ExperimentConfig +from eureka_ml_insights.configs import ExperimentConfig class AIME_PIPELINE(ExperimentConfig): diff --git a/eureka_ml_insights/configs/dna.py b/eureka_ml_insights/user_configs/dna.py similarity index 97% rename from eureka_ml_insights/configs/dna.py rename to eureka_ml_insights/user_configs/dna.py index 6210128..2df3d2d 100644 --- a/eureka_ml_insights/configs/dna.py +++ b/eureka_ml_insights/user_configs/dna.py @@ -18,7 +18,7 @@ from eureka_ml_insights.data_utils.transform import AddColumn from eureka_ml_insights.metrics.reports import CountAggregator -from .config import ( +from eureka_ml_insights.configs import ( AggregatorConfig, DataProcessingConfig, DataSetConfig, @@ -28,8 +28,8 @@ PipelineConfig, PromptProcessingConfig, ) -from .experiment_config import ExperimentConfig -from .model_configs import OAI_GPT4_1106_PREVIEW_CONFIG +from eureka_ml_insights.configs import ExperimentConfig +from eureka_ml_insights.configs.model_configs import OAI_GPT4_1106_PREVIEW_CONFIG class DNA_PIPELINE(ExperimentConfig): diff --git a/eureka_ml_insights/configs/drop.py b/eureka_ml_insights/user_configs/drop.py similarity index 97% rename from eureka_ml_insights/configs/drop.py rename to eureka_ml_insights/user_configs/drop.py index dda1fda..7d90204 100644 --- a/eureka_ml_insights/configs/drop.py +++ b/eureka_ml_insights/user_configs/drop.py @@ -14,7 +14,7 @@ ) from eureka_ml_insights.metrics import AverageAggregator, MaxTokenF1ScoreMetric -from .config import ( +from eureka_ml_insights.configs import ( AggregatorConfig, DataSetConfig, EvalReportingConfig, @@ -24,7 +24,7 @@ PipelineConfig, PromptProcessingConfig, ) -from .experiment_config import ExperimentConfig +from eureka_ml_insights.configs import ExperimentConfig """This file contains user defined configuration classes for the geometric reasoning task on geometer dataset. """ diff --git a/eureka_ml_insights/configs/flenqa.py b/eureka_ml_insights/user_configs/flenqa.py similarity index 97% rename from eureka_ml_insights/configs/flenqa.py rename to eureka_ml_insights/user_configs/flenqa.py index ae4c1ad..59ffbe1 100644 --- a/eureka_ml_insights/configs/flenqa.py +++ b/eureka_ml_insights/user_configs/flenqa.py @@ -17,7 +17,7 @@ from eureka_ml_insights.data_utils.flenqa_utils import FlenQAOutputProcessor from eureka_ml_insights.metrics import CountAggregator, ExactMatch -from .config import ( +from eureka_ml_insights.configs import ( AggregatorConfig, DataProcessingConfig, DataSetConfig, @@ -28,7 +28,7 @@ PipelineConfig, PromptProcessingConfig, ) -from .experiment_config import ExperimentConfig +from eureka_ml_insights.configs import ExperimentConfig class FlenQA_Experiment_Pipeline(ExperimentConfig): diff --git a/eureka_ml_insights/configs/geometer.py b/eureka_ml_insights/user_configs/geometer.py similarity index 96% rename from eureka_ml_insights/configs/geometer.py rename to eureka_ml_insights/user_configs/geometer.py index e084318..addbaa9 100644 --- a/eureka_ml_insights/configs/geometer.py +++ b/eureka_ml_insights/user_configs/geometer.py @@ -11,7 +11,7 @@ ) from eureka_ml_insights.metrics import CountAggregator, GeoMCQMetric -from .config import ( +from eureka_ml_insights.configs import( AggregatorConfig, DataSetConfig, EvalReportingConfig, @@ -21,7 +21,7 @@ PipelineConfig, PromptProcessingConfig, ) -from .experiment_config import ExperimentConfig +from eureka_ml_insights.configs import ExperimentConfig """This file contains user defined configuration classes for the geometric reasoning task on geometer dataset. """ diff --git a/eureka_ml_insights/configs/gpqa.py b/eureka_ml_insights/user_configs/gpqa.py similarity index 97% rename from eureka_ml_insights/configs/gpqa.py rename to eureka_ml_insights/user_configs/gpqa.py index 64340f7..56ac28a 100644 --- a/eureka_ml_insights/configs/gpqa.py +++ b/eureka_ml_insights/user_configs/gpqa.py @@ -15,7 +15,7 @@ ) from eureka_ml_insights.metrics import CountAggregator, ExactMatch -from .config import ( +from eureka_ml_insights.configs import( AggregatorConfig, DataSetConfig, EvalReportingConfig, @@ -25,7 +25,7 @@ PipelineConfig, PromptProcessingConfig, ) -from .experiment_config import ExperimentConfig +from eureka_ml_insights.configs import ExperimentConfig """This file contains user defined configuration classes for the geometric reasoning task on the GPQA dataset. """ diff --git a/eureka_ml_insights/configs/ifeval.py b/eureka_ml_insights/user_configs/ifeval.py similarity index 98% rename from eureka_ml_insights/configs/ifeval.py rename to eureka_ml_insights/user_configs/ifeval.py index a05a5a8..f1ef44d 100644 --- a/eureka_ml_insights/configs/ifeval.py +++ b/eureka_ml_insights/user_configs/ifeval.py @@ -20,7 +20,7 @@ TwoColumnSumAverageAggregator, ) -from .config import ( +from eureka_ml_insights.configs import( AggregatorConfig, DataProcessingConfig, DataSetConfig, @@ -31,7 +31,7 @@ PipelineConfig, PromptProcessingConfig, ) -from .experiment_config import ExperimentConfig +from eureka_ml_insights.configs import ExperimentConfig class IFEval_PIPELINE(ExperimentConfig): diff --git a/eureka_ml_insights/configs/image_understanding/__init__.py b/eureka_ml_insights/user_configs/image_understanding/__init__.py similarity index 100% rename from eureka_ml_insights/configs/image_understanding/__init__.py rename to eureka_ml_insights/user_configs/image_understanding/__init__.py diff --git a/eureka_ml_insights/configs/image_understanding/common.py b/eureka_ml_insights/user_configs/image_understanding/common.py similarity index 95% rename from eureka_ml_insights/configs/image_understanding/common.py rename to eureka_ml_insights/user_configs/image_understanding/common.py index d066d58..a2a2c8a 100644 --- a/eureka_ml_insights/configs/image_understanding/common.py +++ b/eureka_ml_insights/user_configs/image_understanding/common.py @@ -2,7 +2,7 @@ from eureka_ml_insights.data_utils import DataReader, MMDataLoader -from ..config import DataSetConfig +from eureka_ml_insights.configs import DataSetConfig class LOCAL_DATA_PIPELINE: diff --git a/eureka_ml_insights/configs/image_understanding/object_detection.py b/eureka_ml_insights/user_configs/image_understanding/object_detection.py similarity index 99% rename from eureka_ml_insights/configs/image_understanding/object_detection.py rename to eureka_ml_insights/user_configs/image_understanding/object_detection.py index 40820c2..8585184 100644 --- a/eureka_ml_insights/configs/image_understanding/object_detection.py +++ b/eureka_ml_insights/user_configs/image_understanding/object_detection.py @@ -17,7 +17,7 @@ CocoObjectDetectionMetric, ) -from ..config import ( +from eureka_ml_insights.configs import ( AggregatorConfig, DataSetConfig, EvalReportingConfig, diff --git a/eureka_ml_insights/configs/image_understanding/object_recognition.py b/eureka_ml_insights/user_configs/image_understanding/object_recognition.py similarity index 99% rename from eureka_ml_insights/configs/image_understanding/object_recognition.py rename to eureka_ml_insights/user_configs/image_understanding/object_recognition.py index c238a8a..0d7c2ae 100644 --- a/eureka_ml_insights/configs/image_understanding/object_recognition.py +++ b/eureka_ml_insights/user_configs/image_understanding/object_recognition.py @@ -18,7 +18,7 @@ ObjectRecognitionMetric, ) -from ..config import ( +from eureka_ml_insights.configs import ( AggregatorConfig, DataSetConfig, EvalReportingConfig, diff --git a/eureka_ml_insights/configs/image_understanding/spatial_reasoning.py b/eureka_ml_insights/user_configs/image_understanding/spatial_reasoning.py similarity index 99% rename from eureka_ml_insights/configs/image_understanding/spatial_reasoning.py rename to eureka_ml_insights/user_configs/image_understanding/spatial_reasoning.py index 9f13ecf..c2cbfa0 100644 --- a/eureka_ml_insights/configs/image_understanding/spatial_reasoning.py +++ b/eureka_ml_insights/user_configs/image_understanding/spatial_reasoning.py @@ -20,7 +20,7 @@ SpatialAndLayoutReasoningMetric, ) -from ..config import ( +from eureka_ml_insights.configs import ( AggregatorConfig, DataSetConfig, EvalReportingConfig, diff --git a/eureka_ml_insights/configs/image_understanding/visual_prompting.py b/eureka_ml_insights/user_configs/image_understanding/visual_prompting.py similarity index 99% rename from eureka_ml_insights/configs/image_understanding/visual_prompting.py rename to eureka_ml_insights/user_configs/image_understanding/visual_prompting.py index bb9438d..4adaefd 100644 --- a/eureka_ml_insights/configs/image_understanding/visual_prompting.py +++ b/eureka_ml_insights/user_configs/image_understanding/visual_prompting.py @@ -18,7 +18,7 @@ ObjectRecognitionMetric, ) -from ..config import ( +from eureka_ml_insights.configs import ( AggregatorConfig, DataSetConfig, EvalReportingConfig, diff --git a/eureka_ml_insights/configs/kitab.py b/eureka_ml_insights/user_configs/kitab.py similarity index 99% rename from eureka_ml_insights/configs/kitab.py rename to eureka_ml_insights/user_configs/kitab.py index 97bdd63..0decdab 100644 --- a/eureka_ml_insights/configs/kitab.py +++ b/eureka_ml_insights/user_configs/kitab.py @@ -22,7 +22,7 @@ from eureka_ml_insights.metrics import AverageAggregator from eureka_ml_insights.metrics.kitab_metrics import KitabMetric -from .config import ( +from eureka_ml_insights.configs import( AggregatorConfig, DataProcessingConfig, DataSetConfig, @@ -33,7 +33,7 @@ PipelineConfig, PromptProcessingConfig, ) -from .experiment_config import ExperimentConfig +from eureka_ml_insights.configs import ExperimentConfig # Example template for an Azure Language Service config # required for running entity recognition for evaluating human and city name diff --git a/eureka_ml_insights/configs/mmmu.py b/eureka_ml_insights/user_configs/mmmu.py similarity index 98% rename from eureka_ml_insights/configs/mmmu.py rename to eureka_ml_insights/user_configs/mmmu.py index b7c6980..524c73a 100644 --- a/eureka_ml_insights/configs/mmmu.py +++ b/eureka_ml_insights/user_configs/mmmu.py @@ -20,7 +20,7 @@ ) from eureka_ml_insights.metrics import CountAggregator, MMMUMetric -from .config import ( +from eureka_ml_insights.configs import( AggregatorConfig, DataSetConfig, EvalReportingConfig, diff --git a/eureka_ml_insights/configs/nondeterminism.py b/eureka_ml_insights/user_configs/nondeterminism.py similarity index 100% rename from eureka_ml_insights/configs/nondeterminism.py rename to eureka_ml_insights/user_configs/nondeterminism.py diff --git a/eureka_ml_insights/configs/specifications/aime_spec.txt b/eureka_ml_insights/user_configs/specifications/aime_spec.txt similarity index 100% rename from eureka_ml_insights/configs/specifications/aime_spec.txt rename to eureka_ml_insights/user_configs/specifications/aime_spec.txt diff --git a/eureka_ml_insights/configs/specifications/dna_spec.txt b/eureka_ml_insights/user_configs/specifications/dna_spec.txt similarity index 100% rename from eureka_ml_insights/configs/specifications/dna_spec.txt rename to eureka_ml_insights/user_configs/specifications/dna_spec.txt diff --git a/eureka_ml_insights/configs/specifications/flenqa_spec.txt b/eureka_ml_insights/user_configs/specifications/flenqa_spec.txt similarity index 100% rename from eureka_ml_insights/configs/specifications/flenqa_spec.txt rename to eureka_ml_insights/user_configs/specifications/flenqa_spec.txt diff --git a/eureka_ml_insights/configs/specifications/geometer_spec.txt b/eureka_ml_insights/user_configs/specifications/geometer_spec.txt similarity index 100% rename from eureka_ml_insights/configs/specifications/geometer_spec.txt rename to eureka_ml_insights/user_configs/specifications/geometer_spec.txt diff --git a/eureka_ml_insights/configs/specifications/ifeval_spec.txt b/eureka_ml_insights/user_configs/specifications/ifeval_spec.txt similarity index 100% rename from eureka_ml_insights/configs/specifications/ifeval_spec.txt rename to eureka_ml_insights/user_configs/specifications/ifeval_spec.txt diff --git a/eureka_ml_insights/configs/specifications/image_understanding.txt b/eureka_ml_insights/user_configs/specifications/image_understanding.txt similarity index 100% rename from eureka_ml_insights/configs/specifications/image_understanding.txt rename to eureka_ml_insights/user_configs/specifications/image_understanding.txt diff --git a/eureka_ml_insights/configs/specifications/kitab_spec.txt b/eureka_ml_insights/user_configs/specifications/kitab_spec.txt similarity index 100% rename from eureka_ml_insights/configs/specifications/kitab_spec.txt rename to eureka_ml_insights/user_configs/specifications/kitab_spec.txt diff --git a/eureka_ml_insights/configs/specifications/mmmu.txt b/eureka_ml_insights/user_configs/specifications/mmmu.txt similarity index 100% rename from eureka_ml_insights/configs/specifications/mmmu.txt rename to eureka_ml_insights/user_configs/specifications/mmmu.txt diff --git a/eureka_ml_insights/configs/specifications/physical_holoassist_spec.txt b/eureka_ml_insights/user_configs/specifications/physical_holoassist_spec.txt similarity index 100% rename from eureka_ml_insights/configs/specifications/physical_holoassist_spec.txt rename to eureka_ml_insights/user_configs/specifications/physical_holoassist_spec.txt diff --git a/eureka_ml_insights/configs/specifications/toxigen_spec.txt b/eureka_ml_insights/user_configs/specifications/toxigen_spec.txt similarity index 100% rename from eureka_ml_insights/configs/specifications/toxigen_spec.txt rename to eureka_ml_insights/user_configs/specifications/toxigen_spec.txt diff --git a/eureka_ml_insights/configs/specifications/vision_language.txt b/eureka_ml_insights/user_configs/specifications/vision_language.txt similarity index 100% rename from eureka_ml_insights/configs/specifications/vision_language.txt rename to eureka_ml_insights/user_configs/specifications/vision_language.txt diff --git a/eureka_ml_insights/configs/toxigen.py b/eureka_ml_insights/user_configs/toxigen.py similarity index 98% rename from eureka_ml_insights/configs/toxigen.py rename to eureka_ml_insights/user_configs/toxigen.py index 6c2583b..ab9fb91 100644 --- a/eureka_ml_insights/configs/toxigen.py +++ b/eureka_ml_insights/user_configs/toxigen.py @@ -25,7 +25,7 @@ ExactMatch, ) -from .config import ( +from eureka_ml_insights.configs import( AggregatorConfig, DataProcessingConfig, DataSetConfig, @@ -35,8 +35,8 @@ PipelineConfig, PromptProcessingConfig, ) -from .experiment_config import ExperimentConfig -from .model_configs import OAI_GPT4_1106_PREVIEW_CONFIG +from eureka_ml_insights.configs import ExperimentConfig +from eureka_ml_insights.configs.model_configs import OAI_GPT4_1106_PREVIEW_CONFIG """This class specifies the config for running Toxigen discriminative benchmark.""" diff --git a/eureka_ml_insights/configs/vision_language/__init__.py b/eureka_ml_insights/user_configs/vision_language/__init__.py similarity index 100% rename from eureka_ml_insights/configs/vision_language/__init__.py rename to eureka_ml_insights/user_configs/vision_language/__init__.py diff --git a/eureka_ml_insights/configs/vision_language/maze.py b/eureka_ml_insights/user_configs/vision_language/maze.py similarity index 99% rename from eureka_ml_insights/configs/vision_language/maze.py rename to eureka_ml_insights/user_configs/vision_language/maze.py index 0f8da2e..7294a65 100644 --- a/eureka_ml_insights/configs/vision_language/maze.py +++ b/eureka_ml_insights/user_configs/vision_language/maze.py @@ -14,7 +14,7 @@ ) from eureka_ml_insights.metrics import CaseInsensitiveMatch, CountAggregator -from ..config import ( +from eureka_ml_insights.configs import ( AggregatorConfig, DataSetConfig, EvalReportingConfig, diff --git a/eureka_ml_insights/configs/vision_language/spatial_grid.py b/eureka_ml_insights/user_configs/vision_language/spatial_grid.py similarity index 99% rename from eureka_ml_insights/configs/vision_language/spatial_grid.py rename to eureka_ml_insights/user_configs/vision_language/spatial_grid.py index 92166eb..c45d6d7 100644 --- a/eureka_ml_insights/configs/vision_language/spatial_grid.py +++ b/eureka_ml_insights/user_configs/vision_language/spatial_grid.py @@ -13,7 +13,7 @@ SequenceTransform, ) from eureka_ml_insights.metrics import CaseInsensitiveMatch, CountAggregator -from ..config import ( +from eureka_ml_insights.configs import ( AggregatorConfig, DataSetConfig, EvalReportingConfig, diff --git a/eureka_ml_insights/configs/vision_language/spatial_map.py b/eureka_ml_insights/user_configs/vision_language/spatial_map.py similarity index 99% rename from eureka_ml_insights/configs/vision_language/spatial_map.py rename to eureka_ml_insights/user_configs/vision_language/spatial_map.py index 7a4cfdf..1453335 100644 --- a/eureka_ml_insights/configs/vision_language/spatial_map.py +++ b/eureka_ml_insights/user_configs/vision_language/spatial_map.py @@ -14,7 +14,7 @@ ) from eureka_ml_insights.metrics import CaseInsensitiveMatch, CountAggregator -from ..config import ( +from eureka_ml_insights.configs import ( AggregatorConfig, DataSetConfig, EvalReportingConfig, diff --git a/main.py b/main.py index 5e6b143..8eed49c 100755 --- a/main.py +++ b/main.py @@ -3,7 +3,7 @@ import argparse import logging -from eureka_ml_insights import configs +from eureka_ml_insights import user_configs as configs from eureka_ml_insights.configs import model_configs from eureka_ml_insights.core import Pipeline diff --git a/tests/pipeline_tests.py b/tests/pipeline_tests.py index 071c1d8..1d86e33 100644 --- a/tests/pipeline_tests.py +++ b/tests/pipeline_tests.py @@ -12,7 +12,14 @@ path = str(Path(Path(__file__).parent.absolute()).parent.absolute()) # noqa sys.path.insert(0, path) # noqa -from eureka_ml_insights.configs import ( +from eureka_ml_insights.configs import MetricConfig, ModelConfig +from eureka_ml_insights.core import Pipeline +from eureka_ml_insights.data_utils.transform import ( + RunPythonTransform, + SamplerTransform, + SequenceTransform, +) +from eureka_ml_insights.user_configs import ( AIME_PIPELINE, DNA_PIPELINE, GEOMETER_PIPELINE, @@ -28,20 +35,12 @@ SPATIAL_MAP_TEXTONLY_PIPELINE, SPATIAL_REASONING_SINGLE_PIPELINE, VISUAL_PROMPTING_SINGLE_PIPELINE, - GPQA_Experiment_Pipeline, Drop_Experiment_Pipeline, + GPQA_Experiment_Pipeline, IFEval_PIPELINE, - MetricConfig, - ModelConfig, ToxiGen_Discriminative_PIPELINE, ToxiGen_Generative_PIPELINE, ) -from eureka_ml_insights.core import Pipeline -from eureka_ml_insights.data_utils.transform import ( - RunPythonTransform, - SamplerTransform, - SequenceTransform, -) from tests.test_utils import ( DetectionTestModel, DNAEvaluationInferenceTestModel, @@ -262,6 +261,7 @@ def configure_pipeline(self): } return config + class TEST_TOXIGEN_GEN_PIPELINE(ToxiGen_Generative_PIPELINE): def configure_pipeline(self): config = super().configure_pipeline(model_config=ModelConfig(GenericTestModel, {})) @@ -297,7 +297,8 @@ def configure_pipeline(self): "n_iter": N_ITER, } return config - + + class TEST_DROP_PIPELINE(Drop_Experiment_Pipeline): # Test config the Drop benchmark with TestModel and TestDataLoader def configure_pipeline(self): @@ -462,14 +463,17 @@ class TOXIGEN_PipelineTest(PipelineTest, unittest.TestCase): def get_config(self): return TEST_TOXIGEN_PIPELINE().pipeline_config + class TOXIGEN_GEN_PipelineTest(PipelineTest, unittest.TestCase): def get_config(self): return TEST_TOXIGEN_GEN_PIPELINE().pipeline_config + class KITAB_ONE_BOOK_CONSTRAINT_PIPELINE_PipelineTest(PipelineTest, unittest.TestCase): def get_config(self): return TEST_KITAB_ONE_BOOK_CONSTRAINT_PIPELINE().pipeline_config + @unittest.skipIf("skip_tests_with_missing_ds" in os.environ, "Missing public dataset. TODO: revert") class GPQA_PipelineTest(PipelineTest, unittest.TestCase): def get_config(self):