mirror of
https://github.com/huggingface/lerobot.git
synced 2026-05-21 19:49:49 +00:00
Add option for pi family models to train with relative actions (relative to state)
This commit is contained in:
@@ -50,6 +50,9 @@ class PI0Config(PreTrainedConfig):
|
|||||||
min_period: float = 4e-3
|
min_period: float = 4e-3
|
||||||
max_period: float = 4.0
|
max_period: float = 4.0
|
||||||
|
|
||||||
|
# Delta actions: converts absolute actions to delta (relative to state)
|
||||||
|
use_delta_actions: bool = False
|
||||||
|
|
||||||
# Real-Time Chunking (RTC) configuration
|
# Real-Time Chunking (RTC) configuration
|
||||||
rtc_config: RTCConfig | None = None
|
rtc_config: RTCConfig | None = None
|
||||||
|
|
||||||
|
|||||||
@@ -44,6 +44,7 @@ from lerobot.configs.policies import PreTrainedConfig
|
|||||||
from lerobot.policies.pi0.configuration_pi0 import DEFAULT_IMAGE_SIZE, PI0Config
|
from lerobot.policies.pi0.configuration_pi0 import DEFAULT_IMAGE_SIZE, PI0Config
|
||||||
from lerobot.policies.pretrained import PreTrainedPolicy, T
|
from lerobot.policies.pretrained import PreTrainedPolicy, T
|
||||||
from lerobot.policies.rtc.modeling_rtc import RTCProcessor
|
from lerobot.policies.rtc.modeling_rtc import RTCProcessor
|
||||||
|
from lerobot.processor.delta_action_processor import to_absolute_actions, to_delta_actions
|
||||||
from lerobot.utils.constants import (
|
from lerobot.utils.constants import (
|
||||||
ACTION,
|
ACTION,
|
||||||
OBS_LANGUAGE_ATTENTION_MASK,
|
OBS_LANGUAGE_ATTENTION_MASK,
|
||||||
@@ -1259,6 +1260,9 @@ class PI0Policy(PreTrainedPolicy):
|
|||||||
original_action_dim = self.config.output_features[ACTION].shape[0]
|
original_action_dim = self.config.output_features[ACTION].shape[0]
|
||||||
actions = actions[:, :, :original_action_dim]
|
actions = actions[:, :, :original_action_dim]
|
||||||
|
|
||||||
|
if self.config.use_delta_actions:
|
||||||
|
actions = to_absolute_actions(actions, state, [True] * actions.shape[-1])
|
||||||
|
|
||||||
return actions
|
return actions
|
||||||
|
|
||||||
def forward(self, batch: dict[str, Tensor], reduction: str = "mean") -> tuple[Tensor, dict]:
|
def forward(self, batch: dict[str, Tensor], reduction: str = "mean") -> tuple[Tensor, dict]:
|
||||||
@@ -1276,6 +1280,9 @@ class PI0Policy(PreTrainedPolicy):
|
|||||||
state = self.prepare_state(batch)
|
state = self.prepare_state(batch)
|
||||||
actions = self.prepare_action(batch)
|
actions = self.prepare_action(batch)
|
||||||
|
|
||||||
|
if self.config.use_delta_actions:
|
||||||
|
actions = to_delta_actions(actions, state, [True] * actions.shape[-1])
|
||||||
|
|
||||||
# Compute loss
|
# Compute loss
|
||||||
losses = self.model.forward(images, img_masks, lang_tokens, lang_masks, state, actions)
|
losses = self.model.forward(images, img_masks, lang_tokens, lang_masks, state, actions)
|
||||||
|
|
||||||
|
|||||||
@@ -23,6 +23,7 @@ from lerobot.policies.pi0.configuration_pi0 import PI0Config
|
|||||||
from lerobot.processor import (
|
from lerobot.processor import (
|
||||||
AddBatchDimensionProcessorStep,
|
AddBatchDimensionProcessorStep,
|
||||||
ComplementaryDataProcessorStep,
|
ComplementaryDataProcessorStep,
|
||||||
|
DeltaActionsProcessorStep,
|
||||||
DeviceProcessorStep,
|
DeviceProcessorStep,
|
||||||
NormalizerProcessorStep,
|
NormalizerProcessorStep,
|
||||||
PolicyAction,
|
PolicyAction,
|
||||||
@@ -143,6 +144,7 @@ def make_pi0_pre_post_processors(
|
|||||||
norm_map=config.normalization_mapping,
|
norm_map=config.normalization_mapping,
|
||||||
stats=dataset_stats,
|
stats=dataset_stats,
|
||||||
),
|
),
|
||||||
|
DeltaActionsProcessorStep(enabled=config.use_delta_actions),
|
||||||
]
|
]
|
||||||
|
|
||||||
output_steps: list[ProcessorStep] = [
|
output_steps: list[ProcessorStep] = [
|
||||||
|
|||||||
@@ -50,6 +50,9 @@ class PI05Config(PreTrainedConfig):
|
|||||||
min_period: float = 4e-3
|
min_period: float = 4e-3
|
||||||
max_period: float = 4.0
|
max_period: float = 4.0
|
||||||
|
|
||||||
|
# Delta actions: converts absolute actions to delta (relative to state).
|
||||||
|
use_delta_actions: bool = False
|
||||||
|
|
||||||
# Real-Time Chunking (RTC) configuration
|
# Real-Time Chunking (RTC) configuration
|
||||||
rtc_config: RTCConfig | None = None
|
rtc_config: RTCConfig | None = None
|
||||||
|
|
||||||
|
|||||||
@@ -44,10 +44,12 @@ from lerobot.configs.policies import PreTrainedConfig
|
|||||||
from lerobot.policies.pi05.configuration_pi05 import DEFAULT_IMAGE_SIZE, PI05Config
|
from lerobot.policies.pi05.configuration_pi05 import DEFAULT_IMAGE_SIZE, PI05Config
|
||||||
from lerobot.policies.pretrained import PreTrainedPolicy, T
|
from lerobot.policies.pretrained import PreTrainedPolicy, T
|
||||||
from lerobot.policies.rtc.modeling_rtc import RTCProcessor
|
from lerobot.policies.rtc.modeling_rtc import RTCProcessor
|
||||||
|
from lerobot.processor.delta_action_processor import to_absolute_actions, to_delta_actions
|
||||||
from lerobot.utils.constants import (
|
from lerobot.utils.constants import (
|
||||||
ACTION,
|
ACTION,
|
||||||
OBS_LANGUAGE_ATTENTION_MASK,
|
OBS_LANGUAGE_ATTENTION_MASK,
|
||||||
OBS_LANGUAGE_TOKENS,
|
OBS_LANGUAGE_TOKENS,
|
||||||
|
OBS_STATE,
|
||||||
OPENPI_ATTENTION_MASK_VALUE,
|
OPENPI_ATTENTION_MASK_VALUE,
|
||||||
)
|
)
|
||||||
|
|
||||||
@@ -1232,6 +1234,10 @@ class PI05Policy(PreTrainedPolicy):
|
|||||||
original_action_dim = self.config.output_features[ACTION].shape[0]
|
original_action_dim = self.config.output_features[ACTION].shape[0]
|
||||||
actions = actions[:, :, :original_action_dim]
|
actions = actions[:, :, :original_action_dim]
|
||||||
|
|
||||||
|
if self.config.use_delta_actions:
|
||||||
|
state = pad_vector(batch[OBS_STATE], self.config.max_state_dim)
|
||||||
|
actions = to_absolute_actions(actions, state, [True] * actions.shape[-1])
|
||||||
|
|
||||||
return actions
|
return actions
|
||||||
|
|
||||||
def forward(self, batch: dict[str, Tensor], reduction: str = "mean") -> tuple[Tensor, dict]:
|
def forward(self, batch: dict[str, Tensor], reduction: str = "mean") -> tuple[Tensor, dict]:
|
||||||
@@ -1249,6 +1255,10 @@ class PI05Policy(PreTrainedPolicy):
|
|||||||
|
|
||||||
actions = self.prepare_action(batch)
|
actions = self.prepare_action(batch)
|
||||||
|
|
||||||
|
if self.config.use_delta_actions:
|
||||||
|
state = pad_vector(batch[OBS_STATE], self.config.max_state_dim)
|
||||||
|
actions = to_delta_actions(actions, state, [True] * actions.shape[-1])
|
||||||
|
|
||||||
# Compute loss (no separate state needed for PI05)
|
# Compute loss (no separate state needed for PI05)
|
||||||
losses = self.model.forward(images, img_masks, tokens, masks, actions)
|
losses = self.model.forward(images, img_masks, tokens, masks, actions)
|
||||||
|
|
||||||
|
|||||||
@@ -26,6 +26,7 @@ from lerobot.policies.pi05.configuration_pi05 import PI05Config
|
|||||||
from lerobot.policies.pi05.modeling_pi05 import pad_vector
|
from lerobot.policies.pi05.modeling_pi05 import pad_vector
|
||||||
from lerobot.processor import (
|
from lerobot.processor import (
|
||||||
AddBatchDimensionProcessorStep,
|
AddBatchDimensionProcessorStep,
|
||||||
|
DeltaActionsProcessorStep,
|
||||||
DeviceProcessorStep,
|
DeviceProcessorStep,
|
||||||
NormalizerProcessorStep,
|
NormalizerProcessorStep,
|
||||||
PolicyAction,
|
PolicyAction,
|
||||||
@@ -140,6 +141,7 @@ def make_pi05_pre_post_processors(
|
|||||||
norm_map=config.normalization_mapping,
|
norm_map=config.normalization_mapping,
|
||||||
stats=dataset_stats,
|
stats=dataset_stats,
|
||||||
),
|
),
|
||||||
|
DeltaActionsProcessorStep(enabled=config.use_delta_actions),
|
||||||
Pi05PrepareStateTokenizerProcessorStep(max_state_dim=config.max_state_dim),
|
Pi05PrepareStateTokenizerProcessorStep(max_state_dim=config.max_state_dim),
|
||||||
TokenizerProcessorStep(
|
TokenizerProcessorStep(
|
||||||
tokenizer_name="google/paligemma-3b-pt-224",
|
tokenizer_name="google/paligemma-3b-pt-224",
|
||||||
|
|||||||
@@ -41,6 +41,9 @@ class PI0FastConfig(PreTrainedConfig):
|
|||||||
max_action_dim: int = 32
|
max_action_dim: int = 32
|
||||||
max_action_tokens: int = 256
|
max_action_tokens: int = 256
|
||||||
|
|
||||||
|
# Delta actions: converts absolute actions to delta (relative to state).
|
||||||
|
use_delta_actions: bool = False
|
||||||
|
|
||||||
# Real-Time Chunking (RTC) configuration
|
# Real-Time Chunking (RTC) configuration
|
||||||
rtc_config: RTCConfig | None = None
|
rtc_config: RTCConfig | None = None
|
||||||
|
|
||||||
|
|||||||
@@ -48,12 +48,14 @@ from lerobot.configs.policies import PreTrainedConfig
|
|||||||
from lerobot.policies.pi0_fast.configuration_pi0_fast import PI0FastConfig
|
from lerobot.policies.pi0_fast.configuration_pi0_fast import PI0FastConfig
|
||||||
from lerobot.policies.pretrained import PreTrainedPolicy, T
|
from lerobot.policies.pretrained import PreTrainedPolicy, T
|
||||||
from lerobot.policies.rtc.modeling_rtc import RTCProcessor
|
from lerobot.policies.rtc.modeling_rtc import RTCProcessor
|
||||||
|
from lerobot.processor.delta_action_processor import to_absolute_actions
|
||||||
from lerobot.utils.constants import (
|
from lerobot.utils.constants import (
|
||||||
ACTION,
|
ACTION,
|
||||||
ACTION_TOKEN_MASK,
|
ACTION_TOKEN_MASK,
|
||||||
ACTION_TOKENS,
|
ACTION_TOKENS,
|
||||||
OBS_LANGUAGE_ATTENTION_MASK,
|
OBS_LANGUAGE_ATTENTION_MASK,
|
||||||
OBS_LANGUAGE_TOKENS,
|
OBS_LANGUAGE_TOKENS,
|
||||||
|
OBS_STATE,
|
||||||
OPENPI_ATTENTION_MASK_VALUE,
|
OPENPI_ATTENTION_MASK_VALUE,
|
||||||
)
|
)
|
||||||
|
|
||||||
@@ -1315,6 +1317,12 @@ class PI0FastPolicy(PreTrainedPolicy):
|
|||||||
action_tokens, action_horizon=action_horizon, action_dim=action_dim
|
action_tokens, action_horizon=action_horizon, action_dim=action_dim
|
||||||
)
|
)
|
||||||
|
|
||||||
|
if self.config.use_delta_actions and OBS_STATE in batch:
|
||||||
|
state = pad_vector(batch[OBS_STATE], self.config.max_state_dim)
|
||||||
|
continuous_actions = to_absolute_actions(
|
||||||
|
continuous_actions, state, [True] * continuous_actions.shape[-1]
|
||||||
|
)
|
||||||
|
|
||||||
return continuous_actions
|
return continuous_actions
|
||||||
|
|
||||||
def forward(self, batch: dict[str, Tensor]) -> tuple[Tensor, dict]:
|
def forward(self, batch: dict[str, Tensor]) -> tuple[Tensor, dict]:
|
||||||
|
|||||||
@@ -27,6 +27,7 @@ from lerobot.policies.pi0_fast.modeling_pi0_fast import pad_vector
|
|||||||
from lerobot.processor import (
|
from lerobot.processor import (
|
||||||
ActionTokenizerProcessorStep,
|
ActionTokenizerProcessorStep,
|
||||||
AddBatchDimensionProcessorStep,
|
AddBatchDimensionProcessorStep,
|
||||||
|
DeltaActionsProcessorStep,
|
||||||
DeviceProcessorStep,
|
DeviceProcessorStep,
|
||||||
NormalizerProcessorStep,
|
NormalizerProcessorStep,
|
||||||
PolicyAction,
|
PolicyAction,
|
||||||
@@ -147,6 +148,7 @@ def make_pi0_fast_pre_post_processors(
|
|||||||
padding_side="right",
|
padding_side="right",
|
||||||
padding="max_length",
|
padding="max_length",
|
||||||
),
|
),
|
||||||
|
DeltaActionsProcessorStep(enabled=config.use_delta_actions),
|
||||||
ActionTokenizerProcessorStep(
|
ActionTokenizerProcessorStep(
|
||||||
action_tokenizer_name=config.action_tokenizer_name,
|
action_tokenizer_name=config.action_tokenizer_name,
|
||||||
max_action_tokens=config.max_action_tokens,
|
max_action_tokens=config.max_action_tokens,
|
||||||
|
|||||||
@@ -28,7 +28,13 @@ from .core import (
|
|||||||
RobotObservation,
|
RobotObservation,
|
||||||
TransitionKey,
|
TransitionKey,
|
||||||
)
|
)
|
||||||
from .delta_action_processor import MapDeltaActionToRobotActionStep, MapTensorToDeltaActionDictStep
|
from .delta_action_processor import (
|
||||||
|
DeltaActionsProcessorStep,
|
||||||
|
MapDeltaActionToRobotActionStep,
|
||||||
|
MapTensorToDeltaActionDictStep,
|
||||||
|
to_absolute_actions,
|
||||||
|
to_delta_actions,
|
||||||
|
)
|
||||||
from .device_processor import DeviceProcessorStep
|
from .device_processor import DeviceProcessorStep
|
||||||
from .factory import (
|
from .factory import (
|
||||||
make_default_processors,
|
make_default_processors,
|
||||||
@@ -97,6 +103,7 @@ __all__ = [
|
|||||||
"make_default_teleop_action_processor",
|
"make_default_teleop_action_processor",
|
||||||
"make_default_robot_action_processor",
|
"make_default_robot_action_processor",
|
||||||
"make_default_robot_observation_processor",
|
"make_default_robot_observation_processor",
|
||||||
|
"DeltaActionsProcessorStep",
|
||||||
"MapDeltaActionToRobotActionStep",
|
"MapDeltaActionToRobotActionStep",
|
||||||
"MapTensorToDeltaActionDictStep",
|
"MapTensorToDeltaActionDictStep",
|
||||||
"NormalizerProcessorStep",
|
"NormalizerProcessorStep",
|
||||||
@@ -126,6 +133,8 @@ __all__ = [
|
|||||||
"transition_to_batch",
|
"transition_to_batch",
|
||||||
"TransitionKey",
|
"TransitionKey",
|
||||||
"TruncatedProcessorStep",
|
"TruncatedProcessorStep",
|
||||||
|
"to_absolute_actions",
|
||||||
|
"to_delta_actions",
|
||||||
"UnnormalizerProcessorStep",
|
"UnnormalizerProcessorStep",
|
||||||
"VanillaObservationProcessorStep",
|
"VanillaObservationProcessorStep",
|
||||||
]
|
]
|
||||||
|
|||||||
@@ -14,12 +14,54 @@
|
|||||||
# See the License for the specific language governing permissions and
|
# See the License for the specific language governing permissions and
|
||||||
# limitations under the License.
|
# limitations under the License.
|
||||||
|
|
||||||
|
from collections.abc import Sequence
|
||||||
from dataclasses import dataclass
|
from dataclasses import dataclass
|
||||||
|
from typing import Any
|
||||||
|
|
||||||
|
import torch
|
||||||
|
from torch import Tensor
|
||||||
|
|
||||||
from lerobot.configs.types import FeatureType, PipelineFeatureType, PolicyFeature
|
from lerobot.configs.types import FeatureType, PipelineFeatureType, PolicyFeature
|
||||||
|
from lerobot.utils.constants import OBS_STATE
|
||||||
|
|
||||||
from .core import PolicyAction, RobotAction
|
from .core import EnvTransition, PolicyAction, RobotAction, TransitionKey
|
||||||
from .pipeline import ActionProcessorStep, ProcessorStepRegistry, RobotActionProcessorStep
|
from .pipeline import ActionProcessorStep, ProcessorStep, ProcessorStepRegistry, RobotActionProcessorStep
|
||||||
|
|
||||||
|
|
||||||
|
def to_delta_actions(actions: Tensor, state: Tensor, mask: Sequence[bool]) -> Tensor:
    """Convert absolute actions to deltas: delta = action - state on masked dims.

    Only the first ``len(mask)`` dims are considered; among those, dims whose
    mask entry is False are left untouched, as are all dims past ``len(mask)``.

    Args:
        actions: (B, T, action_dim) or (B, action_dim).
        state: (B, state_dim). Broadcast across the time dimension.
        mask: Which dims to convert. Can be shorter than action_dim.

    Returns:
        A new tensor with the same shape as ``actions``; the input is not modified.

    Raises:
        ValueError: If ``actions`` or ``state`` has fewer than ``len(mask)``
            entries along its last dimension.
    """
    mask_t = torch.tensor(mask, dtype=torch.bool, device=actions.device)
    dims = mask_t.shape[0]
    if actions.shape[-1] < dims or state.shape[-1] < dims:
        raise ValueError(
            f"mask covers {dims} dims but actions has {actions.shape[-1]} "
            f"and state has {state.shape[-1]} along the last dimension"
        )

    state_slice = state[..., :dims]
    # Select instead of multiplying by a 0/1 mask: 0 * nan is nan, so a
    # non-finite state entry would otherwise corrupt dims that are masked out.
    state_offset = torch.where(mask_t, state_slice, torch.zeros_like(state_slice))
    if actions.ndim == 3:
        # Insert a time axis so the per-sample state broadcasts over T.
        state_offset = state_offset.unsqueeze(-2)

    # Work on a copy so the caller's tensor is left intact.
    actions = actions.clone()
    actions[..., :dims] -= state_offset
    return actions
|
||||||
|
|
||||||
|
|
||||||
|
def to_absolute_actions(actions: Tensor, state: Tensor, mask: Sequence[bool]) -> Tensor:
    """Convert delta actions back to absolute: absolute = delta + state on masked dims.

    Only the first ``len(mask)`` dims are considered; among those, dims whose
    mask entry is False are left untouched, as are all dims past ``len(mask)``.

    Args:
        actions: (B, T, action_dim) or (B, action_dim).
        state: (B, state_dim). Broadcast across the time dimension.
        mask: Which dims to convert. Can be shorter than action_dim.

    Returns:
        A new tensor with the same shape as ``actions``; the input is not modified.

    Raises:
        ValueError: If ``actions`` or ``state`` has fewer than ``len(mask)``
            entries along its last dimension.
    """
    mask_t = torch.tensor(mask, dtype=torch.bool, device=actions.device)
    dims = mask_t.shape[0]
    if actions.shape[-1] < dims or state.shape[-1] < dims:
        raise ValueError(
            f"mask covers {dims} dims but actions has {actions.shape[-1]} "
            f"and state has {state.shape[-1]} along the last dimension"
        )

    state_slice = state[..., :dims]
    # Select instead of multiplying by a 0/1 mask: 0 * nan is nan, so a
    # non-finite state entry would otherwise corrupt dims that are masked out.
    state_offset = torch.where(mask_t, state_slice, torch.zeros_like(state_slice))
    if actions.ndim == 3:
        # Insert a time axis so the per-sample state broadcasts over T.
        state_offset = state_offset.unsqueeze(-2)

    # Work on a copy so the caller's tensor is left intact.
    actions = actions.clone()
    actions[..., :dims] += state_offset
    return actions
|
||||||
|
|
||||||
|
|
||||||
@ProcessorStepRegistry.register("map_tensor_to_delta_action_dict")
|
@ProcessorStepRegistry.register("map_tensor_to_delta_action_dict")
|
||||||
@@ -141,3 +183,44 @@ class MapDeltaActionToRobotActionStep(RobotActionProcessorStep):
|
|||||||
)
|
)
|
||||||
|
|
||||||
return features
|
return features
|
||||||
|
|
||||||
|
|
||||||
|
@ProcessorStepRegistry.register("delta_actions_processor")
@dataclass
class DeltaActionsProcessorStep(ProcessorStep):
    """Pipeline step that rewrites absolute actions as offsets from the observed state.

    Every action dimension is converted (action -= state). Mirrors OpenPI's
    DeltaActions transform, so a model fed by this step trains on relative
    offsets instead of absolute positions.

    Attributes:
        enabled: When False the step returns the transition it was given, untouched.
    """

    enabled: bool = False

    def __call__(self, transition: EnvTransition) -> EnvTransition:
        """Return the transition with its action expressed relative to the state.

        A disabled step hands back the input object itself. An enabled step
        always returns a shallow copy, which is left unchanged when either the
        action or the observation state is missing.
        """
        if not self.enabled:
            return transition

        result = transition.copy()
        absolute = result.get(TransitionKey.ACTION)
        if absolute is not None:
            obs = result.get(TransitionKey.OBSERVATION, {})
            reference = None if not obs else obs.get(OBS_STATE)
            if reference is not None:
                # Convert every action dimension, not a subset.
                every_dim = [True] * absolute.shape[-1]
                result[TransitionKey.ACTION] = to_delta_actions(absolute, reference, every_dim)
        return result

    def get_config(self) -> dict[str, Any]:
        """Return the step's settings as a plain dict."""
        return {"enabled": self.enabled}

    def transform_features(
        self, features: dict[PipelineFeatureType, dict[str, PolicyFeature]]
    ) -> dict[PipelineFeatureType, dict[str, PolicyFeature]]:
        """Return ``features`` unchanged; this step does not alter feature definitions."""
        return features
|
||||||
|
|||||||
@@ -0,0 +1,122 @@
|
|||||||
|
"""Tests for delta action transforms using a local dummy dataset."""
|
||||||
|
|
||||||
|
import numpy as np
|
||||||
|
import pytest
|
||||||
|
import torch
|
||||||
|
|
||||||
|
from lerobot.processor import TransitionKey, batch_to_transition
|
||||||
|
from lerobot.processor.delta_action_processor import (
|
||||||
|
DeltaActionsProcessorStep,
|
||||||
|
to_absolute_actions,
|
||||||
|
to_delta_actions,
|
||||||
|
)
|
||||||
|
from lerobot.utils.constants import ACTION, OBS_STATE
|
||||||
|
|
||||||
|
ACTION_DIM = 14
|
||||||
|
STATE_DIM = 14
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.fixture
def dataset(tmp_path, empty_lerobot_dataset_factory):
    """Build a small local dataset: 2 episodes of 5 frames with random action/state vectors.

    Values come from a seeded generator so every run sees the same data and
    the tests that compare tensors are reproducible.
    """
    features = {
        "action": {"dtype": "float32", "shape": (ACTION_DIM,), "names": None},
        "observation.state": {"dtype": "float32", "shape": (STATE_DIM,), "names": None},
    }
    ds = empty_lerobot_dataset_factory(root=tmp_path / "delta_test", features=features)
    # Fixed seed: unseeded draws would make failures impossible to reproduce.
    rng = np.random.default_rng(0)
    for ep in range(2):
        for _ in range(5):
            ds.add_frame(
                {
                    "action": rng.standard_normal(ACTION_DIM).astype(np.float32),
                    "observation.state": rng.standard_normal(STATE_DIM).astype(np.float32),
                    "task": f"task_{ep}",
                }
            )
        ds.save_episode()
    ds.finalize()
    return ds
|
||||||
|
|
||||||
|
|
||||||
|
def _collate(dataset, indices):
|
||||||
|
items = [dataset[i] for i in indices]
|
||||||
|
batch = {}
|
||||||
|
for key in items[0]:
|
||||||
|
vals = [item[key] for item in items]
|
||||||
|
if isinstance(vals[0], torch.Tensor):
|
||||||
|
batch[key] = torch.stack(vals)
|
||||||
|
else:
|
||||||
|
batch[key] = vals
|
||||||
|
return batch
|
||||||
|
|
||||||
|
|
||||||
|
def test_roundtrip_3d(dataset):
    """Converting to delta and back must reproduce the original chunked actions."""
    sample = _collate(dataset, range(4))
    obs_state = sample[OBS_STATE]
    # Insert an axis at position 1 and repeat each action 10 times along it.
    chunk = sample[ACTION].unsqueeze(1).expand(-1, 10, -1).clone()
    every_dim = [True] * chunk.shape[-1]

    as_delta = to_delta_actions(chunk, obs_state, every_dim)
    roundtrip = to_absolute_actions(as_delta, obs_state, every_dim)

    torch.testing.assert_close(roundtrip, chunk)
|
||||||
|
|
||||||
|
|
||||||
|
def test_roundtrip_2d(dataset):
    """The delta/absolute round trip also holds for actions without a time axis."""
    sample = _collate(dataset, range(4))
    flat_actions, obs_state = sample[ACTION], sample[OBS_STATE]
    every_dim = [True] * flat_actions.shape[-1]

    as_delta = to_delta_actions(flat_actions, obs_state, every_dim)
    roundtrip = to_absolute_actions(as_delta, obs_state, every_dim)

    torch.testing.assert_close(roundtrip, flat_actions)
|
||||||
|
|
||||||
|
|
||||||
|
def test_delta_changes_all_dims(dataset):
    """All dims should change when mask is all True."""
    batch = _collate(dataset, range(4))
    actions = batch[ACTION].unsqueeze(1)
    state = batch[OBS_STATE]
    mask = [True] * actions.shape[-1]

    delta = to_delta_actions(actions, state, mask)

    # Check each action dimension separately: a summed difference would still
    # be positive if only a single dimension had been converted.
    changed_per_dim = (delta != actions).any(dim=0).any(dim=0)
    assert changed_per_dim.all()
|
||||||
|
|
||||||
|
|
||||||
|
def test_no_mutation(dataset):
    """to_delta_actions must leave the action tensor it receives untouched."""
    sample = _collate(dataset, range(2))
    obs_state = sample[OBS_STATE]
    given = sample[ACTION].unsqueeze(1)
    # Snapshot taken before the call, compared against afterwards.
    snapshot = given.clone()

    to_delta_actions(given, obs_state, [True] * given.shape[-1])

    torch.testing.assert_close(given, snapshot)
|
||||||
|
|
||||||
|
|
||||||
|
def test_processor_step_roundtrip(dataset):
    """An enabled DeltaActionsProcessorStep yields deltas that to_absolute_actions inverts."""
    sample = _collate(dataset, range(4))
    expected = sample[ACTION].clone()
    transition = batch_to_transition(sample)

    processed = DeltaActionsProcessorStep(enabled=True)(transition)
    deltas = processed[TransitionKey.ACTION]
    # The step must actually have altered the actions.
    assert not torch.allclose(deltas, expected)

    # Undo the conversion with the state held by the untouched input transition.
    obs_state = transition[TransitionKey.OBSERVATION][OBS_STATE]
    every_dim = [True] * expected.shape[-1]
    restored = to_absolute_actions(deltas, obs_state, every_dim)
    torch.testing.assert_close(restored, expected)
|
||||||
|
|
||||||
|
|
||||||
|
def test_processor_step_disabled_is_noop(dataset):
    """A step built with enabled=False must pass actions through unchanged."""
    sample = _collate(dataset, range(2))
    expected = sample[ACTION].clone()

    step = DeltaActionsProcessorStep(enabled=False)
    passthrough = step(batch_to_transition(sample))

    torch.testing.assert_close(passthrough[TransitionKey.ACTION], expected)
|
||||||
Reference in New Issue
Block a user