From 7f16e8cb097a3ad31f46fae874b3a9473d02dec3 Mon Sep 17 00:00:00 2001 From: Pepijn Date: Fri, 2 Jan 2026 19:56:42 +0100 Subject: [PATCH] fix --- src/lerobot/scripts/lerobot_train.py | 9 ++++++--- 1 file changed, 6 insertions(+), 3 deletions(-) diff --git a/src/lerobot/scripts/lerobot_train.py b/src/lerobot/scripts/lerobot_train.py index 9117e64c5..3186eae44 100644 --- a/src/lerobot/scripts/lerobot_train.py +++ b/src/lerobot/scripts/lerobot_train.py @@ -303,26 +303,29 @@ def train(cfg: TrainPipelineConfig, accelerator: Accelerator | None = None): device=device, ) - # Compute per-timestep normalizer for relative actions + # Compute per-timestep normalizer for relative actions (main process computes, others load) relative_normalizer = None if cfg.use_relative_actions: + stats_path = cfg.output_dir / "relative_stats.pt" mode = "actions + state" if cfg.use_relative_state else "actions only" + if is_main_process: logging.info(colored(f"Relative mode: {mode}", "cyan", attrs=["bold"])) logging.info("Computing per-timestep stats from dataset (first 1000 batches)...") + cfg.output_dir.mkdir(parents=True, exist_ok=True) temp_loader = torch.utils.data.DataLoader( dataset, batch_size=cfg.batch_size, shuffle=True, num_workers=0 ) mean, std = compute_relative_action_stats(temp_loader, num_batches=1000) relative_normalizer = PerTimestepNormalizer(mean, std) - stats_path = cfg.output_dir / "relative_stats.pt" relative_normalizer.save(stats_path) logging.info(f"Saved stats to: {stats_path}") + # Barrier: wait for main process to finish computing and saving stats accelerator.wait_for_everyone() if not is_main_process: - relative_normalizer = PerTimestepNormalizer.load(cfg.output_dir / "relative_stats.pt") + relative_normalizer = PerTimestepNormalizer.load(stats_path) step = 0 # number of policy updates (forward + backward + optim)