stats per chunck

This commit is contained in:
Pepijn
2026-02-21 08:37:19 +01:00
parent 33cedc2f71
commit 9c981300dd
2 changed files with 26 additions and 19 deletions
@@ -1321,7 +1321,7 @@ class PI0FastPolicy(PreTrainedPolicy):
state = pad_vector(batch[OBS_STATE], self.config.max_state_dim)
continuous_actions = to_absolute_actions(
continuous_actions, state, [True] * continuous_actions.shape[-1]
)
)
return continuous_actions
+25 -18
View File
@@ -244,34 +244,41 @@ def train(cfg: TrainPipelineConfig, accelerator: Accelerator | None = None):
policy = policy.wrap_with_peft(peft_cli_overrides=peft_cli_overrides)
# Recompute action stats as delta if use_delta_actions is enabled.
# Must iterate the actual dataset (which returns action chunks via delta_timestamps)
# so stats capture the full range of chunk-level deltas, not just per-frame deltas.
# We sample a subset for speed — 1M frames is sufficient for accurate stats.
# Must build action CHUNKS (like the model sees) and subtract state from each chunk.
# hf_dataset stores per-frame data; we manually assemble chunks to match delta_timestamps.
if getattr(cfg.policy, "use_delta_actions", False) and is_main_process:
import numpy as np
from lerobot.datasets.compute_stats import get_feature_stats
from lerobot.processor.delta_action_processor import to_delta_actions
max_samples = min(1000000, len(dataset))
indices = np.random.choice(len(dataset), max_samples, replace=False).tolist()
chunk_size = cfg.policy.chunk_size
hf = dataset.hf_dataset
total_frames = len(hf)
max_samples = min(100_000, total_frames - chunk_size)
indices = np.random.choice(total_frames - chunk_size, max_samples, replace=False)
logging.info(
f"use_delta_actions is enabled — computing delta action stats from {max_samples} dataset chunks"
f"use_delta_actions is enabled — computing delta action stats "
f"from {max_samples} chunk samples (chunk_size={chunk_size})"
)
# Read only action and state from parquet (no video decoding)
hf = dataset.hf_dataset
actions_raw = hf.select(indices)["action"]
states_raw = hf.select(indices)["observation.state"]
# Build chunks: for each index i, read actions[i:i+chunk_size] and state[i]
all_delta_actions = []
for action, state in zip(actions_raw, states_raw):
action = torch.as_tensor(action).float()
state = torch.as_tensor(state).float()
if action.ndim == 1:
action = action.unsqueeze(0)
mask = [True] * action.shape[-1]
delta = to_delta_actions(action.unsqueeze(0), state.unsqueeze(0), mask).squeeze(0)
episode_indices = np.array(hf["episode_index"])
for idx in indices:
idx = int(idx)
# Ensure chunk doesn't cross episode boundary
ep_idx = episode_indices[idx]
end_idx = min(idx + chunk_size, total_frames)
if end_idx > idx and episode_indices[end_idx - 1] != ep_idx:
continue
chunk_data = hf[idx:end_idx]
actions = torch.as_tensor(np.stack(chunk_data["action"])).float()
state = torch.as_tensor(np.array(chunk_data["observation.state"][0])).float()
mask = [True] * actions.shape[-1]
delta = to_delta_actions(actions.unsqueeze(0), state.unsqueeze(0), mask).squeeze(0)
all_delta_actions.append(delta.numpy())
all_delta = np.concatenate(all_delta_actions, axis=0)