mirror of
https://github.com/huggingface/lerobot.git
synced 2026-05-22 12:09:42 +00:00
use separate process for stats computation
This commit is contained in:
@@ -305,27 +305,39 @@ def train(cfg: TrainPipelineConfig, accelerator: Accelerator | None = None):
|
|||||||
)
|
)
|
||||||
|
|
||||||
# Compute per-timestep normalizer for relative actions
|
# Compute per-timestep normalizer for relative actions
|
||||||
# Compute BEFORE accelerator.prepare() to avoid NCCL timeouts during long computation
|
|
||||||
relative_normalizer = None
|
relative_normalizer = None
|
||||||
if cfg.use_relative_actions:
|
if cfg.use_relative_actions:
|
||||||
mode = "actions + state" if cfg.use_relative_state else "actions only"
|
mode = "actions + state" if cfg.use_relative_state else "actions only"
|
||||||
|
cfg.output_dir.mkdir(parents=True, exist_ok=True)
|
||||||
|
stats_path = cfg.output_dir / "relative_stats.pt"
|
||||||
|
|
||||||
if is_main_process:
|
if is_main_process:
|
||||||
logging.info(colored(f"Relative mode: {mode}", "cyan", attrs=["bold"]))
|
logging.info(colored(f"Relative mode: {mode}", "cyan", attrs=["bold"]))
|
||||||
logging.info("Computing per-timestep stats from dataset (first 1000 batches)...")
|
|
||||||
|
if stats_path.exists():
|
||||||
|
logging.info(f"Loading pre-computed stats from: {stats_path}")
|
||||||
|
else:
|
||||||
|
logging.info("Computing per-timestep stats (first 1000 batches)...")
|
||||||
|
logging.info("Using fresh dataset to avoid video decoder state issues...")
|
||||||
|
# Create separate dataset instance to avoid corrupting main dataset's video decoders
|
||||||
|
stats_dataset = make_dataset(cfg)
|
||||||
|
temp_loader = torch.utils.data.DataLoader(
|
||||||
|
stats_dataset, batch_size=cfg.batch_size, shuffle=False, num_workers=0
|
||||||
|
)
|
||||||
|
mean, std = compute_relative_action_stats(temp_loader, num_batches=1000)
|
||||||
|
del temp_loader, stats_dataset
|
||||||
|
gc.collect()
|
||||||
|
torch.save({"mean": mean, "std": std}, stats_path)
|
||||||
|
logging.info(f"Saved stats to: {stats_path}")
|
||||||
|
|
||||||
# All ranks compute independently to avoid NCCL timeout (computation takes ~10min)
|
# Poll for stats file instead of using NCCL barrier (avoids timeout during long computation)
|
||||||
temp_loader = torch.utils.data.DataLoader(
|
if not is_main_process:
|
||||||
dataset, batch_size=cfg.batch_size, shuffle=False, num_workers=0
|
while not stats_path.exists():
|
||||||
)
|
time.sleep(5)
|
||||||
mean, std = compute_relative_action_stats(temp_loader, num_batches=1000)
|
|
||||||
del temp_loader
|
|
||||||
gc.collect() # Clean up any video decoder resources
|
|
||||||
relative_normalizer = PerTimestepNormalizer(mean, std)
|
|
||||||
|
|
||||||
if is_main_process:
|
data = torch.load(stats_path, weights_only=True, map_location="cpu")
|
||||||
cfg.output_dir.mkdir(parents=True, exist_ok=True)
|
relative_normalizer = PerTimestepNormalizer(data["mean"], data["std"])
|
||||||
relative_normalizer.save(cfg.output_dir / "relative_stats.pt")
|
accelerator.wait_for_everyone() # Sync after everyone has loaded
|
||||||
logging.info(f"Saved stats to: {cfg.output_dir / 'relative_stats.pt'}")
|
|
||||||
|
|
||||||
step = 0 # number of policy updates (forward + backward + optim)
|
step = 0 # number of policy updates (forward + backward + optim)
|
||||||
|
|
||||||
|
|||||||
Reference in New Issue
Block a user