diff --git a/src/lerobot/scripts/lerobot_train.py b/src/lerobot/scripts/lerobot_train.py index 3186eae44..f88d5c02f 100644 --- a/src/lerobot/scripts/lerobot_train.py +++ b/src/lerobot/scripts/lerobot_train.py @@ -303,29 +303,25 @@ def train(cfg: TrainPipelineConfig, accelerator: Accelerator | None = None): device=device, ) - # Compute per-timestep normalizer for relative actions (main process computes, others load) + # Compute per-timestep normalizer for relative actions + # Each process computes stats independently to avoid distributed sync issues relative_normalizer = None if cfg.use_relative_actions: - stats_path = cfg.output_dir / "relative_stats.pt" mode = "actions + state" if cfg.use_relative_state else "actions only" - if is_main_process: logging.info(colored(f"Relative mode: {mode}", "cyan", attrs=["bold"])) logging.info("Computing per-timestep stats from dataset (first 1000 batches)...") + + temp_loader = torch.utils.data.DataLoader( + dataset, batch_size=cfg.batch_size, shuffle=False, num_workers=0 + ) + mean, std = compute_relative_action_stats(temp_loader, num_batches=1000) + relative_normalizer = PerTimestepNormalizer(mean, std) + + if is_main_process: cfg.output_dir.mkdir(parents=True, exist_ok=True) - temp_loader = torch.utils.data.DataLoader( - dataset, batch_size=cfg.batch_size, shuffle=True, num_workers=0 - ) - mean, std = compute_relative_action_stats(temp_loader, num_batches=1000) - relative_normalizer = PerTimestepNormalizer(mean, std) - relative_normalizer.save(stats_path) - logging.info(f"Saved stats to: {stats_path}") - - # Barrier: wait for main process to finish computing and saving stats - accelerator.wait_for_everyone() - - if not is_main_process: - relative_normalizer = PerTimestepNormalizer.load(stats_path) + relative_normalizer.save(cfg.output_dir / "relative_stats.pt") + logging.info(f"Saved stats to: {cfg.output_dir / 'relative_stats.pt'}") step = 0 # number of policy updates (forward + backward + optim)