fix(eval): address review comments

- Wrap rollout loop in try/finally so finalize() runs on crash/interrupt
- Guard push_to_hub with num_episodes > 0 to avoid pushing empty datasets
- Hoist loop-invariant multi_env and base_repo_id out of creation loop
This commit is contained in:
Khalil Meftah
2026-06-22 19:03:33 +02:00
parent a130a9db39
commit e46c6c1912
+97 -91
View File
@@ -216,10 +216,10 @@ def rollout(
features = _env_features_to_dataset_features(env_features)
fps = env.unwrapped.metadata.get("render_fps", 30)
recording_datasets = []
multi_env = env.num_envs > 1
base_repo_id = recording_repo_id or "eval_recording"
for i in range(env.num_envs):
multi_env = env.num_envs > 1
root = str(recording_dir / f"env_{i}") if multi_env else str(recording_dir)
base_repo_id = recording_repo_id or "eval_recording"
repo_id = f"{base_repo_id}_env_{i}" if multi_env else base_repo_id
recording_datasets.append(
LeRobotDataset.create(
@@ -253,100 +253,112 @@ def rollout(
leave=False,
)
check_env_attributes_and_types(env)
while not np.all(done) and step < max_steps:
# Numpy array to tensor and changing dictionary keys to LeRobot policy format.
observation = preprocess_observation(observation)
if return_observations:
all_observations.append(deepcopy(observation))
try:
while not np.all(done) and step < max_steps:
# Numpy array to tensor and changing dictionary keys to LeRobot policy format.
observation = preprocess_observation(observation)
if return_observations:
all_observations.append(deepcopy(observation))
# Infer "task" from sub-environments (prefer natural language description).
# env.call() works with both SyncVectorEnv and AsyncVectorEnv.
try:
observation["task"] = list(env.call("task_description"))
except (AttributeError, NotImplementedError):
# Infer "task" from sub-environments (prefer natural language description).
# env.call() works with both SyncVectorEnv and AsyncVectorEnv.
try:
observation["task"] = list(env.call("task"))
observation["task"] = list(env.call("task_description"))
except (AttributeError, NotImplementedError):
observation["task"] = [""] * env.num_envs
try:
observation["task"] = list(env.call("task"))
except (AttributeError, NotImplementedError):
observation["task"] = [""] * env.num_envs
# Apply environment-specific preprocessing (e.g., LiberoProcessorStep for LIBERO)
observation = env_preprocessor(observation)
# Apply environment-specific preprocessing (e.g., LiberoProcessorStep for LIBERO)
observation = env_preprocessor(observation)
observation = preprocessor(observation)
with torch.inference_mode():
action = policy.select_action(observation)
action = postprocessor(action)
observation = preprocessor(observation)
with torch.inference_mode():
action = policy.select_action(observation)
action = postprocessor(action)
action_transition = {ACTION: action}
action_transition = env_postprocessor(action_transition)
action = action_transition[ACTION]
action_transition = {ACTION: action}
action_transition = env_postprocessor(action_transition)
action = action_transition[ACTION]
# Convert to CPU / numpy.
action_numpy: np.ndarray = action.to("cpu").numpy()
assert action_numpy.ndim == 2, "Action dimensions should be (batch, action_dim)"
# Convert to CPU / numpy.
action_numpy: np.ndarray = action.to("cpu").numpy()
assert action_numpy.ndim == 2, "Action dimensions should be (batch, action_dim)"
# Apply the next action.
observation, reward, terminated, truncated, info = env.step(action_numpy)
if render_callback is not None:
render_callback(env)
# Apply the next action.
observation, reward, terminated, truncated, info = env.step(action_numpy)
if render_callback is not None:
render_callback(env)
# VectorEnv stores is_success in `info["final_info"][env_index]["is_success"]`. "final_info" isn't
# available if none of the envs finished.
if "final_info" in info:
final_info = info["final_info"]
if not isinstance(final_info, dict):
raise RuntimeError(
"Unsupported `final_info` format: expected dict (Gymnasium >= 1.0). "
"You're likely using an older version of gymnasium (< 1.0). Please upgrade."
# VectorEnv stores is_success in `info["final_info"][env_index]["is_success"]`. "final_info" isn't
# available if none of the envs finished.
if "final_info" in info:
final_info = info["final_info"]
if not isinstance(final_info, dict):
raise RuntimeError(
"Unsupported `final_info` format: expected dict (Gymnasium >= 1.0). "
"You're likely using an older version of gymnasium (< 1.0). Please upgrade."
)
successes = final_info["is_success"].tolist()
elif "is_success" in info:
is_success = info["is_success"]
successes = (
is_success.tolist()
if hasattr(is_success, "tolist")
else [bool(is_success)] * env.num_envs
)
successes = final_info["is_success"].tolist()
elif "is_success" in info:
is_success = info["is_success"]
successes = (
is_success.tolist() if hasattr(is_success, "tolist") else [bool(is_success)] * env.num_envs
else:
successes = [False] * env.num_envs
if recording_datasets is not None and raw_observation is not None:
prev_done = done.copy()
for env_idx in range(env.num_envs):
if prev_done[env_idx]:
continue
frame = _build_raw_frame(
raw_observation,
env_idx,
action_numpy[env_idx],
reward[env_idx],
successes[env_idx],
bool(terminated[env_idx] | truncated[env_idx]),
task_desc,
recording_datasets[env_idx].features,
)
recording_datasets[env_idx].add_frame(frame)
if terminated[env_idx] or truncated[env_idx]:
recording_datasets[env_idx].save_episode()
raw_observation = deepcopy(observation)
# Keep track of which environments are done so far.
# Mark the episode as done if we reach the maximum step limit.
# This ensures that the rollout always terminates cleanly at `max_steps`,
# and allows logging/saving (e.g., videos) to be triggered consistently.
done = terminated | truncated | done
if step + 1 == max_steps:
done = np.ones_like(done, dtype=bool)
all_actions.append(torch.from_numpy(action_numpy))
all_rewards.append(torch.from_numpy(reward))
all_dones.append(torch.from_numpy(done))
all_successes.append(torch.tensor(successes))
step += 1
running_success_rate = (
einops.reduce(torch.stack(all_successes, dim=1), "b n -> b", "any").numpy().mean()
)
else:
successes = [False] * env.num_envs
if recording_datasets is not None and raw_observation is not None:
prev_done = done.copy()
for env_idx in range(env.num_envs):
if prev_done[env_idx]:
continue
frame = _build_raw_frame(
raw_observation,
env_idx,
action_numpy[env_idx],
reward[env_idx],
successes[env_idx],
bool(terminated[env_idx] | truncated[env_idx]),
task_desc,
recording_datasets[env_idx].features,
)
recording_datasets[env_idx].add_frame(frame)
if terminated[env_idx] or truncated[env_idx]:
recording_datasets[env_idx].save_episode()
raw_observation = deepcopy(observation)
# Keep track of which environments are done so far.
# Mark the episode as done if we reach the maximum step limit.
# This ensures that the rollout always terminates cleanly at `max_steps`,
# and allows logging/saving (e.g., videos) to be triggered consistently.
done = terminated | truncated | done
if step + 1 == max_steps:
done = np.ones_like(done, dtype=bool)
all_actions.append(torch.from_numpy(action_numpy))
all_rewards.append(torch.from_numpy(reward))
all_dones.append(torch.from_numpy(done))
all_successes.append(torch.tensor(successes))
step += 1
running_success_rate = (
einops.reduce(torch.stack(all_successes, dim=1), "b n -> b", "any").numpy().mean()
)
progbar.set_postfix({"running_success_rate": f"{running_success_rate.item() * 100:.1f}%"})
progbar.update()
progbar.set_postfix({"running_success_rate": f"{running_success_rate.item() * 100:.1f}%"})
progbar.update()
finally:
if recording_datasets is not None:
for ds in recording_datasets:
ds.finalize()
if recording_repo_id is not None:
if ds.num_episodes > 0:
ds.push_to_hub(private=recording_private)
else:
logging.warning("No episodes recorded for %s — skipping push to hub.", ds.repo_id)
# Track the final observation.
if return_observations:
@@ -366,12 +378,6 @@ def rollout(
stacked_observations[key] = torch.stack([obs[key] for obs in all_observations], dim=1)
ret[OBS_STR] = stacked_observations
if recording_datasets is not None:
for ds in recording_datasets:
ds.finalize()
if recording_repo_id is not None:
ds.push_to_hub(private=recording_private)
if hasattr(policy, "use_original_modules"):
policy.use_original_modules()