Compare commits

...

39 Commits

Author SHA1 Message Date
Pepijn 0394fae446 Merge branch 'feat/add_relative_action_pi_models' into feat/mirror 2026-02-22 16:12:10 +01:00
Pepijn 602b8e66a6 fix multi gpu processor bug 2026-02-22 16:11:52 +01:00
Pepijn ab4dce6fed revert 2026-02-21 18:48:46 +01:00
Pepijn 40f4386e4a nccl 2026-02-21 18:44:35 +01:00
Pepijn 87a91b4b08 Merge branch 'feat/add_relative_action_pi_models' into feat/mirror 2026-02-21 18:19:51 +01:00
Pepijn fadb900c36 compute before dist 2026-02-21 18:19:12 +01:00
Pepijn de0663226a max 1m frames 2026-02-21 17:44:12 +01:00
Pepijn 0ca9d66cae max 1m frames 2026-02-21 17:43:58 +01:00
Pepijn 2222f25da3 Merge branch 'feat/add_relative_action_pi_models' into feat/mirror 2026-02-21 17:28:35 +01:00
Pepijn acae8417aa fix 2026-02-21 17:28:26 +01:00
Pepijn 2697f65cf6 stats for entire dataset 2026-02-21 17:15:55 +01:00
Pepijn 74f42f218e stats for entire dataset 2026-02-21 17:15:45 +01:00
Pepijn ca9d49e305 Merge branch 'feat/add_relative_action_pi_models' into feat/mirror 2026-02-21 17:12:52 +01:00
Pepijn 6705876d47 use quantiles 2026-02-21 17:12:43 +01:00
Pepijn aadbd27675 Merge branch 'feat/add_relative_action_pi_models' into feat/mirror 2026-02-21 08:48:39 +01:00
Pepijn 5221647b5e fix 2026-02-21 08:48:08 +01:00
Pepijn 9c981300dd stats per chunck 2026-02-21 08:37:38 +01:00
Pepijn f5b27aad1b stats per chunck 2026-02-21 08:37:19 +01:00
Pepijn 75f1285507 Merge branch 'feat/add_relative_action_pi_models' into feat/mirror 2026-02-21 08:02:39 +01:00
Pepijn 33cedc2f71 sample 1m 2026-02-21 08:02:25 +01:00
Pepijn aa32e6c4ab Merge branch 'feat/add_relative_action_pi_models' into feat/mirror 2026-02-21 07:52:10 +01:00
Pepijn f906270ec4 load from parquet 2026-02-21 07:51:57 +01:00
Pepijn 733b6d84db Merge branch 'feat/add_relative_action_pi_models' into feat/mirror 2026-02-21 07:42:00 +01:00
Pepijn 8abc9037a3 sample 100k 2026-02-21 07:41:42 +01:00
Pepijn e4d4ac0bda Merge branch 'feat/add_relative_action_pi_models' into feat/mirror 2026-02-21 00:03:37 +01:00
Pepijn e79b2a439b calulate chunk based stats 2026-02-21 00:03:21 +01:00
Pepijn f9ae78ca74 Merge branch 'feat/add_relative_action_pi_models' into feat/mirror 2026-02-20 23:04:36 +01:00
Pepijn e1ced538e3 only recompute state for stats 2026-02-20 23:04:20 +01:00
Pepijn 2a98602ad6 Merge branch 'feat/add_relative_action_pi_models' into feat/mirror 2026-02-20 22:54:46 +01:00
Pepijn a2f5b3571e normalzie after delta conversion 2026-02-20 22:54:29 +01:00
Pepijn cecf2eff4f Merge branch 'feat/add_relative_action_pi_models' into feat/mirror 2026-02-20 17:59:19 +01:00
Pepijn 7e6b598a51 add recomputation of stats and option to compute delta stats 2026-02-20 17:59:06 +01:00
Pepijn 4fa41ba806 formatting 2026-02-13 17:46:18 +01:00
Pepijn 1de2b87a92 Add option for pi family models to train with relative actions (relative to state) 2026-02-13 17:45:59 +01:00
Pepijn e3c511db67 add push to hub 2026-02-05 09:25:49 +01:00
Pepijn aed4130d39 add swap wrist camera's 2026-02-04 22:47:28 +01:00
Pepijn d26349c692 add push to hub 2026-02-04 19:17:40 +01:00
Pepijn a9bce4732b fix setting metadata 2026-02-04 19:04:43 +01:00
Pepijn 86d69e3c1d add mirroring 2026-02-04 18:56:51 +01:00
15 changed files with 1194 additions and 21 deletions
+117 -1
View File
@@ -37,7 +37,7 @@ import torch
from tqdm import tqdm
from lerobot.datasets.aggregate import aggregate_datasets
from lerobot.datasets.compute_stats import aggregate_stats
from lerobot.datasets.compute_stats import aggregate_stats, compute_episode_stats, get_feature_stats
from lerobot.datasets.lerobot_dataset import LeRobotDataset, LeRobotDatasetMetadata
from lerobot.datasets.utils import (
DATA_DIR,
@@ -1522,6 +1522,122 @@ def modify_tasks(
return dataset
def recompute_stats(
dataset: LeRobotDataset,
skip_image_video: bool = True,
delta_action: bool = False,
delta_exclude_joints: list[str] | None = None,
) -> LeRobotDataset:
"""Recompute stats.json from scratch by iterating all episodes.
Args:
dataset: The LeRobotDataset to recompute stats for.
skip_image_video: If True (default), only recompute stats for numeric features
(action, state, etc.) and keep existing image/video stats unchanged.
delta_action: If True, compute action stats as delta (action - state).
Useful when training with use_delta_actions=True so normalization matches.
delta_exclude_joints: Joint names to exclude from delta conversion when
delta_action=True. These dims keep absolute stats. Uses dataset's
action feature names to build the mask. Default: ["gripper"].
Returns:
The same dataset with updated stats.
"""
features = dataset.meta.features
numeric_features = {
k: v for k, v in features.items()
if v["dtype"] not in ["image", "video", "string"]
and k not in ["index", "episode_index", "task_index", "frame_index", "timestamp"]
}
if skip_image_video:
features_to_compute = numeric_features
else:
features_to_compute = {
k: v for k, v in features.items()
if v["dtype"] != "string"
and k not in ["index", "episode_index", "task_index", "frame_index", "timestamp"]
}
# Build delta mask if delta_action is enabled
delta_mask = None
if delta_action and "action" in features and "observation.state" in features:
if delta_exclude_joints is None:
delta_exclude_joints = ["gripper"]
action_names = features["action"].get("names")
if action_names is not None:
exclude = set(delta_exclude_joints)
delta_mask = [n not in exclude for n in action_names]
else:
action_dim = features["action"]["shape"][0]
delta_mask = [True] * action_dim
# Only recompute action stats when delta is enabled — state stays unchanged
features_to_compute = {"action": features["action"]}
logging.info(f"Recomputing action stats as delta (exclude: {delta_exclude_joints})")
else:
logging.info(f"Recomputing stats for features: {list(features_to_compute.keys())}")
data_dir = dataset.root / DATA_DIR
parquet_files = sorted(data_dir.glob("*/*.parquet"))
if not parquet_files:
raise ValueError(f"No parquet files found in {data_dir}")
all_episode_stats = []
numeric_keys = [k for k, v in features_to_compute.items() if v["dtype"] not in ["image", "video"]]
# Also need state for delta computation even though we don't recompute state stats
needs_state = delta_mask is not None
for parquet_path in tqdm(parquet_files, desc="Computing stats from data files"):
df = pd.read_parquet(parquet_path)
for ep_idx in sorted(df["episode_index"].unique()):
ep_df = df[df["episode_index"] == ep_idx]
episode_data = {}
for key in numeric_keys:
if key in ep_df.columns:
values = ep_df[key].values
if hasattr(values[0], "__len__"):
episode_data[key] = np.stack(values)
else:
episode_data[key] = np.array(values)
# Apply delta conversion to actions before computing stats
if delta_mask is not None and "action" in episode_data:
from lerobot.processor.delta_action_processor import to_delta_actions
# Load state for delta even if we're not computing state stats
if needs_state and "observation.state" in ep_df.columns:
state_values = ep_df["observation.state"].values
if hasattr(state_values[0], "__len__"):
states = np.stack(state_values)
else:
states = np.array(state_values)
actions_t = torch.from_numpy(episode_data["action"]).float()
states_t = torch.from_numpy(states).float()
episode_data["action"] = to_delta_actions(actions_t, states_t, delta_mask).numpy()
ep_stats = compute_episode_stats(episode_data, features_to_compute)
all_episode_stats.append(ep_stats)
if not all_episode_stats:
logging.warning("No episode stats computed")
return dataset
new_stats = aggregate_stats(all_episode_stats)
# Merge: keep existing stats for features we didn't recompute
if dataset.meta.stats:
for key, value in dataset.meta.stats.items():
if key not in new_stats:
new_stats[key] = value
write_stats(new_stats, dataset.root)
dataset.meta.stats = new_stats
logging.info(f"Stats recomputed for {len(all_episode_stats)} episodes")
return dataset
def convert_image_to_video_dataset(
dataset: LeRobotDataset,
output_dir: Path,
+7
View File
@@ -470,6 +470,13 @@ def make_policy(
cfg.output_features = {key: ft for key, ft in features.items() if ft.type is FeatureType.ACTION}
if not cfg.input_features:
cfg.input_features = {key: ft for key, ft in features.items() if key not in cfg.output_features}
# Store action feature names for delta_exclude_joints support
if ds_meta is not None and hasattr(cfg, "action_feature_names"):
action_names = ds_meta.features.get(ACTION, {}).get("names")
if action_names is not None:
cfg.action_feature_names = list(action_names)
kwargs["config"] = cfg
# Pass dataset_stats to the policy if available (needed for some policies like SARM)
@@ -50,6 +50,13 @@ class PI0Config(PreTrainedConfig):
min_period: float = 4e-3
max_period: float = 4.0
# Delta actions: converts absolute actions to delta (relative to state).
use_delta_actions: bool = False
# Joint names to exclude from delta (kept absolute). Empty list = all dims delta.
delta_exclude_joints: list[str] = field(default_factory=lambda: ["gripper"])
# Populated at runtime from dataset metadata by make_policy.
action_feature_names: list[str] | None = None
# Real-Time Chunking (RTC) configuration
rtc_config: RTCConfig | None = None
+11 -1
View File
@@ -21,8 +21,10 @@ import torch
from lerobot.configs.types import PipelineFeatureType, PolicyFeature
from lerobot.policies.pi0.configuration_pi0 import PI0Config
from lerobot.processor import (
AbsoluteActionsProcessorStep,
AddBatchDimensionProcessorStep,
ComplementaryDataProcessorStep,
DeltaActionsProcessorStep,
DeviceProcessorStep,
NormalizerProcessorStep,
PolicyAction,
@@ -126,7 +128,13 @@ def make_pi0_pre_post_processors(
A tuple containing the configured pre-processor and post-processor pipelines.
"""
# Add remaining processors
delta_step = DeltaActionsProcessorStep(
enabled=config.use_delta_actions,
exclude_joints=getattr(config, "delta_exclude_joints", []),
action_names=getattr(config, "action_feature_names", None),
)
# OpenPI order: raw → delta → normalize → model → unnormalize → absolute
input_steps: list[ProcessorStep] = [
RenameObservationsProcessorStep(rename_map={}), # To mimic the same processor as pretrained one
AddBatchDimensionProcessorStep(),
@@ -138,6 +146,7 @@ def make_pi0_pre_post_processors(
padding="max_length",
),
DeviceProcessorStep(device=config.device),
delta_step,
NormalizerProcessorStep(
features={**config.input_features, **config.output_features},
norm_map=config.normalization_mapping,
@@ -149,6 +158,7 @@ def make_pi0_pre_post_processors(
UnnormalizerProcessorStep(
features=config.output_features, norm_map=config.normalization_mapping, stats=dataset_stats
),
AbsoluteActionsProcessorStep(enabled=config.use_delta_actions, delta_step=delta_step),
DeviceProcessorStep(device="cpu"),
]
@@ -50,6 +50,13 @@ class PI05Config(PreTrainedConfig):
min_period: float = 4e-3
max_period: float = 4.0
# Delta actions: converts absolute actions to delta (relative to state).
use_delta_actions: bool = False
# Joint names to exclude from delta (kept absolute). Empty list = all dims delta.
delta_exclude_joints: list[str] = field(default_factory=lambda: ["gripper"])
# Populated at runtime from dataset metadata by make_policy.
action_feature_names: list[str] | None = None
# Real-Time Chunking (RTC) configuration
rtc_config: RTCConfig | None = None
+13 -1
View File
@@ -25,7 +25,9 @@ from lerobot.configs.types import PipelineFeatureType, PolicyFeature
from lerobot.policies.pi05.configuration_pi05 import PI05Config
from lerobot.policies.pi05.modeling_pi05 import pad_vector
from lerobot.processor import (
AbsoluteActionsProcessorStep,
AddBatchDimensionProcessorStep,
DeltaActionsProcessorStep,
DeviceProcessorStep,
NormalizerProcessorStep,
PolicyAction,
@@ -129,10 +131,19 @@ def make_pi05_pre_post_processors(
A tuple containing the configured pre-processor and post-processor pipelines.
"""
# Add remaining processors
delta_step = DeltaActionsProcessorStep(
enabled=config.use_delta_actions,
exclude_joints=getattr(config, "delta_exclude_joints", []),
action_names=getattr(config, "action_feature_names", None),
)
# OpenPI order: raw → delta → normalize → model → unnormalize → absolute
# NOTE: NormalizerProcessorStep MUST come before Pi05PrepareStateTokenizerProcessorStep
# because the tokenizer step expects normalized state in [-1, 1] range for discretization
input_steps: list[ProcessorStep] = [
RenameObservationsProcessorStep(rename_map={}), # To mimic the same processor as pretrained one
AddBatchDimensionProcessorStep(),
delta_step,
# NOTE: NormalizerProcessorStep MUST come before Pi05PrepareStateTokenizerProcessorStep
# because the tokenizer step expects normalized state in [-1, 1] range for discretization
NormalizerProcessorStep(
@@ -154,6 +165,7 @@ def make_pi05_pre_post_processors(
UnnormalizerProcessorStep(
features=config.output_features, norm_map=config.normalization_mapping, stats=dataset_stats
),
AbsoluteActionsProcessorStep(enabled=config.use_delta_actions, delta_step=delta_step),
DeviceProcessorStep(device="cpu"),
]
@@ -41,6 +41,9 @@ class PI0FastConfig(PreTrainedConfig):
max_action_dim: int = 32
max_action_tokens: int = 256
# Delta actions: converts absolute actions to delta (relative to state).
use_delta_actions: bool = False
# Real-Time Chunking (RTC) configuration
rtc_config: RTCConfig | None = None
@@ -48,12 +48,14 @@ from lerobot.configs.policies import PreTrainedConfig
from lerobot.policies.pi0_fast.configuration_pi0_fast import PI0FastConfig
from lerobot.policies.pretrained import PreTrainedPolicy, T
from lerobot.policies.rtc.modeling_rtc import RTCProcessor
from lerobot.processor.delta_action_processor import to_absolute_actions
from lerobot.utils.constants import (
ACTION,
ACTION_TOKEN_MASK,
ACTION_TOKENS,
OBS_LANGUAGE_ATTENTION_MASK,
OBS_LANGUAGE_TOKENS,
OBS_STATE,
OPENPI_ATTENTION_MASK_VALUE,
)
@@ -1315,6 +1317,12 @@ class PI0FastPolicy(PreTrainedPolicy):
action_tokens, action_horizon=action_horizon, action_dim=action_dim
)
if self.config.use_delta_actions and OBS_STATE in batch:
state = pad_vector(batch[OBS_STATE], self.config.max_state_dim)
continuous_actions = to_absolute_actions(
continuous_actions, state, [True] * continuous_actions.shape[-1]
)
return continuous_actions
def forward(self, batch: dict[str, Tensor]) -> tuple[Tensor, dict]:
@@ -27,6 +27,7 @@ from lerobot.policies.pi0_fast.modeling_pi0_fast import pad_vector
from lerobot.processor import (
ActionTokenizerProcessorStep,
AddBatchDimensionProcessorStep,
DeltaActionsProcessorStep,
DeviceProcessorStep,
NormalizerProcessorStep,
PolicyAction,
@@ -147,6 +148,7 @@ def make_pi0_fast_pre_post_processors(
padding_side="right",
padding="max_length",
),
DeltaActionsProcessorStep(enabled=config.use_delta_actions),
ActionTokenizerProcessorStep(
action_tokenizer_name=config.action_tokenizer_name,
max_action_tokens=config.max_action_tokens,
+12 -1
View File
@@ -28,7 +28,14 @@ from .core import (
RobotObservation,
TransitionKey,
)
from .delta_action_processor import MapDeltaActionToRobotActionStep, MapTensorToDeltaActionDictStep
from .delta_action_processor import (
AbsoluteActionsProcessorStep,
DeltaActionsProcessorStep,
MapDeltaActionToRobotActionStep,
MapTensorToDeltaActionDictStep,
to_absolute_actions,
to_delta_actions,
)
from .device_processor import DeviceProcessorStep
from .factory import (
make_default_processors,
@@ -97,6 +104,8 @@ __all__ = [
"make_default_teleop_action_processor",
"make_default_robot_action_processor",
"make_default_robot_observation_processor",
"AbsoluteActionsProcessorStep",
"DeltaActionsProcessorStep",
"MapDeltaActionToRobotActionStep",
"MapTensorToDeltaActionDictStep",
"NormalizerProcessorStep",
@@ -126,6 +135,8 @@ __all__ = [
"transition_to_batch",
"TransitionKey",
"TruncatedProcessorStep",
"to_absolute_actions",
"to_delta_actions",
"UnnormalizerProcessorStep",
"VanillaObservationProcessorStep",
]
+168 -3
View File
@@ -14,12 +14,54 @@
# See the License for the specific language governing permissions and
# limitations under the License.
from dataclasses import dataclass
from collections.abc import Sequence
from dataclasses import dataclass, field
from typing import Any
import torch
from torch import Tensor
from lerobot.configs.types import FeatureType, PipelineFeatureType, PolicyFeature
from lerobot.utils.constants import OBS_STATE
from .core import PolicyAction, RobotAction
from .pipeline import ActionProcessorStep, ProcessorStepRegistry, RobotActionProcessorStep
from .core import EnvTransition, PolicyAction, RobotAction, TransitionKey
from .pipeline import ActionProcessorStep, ProcessorStep, ProcessorStepRegistry, RobotActionProcessorStep
def to_delta_actions(actions: Tensor, state: Tensor, mask: Sequence[bool]) -> Tensor:
"""Convert absolute actions to delta: delta = action - state (for masked dims).
Args:
actions: (B, T, action_dim) or (B, action_dim).
state: (B, state_dim). Broadcast across time dimension.
mask: Which dims to convert. Can be shorter than action_dim.
"""
mask_t = torch.tensor(mask, dtype=actions.dtype, device=actions.device)
dims = mask_t.shape[0]
state_offset = state[..., :dims] * mask_t
if actions.ndim == 3:
state_offset = state_offset.unsqueeze(-2)
actions = actions.clone()
actions[..., :dims] -= state_offset
return actions
def to_absolute_actions(actions: Tensor, state: Tensor, mask: Sequence[bool]) -> Tensor:
"""Convert delta actions back to absolute: absolute = delta + state (for masked dims).
Args:
actions: (B, T, action_dim) or (B, action_dim).
state: (B, state_dim). Broadcast across time dimension.
mask: Which dims to convert. Can be shorter than action_dim.
"""
mask_t = torch.tensor(mask, dtype=actions.dtype, device=actions.device)
dims = mask_t.shape[0]
state_offset = state[..., :dims] * mask_t
if actions.ndim == 3:
state_offset = state_offset.unsqueeze(-2)
actions = actions.clone()
actions[..., :dims] += state_offset
return actions
@ProcessorStepRegistry.register("map_tensor_to_delta_action_dict")
@@ -141,3 +183,126 @@ class MapDeltaActionToRobotActionStep(RobotActionProcessorStep):
)
return features
@ProcessorStepRegistry.register("delta_actions_processor")
@dataclass
class DeltaActionsProcessorStep(ProcessorStep):
"""Converts absolute actions to delta actions (action -= state) for masked dimensions.
Mirrors OpenPI's DeltaActions transform. Applied during preprocessing so the model
trains on relative offsets instead of absolute positions.
Caches the last seen state so a paired AbsoluteActionsProcessorStep can reverse
the conversion during postprocessing.
Attributes:
enabled: Whether to apply the delta conversion.
exclude_joints: Joint names to keep absolute (not converted to delta).
action_names: Action dimension names from dataset metadata, used to build
the mask from exclude_joints. If None, all dims are converted.
"""
enabled: bool = False
exclude_joints: list[str] = field(default_factory=list)
action_names: list[str] | None = None
_last_state: torch.Tensor | None = field(default=None, init=False, repr=False)
def _build_mask(self, action_dim: int) -> list[bool]:
if not self.exclude_joints or self.action_names is None:
return [True] * action_dim
exclude_tokens = [str(name).lower() for name in self.exclude_joints if name]
if not exclude_tokens:
return [True] * action_dim
mask = []
for name in self.action_names[:action_dim]:
action_name = str(name).lower()
is_excluded = any(token == action_name or token in action_name for token in exclude_tokens)
mask.append(not is_excluded)
if len(mask) < action_dim:
mask.extend([True] * (action_dim - len(mask)))
return mask
def __call__(self, transition: EnvTransition) -> EnvTransition:
observation = transition.get(TransitionKey.OBSERVATION, {})
state = observation.get(OBS_STATE) if observation else None
# Always cache state for the paired AbsoluteActionsProcessorStep
if state is not None:
self._last_state = state
if not self.enabled:
return transition
new_transition = transition.copy()
action = new_transition.get(TransitionKey.ACTION)
if action is None or state is None:
return new_transition
mask = self._build_mask(action.shape[-1])
new_transition[TransitionKey.ACTION] = to_delta_actions(action, state, mask)
return new_transition
def get_config(self) -> dict[str, Any]:
return {"enabled": self.enabled, "exclude_joints": self.exclude_joints}
def transform_features(
self, features: dict[PipelineFeatureType, dict[str, PolicyFeature]]
) -> dict[PipelineFeatureType, dict[str, PolicyFeature]]:
return features
@ProcessorStepRegistry.register("absolute_actions_processor")
@dataclass
class AbsoluteActionsProcessorStep(ProcessorStep):
"""Converts delta actions back to absolute actions (action += state) for all dimensions.
Mirrors OpenPI's AbsoluteActions transform. Applied during postprocessing so
predicted deltas are converted back to absolute positions for execution.
Reads the cached state from its paired DeltaActionsProcessorStep.
Attributes:
enabled: Whether to apply the absolute conversion.
delta_step: Reference to the paired DeltaActionsProcessorStep that caches state.
"""
enabled: bool = False
delta_step: DeltaActionsProcessorStep | None = field(default=None, repr=False)
def __call__(self, transition: EnvTransition) -> EnvTransition:
if not self.enabled:
return transition
if self.delta_step is None:
raise RuntimeError(
"AbsoluteActionsProcessorStep requires a paired DeltaActionsProcessorStep "
"but delta_step is None. Ensure delta_step is set when constructing the postprocessor."
)
if self.delta_step._last_state is None:
raise RuntimeError(
"AbsoluteActionsProcessorStep requires state from DeltaActionsProcessorStep "
"but no state has been cached. Ensure the preprocessor runs before the postprocessor."
)
new_transition = transition.copy()
action = new_transition.get(TransitionKey.ACTION)
if action is None:
return new_transition
mask = self.delta_step._build_mask(action.shape[-1])
new_transition[TransitionKey.ACTION] = to_absolute_actions(
action, self.delta_step._last_state, mask
)
return new_transition
def get_config(self) -> dict[str, Any]:
return {"enabled": self.enabled}
def transform_features(
self, features: dict[PipelineFeatureType, dict[str, PolicyFeature]]
) -> dict[PipelineFeatureType, dict[str, PolicyFeature]]:
return features
+3 -9
View File
@@ -331,11 +331,9 @@ class _NormalizationMixin:
)
mean, std = stats["mean"], stats["std"]
# Avoid division by zero by adding a small epsilon.
denom = std + self.eps
if inverse:
return tensor * std + mean
return (tensor - mean) / denom
return tensor * (std + 1e-6) + mean
return (tensor - mean) / (std + 1e-6)
if norm_mode == NormalizationMode.MIN_MAX:
min_val = stats.get("min", None)
@@ -367,11 +365,7 @@ class _NormalizationMixin:
"QUANTILES normalization mode requires q01 and q99 stats, please update the dataset with the correct stats using the `augment_dataset_quantile_stats.py` script"
)
denom = q99 - q01
# Avoid division by zero by adding epsilon when quantiles are identical
denom = torch.where(
denom == 0, torch.tensor(self.eps, device=tensor.device, dtype=tensor.dtype), denom
)
denom = q99 - q01 + 1e-6
if inverse:
return (tensor + 1.0) * denom / 2.0 + q01
return 2.0 * (tensor - q01) / denom - 1.0
@@ -0,0 +1,366 @@
# Copyright 2025 The HuggingFace Inc. team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""
Mirror a bimanual robot dataset by swapping left/right arms and inverting joint values.
This script creates a mirrored version of a dataset where:
1. Left and right arm observations/actions are swapped
2. Joint values are inverted according to a mirroring mask
3. Video frames are horizontally flipped
Example usage:
```shell
python -m lerobot.scripts.lerobot_mirror_dataset \
--repo_id=pepijn/openarm_bimanual \
--output_repo_id=pepijn/openarm_bimanual_mirrored
```
"""
import argparse
import logging
import os
import subprocess
from concurrent.futures import ThreadPoolExecutor, as_completed
from pathlib import Path
import numpy as np
import pandas as pd
from tqdm import tqdm
from lerobot.datasets.lerobot_dataset import LeRobotDataset, LeRobotDatasetMetadata
from lerobot.datasets.utils import (
DATA_DIR,
DEFAULT_DATA_PATH,
write_info,
write_stats,
write_tasks,
)
from lerobot.utils.constants import HF_LEROBOT_HOME
logger = logging.getLogger(__name__)
OPENARM_MIRRORING_MASK = {
"joint_1": -1, # Pan - invert
"joint_2": -1, # Lift - invert
"joint_3": -1, # Roll - invert
"joint_4": 1, # Elbow - no invert
"joint_5": -1, # W-Roll - invert
"joint_6": -1, # W-Pitch - invert
"joint_7": -1, # W-Yaw - invert
"gripper": 1, # Gripper - no invert
}
def get_mirroring_mask(robot_type: str) -> dict[str, int]:
"""Get the mirroring mask for a given robot type."""
if robot_type in ["bi_openarm_follower", "openarm_follower", "bi_openarms_follower", "openarms_follower"]:
return OPENARM_MIRRORING_MASK
raise ValueError(f"Unknown robot type: {robot_type}. Add a mirroring mask for this robot.")
def swap_left_right_name(name: str) -> str:
"""Swap 'left' and 'right' in a feature name."""
# Use placeholder to avoid double-swap
result = name.replace("left_", "LEFT_PLACEHOLDER_")
result = result.replace("right_", "left_")
result = result.replace("LEFT_PLACEHOLDER_", "right_")
return result
def mirror_feature_names(names: list[str]) -> tuple[list[str], dict[int, int]]:
"""Mirror feature names by swapping left/right and return the new names and index mapping."""
mirrored_names = [swap_left_right_name(n) for n in names]
old_to_new_idx = {}
for old_idx, old_name in enumerate(names):
new_name = swap_left_right_name(old_name)
new_idx = mirrored_names.index(new_name)
old_to_new_idx[old_idx] = new_idx
return mirrored_names, old_to_new_idx
def apply_mirroring_mask(
value: float,
feature_name: str,
mirroring_mask: dict[str, int],
) -> float:
"""Apply mirroring mask to a joint value."""
name_without_prefix = feature_name.split("_", 1)[1] if "_" in feature_name else feature_name
joint_name = name_without_prefix.split(".")[0]
if joint_name in mirroring_mask:
return value * mirroring_mask[joint_name]
return value
def mirror_array(
array: np.ndarray,
names: list[str],
mirroring_mask: dict[str, int],
) -> np.ndarray:
"""Mirror an array of values (action or state) by swapping left/right and applying mask."""
mirrored_names, idx_mapping = mirror_feature_names(names)
result = np.zeros_like(array)
for old_idx, new_idx in idx_mapping.items():
old_name = names[old_idx]
new_name = mirrored_names[new_idx]
value = array[old_idx]
mirrored_value = apply_mirroring_mask(value, new_name, mirroring_mask)
result[new_idx] = mirrored_value
return result
def flip_video_frames(
input_path: Path,
output_path: Path,
fps: float,
vcodec: str = "libsvtav1",
):
"""Flip video frames horizontally using FFmpeg with same settings as encode_video_frames."""
output_path.parent.mkdir(parents=True, exist_ok=True)
cmd = [
"ffmpeg", "-y", "-i", str(input_path),
"-vf", "hflip",
"-c:v", vcodec,
"-g", "2",
"-crf", "30",
"-r", str(int(fps)),
"-pix_fmt", "yuv420p",
"-loglevel", "error",
]
if vcodec == "libsvtav1":
cmd.extend(["-preset", "12"])
cmd.append(str(output_path))
result = subprocess.run(cmd, capture_output=True, text=True)
if result.returncode != 0:
raise RuntimeError(f"FFmpeg failed: {result.stderr}")
def mirror_dataset(
repo_id: str,
output_repo_id: str,
root: str | Path | None = None,
output_root: str | Path | None = None,
mirroring_mask: dict[str, int] | None = None,
vcodec: str = "libsvtav1",
num_workers: int | None = None,
) -> LeRobotDataset:
"""Mirror a bimanual robot dataset."""
logger.info(f"Loading dataset: {repo_id}")
dataset = LeRobotDataset(repo_id, root=root)
if mirroring_mask is None:
robot_type = dataset.meta.robot_type or "bi_openarms_follower"
mirroring_mask = get_mirroring_mask(robot_type)
logger.info(f"Using mirroring mask for robot type: {robot_type}")
output_root = Path(output_root) if output_root else HF_LEROBOT_HOME / output_repo_id
mirrored_features = {}
for key, feat in dataset.meta.features.items():
new_key = swap_left_right_name(key)
new_feat = feat.copy()
if "names" in new_feat and new_feat["names"]:
new_feat["names"] = [swap_left_right_name(n) for n in new_feat["names"]]
mirrored_features[new_key] = new_feat
logger.info("Creating mirrored dataset metadata...")
new_meta = LeRobotDatasetMetadata.create(
repo_id=output_repo_id,
fps=dataset.meta.fps,
features=mirrored_features,
robot_type=dataset.meta.robot_type,
root=output_root,
use_videos=len(dataset.meta.video_keys) > 0,
)
if dataset.meta.tasks is not None:
write_tasks(dataset.meta.tasks, new_meta.root)
new_meta.tasks = dataset.meta.tasks.copy()
_mirror_data(dataset, new_meta, mirroring_mask)
_mirror_videos(dataset, new_meta, vcodec, num_workers)
_copy_episodes_metadata(dataset, new_meta)
logger.info(f"Mirrored dataset saved to: {output_root}")
return LeRobotDataset(output_repo_id, root=output_root)
def _mirror_data(
src_dataset: LeRobotDataset,
dst_meta: LeRobotDatasetMetadata,
mirroring_mask: dict[str, int],
) -> None:
"""Mirror parquet data files."""
data_dir = src_dataset.root / DATA_DIR
parquet_files = sorted(data_dir.glob("*/*.parquet"))
if not parquet_files:
raise ValueError(f"No parquet files found in {data_dir}")
action_names = src_dataset.meta.features.get("action", {}).get("names", [])
state_names = src_dataset.meta.features.get("observation.state", {}).get("names", [])
for src_path in tqdm(parquet_files, desc="Mirroring data files"):
df = pd.read_parquet(src_path).reset_index(drop=True)
relative_path = src_path.relative_to(src_dataset.root)
chunk_dir = relative_path.parts[1]
file_name = relative_path.parts[2]
chunk_idx = int(chunk_dir.split("-")[1])
file_idx = int(file_name.split("-")[1].split(".")[0])
if "action" in df.columns and action_names:
actions = np.stack(df["action"].values)
mirrored_actions = np.array([
mirror_array(row, action_names, mirroring_mask) for row in actions
])
df["action"] = list(mirrored_actions)
if "observation.state" in df.columns and state_names:
states = np.stack(df["observation.state"].values)
mirrored_states = np.array([
mirror_array(row, state_names, mirroring_mask) for row in states
])
df["observation.state"] = list(mirrored_states)
dst_path = dst_meta.root / DEFAULT_DATA_PATH.format(chunk_index=chunk_idx, file_index=file_idx)
dst_path.parent.mkdir(parents=True, exist_ok=True)
df.to_parquet(dst_path, index=False)
def _mirror_videos(
src_dataset: LeRobotDataset,
dst_meta: LeRobotDatasetMetadata,
vcodec: str,
num_workers: int | None = None,
) -> None:
"""Mirror video files by flipping horizontally and swapping left/right names."""
if not src_dataset.meta.video_keys:
return
video_tasks = []
for old_video_key in src_dataset.meta.video_keys:
new_video_key = swap_left_right_name(old_video_key)
for ep_idx in range(src_dataset.meta.total_episodes):
try:
src_path = src_dataset.root / src_dataset.meta.get_video_file_path(ep_idx, old_video_key)
dst_relative = src_dataset.meta.get_video_file_path(ep_idx, old_video_key)
dst_relative_str = str(dst_relative).replace(old_video_key, new_video_key)
dst_path = dst_meta.root / dst_relative_str
if src_path.exists():
video_tasks.append((src_path, dst_path))
except KeyError:
continue
def process_video(task, pbar):
src_path, dst_path = task
pbar.set_postfix_str(src_path.name)
flip_video_frames(src_path, dst_path, src_dataset.meta.fps, vcodec)
return src_path
if num_workers is None:
num_workers = os.cpu_count() or 16
num_workers = min(len(video_tasks), num_workers)
logger.info(f"Processing {len(video_tasks)} videos with {num_workers} workers")
with tqdm(total=len(video_tasks), desc="Mirroring videos") as pbar:
with ThreadPoolExecutor(max_workers=num_workers) as executor:
futures = {executor.submit(process_video, t, pbar): t for t in video_tasks}
for future in as_completed(futures):
task = futures[future]
future.result()
pbar.set_postfix_str(f"done: {task[0].name}")
pbar.update(1)
def _copy_episodes_metadata(
src_dataset: LeRobotDataset,
dst_meta: LeRobotDatasetMetadata,
) -> None:
"""Copy episodes metadata with swapped video keys."""
episodes_dir = src_dataset.root / "meta/episodes"
dst_episodes_dir = dst_meta.root / "meta/episodes"
if episodes_dir.exists():
dst_episodes_dir.mkdir(parents=True, exist_ok=True)
for src_parquet in episodes_dir.glob("*/*.parquet"):
df = pd.read_parquet(src_parquet)
columns_to_rename = {}
for col in df.columns:
if col.startswith("videos/"):
parts = col.split("/")
if len(parts) >= 2:
video_key = parts[1]
new_video_key = swap_left_right_name(video_key)
new_col = col.replace(f"videos/{video_key}/", f"videos/{new_video_key}/")
columns_to_rename[col] = new_col
if columns_to_rename:
df = df.rename(columns=columns_to_rename)
dst_parquet = dst_episodes_dir / src_parquet.relative_to(episodes_dir)
dst_parquet.parent.mkdir(parents=True, exist_ok=True)
df.to_parquet(dst_parquet, index=False)
dst_meta.info.update({
"total_episodes": src_dataset.meta.total_episodes,
"total_frames": src_dataset.meta.total_frames,
"total_tasks": src_dataset.meta.total_tasks,
"total_videos": src_dataset.meta.total_videos,
"total_chunks": src_dataset.meta.total_chunks,
})
write_info(dst_meta.info, dst_meta.root)
if src_dataset.meta.stats is not None:
mirrored_stats = _mirror_stats(src_dataset.meta.stats)
write_stats(mirrored_stats, dst_meta.root)
def _mirror_stats(stats: dict) -> dict:
"""Mirror stats by swapping left/right in feature names."""
mirrored = {}
for key, value in stats.items():
new_key = swap_left_right_name(key)
if isinstance(value, dict):
mirrored[new_key] = _mirror_stats(value)
else:
mirrored[new_key] = value
return mirrored
def main():
logging.basicConfig(level=logging.INFO)
parser = argparse.ArgumentParser(description="Mirror a bimanual robot dataset")
parser.add_argument("--repo_id", type=str, required=True, help="Source dataset repo_id")
parser.add_argument("--output_repo_id", type=str, required=True, help="Output dataset repo_id")
parser.add_argument("--root", type=str, default=None, help="Source dataset root directory")
parser.add_argument("--output_root", type=str, default=None, help="Output dataset root directory")
parser.add_argument("--vcodec", type=str, default="libsvtav1", help="Video codec (libsvtav1, h264, hevc)")
parser.add_argument("--num_workers", type=int, default=None, help="Number of parallel workers for video processing")
parser.add_argument("--push-to-hub", action="store_true", help="Push mirrored dataset to HuggingFace Hub")
args = parser.parse_args()
dataset = mirror_dataset(
repo_id=args.repo_id,
output_repo_id=args.output_repo_id,
root=args.root,
output_root=args.output_root,
vcodec=args.vcodec,
num_workers=args.num_workers,
)
if getattr(args, "push_to_hub", False):
logger.info(f"Pushing dataset to HuggingFace Hub: {args.output_repo_id}")
dataset.push_to_hub()
if __name__ == "__main__":
main()
+126 -5
View File
@@ -175,8 +175,6 @@ def train(cfg: TrainPipelineConfig, accelerator: Accelerator | None = None):
from accelerate.utils import DistributedDataParallelKwargs
ddp_kwargs = DistributedDataParallelKwargs(find_unused_parameters=True)
# Accelerate auto-detects the device based on the available hardware and ignores the policy.device setting.
# Force the device to be CPU when policy.device is set to CPU.
force_cpu = cfg.policy.device == "cpu"
accelerator = Accelerator(
step_scheduler_with_optimizer=False,
@@ -211,16 +209,98 @@ def train(cfg: TrainPipelineConfig, accelerator: Accelerator | None = None):
torch.backends.cuda.matmul.allow_tf32 = True
# Dataset loading synchronization: main process downloads first to avoid race conditions
delta_action_stats = None
if is_main_process:
logging.info("Creating dataset")
dataset = make_dataset(cfg)
# Compute delta action stats BEFORE distributed sync to avoid NCCL timeout
if getattr(cfg.policy, "use_delta_actions", False):
import numpy as np
from lerobot.datasets.compute_stats import get_feature_stats
from lerobot.processor.delta_action_processor import DeltaActionsProcessorStep, to_delta_actions
chunk_size = cfg.policy.chunk_size
hf = dataset.hf_dataset
total_frames = len(hf)
sample_upper_bound = total_frames - chunk_size
if sample_upper_bound <= 0:
raise ValueError(
f"Cannot compute delta action stats: total_frames={total_frames}, chunk_size={chunk_size}"
)
max_samples = min(100_000, sample_upper_bound)
indices = np.random.choice(sample_upper_bound, max_samples, replace=False)
action_names = dataset.meta.features.get("action", {}).get("names")
delta_mask_step = DeltaActionsProcessorStep(
enabled=True,
exclude_joints=getattr(cfg.policy, "delta_exclude_joints", []),
action_names=action_names,
)
delta_mask = delta_mask_step._build_mask(dataset.meta.features["action"]["shape"][0])
logging.info(
f"use_delta_actions is enabled — computing delta action stats "
f"from {max_samples} chunk samples (chunk_size={chunk_size})"
)
all_delta_actions = []
episode_indices = np.array(hf["episode_index"])
for idx in indices:
idx = int(idx)
ep_idx = episode_indices[idx]
end_idx = min(idx + chunk_size, total_frames)
if end_idx > idx and episode_indices[end_idx - 1] != ep_idx:
continue
chunk_data = hf[idx:end_idx]
actions = torch.tensor(np.stack([np.asarray(a) for a in chunk_data["action"]])).float()
state = torch.tensor(np.asarray(chunk_data["observation.state"][0])).float()
delta = to_delta_actions(actions.unsqueeze(0), state.unsqueeze(0), delta_mask).squeeze(0)
all_delta_actions.append(delta.numpy())
if not all_delta_actions:
raise RuntimeError("Failed to compute delta action stats: no valid chunks found.")
all_delta = np.concatenate(all_delta_actions, axis=0)
delta_stats = get_feature_stats(all_delta, axis=0, keepdims=all_delta.ndim == 1)
delta_action_stats = delta_stats
dataset.meta.stats["action"] = delta_action_stats
norm_type = "UNKNOWN"
if hasattr(cfg.policy, "normalization_mapping"):
from lerobot.configs.types import NormalizationMode
action_norm = cfg.policy.normalization_mapping.get("ACTION", None)
norm_type = action_norm.value if action_norm else "UNKNOWN"
excluded_dims = len(delta_mask) - sum(delta_mask)
logging.info(
f"Delta action stats ({len(all_delta_actions)} chunks, {len(all_delta)} values, norm={norm_type}): "
f"delta_dims={sum(delta_mask)}/{len(delta_mask)} (excluded={excluded_dims}), "
f"mean={np.abs(delta_stats['mean']).mean():.4f}, std={delta_stats['std'].mean():.4f}, "
f"q01={delta_stats['q01'].mean():.4f}, q99={delta_stats['q99'].mean():.4f}"
)
if norm_type == "QUANTILES":
q_range = (delta_stats['q99'] - delta_stats['q01']).mean()
logging.info(f" Quantile range (q99-q01): {q_range:.4f}")
accelerator.wait_for_everyone()
# Now all other processes can safely load the dataset
if not is_main_process:
dataset = make_dataset(cfg)
# Ensure all ranks use the exact same delta action stats.
if getattr(cfg.policy, "use_delta_actions", False):
if accelerator.num_processes > 1 and torch.distributed.is_initialized():
stats_list = [delta_action_stats]
torch.distributed.broadcast_object_list(stats_list, src=0)
delta_action_stats = stats_list[0]
if delta_action_stats is not None:
dataset.meta.stats["action"] = delta_action_stats
# Create environment used for evaluating checkpoints during training on simulation data.
# On real-world data, no need to create an environment as evaluations are done outside train.py,
# using the eval.py instead, with gym_dora environment and dora-rs.
@@ -246,10 +326,22 @@ def train(cfg: TrainPipelineConfig, accelerator: Accelerator | None = None):
# Wait for all processes to finish policy creation before continuing
accelerator.wait_for_everyone()
processor_pretrained_path = cfg.policy.pretrained_path
if (
getattr(cfg.policy, "use_delta_actions", False)
and processor_pretrained_path is not None
and not cfg.resume
):
logging.warning(
"use_delta_actions=true with pretrained processors can skip delta transforms if "
"the checkpoint processors do not define them. Building processors from current policy config."
)
processor_pretrained_path = None
# Create processors - only provide dataset_stats if not resuming from saved processors
processor_kwargs = {}
postprocessor_kwargs = {}
if (cfg.policy.pretrained_path and not cfg.resume) or not cfg.policy.pretrained_path:
if (processor_pretrained_path and not cfg.resume) or not processor_pretrained_path:
# Only provide dataset_stats when not resuming from saved processor state
processor_kwargs["dataset_stats"] = dataset.meta.stats
@@ -257,7 +349,7 @@ def train(cfg: TrainPipelineConfig, accelerator: Accelerator | None = None):
if cfg.policy.type == "sarm":
processor_kwargs["dataset_meta"] = dataset.meta
if cfg.policy.pretrained_path is not None:
if processor_pretrained_path is not None:
processor_kwargs["preprocessor_overrides"] = {
"device_processor": {"device": device.type},
"normalizer_processor": {
@@ -279,7 +371,7 @@ def train(cfg: TrainPipelineConfig, accelerator: Accelerator | None = None):
preprocessor, postprocessor = make_pre_post_processors(
policy_cfg=cfg.policy,
pretrained_path=cfg.policy.pretrained_path,
pretrained_path=processor_pretrained_path,
**processor_kwargs,
**postprocessor_kwargs,
)
@@ -397,7 +489,36 @@ def train(cfg: TrainPipelineConfig, accelerator: Accelerator | None = None):
for _ in range(step, cfg.steps):
start_time = time.perf_counter()
batch = next(dl_iter)
# Debug logging for first few steps and periodically
if is_main_process and (step < 3 or (cfg.log_freq > 0 and step % (cfg.log_freq * 10) == 0)):
action = batch.get("action")
state = batch.get("observation.state")
if action is not None and state is not None:
logging.info(
f"[DEBUG step={step}] PRE-PROCESSOR — "
f"action: shape={tuple(action.shape)}, mean={action.mean():.4f}, std={action.std():.4f}, "
f"min={action.min():.4f}, max={action.max():.4f} | "
f"state: shape={tuple(state.shape)}, mean={state.mean():.4f}"
)
batch = preprocessor(batch)
if is_main_process and (step < 3 or (cfg.log_freq > 0 and step % (cfg.log_freq * 10) == 0)):
action = batch.get("action")
state = batch.get("observation.state")
if action is not None:
logging.info(
f"[DEBUG step={step}] POST-PROCESSOR — "
f"action: shape={tuple(action.shape)}, mean={action.mean():.4f}, std={action.std():.4f}, "
f"min={action.min():.4f}, max={action.max():.4f}"
)
if state is not None:
logging.info(
f"[DEBUG step={step}] POST-PROCESSOR — "
f"state: shape={tuple(state.shape)}, mean={state.mean():.4f}, std={state.std():.4f}"
)
train_tracker.dataloading_s = time.perf_counter() - start_time
train_tracker, output_dict = update_policy(
+344
View File
@@ -0,0 +1,344 @@
"""Tests for delta action transforms — full pipeline validation.
Tests the complete flow matching OpenPI:
raw actions → DeltaActions → Normalize(delta_stats) → model → Unnormalize → AbsoluteActions
Uses real dataset: lerobot-data-collection/dagger_final_1_21
"""
import numpy as np
import pytest
import torch
from lerobot.configs.types import FeatureType, NormalizationMode, PolicyFeature
from lerobot.datasets.compute_stats import get_feature_stats
from lerobot.datasets.lerobot_dataset import LeRobotDataset
from lerobot.processor import TransitionKey, batch_to_transition
from lerobot.processor.delta_action_processor import (
AbsoluteActionsProcessorStep,
DeltaActionsProcessorStep,
to_absolute_actions,
to_delta_actions,
)
from lerobot.processor.normalize_processor import NormalizerProcessorStep, UnnormalizerProcessorStep
from lerobot.utils.constants import ACTION, OBS_STATE
CHUNK_SIZE = 10
REPO_ID = "lerobot-data-collection/dagger_final_1_21"
@pytest.fixture(scope="module")
def dataset():
return LeRobotDataset(REPO_ID, episodes=[0])
@pytest.fixture(scope="module")
def action_dim(dataset):
return dataset.meta.features["action"]["shape"][0]
def _build_action_chunks(dataset, chunk_size, max_chunks=50):
"""Build action chunks from hf_dataset, like the training script does."""
hf = dataset.hf_dataset
total = len(hf)
all_ep = torch.tensor([int(hf[i]["episode_index"]) for i in range(total)])
chunks, states = [], []
for i in range(total - chunk_size + 1):
if all_ep[i] != all_ep[i + chunk_size - 1]:
continue
chunk_actions = torch.stack([hf[i + k]["action"] for k in range(chunk_size)]).float()
state = hf[i]["observation.state"].float()
chunks.append(chunk_actions)
states.append(state)
if len(chunks) >= max_chunks:
break
assert len(chunks) > 0, f"No valid chunks found. total={total}, ep_indices={all_ep.tolist()}"
return torch.stack(chunks), torch.stack(states)
def _compute_delta_chunk_stats(action_chunks, states, mask):
all_deltas = []
for actions, state in zip(action_chunks, states):
delta = to_delta_actions(actions.unsqueeze(0), state.unsqueeze(0), mask).squeeze(0)
all_deltas.append(delta.numpy())
all_delta = np.concatenate(all_deltas, axis=0)
return get_feature_stats(all_delta, axis=0, keepdims=all_delta.ndim == 1)
# --- Basic roundtrip tests ---
def test_roundtrip_3d(action_dim):
actions = torch.randn(4, CHUNK_SIZE, action_dim)
state = torch.randn(4, action_dim)
mask = [True] * action_dim
recovered = to_absolute_actions(to_delta_actions(actions, state, mask), state, mask)
torch.testing.assert_close(recovered, actions)
def test_roundtrip_2d(action_dim):
actions = torch.randn(4, action_dim)
state = torch.randn(4, action_dim)
mask = [True] * action_dim
recovered = to_absolute_actions(to_delta_actions(actions, state, mask), state, mask)
torch.testing.assert_close(recovered, actions)
def test_no_mutation(action_dim):
actions = torch.randn(2, CHUNK_SIZE, action_dim)
original = actions.clone()
state = torch.randn(2, action_dim)
to_delta_actions(actions, state, [True] * action_dim)
torch.testing.assert_close(actions, original)
def test_exclude_joints_supports_partial_name_matching():
names = [
"right_joint_1.pos",
"right_gripper.pos",
"left_joint_1.pos",
"left_gripper.pos",
]
step = DeltaActionsProcessorStep(enabled=True, exclude_joints=["gripper"], action_names=names)
assert step._build_mask(len(names)) == [True, False, True, False]
# --- Chunk-level delta stats test ---
def test_chunk_stats_have_larger_std_than_frame_stats(dataset, action_dim):
"""Chunk-level delta stats should have larger std than per-frame delta stats."""
action_chunks, states = _build_action_chunks(dataset, CHUNK_SIZE)
mask = [True] * action_dim
chunk_stats = _compute_delta_chunk_stats(action_chunks, states, mask)
# Per-frame stats
hf = dataset.hf_dataset
n = min(500, len(hf))
frame_actions = torch.stack([hf[i]["action"] for i in range(n)]).float()
frame_states = torch.stack([hf[i]["observation.state"] for i in range(n)]).float()
frame_deltas = to_delta_actions(frame_actions, frame_states, mask).numpy()
frame_stats = get_feature_stats(frame_deltas, axis=0, keepdims=frame_deltas.ndim == 1)
assert chunk_stats["std"].mean() >= frame_stats["std"].mean(), (
f"Chunk std ({chunk_stats['std'].mean():.4f}) should be >= "
f"frame std ({frame_stats['std'].mean():.4f})"
)
# --- Full pipeline roundtrip: delta → normalize → unnormalize → absolute ---
def test_full_pipeline_roundtrip(dataset, action_dim):
"""Test the complete OpenPI pipeline: delta → normalize → unnormalize → absolute."""
action_chunks, states = _build_action_chunks(dataset, CHUNK_SIZE)
mask = [True] * action_dim
delta_stats = _compute_delta_chunk_stats(action_chunks, states, mask)
stats = {ACTION: {k: v for k, v in delta_stats.items()}}
features = {ACTION: PolicyFeature(type=FeatureType.ACTION, shape=(action_dim,))}
norm_map = {FeatureType.ACTION: NormalizationMode.MEAN_STD}
delta_step = DeltaActionsProcessorStep(enabled=True)
normalizer = NormalizerProcessorStep(features=features, norm_map=norm_map, stats=stats)
unnormalizer = UnnormalizerProcessorStep(features=features, norm_map=norm_map, stats=stats)
absolute_step = AbsoluteActionsProcessorStep(enabled=True, delta_step=delta_step)
original_actions = action_chunks[0].unsqueeze(0)
state = states[0].unsqueeze(0)
batch = {ACTION: original_actions, OBS_STATE: state}
transition = batch_to_transition(batch)
# Forward: delta → normalize
t1 = delta_step(transition)
t2 = normalizer(t1)
normalized_action = t2[TransitionKey.ACTION]
assert normalized_action.abs().mean() < 10, (
f"Normalized actions should be in reasonable range, got mean abs {normalized_action.abs().mean():.2f}"
)
# Reverse: unnormalize → absolute
t3 = unnormalizer(t2)
t4 = absolute_step(t3)
recovered_actions = t4[TransitionKey.ACTION]
torch.testing.assert_close(recovered_actions, original_actions, atol=1e-4, rtol=1e-4)
def test_normalized_delta_values_are_reasonable(dataset, action_dim):
"""With correct chunk stats, normalized delta actions should be in a reasonable range."""
action_chunks, states = _build_action_chunks(dataset, CHUNK_SIZE)
mask = [True] * action_dim
delta_stats = _compute_delta_chunk_stats(action_chunks, states, mask)
mean = torch.tensor(delta_stats["mean"]).float()
std = torch.tensor(delta_stats["std"]).float()
all_normalized = []
for actions, state in zip(action_chunks, states):
delta = to_delta_actions(actions.unsqueeze(0), state.unsqueeze(0), mask).squeeze(0)
normalized = (delta - mean) / (std + 1e-6)
all_normalized.append(normalized)
all_normalized = torch.cat(all_normalized, dim=0)
pct_in_range = (all_normalized.abs() < 5).float().mean()
assert pct_in_range > 0.9, (
f"Only {pct_in_range*100:.1f}% of normalized values in [-5, 5], expected >90%"
)
assert all_normalized.mean().abs() < 1.0, (
f"Mean of normalized deltas is {all_normalized.mean():.2f}, expected near 0"
)
def test_processor_step_roundtrip(dataset, action_dim):
"""DeltaActionsProcessorStep applies delta; to_absolute_actions recovers original."""
hf = dataset.hf_dataset
batch = {
ACTION: torch.stack([hf[i]["action"] for i in range(4)]),
OBS_STATE: torch.stack([hf[i]["observation.state"] for i in range(4)]),
}
original_actions = batch[ACTION].clone()
transition = batch_to_transition(batch)
step = DeltaActionsProcessorStep(enabled=True)
delta_transition = step(transition)
assert not torch.allclose(delta_transition[TransitionKey.ACTION], original_actions)
state = transition[TransitionKey.OBSERVATION][OBS_STATE]
mask = [True] * action_dim
recovered = to_absolute_actions(delta_transition[TransitionKey.ACTION], state, mask)
torch.testing.assert_close(recovered, original_actions)
def test_processor_step_disabled_is_noop(dataset, action_dim):
"""enabled=False should be a no-op."""
hf = dataset.hf_dataset
batch = {
ACTION: torch.stack([hf[i]["action"] for i in range(2)]),
OBS_STATE: torch.stack([hf[i]["observation.state"] for i in range(2)]),
}
original = batch[ACTION].clone()
transition = batch_to_transition(batch)
result = DeltaActionsProcessorStep(enabled=False)(transition)
torch.testing.assert_close(result[TransitionKey.ACTION], original)
# --- Training batch shape validation ---
def test_delta_with_action_chunks(dataset, action_dim):
"""Verify delta works correctly with (B, chunk_size, action_dim) shaped actions."""
action_chunks, states = _build_action_chunks(dataset, CHUNK_SIZE)
# Simulate a training batch: actions=(B, chunk_size, action_dim), state=(B, state_dim)
batch_actions = action_chunks[:4] # (4, chunk_size, action_dim)
batch_states = states[:4] # (4, state_dim)
mask = [True] * action_dim
delta = to_delta_actions(batch_actions, batch_states, mask)
# First action in each chunk should be close to zero (action[t] - state[t] ≈ small)
first_deltas = delta[:, 0, :] # (B, action_dim)
assert first_deltas.abs().mean() < delta.abs().mean(), (
f"First action in chunk should have smaller delta than average. "
f"First: {first_deltas.abs().mean():.4f}, Average: {delta.abs().mean():.4f}"
)
# Later actions should have larger deltas
last_deltas = delta[:, -1, :] # (B, action_dim)
assert last_deltas.abs().mean() >= first_deltas.abs().mean(), (
f"Last action in chunk should have >= delta than first. "
f"Last: {last_deltas.abs().mean():.4f}, First: {first_deltas.abs().mean():.4f}"
)
# Roundtrip
recovered = to_absolute_actions(delta, batch_states, mask)
torch.testing.assert_close(recovered, batch_actions)
def test_delta_stats_match_actual_data_distribution(dataset, action_dim):
"""Verify computed stats match the actual delta distribution."""
action_chunks, states = _build_action_chunks(dataset, CHUNK_SIZE)
mask = [True] * action_dim
# Compute stats like the training script does
delta_stats = _compute_delta_chunk_stats(action_chunks, states, mask)
# Also compute directly
all_deltas = []
for actions, state in zip(action_chunks, states):
delta = to_delta_actions(actions.unsqueeze(0), state.unsqueeze(0), mask).squeeze(0)
all_deltas.append(delta)
all_deltas_tensor = torch.cat(all_deltas, dim=0)
# Compare mean
actual_mean = all_deltas_tensor.mean(dim=0).numpy()
np.testing.assert_allclose(delta_stats["mean"], actual_mean, atol=0.01)
# Compare std
actual_std = all_deltas_tensor.std(dim=0).numpy()
np.testing.assert_allclose(delta_stats["std"], actual_std, atol=0.1)
# Verify q01 < mean < q99
assert (delta_stats["q01"] < delta_stats["mean"]).all(), "q01 should be < mean"
assert (delta_stats["mean"] < delta_stats["q99"]).all(), "mean should be < q99"
def test_quantile_normalization_roundtrip(dataset, action_dim):
"""Full roundtrip with QUANTILES normalization (what OpenPI uses for pi05)."""
action_chunks, states = _build_action_chunks(dataset, CHUNK_SIZE)
mask = [True] * action_dim
delta_stats = _compute_delta_chunk_stats(action_chunks, states, mask)
stats = {ACTION: {k: v for k, v in delta_stats.items()}}
features = {ACTION: PolicyFeature(type=FeatureType.ACTION, shape=(action_dim,))}
norm_map = {FeatureType.ACTION: NormalizationMode.QUANTILES}
delta_step = DeltaActionsProcessorStep(enabled=True)
normalizer = NormalizerProcessorStep(features=features, norm_map=norm_map, stats=stats)
unnormalizer = UnnormalizerProcessorStep(features=features, norm_map=norm_map, stats=stats)
absolute_step = AbsoluteActionsProcessorStep(enabled=True, delta_step=delta_step)
original_actions = action_chunks[0].unsqueeze(0)
state = states[0].unsqueeze(0)
batch = {ACTION: original_actions, OBS_STATE: state}
transition = batch_to_transition(batch)
# Forward: delta → quantile normalize
t1 = delta_step(transition)
t2 = normalizer(t1)
normalized = t2[TransitionKey.ACTION]
# Most values should be in [-1, 1] with quantile normalization
pct_in_range = (normalized.abs() < 2).float().mean()
assert pct_in_range > 0.5, (
f"Only {pct_in_range*100:.1f}% in [-2, 2] after quantile norm, expected >50%"
)
# Reverse: unnormalize → absolute
t3 = unnormalizer(t2)
t4 = absolute_step(t3)
recovered = t4[TransitionKey.ACTION]
torch.testing.assert_close(recovered, original_actions, atol=1e-3, rtol=1e-3)
def test_state_not_modified_by_delta(dataset, action_dim):
"""State should never be modified by the delta processor."""
hf = dataset.hf_dataset
batch = {
ACTION: torch.stack([hf[i]["action"] for i in range(4)]),
OBS_STATE: torch.stack([hf[i]["observation.state"] for i in range(4)]),
}
original_state = batch[OBS_STATE].clone()
transition = batch_to_transition(batch)
step = DeltaActionsProcessorStep(enabled=True)
result = step(transition)
result_state = result[TransitionKey.OBSERVATION][OBS_STATE]
torch.testing.assert_close(result_state, original_state)