From e46c6c191224d895d00f55d719b98b8c4a01f668 Mon Sep 17 00:00:00 2001 From: Khalil Meftah Date: Mon, 22 Jun 2026 19:03:33 +0200 Subject: [PATCH] fix(eval): address review comments - Wrap rollout loop in try/finally so finalize() runs on crash/interrupt - Guard push_to_hub with num_episodes > 0 to avoid pushing empty datasets - Hoist loop-invariant multi_env and base_repo_id out of creation loop --- src/lerobot/scripts/lerobot_eval.py | 188 ++++++++++++++-------------- 1 file changed, 97 insertions(+), 91 deletions(-) diff --git a/src/lerobot/scripts/lerobot_eval.py b/src/lerobot/scripts/lerobot_eval.py index c965573c1..1ec4ea75f 100644 --- a/src/lerobot/scripts/lerobot_eval.py +++ b/src/lerobot/scripts/lerobot_eval.py @@ -216,10 +216,10 @@ def rollout( features = _env_features_to_dataset_features(env_features) fps = env.unwrapped.metadata.get("render_fps", 30) recording_datasets = [] + multi_env = env.num_envs > 1 + base_repo_id = recording_repo_id or "eval_recording" for i in range(env.num_envs): - multi_env = env.num_envs > 1 root = str(recording_dir / f"env_{i}") if multi_env else str(recording_dir) - base_repo_id = recording_repo_id or "eval_recording" repo_id = f"{base_repo_id}_env_{i}" if multi_env else base_repo_id recording_datasets.append( LeRobotDataset.create( @@ -253,100 +253,112 @@ def rollout( leave=False, ) check_env_attributes_and_types(env) - while not np.all(done) and step < max_steps: - # Numpy array to tensor and changing dictionary keys to LeRobot policy format. - observation = preprocess_observation(observation) - if return_observations: - all_observations.append(deepcopy(observation)) + try: + while not np.all(done) and step < max_steps: + # Numpy array to tensor and changing dictionary keys to LeRobot policy format. + observation = preprocess_observation(observation) + if return_observations: + all_observations.append(deepcopy(observation)) - # Infer "task" from sub-environments (prefer natural language description). - # env.call() works with both SyncVectorEnv and AsyncVectorEnv. - try: - observation["task"] = list(env.call("task_description")) - except (AttributeError, NotImplementedError): + # Infer "task" from sub-environments (prefer natural language description). + # env.call() works with both SyncVectorEnv and AsyncVectorEnv. try: - observation["task"] = list(env.call("task")) + observation["task"] = list(env.call("task_description")) except (AttributeError, NotImplementedError): - observation["task"] = [""] * env.num_envs + try: + observation["task"] = list(env.call("task")) + except (AttributeError, NotImplementedError): + observation["task"] = [""] * env.num_envs - # Apply environment-specific preprocessing (e.g., LiberoProcessorStep for LIBERO) - observation = env_preprocessor(observation) + # Apply environment-specific preprocessing (e.g., LiberoProcessorStep for LIBERO) + observation = env_preprocessor(observation) - observation = preprocessor(observation) - with torch.inference_mode(): - action = policy.select_action(observation) - action = postprocessor(action) + observation = preprocessor(observation) + with torch.inference_mode(): + action = policy.select_action(observation) + action = postprocessor(action) - action_transition = {ACTION: action} - action_transition = env_postprocessor(action_transition) - action = action_transition[ACTION] + action_transition = {ACTION: action} + action_transition = env_postprocessor(action_transition) + action = action_transition[ACTION] - # Convert to CPU / numpy. - action_numpy: np.ndarray = action.to("cpu").numpy() - assert action_numpy.ndim == 2, "Action dimensions should be (batch, action_dim)" + # Convert to CPU / numpy. + action_numpy: np.ndarray = action.to("cpu").numpy() + assert action_numpy.ndim == 2, "Action dimensions should be (batch, action_dim)" - # Apply the next action. - observation, reward, terminated, truncated, info = env.step(action_numpy) - if render_callback is not None: - render_callback(env) + # Apply the next action. + observation, reward, terminated, truncated, info = env.step(action_numpy) + if render_callback is not None: + render_callback(env) - # VectorEnv stores is_success in `info["final_info"][env_index]["is_success"]`. "final_info" isn't - # available if none of the envs finished. - if "final_info" in info: - final_info = info["final_info"] - if not isinstance(final_info, dict): - raise RuntimeError( - "Unsupported `final_info` format: expected dict (Gymnasium >= 1.0). " - "You're likely using an older version of gymnasium (< 1.0). Please upgrade." + # VectorEnv stores is_success in `info["final_info"][env_index]["is_success"]`. "final_info" isn't + # available if none of the envs finished. + if "final_info" in info: + final_info = info["final_info"] + if not isinstance(final_info, dict): + raise RuntimeError( + "Unsupported `final_info` format: expected dict (Gymnasium >= 1.0). " + "You're likely using an older version of gymnasium (< 1.0). Please upgrade." + ) + successes = final_info["is_success"].tolist() + elif "is_success" in info: + is_success = info["is_success"] + successes = ( + is_success.tolist() + if hasattr(is_success, "tolist") + else [bool(is_success)] * env.num_envs ) - successes = final_info["is_success"].tolist() - elif "is_success" in info: - is_success = info["is_success"] - successes = ( - is_success.tolist() if hasattr(is_success, "tolist") else [bool(is_success)] * env.num_envs + else: + successes = [False] * env.num_envs + + if recording_datasets is not None and raw_observation is not None: + prev_done = done.copy() + for env_idx in range(env.num_envs): + if prev_done[env_idx]: + continue + frame = _build_raw_frame( + raw_observation, + env_idx, + action_numpy[env_idx], + reward[env_idx], + successes[env_idx], + bool(terminated[env_idx] | truncated[env_idx]), + task_desc, + recording_datasets[env_idx].features, + ) + recording_datasets[env_idx].add_frame(frame) + if terminated[env_idx] or truncated[env_idx]: + recording_datasets[env_idx].save_episode() + raw_observation = deepcopy(observation) + + # Keep track of which environments are done so far. + # Mark the episode as done if we reach the maximum step limit. + # This ensures that the rollout always terminates cleanly at `max_steps`, + # and allows logging/saving (e.g., videos) to be triggered consistently. + done = terminated | truncated | done + if step + 1 == max_steps: + done = np.ones_like(done, dtype=bool) + + all_actions.append(torch.from_numpy(action_numpy)) + all_rewards.append(torch.from_numpy(reward)) + all_dones.append(torch.from_numpy(done)) + all_successes.append(torch.tensor(successes)) + + step += 1 + running_success_rate = ( + einops.reduce(torch.stack(all_successes, dim=1), "b n -> b", "any").numpy().mean() ) - else: - successes = [False] * env.num_envs - - if recording_datasets is not None and raw_observation is not None: - prev_done = done.copy() - for env_idx in range(env.num_envs): - if prev_done[env_idx]: - continue - frame = _build_raw_frame( - raw_observation, - env_idx, - action_numpy[env_idx], - reward[env_idx], - successes[env_idx], - bool(terminated[env_idx] | truncated[env_idx]), - task_desc, - recording_datasets[env_idx].features, - ) - recording_datasets[env_idx].add_frame(frame) - if terminated[env_idx] or truncated[env_idx]: - recording_datasets[env_idx].save_episode() - raw_observation = deepcopy(observation) - - # Keep track of which environments are done so far. - # Mark the episode as done if we reach the maximum step limit. - # This ensures that the rollout always terminates cleanly at `max_steps`, - # and allows logging/saving (e.g., videos) to be triggered consistently. - done = terminated | truncated | done - if step + 1 == max_steps: - done = np.ones_like(done, dtype=bool) - - all_actions.append(torch.from_numpy(action_numpy)) - all_rewards.append(torch.from_numpy(reward)) - all_dones.append(torch.from_numpy(done)) - all_successes.append(torch.tensor(successes)) - - step += 1 - running_success_rate = ( - einops.reduce(torch.stack(all_successes, dim=1), "b n -> b", "any").numpy().mean() - ) - progbar.set_postfix({"running_success_rate": f"{running_success_rate.item() * 100:.1f}%"}) - progbar.update() + progbar.set_postfix({"running_success_rate": f"{running_success_rate.item() * 100:.1f}%"}) + progbar.update() + finally: + if recording_datasets is not None: + for ds in recording_datasets: + ds.finalize() + if recording_repo_id is not None: + if ds.num_episodes > 0: + ds.push_to_hub(private=recording_private) + else: + logging.warning("No episodes recorded for %s — skipping push to hub.", ds.repo_id) # Track the final observation. if return_observations: @@ -366,12 +378,6 @@ def rollout( stacked_observations[key] = torch.stack([obs[key] for obs in all_observations], dim=1) ret[OBS_STR] = stacked_observations - if recording_datasets is not None: - for ds in recording_datasets: - ds.finalize() - if recording_repo_id is not None: - ds.push_to_hub(private=recording_private) - if hasattr(policy, "use_original_modules"): policy.use_original_modules()