From 023a4fd02a872f3a6b8706dd522abd57c7b33512 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Marc-Alexandre=20C=C3=B4t=C3=A9?= Date: Fri, 27 Mar 2026 11:16:16 -0700 Subject: [PATCH 1/8] Add --continue-from flag to replay and extend previous runs Enables extending experiments beyond their original step limit by replaying a previous trajectory from wandb, then letting the LLM agent take over. Features: - --continue-from : replay a specific wandb run - --continue-from (no value): auto-find the best matching run by game, LLM model, agent type, and seed (picks the run with the most steps) - Replay phase feeds recorded actions to the env with no LLM calls, preserving original token usage stats from the wandb trajectory - Verifies replay fidelity by comparing observations (warns on divergence) - Truncates trajectory when target steps < existing run steps - Skips LLM loop when game was already won with max score - Logs as a new wandb run referencing the original run ID in config New files: - tales/wandb_utils.py: fetch_run_trajectory() and find_matching_run() - scripts/test_replay_determinism.py: validates deterministic replay across all 5 environment frameworks (Jericho, TextWorld, ALFWorld, TextWorldExpress, ScienceWorld) Co-authored-by: Copilot <223556219+Copilot@users.noreply.github.com> --- README.md | 23 ++++ benchmark.py | 202 ++++++++++++++++++++++++++++- scripts/test_replay_determinism.py | 158 ++++++++++++++++++++++ tales/wandb_utils.py | 174 +++++++++++++++++++++++++ 4 files changed, 556 insertions(+), 1 deletion(-) create mode 100644 scripts/test_replay_determinism.py create mode 100644 tales/wandb_utils.py diff --git a/README.md b/README.md index 9a804ef..6678c9b 100644 --- a/README.md +++ b/README.md @@ -75,6 +75,29 @@ In order to benchmark a given LLM acting as language agent playing text-based ga python benchmark.py --agent agents/llm.py zero-shot --envs TWCookingLevel1 +### Continuing a Previous Run + +If you have a previous run that was limited to N steps (e.g., 100), you can extend it to more steps (e.g., 1000) using the `--continue-from` flag. This replays the original trajectory without making LLM calls, then lets the LLM take over for the remaining steps. + +**With an explicit run ID:** + + python benchmark.py reasoning --llm gpt-4o --conversation --continue-from --nb-steps 1000 --envs JerichoEnvZork1 --wandb + +**Auto-find matching run** (searches wandb for a run matching the current game, agent, and seed): + + python benchmark.py reasoning --llm gpt-4o --conversation --continue-from --nb-steps 1000 --envs JerichoEnvZork1 --wandb + +**How it works:** +1. Fetches the original run's config and rollout from the wandb API (or auto-finds one) +2. Recreates the environment with the same seed for deterministic replay +3. Replays all recorded actions (no LLM calls, preserving original token usage stats) +4. Verifies observations match the logged trajectory (warns on divergence) +5. Hands off to the LLM agent for the remaining steps +6. Logs as a new wandb run referencing the original run ID + +> [!NOTE] +> The `--continue-from` flag expects a wandb run ID (e.g., `abc123de`) from the `pearls-lab/text-games-benchmark` project, or no value to auto-find. The agent type and parameters must match the original run. When auto-finding, if no matching run is found, the game runs from scratch. + ### API-based LLMs `llm` natively supports OpenAI models and self-hosted models that offer an OpenAI-compatible API (e.g. like vLLM does - more on this below). diff --git a/benchmark.py b/benchmark.py index 7a9756f..9bbb37b 100644 --- a/benchmark.py +++ b/benchmark.py @@ -19,11 +19,65 @@ import tales from tales.logger import log, setup_logging from tales.utils import NumpyEncoder +from tales.wandb_utils import fetch_run_trajectory, find_matching_run os.environ["WANDB_MODE"] = "disabled" def evaluate(agent, env_name, args): + # Fetch trajectory if continuing from a previous run. + trajectory_df = None + continue_from = getattr(args, "continue_from", None) + if continue_from: + # Auto-find matching run if no explicit run ID was provided. + if continue_from == "auto": + continue_from = find_matching_run(env_name, agent.params, args.game_seed) + if continue_from is None: + log.info( + colored( + f"No matching previous run found for {env_name}. " + f"Running from scratch.", + "yellow", + ) + ) + + if continue_from: + original_config, trajectory_df = fetch_run_trajectory(continue_from) + + # Validate that the game matches. + original_game = original_config.get("game") + if original_game and original_game != env_name: + raise ValueError( + f"Environment mismatch: --continue-from run played '{original_game}' " + f"but current run targets '{env_name}'." + ) + + # Override game_seed from original run to ensure deterministic replay. + original_seed = original_config.get("game_seed") + if original_seed is not None and original_seed != args.game_seed: + log.info( + f"Overriding --game-seed from {args.game_seed} to {original_seed} " + f"(from original run)." + ) + args.game_seed = original_seed + + # Truncate trajectory if it has more steps than the target. + if len(trajectory_df) > args.nb_steps: + log.info( + f"Trajectory has {len(trajectory_df)} steps but target is {args.nb_steps}. " + f"Truncating replay to {args.nb_steps} steps." + ) + trajectory_df = trajectory_df.iloc[: args.nb_steps] + + replay_steps = len(trajectory_df) + log.info( + colored( + f"Continuing from run {continue_from}: " + f"replaying {replay_steps} steps, then LLM takes over up to {args.nb_steps}.", + "cyan", + ) + ) + env_params = ( f"a{int(args.admissible_commands)}_s{args.game_seed}_steps{args.nb_steps}" ) @@ -89,6 +143,10 @@ def evaluate(agent, env_name, args): "admissible_commands": args.admissible_commands, **agent.params, } + if trajectory_df is not None: + wandb_config["continued_from_run_id"] = original_config["_run_id"] + wandb_config["continued_from_run_url"] = original_config["_run_url"] + wandb_config["replay_steps"] = len(trajectory_df) wandb_run = wandb.init( project="tales", config=wandb_config, @@ -138,10 +196,147 @@ def evaluate(agent, env_name, args): }, step=0, ) + + # Replay phase: feed recorded actions to the environment (no LLM calls). + replay_steps = 0 + if trajectory_df is not None: + replay_steps = len(trajectory_df) + log.info(colored(f"Replaying {replay_steps} steps...", "cyan")) + + replay_pbar = tqdm( + trajectory_df.iterrows(), + total=args.nb_steps, + desc=f" {env_name} (replay)", + unit="steps", + leave=False, + ) + for _, row in replay_pbar: + step = int(row["Step"]) + action = str(row["Action"]) + + replay_pbar.set_postfix_str( + f"Score: {info['score']}/{info['max_score']} ({info['score']/info['max_score']:.1%})" + ) + + prev_obs = obs + + # Feed the recorded action to the environment. + if "\n" in action.strip(): + obs = "The game only allows one action per step." + else: + obs, _, done, info = env.step(action) + + score = info["score"] + moves = info["moves"] + feedback = info["feedback"] + norm_score = score / max_score + highscore = max(score, highscore) + norm_highscore = highscore / max_score + + if ( + args.admissible_commands + and info["admissible_commands"] + and action not in info["admissible_commands"] + ): + nb_invalid_actions += 1 + + # Verify replay fidelity by comparing observations. + logged_obs = row.get("Observation") + if logged_obs is not None and isinstance(logged_obs, str): + if prev_obs.strip() != logged_obs.strip(): + log.warning( + f"Replay divergence at step {step}:\n" + f" Expected: {logged_obs[:200]!r}\n" + f" Got: {prev_obs[:200]!r}" + ) + + # Build agent history so it has context for subsequent LLM calls. + agent.history.append((f"{prev_obs}\n> ", f"{action}\n")) + + msg = "{:5d}. Time: {:9.2f}\tScore: {:3d}\tMove: {:5d}\tAction: {:20s} (replay)" + msg = msg.format(step, time.time() - start_time, score, moves, action) + log.info(msg) + + # Log to wandb with original token usage from the trajectory. + nb_tokens = row.get("Token Usage", 0) or 0 + nb_tokens_thinking = row.get("Thinking Tokens", 0) or 0 + wandb_run.log( + { + "episode/moves": moves, + "episode/score": score, + "episode/highscore": highscore, + "episode/normalized_score": norm_score, + "episode/normalized_highscore": norm_highscore, + "episode/token_usage": nb_tokens, + "episode/token_usage_thinking": nb_tokens_thinking, + }, + step=step, + ) + + # Store results with original token usage from the trajectory. + # fmt: off + results.append([ + step, score, max_score, norm_score, moves, + prev_obs, action, feedback, + row.get("Prompt", ""), row.get("Response", ""), row.get("Thinking"), + row.get("Token Usage", 0) or 0, row.get("Prompt Tokens", 0) or 0, + row.get("Response Tokens", 0) or 0, row.get("Thinking Tokens", 0) or 0, + ]) + # fmt: on + + if not done: + log.debug(obs) + + if done: + if info["won"]: + nb_wins += 1 + if highscore == max_score: + log.debug(obs) + # Don't break during replay; continue replaying. + # Don't reset either — the original run broke here. + continue + elif info["lost"]: + nb_losts += 1 + + # Reset the game just like the original run did. + last_obs = obs + obs, info = env.reset() + obs = last_obs + "\n\n-= Restarting =-\n" + obs + agent.reset(obs, info, env_name) + nb_resets += 1 + + log.debug(f"{obs}") + + replay_pbar.close() + + if highscore == max_score: + log.info( + colored( + f"Replay complete: game already won with max score ({highscore}/{max_score}). " + f"No further steps needed.", + "green", + ) + ) + # Skip the LLM loop entirely. + replay_steps = args.nb_steps + else: + log.info( + colored( + f"Replay complete: {replay_steps} steps, score={score}, highscore={highscore}. " + f"LLM takes over from step {replay_steps + 1}.", + "cyan", + ) + ) + try: pbar = tqdm( - range(1, args.nb_steps + 1), desc=f" {env_name}", unit="steps", leave=False + range(replay_steps + 1, args.nb_steps + 1), + initial=replay_steps, + total=args.nb_steps, + desc=f" {env_name}", + unit="steps", + leave=False, ) for step in pbar: pbar.set_postfix_str( @@ -455,6 +650,11 @@ def _add_general_settings(parser): help="Force overwriting only log files that have failed.") general_group.add_argument("--debug", action="store_true", help="Debug mode.") + general_group.add_argument("--continue-from", dest="continue_from", + nargs="?", const="auto", + help="Continue from a previous wandb run. " + "Pass a run ID to replay a specific run, or use without a value to auto-find " + "a matching run based on the current config (game, agent, seed).") subgroup = general_group.add_mutually_exclusive_group() subgroup.add_argument( diff --git a/scripts/test_replay_determinism.py b/scripts/test_replay_determinism.py new file mode 100644 index 0000000..a900517 --- /dev/null +++ b/scripts/test_replay_determinism.py @@ -0,0 +1,158 @@ +"""Test that replaying recorded actions through an environment produces the same trajectory. + +This script runs a short game with the random agent, records the trajectory, +then replays the same actions and verifies the observations match. +""" + +import gymnasium as gym +import numpy as np + +import tales + +NB_STEPS = 15 +GAME_SEED = 42 + + +def run_game(env_name, seed, nb_steps, actions=None): + """Run a game and return the trajectory. + + If actions is provided, replay those actions instead of generating random ones. + Returns a list of dicts with step, obs, action, feedback, score, done. + """ + env = gym.make( + f"tales/{env_name}-v0", + disable_env_checker=True, + admissible_commands=False, + ) + + obs, info = env.reset(seed=seed) + rng = np.random.RandomState(seed) + trajectory = [] + + for step in range(1, nb_steps + 1): + if actions is not None: + action = actions[step - 1] + else: + # Use a simple deterministic action for testing. + action = rng.choice( + [ + "look", + "inventory", + "north", + "south", + "east", + "west", + "take all", + "open door", + "examine room", + ] + ) + + prev_obs = obs + if "\n" in action.strip(): + obs = "The game only allows one action per step." + done = False + info_after = info + else: + obs, _, done, info_after = env.step(action) + + trajectory.append( + { + "step": step, + "prev_obs": prev_obs, + "action": action, + "obs_after": obs, + "feedback": info_after.get("feedback", ""), + "score": info_after["score"], + "done": done, + } + ) + + info = info_after + + if done: + last_obs = obs + obs, info = env.reset() + obs = last_obs + "\n\n-= Restarting =-\n" + obs + + env.close() + return trajectory + + +def test_env(env_name): + """Test replay determinism for a single environment.""" + print(f"\n{'='*60}") + print(f"Testing: {env_name}") + print(f"{'='*60}") + + # Run 1: generate trajectory + traj1 = run_game(env_name, GAME_SEED, NB_STEPS) + actions = [t["action"] for t in traj1] + + print(f" Run 1: {len(traj1)} steps, final score={traj1[-1]['score']}") + + # Run 2: replay the same actions + traj2 = run_game(env_name, GAME_SEED, NB_STEPS, actions=actions) + + print(f" Run 2: {len(traj2)} steps, final score={traj2[-1]['score']}") + + # Compare trajectories + mismatches = 0 + for i, (t1, t2) in enumerate(zip(traj1, traj2)): + if t1["prev_obs"].strip() != t2["prev_obs"].strip(): + print(f" MISMATCH at step {t1['step']} (prev_obs):") + print(f" Run 1: {t1['prev_obs'][:100]!r}") + print(f" Run 2: {t2['prev_obs'][:100]!r}") + mismatches += 1 + if t1["obs_after"].strip() != t2["obs_after"].strip(): + print(f" MISMATCH at step {t1['step']} (obs_after):") + print(f" Run 1: {t1['obs_after'][:100]!r}") + print(f" Run 2: {t2['obs_after'][:100]!r}") + mismatches += 1 + if t1["score"] != t2["score"]: + print( + f" MISMATCH at step {t1['step']} (score): {t1['score']} vs {t2['score']}" + ) + mismatches += 1 + + if mismatches == 0: + print(f" ✅ PASS: All {len(traj1)} steps match perfectly.") + else: + print(f" ❌ FAIL: {mismatches} mismatches found.") + + return mismatches == 0 + + +def main(): + # Test one environment from each framework. + test_envs = [] + + # Pick one env per framework. + for task in tales.tasks: + envs = tales.envs_per_task.get(task, []) + if envs: + test_envs.append((task, sorted(envs)[0])) + + print(f"Testing replay determinism for {len(test_envs)} environments:") + for task, env in test_envs: + print(f" {task}: {env}") + + results = {} + for task, env_name in test_envs: + try: + passed = test_env(env_name) + results[env_name] = "PASS" if passed else "FAIL" + except Exception as e: + print(f" ❌ ERROR: {e}") + results[env_name] = f"ERROR: {e}" + + print(f"\n{'='*60}") + print("SUMMARY") + print(f"{'='*60}") + for env_name, result in results.items(): + status = "✅" if result == "PASS" else "❌" + print(f" {status} {env_name}: {result}") + + +if __name__ == "__main__": + main() diff --git a/tales/wandb_utils.py b/tales/wandb_utils.py new file mode 100644 index 0000000..b436a17 --- /dev/null +++ b/tales/wandb_utils.py @@ -0,0 +1,174 @@ +import json +import logging +import tempfile + +import pandas as pd +import wandb + +log = logging.getLogger("tales") + +WANDB_PROJECT = "pearls-lab/text-games-benchmark" + +ROLLOUT_COLUMNS = [ + "Step", + "Score", + "Max Score", + "Normalized Score", + "Moves", + "Observation", + "Action", + "Feedback", + "Prompt", + "Response", + "Thinking", + "Token Usage", + "Prompt Tokens", + "Response Tokens", + "Thinking Tokens", +] + + +def find_matching_run(env_name, agent_params, game_seed, project=WANDB_PROJECT): + """Find a matching wandb run based on game and agent config fields. + + Searches the wandb project for finished runs that match the core experiment + identity: game, LLM model, agent type, LLM seed, and game seed. Fields + that may change between runs (like context_limit or max_steps) are + intentionally excluded from matching. + + Among matching runs, returns the one with the most completed steps. + The caller is responsible for truncating the trajectory if needed. + + Args: + env_name: The game/environment name (e.g., "JerichoEnvZork1"). + agent_params: Dict of agent parameters (from agent.params). + game_seed: The game seed (can be None). + project: The wandb project path. + + Returns: + The run ID of the best matching run, or None if no match found. + """ + api = wandb.Api() + + # Match on stable identity fields only (not context_limit or max_steps). + llm = agent_params.get("llm") + agent_type = agent_params.get("agent_type") + seed = agent_params.get("seed") + + log.info( + f"Searching for matching run: game={env_name}, llm={llm}, " + f"agent_type={agent_type}, seed={seed}, game_seed={game_seed}" + ) + + # Use wandb config filters for the fields that are top-level in config. + filters = { + "config.game": env_name, + "state": "finished", + } + if llm is not None: + filters["config.llm"] = llm + if agent_type is not None: + filters["config.agent_type"] = agent_type + + try: + runs = api.runs(project, filters=filters, order="-created_at") + except wandb.errors.CommError as e: + log.warning(f"Failed to search wandb runs: {e}") + return None + + # Filter by seed and game_seed, collecting candidates. + candidates = [] + for run in runs: + cfg = run.config + if seed is not None and cfg.get("seed") != seed: + continue + if game_seed is not None and cfg.get("game_seed") != game_seed: + continue + + run_steps = run.summary.get("total/Env. Steps", 0) or 0 + candidates.append((run, run_steps)) + + if not candidates: + log.warning( + f"No matching run found for: game={env_name}, llm={llm}, " + f"agent_type={agent_type}, seed={seed}, game_seed={game_seed}" + ) + return None + + # Pick the run with the most steps (will be truncated to nb_steps if needed). + best_run, best_steps = max(candidates, key=lambda x: x[1]) + + log.info( + f"Found matching run: {best_run.name} " + f"({best_steps} steps, id={best_run.id}, url={best_run.url})" + ) + return best_run.id + + +def fetch_run_trajectory(run_id, project=WANDB_PROJECT): + """Fetch run config and rollout trajectory from wandb. + + Args: + run_id: The wandb run ID (e.g., "abc123de"). + project: The wandb project path (e.g., "entity/project"). + + Returns: + A tuple of (run_config, trajectory_df) where: + - run_config is a dict with the original run's configuration + plus metadata (run_id, run_url, run_name). + - trajectory_df is a DataFrame with one row per step, + columns matching ROLLOUT_COLUMNS. + + Raises: + ValueError: If the run or rollout data cannot be found. + """ + api = wandb.Api() + run_path = f"{project}/{run_id}" + log.info(f"Fetching run {run_path} from wandb...") + + try: + run = api.run(run_path) + except wandb.errors.CommError as e: + raise ValueError(f"Could not find wandb run '{run_path}': {e}") from e + + # Extract config. + run_config = dict(run.config) + run_config["_run_id"] = run.id + run_config["_run_url"] = run.url + run_config["_run_name"] = run.name + run_config["_run_state"] = run.state + + # Download the rollout JSONL file. + trajectory_df = _download_rollout(run) + + log.info( + f"Fetched trajectory: {len(trajectory_df)} steps from run '{run.name}' ({run.state})" + ) + return run_config, trajectory_df + + +def _download_rollout(run): + """Download and parse the rollout JSONL file from a wandb run.""" + rollout_file = None + for f in run.files(): + if f.name.endswith(".jsonl"): + rollout_file = f + break + + if rollout_file is None: + raise ValueError( + f"No rollout JSONL file found in wandb run '{run.id}'. " + f"Available files: {[f.name for f in run.files()]}" + ) + + with tempfile.TemporaryDirectory() as tmpdir: + rollout_file.download(root=tmpdir, replace=True) + filepath = f"{tmpdir}/{rollout_file.name}" + df = pd.read_json(filepath, orient="records", lines=True) + + # Validate columns. + missing = set(ROLLOUT_COLUMNS) - set(df.columns) + if missing: + log.warning(f"Rollout is missing expected columns: {missing}") + + return df From c3e84618b4dcfe083904904fb134169d3ba80438 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Marc-Alexandre=20C=C3=B4t=C3=A9?= Date: Fri, 27 Mar 2026 11:34:11 -0700 Subject: [PATCH 2/8] Refactor evaluate() into modular replay and play functions Extract shared helpers (_make_state, _step_env, _check_invalid, _handle_done, _record_step) and standalone replay_trajectory() and play_with_agent() functions from the monolithic evaluate() function. This improves readability and maintainability of the replay/continue logic while preserving identical behavior. Co-authored-by: Copilot <223556219+Copilot@users.noreply.github.com> --- benchmark.py | 554 +++++++++++++++++++++++++++------------------------ 1 file changed, 296 insertions(+), 258 deletions(-) diff --git a/benchmark.py b/benchmark.py index 9bbb37b..e10f780 100644 --- a/benchmark.py +++ b/benchmark.py @@ -24,6 +24,258 @@ os.environ["WANDB_MODE"] = "disabled" +def _make_state(obs, info): + """Create mutable game state dict.""" + return { + "step": 0, + "score": 0, + "moves": 0, + "highscore": 0, + "max_score": info["max_score"], + "nb_wins": 0, + "nb_losts": 0, + "nb_resets": 0, + "nb_invalid_actions": 0, + "obs": obs, + "done": False, + "info": info, + "results": [], + } + + +def _step_env(env, state, action): + """Execute an action and update state. Returns (prev_obs, feedback).""" + prev_obs = state["obs"] + + if "\n" in action.strip(): + state["obs"] = "The game only allows one action per step." + else: + obs, _, done, info = env.step(action) + state["obs"] = obs + state["done"] = done + state["info"] = info + + state["score"] = state["info"]["score"] + state["moves"] = state["info"]["moves"] + state["highscore"] = max(state["score"], state["highscore"]) + + return prev_obs, state["info"]["feedback"] + + +def _check_invalid(state, action, admissible_commands): + """Track invalid actions when admissible commands are enabled.""" + if ( + admissible_commands + and state["info"]["admissible_commands"] + and action not in state["info"]["admissible_commands"] + ): + state["nb_invalid_actions"] += 1 + + +def _handle_done(env, agent, state, env_name, break_on_max=True): + """Handle game-over (win/loss/reset). Returns True if loop should break.""" + if not state["done"]: + return False + + if state["info"]["won"]: + state["nb_wins"] += 1 + if state["highscore"] == state["max_score"]: + log.debug(state["obs"]) + if break_on_max: + return True # Break: no reason to play more. + else: + return False # Replay: don't break or reset. + elif state["info"]["lost"]: + state["nb_losts"] += 1 + + # Restart the game to try for a better score. + last_obs = state["obs"] + obs, info = env.reset() + state["obs"] = last_obs + "\n\n-= Restarting =-\n" + obs + state["info"] = info + agent.reset(state["obs"], info, env_name) + state["nb_resets"] += 1 + log.debug(state["obs"]) + return False + + +def _record_step(state, prev_obs, action, feedback, token_stats, wandb_run): + """Append step to results and log to wandb.""" + s = state + norm_score = s["score"] / s["max_score"] + norm_highscore = s["highscore"] / s["max_score"] + + wandb_run.log( + { + "episode/moves": s["moves"], + "episode/score": s["score"], + "episode/highscore": s["highscore"], + "episode/normalized_score": norm_score, + "episode/normalized_highscore": norm_highscore, + "episode/token_usage": token_stats["nb_tokens"], + "episode/token_usage_thinking": token_stats.get("nb_tokens_thinking", 0), + }, + step=s["step"], + ) + + # fmt: off + s["results"].append([ + s["step"], s["score"], s["max_score"], norm_score, s["moves"], + prev_obs, action, feedback, + token_stats["prompt"], token_stats["response"], token_stats.get("thinking"), + token_stats["nb_tokens"], token_stats["nb_tokens_prompt"], + token_stats["nb_tokens_response"], token_stats.get("nb_tokens_thinking", 0), + ]) + # fmt: on + + +def replay_trajectory( + env, agent, trajectory_df, state, wandb_run, args, env_name, start_time +): + """Replay recorded actions through the environment (no LLM calls). + + Feeds each action from the trajectory to the environment, verifies + observations match, and builds agent history for subsequent LLM play. + """ + replay_steps = len(trajectory_df) + log.info(colored(f"Replaying {replay_steps} steps...", "cyan")) + + replay_pbar = tqdm( + trajectory_df.iterrows(), + total=args.nb_steps, + desc=f" {env_name} (replay)", + unit="steps", + leave=False, + ) + for _, row in replay_pbar: + state["step"] = int(row["Step"]) + action = str(row["Action"]) + + replay_pbar.set_postfix_str( + f"Score: {state['info']['score']}/{state['info']['max_score']}" + f" ({state['info']['score']/state['info']['max_score']:.1%})" + ) + + prev_obs, feedback = _step_env(env, state, action) + _check_invalid(state, action, args.admissible_commands) + + # Verify replay fidelity. + logged_obs = row.get("Observation") + if logged_obs is not None and isinstance(logged_obs, str): + if prev_obs.strip() != logged_obs.strip(): + log.warning( + f"Replay divergence at step {state['step']}:\n" + f" Expected: {logged_obs[:200]!r}\n" + f" Got: {prev_obs[:200]!r}" + ) + + # Build agent history for subsequent LLM calls. + agent.history.append((f"{prev_obs}\n> ", f"{action}\n")) + + msg = "{:5d}. Time: {:9.2f}\tScore: {:3d}\tMove: {:5d}\tAction: {:20s} (replay)" + msg = msg.format( + state["step"], + time.time() - start_time, + state["score"], + state["moves"], + action, + ) + log.info(msg) + + # Use original token stats from the trajectory. + token_stats = { + "prompt": row.get("Prompt", "") or "", + "response": row.get("Response", "") or "", + "thinking": row.get("Thinking"), + "nb_tokens": row.get("Token Usage", 0) or 0, + "nb_tokens_prompt": row.get("Prompt Tokens", 0) or 0, + "nb_tokens_response": row.get("Response Tokens", 0) or 0, + "nb_tokens_thinking": row.get("Thinking Tokens", 0) or 0, + } + _record_step(state, prev_obs, action, feedback, token_stats, wandb_run) + + if not state["done"]: + log.debug(state["obs"]) + + _handle_done(env, agent, state, env_name, break_on_max=False) + + replay_pbar.close() + + if state["highscore"] == state["max_score"]: + log.info( + colored( + f"Replay complete: game already won with max score " + f"({state['highscore']}/{state['max_score']}). No further steps needed.", + "green", + ) + ) + return args.nb_steps # Signal: skip the play loop. + else: + log.info( + colored( + f"Replay complete: {replay_steps} steps, score={state['score']}, " + f"highscore={state['highscore']}. " + f"LLM takes over from step {replay_steps + 1}.", + "cyan", + ) + ) + return replay_steps + + +def play_with_agent( + env, agent, state, wandb_run, args, env_name, start_time, start_step +): + """Play the game with the LLM agent from start_step to nb_steps.""" + pbar = tqdm( + range(start_step, args.nb_steps + 1), + initial=start_step - 1, + total=args.nb_steps, + desc=f" {env_name}", + unit="steps", + leave=False, + ) + for step in pbar: + state["step"] = step + pbar.set_postfix_str( + f"Score: {state['info']['score']}/{state['info']['max_score']}" + f" ({state['info']['score']/state['info']['max_score']:.1%})" + ) + + action, stats = agent.act( + state["obs"], state["score"], state["done"], state["info"] + ) + log.debug(colored(f"> {action}", "green")) + + if args.debug: + breakpoint() + + prev_obs, feedback = _step_env(env, state, action) + _check_invalid(state, action, args.admissible_commands) + + msg = "{:5d}. Time: {:9.2f}\tScore: {:3d}\tMove: {:5d}\tAction: {:20s}" + msg = msg.format( + step, time.time() - start_time, state["score"], state["moves"], action + ) + log.info(msg) + + token_stats = { + "prompt": stats["prompt"], + "response": stats["response"], + "thinking": stats.get("thinking"), + "nb_tokens": stats["nb_tokens"], + "nb_tokens_prompt": stats["nb_tokens_prompt"], + "nb_tokens_response": stats["nb_tokens_response"], + "nb_tokens_thinking": stats.get("nb_tokens_thinking", 0), + } + _record_step(state, prev_obs, action, feedback, token_stats, wandb_run) + + if not state["done"]: + log.debug(state["obs"]) + + if _handle_done(env, agent, state, env_name, break_on_max=True): + break + + def evaluate(agent, env_name, args): # Fetch trajectory if continuing from a previous run. trajectory_df = None @@ -171,261 +423,47 @@ def evaluate(agent, env_name, args): log.debug(f"Environment reset.\n{obs}\n") - status = "running" - max_score = info["max_score"] - step = 0 - nb_resets = 0 - nb_wins = 0 - nb_losts = 0 - nb_resets = 0 - nb_invalid_actions = 0 - moves = 0 - highscore = 0 - score = 0 - done = False - results = [] + state = _make_state(obs, info) wandb_run.log( { - "episode/moves": moves, - "episode/score": score, - "episode/highscore": highscore, - "episode/normalized_score": score / max_score, - "episode/normalized_highscore": highscore / max_score, + "episode/moves": 0, + "episode/score": 0, + "episode/highscore": 0, + "episode/normalized_score": 0, + "episode/normalized_highscore": 0, "episode/token_usage": 0, }, step=0, ) - # Replay phase: feed recorded actions to the environment (no LLM calls). + # Replay phase (if continuing from a previous run). replay_steps = 0 if trajectory_df is not None: - replay_steps = len(trajectory_df) - log.info(colored(f"Replaying {replay_steps} steps...", "cyan")) - - replay_pbar = tqdm( - trajectory_df.iterrows(), - total=args.nb_steps, - desc=f" {env_name} (replay)", - unit="steps", - leave=False, + replay_steps = replay_trajectory( + env, agent, trajectory_df, state, wandb_run, args, env_name, start_time ) - for _, row in replay_pbar: - step = int(row["Step"]) - action = str(row["Action"]) - - replay_pbar.set_postfix_str( - f"Score: {info['score']}/{info['max_score']} ({info['score']/info['max_score']:.1%})" - ) - - prev_obs = obs - - # Feed the recorded action to the environment. - if "\n" in action.strip(): - obs = "The game only allows one action per step." - else: - obs, _, done, info = env.step(action) - - score = info["score"] - moves = info["moves"] - feedback = info["feedback"] - norm_score = score / max_score - highscore = max(score, highscore) - norm_highscore = highscore / max_score - - if ( - args.admissible_commands - and info["admissible_commands"] - and action not in info["admissible_commands"] - ): - nb_invalid_actions += 1 - - # Verify replay fidelity by comparing observations. - logged_obs = row.get("Observation") - if logged_obs is not None and isinstance(logged_obs, str): - if prev_obs.strip() != logged_obs.strip(): - log.warning( - f"Replay divergence at step {step}:\n" - f" Expected: {logged_obs[:200]!r}\n" - f" Got: {prev_obs[:200]!r}" - ) - - # Build agent history so it has context for subsequent LLM calls. - agent.history.append((f"{prev_obs}\n> ", f"{action}\n")) - - msg = "{:5d}. Time: {:9.2f}\tScore: {:3d}\tMove: {:5d}\tAction: {:20s} (replay)" - msg = msg.format(step, time.time() - start_time, score, moves, action) - log.info(msg) - - # Log to wandb with original token usage from the trajectory. - nb_tokens = row.get("Token Usage", 0) or 0 - nb_tokens_thinking = row.get("Thinking Tokens", 0) or 0 - wandb_run.log( - { - "episode/moves": moves, - "episode/score": score, - "episode/highscore": highscore, - "episode/normalized_score": norm_score, - "episode/normalized_highscore": norm_highscore, - "episode/token_usage": nb_tokens, - "episode/token_usage_thinking": nb_tokens_thinking, - }, - step=step, - ) - - # Store results with original token usage from the trajectory. - # fmt: off - results.append([ - step, score, max_score, norm_score, moves, - prev_obs, action, feedback, - row.get("Prompt", ""), row.get("Response", ""), row.get("Thinking"), - row.get("Token Usage", 0) or 0, row.get("Prompt Tokens", 0) or 0, - row.get("Response Tokens", 0) or 0, row.get("Thinking Tokens", 0) or 0, - ]) - # fmt: on - - if not done: - log.debug(obs) - - if done: - if info["won"]: - nb_wins += 1 - if highscore == max_score: - log.debug(obs) - # Don't break during replay; continue replaying. - # Don't reset either — the original run broke here. - continue - elif info["lost"]: - nb_losts += 1 - - # Reset the game just like the original run did. - last_obs = obs - obs, info = env.reset() - obs = last_obs + "\n\n-= Restarting =-\n" + obs - agent.reset(obs, info, env_name) - nb_resets += 1 - - log.debug(f"{obs}") - - replay_pbar.close() - - if highscore == max_score: - log.info( - colored( - f"Replay complete: game already won with max score ({highscore}/{max_score}). " - f"No further steps needed.", - "green", - ) - ) - # Skip the LLM loop entirely. - replay_steps = args.nb_steps - else: - log.info( - colored( - f"Replay complete: {replay_steps} steps, score={score}, highscore={highscore}. " - f"LLM takes over from step {replay_steps + 1}.", - "cyan", - ) - ) + # Play phase (LLM-driven). + status = "running" try: - - pbar = tqdm( - range(replay_steps + 1, args.nb_steps + 1), - initial=replay_steps, - total=args.nb_steps, - desc=f" {env_name}", - unit="steps", - leave=False, + play_with_agent( + env, + agent, + state, + wandb_run, + args, + env_name, + start_time, + start_step=replay_steps + 1, ) - for step in pbar: - pbar.set_postfix_str( - f"Score: {info['score']}/{info['max_score']} ({info['score']/info['max_score']:.1%})" - ) - action, stats = agent.act(obs, score, done, info) - log.debug(colored(f"> {action}", "green")) - - if args.debug: - breakpoint() - - prev_obs = obs - - # Force one action per step. - if "\n" in action.strip(): - obs = "The game only allows one action per step." - else: - obs, _, done, info = env.step(action) - - score = info["score"] - moves = info["moves"] - feedback = info["feedback"] - norm_score = score / max_score - highscore = max(score, highscore) - norm_highscore = highscore / max_score - - if ( - args.admissible_commands - and info["admissible_commands"] - and action not in info["admissible_commands"] - ): - nb_invalid_actions += 1 - - msg = "{:5d}. Time: {:9.2f}\tScore: {:3d}\tMove: {:5d}\tAction: {:20s}" - msg = msg.format(step, time.time() - start_time, score, moves, action) - log.info(msg) - - wandb_run.log( - { - "episode/moves": moves, - "episode/score": score, - "episode/highscore": highscore, - "episode/normalized_score": norm_score, - "episode/normalized_highscore": norm_highscore, - "episode/token_usage": stats["nb_tokens"], - "episode/token_usage_thinking": stats.get("nb_tokens_thinking", 0), - }, - step=step, - ) - - # fmt: off - results.append([ - step, score, max_score, norm_score, moves, - prev_obs, action, feedback, - stats["prompt"], stats["response"], stats.get("thinking"), - stats["nb_tokens"], stats["nb_tokens_prompt"], stats["nb_tokens_response"], stats.get("nb_tokens_thinking", 0), - ]) - # fmt: on - - if not done: - log.debug(obs) - - if done: - if info["won"]: - nb_wins += 1 - if highscore == max_score: - log.debug(obs) - break # No reason to play that game more. - elif info["lost"]: - nb_losts += 1 - - # Replay the game in the hope of achieving a better score. - last_obs = obs - obs, info = env.reset() - obs = last_obs + "\n\n-= Restarting =-\n" + obs - agent.reset(obs, info, env_name) - nb_resets += 1 - - log.debug(f"{obs}") - status = "finished" except KeyboardInterrupt as e: status = "killed" log.critical(colored(f"{env_name} (killed)", "red")) log.error(str(e)) - time.sleep( - 1 - ) # Give time for the user to issue another ctrl+c to cancel the script. + time.sleep(1) if args.debug: raise @@ -438,16 +476,16 @@ def evaluate(agent, env_name, args): env.close() - stats = { - "nb_steps": step, - "nb_moves": moves, - "nb_invalid_actions": nb_invalid_actions, - "nb_losts": nb_losts, - "nb_wins": nb_wins, - "nb_resets": nb_resets, - "highscore": highscore, - "max_score": max_score, - "norm_score": highscore / max_score, + final_stats = { + "nb_steps": state["step"], + "nb_moves": state["moves"], + "nb_invalid_actions": state["nb_invalid_actions"], + "nb_losts": state["nb_losts"], + "nb_wins": state["nb_wins"], + "nb_resets": state["nb_resets"], + "highscore": state["highscore"], + "max_score": state["max_score"], + "norm_score": state["highscore"] / state["max_score"], "duration": time.time() - start_time, } @@ -459,28 +497,28 @@ def evaluate(agent, env_name, args): "Token Usage", "Prompt Tokens", "Response Tokens", "Thinking Tokens", ] # fmt: on - df = pd.DataFrame(results, columns=columns) + df = pd.DataFrame(state["results"], columns=columns) df.to_json(rollouts_file, orient="records", lines=True) wandb_stats = { - "total/Env. Steps": stats["nb_steps"], - "total/Game Moves": stats["nb_moves"], - "total/Invalid Actions": stats["nb_invalid_actions"], - "total/Losts": stats["nb_losts"], - "total/Wins": stats["nb_wins"], - "total/Resets": stats["nb_resets"], + "total/Env. Steps": final_stats["nb_steps"], + "total/Game Moves": final_stats["nb_moves"], + "total/Invalid Actions": final_stats["nb_invalid_actions"], + "total/Losts": final_stats["nb_losts"], + "total/Wins": final_stats["nb_wins"], + "total/Resets": final_stats["nb_resets"], "total/Tokens": df["Token Usage"].sum(), "total/Prompt Tokens": df["Prompt Tokens"].sum(), "total/Response Tokens": df["Response Tokens"].sum(), "total/Thinking Tokens": df["Thinking Tokens"].sum(), - "final/Highscore": stats["highscore"], - "final/Game Max Score": stats["max_score"], - "final/Normalized Score": stats["norm_score"], - "final/Duration": stats["duration"], + "final/Highscore": final_stats["highscore"], + "final/Game Max Score": final_stats["max_score"], + "final/Normalized Score": final_stats["norm_score"], + "final/Duration": final_stats["duration"], } wandb_run.log( {"episode/rollout": wandb.Table(dataframe=df), **wandb_stats}, - step=stats["nb_steps"], + step=final_stats["nb_steps"], ) # Save summary. @@ -490,7 +528,7 @@ def evaluate(agent, env_name, args): "env_params": env_params, "wandb_run_id": wandb_run.id, "wandb_url": wandb_run.url, - **stats, + **final_stats, **wandb_stats, } From 138e13288e7decf3505252c9b4849283f8131ebf Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Marc-Alexandre=20C=C3=B4t=C3=A9?= Date: Fri, 27 Mar 2026 11:53:59 -0700 Subject: [PATCH 3/8] Update --continue-from docs: cover truncation and early stop Document that trajectories are truncated to --nb-steps when the original run is longer, and that replay stops early when the game was already completed (max score reached). Co-authored-by: Copilot <223556219+Copilot@users.noreply.github.com> --- README.md | 21 ++++++++++++++------- 1 file changed, 14 insertions(+), 7 deletions(-) diff --git a/README.md b/README.md index 6678c9b..0c3079c 100644 --- a/README.md +++ b/README.md @@ -77,7 +77,12 @@ In order to benchmark a given LLM acting as language agent playing text-based ga ### Continuing a Previous Run -If you have a previous run that was limited to N steps (e.g., 100), you can extend it to more steps (e.g., 1000) using the `--continue-from` flag. This replays the original trajectory without making LLM calls, then lets the LLM take over for the remaining steps. +Use the `--continue-from` flag to replay a previous wandb-logged trajectory. The replay is deterministic (no LLM calls) and preserves the original token usage stats. This is useful for: + +- **Extending** a short run (e.g., 100 steps → 1000): replays the original trajectory, then lets the LLM take over for the remaining steps. +- **Reproducing** a previous run exactly: set `--nb-steps` equal to or less than the original run's length to replay without any new LLM calls. + +When auto-finding, the longest matching run is always selected and truncated to `--nb-steps` if needed. **With an explicit run ID:** @@ -88,12 +93,14 @@ If you have a previous run that was limited to N steps (e.g., 100), you can exte python benchmark.py reasoning --llm gpt-4o --conversation --continue-from --nb-steps 1000 --envs JerichoEnvZork1 --wandb **How it works:** -1. Fetches the original run's config and rollout from the wandb API (or auto-finds one) -2. Recreates the environment with the same seed for deterministic replay -3. Replays all recorded actions (no LLM calls, preserving original token usage stats) -4. Verifies observations match the logged trajectory (warns on divergence) -5. Hands off to the LLM agent for the remaining steps -6. Logs as a new wandb run referencing the original run ID +1. Fetches the original run's config and rollout from the wandb API (or auto-finds the longest matching run) +2. Truncates the trajectory to `--nb-steps` if the original run is longer +3. Recreates the environment with the same seed for deterministic replay +4. Replays recorded actions (no LLM calls, preserving original token usage stats) +5. Verifies observations match the logged trajectory (warns on divergence) +6. If the game completed during replay (max score reached), stops early +7. Otherwise, hands off to the LLM agent for any remaining steps +8. Logs as a new wandb run referencing the original run ID > [!NOTE] > The `--continue-from` flag expects a wandb run ID (e.g., `abc123de`) from the `pearls-lab/text-games-benchmark` project, or no value to auto-find. The agent type and parameters must match the original run. When auto-finding, if no matching run is found, the game runs from scratch. From d2da59392c06d94fffcecfe2866a106c7068e647 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Marc-Alexandre=20C=C3=B4t=C3=A9?= Date: Wed, 1 Apr 2026 09:30:16 -0700 Subject: [PATCH 4/8] Update model lists and token counter for new GPT-5.x models - Add gpt-5.3, gpt-5.4 to OPENAI_MODELS in reasoning agent - Remove claude-haiku-4.5 from CLAUDE_MODELS (not a reasoning model) - Simplify token counter: use startswith for gpt-5.x and gpt-4.1.x - Fix wandb duplicate-run check: always check (not just when force_all is off), add project path, exclude 'without-help' tag Co-authored-by: Copilot <223556219+Copilot@users.noreply.github.com> --- agents/reasoning.py | 3 ++- benchmark.py | 10 ++++++++-- tales/token.py | 14 ++++---------- 3 files changed, 14 insertions(+), 13 deletions(-) diff --git a/agents/reasoning.py b/agents/reasoning.py index 7dfff88..870e07f 100644 --- a/agents/reasoning.py +++ b/agents/reasoning.py @@ -33,7 +33,6 @@ "claude-4-sonnet", "claude-4-opus", "claude-sonnet-4.5", - "claude-haiku-4.5", "claude-opus-4.5", "claude-opus-4.6", "claude-sonnet-4.6", @@ -48,6 +47,8 @@ "o3", "gpt-5.1", "gpt-5.2", + "gpt-5.3", + "gpt-5.4", "gpt-5", "gpt-5-mini", "gpt-5-nano", diff --git a/benchmark.py b/benchmark.py index e10f780..dc045db 100644 --- a/benchmark.py +++ b/benchmark.py @@ -355,10 +355,16 @@ def evaluate(agent, env_name, args): return summary run_name = f"{env_name} - {agent.uid}" - if args.wandb and not args.force_all: + if args.wandb: # and not args.force_all: # Check if there already exists a run with the same name using Wandb API. wandb_api = wandb.Api() - wandb_runs = wandb_api.runs(filters={"display_name": run_name}) + wandb_runs = wandb_api.runs( + "pearls-lab/text-games-benchmark", + filters={ + "display_name": run_name, + "tags": {"$ne": "without-help"}, + }, + ) if wandb_runs: wandb_run = wandb_runs[0] log.info(f"Previous evaluation found: {wandb_run.url} ({wandb_run.state})") diff --git a/tales/token.py b/tales/token.py index 2aeab78..9378c28 100644 --- a/tales/token.py +++ b/tales/token.py @@ -46,17 +46,11 @@ def __init__(self, model: str): self.model = model if self.model in tiktoken.model.MODEL_TO_ENCODING: self.tokenize = tiktoken.encoding_for_model(self.model).encode - elif self.model in ( - "o4-mini", - "o3", - "gpt-5", - "gpt-5-mini", - "gpt-5-nano", - "gpt-5.1", - "gpt-5.2", - ): + elif self.model.startswith("gpt-5"): + self.tokenize = tiktoken.encoding_for_model("gpt-5").encode + elif self.model in ("o4-mini", "o3"): self.tokenize = tiktoken.encoding_for_model("o3-mini").encode - elif self.model in ("gpt-4.1", "gpt-4.1-nano", "gpt-4.1-mini"): + elif self.model.startswith("gpt-4.1"): self.tokenize = tiktoken.encoding_for_model("gpt-4o").encode else: self.tokenize = tiktoken.encoding_for_model(self.model.split("_")[0]).encode From dbbee55cc9d083c5155300119d73efa7fb7f40df Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Marc-Alexandre=20C=C3=B4t=C3=A9?= Date: Wed, 1 Apr 2026 10:00:44 -0700 Subject: [PATCH 5/8] Read WANDB_PROJECT from environment variable with fallback Default is 'tales'; entity resolved by wandb from the logged-in user. Removes hardcoded org/entity references from the codebase. Co-authored-by: Copilot <223556219+Copilot@users.noreply.github.com> --- README.md | 2 +- benchmark.py | 12 +++--------- tales/wandb_utils.py | 3 ++- 3 files changed, 6 insertions(+), 11 deletions(-) diff --git a/README.md b/README.md index 0c3079c..588eef6 100644 --- a/README.md +++ b/README.md @@ -103,7 +103,7 @@ When auto-finding, the longest matching run is always selected and truncated to 8. Logs as a new wandb run referencing the original run ID > [!NOTE] -> The `--continue-from` flag expects a wandb run ID (e.g., `abc123de`) from the `pearls-lab/text-games-benchmark` project, or no value to auto-find. The agent type and parameters must match the original run. When auto-finding, if no matching run is found, the game runs from scratch. +> The `--continue-from` flag expects a wandb run ID (e.g., `abc123de`) from your wandb project, or no value to auto-find. The agent type and parameters must match the original run. When auto-finding, if no matching run is found, the game runs from scratch. ### API-based LLMs diff --git a/benchmark.py b/benchmark.py index dc045db..90aaa07 100644 --- a/benchmark.py +++ b/benchmark.py @@ -355,16 +355,10 @@ def evaluate(agent, env_name, args): return summary run_name = f"{env_name} - {agent.uid}" - if args.wandb: # and not args.force_all: + if args.wandb and not args.force_all: # Check if there already exists a run with the same name using Wandb API. wandb_api = wandb.Api() - wandb_runs = wandb_api.runs( - "pearls-lab/text-games-benchmark", - filters={ - "display_name": run_name, - "tags": {"$ne": "without-help"}, - }, - ) + wandb_runs = wandb_api.runs(filters={"display_name": run_name}) if wandb_runs: wandb_run = wandb_runs[0] log.info(f"Previous evaluation found: {wandb_run.url} ({wandb_run.state})") @@ -406,7 +400,7 @@ def evaluate(agent, env_name, args): wandb_config["continued_from_run_url"] = original_config["_run_url"] wandb_config["replay_steps"] = len(trajectory_df) wandb_run = wandb.init( - project="tales", + project=os.environ.get("WANDB_PROJECT", "tales"), config=wandb_config, reinit=True, name=run_name, diff --git a/tales/wandb_utils.py b/tales/wandb_utils.py index b436a17..5dd57ec 100644 --- a/tales/wandb_utils.py +++ b/tales/wandb_utils.py @@ -1,5 +1,6 @@ import json import logging +import os import tempfile import pandas as pd @@ -7,7 +8,7 @@ log = logging.getLogger("tales") -WANDB_PROJECT = "pearls-lab/text-games-benchmark" +WANDB_PROJECT = os.environ.get("WANDB_PROJECT", "tales") ROLLOUT_COLUMNS = [ "Step", From d87862f60ea461b57142fff3a7b39c14ae224ed4 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Marc-Alexandre=20C=C3=B4t=C3=A9?= Date: Wed, 1 Apr 2026 10:28:18 -0700 Subject: [PATCH 6/8] Use shared WANDB_PROJECT constant from wandb_utils in benchmark Co-authored-by: Copilot <223556219+Copilot@users.noreply.github.com> --- benchmark.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/benchmark.py b/benchmark.py index 90aaa07..2a15922 100644 --- a/benchmark.py +++ b/benchmark.py @@ -19,7 +19,7 @@ import tales from tales.logger import log, setup_logging from tales.utils import NumpyEncoder -from tales.wandb_utils import fetch_run_trajectory, find_matching_run +from tales.wandb_utils import WANDB_PROJECT, fetch_run_trajectory, find_matching_run os.environ["WANDB_MODE"] = "disabled" @@ -400,7 +400,7 @@ def evaluate(agent, env_name, args): wandb_config["continued_from_run_url"] = original_config["_run_url"] wandb_config["replay_steps"] = len(trajectory_df) wandb_run = wandb.init( - project=os.environ.get("WANDB_PROJECT", "tales"), + project=WANDB_PROJECT, config=wandb_config, reinit=True, name=run_name, From d27d0e3da67fdf1d464b5e849ed1af008c5f1baa Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Marc-Alexandre=20C=C3=B4t=C3=A9?= Date: Wed, 1 Apr 2026 09:45:47 -0700 Subject: [PATCH 7/8] Add TALES score vs. step budget analysis script Fetches run data from wandb, computes average normalized score (TALES metric) at configurable budget intervals, and generates line plots with per-framework breakdowns. Features: - Per-step history caching (reusable across budget intervals) - Incremental cache updates (only downloads new runs) - Per-(model, budget) completeness filtering - --max-steps: filter by one or more max_steps values - --budget-step: configurable budget interval - --continuous: smooth curves at every step (with merge_asof optimization) - --log-x: logarithmic x-axis scale Co-authored-by: Copilot <223556219+Copilot@users.noreply.github.com> --- analysis/score_vs_budget.py | 527 ++++++++++++++++++++++++++++++++++++ 1 file changed, 527 insertions(+) create mode 100644 analysis/score_vs_budget.py diff --git a/analysis/score_vs_budget.py b/analysis/score_vs_budget.py new file mode 100644 index 0000000..fad1d19 --- /dev/null +++ b/analysis/score_vs_budget.py @@ -0,0 +1,527 @@ +"""Average normalized score vs. step budget across all games. + +For each model, takes runs with the specified max_steps, reads the +per-step normalized highscore, and plots TALES score at configurable +budget intervals. + +Usage: + python analysis/score_vs_budget.py [--cache analysis/data.csv] [--max-steps 300 400] [--budget-step 25] +""" + +import argparse +import sys +from pathlib import Path + +import matplotlib.pyplot as plt +import pandas as pd +import wandb + +WANDB_PROJECT = "pearls-lab/text-games-benchmark" + +FRAMEWORK_LABELS = { + "jericho": "Jericho", + "textworld": "TextWorld", + "textworld_express": "TextWorldExpress", + "alfworld": "ALFWorld", + "scienceworld": "ScienceWorld", +} + + +def _infer_framework(game: str) -> str: + """Infer framework from game name prefix (fallback for cached data).""" + if game.startswith("Jericho"): + return "Jericho" + if game.startswith("TWX"): + return "TextWorldExpress" + if game.startswith("TW"): + return "TextWorld" + if game.startswith("ALFWorld"): + return "ALFWorld" + if game.startswith("ScienceWorld"): + return "ScienceWorld" + return "Unknown" + + +def fetch_runs(max_steps_list: list[int]) -> list: + """Fetch runs, preferring the longest max_steps per (model, game, seed). + + Queries from largest to smallest max_steps. Once a (model, game, seed) + tuple is covered by a longer run, shorter runs for that tuple are skipped. + """ + api = wandb.Api() + seen_keys: set[tuple] = set() + selected: list = [] + + for max_steps in sorted(max_steps_list, reverse=True): + filters = {"config.max_steps": max_steps, "state": "finished"} + print(f"Querying wandb for runs with max_steps={max_steps}...") + runs = list(api.runs(WANDB_PROJECT, filters=filters, order="-created_at")) + print(f" Found {len(runs)} runs.") + + skipped = 0 + for r in runs: + key = ( + r.config.get("llm", "unknown"), + r.config.get("game", "unknown"), + r.config.get("seed"), + ) + if key in seen_keys: + skipped += 1 + continue + seen_keys.add(key) + selected.append(r) + + if skipped: + print(f" Skipped {skipped} run(s) already covered by longer runs.") + + print(f"Total: {len(selected)} unique runs to process.") + return selected + + +def build_history_table(runs: list) -> pd.DataFrame: + """Download per-step normalized highscore for each run.""" + records = [] + for i, run in enumerate(runs): + model = run.config.get("llm", "unknown") + game = run.config.get("game", "unknown") + seed = run.config.get("seed") + run_max_steps = run.config.get("max_steps", 0) + framework = run.config.get("framework", "unknown") + framework = FRAMEWORK_LABELS.get(framework, framework) + print( + f" [{i + 1}/{len(runs)}] {game} / {model} (seed={seed}, max_steps={run_max_steps})" + ) + + history = run.history(keys=["episode/normalized_highscore"], pandas=True) + if history.empty: + continue + + history = history.dropna(subset=["episode/normalized_highscore"]) + history = history.sort_values("_step") + + for _, row in history.iterrows(): + records.append( + { + "run_id": run.id, + "model": model, + "game": game, + "framework": framework, + "seed": seed, + "max_steps": run_max_steps, + "step": int(row["_step"]), + "normalized_score": row["episode/normalized_highscore"], + } + ) + + return pd.DataFrame(records) + + +def compute_budgets(history_df: pd.DataFrame, budgets: list[int]) -> pd.DataFrame: + """From per-step history, compute the normalized highscore at each budget cutoff. + + Uses ``merge_asof`` to map each budget to the most recent step per run, + avoiding a per-budget inner loop. + """ + run_cols = ["run_id", "model", "game", "framework", "seed", "max_steps"] + history_df = history_df.copy() + history_df["step"] = history_df["step"].astype(int) + history_df.sort_values(["run_id", "step"], inplace=True) + + budget_s = pd.Series(budgets, name="budget").sort_values() + + # For each run, merge_asof finds the latest step <= each budget. + parts = [] + for key, grp in history_df.groupby(run_cols, sort=False): + run_id, model, game, framework, seed, max_steps = key + max_steps = int(max_steps) + run_budgets = budget_s[budget_s <= max_steps].reset_index(drop=True) + if run_budgets.empty: + continue + + bdf = run_budgets.to_frame() + merged = pd.merge_asof( + bdf, grp[["step", "normalized_score"]], left_on="budget", right_on="step" + ) + merged["normalized_score"] = merged["normalized_score"].fillna(0.0) + merged["run_id"] = run_id + merged["model"] = model + merged["game"] = game + merged["framework"] = framework + merged["seed"] = seed + parts.append( + merged[ + [ + "run_id", + "model", + "game", + "framework", + "seed", + "budget", + "normalized_score", + ] + ] + ) + + return pd.concat(parts, ignore_index=True) if parts else pd.DataFrame() + + +def plot_budget( + df: pd.DataFrame, output: Path, continuous: bool = False, log_x: bool = False +) -> None: + """Line plot: TALES metric (avg normalized score across all games per seed) vs. step budget.""" + models = sorted(df["model"].unique()) + cmap = plt.get_cmap("tab10") + + fig, ax = plt.subplots(figsize=(10, 6)) + + # TALES metric: for each (model, seed, budget), average normalized_score across games. + tales = ( + df.groupby(["model", "seed", "budget"])["normalized_score"].mean().reset_index() + ) + tales.rename(columns={"normalized_score": "tales_score"}, inplace=True) + + for i, model in enumerate(models): + mdf = tales[tales["model"] == model] + # Aggregate across seeds (mean ± std) for each budget. + agg = mdf.groupby("budget")["tales_score"].agg(["mean", "std", "count"]) + agg = agg.sort_index().dropna(subset=["mean"]) + agg["std"] = agg["std"].fillna(0) + + if agg.empty: + continue + + ax.plot( + agg.index, + agg["mean"], + marker=None if continuous else "o", + label=model, + color=cmap(i), + linewidth=1.5 if continuous else 2, + ) + if (agg["count"] > 1).any(): + ax.fill_between( + agg.index, + agg["mean"] - agg["std"], + agg["mean"] + agg["std"], + alpha=0.12, + color=cmap(i), + ) + + # Annotate final point with the score value. + last = agg.iloc[-1] + ax.annotate( + f"{last['mean']:.3f}", + (agg.index[-1], last["mean"]), + textcoords="offset points", + xytext=(8, 0), + fontsize=9, + color=cmap(i), + ) + + budgets = sorted(df["budget"].unique()) + n_games = df["game"].nunique() + + ax.set_xlabel("Step Budget", fontsize=12) + ax.set_ylabel("TALES Score", fontsize=12) + ax.set_title( + f"TALES Score vs. Step Budget\n(avg. normalized score across {n_games} games)", + fontsize=14, + ) + if log_x: + ax.set_xscale("log") + ax.xaxis.set_major_formatter(plt.ScalarFormatter()) + if not continuous: + ax.set_xticks(budgets) + ax.set_ylim(bottom=0) + ax.grid(True, alpha=0.3) + ax.legend(fontsize=10) + fig.tight_layout() + + output.parent.mkdir(parents=True, exist_ok=True) + fig.savefig(output, dpi=150, bbox_inches="tight") + print(f"Saved plot to {output}") + + +def print_table(df: pd.DataFrame) -> None: + """Print TALES scores per model and step budget.""" + tales = ( + df.groupby(["model", "seed", "budget"])["normalized_score"].mean().reset_index() + ) + pivot = ( + tales.groupby(["model", "budget"])["normalized_score"].mean().unstack("budget") + ) + budgets = sorted(df["budget"].unique()) + pivot = pivot.reindex(columns=budgets) + print("\nTALES Score per Model and Step Budget:") + print(pivot.round(4).to_string()) + print() + + +def plot_per_framework( + df: pd.DataFrame, output: Path, continuous: bool = False, log_x: bool = False +) -> None: + """One subplot per framework, TALES-style metric (avg across games) vs. budget.""" + frameworks = sorted(df["framework"].unique()) + models = sorted(df["model"].unique()) + cmap = plt.get_cmap("tab10") + model_colors = {m: cmap(i) for i, m in enumerate(models)} + + n_fw = len(frameworks) + cols = min(3, n_fw) + rows = (n_fw + cols - 1) // cols + fig, axes = plt.subplots(rows, cols, figsize=(6 * cols, 4.5 * rows), squeeze=False) + + budgets = sorted(df["budget"].unique()) + + for idx, fw in enumerate(frameworks): + ax = axes[idx // cols][idx % cols] + fdf = df[df["framework"] == fw] + n_games = fdf["game"].nunique() + has_data = False + + for model in models: + mdf = fdf[fdf["model"] == model] + if mdf.empty: + continue + + tales = ( + mdf.groupby(["seed", "budget"])["normalized_score"].mean().reset_index() + ) + agg = tales.groupby("budget")["normalized_score"].agg( + ["mean", "std", "count"] + ) + agg = agg.sort_index().reindex(budgets).dropna(subset=["mean"]) + agg["std"] = agg["std"].fillna(0) + + if agg.empty: + continue + + has_data = True + ax.plot( + agg.index, + agg["mean"], + marker=None if continuous else "o", + label=model, + color=model_colors[model], + linewidth=1.2 if continuous else 1.8, + ) + if (agg["count"] > 1).any(): + ax.fill_between( + agg.index, + agg["mean"] - agg["std"], + agg["mean"] + agg["std"], + alpha=0.12, + color=model_colors[model], + ) + + ax.set_title(f"{fw} ({n_games} games)", fontsize=11) + ax.set_xlabel("Step Budget") + ax.set_ylabel("Avg. Normalized Score") + if log_x: + ax.set_xscale("log") + ax.xaxis.set_major_formatter(plt.ScalarFormatter()) + elif has_data and not continuous: + ax.set_xticks(budgets) + ax.set_ylim(-0.05, 1.05) + ax.grid(True, alpha=0.3) + + for idx in range(n_fw, rows * cols): + axes[idx // cols][idx % cols].set_visible(False) + + handles, labels = axes[0][0].get_legend_handles_labels() + if handles: + fig.legend( + handles, + labels, + loc="lower center", + ncol=min(len(models), 5), + fontsize=9, + ) + + fig.suptitle( + "TALES Score vs. Step Budget — by Framework", + fontsize=14, + y=1.01, + ) + fig.tight_layout() + + fw_path = output.with_stem(output.stem + "_by_framework") + fw_path.parent.mkdir(parents=True, exist_ok=True) + fig.savefig(fw_path, dpi=150, bbox_inches="tight") + print(f"Saved per-framework plot to {fw_path}") + + +def parse_args() -> argparse.Namespace: + parser = argparse.ArgumentParser(description=__doc__) + parser.add_argument( + "--max-steps", + type=int, + nargs="+", + default=[300], + help="Filter runs by config.max_steps (one or more values, default: 300). " + "Budget range goes up to the largest value.", + ) + parser.add_argument( + "--budget-step", + type=int, + default=50, + help="Step interval for budget cutoffs (default: 50). " + "Ignored when --continuous is set.", + ) + parser.add_argument( + "--continuous", + action="store_true", + help="Plot a smooth curve using every step instead of discrete budget points.", + ) + parser.add_argument( + "--log-x", + action="store_true", + help="Use logarithmic scale for the x-axis (step budget).", + ) + parser.add_argument( + "--output", + type=Path, + default=None, + help="Output path for the plot (default: analysis/score_vs_budget_{max_steps}.png)", + ) + parser.add_argument( + "--cache", + type=Path, + default=None, + help="Cache per-step history to a CSV file. The cache stores raw per-step " + "data and is reusable across different --budget-step values.", + ) + return parser.parse_args() + + +def main() -> None: + args = parse_args() + + max_budget = max(args.max_steps) + step = 1 if args.continuous else args.budget_step + budgets = list(range(step, max_budget + 1, step)) + if args.continuous: + label = "continuous" + else: + steps_label = "_".join(str(s) for s in sorted(args.max_steps)) + label = f"{steps_label}_step{args.budget_step}" + output = args.output or Path(f"analysis/score_vs_budget_{label}.png") + + # --- Load / update per-step history cache --- + if args.cache and args.cache.exists(): + print(f"Loading cached history from {args.cache}") + history_df = pd.read_csv(args.cache, low_memory=False) + else: + history_df = pd.DataFrame() + + # Always query wandb for the current run list and fetch any missing runs. + # fetch_runs returns only the best (longest) run per (model, game, seed). + runs = fetch_runs(max_steps_list=args.max_steps) + if not runs and history_df.empty: + print("No runs found. Exiting.", file=sys.stderr) + sys.exit(1) + + selected_ids = {r.id for r in runs} + cached_ids = ( + set(history_df["run_id"].unique()) if "run_id" in history_df.columns else set() + ) + + # Remove cached runs not in the selected set (stale or superseded). + obsolete_ids = cached_ids - selected_ids + # Also remove runs with old cache format (NaN step values). + if "step" in history_df.columns: + old_fmt = set(history_df.loc[history_df["step"].isna(), "run_id"].unique()) + obsolete_ids |= old_fmt + if obsolete_ids: + reason = "superseded/stale/old-format" + print(f" Removing {len(obsolete_ids)} cached run(s) ({reason}).") + history_df = history_df[~history_df["run_id"].isin(obsolete_ids)] + cached_ids -= obsolete_ids + + new_runs = [r for r in runs if r.id not in cached_ids] + + if new_runs: + print( + f" Fetching {len(new_runs)} new run(s) " + f"(skipping {len(runs) - len(new_runs)} cached)..." + ) + new_df = build_history_table(new_runs) + history_df = ( + pd.concat([history_df, new_df], ignore_index=True) + if not history_df.empty + else new_df + ) + + if history_df.empty: + print("No score data found in runs. Exiting.", file=sys.stderr) + sys.exit(1) + + cache_changed = bool(new_runs or obsolete_ids) + if args.cache and (cache_changed or not args.cache.exists()): + args.cache.parent.mkdir(parents=True, exist_ok=True) + history_df.to_csv(args.cache, index=False) + print(f"{'Updated' if cache_changed else 'Saved'} cache at {args.cache}") + + # --- Compute budget table from per-step history --- + if args.continuous: + print("\nComputing budgets at every step (continuous)...") + else: + print(f"\nComputing budgets at step interval {args.budget_step}...") + df = compute_budgets(history_df, budgets) + + n_models = df["model"].nunique() + n_games = df["game"].nunique() + print(f"Data: {n_models} models, {n_games} games, budgets={budgets}") + + # Ensure framework column exists and has no missing values. + if "framework" not in df.columns: + df["framework"] = df["game"].map(_infer_framework) + else: + mask = df["framework"].isna() + if mask.any(): + df.loc[mask, "framework"] = df.loc[mask, "game"].map(_infer_framework) + + # Drop (model, budget) pairs where the model hasn't completed all games. + games_per_mb = ( + df.groupby(["model", "budget"])["game"].nunique().reset_index(name="n_games") + ) + incomplete_mb = games_per_mb[games_per_mb["n_games"] < n_games] + + if not incomplete_mb.empty: + total_budgets_per_model = games_per_mb.groupby("model").size() + incomplete_budgets_per_model = incomplete_mb.groupby("model").size() + + for model in incomplete_budgets_per_model.index: + if incomplete_budgets_per_model[model] == total_budgets_per_model[model]: + max_count = int( + games_per_mb.loc[games_per_mb["model"] == model, "n_games"].max() + ) + print( + f" ⚠ Dropping {model}: only {max_count}/{n_games} games completed" + ) + elif args.continuous: + n_inc = int(incomplete_budgets_per_model[model]) + print(f" ⚠ Skipping {model} at {n_inc} incomplete budget point(s)") + else: + rows = incomplete_mb[incomplete_mb["model"] == model] + for _, row in rows.iterrows(): + print( + f" ⚠ Skipping {model} at budget {int(row['budget'])}: " + f"only {int(row['n_games'])}/{n_games} games" + ) + + drop_keys = set(zip(incomplete_mb["model"], incomplete_mb["budget"])) + df = df[~df.apply(lambda r: (r["model"], r["budget"]) in drop_keys, axis=1)] + + if df.empty: + print("No complete data remaining. Exiting.", file=sys.stderr) + sys.exit(1) + + if not args.continuous: + print_table(df) + plot_budget(df, output=output, continuous=args.continuous, log_x=args.log_x) + plot_per_framework(df, output=output, continuous=args.continuous, log_x=args.log_x) + + +if __name__ == "__main__": + main() From d2501284da1dedfee47289c15cf501c45e942941 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Marc-Alexandre=20C=C3=B4t=C3=A9?= Date: Wed, 1 Apr 2026 09:45:56 -0700 Subject: [PATCH 8/8] Add script to build HuggingFace dataset from wandb trajectories Downloads JSONL rollout files from wandb runs, enriches each step with run metadata, and saves as Parquet or JSONL for HuggingFace. Features: - Filtering by --models, --frameworks, --games, --max-steps, --seeds - Deduplication (longest run per model/game/seed tuple) - Incremental builds via --cache (only fetches new runs) - --dry-run to preview matching runs - --format parquet|jsonl Co-authored-by: Copilot <223556219+Copilot@users.noreply.github.com> --- analysis/build_hf_dataset.py | 353 +++++++++++++++++++++++++++++++++++ 1 file changed, 353 insertions(+) create mode 100644 analysis/build_hf_dataset.py diff --git a/analysis/build_hf_dataset.py b/analysis/build_hf_dataset.py new file mode 100644 index 0000000..39eab92 --- /dev/null +++ b/analysis/build_hf_dataset.py @@ -0,0 +1,353 @@ +"""Build a Hugging Face dataset from TALES wandb trajectories. + +Downloads JSONL rollout files from wandb runs, enriches each step with +run metadata (model, game, framework, seed, etc.), and saves the result +as a Parquet dataset ready for ``datasets.load_dataset()``. + +Usage: + python analysis/build_hf_dataset.py --output analysis/tales_dataset + python analysis/build_hf_dataset.py --models claude-opus-4.6 gpt-5.4-mini --max-steps 300 400 + python analysis/build_hf_dataset.py --frameworks jericho scienceworld --games JerichoEnvZork1 + python analysis/build_hf_dataset.py --cache analysis/hf_cache.json --output analysis/tales_dataset +""" + +import argparse +import json +import sys +import tempfile +from pathlib import Path + +import pandas as pd +import wandb + +WANDB_PROJECT = "pearls-lab/text-games-benchmark" + +METADATA_COLUMNS = [ + "run_id", + "model", + "game", + "framework", + "agent_type", + "seed", + "game_seed", + "max_steps", +] + +ROLLOUT_COLUMNS = [ + "Step", + "Score", + "Max Score", + "Normalized Score", + "Moves", + "Observation", + "Action", + "Feedback", + "Prompt", + "Response", + "Thinking", + "Token Usage", + "Prompt Tokens", + "Response Tokens", + "Thinking Tokens", +] + + +def parse_args() -> argparse.Namespace: + parser = argparse.ArgumentParser( + description=__doc__, formatter_class=argparse.RawDescriptionHelpFormatter + ) + parser.add_argument( + "--output", + type=Path, + default=Path("analysis/tales_dataset"), + help="Output directory for the dataset (default: analysis/tales_dataset).", + ) + parser.add_argument( + "--format", + choices=["parquet", "jsonl"], + default="parquet", + help="Output format (default: parquet).", + ) + + filt = parser.add_argument_group("filtering") + filt.add_argument( + "--models", + nargs="+", + default=None, + help="Only include these models (e.g., claude-opus-4.6 gpt-5.4-mini).", + ) + filt.add_argument( + "--frameworks", + nargs="+", + default=None, + help="Only include these frameworks (e.g., jericho textworld scienceworld).", + ) + filt.add_argument( + "--games", + nargs="+", + default=None, + help="Only include these game names (e.g., JerichoEnvZork1 ScienceWorldBoil).", + ) + filt.add_argument( + "--max-steps", + type=int, + nargs="+", + default=None, + help="Only include runs with these max_steps values (e.g., 300 400).", + ) + filt.add_argument( + "--seeds", + type=int, + nargs="+", + default=None, + help="Only include runs with these seed values.", + ) + + parser.add_argument( + "--cache", + type=Path, + default=None, + help="JSON file tracking which run IDs have been downloaded. " + "Enables incremental builds — only new runs are fetched.", + ) + parser.add_argument( + "--dry-run", + action="store_true", + help="List matching runs without downloading trajectories.", + ) + return parser.parse_args() + + +def fetch_runs(args: argparse.Namespace) -> list: + """Fetch finished wandb runs matching the provided filters. + + Prefers the longest run per (model, game, seed) tuple. + """ + api = wandb.Api() + + filters: dict = {"state": "finished"} + if args.max_steps and len(args.max_steps) == 1: + filters["config.max_steps"] = args.max_steps[0] + if args.models and len(args.models) == 1: + filters["config.llm"] = args.models[0] + if args.frameworks and len(args.frameworks) == 1: + filters["config.framework"] = args.frameworks[0] + + print("Querying wandb for runs...") + all_runs = list(api.runs(WANDB_PROJECT, filters=filters, order="-created_at")) + print(f" Found {len(all_runs)} finished runs (server-side filter).") + + # Client-side filtering for multi-value filters. + max_steps_set = set(args.max_steps) if args.max_steps else None + models_set = set(args.models) if args.models else None + frameworks_set = set(args.frameworks) if args.frameworks else None + games_set = set(args.games) if args.games else None + seeds_set = set(args.seeds) if args.seeds else None + + filtered = [] + for r in all_runs: + cfg = r.config + if max_steps_set and cfg.get("max_steps") not in max_steps_set: + continue + if models_set and cfg.get("llm") not in models_set: + continue + if frameworks_set and cfg.get("framework") not in frameworks_set: + continue + if games_set and cfg.get("game") not in games_set: + continue + if seeds_set and cfg.get("seed") not in seeds_set: + continue + filtered.append(r) + + # Deduplicate: keep longest run per (model, game, seed). + filtered.sort(key=lambda r: r.config.get("max_steps", 0) or 0, reverse=True) + seen: set[tuple] = set() + selected = [] + for r in filtered: + key = ( + r.config.get("llm", "unknown"), + r.config.get("game", "unknown"), + r.config.get("seed"), + ) + if key in seen: + continue + seen.add(key) + selected.append(r) + + print(f" After filtering & dedup: {len(selected)} runs.") + return selected + + +def download_trajectory(run) -> pd.DataFrame | None: + """Download the JSONL rollout from a wandb run and return as DataFrame.""" + rollout_file = None + for f in run.files(): + if f.name.endswith(".jsonl"): + rollout_file = f + break + + if rollout_file is None: + return None + + with tempfile.TemporaryDirectory() as tmpdir: + rollout_file.download(root=tmpdir, replace=True) + filepath = f"{tmpdir}/{rollout_file.name}" + df = pd.read_json(filepath, orient="records", lines=True) + + return df + + +def build_dataset(runs: list, cached_ids: set[str]) -> pd.DataFrame: + """Download trajectories for all runs and build a single DataFrame.""" + parts = [] + skipped_cache = 0 + skipped_no_data = 0 + + for i, run in enumerate(runs): + cfg = run.config + model = cfg.get("llm", "unknown") + game = cfg.get("game", "unknown") + seed = cfg.get("seed") + max_steps = cfg.get("max_steps", 0) + + if run.id in cached_ids: + skipped_cache += 1 + continue + + print( + f" [{i + 1}/{len(runs)}] {game} / {model} " + f"(seed={seed}, max_steps={max_steps})" + ) + + traj = download_trajectory(run) + if traj is None or traj.empty: + print(f" ⚠ No trajectory found, skipping.") + skipped_no_data += 1 + continue + + # Attach metadata columns. + traj["run_id"] = run.id + traj["model"] = model + traj["game"] = game + traj["framework"] = cfg.get("framework", "unknown") + traj["agent_type"] = cfg.get("agent_type", "unknown") + traj["seed"] = seed + traj["game_seed"] = cfg.get("game_seed") + traj["max_steps"] = max_steps + + parts.append(traj) + + if skipped_cache: + print(f" Skipped {skipped_cache} already-cached run(s).") + if skipped_no_data: + print(f" Skipped {skipped_no_data} run(s) with no trajectory data.") + + if not parts: + return pd.DataFrame() + + return pd.concat(parts, ignore_index=True) + + +def load_cache(cache_path: Path | None) -> tuple[set[str], pd.DataFrame | None]: + """Load cache index and any previously saved data.""" + if cache_path is None or not cache_path.exists(): + return set(), None + + with open(cache_path) as f: + cache = json.load(f) + + cached_ids = set(cache.get("run_ids", [])) + data_path = cache.get("data_path") + + prev_df = None + if data_path and Path(data_path).exists(): + print(f"Loading previously cached data from {data_path}") + if data_path.endswith(".parquet"): + prev_df = pd.read_parquet(data_path) + else: + prev_df = pd.read_json(data_path, orient="records", lines=True) + + print(f"Cache: {len(cached_ids)} previously downloaded run(s).") + return cached_ids, prev_df + + +def save_cache(cache_path: Path, run_ids: set[str], data_path: str) -> None: + """Save cache index.""" + cache_path.parent.mkdir(parents=True, exist_ok=True) + with open(cache_path, "w") as f: + json.dump({"run_ids": sorted(run_ids), "data_path": data_path}, f, indent=2) + + +def main() -> None: + args = parse_args() + + runs = fetch_runs(args) + if not runs: + print("No matching runs found. Exiting.", file=sys.stderr) + sys.exit(1) + + if args.dry_run: + print(f"\nDry run — {len(runs)} runs would be downloaded:\n") + summary: dict[str, set] = {} + for r in runs: + model = r.config.get("llm", "unknown") + game = r.config.get("game", "unknown") + summary.setdefault(model, set()).add(game) + for model in sorted(summary): + print(f" {model}: {len(summary[model])} games") + return + + # Load cache for incremental builds. + cached_ids, prev_df = load_cache(args.cache) + + # Download new trajectories. + new_df = build_dataset(runs, cached_ids) + + # Merge with previously cached data. + if prev_df is not None and not prev_df.empty: + # Remove any previously cached runs that are no longer selected + # (e.g., superseded by longer runs or filtered out). + selected_ids = {r.id for r in runs} + prev_df = prev_df[prev_df["run_id"].isin(selected_ids)] + + if not new_df.empty: + df = pd.concat([prev_df, new_df], ignore_index=True) + else: + df = prev_df + else: + df = new_df + + if df.empty: + print("No trajectory data collected. Exiting.", file=sys.stderr) + sys.exit(1) + + # Reorder columns: metadata first, then rollout columns. + all_cols = METADATA_COLUMNS + [c for c in ROLLOUT_COLUMNS if c in df.columns] + extra = [c for c in df.columns if c not in all_cols] + df = df[all_cols + extra] + + # Save dataset. + args.output.mkdir(parents=True, exist_ok=True) + if args.format == "parquet": + out_path = args.output / "tales_trajectories.parquet" + df.to_parquet(out_path, index=False) + else: + out_path = args.output / "tales_trajectories.jsonl" + df.to_json(out_path, orient="records", lines=True) + + # Update cache. + all_ids = cached_ids | set(df["run_id"].unique()) + if args.cache: + save_cache(args.cache, all_ids, str(out_path)) + + n_runs = df["run_id"].nunique() + n_models = df["model"].nunique() + n_games = df["game"].nunique() + n_steps = len(df) + print(f"\n✓ Dataset saved to {out_path}") + print(f" {n_runs} runs, {n_models} models, {n_games} games, {n_steps} steps") + print(f" Size: {out_path.stat().st_size / 1024 / 1024:.1f} MB") + + +if __name__ == "__main__": + main()