mirror of
https://github.com/huggingface/lerobot.git
synced 2026-05-19 18:49:52 +00:00
fix(rollout): preserve relative action chunks
This commit is contained in:
@@ -13,7 +13,9 @@ Flow under test:
|
||||
import importlib.util
|
||||
import sys
|
||||
from pathlib import Path
|
||||
from types import ModuleType, SimpleNamespace
|
||||
|
||||
import numpy as np
|
||||
import torch
|
||||
|
||||
from lerobot.configs.types import (
|
||||
@@ -22,7 +24,13 @@ from lerobot.configs.types import (
|
||||
PolicyFeature,
|
||||
RTCAttentionSchedule,
|
||||
)
|
||||
from lerobot.processor import TransitionKey, batch_to_transition
|
||||
from lerobot.processor import (
|
||||
PolicyProcessorPipeline,
|
||||
TransitionKey,
|
||||
batch_to_transition,
|
||||
policy_action_to_transition,
|
||||
transition_to_policy_action,
|
||||
)
|
||||
from lerobot.processor.normalize_processor import NormalizerProcessorStep, UnnormalizerProcessorStep
|
||||
from lerobot.processor.relative_action_processor import (
|
||||
AbsoluteActionsProcessorStep,
|
||||
@@ -52,6 +60,34 @@ _rtc_debug_mod = _import_rtc_module("lerobot.policies.rtc.debug_tracker", "debug
|
||||
_rtc_mod = _import_rtc_module("lerobot.policies.rtc.modeling_rtc", "modeling_rtc.py")
|
||||
RTCProcessor = _rtc_mod.RTCProcessor
|
||||
|
||||
|
||||
def _ensure_rollout_test_packages() -> Path:
|
||||
rollout_dir = Path(__file__).resolve().parents[3] / "src" / "lerobot" / "rollout"
|
||||
rollout_pkg = sys.modules.setdefault("lerobot.rollout", ModuleType("lerobot.rollout"))
|
||||
rollout_pkg.__path__ = [str(rollout_dir)]
|
||||
inference_pkg = sys.modules.setdefault(
|
||||
"lerobot.rollout.inference", ModuleType("lerobot.rollout.inference")
|
||||
)
|
||||
inference_pkg.__path__ = [str(rollout_dir / "inference")]
|
||||
return rollout_dir
|
||||
|
||||
|
||||
def _import_rollout_module(module_name: str, relative_path: str):
|
||||
rollout_dir = _ensure_rollout_test_packages()
|
||||
spec = importlib.util.spec_from_file_location(module_name, rollout_dir / relative_path)
|
||||
mod = importlib.util.module_from_spec(spec)
|
||||
sys.modules[module_name] = mod
|
||||
spec.loader.exec_module(mod)
|
||||
return mod
|
||||
|
||||
|
||||
_rollout_robot_wrapper_mod = _import_rollout_module("lerobot.rollout.robot_wrapper", "robot_wrapper.py")
|
||||
_rollout_base_mod = _import_rollout_module("lerobot.rollout.inference.base", "inference/base.py")
|
||||
_rollout_sync_mod = _import_rollout_module("lerobot.rollout.inference.sync", "inference/sync.py")
|
||||
_rollout_rtc_mod = _import_rollout_module("lerobot.rollout.inference.rtc", "inference/rtc.py")
|
||||
SyncInferenceEngine = _rollout_sync_mod.SyncInferenceEngine
|
||||
get_current_raw_state = _rollout_rtc_mod._get_current_raw_state
|
||||
|
||||
ACTION_DIM = 6
|
||||
CHUNK_SIZE = 50
|
||||
EXECUTION_HORIZON = 10
|
||||
@@ -89,6 +125,44 @@ def _make_relative_pipeline(action_dim=ACTION_DIM, norm_mode=NormalizationMode.M
|
||||
return relative_step, normalizer, unnormalizer, absolute_step
|
||||
|
||||
|
||||
def _make_relative_sync_pipelines(
|
||||
action_dim=ACTION_DIM,
|
||||
action_names: list[str] | None = None,
|
||||
exclude_joints: list[str] | None = None,
|
||||
):
|
||||
relative_step = RelativeActionsProcessorStep(
|
||||
enabled=True,
|
||||
exclude_joints=exclude_joints or [],
|
||||
action_names=action_names or [f"joint_{i}.pos" for i in range(action_dim)],
|
||||
)
|
||||
absolute_step = AbsoluteActionsProcessorStep(enabled=True, relative_step=relative_step)
|
||||
preprocessor = PolicyProcessorPipeline(steps=[relative_step], name="test_preprocessor")
|
||||
postprocessor = PolicyProcessorPipeline(
|
||||
steps=[absolute_step],
|
||||
name="test_postprocessor",
|
||||
to_transition=policy_action_to_transition,
|
||||
to_output=transition_to_policy_action,
|
||||
)
|
||||
return relative_step, preprocessor, postprocessor
|
||||
|
||||
|
||||
class _ChunkPolicyStub:
|
||||
def __init__(self, action_dim: int, n_action_steps: int):
|
||||
self.config = SimpleNamespace(use_amp=False, n_action_steps=n_action_steps)
|
||||
self._chunk = torch.zeros(1, n_action_steps, action_dim)
|
||||
self.predict_calls = 0
|
||||
|
||||
def reset(self):
|
||||
return None
|
||||
|
||||
def predict_action_chunk(self, batch):
|
||||
self.predict_calls += 1
|
||||
return self._chunk.clone()
|
||||
|
||||
def select_action(self, batch):
|
||||
raise AssertionError("SyncInferenceEngine should consume chunk outputs directly")
|
||||
|
||||
|
||||
class TestActionQueueRelativeActions:
|
||||
"""Verify ActionQueue stores model-space (relative) actions for RTC and absolute for robot."""
|
||||
|
||||
@@ -120,6 +194,142 @@ class TestActionQueueRelativeActions:
|
||||
torch.testing.assert_close(first_action, absolute_actions[0])
|
||||
|
||||
|
||||
class TestRolloutInferenceRelativeActions:
|
||||
"""Regression tests for rollout inference engines with relative-action policies."""
|
||||
|
||||
def test_sync_engine_postprocesses_chunk_before_queueing(self):
|
||||
"""Queued sync actions must stay anchored to the state from the chunk-producing step."""
|
||||
_, preprocessor, postprocessor = _make_relative_sync_pipelines(ACTION_DIM)
|
||||
policy = _ChunkPolicyStub(action_dim=ACTION_DIM, n_action_steps=3)
|
||||
ordered_action_keys = [f"joint_{i}.pos" for i in range(ACTION_DIM)]
|
||||
dataset_features = {ACTION: {"names": ordered_action_keys}}
|
||||
|
||||
engine = SyncInferenceEngine(
|
||||
policy=policy,
|
||||
preprocessor=preprocessor,
|
||||
postprocessor=postprocessor,
|
||||
dataset_features=dataset_features,
|
||||
ordered_action_keys=ordered_action_keys,
|
||||
task="test",
|
||||
device="cpu",
|
||||
robot_type="mock",
|
||||
)
|
||||
|
||||
state_1 = np.arange(1, ACTION_DIM + 1, dtype=np.float32)
|
||||
state_2 = 10 * state_1
|
||||
|
||||
action_1 = engine.get_action({OBS_STATE: state_1.copy()})
|
||||
action_2 = engine.get_action({OBS_STATE: state_2.copy()})
|
||||
|
||||
torch.testing.assert_close(action_1, torch.from_numpy(state_1))
|
||||
torch.testing.assert_close(action_2, torch.from_numpy(state_1))
|
||||
assert policy.predict_calls == 1
|
||||
|
||||
def test_sync_engine_restores_action_names_for_relative_exclusions(self):
|
||||
"""Serialized processors may omit action names; sync rollout must still honor gripper exclusions."""
|
||||
action_names = [f"joint_{i}.pos" for i in range(ACTION_DIM - 1)] + ["gripper.pos"]
|
||||
relative_step, preprocessor, postprocessor = _make_relative_sync_pipelines(
|
||||
ACTION_DIM,
|
||||
action_names=action_names,
|
||||
exclude_joints=["gripper"],
|
||||
)
|
||||
relative_step.action_names = None
|
||||
policy = _ChunkPolicyStub(action_dim=ACTION_DIM, n_action_steps=1)
|
||||
policy.config.action_feature_names = action_names
|
||||
dataset_features = {ACTION: {"names": action_names}}
|
||||
|
||||
assert relative_step.action_names is None
|
||||
|
||||
engine = SyncInferenceEngine(
|
||||
policy=policy,
|
||||
preprocessor=preprocessor,
|
||||
postprocessor=postprocessor,
|
||||
dataset_features=dataset_features,
|
||||
ordered_action_keys=action_names,
|
||||
task="test",
|
||||
device="cpu",
|
||||
robot_type="mock",
|
||||
)
|
||||
|
||||
state = np.arange(1, ACTION_DIM + 1, dtype=np.float32)
|
||||
action = engine.get_action({OBS_STATE: state.copy()})
|
||||
expected = torch.from_numpy(state.copy())
|
||||
expected[-1] = 0.0
|
||||
|
||||
torch.testing.assert_close(action, expected)
|
||||
assert relative_step.action_names == action_names
|
||||
|
||||
def test_sync_engine_does_not_remap_chunk_through_dataset_action_names(self):
|
||||
"""Postprocessed chunks are already in policy order; dataset feature order must not scramble them."""
|
||||
action_names = [f"joint_{i}.pos" for i in range(ACTION_DIM)]
|
||||
_, preprocessor, postprocessor = _make_relative_sync_pipelines(
|
||||
ACTION_DIM,
|
||||
action_names=action_names,
|
||||
)
|
||||
policy = _ChunkPolicyStub(action_dim=ACTION_DIM, n_action_steps=1)
|
||||
policy.config.action_feature_names = action_names
|
||||
|
||||
engine = SyncInferenceEngine(
|
||||
policy=policy,
|
||||
preprocessor=preprocessor,
|
||||
postprocessor=postprocessor,
|
||||
dataset_features={ACTION: {"names": list(reversed(action_names))}},
|
||||
ordered_action_keys=action_names,
|
||||
task="test",
|
||||
device="cpu",
|
||||
robot_type="mock",
|
||||
)
|
||||
|
||||
state = np.arange(1, ACTION_DIM + 1, dtype=np.float32)
|
||||
action = engine.get_action({OBS_STATE: state.copy()})
|
||||
|
||||
torch.testing.assert_close(action, torch.from_numpy(state))
|
||||
|
||||
def test_rtc_reanchoring_prefers_raw_cached_state(self):
|
||||
"""RTC re-anchoring must use the raw state cached before observation normalization."""
|
||||
action_dim = ACTION_DIM
|
||||
relative_step = RelativeActionsProcessorStep(
|
||||
enabled=True,
|
||||
action_names=[f"joint_{i}.pos" for i in range(action_dim)],
|
||||
)
|
||||
features = {
|
||||
OBS_STATE: PolicyFeature(type=FeatureType.STATE, shape=(action_dim,)),
|
||||
ACTION: PolicyFeature(type=FeatureType.ACTION, shape=(action_dim,)),
|
||||
}
|
||||
stats = {
|
||||
OBS_STATE: {
|
||||
"mean": np.arange(1, action_dim + 1, dtype=np.float32),
|
||||
"std": 2 * np.ones(action_dim, dtype=np.float32),
|
||||
},
|
||||
ACTION: {
|
||||
"mean": np.zeros(action_dim, dtype=np.float32),
|
||||
"std": np.ones(action_dim, dtype=np.float32),
|
||||
},
|
||||
}
|
||||
preprocessor = PolicyProcessorPipeline(
|
||||
steps=[
|
||||
relative_step,
|
||||
NormalizerProcessorStep(
|
||||
features=features,
|
||||
norm_map={
|
||||
FeatureType.STATE: NormalizationMode.MEAN_STD,
|
||||
FeatureType.ACTION: NormalizationMode.MEAN_STD,
|
||||
},
|
||||
stats=stats,
|
||||
),
|
||||
],
|
||||
name="test_preprocessor_with_state_norm",
|
||||
)
|
||||
|
||||
raw_state = torch.from_numpy(np.arange(5, 5 + action_dim, dtype=np.float32)).unsqueeze(0)
|
||||
preprocessed = preprocessor({OBS_STATE: raw_state.clone()})
|
||||
|
||||
current_state = get_current_raw_state(relative_step, preprocessed.get(OBS_STATE))
|
||||
|
||||
torch.testing.assert_close(current_state, raw_state)
|
||||
assert not torch.allclose(preprocessed[OBS_STATE], raw_state)
|
||||
|
||||
|
||||
class TestRTCDenoiseWithRelativeLeftovers:
|
||||
"""Verify RTC denoise_step correctly handles relative-space prev_chunk_left_over."""
|
||||
|
||||
|
||||
Reference in New Issue
Block a user