shuffle false

This commit is contained in:
Pepijn
2026-01-02 22:34:57 +01:00
parent 7f16e8cb09
commit c5f66edff9
+12 -16
View File
@@ -303,29 +303,25 @@ def train(cfg: TrainPipelineConfig, accelerator: Accelerator | None = None):
device=device,
)
# Compute per-timestep normalizer for relative actions (main process computes, others load)
# Compute per-timestep normalizer for relative actions
# Each process computes stats independently to avoid distributed sync issues
relative_normalizer = None
if cfg.use_relative_actions:
stats_path = cfg.output_dir / "relative_stats.pt"
mode = "actions + state" if cfg.use_relative_state else "actions only"
if is_main_process:
logging.info(colored(f"Relative mode: {mode}", "cyan", attrs=["bold"]))
logging.info("Computing per-timestep stats from dataset (first 1000 batches)...")
temp_loader = torch.utils.data.DataLoader(
dataset, batch_size=cfg.batch_size, shuffle=False, num_workers=0
)
mean, std = compute_relative_action_stats(temp_loader, num_batches=1000)
relative_normalizer = PerTimestepNormalizer(mean, std)
if is_main_process:
cfg.output_dir.mkdir(parents=True, exist_ok=True)
temp_loader = torch.utils.data.DataLoader(
dataset, batch_size=cfg.batch_size, shuffle=True, num_workers=0
)
mean, std = compute_relative_action_stats(temp_loader, num_batches=1000)
relative_normalizer = PerTimestepNormalizer(mean, std)
relative_normalizer.save(stats_path)
logging.info(f"Saved stats to: {stats_path}")
# Barrier: wait for main process to finish computing and saving stats
accelerator.wait_for_everyone()
if not is_main_process:
relative_normalizer = PerTimestepNormalizer.load(stats_path)
relative_normalizer.save(cfg.output_dir / "relative_stats.pt")
logging.info(f"Saved stats to: {cfg.output_dir / 'relative_stats.pt'}")
step = 0 # number of policy updates (forward + backward + optim)