diff --git a/src/agentlab/analyze/agent_xray.py b/src/agentlab/analyze/agent_xray.py index 84dc423d..db95a59b 100644 --- a/src/agentlab/analyze/agent_xray.py +++ b/src/agentlab/analyze/agent_xray.py @@ -74,6 +74,7 @@ class EpisodeId: agent_id: str = None task_name: str = None seed: int = None + row_index: int = None # unique row index to disambiguate selections @dataclass @@ -99,6 +100,24 @@ def update_exp_result(self, episode_id: EpisodeId): if self.result_df is None or episode_id.task_name is None or episode_id.seed is None: self.exp_result = None + # Prefer selecting by explicit row index if available + if episode_id.row_index is not None: + tmp_df = self.result_df.reset_index(inplace=False) + tmp_df["_row_index"] = tmp_df.index + sub_df = tmp_df[tmp_df["_row_index"] == episode_id.row_index] + if len(sub_df) == 0: + self.exp_result = None + raise ValueError(f"Could not find episode for row_index: {episode_id.row_index}") + if len(sub_df) > 1: + warning( + f"Found multiple rows for row_index: {episode_id.row_index}. Using the first one." + ) + exp_dir = sub_df.iloc[0]["exp_dir"] + print(exp_dir) + self.exp_result = ExpResult(exp_dir) + self.step = 0 + return + # find unique row for task_name and seed result_df = self.agent_df.reset_index(inplace=False) sub_df = result_df[ @@ -128,16 +147,15 @@ def get_agent_id(self, row: pd.Series): return agent_id def filter_agent_id(self, agent_id: list[tuple]): - # query_str = " & ".join([f"`{col}` == {repr(val)}" for col, val in agent_id]) - # agent_df = info.result_df.query(query_str) - - agent_df = self.result_df.reset_index(inplace=False) - agent_df.set_index(TASK_NAME_KEY, inplace=True) + # Preserve a stable row index to disambiguate selections later + tmp_df = self.result_df.reset_index(inplace=False) + tmp_df["_row_index"] = tmp_df.index + tmp_df.set_index(TASK_NAME_KEY, inplace=True) for col, val in agent_id: col = col.replace(".\n", ".") - agent_df = agent_df[agent_df[col] == val] - self.agent_df = agent_df + tmp_df = tmp_df[tmp_df[col] == val] + self.agent_df = tmp_df info = Info() @@ -735,7 +753,7 @@ def dict_msg_to_markdown(d: dict): case _: parts.append(f"\n```\n{str(item)}\n```\n") - markdown = f"### {d["role"].capitalize()}\n" + markdown = f"### {d['role'].capitalize()}\n" markdown += "\n".join(parts) return markdown @@ -1003,7 +1021,8 @@ def get_seeds_df(result_df: pd.DataFrame, task_name: str): def extract_columns(row: pd.Series): return pd.Series( { - "seed": row[TASK_SEED_KEY], + "index": row.get("_row_index", None), + "seed": row.get(TASK_SEED_KEY, None), "reward": row.get("cum_reward", None), "err": bool(row.get("err_msg", None)), "n_steps": row.get("n_steps", None), @@ -1011,6 +1030,8 @@ def extract_columns(row: pd.Series): ) seed_df = result_df.apply(extract_columns, axis=1) + # Ensure column order and readability + seed_df = seed_df[["seed", "reward", "err", "n_steps","index"]] return seed_df @@ -1028,15 +1049,26 @@ def on_select_task(evt: gr.SelectData, df: pd.DataFrame, agent_id: list[tuple]): def update_seeds(agent_task_id: tuple): agent_id, task_name = agent_task_id seed_df = get_seeds_df(info.agent_df, task_name) - first_seed = seed_df.iloc[0]["seed"] - return seed_df, EpisodeId(agent_id=agent_id, task_name=task_name, seed=first_seed) + first_seed = int(seed_df.iloc[0]["seed"]) if len(seed_df) else None + first_index = int(seed_df.iloc[0]["index"]) if len(seed_df) else None + return seed_df, EpisodeId( + agent_id=agent_id, task_name=task_name, seed=first_seed, row_index=first_index + ) def on_select_seed(evt: gr.SelectData, df: pd.DataFrame, agent_task_id: tuple): agent_id, task_name = agent_task_id col_idx = df.columns.get_loc("seed") - seed = evt.row_value[col_idx] # seed should be the first column - return EpisodeId(agent_id=agent_id, task_name=task_name, seed=seed) + idx_col = df.columns.get_loc("index") if "index" in df.columns else None + seed = evt.row_value[col_idx] + row_index = evt.row_value[idx_col] if idx_col is not None else None + try: + seed = int(seed) + if row_index is not None: + row_index = int(row_index) + except Exception: + pass + return EpisodeId(agent_id=agent_id, task_name=task_name, seed=seed, row_index=row_index) def new_episode(episode_id: EpisodeId, progress=gr.Progress()):