mirror of
https://github.com/huggingface/lerobot.git
synced 2026-06-18 00:37:10 +00:00
fix(eval): infer recording features from actual env observations
This commit is contained in:
@@ -96,18 +96,32 @@ from lerobot.utils.utils import (
|
||||
)
|
||||
|
||||
|
||||
def _env_features_to_dataset_features(env_features: dict) -> dict:
|
||||
"""Convert EnvConfig.features (PolicyFeature objects) to the plain dict format for LeRobotDataset.create()."""
|
||||
def _env_features_to_dataset_features(env_features: dict, raw_obs: dict | None = None) -> dict:
|
||||
"""Convert EnvConfig.features (PolicyFeature objects) to the plain dict format for LeRobotDataset.create().
|
||||
|
||||
If raw_obs is provided, visual feature shapes are inferred from the actual observation
|
||||
to avoid mismatches between the env config and the real observation resolution.
|
||||
"""
|
||||
features = {}
|
||||
for key, ft in env_features.items():
|
||||
if ft.type.value == "visual":
|
||||
features[key] = {
|
||||
"dtype": "video",
|
||||
"shape": tuple(ft.shape),
|
||||
"names": ["channel", "height", "width"],
|
||||
}
|
||||
shape = tuple(ft.shape)
|
||||
if raw_obs is not None and key in raw_obs and isinstance(raw_obs[key], np.ndarray):
|
||||
shape = raw_obs[key].shape[1:] # strip batch dim
|
||||
elif raw_obs is not None and "pixels" in raw_obs:
|
||||
pixels = raw_obs["pixels"]
|
||||
if isinstance(pixels, dict):
|
||||
for cam_name, img in pixels.items():
|
||||
if key == f"{OBS_IMAGES}.{cam_name}" or key == cam_name:
|
||||
shape = img.shape[1:] # strip batch dim
|
||||
elif key in ("pixels", OBS_IMAGE):
|
||||
shape = pixels.shape[1:] # strip batch dim
|
||||
features[key] = {"dtype": "video", "shape": shape, "names": ["height", "width", "channel"]}
|
||||
else:
|
||||
features[key] = {"dtype": "float32", "shape": tuple(ft.shape), "names": None}
|
||||
shape = tuple(ft.shape)
|
||||
if raw_obs is not None and key in raw_obs and isinstance(raw_obs[key], np.ndarray):
|
||||
shape = raw_obs[key].shape[1:] # strip batch dim
|
||||
features[key] = {"dtype": "float32", "shape": shape, "names": None}
|
||||
features["next.reward"] = {"dtype": "float32", "shape": (1,), "names": None}
|
||||
features["next.success"] = {"dtype": "bool", "shape": (1,), "names": None}
|
||||
features["next.done"] = {"dtype": "bool", "shape": (1,), "names": None}
|
||||
@@ -145,7 +159,10 @@ def _build_raw_frame(
|
||||
continue
|
||||
raw_key = key
|
||||
if raw_key in raw_obs and isinstance(raw_obs[raw_key], np.ndarray):
|
||||
frame[key] = raw_obs[raw_key][env_idx]
|
||||
val = raw_obs[raw_key][env_idx]
|
||||
if val.dtype == np.float64:
|
||||
val = val.astype(np.float32)
|
||||
frame[key] = val
|
||||
frame[ACTION] = action
|
||||
frame["next.reward"] = np.atleast_1d(np.float32(reward))
|
||||
frame["next.success"] = np.atleast_1d(np.bool_(success))
|
||||
@@ -777,7 +794,8 @@ def run_one(
|
||||
if recording_dir is not None and env_features is not None:
|
||||
task_recording_dir = recording_dir / f"{task_group}_{task_id}"
|
||||
fps = env.unwrapped.metadata.get("render_fps", 30)
|
||||
features = _env_features_to_dataset_features(env_features)
|
||||
sample_obs, _ = env.reset()
|
||||
features = _env_features_to_dataset_features(env_features, raw_obs=sample_obs)
|
||||
recording_dataset = LeRobotDataset.create(
|
||||
repo_id=f"eval_{task_group}_{task_id}",
|
||||
fps=fps,
|
||||
|
||||
Reference in New Issue
Block a user