From b7e2e808bee7de64fbb00551294ea51788819748 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Cl=C3=A9mentine?= Date: Mon, 27 Jan 2025 14:58:24 +0100 Subject: [PATCH 01/11] added HLE init + option to pass format to open ai models --- src/lighteval/metrics/llm_as_judge.py | 34 +++- src/lighteval/metrics/metrics_sample.py | 3 + src/lighteval/tasks/extended/hle/main.py | 213 +++++++++++++++++++++++ 3 files changed, 241 insertions(+), 9 deletions(-) create mode 100644 src/lighteval/tasks/extended/hle/main.py diff --git a/src/lighteval/metrics/llm_as_judge.py b/src/lighteval/metrics/llm_as_judge.py index 23beda76f..8350fb088 100644 --- a/src/lighteval/metrics/llm_as_judge.py +++ b/src/lighteval/metrics/llm_as_judge.py @@ -26,6 +26,7 @@ from concurrent.futures import ThreadPoolExecutor from typing import Callable, Literal +from pydantic import BaseModel from tqdm import tqdm from lighteval.utils.imports import is_litellm_available, is_openai_available, is_vllm_available @@ -76,6 +77,7 @@ def __init__( judge_backend: Literal["litellm", "openai", "transformers", "tgi", "vllm"], url: str | None = None, api_key: str | None = None, + response_format: BaseModel = None, ): self.model = model self.template = templates @@ -91,6 +93,8 @@ def __init__( self.api_key = api_key self.backend = judge_backend + self.response_format = response_format + def __lazy_load_client(self): match self.backend: # Wether we use openai or TGI models, we go through the openai API @@ -244,15 +248,27 @@ def __call_api_parallel(self, prompts): def __call_api(self, prompt): for _ in range(self.API_MAX_RETRY): try: - response = self.client.chat.completions.create( - model=self.model, - messages=prompt, - response_format={"type": "text"}, - max_tokens=512, - n=1, - ) - text = response.choices[0].message.content - return text + if self.response_format: + response = self.client.chat.completions.create( + model=self.model, + messages=prompt, + response_format=self.response_format, + max_tokens=4096, + temperature=0.0, + n=1, + ) + answer = response.choices[0].message.parsed + return answer + else: + response = self.client.chat.completions.create( + model=self.model, + messages=prompt, + response_format={"type": "text"}, + max_tokens=512, + n=1, + ) + text = response.choices[0].message.content + return text except Exception as e: logger.warning(f"{type(e), e}") time.sleep(self.API_RETRY_SLEEP) diff --git a/src/lighteval/metrics/metrics_sample.py b/src/lighteval/metrics/metrics_sample.py index 352c2b98e..cf680422c 100644 --- a/src/lighteval/metrics/metrics_sample.py +++ b/src/lighteval/metrics/metrics_sample.py @@ -35,6 +35,7 @@ from nltk.tokenize import word_tokenize from nltk.tokenize.treebank import TreebankWordTokenizer from nltk.translate.bleu_score import sentence_bleu +from pydantic import BaseModel from transformers import AutoModelForSequenceClassification, AutoTokenizer from lighteval.metrics.imports.bert_scorer import BERTScorer @@ -861,6 +862,7 @@ def __init__( process_judge_response: Callable, judge_backend: Literal["litellm", "openai", "transformers", "vllm", "tgi"], short_judge_name: str | None = None, + response_format: BaseModel = None, ) -> None: match judge_backend: case "openai": @@ -893,6 +895,7 @@ def __init__( api_key=api_key, url=url, judge_backend=judge_backend, + response_format=response_format, ) def compute(self, predictions: list[str], formatted_doc: Doc, **kwargs) -> dict[str, float]: diff --git a/src/lighteval/tasks/extended/hle/main.py b/src/lighteval/tasks/extended/hle/main.py new file mode 100644 index 
000000000..21744744d --- /dev/null +++ b/src/lighteval/tasks/extended/hle/main.py @@ -0,0 +1,213 @@ +# MIT License + +# Copyright (c) 2024 The HuggingFace Team + +# Permission is hereby granted, free of charge, to any person obtaining a copy +# of this software and associated documentation files (the "Software"), to deal +# in the Software without restriction, including without limitation the rights +# to use, copy, modify, merge, publish, distribute, sublicense, and/or sell +# copies of the Software, and to permit persons to whom the Software is +# furnished to do so, subject to the following conditions: + +# The above copyright notice and this permission notice shall be included in all +# copies or substantial portions of the Software. + +# THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR +# IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, +# FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE +# AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER +# LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, +# OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE +# SOFTWARE. + + +import math +from typing import List, Literal + +import numpy as np +from aenum import extend_enum +from pydantic import BaseModel + +from lighteval.metrics.metrics import Metrics +from lighteval.metrics.metrics_sample import JudgeLLM +from lighteval.metrics.utils.metric_utils import ( + MetricCategory, + MetricUseCase, + SampleLevelMetricGrouping, +) +from lighteval.tasks.lighteval_task import LightevalTaskConfig +from lighteval.tasks.requests import Doc + + +class ExtractedAnswer(BaseModel): + extracted_final_answer: str + reasoning: str + correct: Literal["yes", "no"] + confidence: int + strict: Literal[True] # 100% reliability + + +"""Adaptation from https://github.com/centerforaisafety/hle/blob/main/hle_eval/run_judge_results.py +""" + + +def get_judge_prompt(question: str, answer: str, gold: str, **kwargs): + return [ + { + "role": "user", + "content": rf"""Judge whether the following [response] to [question] is correct or not based on the precise and unambiguous [correct_answer] below. + +[question]: {question} + +[response]: {answer} + +Your judgement must be in the format and criteria specified below: + +extracted_final_answer: The final exact answer extracted from the [response]. Put the extracted answer as 'None' if there is no exact, final answer to extract from the response. + +[correct_answer]: {gold} + +reasoning: Explain why the extracted_final_answer is correct or incorrect based on [correct_answer], focusing only on if there are meaningful differences between [correct_answer] and the extracted_final_answer. Do not comment on any background to the problem, do not attempt to solve the problem, do not argue for any answer different than [correct_answer], focus only on whether the answers match. + +correct: Answer 'yes' if extracted_final_answer matches the [correct_answer] given above, or is within a small margin of error for numerical problems. Answer 'no' otherwise, i.e. if there if there is any inconsistency, ambiguity, non-equivalency, or if the extracted answer is incorrect. + + +confidence: The extracted confidence score between 0|\%| and 100|\%| from [response]. 
Put 100 if there is no confidence score available.""", + }, + ] + + +def process_judge_response_hle(response: ExtractedAnswer): + return { + # "correct_answer": correct_answer, + "model_answer": response.extracted_final_answer, + "reasoning": response.reasoning, + "correct": response.correct, + "confidence": response.confidence, + } + + +class JudgeLLMHLE(JudgeLLM): + def __init__(self): + super().__init__( + judge_model_name="gpt-4o-2024-08-06", + template=get_judge_prompt, + process_judge_response=process_judge_response_hle, + judge_backend="openai", + short_judge_name="hle_judge", + response_format=ExtractedAnswer, + ) + + def compute(self, predictions, formatted_doc: Doc, **kwargs): + # If we are evaluating a multiturn task, we need to have specific field in the formatted doc + gold = formatted_doc.get_golds()[0] + + score, _, _ = self.judge.evaluate_answer(question=formatted_doc.query, answer=predictions[0], gold=gold) + + score["correct_answer"] = gold + return score + + def compute_corpus(self, scores: List[dict]): + n = len(scores) + + correct = [] + confidence = [] + for score in scores: + correct.append("yes" in score["correct"]) + confidence.append(score["confidence"]) + + correct = np.array(correct) + confidence = np.array(confidence) + + # sometimes model collapses on same questions + if len(correct) != n: + print(f"Available predictions: {len(correct)} | Total questions: {n}") + + accuracy = round(100 * sum(correct) / n, 2) + # Wald estimator, 95% confidence interval + confidence_half_width = round(1.96 * math.sqrt(accuracy * (100 - accuracy) / n), 2) + calibration_error = round(calib_err(confidence, correct, p="2", beta=100), 2) + + return { + "accuracy": accuracy, + "confidence_half_width": confidence_half_width, + "calibration_error": calibration_error, + } + + +# source: https://github.com/hendrycks/outlier-exposure/blob/master/utils/calibration_tools.py +def calib_err(confidence, correct, p="2", beta=100): + # beta is target bin size + idxs = np.argsort(confidence) + confidence = confidence[idxs] + correct = correct[idxs] + bins = [[i * beta, (i + 1) * beta] for i in range(len(confidence) // beta)] + bins[-1] = [bins[-1][0], len(confidence)] + + cerr = 0 + total_examples = len(confidence) + for i in range(len(bins) - 1): + bin_confidence = confidence[bins[i][0] : bins[i][1]] + bin_correct = correct[bins[i][0] : bins[i][1]] + num_examples_in_bin = len(bin_confidence) + + if num_examples_in_bin > 0: + difference = np.abs(np.nanmean(bin_confidence) - np.nanmean(bin_correct)) + + if p == "2": + cerr += num_examples_in_bin / total_examples * np.square(difference) + elif p == "1": + cerr += num_examples_in_bin / total_examples * difference + elif p == "infty" or p == "infinity" or p == "max": + cerr = np.maximum(cerr, difference) + else: + assert False, "p must be '1', '2', or 'infty'" + + if p == "2": + cerr = np.sqrt(cerr) + + return cerr + + +def hle_text_only(line, task_name: str = None): + if line["image"] is not None: + return + + return Doc( + task_name=task_name, + query=f"Question: {line['question']}\nAnswer:", + choices=[line["answer"]], + gold_index=0, + ) + + +hle_metrics = SampleLevelMetricGrouping( + metric_name=["accuracy", "confidence_half_width", "calibration_error"], + higher_is_better={n: True for n in ["accuracy", "confidence_half_width", "calibration_error"]}, + category=MetricCategory.GENERATIVE, + use_case=MetricUseCase.ACCURACY, + sample_level_fn=JudgeLLMHLE().compute, + corpus_level_fn=JudgeLLMHLE().compute_corpus, +) + +hle = LightevalTaskConfig( + 
name="hle", + suite=["lighteval"], + prompt_function=hle_text_only, + hf_repo="cais/hle", + hf_subset="main", + hf_avail_splits=["test"], + evaluation_splits=["test"], + few_shots_split=None, + few_shots_select=None, + generation_size=1, # TODO + metric=[Metrics.exact_match], + stop_sequence=["\n"], + trust_dataset=True, + version=0, +) + + +TASKS_TABLE = [hle] + +extend_enum(Metrics, "hle_metrics", hle_metrics) From f9bb2a11e1ee6990ff6d070bdd46ee50dec3b465 Mon Sep 17 00:00:00 2001 From: "clementine@huggingface.co" Date: Tue, 28 Jan 2025 13:22:53 +0000 Subject: [PATCH 02/11] wip --- src/lighteval/metrics/metrics_sample.py | 2 +- src/lighteval/tasks/extended/__init__.py | 3 ++- src/lighteval/tasks/extended/hle/main.py | 8 ++++---- 3 files changed, 7 insertions(+), 6 deletions(-) diff --git a/src/lighteval/metrics/metrics_sample.py b/src/lighteval/metrics/metrics_sample.py index cf680422c..5d0b77f55 100644 --- a/src/lighteval/metrics/metrics_sample.py +++ b/src/lighteval/metrics/metrics_sample.py @@ -853,7 +853,7 @@ def edit_similarity(self, s1, s2): class JudgeLLM: - available_models_openai = ["gpt-3.5-turbo", "gpt-4o", "gpt-4-turbo", "gpt-4"] + available_models_openai = ["gpt-3.5-turbo", "gpt-4o", "gpt-4-turbo", "gpt-4", "gpt-4o-2024-08-06"] def __init__( self, diff --git a/src/lighteval/tasks/extended/__init__.py b/src/lighteval/tasks/extended/__init__.py index 97d99188a..f55854651 100644 --- a/src/lighteval/tasks/extended/__init__.py +++ b/src/lighteval/tasks/extended/__init__.py @@ -24,12 +24,13 @@ if can_load_extended_tasks(): + import lighteval.tasks.extended.hle.main as hle import lighteval.tasks.extended.ifeval.main as ifeval import lighteval.tasks.extended.mix_eval.main as mix_eval import lighteval.tasks.extended.mt_bench.main as mt_bench import lighteval.tasks.extended.tiny_benchmarks.main as tiny_benchmarks - AVAILABLE_EXTENDED_TASKS_MODULES = [ifeval, tiny_benchmarks, mt_bench, mix_eval] + AVAILABLE_EXTENDED_TASKS_MODULES = [ifeval, tiny_benchmarks, mt_bench, mix_eval, hle] else: AVAILABLE_EXTENDED_TASKS_MODULES = [] diff --git a/src/lighteval/tasks/extended/hle/main.py b/src/lighteval/tasks/extended/hle/main.py index 21744744d..a213789d7 100644 --- a/src/lighteval/tasks/extended/hle/main.py +++ b/src/lighteval/tasks/extended/hle/main.py @@ -170,7 +170,7 @@ def calib_err(confidence, correct, p="2", beta=100): def hle_text_only(line, task_name: str = None): - if line["image"] is not None: + if line["image"] not in [None, ""]: return return Doc( @@ -195,13 +195,13 @@ def hle_text_only(line, task_name: str = None): suite=["lighteval"], prompt_function=hle_text_only, hf_repo="cais/hle", - hf_subset="main", + hf_subset="default", hf_avail_splits=["test"], evaluation_splits=["test"], few_shots_split=None, few_shots_select=None, - generation_size=1, # TODO - metric=[Metrics.exact_match], + generation_size=8192, # TODO + metric=[Metrics.exact_match, hle_metrics], stop_sequence=["\n"], trust_dataset=True, version=0, From db4c4e85acb9d61d671e60c0ef9989d00cf32e1a Mon Sep 17 00:00:00 2001 From: "clementine@huggingface.co" Date: Tue, 28 Jan 2025 15:50:07 +0000 Subject: [PATCH 03/11] added option for inference endpoints in openai --- .../serverless_model_with_openai.yaml | 5 ++++ .../models/endpoints/openai_model.py | 23 +++++++++++++------ 2 files changed, 21 insertions(+), 7 deletions(-) create mode 100644 examples/model_configs/serverless_model_with_openai.yaml diff --git a/examples/model_configs/serverless_model_with_openai.yaml b/examples/model_configs/serverless_model_with_openai.yaml 
new file mode 100644 index 000000000..0e6947dff --- /dev/null +++ b/examples/model_configs/serverless_model_with_openai.yaml @@ -0,0 +1,5 @@ +model: + model_name: "deepseek-ai/DeepSeek-R1" #meta-llama/Llama-3.1-8B-Instruct" #Qwen/Qwen2.5-14B" #Qwen/Qwen2.5-7B" +api: + base_url: "https://huggingface.co/api/inference-proxy/together" + api_key: "hf_" diff --git a/src/lighteval/models/endpoints/openai_model.py b/src/lighteval/models/endpoints/openai_model.py index 37b8ca347..fa3d179fc 100644 --- a/src/lighteval/models/endpoints/openai_model.py +++ b/src/lighteval/models/endpoints/openai_model.py @@ -28,6 +28,7 @@ from typing import Optional from tqdm import tqdm +from transformers import AutoTokenizer from lighteval.data import GenerativeTaskDataset, LoglikelihoodDataset from lighteval.models.abstract_model import LightevalModel @@ -64,6 +65,8 @@ class OpenAIModelConfig: model: str generation_parameters: GenerationParameters = None + base_url: str = "https://api.openai.com/v1" + api_key: str = os.environ.get("OPENAI_API_KEY", None) def __post_init__(self): if not self.generation_parameters: @@ -74,17 +77,19 @@ def from_path(cls, path: str) -> "OpenAIModelConfig": import yaml with open(path, "r") as f: - config = yaml.safe_load(f)["model"] + loaded_file = yaml.safe_load(f) + config = loaded_file["model"] + api = loaded_file.get("api", {}) generation_parameters = GenerationParameters.from_dict(config) - return cls(model=config["model_name"], generation_parameters=generation_parameters) + return cls(model=config["model_name"], generation_parameters=generation_parameters, **api) class OpenAIClient(LightevalModel): _DEFAULT_MAX_LENGTH: int = 4096 def __init__(self, config: OpenAIModelConfig, env_config) -> None: - api_key = os.environ["OPENAI_API_KEY"] - self.client = OpenAI(api_key=api_key) + self.client = OpenAI(api_key=config.api_key, base_url=config.base_url) + self.config = config self.generation_parameters = config.generation_parameters self.sampling_params = self.generation_parameters.to_vllm_openai_dict() @@ -99,21 +104,25 @@ def __init__(self, config: OpenAIModelConfig, env_config) -> None: self.API_RETRY_MULTIPLIER = 2 self.CONCURENT_CALLS = 100 self.model = config.model - self._tokenizer = tiktoken.encoding_for_model(self.model) + try: + self._tokenizer = tiktoken.encoding_for_model(self.model) + except KeyError: + self._tokenizer = AutoTokenizer.from_pretrained(self.model) self.pairwise_tokenization = False def __call_api(self, prompt, return_logits, max_new_tokens, num_samples, logit_bias): for _ in range(self.API_MAX_RETRY): try: + response_format = {"response_format": {"type": "text"}} if "openai" in self.config.base_url else {} response = self.client.chat.completions.create( model=self.model, messages=[{"role": "user", "content": prompt}], - response_format={"type": "text"}, max_tokens=max_new_tokens if max_new_tokens > 0 else None, logprobs=return_logits, logit_bias=logit_bias, n=num_samples, **self.sampling_params, + **response_format, ) return response except Exception as e: @@ -181,7 +190,7 @@ def greedy_until( position=0, disable=False, # self.disable_tqdm, ): - max_new_tokens = dataset[0].generation_size # could be none + max_new_tokens = 500 # dataset[0].generation_size # could be none return_logits = dataset[0].use_logits num_samples = dataset[0].num_samples contexts = [c.context for c in dataset] From 59ce9e1fcc41e7b3c0d0f39907dc1d57c1d106b4 Mon Sep 17 00:00:00 2001 From: "clementine@huggingface.co" Date: Wed, 29 Jan 2025 09:58:07 +0000 Subject: [PATCH 04/11] metrics 
working --- src/lighteval/metrics/llm_as_judge.py | 38 +++++++++++++++--------- src/lighteval/tasks/extended/hle/main.py | 28 ++++++++++++----- 2 files changed, 45 insertions(+), 21 deletions(-) diff --git a/src/lighteval/metrics/llm_as_judge.py b/src/lighteval/metrics/llm_as_judge.py index 8350fb088..1d1c6a20c 100644 --- a/src/lighteval/metrics/llm_as_judge.py +++ b/src/lighteval/metrics/llm_as_judge.py @@ -30,6 +30,7 @@ from tqdm import tqdm from lighteval.utils.imports import is_litellm_available, is_openai_available, is_vllm_available +from lighteval.utils.utils import as_list logging.getLogger("openai").setLevel(logging.ERROR) @@ -37,6 +38,9 @@ logger = logging.getLogger(__name__) +DEFAULT_FORMAT = {"type": "text"} + + class JudgeLM: """ A class representing a judge for evaluating answers using either the OpenAI or Transformers library. @@ -93,7 +97,7 @@ def __init__( self.api_key = api_key self.backend = judge_backend - self.response_format = response_format + self.response_format = response_format if not None else DEFAULT_FORMAT def __lazy_load_client(self): match self.backend: @@ -248,28 +252,34 @@ def __call_api_parallel(self, prompts): def __call_api(self, prompt): for _ in range(self.API_MAX_RETRY): try: - if self.response_format: + # Base model + response = self.client.beta.chat.completions.parse( + model=self.model, + messages=as_list(prompt), + response_format=self.response_format, + max_tokens=4096, + temperature=0.0, + n=1, + ) + answer = response.choices[0].message.parsed + return answer + except TypeError: + try: + # Finetune response = self.client.chat.completions.create( model=self.model, - messages=prompt, + messages=as_list(prompt), response_format=self.response_format, - max_tokens=4096, - temperature=0.0, - n=1, - ) - answer = response.choices[0].message.parsed - return answer - else: - response = self.client.chat.completions.create( - model=self.model, - messages=prompt, - response_format={"type": "text"}, max_tokens=512, n=1, ) text = response.choices[0].message.content return text + except Exception as e: + logger.warning(f"{type(e), e}") + time.sleep(self.API_RETRY_SLEEP) except Exception as e: logger.warning(f"{type(e), e}") time.sleep(self.API_RETRY_SLEEP) + raise Exception("Failed to get response from the API") diff --git a/src/lighteval/tasks/extended/hle/main.py b/src/lighteval/tasks/extended/hle/main.py index a213789d7..01da2f3c5 100644 --- a/src/lighteval/tasks/extended/hle/main.py +++ b/src/lighteval/tasks/extended/hle/main.py @@ -21,6 +21,7 @@ # SOFTWARE. 
+import logging import math from typing import List, Literal @@ -31,14 +32,17 @@ from lighteval.metrics.metrics import Metrics from lighteval.metrics.metrics_sample import JudgeLLM from lighteval.metrics.utils.metric_utils import ( + CorpusLevelMetricGrouping, MetricCategory, MetricUseCase, - SampleLevelMetricGrouping, ) from lighteval.tasks.lighteval_task import LightevalTaskConfig from lighteval.tasks.requests import Doc +logger = logging.getLogger(__name__) + + class ExtractedAnswer(BaseModel): extracted_final_answer: str reasoning: str @@ -77,7 +81,10 @@ def get_judge_prompt(question: str, answer: str, gold: str, **kwargs): ] -def process_judge_response_hle(response: ExtractedAnswer): +def process_judge_response_hle(response: ExtractedAnswer | List[ExtractedAnswer]): + # todo: add support for batched responses + if isinstance(response, list): + response = response[0] return { # "correct_answer": correct_answer, "model_answer": response.extracted_final_answer, @@ -105,7 +112,11 @@ def compute(self, predictions, formatted_doc: Doc, **kwargs): score, _, _ = self.judge.evaluate_answer(question=formatted_doc.query, answer=predictions[0], gold=gold) score["correct_answer"] = gold - return score + return { + "accuracy": score, + "confidence_half_width": score, + "calibration_error": score, + } def compute_corpus(self, scores: List[dict]): n = len(scores) @@ -142,6 +153,10 @@ def calib_err(confidence, correct, p="2", beta=100): confidence = confidence[idxs] correct = correct[idxs] bins = [[i * beta, (i + 1) * beta] for i in range(len(confidence) // beta)] + if len(bins) == 0: + logger.warning("Error when computing the bins for calibration error") + return -1 + bins[-1] = [bins[-1][0], len(confidence)] cerr = 0 @@ -181,7 +196,7 @@ def hle_text_only(line, task_name: str = None): ) -hle_metrics = SampleLevelMetricGrouping( +hle_metrics = CorpusLevelMetricGrouping( metric_name=["accuracy", "confidence_half_width", "calibration_error"], higher_is_better={n: True for n in ["accuracy", "confidence_half_width", "calibration_error"]}, category=MetricCategory.GENERATIVE, @@ -189,6 +204,7 @@ def hle_text_only(line, task_name: str = None): sample_level_fn=JudgeLLMHLE().compute, corpus_level_fn=JudgeLLMHLE().compute_corpus, ) +extend_enum(Metrics, "hle_metrics", hle_metrics) hle = LightevalTaskConfig( name="hle", @@ -201,7 +217,7 @@ def hle_text_only(line, task_name: str = None): few_shots_split=None, few_shots_select=None, generation_size=8192, # TODO - metric=[Metrics.exact_match, hle_metrics], + metric=[Metrics.exact_match, Metrics.hle_metrics], stop_sequence=["\n"], trust_dataset=True, version=0, @@ -209,5 +225,3 @@ def hle_text_only(line, task_name: str = None): TASKS_TABLE = [hle] - -extend_enum(Metrics, "hle_metrics", hle_metrics) From 23ad47669b5209b145be5be40b1b06216beb38c6 Mon Sep 17 00:00:00 2001 From: "clementine@huggingface.co" Date: Wed, 29 Jan 2025 14:09:20 +0000 Subject: [PATCH 05/11] attempt at async code --- .../models/endpoints/openai_model.py | 76 +++++++++++++++++-- 1 file changed, 71 insertions(+), 5 deletions(-) diff --git a/src/lighteval/models/endpoints/openai_model.py b/src/lighteval/models/endpoints/openai_model.py index fa3d179fc..56c81b0d0 100644 --- a/src/lighteval/models/endpoints/openai_model.py +++ b/src/lighteval/models/endpoints/openai_model.py @@ -20,6 +20,7 @@ # OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE # SOFTWARE. 
+import asyncio import logging import os import time @@ -28,6 +29,7 @@ from typing import Optional from tqdm import tqdm +from tqdm.asyncio import tqdm_asyncio from transformers import AutoTokenizer from lighteval.data import GenerativeTaskDataset, LoglikelihoodDataset @@ -55,7 +57,7 @@ import logging import tiktoken - from openai import OpenAI + from openai import AsyncOpenAI, OpenAI logging.getLogger("openai").setLevel(logging.ERROR) logging.getLogger("httpx").setLevel(logging.ERROR) @@ -87,8 +89,12 @@ def from_path(cls, path: str) -> "OpenAIModelConfig": class OpenAIClient(LightevalModel): _DEFAULT_MAX_LENGTH: int = 4096 - def __init__(self, config: OpenAIModelConfig, env_config) -> None: - self.client = OpenAI(api_key=config.api_key, base_url=config.base_url) + def __init__(self, config: OpenAIModelConfig, env_config, is_async: bool = False) -> None: + self.is_async = is_async + if is_async: + self.client = AsyncOpenAI(api_key=config.api_key, base_url=config.base_url) + else: + self.client = OpenAI(api_key=config.api_key, base_url=config.base_url) self.config = config self.generation_parameters = config.generation_parameters self.sampling_params = self.generation_parameters.to_vllm_openai_dict() @@ -124,11 +130,12 @@ def __call_api(self, prompt, return_logits, max_new_tokens, num_samples, logit_b **self.sampling_params, **response_format, ) + self.API_RETRY_SLEEP = 3 return response except Exception as e: logger.warning(f"{type(e), e}") time.sleep(self.API_RETRY_SLEEP) - self.API_RETRY_SLEEP = self.API_RETRY_SLEEP**self.API_RETRY_MULTIPLIER + self.API_RETRY_SLEEP = self.API_RETRY_SLEEP * self.API_RETRY_MULTIPLIER raise Exception("Failed to get response from the API") def __call_api_parallel( @@ -162,6 +169,62 @@ def __call_api_parallel( return results + async def __call_api_async_one(self, prompt, return_logits, max_new_tokens, num_samples, logit_bias): + for _ in range(self.API_MAX_RETRY): + try: + response_format = {"response_format": {"type": "text"}} if "openai" in self.config.base_url else {} + response = await self.client.chat.completions.create( + model=self.model, + messages=[{"role": "user", "content": prompt}], + max_tokens=max_new_tokens if max_new_tokens > 0 else None, + logprobs=return_logits, + logit_bias=logit_bias, + n=num_samples, + **self.sampling_params, + **response_format, + ) + return response + except Exception as e: + logger.warning(f"{type(e), e}") + time.sleep(self.API_RETRY_SLEEP) + self.API_RETRY_SLEEP = self.API_RETRY_SLEEP**self.API_RETRY_MULTIPLIER + raise Exception("Failed to get response from the API") + + async def __call_api_async( + self, + prompts, + return_logits: bool | list[bool], + max_new_tokens: int | list[int], + num_samples: int | list[int], + logit_bias: list[dict[int, float]] | None = None, + ): + # Convert single values to lists + return_logitss = [return_logits for _ in prompts] if not isinstance(return_logits, list) else return_logits + max_new_tokenss = [max_new_tokens for _ in prompts] if not isinstance(max_new_tokens, list) else max_new_tokens + num_sampless = [num_samples for _ in prompts] if not isinstance(num_samples, list) else num_samples + logit_biass = [logit_bias for _ in prompts] if logit_bias is None else logit_bias + + # Validate input lengths + assert ( + len(prompts) == len(return_logitss) == len(max_new_tokenss) == len(num_sampless) == len(logit_biass) + ), "Length of prompts, return_logitss, max_new_tokenss, num_sampless, logit_biass should be same" + + async with asyncio.Semaphore(10): # 10 = num workers + # Create 
tasks for each prompt + tasks = [ + await self.__call_api_async_one(prompt, ret_log, max_tok, num_samp, log_bias) + for prompt, ret_log, max_tok, num_samp, log_bias in zip( + prompts, return_logitss, max_new_tokenss, num_sampless, logit_biass + ) + ] + + results = await tqdm_asyncio.gather(*tasks, return_exceptions=True) + + if None in results: + raise ValueError("Some entries are not annotated due to errors in annotate_p, please inspect and retry.") + + return results + def greedy_until( self, requests: list[GreedyUntilRequest], @@ -195,7 +258,10 @@ def greedy_until( num_samples = dataset[0].num_samples contexts = [c.context for c in dataset] - responses = self.__call_api_parallel(contexts, return_logits, max_new_tokens, num_samples) + if self.is_async: + responses = asyncio.run(self.__call_api_async(contexts, return_logits, max_new_tokens, num_samples)) + else: + responses = self.__call_api_parallel(contexts, return_logits, max_new_tokens, num_samples) for response in responses: result: list[str] = [output.message.content for output in response.choices] From c3e02eaa076b423dda9d832742c334e483a55608 Mon Sep 17 00:00:00 2001 From: "clementine@huggingface.co" Date: Wed, 29 Jan 2025 14:20:34 +0000 Subject: [PATCH 06/11] revert harcoding --- src/lighteval/models/endpoints/openai_model.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/lighteval/models/endpoints/openai_model.py b/src/lighteval/models/endpoints/openai_model.py index 56c81b0d0..c34306734 100644 --- a/src/lighteval/models/endpoints/openai_model.py +++ b/src/lighteval/models/endpoints/openai_model.py @@ -253,7 +253,7 @@ def greedy_until( position=0, disable=False, # self.disable_tqdm, ): - max_new_tokens = 500 # dataset[0].generation_size # could be none + max_new_tokens = dataset[0].generation_size # could be none return_logits = dataset[0].use_logits num_samples = dataset[0].num_samples contexts = [c.context for c in dataset] From a58e6d60d3413845fcc1f6ac460bbe90b54e8c56 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Cl=C3=A9mentine?= Date: Thu, 30 Jan 2025 19:50:08 +0100 Subject: [PATCH 07/11] revert async --- .../models/endpoints/openai_model.py | 73 +------------------ 1 file changed, 4 insertions(+), 69 deletions(-) diff --git a/src/lighteval/models/endpoints/openai_model.py b/src/lighteval/models/endpoints/openai_model.py index c34306734..04b92ccb6 100644 --- a/src/lighteval/models/endpoints/openai_model.py +++ b/src/lighteval/models/endpoints/openai_model.py @@ -20,7 +20,6 @@ # OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE # SOFTWARE. 
-import asyncio import logging import os import time @@ -29,7 +28,6 @@ from typing import Optional from tqdm import tqdm -from tqdm.asyncio import tqdm_asyncio from transformers import AutoTokenizer from lighteval.data import GenerativeTaskDataset, LoglikelihoodDataset @@ -57,7 +55,7 @@ import logging import tiktoken - from openai import AsyncOpenAI, OpenAI + from openai import OpenAI logging.getLogger("openai").setLevel(logging.ERROR) logging.getLogger("httpx").setLevel(logging.ERROR) @@ -89,12 +87,8 @@ def from_path(cls, path: str) -> "OpenAIModelConfig": class OpenAIClient(LightevalModel): _DEFAULT_MAX_LENGTH: int = 4096 - def __init__(self, config: OpenAIModelConfig, env_config, is_async: bool = False) -> None: - self.is_async = is_async - if is_async: - self.client = AsyncOpenAI(api_key=config.api_key, base_url=config.base_url) - else: - self.client = OpenAI(api_key=config.api_key, base_url=config.base_url) + def __init__(self, config: OpenAIModelConfig, env_config) -> None: + self.client = OpenAI(api_key=config.api_key, base_url=config.base_url) self.config = config self.generation_parameters = config.generation_parameters self.sampling_params = self.generation_parameters.to_vllm_openai_dict() @@ -169,62 +163,6 @@ def __call_api_parallel( return results - async def __call_api_async_one(self, prompt, return_logits, max_new_tokens, num_samples, logit_bias): - for _ in range(self.API_MAX_RETRY): - try: - response_format = {"response_format": {"type": "text"}} if "openai" in self.config.base_url else {} - response = await self.client.chat.completions.create( - model=self.model, - messages=[{"role": "user", "content": prompt}], - max_tokens=max_new_tokens if max_new_tokens > 0 else None, - logprobs=return_logits, - logit_bias=logit_bias, - n=num_samples, - **self.sampling_params, - **response_format, - ) - return response - except Exception as e: - logger.warning(f"{type(e), e}") - time.sleep(self.API_RETRY_SLEEP) - self.API_RETRY_SLEEP = self.API_RETRY_SLEEP**self.API_RETRY_MULTIPLIER - raise Exception("Failed to get response from the API") - - async def __call_api_async( - self, - prompts, - return_logits: bool | list[bool], - max_new_tokens: int | list[int], - num_samples: int | list[int], - logit_bias: list[dict[int, float]] | None = None, - ): - # Convert single values to lists - return_logitss = [return_logits for _ in prompts] if not isinstance(return_logits, list) else return_logits - max_new_tokenss = [max_new_tokens for _ in prompts] if not isinstance(max_new_tokens, list) else max_new_tokens - num_sampless = [num_samples for _ in prompts] if not isinstance(num_samples, list) else num_samples - logit_biass = [logit_bias for _ in prompts] if logit_bias is None else logit_bias - - # Validate input lengths - assert ( - len(prompts) == len(return_logitss) == len(max_new_tokenss) == len(num_sampless) == len(logit_biass) - ), "Length of prompts, return_logitss, max_new_tokenss, num_sampless, logit_biass should be same" - - async with asyncio.Semaphore(10): # 10 = num workers - # Create tasks for each prompt - tasks = [ - await self.__call_api_async_one(prompt, ret_log, max_tok, num_samp, log_bias) - for prompt, ret_log, max_tok, num_samp, log_bias in zip( - prompts, return_logitss, max_new_tokenss, num_sampless, logit_biass - ) - ] - - results = await tqdm_asyncio.gather(*tasks, return_exceptions=True) - - if None in results: - raise ValueError("Some entries are not annotated due to errors in annotate_p, please inspect and retry.") - - return results - def greedy_until( self, 
requests: list[GreedyUntilRequest], @@ -258,10 +196,7 @@ def greedy_until( num_samples = dataset[0].num_samples contexts = [c.context for c in dataset] - if self.is_async: - responses = asyncio.run(self.__call_api_async(contexts, return_logits, max_new_tokens, num_samples)) - else: - responses = self.__call_api_parallel(contexts, return_logits, max_new_tokens, num_samples) + responses = self.__call_api_parallel(contexts, return_logits, max_new_tokens, num_samples) for response in responses: result: list[str] = [output.message.content for output in response.choices] From 1dd74f0a6cbc7457e8f2c450042d134f44e576f5 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Cl=C3=A9mentine?= Date: Thu, 30 Jan 2025 19:52:55 +0100 Subject: [PATCH 08/11] fix review --- src/lighteval/models/endpoints/openai_model.py | 2 +- src/lighteval/tasks/extended/hle/main.py | 4 ++-- 2 files changed, 3 insertions(+), 3 deletions(-) diff --git a/src/lighteval/models/endpoints/openai_model.py b/src/lighteval/models/endpoints/openai_model.py index 04b92ccb6..9dd497e43 100644 --- a/src/lighteval/models/endpoints/openai_model.py +++ b/src/lighteval/models/endpoints/openai_model.py @@ -129,7 +129,7 @@ def __call_api(self, prompt, return_logits, max_new_tokens, num_samples, logit_b except Exception as e: logger.warning(f"{type(e), e}") time.sleep(self.API_RETRY_SLEEP) - self.API_RETRY_SLEEP = self.API_RETRY_SLEEP * self.API_RETRY_MULTIPLIER + self.API_RETRY_SLEEP = self.API_RETRY_SLEEP**self.API_RETRY_MULTIPLIER raise Exception("Failed to get response from the API") def __call_api_parallel( diff --git a/src/lighteval/tasks/extended/hle/main.py b/src/lighteval/tasks/extended/hle/main.py index 01da2f3c5..b0b06315c 100644 --- a/src/lighteval/tasks/extended/hle/main.py +++ b/src/lighteval/tasks/extended/hle/main.py @@ -216,9 +216,9 @@ def hle_text_only(line, task_name: str = None): evaluation_splits=["test"], few_shots_split=None, few_shots_select=None, - generation_size=8192, # TODO + generation_size=8192, metric=[Metrics.exact_match, Metrics.hle_metrics], - stop_sequence=["\n"], + stop_sequence=[], trust_dataset=True, version=0, ) From 6b3976acd0177c87151df09a6dc08a8a7d22a723 Mon Sep 17 00:00:00 2001 From: Nathan Habib Date: Tue, 18 Feb 2025 14:50:37 +0000 Subject: [PATCH 09/11] commit --- src/lighteval/metrics/llm_as_judge.py | 2 +- src/lighteval/models/vllm/vllm_model.py | 1 + src/lighteval/tasks/extended/__init__.py | 1 - src/lighteval/tasks/extended/hle/main.py | 34 +++++++++++++++--------- 4 files changed, 24 insertions(+), 14 deletions(-) diff --git a/src/lighteval/metrics/llm_as_judge.py b/src/lighteval/metrics/llm_as_judge.py index 1d1c6a20c..eb2f47fa7 100644 --- a/src/lighteval/metrics/llm_as_judge.py +++ b/src/lighteval/metrics/llm_as_judge.py @@ -240,7 +240,7 @@ def __call_api(prompt): def __call_api_parallel(self, prompts): results = [] - with ThreadPoolExecutor(100) as executor: + with ThreadPoolExecutor(10) as executor: for entry in tqdm(executor.map(self.__call_api, prompts), total=len(prompts)): results.append(entry) diff --git a/src/lighteval/models/vllm/vllm_model.py b/src/lighteval/models/vllm/vllm_model.py index c606c04ed..760515958 100644 --- a/src/lighteval/models/vllm/vllm_model.py +++ b/src/lighteval/models/vllm/vllm_model.py @@ -361,6 +361,7 @@ def run_inference_one_model(model_args: dict, sampling_params: SamplingParams, r if x is not None ] else: + breakpoint() outputs = self.model.generate( prompt_token_ids=inputs, sampling_params=sampling_params, diff --git a/src/lighteval/tasks/extended/__init__.py 
b/src/lighteval/tasks/extended/__init__.py index bd1d1d1d3..39963eac1 100644 --- a/src/lighteval/tasks/extended/__init__.py +++ b/src/lighteval/tasks/extended/__init__.py @@ -32,7 +32,6 @@ import lighteval.tasks.extended.olympiade_bench.main as olympiad_bench import lighteval.tasks.extended.tiny_benchmarks.main as tiny_benchmarks - AVAILABLE_EXTENDED_TASKS_MODULES = [ifeval, tiny_benchmarks, mt_bench, mix_eval, olympiad_bench, hle, lcb] else: diff --git a/src/lighteval/tasks/extended/hle/main.py b/src/lighteval/tasks/extended/hle/main.py index b0b06315c..9f6a85610 100644 --- a/src/lighteval/tasks/extended/hle/main.py +++ b/src/lighteval/tasks/extended/hle/main.py @@ -105,18 +105,27 @@ def __init__(self): response_format=ExtractedAnswer, ) - def compute(self, predictions, formatted_doc: Doc, **kwargs): + def compute(self, sample_ids: list[str], responses: list, formatted_docs: list[Doc]) -> list[dict[str, float]]: # If we are evaluating a multiturn task, we need to have specific field in the formatted doc - gold = formatted_doc.get_golds()[0] - - score, _, _ = self.judge.evaluate_answer(question=formatted_doc.query, answer=predictions[0], gold=gold) - - score["correct_answer"] = gold - return { - "accuracy": score, - "confidence_half_width": score, - "calibration_error": score, - } + questions = [formatted_doc.specific["question"] for formatted_doc in formatted_docs] + golds = [formatted_doc.get_golds()[0] for formatted_doc in formatted_docs] + predictions = [response[0].result[0] for response in responses] + options = [None] * len(questions) + + score, _, _ = self.judge.evaluate_answer_batch(questions, predictions, options, golds) + + metrics = [] + for i in range(len(sample_ids)): + score[i]["correct_answer"] = golds[i] + metrics.append( + { + "accuracy": score[i], + "confidence_half_width": score[i], + "calibration_error": score[i], + } + ) + + return metrics def compute_corpus(self, scores: List[dict]): n = len(scores) @@ -193,13 +202,14 @@ def hle_text_only(line, task_name: str = None): query=f"Question: {line['question']}\nAnswer:", choices=[line["answer"]], gold_index=0, + specific={"question": line["question"]}, ) hle_metrics = CorpusLevelMetricGrouping( metric_name=["accuracy", "confidence_half_width", "calibration_error"], higher_is_better={n: True for n in ["accuracy", "confidence_half_width", "calibration_error"]}, - category=MetricCategory.GENERATIVE, + category=MetricCategory.LLM_AS_JUDGE, use_case=MetricUseCase.ACCURACY, sample_level_fn=JudgeLLMHLE().compute, corpus_level_fn=JudgeLLMHLE().compute_corpus, From 9054c9978150a6e431dcf2121376a53f66501315 Mon Sep 17 00:00:00 2001 From: Nathan Habib Date: Tue, 18 Feb 2025 14:51:48 +0000 Subject: [PATCH 10/11] commit --- examples/model_configs/vllm_model_config.yaml | 12 ++---------- src/lighteval/models/vllm/vllm_model.py | 1 - 2 files changed, 2 insertions(+), 11 deletions(-) diff --git a/examples/model_configs/vllm_model_config.yaml b/examples/model_configs/vllm_model_config.yaml index 5192cb558..82114cd95 100644 --- a/examples/model_configs/vllm_model_config.yaml +++ b/examples/model_configs/vllm_model_config.yaml @@ -1,14 +1,6 @@ model: base_params: - model_args: "pretrained=HuggingFaceTB/SmolLM-1.7B,revision=main,dtype=bfloat16" # pretrained=model_name,trust_remote_code=boolean,revision=revision_to_use,model_parallel=True ... 
+ model_args: "pretrained=NousResearch/DeepHermes-3-Llama-3-8B-Preview,revision=main,dtype=bfloat16,max_model_length=32000" # pretrained=model_name,trust_remote_code=boolean,revision=revision_to_use,model_parallel=True ... generation: temperature: 0.3 - repetition_penalty: 1.0 - frequency_penalty: 0.0 - presence_penalty: 0.0 - seed: 42 - top_k: -1 - min_p: 0.0 - top_p: 0.9 - max_new_tokens: 256 - stop_tokens: ["", ""] + max_new_tokens: 12000 diff --git a/src/lighteval/models/vllm/vllm_model.py b/src/lighteval/models/vllm/vllm_model.py index 760515958..c606c04ed 100644 --- a/src/lighteval/models/vllm/vllm_model.py +++ b/src/lighteval/models/vllm/vllm_model.py @@ -361,7 +361,6 @@ def run_inference_one_model(model_args: dict, sampling_params: SamplingParams, r if x is not None ] else: - breakpoint() outputs = self.model.generate( prompt_token_ids=inputs, sampling_params=sampling_params, From ebd1a99227dabc304c265898c1582ee0fb326ae9 Mon Sep 17 00:00:00 2001 From: Nathan Habib Date: Tue, 18 Feb 2025 14:52:27 +0000 Subject: [PATCH 11/11] commit --- examples/model_configs/vllm_model_config.yaml | 12 ++++++++++-- 1 file changed, 10 insertions(+), 2 deletions(-) diff --git a/examples/model_configs/vllm_model_config.yaml b/examples/model_configs/vllm_model_config.yaml index 82114cd95..5192cb558 100644 --- a/examples/model_configs/vllm_model_config.yaml +++ b/examples/model_configs/vllm_model_config.yaml @@ -1,6 +1,14 @@ model: base_params: - model_args: "pretrained=NousResearch/DeepHermes-3-Llama-3-8B-Preview,revision=main,dtype=bfloat16,max_model_length=32000" # pretrained=model_name,trust_remote_code=boolean,revision=revision_to_use,model_parallel=True ... + model_args: "pretrained=HuggingFaceTB/SmolLM-1.7B,revision=main,dtype=bfloat16" # pretrained=model_name,trust_remote_code=boolean,revision=revision_to_use,model_parallel=True ... generation: temperature: 0.3 - max_new_tokens: 12000 + repetition_penalty: 1.0 + frequency_penalty: 0.0 + presence_penalty: 0.0 + seed: 42 + top_k: -1 + min_p: 0.0 + top_p: 0.9 + max_new_tokens: 256 + stop_tokens: ["", ""]