From 4f5e6596be14af29993a4cbda06d11640a41023c Mon Sep 17 00:00:00 2001 From: Khalil Meftah Date: Tue, 16 Jun 2026 22:13:23 +0200 Subject: [PATCH] refactor(eval): remove shape inference and shallow copy helpers --- src/lerobot/scripts/lerobot_eval.py | 44 +++-------------------------- 1 file changed, 4 insertions(+), 40 deletions(-) diff --git a/src/lerobot/scripts/lerobot_eval.py b/src/lerobot/scripts/lerobot_eval.py index 0421cac62..6ec262a30 100644 --- a/src/lerobot/scripts/lerobot_eval.py +++ b/src/lerobot/scripts/lerobot_eval.py @@ -96,28 +96,11 @@ from lerobot.utils.utils import ( ) -def _infer_shape_from_obs(key: str, raw_obs: dict, fallback: tuple) -> tuple: - """Infer the observation shape from a raw env observation, stripping the batch dim.""" - if key in raw_obs and isinstance(raw_obs[key], np.ndarray): - return raw_obs[key].shape[1:] - if "pixels" in raw_obs: - pixels = raw_obs["pixels"] - if isinstance(pixels, dict): - for cam_name, img in pixels.items(): - if key in (f"{OBS_IMAGES}.{cam_name}", cam_name): - return img.shape[1:] - elif key in ("pixels", OBS_IMAGE): - return pixels.shape[1:] - return fallback - - -def _env_features_to_dataset_features(env_features: dict, raw_obs: dict | None = None) -> dict: +def _env_features_to_dataset_features(env_features: dict) -> dict: """Convert EnvConfig.features to the dict format expected by LeRobotDataset.create().""" features = {} for key, ft in env_features.items(): shape = tuple(ft.shape) - if raw_obs is not None: - shape = _infer_shape_from_obs(key, raw_obs, shape) if ft.type is FeatureType.VISUAL: features[key] = {"dtype": "video", "shape": shape, "names": ["height", "width", "channel"]} else: @@ -172,19 +155,6 @@ def _build_raw_frame( return frame -def _shallow_copy_obs(obs: dict) -> dict: - """Copy observation dict, cloning only ndarray/Tensor values to avoid mutation.""" - out = {} - for k, v in obs.items(): - if isinstance(v, np.ndarray): - out[k] = v.copy() - elif isinstance(v, dict): - out[k] = _shallow_copy_obs(v) - else: - out[k] = v - return out - - def rollout( env: gym.vector.VectorEnv, policy: PreTrainedPolicy, @@ -241,7 +211,7 @@ def rollout( raw_observation = None task_desc = "" if recording_dir is not None and env_features is not None: - features = _env_features_to_dataset_features(env_features, raw_obs=observation) + features = _env_features_to_dataset_features(env_features) fps = env.unwrapped.metadata.get("render_fps", 30) recording_datasets = [] for i in range(env.num_envs): @@ -255,13 +225,7 @@ def rollout( use_videos=True, ) ) - raw_observation = _shallow_copy_obs(observation) - obs_keys = set(observation.keys()) - if "pixels" in obs_keys and isinstance(observation["pixels"], dict): - obs_keys.update(f"{OBS_IMAGES}.{c}" for c in observation["pixels"]) - missing = [k for k in env_features if k != ACTION and not k.startswith("next.") and k not in obs_keys] - if missing: - logging.warning("Recording: env_features keys %s not found in env observations.", missing) + raw_observation = deepcopy(observation) try: task_desc = list(env.call("task_description"))[0] except (AttributeError, NotImplementedError): @@ -357,7 +321,7 @@ def rollout( recording_datasets[env_idx].add_frame(frame) if terminated[env_idx] or truncated[env_idx]: recording_datasets[env_idx].save_episode() - raw_observation = _shallow_copy_obs(observation) + raw_observation = deepcopy(observation) # Keep track of which environments are done so far. # Mark the episode as done if we reach the maximum step limit.