From cf2e42f55756d898583cd2901dac15876f3e0103 Mon Sep 17 00:00:00 2001 From: Pepijn Date: Thu, 23 Apr 2026 13:52:58 +0200 Subject: [PATCH] fix(rollout): preserve relative action chunks --- src/lerobot/rollout/inference/rtc.py | 19 +- src/lerobot/rollout/inference/sync.py | 58 ++++- .../policies/rtc/test_rtc_relative_actions.py | 212 +++++++++++++++++- 3 files changed, 275 insertions(+), 14 deletions(-) diff --git a/src/lerobot/rollout/inference/rtc.py b/src/lerobot/rollout/inference/rtc.py index ae8719b77..db9c3794e 100644 --- a/src/lerobot/rollout/inference/rtc.py +++ b/src/lerobot/rollout/inference/rtc.py @@ -109,6 +109,21 @@ def _normalize_prev_actions_length(prev_actions: torch.Tensor, target_steps: int return padded +def _get_current_raw_state( + relative_step: RelativeActionsProcessorStep, + fallback_state: torch.Tensor | None, +) -> torch.Tensor | None: + """Return the current raw state cached by the relative-action step. + + ``RelativeActionsProcessorStep`` caches the observation state before any + observation normalization. Re-anchoring RTC leftovers must use that raw + state rather than the normalized observation that the policy consumes. + """ + if relative_step._last_state is not None: + return relative_step._last_state + return fallback_state + + # --------------------------------------------------------------------------- # RTCInferenceEngine # --------------------------------------------------------------------------- @@ -318,7 +333,9 @@ class RTCInferenceEngine(InferenceEngine): preprocessed = self._preprocessor(obs_batch) if prev_actions is not None and self._relative_step is not None: - state_tensor = preprocessed.get(OBS_STATE) + state_tensor = _get_current_raw_state( + self._relative_step, obs_batch.get(OBS_STATE) + ) if state_tensor is not None: prev_abs = queue.get_processed_left_over() if prev_abs is not None and prev_abs.numel() > 0: diff --git a/src/lerobot/rollout/inference/sync.py b/src/lerobot/rollout/inference/sync.py index aaed0b356..f4e5596c9 100644 --- a/src/lerobot/rollout/inference/sync.py +++ b/src/lerobot/rollout/inference/sync.py @@ -17,14 +17,16 @@ from __future__ import annotations import logging +from collections import deque from contextlib import nullcontext from copy import copy import torch from lerobot.policies.pretrained import PreTrainedPolicy -from lerobot.policies.utils import make_robot_action, prepare_observation_for_inference -from lerobot.processor import PolicyProcessorPipeline +from lerobot.policies.utils import prepare_observation_for_inference +from lerobot.processor import PolicyProcessorPipeline, RelativeActionsProcessorStep +from lerobot.utils.constants import ACTION from .base import InferenceEngine @@ -34,9 +36,9 @@ logger = logging.getLogger(__name__) class SyncInferenceEngine(InferenceEngine): """Inline synchronous inference: compute one action per call. - ``get_action`` runs the full policy pipeline (pre/post-processor + - ``select_action``) on the given observation frame and returns a - CPU action tensor reordered to match the dataset action keys. + ``get_action`` runs the full policy pipeline when its local action + queue is empty, postprocesses the whole predicted chunk immediately, + and then returns one already-postprocessed CPU action at a time. """ def __init__( @@ -58,6 +60,19 @@ class SyncInferenceEngine(InferenceEngine): self._task = task self._device = torch.device(device or "cpu") self._robot_type = robot_type + self._processed_action_queue: deque[torch.Tensor] = deque() + + self._relative_step = next( + (s for s in preprocessor.steps if isinstance(s, RelativeActionsProcessorStep) and s.enabled), + None, + ) + if self._relative_step is not None and self._relative_step.action_names is None: + cfg_names = getattr(policy.config, "action_feature_names", None) + action_names = cfg_names or dataset_features.get(ACTION, {}).get("names") + if action_names: + self._relative_step.action_names = list(action_names) + logger.info("Relative actions enabled: sync chunks will be postprocessed before queueing") + logger.info( "SyncInferenceEngine initialized (device=%s, action_keys=%d)", self._device, @@ -78,9 +93,29 @@ class SyncInferenceEngine(InferenceEngine): self._policy.reset() self._preprocessor.reset() self._postprocessor.reset() + self._processed_action_queue.clear() + + def _enqueue_processed_chunk(self, action_chunk: torch.Tensor) -> None: + """Queue postprocessed per-step actions in policy output order.""" + if action_chunk.ndim == 2: + action_chunk = action_chunk.unsqueeze(0) + + n_action_steps = getattr(self._policy.config, "n_action_steps", action_chunk.shape[1]) + action_chunk = action_chunk[:, : min(n_action_steps, action_chunk.shape[1])] + + for action in action_chunk.squeeze(0): + action_tensor = action.detach().cpu() + if len(action_tensor) != len(self._ordered_action_keys): + raise ValueError( + f"Action tensor length ({len(action_tensor)}) != action keys " + f"({len(self._ordered_action_keys)})" + ) + self._processed_action_queue.append(action_tensor) def get_action(self, obs_frame: dict | None) -> torch.Tensor | None: """Run the full inference pipeline on ``obs_frame`` and return an action tensor.""" + if self._processed_action_queue: + return self._processed_action_queue.popleft().clone() if obs_frame is None: return None # Shallow copy is intentional: the caller (`send_next_action`) builds @@ -97,11 +132,10 @@ class SyncInferenceEngine(InferenceEngine): observation, self._device, self._task, self._robot_type ) observation = self._preprocessor(observation) - action = self._policy.select_action(observation) - action = self._postprocessor(action) - action_tensor = action.squeeze(0).cpu() + action_chunk = self._policy.predict_action_chunk(observation) + processed_chunk = self._postprocessor(action_chunk) - # Reorder to match dataset action ordering so the caller can treat - # the returned tensor uniformly across backends. - action_dict = make_robot_action(action_tensor, self._dataset_features) - return torch.tensor([action_dict[k] for k in self._ordered_action_keys]) + self._enqueue_processed_chunk(processed_chunk) + if not self._processed_action_queue: + return None + return self._processed_action_queue.popleft().clone() diff --git a/tests/policies/rtc/test_rtc_relative_actions.py b/tests/policies/rtc/test_rtc_relative_actions.py index 14c115764..5fb712a57 100644 --- a/tests/policies/rtc/test_rtc_relative_actions.py +++ b/tests/policies/rtc/test_rtc_relative_actions.py @@ -13,7 +13,9 @@ Flow under test: import importlib.util import sys from pathlib import Path +from types import ModuleType, SimpleNamespace +import numpy as np import torch from lerobot.configs.types import ( @@ -22,7 +24,13 @@ from lerobot.configs.types import ( PolicyFeature, RTCAttentionSchedule, ) -from lerobot.processor import TransitionKey, batch_to_transition +from lerobot.processor import ( + PolicyProcessorPipeline, + TransitionKey, + batch_to_transition, + policy_action_to_transition, + transition_to_policy_action, +) from lerobot.processor.normalize_processor import NormalizerProcessorStep, UnnormalizerProcessorStep from lerobot.processor.relative_action_processor import ( AbsoluteActionsProcessorStep, @@ -52,6 +60,34 @@ _rtc_debug_mod = _import_rtc_module("lerobot.policies.rtc.debug_tracker", "debug _rtc_mod = _import_rtc_module("lerobot.policies.rtc.modeling_rtc", "modeling_rtc.py") RTCProcessor = _rtc_mod.RTCProcessor + +def _ensure_rollout_test_packages() -> Path: + rollout_dir = Path(__file__).resolve().parents[3] / "src" / "lerobot" / "rollout" + rollout_pkg = sys.modules.setdefault("lerobot.rollout", ModuleType("lerobot.rollout")) + rollout_pkg.__path__ = [str(rollout_dir)] + inference_pkg = sys.modules.setdefault( + "lerobot.rollout.inference", ModuleType("lerobot.rollout.inference") + ) + inference_pkg.__path__ = [str(rollout_dir / "inference")] + return rollout_dir + + +def _import_rollout_module(module_name: str, relative_path: str): + rollout_dir = _ensure_rollout_test_packages() + spec = importlib.util.spec_from_file_location(module_name, rollout_dir / relative_path) + mod = importlib.util.module_from_spec(spec) + sys.modules[module_name] = mod + spec.loader.exec_module(mod) + return mod + + +_rollout_robot_wrapper_mod = _import_rollout_module("lerobot.rollout.robot_wrapper", "robot_wrapper.py") +_rollout_base_mod = _import_rollout_module("lerobot.rollout.inference.base", "inference/base.py") +_rollout_sync_mod = _import_rollout_module("lerobot.rollout.inference.sync", "inference/sync.py") +_rollout_rtc_mod = _import_rollout_module("lerobot.rollout.inference.rtc", "inference/rtc.py") +SyncInferenceEngine = _rollout_sync_mod.SyncInferenceEngine +get_current_raw_state = _rollout_rtc_mod._get_current_raw_state + ACTION_DIM = 6 CHUNK_SIZE = 50 EXECUTION_HORIZON = 10 @@ -89,6 +125,44 @@ def _make_relative_pipeline(action_dim=ACTION_DIM, norm_mode=NormalizationMode.M return relative_step, normalizer, unnormalizer, absolute_step +def _make_relative_sync_pipelines( + action_dim=ACTION_DIM, + action_names: list[str] | None = None, + exclude_joints: list[str] | None = None, +): + relative_step = RelativeActionsProcessorStep( + enabled=True, + exclude_joints=exclude_joints or [], + action_names=action_names or [f"joint_{i}.pos" for i in range(action_dim)], + ) + absolute_step = AbsoluteActionsProcessorStep(enabled=True, relative_step=relative_step) + preprocessor = PolicyProcessorPipeline(steps=[relative_step], name="test_preprocessor") + postprocessor = PolicyProcessorPipeline( + steps=[absolute_step], + name="test_postprocessor", + to_transition=policy_action_to_transition, + to_output=transition_to_policy_action, + ) + return relative_step, preprocessor, postprocessor + + +class _ChunkPolicyStub: + def __init__(self, action_dim: int, n_action_steps: int): + self.config = SimpleNamespace(use_amp=False, n_action_steps=n_action_steps) + self._chunk = torch.zeros(1, n_action_steps, action_dim) + self.predict_calls = 0 + + def reset(self): + return None + + def predict_action_chunk(self, batch): + self.predict_calls += 1 + return self._chunk.clone() + + def select_action(self, batch): + raise AssertionError("SyncInferenceEngine should consume chunk outputs directly") + + class TestActionQueueRelativeActions: """Verify ActionQueue stores model-space (relative) actions for RTC and absolute for robot.""" @@ -120,6 +194,142 @@ class TestActionQueueRelativeActions: torch.testing.assert_close(first_action, absolute_actions[0]) +class TestRolloutInferenceRelativeActions: + """Regression tests for rollout inference engines with relative-action policies.""" + + def test_sync_engine_postprocesses_chunk_before_queueing(self): + """Queued sync actions must stay anchored to the state from the chunk-producing step.""" + _, preprocessor, postprocessor = _make_relative_sync_pipelines(ACTION_DIM) + policy = _ChunkPolicyStub(action_dim=ACTION_DIM, n_action_steps=3) + ordered_action_keys = [f"joint_{i}.pos" for i in range(ACTION_DIM)] + dataset_features = {ACTION: {"names": ordered_action_keys}} + + engine = SyncInferenceEngine( + policy=policy, + preprocessor=preprocessor, + postprocessor=postprocessor, + dataset_features=dataset_features, + ordered_action_keys=ordered_action_keys, + task="test", + device="cpu", + robot_type="mock", + ) + + state_1 = np.arange(1, ACTION_DIM + 1, dtype=np.float32) + state_2 = 10 * state_1 + + action_1 = engine.get_action({OBS_STATE: state_1.copy()}) + action_2 = engine.get_action({OBS_STATE: state_2.copy()}) + + torch.testing.assert_close(action_1, torch.from_numpy(state_1)) + torch.testing.assert_close(action_2, torch.from_numpy(state_1)) + assert policy.predict_calls == 1 + + def test_sync_engine_restores_action_names_for_relative_exclusions(self): + """Serialized processors may omit action names; sync rollout must still honor gripper exclusions.""" + action_names = [f"joint_{i}.pos" for i in range(ACTION_DIM - 1)] + ["gripper.pos"] + relative_step, preprocessor, postprocessor = _make_relative_sync_pipelines( + ACTION_DIM, + action_names=action_names, + exclude_joints=["gripper"], + ) + relative_step.action_names = None + policy = _ChunkPolicyStub(action_dim=ACTION_DIM, n_action_steps=1) + policy.config.action_feature_names = action_names + dataset_features = {ACTION: {"names": action_names}} + + assert relative_step.action_names is None + + engine = SyncInferenceEngine( + policy=policy, + preprocessor=preprocessor, + postprocessor=postprocessor, + dataset_features=dataset_features, + ordered_action_keys=action_names, + task="test", + device="cpu", + robot_type="mock", + ) + + state = np.arange(1, ACTION_DIM + 1, dtype=np.float32) + action = engine.get_action({OBS_STATE: state.copy()}) + expected = torch.from_numpy(state.copy()) + expected[-1] = 0.0 + + torch.testing.assert_close(action, expected) + assert relative_step.action_names == action_names + + def test_sync_engine_does_not_remap_chunk_through_dataset_action_names(self): + """Postprocessed chunks are already in policy order; dataset feature order must not scramble them.""" + action_names = [f"joint_{i}.pos" for i in range(ACTION_DIM)] + _, preprocessor, postprocessor = _make_relative_sync_pipelines( + ACTION_DIM, + action_names=action_names, + ) + policy = _ChunkPolicyStub(action_dim=ACTION_DIM, n_action_steps=1) + policy.config.action_feature_names = action_names + + engine = SyncInferenceEngine( + policy=policy, + preprocessor=preprocessor, + postprocessor=postprocessor, + dataset_features={ACTION: {"names": list(reversed(action_names))}}, + ordered_action_keys=action_names, + task="test", + device="cpu", + robot_type="mock", + ) + + state = np.arange(1, ACTION_DIM + 1, dtype=np.float32) + action = engine.get_action({OBS_STATE: state.copy()}) + + torch.testing.assert_close(action, torch.from_numpy(state)) + + def test_rtc_reanchoring_prefers_raw_cached_state(self): + """RTC re-anchoring must use the raw state cached before observation normalization.""" + action_dim = ACTION_DIM + relative_step = RelativeActionsProcessorStep( + enabled=True, + action_names=[f"joint_{i}.pos" for i in range(action_dim)], + ) + features = { + OBS_STATE: PolicyFeature(type=FeatureType.STATE, shape=(action_dim,)), + ACTION: PolicyFeature(type=FeatureType.ACTION, shape=(action_dim,)), + } + stats = { + OBS_STATE: { + "mean": np.arange(1, action_dim + 1, dtype=np.float32), + "std": 2 * np.ones(action_dim, dtype=np.float32), + }, + ACTION: { + "mean": np.zeros(action_dim, dtype=np.float32), + "std": np.ones(action_dim, dtype=np.float32), + }, + } + preprocessor = PolicyProcessorPipeline( + steps=[ + relative_step, + NormalizerProcessorStep( + features=features, + norm_map={ + FeatureType.STATE: NormalizationMode.MEAN_STD, + FeatureType.ACTION: NormalizationMode.MEAN_STD, + }, + stats=stats, + ), + ], + name="test_preprocessor_with_state_norm", + ) + + raw_state = torch.from_numpy(np.arange(5, 5 + action_dim, dtype=np.float32)).unsqueeze(0) + preprocessed = preprocessor({OBS_STATE: raw_state.clone()}) + + current_state = get_current_raw_state(relative_step, preprocessed.get(OBS_STATE)) + + torch.testing.assert_close(current_state, raw_state) + assert not torch.allclose(preprocessed[OBS_STATE], raw_state) + + class TestRTCDenoiseWithRelativeLeftovers: """Verify RTC denoise_step correctly handles relative-space prev_chunk_left_over."""