Add option for pi family models to train with relative actions (relative to state)

This commit is contained in:
Pepijn
2026-02-13 17:45:59 +01:00
parent 6600b60e7f
commit 1de2b87a92
12 changed files with 257 additions and 3 deletions
@@ -50,6 +50,9 @@ class PI0Config(PreTrainedConfig):
min_period: float = 4e-3
max_period: float = 4.0
# Delta actions: converts absolute actions to delta (relative to state)
use_delta_actions: bool = False
# Real-Time Chunking (RTC) configuration
rtc_config: RTCConfig | None = None
+7
View File
@@ -44,6 +44,7 @@ from lerobot.configs.policies import PreTrainedConfig
from lerobot.policies.pi0.configuration_pi0 import DEFAULT_IMAGE_SIZE, PI0Config
from lerobot.policies.pretrained import PreTrainedPolicy, T
from lerobot.policies.rtc.modeling_rtc import RTCProcessor
from lerobot.processor.delta_action_processor import to_absolute_actions, to_delta_actions
from lerobot.utils.constants import (
ACTION,
OBS_LANGUAGE_ATTENTION_MASK,
@@ -1259,6 +1260,9 @@ class PI0Policy(PreTrainedPolicy):
original_action_dim = self.config.output_features[ACTION].shape[0]
actions = actions[:, :, :original_action_dim]
if self.config.use_delta_actions:
actions = to_absolute_actions(actions, state, [True] * actions.shape[-1])
return actions
def forward(self, batch: dict[str, Tensor], reduction: str = "mean") -> tuple[Tensor, dict]:
@@ -1276,6 +1280,9 @@ class PI0Policy(PreTrainedPolicy):
state = self.prepare_state(batch)
actions = self.prepare_action(batch)
if self.config.use_delta_actions:
actions = to_delta_actions(actions, state, [True] * actions.shape[-1])
# Compute loss
losses = self.model.forward(images, img_masks, lang_tokens, lang_masks, state, actions)
@@ -23,6 +23,7 @@ from lerobot.policies.pi0.configuration_pi0 import PI0Config
from lerobot.processor import (
AddBatchDimensionProcessorStep,
ComplementaryDataProcessorStep,
DeltaActionsProcessorStep,
DeviceProcessorStep,
NormalizerProcessorStep,
PolicyAction,
@@ -143,6 +144,7 @@ def make_pi0_pre_post_processors(
norm_map=config.normalization_mapping,
stats=dataset_stats,
),
DeltaActionsProcessorStep(enabled=config.use_delta_actions),
]
output_steps: list[ProcessorStep] = [
@@ -50,6 +50,9 @@ class PI05Config(PreTrainedConfig):
min_period: float = 4e-3
max_period: float = 4.0
# Delta actions: converts absolute actions to delta (relative to state).
use_delta_actions: bool = False
# Real-Time Chunking (RTC) configuration
rtc_config: RTCConfig | None = None
@@ -44,10 +44,12 @@ from lerobot.configs.policies import PreTrainedConfig
from lerobot.policies.pi05.configuration_pi05 import DEFAULT_IMAGE_SIZE, PI05Config
from lerobot.policies.pretrained import PreTrainedPolicy, T
from lerobot.policies.rtc.modeling_rtc import RTCProcessor
from lerobot.processor.delta_action_processor import to_absolute_actions, to_delta_actions
from lerobot.utils.constants import (
ACTION,
OBS_LANGUAGE_ATTENTION_MASK,
OBS_LANGUAGE_TOKENS,
OBS_STATE,
OPENPI_ATTENTION_MASK_VALUE,
)
@@ -1232,6 +1234,10 @@ class PI05Policy(PreTrainedPolicy):
original_action_dim = self.config.output_features[ACTION].shape[0]
actions = actions[:, :, :original_action_dim]
if self.config.use_delta_actions:
state = pad_vector(batch[OBS_STATE], self.config.max_state_dim)
actions = to_absolute_actions(actions, state, [True] * actions.shape[-1])
return actions
def forward(self, batch: dict[str, Tensor], reduction: str = "mean") -> tuple[Tensor, dict]:
@@ -1249,6 +1255,10 @@ class PI05Policy(PreTrainedPolicy):
actions = self.prepare_action(batch)
if self.config.use_delta_actions:
state = pad_vector(batch[OBS_STATE], self.config.max_state_dim)
actions = to_delta_actions(actions, state, [True] * actions.shape[-1])
# Compute loss (no separate state needed for PI05)
losses = self.model.forward(images, img_masks, tokens, masks, actions)
@@ -26,6 +26,7 @@ from lerobot.policies.pi05.configuration_pi05 import PI05Config
from lerobot.policies.pi05.modeling_pi05 import pad_vector
from lerobot.processor import (
AddBatchDimensionProcessorStep,
DeltaActionsProcessorStep,
DeviceProcessorStep,
NormalizerProcessorStep,
PolicyAction,
@@ -140,6 +141,7 @@ def make_pi05_pre_post_processors(
norm_map=config.normalization_mapping,
stats=dataset_stats,
),
DeltaActionsProcessorStep(enabled=config.use_delta_actions),
Pi05PrepareStateTokenizerProcessorStep(max_state_dim=config.max_state_dim),
TokenizerProcessorStep(
tokenizer_name="google/paligemma-3b-pt-224",
@@ -41,6 +41,9 @@ class PI0FastConfig(PreTrainedConfig):
max_action_dim: int = 32
max_action_tokens: int = 256
# Delta actions: converts absolute actions to delta (relative to state).
use_delta_actions: bool = False
# Real-Time Chunking (RTC) configuration
rtc_config: RTCConfig | None = None
@@ -48,12 +48,14 @@ from lerobot.configs.policies import PreTrainedConfig
from lerobot.policies.pi0_fast.configuration_pi0_fast import PI0FastConfig
from lerobot.policies.pretrained import PreTrainedPolicy, T
from lerobot.policies.rtc.modeling_rtc import RTCProcessor
from lerobot.processor.delta_action_processor import to_absolute_actions
from lerobot.utils.constants import (
ACTION,
ACTION_TOKEN_MASK,
ACTION_TOKENS,
OBS_LANGUAGE_ATTENTION_MASK,
OBS_LANGUAGE_TOKENS,
OBS_STATE,
OPENPI_ATTENTION_MASK_VALUE,
)
@@ -1315,6 +1317,12 @@ class PI0FastPolicy(PreTrainedPolicy):
action_tokens, action_horizon=action_horizon, action_dim=action_dim
)
if self.config.use_delta_actions and OBS_STATE in batch:
state = pad_vector(batch[OBS_STATE], self.config.max_state_dim)
continuous_actions = to_absolute_actions(
continuous_actions, state, [True] * continuous_actions.shape[-1]
)
return continuous_actions
def forward(self, batch: dict[str, Tensor]) -> tuple[Tensor, dict]:
@@ -27,6 +27,7 @@ from lerobot.policies.pi0_fast.modeling_pi0_fast import pad_vector
from lerobot.processor import (
ActionTokenizerProcessorStep,
AddBatchDimensionProcessorStep,
DeltaActionsProcessorStep,
DeviceProcessorStep,
NormalizerProcessorStep,
PolicyAction,
@@ -147,6 +148,7 @@ def make_pi0_fast_pre_post_processors(
padding_side="right",
padding="max_length",
),
DeltaActionsProcessorStep(enabled=config.use_delta_actions),
ActionTokenizerProcessorStep(
action_tokenizer_name=config.action_tokenizer_name,
max_action_tokens=config.max_action_tokens,
+10 -1
View File
@@ -28,7 +28,13 @@ from .core import (
RobotObservation,
TransitionKey,
)
from .delta_action_processor import MapDeltaActionToRobotActionStep, MapTensorToDeltaActionDictStep
from .delta_action_processor import (
DeltaActionsProcessorStep,
MapDeltaActionToRobotActionStep,
MapTensorToDeltaActionDictStep,
to_absolute_actions,
to_delta_actions,
)
from .device_processor import DeviceProcessorStep
from .factory import (
make_default_processors,
@@ -97,6 +103,7 @@ __all__ = [
"make_default_teleop_action_processor",
"make_default_robot_action_processor",
"make_default_robot_observation_processor",
"DeltaActionsProcessorStep",
"MapDeltaActionToRobotActionStep",
"MapTensorToDeltaActionDictStep",
"NormalizerProcessorStep",
@@ -126,6 +133,8 @@ __all__ = [
"transition_to_batch",
"TransitionKey",
"TruncatedProcessorStep",
"to_absolute_actions",
"to_delta_actions",
"UnnormalizerProcessorStep",
"VanillaObservationProcessorStep",
]
@@ -14,12 +14,54 @@
# See the License for the specific language governing permissions and
# limitations under the License.
from collections.abc import Sequence
from dataclasses import dataclass
from typing import Any
import torch
from torch import Tensor
from lerobot.configs.types import FeatureType, PipelineFeatureType, PolicyFeature
from lerobot.utils.constants import OBS_STATE
from .core import PolicyAction, RobotAction
from .pipeline import ActionProcessorStep, ProcessorStepRegistry, RobotActionProcessorStep
from .core import EnvTransition, PolicyAction, RobotAction, TransitionKey
from .pipeline import ActionProcessorStep, ProcessorStep, ProcessorStepRegistry, RobotActionProcessorStep
def to_delta_actions(actions: Tensor, state: Tensor, mask: Sequence[bool]) -> Tensor:
    """Convert absolute actions to delta form: ``delta = action - state`` on masked dims.

    Args:
        actions: Action tensor of shape ``(B, T, action_dim)`` or ``(B, action_dim)``.
        state: State tensor of shape ``(B, state_dim)``; broadcast over the time axis.
        mask: Per-dimension switch; may be shorter than ``action_dim``, in which case
            the trailing dimensions are left untouched.

    Returns:
        A new tensor; the input ``actions`` is not modified.
    """
    selector = torch.tensor(mask, dtype=actions.dtype, device=actions.device)
    n = len(mask)
    offset = state[..., :n] * selector
    # Insert a singleton time axis so the offset broadcasts over (B, T, n) chunks.
    if actions.ndim == 3:
        offset = offset[..., None, :]
    out = actions.clone()
    out[..., :n] = out[..., :n] - offset
    return out
def to_absolute_actions(actions: Tensor, state: Tensor, mask: Sequence[bool]) -> Tensor:
    """Convert delta actions back to absolute form: ``absolute = delta + state`` on masked dims.

    Args:
        actions: Action tensor of shape ``(B, T, action_dim)`` or ``(B, action_dim)``.
        state: State tensor of shape ``(B, state_dim)``; broadcast over the time axis.
        mask: Per-dimension switch; may be shorter than ``action_dim``, in which case
            the trailing dimensions are left untouched.

    Returns:
        A new tensor; the input ``actions`` is not modified.
    """
    selector = torch.tensor(mask, dtype=actions.dtype, device=actions.device)
    n = len(mask)
    offset = state[..., :n] * selector
    # Insert a singleton time axis so the offset broadcasts over (B, T, n) chunks.
    if actions.ndim == 3:
        offset = offset[..., None, :]
    out = actions.clone()
    out[..., :n] = out[..., :n] + offset
    return out
@ProcessorStepRegistry.register("map_tensor_to_delta_action_dict")
@@ -141,3 +183,44 @@ class MapDeltaActionToRobotActionStep(RobotActionProcessorStep):
)
return features
@ProcessorStepRegistry.register("delta_actions_processor")
@dataclass
class DeltaActionsProcessorStep(ProcessorStep):
    """Rewrites absolute actions as offsets from the current state (``action - state``).

    Mirrors OpenPI's ``DeltaActions`` transform: applied at preprocessing time so the
    policy trains on relative motions rather than absolute positions. All action
    dimensions are converted.

    Attributes:
        enabled: When False the step passes the transition through untouched.
    """

    enabled: bool = False

    def __call__(self, transition: EnvTransition) -> EnvTransition:
        """Return a shallow copy whose ACTION entry holds delta actions.

        The transition is passed through unchanged when the step is disabled, and
        the copy is returned unmodified when either the action or the observation
        state is missing.
        """
        if not self.enabled:
            return transition
        result = transition.copy()
        action = result.get(TransitionKey.ACTION)
        if action is None:
            return result
        obs = result.get(TransitionKey.OBSERVATION, {})
        state = obs.get(OBS_STATE) if obs else None
        if state is None:
            return result
        full_mask = [True] * action.shape[-1]
        result[TransitionKey.ACTION] = to_delta_actions(action, state, full_mask)
        return result

    def get_config(self) -> dict[str, Any]:
        """Serializable configuration for this step."""
        return {"enabled": self.enabled}

    def transform_features(
        self, features: dict[PipelineFeatureType, dict[str, PolicyFeature]]
    ) -> dict[PipelineFeatureType, dict[str, PolicyFeature]]:
        """Delta conversion keeps feature shapes and dtypes, so pass features through."""
        return features
+122
View File
@@ -0,0 +1,122 @@
"""Tests for delta action transforms using a local dummy dataset."""
import numpy as np
import pytest
import torch
from lerobot.processor import TransitionKey, batch_to_transition
from lerobot.processor.delta_action_processor import (
DeltaActionsProcessorStep,
to_absolute_actions,
to_delta_actions,
)
from lerobot.utils.constants import ACTION, OBS_STATE
ACTION_DIM = 14
STATE_DIM = 14
@pytest.fixture
def dataset(tmp_path, empty_lerobot_dataset_factory):
    """Build a tiny on-disk LeRobot dataset with random 14-dim actions and states.

    Creates 2 episodes of 5 frames each under ``tmp_path / "delta_test"`` via the
    project-provided ``empty_lerobot_dataset_factory`` fixture, then finalizes it
    so the tests can index frames directly.
    """
    # Declare the per-frame schema the factory needs to create the dataset.
    features = {
        "action": {"dtype": "float32", "shape": (ACTION_DIM,), "names": None},
        "observation.state": {"dtype": "float32", "shape": (STATE_DIM,), "names": None},
    }
    ds = empty_lerobot_dataset_factory(root=tmp_path / "delta_test", features=features)
    for ep in range(2):
        for _ in range(5):
            ds.add_frame({
                "action": np.random.randn(ACTION_DIM).astype(np.float32),
                "observation.state": np.random.randn(STATE_DIM).astype(np.float32),
                "task": f"task_{ep}",
            })
        # Episodes must be saved frame-batch by frame-batch before finalizing.
        ds.save_episode()
    ds.finalize()
    return ds
def _collate(dataset, indices):
items = [dataset[i] for i in indices]
batch = {}
for key in items[0]:
vals = [item[key] for item in items]
if isinstance(vals[0], torch.Tensor):
batch[key] = torch.stack(vals)
else:
batch[key] = vals
return batch
def test_roundtrip_3d(dataset):
    """to_delta_actions followed by to_absolute_actions is the identity on (B, T, D) data."""
    batch = _collate(dataset, range(4))
    state = batch[OBS_STATE]
    # Tile each action over a 10-step horizon to exercise the chunked (B, T, D) path.
    actions = batch[ACTION].unsqueeze(1).expand(-1, 10, -1).clone()
    full_mask = [True] * actions.shape[-1]
    roundtripped = to_absolute_actions(to_delta_actions(actions, state, full_mask), state, full_mask)
    torch.testing.assert_close(roundtripped, actions)
def test_roundtrip_2d(dataset):
    """The transforms also round-trip on unchunked (B, action_dim) tensors."""
    batch = _collate(dataset, range(4))
    actions = batch[ACTION]
    state = batch[OBS_STATE]
    full_mask = [True] * actions.shape[-1]
    roundtripped = to_absolute_actions(to_delta_actions(actions, state, full_mask), state, full_mask)
    torch.testing.assert_close(roundtripped, actions)
def test_delta_changes_all_dims(dataset):
    """With an all-True mask, subtracting the state must actually alter the actions."""
    batch = _collate(dataset, range(4))
    actions = batch[ACTION].unsqueeze(1)
    delta = to_delta_actions(actions, batch[OBS_STATE], [True] * actions.shape[-1])
    total_shift = (delta - actions).abs().sum()
    assert total_shift > 0
def test_no_mutation(dataset):
    """to_delta_actions returns a new tensor and leaves its input untouched."""
    batch = _collate(dataset, range(2))
    actions = batch[ACTION].unsqueeze(1)
    snapshot = actions.clone()
    to_delta_actions(actions, batch[OBS_STATE], [True] * actions.shape[-1])
    torch.testing.assert_close(actions, snapshot)
def test_processor_step_roundtrip(dataset):
    """The enabled step subtracts the state; adding it back recovers the originals."""
    batch = _collate(dataset, range(4))
    original_actions = batch[ACTION].clone()
    transition = batch_to_transition(batch)

    delta_transition = DeltaActionsProcessorStep(enabled=True)(transition)
    delta_actions = delta_transition[TransitionKey.ACTION]
    assert not torch.allclose(delta_actions, original_actions)

    state = transition[TransitionKey.OBSERVATION][OBS_STATE]
    recovered = to_absolute_actions(delta_actions, state, [True] * original_actions.shape[-1])
    torch.testing.assert_close(recovered, original_actions)
def test_processor_step_disabled_is_noop(dataset):
    """With enabled=False the step must leave the action untouched."""
    batch = _collate(dataset, range(2))
    expected = batch[ACTION].clone()
    result = DeltaActionsProcessorStep(enabled=False)(batch_to_transition(batch))
    torch.testing.assert_close(result[TransitionKey.ACTION], expected)