Skip to content

Commit

Permalink
webarena debugging
Browse files Browse the repository at this point in the history
  • Loading branch information
zhudotexe committed Dec 12, 2024
1 parent edfe2ad commit 10d3286
Show file tree
Hide file tree
Showing 4 changed files with 26 additions and 18 deletions.
3 changes: 2 additions & 1 deletion bench_engines.py
Original file line number Diff line number Diff line change
Expand Up @@ -38,6 +38,8 @@ def get_engine(model_class: str, model_id: str, context_size: int = None):
return OpenAIEngine(model="gpt-4o-2024-05-13", temperature=0, max_context_size=context_size)
if model_id == "gpt-3.5-turbo-0125":
return OpenAIEngine(model="gpt-3.5-turbo-0125", temperature=0, max_context_size=context_size)
if model_id == "gpt-4o-mini":
return OpenAIEngine(model="gpt-4o-mini", temperature=0, max_context_size=context_size)
# ==== MISTRAL ====
if model_class == "mistral":
from kani.ext.vllm import VLLMEngine
Expand Down Expand Up @@ -166,7 +168,6 @@ def get_engine(model_class: str, model_id: str, context_size: int = None):
},
sampling_params=SamplingParams(temperature=0.7, max_tokens=2048, min_tokens=1),
)
# todo: cohere
raise ValueError("unknown engine")


Expand Down
19 changes: 11 additions & 8 deletions bench_webarena.py
Original file line number Diff line number Diff line change
Expand Up @@ -81,14 +81,17 @@ def wa_ensure_auth(config_file: Path) -> Path:
comb = get_site_comb_from_filepath(cookie_file_name)
temp_dir = tempfile.mkdtemp()
# subprocess to renew the cookie
subprocess.run([
"python",
"experiments/webarena/auto_login.py",
"--auth_folder",
temp_dir,
"--site_list",
*comb,
])
subprocess.run(
[
"python",
"experiments/webarena/auto_login.py",
"--auth_folder",
temp_dir,
"--site_list",
*comb,
],
check=True,
)
_c["storage_state"] = f"{temp_dir}/{cookie_file_name}"
assert os.path.exists(_c["storage_state"])
# write a temp copy of the config file
Expand Down
18 changes: 9 additions & 9 deletions redel/tools/webarena/patches.py
Original file line number Diff line number Diff line change
Expand Up @@ -139,15 +139,15 @@ def patch_to_support_webarena():

# WebArena runs a subprocess to login to get cookies
# which spews logs / warnings, so we silence them
_subprocess_run = subprocess.run

def subprocess_run(*args, **kwargs):
if any("auto_login.py" in a for a in args[0]):
kwargs["stdout"] = subprocess.PIPE
kwargs["stderr"] = subprocess.PIPE
return _subprocess_run(*args, **kwargs)

subprocess.run = lambda *args, **kwargs: subprocess_run(*args, **kwargs)
# _subprocess_run = subprocess.run
#
# def subprocess_run(*args, **kwargs):
# if any("auto_login.py" in a for a in args[0]):
# kwargs["stdout"] = subprocess.PIPE
# kwargs["stderr"] = subprocess.PIPE
# return _subprocess_run(*args, **kwargs)
#
# subprocess.run = lambda *args, **kwargs: subprocess_run(*args, **kwargs)

# WebArena's get_bounding_client_rect method is very slow
with ignore_webarena_warnings():
Expand Down
4 changes: 4 additions & 0 deletions test-webarena.sh
Original file line number Diff line number Diff line change
@@ -0,0 +1,4 @@
#!/bin/zsh

source slurm/webarena-env.sh
python bench_webarena.py --config baseline --model-class openai --large-model gpt-4o-mini --small-model gpt-4o-mini --save-dir experiments/webarena/dev/baseline

0 comments on commit 10d3286

Please sign in to comment.