fix(rollout): preserve relative action chunks

This commit is contained in:
Pepijn
2026-04-23 13:52:58 +02:00
parent eb9519eb91
commit cf2e42f557
3 changed files with 275 additions and 14 deletions
+18 -1
View File
@@ -109,6 +109,21 @@ def _normalize_prev_actions_length(prev_actions: torch.Tensor, target_steps: int
return padded return padded
def _get_current_raw_state(
    relative_step: RelativeActionsProcessorStep,
    fallback_state: torch.Tensor | None,
) -> torch.Tensor | None:
    """Select the state used to re-anchor RTC leftovers.

    The state cached on ``relative_step`` (``_last_state``) wins whenever it is
    set: ``RelativeActionsProcessorStep`` records the observation state before
    any observation normalization, and re-anchoring has to happen against that
    raw state, not against the normalized observation fed to the policy.

    Args:
        relative_step: Relative-action step whose cached state is preferred.
        fallback_state: Value returned when the step has nothing cached.

    Returns:
        The cached raw state, or ``fallback_state`` (possibly ``None``) when
        the cache is empty.
    """
    cached_state = relative_step._last_state
    return fallback_state if cached_state is None else cached_state
# --------------------------------------------------------------------------- # ---------------------------------------------------------------------------
# RTCInferenceEngine # RTCInferenceEngine
# --------------------------------------------------------------------------- # ---------------------------------------------------------------------------
@@ -318,7 +333,9 @@ class RTCInferenceEngine(InferenceEngine):
preprocessed = self._preprocessor(obs_batch) preprocessed = self._preprocessor(obs_batch)
if prev_actions is not None and self._relative_step is not None: if prev_actions is not None and self._relative_step is not None:
state_tensor = preprocessed.get(OBS_STATE) state_tensor = _get_current_raw_state(
self._relative_step, obs_batch.get(OBS_STATE)
)
if state_tensor is not None: if state_tensor is not None:
prev_abs = queue.get_processed_left_over() prev_abs = queue.get_processed_left_over()
if prev_abs is not None and prev_abs.numel() > 0: if prev_abs is not None and prev_abs.numel() > 0:
+46 -12
View File
@@ -17,14 +17,16 @@
from __future__ import annotations from __future__ import annotations
import logging import logging
from collections import deque
from contextlib import nullcontext from contextlib import nullcontext
from copy import copy from copy import copy
import torch import torch
from lerobot.policies.pretrained import PreTrainedPolicy from lerobot.policies.pretrained import PreTrainedPolicy
from lerobot.policies.utils import make_robot_action, prepare_observation_for_inference from lerobot.policies.utils import prepare_observation_for_inference
from lerobot.processor import PolicyProcessorPipeline from lerobot.processor import PolicyProcessorPipeline, RelativeActionsProcessorStep
from lerobot.utils.constants import ACTION
from .base import InferenceEngine from .base import InferenceEngine
@@ -34,9 +36,9 @@ logger = logging.getLogger(__name__)
class SyncInferenceEngine(InferenceEngine): class SyncInferenceEngine(InferenceEngine):
"""Inline synchronous inference: compute one action per call. """Inline synchronous inference: compute one action per call.
``get_action`` runs the full policy pipeline (pre/post-processor + ``get_action`` runs the full policy pipeline when its local action
``select_action``) on the given observation frame and returns a queue is empty, postprocesses the whole predicted chunk immediately,
CPU action tensor reordered to match the dataset action keys. and then returns one already-postprocessed CPU action at a time.
""" """
def __init__( def __init__(
@@ -58,6 +60,19 @@ class SyncInferenceEngine(InferenceEngine):
self._task = task self._task = task
self._device = torch.device(device or "cpu") self._device = torch.device(device or "cpu")
self._robot_type = robot_type self._robot_type = robot_type
self._processed_action_queue: deque[torch.Tensor] = deque()
self._relative_step = next(
(s for s in preprocessor.steps if isinstance(s, RelativeActionsProcessorStep) and s.enabled),
None,
)
if self._relative_step is not None and self._relative_step.action_names is None:
cfg_names = getattr(policy.config, "action_feature_names", None)
action_names = cfg_names or dataset_features.get(ACTION, {}).get("names")
if action_names:
self._relative_step.action_names = list(action_names)
logger.info("Relative actions enabled: sync chunks will be postprocessed before queueing")
logger.info( logger.info(
"SyncInferenceEngine initialized (device=%s, action_keys=%d)", "SyncInferenceEngine initialized (device=%s, action_keys=%d)",
self._device, self._device,
@@ -78,9 +93,29 @@ class SyncInferenceEngine(InferenceEngine):
self._policy.reset() self._policy.reset()
self._preprocessor.reset() self._preprocessor.reset()
self._postprocessor.reset() self._postprocessor.reset()
self._processed_action_queue.clear()
def _enqueue_processed_chunk(self, action_chunk: torch.Tensor) -> None:
    """Queue postprocessed per-step actions in policy output order.

    Args:
        action_chunk: Postprocessed chunk shaped ``(steps, dim)`` or
            ``(1, steps, dim)``. At most ``policy.config.n_action_steps``
            leading steps are queued; when that attribute is missing or
            ``None`` the whole chunk is queued.

    Raises:
        ValueError: If the chunk is not a single-sample chunk, or if its
            action width differs from the number of ordered action keys.
    """
    if action_chunk.ndim == 2:
        action_chunk = action_chunk.unsqueeze(0)
    # ``squeeze(0)`` would silently do nothing for a batch larger than one and
    # the queue would then receive 2-D tensors, so reject that case explicitly.
    if action_chunk.ndim != 3 or action_chunk.shape[0] != 1:
        raise ValueError(
            "Expected an action chunk shaped (steps, dim) or (1, steps, dim), "
            f"got {tuple(action_chunk.shape)}"
        )

    available_steps = action_chunk.shape[1]
    n_action_steps = getattr(self._policy.config, "n_action_steps", None)
    if n_action_steps is None:
        n_action_steps = available_steps

    # One detach/device transfer for the whole chunk instead of one per step.
    steps = action_chunk[0, : min(n_action_steps, available_steps)].detach().cpu()

    # Every step shares the same width, so a single check covers them all.
    if steps.shape[-1] != len(self._ordered_action_keys):
        raise ValueError(
            f"Action tensor length ({steps.shape[-1]}) != action keys "
            f"({len(self._ordered_action_keys)})"
        )
    self._processed_action_queue.extend(steps.unbind(0))
def get_action(self, obs_frame: dict | None) -> torch.Tensor | None: def get_action(self, obs_frame: dict | None) -> torch.Tensor | None:
"""Run the full inference pipeline on ``obs_frame`` and return an action tensor.""" """Run the full inference pipeline on ``obs_frame`` and return an action tensor."""
if self._processed_action_queue:
return self._processed_action_queue.popleft().clone()
if obs_frame is None: if obs_frame is None:
return None return None
# Shallow copy is intentional: the caller (`send_next_action`) builds # Shallow copy is intentional: the caller (`send_next_action`) builds
@@ -97,11 +132,10 @@ class SyncInferenceEngine(InferenceEngine):
observation, self._device, self._task, self._robot_type observation, self._device, self._task, self._robot_type
) )
observation = self._preprocessor(observation) observation = self._preprocessor(observation)
action = self._policy.select_action(observation) action_chunk = self._policy.predict_action_chunk(observation)
action = self._postprocessor(action) processed_chunk = self._postprocessor(action_chunk)
action_tensor = action.squeeze(0).cpu()
# Reorder to match dataset action ordering so the caller can treat self._enqueue_processed_chunk(processed_chunk)
# the returned tensor uniformly across backends. if not self._processed_action_queue:
action_dict = make_robot_action(action_tensor, self._dataset_features) return None
return torch.tensor([action_dict[k] for k in self._ordered_action_keys]) return self._processed_action_queue.popleft().clone()
+211 -1
View File
@@ -13,7 +13,9 @@ Flow under test:
import importlib.util import importlib.util
import sys import sys
from pathlib import Path from pathlib import Path
from types import ModuleType, SimpleNamespace
import numpy as np
import torch import torch
from lerobot.configs.types import ( from lerobot.configs.types import (
@@ -22,7 +24,13 @@ from lerobot.configs.types import (
PolicyFeature, PolicyFeature,
RTCAttentionSchedule, RTCAttentionSchedule,
) )
from lerobot.processor import TransitionKey, batch_to_transition from lerobot.processor import (
PolicyProcessorPipeline,
TransitionKey,
batch_to_transition,
policy_action_to_transition,
transition_to_policy_action,
)
from lerobot.processor.normalize_processor import NormalizerProcessorStep, UnnormalizerProcessorStep from lerobot.processor.normalize_processor import NormalizerProcessorStep, UnnormalizerProcessorStep
from lerobot.processor.relative_action_processor import ( from lerobot.processor.relative_action_processor import (
AbsoluteActionsProcessorStep, AbsoluteActionsProcessorStep,
@@ -52,6 +60,34 @@ _rtc_debug_mod = _import_rtc_module("lerobot.policies.rtc.debug_tracker", "debug
_rtc_mod = _import_rtc_module("lerobot.policies.rtc.modeling_rtc", "modeling_rtc.py") _rtc_mod = _import_rtc_module("lerobot.policies.rtc.modeling_rtc", "modeling_rtc.py")
RTCProcessor = _rtc_mod.RTCProcessor RTCProcessor = _rtc_mod.RTCProcessor
def _ensure_rollout_test_packages() -> Path:
    """Make ``lerobot.rollout`` and ``lerobot.rollout.inference`` resolvable.

    For each package an existing ``sys.modules`` entry is reused, otherwise an
    empty ``ModuleType`` placeholder is registered. In both cases ``__path__``
    is overwritten to point at the source directory, located relative to this
    test file.

    Returns:
        The ``src/lerobot/rollout`` directory.
    """
    repo_root = Path(__file__).resolve().parents[3]
    rollout_dir = repo_root / "src" / "lerobot" / "rollout"

    package_dirs = (
        ("lerobot.rollout", rollout_dir),
        ("lerobot.rollout.inference", rollout_dir / "inference"),
    )
    for package_name, package_dir in package_dirs:
        package = sys.modules.setdefault(package_name, ModuleType(package_name))
        package.__path__ = [str(package_dir)]
    return rollout_dir
def _import_rollout_module(module_name: str, relative_path: str):
    """Load a rollout source file as ``module_name`` and register it.

    Args:
        module_name: Fully qualified name to register in ``sys.modules``.
        relative_path: Path of the source file relative to the rollout
            directory returned by ``_ensure_rollout_test_packages``.

    Returns:
        The executed module object.

    Raises:
        ImportError: If no import spec or loader can be built for the file.
    """
    rollout_dir = _ensure_rollout_test_packages()
    module_path = rollout_dir / relative_path
    spec = importlib.util.spec_from_file_location(module_name, module_path)
    if spec is None or spec.loader is None:
        raise ImportError(f"Cannot build an import spec for {module_name!r} from {module_path}")

    mod = importlib.util.module_from_spec(spec)
    # Register before executing so imports triggered by the module body can
    # already find it under its qualified name.
    sys.modules[module_name] = mod
    try:
        spec.loader.exec_module(mod)
    except BaseException:
        # Never leave a half-initialised module behind for later imports.
        sys.modules.pop(module_name, None)
        raise
    return mod
# Load the rollout modules straight from their source files and register them
# in ``sys.modules`` under their package-qualified names.
# NOTE(review): robot_wrapper and inference/base are presumably loaded first so
# that imports inside sync.py / rtc.py resolve to these instances — confirm
# before reordering.
_rollout_robot_wrapper_mod = _import_rollout_module("lerobot.rollout.robot_wrapper", "robot_wrapper.py")
_rollout_base_mod = _import_rollout_module("lerobot.rollout.inference.base", "inference/base.py")
_rollout_sync_mod = _import_rollout_module("lerobot.rollout.inference.sync", "inference/sync.py")
_rollout_rtc_mod = _import_rollout_module("lerobot.rollout.inference.rtc", "inference/rtc.py")
# Short aliases for the objects exercised by the tests in this module; the RTC
# helper is private in its module and re-exposed here without the underscore.
SyncInferenceEngine = _rollout_sync_mod.SyncInferenceEngine
get_current_raw_state = _rollout_rtc_mod._get_current_raw_state
ACTION_DIM = 6 ACTION_DIM = 6
CHUNK_SIZE = 50 CHUNK_SIZE = 50
EXECUTION_HORIZON = 10 EXECUTION_HORIZON = 10
@@ -89,6 +125,44 @@ def _make_relative_pipeline(action_dim=ACTION_DIM, norm_mode=NormalizationMode.M
return relative_step, normalizer, unnormalizer, absolute_step return relative_step, normalizer, unnormalizer, absolute_step
def _make_relative_sync_pipelines(
    action_dim=ACTION_DIM,
    action_names: list[str] | None = None,
    exclude_joints: list[str] | None = None,
):
    """Build a minimal relative-action pre/post-processor pair for sync tests.

    Args:
        action_dim: Number of action dimensions; only used to generate the
            default ``joint_{i}.pos`` names.
        action_names: Action names given to the relative step. Falls back to
            the generated names when omitted or empty.
        exclude_joints: Joint names passed to the relative step as
            exclusions. Falls back to an empty list.

    Returns:
        ``(relative_step, preprocessor, postprocessor)``. The preprocessor
        holds only the relative step; the postprocessor holds only the
        matching absolute step bound to that same relative step.
    """
    names = action_names if action_names else [f"joint_{i}.pos" for i in range(action_dim)]
    excluded = exclude_joints if exclude_joints else []

    relative_step = RelativeActionsProcessorStep(
        enabled=True,
        exclude_joints=excluded,
        action_names=names,
    )
    preprocessor = PolicyProcessorPipeline(steps=[relative_step], name="test_preprocessor")
    postprocessor = PolicyProcessorPipeline(
        steps=[AbsoluteActionsProcessorStep(enabled=True, relative_step=relative_step)],
        name="test_postprocessor",
        to_transition=policy_action_to_transition,
        to_output=transition_to_policy_action,
    )
    return relative_step, preprocessor, postprocessor
class _ChunkPolicyStub:
    """Minimal policy double that only supports chunked prediction.

    Every ``predict_action_chunk`` call returns a fresh all-zero chunk shaped
    ``(1, n_action_steps, action_dim)`` and bumps ``predict_calls``.
    ``select_action`` fails loudly so a test notices if the engine falls back
    to per-step inference.
    """

    def __init__(self, action_dim: int, n_action_steps: int):
        self.predict_calls = 0
        self._chunk = torch.zeros((1, n_action_steps, action_dim))
        self.config = SimpleNamespace(use_amp=False, n_action_steps=n_action_steps)

    def reset(self):
        """No state to reset; present only to satisfy the policy interface."""
        return None

    def predict_action_chunk(self, batch):
        """Count the call and return a copy of the stored zero chunk."""
        self.predict_calls = self.predict_calls + 1
        zero_chunk = self._chunk.clone()
        return zero_chunk

    def select_action(self, batch):
        """Always fail: the sync engine is expected to use chunk outputs."""
        raise AssertionError("SyncInferenceEngine should consume chunk outputs directly")
class TestActionQueueRelativeActions: class TestActionQueueRelativeActions:
"""Verify ActionQueue stores model-space (relative) actions for RTC and absolute for robot.""" """Verify ActionQueue stores model-space (relative) actions for RTC and absolute for robot."""
@@ -120,6 +194,142 @@ class TestActionQueueRelativeActions:
torch.testing.assert_close(first_action, absolute_actions[0]) torch.testing.assert_close(first_action, absolute_actions[0])
class TestRolloutInferenceRelativeActions:
    """Regression coverage for rollout inference engines driving relative-action policies."""

    @staticmethod
    def _build_sync_engine(policy, preprocessor, postprocessor, action_keys, dataset_action_names=None):
        """Create a CPU ``SyncInferenceEngine`` around the given stub policy and pipelines.

        ``dataset_action_names`` defaults to ``action_keys`` so the dataset
        feature order matches the engine's ordered action keys.
        """
        if dataset_action_names is None:
            dataset_action_names = action_keys
        return SyncInferenceEngine(
            policy=policy,
            preprocessor=preprocessor,
            postprocessor=postprocessor,
            dataset_features={ACTION: {"names": dataset_action_names}},
            ordered_action_keys=action_keys,
            task="test",
            device="cpu",
            robot_type="mock",
        )

    def test_sync_engine_postprocesses_chunk_before_queueing(self):
        """Queued actions stay anchored to the state seen when the chunk was predicted."""
        joint_keys = [f"joint_{i}.pos" for i in range(ACTION_DIM)]
        _, preprocessor, postprocessor = _make_relative_sync_pipelines(ACTION_DIM)
        policy = _ChunkPolicyStub(action_dim=ACTION_DIM, n_action_steps=3)
        engine = self._build_sync_engine(policy, preprocessor, postprocessor, joint_keys)

        first_state = np.arange(1, ACTION_DIM + 1, dtype=np.float32)
        later_state = 10 * first_state

        first_action = engine.get_action({OBS_STATE: first_state.copy()})
        second_action = engine.get_action({OBS_STATE: later_state.copy()})

        # The stub predicts zero deltas, so both actions equal the first state.
        torch.testing.assert_close(first_action, torch.from_numpy(first_state))
        torch.testing.assert_close(second_action, torch.from_numpy(first_state))
        assert policy.predict_calls == 1

    def test_sync_engine_restores_action_names_for_relative_exclusions(self):
        """A relative step without action names gets them back so exclusions still apply."""
        action_names = [f"joint_{i}.pos" for i in range(ACTION_DIM - 1)]
        action_names.append("gripper.pos")
        relative_step, preprocessor, postprocessor = _make_relative_sync_pipelines(
            ACTION_DIM,
            action_names=action_names,
            exclude_joints=["gripper"],
        )
        # Mimic a serialized processor that omitted the action names.
        relative_step.action_names = None

        policy = _ChunkPolicyStub(action_dim=ACTION_DIM, n_action_steps=1)
        policy.config.action_feature_names = action_names
        assert relative_step.action_names is None

        engine = self._build_sync_engine(policy, preprocessor, postprocessor, action_names)

        state = np.arange(1, ACTION_DIM + 1, dtype=np.float32)
        action = engine.get_action({OBS_STATE: state.copy()})

        # The excluded gripper keeps the policy's zero output.
        expected = torch.from_numpy(state.copy())
        expected[-1] = 0.0
        torch.testing.assert_close(action, expected)
        assert relative_step.action_names == action_names

    def test_sync_engine_does_not_remap_chunk_through_dataset_action_names(self):
        """A reversed dataset feature order must not reorder the postprocessed chunk."""
        action_names = [f"joint_{i}.pos" for i in range(ACTION_DIM)]
        _, preprocessor, postprocessor = _make_relative_sync_pipelines(
            ACTION_DIM,
            action_names=action_names,
        )
        policy = _ChunkPolicyStub(action_dim=ACTION_DIM, n_action_steps=1)
        policy.config.action_feature_names = action_names

        engine = self._build_sync_engine(
            policy,
            preprocessor,
            postprocessor,
            action_names,
            dataset_action_names=list(reversed(action_names)),
        )

        state = np.arange(1, ACTION_DIM + 1, dtype=np.float32)
        action = engine.get_action({OBS_STATE: state.copy()})

        torch.testing.assert_close(action, torch.from_numpy(state))

    def test_rtc_reanchoring_prefers_raw_cached_state(self):
        """RTC re-anchoring uses the state cached before observation normalization."""
        action_dim = ACTION_DIM
        relative_step = RelativeActionsProcessorStep(
            enabled=True,
            action_names=[f"joint_{i}.pos" for i in range(action_dim)],
        )

        state_stats = {
            "mean": np.arange(1, action_dim + 1, dtype=np.float32),
            "std": 2 * np.ones(action_dim, dtype=np.float32),
        }
        action_stats = {
            "mean": np.zeros(action_dim, dtype=np.float32),
            "std": np.ones(action_dim, dtype=np.float32),
        }
        normalizer = NormalizerProcessorStep(
            features={
                OBS_STATE: PolicyFeature(type=FeatureType.STATE, shape=(action_dim,)),
                ACTION: PolicyFeature(type=FeatureType.ACTION, shape=(action_dim,)),
            },
            norm_map={
                feature_type: NormalizationMode.MEAN_STD
                for feature_type in (FeatureType.STATE, FeatureType.ACTION)
            },
            stats={OBS_STATE: state_stats, ACTION: action_stats},
        )
        preprocessor = PolicyProcessorPipeline(
            steps=[relative_step, normalizer],
            name="test_preprocessor_with_state_norm",
        )

        raw_state = torch.from_numpy(np.arange(5, 5 + action_dim, dtype=np.float32)).unsqueeze(0)
        preprocessed = preprocessor({OBS_STATE: raw_state.clone()})
        current_state = get_current_raw_state(relative_step, preprocessed.get(OBS_STATE))

        torch.testing.assert_close(current_state, raw_state)
        # Sanity check: the pipeline output really was normalized.
        assert not torch.allclose(preprocessed[OBS_STATE], raw_state)
class TestRTCDenoiseWithRelativeLeftovers: class TestRTCDenoiseWithRelativeLeftovers:
"""Verify RTC denoise_step correctly handles relative-space prev_chunk_left_over.""" """Verify RTC denoise_step correctly handles relative-space prev_chunk_left_over."""