Merge branch 'feat/add_relative_action_pi_models' into feat/mirror

This commit is contained in:
Pepijn
2026-02-20 22:54:46 +01:00
10 changed files with 108 additions and 56 deletions
+1 -1
View File
@@ -471,7 +471,7 @@ def make_policy(
if not cfg.input_features: if not cfg.input_features:
cfg.input_features = {key: ft for key, ft in features.items() if key not in cfg.output_features} cfg.input_features = {key: ft for key, ft in features.items() if key not in cfg.output_features}
# Store action feature names on config for delta_exclude_joints support # Store action feature names for delta_exclude_joints support
if ds_meta is not None and hasattr(cfg, "action_feature_names"): if ds_meta is not None and hasattr(cfg, "action_feature_names"):
action_names = ds_meta.features.get(ACTION, {}).get("names") action_names = ds_meta.features.get(ACTION, {}).get("names")
if action_names is not None: if action_names is not None:
@@ -52,9 +52,9 @@ class PI0Config(PreTrainedConfig):
# Delta actions: converts absolute actions to delta (relative to state). # Delta actions: converts absolute actions to delta (relative to state).
use_delta_actions: bool = False use_delta_actions: bool = False
# Joint names to exclude from delta conversion (kept as absolute). # Joint names to exclude from delta (kept absolute). Empty list = all dims delta.
delta_exclude_joints: list[str] = field(default_factory=lambda: ["gripper"]) delta_exclude_joints: list[str] = field(default_factory=lambda: ["gripper"])
# Populated at runtime by make_policy from dataset metadata. # Populated at runtime from dataset metadata by make_policy.
action_feature_names: list[str] | None = None action_feature_names: list[str] | None = None
# Real-Time Chunking (RTC) configuration # Real-Time Chunking (RTC) configuration
-19
View File
@@ -44,7 +44,6 @@ from lerobot.configs.policies import PreTrainedConfig
from lerobot.policies.pi0.configuration_pi0 import DEFAULT_IMAGE_SIZE, PI0Config from lerobot.policies.pi0.configuration_pi0 import DEFAULT_IMAGE_SIZE, PI0Config
from lerobot.policies.pretrained import PreTrainedPolicy, T from lerobot.policies.pretrained import PreTrainedPolicy, T
from lerobot.policies.rtc.modeling_rtc import RTCProcessor from lerobot.policies.rtc.modeling_rtc import RTCProcessor
from lerobot.processor.delta_action_processor import to_absolute_actions, to_delta_actions
from lerobot.utils.constants import ( from lerobot.utils.constants import (
ACTION, ACTION,
OBS_LANGUAGE_ATTENTION_MASK, OBS_LANGUAGE_ATTENTION_MASK,
@@ -1221,18 +1220,6 @@ class PI0Policy(PreTrainedPolicy):
state = pad_vector(batch[OBS_STATE], self.config.max_state_dim) state = pad_vector(batch[OBS_STATE], self.config.max_state_dim)
return state return state
def _build_delta_mask(self, action_dim: int) -> list[bool]:
"""Build a boolean mask for delta action conversion.
Uses action_feature_names and delta_exclude_joints to determine which
dims get delta conversion. Falls back to all-True if names are unavailable.
"""
names = self.config.action_feature_names
if names is None:
return [True] * action_dim
exclude = set(self.config.delta_exclude_joints)
return [n not in exclude for n in names]
def prepare_action(self, batch): def prepare_action(self, batch):
"""Pad action""" """Pad action"""
actions = pad_vector(batch[ACTION], self.config.max_action_dim) actions = pad_vector(batch[ACTION], self.config.max_action_dim)
@@ -1272,9 +1259,6 @@ class PI0Policy(PreTrainedPolicy):
original_action_dim = self.config.output_features[ACTION].shape[0] original_action_dim = self.config.output_features[ACTION].shape[0]
actions = actions[:, :, :original_action_dim] actions = actions[:, :, :original_action_dim]
if self.config.use_delta_actions:
actions = to_absolute_actions(actions, state, self._build_delta_mask(actions.shape[-1]))
return actions return actions
def forward(self, batch: dict[str, Tensor], reduction: str = "mean") -> tuple[Tensor, dict]: def forward(self, batch: dict[str, Tensor], reduction: str = "mean") -> tuple[Tensor, dict]:
@@ -1292,9 +1276,6 @@ class PI0Policy(PreTrainedPolicy):
state = self.prepare_state(batch) state = self.prepare_state(batch)
actions = self.prepare_action(batch) actions = self.prepare_action(batch)
if self.config.use_delta_actions:
actions = to_delta_actions(actions, state, self._build_delta_mask(actions.shape[-1]))
# Compute loss # Compute loss
losses = self.model.forward(images, img_masks, lang_tokens, lang_masks, state, actions) losses = self.model.forward(images, img_masks, lang_tokens, lang_masks, state, actions)
+10 -2
View File
@@ -21,6 +21,7 @@ import torch
from lerobot.configs.types import PipelineFeatureType, PolicyFeature from lerobot.configs.types import PipelineFeatureType, PolicyFeature
from lerobot.policies.pi0.configuration_pi0 import PI0Config from lerobot.policies.pi0.configuration_pi0 import PI0Config
from lerobot.processor import ( from lerobot.processor import (
AbsoluteActionsProcessorStep,
AddBatchDimensionProcessorStep, AddBatchDimensionProcessorStep,
ComplementaryDataProcessorStep, ComplementaryDataProcessorStep,
DeltaActionsProcessorStep, DeltaActionsProcessorStep,
@@ -127,7 +128,13 @@ def make_pi0_pre_post_processors(
A tuple containing the configured pre-processor and post-processor pipelines. A tuple containing the configured pre-processor and post-processor pipelines.
""" """
# Add remaining processors delta_step = DeltaActionsProcessorStep(
enabled=config.use_delta_actions,
exclude_joints=getattr(config, "delta_exclude_joints", []),
action_names=getattr(config, "action_feature_names", None),
)
# OpenPI order: raw → delta → normalize → model → unnormalize → absolute
input_steps: list[ProcessorStep] = [ input_steps: list[ProcessorStep] = [
RenameObservationsProcessorStep(rename_map={}), # To mimic the same processor as pretrained one RenameObservationsProcessorStep(rename_map={}), # To mimic the same processor as pretrained one
AddBatchDimensionProcessorStep(), AddBatchDimensionProcessorStep(),
@@ -139,18 +146,19 @@ def make_pi0_pre_post_processors(
padding="max_length", padding="max_length",
), ),
DeviceProcessorStep(device=config.device), DeviceProcessorStep(device=config.device),
delta_step,
NormalizerProcessorStep( NormalizerProcessorStep(
features={**config.input_features, **config.output_features}, features={**config.input_features, **config.output_features},
norm_map=config.normalization_mapping, norm_map=config.normalization_mapping,
stats=dataset_stats, stats=dataset_stats,
), ),
DeltaActionsProcessorStep(enabled=config.use_delta_actions),
] ]
output_steps: list[ProcessorStep] = [ output_steps: list[ProcessorStep] = [
UnnormalizerProcessorStep( UnnormalizerProcessorStep(
features=config.output_features, norm_map=config.normalization_mapping, stats=dataset_stats features=config.output_features, norm_map=config.normalization_mapping, stats=dataset_stats
), ),
AbsoluteActionsProcessorStep(enabled=config.use_delta_actions, delta_step=delta_step),
DeviceProcessorStep(device="cpu"), DeviceProcessorStep(device="cpu"),
] ]
@@ -52,9 +52,9 @@ class PI05Config(PreTrainedConfig):
# Delta actions: converts absolute actions to delta (relative to state). # Delta actions: converts absolute actions to delta (relative to state).
use_delta_actions: bool = False use_delta_actions: bool = False
# Joint names to exclude from delta conversion (kept as absolute). # Joint names to exclude from delta (kept absolute). Empty list = all dims delta.
delta_exclude_joints: list[str] = field(default_factory=lambda: ["gripper"]) delta_exclude_joints: list[str] = field(default_factory=lambda: ["gripper"])
# Populated at runtime by make_policy from dataset metadata. # Populated at runtime from dataset metadata by make_policy.
action_feature_names: list[str] | None = None action_feature_names: list[str] | None = None
# Real-Time Chunking (RTC) configuration # Real-Time Chunking (RTC) configuration
@@ -44,12 +44,10 @@ from lerobot.configs.policies import PreTrainedConfig
from lerobot.policies.pi05.configuration_pi05 import DEFAULT_IMAGE_SIZE, PI05Config from lerobot.policies.pi05.configuration_pi05 import DEFAULT_IMAGE_SIZE, PI05Config
from lerobot.policies.pretrained import PreTrainedPolicy, T from lerobot.policies.pretrained import PreTrainedPolicy, T
from lerobot.policies.rtc.modeling_rtc import RTCProcessor from lerobot.policies.rtc.modeling_rtc import RTCProcessor
from lerobot.processor.delta_action_processor import to_absolute_actions, to_delta_actions
from lerobot.utils.constants import ( from lerobot.utils.constants import (
ACTION, ACTION,
OBS_LANGUAGE_ATTENTION_MASK, OBS_LANGUAGE_ATTENTION_MASK,
OBS_LANGUAGE_TOKENS, OBS_LANGUAGE_TOKENS,
OBS_STATE,
OPENPI_ATTENTION_MASK_VALUE, OPENPI_ATTENTION_MASK_VALUE,
) )
@@ -1201,14 +1199,6 @@ class PI05Policy(PreTrainedPolicy):
actions = pad_vector(batch[ACTION], self.config.max_action_dim) actions = pad_vector(batch[ACTION], self.config.max_action_dim)
return actions return actions
def _build_delta_mask(self, action_dim: int) -> list[bool]:
"""Build a boolean mask for delta action conversion."""
names = self.config.action_feature_names
if names is None:
return [True] * action_dim
exclude = set(self.config.delta_exclude_joints)
return [n not in exclude for n in names]
@torch.no_grad() @torch.no_grad()
def select_action(self, batch: dict[str, Tensor]) -> Tensor: def select_action(self, batch: dict[str, Tensor]) -> Tensor:
"""Select a single action given environment observations.""" """Select a single action given environment observations."""
@@ -1242,10 +1232,6 @@ class PI05Policy(PreTrainedPolicy):
original_action_dim = self.config.output_features[ACTION].shape[0] original_action_dim = self.config.output_features[ACTION].shape[0]
actions = actions[:, :, :original_action_dim] actions = actions[:, :, :original_action_dim]
if self.config.use_delta_actions:
state = pad_vector(batch[OBS_STATE], self.config.max_state_dim)
actions = to_absolute_actions(actions, state, self._build_delta_mask(actions.shape[-1]))
return actions return actions
def forward(self, batch: dict[str, Tensor], reduction: str = "mean") -> tuple[Tensor, dict]: def forward(self, batch: dict[str, Tensor], reduction: str = "mean") -> tuple[Tensor, dict]:
@@ -1263,10 +1249,6 @@ class PI05Policy(PreTrainedPolicy):
actions = self.prepare_action(batch) actions = self.prepare_action(batch)
if self.config.use_delta_actions:
state = pad_vector(batch[OBS_STATE], self.config.max_state_dim)
actions = to_delta_actions(actions, state, self._build_delta_mask(actions.shape[-1]))
# Compute loss (no separate state needed for PI05) # Compute loss (no separate state needed for PI05)
losses = self.model.forward(images, img_masks, tokens, masks, actions) losses = self.model.forward(images, img_masks, tokens, masks, actions)
+12 -2
View File
@@ -25,6 +25,7 @@ from lerobot.configs.types import PipelineFeatureType, PolicyFeature
from lerobot.policies.pi05.configuration_pi05 import PI05Config from lerobot.policies.pi05.configuration_pi05 import PI05Config
from lerobot.policies.pi05.modeling_pi05 import pad_vector from lerobot.policies.pi05.modeling_pi05 import pad_vector
from lerobot.processor import ( from lerobot.processor import (
AbsoluteActionsProcessorStep,
AddBatchDimensionProcessorStep, AddBatchDimensionProcessorStep,
DeltaActionsProcessorStep, DeltaActionsProcessorStep,
DeviceProcessorStep, DeviceProcessorStep,
@@ -130,10 +131,19 @@ def make_pi05_pre_post_processors(
A tuple containing the configured pre-processor and post-processor pipelines. A tuple containing the configured pre-processor and post-processor pipelines.
""" """
# Add remaining processors delta_step = DeltaActionsProcessorStep(
enabled=config.use_delta_actions,
exclude_joints=getattr(config, "delta_exclude_joints", []),
action_names=getattr(config, "action_feature_names", None),
)
# OpenPI order: raw → delta → normalize → model → unnormalize → absolute
# NOTE: NormalizerProcessorStep MUST come before Pi05PrepareStateTokenizerProcessorStep
# because the tokenizer step expects normalized state in [-1, 1] range for discretization
input_steps: list[ProcessorStep] = [ input_steps: list[ProcessorStep] = [
RenameObservationsProcessorStep(rename_map={}), # To mimic the same processor as pretrained one RenameObservationsProcessorStep(rename_map={}), # To mimic the same processor as pretrained one
AddBatchDimensionProcessorStep(), AddBatchDimensionProcessorStep(),
delta_step,
# NOTE: NormalizerProcessorStep MUST come before Pi05PrepareStateTokenizerProcessorStep # NOTE: NormalizerProcessorStep MUST come before Pi05PrepareStateTokenizerProcessorStep
# because the tokenizer step expects normalized state in [-1, 1] range for discretization # because the tokenizer step expects normalized state in [-1, 1] range for discretization
NormalizerProcessorStep( NormalizerProcessorStep(
@@ -141,7 +151,6 @@ def make_pi05_pre_post_processors(
norm_map=config.normalization_mapping, norm_map=config.normalization_mapping,
stats=dataset_stats, stats=dataset_stats,
), ),
DeltaActionsProcessorStep(enabled=config.use_delta_actions),
Pi05PrepareStateTokenizerProcessorStep(max_state_dim=config.max_state_dim), Pi05PrepareStateTokenizerProcessorStep(max_state_dim=config.max_state_dim),
TokenizerProcessorStep( TokenizerProcessorStep(
tokenizer_name="google/paligemma-3b-pt-224", tokenizer_name="google/paligemma-3b-pt-224",
@@ -156,6 +165,7 @@ def make_pi05_pre_post_processors(
UnnormalizerProcessorStep( UnnormalizerProcessorStep(
features=config.output_features, norm_map=config.normalization_mapping, stats=dataset_stats features=config.output_features, norm_map=config.normalization_mapping, stats=dataset_stats
), ),
AbsoluteActionsProcessorStep(enabled=config.use_delta_actions, delta_step=delta_step),
DeviceProcessorStep(device="cpu"), DeviceProcessorStep(device="cpu"),
] ]
+2
View File
@@ -29,6 +29,7 @@ from .core import (
TransitionKey, TransitionKey,
) )
from .delta_action_processor import ( from .delta_action_processor import (
AbsoluteActionsProcessorStep,
DeltaActionsProcessorStep, DeltaActionsProcessorStep,
MapDeltaActionToRobotActionStep, MapDeltaActionToRobotActionStep,
MapTensorToDeltaActionDictStep, MapTensorToDeltaActionDictStep,
@@ -103,6 +104,7 @@ __all__ = [
"make_default_teleop_action_processor", "make_default_teleop_action_processor",
"make_default_robot_action_processor", "make_default_robot_action_processor",
"make_default_robot_observation_processor", "make_default_robot_observation_processor",
"AbsoluteActionsProcessorStep",
"DeltaActionsProcessorStep", "DeltaActionsProcessorStep",
"MapDeltaActionToRobotActionStep", "MapDeltaActionToRobotActionStep",
"MapTensorToDeltaActionDictStep", "MapTensorToDeltaActionDictStep",
@@ -15,7 +15,7 @@
# limitations under the License. # limitations under the License.
from collections.abc import Sequence from collections.abc import Sequence
from dataclasses import dataclass from dataclasses import dataclass, field
from typing import Any from typing import Any
import torch import torch
@@ -188,33 +188,102 @@ class MapDeltaActionToRobotActionStep(RobotActionProcessorStep):
@ProcessorStepRegistry.register("delta_actions_processor")
@dataclass
class DeltaActionsProcessorStep(ProcessorStep):
    """Converts absolute actions to delta actions (action -= state) for masked dimensions.

    Mirrors OpenPI's DeltaActions transform. Applied during preprocessing so the model
    trains on relative offsets instead of absolute positions.

    Caches the last seen state so a paired AbsoluteActionsProcessorStep can reverse
    the conversion during postprocessing.

    Attributes:
        enabled: Whether to apply the delta conversion.
        exclude_joints: Joint names to keep absolute (not converted to delta).
        action_names: Action dimension names from dataset metadata, used to build
            the mask from exclude_joints. If None, all dims are converted.
    """

    enabled: bool = False
    exclude_joints: list[str] = field(default_factory=list)
    action_names: list[str] | None = None
    # Most recent observation state seen by __call__; read by the paired
    # AbsoluteActionsProcessorStep. Runtime-only: excluded from __init__ and repr.
    _last_state: torch.Tensor | None = field(default=None, init=False, repr=False)

    def _build_mask(self, action_dim: int) -> list[bool]:
        """Return a per-dimension mask: True = convert to delta, False = keep absolute.

        Args:
            action_dim: Size of the last action dimension. Only used for the
                all-True fallback when there is nothing to exclude or no names
                are available.

        Returns:
            A list of booleans. In the fallback case it has ``action_dim`` entries;
            otherwise it has one entry per element of ``action_names``.
        """
        if not self.exclude_joints or self.action_names is None:
            return [True] * action_dim
        # Exclusion is an exact name match against action_names.
        # NOTE(review): the mask length follows action_names, not action_dim --
        # confirm that to_delta_actions tolerates a length mismatch.
        exclude = set(self.exclude_joints)
        return [n not in exclude for n in self.action_names]

    def __call__(self, transition: EnvTransition) -> EnvTransition:
        """Cache the observation state and, if enabled, convert the action to delta.

        Args:
            transition: The incoming transition.

        Returns:
            The original transition when the step is disabled; otherwise a shallow
            copy whose action has been converted by ``to_delta_actions`` (left
            unchanged if either the action or the state is missing).
        """
        observation = transition.get(TransitionKey.OBSERVATION, {})
        state = observation.get(OBS_STATE) if observation else None

        # Always cache state for the paired AbsoluteActionsProcessorStep,
        # even when this step is disabled.
        if state is not None:
            self._last_state = state

        if not self.enabled:
            return transition

        new_transition = transition.copy()
        action = new_transition.get(TransitionKey.ACTION)
        if action is None or state is None:
            return new_transition

        mask = self._build_mask(action.shape[-1])
        new_transition[TransitionKey.ACTION] = to_delta_actions(action, state, mask)
        return new_transition

    def get_config(self) -> dict[str, Any]:
        """Return the constructor arguments needed to rebuild this step.

        ``action_names`` is included because the exclusion mask is derived from it:
        without the names, a rebuilt step would fall back to converting every
        dimension, including the ones listed in ``exclude_joints``.
        """
        return {
            "enabled": self.enabled,
            "exclude_joints": self.exclude_joints,
            "action_names": self.action_names,
        }

    def transform_features(
        self, features: dict[PipelineFeatureType, dict[str, PolicyFeature]]
    ) -> dict[PipelineFeatureType, dict[str, PolicyFeature]]:
        """Return ``features`` unchanged."""
        return features
@ProcessorStepRegistry.register("absolute_actions_processor")
@dataclass
class AbsoluteActionsProcessorStep(ProcessorStep):
"""Converts delta actions back to absolute actions (action += state) for the dimensions masked by the paired delta step.
Mirrors OpenPI's AbsoluteActions transform. Applied during postprocessing so
predicted deltas are converted back to absolute positions for execution.
Reads the cached state from its paired DeltaActionsProcessorStep.
Attributes:
enabled: Whether to apply the absolute conversion.
delta_step: Reference to the paired DeltaActionsProcessorStep that caches state.
"""
enabled: bool = False
delta_step: DeltaActionsProcessorStep | None = field(default=None, repr=False)
def __call__(self, transition: EnvTransition) -> EnvTransition: def __call__(self, transition: EnvTransition) -> EnvTransition:
if not self.enabled: if not self.enabled:
return transition return transition
if self.delta_step is None:
raise RuntimeError(
"AbsoluteActionsProcessorStep requires a paired DeltaActionsProcessorStep "
"but delta_step is None. Ensure delta_step is set when constructing the postprocessor."
)
if self.delta_step._last_state is None:
raise RuntimeError(
"AbsoluteActionsProcessorStep requires state from DeltaActionsProcessorStep "
"but no state has been cached. Ensure the preprocessor runs before the postprocessor."
)
new_transition = transition.copy() new_transition = transition.copy()
action = new_transition.get(TransitionKey.ACTION) action = new_transition.get(TransitionKey.ACTION)
if action is None: if action is None:
return new_transition return new_transition
observation = new_transition.get(TransitionKey.OBSERVATION, {}) mask = self.delta_step._build_mask(action.shape[-1])
state = observation.get(OBS_STATE) if observation else None new_transition[TransitionKey.ACTION] = to_absolute_actions(
if state is None: action, self.delta_step._last_state, mask
return new_transition )
mask = [True] * action.shape[-1]
new_transition[TransitionKey.ACTION] = to_delta_actions(action, state, mask)
return new_transition return new_transition
def get_config(self) -> dict[str, Any]: def get_config(self) -> dict[str, Any]:
+1 -1
View File
@@ -248,7 +248,7 @@ def train(cfg: TrainPipelineConfig, accelerator: Accelerator | None = None):
logging.info("use_delta_actions is enabled — recomputing action stats as delta (action - state)") logging.info("use_delta_actions is enabled — recomputing action stats as delta (action - state)")
from lerobot.datasets.dataset_tools import recompute_stats from lerobot.datasets.dataset_tools import recompute_stats
exclude = getattr(cfg.policy, "delta_exclude_joints", ["gripper"]) exclude = getattr(cfg.policy, "delta_exclude_joints", [])
recompute_stats(dataset, skip_image_video=True, delta_action=True, delta_exclude_joints=exclude) recompute_stats(dataset, skip_image_video=True, delta_action=True, delta_exclude_joints=exclude)
# Wait for all processes to finish policy creation before continuing # Wait for all processes to finish policy creation before continuing