
Commit

feat: add wiki and retrieval
zhudotexe committed Feb 13, 2024
1 parent eac4bfa commit b390f00
Showing 7 changed files with 285 additions and 21 deletions.
33 changes: 30 additions & 3 deletions README.md
@@ -29,6 +29,11 @@ TODO: move to website
The `fanoutqa` package requires Python 3.8+.

To work with just the data, use `pip install fanoutqa`.
To include the baseline retriever and evaluation metrics, use `pip install "fanoutqa[all]"` and read the following section.

### Optional

To include a baseline BM25-based retriever, use `pip install "fanoutqa[retrieval]"`.

To run evaluations on the dev set, you will need to run a couple more steps:

@@ -39,6 +44,14 @@ unzip BLEURT-20.zip
rm BLEURT-20.zip
```

## Quickstart

1. Use `fanoutqa.load_dev()` or `fanoutqa.load_test()` to load the dataset.
2. Run your generations.
    1. Use `fanoutqa.wiki_search(title)` and `fanoutqa.wiki_content(evidence)` to retrieve the contents of
       Wikipedia pages for the Open Book and Evidence Provided settings.
3. Evaluate your generations with `fanoutqa.evaluate(dev_questions, answers)` (see below for the schema; a short end-to-end sketch follows).
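
A minimal end-to-end sketch of these steps, assuming each `DevQuestion` exposes `id` and `question` fields and that answers are recorded as dicts with `id` and `answer` keys (see the schema below); `generate_answer` is a placeholder for your own system:

```python
import fanoutqa


def generate_answer(question: str) -> str:
    # placeholder: call your own model here (closed book shown)
    return "your model's answer"


dev_questions = fanoutqa.load_dev()
answers = [{"id": q.id, "answer": generate_answer(q.question)} for q in dev_questions]

scores = fanoutqa.evaluate(dev_questions, answers)
print(scores)
```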

## Data Format

To load the dev or test questions, simply use `fanoutqa.load_dev()` or `fanoutqa.load_test()`. This will return a list of `DevQuestion` or `TestQuestion` objects, respectively.
@@ -94,12 +107,19 @@ class TestQuestion:

## Wikipedia Retrieval

To retrieve the contents of Wikipedia pages used as evidence, this package queries Wikipedia's Revisions API. There
are two main functions to interface with Wikipedia (a short usage sketch follows the list):

- `wiki_search(query)` returns a list of Evidence (Wikipedia pages that best match the query)
- `wiki_content(evidence)` takes an Evidence and returns its content (as of the dataset epoch) as Markdown.
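
A short sketch of both functions; the search query is illustrative, and `q.necessary_evidence` is the evidence list attached to each dev question:

```python
import fanoutqa

# Open Book setting: search for candidate pages, then fetch their contents
evidence = fanoutqa.wiki_search("Nintendo video game consoles")
page_md = fanoutqa.wiki_content(evidence[0])  # Markdown, as of the dataset epoch

# Evidence Provided setting: fetch the evidence attached to a dev question
q = fanoutqa.load_dev()[0]
provided = [fanoutqa.wiki_content(ev) for ev in q.necessary_evidence]
```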

To save on time waiting for requests and computation power (both locally and on Wikipedia's end), this package
aggressively caches retrieved Wikipedia pages. By default, this cache is located in `~/.cache/fanoutqa/wikicache`.
We provide many cached pages that you can use to prepopulate this cache with the following commands:

```shell
mkdir -p ~/.cache/fanoutqa/wikicache
```

## Evaluation

@@ -130,3 +150,10 @@ In the email body, please include details about your system, including at least:
- a link to your paper and recommended short citation, if applicable
- whether it is a new foundation model, a fine-tune, a prompting approach, or other

## Additional Resources

Although this package queries live Wikipedia and uses the Revisions API to get page content as of the dataset epoch,
we also provide a snapshot of English Wikipedia as of Nov 20, 2023. You can download this
snapshot [here](https://datasets.mechanus.zhu.codes/fanoutqa/enwiki-20231120-pages-articles-multistream.xml.bz2) (23G)
and its
index [here](https://datasets.mechanus.zhu.codes/fanoutqa/enwiki-20231120-pages-articles-multistream-index.txt.bz2).
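
For fully offline use, the multistream format lets you extract a single page without decompressing the whole dump: each line of the index file is `offset:pageid:title`, where the offset points to the start of the bz2 stream (a block of roughly 100 pages) containing that page. A rough sketch, assuming both files have been downloaded and the index has been decompressed locally (file names follow the links above):

```python
import bz2

DUMP = "enwiki-20231120-pages-articles-multistream.xml.bz2"
INDEX = "enwiki-20231120-pages-articles-multistream-index.txt"  # decompressed index


def find_offset(title: str) -> int:
    """Find the byte offset of the bz2 stream that contains *title*."""
    with open(INDEX, encoding="utf-8") as f:
        for line in f:
            offset, pageid, page_title = line.rstrip("\n").split(":", 2)
            if page_title == title:
                return int(offset)
    raise KeyError(title)


def read_stream(offset: int) -> str:
    """Decompress the single bz2 stream starting at *offset* and return its XML."""
    decomp = bz2.BZ2Decompressor()
    xml = bytearray()
    with open(DUMP, "rb") as f:
        f.seek(offset)
        while not decomp.eof:
            chunk = f.read(1024 * 1024)
            if not chunk:
                break
            xml += decomp.decompress(chunk)
    return xml.decode("utf-8")


# the returned XML contains ~100 <page> elements; parse it to find the one you want
print(read_stream(find_offset("Python (programming language)"))[:500])
```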
2 changes: 1 addition & 1 deletion fanoutqa/__init__.py
@@ -1,2 +1,2 @@
from .models import DevQuestion, TestQuestion
from .utils import load_dev, load_test
from .wiki import wiki_content, wiki_search
109 changes: 109 additions & 0 deletions fanoutqa/retrieval.py
@@ -0,0 +1,109 @@
"""This module contains a baseline implementation of a retriever for use with long Wikipedia articles"""

from dataclasses import dataclass
from typing import Iterable

try:
    import numpy as np
    from rank_bm25 import BM25Plus
except ImportError as e:
    raise ImportError(
        "Using the baseline retriever requires the rank_bm25 package. Use `pip install fanoutqa[retrieval]`."
    ) from e

from .models import Evidence
from .norm import normalize
from .wiki import wiki_content


@dataclass
class RetrievalResult:
    title: str
    """The title of the article this fragment comes from."""

    content: str
    """The content of the fragment."""


class Corpus:
"""A corpus of wiki docs. Indexes the docs on creation, normalizing the text beforehand with lemmatization.
Splits the documents into chunks no longer than a given length, preferring splitting on paragraph and sentence
boundaries. Documents will be converted to Markdown.
Uses BM25+ (Lv and Zhai, 2011), a TF-IDF based approach to retrieve document fragments.
To retrieve chunks corresponding to a query, iterate over ``Corpus.best(query)``.
.. code-block:: python
# example of how to use in the Evidence Provided setting
prompt = "..."
corpus = fanoutqa.retrieval.Corpus(q.necessary_evidence)
for fragment in corpus.best(q.question):
# use your own structured prompt format here
prompt += f"# {fragment.title}\n{fragment.content}\n\n"
"""

    def __init__(self, documents: list[Evidence], doc_len: int = 2048):
        """
        :param documents: The list of evidences to index
        :param doc_len: The maximum length, in characters, of each chunk
        """

        self.documents = []
        normalized_corpus = []
        for doc in documents:
            title = doc.title
            content = wiki_content(doc)
            for chunk in chunk_text(content, max_chunk_size=doc_len):
                self.documents.append(RetrievalResult(title, chunk))
                normalized_corpus.append(self.tokenize(chunk))

        self.index = BM25Plus(normalized_corpus)

    @staticmethod
    def tokenize(text: str):
        return normalize(text).split(" ")

    def best(self, q: str) -> Iterable[RetrievalResult]:
        """Yield the best matching fragments to the given query."""

        tok_q = self.tokenize(q)
        scores = self.index.get_scores(tok_q)
        idxs = np.argsort(scores)[::-1]
        for idx in idxs:
            yield self.documents[idx]


def chunk_text(text, max_chunk_size=1024, chunk_on=("\n\n", "\n", ". ", ", ", " "), chunker_i=0):
"""
Recursively chunks *text* into a list of str, with each element no longer than *max_chunk_size*.
Prefers splitting on the elements of *chunk_on*, in order.
"""

if len(text) <= max_chunk_size: # the chunk is small enough
return [text]
if chunker_i >= len(chunk_on): # we have no more preferred chunk_on characters
# optimization: instead of merging a thousand characters, just use list slicing
return [text[:max_chunk_size], *chunk_text(text[max_chunk_size:], max_chunk_size, chunk_on, chunker_i + 1)]

# split on the current character
chunks = []
split_char = chunk_on[chunker_i]
for chunk in text.split(split_char):
chunk = f"{chunk}{split_char}"
if len(chunk) > max_chunk_size: # this chunk needs to be split more, recurse
chunks.extend(chunk_text(chunk, max_chunk_size, chunk_on, chunker_i + 1))
elif chunks and len(chunk) + len(chunks[-1]) <= max_chunk_size: # this chunk can be merged
chunks[-1] += chunk
else:
chunks.append(chunk)

# if the last chunk is just the split_char, yeet it
if chunks[-1] == split_char:
chunks.pop()

# remove extra split_char from last chunk
chunks[-1] = chunks[-1][: -len(split_char)]
return chunks
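
For illustration, a quick way to sanity-check the chunker on its own (the sample text is made up; `chunk_text` is the module-level helper above):

```python
from fanoutqa.retrieval import chunk_text

text = ("This is a sentence. " * 10 + "\n\n") * 3  # three paragraphs of repeated sentences
chunks = chunk_text(text, max_chunk_size=256)
print([len(c) for c in chunks])            # each fragment is at most 256 characters
print(all(len(c) <= 256 for c in chunks))  # True
```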
35 changes: 35 additions & 0 deletions fanoutqa/utils.py
@@ -1,12 +1,17 @@
import datetime
import json
import os
from pathlib import Path
from typing import TypeAlias, Union

from markdownify import MarkdownConverter

from .models import DevQuestion, TestQuestion

AnyPath: TypeAlias = Union[str, bytes, os.PathLike]
PKG_ROOT = Path(__file__).parent
DATASET_EPOCH = datetime.datetime(year=2023, month=11, day=20, tzinfo=datetime.timezone.utc)
"""The day before which to get revisions from Wikipedia, to ensure that the contents of pages don't change over time."""


def load_dev(fp: AnyPath = None) -> list[DevQuestion]:
@@ -33,3 +38,33 @@ def load_test(fp: AnyPath = None) -> list[TestQuestion]:
    with open(fp) as f:
        data = json.load(f)
    return [TestQuestion.from_dict(d) for d in data]


# markdown
# We make some minor adjustments to markdownify's default style to make it look a little bit nicer
def discard(*_):
    return ""


class MDConverter(MarkdownConverter):
    def convert_img(self, el, text, convert_as_inline):
        alt = el.attrs.get("alt", None) or ""
        return f"![{alt}](image)"

    def convert_a(self, el, text, convert_as_inline):
        return text

    # noinspection PyMethodMayBeStatic,PyUnusedLocal
    def convert_div(self, el, text, convert_as_inline):
        content = text.strip()
        if not content:
            return ""
        return f"{content}\n"

    # sometimes these appear inline and are just annoying
    convert_script = discard
    convert_style = discard


def markdownify(html: str):
    return MDConverter(heading_style="atx").convert(html)
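
A small illustration of the converter's adjustments (links flattened to their text, images replaced by an `![alt](image)` stub, scripts and styles dropped); the HTML snippet is made up:

```python
from fanoutqa.utils import markdownify

html = (
    "<div><h2>Example</h2>"
    '<p>See <a href="/wiki/Python_(programming_language)">Python</a> '
    '<img src="logo.png" alt="logo"></p>'
    "<script>console.log('hi')</script></div>"
)
print(markdownify(html))
# roughly "## Example ... See Python ![logo](image)" - exact whitespace depends on markdownify
```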
90 changes: 90 additions & 0 deletions fanoutqa/wiki.py
@@ -0,0 +1,90 @@
"""Utils for working with Wikipedia"""

import functools
import logging
import urllib.parse
from pathlib import Path

import httpx

from .models import Evidence
from .utils import DATASET_EPOCH, markdownify

USER_AGENT = "fanoutqa/1.0.0 ([email protected])"
CACHE_DIR = Path("~/.cache/fanoutqa/wikicache").expanduser()  # expanduser so the cache lands in the home dir, not a literal "~"
CACHE_DIR.mkdir(exist_ok=True, parents=True)

log = logging.getLogger(__name__)
wikipedia = httpx.Client(base_url="https://en.wikipedia.org/w/api.php", headers={"User-Agent": USER_AGENT})


class LazyEvidence(Evidence):
"""A subclass of Evidence without a known revision ID; lazily loads it when needed."""

def __init__(self, title: str, pageid: int, url: str = None):
self.title = title
self.pageid = pageid
self._url = url

@property
def url(self):
if self._url is not None:
return self._url
encoded_title = urllib.parse.quote(self.title)
return f"https://en.wikipedia.org/wiki/{encoded_title}"

    @functools.cached_property
    def revid(self):
        resp = wikipedia.get(
            "",
            params={
                "format": "json",
                "action": "query",
                "prop": "revisions",
                "rvprop": "ids|timestamp",
                "rvlimit": 1,
                "pageids": self.pageid,
                "rvstart": DATASET_EPOCH.isoformat(),
            },
        )
        resp.raise_for_status()
        data = resp.json()
page = data["query"]["pages"][self.pageid]
return page["revisions"][0]["revid"]


@functools.lru_cache()
def wiki_search(query: str, results=10) -> list[Evidence]:
"""Return a list of Evidence documents given the search query."""
# get the list of articles that match the query
resp = wikipedia.get(
"", params={"format": "json", "action": "query", "list": "search", "srsearch": query, "srlimit": results}
)
resp.raise_for_status()
data = resp.json()

# and return a LazyEvidence for each
return [LazyEvidence(title=d["title"], pageid=d["pageid"]) for d in data["query"]["search"]]


def wiki_content(doc: Evidence) -> str:
"""Get the page content in markdown, including tables and infoboxes, appropriate for displaying to an LLM."""
# get the cached content, if available
cache_filename = CACHE_DIR / f"{doc.pageid}-dated.md"
if cache_filename.exists():
return cache_filename.read_text()

# otherwise retrieve it from Wikipedia
resp = wikipedia.get("", params={"format": "json", "action": "parse", "oldid": doc.revid, "prop": "text"})
resp.raise_for_status()
data = resp.json()
try:
html = data["parse"]["text"]["*"]
except KeyError:
log.warning(f"Could not find dated revision of {doc.title} - maybe the page did not exist yet?")
html = ""

# MD it, cache it, and return
text = markdownify(html)
cache_filename.write_text(text)
return text
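
A short sketch of the caching behavior (the first call needs network access; the query is illustrative):

```python
import fanoutqa
from fanoutqa.wiki import CACHE_DIR

evidence = fanoutqa.wiki_search("William Shakespeare")
text = fanoutqa.wiki_content(evidence[0])  # first call hits the API and writes the cache file
cache_file = CACHE_DIR / f"{evidence[0].pageid}-dated.md"
print(cache_file.exists())                 # True - later calls read this file instead
```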
21 changes: 20 additions & 1 deletion pyproject.toml
@@ -22,7 +22,26 @@ classifiers = [
"Intended Audience :: Science/Research",
"Topic :: Scientific/Engineering :: Artificial Intelligence",
]
dependencies = [
"ftfy>=6.1.3,<=7.0.0",
"httpx>=0.26.0,<1.0.0",
"markdownify~=0.11.6",
"spacy>=3.7.2,<4.0.0",
# spacy model
"en-core-web-sm @ https://github.com/explosion/spacy-models/releases/download/en_core_web_sm-3.7.1/en_core_web_sm-3.7.1-py3-none-any.whl",
]

[project.optional-dependencies]
all = [".[retrieval,eval]"]

retrieval = [
"rank-bm25~=0.2.2",
]

eval = [
"rouge-score~=0.1.2",
"git+https://github.com/google-research/bleurt.git@master",
]

[project.urls]
"Homepage" = "https://github.com/zhudotexe/fanoutqa"
16 changes: 0 additions & 16 deletions requirements.txt
@@ -1,19 +1,3 @@
ftfy~=6.1.3
spacy~=3.7.2
tqdm~=4.66.1

# spacy model
en-core-web-sm @ https://github.com/explosion/spacy-models/releases/download/en_core_web_sm-3.7.1/en_core_web_sm-3.7.1-py3-none-any.whl

# wikipedia
beautifulsoup4~=4.12.3
markdownify~=0.11.6
pymediawiki~=0.7.4

# eval
# rouge-score~=0.1.2
# git+https://github.com/google-research/bleurt.git@master

# dev requirements -- to just install the package from source, use pip install .
# include the main package deps
-e .
