From 040c6b3d66ca1662e63f461495669450c4b6ec3d Mon Sep 17 00:00:00 2001 From: Khalil Meftah Date: Tue, 16 Jun 2026 21:35:05 +0200 Subject: [PATCH] refactor(eval): per-env datasets recording, no double reset - Extract _infer_shape_from_obs() to reduce nesting in feature conversion - Move dataset creation into rollout() using its own env.reset() observation, eliminating the extra reset in run_one() - Replace deepcopy with _shallow_copy_obs() for raw observation stashing - Support batch_size > 1: each parallel env records to its own dataset (single env skips the env_0/ nesting for simplicity) - One-time warning for env_features keys missing from observations - Pass recording_dir + env_features through the call chain instead of a pre-built recording_dataset object --- src/lerobot/scripts/lerobot_eval.py | 158 +++++++++++++++++----------- 1 file changed, 95 insertions(+), 63 deletions(-) diff --git a/src/lerobot/scripts/lerobot_eval.py b/src/lerobot/scripts/lerobot_eval.py index dfdf3c539..0421cac62 100644 --- a/src/lerobot/scripts/lerobot_eval.py +++ b/src/lerobot/scripts/lerobot_eval.py @@ -96,31 +96,31 @@ from lerobot.utils.utils import ( ) -def _env_features_to_dataset_features(env_features: dict, raw_obs: dict | None = None) -> dict: - """Convert EnvConfig.features (PolicyFeature objects) to the plain dict format for LeRobotDataset.create(). +def _infer_shape_from_obs(key: str, raw_obs: dict, fallback: tuple) -> tuple: + """Infer the observation shape from a raw env observation, stripping the batch dim.""" + if key in raw_obs and isinstance(raw_obs[key], np.ndarray): + return raw_obs[key].shape[1:] + if "pixels" in raw_obs: + pixels = raw_obs["pixels"] + if isinstance(pixels, dict): + for cam_name, img in pixels.items(): + if key in (f"{OBS_IMAGES}.{cam_name}", cam_name): + return img.shape[1:] + elif key in ("pixels", OBS_IMAGE): + return pixels.shape[1:] + return fallback - If raw_obs is provided, visual feature shapes are inferred from the actual observation - to avoid mismatches between the env config and the real observation resolution. - """ + +def _env_features_to_dataset_features(env_features: dict, raw_obs: dict | None = None) -> dict: + """Convert EnvConfig.features to the dict format expected by LeRobotDataset.create().""" features = {} for key, ft in env_features.items(): + shape = tuple(ft.shape) + if raw_obs is not None: + shape = _infer_shape_from_obs(key, raw_obs, shape) if ft.type is FeatureType.VISUAL: - shape = tuple(ft.shape) - if raw_obs is not None and key in raw_obs and isinstance(raw_obs[key], np.ndarray): - shape = raw_obs[key].shape[1:] # strip batch dim - elif raw_obs is not None and "pixels" in raw_obs: - pixels = raw_obs["pixels"] - if isinstance(pixels, dict): - for cam_name, img in pixels.items(): - if key == f"{OBS_IMAGES}.{cam_name}" or key == cam_name: - shape = img.shape[1:] # strip batch dim - elif key in ("pixels", OBS_IMAGE): - shape = pixels.shape[1:] # strip batch dim features[key] = {"dtype": "video", "shape": shape, "names": ["height", "width", "channel"]} else: - shape = tuple(ft.shape) - if raw_obs is not None and key in raw_obs and isinstance(raw_obs[key], np.ndarray): - shape = raw_obs[key].shape[1:] # strip batch dim features[key] = {"dtype": "float32", "shape": shape, "names": None} features["next.reward"] = {"dtype": "float32", "shape": (1,), "names": None} features["next.success"] = {"dtype": "bool", "shape": (1,), "names": None} @@ -147,6 +147,8 @@ def _build_raw_frame( for key in env_features: if key == ACTION: continue + if key.startswith("next."): + continue if "pixels" in raw_obs and isinstance(raw_obs["pixels"], dict): for cam_name, img in raw_obs["pixels"].items(): candidate = f"{OBS_IMAGES}.{cam_name}" @@ -157,9 +159,8 @@ def _build_raw_frame( if "pixels" in raw_obs and not isinstance(raw_obs["pixels"], dict) and key in ("pixels", OBS_IMAGE): frame[key] = raw_obs["pixels"][env_idx] continue - raw_key = key - if raw_key in raw_obs and isinstance(raw_obs[raw_key], np.ndarray): - val = raw_obs[raw_key][env_idx] + if key in raw_obs and isinstance(raw_obs[key], np.ndarray): + val = raw_obs[key][env_idx] if val.dtype == np.float64: val = val.astype(np.float32) frame[key] = val @@ -171,6 +172,19 @@ def _build_raw_frame( return frame +def _shallow_copy_obs(obs: dict) -> dict: + """Copy observation dict, cloning only ndarray/Tensor values to avoid mutation.""" + out = {} + for k, v in obs.items(): + if isinstance(v, np.ndarray): + out[k] = v.copy() + elif isinstance(v, dict): + out[k] = _shallow_copy_obs(v) + else: + out[k] = v + return out + + def rollout( env: gym.vector.VectorEnv, policy: PreTrainedPolicy, @@ -181,7 +195,8 @@ def rollout( seeds: list[int] | None = None, return_observations: bool = False, render_callback: Callable[[gym.vector.VectorEnv], None] | None = None, - recording_dataset: Any | None = None, + recording_dir: Path | None = None, + env_features: dict | None = None, ) -> dict: """Run a batched policy rollout once through a batch of environments. @@ -222,9 +237,31 @@ def rollout( if render_callback is not None: render_callback(env) - raw_observation = deepcopy(observation) if recording_dataset is not None else None + recording_datasets: list[LeRobotDataset] | None = None + raw_observation = None task_desc = "" - if recording_dataset is not None: + if recording_dir is not None and env_features is not None: + features = _env_features_to_dataset_features(env_features, raw_obs=observation) + fps = env.unwrapped.metadata.get("render_fps", 30) + recording_datasets = [] + for i in range(env.num_envs): + root = str(recording_dir / f"env_{i}") if env.num_envs > 1 else str(recording_dir) + recording_datasets.append( + LeRobotDataset.create( + repo_id="eval_recording", + fps=fps, + features=features, + root=root, + use_videos=True, + ) + ) + raw_observation = _shallow_copy_obs(observation) + obs_keys = set(observation.keys()) + if "pixels" in obs_keys and isinstance(observation["pixels"], dict): + obs_keys.update(f"{OBS_IMAGES}.{c}" for c in observation["pixels"]) + missing = [k for k in env_features if k != ACTION and not k.startswith("next.") and k not in obs_keys] + if missing: + logging.warning("Recording: env_features keys %s not found in env observations.", missing) try: task_desc = list(env.call("task_description"))[0] except (AttributeError, NotImplementedError): @@ -302,7 +339,7 @@ def rollout( else: successes = [False] * env.num_envs - if recording_dataset is not None and raw_observation is not None: + if recording_datasets is not None and raw_observation is not None: prev_done = done.copy() for env_idx in range(env.num_envs): if prev_done[env_idx]: @@ -315,12 +352,12 @@ def rollout( successes[env_idx], bool(terminated[env_idx] | truncated[env_idx]), task_desc, - recording_dataset.features, + recording_datasets[env_idx].features, ) - recording_dataset.add_frame(frame) + recording_datasets[env_idx].add_frame(frame) if terminated[env_idx] or truncated[env_idx]: - recording_dataset.save_episode() - raw_observation = deepcopy(observation) + recording_datasets[env_idx].save_episode() + raw_observation = _shallow_copy_obs(observation) # Keep track of which environments are done so far. # Mark the episode as done if we reach the maximum step limit. @@ -360,6 +397,10 @@ def rollout( stacked_observations[key] = torch.stack([obs[key] for obs in all_observations], dim=1) ret[OBS_STR] = stacked_observations + if recording_datasets is not None: + for ds in recording_datasets: + ds.finalize() + if hasattr(policy, "use_original_modules"): policy.use_original_modules() @@ -378,7 +419,8 @@ def eval_policy( videos_dir: Path | None = None, return_episode_data: bool = False, start_seed: int | None = None, - recording_dataset: Any | None = None, + recording_dir: Path | None = None, + env_features: dict | None = None, ) -> dict: """ Args: @@ -467,7 +509,8 @@ def eval_policy( seeds=list(seeds) if seeds else None, return_observations=return_episode_data, render_callback=render_frame if max_episodes_rendered > 0 else None, - recording_dataset=recording_dataset, + recording_dir=recording_dir, + env_features=env_features, ) # Figure out where in each rollout sequence the first done condition was encountered (results after @@ -732,7 +775,8 @@ def eval_one( videos_dir: Path | None, return_episode_data: bool, start_seed: int | None, - recording_dataset: Any | None = None, + recording_dir: Path | None = None, + env_features: dict | None = None, ) -> TaskMetrics: """Evaluates one task_id of one suite using the provided vec env.""" @@ -750,7 +794,8 @@ def eval_one( videos_dir=task_videos_dir, return_episode_data=return_episode_data, start_seed=start_seed, - recording_dataset=recording_dataset, + recording_dir=recording_dir, + env_features=env_features, ) per_episode = task_result["per_episode"] @@ -790,38 +835,25 @@ def run_one( task_videos_dir = videos_dir / f"{task_group}_{task_id}" task_videos_dir.mkdir(parents=True, exist_ok=True) - recording_dataset = None + task_recording_dir = None if recording_dir is not None and env_features is not None: task_recording_dir = recording_dir / f"{task_group}_{task_id}" - fps = env.unwrapped.metadata.get("render_fps", 30) - sample_obs, _ = env.reset() - features = _env_features_to_dataset_features(env_features, raw_obs=sample_obs) - recording_dataset = LeRobotDataset.create( - repo_id=f"eval_{task_group}_{task_id}", - fps=fps, - features=features, - root=str(task_recording_dir), - use_videos=True, - ) - try: - metrics = eval_one( - env, - policy=policy, - env_preprocessor=env_preprocessor, - env_postprocessor=env_postprocessor, - preprocessor=preprocessor, - postprocessor=postprocessor, - n_episodes=n_episodes, - max_episodes_rendered=max_episodes_rendered, - videos_dir=task_videos_dir, - return_episode_data=return_episode_data, - start_seed=start_seed, - recording_dataset=recording_dataset, - ) - finally: - if recording_dataset is not None: - recording_dataset.finalize() + metrics = eval_one( + env, + policy=policy, + env_preprocessor=env_preprocessor, + env_postprocessor=env_postprocessor, + preprocessor=preprocessor, + postprocessor=postprocessor, + n_episodes=n_episodes, + max_episodes_rendered=max_episodes_rendered, + videos_dir=task_videos_dir, + return_episode_data=return_episode_data, + start_seed=start_seed, + recording_dir=task_recording_dir, + env_features=env_features, + ) if max_episodes_rendered > 0: metrics.setdefault("video_paths", [])