Files
lerobot/tests/policies/test_relative_actions.py
T

349 lines
14 KiB
Python

"""Tests for relative action transforms — full pipeline validation.
Tests the complete flow matching OpenPI:
raw actions → RelativeActions → Normalize(relative_stats) → model → Unnormalize → AbsoluteActions
Uses real dataset: lerobot-data-collection/dagger_final_1_21
"""
import numpy as np
import pytest
import torch
pytest.importorskip("datasets", reason="datasets is required (install lerobot[dataset])")
from lerobot.configs.types import FeatureType, NormalizationMode, PolicyFeature
from lerobot.datasets.compute_stats import get_feature_stats
from lerobot.datasets.lerobot_dataset import LeRobotDataset
from lerobot.processor import TransitionKey, batch_to_transition
from lerobot.processor.normalize_processor import NormalizerProcessorStep, UnnormalizerProcessorStep
from lerobot.processor.relative_action_processor import (
AbsoluteActionsProcessorStep,
RelativeActionsProcessorStep,
to_absolute_actions,
to_relative_actions,
)
from lerobot.utils.constants import ACTION, OBS_STATE
CHUNK_SIZE = 10
REPO_ID = "lerobot-data-collection/dagger_final_1_21"
@pytest.fixture(scope="module")
def dataset():
return LeRobotDataset(REPO_ID, episodes=[0])
@pytest.fixture(scope="module")
def action_dim(dataset):
return dataset.meta.features["action"]["shape"][0]
def _build_action_chunks(dataset, chunk_size, max_chunks=50):
"""Build action chunks from hf_dataset, like the training script does."""
hf = dataset.hf_dataset
total = len(hf)
all_ep = torch.tensor([int(hf[i]["episode_index"]) for i in range(total)])
chunks, states = [], []
for i in range(total - chunk_size + 1):
if all_ep[i] != all_ep[i + chunk_size - 1]:
continue
chunk_actions = torch.stack([hf[i + k]["action"] for k in range(chunk_size)]).float()
state = hf[i]["observation.state"].float()
chunks.append(chunk_actions)
states.append(state)
if len(chunks) >= max_chunks:
break
assert len(chunks) > 0, f"No valid chunks found. total={total}, ep_indices={all_ep.tolist()}"
return torch.stack(chunks), torch.stack(states)
def _compute_relative_chunk_stats(action_chunks, states, mask):
all_chunks = []
for actions, state in zip(action_chunks, states, strict=True):
relative = to_relative_actions(actions.unsqueeze(0), state.unsqueeze(0), mask).squeeze(0)
all_chunks.append(relative.numpy())
all_relative = np.concatenate(all_chunks, axis=0)
return get_feature_stats(all_relative, axis=0, keepdims=all_relative.ndim == 1)
# Basic roundtrip tests
def test_roundtrip_3d(action_dim):
actions = torch.randn(4, CHUNK_SIZE, action_dim)
state = torch.randn(4, action_dim)
mask = [True] * action_dim
recovered = to_absolute_actions(to_relative_actions(actions, state, mask), state, mask)
torch.testing.assert_close(recovered, actions)
def test_roundtrip_2d(action_dim):
actions = torch.randn(4, action_dim)
state = torch.randn(4, action_dim)
mask = [True] * action_dim
recovered = to_absolute_actions(to_relative_actions(actions, state, mask), state, mask)
torch.testing.assert_close(recovered, actions)
def test_no_mutation(action_dim):
actions = torch.randn(2, CHUNK_SIZE, action_dim)
original = actions.clone()
state = torch.randn(2, action_dim)
to_relative_actions(actions, state, [True] * action_dim)
torch.testing.assert_close(actions, original)
def test_exclude_joints_supports_partial_name_matching():
names = [
"right_joint_1.pos",
"right_gripper.pos",
"left_joint_1.pos",
"left_gripper.pos",
]
step = RelativeActionsProcessorStep(enabled=True, exclude_joints=["gripper"], action_names=names)
assert step._build_mask(len(names)) == [True, False, True, False]
# Chunk-level relative stats test
def test_chunk_stats_have_larger_std_than_frame_stats(dataset, action_dim):
"""Chunk-level relative stats should have larger std than per-frame relative stats."""
action_chunks, states = _build_action_chunks(dataset, CHUNK_SIZE)
mask = [True] * action_dim
chunk_stats = _compute_relative_chunk_stats(action_chunks, states, mask)
# Per-frame stats
hf = dataset.hf_dataset
n = min(500, len(hf))
frame_actions = torch.stack([hf[i]["action"] for i in range(n)]).float()
frame_states = torch.stack([hf[i]["observation.state"] for i in range(n)]).float()
frame_relatives = to_relative_actions(frame_actions, frame_states, mask).numpy()
frame_stats = get_feature_stats(frame_relatives, axis=0, keepdims=frame_relatives.ndim == 1)
assert chunk_stats["std"].mean() >= frame_stats["std"].mean(), (
f"Chunk std ({chunk_stats['std'].mean():.4f}) should be >= "
f"frame std ({frame_stats['std'].mean():.4f})"
)
# Full pipeline roundtrip: relative → normalize → unnormalize → absolute
def test_full_pipeline_roundtrip(dataset, action_dim):
"""Test the complete OpenPI pipeline: relative → normalize → unnormalize → absolute."""
action_chunks, states = _build_action_chunks(dataset, CHUNK_SIZE)
mask = [True] * action_dim
relative_stats = _compute_relative_chunk_stats(action_chunks, states, mask)
stats = {ACTION: dict(relative_stats.items())}
features = {ACTION: PolicyFeature(type=FeatureType.ACTION, shape=(action_dim,))}
norm_map = {FeatureType.ACTION: NormalizationMode.MEAN_STD}
relative_step = RelativeActionsProcessorStep(enabled=True)
normalizer = NormalizerProcessorStep(features=features, norm_map=norm_map, stats=stats)
unnormalizer = UnnormalizerProcessorStep(features=features, norm_map=norm_map, stats=stats)
absolute_step = AbsoluteActionsProcessorStep(enabled=True, relative_step=relative_step)
original_actions = action_chunks[0].unsqueeze(0)
state = states[0].unsqueeze(0)
batch = {ACTION: original_actions, OBS_STATE: state}
transition = batch_to_transition(batch)
# Forward: relative → normalize
t1 = relative_step(transition)
t2 = normalizer(t1)
normalized_action = t2[TransitionKey.ACTION]
assert normalized_action.abs().mean() < 10, (
f"Normalized actions should be in reasonable range, got mean abs {normalized_action.abs().mean():.2f}"
)
# Reverse: unnormalize → absolute
t3 = unnormalizer(t2)
t4 = absolute_step(t3)
recovered_actions = t4[TransitionKey.ACTION]
torch.testing.assert_close(recovered_actions, original_actions, atol=1e-4, rtol=1e-4)
def test_normalized_relative_values_are_reasonable(dataset, action_dim):
"""With correct chunk stats, normalized relative actions should be in a reasonable range."""
action_chunks, states = _build_action_chunks(dataset, CHUNK_SIZE)
mask = [True] * action_dim
relative_stats = _compute_relative_chunk_stats(action_chunks, states, mask)
mean = torch.tensor(relative_stats["mean"]).float()
std = torch.tensor(relative_stats["std"]).float()
all_normalized = []
for actions, state in zip(action_chunks, states, strict=True):
relative = to_relative_actions(actions.unsqueeze(0), state.unsqueeze(0), mask).squeeze(0)
normalized = (relative - mean) / (std + 1e-6)
all_normalized.append(normalized)
all_normalized = torch.cat(all_normalized, dim=0)
pct_in_range = (all_normalized.abs() < 5).float().mean()
assert pct_in_range > 0.9, (
f"Only {pct_in_range * 100:.1f}% of normalized values in [-5, 5], expected >90%"
)
assert all_normalized.mean().abs() < 1.0, (
f"Mean of normalized relative actions is {all_normalized.mean():.2f}, expected near 0"
)
def test_processor_step_roundtrip(dataset, action_dim):
"""RelativeActionsProcessorStep applies relative offsets; to_absolute_actions recovers original."""
hf = dataset.hf_dataset
batch = {
ACTION: torch.stack([hf[i]["action"] for i in range(4)]),
OBS_STATE: torch.stack([hf[i]["observation.state"] for i in range(4)]),
}
original_actions = batch[ACTION].clone()
transition = batch_to_transition(batch)
step = RelativeActionsProcessorStep(enabled=True)
relative_transition = step(transition)
assert not torch.allclose(relative_transition[TransitionKey.ACTION], original_actions)
state = transition[TransitionKey.OBSERVATION][OBS_STATE]
mask = [True] * action_dim
recovered = to_absolute_actions(relative_transition[TransitionKey.ACTION], state, mask)
torch.testing.assert_close(recovered, original_actions)
def test_processor_step_disabled_is_noop(dataset, action_dim):
"""enabled=False should be a no-op."""
hf = dataset.hf_dataset
batch = {
ACTION: torch.stack([hf[i]["action"] for i in range(2)]),
OBS_STATE: torch.stack([hf[i]["observation.state"] for i in range(2)]),
}
original = batch[ACTION].clone()
transition = batch_to_transition(batch)
result = RelativeActionsProcessorStep(enabled=False)(transition)
torch.testing.assert_close(result[TransitionKey.ACTION], original)
# Training batch shape validation
def test_relative_with_action_chunks(dataset, action_dim):
"""Verify relative actions work correctly with (B, chunk_size, action_dim) shaped actions."""
action_chunks, states = _build_action_chunks(dataset, CHUNK_SIZE)
# Simulate a training batch: actions=(B, chunk_size, action_dim), state=(B, state_dim)
batch_actions = action_chunks[:4] # (4, chunk_size, action_dim)
batch_states = states[:4] # (4, state_dim)
mask = [True] * action_dim
relative = to_relative_actions(batch_actions, batch_states, mask)
# First action in each chunk should be close to zero (action[t] - state[t] ≈ small)
first_relatives = relative[:, 0, :] # (B, action_dim)
assert first_relatives.abs().mean() < relative.abs().mean(), (
f"First action in chunk should have smaller relative offset than average. "
f"First: {first_relatives.abs().mean():.4f}, Average: {relative.abs().mean():.4f}"
)
# Later actions should have larger relative offsets
last_relatives = relative[:, -1, :] # (B, action_dim)
assert last_relatives.abs().mean() >= first_relatives.abs().mean(), (
f"Last action in chunk should have >= relative offset than first. "
f"Last: {last_relatives.abs().mean():.4f}, First: {first_relatives.abs().mean():.4f}"
)
# Roundtrip
recovered = to_absolute_actions(relative, batch_states, mask)
torch.testing.assert_close(recovered, batch_actions)
def test_relative_stats_match_actual_data_distribution(dataset, action_dim):
"""Verify computed stats match the actual relative-action distribution."""
action_chunks, states = _build_action_chunks(dataset, CHUNK_SIZE)
mask = [True] * action_dim
# Compute stats like the training script does
relative_stats = _compute_relative_chunk_stats(action_chunks, states, mask)
# Also compute directly
all_relatives = []
for actions, state in zip(action_chunks, states, strict=True):
rel = to_relative_actions(actions.unsqueeze(0), state.unsqueeze(0), mask).squeeze(0)
all_relatives.append(rel)
all_relatives_tensor = torch.cat(all_relatives, dim=0)
# Compare mean
actual_mean = all_relatives_tensor.mean(dim=0).numpy()
np.testing.assert_allclose(relative_stats["mean"], actual_mean, atol=0.01)
# Compare std
actual_std = all_relatives_tensor.std(dim=0).numpy()
np.testing.assert_allclose(relative_stats["std"], actual_std, atol=0.1)
# Verify q01 < mean < q99
assert (relative_stats["q01"] < relative_stats["mean"]).all(), "q01 should be < mean"
assert (relative_stats["mean"] < relative_stats["q99"]).all(), "mean should be < q99"
def test_quantile_normalization_roundtrip(dataset, action_dim):
"""Full roundtrip with QUANTILES normalization (what OpenPI uses for pi05)."""
action_chunks, states = _build_action_chunks(dataset, CHUNK_SIZE)
mask = [True] * action_dim
relative_stats = _compute_relative_chunk_stats(action_chunks, states, mask)
stats = {ACTION: dict(relative_stats.items())}
features = {ACTION: PolicyFeature(type=FeatureType.ACTION, shape=(action_dim,))}
norm_map = {FeatureType.ACTION: NormalizationMode.QUANTILES}
relative_step = RelativeActionsProcessorStep(enabled=True)
normalizer = NormalizerProcessorStep(features=features, norm_map=norm_map, stats=stats)
unnormalizer = UnnormalizerProcessorStep(features=features, norm_map=norm_map, stats=stats)
absolute_step = AbsoluteActionsProcessorStep(enabled=True, relative_step=relative_step)
original_actions = action_chunks[0].unsqueeze(0)
state = states[0].unsqueeze(0)
batch = {ACTION: original_actions, OBS_STATE: state}
transition = batch_to_transition(batch)
# Forward: relative → quantile normalize
t1 = relative_step(transition)
t2 = normalizer(t1)
normalized = t2[TransitionKey.ACTION]
# Most values should be in [-1, 1] with quantile normalization
pct_in_range = (normalized.abs() < 2).float().mean()
assert pct_in_range > 0.5, f"Only {pct_in_range * 100:.1f}% in [-2, 2] after quantile norm, expected >50%"
# Reverse: unnormalize → absolute
t3 = unnormalizer(t2)
t4 = absolute_step(t3)
recovered = t4[TransitionKey.ACTION]
torch.testing.assert_close(recovered, original_actions, atol=1e-3, rtol=1e-3)
def test_state_not_modified_by_relative_processor(dataset, action_dim):
"""State should never be modified by the relative-action processor."""
hf = dataset.hf_dataset
batch = {
ACTION: torch.stack([hf[i]["action"] for i in range(4)]),
OBS_STATE: torch.stack([hf[i]["observation.state"] for i in range(4)]),
}
original_state = batch[OBS_STATE].clone()
transition = batch_to_transition(batch)
step = RelativeActionsProcessorStep(enabled=True)
result = step(transition)
result_state = result[TransitionKey.OBSERVATION][OBS_STATE]
torch.testing.assert_close(result_state, original_state)