Compare commits

..

6 Commits

Author SHA1 Message Date
Khalil Meftah a130a9db39 feat(eval): optionally push recorded eval datasets to the Hub 2026-06-17 11:41:50 +02:00
Khalil Meftah 4f5e6596be refactor(eval): remove shape inference and shallow copy helpers 2026-06-16 22:13:23 +02:00
Khalil Meftah afeeeb8982 Merge branch 'main' into feat/eval-dataset-recording 2026-06-16 21:45:06 +02:00
Khalil Meftah 040c6b3d66 refactor(eval): per-env datasets recording, no double reset
- Extract _infer_shape_from_obs() to reduce nesting in feature conversion
- Move dataset creation into rollout() using its own env.reset() observation,
  eliminating the extra reset in run_one()
- Replace deepcopy with _shallow_copy_obs() for raw observation stashing
- Support batch_size > 1: each parallel env records to its own dataset
  (single env skips the env_0/ nesting for simplicity)
- One-time warning for env_features keys missing from observations
- Pass recording_dir + env_features through the call chain instead of
  a pre-built recording_dataset object
2026-06-16 21:35:05 +02:00
Khalil Meftah acd31c7de2 fix(eval): use FeatureType enum comparison instead of string value 2026-06-16 15:22:50 +02:00
Khalil Meftah 240393d238 feat(eval): record eval rollouts as raw LeRobot datasets
- Record raw env observations inline during rollout(), before
preprocess_observation() transforms them. Uses LeRobotDataset.create()
with add_frame()/save_episode().

- Supports vectorized envs: each env in the batch records independently,
with save_episode() called per env on termination. Each task gets its
own dataset under output_dir/recordings/{task_group}_{task_id}/.

Enabled via --eval.recording=true; disabled by default.
2026-06-15 16:12:25 +02:00
8 changed files with 192 additions and 33 deletions
-1
View File
@@ -136,7 +136,6 @@ Learn how to implement your own simulation environment or benchmark and distribu
- **[X](https://x.com/LeRobotHF):** Follow us on X to stay up-to-date with the latest developments. - **[X](https://x.com/LeRobotHF):** Follow us on X to stay up-to-date with the latest developments.
- **[Robot Learning Tutorial](https://huggingface.co/spaces/lerobot/robot-learning-tutorial):** A free, hands-on course to learn robot learning using LeRobot. - **[Robot Learning Tutorial](https://huggingface.co/spaces/lerobot/robot-learning-tutorial):** A free, hands-on course to learn robot learning using LeRobot.
- **[T-Shirt Folding Experiment](https://huggingface.co/spaces/lerobot/robot-folding):** An end-to-end demonstration of folding t-shirts with LeRobot. - **[T-Shirt Folding Experiment](https://huggingface.co/spaces/lerobot/robot-folding):** An end-to-end demonstration of folding t-shirts with LeRobot.
- **[LeLab](https://github.com/huggingface/leLab):** A web interface for LeRobot — teleoperate, calibrate, record datasets, replay, and train your SO arm from the browser, no CLI required.
## Citation ## Citation
+2 -3
View File
@@ -442,12 +442,11 @@ class OpenCVCamera(Camera):
Stops on DeviceNotConnectedError, logs other errors and continues. Stops on DeviceNotConnectedError, logs other errors and continues.
""" """
stop_event = self.stop_event if self.stop_event is None:
if stop_event is None:
raise RuntimeError(f"{self}: stop_event is not initialized before starting read loop.") raise RuntimeError(f"{self}: stop_event is not initialized before starting read loop.")
failure_count = 0 failure_count = 0
while not stop_event.is_set(): while not self.stop_event.is_set():
try: try:
raw_frame = self._read_from_hardware() raw_frame = self._read_from_hardware()
processed_frame = self._postprocess_image(raw_frame) processed_frame = self._postprocess_image(raw_frame)
@@ -471,12 +471,11 @@ class RealSenseCamera(Camera):
Stops on DeviceNotConnectedError, logs other errors and continues. Stops on DeviceNotConnectedError, logs other errors and continues.
""" """
stop_event = self.stop_event if self.stop_event is None:
if stop_event is None:
raise RuntimeError(f"{self}: stop_event is not initialized before starting read loop.") raise RuntimeError(f"{self}: stop_event is not initialized before starting read loop.")
failure_count = 0 failure_count = 0
while not stop_event.is_set(): while not self.stop_event.is_set():
try: try:
frame = self._read_from_hardware() frame = self._read_from_hardware()
color_frame_raw = frame.get_color_frame() color_frame_raw = frame.get_color_frame()
+2 -3
View File
@@ -246,12 +246,11 @@ class ZMQCamera(Camera):
""" """
Internal loop run by the background thread for asynchronous reading. Internal loop run by the background thread for asynchronous reading.
""" """
stop_event = self.stop_event if self.stop_event is None:
if stop_event is None:
raise RuntimeError(f"{self}: stop_event is not initialized.") raise RuntimeError(f"{self}: stop_event is not initialized.")
failure_count = 0 failure_count = 0
while not stop_event.is_set(): while not self.stop_event.is_set():
try: try:
frame = self._read_from_hardware() frame = self._read_from_hardware()
capture_time = time.perf_counter() capture_time = time.perf_counter()
+9
View File
@@ -73,8 +73,17 @@ class EvalConfig:
# `use_async_envs` specifies whether to use asynchronous environments (multiprocessing). # `use_async_envs` specifies whether to use asynchronous environments (multiprocessing).
# Defaults to True; automatically downgraded to SyncVectorEnv when batch_size=1. # Defaults to True; automatically downgraded to SyncVectorEnv when batch_size=1.
use_async_envs: bool = True use_async_envs: bool = True
# Whether to record eval rollouts as a LeRobot dataset on disk.
recording: bool = False
# If set, push recorded eval datasets to the Hub under this repo id (one repo per task,
# suffixed by task and env index). Requires recording=true.
recording_repo_id: str | None = None
# Whether the pushed recording repositories should be private.
recording_private: bool = False
def __post_init__(self) -> None: def __post_init__(self) -> None:
if self.recording_repo_id is not None and not self.recording:
raise ValueError("eval.recording_repo_id requires eval.recording=true.")
if self.batch_size == 0: if self.batch_size == 0:
self.batch_size = self._auto_batch_size() self.batch_size = self._auto_batch_size()
if self.batch_size > self.n_episodes: if self.batch_size > self.n_episodes:
-12
View File
@@ -74,8 +74,6 @@ class DatasetReader:
self.episodes = episodes self.episodes = episodes
self._tolerance_s = tolerance_s self._tolerance_s = tolerance_s
self._video_backend = video_backend self._video_backend = video_backend
if image_transforms is not None and not callable(image_transforms):
raise TypeError("image_transforms must be callable or None.")
self._image_transforms = image_transforms self._image_transforms = image_transforms
self._return_uint8 = return_uint8 self._return_uint8 = return_uint8
@@ -88,16 +86,6 @@ class DatasetReader:
check_delta_timestamps(delta_timestamps, meta.fps, tolerance_s) check_delta_timestamps(delta_timestamps, meta.fps, tolerance_s)
self.delta_indices = get_delta_indices(delta_timestamps, meta.fps) self.delta_indices = get_delta_indices(delta_timestamps, meta.fps)
def set_image_transforms(self, image_transforms: Callable | None) -> None:
"""Replace the transform applied to visual observations."""
if image_transforms is not None and not callable(image_transforms):
raise TypeError("image_transforms must be callable or None.")
self._image_transforms = image_transforms
def clear_image_transforms(self) -> None:
"""Remove the transform applied to visual observations."""
self._image_transforms = None
def try_load(self) -> bool: def try_load(self) -> bool:
"""Attempt to load from local cache. Returns True if data is sufficient.""" """Attempt to load from local cache. Returns True if data is sufficient."""
try: try:
+7 -5
View File
@@ -201,6 +201,8 @@ class LeRobotDataset(torch.utils.data.Dataset):
super().__init__() super().__init__()
self.repo_id = repo_id self.repo_id = repo_id
self._requested_root = Path(root) if root else None self._requested_root = Path(root) if root else None
self.reader = None
self.set_image_transforms(image_transforms)
self.delta_timestamps = delta_timestamps self.delta_timestamps = delta_timestamps
self.tolerance_s = tolerance_s self.tolerance_s = tolerance_s
self.revision = revision if revision else CODEBASE_VERSION self.revision = revision if revision else CODEBASE_VERSION
@@ -247,7 +249,6 @@ class LeRobotDataset(torch.utils.data.Dataset):
image_transforms=image_transforms, image_transforms=image_transforms,
return_uint8=self._return_uint8, return_uint8=self._return_uint8,
) )
self.image_transforms = image_transforms
# Load actual data # Load actual data
if force_cache_sync or not self.reader.try_load(): if force_cache_sync or not self.reader.try_load():
@@ -504,14 +505,15 @@ class LeRobotDataset(torch.utils.data.Dataset):
def set_image_transforms(self, image_transforms: Callable | None) -> None: def set_image_transforms(self, image_transforms: Callable | None) -> None:
"""Replace the transform applied to visual observations.""" """Replace the transform applied to visual observations."""
self._ensure_reader().set_image_transforms(image_transforms) if image_transforms is not None and not callable(image_transforms):
raise TypeError("image_transforms must be callable or None.")
self.image_transforms = image_transforms self.image_transforms = image_transforms
if self.reader is not None:
self.reader._image_transforms = image_transforms
def clear_image_transforms(self) -> None: def clear_image_transforms(self) -> None:
"""Remove the transform applied to visual observations.""" """Remove the transform applied to visual observations."""
if self.reader is not None: self.set_image_transforms(None)
self.reader.set_image_transforms(None)
self.image_transforms = None
# ── Hub methods (stay on facade) ────────────────────────────────── # ── Hub methods (stay on facade) ──────────────────────────────────
+170 -6
View File
@@ -72,8 +72,9 @@ from termcolor import colored
from torch import Tensor, nn from torch import Tensor, nn
from tqdm import trange from tqdm import trange
from lerobot.configs import parser from lerobot.configs import FeatureType, parser
from lerobot.configs.eval import EvalPipelineConfig from lerobot.configs.eval import EvalPipelineConfig
from lerobot.datasets.lerobot_dataset import LeRobotDataset
from lerobot.envs import ( from lerobot.envs import (
check_env_attributes_and_types, check_env_attributes_and_types,
close_envs, close_envs,
@@ -84,7 +85,7 @@ from lerobot.envs import (
from lerobot.policies import PreTrainedPolicy, make_policy, make_pre_post_processors from lerobot.policies import PreTrainedPolicy, make_policy, make_pre_post_processors
from lerobot.processor import PolicyProcessorPipeline from lerobot.processor import PolicyProcessorPipeline
from lerobot.types import PolicyAction from lerobot.types import PolicyAction
from lerobot.utils.constants import ACTION, DONE, OBS_STR, REWARD from lerobot.utils.constants import ACTION, DONE, OBS_IMAGE, OBS_IMAGES, OBS_STR, REWARD
from lerobot.utils.device_utils import get_safe_torch_device from lerobot.utils.device_utils import get_safe_torch_device
from lerobot.utils.import_utils import register_third_party_plugins from lerobot.utils.import_utils import register_third_party_plugins
from lerobot.utils.io_utils import write_video from lerobot.utils.io_utils import write_video
@@ -95,6 +96,65 @@ from lerobot.utils.utils import (
) )
def _env_features_to_dataset_features(env_features: dict) -> dict:
"""Convert EnvConfig.features to the dict format expected by LeRobotDataset.create()."""
features = {}
for key, ft in env_features.items():
shape = tuple(ft.shape)
if ft.type is FeatureType.VISUAL:
features[key] = {"dtype": "video", "shape": shape, "names": ["height", "width", "channel"]}
else:
features[key] = {"dtype": "float32", "shape": shape, "names": None}
features["next.reward"] = {"dtype": "float32", "shape": (1,), "names": None}
features["next.success"] = {"dtype": "bool", "shape": (1,), "names": None}
features["next.done"] = {"dtype": "bool", "shape": (1,), "names": None}
return features
def _build_raw_frame(
    raw_obs: dict,
    env_idx: int,
    action: np.ndarray,
    reward: float,
    success: bool,
    done: bool,
    task: str,
    env_features: dict,
) -> dict:
    """Assemble one dataset frame for a single env out of batched raw observations.

    Only keys listed in ``env_features`` are considered, so the frame lines up
    with the schema produced by ``_env_features_to_dataset_features()``. Keys
    that have no matching entry in ``raw_obs`` are silently left out.

    Args:
        raw_obs: Raw observation dict; each value is indexed with ``env_idx``.
        env_idx: Index of the env whose slice is extracted.
        action: Action stored under the ``ACTION`` key as-is.
        reward: Step reward, stored as a 1-element float32 array.
        success: Success flag, stored as a 1-element bool array.
        done: Done flag, stored as a 1-element bool array.
        task: Task description stored under ``"task"``.
        env_features: Mapping whose keys select which observations to copy.

    Returns:
        The frame dict for this env and step.
    """
    has_pixels = "pixels" in raw_obs
    pixels = raw_obs["pixels"] if has_pixels else None
    per_camera = has_pixels and isinstance(pixels, dict)
    single_image = has_pixels and not per_camera

    frame: dict[str, Any] = {}
    for key in env_features:
        # The action and the "next.*" columns are filled in after the loop.
        if key == ACTION or key.startswith("next."):
            continue

        if per_camera:
            # Camera images live under "pixels" keyed by camera name.
            hits = [
                img[env_idx]
                for cam_name, img in pixels.items()
                if f"{OBS_IMAGES}.{cam_name}" == key
            ]
            if hits:
                frame[key] = hits[-1]
                continue

        if single_image and key in ("pixels", OBS_IMAGE):
            frame[key] = pixels[env_idx]
            continue

        if key not in raw_obs:
            continue
        column = raw_obs[key]
        if not isinstance(column, np.ndarray):
            continue
        item = column[env_idx]
        # Downcast float64 values to float32 before storing them.
        frame[key] = item.astype(np.float32) if item.dtype == np.float64 else item

    frame.update(
        {
            ACTION: action,
            "next.reward": np.atleast_1d(np.float32(reward)),
            "next.success": np.atleast_1d(np.bool_(success)),
            "next.done": np.atleast_1d(np.bool_(done)),
            "task": task,
        }
    )
    return frame
def rollout( def rollout(
env: gym.vector.VectorEnv, env: gym.vector.VectorEnv,
policy: PreTrainedPolicy, policy: PreTrainedPolicy,
@@ -105,6 +165,10 @@ def rollout(
seeds: list[int] | None = None, seeds: list[int] | None = None,
return_observations: bool = False, return_observations: bool = False,
render_callback: Callable[[gym.vector.VectorEnv], None] | None = None, render_callback: Callable[[gym.vector.VectorEnv], None] | None = None,
recording_dir: Path | None = None,
env_features: dict | None = None,
recording_repo_id: str | None = None,
recording_private: bool = False,
) -> dict: ) -> dict:
"""Run a batched policy rollout once through a batch of environments. """Run a batched policy rollout once through a batch of environments.
@@ -145,6 +209,33 @@ def rollout(
if render_callback is not None: if render_callback is not None:
render_callback(env) render_callback(env)
recording_datasets: list[LeRobotDataset] | None = None
raw_observation = None
task_desc = ""
if recording_dir is not None and env_features is not None:
features = _env_features_to_dataset_features(env_features)
fps = env.unwrapped.metadata.get("render_fps", 30)
recording_datasets = []
for i in range(env.num_envs):
multi_env = env.num_envs > 1
root = str(recording_dir / f"env_{i}") if multi_env else str(recording_dir)
base_repo_id = recording_repo_id or "eval_recording"
repo_id = f"{base_repo_id}_env_{i}" if multi_env else base_repo_id
recording_datasets.append(
LeRobotDataset.create(
repo_id=repo_id,
fps=fps,
features=features,
root=root,
use_videos=True,
)
)
raw_observation = deepcopy(observation)
try:
task_desc = list(env.call("task_description"))[0]
except (AttributeError, NotImplementedError):
task_desc = ""
all_observations = [] all_observations = []
all_actions = [] all_actions = []
all_rewards = [] all_rewards = []
@@ -217,6 +308,26 @@ def rollout(
else: else:
successes = [False] * env.num_envs successes = [False] * env.num_envs
if recording_datasets is not None and raw_observation is not None:
prev_done = done.copy()
for env_idx in range(env.num_envs):
if prev_done[env_idx]:
continue
frame = _build_raw_frame(
raw_observation,
env_idx,
action_numpy[env_idx],
reward[env_idx],
successes[env_idx],
bool(terminated[env_idx] | truncated[env_idx]),
task_desc,
recording_datasets[env_idx].features,
)
recording_datasets[env_idx].add_frame(frame)
if terminated[env_idx] or truncated[env_idx]:
recording_datasets[env_idx].save_episode()
raw_observation = deepcopy(observation)
# Keep track of which environments are done so far. # Keep track of which environments are done so far.
# Mark the episode as done if we reach the maximum step limit. # Mark the episode as done if we reach the maximum step limit.
# This ensures that the rollout always terminates cleanly at `max_steps`, # This ensures that the rollout always terminates cleanly at `max_steps`,
@@ -255,6 +366,12 @@ def rollout(
stacked_observations[key] = torch.stack([obs[key] for obs in all_observations], dim=1) stacked_observations[key] = torch.stack([obs[key] for obs in all_observations], dim=1)
ret[OBS_STR] = stacked_observations ret[OBS_STR] = stacked_observations
if recording_datasets is not None:
for ds in recording_datasets:
ds.finalize()
if recording_repo_id is not None:
ds.push_to_hub(private=recording_private)
if hasattr(policy, "use_original_modules"): if hasattr(policy, "use_original_modules"):
policy.use_original_modules() policy.use_original_modules()
@@ -273,6 +390,10 @@ def eval_policy(
videos_dir: Path | None = None, videos_dir: Path | None = None,
return_episode_data: bool = False, return_episode_data: bool = False,
start_seed: int | None = None, start_seed: int | None = None,
recording_dir: Path | None = None,
env_features: dict | None = None,
recording_repo_id: str | None = None,
recording_private: bool = False,
) -> dict: ) -> dict:
""" """
Args: Args:
@@ -361,6 +482,10 @@ def eval_policy(
seeds=list(seeds) if seeds else None, seeds=list(seeds) if seeds else None,
return_observations=return_episode_data, return_observations=return_episode_data,
render_callback=render_frame if max_episodes_rendered > 0 else None, render_callback=render_frame if max_episodes_rendered > 0 else None,
recording_dir=recording_dir,
env_features=env_features,
recording_repo_id=recording_repo_id,
recording_private=recording_private,
) )
# Figure out where in each rollout sequence the first done condition was encountered (results after # Figure out where in each rollout sequence the first done condition was encountered (results after
@@ -563,6 +688,10 @@ def eval_main(cfg: EvalPipelineConfig):
# Create environment-specific preprocessor and postprocessor (e.g., for LIBERO environments) # Create environment-specific preprocessor and postprocessor (e.g., for LIBERO environments)
env_preprocessor, env_postprocessor = make_env_pre_post_processors(env_cfg=cfg.env, policy_cfg=cfg.policy) env_preprocessor, env_postprocessor = make_env_pre_post_processors(env_cfg=cfg.env, policy_cfg=cfg.policy)
recording_dir = Path(cfg.output_dir) / "recordings" if cfg.eval.recording else None
max_episodes_rendered = 0 if cfg.eval.recording else 10
videos_dir = None if cfg.eval.recording else Path(cfg.output_dir) / "videos"
with torch.no_grad(), torch.autocast(device_type=device.type) if cfg.policy.use_amp else nullcontext(): with torch.no_grad(), torch.autocast(device_type=device.type) if cfg.policy.use_amp else nullcontext():
info = eval_policy_all( info = eval_policy_all(
envs=envs, envs=envs,
@@ -572,10 +701,15 @@ def eval_main(cfg: EvalPipelineConfig):
preprocessor=preprocessor, preprocessor=preprocessor,
postprocessor=postprocessor, postprocessor=postprocessor,
n_episodes=cfg.eval.n_episodes, n_episodes=cfg.eval.n_episodes,
max_episodes_rendered=10, max_episodes_rendered=max_episodes_rendered,
videos_dir=Path(cfg.output_dir) / "videos", videos_dir=videos_dir,
return_episode_data=False,
start_seed=cfg.seed, start_seed=cfg.seed,
max_parallel_tasks=cfg.env.max_parallel_tasks, max_parallel_tasks=cfg.env.max_parallel_tasks,
recording_dir=recording_dir,
env_features=cfg.env.features if cfg.eval.recording else None,
recording_repo_id=cfg.eval.recording_repo_id,
recording_private=cfg.eval.recording_private,
) )
print("Overall Aggregated Metrics:") print("Overall Aggregated Metrics:")
print(info["overall"]) print(info["overall"])
@@ -618,6 +752,10 @@ def eval_one(
videos_dir: Path | None, videos_dir: Path | None,
return_episode_data: bool, return_episode_data: bool,
start_seed: int | None, start_seed: int | None,
recording_dir: Path | None = None,
env_features: dict | None = None,
recording_repo_id: str | None = None,
recording_private: bool = False,
) -> TaskMetrics: ) -> TaskMetrics:
"""Evaluates one task_id of one suite using the provided vec env.""" """Evaluates one task_id of one suite using the provided vec env."""
@@ -635,6 +773,10 @@ def eval_one(
videos_dir=task_videos_dir, videos_dir=task_videos_dir,
return_episode_data=return_episode_data, return_episode_data=return_episode_data,
start_seed=start_seed, start_seed=start_seed,
recording_dir=recording_dir,
env_features=env_features,
recording_repo_id=recording_repo_id,
recording_private=recording_private,
) )
per_episode = task_result["per_episode"] per_episode = task_result["per_episode"]
@@ -661,6 +803,10 @@ def run_one(
videos_dir: Path | None, videos_dir: Path | None,
return_episode_data: bool, return_episode_data: bool,
start_seed: int | None, start_seed: int | None,
recording_dir: Path | None = None,
env_features: dict | None = None,
recording_repo_id: str | None = None,
recording_private: bool = False,
): ):
""" """
Run eval_one for a single (task_group, task_id, env). Run eval_one for a single (task_group, task_id, env).
@@ -672,7 +818,13 @@ def run_one(
task_videos_dir = videos_dir / f"{task_group}_{task_id}" task_videos_dir = videos_dir / f"{task_group}_{task_id}"
task_videos_dir.mkdir(parents=True, exist_ok=True) task_videos_dir.mkdir(parents=True, exist_ok=True)
# Call the existing eval_one (assumed to return TaskMetrics-like dict) task_recording_dir = None
task_repo_id = None
if recording_dir is not None and env_features is not None:
task_recording_dir = recording_dir / f"{task_group}_{task_id}"
if recording_repo_id is not None:
task_repo_id = f"{recording_repo_id}_{task_group}_{task_id}"
metrics = eval_one( metrics = eval_one(
env, env,
policy=policy, policy=policy,
@@ -685,8 +837,12 @@ def run_one(
videos_dir=task_videos_dir, videos_dir=task_videos_dir,
return_episode_data=return_episode_data, return_episode_data=return_episode_data,
start_seed=start_seed, start_seed=start_seed,
recording_dir=task_recording_dir,
env_features=env_features,
recording_repo_id=task_repo_id,
recording_private=recording_private,
) )
# ensure we always provide video_paths key to simplify accumulation
if max_episodes_rendered > 0: if max_episodes_rendered > 0:
metrics.setdefault("video_paths", []) metrics.setdefault("video_paths", [])
return task_group, task_id, metrics return task_group, task_id, metrics
@@ -702,6 +858,10 @@ def eval_policy_all(
n_episodes: int, n_episodes: int,
*, *,
max_episodes_rendered: int = 0, max_episodes_rendered: int = 0,
recording_dir: Path | None = None,
env_features: dict | None = None,
recording_repo_id: str | None = None,
recording_private: bool = False,
videos_dir: Path | None = None, videos_dir: Path | None = None,
return_episode_data: bool = False, return_episode_data: bool = False,
start_seed: int | None = None, start_seed: int | None = None,
@@ -761,6 +921,10 @@ def eval_policy_all(
videos_dir=videos_dir, videos_dir=videos_dir,
return_episode_data=return_episode_data, return_episode_data=return_episode_data,
start_seed=start_seed, start_seed=start_seed,
recording_dir=recording_dir,
env_features=env_features,
recording_repo_id=recording_repo_id,
recording_private=recording_private,
) )
if max_parallel_tasks <= 1: if max_parallel_tasks <= 1: