feat(training): add inline offline validation with train/eval split

- Add eval_split config for balanced per-task holdout
- Add eval_steps for periodic inline eval loss computation
- Add max_eval_samples to cap eval cost
This commit is contained in:
Khalil Meftah
2026-06-14 21:29:54 +02:00
parent 64773e7b22
commit 2201401c99
5 changed files with 140 additions and 4 deletions
+2
View File
@@ -39,6 +39,8 @@ class DatasetConfig:
# This reduces memory and speeds up DataLoader IPC. The training pipeline handles the conversion.
return_uint8: bool = False
streaming: bool = False
# Fraction of episodes held out per task for offline evaluation (0.0 = disabled).
eval_split: float = 0.0
def __post_init__(self) -> None:
if self.episodes is not None:
+4
View File
@@ -103,6 +103,10 @@ class TrainPipelineConfig(HubMixin):
# Run policy in the simulation environment every N steps to measure reward/success (0 = disabled).
env_eval_freq: int = 20_000
log_freq: int = 200
# Compute eval loss on held-out episodes every N steps (0 = disabled). Requires dataset.eval_split > 0.
eval_steps: int = 0
# Cap on total eval samples, split uniformly across tasks (0 = use all held-out data).
max_eval_samples: int = 0
tolerance_s: float = 1e-4
save_checkpoint: bool = True
# Checkpoint is saved every `save_freq` training iterations and after the last training step.
+2 -1
View File
@@ -35,7 +35,7 @@ from .dataset_tools import (
remove_feature,
split_dataset,
)
from .factory import make_dataset, resolve_delta_timestamps
from .factory import make_dataset, make_train_eval_datasets, resolve_delta_timestamps
from .image_writer import safe_stop_image_writer
from .io_utils import load_episodes, write_stats
from .language import (
@@ -89,6 +89,7 @@ __all__ = [
"get_feature_stats",
"load_episodes",
"make_dataset",
"make_train_eval_datasets",
"merge_datasets",
"modify_features",
"modify_tasks",
+79
View File
@@ -14,6 +14,7 @@
# See the License for the specific language governing permissions and
# limitations under the License.
import logging
import math
from pprint import pformat
import torch
@@ -130,3 +131,81 @@ def make_dataset(cfg: TrainPipelineConfig) -> LeRobotDataset | MultiLeRobotDatas
dataset.meta.stats[key][stats_type] = torch.tensor(stats, dtype=torch.float32)
return dataset
def make_train_eval_datasets(
    cfg: TrainPipelineConfig,
) -> tuple[LeRobotDataset | MultiLeRobotDataset, LeRobotDataset | None]:
    """Create the training dataset and an optional held-out eval dataset.

    Episodes are grouped by their first task label (or ``""`` when an episode has
    no task). Within each group, the last ``ceil(n_episodes * eval_split)``
    episodes are held out for evaluation and the rest are kept for training.

    Args:
        cfg: Training pipeline configuration. ``cfg.dataset.eval_split`` is the
            per-task holdout fraction and must lie in ``[0.0, 1.0)``.

    Returns:
        ``(train_dataset, eval_dataset)``. When ``eval_split == 0.0`` the full
        dataset built by ``make_dataset`` is returned together with ``None``.

    Raises:
        ValueError: If ``eval_split`` is outside ``[0.0, 1.0)`` (including NaN),
            or if the split leaves no training episodes at all.
    """
    eval_split = cfg.dataset.eval_split
    # Fail fast, before any dataset is loaded. A fraction > 1 would make the slice
    # bounds below negative and silently produce a wrong split; a negative fraction
    # would produce an empty eval set. The chained comparison also rejects NaN.
    if not 0.0 <= eval_split < 1.0:
        raise ValueError(f"eval_split must be in [0.0, 1.0), got {eval_split}.")

    full_dataset = make_dataset(cfg)
    if eval_split == 0.0:
        return full_dataset, None

    # NOTE(review): make_dataset is annotated as possibly returning a MultiLeRobotDataset,
    # and cfg.dataset has a `streaming` flag; the split below always rebuilds plain
    # LeRobotDataset instances from cfg.dataset.repo_id without forwarding `streaming`.
    # Confirm those configurations are not expected together with eval_split > 0.
    base_episodes = (
        full_dataset.episodes if full_dataset.episodes is not None else list(range(full_dataset.num_episodes))
    )

    # Group episode indices by task, preserving their original order within each task.
    episode_tasks = full_dataset.meta.episodes["tasks"]
    task_to_episodes: dict[str, list[int]] = {}
    for ep_idx in base_episodes:
        ep_tasks = episode_tasks[ep_idx]
        task_key = ep_tasks[0] if ep_tasks else ""
        task_to_episodes.setdefault(task_key, []).append(ep_idx)

    train_episodes: list[int] = []
    eval_episodes: list[int] = []
    tasks_without_train: list[str] = []
    for task_key, eps in task_to_episodes.items():
        # Round before ceil so float noise (e.g. 100 * 0.07 == 7.000000000000001)
        # does not hold out one episode too many. eval_split > 0 here, so every
        # non-empty task still holds out at least one episode.
        n_eval = max(1, math.ceil(round(len(eps) * eval_split, 9)))
        n_train = len(eps) - n_eval
        if n_train == 0:
            tasks_without_train.append(task_key)
        train_episodes.extend(eps[:n_train])
        eval_episodes.extend(eps[n_train:])

    if not train_episodes:
        raise ValueError(
            f"eval_split={cfg.dataset.eval_split} leaves 0 training episodes from {len(base_episodes)} total."
        )
    if tasks_without_train:
        # Not fatal, but the policy will never see these tasks during training.
        logging.warning(
            f"eval_split={eval_split} leaves no training episodes for "
            f"{len(tasks_without_train)} of {len(task_to_episodes)} tasks."
        )

    logging.info(
        f"Train/eval split: {len(train_episodes)} train, {len(eval_episodes)} eval "
        f"(eval_split={cfg.dataset.eval_split}, {len(task_to_episodes)} tasks)"
    )

    delta_timestamps = resolve_delta_timestamps(cfg.trainable_config, full_dataset.meta)
    # Augmentations apply to the training split only; the eval split stays untransformed.
    train_image_transforms = (
        ImageTransforms(cfg.dataset.image_transforms) if cfg.dataset.image_transforms.enable else None
    )

    # Everything except the episode list and the image transforms is shared by both splits.
    # Frames are always requested as uint8 here, independent of cfg.dataset.return_uint8.
    shared_kwargs = {
        "root": cfg.dataset.root,
        "delta_timestamps": delta_timestamps,
        "revision": cfg.dataset.revision,
        "video_backend": cfg.dataset.video_backend,
        "return_uint8": True,
        "tolerance_s": cfg.tolerance_s,
    }
    train_dataset = LeRobotDataset(
        cfg.dataset.repo_id,
        episodes=train_episodes,
        image_transforms=train_image_transforms,
        **shared_kwargs,
    )
    eval_dataset = LeRobotDataset(
        cfg.dataset.repo_id,
        episodes=eval_episodes,
        image_transforms=None,
        **shared_kwargs,
    )

    if cfg.dataset.use_imagenet_stats:
        # Overwrite the camera statistics of both splits with the ImageNet values.
        for ds in (train_dataset, eval_dataset):
            for key in ds.meta.camera_keys:
                for stats_type, stats in IMAGENET_STATS.items():
                    ds.meta.stats[key][stats_type] = torch.tensor(stats, dtype=torch.float32)

    return train_dataset, eval_dataset
+53 -3
View File
@@ -45,7 +45,8 @@ from lerobot.common.train_utils import (
from lerobot.common.wandb_utils import WandBLogger
from lerobot.configs import parser
from lerobot.configs.train import TrainPipelineConfig
from lerobot.datasets import EpisodeAwareSampler, compute_sampler_state, make_dataset
from lerobot.datasets import EpisodeAwareSampler, compute_sampler_state
from lerobot.datasets.factory import make_train_eval_datasets
from lerobot.envs import close_envs, make_env, make_env_pre_post_processors
from lerobot.optim.factory import make_optimizer_and_scheduler
from lerobot.policies import PreTrainedPolicy, make_policy, make_pre_post_processors
@@ -244,13 +245,13 @@ def train(cfg: TrainPipelineConfig, accelerator: "Accelerator | None" = None):
# LeRobotDataset skips its snapshot_download when try_load() succeeds, so no rank re-downloads.
if is_main_process:
logging.info("Creating dataset")
dataset = make_dataset(cfg)
dataset, eval_dataset = make_train_eval_datasets(cfg)
accelerator.wait_for_everyone()
# Other ranks read from the shared copy populated by the main process.
if not is_main_process:
dataset = make_dataset(cfg)
dataset, eval_dataset = make_train_eval_datasets(cfg)
# Create environment used for evaluating checkpoints during training on simulation data.
# On real-world data, no need to create an environment as evaluations are done outside train.py,
@@ -434,6 +435,8 @@ def train(cfg: TrainPipelineConfig, accelerator: "Accelerator | None" = None):
f"Resuming data order at epoch {sampler_state['epoch']}, "
f"sample {sampler_state['start_index']}"
)
if dataset.reader._absolute_to_relative_idx is not None:
sampler.indices = [dataset.reader._absolute_to_relative_idx[i] for i in sampler.indices]
else:
shuffle = True
sampler = None
@@ -455,6 +458,31 @@ def train(cfg: TrainPipelineConfig, accelerator: "Accelerator | None" = None):
persistent_workers=cfg.persistent_workers and cfg.num_workers > 0,
)
# Build eval dataloader if a held-out split exists
eval_dataloader = None
if eval_dataset is not None:
eval_ds = eval_dataset
if cfg.max_eval_samples > 0 and hasattr(eval_dataset, "hf_dataset"):
task_indices = eval_dataset.hf_dataset["task_index"]
unique_tasks = sorted(set(task_indices))
per_task = max(1, cfg.max_eval_samples // len(unique_tasks))
selected: list[int] = []
for t in unique_tasks:
frames = [i for i, ti in enumerate(task_indices) if ti == t][:per_task]
selected.extend(frames)
eval_ds = torch.utils.data.Subset(eval_dataset, selected)
eval_collate_fn = lerobot_collate_fn if dataset.meta.has_language_columns else None
eval_dataloader = torch.utils.data.DataLoader(
eval_ds,
batch_size=cfg.batch_size,
shuffle=False,
num_workers=cfg.num_workers,
pin_memory=device.type == "cuda",
drop_last=False,
collate_fn=eval_collate_fn,
)
# Prepare everything with accelerator
accelerator.wait_for_everyone()
policy, optimizer, dataloader, lr_scheduler = accelerator.prepare(
@@ -535,6 +563,7 @@ def train(cfg: TrainPipelineConfig, accelerator: "Accelerator | None" = None):
is_log_step = cfg.log_freq > 0 and step % cfg.log_freq == 0
is_saving_step = step % cfg.save_freq == 0 or step == cfg.steps
is_env_eval_step = cfg.env_eval_freq > 0 and step % cfg.env_eval_freq == 0
is_eval_step = cfg.eval_steps > 0 and eval_dataloader is not None and step % cfg.eval_steps == 0
if is_log_step:
# Collective reduce must run on every rank, before the main-process gate below.
@@ -557,6 +586,27 @@ def train(cfg: TrainPipelineConfig, accelerator: "Accelerator | None" = None):
wandb_logger.log_dict(wandb_log_dict, step)
train_tracker.reset_averages()
if is_eval_step:
policy.eval()
eval_loss_sum = 0.0
n_eval_batches = 0
with torch.no_grad(), accelerator.autocast():
for eval_batch in eval_dataloader:
for cam_key in dataset.meta.camera_keys:
if cam_key in eval_batch and eval_batch[cam_key].dtype == torch.uint8:
eval_batch[cam_key] = eval_batch[cam_key].to(dtype=torch.float32) / 255.0
eval_batch = preprocessor(eval_batch)
loss, _ = policy.forward(eval_batch)
eval_loss_sum += loss.item()
n_eval_batches += 1
eval_loss = eval_loss_sum / max(n_eval_batches, 1)
policy.train()
if is_main_process:
logging.info(f"step {step}: eval_loss={eval_loss:.4f}")
if wandb_logger:
wandb_logger.log_dict({"eval_loss": eval_loss}, step=step, mode="eval")
if cfg.save_checkpoint and is_saving_step:
if is_main_process:
logging.info(f"Checkpoint policy after step {step}")