From 2201401c9914b832bb8f5fdf605992ff2ff58104 Mon Sep 17 00:00:00 2001 From: Khalil Meftah Date: Sun, 14 Jun 2026 21:29:54 +0200 Subject: [PATCH] feat(training): add inline offline validation with train/eval split - Add eval_split config for balanced per-task holdout - Add eval_steps for periodic inline eval loss computation - Add max_eval_samples to cap eval cost --- src/lerobot/configs/default.py | 2 + src/lerobot/configs/train.py | 4 ++ src/lerobot/datasets/__init__.py | 3 +- src/lerobot/datasets/factory.py | 79 ++++++++++++++++++++++++++++ src/lerobot/scripts/lerobot_train.py | 56 ++++++++++++++++++-- 5 files changed, 140 insertions(+), 4 deletions(-) diff --git a/src/lerobot/configs/default.py b/src/lerobot/configs/default.py index b809e71d9..1fcd8d508 100644 --- a/src/lerobot/configs/default.py +++ b/src/lerobot/configs/default.py @@ -39,6 +39,8 @@ class DatasetConfig: # This reduces memory and speeds up DataLoader IPC. The training pipeline handles the conversion. return_uint8: bool = False streaming: bool = False + # Fraction of episodes held out per task for offline evaluation (0.0 = disabled). + eval_split: float = 0.0 def __post_init__(self) -> None: if self.episodes is not None: diff --git a/src/lerobot/configs/train.py b/src/lerobot/configs/train.py index 949ebbae0..5b847f700 100644 --- a/src/lerobot/configs/train.py +++ b/src/lerobot/configs/train.py @@ -103,6 +103,10 @@ class TrainPipelineConfig(HubMixin): # Run policy in the simulation environment every N steps to measure reward/success (0 = disabled). env_eval_freq: int = 20_000 log_freq: int = 200 + # Compute eval loss on held-out episodes every N steps (0 = disabled). Requires eval_split > 0. + eval_steps: int = 0 + # Cap on total eval samples, split uniformly across tasks (0 = use all held-out data). + max_eval_samples: int = 0 tolerance_s: float = 1e-4 save_checkpoint: bool = True # Checkpoint is saved every `save_freq` training iterations and after the last training step. diff --git a/src/lerobot/datasets/__init__.py b/src/lerobot/datasets/__init__.py index bd12a7248..7715a115e 100644 --- a/src/lerobot/datasets/__init__.py +++ b/src/lerobot/datasets/__init__.py @@ -35,7 +35,7 @@ from .dataset_tools import ( remove_feature, split_dataset, ) -from .factory import make_dataset, resolve_delta_timestamps +from .factory import make_dataset, make_train_eval_datasets, resolve_delta_timestamps from .image_writer import safe_stop_image_writer from .io_utils import load_episodes, write_stats from .language import ( @@ -89,6 +89,7 @@ __all__ = [ "get_feature_stats", "load_episodes", "make_dataset", + "make_train_eval_datasets", "merge_datasets", "modify_features", "modify_tasks", diff --git a/src/lerobot/datasets/factory.py b/src/lerobot/datasets/factory.py index cbbe83dc8..cd29ee99e 100644 --- a/src/lerobot/datasets/factory.py +++ b/src/lerobot/datasets/factory.py @@ -14,6 +14,7 @@ # See the License for the specific language governing permissions and # limitations under the License. import logging +import math from pprint import pformat import torch @@ -130,3 +131,81 @@ def make_dataset(cfg: TrainPipelineConfig) -> LeRobotDataset | MultiLeRobotDatas dataset.meta.stats[key][stats_type] = torch.tensor(stats, dtype=torch.float32) return dataset + + +def make_train_eval_datasets( + cfg: TrainPipelineConfig, +) -> tuple[LeRobotDataset | MultiLeRobotDataset, LeRobotDataset | None]: + """Create train and optional eval datasets by splitting episodes based on eval_split. + + The last ceil(n_episodes * eval_split) episodes per task are held out for evaluation. + If eval_split == 0.0, returns (full_dataset, None). + """ + full_dataset = make_dataset(cfg) + + if cfg.dataset.eval_split == 0.0: + return full_dataset, None + + base_episodes = ( + full_dataset.episodes if full_dataset.episodes is not None else list(range(full_dataset.num_episodes)) + ) + + episode_tasks = full_dataset.meta.episodes["tasks"] + task_to_episodes: dict[str, list[int]] = {} + for ep_idx in base_episodes: + task_key = episode_tasks[ep_idx][0] if episode_tasks[ep_idx] else "" + task_to_episodes.setdefault(task_key, []).append(ep_idx) + + train_episodes, eval_episodes = [], [] + for eps in task_to_episodes.values(): + n_eval = math.ceil(len(eps) * cfg.dataset.eval_split) + train_episodes.extend(eps[: len(eps) - n_eval]) + eval_episodes.extend(eps[len(eps) - n_eval :]) + + if not train_episodes: + raise ValueError( + f"eval_split={cfg.dataset.eval_split} leaves 0 training episodes from {len(base_episodes)} total." + ) + + logging.info( + f"Train/eval split: {len(train_episodes)} train, {len(eval_episodes)} eval " + f"(eval_split={cfg.dataset.eval_split}, {len(task_to_episodes)} tasks)" + ) + + delta_timestamps = resolve_delta_timestamps(cfg.trainable_config, full_dataset.meta) + + train_image_transforms = ( + ImageTransforms(cfg.dataset.image_transforms) if cfg.dataset.image_transforms.enable else None + ) + + train_dataset = LeRobotDataset( + cfg.dataset.repo_id, + root=cfg.dataset.root, + episodes=train_episodes, + delta_timestamps=delta_timestamps, + image_transforms=train_image_transforms, + revision=cfg.dataset.revision, + video_backend=cfg.dataset.video_backend, + return_uint8=True, + tolerance_s=cfg.tolerance_s, + ) + + eval_dataset = LeRobotDataset( + cfg.dataset.repo_id, + root=cfg.dataset.root, + episodes=eval_episodes, + delta_timestamps=delta_timestamps, + image_transforms=None, + revision=cfg.dataset.revision, + video_backend=cfg.dataset.video_backend, + return_uint8=True, + tolerance_s=cfg.tolerance_s, + ) + + if cfg.dataset.use_imagenet_stats: + for ds in (train_dataset, eval_dataset): + for key in ds.meta.camera_keys: + for stats_type, stats in IMAGENET_STATS.items(): + ds.meta.stats[key][stats_type] = torch.tensor(stats, dtype=torch.float32) + + return train_dataset, eval_dataset diff --git a/src/lerobot/scripts/lerobot_train.py b/src/lerobot/scripts/lerobot_train.py index 5bfc3cb86..9e4a9a5b5 100644 --- a/src/lerobot/scripts/lerobot_train.py +++ b/src/lerobot/scripts/lerobot_train.py @@ -45,7 +45,8 @@ from lerobot.common.train_utils import ( from lerobot.common.wandb_utils import WandBLogger from lerobot.configs import parser from lerobot.configs.train import TrainPipelineConfig -from lerobot.datasets import EpisodeAwareSampler, compute_sampler_state, make_dataset +from lerobot.datasets import EpisodeAwareSampler, compute_sampler_state +from lerobot.datasets.factory import make_train_eval_datasets from lerobot.envs import close_envs, make_env, make_env_pre_post_processors from lerobot.optim.factory import make_optimizer_and_scheduler from lerobot.policies import PreTrainedPolicy, make_policy, make_pre_post_processors @@ -244,13 +245,13 @@ def train(cfg: TrainPipelineConfig, accelerator: "Accelerator | None" = None): # LeRobotDataset skips its snapshot_download when try_load() succeeds, so no rank re-downloads. if is_main_process: logging.info("Creating dataset") - dataset = make_dataset(cfg) + dataset, eval_dataset = make_train_eval_datasets(cfg) accelerator.wait_for_everyone() # Other ranks read from the shared copy populated by the main process. if not is_main_process: - dataset = make_dataset(cfg) + dataset, eval_dataset = make_train_eval_datasets(cfg) # Create environment used for evaluating checkpoints during training on simulation data. # On real-world data, no need to create an environment as evaluations are done outside train.py, @@ -434,6 +435,8 @@ def train(cfg: TrainPipelineConfig, accelerator: "Accelerator | None" = None): f"Resuming data order at epoch {sampler_state['epoch']}, " f"sample {sampler_state['start_index']}" ) + if dataset.reader._absolute_to_relative_idx is not None: + sampler.indices = [dataset.reader._absolute_to_relative_idx[i] for i in sampler.indices] else: shuffle = True sampler = None @@ -455,6 +458,31 @@ def train(cfg: TrainPipelineConfig, accelerator: "Accelerator | None" = None): persistent_workers=cfg.persistent_workers and cfg.num_workers > 0, ) + # Build eval dataloader if a held-out split exists + eval_dataloader = None + if eval_dataset is not None: + eval_ds = eval_dataset + if cfg.max_eval_samples > 0 and hasattr(eval_dataset, "hf_dataset"): + task_indices = eval_dataset.hf_dataset["task_index"] + unique_tasks = sorted(set(task_indices)) + per_task = max(1, cfg.max_eval_samples // len(unique_tasks)) + selected: list[int] = [] + for t in unique_tasks: + frames = [i for i, ti in enumerate(task_indices) if ti == t][:per_task] + selected.extend(frames) + eval_ds = torch.utils.data.Subset(eval_dataset, selected) + + eval_collate_fn = lerobot_collate_fn if dataset.meta.has_language_columns else None + eval_dataloader = torch.utils.data.DataLoader( + eval_ds, + batch_size=cfg.batch_size, + shuffle=False, + num_workers=cfg.num_workers, + pin_memory=device.type == "cuda", + drop_last=False, + collate_fn=eval_collate_fn, + ) + # Prepare everything with accelerator accelerator.wait_for_everyone() policy, optimizer, dataloader, lr_scheduler = accelerator.prepare( @@ -535,6 +563,7 @@ def train(cfg: TrainPipelineConfig, accelerator: "Accelerator | None" = None): is_log_step = cfg.log_freq > 0 and step % cfg.log_freq == 0 is_saving_step = step % cfg.save_freq == 0 or step == cfg.steps is_env_eval_step = cfg.env_eval_freq > 0 and step % cfg.env_eval_freq == 0 + is_eval_step = cfg.eval_steps > 0 and eval_dataloader is not None and step % cfg.eval_steps == 0 if is_log_step: # Collective reduce must run on every rank, before the main-process gate below. @@ -557,6 +586,27 @@ def train(cfg: TrainPipelineConfig, accelerator: "Accelerator | None" = None): wandb_logger.log_dict(wandb_log_dict, step) train_tracker.reset_averages() + if is_eval_step: + policy.eval() + eval_loss_sum = 0.0 + n_eval_batches = 0 + with torch.no_grad(), accelerator.autocast(): + for eval_batch in eval_dataloader: + for cam_key in dataset.meta.camera_keys: + if cam_key in eval_batch and eval_batch[cam_key].dtype == torch.uint8: + eval_batch[cam_key] = eval_batch[cam_key].to(dtype=torch.float32) / 255.0 + eval_batch = preprocessor(eval_batch) + loss, _ = policy.forward(eval_batch) + eval_loss_sum += loss.item() + n_eval_batches += 1 + eval_loss = eval_loss_sum / max(n_eval_batches, 1) + policy.train() + + if is_main_process: + logging.info(f"step {step}: eval_loss={eval_loss:.4f}") + if wandb_logger: + wandb_logger.log_dict({"eval_loss": eval_loss}, step=step, mode="eval") + if cfg.save_checkpoint and is_saving_step: if is_main_process: logging.info(f"Checkpoint policy after step {step}")