mirror of
https://github.com/huggingface/lerobot.git
synced 2026-05-21 11:39:50 +00:00
fix(rollout): preserve relative action chunks
This commit is contained in:
@@ -109,6 +109,21 @@ def _normalize_prev_actions_length(prev_actions: torch.Tensor, target_steps: int
|
|||||||
return padded
|
return padded
|
||||||
|
|
||||||
|
|
||||||
|
def _get_current_raw_state(
|
||||||
|
relative_step: RelativeActionsProcessorStep,
|
||||||
|
fallback_state: torch.Tensor | None,
|
||||||
|
) -> torch.Tensor | None:
|
||||||
|
"""Return the current raw state cached by the relative-action step.
|
||||||
|
|
||||||
|
``RelativeActionsProcessorStep`` caches the observation state before any
|
||||||
|
observation normalization. Re-anchoring RTC leftovers must use that raw
|
||||||
|
state rather than the normalized observation that the policy consumes.
|
||||||
|
"""
|
||||||
|
if relative_step._last_state is not None:
|
||||||
|
return relative_step._last_state
|
||||||
|
return fallback_state
|
||||||
|
|
||||||
|
|
||||||
# ---------------------------------------------------------------------------
|
# ---------------------------------------------------------------------------
|
||||||
# RTCInferenceEngine
|
# RTCInferenceEngine
|
||||||
# ---------------------------------------------------------------------------
|
# ---------------------------------------------------------------------------
|
||||||
@@ -318,7 +333,9 @@ class RTCInferenceEngine(InferenceEngine):
|
|||||||
preprocessed = self._preprocessor(obs_batch)
|
preprocessed = self._preprocessor(obs_batch)
|
||||||
|
|
||||||
if prev_actions is not None and self._relative_step is not None:
|
if prev_actions is not None and self._relative_step is not None:
|
||||||
state_tensor = preprocessed.get(OBS_STATE)
|
state_tensor = _get_current_raw_state(
|
||||||
|
self._relative_step, obs_batch.get(OBS_STATE)
|
||||||
|
)
|
||||||
if state_tensor is not None:
|
if state_tensor is not None:
|
||||||
prev_abs = queue.get_processed_left_over()
|
prev_abs = queue.get_processed_left_over()
|
||||||
if prev_abs is not None and prev_abs.numel() > 0:
|
if prev_abs is not None and prev_abs.numel() > 0:
|
||||||
|
|||||||
@@ -17,14 +17,16 @@
|
|||||||
from __future__ import annotations
|
from __future__ import annotations
|
||||||
|
|
||||||
import logging
|
import logging
|
||||||
|
from collections import deque
|
||||||
from contextlib import nullcontext
|
from contextlib import nullcontext
|
||||||
from copy import copy
|
from copy import copy
|
||||||
|
|
||||||
import torch
|
import torch
|
||||||
|
|
||||||
from lerobot.policies.pretrained import PreTrainedPolicy
|
from lerobot.policies.pretrained import PreTrainedPolicy
|
||||||
from lerobot.policies.utils import make_robot_action, prepare_observation_for_inference
|
from lerobot.policies.utils import prepare_observation_for_inference
|
||||||
from lerobot.processor import PolicyProcessorPipeline
|
from lerobot.processor import PolicyProcessorPipeline, RelativeActionsProcessorStep
|
||||||
|
from lerobot.utils.constants import ACTION
|
||||||
|
|
||||||
from .base import InferenceEngine
|
from .base import InferenceEngine
|
||||||
|
|
||||||
@@ -34,9 +36,9 @@ logger = logging.getLogger(__name__)
|
|||||||
class SyncInferenceEngine(InferenceEngine):
|
class SyncInferenceEngine(InferenceEngine):
|
||||||
"""Inline synchronous inference: compute one action per call.
|
"""Inline synchronous inference: compute one action per call.
|
||||||
|
|
||||||
``get_action`` runs the full policy pipeline (pre/post-processor +
|
``get_action`` runs the full policy pipeline when its local action
|
||||||
``select_action``) on the given observation frame and returns a
|
queue is empty, postprocesses the whole predicted chunk immediately,
|
||||||
CPU action tensor reordered to match the dataset action keys.
|
and then returns one already-postprocessed CPU action at a time.
|
||||||
"""
|
"""
|
||||||
|
|
||||||
def __init__(
|
def __init__(
|
||||||
@@ -58,6 +60,19 @@ class SyncInferenceEngine(InferenceEngine):
|
|||||||
self._task = task
|
self._task = task
|
||||||
self._device = torch.device(device or "cpu")
|
self._device = torch.device(device or "cpu")
|
||||||
self._robot_type = robot_type
|
self._robot_type = robot_type
|
||||||
|
self._processed_action_queue: deque[torch.Tensor] = deque()
|
||||||
|
|
||||||
|
self._relative_step = next(
|
||||||
|
(s for s in preprocessor.steps if isinstance(s, RelativeActionsProcessorStep) and s.enabled),
|
||||||
|
None,
|
||||||
|
)
|
||||||
|
if self._relative_step is not None and self._relative_step.action_names is None:
|
||||||
|
cfg_names = getattr(policy.config, "action_feature_names", None)
|
||||||
|
action_names = cfg_names or dataset_features.get(ACTION, {}).get("names")
|
||||||
|
if action_names:
|
||||||
|
self._relative_step.action_names = list(action_names)
|
||||||
|
logger.info("Relative actions enabled: sync chunks will be postprocessed before queueing")
|
||||||
|
|
||||||
logger.info(
|
logger.info(
|
||||||
"SyncInferenceEngine initialized (device=%s, action_keys=%d)",
|
"SyncInferenceEngine initialized (device=%s, action_keys=%d)",
|
||||||
self._device,
|
self._device,
|
||||||
@@ -78,9 +93,29 @@ class SyncInferenceEngine(InferenceEngine):
|
|||||||
self._policy.reset()
|
self._policy.reset()
|
||||||
self._preprocessor.reset()
|
self._preprocessor.reset()
|
||||||
self._postprocessor.reset()
|
self._postprocessor.reset()
|
||||||
|
self._processed_action_queue.clear()
|
||||||
|
|
||||||
|
def _enqueue_processed_chunk(self, action_chunk: torch.Tensor) -> None:
|
||||||
|
"""Queue postprocessed per-step actions in policy output order."""
|
||||||
|
if action_chunk.ndim == 2:
|
||||||
|
action_chunk = action_chunk.unsqueeze(0)
|
||||||
|
|
||||||
|
n_action_steps = getattr(self._policy.config, "n_action_steps", action_chunk.shape[1])
|
||||||
|
action_chunk = action_chunk[:, : min(n_action_steps, action_chunk.shape[1])]
|
||||||
|
|
||||||
|
for action in action_chunk.squeeze(0):
|
||||||
|
action_tensor = action.detach().cpu()
|
||||||
|
if len(action_tensor) != len(self._ordered_action_keys):
|
||||||
|
raise ValueError(
|
||||||
|
f"Action tensor length ({len(action_tensor)}) != action keys "
|
||||||
|
f"({len(self._ordered_action_keys)})"
|
||||||
|
)
|
||||||
|
self._processed_action_queue.append(action_tensor)
|
||||||
|
|
||||||
def get_action(self, obs_frame: dict | None) -> torch.Tensor | None:
|
def get_action(self, obs_frame: dict | None) -> torch.Tensor | None:
|
||||||
"""Run the full inference pipeline on ``obs_frame`` and return an action tensor."""
|
"""Run the full inference pipeline on ``obs_frame`` and return an action tensor."""
|
||||||
|
if self._processed_action_queue:
|
||||||
|
return self._processed_action_queue.popleft().clone()
|
||||||
if obs_frame is None:
|
if obs_frame is None:
|
||||||
return None
|
return None
|
||||||
# Shallow copy is intentional: the caller (`send_next_action`) builds
|
# Shallow copy is intentional: the caller (`send_next_action`) builds
|
||||||
@@ -97,11 +132,10 @@ class SyncInferenceEngine(InferenceEngine):
|
|||||||
observation, self._device, self._task, self._robot_type
|
observation, self._device, self._task, self._robot_type
|
||||||
)
|
)
|
||||||
observation = self._preprocessor(observation)
|
observation = self._preprocessor(observation)
|
||||||
action = self._policy.select_action(observation)
|
action_chunk = self._policy.predict_action_chunk(observation)
|
||||||
action = self._postprocessor(action)
|
processed_chunk = self._postprocessor(action_chunk)
|
||||||
action_tensor = action.squeeze(0).cpu()
|
|
||||||
|
|
||||||
# Reorder to match dataset action ordering so the caller can treat
|
self._enqueue_processed_chunk(processed_chunk)
|
||||||
# the returned tensor uniformly across backends.
|
if not self._processed_action_queue:
|
||||||
action_dict = make_robot_action(action_tensor, self._dataset_features)
|
return None
|
||||||
return torch.tensor([action_dict[k] for k in self._ordered_action_keys])
|
return self._processed_action_queue.popleft().clone()
|
||||||
|
|||||||
@@ -13,7 +13,9 @@ Flow under test:
|
|||||||
import importlib.util
|
import importlib.util
|
||||||
import sys
|
import sys
|
||||||
from pathlib import Path
|
from pathlib import Path
|
||||||
|
from types import ModuleType, SimpleNamespace
|
||||||
|
|
||||||
|
import numpy as np
|
||||||
import torch
|
import torch
|
||||||
|
|
||||||
from lerobot.configs.types import (
|
from lerobot.configs.types import (
|
||||||
@@ -22,7 +24,13 @@ from lerobot.configs.types import (
|
|||||||
PolicyFeature,
|
PolicyFeature,
|
||||||
RTCAttentionSchedule,
|
RTCAttentionSchedule,
|
||||||
)
|
)
|
||||||
from lerobot.processor import TransitionKey, batch_to_transition
|
from lerobot.processor import (
|
||||||
|
PolicyProcessorPipeline,
|
||||||
|
TransitionKey,
|
||||||
|
batch_to_transition,
|
||||||
|
policy_action_to_transition,
|
||||||
|
transition_to_policy_action,
|
||||||
|
)
|
||||||
from lerobot.processor.normalize_processor import NormalizerProcessorStep, UnnormalizerProcessorStep
|
from lerobot.processor.normalize_processor import NormalizerProcessorStep, UnnormalizerProcessorStep
|
||||||
from lerobot.processor.relative_action_processor import (
|
from lerobot.processor.relative_action_processor import (
|
||||||
AbsoluteActionsProcessorStep,
|
AbsoluteActionsProcessorStep,
|
||||||
@@ -52,6 +60,34 @@ _rtc_debug_mod = _import_rtc_module("lerobot.policies.rtc.debug_tracker", "debug
|
|||||||
_rtc_mod = _import_rtc_module("lerobot.policies.rtc.modeling_rtc", "modeling_rtc.py")
|
_rtc_mod = _import_rtc_module("lerobot.policies.rtc.modeling_rtc", "modeling_rtc.py")
|
||||||
RTCProcessor = _rtc_mod.RTCProcessor
|
RTCProcessor = _rtc_mod.RTCProcessor
|
||||||
|
|
||||||
|
|
||||||
|
def _ensure_rollout_test_packages() -> Path:
|
||||||
|
rollout_dir = Path(__file__).resolve().parents[3] / "src" / "lerobot" / "rollout"
|
||||||
|
rollout_pkg = sys.modules.setdefault("lerobot.rollout", ModuleType("lerobot.rollout"))
|
||||||
|
rollout_pkg.__path__ = [str(rollout_dir)]
|
||||||
|
inference_pkg = sys.modules.setdefault(
|
||||||
|
"lerobot.rollout.inference", ModuleType("lerobot.rollout.inference")
|
||||||
|
)
|
||||||
|
inference_pkg.__path__ = [str(rollout_dir / "inference")]
|
||||||
|
return rollout_dir
|
||||||
|
|
||||||
|
|
||||||
|
def _import_rollout_module(module_name: str, relative_path: str):
|
||||||
|
rollout_dir = _ensure_rollout_test_packages()
|
||||||
|
spec = importlib.util.spec_from_file_location(module_name, rollout_dir / relative_path)
|
||||||
|
mod = importlib.util.module_from_spec(spec)
|
||||||
|
sys.modules[module_name] = mod
|
||||||
|
spec.loader.exec_module(mod)
|
||||||
|
return mod
|
||||||
|
|
||||||
|
|
||||||
|
_rollout_robot_wrapper_mod = _import_rollout_module("lerobot.rollout.robot_wrapper", "robot_wrapper.py")
|
||||||
|
_rollout_base_mod = _import_rollout_module("lerobot.rollout.inference.base", "inference/base.py")
|
||||||
|
_rollout_sync_mod = _import_rollout_module("lerobot.rollout.inference.sync", "inference/sync.py")
|
||||||
|
_rollout_rtc_mod = _import_rollout_module("lerobot.rollout.inference.rtc", "inference/rtc.py")
|
||||||
|
SyncInferenceEngine = _rollout_sync_mod.SyncInferenceEngine
|
||||||
|
get_current_raw_state = _rollout_rtc_mod._get_current_raw_state
|
||||||
|
|
||||||
ACTION_DIM = 6
|
ACTION_DIM = 6
|
||||||
CHUNK_SIZE = 50
|
CHUNK_SIZE = 50
|
||||||
EXECUTION_HORIZON = 10
|
EXECUTION_HORIZON = 10
|
||||||
@@ -89,6 +125,44 @@ def _make_relative_pipeline(action_dim=ACTION_DIM, norm_mode=NormalizationMode.M
|
|||||||
return relative_step, normalizer, unnormalizer, absolute_step
|
return relative_step, normalizer, unnormalizer, absolute_step
|
||||||
|
|
||||||
|
|
||||||
|
def _make_relative_sync_pipelines(
|
||||||
|
action_dim=ACTION_DIM,
|
||||||
|
action_names: list[str] | None = None,
|
||||||
|
exclude_joints: list[str] | None = None,
|
||||||
|
):
|
||||||
|
relative_step = RelativeActionsProcessorStep(
|
||||||
|
enabled=True,
|
||||||
|
exclude_joints=exclude_joints or [],
|
||||||
|
action_names=action_names or [f"joint_{i}.pos" for i in range(action_dim)],
|
||||||
|
)
|
||||||
|
absolute_step = AbsoluteActionsProcessorStep(enabled=True, relative_step=relative_step)
|
||||||
|
preprocessor = PolicyProcessorPipeline(steps=[relative_step], name="test_preprocessor")
|
||||||
|
postprocessor = PolicyProcessorPipeline(
|
||||||
|
steps=[absolute_step],
|
||||||
|
name="test_postprocessor",
|
||||||
|
to_transition=policy_action_to_transition,
|
||||||
|
to_output=transition_to_policy_action,
|
||||||
|
)
|
||||||
|
return relative_step, preprocessor, postprocessor
|
||||||
|
|
||||||
|
|
||||||
|
class _ChunkPolicyStub:
|
||||||
|
def __init__(self, action_dim: int, n_action_steps: int):
|
||||||
|
self.config = SimpleNamespace(use_amp=False, n_action_steps=n_action_steps)
|
||||||
|
self._chunk = torch.zeros(1, n_action_steps, action_dim)
|
||||||
|
self.predict_calls = 0
|
||||||
|
|
||||||
|
def reset(self):
|
||||||
|
return None
|
||||||
|
|
||||||
|
def predict_action_chunk(self, batch):
|
||||||
|
self.predict_calls += 1
|
||||||
|
return self._chunk.clone()
|
||||||
|
|
||||||
|
def select_action(self, batch):
|
||||||
|
raise AssertionError("SyncInferenceEngine should consume chunk outputs directly")
|
||||||
|
|
||||||
|
|
||||||
class TestActionQueueRelativeActions:
|
class TestActionQueueRelativeActions:
|
||||||
"""Verify ActionQueue stores model-space (relative) actions for RTC and absolute for robot."""
|
"""Verify ActionQueue stores model-space (relative) actions for RTC and absolute for robot."""
|
||||||
|
|
||||||
@@ -120,6 +194,142 @@ class TestActionQueueRelativeActions:
|
|||||||
torch.testing.assert_close(first_action, absolute_actions[0])
|
torch.testing.assert_close(first_action, absolute_actions[0])
|
||||||
|
|
||||||
|
|
||||||
|
class TestRolloutInferenceRelativeActions:
|
||||||
|
"""Regression tests for rollout inference engines with relative-action policies."""
|
||||||
|
|
||||||
|
def test_sync_engine_postprocesses_chunk_before_queueing(self):
|
||||||
|
"""Queued sync actions must stay anchored to the state from the chunk-producing step."""
|
||||||
|
_, preprocessor, postprocessor = _make_relative_sync_pipelines(ACTION_DIM)
|
||||||
|
policy = _ChunkPolicyStub(action_dim=ACTION_DIM, n_action_steps=3)
|
||||||
|
ordered_action_keys = [f"joint_{i}.pos" for i in range(ACTION_DIM)]
|
||||||
|
dataset_features = {ACTION: {"names": ordered_action_keys}}
|
||||||
|
|
||||||
|
engine = SyncInferenceEngine(
|
||||||
|
policy=policy,
|
||||||
|
preprocessor=preprocessor,
|
||||||
|
postprocessor=postprocessor,
|
||||||
|
dataset_features=dataset_features,
|
||||||
|
ordered_action_keys=ordered_action_keys,
|
||||||
|
task="test",
|
||||||
|
device="cpu",
|
||||||
|
robot_type="mock",
|
||||||
|
)
|
||||||
|
|
||||||
|
state_1 = np.arange(1, ACTION_DIM + 1, dtype=np.float32)
|
||||||
|
state_2 = 10 * state_1
|
||||||
|
|
||||||
|
action_1 = engine.get_action({OBS_STATE: state_1.copy()})
|
||||||
|
action_2 = engine.get_action({OBS_STATE: state_2.copy()})
|
||||||
|
|
||||||
|
torch.testing.assert_close(action_1, torch.from_numpy(state_1))
|
||||||
|
torch.testing.assert_close(action_2, torch.from_numpy(state_1))
|
||||||
|
assert policy.predict_calls == 1
|
||||||
|
|
||||||
|
def test_sync_engine_restores_action_names_for_relative_exclusions(self):
|
||||||
|
"""Serialized processors may omit action names; sync rollout must still honor gripper exclusions."""
|
||||||
|
action_names = [f"joint_{i}.pos" for i in range(ACTION_DIM - 1)] + ["gripper.pos"]
|
||||||
|
relative_step, preprocessor, postprocessor = _make_relative_sync_pipelines(
|
||||||
|
ACTION_DIM,
|
||||||
|
action_names=action_names,
|
||||||
|
exclude_joints=["gripper"],
|
||||||
|
)
|
||||||
|
relative_step.action_names = None
|
||||||
|
policy = _ChunkPolicyStub(action_dim=ACTION_DIM, n_action_steps=1)
|
||||||
|
policy.config.action_feature_names = action_names
|
||||||
|
dataset_features = {ACTION: {"names": action_names}}
|
||||||
|
|
||||||
|
assert relative_step.action_names is None
|
||||||
|
|
||||||
|
engine = SyncInferenceEngine(
|
||||||
|
policy=policy,
|
||||||
|
preprocessor=preprocessor,
|
||||||
|
postprocessor=postprocessor,
|
||||||
|
dataset_features=dataset_features,
|
||||||
|
ordered_action_keys=action_names,
|
||||||
|
task="test",
|
||||||
|
device="cpu",
|
||||||
|
robot_type="mock",
|
||||||
|
)
|
||||||
|
|
||||||
|
state = np.arange(1, ACTION_DIM + 1, dtype=np.float32)
|
||||||
|
action = engine.get_action({OBS_STATE: state.copy()})
|
||||||
|
expected = torch.from_numpy(state.copy())
|
||||||
|
expected[-1] = 0.0
|
||||||
|
|
||||||
|
torch.testing.assert_close(action, expected)
|
||||||
|
assert relative_step.action_names == action_names
|
||||||
|
|
||||||
|
def test_sync_engine_does_not_remap_chunk_through_dataset_action_names(self):
|
||||||
|
"""Postprocessed chunks are already in policy order; dataset feature order must not scramble them."""
|
||||||
|
action_names = [f"joint_{i}.pos" for i in range(ACTION_DIM)]
|
||||||
|
_, preprocessor, postprocessor = _make_relative_sync_pipelines(
|
||||||
|
ACTION_DIM,
|
||||||
|
action_names=action_names,
|
||||||
|
)
|
||||||
|
policy = _ChunkPolicyStub(action_dim=ACTION_DIM, n_action_steps=1)
|
||||||
|
policy.config.action_feature_names = action_names
|
||||||
|
|
||||||
|
engine = SyncInferenceEngine(
|
||||||
|
policy=policy,
|
||||||
|
preprocessor=preprocessor,
|
||||||
|
postprocessor=postprocessor,
|
||||||
|
dataset_features={ACTION: {"names": list(reversed(action_names))}},
|
||||||
|
ordered_action_keys=action_names,
|
||||||
|
task="test",
|
||||||
|
device="cpu",
|
||||||
|
robot_type="mock",
|
||||||
|
)
|
||||||
|
|
||||||
|
state = np.arange(1, ACTION_DIM + 1, dtype=np.float32)
|
||||||
|
action = engine.get_action({OBS_STATE: state.copy()})
|
||||||
|
|
||||||
|
torch.testing.assert_close(action, torch.from_numpy(state))
|
||||||
|
|
||||||
|
def test_rtc_reanchoring_prefers_raw_cached_state(self):
|
||||||
|
"""RTC re-anchoring must use the raw state cached before observation normalization."""
|
||||||
|
action_dim = ACTION_DIM
|
||||||
|
relative_step = RelativeActionsProcessorStep(
|
||||||
|
enabled=True,
|
||||||
|
action_names=[f"joint_{i}.pos" for i in range(action_dim)],
|
||||||
|
)
|
||||||
|
features = {
|
||||||
|
OBS_STATE: PolicyFeature(type=FeatureType.STATE, shape=(action_dim,)),
|
||||||
|
ACTION: PolicyFeature(type=FeatureType.ACTION, shape=(action_dim,)),
|
||||||
|
}
|
||||||
|
stats = {
|
||||||
|
OBS_STATE: {
|
||||||
|
"mean": np.arange(1, action_dim + 1, dtype=np.float32),
|
||||||
|
"std": 2 * np.ones(action_dim, dtype=np.float32),
|
||||||
|
},
|
||||||
|
ACTION: {
|
||||||
|
"mean": np.zeros(action_dim, dtype=np.float32),
|
||||||
|
"std": np.ones(action_dim, dtype=np.float32),
|
||||||
|
},
|
||||||
|
}
|
||||||
|
preprocessor = PolicyProcessorPipeline(
|
||||||
|
steps=[
|
||||||
|
relative_step,
|
||||||
|
NormalizerProcessorStep(
|
||||||
|
features=features,
|
||||||
|
norm_map={
|
||||||
|
FeatureType.STATE: NormalizationMode.MEAN_STD,
|
||||||
|
FeatureType.ACTION: NormalizationMode.MEAN_STD,
|
||||||
|
},
|
||||||
|
stats=stats,
|
||||||
|
),
|
||||||
|
],
|
||||||
|
name="test_preprocessor_with_state_norm",
|
||||||
|
)
|
||||||
|
|
||||||
|
raw_state = torch.from_numpy(np.arange(5, 5 + action_dim, dtype=np.float32)).unsqueeze(0)
|
||||||
|
preprocessed = preprocessor({OBS_STATE: raw_state.clone()})
|
||||||
|
|
||||||
|
current_state = get_current_raw_state(relative_step, preprocessed.get(OBS_STATE))
|
||||||
|
|
||||||
|
torch.testing.assert_close(current_state, raw_state)
|
||||||
|
assert not torch.allclose(preprocessed[OBS_STATE], raw_state)
|
||||||
|
|
||||||
|
|
||||||
class TestRTCDenoiseWithRelativeLeftovers:
|
class TestRTCDenoiseWithRelativeLeftovers:
|
||||||
"""Verify RTC denoise_step correctly handles relative-space prev_chunk_left_over."""
|
"""Verify RTC denoise_step correctly handles relative-space prev_chunk_left_over."""
|
||||||
|
|
||||||
|
|||||||
Reference in New Issue
Block a user