mirror of
https://github.com/huggingface/lerobot.git
synced 2026-05-23 04:30:10 +00:00
fix(rollout): preserve relative action chunks
This commit is contained in:
@@ -109,6 +109,21 @@ def _normalize_prev_actions_length(prev_actions: torch.Tensor, target_steps: int
|
||||
return padded
|
||||
|
||||
|
||||
def _get_current_raw_state(
    relative_step: RelativeActionsProcessorStep,
    fallback_state: torch.Tensor | None,
) -> torch.Tensor | None:
    """Pick the raw state used to re-anchor RTC leftovers.

    ``RelativeActionsProcessorStep`` keeps the observation state it saw
    before observation normalization in ``_last_state``. That cached value
    is preferred over the normalized observation consumed by the policy.

    Args:
        relative_step: Relative-action step whose ``_last_state`` cache is
            consulted.
        fallback_state: Value returned when the step has nothing cached.

    Returns:
        The cached state when it is not ``None``, otherwise
        ``fallback_state`` (which may itself be ``None``).
    """
    cached_state = relative_step._last_state
    return fallback_state if cached_state is None else cached_state
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# RTCInferenceEngine
|
||||
# ---------------------------------------------------------------------------
|
||||
@@ -318,7 +333,9 @@ class RTCInferenceEngine(InferenceEngine):
|
||||
preprocessed = self._preprocessor(obs_batch)
|
||||
|
||||
if prev_actions is not None and self._relative_step is not None:
|
||||
state_tensor = preprocessed.get(OBS_STATE)
|
||||
state_tensor = _get_current_raw_state(
|
||||
self._relative_step, obs_batch.get(OBS_STATE)
|
||||
)
|
||||
if state_tensor is not None:
|
||||
prev_abs = queue.get_processed_left_over()
|
||||
if prev_abs is not None and prev_abs.numel() > 0:
|
||||
|
||||
@@ -17,14 +17,16 @@
|
||||
from __future__ import annotations
|
||||
|
||||
import logging
|
||||
from collections import deque
|
||||
from contextlib import nullcontext
|
||||
from copy import copy
|
||||
|
||||
import torch
|
||||
|
||||
from lerobot.policies.pretrained import PreTrainedPolicy
|
||||
from lerobot.policies.utils import make_robot_action, prepare_observation_for_inference
|
||||
from lerobot.processor import PolicyProcessorPipeline
|
||||
from lerobot.policies.utils import prepare_observation_for_inference
|
||||
from lerobot.processor import PolicyProcessorPipeline, RelativeActionsProcessorStep
|
||||
from lerobot.utils.constants import ACTION
|
||||
|
||||
from .base import InferenceEngine
|
||||
|
||||
@@ -34,9 +36,9 @@ logger = logging.getLogger(__name__)
|
||||
class SyncInferenceEngine(InferenceEngine):
|
||||
"""Inline synchronous inference: compute one action per call.
|
||||
|
||||
``get_action`` runs the full policy pipeline (pre/post-processor +
|
||||
``select_action``) on the given observation frame and returns a
|
||||
CPU action tensor reordered to match the dataset action keys.
|
||||
``get_action`` runs the full policy pipeline when its local action
|
||||
queue is empty, postprocesses the whole predicted chunk immediately,
|
||||
and then returns one already-postprocessed CPU action at a time.
|
||||
"""
|
||||
|
||||
def __init__(
|
||||
@@ -58,6 +60,19 @@ class SyncInferenceEngine(InferenceEngine):
|
||||
self._task = task
|
||||
self._device = torch.device(device or "cpu")
|
||||
self._robot_type = robot_type
|
||||
self._processed_action_queue: deque[torch.Tensor] = deque()
|
||||
|
||||
self._relative_step = next(
|
||||
(s for s in preprocessor.steps if isinstance(s, RelativeActionsProcessorStep) and s.enabled),
|
||||
None,
|
||||
)
|
||||
if self._relative_step is not None and self._relative_step.action_names is None:
|
||||
cfg_names = getattr(policy.config, "action_feature_names", None)
|
||||
action_names = cfg_names or dataset_features.get(ACTION, {}).get("names")
|
||||
if action_names:
|
||||
self._relative_step.action_names = list(action_names)
|
||||
logger.info("Relative actions enabled: sync chunks will be postprocessed before queueing")
|
||||
|
||||
logger.info(
|
||||
"SyncInferenceEngine initialized (device=%s, action_keys=%d)",
|
||||
self._device,
|
||||
@@ -78,9 +93,29 @@ class SyncInferenceEngine(InferenceEngine):
|
||||
self._policy.reset()
|
||||
self._preprocessor.reset()
|
||||
self._postprocessor.reset()
|
||||
self._processed_action_queue.clear()
|
||||
|
||||
def _enqueue_processed_chunk(self, action_chunk: torch.Tensor) -> None:
    """Queue postprocessed per-step actions in policy output order.

    A 2-D ``action_chunk`` is treated as one unbatched chunk and receives a
    leading batch axis. After that the chunk must be 3-D with a batch size
    of exactly one. At most ``policy.config.n_action_steps`` leading steps
    are queued; the whole chunk is used when that attribute is missing or
    ``None``.

    Args:
        action_chunk: Postprocessed chunk indexed as (step, action) or
            (batch, step, action).

    Raises:
        ValueError: If the chunk is not 2-D/3-D, holds more than one batch
            element, or its per-step width differs from the number of
            ordered action keys.
    """
    if action_chunk.ndim == 2:
        action_chunk = action_chunk.unsqueeze(0)

    # Reject anything that is not a single batched chunk. ``squeeze(0)`` is a
    # silent no-op for batch > 1 and would make the loop iterate over batch
    # elements instead of steps, queueing 2-D tensors as "actions".
    if action_chunk.ndim != 3 or action_chunk.shape[0] != 1:
        raise ValueError(
            "Expected an action chunk shaped (steps, action_dim) or "
            f"(1, steps, action_dim), got {tuple(action_chunk.shape)}"
        )

    available_steps = action_chunk.shape[1]
    # The attribute may exist but be None; fall back to the full chunk then.
    n_action_steps = getattr(self._policy.config, "n_action_steps", None)
    if n_action_steps is None:
        n_action_steps = available_steps

    steps = action_chunk[0, : min(n_action_steps, available_steps)]
    if steps.shape[0] == 0:
        # Nothing to queue; the caller sees an empty queue.
        return

    # Every step has the same width, so validate once before queueing.
    expected_width = len(self._ordered_action_keys)
    if steps.shape[1] != expected_width:
        raise ValueError(
            f"Action tensor length ({steps.shape[1]}) != action keys "
            f"({expected_width})"
        )

    # One device transfer for the whole chunk, then split into per-step rows.
    self._processed_action_queue.extend(steps.detach().cpu().unbind(0))
|
||||
|
||||
def get_action(self, obs_frame: dict | None) -> torch.Tensor | None:
|
||||
"""Run the full inference pipeline on ``obs_frame`` and return an action tensor."""
|
||||
if self._processed_action_queue:
|
||||
return self._processed_action_queue.popleft().clone()
|
||||
if obs_frame is None:
|
||||
return None
|
||||
# Shallow copy is intentional: the caller (`send_next_action`) builds
|
||||
@@ -97,11 +132,10 @@ class SyncInferenceEngine(InferenceEngine):
|
||||
observation, self._device, self._task, self._robot_type
|
||||
)
|
||||
observation = self._preprocessor(observation)
|
||||
action = self._policy.select_action(observation)
|
||||
action = self._postprocessor(action)
|
||||
action_tensor = action.squeeze(0).cpu()
|
||||
action_chunk = self._policy.predict_action_chunk(observation)
|
||||
processed_chunk = self._postprocessor(action_chunk)
|
||||
|
||||
# Reorder to match dataset action ordering so the caller can treat
|
||||
# the returned tensor uniformly across backends.
|
||||
action_dict = make_robot_action(action_tensor, self._dataset_features)
|
||||
return torch.tensor([action_dict[k] for k in self._ordered_action_keys])
|
||||
self._enqueue_processed_chunk(processed_chunk)
|
||||
if not self._processed_action_queue:
|
||||
return None
|
||||
return self._processed_action_queue.popleft().clone()
|
||||
|
||||
Reference in New Issue
Block a user