diff --git a/src/lerobot/scripts/lerobot_train.py b/src/lerobot/scripts/lerobot_train.py index 9117e64c5..3186eae44 100644 --- a/src/lerobot/scripts/lerobot_train.py +++ b/src/lerobot/scripts/lerobot_train.py @@ -303,26 +303,29 @@ def train(cfg: TrainPipelineConfig, accelerator: Accelerator | None = None): device=device, ) - # Compute per-timestep normalizer for relative actions + # Compute per-timestep normalizer for relative actions (main process computes, others load) relative_normalizer = None if cfg.use_relative_actions: + stats_path = cfg.output_dir / "relative_stats.pt" mode = "actions + state" if cfg.use_relative_state else "actions only" + if is_main_process: logging.info(colored(f"Relative mode: {mode}", "cyan", attrs=["bold"])) logging.info("Computing per-timestep stats from dataset (first 1000 batches)...") + cfg.output_dir.mkdir(parents=True, exist_ok=True) temp_loader = torch.utils.data.DataLoader( dataset, batch_size=cfg.batch_size, shuffle=True, num_workers=0 ) mean, std = compute_relative_action_stats(temp_loader, num_batches=1000) relative_normalizer = PerTimestepNormalizer(mean, std) - stats_path = cfg.output_dir / "relative_stats.pt" relative_normalizer.save(stats_path) logging.info(f"Saved stats to: {stats_path}") + # Barrier: wait for main process to finish computing and saving stats accelerator.wait_for_everyone() if not is_main_process: - relative_normalizer = PerTimestepNormalizer.load(cfg.output_dir / "relative_stats.pt") + relative_normalizer = PerTimestepNormalizer.load(stats_path) step = 0 # number of policy updates (forward + backward + optim)