mirror of
https://github.com/huggingface/lerobot.git
synced 2026-05-16 00:59:46 +00:00
Add option for pi family models to train with relative actions (relative to state)
This commit is contained in:
@@ -50,6 +50,9 @@ class PI0Config(PreTrainedConfig):
|
||||
min_period: float = 4e-3
|
||||
max_period: float = 4.0
|
||||
|
||||
# Delta actions: converts absolute actions to delta (relative to state).
|
||||
use_delta_actions: bool = False
|
||||
|
||||
# Real-Time Chunking (RTC) configuration
|
||||
rtc_config: RTCConfig | None = None
|
||||
|
||||
|
||||
@@ -44,6 +44,7 @@ from lerobot.configs.policies import PreTrainedConfig
|
||||
from lerobot.policies.pi0.configuration_pi0 import DEFAULT_IMAGE_SIZE, PI0Config
|
||||
from lerobot.policies.pretrained import PreTrainedPolicy, T
|
||||
from lerobot.policies.rtc.modeling_rtc import RTCProcessor
|
||||
from lerobot.processor.delta_action_processor import to_absolute_actions, to_delta_actions
|
||||
from lerobot.utils.constants import (
|
||||
ACTION,
|
||||
OBS_LANGUAGE_ATTENTION_MASK,
|
||||
@@ -1259,6 +1260,9 @@ class PI0Policy(PreTrainedPolicy):
|
||||
original_action_dim = self.config.output_features[ACTION].shape[0]
|
||||
actions = actions[:, :, :original_action_dim]
|
||||
|
||||
if self.config.use_delta_actions:
|
||||
actions = to_absolute_actions(actions, state, [True] * actions.shape[-1])
|
||||
|
||||
return actions
|
||||
|
||||
def forward(self, batch: dict[str, Tensor], reduction: str = "mean") -> tuple[Tensor, dict]:
|
||||
@@ -1276,6 +1280,9 @@ class PI0Policy(PreTrainedPolicy):
|
||||
state = self.prepare_state(batch)
|
||||
actions = self.prepare_action(batch)
|
||||
|
||||
if self.config.use_delta_actions:
|
||||
actions = to_delta_actions(actions, state, [True] * actions.shape[-1])
|
||||
|
||||
# Compute loss
|
||||
losses = self.model.forward(images, img_masks, lang_tokens, lang_masks, state, actions)
|
||||
|
||||
|
||||
@@ -23,6 +23,7 @@ from lerobot.policies.pi0.configuration_pi0 import PI0Config
|
||||
from lerobot.processor import (
|
||||
AddBatchDimensionProcessorStep,
|
||||
ComplementaryDataProcessorStep,
|
||||
DeltaActionsProcessorStep,
|
||||
DeviceProcessorStep,
|
||||
NormalizerProcessorStep,
|
||||
PolicyAction,
|
||||
@@ -143,6 +144,7 @@ def make_pi0_pre_post_processors(
|
||||
norm_map=config.normalization_mapping,
|
||||
stats=dataset_stats,
|
||||
),
|
||||
DeltaActionsProcessorStep(enabled=config.use_delta_actions),
|
||||
]
|
||||
|
||||
output_steps: list[ProcessorStep] = [
|
||||
|
||||
@@ -50,6 +50,9 @@ class PI05Config(PreTrainedConfig):
|
||||
min_period: float = 4e-3
|
||||
max_period: float = 4.0
|
||||
|
||||
# Delta actions: converts absolute actions to delta (relative to state).
|
||||
use_delta_actions: bool = False
|
||||
|
||||
# Real-Time Chunking (RTC) configuration
|
||||
rtc_config: RTCConfig | None = None
|
||||
|
||||
|
||||
@@ -44,10 +44,12 @@ from lerobot.configs.policies import PreTrainedConfig
|
||||
from lerobot.policies.pi05.configuration_pi05 import DEFAULT_IMAGE_SIZE, PI05Config
|
||||
from lerobot.policies.pretrained import PreTrainedPolicy, T
|
||||
from lerobot.policies.rtc.modeling_rtc import RTCProcessor
|
||||
from lerobot.processor.delta_action_processor import to_absolute_actions, to_delta_actions
|
||||
from lerobot.utils.constants import (
|
||||
ACTION,
|
||||
OBS_LANGUAGE_ATTENTION_MASK,
|
||||
OBS_LANGUAGE_TOKENS,
|
||||
OBS_STATE,
|
||||
OPENPI_ATTENTION_MASK_VALUE,
|
||||
)
|
||||
|
||||
@@ -1232,6 +1234,10 @@ class PI05Policy(PreTrainedPolicy):
|
||||
original_action_dim = self.config.output_features[ACTION].shape[0]
|
||||
actions = actions[:, :, :original_action_dim]
|
||||
|
||||
if self.config.use_delta_actions:
|
||||
state = pad_vector(batch[OBS_STATE], self.config.max_state_dim)
|
||||
actions = to_absolute_actions(actions, state, [True] * actions.shape[-1])
|
||||
|
||||
return actions
|
||||
|
||||
def forward(self, batch: dict[str, Tensor], reduction: str = "mean") -> tuple[Tensor, dict]:
|
||||
@@ -1249,6 +1255,10 @@ class PI05Policy(PreTrainedPolicy):
|
||||
|
||||
actions = self.prepare_action(batch)
|
||||
|
||||
if self.config.use_delta_actions:
|
||||
state = pad_vector(batch[OBS_STATE], self.config.max_state_dim)
|
||||
actions = to_delta_actions(actions, state, [True] * actions.shape[-1])
|
||||
|
||||
# Compute loss (no separate state needed for PI05)
|
||||
losses = self.model.forward(images, img_masks, tokens, masks, actions)
|
||||
|
||||
|
||||
@@ -26,6 +26,7 @@ from lerobot.policies.pi05.configuration_pi05 import PI05Config
|
||||
from lerobot.policies.pi05.modeling_pi05 import pad_vector
|
||||
from lerobot.processor import (
|
||||
AddBatchDimensionProcessorStep,
|
||||
DeltaActionsProcessorStep,
|
||||
DeviceProcessorStep,
|
||||
NormalizerProcessorStep,
|
||||
PolicyAction,
|
||||
@@ -140,6 +141,7 @@ def make_pi05_pre_post_processors(
|
||||
norm_map=config.normalization_mapping,
|
||||
stats=dataset_stats,
|
||||
),
|
||||
DeltaActionsProcessorStep(enabled=config.use_delta_actions),
|
||||
Pi05PrepareStateTokenizerProcessorStep(max_state_dim=config.max_state_dim),
|
||||
TokenizerProcessorStep(
|
||||
tokenizer_name="google/paligemma-3b-pt-224",
|
||||
|
||||
@@ -41,6 +41,9 @@ class PI0FastConfig(PreTrainedConfig):
|
||||
max_action_dim: int = 32
|
||||
max_action_tokens: int = 256
|
||||
|
||||
# Delta actions: converts absolute actions to delta (relative to state).
|
||||
use_delta_actions: bool = False
|
||||
|
||||
# Real-Time Chunking (RTC) configuration
|
||||
rtc_config: RTCConfig | None = None
|
||||
|
||||
|
||||
@@ -48,12 +48,14 @@ from lerobot.configs.policies import PreTrainedConfig
|
||||
from lerobot.policies.pi0_fast.configuration_pi0_fast import PI0FastConfig
|
||||
from lerobot.policies.pretrained import PreTrainedPolicy, T
|
||||
from lerobot.policies.rtc.modeling_rtc import RTCProcessor
|
||||
from lerobot.processor.delta_action_processor import to_absolute_actions
|
||||
from lerobot.utils.constants import (
|
||||
ACTION,
|
||||
ACTION_TOKEN_MASK,
|
||||
ACTION_TOKENS,
|
||||
OBS_LANGUAGE_ATTENTION_MASK,
|
||||
OBS_LANGUAGE_TOKENS,
|
||||
OBS_STATE,
|
||||
OPENPI_ATTENTION_MASK_VALUE,
|
||||
)
|
||||
|
||||
@@ -1315,6 +1317,12 @@ class PI0FastPolicy(PreTrainedPolicy):
|
||||
action_tokens, action_horizon=action_horizon, action_dim=action_dim
|
||||
)
|
||||
|
||||
if self.config.use_delta_actions and OBS_STATE in batch:
|
||||
state = pad_vector(batch[OBS_STATE], self.config.max_state_dim)
|
||||
continuous_actions = to_absolute_actions(
|
||||
continuous_actions, state, [True] * continuous_actions.shape[-1]
|
||||
)
|
||||
|
||||
return continuous_actions
|
||||
|
||||
def forward(self, batch: dict[str, Tensor]) -> tuple[Tensor, dict]:
|
||||
|
||||
@@ -27,6 +27,7 @@ from lerobot.policies.pi0_fast.modeling_pi0_fast import pad_vector
|
||||
from lerobot.processor import (
|
||||
ActionTokenizerProcessorStep,
|
||||
AddBatchDimensionProcessorStep,
|
||||
DeltaActionsProcessorStep,
|
||||
DeviceProcessorStep,
|
||||
NormalizerProcessorStep,
|
||||
PolicyAction,
|
||||
@@ -147,6 +148,7 @@ def make_pi0_fast_pre_post_processors(
|
||||
padding_side="right",
|
||||
padding="max_length",
|
||||
),
|
||||
DeltaActionsProcessorStep(enabled=config.use_delta_actions),
|
||||
ActionTokenizerProcessorStep(
|
||||
action_tokenizer_name=config.action_tokenizer_name,
|
||||
max_action_tokens=config.max_action_tokens,
|
||||
|
||||
@@ -28,7 +28,13 @@ from .core import (
|
||||
RobotObservation,
|
||||
TransitionKey,
|
||||
)
|
||||
from .delta_action_processor import MapDeltaActionToRobotActionStep, MapTensorToDeltaActionDictStep
|
||||
from .delta_action_processor import (
|
||||
DeltaActionsProcessorStep,
|
||||
MapDeltaActionToRobotActionStep,
|
||||
MapTensorToDeltaActionDictStep,
|
||||
to_absolute_actions,
|
||||
to_delta_actions,
|
||||
)
|
||||
from .device_processor import DeviceProcessorStep
|
||||
from .factory import (
|
||||
make_default_processors,
|
||||
@@ -97,6 +103,7 @@ __all__ = [
|
||||
"make_default_teleop_action_processor",
|
||||
"make_default_robot_action_processor",
|
||||
"make_default_robot_observation_processor",
|
||||
"DeltaActionsProcessorStep",
|
||||
"MapDeltaActionToRobotActionStep",
|
||||
"MapTensorToDeltaActionDictStep",
|
||||
"NormalizerProcessorStep",
|
||||
@@ -126,6 +133,8 @@ __all__ = [
|
||||
"transition_to_batch",
|
||||
"TransitionKey",
|
||||
"TruncatedProcessorStep",
|
||||
"to_absolute_actions",
|
||||
"to_delta_actions",
|
||||
"UnnormalizerProcessorStep",
|
||||
"VanillaObservationProcessorStep",
|
||||
]
|
||||
|
||||
@@ -14,12 +14,54 @@
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
from collections.abc import Sequence
|
||||
from dataclasses import dataclass
|
||||
from typing import Any
|
||||
|
||||
import torch
|
||||
from torch import Tensor
|
||||
|
||||
from lerobot.configs.types import FeatureType, PipelineFeatureType, PolicyFeature
|
||||
from lerobot.utils.constants import OBS_STATE
|
||||
|
||||
from .core import PolicyAction, RobotAction
|
||||
from .pipeline import ActionProcessorStep, ProcessorStepRegistry, RobotActionProcessorStep
|
||||
from .core import EnvTransition, PolicyAction, RobotAction, TransitionKey
|
||||
from .pipeline import ActionProcessorStep, ProcessorStep, ProcessorStepRegistry, RobotActionProcessorStep
|
||||
|
||||
|
||||
def _masked_state_offset(actions: Tensor, state: Tensor, mask: Sequence[bool]) -> tuple[Tensor, int]:
    """Build the state offset shared by the delta and absolute conversions.

    Args:
        actions: (B, T, action_dim) or (B, action_dim). Only its rank and device are used.
        state: (B, state_dim).
        mask: Per-dimension flags; only the first ``len(mask)`` dims are considered.

    Returns:
        A pair ``(offset, dims)`` where ``dims == len(mask)`` and ``offset`` holds
        ``state[..., :dims]`` with every unmasked entry replaced by zero. For 3-D
        actions a singleton time axis is inserted so the offset broadcasts over T.
    """
    keep = torch.tensor(mask, dtype=torch.bool, device=actions.device)
    dims = keep.shape[0]
    leading_state = state[..., :dims]
    # Select rather than multiply by a 0/1 mask: ``nan * 0`` is still ``nan``, so a
    # non-finite state entry in an *unmasked* dimension would otherwise leak into
    # the action even though that dimension is supposed to be left untouched.
    offset = torch.where(keep, leading_state, torch.zeros_like(leading_state))
    if actions.ndim == 3:
        offset = offset.unsqueeze(-2)
    return offset, dims


def to_delta_actions(actions: Tensor, state: Tensor, mask: Sequence[bool]) -> Tensor:
    """Convert absolute actions to delta: delta = action - state (for masked dims).

    Args:
        actions: (B, T, action_dim) or (B, action_dim).
        state: (B, state_dim). Broadcast across time dimension.
        mask: Which dims to convert. Can be shorter than action_dim; dims beyond
            ``len(mask)`` and dims whose flag is False are returned unchanged.

    Returns:
        A new tensor with the same shape as ``actions``; the input is not modified.
    """
    offset, dims = _masked_state_offset(actions, state, mask)
    # Work on a copy so the caller's tensor is never mutated by the in-place update.
    delta = actions.clone()
    delta[..., :dims] -= offset
    return delta


def to_absolute_actions(actions: Tensor, state: Tensor, mask: Sequence[bool]) -> Tensor:
    """Convert delta actions back to absolute: absolute = delta + state (for masked dims).

    Args:
        actions: (B, T, action_dim) or (B, action_dim).
        state: (B, state_dim). Broadcast across time dimension.
        mask: Which dims to convert. Can be shorter than action_dim; dims beyond
            ``len(mask)`` and dims whose flag is False are returned unchanged.

    Returns:
        A new tensor with the same shape as ``actions``; the input is not modified.
    """
    offset, dims = _masked_state_offset(actions, state, mask)
    # Work on a copy so the caller's tensor is never mutated by the in-place update.
    absolute = actions.clone()
    absolute[..., :dims] += offset
    return absolute
|
||||
|
||||
|
||||
@ProcessorStepRegistry.register("map_tensor_to_delta_action_dict")
|
||||
@@ -141,3 +183,44 @@ class MapDeltaActionToRobotActionStep(RobotActionProcessorStep):
|
||||
)
|
||||
|
||||
return features
|
||||
|
||||
|
||||
@ProcessorStepRegistry.register("delta_actions_processor")
@dataclass
class DeltaActionsProcessorStep(ProcessorStep):
    """Preprocessing step that rewrites the transition's action as ``action - state``.

    When enabled, every action dimension is converted (an all-True mask is passed to
    ``to_delta_actions``). The step is a pass-through when disabled, when the
    transition carries no action, or when the observation has no ``OBS_STATE`` entry.

    NOTE(review): the pi0 / pi05 policies' ``forward`` also appear to call
    ``to_delta_actions`` when ``use_delta_actions`` is set - confirm the delta is not
    applied twice when this step is part of their preprocessing pipeline.

    Attributes:
        enabled: Whether to apply the delta conversion.
    """

    enabled: bool = False

    def __call__(self, transition: EnvTransition) -> EnvTransition:
        """Return ``transition`` itself when disabled, else a shallow copy with a delta action."""
        if not self.enabled:
            return transition

        converted = transition.copy()
        action = converted.get(TransitionKey.ACTION)
        if action is not None:
            observation = converted.get(TransitionKey.OBSERVATION, {})
            state = observation.get(OBS_STATE) if observation else None
            if state is not None:
                # Convert every action dimension.
                all_dims = [True] * action.shape[-1]
                converted[TransitionKey.ACTION] = to_delta_actions(action, state, all_dims)
        return converted

    def get_config(self) -> dict[str, Any]:
        """Return the step's configuration as a dict holding the ``enabled`` flag."""
        return {"enabled": self.enabled}

    def transform_features(
        self, features: dict[PipelineFeatureType, dict[str, PolicyFeature]]
    ) -> dict[PipelineFeatureType, dict[str, PolicyFeature]]:
        """Return ``features`` unchanged."""
        return features
|
||||
|
||||
@@ -0,0 +1,122 @@
|
||||
"""Tests for delta action transforms using a local dummy dataset."""
|
||||
|
||||
import numpy as np
|
||||
import pytest
|
||||
import torch
|
||||
|
||||
from lerobot.processor import TransitionKey, batch_to_transition
|
||||
from lerobot.processor.delta_action_processor import (
|
||||
DeltaActionsProcessorStep,
|
||||
to_absolute_actions,
|
||||
to_delta_actions,
|
||||
)
|
||||
from lerobot.utils.constants import ACTION, OBS_STATE
|
||||
|
||||
ACTION_DIM = 14
|
||||
STATE_DIM = 14
|
||||
|
||||
|
||||
@pytest.fixture
def dataset(tmp_path, empty_lerobot_dataset_factory):
    """Build a small on-disk dataset: 2 episodes of 5 frames with random action/state vectors."""
    features = {
        "action": {"dtype": "float32", "shape": (ACTION_DIM,), "names": None},
        "observation.state": {"dtype": "float32", "shape": (STATE_DIM,), "names": None},
    }
    ds = empty_lerobot_dataset_factory(root=tmp_path / "delta_test", features=features)
    # Seeded generator: the tests compare values derived from this data, so a failure
    # must be reproducible rather than depend on the global NumPy RNG state.
    rng = np.random.default_rng(0)
    for ep in range(2):
        for _ in range(5):
            ds.add_frame(
                {
                    "action": rng.standard_normal(ACTION_DIM).astype(np.float32),
                    "observation.state": rng.standard_normal(STATE_DIM).astype(np.float32),
                    "task": f"task_{ep}",
                }
            )
        ds.save_episode()
    ds.finalize()
    return ds
|
||||
|
||||
|
||||
def _collate(dataset, indices):
|
||||
items = [dataset[i] for i in indices]
|
||||
batch = {}
|
||||
for key in items[0]:
|
||||
vals = [item[key] for item in items]
|
||||
if isinstance(vals[0], torch.Tensor):
|
||||
batch[key] = torch.stack(vals)
|
||||
else:
|
||||
batch[key] = vals
|
||||
return batch
|
||||
|
||||
|
||||
def test_roundtrip_3d(dataset):
    """Converting a (B, T, action_dim) chunk to delta and back must recover it."""
    batch = _collate(dataset, range(4))
    state = batch[OBS_STATE]
    # Repeat each frame's action over a 10-step horizon to get a 3-D action chunk.
    chunk = batch[ACTION].unsqueeze(1).expand(-1, 10, -1).clone()
    all_dims = [True] * chunk.shape[-1]

    as_delta = to_delta_actions(chunk, state, all_dims)
    restored = to_absolute_actions(as_delta, state, all_dims)

    torch.testing.assert_close(restored, chunk)
|
||||
|
||||
|
||||
def test_roundtrip_2d(dataset):
    """The delta/absolute round trip must also hold for (B, action_dim) actions."""
    batch = _collate(dataset, range(4))
    flat_actions, state = batch[ACTION], batch[OBS_STATE]
    all_dims = [True] * flat_actions.shape[-1]

    as_delta = to_delta_actions(flat_actions, state, all_dims)
    restored = to_absolute_actions(as_delta, state, all_dims)

    torch.testing.assert_close(restored, flat_actions)
|
||||
|
||||
|
||||
def test_delta_changes_all_dims(dataset):
    """With an all-True mask every action dimension is offset by the state."""
    batch = _collate(dataset, range(4))
    actions = batch[ACTION].unsqueeze(1)
    state = batch[OBS_STATE]
    mask = [True] * actions.shape[-1]

    delta = to_delta_actions(actions, state, mask)

    # Check the value in every dimension, not just that the summed difference is
    # non-zero: a single changed element would satisfy the latter.
    torch.testing.assert_close(delta, actions - state.unsqueeze(1))
    # Every action dimension must differ from the input for at least one sample.
    assert (delta != actions).any(dim=0).all()
|
||||
|
||||
|
||||
def test_no_mutation(dataset):
    """Neither conversion may modify its input tensors in place."""
    batch = _collate(dataset, range(2))
    actions = batch[ACTION].unsqueeze(1)
    state = batch[OBS_STATE]
    actions_before = actions.clone()
    state_before = state.clone()
    mask = [True] * actions.shape[-1]

    to_delta_actions(actions, state, mask)
    to_absolute_actions(actions, state, mask)

    # Zero tolerance: "not mutated" means bit-for-bit identical, for both inputs.
    torch.testing.assert_close(actions, actions_before, rtol=0, atol=0)
    torch.testing.assert_close(state, state_before, rtol=0, atol=0)
|
||||
|
||||
|
||||
def test_processor_step_roundtrip(dataset):
    """An enabled DeltaActionsProcessorStep yields deltas that to_absolute_actions inverts."""
    batch = _collate(dataset, range(4))
    reference = batch[ACTION].clone()
    transition = batch_to_transition(batch)

    processed = DeltaActionsProcessorStep(enabled=True)(transition)
    as_delta = processed[TransitionKey.ACTION]
    # The step must actually have changed the action.
    assert not torch.allclose(as_delta, reference)

    # Undo the conversion with the state taken from the input transition.
    state = transition[TransitionKey.OBSERVATION][OBS_STATE]
    all_dims = [True] * reference.shape[-1]
    restored = to_absolute_actions(as_delta, state, all_dims)
    torch.testing.assert_close(restored, reference)
|
||||
|
||||
|
||||
def test_processor_step_disabled_is_noop(dataset):
    """A step built with enabled=False must leave the action untouched."""
    batch = _collate(dataset, range(2))
    reference = batch[ACTION].clone()

    disabled_step = DeltaActionsProcessorStep(enabled=False)
    output = disabled_step(batch_to_transition(batch))

    torch.testing.assert_close(output[TransitionKey.ACTION], reference)
|
||||
Reference in New Issue
Block a user