From 9c981300dd263def413f829e414881634ae44e25 Mon Sep 17 00:00:00 2001 From: Pepijn Date: Sat, 21 Feb 2026 08:37:19 +0100 Subject: [PATCH] stats per chunck --- .../policies/pi0_fast/modeling_pi0_fast.py | 2 +- src/lerobot/scripts/lerobot_train.py | 43 +++++++++++-------- 2 files changed, 26 insertions(+), 19 deletions(-) diff --git a/src/lerobot/policies/pi0_fast/modeling_pi0_fast.py b/src/lerobot/policies/pi0_fast/modeling_pi0_fast.py index d1e8515b0..fd8b80359 100644 --- a/src/lerobot/policies/pi0_fast/modeling_pi0_fast.py +++ b/src/lerobot/policies/pi0_fast/modeling_pi0_fast.py @@ -1321,7 +1321,7 @@ class PI0FastPolicy(PreTrainedPolicy): state = pad_vector(batch[OBS_STATE], self.config.max_state_dim) continuous_actions = to_absolute_actions( continuous_actions, state, [True] * continuous_actions.shape[-1] - ) + ) return continuous_actions diff --git a/src/lerobot/scripts/lerobot_train.py b/src/lerobot/scripts/lerobot_train.py index ce6b262d1..e5fb11b92 100644 --- a/src/lerobot/scripts/lerobot_train.py +++ b/src/lerobot/scripts/lerobot_train.py @@ -244,34 +244,41 @@ def train(cfg: TrainPipelineConfig, accelerator: Accelerator | None = None): policy = policy.wrap_with_peft(peft_cli_overrides=peft_cli_overrides) # Recompute action stats as delta if use_delta_actions is enabled. - # Must iterate the actual dataset (which returns action chunks via delta_timestamps) - # so stats capture the full range of chunk-level deltas, not just per-frame deltas. - # We sample a subset for speed — 1M frames is sufficient for accurate stats. + # Must build action CHUNKS (like the model sees) and subtract state from each chunk. + # hf_dataset stores per-frame data; we manually assemble chunks to match delta_timestamps. if getattr(cfg.policy, "use_delta_actions", False) and is_main_process: import numpy as np from lerobot.datasets.compute_stats import get_feature_stats from lerobot.processor.delta_action_processor import to_delta_actions - max_samples = min(1000000, len(dataset)) - indices = np.random.choice(len(dataset), max_samples, replace=False).tolist() + chunk_size = cfg.policy.chunk_size + hf = dataset.hf_dataset + total_frames = len(hf) + max_samples = min(100_000, total_frames - chunk_size) + indices = np.random.choice(total_frames - chunk_size, max_samples, replace=False) logging.info( - f"use_delta_actions is enabled — computing delta action stats from {max_samples} dataset chunks" + f"use_delta_actions is enabled — computing delta action stats " + f"from {max_samples} chunk samples (chunk_size={chunk_size})" ) - # Read only action and state from parquet (no video decoding) - hf = dataset.hf_dataset - actions_raw = hf.select(indices)["action"] - states_raw = hf.select(indices)["observation.state"] - + # Build chunks: for each index i, read actions[i:i+chunk_size] and state[i] all_delta_actions = [] - for action, state in zip(actions_raw, states_raw): - action = torch.as_tensor(action).float() - state = torch.as_tensor(state).float() - if action.ndim == 1: - action = action.unsqueeze(0) - mask = [True] * action.shape[-1] - delta = to_delta_actions(action.unsqueeze(0), state.unsqueeze(0), mask).squeeze(0) + episode_indices = np.array(hf["episode_index"]) + for idx in indices: + idx = int(idx) + # Ensure chunk doesn't cross episode boundary + ep_idx = episode_indices[idx] + end_idx = min(idx + chunk_size, total_frames) + if end_idx > idx and episode_indices[end_idx - 1] != ep_idx: + continue + + chunk_data = hf[idx:end_idx] + actions = torch.as_tensor(np.stack(chunk_data["action"])).float() + state = torch.as_tensor(np.array(chunk_data["observation.state"][0])).float() + + mask = [True] * actions.shape[-1] + delta = to_delta_actions(actions.unsqueeze(0), state.unsqueeze(0), mask).squeeze(0) all_delta_actions.append(delta.numpy()) all_delta = np.concatenate(all_delta_actions, axis=0)