fix(rollout): preserve relative action chunks

This commit is contained in:
Pepijn
2026-04-23 13:52:58 +02:00
parent eb9519eb91
commit e9f3f88377
3 changed files with 275 additions and 14 deletions
+18 -1
View File
@@ -109,6 +109,21 @@ def _normalize_prev_actions_length(prev_actions: torch.Tensor, target_steps: int
return padded
def _get_current_raw_state(
relative_step: RelativeActionsProcessorStep,
fallback_state: torch.Tensor | None,
) -> torch.Tensor | None:
"""Return the current raw state cached by the relative-action step.
``RelativeActionsProcessorStep`` caches the observation state before any
observation normalization. Re-anchoring RTC leftovers must use that raw
state rather than the normalized observation that the policy consumes.
"""
if relative_step._last_state is not None:
return relative_step._last_state
return fallback_state
# ---------------------------------------------------------------------------
# RTCInferenceEngine
# ---------------------------------------------------------------------------
@@ -318,7 +333,9 @@ class RTCInferenceEngine(InferenceEngine):
preprocessed = self._preprocessor(obs_batch)
if prev_actions is not None and self._relative_step is not None:
state_tensor = preprocessed.get(OBS_STATE)
state_tensor = _get_current_raw_state(
self._relative_step, obs_batch.get(OBS_STATE)
)
if state_tensor is not None:
prev_abs = queue.get_processed_left_over()
if prev_abs is not None and prev_abs.numel() > 0:
+46 -12
View File
@@ -17,14 +17,16 @@
from __future__ import annotations
import logging
from collections import deque
from contextlib import nullcontext
from copy import copy
import torch
from lerobot.policies.pretrained import PreTrainedPolicy
from lerobot.policies.utils import make_robot_action, prepare_observation_for_inference
from lerobot.processor import PolicyProcessorPipeline
from lerobot.policies.utils import prepare_observation_for_inference
from lerobot.processor import PolicyProcessorPipeline, RelativeActionsProcessorStep
from lerobot.utils.constants import ACTION
from .base import InferenceEngine
@@ -34,9 +36,9 @@ logger = logging.getLogger(__name__)
class SyncInferenceEngine(InferenceEngine):
"""Inline synchronous inference: compute one action per call.
``get_action`` runs the full policy pipeline (pre/post-processor +
``select_action``) on the given observation frame and returns a
CPU action tensor reordered to match the dataset action keys.
``get_action`` runs the full policy pipeline when its local action
queue is empty, postprocesses the whole predicted chunk immediately,
and then returns one already-postprocessed CPU action at a time.
"""
def __init__(
@@ -58,6 +60,19 @@ class SyncInferenceEngine(InferenceEngine):
self._task = task
self._device = torch.device(device or "cpu")
self._robot_type = robot_type
self._processed_action_queue: deque[torch.Tensor] = deque()
self._relative_step = next(
(s for s in preprocessor.steps if isinstance(s, RelativeActionsProcessorStep) and s.enabled),
None,
)
if self._relative_step is not None and self._relative_step.action_names is None:
cfg_names = getattr(policy.config, "action_feature_names", None)
action_names = cfg_names or dataset_features.get(ACTION, {}).get("names")
if action_names:
self._relative_step.action_names = list(action_names)
logger.info("Relative actions enabled: sync chunks will be postprocessed before queueing")
logger.info(
"SyncInferenceEngine initialized (device=%s, action_keys=%d)",
self._device,
@@ -78,9 +93,29 @@ class SyncInferenceEngine(InferenceEngine):
self._policy.reset()
self._preprocessor.reset()
self._postprocessor.reset()
self._processed_action_queue.clear()
def _enqueue_processed_chunk(self, action_chunk: torch.Tensor) -> None:
"""Queue postprocessed per-step actions in policy output order."""
if action_chunk.ndim == 2:
action_chunk = action_chunk.unsqueeze(0)
n_action_steps = getattr(self._policy.config, "n_action_steps", action_chunk.shape[1])
action_chunk = action_chunk[:, : min(n_action_steps, action_chunk.shape[1])]
for action in action_chunk.squeeze(0):
action_tensor = action.detach().cpu()
if len(action_tensor) != len(self._ordered_action_keys):
raise ValueError(
f"Action tensor length ({len(action_tensor)}) != action keys "
f"({len(self._ordered_action_keys)})"
)
self._processed_action_queue.append(action_tensor)
def get_action(self, obs_frame: dict | None) -> torch.Tensor | None:
"""Run the full inference pipeline on ``obs_frame`` and return an action tensor."""
if self._processed_action_queue:
return self._processed_action_queue.popleft().clone()
if obs_frame is None:
return None
# Shallow copy is intentional: the caller (`send_next_action`) builds
@@ -97,11 +132,10 @@ class SyncInferenceEngine(InferenceEngine):
observation, self._device, self._task, self._robot_type
)
observation = self._preprocessor(observation)
action = self._policy.select_action(observation)
action = self._postprocessor(action)
action_tensor = action.squeeze(0).cpu()
action_chunk = self._policy.predict_action_chunk(observation)
processed_chunk = self._postprocessor(action_chunk)
# Reorder to match dataset action ordering so the caller can treat
# the returned tensor uniformly across backends.
action_dict = make_robot_action(action_tensor, self._dataset_features)
return torch.tensor([action_dict[k] for k in self._ordered_action_keys])
self._enqueue_processed_chunk(processed_chunk)
if not self._processed_action_queue:
return None
return self._processed_action_queue.popleft().clone()