From 936187cd7669d316558d9ea25bb0bdc158c88f78 Mon Sep 17 00:00:00 2001 From: pepijn Date: Wed, 1 Apr 2026 20:45:35 +0000 Subject: [PATCH] fix relative actions: convert before normalization, use global stats The previous implementation had a double-normalization bug: the preprocessor normalized actions with absolute stats, then convert_to_relative subtracted normalized state (wrong), then the per-timestep normalizer re-normalized. Now the correct flow is: 1. Convert batch to relative on raw data (before preprocessing) 2. Compute global relative stats (mean/std across all timesteps) 3. Hotswap the preprocessor normalizer to use relative stats 4. Preprocessor normalizes relative values correctly This brings loss from ~3000+ down to ~0.5, matching the main branch. Made-with: Cursor --- src/lerobot/scripts/lerobot_train.py | 75 ++++++++++++++++----------- src/lerobot/utils/relative_actions.py | 44 +++++++++++++++- 2 files changed, 88 insertions(+), 31 deletions(-) diff --git a/src/lerobot/scripts/lerobot_train.py b/src/lerobot/scripts/lerobot_train.py index a9c0ddb54..355808e75 100644 --- a/src/lerobot/scripts/lerobot_train.py +++ b/src/lerobot/scripts/lerobot_train.py @@ -49,8 +49,7 @@ from lerobot.utils.train_utils import ( ) from lerobot.utils.relative_actions import ( convert_to_relative_actions, - compute_relative_action_stats, - PerTimestepNormalizer, + compute_global_relative_stats, ) from lerobot.utils.utils import ( format_big_number, @@ -304,42 +303,60 @@ def train(cfg: TrainPipelineConfig, accelerator: Accelerator | None = None): device=device, ) - # Compute per-timestep normalizer for relative actions - relative_normalizer = None + # Compute relative action/state stats and hotswap them into the normalizer + raw_state_key = None if cfg.use_relative_actions: + from lerobot.processor.normalize_processor import hotswap_stats + mode = "actions + state" if cfg.use_relative_state else "actions only" cfg.output_dir.mkdir(parents=True, exist_ok=True) stats_path = cfg.output_dir / "relative_stats.pt" - + + reverse_rename = {v: k for k, v in cfg.rename_map.items()} if cfg.rename_map else {} + raw_state_key = reverse_rename.get("observation.state", "observation.state") + if is_main_process: logging.info(colored(f"Relative mode: {mode}", "cyan", attrs=["bold"])) - + if stats_path.exists(): - logging.info(f"Loading pre-computed stats from: {stats_path}") + logging.info(f"Loading pre-computed relative stats from: {stats_path}") else: - logging.info("Computing per-timestep stats (first 1000 batches)...") - logging.info("Using fresh dataset to avoid video decoder state issues...") - # Create separate dataset instance to avoid corrupting main dataset's video decoders + logging.info("Computing global relative stats (first 1000 batches)...") stats_dataset = make_dataset(cfg) temp_loader = torch.utils.data.DataLoader( stats_dataset, batch_size=cfg.batch_size, shuffle=False, num_workers=0 ) - reverse_rename = {v: k for k, v in cfg.rename_map.items()} if cfg.rename_map else {} - raw_state_key = reverse_rename.get("observation.state", "observation.state") - mean, std = compute_relative_action_stats(temp_loader, state_key=raw_state_key, num_batches=1000) + rel_stats = compute_global_relative_stats( + temp_loader, state_key=raw_state_key, + convert_state=cfg.use_relative_state, num_batches=1000, + ) del temp_loader, stats_dataset gc.collect() - torch.save({"mean": mean, "std": std}, stats_path) - logging.info(f"Saved stats to: {stats_path}") - - # Poll for stats file instead of using NCCL barrier (avoids timeout during long computation) + torch.save(rel_stats, stats_path) + logging.info(f"Saved relative stats to: {stats_path}") + if not is_main_process: while not stats_path.exists(): time.sleep(5) - - data = torch.load(stats_path, weights_only=True, map_location="cpu") - relative_normalizer = PerTimestepNormalizer(data["mean"], data["std"]) - accelerator.wait_for_everyone() # Sync after everyone has loaded + + rel_stats = torch.load(stats_path, weights_only=True, map_location="cpu") + + # Replace absolute stats with relative stats in the normalizer + updated_stats = dict(dataset.meta.stats) + updated_stats["action"] = { + **updated_stats["action"], + "mean": rel_stats["action_mean"].numpy(), + "std": rel_stats["action_std"].numpy(), + } + if cfg.use_relative_state and "state_mean" in rel_stats: + updated_stats[raw_state_key] = { + **updated_stats.get(raw_state_key, {}), + "mean": rel_stats["state_mean"].numpy(), + "std": rel_stats["state_std"].numpy(), + } + preprocessor = hotswap_stats(preprocessor, updated_stats) + logging.info("Hotswapped normalizer stats with relative stats") + accelerator.wait_for_everyone() step = 0 # number of policy updates (forward + backward + optim) @@ -427,13 +444,14 @@ def train(cfg: TrainPipelineConfig, accelerator: Accelerator | None = None): for _ in range(step, cfg.steps): start_time = time.perf_counter() batch = next(dl_iter) - batch = preprocessor(batch) - - # Convert to relative actions (and optionally state) if enabled + + # Convert to relative on raw data BEFORE normalization if cfg.use_relative_actions: - batch = convert_to_relative_actions(batch, convert_state=cfg.use_relative_state) - if relative_normalizer is not None: - batch["action"] = relative_normalizer.normalize(batch["action"]) + batch = convert_to_relative_actions( + batch, state_key=raw_state_key, convert_state=cfg.use_relative_state, + ) + + batch = preprocessor(batch) train_tracker.dataloading_s = time.perf_counter() - start_time @@ -489,9 +507,6 @@ def train(cfg: TrainPipelineConfig, accelerator: Accelerator | None = None): preprocessor=preprocessor, postprocessor=postprocessor, ) - # Save relative action stats with checkpoint - if relative_normalizer is not None: - relative_normalizer.save(checkpoint_dir / "relative_stats.pt") update_last_checkpoint(checkpoint_dir) if wandb_logger: wandb_logger.log_policy(checkpoint_dir) diff --git a/src/lerobot/utils/relative_actions.py b/src/lerobot/utils/relative_actions.py index 888cb63a0..c9176812c 100644 --- a/src/lerobot/utils/relative_actions.py +++ b/src/lerobot/utils/relative_actions.py @@ -74,11 +74,53 @@ def compute_relative_action_stats( rel = action.clone() rel[..., :min_dim] -= current_pos[:, None, :min_dim] all_rel.append(rel) - + all_rel = torch.cat(all_rel, dim=0) return all_rel.mean(dim=0), all_rel.std(dim=0).clamp(min=1e-6) +def compute_global_relative_stats( + dataloader, + state_key: str = "observation.state", + convert_state: bool = True, + num_batches: int | None = None, +) -> dict[str, torch.Tensor]: + """Compute global mean/std for relative actions (and state) across all timesteps. + + Returns stats compatible with the standard MEAN_STD normalizer (shape = action_dim). + """ + all_rel_actions = [] + all_rel_states = [] + for i, batch in enumerate(dataloader): + if num_batches is not None and i >= num_batches: + break + action, state = batch["action"], batch[state_key] + current_pos = state[:, -1, :] if state.dim() == 3 else state + + min_dim = min(action.shape[-1], current_pos.shape[-1]) + rel = action.clone() + rel[..., :min_dim] -= current_pos[:, None, :min_dim] + all_rel_actions.append(rel.reshape(-1, rel.shape[-1])) + + if convert_state: + if state.dim() == 3: + rel_state = state - current_pos[:, None, :] + else: + rel_state = torch.zeros_like(state) + all_rel_states.append(rel_state.reshape(-1, rel_state.shape[-1])) + + all_rel_actions = torch.cat(all_rel_actions, dim=0) + result = { + "action_mean": all_rel_actions.mean(dim=0), + "action_std": all_rel_actions.std(dim=0).clamp(min=1e-6), + } + if convert_state and all_rel_states: + all_rel_states = torch.cat(all_rel_states, dim=0) + result["state_mean"] = all_rel_states.mean(dim=0) + result["state_std"] = all_rel_states.std(dim=0).clamp(min=1e-6) + return result + + def convert_to_relative( batch: dict, state_key: str = "observation.state",