add recomputation of stats and option to compute delta stats

This commit is contained in:
Pepijn
2026-02-20 17:59:06 +01:00
parent 4fa41ba806
commit 7e6b598a51
7 changed files with 153 additions and 6 deletions
+105 -1
View File
@@ -37,7 +37,7 @@ import torch
from tqdm import tqdm
from lerobot.datasets.aggregate import aggregate_datasets
from lerobot.datasets.compute_stats import aggregate_stats
from lerobot.datasets.compute_stats import aggregate_stats, compute_episode_stats, get_feature_stats
from lerobot.datasets.lerobot_dataset import LeRobotDataset, LeRobotDatasetMetadata
from lerobot.datasets.utils import (
DATA_DIR,
@@ -1522,6 +1522,110 @@ def modify_tasks(
return dataset
def recompute_stats(
dataset: LeRobotDataset,
skip_image_video: bool = True,
delta_action: bool = False,
delta_exclude_joints: list[str] | None = None,
) -> LeRobotDataset:
"""Recompute stats.json from scratch by iterating all episodes.
Args:
dataset: The LeRobotDataset to recompute stats for.
skip_image_video: If True (default), only recompute stats for numeric features
(action, state, etc.) and keep existing image/video stats unchanged.
delta_action: If True, compute action stats as delta (action - state).
Useful when training with use_delta_actions=True so normalization matches.
delta_exclude_joints: Joint names to exclude from delta conversion when
delta_action=True. These dims keep absolute stats. Uses dataset's
action feature names to build the mask. Default: ["gripper"].
Returns:
The same dataset with updated stats.
"""
features = dataset.meta.features
numeric_features = {
k: v for k, v in features.items()
if v["dtype"] not in ["image", "video", "string"]
and k not in ["index", "episode_index", "task_index", "frame_index", "timestamp"]
}
if skip_image_video:
features_to_compute = numeric_features
else:
features_to_compute = {
k: v for k, v in features.items()
if v["dtype"] != "string"
and k not in ["index", "episode_index", "task_index", "frame_index", "timestamp"]
}
logging.info(f"Recomputing stats for features: {list(features_to_compute.keys())}")
# Build delta mask if delta_action is enabled
delta_mask = None
if delta_action and "action" in features and "observation.state" in features:
if delta_exclude_joints is None:
delta_exclude_joints = ["gripper"]
action_names = features["action"].get("names")
if action_names is not None:
exclude = set(delta_exclude_joints)
delta_mask = [n not in exclude for n in action_names]
else:
action_dim = features["action"]["shape"][0]
delta_mask = [True] * action_dim
logging.info(f"Delta action stats enabled (exclude: {delta_exclude_joints})")
data_dir = dataset.root / DATA_DIR
parquet_files = sorted(data_dir.glob("*/*.parquet"))
if not parquet_files:
raise ValueError(f"No parquet files found in {data_dir}")
all_episode_stats = []
numeric_keys = [k for k, v in features_to_compute.items() if v["dtype"] not in ["image", "video"]]
for parquet_path in tqdm(parquet_files, desc="Computing stats from data files"):
df = pd.read_parquet(parquet_path)
for ep_idx in sorted(df["episode_index"].unique()):
ep_df = df[df["episode_index"] == ep_idx]
episode_data = {}
for key in numeric_keys:
if key in ep_df.columns:
values = ep_df[key].values
if hasattr(values[0], "__len__"):
episode_data[key] = np.stack(values)
else:
episode_data[key] = np.array(values)
# Apply delta conversion to actions before computing stats
if delta_mask is not None and "action" in episode_data and "observation.state" in episode_data:
from lerobot.processor.delta_action_processor import to_delta_actions
actions_t = torch.from_numpy(episode_data["action"]).float()
states_t = torch.from_numpy(episode_data["observation.state"]).float()
episode_data["action"] = to_delta_actions(actions_t, states_t, delta_mask).numpy()
ep_stats = compute_episode_stats(episode_data, features_to_compute)
all_episode_stats.append(ep_stats)
if not all_episode_stats:
logging.warning("No episode stats computed")
return dataset
new_stats = aggregate_stats(all_episode_stats)
if skip_image_video and dataset.meta.stats:
for key, value in dataset.meta.stats.items():
if key not in new_stats:
new_stats[key] = value
write_stats(new_stats, dataset.root)
dataset.meta.stats = new_stats
logging.info(f"Stats recomputed for {len(all_episode_stats)} episodes")
return dataset
def convert_image_to_video_dataset(
dataset: LeRobotDataset,
output_dir: Path,
+7
View File
@@ -470,6 +470,13 @@ def make_policy(
cfg.output_features = {key: ft for key, ft in features.items() if ft.type is FeatureType.ACTION}
if not cfg.input_features:
cfg.input_features = {key: ft for key, ft in features.items() if key not in cfg.output_features}
# Store action feature names on config for delta_exclude_joints support
if ds_meta is not None and hasattr(cfg, "action_feature_names"):
action_names = ds_meta.features.get(ACTION, {}).get("names")
if action_names is not None:
cfg.action_feature_names = list(action_names)
kwargs["config"] = cfg
# Pass dataset_stats to the policy if available (needed for some policies like SARM)
@@ -50,8 +50,12 @@ class PI0Config(PreTrainedConfig):
min_period: float = 4e-3
max_period: float = 4.0
# Delta actions: converts absolute actions to delta (relative to state)
# Delta actions: converts absolute actions to delta (relative to state).
use_delta_actions: bool = False
# Joint names to exclude from delta conversion (kept as absolute).
delta_exclude_joints: list[str] = field(default_factory=lambda: ["gripper"])
# Populated at runtime by make_policy from dataset metadata.
action_feature_names: list[str] | None = None
# Real-Time Chunking (RTC) configuration
rtc_config: RTCConfig | None = None
+14 -2
View File
@@ -1221,6 +1221,18 @@ class PI0Policy(PreTrainedPolicy):
state = pad_vector(batch[OBS_STATE], self.config.max_state_dim)
return state
def _build_delta_mask(self, action_dim: int) -> list[bool]:
"""Build a boolean mask for delta action conversion.
Uses action_feature_names and delta_exclude_joints to determine which
dims get delta conversion. Falls back to all-True if names are unavailable.
"""
names = self.config.action_feature_names
if names is None:
return [True] * action_dim
exclude = set(self.config.delta_exclude_joints)
return [n not in exclude for n in names]
def prepare_action(self, batch):
"""Pad action"""
actions = pad_vector(batch[ACTION], self.config.max_action_dim)
@@ -1261,7 +1273,7 @@ class PI0Policy(PreTrainedPolicy):
actions = actions[:, :, :original_action_dim]
if self.config.use_delta_actions:
actions = to_absolute_actions(actions, state, [True] * actions.shape[-1])
actions = to_absolute_actions(actions, state, self._build_delta_mask(actions.shape[-1]))
return actions
@@ -1281,7 +1293,7 @@ class PI0Policy(PreTrainedPolicy):
actions = self.prepare_action(batch)
if self.config.use_delta_actions:
actions = to_delta_actions(actions, state, [True] * actions.shape[-1])
actions = to_delta_actions(actions, state, self._build_delta_mask(actions.shape[-1]))
# Compute loss
losses = self.model.forward(images, img_masks, lang_tokens, lang_masks, state, actions)
@@ -52,6 +52,10 @@ class PI05Config(PreTrainedConfig):
# Delta actions: converts absolute actions to delta (relative to state).
use_delta_actions: bool = False
# Joint names to exclude from delta conversion (kept as absolute).
delta_exclude_joints: list[str] = field(default_factory=lambda: ["gripper"])
# Populated at runtime by make_policy from dataset metadata.
action_feature_names: list[str] | None = None
# Real-Time Chunking (RTC) configuration
rtc_config: RTCConfig | None = None
+10 -2
View File
@@ -1201,6 +1201,14 @@ class PI05Policy(PreTrainedPolicy):
actions = pad_vector(batch[ACTION], self.config.max_action_dim)
return actions
def _build_delta_mask(self, action_dim: int) -> list[bool]:
"""Build a boolean mask for delta action conversion."""
names = self.config.action_feature_names
if names is None:
return [True] * action_dim
exclude = set(self.config.delta_exclude_joints)
return [n not in exclude for n in names]
@torch.no_grad()
def select_action(self, batch: dict[str, Tensor]) -> Tensor:
"""Select a single action given environment observations."""
@@ -1236,7 +1244,7 @@ class PI05Policy(PreTrainedPolicy):
if self.config.use_delta_actions:
state = pad_vector(batch[OBS_STATE], self.config.max_state_dim)
actions = to_absolute_actions(actions, state, [True] * actions.shape[-1])
actions = to_absolute_actions(actions, state, self._build_delta_mask(actions.shape[-1]))
return actions
@@ -1257,7 +1265,7 @@ class PI05Policy(PreTrainedPolicy):
if self.config.use_delta_actions:
state = pad_vector(batch[OBS_STATE], self.config.max_state_dim)
actions = to_delta_actions(actions, state, [True] * actions.shape[-1])
actions = to_delta_actions(actions, state, self._build_delta_mask(actions.shape[-1]))
# Compute loss (no separate state needed for PI05)
losses = self.model.forward(images, img_masks, tokens, masks, actions)
+8
View File
@@ -243,6 +243,14 @@ def train(cfg: TrainPipelineConfig, accelerator: Accelerator | None = None):
peft_cli_overrides = dataclasses.asdict(cfg.peft)
policy = policy.wrap_with_peft(peft_cli_overrides=peft_cli_overrides)
# Recompute action stats as delta if use_delta_actions is enabled
if getattr(cfg.policy, "use_delta_actions", False) and is_main_process:
logging.info("use_delta_actions is enabled — recomputing action stats as delta (action - state)")
from lerobot.datasets.dataset_tools import recompute_stats
exclude = getattr(cfg.policy, "delta_exclude_joints", ["gripper"])
recompute_stats(dataset, skip_image_video=True, delta_action=True, delta_exclude_joints=exclude)
# Wait for all processes to finish policy creation before continuing
accelerator.wait_for_everyone()