Merge branch 'feat/add_relative_action_pi_models' into feat/mirror

This commit is contained in:
Pepijn
2026-02-20 22:54:46 +01:00
10 changed files with 108 additions and 56 deletions
+1 -1
View File
@@ -471,7 +471,7 @@ def make_policy(
if not cfg.input_features:
cfg.input_features = {key: ft for key, ft in features.items() if key not in cfg.output_features}
# Store action feature names on config for delta_exclude_joints support
# Store action feature names for delta_exclude_joints support
if ds_meta is not None and hasattr(cfg, "action_feature_names"):
action_names = ds_meta.features.get(ACTION, {}).get("names")
if action_names is not None:
@@ -52,9 +52,9 @@ class PI0Config(PreTrainedConfig):
# Delta actions: converts absolute actions to delta (relative to state).
use_delta_actions: bool = False
# Joint names to exclude from delta conversion (kept as absolute).
# Joint names to exclude from delta (kept absolute). Empty list = all dims delta.
delta_exclude_joints: list[str] = field(default_factory=lambda: ["gripper"])
# Populated at runtime by make_policy from dataset metadata.
# Populated at runtime from dataset metadata by make_policy.
action_feature_names: list[str] | None = None
# Real-Time Chunking (RTC) configuration
-19
View File
@@ -44,7 +44,6 @@ from lerobot.configs.policies import PreTrainedConfig
from lerobot.policies.pi0.configuration_pi0 import DEFAULT_IMAGE_SIZE, PI0Config
from lerobot.policies.pretrained import PreTrainedPolicy, T
from lerobot.policies.rtc.modeling_rtc import RTCProcessor
from lerobot.processor.delta_action_processor import to_absolute_actions, to_delta_actions
from lerobot.utils.constants import (
ACTION,
OBS_LANGUAGE_ATTENTION_MASK,
@@ -1221,18 +1220,6 @@ class PI0Policy(PreTrainedPolicy):
state = pad_vector(batch[OBS_STATE], self.config.max_state_dim)
return state
def _build_delta_mask(self, action_dim: int) -> list[bool]:
"""Build a boolean mask for delta action conversion.
Uses action_feature_names and delta_exclude_joints to determine which
dims get delta conversion. Falls back to all-True if names are unavailable.
"""
names = self.config.action_feature_names
if names is None:
return [True] * action_dim
exclude = set(self.config.delta_exclude_joints)
return [n not in exclude for n in names]
def prepare_action(self, batch):
"""Pad action"""
actions = pad_vector(batch[ACTION], self.config.max_action_dim)
@@ -1272,9 +1259,6 @@ class PI0Policy(PreTrainedPolicy):
original_action_dim = self.config.output_features[ACTION].shape[0]
actions = actions[:, :, :original_action_dim]
if self.config.use_delta_actions:
actions = to_absolute_actions(actions, state, self._build_delta_mask(actions.shape[-1]))
return actions
def forward(self, batch: dict[str, Tensor], reduction: str = "mean") -> tuple[Tensor, dict]:
@@ -1292,9 +1276,6 @@ class PI0Policy(PreTrainedPolicy):
state = self.prepare_state(batch)
actions = self.prepare_action(batch)
if self.config.use_delta_actions:
actions = to_delta_actions(actions, state, self._build_delta_mask(actions.shape[-1]))
# Compute loss
losses = self.model.forward(images, img_masks, lang_tokens, lang_masks, state, actions)
+10 -2
View File
@@ -21,6 +21,7 @@ import torch
from lerobot.configs.types import PipelineFeatureType, PolicyFeature
from lerobot.policies.pi0.configuration_pi0 import PI0Config
from lerobot.processor import (
AbsoluteActionsProcessorStep,
AddBatchDimensionProcessorStep,
ComplementaryDataProcessorStep,
DeltaActionsProcessorStep,
@@ -127,7 +128,13 @@ def make_pi0_pre_post_processors(
A tuple containing the configured pre-processor and post-processor pipelines.
"""
# Add remaining processors
delta_step = DeltaActionsProcessorStep(
enabled=config.use_delta_actions,
exclude_joints=getattr(config, "delta_exclude_joints", []),
action_names=getattr(config, "action_feature_names", None),
)
# OpenPI order: raw → delta → normalize → model → unnormalize → absolute
input_steps: list[ProcessorStep] = [
RenameObservationsProcessorStep(rename_map={}), # To mimic the same processor as pretrained one
AddBatchDimensionProcessorStep(),
@@ -139,18 +146,19 @@ def make_pi0_pre_post_processors(
padding="max_length",
),
DeviceProcessorStep(device=config.device),
delta_step,
NormalizerProcessorStep(
features={**config.input_features, **config.output_features},
norm_map=config.normalization_mapping,
stats=dataset_stats,
),
DeltaActionsProcessorStep(enabled=config.use_delta_actions),
]
output_steps: list[ProcessorStep] = [
UnnormalizerProcessorStep(
features=config.output_features, norm_map=config.normalization_mapping, stats=dataset_stats
),
AbsoluteActionsProcessorStep(enabled=config.use_delta_actions, delta_step=delta_step),
DeviceProcessorStep(device="cpu"),
]
@@ -52,9 +52,9 @@ class PI05Config(PreTrainedConfig):
# Delta actions: converts absolute actions to delta (relative to state).
use_delta_actions: bool = False
# Joint names to exclude from delta conversion (kept as absolute).
# Joint names to exclude from delta (kept absolute). Empty list = all dims delta.
delta_exclude_joints: list[str] = field(default_factory=lambda: ["gripper"])
# Populated at runtime by make_policy from dataset metadata.
# Populated at runtime from dataset metadata by make_policy.
action_feature_names: list[str] | None = None
# Real-Time Chunking (RTC) configuration
@@ -44,12 +44,10 @@ from lerobot.configs.policies import PreTrainedConfig
from lerobot.policies.pi05.configuration_pi05 import DEFAULT_IMAGE_SIZE, PI05Config
from lerobot.policies.pretrained import PreTrainedPolicy, T
from lerobot.policies.rtc.modeling_rtc import RTCProcessor
from lerobot.processor.delta_action_processor import to_absolute_actions, to_delta_actions
from lerobot.utils.constants import (
ACTION,
OBS_LANGUAGE_ATTENTION_MASK,
OBS_LANGUAGE_TOKENS,
OBS_STATE,
OPENPI_ATTENTION_MASK_VALUE,
)
@@ -1201,14 +1199,6 @@ class PI05Policy(PreTrainedPolicy):
actions = pad_vector(batch[ACTION], self.config.max_action_dim)
return actions
def _build_delta_mask(self, action_dim: int) -> list[bool]:
"""Build a boolean mask for delta action conversion."""
names = self.config.action_feature_names
if names is None:
return [True] * action_dim
exclude = set(self.config.delta_exclude_joints)
return [n not in exclude for n in names]
@torch.no_grad()
def select_action(self, batch: dict[str, Tensor]) -> Tensor:
"""Select a single action given environment observations."""
@@ -1242,10 +1232,6 @@ class PI05Policy(PreTrainedPolicy):
original_action_dim = self.config.output_features[ACTION].shape[0]
actions = actions[:, :, :original_action_dim]
if self.config.use_delta_actions:
state = pad_vector(batch[OBS_STATE], self.config.max_state_dim)
actions = to_absolute_actions(actions, state, self._build_delta_mask(actions.shape[-1]))
return actions
def forward(self, batch: dict[str, Tensor], reduction: str = "mean") -> tuple[Tensor, dict]:
@@ -1263,10 +1249,6 @@ class PI05Policy(PreTrainedPolicy):
actions = self.prepare_action(batch)
if self.config.use_delta_actions:
state = pad_vector(batch[OBS_STATE], self.config.max_state_dim)
actions = to_delta_actions(actions, state, self._build_delta_mask(actions.shape[-1]))
# Compute loss (no separate state needed for PI05)
losses = self.model.forward(images, img_masks, tokens, masks, actions)
+12 -2
View File
@@ -25,6 +25,7 @@ from lerobot.configs.types import PipelineFeatureType, PolicyFeature
from lerobot.policies.pi05.configuration_pi05 import PI05Config
from lerobot.policies.pi05.modeling_pi05 import pad_vector
from lerobot.processor import (
AbsoluteActionsProcessorStep,
AddBatchDimensionProcessorStep,
DeltaActionsProcessorStep,
DeviceProcessorStep,
@@ -130,10 +131,19 @@ def make_pi05_pre_post_processors(
A tuple containing the configured pre-processor and post-processor pipelines.
"""
# Add remaining processors
delta_step = DeltaActionsProcessorStep(
enabled=config.use_delta_actions,
exclude_joints=getattr(config, "delta_exclude_joints", []),
action_names=getattr(config, "action_feature_names", None),
)
# OpenPI order: raw → delta → normalize → model → unnormalize → absolute
# NOTE: NormalizerProcessorStep MUST come before Pi05PrepareStateTokenizerProcessorStep
# because the tokenizer step expects normalized state in [-1, 1] range for discretization
input_steps: list[ProcessorStep] = [
RenameObservationsProcessorStep(rename_map={}), # To mimic the same processor as pretrained one
AddBatchDimensionProcessorStep(),
delta_step,
# NOTE: NormalizerProcessorStep MUST come before Pi05PrepareStateTokenizerProcessorStep
# because the tokenizer step expects normalized state in [-1, 1] range for discretization
NormalizerProcessorStep(
@@ -141,7 +151,6 @@ def make_pi05_pre_post_processors(
norm_map=config.normalization_mapping,
stats=dataset_stats,
),
DeltaActionsProcessorStep(enabled=config.use_delta_actions),
Pi05PrepareStateTokenizerProcessorStep(max_state_dim=config.max_state_dim),
TokenizerProcessorStep(
tokenizer_name="google/paligemma-3b-pt-224",
@@ -156,6 +165,7 @@ def make_pi05_pre_post_processors(
UnnormalizerProcessorStep(
features=config.output_features, norm_map=config.normalization_mapping, stats=dataset_stats
),
AbsoluteActionsProcessorStep(enabled=config.use_delta_actions, delta_step=delta_step),
DeviceProcessorStep(device="cpu"),
]
+2
View File
@@ -29,6 +29,7 @@ from .core import (
TransitionKey,
)
from .delta_action_processor import (
AbsoluteActionsProcessorStep,
DeltaActionsProcessorStep,
MapDeltaActionToRobotActionStep,
MapTensorToDeltaActionDictStep,
@@ -103,6 +104,7 @@ __all__ = [
"make_default_teleop_action_processor",
"make_default_robot_action_processor",
"make_default_robot_observation_processor",
"AbsoluteActionsProcessorStep",
"DeltaActionsProcessorStep",
"MapDeltaActionToRobotActionStep",
"MapTensorToDeltaActionDictStep",
@@ -15,7 +15,7 @@
# limitations under the License.
from collections.abc import Sequence
from dataclasses import dataclass
from dataclasses import dataclass, field
from typing import Any
import torch
@@ -188,33 +188,102 @@ class MapDeltaActionToRobotActionStep(RobotActionProcessorStep):
@ProcessorStepRegistry.register("delta_actions_processor")
@dataclass
class DeltaActionsProcessorStep(ProcessorStep):
"""Converts absolute actions to delta actions (action -= state) for all dimensions.
"""Converts absolute actions to delta actions (action -= state) for masked dimensions.
Mirrors OpenPI's DeltaActions transform. Applied during preprocessing so the model
trains on relative offsets instead of absolute positions.
Caches the last seen state so a paired AbsoluteActionsProcessorStep can reverse
the conversion during postprocessing.
Attributes:
enabled: Whether to apply the delta conversion.
exclude_joints: Joint names to keep absolute (not converted to delta).
action_names: Action dimension names from dataset metadata, used to build
the mask from exclude_joints. If None, all dims are converted.
"""
enabled: bool = False
exclude_joints: list[str] = field(default_factory=list)
action_names: list[str] | None = None
_last_state: torch.Tensor | None = field(default=None, init=False, repr=False)
def _build_mask(self, action_dim: int) -> list[bool]:
if not self.exclude_joints or self.action_names is None:
return [True] * action_dim
exclude = set(self.exclude_joints)
return [n not in exclude for n in self.action_names]
def __call__(self, transition: EnvTransition) -> EnvTransition:
observation = transition.get(TransitionKey.OBSERVATION, {})
state = observation.get(OBS_STATE) if observation else None
# Always cache state for the paired AbsoluteActionsProcessorStep
if state is not None:
self._last_state = state
if not self.enabled:
return transition
new_transition = transition.copy()
action = new_transition.get(TransitionKey.ACTION)
if action is None or state is None:
return new_transition
mask = self._build_mask(action.shape[-1])
new_transition[TransitionKey.ACTION] = to_delta_actions(action, state, mask)
return new_transition
def get_config(self) -> dict[str, Any]:
return {"enabled": self.enabled, "exclude_joints": self.exclude_joints}
def transform_features(
self, features: dict[PipelineFeatureType, dict[str, PolicyFeature]]
) -> dict[PipelineFeatureType, dict[str, PolicyFeature]]:
return features
@ProcessorStepRegistry.register("absolute_actions_processor")
@dataclass
class AbsoluteActionsProcessorStep(ProcessorStep):
"""Converts delta actions back to absolute actions (action += state) for all dimensions.
Mirrors OpenPI's AbsoluteActions transform. Applied during postprocessing so
predicted deltas are converted back to absolute positions for execution.
Reads the cached state from its paired DeltaActionsProcessorStep.
Attributes:
enabled: Whether to apply the absolute conversion.
delta_step: Reference to the paired DeltaActionsProcessorStep that caches state.
"""
enabled: bool = False
delta_step: DeltaActionsProcessorStep | None = field(default=None, repr=False)
def __call__(self, transition: EnvTransition) -> EnvTransition:
if not self.enabled:
return transition
if self.delta_step is None:
raise RuntimeError(
"AbsoluteActionsProcessorStep requires a paired DeltaActionsProcessorStep "
"but delta_step is None. Ensure delta_step is set when constructing the postprocessor."
)
if self.delta_step._last_state is None:
raise RuntimeError(
"AbsoluteActionsProcessorStep requires state from DeltaActionsProcessorStep "
"but no state has been cached. Ensure the preprocessor runs before the postprocessor."
)
new_transition = transition.copy()
action = new_transition.get(TransitionKey.ACTION)
if action is None:
return new_transition
observation = new_transition.get(TransitionKey.OBSERVATION, {})
state = observation.get(OBS_STATE) if observation else None
if state is None:
return new_transition
mask = [True] * action.shape[-1]
new_transition[TransitionKey.ACTION] = to_delta_actions(action, state, mask)
mask = self.delta_step._build_mask(action.shape[-1])
new_transition[TransitionKey.ACTION] = to_absolute_actions(
action, self.delta_step._last_state, mask
)
return new_transition
def get_config(self) -> dict[str, Any]:
+1 -1
View File
@@ -248,7 +248,7 @@ def train(cfg: TrainPipelineConfig, accelerator: Accelerator | None = None):
logging.info("use_delta_actions is enabled — recomputing action stats as delta (action - state)")
from lerobot.datasets.dataset_tools import recompute_stats
exclude = getattr(cfg.policy, "delta_exclude_joints", ["gripper"])
exclude = getattr(cfg.policy, "delta_exclude_joints", [])
recompute_stats(dataset, skip_image_video=True, delta_action=True, delta_exclude_joints=exclude)
# Wait for all processes to finish policy creation before continuing