Commit
Showing 7 changed files with 285 additions and 21 deletions.
fanoutqa/__init__.py
@@ -1,2 +1,2 @@
from .models import DevQuestion, TestQuestion
from .utils import load_dev, load_test
from .wiki import wiki_content, wiki_search
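
For orientation, the re-exports above mean the dataset loaders and the new Wikipedia helpers are available directly on the package. A minimal sketch, assuming `load_dev()` returns `DevQuestion` objects carrying the `question` and `necessary_evidence` attributes referenced in the retrieval docstring below:

import fanoutqa

dev_questions = fanoutqa.load_dev()  # assumed to return a list of DevQuestion objects
q = dev_questions[0]
print(q.question)                    # attribute used by the Corpus docstring example below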

fanoutqa/retrieval.py
@@ -0,0 +1,109 @@
"""This module contains a baseline implementation of a retriever for use with long Wikipedia articles""" | ||
|
||
from dataclasses import dataclass | ||
from typing import Iterable | ||
|
||
try: | ||
import numpy as np | ||
from rank_bm25 import BM25Plus | ||
except ImportError as e: | ||
raise ImportError( | ||
"Using the baseline retriever requires the rank_bm25 package. Use `pip install fanoutqa[retrieval]`." | ||
) from e | ||
|
||
from .models import Evidence | ||
from .norm import normalize | ||
from .wiki import wiki_content | ||
|
||
|
||
@dataclass | ||
class RetrievalResult: | ||
title: str | ||
"""The title of the article this fragment comes from.""" | ||
|
||
content: str | ||
"""The content of the fragment.""" | ||
|
||
|
||
class Corpus: | ||
"""A corpus of wiki docs. Indexes the docs on creation, normalizing the text beforehand with lemmatization. | ||
Splits the documents into chunks no longer than a given length, preferring splitting on paragraph and sentence | ||
boundaries. Documents will be converted to Markdown. | ||
Uses BM25+ (Lv and Zhai, 2011), a TF-IDF based approach to retrieve document fragments. | ||
To retrieve chunks corresponding to a query, iterate over ``Corpus.best(query)``. | ||
.. code-block:: python | ||
# example of how to use in the Evidence Provided setting | ||
prompt = "..." | ||
corpus = fanoutqa.retrieval.Corpus(q.necessary_evidence) | ||
for fragment in corpus.best(q.question): | ||
# use your own structured prompt format here | ||
prompt += f"# {fragment.title}\n{fragment.content}\n\n" | ||
""" | ||
|
||
def __init__(self, documents: list[Evidence], doc_len: int = 2048): | ||
""" | ||
:param documents: The list of evidences to index | ||
:param doc_len: The maximum length, in characters, of each chunk | ||
""" | ||
|
||
self.documents = [] | ||
normalized_corpus = [] | ||
for doc in documents: | ||
title = doc.title | ||
content = wiki_content(doc) | ||
for chunk in chunk_text(content, max_chunk_size=doc_len): | ||
self.documents.append(RetrievalResult(title, chunk)) | ||
normalized_corpus.append(self.tokenize(chunk)) | ||
|
||
self.index = BM25Plus(normalized_corpus) | ||
|
||
@staticmethod | ||
def tokenize(text: str): | ||
return normalize(text).split(" ") | ||
|
||
def best(self, q: str) -> Iterable[RetrievalResult]: | ||
"""Yield the best matching fragments to the given query.""" | ||
|
||
tok_q = self.tokenize(q) | ||
scores = self.index.get_scores(tok_q) | ||
idxs = np.argsort(scores)[::-1] | ||
for idx in idxs: | ||
yield self.documents[idx] | ||
|
||
|
||
def chunk_text(text, max_chunk_size=1024, chunk_on=("\n\n", "\n", ". ", ", ", " "), chunker_i=0): | ||
""" | ||
Recursively chunks *text* into a list of str, with each element no longer than *max_chunk_size*. | ||
Prefers splitting on the elements of *chunk_on*, in order. | ||
""" | ||
|
||
if len(text) <= max_chunk_size: # the chunk is small enough | ||
return [text] | ||
if chunker_i >= len(chunk_on): # we have no more preferred chunk_on characters | ||
# optimization: instead of merging a thousand characters, just use list slicing | ||
return [text[:max_chunk_size], *chunk_text(text[max_chunk_size:], max_chunk_size, chunk_on, chunker_i + 1)] | ||
|
||
# split on the current character | ||
chunks = [] | ||
split_char = chunk_on[chunker_i] | ||
for chunk in text.split(split_char): | ||
chunk = f"{chunk}{split_char}" | ||
if len(chunk) > max_chunk_size: # this chunk needs to be split more, recurse | ||
chunks.extend(chunk_text(chunk, max_chunk_size, chunk_on, chunker_i + 1)) | ||
elif chunks and len(chunk) + len(chunks[-1]) <= max_chunk_size: # this chunk can be merged | ||
chunks[-1] += chunk | ||
else: | ||
chunks.append(chunk) | ||
|
||
# if the last chunk is just the split_char, yeet it | ||
if chunks[-1] == split_char: | ||
chunks.pop() | ||
|
||
# remove extra split_char from last chunk | ||
chunks[-1] = chunks[-1][: -len(split_char)] | ||
return chunks |
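
To make the intended workflow concrete, here is a rough sketch of using this baseline retriever in the Evidence Provided setting, expanding the docstring example above. The prompt wording and the top-5 cutoff are illustrative assumptions, not part of this commit:

from itertools import islice

import fanoutqa
from fanoutqa.retrieval import Corpus

q = fanoutqa.load_dev()[0]

# index the question's gold evidence, chunked to at most 2048 characters per fragment
corpus = Corpus(q.necessary_evidence, doc_len=2048)

# pack the few best-scoring fragments into a prompt (format is up to the caller)
prompt = f"Question: {q.question}\n\nEvidence:\n"
for fragment in islice(corpus.best(q.question), 5):
    prompt += f"# {fragment.title}\n{fragment.content}\n\n"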

fanoutqa/wiki.py
@@ -0,0 +1,90 @@
"""Utils for working with Wikipedia""" | ||
|
||
import functools | ||
import logging | ||
import urllib.parse | ||
from pathlib import Path | ||
|
||
import httpx | ||
|
||
from .models import Evidence | ||
from .utils import DATASET_EPOCH, markdownify | ||
|
||
USER_AGENT = "fanoutqa/1.0.0 ([email protected])" | ||
# expanduser() so the cache goes in the user's home dir rather than a literal "~" folder
CACHE_DIR = Path("~/.cache/fanoutqa/wikicache").expanduser()
CACHE_DIR.mkdir(exist_ok=True, parents=True)

log = logging.getLogger(__name__)
wikipedia = httpx.Client(base_url="https://en.wikipedia.org/w/api.php", headers={"User-Agent": USER_AGENT})


class LazyEvidence(Evidence):
    """A subclass of Evidence without a known revision ID; lazily loads it when needed."""

    def __init__(self, title: str, pageid: int, url: str = None):
        self.title = title
        self.pageid = pageid
        self._url = url

    @property
    def url(self):
        if self._url is not None:
            return self._url
        encoded_title = urllib.parse.quote(self.title)
        return f"https://en.wikipedia.org/wiki/{encoded_title}"

    @functools.cached_property
    def revid(self):
        resp = wikipedia.get(
            "",
            params={
                "format": "json",
                "action": "query",
                "prop": "revisions",
                "rvprop": "ids|timestamp",
                "rvlimit": 1,
                "pageids": self.pageid,
                "rvstart": DATASET_EPOCH.isoformat(),
            },
        )
        resp.raise_for_status()
        data = resp.json()
        # the JSON response keys pages by their page ID as a string
        page = data["query"]["pages"][str(self.pageid)]
        return page["revisions"][0]["revid"]


@functools.lru_cache()
def wiki_search(query: str, results=10) -> list[Evidence]:
    """Return a list of Evidence documents given the search query."""
    # get the list of articles that match the query
    resp = wikipedia.get(
        "", params={"format": "json", "action": "query", "list": "search", "srsearch": query, "srlimit": results}
    )
    resp.raise_for_status()
    data = resp.json()

    # and return a LazyEvidence for each
    return [LazyEvidence(title=d["title"], pageid=d["pageid"]) for d in data["query"]["search"]]


def wiki_content(doc: Evidence) -> str:
    """Get the page content in markdown, including tables and infoboxes, appropriate for displaying to an LLM."""
    # get the cached content, if available
    cache_filename = CACHE_DIR / f"{doc.pageid}-dated.md"
    if cache_filename.exists():
        return cache_filename.read_text()

    # otherwise retrieve it from Wikipedia
    resp = wikipedia.get("", params={"format": "json", "action": "parse", "oldid": doc.revid, "prop": "text"})
    resp.raise_for_status()
    data = resp.json()
    try:
        html = data["parse"]["text"]["*"]
    except KeyError:
        log.warning(f"Could not find dated revision of {doc.title} - maybe the page did not exist yet?")
        html = ""

    # MD it, cache it, and return
    text = markdownify(html)
    cache_filename.write_text(text)
    return text
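
Similarly, a small sketch of combining the new Wikipedia helpers with the retriever in an Open Book style workflow. The search query is made up for illustration; everything else uses only the functions defined in this commit:

import fanoutqa
from fanoutqa.retrieval import Corpus

# wiki_search returns LazyEvidence objects; wiki_content fetches (and caches) the dated markdown
docs = fanoutqa.wiki_search("2023 Formula One World Championship")
print(docs[0].title, len(fanoutqa.wiki_content(docs[0])))

# the same Evidence objects can be fed straight into the BM25+ Corpus from retrieval.py
corpus = Corpus(docs[:3])
for fragment in corpus.best("Who won the 2023 Formula One World Championship?"):
    print(fragment.title)
    break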