mirror of
https://github.com/huggingface/lerobot.git
synced 2026-05-24 21:19:53 +00:00
feat(policies): add relative action support for pi0, pi0.5, and pi0_fast (#2970)
* Add option for pi family models to train with relative actions (relative to state)
* formatting
* add recomputation of stats and option to compute delta stats
* normalize after delta conversion
* only recompute state for stats
* calculate chunk based stats
* sample 100k
* load from parquet
* sample 1m
* stats per chunk
* fix
* use quantiles
* stats for entire dataset
* fix
* max 1m frames
* compute before dist
* fix multi gpu processor bug
* Fix RTC with delta actions and OpenArms motor_type wiring
* feat: align pi0_fast delta actions with pi0/pi05 and add RTC integration tests
  - Add delta_exclude_joints and action_feature_names to PI0FastConfig
  - Move to_absolute_actions from modeling to processor pipeline for pi0_fast
  - Add delta action detection and logging to eval_with_real_robot.py
  - Add delta actions documentation to pi0 and pi05 READMEs
  - Fix ruff lint issues in test_delta_actions.py
  - Add test_rtc_delta_actions.py (24 tests) covering:
    - ActionQueue with delta vs absolute actions
    - RTC denoise step with delta leftovers
    - Full pipeline roundtrip (delta → RTC → absolute)
    - State rebasing approximation bounds
    - Non-delta policy compatibility
    - Multi-chunk consistency
* chore: clean up test comments, add OpenPI attribution, remove debug logging
  - Replace decorative comment separators in test files with plain section headers
  - Add attribution comments for 1e-6 epsilon in normalize_processor.py (from OpenPI)
  - Remove debug logging blocks from lerobot_train.py
* refactor: extract compute_delta_action_stats into compute_stats.py
  Move the ~70-line inline delta action stats block from lerobot_train.py into a dedicated function in compute_stats.py, where all other stats computation already lives. The training script now calls it in 6 lines.
* refactor: remove unused get_processed_left_over from ActionQueue
  This method was never called outside of tests. Leftover actions for RTC guidance are always retrieved via get_left_over() (delta/original space).
* revert: remove logging-only changes from eval_with_real_robot.py
  The delta actions detection helper and log message added no functional value; the script already handles delta policies correctly via the processor pipeline.
* refactor: use ACTION/OBS_STATE constants instead of hardcoded strings
  Replace hardcoded "action" and "observation.state" with ACTION and OBS_STATE from utils.constants in compute_stats.py, dataset_tools.py, and lerobot_train.py.
* style: remove stray blank lines in training loop
* refactor: move delta action stats to preprocessing step, remove on-the-fly computation
  - Remove on-the-fly compute_delta_action_stats from lerobot_train.py
  - Rewrite recompute_stats to delegate action stats to compute_delta_action_stats (chunk-based sampling matching what the model sees during training)
  - Add chunk_size parameter to recompute_stats for delta action computation
  - Add delta actions documentation to pi0.mdx and pi05.mdx
* feat: add recompute_stats CLI operation to lerobot-edit-dataset
* fix(tests): relax quantile normalization test tolerance for 1e-6 epsilon
* chore: remove agents_memory/pr_details.md from repo
* refactor: rename delta actions to relative actions throughout
  What OpenPI calls "DeltaActions" is actually UMI's "relative trajectory" representation: each action in the chunk is an offset from the current state, not from the previous action. This avoids error accumulation.
  Renamed across all source, tests, docs, and CLI:
  - DeltaActionsProcessorStep → RelativeActionsProcessorStep
  - to_delta_actions → to_relative_actions
  - use_delta_actions → use_relative_actions
  - delta_exclude_joints → relative_exclude_joints
  - compute_delta_action_stats → compute_relative_action_stats
  - delta_action_processor.py → relative_action_processor.py
  - test_delta_actions.py → test_relative_actions.py
  Kept as-is: AbsoluteActionsProcessorStep (converts TO absolute), registry ID "delta_actions_processor" (backward compat), and unrelated delta references (IK pipeline, Robosuite, RA-BC metrics, gym envs).
* docs: add Action Representations guide
  Dedicated page explaining absolute, relative, and delta actions with numerical examples, joint vs EE space, and how to use kinematics pipelines and the relative action processor. References UMI paper (Chi et al., 2024) for the terminology.
* docs: remove redundant OpenPI naming note from action representations
* docs: remove opinionated OpenPI reference from delta actions section
* docs: replace ASCII diagram with UMI paper figure
* docs: remove OpenPI reference from action representations
* docs: use HF-hosted image instead of local asset
* docs: clarify figure attribution
* revert: restore original normalization epsilon behavior
  The 1e-6 unconditional epsilon change perturbed all normalized values, breaking backward compatibility tests. The original approach (1e-8 eps for MEAN_STD, conditional torch.where for QUANTILES) already handles division by zero correctly without affecting non-degenerate cases.
* fix: restore delta_action_processor.py used by phone/RL teleop
  The rename commit incorrectly deleted delta_action_processor.py and duplicated its classes into relative_action_processor.py. Restore the original file and import from it instead.
* fix(processor): address PR #2970 review comments
  - Remove shebang from relative_action_processor.py (library module, not script)
  - Add device alignment in to_relative_actions/to_absolute_actions so _last_state on CPU doesn't cause cross-device errors when actions are on CUDA
  - Rename delta_step → relative_step in AbsoluteActionsProcessorStep for naming consistency; update factory.py, all processor files, and tests
  - Expand _reconnect_relative_absolute_steps docstring to explain why post-hoc rewiring is needed after deserialization
  - Fix off-by-one in compute_stats.py: sample_upper_bound = total_frames - chunk_size + 1 so last valid start index is included and total_frames == chunk_size is not rejected
  - Remove redundant NOTE comment in processor_pi05.py (duplicated two lines below)
  - Fix pi0_fast processor ordering: move relative_step before NormalizerProcessorStep so normalizer sees delta actions (matching pi0/pi05); flip postprocessor to unnormalize → absolute accordingly. Relative stats are now required for all pi models
  - Revert use_relative_joint_actions_aloha → use_delta_joint_actions_aloha in configuration_smolvla.py (preserve existing public API)
  - Update action_representations.mdx: add missing joint to 6-DOF example, fix 'based on a figure', clarify pi family ordering, add RTC compatibility section
* update rtc link
* feat: compute relative action stats over full dataset with optional parallelism
  Remove the 100k sample cap from compute_relative_action_stats and process all valid chunks. Vectorize with numpy (pre-load actions/states, fancy indexing + broadcasting) for a large speedup over the per-index HF dataset loop. Add num_workers param for thread-based parallelism (numpy releases the GIL). Update docs to show --push_to_hub for recompute_stats.
* style: apply ruff formatting to compute_stats.py
* testing on real robot
* style: fix ruff format and remove redundant .keys() calls
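
The rename note in the commit message distinguishes relative actions (offsets from the current state) from delta actions (offsets from the previous action). A minimal sketch of that difference for a single joint follows. The helper names are hypothetical and use plain Python floats; this is not the lerobot `to_relative_actions` implementation, only an illustration of why a per-step error does not accumulate in the relative representation.

```python
# Hypothetical helpers for illustration only; not the lerobot API.


def to_relative(chunk, state):
    """Relative trajectory: every action is an offset from the same current state."""
    return [a - state for a in chunk]


def to_delta(chunk, state):
    """Delta actions: every action is an offset from the previous target."""
    prev = [state] + chunk[:-1]
    return [a - p for a, p in zip(chunk, prev)]


def from_relative(rel, state):
    """Invert to_relative: add the cached state back to every offset."""
    return [state + r for r in rel]


def from_delta(deltas, state):
    """Invert to_delta by integrating the offsets step by step."""
    out, current = [], state
    for d in deltas:
        current += d  # an error in an earlier delta carries into every later target
        out.append(current)
    return out


state = 10.0
chunk = [11.0, 13.0, 16.0]  # absolute joint targets

relative = to_relative(chunk, state)  # [1.0, 3.0, 6.0]
delta = to_delta(chunk, state)        # [1.0, 2.0, 3.0]

# Perturb the first predicted value by 0.5 in each representation.
noisy_relative = [relative[0] + 0.5] + relative[1:]
noisy_delta = [delta[0] + 0.5] + delta[1:]

print(from_relative(noisy_relative, state))  # [11.5, 13.0, 16.0]: only the first target is off
print(from_delta(noisy_delta, state))        # [11.5, 13.5, 16.5]: every target is off by 0.5
```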
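
The off-by-one fix to `sample_upper_bound` mentioned in the review-comment item can be checked with a short sketch. `valid_chunk_starts` is a hypothetical helper, not the function in compute_stats.py; it assumes only that a chunk starting at index `i` needs frames `i` through `i + chunk_size - 1` to exist, and it ignores episode boundaries.

```python
def valid_chunk_starts(total_frames: int, chunk_size: int) -> range:
    """Start indices i for which frames i .. i + chunk_size - 1 all exist."""
    # Exclusive upper bound on the start index (the expression from the commit message).
    sample_upper_bound = total_frames - chunk_size + 1
    if sample_upper_bound <= 0:
        raise ValueError("dataset is shorter than one chunk")
    return range(sample_upper_bound)


# 10 frames, chunks of 4: starts 0..6 are valid (6 + 4 - 1 == 9, the last frame).
print(list(valid_chunk_starts(10, 4)))

# A dataset exactly one chunk long still yields one valid start.
print(list(valid_chunk_starts(4, 4)))
```

Without the `+ 1`, the first call would drop start index 6 and the second would be rejected as having no valid chunk.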
@@ -25,7 +25,7 @@ import torch
 from lerobot.policies.rtc.action_queue import ActionQueue
 from lerobot.policies.rtc.configuration_rtc import RTCConfig
 
-# ====================== Fixtures ======================
+# Fixtures
 
 
 @pytest.fixture
@@ -63,7 +63,7 @@ def action_queue_rtc_disabled(rtc_config_disabled):
     return ActionQueue(rtc_config_disabled)
 
 
-# ====================== Initialization Tests ======================
+# Initialization tests
 
 
 def test_action_queue_initialization_rtc_enabled(rtc_config_enabled):
@@ -84,7 +84,7 @@ def test_action_queue_initialization_rtc_disabled(rtc_config_disabled):
     assert queue.cfg.enabled is False
 
 
-# ====================== get() Tests ======================
+# get() tests
 
 
 def test_get_returns_none_when_empty(action_queue_rtc_enabled):
@@ -136,7 +136,7 @@ def test_get_increments_last_index(action_queue_rtc_enabled, sample_actions):
     assert action_queue_rtc_enabled.last_index == 2
 
 
-# ====================== qsize() Tests ======================
+# qsize() tests
 
 
 def test_qsize_returns_zero_when_empty(action_queue_rtc_enabled):
@@ -167,7 +167,7 @@ def test_qsize_after_exhaustion(action_queue_rtc_enabled, sample_actions):
     assert action_queue_rtc_enabled.qsize() == 0
 
 
-# ====================== empty() Tests ======================
+# empty() tests
 
 
 def test_empty_returns_true_when_empty(action_queue_rtc_enabled):
@@ -202,7 +202,7 @@ def test_empty_after_full_consumption(action_queue_rtc_enabled, sample_actions):
     assert action_queue_rtc_enabled.empty() is True
 
 
-# ====================== get_action_index() Tests ======================
+# get_action_index() tests
 
 
 def test_get_action_index_initial_value(action_queue_rtc_enabled):
@@ -222,7 +222,7 @@ def test_get_action_index_after_consumption(action_queue_rtc_enabled, sample_act
     assert action_queue_rtc_enabled.get_action_index() == 3
 
 
-# ====================== get_left_over() Tests ======================
+# get_left_over() tests
 
 
 def test_get_left_over_returns_none_when_empty(action_queue_rtc_enabled):
@@ -269,7 +269,7 @@ def test_get_left_over_returns_empty_after_exhaustion(action_queue_rtc_enabled,
     assert leftover.shape == (0, 6)
 
 
-# ====================== merge() with RTC Enabled Tests ======================
+# merge() with RTC enabled tests
 
 
 def test_merge_replaces_queue_when_rtc_enabled(action_queue_rtc_enabled, sample_actions):
@@ -336,7 +336,7 @@ def test_merge_with_large_delay(action_queue_rtc_enabled, sample_actions):
     assert action_queue_rtc_enabled.qsize() == 0
 
 
-# ====================== merge() with RTC Disabled Tests ======================
+# merge() with RTC disabled tests
 
 
 def test_merge_appends_when_rtc_disabled(action_queue_rtc_disabled, sample_actions):
@@ -402,7 +402,7 @@ def test_merge_first_call_with_rtc_disabled(action_queue_rtc_disabled, sample_ac
     assert action_queue_rtc_disabled.last_index == 0
 
 
-# ====================== merge() with Different Action Shapes Tests ======================
+# merge() with different action shapes tests
 
 
 def test_merge_with_different_action_dims():
@@ -431,7 +431,7 @@ def test_merge_with_different_lengths():
     assert queue.qsize() == 35
 
 
-# ====================== merge() Delay Validation Tests ======================
+# merge() delay validation tests
 
 
 def test_merge_validates_delay_consistency(action_queue_rtc_enabled, sample_actions, caplog):
@@ -509,7 +509,7 @@ def test_merge_skips_validation_when_action_index_none(action_queue_rtc_enabled,
     assert "Indexes diff is not equal to real delay" not in caplog.text
 
 
-# ====================== Thread Safety Tests ======================
+# Thread safety tests
 
 
 def test_get_is_thread_safe(action_queue_rtc_enabled, sample_actions):
@@ -621,7 +621,7 @@ def test_concurrent_get_and_merge(action_queue_rtc_disabled, sample_actions):
     assert consumed_count[0] <= 200
 
 
-# ====================== get_left_over() Thread Safety Tests ======================
+# get_left_over() thread safety tests
 
 
 def test_get_left_over_is_thread_safe(action_queue_rtc_enabled, sample_actions):
@@ -670,7 +670,7 @@ def test_get_left_over_is_thread_safe(action_queue_rtc_enabled, sample_actions):
     assert len(leftovers) > 0
 
 
-# ====================== Edge Cases Tests ======================
+# Edge cases tests
 
 
 def test_queue_with_single_action(action_queue_rtc_enabled):
@@ -773,7 +773,7 @@ def test_qsize_with_none_queue(action_queue_rtc_enabled):
     assert action_queue_rtc_enabled.qsize() == 0
 
 
-# ====================== Integration Tests ======================
+# Integration tests
 
 
 def test_typical_rtc_workflow(action_queue_rtc_enabled, sample_actions):

@@ -0,0 +1,607 @@
"""Tests for RTC + relative actions integration.

Validates that Real-Time Chunking (RTC) works correctly when the policy uses
relative actions. The key invariant: RTC guidance operates in model space
(normalized relative actions), while the robot receives absolute actions after postprocessing.

Flow under test:
    Preprocessor: raw obs → relative step caches state → normalizer
    Model: generates normalized relative actions (guided by RTC using leftover relative actions)
    Postprocessor: unnormalize → absolute step (relative + cached state) → robot actions
"""

import importlib.util
import sys
from pathlib import Path

import torch

from lerobot.configs.types import (
    FeatureType,
    NormalizationMode,
    PolicyFeature,
    RTCAttentionSchedule,
)
from lerobot.processor import TransitionKey, batch_to_transition
from lerobot.processor.normalize_processor import NormalizerProcessorStep, UnnormalizerProcessorStep
from lerobot.processor.relative_action_processor import (
    AbsoluteActionsProcessorStep,
    RelativeActionsProcessorStep,
    to_relative_actions,
)
from lerobot.utils.constants import ACTION, OBS_STATE


def _import_rtc_module(module_name: str, filename: str):
    """Import an RTC module directly from its file path, bypassing lerobot.policies.__init__."""
    rtc_dir = Path(__file__).resolve().parents[3] / "src" / "lerobot" / "policies" / "rtc"
    spec = importlib.util.spec_from_file_location(module_name, rtc_dir / filename)
    mod = importlib.util.module_from_spec(spec)
    sys.modules[module_name] = mod
    spec.loader.exec_module(mod)
    return mod


_rtc_cfg_mod = _import_rtc_module("lerobot.policies.rtc.configuration_rtc", "configuration_rtc.py")
RTCConfig = _rtc_cfg_mod.RTCConfig

_action_queue_mod = _import_rtc_module("lerobot.policies.rtc.action_queue", "action_queue.py")
ActionQueue = _action_queue_mod.ActionQueue

_rtc_debug_mod = _import_rtc_module("lerobot.policies.rtc.debug_tracker", "debug_tracker.py")
_rtc_mod = _import_rtc_module("lerobot.policies.rtc.modeling_rtc", "modeling_rtc.py")
RTCProcessor = _rtc_mod.RTCProcessor

ACTION_DIM = 6
CHUNK_SIZE = 50
EXECUTION_HORIZON = 10


def _make_rtc_config(enabled=True):
    return RTCConfig(
        enabled=enabled,
        execution_horizon=EXECUTION_HORIZON,
        max_guidance_weight=10.0,
        prefix_attention_schedule=RTCAttentionSchedule.EXP,
    )


def _make_relative_pipeline(action_dim=ACTION_DIM, norm_mode=NormalizationMode.MEAN_STD):
    """Build paired relative/absolute processor steps and normalizer/unnormalizer."""
    features = {ACTION: PolicyFeature(type=FeatureType.ACTION, shape=(action_dim,))}
    norm_map = {FeatureType.ACTION: norm_mode}

    stats = {
        ACTION: {
            "mean": torch.zeros(action_dim).numpy(),
            "std": torch.ones(action_dim).numpy(),
            "q01": (-2 * torch.ones(action_dim)).numpy(),
            "q99": (2 * torch.ones(action_dim)).numpy(),
            "min": (-3 * torch.ones(action_dim)).numpy(),
            "max": (3 * torch.ones(action_dim)).numpy(),
        }
    }

    relative_step = RelativeActionsProcessorStep(enabled=True)
    normalizer = NormalizerProcessorStep(features=features, norm_map=norm_map, stats=stats)
    unnormalizer = UnnormalizerProcessorStep(features=features, norm_map=norm_map, stats=stats)
    absolute_step = AbsoluteActionsProcessorStep(enabled=True, relative_step=relative_step)
    return relative_step, normalizer, unnormalizer, absolute_step


class TestActionQueueRelativeActions:
    """Verify ActionQueue stores model-space (relative) actions for RTC and absolute for robot."""

    def test_left_over_returns_relative_actions(self):
        """get_left_over() should return the original (relative-space) actions."""
        cfg = _make_rtc_config()
        queue = ActionQueue(cfg)

        relative_actions = torch.randn(CHUNK_SIZE, ACTION_DIM)
        absolute_actions = torch.randn(CHUNK_SIZE, ACTION_DIM)
        queue.merge(relative_actions, absolute_actions, real_delay=0)

        for _ in range(5):
            queue.get()

        leftover = queue.get_left_over()
        torch.testing.assert_close(leftover, relative_actions[5:])

    def test_robot_receives_absolute_actions(self):
        """The robot (via get()) should receive postprocessed absolute actions."""
        cfg = _make_rtc_config()
        queue = ActionQueue(cfg)

        relative_actions = torch.randn(CHUNK_SIZE, ACTION_DIM)
        absolute_actions = torch.randn(CHUNK_SIZE, ACTION_DIM)
        queue.merge(relative_actions, absolute_actions, real_delay=0)

        first_action = queue.get()
        torch.testing.assert_close(first_action, absolute_actions[0])


class TestRTCDenoiseWithRelativeLeftovers:
    """Verify RTC denoise_step correctly handles relative-space prev_chunk_left_over."""

    def test_first_chunk_no_guidance(self):
        """First chunk (no leftovers) should return v_t without guidance."""
        rtc = RTCProcessor(_make_rtc_config())
        x_t = torch.randn(1, CHUNK_SIZE, ACTION_DIM)

        def mock_denoise(x):
            return torch.ones_like(x)

        result = rtc.denoise_step(
            x_t=x_t,
            prev_chunk_left_over=None,
            inference_delay=0,
            time=0.5,
            original_denoise_step_partial=mock_denoise,
        )
        torch.testing.assert_close(result, torch.ones_like(x_t))

    def test_relative_leftovers_shape_preserved(self):
        """RTC output should have the same shape as input regardless of leftover shape."""
        rtc = RTCProcessor(_make_rtc_config())
        x_t = torch.randn(1, CHUNK_SIZE, ACTION_DIM)
        shorter_leftover = torch.randn(1, 20, ACTION_DIM)

        def mock_denoise(x):
            return torch.zeros_like(x)

        result = rtc.denoise_step(
            x_t=x_t,
            prev_chunk_left_over=shorter_leftover,
            inference_delay=5,
            time=0.5,
            original_denoise_step_partial=mock_denoise,
        )
        assert result.shape == x_t.shape

    def test_guidance_steers_toward_previous_relative_actions(self):
        """RTC guidance should push x1_t toward prev_chunk_left_over in relative space."""
        rtc = RTCProcessor(_make_rtc_config())
        x_t = torch.randn(1, CHUNK_SIZE, ACTION_DIM)
        prev_relatives = torch.randn(1, CHUNK_SIZE, ACTION_DIM)

        def mock_denoise(x):
            return torch.zeros_like(x)

        result_without_guidance = rtc.denoise_step(
            x_t=x_t.clone(),
            prev_chunk_left_over=None,
            inference_delay=5,
            time=0.5,
            original_denoise_step_partial=mock_denoise,
        )

        result_with_guidance = rtc.denoise_step(
            x_t=x_t.clone(),
            prev_chunk_left_over=prev_relatives,
            inference_delay=5,
            time=0.5,
            original_denoise_step_partial=mock_denoise,
        )

        assert not torch.allclose(result_with_guidance, result_without_guidance, atol=1e-6)


class TestFullPipelineRelativeRTC:
    """End-to-end test of the RTC + relative actions pipeline matching eval_with_real_robot.py flow."""

    def test_preprocessor_caches_state_for_postprocessor(self):
        """Preprocessor's relative step should cache state so postprocessor can convert back."""
        relative_step, normalizer, unnormalizer, absolute_step = _make_relative_pipeline()

        state = torch.randn(1, ACTION_DIM)
        actions = torch.randn(1, CHUNK_SIZE, ACTION_DIM)

        batch = {ACTION: actions, OBS_STATE: state}
        transition = batch_to_transition(batch)

        relative_step(transition)
        assert relative_step._last_state is not None
        torch.testing.assert_close(relative_step._last_state, state)

    def test_preprocessor_caches_state_without_actions(self):
        """During inference, preprocessor receives only observations (no actions).
        Relative step should still cache state for the postprocessor."""
        relative_step, _, _, _ = _make_relative_pipeline()

        state = torch.randn(1, ACTION_DIM)
        batch = {OBS_STATE: state}
        transition = batch_to_transition(batch)

        relative_step(transition)
        assert relative_step._last_state is not None
        torch.testing.assert_close(relative_step._last_state, state)

    def test_roundtrip_with_identity_normalization(self):
        """Actions → relative → normalize → [model] → unnormalize → absolute should recover originals.
        Using mean=0, std=1 normalization (identity)."""
        relative_step, normalizer, unnormalizer, absolute_step = _make_relative_pipeline()

        state = torch.randn(1, ACTION_DIM)
        actions = torch.randn(1, CHUNK_SIZE, ACTION_DIM)

        batch = {ACTION: actions.clone(), OBS_STATE: state}
        transition = batch_to_transition(batch)

        t1 = relative_step(transition)
        t2 = normalizer(t1)

        model_output = t2[TransitionKey.ACTION].clone()

        model_transition = {TransitionKey.ACTION: model_output}
        t3 = unnormalizer(model_transition)
        t4 = absolute_step(t3)

        recovered = t4[TransitionKey.ACTION]
        torch.testing.assert_close(recovered, actions, atol=1e-5, rtol=1e-5)

    def test_eval_loop_simulation(self):
        """Simulate the eval_with_real_robot.py loop with relative actions.

        Iteration 1: No leftovers → model generates relative actions → store for RTC
        Iteration 2: Use leftovers as RTC guidance → model generates new relative actions
        Both iterations: postprocessor converts relative actions to absolute for robot
        """
        relative_step, normalizer, unnormalizer, absolute_step = _make_relative_pipeline()
        rtc = RTCProcessor(_make_rtc_config())
        queue = ActionQueue(_make_rtc_config())

        def mock_model(prev_chunk_left_over, inference_delay, state):
            """Simulate model generating relative actions with RTC."""
            x_t = torch.randn(1, CHUNK_SIZE, ACTION_DIM)

            def denoise(x):
                return -0.1 * x

            guided_v = rtc.denoise_step(
                x_t=x_t,
                prev_chunk_left_over=prev_chunk_left_over,
                inference_delay=inference_delay,
                time=0.5,
                original_denoise_step_partial=denoise,
            )
            return x_t - 0.5 * guided_v

        # --- Iteration 1: first chunk, no leftovers ---
        state_1 = torch.randn(1, ACTION_DIM)
        obs_batch_1 = {OBS_STATE: state_1}
        relative_step(batch_to_transition(obs_batch_1))

        model_relatives_1 = mock_model(prev_chunk_left_over=None, inference_delay=0, state=state_1)
        original_actions_1 = model_relatives_1.squeeze(0)

        model_transition_1 = {TransitionKey.ACTION: model_relatives_1}
        postprocessed_1 = absolute_step(unnormalizer(model_transition_1))[TransitionKey.ACTION].squeeze(0)

        queue.merge(original_actions_1, postprocessed_1, real_delay=0)

        # Consume some actions (simulate robot executing)
        for _ in range(5):
            action = queue.get()
            assert action is not None

        # --- Iteration 2: use leftovers for RTC ---
        prev_actions = queue.get_left_over()
        assert prev_actions is not None
        assert prev_actions.shape[0] == CHUNK_SIZE - 5

        state_2 = state_1 + 0.01 * torch.randn(1, ACTION_DIM)
        obs_batch_2 = {OBS_STATE: state_2}
        relative_step(batch_to_transition(obs_batch_2))

        model_relatives_2 = mock_model(
            prev_chunk_left_over=prev_actions.unsqueeze(0), inference_delay=3, state=state_2
        )
        original_actions_2 = model_relatives_2.squeeze(0)

        model_transition_2 = {TransitionKey.ACTION: model_relatives_2}
        postprocessed_2 = absolute_step(unnormalizer(model_transition_2))[TransitionKey.ACTION].squeeze(0)

        queue.merge(original_actions_2, postprocessed_2, real_delay=3)

        # Postprocessed actions should be in absolute space
        action = queue.get()
        assert action is not None
        assert action.shape == (ACTION_DIM,)

        # Verify leftovers are in relative space (original_queue stores relative actions)
        leftover_relatives = queue.get_left_over()
        assert leftover_relatives is not None
        assert leftover_relatives.shape[1] == ACTION_DIM

    def test_postprocessor_uses_correct_state_per_iteration(self):
        """Each iteration's postprocessor should use the state from that iteration's preprocessor,
        not a stale state from a previous iteration."""
        relative_step, _, unnormalizer, absolute_step = _make_relative_pipeline()

        state_1 = torch.tensor([[1.0, 2.0, 3.0, 4.0, 5.0, 6.0]])
        state_2 = torch.tensor([[10.0, 20.0, 30.0, 40.0, 50.0, 60.0]])
        relatives = torch.zeros(1, 5, ACTION_DIM)

        # Iteration 1: cache state_1
        relative_step(batch_to_transition({OBS_STATE: state_1}))
        result_1 = absolute_step(unnormalizer({TransitionKey.ACTION: relatives.clone()}))[
            TransitionKey.ACTION
        ]
        # relative=0 + state_1 should give state_1
        for t in range(5):
            torch.testing.assert_close(result_1[0, t], state_1[0], atol=1e-5, rtol=1e-5)

        # Iteration 2: cache state_2
        relative_step(batch_to_transition({OBS_STATE: state_2}))
        result_2 = absolute_step(unnormalizer({TransitionKey.ACTION: relatives.clone()}))[
            TransitionKey.ACTION
        ]
        for t in range(5):
            torch.testing.assert_close(result_2[0, t], state_2[0], atol=1e-5, rtol=1e-5)


class TestStateRebasingApproximation:
    """Verify that the approximation from not rebasing leftover relative actions is small
    when state changes between inference calls are small (real-time control regime)."""

    def test_small_state_change_produces_small_error(self):
        """With small state changes (typical in real-time control),
        using stale relative actions for RTC guidance introduces negligible error."""
        state_old = torch.randn(1, ACTION_DIM)
        state_new = state_old + 0.01 * torch.randn(1, ACTION_DIM)

        actions_absolute = torch.randn(1, CHUNK_SIZE, ACTION_DIM)
        mask = [True] * ACTION_DIM

        relatives_old = to_relative_actions(actions_absolute, state_old, mask)
        relatives_new = to_relative_actions(actions_absolute, state_new, mask)

        error = (relatives_old - relatives_new).abs().mean()
        state_change = (state_old - state_new).abs().mean()

        # Error should be proportional to state change
        assert error < 0.1, (
            f"Relative-action error {error:.4f} should be small for small state change {state_change:.4f}"
        )

    def test_large_state_change_produces_proportional_error(self):
        """With large state changes, stale relative actions diverge more (but RTC guidance decays)."""
        state_old = torch.randn(1, ACTION_DIM)
        state_new = state_old + 10.0 * torch.randn(1, ACTION_DIM)

        actions_absolute = torch.randn(1, CHUNK_SIZE, ACTION_DIM)
        mask = [True] * ACTION_DIM

        relatives_old = to_relative_actions(actions_absolute, state_old, mask)
        relatives_new = to_relative_actions(actions_absolute, state_new, mask)

        error = (relatives_old - relatives_new).abs().mean()
        state_change = (state_old - state_new).abs().mean()

        # Error should be roughly equal to state change
        torch.testing.assert_close(
            error.clone().detach(), state_change.clone().detach(), atol=1e-5, rtol=1e-5
        )

    def test_excluded_joints_not_affected_by_state_change(self):
        """Joints excluded from relative conversion should not contribute rebasing error."""
        state_old = torch.randn(1, ACTION_DIM)
        state_new = state_old.clone()
        state_new[0, -1] = state_old[0, -1] + 100.0

        actions = torch.randn(1, CHUNK_SIZE, ACTION_DIM)
        mask = [True] * (ACTION_DIM - 1) + [False]

        relatives_old = to_relative_actions(actions, state_old, mask)
        relatives_new = to_relative_actions(actions, state_new, mask)

        # Last dim (excluded) should have zero error
        error_excluded = (relatives_old[..., -1] - relatives_new[..., -1]).abs().max()
        assert error_excluded < 1e-6, f"Excluded joint should have zero error, got {error_excluded}"


def _detect_relative_actions(preprocessor) -> bool:
    """Mirror of the helper in eval_with_real_robot.py for testing without importing it."""
    return any(isinstance(step, RelativeActionsProcessorStep) and step.enabled for step in preprocessor.steps)


class TestDetectRelativeActions:
    """Test the _detect_relative_actions helper logic used by eval_with_real_robot.py."""

    def test_detects_enabled_relative_step(self):
        class FakePipeline:
            steps = [RelativeActionsProcessorStep(enabled=True)]

        assert _detect_relative_actions(FakePipeline()) is True

    def test_ignores_disabled_relative_step(self):
        class FakePipeline:
            steps = [RelativeActionsProcessorStep(enabled=False)]

        assert _detect_relative_actions(FakePipeline()) is False

    def test_returns_false_when_no_relative_step(self):
        class FakePipeline:
            steps = []

        assert _detect_relative_actions(FakePipeline()) is False


class TestNonRelativePolicy:
    """Verify the same pipeline works when relative actions are disabled (standard absolute policy)."""

    def test_disabled_relative_step_is_noop(self):
        relative_step = RelativeActionsProcessorStep(enabled=False)
        absolute_step = AbsoluteActionsProcessorStep(enabled=False, relative_step=relative_step)

        state = torch.randn(1, ACTION_DIM)
        actions = torch.randn(1, CHUNK_SIZE, ACTION_DIM)

        transition = batch_to_transition({ACTION: actions.clone(), OBS_STATE: state})
        t1 = relative_step(transition)
        torch.testing.assert_close(t1[TransitionKey.ACTION], actions)

        t2 = absolute_step({TransitionKey.ACTION: actions.clone()})
        torch.testing.assert_close(t2[TransitionKey.ACTION], actions)

    def test_eval_loop_without_relative_actions(self):
        """Full eval loop simulation with relative actions disabled: original and processed actions are identical."""
        features = {ACTION: PolicyFeature(type=FeatureType.ACTION, shape=(ACTION_DIM,))}
        norm_map = {FeatureType.ACTION: NormalizationMode.MEAN_STD}
        stats = {
            ACTION: {
                "mean": torch.zeros(ACTION_DIM).numpy(),
                "std": torch.ones(ACTION_DIM).numpy(),
            }
        }

        relative_step = RelativeActionsProcessorStep(enabled=False)
        unnormalizer = UnnormalizerProcessorStep(features=features, norm_map=norm_map, stats=stats)
        absolute_step = AbsoluteActionsProcessorStep(enabled=False, relative_step=relative_step)

        rtc = RTCProcessor(_make_rtc_config())
        queue = ActionQueue(_make_rtc_config())

        state = torch.randn(1, ACTION_DIM)
        relative_step(batch_to_transition({OBS_STATE: state}))

        model_output = torch.randn(1, CHUNK_SIZE, ACTION_DIM)
        post = absolute_step(unnormalizer({TransitionKey.ACTION: model_output.clone()}))[
            TransitionKey.ACTION
        ].squeeze(0)
        original = model_output.squeeze(0)

        # With identity norm and no relative-action transform, original and postprocessed should match
        torch.testing.assert_close(original, post, atol=1e-5, rtol=1e-5)

        queue.merge(original, post, real_delay=0)

        for _ in range(5):
            queue.get()

        prev_actions = queue.get_left_over()
        assert prev_actions is not None

        # RTC guidance works the same way (absolute space)
        x_t = torch.randn(1, CHUNK_SIZE, ACTION_DIM)
        result = rtc.denoise_step(
            x_t=x_t,
            prev_chunk_left_over=prev_actions.unsqueeze(0),
            inference_delay=3,
            time=0.5,
            original_denoise_step_partial=lambda x: torch.zeros_like(x),
        )
        assert result.shape == x_t.shape

    def test_detect_relative_returns_false_when_disabled(self):
        class FakePipeline:
            steps = [RelativeActionsProcessorStep(enabled=False)]

        assert not _detect_relative_actions(FakePipeline())

    def test_detect_relative_returns_false_when_absent(self):
        class FakePipeline:
            steps = []

        assert not _detect_relative_actions(FakePipeline())


|
||||
class TestMultiChunkConsistency:
|
||||
"""Test multiple RTC iterations with relative actions maintain consistency."""
|
||||
|
||||
def test_three_iteration_pipeline(self):
|
||||
"""Simulate three consecutive RTC iterations and verify queue state consistency."""
|
||||
relative_step, normalizer, unnormalizer, absolute_step = _make_relative_pipeline()
|
||||
queue = ActionQueue(_make_rtc_config())
|
||||
|
||||
states = [torch.randn(1, ACTION_DIM) + i * 0.01 for i in range(3)]
|
||||
|
||||
for i in range(3):
|
||||
queue.get_left_over()
|
||||
|
||||
relative_step(batch_to_transition({OBS_STATE: states[i]}))
|
||||
|
||||
model_output = torch.randn(1, CHUNK_SIZE, ACTION_DIM)
|
||||
post_transition = absolute_step(unnormalizer({TransitionKey.ACTION: model_output.clone()}))
|
||||
postprocessed = post_transition[TransitionKey.ACTION].squeeze(0)
|
||||
original = model_output.squeeze(0)
|
||||
|
||||
delay = min(i * 2, CHUNK_SIZE - 1)
|
||||
queue.merge(original, postprocessed, real_delay=delay)
|
||||
|
||||
for _ in range(5):
|
||||
action = queue.get()
|
||||
assert action is not None
|
||||
assert action.shape == (ACTION_DIM,)
|
||||
|
||||
# After 3 iterations, queue should still be in valid state
|
||||
assert queue.qsize() > 0
|
||||
leftover = queue.get_left_over()
|
||||
assert leftover is not None
|
||||
assert leftover.ndim == 2
|
||||
assert leftover.shape[1] == ACTION_DIM
|
||||
|
||||
def test_leftover_and_processed_differ_when_relative_enabled(self):
|
||||
"""With relative actions enabled, original leftovers (relative) must differ from processed (absolute)."""
|
||||
relative_step, _, unnormalizer, absolute_step = _make_relative_pipeline()
|
||||
queue = ActionQueue(_make_rtc_config())
|
||||
|
||||
state = torch.tensor([[1.0, 2.0, 3.0, 4.0, 5.0, 6.0]])
|
||||
relative_step(batch_to_transition({OBS_STATE: state}))
|
||||
|
||||
model_relatives = torch.randn(1, CHUNK_SIZE, ACTION_DIM)
|
||||
post = absolute_step(unnormalizer({TransitionKey.ACTION: model_relatives.clone()}))[
|
||||
TransitionKey.ACTION
|
||||
].squeeze(0)
|
||||
original = model_relatives.squeeze(0)
|
||||
|
||||
queue.merge(original, post, real_delay=0)
|
||||
|
||||
relative_leftover = queue.get_left_over()
|
||||
|
||||
# Leftovers (relative) must differ from the postprocessed absolute actions
|
||||
assert not torch.allclose(relative_leftover, post, atol=1e-3)
|
||||
state_expanded = state.expand_as(relative_leftover)
|
||||
torch.testing.assert_close(post, relative_leftover + state_expanded, atol=1e-5, rtol=1e-5)
|
||||
|
||||
def test_rtc_guidance_uses_relative_space(self):
|
||||
"""Verify that RTC denoise_step receives relative-space leftovers, not absolute."""
|
||||
relative_step, _, unnormalizer, absolute_step = _make_relative_pipeline()
|
||||
rtc = RTCProcessor(_make_rtc_config())
|
||||
queue = ActionQueue(_make_rtc_config())
|
||||
|
||||
state = torch.tensor([[10.0, 20.0, 30.0, 40.0, 50.0, 60.0]])
|
||||
relative_step(batch_to_transition({OBS_STATE: state}))
|
||||
|
||||
model_relatives = 0.1 * torch.randn(1, CHUNK_SIZE, ACTION_DIM)
|
||||
post = absolute_step(unnormalizer({TransitionKey.ACTION: model_relatives.clone()}))[
|
||||
TransitionKey.ACTION
|
||||
].squeeze(0)
|
||||
original = model_relatives.squeeze(0)
|
||||
|
||||
queue.merge(original, post, real_delay=0)
|
||||
|
||||
for _ in range(5):
|
||||
queue.get()
|
||||
|
||||
prev_left_over = queue.get_left_over()
|
||||
|
||||
# prev_left_over should be small relative offsets (around 0.1 * randn), not large absolute values
|
||||
assert prev_left_over.abs().mean() < 5.0, (
|
||||
f"Leftover should be small relative offsets, got mean abs {prev_left_over.abs().mean():.2f}"
|
||||
)
|
||||
|
||||
x_t = torch.randn(1, CHUNK_SIZE, ACTION_DIM)
|
||||
|
||||
def denoise(x):
|
||||
return torch.zeros_like(x)
|
||||
|
||||
result = rtc.denoise_step(
|
||||
x_t=x_t,
|
||||
prev_chunk_left_over=prev_left_over.unsqueeze(0),
|
||||
inference_delay=3,
|
||||
time=0.5,
|
||||
original_denoise_step_partial=denoise,
|
||||
)
|
||||
|
||||
assert result.shape == x_t.shape
|
||||
@@ -0,0 +1,346 @@
|
||||
"""Tests for relative action transforms — full pipeline validation.
|
||||
|
||||
Tests the complete flow, mirroring OpenPI:
|
||||
raw actions → RelativeActions → Normalize(relative_stats) → model → Unnormalize → AbsoluteActions
|
||||
|
||||
Uses a real dataset: lerobot-data-collection/dagger_final_1_21
|
||||
"""
|
||||
|
||||
import numpy as np
|
||||
import pytest
|
||||
import torch
|
||||
|
||||
from lerobot.configs.types import FeatureType, NormalizationMode, PolicyFeature
|
||||
from lerobot.datasets.compute_stats import get_feature_stats
|
||||
from lerobot.datasets.lerobot_dataset import LeRobotDataset
|
||||
from lerobot.processor import TransitionKey, batch_to_transition
|
||||
from lerobot.processor.normalize_processor import NormalizerProcessorStep, UnnormalizerProcessorStep
|
||||
from lerobot.processor.relative_action_processor import (
|
||||
AbsoluteActionsProcessorStep,
|
||||
RelativeActionsProcessorStep,
|
||||
to_absolute_actions,
|
||||
to_relative_actions,
|
||||
)
|
||||
from lerobot.utils.constants import ACTION, OBS_STATE
|
||||
|
||||
CHUNK_SIZE = 10
|
||||
REPO_ID = "lerobot-data-collection/dagger_final_1_21"
|
||||
|
||||
|
||||
@pytest.fixture(scope="module")
|
||||
def dataset():
|
||||
return LeRobotDataset(REPO_ID, episodes=[0])
|
||||
|
||||
|
||||
@pytest.fixture(scope="module")
|
||||
def action_dim(dataset):
|
||||
return dataset.meta.features["action"]["shape"][0]
|
||||
|
||||
|
||||
def _build_action_chunks(dataset, chunk_size, max_chunks=50):
|
||||
"""Build action chunks from hf_dataset, like the training script does."""
|
||||
hf = dataset.hf_dataset
|
||||
total = len(hf)
|
||||
all_ep = torch.tensor([int(hf[i]["episode_index"]) for i in range(total)])
|
||||
chunks, states = [], []
|
||||
for i in range(total - chunk_size + 1):
|
||||
if all_ep[i] != all_ep[i + chunk_size - 1]:
|
||||
continue
|
||||
chunk_actions = torch.stack([hf[i + k]["action"] for k in range(chunk_size)]).float()
|
||||
state = hf[i]["observation.state"].float()
|
||||
chunks.append(chunk_actions)
|
||||
states.append(state)
|
||||
if len(chunks) >= max_chunks:
|
||||
break
|
||||
assert len(chunks) > 0, f"No valid chunks found. total={total}, ep_indices={all_ep.tolist()}"
|
||||
return torch.stack(chunks), torch.stack(states)
|
||||
|
||||
|
||||
def _compute_relative_chunk_stats(action_chunks, states, mask):
|
||||
all_chunks = []
|
||||
for actions, state in zip(action_chunks, states, strict=True):
|
||||
relative = to_relative_actions(actions.unsqueeze(0), state.unsqueeze(0), mask).squeeze(0)
|
||||
all_chunks.append(relative.numpy())
|
||||
all_relative = np.concatenate(all_chunks, axis=0)
|
||||
return get_feature_stats(all_relative, axis=0, keepdims=all_relative.ndim == 1)
|
||||
|
||||
|
||||
# Basic roundtrip tests
|
||||
|
||||
|
||||
def test_roundtrip_3d(action_dim):
|
||||
actions = torch.randn(4, CHUNK_SIZE, action_dim)
|
||||
state = torch.randn(4, action_dim)
|
||||
mask = [True] * action_dim
|
||||
recovered = to_absolute_actions(to_relative_actions(actions, state, mask), state, mask)
|
||||
torch.testing.assert_close(recovered, actions)
|
||||
|
||||
|
||||
def test_roundtrip_2d(action_dim):
|
||||
actions = torch.randn(4, action_dim)
|
||||
state = torch.randn(4, action_dim)
|
||||
mask = [True] * action_dim
|
||||
recovered = to_absolute_actions(to_relative_actions(actions, state, mask), state, mask)
|
||||
torch.testing.assert_close(recovered, actions)
|
||||
|
||||
|
||||
def test_no_mutation(action_dim):
|
||||
actions = torch.randn(2, CHUNK_SIZE, action_dim)
|
||||
original = actions.clone()
|
||||
state = torch.randn(2, action_dim)
|
||||
to_relative_actions(actions, state, [True] * action_dim)
|
||||
torch.testing.assert_close(actions, original)
|
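The roundtrip and no-mutation tests above rely on two properties: the conversion is a masked subtraction of the state, and it returns a new object instead of editing its input. A minimal pure-Python sketch of that behaviour follows, assuming the mask selects which dimensions become relative. `to_relative` and `to_absolute` here are illustrative names operating on nested lists, not the tensor functions imported from `lerobot.processor.relative_action_processor`.

```python
def to_relative(chunk, state, mask):
    """Return a new chunk with `state` subtracted from every masked dimension."""
    return [[a - s if m else a for a, s, m in zip(action, state, mask)] for action in chunk]


def to_absolute(chunk, state, mask):
    """Inverse of to_relative: add `state` back on every masked dimension."""
    return [[a + s if m else a for a, s, m in zip(action, state, mask)] for action in chunk]


state = [1.0, 2.0]
chunk = [[1.5, 2.5], [2.0, 3.0]]  # two actions, two dimensions
mask = [True, False]  # the second dimension stays absolute

relative = to_relative(chunk, state, mask)
recovered = to_absolute(relative, state, mask)
```

Every action in the chunk is rebased on the same state, the one observed at the start of the chunk.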
||||
|
||||
|
||||
def test_exclude_joints_supports_partial_name_matching():
|
||||
names = [
|
||||
"right_joint_1.pos",
|
||||
"right_gripper.pos",
|
||||
"left_joint_1.pos",
|
||||
"left_gripper.pos",
|
||||
]
|
||||
step = RelativeActionsProcessorStep(enabled=True, exclude_joints=["gripper"], action_names=names)
|
||||
assert step._build_mask(len(names)) == [True, False, True, False]
|
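The expected mask above can be reproduced with a substring match over the action names. This is a sketch under the assumption that "partial name matching" means "the excluded token appears anywhere in the feature name"; `build_mask` is a hypothetical free function, not the private `_build_mask` method itself.

```python
def build_mask(action_names, exclude_joints):
    """True means convert to relative; False means keep the dimension absolute."""
    return [not any(token in name for token in exclude_joints) for name in action_names]


names = [
    "right_joint_1.pos",
    "right_gripper.pos",
    "left_joint_1.pos",
    "left_gripper.pos",
]
mask = build_mask(names, ["gripper"])
```

With this convention a single `"gripper"` entry excludes both the left and right gripper dimensions.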
||||
|
||||
|
||||
# Chunk-level relative stats test
|
||||
|
||||
|
||||
def test_chunk_stats_have_larger_std_than_frame_stats(dataset, action_dim):
|
||||
"""Chunk-level relative stats should have a std at least as large as per-frame relative stats."""
|
||||
action_chunks, states = _build_action_chunks(dataset, CHUNK_SIZE)
|
||||
mask = [True] * action_dim
|
||||
|
||||
chunk_stats = _compute_relative_chunk_stats(action_chunks, states, mask)
|
||||
|
||||
# Per-frame stats
|
||||
hf = dataset.hf_dataset
|
||||
n = min(500, len(hf))
|
||||
frame_actions = torch.stack([hf[i]["action"] for i in range(n)]).float()
|
||||
frame_states = torch.stack([hf[i]["observation.state"] for i in range(n)]).float()
|
||||
frame_relatives = to_relative_actions(frame_actions, frame_states, mask).numpy()
|
||||
frame_stats = get_feature_stats(frame_relatives, axis=0, keepdims=frame_relatives.ndim == 1)
|
||||
|
||||
assert chunk_stats["std"].mean() >= frame_stats["std"].mean(), (
|
||||
f"Chunk std ({chunk_stats['std'].mean():.4f}) should be >= "
|
||||
f"frame std ({frame_stats['std'].mean():.4f})"
|
||||
)
|
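The intuition behind this test is that offsets measured from the chunk's first state grow with the horizon, so pooling all horizons widens the distribution. The toy example below uses a made-up 1-D trajectory (not the dataset above) and only the standard library.

```python
from statistics import pstdev

CHUNK = 4
positions = list(range(20))  # toy 1-D trajectory moving one unit per step

# Per-frame: each action is compared with the state of its own frame.
frame_relative = [positions[t + 1] - positions[t] for t in range(len(positions) - 1)]

# Chunk-level: all CHUNK future actions are compared with the state at the chunk start.
chunk_relative = [
    positions[t + k + 1] - positions[t]
    for t in range(len(positions) - CHUNK)
    for k in range(CHUNK)
]

frame_std = pstdev(frame_relative)
chunk_std = pstdev(chunk_relative)
```

In this toy case the per-frame offsets are constant, so their std is zero, while the chunk offsets range from 1 to `CHUNK`. Normalizing chunked relative actions with per-frame stats would therefore produce badly scaled targets.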
||||
|
||||
|
||||
# Full pipeline roundtrip: relative → normalize → unnormalize → absolute
|
||||
|
||||
|
||||
def test_full_pipeline_roundtrip(dataset, action_dim):
|
||||
"""Test the complete OpenPI pipeline: relative → normalize → unnormalize → absolute."""
|
||||
action_chunks, states = _build_action_chunks(dataset, CHUNK_SIZE)
|
||||
mask = [True] * action_dim
|
||||
|
||||
relative_stats = _compute_relative_chunk_stats(action_chunks, states, mask)
|
||||
stats = {ACTION: dict(relative_stats)}
|
||||
|
||||
features = {ACTION: PolicyFeature(type=FeatureType.ACTION, shape=(action_dim,))}
|
||||
norm_map = {FeatureType.ACTION: NormalizationMode.MEAN_STD}
|
||||
|
||||
relative_step = RelativeActionsProcessorStep(enabled=True)
|
||||
normalizer = NormalizerProcessorStep(features=features, norm_map=norm_map, stats=stats)
|
||||
unnormalizer = UnnormalizerProcessorStep(features=features, norm_map=norm_map, stats=stats)
|
||||
absolute_step = AbsoluteActionsProcessorStep(enabled=True, relative_step=relative_step)
|
||||
|
||||
original_actions = action_chunks[0].unsqueeze(0)
|
||||
state = states[0].unsqueeze(0)
|
||||
|
||||
batch = {ACTION: original_actions, OBS_STATE: state}
|
||||
transition = batch_to_transition(batch)
|
||||
|
||||
# Forward: relative → normalize
|
||||
t1 = relative_step(transition)
|
||||
t2 = normalizer(t1)
|
||||
|
||||
normalized_action = t2[TransitionKey.ACTION]
|
||||
assert normalized_action.abs().mean() < 10, (
|
||||
f"Normalized actions should be in reasonable range, got mean abs {normalized_action.abs().mean():.2f}"
|
||||
)
|
||||
|
||||
# Reverse: unnormalize → absolute
|
||||
t3 = unnormalizer(t2)
|
||||
t4 = absolute_step(t3)
|
||||
|
||||
recovered_actions = t4[TransitionKey.ACTION]
|
||||
torch.testing.assert_close(recovered_actions, original_actions, atol=1e-4, rtol=1e-4)
|
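The normalize and unnormalize half of this roundtrip can be sketched in a few lines. The sketch assumes the mean/std form with the `1e-6` epsilon that the manual normalization in these tests also uses; `normalize` and `unnormalize` are illustrative helpers on plain lists, not the processor steps.

```python
EPS = 1e-6  # keeps the division finite when a dimension has zero std


def normalize(x, mean, std):
    """Per-dimension (x - mean) / (std + EPS)."""
    return [(v - m) / (s + EPS) for v, m, s in zip(x, mean, std)]


def unnormalize(z, mean, std):
    """Inverse of normalize, using the same epsilon."""
    return [v * (s + EPS) + m for v, m, s in zip(z, mean, std)]


mean = [0.1, -0.2]
std = [0.5, 0.0]  # the second dimension never moves in this toy data
x = [0.6, -0.2]

z = normalize(x, mean, std)
back = unnormalize(z, mean, std)
```

The inverse is only exact up to floating-point error when both directions use the same epsilon, which is why the test above compares with `atol=1e-4` instead of demanding bitwise equality.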
||||
|
||||
|
||||
def test_normalized_relative_values_are_reasonable(dataset, action_dim):
|
||||
"""With correct chunk stats, normalized relative actions should be in a reasonable range."""
|
||||
action_chunks, states = _build_action_chunks(dataset, CHUNK_SIZE)
|
||||
mask = [True] * action_dim
|
||||
|
||||
relative_stats = _compute_relative_chunk_stats(action_chunks, states, mask)
|
||||
mean = torch.tensor(relative_stats["mean"]).float()
|
||||
std = torch.tensor(relative_stats["std"]).float()
|
||||
|
||||
all_normalized = []
|
||||
for actions, state in zip(action_chunks, states, strict=True):
|
||||
relative = to_relative_actions(actions.unsqueeze(0), state.unsqueeze(0), mask).squeeze(0)
|
||||
normalized = (relative - mean) / (std + 1e-6)
|
||||
all_normalized.append(normalized)
|
||||
|
||||
all_normalized = torch.cat(all_normalized, dim=0)
|
||||
|
||||
pct_in_range = (all_normalized.abs() < 5).float().mean()
|
||||
assert pct_in_range > 0.9, (
|
||||
f"Only {pct_in_range * 100:.1f}% of normalized values in [-5, 5], expected >90%"
|
||||
)
|
||||
|
||||
assert all_normalized.mean().abs() < 1.0, (
|
||||
f"Mean of normalized relative actions is {all_normalized.mean():.2f}, expected near 0"
|
||||
)
|
||||
|
||||
|
||||
def test_processor_step_roundtrip(dataset, action_dim):
|
||||
"""RelativeActionsProcessorStep applies relative offsets; to_absolute_actions recovers original."""
|
||||
hf = dataset.hf_dataset
|
||||
batch = {
|
||||
ACTION: torch.stack([hf[i]["action"] for i in range(4)]),
|
||||
OBS_STATE: torch.stack([hf[i]["observation.state"] for i in range(4)]),
|
||||
}
|
||||
original_actions = batch[ACTION].clone()
|
||||
transition = batch_to_transition(batch)
|
||||
|
||||
step = RelativeActionsProcessorStep(enabled=True)
|
||||
relative_transition = step(transition)
|
||||
assert not torch.allclose(relative_transition[TransitionKey.ACTION], original_actions)
|
||||
|
||||
state = transition[TransitionKey.OBSERVATION][OBS_STATE]
|
||||
mask = [True] * action_dim
|
||||
recovered = to_absolute_actions(relative_transition[TransitionKey.ACTION], state, mask)
|
||||
torch.testing.assert_close(recovered, original_actions)
|
||||
|
||||
|
||||
def test_processor_step_disabled_is_noop(dataset, action_dim):
|
||||
"""enabled=False should be a no-op."""
|
||||
hf = dataset.hf_dataset
|
||||
batch = {
|
||||
ACTION: torch.stack([hf[i]["action"] for i in range(2)]),
|
||||
OBS_STATE: torch.stack([hf[i]["observation.state"] for i in range(2)]),
|
||||
}
|
||||
original = batch[ACTION].clone()
|
||||
transition = batch_to_transition(batch)
|
||||
result = RelativeActionsProcessorStep(enabled=False)(transition)
|
||||
torch.testing.assert_close(result[TransitionKey.ACTION], original)
|
||||
|
||||
|
||||
# Training batch shape validation
|
||||
|
||||
|
||||
def test_relative_with_action_chunks(dataset, action_dim):
|
||||
"""Verify relative actions work correctly with (B, chunk_size, action_dim) shaped actions."""
|
||||
action_chunks, states = _build_action_chunks(dataset, CHUNK_SIZE)
|
||||
|
||||
# Simulate a training batch: actions=(B, chunk_size, action_dim), state=(B, state_dim)
|
||||
batch_actions = action_chunks[:4] # (4, chunk_size, action_dim)
|
||||
batch_states = states[:4] # (4, state_dim)
|
||||
|
||||
mask = [True] * action_dim
|
||||
relative = to_relative_actions(batch_actions, batch_states, mask)
|
||||
|
||||
# First action in each chunk should be close to zero (action[t] - state[t] ≈ small)
|
||||
first_relatives = relative[:, 0, :] # (B, action_dim)
|
||||
assert first_relatives.abs().mean() < relative.abs().mean(), (
|
||||
f"First action in chunk should have smaller relative offset than average. "
|
||||
f"First: {first_relatives.abs().mean():.4f}, Average: {relative.abs().mean():.4f}"
|
||||
)
|
||||
|
||||
# Later actions should have larger relative offsets
|
||||
last_relatives = relative[:, -1, :] # (B, action_dim)
|
||||
assert last_relatives.abs().mean() >= first_relatives.abs().mean(), (
|
||||
f"Last action in chunk should have a relative offset >= that of the first. "
|
||||
f"Last: {last_relatives.abs().mean():.4f}, First: {first_relatives.abs().mean():.4f}"
|
||||
)
|
||||
|
||||
# Roundtrip
|
||||
recovered = to_absolute_actions(relative, batch_states, mask)
|
||||
torch.testing.assert_close(recovered, batch_actions)
|
||||
|
||||
|
||||
def test_relative_stats_match_actual_data_distribution(dataset, action_dim):
|
||||
"""Verify computed stats match the actual relative-action distribution."""
|
||||
action_chunks, states = _build_action_chunks(dataset, CHUNK_SIZE)
|
||||
mask = [True] * action_dim
|
||||
|
||||
# Compute stats like the training script does
|
||||
relative_stats = _compute_relative_chunk_stats(action_chunks, states, mask)
|
||||
|
||||
# Also compute directly
|
||||
all_relatives = []
|
||||
for actions, state in zip(action_chunks, states, strict=True):
|
||||
rel = to_relative_actions(actions.unsqueeze(0), state.unsqueeze(0), mask).squeeze(0)
|
||||
all_relatives.append(rel)
|
||||
all_relatives_tensor = torch.cat(all_relatives, dim=0)
|
||||
|
||||
# Compare mean
|
||||
actual_mean = all_relatives_tensor.mean(dim=0).numpy()
|
||||
np.testing.assert_allclose(relative_stats["mean"], actual_mean, atol=0.01)
|
||||
|
||||
# Compare std
|
||||
actual_std = all_relatives_tensor.std(dim=0).numpy()
|
||||
np.testing.assert_allclose(relative_stats["std"], actual_std, atol=0.1)
|
||||
|
||||
# Verify q01 < mean < q99
|
||||
assert (relative_stats["q01"] < relative_stats["mean"]).all(), "q01 should be < mean"
|
||||
assert (relative_stats["mean"] < relative_stats["q99"]).all(), "mean should be < q99"
|
||||
|
||||
|
||||
def test_quantile_normalization_roundtrip(dataset, action_dim):
|
||||
"""Full roundtrip with QUANTILES normalization (what OpenPI uses for pi05)."""
|
||||
action_chunks, states = _build_action_chunks(dataset, CHUNK_SIZE)
|
||||
mask = [True] * action_dim
|
||||
|
||||
relative_stats = _compute_relative_chunk_stats(action_chunks, states, mask)
|
||||
stats = {ACTION: dict(relative_stats)}
|
||||
|
||||
features = {ACTION: PolicyFeature(type=FeatureType.ACTION, shape=(action_dim,))}
|
||||
norm_map = {FeatureType.ACTION: NormalizationMode.QUANTILES}
|
||||
|
||||
relative_step = RelativeActionsProcessorStep(enabled=True)
|
||||
normalizer = NormalizerProcessorStep(features=features, norm_map=norm_map, stats=stats)
|
||||
unnormalizer = UnnormalizerProcessorStep(features=features, norm_map=norm_map, stats=stats)
|
||||
absolute_step = AbsoluteActionsProcessorStep(enabled=True, relative_step=relative_step)
|
||||
|
||||
original_actions = action_chunks[0].unsqueeze(0)
|
||||
state = states[0].unsqueeze(0)
|
||||
|
||||
batch = {ACTION: original_actions, OBS_STATE: state}
|
||||
transition = batch_to_transition(batch)
|
||||
|
||||
# Forward: relative → quantile normalize
|
||||
t1 = relative_step(transition)
|
||||
t2 = normalizer(t1)
|
||||
|
||||
normalized = t2[TransitionKey.ACTION]
|
||||
# Quantile normalization maps most values into roughly [-1, 1]; assert a looser [-2, 2] bound
|
||||
pct_in_range = (normalized.abs() < 2).float().mean()
|
||||
assert pct_in_range > 0.5, f"Only {pct_in_range * 100:.1f}% in [-2, 2] after quantile norm, expected >50%"
|
||||
|
||||
# Reverse: unnormalize → absolute
|
||||
t3 = unnormalizer(t2)
|
||||
t4 = absolute_step(t3)
|
||||
|
||||
recovered = t4[TransitionKey.ACTION]
|
||||
torch.testing.assert_close(recovered, original_actions, atol=1e-3, rtol=1e-3)
|
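For readers unfamiliar with quantile normalization, here is a small sketch of one common q01/q99 form: scale the 1st-to-99th percentile range onto [-1, 1]. The exact formula used by the `QUANTILES` mode is not shown in this diff, so the epsilon and the scaling below are assumptions, and both helper names are hypothetical.

```python
EPS = 1e-6  # assumed guard against q99 == q01


def quantile_normalize(value, q01, q99):
    """Map q01 to -1 and q99 to (almost) +1; values outside the quantiles exceed that range."""
    return (value - q01) / (q99 - q01 + EPS) * 2.0 - 1.0


def quantile_unnormalize(z, q01, q99):
    """Inverse of quantile_normalize with the same epsilon."""
    return (z + 1.0) / 2.0 * (q99 - q01 + EPS) + q01


q01, q99 = -1.0, 3.0
low = quantile_normalize(-1.0, q01, q99)  # sits at the lower quantile
high = quantile_normalize(3.0, q01, q99)  # sits at the upper quantile
outlier = quantile_normalize(5.0, q01, q99)  # beyond q99, lands outside [-1, 1]
restored = quantile_unnormalize(outlier, q01, q99)
```

Values beyond the quantiles are not clipped in this form, which is one reason a test might check a bound wider than [-1, 1].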
||||
|
||||
|
||||
def test_state_not_modified_by_relative_processor(dataset, action_dim):
|
||||
"""State should never be modified by the relative-action processor."""
|
||||
hf = dataset.hf_dataset
|
||||
batch = {
|
||||
ACTION: torch.stack([hf[i]["action"] for i in range(4)]),
|
||||
OBS_STATE: torch.stack([hf[i]["observation.state"] for i in range(4)]),
|
||||
}
|
||||
original_state = batch[OBS_STATE].clone()
|
||||
transition = batch_to_transition(batch)
|
||||
|
||||
step = RelativeActionsProcessorStep(enabled=True)
|
||||
result = step(transition)
|
||||
|
||||
result_state = result[TransitionKey.OBSERVATION][OBS_STATE]
|
||||
torch.testing.assert_close(result_state, original_state)
|
||||