mirror of
https://github.com/huggingface/lerobot.git
synced 2026-05-20 02:59:50 +00:00
818 lines
32 KiB
Python
818 lines
32 KiB
Python
"""Tests for RTC + relative actions integration.
|
|
|
|
Validates that Real-Time Chunking (RTC) works correctly when the policy uses
|
|
relative actions. The key invariant: RTC guidance operates in model space
|
|
(normalized relative actions), while the robot receives absolute actions after postprocessing.
|
|
|
|
Flow under test:
|
|
Preprocessor: raw obs → relative step caches state → normalizer
|
|
Model: generates normalized relative actions (guided by RTC using leftover relative actions)
|
|
Postprocessor: unnormalize → absolute step (relative + cached state) → robot actions
|
|
"""
|
|
|
|
import importlib.util
|
|
import sys
|
|
from pathlib import Path
|
|
from types import ModuleType, SimpleNamespace
|
|
|
|
import numpy as np
|
|
import torch
|
|
|
|
from lerobot.configs.types import (
|
|
FeatureType,
|
|
NormalizationMode,
|
|
PolicyFeature,
|
|
RTCAttentionSchedule,
|
|
)
|
|
from lerobot.processor import (
|
|
PolicyProcessorPipeline,
|
|
TransitionKey,
|
|
batch_to_transition,
|
|
policy_action_to_transition,
|
|
transition_to_policy_action,
|
|
)
|
|
from lerobot.processor.normalize_processor import NormalizerProcessorStep, UnnormalizerProcessorStep
|
|
from lerobot.processor.relative_action_processor import (
|
|
AbsoluteActionsProcessorStep,
|
|
RelativeActionsProcessorStep,
|
|
to_relative_actions,
|
|
)
|
|
from lerobot.utils.constants import ACTION, OBS_STATE
|
|
|
|
|
|
def _import_rtc_module(module_name: str, filename: str):
|
|
"""Import an RTC module directly from its file path, bypassing lerobot.policies.__init__."""
|
|
rtc_dir = Path(__file__).resolve().parents[3] / "src" / "lerobot" / "policies" / "rtc"
|
|
spec = importlib.util.spec_from_file_location(module_name, rtc_dir / filename)
|
|
mod = importlib.util.module_from_spec(spec)
|
|
sys.modules[module_name] = mod
|
|
spec.loader.exec_module(mod)
|
|
return mod
|
|
|
|
|
|
_rtc_cfg_mod = _import_rtc_module("lerobot.policies.rtc.configuration_rtc", "configuration_rtc.py")
|
|
RTCConfig = _rtc_cfg_mod.RTCConfig
|
|
|
|
_action_queue_mod = _import_rtc_module("lerobot.policies.rtc.action_queue", "action_queue.py")
|
|
ActionQueue = _action_queue_mod.ActionQueue
|
|
|
|
_rtc_debug_mod = _import_rtc_module("lerobot.policies.rtc.debug_tracker", "debug_tracker.py")
|
|
_rtc_mod = _import_rtc_module("lerobot.policies.rtc.modeling_rtc", "modeling_rtc.py")
|
|
RTCProcessor = _rtc_mod.RTCProcessor
|
|
|
|
|
|
def _ensure_rollout_test_packages() -> Path:
|
|
rollout_dir = Path(__file__).resolve().parents[3] / "src" / "lerobot" / "rollout"
|
|
rollout_pkg = sys.modules.setdefault("lerobot.rollout", ModuleType("lerobot.rollout"))
|
|
rollout_pkg.__path__ = [str(rollout_dir)]
|
|
inference_pkg = sys.modules.setdefault(
|
|
"lerobot.rollout.inference", ModuleType("lerobot.rollout.inference")
|
|
)
|
|
inference_pkg.__path__ = [str(rollout_dir / "inference")]
|
|
return rollout_dir
|
|
|
|
|
|
def _import_rollout_module(module_name: str, relative_path: str):
|
|
rollout_dir = _ensure_rollout_test_packages()
|
|
spec = importlib.util.spec_from_file_location(module_name, rollout_dir / relative_path)
|
|
mod = importlib.util.module_from_spec(spec)
|
|
sys.modules[module_name] = mod
|
|
spec.loader.exec_module(mod)
|
|
return mod
|
|
|
|
|
|
_rollout_robot_wrapper_mod = _import_rollout_module("lerobot.rollout.robot_wrapper", "robot_wrapper.py")
|
|
_rollout_base_mod = _import_rollout_module("lerobot.rollout.inference.base", "inference/base.py")
|
|
_rollout_sync_mod = _import_rollout_module("lerobot.rollout.inference.sync", "inference/sync.py")
|
|
_rollout_rtc_mod = _import_rollout_module("lerobot.rollout.inference.rtc", "inference/rtc.py")
|
|
SyncInferenceEngine = _rollout_sync_mod.SyncInferenceEngine
|
|
get_current_raw_state = _rollout_rtc_mod._get_current_raw_state
|
|
|
|
ACTION_DIM = 6
|
|
CHUNK_SIZE = 50
|
|
EXECUTION_HORIZON = 10
|
|
|
|
|
|
def _make_rtc_config(enabled=True):
|
|
return RTCConfig(
|
|
enabled=enabled,
|
|
execution_horizon=EXECUTION_HORIZON,
|
|
max_guidance_weight=10.0,
|
|
prefix_attention_schedule=RTCAttentionSchedule.EXP,
|
|
)
|
|
|
|
|
|
def _make_relative_pipeline(action_dim=ACTION_DIM, norm_mode=NormalizationMode.MEAN_STD):
|
|
"""Build paired relative/absolute processor steps and normalizer/unnormalizer."""
|
|
features = {ACTION: PolicyFeature(type=FeatureType.ACTION, shape=(action_dim,))}
|
|
norm_map = {FeatureType.ACTION: norm_mode}
|
|
|
|
stats = {
|
|
ACTION: {
|
|
"mean": torch.zeros(action_dim).numpy(),
|
|
"std": torch.ones(action_dim).numpy(),
|
|
"q01": (-2 * torch.ones(action_dim)).numpy(),
|
|
"q99": (2 * torch.ones(action_dim)).numpy(),
|
|
"min": (-3 * torch.ones(action_dim)).numpy(),
|
|
"max": (3 * torch.ones(action_dim)).numpy(),
|
|
}
|
|
}
|
|
|
|
relative_step = RelativeActionsProcessorStep(enabled=True)
|
|
normalizer = NormalizerProcessorStep(features=features, norm_map=norm_map, stats=stats)
|
|
unnormalizer = UnnormalizerProcessorStep(features=features, norm_map=norm_map, stats=stats)
|
|
absolute_step = AbsoluteActionsProcessorStep(enabled=True, relative_step=relative_step)
|
|
return relative_step, normalizer, unnormalizer, absolute_step
|
|
|
|
|
|
def _make_relative_sync_pipelines(
|
|
action_dim=ACTION_DIM,
|
|
action_names: list[str] | None = None,
|
|
exclude_joints: list[str] | None = None,
|
|
):
|
|
relative_step = RelativeActionsProcessorStep(
|
|
enabled=True,
|
|
exclude_joints=exclude_joints or [],
|
|
action_names=action_names or [f"joint_{i}.pos" for i in range(action_dim)],
|
|
)
|
|
absolute_step = AbsoluteActionsProcessorStep(enabled=True, relative_step=relative_step)
|
|
preprocessor = PolicyProcessorPipeline(steps=[relative_step], name="test_preprocessor")
|
|
postprocessor = PolicyProcessorPipeline(
|
|
steps=[absolute_step],
|
|
name="test_postprocessor",
|
|
to_transition=policy_action_to_transition,
|
|
to_output=transition_to_policy_action,
|
|
)
|
|
return relative_step, preprocessor, postprocessor
|
|
|
|
|
|
class _ChunkPolicyStub:
|
|
def __init__(self, action_dim: int, n_action_steps: int):
|
|
self.config = SimpleNamespace(use_amp=False, n_action_steps=n_action_steps)
|
|
self._chunk = torch.zeros(1, n_action_steps, action_dim)
|
|
self.predict_calls = 0
|
|
|
|
def reset(self):
|
|
return None
|
|
|
|
def predict_action_chunk(self, batch):
|
|
self.predict_calls += 1
|
|
return self._chunk.clone()
|
|
|
|
def select_action(self, batch):
|
|
raise AssertionError("SyncInferenceEngine should consume chunk outputs directly")
|
|
|
|
|
|
class TestActionQueueRelativeActions:
|
|
"""Verify ActionQueue stores model-space (relative) actions for RTC and absolute for robot."""
|
|
|
|
def test_left_over_returns_relative_actions(self):
|
|
"""get_left_over() should return the original (relative-space) actions."""
|
|
cfg = _make_rtc_config()
|
|
queue = ActionQueue(cfg)
|
|
|
|
relative_actions = torch.randn(CHUNK_SIZE, ACTION_DIM)
|
|
absolute_actions = torch.randn(CHUNK_SIZE, ACTION_DIM)
|
|
queue.merge(relative_actions, absolute_actions, real_delay=0)
|
|
|
|
for _ in range(5):
|
|
queue.get()
|
|
|
|
leftover = queue.get_left_over()
|
|
torch.testing.assert_close(leftover, relative_actions[5:])
|
|
|
|
def test_robot_receives_absolute_actions(self):
|
|
"""The robot (via get()) should receive postprocessed absolute actions."""
|
|
cfg = _make_rtc_config()
|
|
queue = ActionQueue(cfg)
|
|
|
|
relative_actions = torch.randn(CHUNK_SIZE, ACTION_DIM)
|
|
absolute_actions = torch.randn(CHUNK_SIZE, ACTION_DIM)
|
|
queue.merge(relative_actions, absolute_actions, real_delay=0)
|
|
|
|
first_action = queue.get()
|
|
torch.testing.assert_close(first_action, absolute_actions[0])
|
|
|
|
|
|
class TestRolloutInferenceRelativeActions:
|
|
"""Regression tests for rollout inference engines with relative-action policies."""
|
|
|
|
def test_sync_engine_postprocesses_chunk_before_queueing(self):
|
|
"""Queued sync actions must stay anchored to the state from the chunk-producing step."""
|
|
_, preprocessor, postprocessor = _make_relative_sync_pipelines(ACTION_DIM)
|
|
policy = _ChunkPolicyStub(action_dim=ACTION_DIM, n_action_steps=3)
|
|
ordered_action_keys = [f"joint_{i}.pos" for i in range(ACTION_DIM)]
|
|
dataset_features = {ACTION: {"names": ordered_action_keys}}
|
|
|
|
engine = SyncInferenceEngine(
|
|
policy=policy,
|
|
preprocessor=preprocessor,
|
|
postprocessor=postprocessor,
|
|
dataset_features=dataset_features,
|
|
ordered_action_keys=ordered_action_keys,
|
|
task="test",
|
|
device="cpu",
|
|
robot_type="mock",
|
|
)
|
|
|
|
state_1 = np.arange(1, ACTION_DIM + 1, dtype=np.float32)
|
|
state_2 = 10 * state_1
|
|
|
|
action_1 = engine.get_action({OBS_STATE: state_1.copy()})
|
|
action_2 = engine.get_action({OBS_STATE: state_2.copy()})
|
|
|
|
torch.testing.assert_close(action_1, torch.from_numpy(state_1))
|
|
torch.testing.assert_close(action_2, torch.from_numpy(state_1))
|
|
assert policy.predict_calls == 1
|
|
|
|
def test_sync_engine_restores_action_names_for_relative_exclusions(self):
|
|
"""Serialized processors may omit action names; sync rollout must still honor gripper exclusions."""
|
|
action_names = [f"joint_{i}.pos" for i in range(ACTION_DIM - 1)] + ["gripper.pos"]
|
|
relative_step, preprocessor, postprocessor = _make_relative_sync_pipelines(
|
|
ACTION_DIM,
|
|
action_names=action_names,
|
|
exclude_joints=["gripper"],
|
|
)
|
|
relative_step.action_names = None
|
|
policy = _ChunkPolicyStub(action_dim=ACTION_DIM, n_action_steps=1)
|
|
policy.config.action_feature_names = action_names
|
|
dataset_features = {ACTION: {"names": action_names}}
|
|
|
|
assert relative_step.action_names is None
|
|
|
|
engine = SyncInferenceEngine(
|
|
policy=policy,
|
|
preprocessor=preprocessor,
|
|
postprocessor=postprocessor,
|
|
dataset_features=dataset_features,
|
|
ordered_action_keys=action_names,
|
|
task="test",
|
|
device="cpu",
|
|
robot_type="mock",
|
|
)
|
|
|
|
state = np.arange(1, ACTION_DIM + 1, dtype=np.float32)
|
|
action = engine.get_action({OBS_STATE: state.copy()})
|
|
expected = torch.from_numpy(state.copy())
|
|
expected[-1] = 0.0
|
|
|
|
torch.testing.assert_close(action, expected)
|
|
assert relative_step.action_names == action_names
|
|
|
|
def test_sync_engine_does_not_remap_chunk_through_dataset_action_names(self):
|
|
"""Postprocessed chunks are already in policy order; dataset feature order must not scramble them."""
|
|
action_names = [f"joint_{i}.pos" for i in range(ACTION_DIM)]
|
|
_, preprocessor, postprocessor = _make_relative_sync_pipelines(
|
|
ACTION_DIM,
|
|
action_names=action_names,
|
|
)
|
|
policy = _ChunkPolicyStub(action_dim=ACTION_DIM, n_action_steps=1)
|
|
policy.config.action_feature_names = action_names
|
|
|
|
engine = SyncInferenceEngine(
|
|
policy=policy,
|
|
preprocessor=preprocessor,
|
|
postprocessor=postprocessor,
|
|
dataset_features={ACTION: {"names": list(reversed(action_names))}},
|
|
ordered_action_keys=action_names,
|
|
task="test",
|
|
device="cpu",
|
|
robot_type="mock",
|
|
)
|
|
|
|
state = np.arange(1, ACTION_DIM + 1, dtype=np.float32)
|
|
action = engine.get_action({OBS_STATE: state.copy()})
|
|
|
|
torch.testing.assert_close(action, torch.from_numpy(state))
|
|
|
|
def test_rtc_reanchoring_prefers_raw_cached_state(self):
|
|
"""RTC re-anchoring must use the raw state cached before observation normalization."""
|
|
action_dim = ACTION_DIM
|
|
relative_step = RelativeActionsProcessorStep(
|
|
enabled=True,
|
|
action_names=[f"joint_{i}.pos" for i in range(action_dim)],
|
|
)
|
|
features = {
|
|
OBS_STATE: PolicyFeature(type=FeatureType.STATE, shape=(action_dim,)),
|
|
ACTION: PolicyFeature(type=FeatureType.ACTION, shape=(action_dim,)),
|
|
}
|
|
stats = {
|
|
OBS_STATE: {
|
|
"mean": np.arange(1, action_dim + 1, dtype=np.float32),
|
|
"std": 2 * np.ones(action_dim, dtype=np.float32),
|
|
},
|
|
ACTION: {
|
|
"mean": np.zeros(action_dim, dtype=np.float32),
|
|
"std": np.ones(action_dim, dtype=np.float32),
|
|
},
|
|
}
|
|
preprocessor = PolicyProcessorPipeline(
|
|
steps=[
|
|
relative_step,
|
|
NormalizerProcessorStep(
|
|
features=features,
|
|
norm_map={
|
|
FeatureType.STATE: NormalizationMode.MEAN_STD,
|
|
FeatureType.ACTION: NormalizationMode.MEAN_STD,
|
|
},
|
|
stats=stats,
|
|
),
|
|
],
|
|
name="test_preprocessor_with_state_norm",
|
|
)
|
|
|
|
raw_state = torch.from_numpy(np.arange(5, 5 + action_dim, dtype=np.float32)).unsqueeze(0)
|
|
preprocessed = preprocessor({OBS_STATE: raw_state.clone()})
|
|
|
|
current_state = get_current_raw_state(relative_step, preprocessed.get(OBS_STATE))
|
|
|
|
torch.testing.assert_close(current_state, raw_state)
|
|
assert not torch.allclose(preprocessed[OBS_STATE], raw_state)
|
|
|
|
|
|
class TestRTCDenoiseWithRelativeLeftovers:
|
|
"""Verify RTC denoise_step correctly handles relative-space prev_chunk_left_over."""
|
|
|
|
def test_first_chunk_no_guidance(self):
|
|
"""First chunk (no leftovers) should return v_t without guidance."""
|
|
rtc = RTCProcessor(_make_rtc_config())
|
|
x_t = torch.randn(1, CHUNK_SIZE, ACTION_DIM)
|
|
|
|
def mock_denoise(x):
|
|
return torch.ones_like(x)
|
|
|
|
result = rtc.denoise_step(
|
|
x_t=x_t,
|
|
prev_chunk_left_over=None,
|
|
inference_delay=0,
|
|
time=0.5,
|
|
original_denoise_step_partial=mock_denoise,
|
|
)
|
|
torch.testing.assert_close(result, torch.ones_like(x_t))
|
|
|
|
def test_relative_leftovers_shape_preserved(self):
|
|
"""RTC output should have the same shape as input regardless of leftover shape."""
|
|
rtc = RTCProcessor(_make_rtc_config())
|
|
x_t = torch.randn(1, CHUNK_SIZE, ACTION_DIM)
|
|
shorter_leftover = torch.randn(1, 20, ACTION_DIM)
|
|
|
|
def mock_denoise(x):
|
|
return torch.zeros_like(x)
|
|
|
|
result = rtc.denoise_step(
|
|
x_t=x_t,
|
|
prev_chunk_left_over=shorter_leftover,
|
|
inference_delay=5,
|
|
time=0.5,
|
|
original_denoise_step_partial=mock_denoise,
|
|
)
|
|
assert result.shape == x_t.shape
|
|
|
|
def test_guidance_steers_toward_previous_relative_actions(self):
|
|
"""RTC guidance should push x1_t toward prev_chunk_left_over in relative space."""
|
|
rtc = RTCProcessor(_make_rtc_config())
|
|
x_t = torch.randn(1, CHUNK_SIZE, ACTION_DIM)
|
|
prev_relatives = torch.randn(1, CHUNK_SIZE, ACTION_DIM)
|
|
|
|
def mock_denoise(x):
|
|
return torch.zeros_like(x)
|
|
|
|
result_without_guidance = rtc.denoise_step(
|
|
x_t=x_t.clone(),
|
|
prev_chunk_left_over=None,
|
|
inference_delay=5,
|
|
time=0.5,
|
|
original_denoise_step_partial=mock_denoise,
|
|
)
|
|
|
|
result_with_guidance = rtc.denoise_step(
|
|
x_t=x_t.clone(),
|
|
prev_chunk_left_over=prev_relatives,
|
|
inference_delay=5,
|
|
time=0.5,
|
|
original_denoise_step_partial=mock_denoise,
|
|
)
|
|
|
|
assert not torch.allclose(result_with_guidance, result_without_guidance, atol=1e-6)
|
|
|
|
|
|
class TestFullPipelineRelativeRTC:
|
|
"""End-to-end test of the RTC + relative actions pipeline matching lerobot-rollout flow."""
|
|
|
|
def test_preprocessor_caches_state_for_postprocessor(self):
|
|
"""Preprocessor's relative step should cache state so postprocessor can convert back."""
|
|
relative_step, normalizer, unnormalizer, absolute_step = _make_relative_pipeline()
|
|
|
|
state = torch.randn(1, ACTION_DIM)
|
|
actions = torch.randn(1, CHUNK_SIZE, ACTION_DIM)
|
|
|
|
batch = {ACTION: actions, OBS_STATE: state}
|
|
transition = batch_to_transition(batch)
|
|
|
|
relative_step(transition)
|
|
assert relative_step._last_state is not None
|
|
torch.testing.assert_close(relative_step._last_state, state)
|
|
|
|
def test_preprocessor_caches_state_without_actions(self):
|
|
"""During inference, preprocessor receives only observations (no actions).
|
|
Relative step should still cache state for the postprocessor."""
|
|
relative_step, _, _, _ = _make_relative_pipeline()
|
|
|
|
state = torch.randn(1, ACTION_DIM)
|
|
batch = {OBS_STATE: state}
|
|
transition = batch_to_transition(batch)
|
|
|
|
relative_step(transition)
|
|
assert relative_step._last_state is not None
|
|
torch.testing.assert_close(relative_step._last_state, state)
|
|
|
|
def test_roundtrip_with_identity_normalization(self):
|
|
"""Actions → relative → normalize → [model] → unnormalize → absolute should recover originals.
|
|
Using mean=0, std=1 normalization (identity)."""
|
|
relative_step, normalizer, unnormalizer, absolute_step = _make_relative_pipeline()
|
|
|
|
state = torch.randn(1, ACTION_DIM)
|
|
actions = torch.randn(1, CHUNK_SIZE, ACTION_DIM)
|
|
|
|
batch = {ACTION: actions.clone(), OBS_STATE: state}
|
|
transition = batch_to_transition(batch)
|
|
|
|
t1 = relative_step(transition)
|
|
t2 = normalizer(t1)
|
|
|
|
model_output = t2[TransitionKey.ACTION].clone()
|
|
|
|
model_transition = {TransitionKey.ACTION: model_output}
|
|
t3 = unnormalizer(model_transition)
|
|
t4 = absolute_step(t3)
|
|
|
|
recovered = t4[TransitionKey.ACTION]
|
|
torch.testing.assert_close(recovered, actions, atol=1e-5, rtol=1e-5)
|
|
|
|
def test_eval_loop_simulation(self):
|
|
"""Simulate the lerobot-rollout loop with relative actions.
|
|
|
|
Iteration 1: No leftovers → model generates relative actions → store for RTC
|
|
Iteration 2: Use leftovers as RTC guidance → model generates new relative actions
|
|
Both iterations: postprocessor converts relative actions to absolute for robot
|
|
"""
|
|
relative_step, normalizer, unnormalizer, absolute_step = _make_relative_pipeline()
|
|
rtc = RTCProcessor(_make_rtc_config())
|
|
queue = ActionQueue(_make_rtc_config())
|
|
|
|
def mock_model(prev_chunk_left_over, inference_delay, state):
|
|
"""Simulate model generating relative actions with RTC."""
|
|
x_t = torch.randn(1, CHUNK_SIZE, ACTION_DIM)
|
|
|
|
def denoise(x):
|
|
return -0.1 * x
|
|
|
|
guided_v = rtc.denoise_step(
|
|
x_t=x_t,
|
|
prev_chunk_left_over=prev_chunk_left_over,
|
|
inference_delay=inference_delay,
|
|
time=0.5,
|
|
original_denoise_step_partial=denoise,
|
|
)
|
|
return x_t - 0.5 * guided_v
|
|
|
|
# --- Iteration 1: first chunk, no leftovers ---
|
|
state_1 = torch.randn(1, ACTION_DIM)
|
|
obs_batch_1 = {OBS_STATE: state_1}
|
|
relative_step(batch_to_transition(obs_batch_1))
|
|
|
|
model_relatives_1 = mock_model(prev_chunk_left_over=None, inference_delay=0, state=state_1)
|
|
original_actions_1 = model_relatives_1.squeeze(0)
|
|
|
|
model_transition_1 = {TransitionKey.ACTION: model_relatives_1}
|
|
postprocessed_1 = absolute_step(unnormalizer(model_transition_1))[TransitionKey.ACTION].squeeze(0)
|
|
|
|
queue.merge(original_actions_1, postprocessed_1, real_delay=0)
|
|
|
|
# Consume some actions (simulate robot executing)
|
|
for _ in range(5):
|
|
action = queue.get()
|
|
assert action is not None
|
|
|
|
# --- Iteration 2: use leftovers for RTC ---
|
|
prev_actions = queue.get_left_over()
|
|
assert prev_actions is not None
|
|
assert prev_actions.shape[0] == CHUNK_SIZE - 5
|
|
|
|
state_2 = state_1 + 0.01 * torch.randn(1, ACTION_DIM)
|
|
obs_batch_2 = {OBS_STATE: state_2}
|
|
relative_step(batch_to_transition(obs_batch_2))
|
|
|
|
model_relatives_2 = mock_model(
|
|
prev_chunk_left_over=prev_actions.unsqueeze(0), inference_delay=3, state=state_2
|
|
)
|
|
original_actions_2 = model_relatives_2.squeeze(0)
|
|
|
|
model_transition_2 = {TransitionKey.ACTION: model_relatives_2}
|
|
postprocessed_2 = absolute_step(unnormalizer(model_transition_2))[TransitionKey.ACTION].squeeze(0)
|
|
|
|
queue.merge(original_actions_2, postprocessed_2, real_delay=3)
|
|
|
|
# Postprocessed actions should be in absolute space
|
|
action = queue.get()
|
|
assert action is not None
|
|
assert action.shape == (ACTION_DIM,)
|
|
|
|
# Verify leftovers are in relative space (original_queue stores relative actions)
|
|
leftover_relatives = queue.get_left_over()
|
|
assert leftover_relatives is not None
|
|
assert leftover_relatives.shape[1] == ACTION_DIM
|
|
|
|
def test_postprocessor_uses_correct_state_per_iteration(self):
|
|
"""Each iteration's postprocessor should use the state from that iteration's preprocessor,
|
|
not a stale state from a previous iteration."""
|
|
relative_step, _, unnormalizer, absolute_step = _make_relative_pipeline()
|
|
|
|
state_1 = torch.tensor([[1.0, 2.0, 3.0, 4.0, 5.0, 6.0]])
|
|
state_2 = torch.tensor([[10.0, 20.0, 30.0, 40.0, 50.0, 60.0]])
|
|
relatives = torch.zeros(1, 5, ACTION_DIM)
|
|
|
|
# Iteration 1: cache state_1
|
|
relative_step(batch_to_transition({OBS_STATE: state_1}))
|
|
result_1 = absolute_step(unnormalizer({TransitionKey.ACTION: relatives.clone()}))[
|
|
TransitionKey.ACTION
|
|
]
|
|
# relative=0 + state_1 should give state_1
|
|
for t in range(5):
|
|
torch.testing.assert_close(result_1[0, t], state_1[0], atol=1e-5, rtol=1e-5)
|
|
|
|
# Iteration 2: cache state_2
|
|
relative_step(batch_to_transition({OBS_STATE: state_2}))
|
|
result_2 = absolute_step(unnormalizer({TransitionKey.ACTION: relatives.clone()}))[
|
|
TransitionKey.ACTION
|
|
]
|
|
for t in range(5):
|
|
torch.testing.assert_close(result_2[0, t], state_2[0], atol=1e-5, rtol=1e-5)
|
|
|
|
|
|
class TestStateRebasingApproximation:
|
|
"""Verify that the approximation from not rebasing leftover relative actions is small
|
|
when state changes between inference calls are small (real-time control regime)."""
|
|
|
|
def test_small_state_change_produces_small_error(self):
|
|
"""With small state changes (typical in real-time control),
|
|
using stale relative actions for RTC guidance introduces negligible error."""
|
|
state_old = torch.randn(1, ACTION_DIM)
|
|
state_new = state_old + 0.01 * torch.randn(1, ACTION_DIM)
|
|
|
|
actions_absolute = torch.randn(1, CHUNK_SIZE, ACTION_DIM)
|
|
mask = [True] * ACTION_DIM
|
|
|
|
relatives_old = to_relative_actions(actions_absolute, state_old, mask)
|
|
relatives_new = to_relative_actions(actions_absolute, state_new, mask)
|
|
|
|
error = (relatives_old - relatives_new).abs().mean()
|
|
state_change = (state_old - state_new).abs().mean()
|
|
|
|
# Error should be proportional to state change
|
|
assert error < 0.1, (
|
|
f"Relative-action error {error:.4f} should be small for small state change {state_change:.4f}"
|
|
)
|
|
|
|
def test_large_state_change_produces_proportional_error(self):
|
|
"""With large state changes, stale relative actions diverge more (but RTC guidance decays)."""
|
|
state_old = torch.randn(1, ACTION_DIM)
|
|
state_new = state_old + 10.0 * torch.randn(1, ACTION_DIM)
|
|
|
|
actions_absolute = torch.randn(1, CHUNK_SIZE, ACTION_DIM)
|
|
mask = [True] * ACTION_DIM
|
|
|
|
relatives_old = to_relative_actions(actions_absolute, state_old, mask)
|
|
relatives_new = to_relative_actions(actions_absolute, state_new, mask)
|
|
|
|
error = (relatives_old - relatives_new).abs().mean()
|
|
state_change = (state_old - state_new).abs().mean()
|
|
|
|
# Error should be roughly equal to state change
|
|
torch.testing.assert_close(
|
|
error.clone().detach(), state_change.clone().detach(), atol=1e-5, rtol=1e-5
|
|
)
|
|
|
|
def test_excluded_joints_not_affected_by_state_change(self):
|
|
"""Joints excluded from relative conversion should not contribute rebasing error."""
|
|
state_old = torch.randn(1, ACTION_DIM)
|
|
state_new = state_old.clone()
|
|
state_new[0, -1] = state_old[0, -1] + 100.0
|
|
|
|
actions = torch.randn(1, CHUNK_SIZE, ACTION_DIM)
|
|
mask = [True] * (ACTION_DIM - 1) + [False]
|
|
|
|
relatives_old = to_relative_actions(actions, state_old, mask)
|
|
relatives_new = to_relative_actions(actions, state_new, mask)
|
|
|
|
# Last dim (excluded) should have zero error
|
|
error_excluded = (relatives_old[..., -1] - relatives_new[..., -1]).abs().max()
|
|
assert error_excluded < 1e-6, f"Excluded joint should have zero error, got {error_excluded}"
|
|
|
|
|
|
def _detect_relative_actions(preprocessor) -> bool:
|
|
"""Mirror of the helper in lerobot-rollout for testing without importing it."""
|
|
return any(isinstance(step, RelativeActionsProcessorStep) and step.enabled for step in preprocessor.steps)
|
|
|
|
|
|
class TestDetectRelativeActions:
|
|
"""Test the _detect_relative_actions helper logic used by lerobot-rollout."""
|
|
|
|
def test_detects_enabled_relative_step(self):
|
|
class FakePipeline:
|
|
steps = [RelativeActionsProcessorStep(enabled=True)]
|
|
|
|
assert _detect_relative_actions(FakePipeline()) is True
|
|
|
|
def test_ignores_disabled_relative_step(self):
|
|
class FakePipeline:
|
|
steps = [RelativeActionsProcessorStep(enabled=False)]
|
|
|
|
assert _detect_relative_actions(FakePipeline()) is False
|
|
|
|
def test_returns_false_when_no_relative_step(self):
|
|
class FakePipeline:
|
|
steps = []
|
|
|
|
assert _detect_relative_actions(FakePipeline()) is False
|
|
|
|
|
|
class TestNonRelativePolicy:
|
|
"""Verify the same pipeline works when relative actions are disabled (standard absolute policy)."""
|
|
|
|
def test_disabled_relative_step_is_noop(self):
|
|
relative_step = RelativeActionsProcessorStep(enabled=False)
|
|
absolute_step = AbsoluteActionsProcessorStep(enabled=False, relative_step=relative_step)
|
|
|
|
state = torch.randn(1, ACTION_DIM)
|
|
actions = torch.randn(1, CHUNK_SIZE, ACTION_DIM)
|
|
|
|
transition = batch_to_transition({ACTION: actions.clone(), OBS_STATE: state})
|
|
t1 = relative_step(transition)
|
|
torch.testing.assert_close(t1[TransitionKey.ACTION], actions)
|
|
|
|
t2 = absolute_step({TransitionKey.ACTION: actions.clone()})
|
|
torch.testing.assert_close(t2[TransitionKey.ACTION], actions)
|
|
|
|
def test_eval_loop_without_relative_actions(self):
|
|
"""Full eval loop simulation with relative actions disabled: original and processed actions are identical."""
|
|
features = {ACTION: PolicyFeature(type=FeatureType.ACTION, shape=(ACTION_DIM,))}
|
|
norm_map = {FeatureType.ACTION: NormalizationMode.MEAN_STD}
|
|
stats = {
|
|
ACTION: {
|
|
"mean": torch.zeros(ACTION_DIM).numpy(),
|
|
"std": torch.ones(ACTION_DIM).numpy(),
|
|
}
|
|
}
|
|
|
|
relative_step = RelativeActionsProcessorStep(enabled=False)
|
|
unnormalizer = UnnormalizerProcessorStep(features=features, norm_map=norm_map, stats=stats)
|
|
absolute_step = AbsoluteActionsProcessorStep(enabled=False, relative_step=relative_step)
|
|
|
|
rtc = RTCProcessor(_make_rtc_config())
|
|
queue = ActionQueue(_make_rtc_config())
|
|
|
|
state = torch.randn(1, ACTION_DIM)
|
|
relative_step(batch_to_transition({OBS_STATE: state}))
|
|
|
|
model_output = torch.randn(1, CHUNK_SIZE, ACTION_DIM)
|
|
post = absolute_step(unnormalizer({TransitionKey.ACTION: model_output.clone()}))[
|
|
TransitionKey.ACTION
|
|
].squeeze(0)
|
|
original = model_output.squeeze(0)
|
|
|
|
# With identity norm and no relative-action transform, original and postprocessed should match
|
|
torch.testing.assert_close(original, post, atol=1e-5, rtol=1e-5)
|
|
|
|
queue.merge(original, post, real_delay=0)
|
|
|
|
for _ in range(5):
|
|
queue.get()
|
|
|
|
prev_actions = queue.get_left_over()
|
|
assert prev_actions is not None
|
|
|
|
# RTC guidance works the same way (absolute space)
|
|
x_t = torch.randn(1, CHUNK_SIZE, ACTION_DIM)
|
|
result = rtc.denoise_step(
|
|
x_t=x_t,
|
|
prev_chunk_left_over=prev_actions.unsqueeze(0),
|
|
inference_delay=3,
|
|
time=0.5,
|
|
original_denoise_step_partial=lambda x: torch.zeros_like(x),
|
|
)
|
|
assert result.shape == x_t.shape
|
|
|
|
def test_detect_relative_returns_false_when_disabled(self):
|
|
class FakePipeline:
|
|
steps = [RelativeActionsProcessorStep(enabled=False)]
|
|
|
|
assert not _detect_relative_actions(FakePipeline())
|
|
|
|
def test_detect_relative_returns_false_when_absent(self):
|
|
class FakePipeline:
|
|
steps = []
|
|
|
|
assert not _detect_relative_actions(FakePipeline())
|
|
|
|
|
|
class TestMultiChunkConsistency:
|
|
"""Test multiple RTC iterations with relative actions maintain consistency."""
|
|
|
|
def test_three_iteration_pipeline(self):
|
|
"""Simulate three consecutive RTC iterations and verify queue state consistency."""
|
|
relative_step, normalizer, unnormalizer, absolute_step = _make_relative_pipeline()
|
|
queue = ActionQueue(_make_rtc_config())
|
|
|
|
states = [torch.randn(1, ACTION_DIM) + i * 0.01 for i in range(3)]
|
|
|
|
for i in range(3):
|
|
queue.get_left_over()
|
|
|
|
relative_step(batch_to_transition({OBS_STATE: states[i]}))
|
|
|
|
model_output = torch.randn(1, CHUNK_SIZE, ACTION_DIM)
|
|
post_transition = absolute_step(unnormalizer({TransitionKey.ACTION: model_output.clone()}))
|
|
postprocessed = post_transition[TransitionKey.ACTION].squeeze(0)
|
|
original = model_output.squeeze(0)
|
|
|
|
delay = min(i * 2, CHUNK_SIZE - 1)
|
|
queue.merge(original, postprocessed, real_delay=delay)
|
|
|
|
for _ in range(5):
|
|
action = queue.get()
|
|
assert action is not None
|
|
assert action.shape == (ACTION_DIM,)
|
|
|
|
# After 3 iterations, queue should still be in valid state
|
|
assert queue.qsize() > 0
|
|
leftover = queue.get_left_over()
|
|
assert leftover is not None
|
|
assert leftover.ndim == 2
|
|
assert leftover.shape[1] == ACTION_DIM
|
|
|
|
def test_leftover_and_processed_differ_when_relative_enabled(self):
|
|
"""With relative actions enabled, original leftovers (relative) must differ from processed (absolute)."""
|
|
relative_step, _, unnormalizer, absolute_step = _make_relative_pipeline()
|
|
queue = ActionQueue(_make_rtc_config())
|
|
|
|
state = torch.tensor([[1.0, 2.0, 3.0, 4.0, 5.0, 6.0]])
|
|
relative_step(batch_to_transition({OBS_STATE: state}))
|
|
|
|
model_relatives = torch.randn(1, CHUNK_SIZE, ACTION_DIM)
|
|
post = absolute_step(unnormalizer({TransitionKey.ACTION: model_relatives.clone()}))[
|
|
TransitionKey.ACTION
|
|
].squeeze(0)
|
|
original = model_relatives.squeeze(0)
|
|
|
|
queue.merge(original, post, real_delay=0)
|
|
|
|
relative_leftover = queue.get_left_over()
|
|
|
|
# Leftovers (relative) must differ from the postprocessed absolute actions
|
|
assert not torch.allclose(relative_leftover, post, atol=1e-3)
|
|
state_expanded = state.squeeze(0).unsqueeze(0).expand_as(relative_leftover)
|
|
torch.testing.assert_close(post, relative_leftover + state_expanded, atol=1e-5, rtol=1e-5)
|
|
|
|
def test_rtc_guidance_uses_relative_space(self):
|
|
"""Verify that RTC denoise_step receives relative-space leftovers, not absolute."""
|
|
relative_step, _, unnormalizer, absolute_step = _make_relative_pipeline()
|
|
rtc = RTCProcessor(_make_rtc_config())
|
|
queue = ActionQueue(_make_rtc_config())
|
|
|
|
state = torch.tensor([[10.0, 20.0, 30.0, 40.0, 50.0, 60.0]])
|
|
relative_step(batch_to_transition({OBS_STATE: state}))
|
|
|
|
model_relatives = 0.1 * torch.randn(1, CHUNK_SIZE, ACTION_DIM)
|
|
post = absolute_step(unnormalizer({TransitionKey.ACTION: model_relatives.clone()}))[
|
|
TransitionKey.ACTION
|
|
].squeeze(0)
|
|
original = model_relatives.squeeze(0)
|
|
|
|
queue.merge(original, post, real_delay=0)
|
|
|
|
for _ in range(5):
|
|
queue.get()
|
|
|
|
prev_left_over = queue.get_left_over()
|
|
|
|
# prev_left_over should be small relative offsets (around 0.1 * randn), not large absolute values
|
|
assert prev_left_over.abs().mean() < 5.0, (
|
|
f"Leftover should be small relative offsets, got mean abs {prev_left_over.abs().mean():.2f}"
|
|
)
|
|
|
|
x_t = torch.randn(1, CHUNK_SIZE, ACTION_DIM)
|
|
|
|
def denoise(x):
|
|
return torch.zeros_like(x)
|
|
|
|
result = rtc.denoise_step(
|
|
x_t=x_t,
|
|
prev_chunk_left_over=prev_left_over.unsqueeze(0),
|
|
inference_delay=3,
|
|
time=0.5,
|
|
original_denoise_step_partial=denoise,
|
|
)
|
|
|
|
assert result.shape == x_t.shape
|