diff --git a/src/lerobot/scripts/lerobot_train.py b/src/lerobot/scripts/lerobot_train.py index dc84e4244..53274ed49 100644 --- a/src/lerobot/scripts/lerobot_train.py +++ b/src/lerobot/scripts/lerobot_train.py @@ -243,13 +243,35 @@ def train(cfg: TrainPipelineConfig, accelerator: Accelerator | None = None): peft_cli_overrides = dataclasses.asdict(cfg.peft) policy = policy.wrap_with_peft(peft_cli_overrides=peft_cli_overrides) - # Recompute action stats as delta if use_delta_actions is enabled + # Recompute action stats as delta if use_delta_actions is enabled. + # Must iterate the actual dataset (which returns action chunks via delta_timestamps) + # so stats capture the full range of chunk-level deltas, not just per-frame deltas. if getattr(cfg.policy, "use_delta_actions", False) and is_main_process: - logging.info("use_delta_actions is enabled — recomputing action stats as delta (action - state)") - from lerobot.datasets.dataset_tools import recompute_stats + logging.info("use_delta_actions is enabled — computing delta action stats from dataset chunks") + from lerobot.datasets.compute_stats import get_feature_stats + from lerobot.processor.delta_action_processor import to_delta_actions - exclude = getattr(cfg.policy, "delta_exclude_joints", []) - recompute_stats(dataset, skip_image_video=True, delta_action=True, delta_exclude_joints=exclude) + all_delta_actions = [] + for i in range(len(dataset)): + item = dataset[i] + action = item["action"] + state = item["observation.state"] + # action may be (chunk_size, action_dim) or (action_dim,) + if action.ndim == 1: + action = action.unsqueeze(0) + mask = [True] * action.shape[-1] + delta = to_delta_actions(action.unsqueeze(0), state.unsqueeze(0), mask).squeeze(0) + all_delta_actions.append(delta.numpy()) + + import numpy as np + + all_delta = np.concatenate(all_delta_actions, axis=0) + delta_stats = get_feature_stats(all_delta, axis=0, keepdims=all_delta.ndim == 1) + dataset.meta.stats["action"] = delta_stats + logging.info( + f"Delta action stats computed from {len(dataset)} samples: " + f"mean={np.abs(delta_stats['mean']).mean():.4f}, std={delta_stats['std'].mean():.4f}" + ) # Wait for all processes to finish policy creation before continuing accelerator.wait_for_everyone()