fix(rollout): preserve relative action chunks

This commit is contained in:
Pepijn
2026-04-23 13:52:58 +02:00
parent eb9519eb91
commit e9f3f88377
3 changed files with 275 additions and 14 deletions
+211 -1
View File
@@ -13,7 +13,9 @@ Flow under test:
import importlib.util
import sys
from pathlib import Path
from types import ModuleType, SimpleNamespace
import numpy as np
import torch
from lerobot.configs.types import (
@@ -22,7 +24,13 @@ from lerobot.configs.types import (
PolicyFeature,
RTCAttentionSchedule,
)
from lerobot.processor import TransitionKey, batch_to_transition
from lerobot.processor import (
PolicyProcessorPipeline,
TransitionKey,
batch_to_transition,
policy_action_to_transition,
transition_to_policy_action,
)
from lerobot.processor.normalize_processor import NormalizerProcessorStep, UnnormalizerProcessorStep
from lerobot.processor.relative_action_processor import (
AbsoluteActionsProcessorStep,
@@ -52,6 +60,34 @@ _rtc_debug_mod = _import_rtc_module("lerobot.policies.rtc.debug_tracker", "debug
_rtc_mod = _import_rtc_module("lerobot.policies.rtc.modeling_rtc", "modeling_rtc.py")
RTCProcessor = _rtc_mod.RTCProcessor
def _ensure_rollout_test_packages() -> Path:
rollout_dir = Path(__file__).resolve().parents[3] / "src" / "lerobot" / "rollout"
rollout_pkg = sys.modules.setdefault("lerobot.rollout", ModuleType("lerobot.rollout"))
rollout_pkg.__path__ = [str(rollout_dir)]
inference_pkg = sys.modules.setdefault(
"lerobot.rollout.inference", ModuleType("lerobot.rollout.inference")
)
inference_pkg.__path__ = [str(rollout_dir / "inference")]
return rollout_dir
def _import_rollout_module(module_name: str, relative_path: str):
rollout_dir = _ensure_rollout_test_packages()
spec = importlib.util.spec_from_file_location(module_name, rollout_dir / relative_path)
mod = importlib.util.module_from_spec(spec)
sys.modules[module_name] = mod
spec.loader.exec_module(mod)
return mod
_rollout_robot_wrapper_mod = _import_rollout_module("lerobot.rollout.robot_wrapper", "robot_wrapper.py")
_rollout_base_mod = _import_rollout_module("lerobot.rollout.inference.base", "inference/base.py")
_rollout_sync_mod = _import_rollout_module("lerobot.rollout.inference.sync", "inference/sync.py")
_rollout_rtc_mod = _import_rollout_module("lerobot.rollout.inference.rtc", "inference/rtc.py")
SyncInferenceEngine = _rollout_sync_mod.SyncInferenceEngine
get_current_raw_state = _rollout_rtc_mod._get_current_raw_state
ACTION_DIM = 6
CHUNK_SIZE = 50
EXECUTION_HORIZON = 10
@@ -89,6 +125,44 @@ def _make_relative_pipeline(action_dim=ACTION_DIM, norm_mode=NormalizationMode.M
return relative_step, normalizer, unnormalizer, absolute_step
def _make_relative_sync_pipelines(
action_dim=ACTION_DIM,
action_names: list[str] | None = None,
exclude_joints: list[str] | None = None,
):
relative_step = RelativeActionsProcessorStep(
enabled=True,
exclude_joints=exclude_joints or [],
action_names=action_names or [f"joint_{i}.pos" for i in range(action_dim)],
)
absolute_step = AbsoluteActionsProcessorStep(enabled=True, relative_step=relative_step)
preprocessor = PolicyProcessorPipeline(steps=[relative_step], name="test_preprocessor")
postprocessor = PolicyProcessorPipeline(
steps=[absolute_step],
name="test_postprocessor",
to_transition=policy_action_to_transition,
to_output=transition_to_policy_action,
)
return relative_step, preprocessor, postprocessor
class _ChunkPolicyStub:
def __init__(self, action_dim: int, n_action_steps: int):
self.config = SimpleNamespace(use_amp=False, n_action_steps=n_action_steps)
self._chunk = torch.zeros(1, n_action_steps, action_dim)
self.predict_calls = 0
def reset(self):
return None
def predict_action_chunk(self, batch):
self.predict_calls += 1
return self._chunk.clone()
def select_action(self, batch):
raise AssertionError("SyncInferenceEngine should consume chunk outputs directly")
class TestActionQueueRelativeActions:
"""Verify ActionQueue stores model-space (relative) actions for RTC and absolute for robot."""
@@ -120,6 +194,142 @@ class TestActionQueueRelativeActions:
torch.testing.assert_close(first_action, absolute_actions[0])
class TestRolloutInferenceRelativeActions:
"""Regression tests for rollout inference engines with relative-action policies."""
def test_sync_engine_postprocesses_chunk_before_queueing(self):
"""Queued sync actions must stay anchored to the state from the chunk-producing step."""
_, preprocessor, postprocessor = _make_relative_sync_pipelines(ACTION_DIM)
policy = _ChunkPolicyStub(action_dim=ACTION_DIM, n_action_steps=3)
ordered_action_keys = [f"joint_{i}.pos" for i in range(ACTION_DIM)]
dataset_features = {ACTION: {"names": ordered_action_keys}}
engine = SyncInferenceEngine(
policy=policy,
preprocessor=preprocessor,
postprocessor=postprocessor,
dataset_features=dataset_features,
ordered_action_keys=ordered_action_keys,
task="test",
device="cpu",
robot_type="mock",
)
state_1 = np.arange(1, ACTION_DIM + 1, dtype=np.float32)
state_2 = 10 * state_1
action_1 = engine.get_action({OBS_STATE: state_1.copy()})
action_2 = engine.get_action({OBS_STATE: state_2.copy()})
torch.testing.assert_close(action_1, torch.from_numpy(state_1))
torch.testing.assert_close(action_2, torch.from_numpy(state_1))
assert policy.predict_calls == 1
def test_sync_engine_restores_action_names_for_relative_exclusions(self):
"""Serialized processors may omit action names; sync rollout must still honor gripper exclusions."""
action_names = [f"joint_{i}.pos" for i in range(ACTION_DIM - 1)] + ["gripper.pos"]
relative_step, preprocessor, postprocessor = _make_relative_sync_pipelines(
ACTION_DIM,
action_names=action_names,
exclude_joints=["gripper"],
)
relative_step.action_names = None
policy = _ChunkPolicyStub(action_dim=ACTION_DIM, n_action_steps=1)
policy.config.action_feature_names = action_names
dataset_features = {ACTION: {"names": action_names}}
assert relative_step.action_names is None
engine = SyncInferenceEngine(
policy=policy,
preprocessor=preprocessor,
postprocessor=postprocessor,
dataset_features=dataset_features,
ordered_action_keys=action_names,
task="test",
device="cpu",
robot_type="mock",
)
state = np.arange(1, ACTION_DIM + 1, dtype=np.float32)
action = engine.get_action({OBS_STATE: state.copy()})
expected = torch.from_numpy(state.copy())
expected[-1] = 0.0
torch.testing.assert_close(action, expected)
assert relative_step.action_names == action_names
def test_sync_engine_does_not_remap_chunk_through_dataset_action_names(self):
"""Postprocessed chunks are already in policy order; dataset feature order must not scramble them."""
action_names = [f"joint_{i}.pos" for i in range(ACTION_DIM)]
_, preprocessor, postprocessor = _make_relative_sync_pipelines(
ACTION_DIM,
action_names=action_names,
)
policy = _ChunkPolicyStub(action_dim=ACTION_DIM, n_action_steps=1)
policy.config.action_feature_names = action_names
engine = SyncInferenceEngine(
policy=policy,
preprocessor=preprocessor,
postprocessor=postprocessor,
dataset_features={ACTION: {"names": list(reversed(action_names))}},
ordered_action_keys=action_names,
task="test",
device="cpu",
robot_type="mock",
)
state = np.arange(1, ACTION_DIM + 1, dtype=np.float32)
action = engine.get_action({OBS_STATE: state.copy()})
torch.testing.assert_close(action, torch.from_numpy(state))
def test_rtc_reanchoring_prefers_raw_cached_state(self):
"""RTC re-anchoring must use the raw state cached before observation normalization."""
action_dim = ACTION_DIM
relative_step = RelativeActionsProcessorStep(
enabled=True,
action_names=[f"joint_{i}.pos" for i in range(action_dim)],
)
features = {
OBS_STATE: PolicyFeature(type=FeatureType.STATE, shape=(action_dim,)),
ACTION: PolicyFeature(type=FeatureType.ACTION, shape=(action_dim,)),
}
stats = {
OBS_STATE: {
"mean": np.arange(1, action_dim + 1, dtype=np.float32),
"std": 2 * np.ones(action_dim, dtype=np.float32),
},
ACTION: {
"mean": np.zeros(action_dim, dtype=np.float32),
"std": np.ones(action_dim, dtype=np.float32),
},
}
preprocessor = PolicyProcessorPipeline(
steps=[
relative_step,
NormalizerProcessorStep(
features=features,
norm_map={
FeatureType.STATE: NormalizationMode.MEAN_STD,
FeatureType.ACTION: NormalizationMode.MEAN_STD,
},
stats=stats,
),
],
name="test_preprocessor_with_state_norm",
)
raw_state = torch.from_numpy(np.arange(5, 5 + action_dim, dtype=np.float32)).unsqueeze(0)
preprocessed = preprocessor({OBS_STATE: raw_state.clone()})
current_state = get_current_raw_state(relative_step, preprocessed.get(OBS_STATE))
torch.testing.assert_close(current_state, raw_state)
assert not torch.allclose(preprocessed[OBS_STATE], raw_state)
class TestRTCDenoiseWithRelativeLeftovers:
"""Verify RTC denoise_step correctly handles relative-space prev_chunk_left_over."""