From 36470d059e55fedeb5d5e8bc142bc9747a42f7f0 Mon Sep 17 00:00:00 2001 From: Khalil Meftah Date: Mon, 15 Jun 2026 18:38:12 +0200 Subject: [PATCH] fix(eval): align raw frame keys with dataset schema and fix numpy types --- src/lerobot/scripts/lerobot_eval.py | 43 +++++++++++++++++------------ 1 file changed, 26 insertions(+), 17 deletions(-) diff --git a/src/lerobot/scripts/lerobot_eval.py b/src/lerobot/scripts/lerobot_eval.py index f877fd39f..990800e84 100644 --- a/src/lerobot/scripts/lerobot_eval.py +++ b/src/lerobot/scripts/lerobot_eval.py @@ -85,7 +85,7 @@ from lerobot.envs import ( from lerobot.policies import PreTrainedPolicy, make_policy, make_pre_post_processors from lerobot.processor import PolicyProcessorPipeline from lerobot.types import PolicyAction -from lerobot.utils.constants import ACTION, DONE, OBS_IMAGE, OBS_IMAGES, OBS_STATE, OBS_STR, REWARD +from lerobot.utils.constants import ACTION, DONE, OBS_IMAGE, OBS_IMAGES, OBS_STR, REWARD from lerobot.utils.device_utils import get_safe_torch_device from lerobot.utils.import_utils import register_third_party_plugins from lerobot.utils.io_utils import write_video @@ -122,26 +122,34 @@ def _build_raw_frame( success: bool, done: bool, task: str, + env_features: dict, ) -> dict: - """Build a dataset frame from raw env observations for one env index.""" + """Build a dataset frame from raw env observations for one env index. + + Keys in the frame match the keys in env_features so they align with the + dataset schema created by _env_features_to_dataset_features(). + """ frame: dict[str, Any] = {} - if "pixels" in raw_obs: - if isinstance(raw_obs["pixels"], dict): - for cam_name, img in raw_obs["pixels"].items(): - frame[f"{OBS_IMAGES}.{cam_name}"] = img[env_idx] - else: - frame[OBS_IMAGE] = raw_obs["pixels"][env_idx] - if "agent_pos" in raw_obs: - frame[OBS_STATE] = raw_obs["agent_pos"][env_idx] - for key, val in raw_obs.items(): - if key in ("pixels", "agent_pos"): + for key in env_features: + if key == ACTION: continue - if isinstance(val, np.ndarray): - frame[f"{OBS_STR}.{key}"] = val[env_idx] + if "pixels" in raw_obs and isinstance(raw_obs["pixels"], dict): + for cam_name, img in raw_obs["pixels"].items(): + candidate = f"{OBS_IMAGES}.{cam_name}" + if candidate == key: + frame[key] = img[env_idx] + if key in frame: + continue + if "pixels" in raw_obs and not isinstance(raw_obs["pixels"], dict) and key in ("pixels", OBS_IMAGE): + frame[key] = raw_obs["pixels"][env_idx] + continue + raw_key = key + if raw_key in raw_obs and isinstance(raw_obs[raw_key], np.ndarray): + frame[key] = raw_obs[raw_key][env_idx] frame[ACTION] = action - frame["next.reward"] = np.float32(reward) - frame["next.success"] = success - frame["next.done"] = done + frame["next.reward"] = np.atleast_1d(np.float32(reward)) + frame["next.success"] = np.atleast_1d(np.bool_(success)) + frame["next.done"] = np.atleast_1d(np.bool_(done)) frame["task"] = task return frame @@ -290,6 +298,7 @@ def rollout( successes[env_idx], bool(terminated[env_idx] | truncated[env_idx]), task_desc, + recording_dataset.features, ) recording_dataset.add_frame(frame) if terminated[env_idx] or truncated[env_idx]: