fix(eval): align raw frame keys with dataset schema and fix numpy types

This commit is contained in:
Khalil Meftah
2026-06-15 18:38:12 +02:00
parent 040a1df9d6
commit 36470d059e
+26 -17
View File
@@ -85,7 +85,7 @@ from lerobot.envs import (
from lerobot.policies import PreTrainedPolicy, make_policy, make_pre_post_processors
from lerobot.processor import PolicyProcessorPipeline
from lerobot.types import PolicyAction
from lerobot.utils.constants import ACTION, DONE, OBS_IMAGE, OBS_IMAGES, OBS_STATE, OBS_STR, REWARD
from lerobot.utils.constants import ACTION, DONE, OBS_IMAGE, OBS_IMAGES, OBS_STR, REWARD
from lerobot.utils.device_utils import get_safe_torch_device
from lerobot.utils.import_utils import register_third_party_plugins
from lerobot.utils.io_utils import write_video
@@ -122,26 +122,34 @@ def _build_raw_frame(
success: bool,
done: bool,
task: str,
env_features: dict,
) -> dict:
"""Build a dataset frame from raw env observations for one env index."""
"""Build a dataset frame from raw env observations for one env index.
Keys in the frame match the keys in env_features so they align with the
dataset schema created by _env_features_to_dataset_features().
"""
frame: dict[str, Any] = {}
if "pixels" in raw_obs:
if isinstance(raw_obs["pixels"], dict):
for cam_name, img in raw_obs["pixels"].items():
frame[f"{OBS_IMAGES}.{cam_name}"] = img[env_idx]
else:
frame[OBS_IMAGE] = raw_obs["pixels"][env_idx]
if "agent_pos" in raw_obs:
frame[OBS_STATE] = raw_obs["agent_pos"][env_idx]
for key, val in raw_obs.items():
if key in ("pixels", "agent_pos"):
for key in env_features:
if key == ACTION:
continue
if isinstance(val, np.ndarray):
frame[f"{OBS_STR}.{key}"] = val[env_idx]
if "pixels" in raw_obs and isinstance(raw_obs["pixels"], dict):
for cam_name, img in raw_obs["pixels"].items():
candidate = f"{OBS_IMAGES}.{cam_name}"
if candidate == key:
frame[key] = img[env_idx]
if key in frame:
continue
if "pixels" in raw_obs and not isinstance(raw_obs["pixels"], dict) and key in ("pixels", OBS_IMAGE):
frame[key] = raw_obs["pixels"][env_idx]
continue
raw_key = key
if raw_key in raw_obs and isinstance(raw_obs[raw_key], np.ndarray):
frame[key] = raw_obs[raw_key][env_idx]
frame[ACTION] = action
frame["next.reward"] = np.float32(reward)
frame["next.success"] = success
frame["next.done"] = done
frame["next.reward"] = np.atleast_1d(np.float32(reward))
frame["next.success"] = np.atleast_1d(np.bool_(success))
frame["next.done"] = np.atleast_1d(np.bool_(done))
frame["task"] = task
return frame
@@ -290,6 +298,7 @@ def rollout(
successes[env_idx],
bool(terminated[env_idx] | truncated[env_idx]),
task_desc,
recording_dataset.features,
)
recording_dataset.add_frame(frame)
if terminated[env_idx] or truncated[env_idx]: