diff --git a/src/lerobot/scripts/lerobot_train.py b/src/lerobot/scripts/lerobot_train.py
index dd36f4a39..deb5a4681 100644
--- a/src/lerobot/scripts/lerobot_train.py
+++ b/src/lerobot/scripts/lerobot_train.py
@@ -305,27 +305,39 @@ def train(cfg: TrainPipelineConfig, accelerator: Accelerator | None = None):
     )

     # Compute per-timestep normalizer for relative actions
-    # Compute BEFORE accelerator.prepare() to avoid NCCL timeouts during long computation
     relative_normalizer = None
     if cfg.use_relative_actions:
         mode = "actions + state" if cfg.use_relative_state else "actions only"
+        cfg.output_dir.mkdir(parents=True, exist_ok=True)
+        stats_path = cfg.output_dir / "relative_stats.pt"
+
         if is_main_process:
             logging.info(colored(f"Relative mode: {mode}", "cyan", attrs=["bold"]))
-            logging.info("Computing per-timestep stats from dataset (first 1000 batches)...")
+
+            if stats_path.exists():
+                logging.info(f"Loading pre-computed stats from: {stats_path}")
+            else:
+                logging.info("Computing per-timestep stats (first 1000 batches)...")
+                logging.info("Using fresh dataset to avoid video decoder state issues...")
+                # Create separate dataset instance to avoid corrupting main dataset's video decoders
+                stats_dataset = make_dataset(cfg)
+                temp_loader = torch.utils.data.DataLoader(
+                    stats_dataset, batch_size=cfg.batch_size, shuffle=False, num_workers=0
+                )
+                mean, std = compute_relative_action_stats(temp_loader, num_batches=1000)
+                del temp_loader, stats_dataset
+                gc.collect()
+                torch.save({"mean": mean, "std": std}, stats_path)
+                logging.info(f"Saved stats to: {stats_path}")

-        # All ranks compute independently to avoid NCCL timeout (computation takes ~10min)
-        temp_loader = torch.utils.data.DataLoader(
-            dataset, batch_size=cfg.batch_size, shuffle=False, num_workers=0
-        )
-        mean, std = compute_relative_action_stats(temp_loader, num_batches=1000)
-        del temp_loader
-        gc.collect()  # Clean up any video decoder resources
-        relative_normalizer = PerTimestepNormalizer(mean, std)
+        # Poll for stats file instead of using NCCL barrier (avoids timeout during long computation)
+        if not is_main_process:
+            while not stats_path.exists():
+                time.sleep(5)

-        if is_main_process:
-            cfg.output_dir.mkdir(parents=True, exist_ok=True)
-            relative_normalizer.save(cfg.output_dir / "relative_stats.pt")
-            logging.info(f"Saved stats to: {cfg.output_dir / 'relative_stats.pt'}")
+        data = torch.load(stats_path, weights_only=True, map_location="cpu")
+        relative_normalizer = PerTimestepNormalizer(data["mean"], data["std"])
+        accelerator.wait_for_everyone()  # Sync after everyone has loaded

     step = 0  # number of policy updates (forward + backward + optim)