fix(eval): infer recording features from actual env observations

This commit is contained in:
Khalil Meftah
2026-06-15 18:47:16 +02:00
parent 36470d059e
commit 58cf6c8710
+28 -10
View File
@@ -96,18 +96,32 @@ from lerobot.utils.utils import (
)
def _env_features_to_dataset_features(env_features: dict) -> dict:
"""Convert EnvConfig.features (PolicyFeature objects) to the plain dict format for LeRobotDataset.create()."""
def _env_features_to_dataset_features(env_features: dict, raw_obs: dict | None = None) -> dict:
"""Convert EnvConfig.features (PolicyFeature objects) to the plain dict format for LeRobotDataset.create().
If raw_obs is provided, visual feature shapes are inferred from the actual observation
to avoid mismatches between the env config and the real observation resolution.
"""
features = {}
for key, ft in env_features.items():
if ft.type.value == "visual":
features[key] = {
"dtype": "video",
"shape": tuple(ft.shape),
"names": ["channel", "height", "width"],
}
shape = tuple(ft.shape)
if raw_obs is not None and key in raw_obs and isinstance(raw_obs[key], np.ndarray):
shape = raw_obs[key].shape[1:] # strip batch dim
elif raw_obs is not None and "pixels" in raw_obs:
pixels = raw_obs["pixels"]
if isinstance(pixels, dict):
for cam_name, img in pixels.items():
if key == f"{OBS_IMAGES}.{cam_name}" or key == cam_name:
shape = img.shape[1:] # strip batch dim
elif key in ("pixels", OBS_IMAGE):
shape = pixels.shape[1:] # strip batch dim
features[key] = {"dtype": "video", "shape": shape, "names": ["height", "width", "channel"]}
else:
features[key] = {"dtype": "float32", "shape": tuple(ft.shape), "names": None}
shape = tuple(ft.shape)
if raw_obs is not None and key in raw_obs and isinstance(raw_obs[key], np.ndarray):
shape = raw_obs[key].shape[1:] # strip batch dim
features[key] = {"dtype": "float32", "shape": shape, "names": None}
features["next.reward"] = {"dtype": "float32", "shape": (1,), "names": None}
features["next.success"] = {"dtype": "bool", "shape": (1,), "names": None}
features["next.done"] = {"dtype": "bool", "shape": (1,), "names": None}
@@ -145,7 +159,10 @@ def _build_raw_frame(
continue
raw_key = key
if raw_key in raw_obs and isinstance(raw_obs[raw_key], np.ndarray):
frame[key] = raw_obs[raw_key][env_idx]
val = raw_obs[raw_key][env_idx]
if val.dtype == np.float64:
val = val.astype(np.float32)
frame[key] = val
frame[ACTION] = action
frame["next.reward"] = np.atleast_1d(np.float32(reward))
frame["next.success"] = np.atleast_1d(np.bool_(success))
@@ -777,7 +794,8 @@ def run_one(
if recording_dir is not None and env_features is not None:
task_recording_dir = recording_dir / f"{task_group}_{task_id}"
fps = env.unwrapped.metadata.get("render_fps", 30)
features = _env_features_to_dataset_features(env_features)
sample_obs, _ = env.reset()
features = _env_features_to_dataset_features(env_features, raw_obs=sample_obs)
recording_dataset = LeRobotDataset.create(
repo_id=f"eval_{task_group}_{task_id}",
fps=fps,