Showing 9 changed files with 386 additions and 4 deletions.
@@ -0,0 +1 @@
from .scorer import evaluate
@@ -0,0 +1,55 @@
import hashlib
import os

from kani import Kani
from kani.engines.openai import OpenAIEngine

from fanoutqa.eval.utils import str_answer
from fanoutqa.models import DevQuestion
from fanoutqa.utils import CACHE_DIR

LLM_CACHE_DIR = CACHE_DIR / "llmcache"
LLM_CACHE_DIR.mkdir(exist_ok=True)
OPENAI_API_KEY = os.getenv("FANOUTQA_OPENAI_API_KEY", "")

engine = OpenAIEngine(api_key=OPENAI_API_KEY, model="gpt-4-0613", temperature=0, seed=31415)
factuality_system = "You are comparing a submitted answer to an expert answer on a given question."


def factuality_prompt(question: str, reference: str, answer: str):
    return (
        f"[BEGIN DATA]\n************\n[Question]: {question}\n************\n[Expert]:"
        f" {reference}\n************\n[Submission]: {answer}\n************\n[END DATA]\n\nCompare the factual content"
        " of the submitted answer with the expert answer. Ignore any differences in style, grammar, or"
        " punctuation.\nThe submitted answer may either be a subset or superset of the expert answer, or it may"
        " conflict with it. Determine which case applies. First, write out in a step by step manner your reasoning"
        " about the factual content to be sure that your conclusion is correct. Avoid simply stating the correct"
        ' answers at the outset. Then print only the single character "A", "B", "C", "D", "E", or "F" (without quotes'
        " or punctuation) on its own line corresponding to the correct answer. At the end, repeat just the letter"
        " again by itself on a new line.\n(A) The submitted answer is a subset of the expert answer and is fully"
        " consistent with it.\n(B) The submitted answer is a superset of the expert answer and is fully consistent"
        " with it.\n(C) The submitted answer contains all the same details as the expert answer.\n(D) There is a"
        " disagreement between the submitted answer and the expert answer.\n(E) The answers differ, but these"
        " differences don't matter from the perspective of factuality.\n(F) The submitted answer does not answer the"
        " question or is otherwise invalid."
    )


async def get_llm_factuality(question: DevQuestion, answer: str, cache_key=None):
    """Query GPT-4 to determine the factual equivalence of the generated answer and reference answer."""
    # cache
    if cache_key:
        ans_hash = hashlib.sha256(answer.encode()).hexdigest()[:8]
        cache_filename = LLM_CACHE_DIR / f"factual-{cache_key}-{question.id}-{ans_hash}.txt"
        if cache_filename.exists():
            return cache_filename.read_text()

    # ask the LLM to judge the factuality of the answer
    prompt = factuality_prompt(question.question, str_answer(question.answer), answer)
    ai = Kani(engine, system_prompt=factuality_system)
    resp = await ai.chat_round_str(prompt)

    if cache_key:
        # noinspection PyUnboundLocalVariable
        cache_filename.write_text(resp)
    return resp
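Judging by the `from fanoutqa.eval.llm import OPENAI_API_KEY, get_llm_factuality` line in the scorer hunk further down, this file appears to be `fanoutqa/eval/llm.py`. Each judge response is cached under a filename built from the run's cache key, the question id, and a short hash of the candidate answer, so re-scoring the same submission never re-queries GPT-4. A minimal sketch of that naming scheme (the directory and ids below are made up; the real code uses `LLM_CACHE_DIR` above):

```python
import hashlib
from pathlib import Path


def cache_path(cache_dir: Path, cache_key: str, question_id: str, answer: str) -> Path:
    """Mirror of the cache-file naming above: one file per (run, question, answer-hash) triple."""
    ans_hash = hashlib.sha256(answer.encode()).hexdigest()[:8]  # short, stable fingerprint of the answer text
    return cache_dir / f"factual-{cache_key}-{question_id}-{ans_hash}.txt"


# Made-up directory and ids; the real code derives the directory from fanoutqa.utils.CACHE_DIR.
print(cache_path(Path.home() / ".cache" / "fanoutqa" / "llmcache", "my-system", "dev-0001", "The answer is 42."))
```

Since the judge engine is pinned to `gpt-4-0613` with temperature 0 and a fixed seed, cached and fresh responses should usually agree.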
@@ -0,0 +1,38 @@
from dataclasses import dataclass
from typing import TypedDict


@dataclass
class AccuracyScore:
    loose: float
    """Loose accuracy: The mean proportion of reference strings found in the generation."""

    strict: float
    """Strict accuracy: The proportion of questions with a loose accuracy of 1.0."""


@dataclass
class RougeScorePart:
    precision: float
    recall: float
    fscore: float


@dataclass
class RougeScore:
    rouge1: RougeScorePart
    rouge2: RougeScorePart
    rougeL: RougeScorePart


@dataclass
class EvaluationScore:
    acc: AccuracyScore
    rouge: RougeScore
    bleurt: float
    gpt: float


class Answer(TypedDict):
    id: str
    answer: str
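This hunk appears to be `fanoutqa/eval/models.py`, matching the `from fanoutqa.eval.models import ...` line in the scorer below. To make the data shapes concrete, here is a small sketch of the `Answer` records a submission provides and the nested `EvaluationScore` the scorer hands back; the import path assumes the package is installed as `fanoutqa`, and every number is an illustrative placeholder:

```python
from fanoutqa.eval.models import AccuracyScore, Answer, EvaluationScore, RougeScore, RougeScorePart

# A submission is just a list of {question id, free-text answer} records.
answers: list[Answer] = [
    {"id": "dev-0001", "answer": "George Washington and John Adams"},  # made-up id and answer
]

# The scorer returns one nested EvaluationScore; the values here are placeholders, not real results.
part = RougeScorePart(precision=0.61, recall=0.55, fscore=0.58)
score = EvaluationScore(
    acc=AccuracyScore(loose=0.72, strict=0.41),
    rouge=RougeScore(rouge1=part, rouge2=part, rougeL=part),
    bleurt=0.49,
    gpt=0.63,
)
print(f"strict acc {score.acc.strict:.2f}, ROUGE-L F1 {score.rouge.rougeL.fscore:.2f}")
```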
@@ -0,0 +1,173 @@
import asyncio
import warnings
from typing import Iterable, Optional

import rouge_score
import rouge_score.scoring
from bleurt.score import BleurtScorer
from rouge_score.rouge_scorer import RougeScorer

from fanoutqa.eval.llm import OPENAI_API_KEY, get_llm_factuality
from fanoutqa.eval.models import AccuracyScore, Answer, EvaluationScore, RougeScore, RougeScorePart
from fanoutqa.eval.string import answer_in_text
from fanoutqa.eval.utils import str_answer
from fanoutqa.models import DevQuestion
from fanoutqa.utils import batched, copy_doc

ROUGE_TYPES = ("rouge1", "rouge2", "rougeL")


class Scorer:
    def __init__(
        self, questions: list[DevQuestion], answers: list[Answer], only_score_answered=False, llm_cache_key: str = None
    ):
        """
        :param questions: The questions and reference answers, as loaded by the dataset
        :param answers: The generated answers to score
        :param only_score_answered: Whether to only score questions that have an answer (True), or consider unanswered
            questions to have 0 score (False, default).
        :param llm_cache_key: If this is provided, cache the LLM-as-judge generations with this key. We recommend
            setting this to a human-readable key for each system under test.
        """
        self.questions = questions
        self.questions_by_id = {q.id: q for q in self.questions}
        self.answers = answers
        self.answers_by_id = {r["id"]: r for r in self.answers}

        # number of trials to eval
        self.only_score_answered = only_score_answered
        if self.only_score_answered:
            self.eval_len = len(self.answers)
        else:
            self.eval_len = len(self.questions)

        self.llm_cache_key = llm_cache_key

        # ext evallers
        self.rouge = RougeScorer(ROUGE_TYPES, use_stemmer=True)
        self.bleurt = BleurtScorer("BLEURT-20")

    async def score(self):
        acc = self.score_accuracy()
        rouge = self.score_rouge()
        bleurt_ = self.score_bleurt()
        # require FANOUTQA_OPENAI_API_KEY to be set to do GPT judge to prevent footguns
        if not OPENAI_API_KEY:
            warnings.warn(
                "No OpenAI API key found! To run GPT-as-judge scoring, set the `FANOUTQA_OPENAI_API_KEY` env var to"
                " your OpenAI API key."
            )
            gptscore = 0
        else:
            gptscore = await self.score_gpt()
        return EvaluationScore(acc=acc, rouge=rouge, bleurt=bleurt_, gpt=gptscore)
    def get_qa_pairs(self) -> Iterable[tuple[DevQuestion, Optional[Answer]]]:
        """Yield pairs of questions and answers to score.
        The answer may be None if there is no answer for a given question and ``only_score_answered`` is False.
        """
        if self.only_score_answered:
            for a in self.answers:
                q = self.questions_by_id.get(a["id"])
                yield q, a
        else:
            for q in self.questions:
                a = self.answers_by_id.get(q.id)
                if a is None:
                    yield q, None
                    continue  # don't yield an unanswered question twice
                yield q, a

    # scorers
    def score_accuracy(self) -> AccuracyScore:
        """Get the loose and strict accuracy scores for the loaded qs and as."""
        accs = []
        n_perfect = 0
        for q, a in self.get_qa_pairs():
            if a is None:
                accs.append(0)
                continue
            result = answer_in_text(q.answer, a["answer"])
            accs.append(result.score)
            if result.found:
                n_perfect += 1

        assert len(accs) == self.eval_len
        avg_acc = sum(accs) / self.eval_len
        pct_perfect = n_perfect / self.eval_len
        return AccuracyScore(loose=avg_acc, strict=pct_perfect)

    def score_rouge(self) -> RougeScore:
        """Get the ROUGE-1, ROUGE-2, and ROUGE-L scores (P/R/F1) for the loaded qs and as."""
        scores = {t: [] for t in ROUGE_TYPES}
        for q, a in self.get_qa_pairs():
            if a is None:
                for score in scores.values():
                    score.append(rouge_score.scoring.Score(0, 0, 0))
                continue
            results = self.rouge.score(str_answer(q.answer), str_answer(a["answer"]))
            for k, v in results.items():
                scores[k].append(v)

        assert all(len(v) == self.eval_len for v in scores.values())
        out = {}
        for k, v in scores.items():
            avg_precision = sum(s.precision for s in v) / self.eval_len
            avg_recall = sum(s.recall for s in v) / self.eval_len
            avg_fscore = sum(s.fmeasure for s in v) / self.eval_len
            out[k] = RougeScorePart(precision=avg_precision, recall=avg_recall, fscore=avg_fscore)
        return RougeScore(**out)

    def score_bleurt(self) -> float:
        """Get the BLEURT score for the loaded qs and as."""
        references = []
        candidates = []
        for q, a in self.get_qa_pairs():
            if a is None:
                candidates.append("")
            else:
                candidates.append(str_answer(a["answer"]))
            references.append(str_answer(q.answer))

        scores = self.bleurt.score(references=references, candidates=candidates)
        assert len(scores) == self.eval_len
        avg_score = sum(scores) / self.eval_len
        return avg_score

    async def score_gpt(self):
        """Use GPT-4 as a judge to grade the loaded qs and as."""
        accs = []

        for pairs in batched(self.get_qa_pairs(), 20):
            # eval 20 qs at a time
            coros = []
            for q, a in pairs:
                if a is None:
                    accs.append(0)
                    continue
                # sometimes we have fun neural text degeneration, just cut it off
                ans = a["answer"]
                if len(a["answer"]) > 4000:
                    warnings.warn(f"The answer to question ID {a['id']} is too long, trimming it to 4000 characters.")
                    ans = ans[:4000]
                coro = get_llm_factuality(q, ans, cache_key=self.llm_cache_key)
                coros.append(coro)

            # and score their answers
            # B, C, E = full score, anything else = 0
            answers = await asyncio.gather(*coros)
            for result in answers:
                mc = result.strip()[-1].lower()
                if mc in "bce":
                    accs.append(1)
                else:
                    accs.append(0)

        assert len(accs) == self.eval_len
        avg_acc = sum(accs) / self.eval_len
        return avg_acc


@copy_doc(Scorer.__init__)
def evaluate(questions: list[DevQuestion], answers: list[Answer], **kwargs) -> EvaluationScore:
    scorer = Scorer(questions, answers, **kwargs)
    return asyncio.run(scorer.score())
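This hunk is the scorer itself (apparently `fanoutqa/eval/scorer.py`, given the `from .scorer import evaluate` re-export in the first hunk), and `evaluate` is the public entry point. A hedged end-to-end usage sketch follows; `fanoutqa.load_dev()` and the answers-file layout are assumptions about the surrounding package rather than anything shown in this diff, and GPT-judge scoring only runs when `FANOUTQA_OPENAI_API_KEY` is set.

```python
import json

import fanoutqa  # assumed to expose a dev-set loader; not defined in this diff
from fanoutqa.eval import evaluate

questions = fanoutqa.load_dev()  # assumption: returns list[DevQuestion]
with open("my_system_answers.json") as f:
    answers = json.load(f)  # expected shape: [{"id": "...", "answer": "..."}, ...]

scores = evaluate(questions, answers, llm_cache_key="my-system")
print("loose acc: ", scores.acc.loose)
print("strict acc:", scores.acc.strict)
print("ROUGE-L F1:", scores.rouge.rougeL.fscore)
print("BLEURT:    ", scores.bleurt)
print("GPT judge: ", scores.gpt)
```

Because `evaluate` wraps the async `Scorer.score` in `asyncio.run`, it should be called from plain synchronous code rather than from inside an already-running event loop.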
@@ -0,0 +1,38 @@
import itertools
import re
from collections import namedtuple

from fanoutqa.models import AnswerType
from fanoutqa.norm import normalize

AccuracyResult = namedtuple("AccuracyResult", "found score missing")


def answer_in_text(reference: AnswerType, candidate: str) -> AccuracyResult:
    """What proportion of answer strings found in the reference can also be found in the candidate?"""
    if isinstance(reference, list):
        missing = []
        for a in reference:
            result = answer_in_text(a, candidate)
            missing.extend(result.missing)
        n_found = len(reference) - len(missing)
        return AccuracyResult(found=n_found == len(reference), score=n_found / len(reference), missing=missing)
    elif isinstance(reference, dict):
        missing = []
        vals = itertools.chain(reference.keys(), reference.values())
        for a in vals:
            result = answer_in_text(a, candidate)
            missing.extend(result.missing)
        n_ref = len(reference) * 2
        n_found = n_ref - len(missing)  # kvs
        return AccuracyResult(found=n_found == n_ref, score=n_found / n_ref, missing=missing)
    else:
        if isinstance(reference, bool):
            reference = "yes" if reference else "no"
        # primitive
        norm_ans = normalize(reference)
        norm_cand = normalize(candidate)
        # ensure the answer is surrounded by word boundaries
        if not re.search(rf"\b{re.escape(norm_ans)}\b", norm_cand):
            return AccuracyResult(found=False, score=0, missing=[norm_ans])
        return AccuracyResult(found=True, score=1, missing=[])
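This hunk looks like `fanoutqa/eval/string.py`, the loose-accuracy matcher the scorer imports. A quick sketch of how `answer_in_text` treats the three reference shapes; the exact normalized strings in `missing` depend on `fanoutqa.norm.normalize`, which this diff does not show, so only `found` and `score` are shown:

```python
from fanoutqa.eval.string import answer_in_text

# List reference: score is the fraction of reference strings found in the candidate.
r = answer_in_text(["Paris", "Berlin"], "The capitals are Paris and Madrid.")
print(r.found, r.score)  # expected: False 0.5 (only one of the two strings appears)

# Dict reference: both keys and values are searched, so each entry contributes two targets.
r = answer_in_text({"France": "Paris"}, "Paris is the capital of France.")
print(r.found, r.score)  # expected: True 1.0

# Bool reference: mapped to "yes"/"no" before word-boundary matching.
r = answer_in_text(True, "Yes, that is correct.")
print(r.found, r.score)  # expected: True 1
```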
@@ -0,0 +1,14 @@
from fanoutqa.models import AnswerType


def str_answer(ans: AnswerType) -> str:
    """Ensure the answer is a string for string-based metrics like ROUGE. Don't normalize it otherwise."""
    if isinstance(ans, list):
        return "\n".join(map(str_answer, ans))
    elif isinstance(ans, dict):
        return "\n".join(f"{k} - {str_answer(v)}" for k, v in ans.items())
    elif isinstance(ans, bool):
        return "yes" if ans else "no"
    elif ans is None:
        return ""
    return str(ans)
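Finally, `str_answer` (apparently `fanoutqa/eval/utils.py`) flattens structured reference answers into plain strings for ROUGE, BLEURT, and the judge prompt. A short sketch of its behavior, again assuming the package is importable as `fanoutqa`:

```python
from fanoutqa.eval.utils import str_answer

print(str_answer(["Paris", 42, True]))
# Paris
# 42
# yes

print(str_answer({"France": "Paris", "Germany": "Berlin"}))
# France - Paris
# Germany - Berlin

print(repr(str_answer(None)))  # '' -- a None reference becomes an empty string
```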