diff --git a/src/lerobot/rollout/inference/sync.py b/src/lerobot/rollout/inference/sync.py index f4e5596c9..ebbfba928 100644 --- a/src/lerobot/rollout/inference/sync.py +++ b/src/lerobot/rollout/inference/sync.py @@ -24,9 +24,8 @@ from copy import copy import torch from lerobot.policies.pretrained import PreTrainedPolicy -from lerobot.policies.utils import prepare_observation_for_inference -from lerobot.processor import PolicyProcessorPipeline, RelativeActionsProcessorStep -from lerobot.utils.constants import ACTION +from lerobot.policies.utils import make_robot_action, prepare_observation_for_inference +from lerobot.processor import PolicyProcessorPipeline from .base import InferenceEngine @@ -62,17 +61,6 @@ class SyncInferenceEngine(InferenceEngine): self._robot_type = robot_type self._processed_action_queue: deque[torch.Tensor] = deque() - self._relative_step = next( - (s for s in preprocessor.steps if isinstance(s, RelativeActionsProcessorStep) and s.enabled), - None, - ) - if self._relative_step is not None and self._relative_step.action_names is None: - cfg_names = getattr(policy.config, "action_feature_names", None) - action_names = cfg_names or dataset_features.get(ACTION, {}).get("names") - if action_names: - self._relative_step.action_names = list(action_names) - logger.info("Relative actions enabled: sync chunks will be postprocessed before queueing") - logger.info( "SyncInferenceEngine initialized (device=%s, action_keys=%d)", self._device, @@ -96,7 +84,7 @@ class SyncInferenceEngine(InferenceEngine): self._processed_action_queue.clear() def _enqueue_processed_chunk(self, action_chunk: torch.Tensor) -> None: - """Queue postprocessed per-step actions in policy output order.""" + """Convert a postprocessed action chunk into ordered per-step CPU tensors.""" if action_chunk.ndim == 2: action_chunk = action_chunk.unsqueeze(0) @@ -104,13 +92,12 @@ class SyncInferenceEngine(InferenceEngine): action_chunk = action_chunk[:, : min(n_action_steps, action_chunk.shape[1])] for action in action_chunk.squeeze(0): - action_tensor = action.detach().cpu() - if len(action_tensor) != len(self._ordered_action_keys): - raise ValueError( - f"Action tensor length ({len(action_tensor)}) != action keys " - f"({len(self._ordered_action_keys)})" - ) - self._processed_action_queue.append(action_tensor) + action_tensor = action.cpu() + action_dict = make_robot_action(action_tensor, self._dataset_features) + ordered_action = torch.tensor( + [action_dict[k] for k in self._ordered_action_keys], dtype=action_tensor.dtype + ) + self._processed_action_queue.append(ordered_action) def get_action(self, obs_frame: dict | None) -> torch.Tensor | None: """Run the full inference pipeline on ``obs_frame`` and return an action tensor.""" diff --git a/tests/policies/rtc/test_rtc_relative_actions.py b/tests/policies/rtc/test_rtc_relative_actions.py index 5fb712a57..ea62ec4ed 100644 --- a/tests/policies/rtc/test_rtc_relative_actions.py +++ b/tests/policies/rtc/test_rtc_relative_actions.py @@ -225,66 +225,6 @@ class TestRolloutInferenceRelativeActions: torch.testing.assert_close(action_2, torch.from_numpy(state_1)) assert policy.predict_calls == 1 - def test_sync_engine_restores_action_names_for_relative_exclusions(self): - """Serialized processors may omit action names; sync rollout must still honor gripper exclusions.""" - action_names = [f"joint_{i}.pos" for i in range(ACTION_DIM - 1)] + ["gripper.pos"] - relative_step, preprocessor, postprocessor = _make_relative_sync_pipelines( - ACTION_DIM, - action_names=action_names, - exclude_joints=["gripper"], - ) - relative_step.action_names = None - policy = _ChunkPolicyStub(action_dim=ACTION_DIM, n_action_steps=1) - policy.config.action_feature_names = action_names - dataset_features = {ACTION: {"names": action_names}} - - assert relative_step.action_names is None - - engine = SyncInferenceEngine( - policy=policy, - preprocessor=preprocessor, - postprocessor=postprocessor, - dataset_features=dataset_features, - ordered_action_keys=action_names, - task="test", - device="cpu", - robot_type="mock", - ) - - state = np.arange(1, ACTION_DIM + 1, dtype=np.float32) - action = engine.get_action({OBS_STATE: state.copy()}) - expected = torch.from_numpy(state.copy()) - expected[-1] = 0.0 - - torch.testing.assert_close(action, expected) - assert relative_step.action_names == action_names - - def test_sync_engine_does_not_remap_chunk_through_dataset_action_names(self): - """Postprocessed chunks are already in policy order; dataset feature order must not scramble them.""" - action_names = [f"joint_{i}.pos" for i in range(ACTION_DIM)] - _, preprocessor, postprocessor = _make_relative_sync_pipelines( - ACTION_DIM, - action_names=action_names, - ) - policy = _ChunkPolicyStub(action_dim=ACTION_DIM, n_action_steps=1) - policy.config.action_feature_names = action_names - - engine = SyncInferenceEngine( - policy=policy, - preprocessor=preprocessor, - postprocessor=postprocessor, - dataset_features={ACTION: {"names": list(reversed(action_names))}}, - ordered_action_keys=action_names, - task="test", - device="cpu", - robot_type="mock", - ) - - state = np.arange(1, ACTION_DIM + 1, dtype=np.float32) - action = engine.get_action({OBS_STATE: state.copy()}) - - torch.testing.assert_close(action, torch.from_numpy(state)) - def test_rtc_reanchoring_prefers_raw_cached_state(self): """RTC re-anchoring must use the raw state cached before observation normalization.""" action_dim = ACTION_DIM