fix(rollout): preserve relative action chunks

This commit is contained in:
Pepijn
2026-04-23 13:52:58 +02:00
parent eb9519eb91
commit e9f3f88377
3 changed files with 275 additions and 14 deletions
+18 -1
View File
@@ -109,6 +109,21 @@ def _normalize_prev_actions_length(prev_actions: torch.Tensor, target_steps: int
return padded
def _get_current_raw_state(
relative_step: RelativeActionsProcessorStep,
fallback_state: torch.Tensor | None,
) -> torch.Tensor | None:
"""Return the current raw state cached by the relative-action step.
``RelativeActionsProcessorStep`` caches the observation state before any
observation normalization. Re-anchoring RTC leftovers must use that raw
state rather than the normalized observation that the policy consumes.
"""
if relative_step._last_state is not None:
return relative_step._last_state
return fallback_state
# ---------------------------------------------------------------------------
# RTCInferenceEngine
# ---------------------------------------------------------------------------
@@ -318,7 +333,9 @@ class RTCInferenceEngine(InferenceEngine):
preprocessed = self._preprocessor(obs_batch)
if prev_actions is not None and self._relative_step is not None:
state_tensor = preprocessed.get(OBS_STATE)
state_tensor = _get_current_raw_state(
self._relative_step, obs_batch.get(OBS_STATE)
)
if state_tensor is not None:
prev_abs = queue.get_processed_left_over()
if prev_abs is not None and prev_abs.numel() > 0:
+46 -12
View File
@@ -17,14 +17,16 @@
from __future__ import annotations
import logging
from collections import deque
from contextlib import nullcontext
from copy import copy
import torch
from lerobot.policies.pretrained import PreTrainedPolicy
from lerobot.policies.utils import make_robot_action, prepare_observation_for_inference
from lerobot.processor import PolicyProcessorPipeline
from lerobot.policies.utils import prepare_observation_for_inference
from lerobot.processor import PolicyProcessorPipeline, RelativeActionsProcessorStep
from lerobot.utils.constants import ACTION
from .base import InferenceEngine
@@ -34,9 +36,9 @@ logger = logging.getLogger(__name__)
class SyncInferenceEngine(InferenceEngine):
"""Inline synchronous inference: compute one action per call.
``get_action`` runs the full policy pipeline (pre/post-processor +
``select_action``) on the given observation frame and returns a
CPU action tensor reordered to match the dataset action keys.
``get_action`` runs the full policy pipeline when its local action
queue is empty, postprocesses the whole predicted chunk immediately,
and then returns one already-postprocessed CPU action at a time.
"""
def __init__(
@@ -58,6 +60,19 @@ class SyncInferenceEngine(InferenceEngine):
self._task = task
self._device = torch.device(device or "cpu")
self._robot_type = robot_type
self._processed_action_queue: deque[torch.Tensor] = deque()
self._relative_step = next(
(s for s in preprocessor.steps if isinstance(s, RelativeActionsProcessorStep) and s.enabled),
None,
)
if self._relative_step is not None and self._relative_step.action_names is None:
cfg_names = getattr(policy.config, "action_feature_names", None)
action_names = cfg_names or dataset_features.get(ACTION, {}).get("names")
if action_names:
self._relative_step.action_names = list(action_names)
logger.info("Relative actions enabled: sync chunks will be postprocessed before queueing")
logger.info(
"SyncInferenceEngine initialized (device=%s, action_keys=%d)",
self._device,
@@ -78,9 +93,29 @@ class SyncInferenceEngine(InferenceEngine):
self._policy.reset()
self._preprocessor.reset()
self._postprocessor.reset()
self._processed_action_queue.clear()
def _enqueue_processed_chunk(self, action_chunk: torch.Tensor) -> None:
"""Queue postprocessed per-step actions in policy output order."""
if action_chunk.ndim == 2:
action_chunk = action_chunk.unsqueeze(0)
n_action_steps = getattr(self._policy.config, "n_action_steps", action_chunk.shape[1])
action_chunk = action_chunk[:, : min(n_action_steps, action_chunk.shape[1])]
for action in action_chunk.squeeze(0):
action_tensor = action.detach().cpu()
if len(action_tensor) != len(self._ordered_action_keys):
raise ValueError(
f"Action tensor length ({len(action_tensor)}) != action keys "
f"({len(self._ordered_action_keys)})"
)
self._processed_action_queue.append(action_tensor)
def get_action(self, obs_frame: dict | None) -> torch.Tensor | None:
"""Run the full inference pipeline on ``obs_frame`` and return an action tensor."""
if self._processed_action_queue:
return self._processed_action_queue.popleft().clone()
if obs_frame is None:
return None
# Shallow copy is intentional: the caller (`send_next_action`) builds
@@ -97,11 +132,10 @@ class SyncInferenceEngine(InferenceEngine):
observation, self._device, self._task, self._robot_type
)
observation = self._preprocessor(observation)
action = self._policy.select_action(observation)
action = self._postprocessor(action)
action_tensor = action.squeeze(0).cpu()
action_chunk = self._policy.predict_action_chunk(observation)
processed_chunk = self._postprocessor(action_chunk)
# Reorder to match dataset action ordering so the caller can treat
# the returned tensor uniformly across backends.
action_dict = make_robot_action(action_tensor, self._dataset_features)
return torch.tensor([action_dict[k] for k in self._ordered_action_keys])
self._enqueue_processed_chunk(processed_chunk)
if not self._processed_action_queue:
return None
return self._processed_action_queue.popleft().clone()
+211 -1
View File
@@ -13,7 +13,9 @@ Flow under test:
import importlib.util
import sys
from pathlib import Path
from types import ModuleType, SimpleNamespace
import numpy as np
import torch
from lerobot.configs.types import (
@@ -22,7 +24,13 @@ from lerobot.configs.types import (
PolicyFeature,
RTCAttentionSchedule,
)
from lerobot.processor import TransitionKey, batch_to_transition
from lerobot.processor import (
PolicyProcessorPipeline,
TransitionKey,
batch_to_transition,
policy_action_to_transition,
transition_to_policy_action,
)
from lerobot.processor.normalize_processor import NormalizerProcessorStep, UnnormalizerProcessorStep
from lerobot.processor.relative_action_processor import (
AbsoluteActionsProcessorStep,
@@ -52,6 +60,34 @@ _rtc_debug_mod = _import_rtc_module("lerobot.policies.rtc.debug_tracker", "debug
_rtc_mod = _import_rtc_module("lerobot.policies.rtc.modeling_rtc", "modeling_rtc.py")
RTCProcessor = _rtc_mod.RTCProcessor
def _ensure_rollout_test_packages() -> Path:
rollout_dir = Path(__file__).resolve().parents[3] / "src" / "lerobot" / "rollout"
rollout_pkg = sys.modules.setdefault("lerobot.rollout", ModuleType("lerobot.rollout"))
rollout_pkg.__path__ = [str(rollout_dir)]
inference_pkg = sys.modules.setdefault(
"lerobot.rollout.inference", ModuleType("lerobot.rollout.inference")
)
inference_pkg.__path__ = [str(rollout_dir / "inference")]
return rollout_dir
def _import_rollout_module(module_name: str, relative_path: str):
rollout_dir = _ensure_rollout_test_packages()
spec = importlib.util.spec_from_file_location(module_name, rollout_dir / relative_path)
mod = importlib.util.module_from_spec(spec)
sys.modules[module_name] = mod
spec.loader.exec_module(mod)
return mod
_rollout_robot_wrapper_mod = _import_rollout_module("lerobot.rollout.robot_wrapper", "robot_wrapper.py")
_rollout_base_mod = _import_rollout_module("lerobot.rollout.inference.base", "inference/base.py")
_rollout_sync_mod = _import_rollout_module("lerobot.rollout.inference.sync", "inference/sync.py")
_rollout_rtc_mod = _import_rollout_module("lerobot.rollout.inference.rtc", "inference/rtc.py")
SyncInferenceEngine = _rollout_sync_mod.SyncInferenceEngine
get_current_raw_state = _rollout_rtc_mod._get_current_raw_state
ACTION_DIM = 6
CHUNK_SIZE = 50
EXECUTION_HORIZON = 10
@@ -89,6 +125,44 @@ def _make_relative_pipeline(action_dim=ACTION_DIM, norm_mode=NormalizationMode.M
return relative_step, normalizer, unnormalizer, absolute_step
def _make_relative_sync_pipelines(
action_dim=ACTION_DIM,
action_names: list[str] | None = None,
exclude_joints: list[str] | None = None,
):
relative_step = RelativeActionsProcessorStep(
enabled=True,
exclude_joints=exclude_joints or [],
action_names=action_names or [f"joint_{i}.pos" for i in range(action_dim)],
)
absolute_step = AbsoluteActionsProcessorStep(enabled=True, relative_step=relative_step)
preprocessor = PolicyProcessorPipeline(steps=[relative_step], name="test_preprocessor")
postprocessor = PolicyProcessorPipeline(
steps=[absolute_step],
name="test_postprocessor",
to_transition=policy_action_to_transition,
to_output=transition_to_policy_action,
)
return relative_step, preprocessor, postprocessor
class _ChunkPolicyStub:
def __init__(self, action_dim: int, n_action_steps: int):
self.config = SimpleNamespace(use_amp=False, n_action_steps=n_action_steps)
self._chunk = torch.zeros(1, n_action_steps, action_dim)
self.predict_calls = 0
def reset(self):
return None
def predict_action_chunk(self, batch):
self.predict_calls += 1
return self._chunk.clone()
def select_action(self, batch):
raise AssertionError("SyncInferenceEngine should consume chunk outputs directly")
class TestActionQueueRelativeActions:
"""Verify ActionQueue stores model-space (relative) actions for RTC and absolute for robot."""
@@ -120,6 +194,142 @@ class TestActionQueueRelativeActions:
torch.testing.assert_close(first_action, absolute_actions[0])
class TestRolloutInferenceRelativeActions:
"""Regression tests for rollout inference engines with relative-action policies."""
def test_sync_engine_postprocesses_chunk_before_queueing(self):
"""Queued sync actions must stay anchored to the state from the chunk-producing step."""
_, preprocessor, postprocessor = _make_relative_sync_pipelines(ACTION_DIM)
policy = _ChunkPolicyStub(action_dim=ACTION_DIM, n_action_steps=3)
ordered_action_keys = [f"joint_{i}.pos" for i in range(ACTION_DIM)]
dataset_features = {ACTION: {"names": ordered_action_keys}}
engine = SyncInferenceEngine(
policy=policy,
preprocessor=preprocessor,
postprocessor=postprocessor,
dataset_features=dataset_features,
ordered_action_keys=ordered_action_keys,
task="test",
device="cpu",
robot_type="mock",
)
state_1 = np.arange(1, ACTION_DIM + 1, dtype=np.float32)
state_2 = 10 * state_1
action_1 = engine.get_action({OBS_STATE: state_1.copy()})
action_2 = engine.get_action({OBS_STATE: state_2.copy()})
torch.testing.assert_close(action_1, torch.from_numpy(state_1))
torch.testing.assert_close(action_2, torch.from_numpy(state_1))
assert policy.predict_calls == 1
def test_sync_engine_restores_action_names_for_relative_exclusions(self):
"""Serialized processors may omit action names; sync rollout must still honor gripper exclusions."""
action_names = [f"joint_{i}.pos" for i in range(ACTION_DIM - 1)] + ["gripper.pos"]
relative_step, preprocessor, postprocessor = _make_relative_sync_pipelines(
ACTION_DIM,
action_names=action_names,
exclude_joints=["gripper"],
)
relative_step.action_names = None
policy = _ChunkPolicyStub(action_dim=ACTION_DIM, n_action_steps=1)
policy.config.action_feature_names = action_names
dataset_features = {ACTION: {"names": action_names}}
assert relative_step.action_names is None
engine = SyncInferenceEngine(
policy=policy,
preprocessor=preprocessor,
postprocessor=postprocessor,
dataset_features=dataset_features,
ordered_action_keys=action_names,
task="test",
device="cpu",
robot_type="mock",
)
state = np.arange(1, ACTION_DIM + 1, dtype=np.float32)
action = engine.get_action({OBS_STATE: state.copy()})
expected = torch.from_numpy(state.copy())
expected[-1] = 0.0
torch.testing.assert_close(action, expected)
assert relative_step.action_names == action_names
def test_sync_engine_does_not_remap_chunk_through_dataset_action_names(self):
"""Postprocessed chunks are already in policy order; dataset feature order must not scramble them."""
action_names = [f"joint_{i}.pos" for i in range(ACTION_DIM)]
_, preprocessor, postprocessor = _make_relative_sync_pipelines(
ACTION_DIM,
action_names=action_names,
)
policy = _ChunkPolicyStub(action_dim=ACTION_DIM, n_action_steps=1)
policy.config.action_feature_names = action_names
engine = SyncInferenceEngine(
policy=policy,
preprocessor=preprocessor,
postprocessor=postprocessor,
dataset_features={ACTION: {"names": list(reversed(action_names))}},
ordered_action_keys=action_names,
task="test",
device="cpu",
robot_type="mock",
)
state = np.arange(1, ACTION_DIM + 1, dtype=np.float32)
action = engine.get_action({OBS_STATE: state.copy()})
torch.testing.assert_close(action, torch.from_numpy(state))
def test_rtc_reanchoring_prefers_raw_cached_state(self):
"""RTC re-anchoring must use the raw state cached before observation normalization."""
action_dim = ACTION_DIM
relative_step = RelativeActionsProcessorStep(
enabled=True,
action_names=[f"joint_{i}.pos" for i in range(action_dim)],
)
features = {
OBS_STATE: PolicyFeature(type=FeatureType.STATE, shape=(action_dim,)),
ACTION: PolicyFeature(type=FeatureType.ACTION, shape=(action_dim,)),
}
stats = {
OBS_STATE: {
"mean": np.arange(1, action_dim + 1, dtype=np.float32),
"std": 2 * np.ones(action_dim, dtype=np.float32),
},
ACTION: {
"mean": np.zeros(action_dim, dtype=np.float32),
"std": np.ones(action_dim, dtype=np.float32),
},
}
preprocessor = PolicyProcessorPipeline(
steps=[
relative_step,
NormalizerProcessorStep(
features=features,
norm_map={
FeatureType.STATE: NormalizationMode.MEAN_STD,
FeatureType.ACTION: NormalizationMode.MEAN_STD,
},
stats=stats,
),
],
name="test_preprocessor_with_state_norm",
)
raw_state = torch.from_numpy(np.arange(5, 5 + action_dim, dtype=np.float32)).unsqueeze(0)
preprocessed = preprocessor({OBS_STATE: raw_state.clone()})
current_state = get_current_raw_state(relative_step, preprocessed.get(OBS_STATE))
torch.testing.assert_close(current_state, raw_state)
assert not torch.allclose(preprocessed[OBS_STATE], raw_state)
class TestRTCDenoiseWithRelativeLeftovers:
"""Verify RTC denoise_step correctly handles relative-space prev_chunk_left_over."""