
Commit

chore: batch and cache tp NL JSONs
zhudotexe committed Dec 9, 2024
1 parent 1145092 commit 10a3b36
Showing 3 changed files with 65 additions and 21 deletions.
1 change: 1 addition & 0 deletions .gitignore
@@ -172,4 +172,5 @@ BLEURT-20/
/experiments/webarena/**/webarena_trace.zip
/experiments/webarena/**/traces/
/experiments/webarena/traces
+/utils/tp_nl_to_json_cache

53 changes: 34 additions & 19 deletions score_travelplanner.py
@@ -10,13 +10,43 @@
import sys
from pathlib import Path

-from redel.utils import read_jsonl
+from redel.utils import batched, read_jsonl
from utils.tp_nl_to_json import nl_to_tp_json

EXPECTED_RESULTS = 180
ID_TO_IDX_MAP = Path(__file__).parent / "experiments/travelplanner/id_to_idx.json"


+async def transform_one(result, id_to_idx):
+    idx = result.get("idx", id_to_idx[result["id"]])
+    query = result["question"]
+    text_answer = result["answer"] or None
+    plan = result["plan"] or None
+
+    # some stupid fixes for the eval script
+    if text_answer and not plan:
+        print(f"No plan output for {idx=}, using 2-step")
+        plan = await nl_to_tp_json(text_answer)
+    # fix some gpt weirdness
+    if not plan:
+        plan = None
+    else:
+        if not isinstance(plan, list):
+            plan = [plan]
+        if plan and not plan[0]:
+            plan = None
+
+    if plan:
+        for day in plan:
+            for key in ("accommodation", "breakfast", "lunch", "dinner", "attraction"):
+                # every accommodation needs `, {CITY_NAME}` at the end (just tack on current city if missing)
+                if day.get(key) and day[key] != "-" and "," not in day[key]:
+                    *_, current_city = day["current_city"].split("to ")
+                    day[key] += f", {current_city}"
+
+    return {"idx": idx, "query": query, "plan": plan}
+
+
async def transform_submission(fp: Path):
"""Read in the answers and generations and transform them all into the correct TP eval format."""
results = read_jsonl(fp)
@@ -29,24 +29,9 @@ async def transform_submission(fp: Path):
except FileNotFoundError:
id_to_idx = {}

-    for result in results:
-        idx = result.get("idx", id_to_idx[result["id"]])
-        query = result["question"]
-        text_answer = result["answer"] or None
-        plan = result["plan"] or None
-
-        # some stupid fixes for the eval script
-        if plan:
-            for day in plan:
-                # every accommodation needs `, {CITY_NAME}` at the end (just tack on current city if missing)
-                if day["accommodation"] and day["accommodation"] != "-" and not "," in day["accommodation"]:
-                    *_, current_city = day["current_city"].split("to ")
-                    day["accommodation"] += f", {current_city}"
-        elif text_answer:
-            print(f"No plan output for {idx=}, using 2-step")
-            plan = await nl_to_tp_json(text_answer)
-
-        transformed.append({"idx": idx, "query": query, "plan": plan})
+    for result_batch in batched(results, 20):
+        result_batch_results = await asyncio.gather(*(transform_one(r, id_to_idx) for r in result_batch))
+        transformed.extend(result_batch_results)

# ensure there are the right number of results, and they're sorted by idx
missing_idxs = set(range(EXPECTED_RESULTS)).difference({t["idx"] for t in transformed})
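The rewritten loop awaits each 20-item chunk with asyncio.gather, so at most 20 nl_to_tp_json calls (and hence GPT-4 requests) are in flight at a time. The batched helper is imported from redel.utils and its implementation is not part of this diff; the sketch below assumes itertools.batched-style chunking, and process_all, worker, and _demo are hypothetical names used only for illustration.

import asyncio
from itertools import islice


def batched(iterable, n):
    # Yield successive tuples of at most n items (itertools.batched-style semantics,
    # assumed here to match redel.utils.batched, which this commit does not show).
    it = iter(iterable)
    while chunk := tuple(islice(it, n)):
        yield chunk


async def process_all(items, worker, batch_size=20):
    # Await each chunk as a group so no more than batch_size coroutines run at once.
    results = []
    for chunk in batched(items, batch_size):
        results.extend(await asyncio.gather(*(worker(item) for item in chunk)))
    return results


async def _demo():
    async def worker(i):
        await asyncio.sleep(0.01)  # stand-in for a slow API call
        return i * 2

    print(await process_all(range(50), worker))


if __name__ == "__main__":
    asyncio.run(_demo())

Batching this way gives a simple concurrency cap at the cost of letting each batch wait on its slowest item; a semaphore around each call would keep the pipeline full between batches.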
32 changes: 30 additions & 2 deletions utils/tp_nl_to_json.py
@@ -5,7 +5,10 @@
https://github.com/OSU-NLP-Group/TravelPlanner/blob/main/postprocess/openai_request.py
https://github.com/OSU-NLP-Group/TravelPlanner/blob/main/postprocess/parsing.py
"""

+import hashlib
import json
+from pathlib import Path

from kani import Kani
from kani.engines.openai import OpenAIEngine
@@ -36,12 +39,37 @@


async def nl_to_tp_json(text):
+    # check for existing cache
+    cache_key = hashlib.sha256(text.encode()).hexdigest()
+    cache_base = Path(__file__).parent / "tp_nl_to_json_cache"
+    cache_base.mkdir(exist_ok=True)
+    fn = cache_base / f"{cache_key}.json"
+    if fn.exists():
+        with open(fn) as f:
+            return json.load(f)
+    # get result, return if null
+    result = await _nl_to_tp_json(text)
+    if not result:
+        result = None
+    # cache new result
+    with open(fn, "w") as f:
+        json.dump(result, f)
+    return result
+
+
+async def _nl_to_tp_json(text):
engine = OpenAIEngine(model="gpt-4", temperature=0)
ai = Kani(engine)
query = f"{prefix}\nText: {text}\nPlease output the corresponding JSON only."
resp = await ai.chat_round_str(query)
try:
return json.loads(resp)
except json.JSONDecodeError:
-        json_text = resp[resp.find("["):resp.rfind("]") + 1]
-        return json.loads(json_text)
+        print("Could not decode JSON response, searching for brackets...")
+        print(resp)
+        json_text = resp[resp.find("[") : resp.rfind("]") + 1]
+        try:
+            return json.loads(json_text)
+        except json.JSONDecodeError:
+            print("Could not decode JSON response again, returning null!")
+            return
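The new error handling retries the parse on the slice between the first "[" and the last "]" and returns None (serialized as null in the cache) if that also fails. The same fallback can be pulled into a small standalone helper; extract_json_array below is a hypothetical name used for illustration, not code from this commit.

import json


def extract_json_array(resp: str):
    # Best-effort parse of a JSON array from a model response: try the raw text,
    # then the slice between the first "[" and the last "]", otherwise give up.
    try:
        return json.loads(resp)
    except json.JSONDecodeError:
        pass
    start, end = resp.find("["), resp.rfind("]")
    if start == -1 or end == -1:
        return None
    try:
        return json.loads(resp[start : end + 1])
    except json.JSONDecodeError:
        return None


assert extract_json_array('Sure, here is the plan:\n[{"days": 1}]') == [{"days": 1}]
assert extract_json_array("no json here") is None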

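nl_to_tp_json now memoizes each conversion on disk under utils/tp_nl_to_json_cache, keyed by the SHA-256 hash of the input text, so rescoring a run skips repeated GPT-4 calls for unchanged answers (the .gitignore entry above keeps that cache out of version control). Note that a cached null is returned as-is rather than retried. The same pattern can be factored into a reusable decorator; disk_json_cache is a hypothetical sketch of that idea, not code from this commit.

import functools
import hashlib
import json
from pathlib import Path


def disk_json_cache(cache_dir: str):
    # Cache an async str -> JSON-serializable function on disk, one file per input,
    # keyed by the SHA-256 hash of the input text.
    def decorator(fn):
        base = Path(__file__).parent / cache_dir
        base.mkdir(exist_ok=True)

        @functools.wraps(fn)
        async def wrapper(text: str):
            path = base / f"{hashlib.sha256(text.encode()).hexdigest()}.json"
            if path.exists():
                return json.loads(path.read_text())
            result = await fn(text)
            path.write_text(json.dumps(result))
            return result

        return wrapper

    return decorator


# Hypothetical usage mirroring the commit's inline version:
# @disk_json_cache("tp_nl_to_json_cache")
# async def nl_to_tp_json(text): ...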