From 8d1401abe31655b3801bad5eae9a0752490541ae Mon Sep 17 00:00:00 2001 From: Steven Palma Date: Wed, 22 Apr 2026 12:55:22 +0200 Subject: [PATCH] fix(rollout): device and task propagation + feature pos + warn fps + move rename_map config --- src/lerobot/configs/dataset.py | 4 +-- src/lerobot/rollout/configs.py | 28 +++++++++++++++++++++ src/lerobot/rollout/context.py | 19 ++++++++------ src/lerobot/rollout/strategies/base.py | 4 +++ src/lerobot/rollout/strategies/core.py | 1 + src/lerobot/rollout/strategies/dagger.py | 8 ++++++ src/lerobot/rollout/strategies/highlight.py | 4 +++ src/lerobot/rollout/strategies/sentry.py | 4 +++ 8 files changed, 62 insertions(+), 10 deletions(-) diff --git a/src/lerobot/configs/dataset.py b/src/lerobot/configs/dataset.py index d025a1682..e359aadc7 100644 --- a/src/lerobot/configs/dataset.py +++ b/src/lerobot/configs/dataset.py @@ -14,7 +14,7 @@ """Shared dataset recording configuration used by both ``lerobot-record`` and ``lerobot-rollout``.""" -from dataclasses import dataclass, field +from dataclasses import dataclass from datetime import datetime from pathlib import Path @@ -68,8 +68,6 @@ class DatasetRecordConfig: # Number of threads per encoder instance. None = auto (codec default). # Lower values reduce CPU usage, maps to 'lp' (via svtav1-params) for libsvtav1 and 'threads' for h264/hevc.. 
encoder_threads: int | None = None - # Rename map for the observation to override the image and state keys - rename_map: dict[str, str] = field(default_factory=dict) def __post_init__(self) -> None: if self.repo_id: diff --git a/src/lerobot/rollout/configs.py b/src/lerobot/rollout/configs.py index d1527dc08..f74baa781 100644 --- a/src/lerobot/rollout/configs.py +++ b/src/lerobot/rollout/configs.py @@ -26,6 +26,7 @@ from lerobot.configs import PreTrainedConfig, parser from lerobot.configs.dataset import DatasetRecordConfig from lerobot.robots.config import RobotConfig from lerobot.teleoperators.config import TeleoperatorConfig +from lerobot.utils.device_utils import auto_select_torch_device, is_torch_device_available from .inference import InferenceEngineConfig, SyncInferenceConfig @@ -205,6 +206,8 @@ class RolloutConfig: # Use vocal synthesis to read events play_sounds: bool = True resume: bool = False + # Rename map for mapping robot/dataset observation keys to policy keys + rename_map: dict[str, str] = field(default_factory=dict) # Torch compile use_torch_compile: bool = False @@ -285,6 +288,31 @@ class RolloutConfig: if self.policy is None: raise ValueError("--policy.path is required for rollout") + # --- Task resolution --- + # When any --dataset.* flag is passed, draccus + # creates a DatasetRecordConfig with single_task="". If the user set + # the task via the top-level --task flag, propagate it so that all + # downstream consumers (inference engine, dataset frame builders) see it. 
+ if self.dataset is not None and not self.dataset.single_task and self.task: + logger.info("Propagating top-level task '%s' to dataset config", self.task) + self.dataset.single_task = self.task + elif self.dataset is not None and self.dataset.single_task and not self.task: + logger.info("Propagating dataset single_task '%s' to top-level task", self.dataset.single_task) + self.task = self.dataset.single_task + + # --- Device resolution --- + # Resolve device from the policy config when not explicitly set so all + # components (policy.to, preprocessor, inference engine) use the same + # device string instead of inconsistent fallbacks. + if self.device is None or not is_torch_device_available(self.device): + resolved = self.policy.device + if resolved: + self.device = resolved + logger.info("Resolved device from policy config: %s", self.device) + else: + self.device = auto_select_torch_device().type + logger.info("No policy config to resolve device from; auto-selected device: %s", self.device) + @classmethod def __get_path_fields__(cls) -> list[str]: return ["policy"] diff --git a/src/lerobot/rollout/context.py b/src/lerobot/rollout/context.py index 273198d56..ee21be56e 100644 --- a/src/lerobot/rollout/context.py +++ b/src/lerobot/rollout/context.py @@ -272,11 +272,16 @@ def build_rollout_context( # ) # --- 4. Features + action-key reconciliation --------------------- + # TODO(Steven): Only `.pos` joint features are used for policy inference — velocity and + # torque channels are observation-only and must be excluded from the state + # and action tensors that the policy sees. 
all_obs_features = robot.observation_features observation_features_hw = { - k: v for k, v in all_obs_features.items() if v is float or isinstance(v, tuple) + k: v + for k, v in all_obs_features.items() + if isinstance(v, tuple) or (v is float and k.endswith(".pos")) } - action_features_hw = robot.action_features + action_features_hw = {k: v for k, v in robot.action_features.items() if k.endswith(".pos")} # The action side is always needed: sync inference reads action names from # ``dataset_features[ACTION]`` to map policy tensors back to robot actions. @@ -293,7 +298,7 @@ def build_rollout_context( ) dataset_features = combine_feature_dicts(action_dataset_features, observation_dataset_features) hw_features = hw_to_dataset_features(observation_features_hw, "observation") - raw_action_keys = list(robot.action_features.keys()) + raw_action_keys = list(action_features_hw.keys()) policy_action_names = getattr(policy_config, "action_feature_names", None) ordered_action_keys = _resolve_action_key_order( list(policy_action_names) if policy_action_names else None, @@ -301,7 +306,7 @@ def build_rollout_context( ) # Validate visual features if no rename_map is active - rename_map = cfg.dataset.rename_map if cfg.dataset else {} + rename_map = cfg.rename_map if not rename_map: expected_visuals = {k for k, v in full_config.input_features.items() if v.type == FeatureType.VISUAL} provided_visuals = { @@ -366,7 +371,7 @@ def build_rollout_context( if dataset is not None: dataset_stats = rename_stats( dataset.meta.stats, - cfg.dataset.rename_map if cfg.dataset else {}, + cfg.rename_map, ) preprocessor, postprocessor = make_pre_post_processors( @@ -374,8 +379,8 @@ def build_rollout_context( pretrained_path=cfg.policy.pretrained_path, dataset_stats=dataset_stats, preprocessor_overrides={ - "device_processor": {"device": cfg.device or getattr(policy_config, "device", "cpu")}, - "rename_observations_processor": {"rename_map": cfg.dataset.rename_map if cfg.dataset else {}}, + 
"device_processor": {"device": cfg.device}, + "rename_observations_processor": {"rename_map": cfg.rename_map}, }, ) diff --git a/src/lerobot/rollout/strategies/base.py b/src/lerobot/rollout/strategies/base.py index 6dca99b00..10871c424 100644 --- a/src/lerobot/rollout/strategies/base.py +++ b/src/lerobot/rollout/strategies/base.py @@ -72,6 +72,10 @@ class BaseStrategy(RolloutStrategy): dt = time.perf_counter() - loop_start if (sleep_t := control_interval - dt) > 0: precise_sleep(sleep_t) + else: + logging.warning( + f"Record loop is running slower ({1 / dt:.1f} Hz) than the target FPS ({cfg.fps} Hz). Dataset frames might be dropped and robot control might be unstable. Common causes are: 1) Camera FPS not keeping up 2) Policy inference taking too long 3) CPU starvation" + ) def teardown(self, ctx: RolloutContext) -> None: """Disconnect hardware and stop inference.""" diff --git a/src/lerobot/rollout/strategies/core.py b/src/lerobot/rollout/strategies/core.py index f0f146109..5c336ea3a 100644 --- a/src/lerobot/rollout/strategies/core.py +++ b/src/lerobot/rollout/strategies/core.py @@ -62,6 +62,7 @@ class RolloutStrategy(abc.ABC): self._interpolator = ActionInterpolator(multiplier=ctx.runtime.cfg.interpolation_multiplier) self._engine = ctx.policy.inference logger.info("Starting inference engine...") + self._engine.reset() self._engine.start() self._warmup_flushed = False logger.info("Inference engine started") diff --git a/src/lerobot/rollout/strategies/dagger.py b/src/lerobot/rollout/strategies/dagger.py index ba842844f..66fadee27 100644 --- a/src/lerobot/rollout/strategies/dagger.py +++ b/src/lerobot/rollout/strategies/dagger.py @@ -508,6 +508,10 @@ class DAggerStrategy(RolloutStrategy): dt = time.perf_counter() - loop_start if (sleep_t := control_interval - dt) > 0: precise_sleep(sleep_t) + else: + logging.warning( + f"Record loop is running slower ({1 / dt:.1f} Hz) than the target FPS ({cfg.fps} Hz). 
Dataset frames might be dropped and robot control might be unstable. Common causes are: 1) Camera FPS not keeping up 2) Policy inference taking too long 3) CPU starvation" + ) finally: logger.info("DAgger continuous control loop ended — pausing engine") @@ -648,6 +652,10 @@ class DAggerStrategy(RolloutStrategy): dt = time.perf_counter() - loop_start if (sleep_t := control_interval - dt) > 0: precise_sleep(sleep_t) + else: + logging.warning( + f"Record loop is running slower ({1 / dt:.1f} Hz) than the target FPS ({cfg.fps} Hz). Dataset frames might be dropped and robot control might be unstable. Common causes are: 1) Camera FPS not keeping up 2) Policy inference taking too long 3) CPU starvation" + ) finally: logger.info("DAgger corrections-only loop ended — pausing engine") diff --git a/src/lerobot/rollout/strategies/highlight.py b/src/lerobot/rollout/strategies/highlight.py index 05b32e225..1d3a7e55a 100644 --- a/src/lerobot/rollout/strategies/highlight.py +++ b/src/lerobot/rollout/strategies/highlight.py @@ -189,6 +189,10 @@ class HighlightStrategy(RolloutStrategy): dt = time.perf_counter() - loop_start if (sleep_t := control_interval - dt) > 0: precise_sleep(sleep_t) + else: + logging.warning( + f"Record loop is running slower ({1 / dt:.1f} Hz) than the target FPS ({cfg.fps} Hz). Dataset frames might be dropped and robot control might be unstable. 
Common causes are: 1) Camera FPS not keeping up 2) Policy inference taking too long 3) CPU starvation" + ) finally: logger.info("Highlight control loop ended") diff --git a/src/lerobot/rollout/strategies/sentry.py b/src/lerobot/rollout/strategies/sentry.py index 047aced25..3a56e8163 100644 --- a/src/lerobot/rollout/strategies/sentry.py +++ b/src/lerobot/rollout/strategies/sentry.py @@ -160,6 +160,10 @@ class SentryStrategy(RolloutStrategy): dt = time.perf_counter() - loop_start if (sleep_t := control_interval - dt) > 0: precise_sleep(sleep_t) + else: + logging.warning( + f"Record loop is running slower ({1 / dt:.1f} Hz) than the target FPS ({cfg.fps} Hz). Dataset frames might be dropped and robot control might be unstable. Common causes are: 1) Camera FPS not keeping up 2) Policy inference taking too long 3) CPU starvation" + ) finally: logger.info("Sentry control loop ended — saving final episode")