diff --git a/src/lerobot/scripts/lerobot_eval.py b/src/lerobot/scripts/lerobot_eval.py index f877fd39f..dfdf3c539 100644 --- a/src/lerobot/scripts/lerobot_eval.py +++ b/src/lerobot/scripts/lerobot_eval.py @@ -72,7 +72,7 @@ from termcolor import colored from torch import Tensor, nn from tqdm import trange -from lerobot.configs import parser +from lerobot.configs import FeatureType, parser from lerobot.configs.eval import EvalPipelineConfig from lerobot.datasets.lerobot_dataset import LeRobotDataset from lerobot.envs import ( @@ -85,7 +85,7 @@ from lerobot.envs import ( from lerobot.policies import PreTrainedPolicy, make_policy, make_pre_post_processors from lerobot.processor import PolicyProcessorPipeline from lerobot.types import PolicyAction -from lerobot.utils.constants import ACTION, DONE, OBS_IMAGE, OBS_IMAGES, OBS_STATE, OBS_STR, REWARD +from lerobot.utils.constants import ACTION, DONE, OBS_IMAGE, OBS_IMAGES, OBS_STR, REWARD from lerobot.utils.device_utils import get_safe_torch_device from lerobot.utils.import_utils import register_third_party_plugins from lerobot.utils.io_utils import write_video @@ -96,18 +96,32 @@ from lerobot.utils.utils import ( ) -def _env_features_to_dataset_features(env_features: dict) -> dict: - """Convert EnvConfig.features (PolicyFeature objects) to the plain dict format for LeRobotDataset.create().""" +def _env_features_to_dataset_features(env_features: dict, raw_obs: dict | None = None) -> dict: + """Convert EnvConfig.features (PolicyFeature objects) to the plain dict format for LeRobotDataset.create(). + + If raw_obs is provided, visual feature shapes are inferred from the actual observation + to avoid mismatches between the env config and the real observation resolution. + """ features = {} for key, ft in env_features.items(): - if ft.type.value == "visual": - features[key] = { - "dtype": "video", - "shape": tuple(ft.shape), - "names": ["channel", "height", "width"], - } + if ft.type is FeatureType.VISUAL: + shape = tuple(ft.shape) + if raw_obs is not None and key in raw_obs and isinstance(raw_obs[key], np.ndarray): + shape = raw_obs[key].shape[1:] # strip batch dim + elif raw_obs is not None and "pixels" in raw_obs: + pixels = raw_obs["pixels"] + if isinstance(pixels, dict): + for cam_name, img in pixels.items(): + if key == f"{OBS_IMAGES}.{cam_name}" or key == cam_name: + shape = img.shape[1:] # strip batch dim + elif key in ("pixels", OBS_IMAGE): + shape = pixels.shape[1:] # strip batch dim + features[key] = {"dtype": "video", "shape": shape, "names": ["height", "width", "channel"]} else: - features[key] = {"dtype": "float32", "shape": tuple(ft.shape), "names": None} + shape = tuple(ft.shape) + if raw_obs is not None and key in raw_obs and isinstance(raw_obs[key], np.ndarray): + shape = raw_obs[key].shape[1:] # strip batch dim + features[key] = {"dtype": "float32", "shape": shape, "names": None} features["next.reward"] = {"dtype": "float32", "shape": (1,), "names": None} features["next.success"] = {"dtype": "bool", "shape": (1,), "names": None} features["next.done"] = {"dtype": "bool", "shape": (1,), "names": None} @@ -122,26 +136,37 @@ def _build_raw_frame( success: bool, done: bool, task: str, + env_features: dict, ) -> dict: - """Build a dataset frame from raw env observations for one env index.""" + """Build a dataset frame from raw env observations for one env index. + + Keys in the frame match the keys in env_features so they align with the + dataset schema created by _env_features_to_dataset_features(). + """ frame: dict[str, Any] = {} - if "pixels" in raw_obs: - if isinstance(raw_obs["pixels"], dict): - for cam_name, img in raw_obs["pixels"].items(): - frame[f"{OBS_IMAGES}.{cam_name}"] = img[env_idx] - else: - frame[OBS_IMAGE] = raw_obs["pixels"][env_idx] - if "agent_pos" in raw_obs: - frame[OBS_STATE] = raw_obs["agent_pos"][env_idx] - for key, val in raw_obs.items(): - if key in ("pixels", "agent_pos"): + for key in env_features: + if key == ACTION: continue - if isinstance(val, np.ndarray): - frame[f"{OBS_STR}.{key}"] = val[env_idx] + if "pixels" in raw_obs and isinstance(raw_obs["pixels"], dict): + for cam_name, img in raw_obs["pixels"].items(): + candidate = f"{OBS_IMAGES}.{cam_name}" + if candidate == key: + frame[key] = img[env_idx] + if key in frame: + continue + if "pixels" in raw_obs and not isinstance(raw_obs["pixels"], dict) and key in ("pixels", OBS_IMAGE): + frame[key] = raw_obs["pixels"][env_idx] + continue + raw_key = key + if raw_key in raw_obs and isinstance(raw_obs[raw_key], np.ndarray): + val = raw_obs[raw_key][env_idx] + if val.dtype == np.float64: + val = val.astype(np.float32) + frame[key] = val frame[ACTION] = action - frame["next.reward"] = np.float32(reward) - frame["next.success"] = success - frame["next.done"] = done + frame["next.reward"] = np.atleast_1d(np.float32(reward)) + frame["next.success"] = np.atleast_1d(np.bool_(success)) + frame["next.done"] = np.atleast_1d(np.bool_(done)) frame["task"] = task return frame @@ -290,6 +315,7 @@ def rollout( successes[env_idx], bool(terminated[env_idx] | truncated[env_idx]), task_desc, + recording_dataset.features, ) recording_dataset.add_frame(frame) if terminated[env_idx] or truncated[env_idx]: @@ -768,7 +794,8 @@ def run_one( if recording_dir is not None and env_features is not None: task_recording_dir = recording_dir / f"{task_group}_{task_id}" fps = env.unwrapped.metadata.get("render_fps", 30) - features = _env_features_to_dataset_features(env_features) + sample_obs, _ = env.reset() + features = _env_features_to_dataset_features(env_features, raw_obs=sample_obs) recording_dataset = LeRobotDataset.create( repo_id=f"eval_{task_group}_{task_id}", fps=fps,