chore(rollout): remove speculative action order fix

This commit is contained in:
Pepijn
2026-04-24 16:24:44 +02:00
parent ee737b72d0
commit 727ca1a92c
2 changed files with 9 additions and 82 deletions
+9 -22
View File
@@ -24,9 +24,8 @@ from copy import copy
import torch import torch
from lerobot.policies.pretrained import PreTrainedPolicy from lerobot.policies.pretrained import PreTrainedPolicy
from lerobot.policies.utils import prepare_observation_for_inference from lerobot.policies.utils import make_robot_action, prepare_observation_for_inference
from lerobot.processor import PolicyProcessorPipeline, RelativeActionsProcessorStep from lerobot.processor import PolicyProcessorPipeline
from lerobot.utils.constants import ACTION
from .base import InferenceEngine from .base import InferenceEngine
@@ -62,17 +61,6 @@ class SyncInferenceEngine(InferenceEngine):
self._robot_type = robot_type self._robot_type = robot_type
self._processed_action_queue: deque[torch.Tensor] = deque() self._processed_action_queue: deque[torch.Tensor] = deque()
self._relative_step = next(
(s for s in preprocessor.steps if isinstance(s, RelativeActionsProcessorStep) and s.enabled),
None,
)
if self._relative_step is not None and self._relative_step.action_names is None:
cfg_names = getattr(policy.config, "action_feature_names", None)
action_names = cfg_names or dataset_features.get(ACTION, {}).get("names")
if action_names:
self._relative_step.action_names = list(action_names)
logger.info("Relative actions enabled: sync chunks will be postprocessed before queueing")
logger.info( logger.info(
"SyncInferenceEngine initialized (device=%s, action_keys=%d)", "SyncInferenceEngine initialized (device=%s, action_keys=%d)",
self._device, self._device,
@@ -96,7 +84,7 @@ class SyncInferenceEngine(InferenceEngine):
self._processed_action_queue.clear() self._processed_action_queue.clear()
def _enqueue_processed_chunk(self, action_chunk: torch.Tensor) -> None: def _enqueue_processed_chunk(self, action_chunk: torch.Tensor) -> None:
"""Queue postprocessed per-step actions in policy output order.""" """Convert a postprocessed action chunk into ordered per-step CPU tensors."""
if action_chunk.ndim == 2: if action_chunk.ndim == 2:
action_chunk = action_chunk.unsqueeze(0) action_chunk = action_chunk.unsqueeze(0)
@@ -104,13 +92,12 @@ class SyncInferenceEngine(InferenceEngine):
action_chunk = action_chunk[:, : min(n_action_steps, action_chunk.shape[1])] action_chunk = action_chunk[:, : min(n_action_steps, action_chunk.shape[1])]
for action in action_chunk.squeeze(0): for action in action_chunk.squeeze(0):
action_tensor = action.detach().cpu() action_tensor = action.cpu()
if len(action_tensor) != len(self._ordered_action_keys): action_dict = make_robot_action(action_tensor, self._dataset_features)
raise ValueError( ordered_action = torch.tensor(
f"Action tensor length ({len(action_tensor)}) != action keys " [action_dict[k] for k in self._ordered_action_keys], dtype=action_tensor.dtype
f"({len(self._ordered_action_keys)})" )
) self._processed_action_queue.append(ordered_action)
self._processed_action_queue.append(action_tensor)
def get_action(self, obs_frame: dict | None) -> torch.Tensor | None: def get_action(self, obs_frame: dict | None) -> torch.Tensor | None:
"""Run the full inference pipeline on ``obs_frame`` and return an action tensor.""" """Run the full inference pipeline on ``obs_frame`` and return an action tensor."""
@@ -225,66 +225,6 @@ class TestRolloutInferenceRelativeActions:
torch.testing.assert_close(action_2, torch.from_numpy(state_1)) torch.testing.assert_close(action_2, torch.from_numpy(state_1))
assert policy.predict_calls == 1 assert policy.predict_calls == 1
def test_sync_engine_restores_action_names_for_relative_exclusions(self):
    """Check that the sync engine repopulates missing action names on the relative step.

    The relative-actions step is built with a ``gripper`` exclusion and then has
    its ``action_names`` cleared, mimicking a serialized processor that omitted
    them. After one ``get_action`` call the test requires that:

    * every entry except the last equals the observed state, and the last
      (gripper) entry is ``0.0``;
    * ``relative_step.action_names`` equals the names supplied via the policy
      config / dataset features.
    """
    names = [*(f"joint_{idx}.pos" for idx in range(ACTION_DIM - 1)), "gripper.pos"]
    pipelines = _make_relative_sync_pipelines(
        ACTION_DIM,
        action_names=names,
        exclude_joints=["gripper"],
    )
    relative_step, pre, post = pipelines

    # Drop the names so the engine is the only thing that can restore them.
    relative_step.action_names = None

    stub_policy = _ChunkPolicyStub(action_dim=ACTION_DIM, n_action_steps=1)
    stub_policy.config.action_feature_names = names
    features = {ACTION: {"names": names}}

    # Precondition: names are still missing right before the engine is built.
    assert relative_step.action_names is None

    sync_engine = SyncInferenceEngine(
        policy=stub_policy,
        preprocessor=pre,
        postprocessor=post,
        dataset_features=features,
        ordered_action_keys=names,
        task="test",
        device="cpu",
        robot_type="mock",
    )

    observed_state = np.arange(1, ACTION_DIM + 1, dtype=np.float32)
    result = sync_engine.get_action({OBS_STATE: observed_state.copy()})

    # Expected output: the observed state with the trailing gripper slot zeroed.
    want = torch.from_numpy(observed_state.copy())
    want[-1] = 0.0

    torch.testing.assert_close(result, want)
    assert relative_step.action_names == names
def test_sync_engine_does_not_remap_chunk_through_dataset_action_names(self):
    """Check that dataset feature name order does not reorder the returned action.

    The engine receives ``dataset_features`` whose action names are the reverse
    of ``ordered_action_keys``. The action returned by ``get_action`` must still
    equal the observed state element-for-element, i.e. not be permuted.
    """
    joint_names = [f"joint_{idx}.pos" for idx in range(ACTION_DIM)]
    _, pre, post = _make_relative_sync_pipelines(
        ACTION_DIM,
        action_names=joint_names,
    )

    stub_policy = _ChunkPolicyStub(action_dim=ACTION_DIM, n_action_steps=1)
    stub_policy.config.action_feature_names = joint_names

    # Deliberately hand the engine the names in the opposite order.
    reversed_features = {ACTION: {"names": list(reversed(joint_names))}}

    sync_engine = SyncInferenceEngine(
        policy=stub_policy,
        preprocessor=pre,
        postprocessor=post,
        dataset_features=reversed_features,
        ordered_action_keys=joint_names,
        task="test",
        device="cpu",
        robot_type="mock",
    )

    observed_state = np.arange(1, ACTION_DIM + 1, dtype=np.float32)
    result = sync_engine.get_action({OBS_STATE: observed_state.copy()})

    torch.testing.assert_close(result, torch.from_numpy(observed_state))
def test_rtc_reanchoring_prefers_raw_cached_state(self): def test_rtc_reanchoring_prefers_raw_cached_state(self):
"""RTC re-anchoring must use the raw state cached before observation normalization.""" """RTC re-anchoring must use the raw state cached before observation normalization."""
action_dim = ACTION_DIM action_dim = ACTION_DIM