fix(rollout): device and task propagation + feature pos + warn fps + move rename_map config

This commit is contained in:
Steven Palma
2026-04-22 12:55:22 +02:00
parent 195b777367
commit 8d1401abe3
8 changed files with 62 additions and 10 deletions
+1 -3
View File
@@ -14,7 +14,7 @@
"""Shared dataset recording configuration used by both ``lerobot-record`` and ``lerobot-rollout``."""
from dataclasses import dataclass, field
from dataclasses import dataclass
from datetime import datetime
from pathlib import Path
@@ -68,8 +68,6 @@ class DatasetRecordConfig:
# Number of threads per encoder instance. None = auto (codec default).
# Lower values reduce CPU usage. Maps to 'lp' (via svtav1-params) for libsvtav1 and 'threads' for h264/hevc.
encoder_threads: int | None = None
# Rename map for the observation to override the image and state keys
rename_map: dict[str, str] = field(default_factory=dict)
def __post_init__(self) -> None:
if self.repo_id:
+28
View File
@@ -26,6 +26,7 @@ from lerobot.configs import PreTrainedConfig, parser
from lerobot.configs.dataset import DatasetRecordConfig
from lerobot.robots.config import RobotConfig
from lerobot.teleoperators.config import TeleoperatorConfig
from lerobot.utils.device_utils import auto_select_torch_device, is_torch_device_available
from .inference import InferenceEngineConfig, SyncInferenceConfig
@@ -205,6 +206,8 @@ class RolloutConfig:
# Use vocal synthesis to read events
play_sounds: bool = True
resume: bool = False
# Rename map for mapping robot/dataset observation keys to policy keys
rename_map: dict[str, str] = field(default_factory=dict)
# Torch compile
use_torch_compile: bool = False
@@ -285,6 +288,31 @@ class RolloutConfig:
if self.policy is None:
raise ValueError("--policy.path is required for rollout")
# --- Task resolution ---
# When any --dataset.* flag is passed, draccus creates a DatasetRecordConfig
# with single_task="". If the user set the task via the top-level --task
# flag, propagate it so that all downstream consumers (inference engine,
# dataset frame builders) see it. Conversely, if only the dataset's
# single_task is set, copy it to the top-level task so both stay in sync.
if self.dataset is not None and not self.dataset.single_task and self.task:
logger.info("Propagating top-level task '%s' to dataset config", self.task)
self.dataset.single_task = self.task
elif self.dataset is not None and self.dataset.single_task and not self.task:
logger.info("Propagating dataset single_task '%s' to top-level task", self.dataset.single_task)
self.task = self.dataset.single_task
# --- Device resolution ---
# Resolve a single device string so all components (policy.to, preprocessor,
# inference engine) use the same device instead of inconsistent fallbacks.
# When the device is unset or not available, fall back to the policy
# config's device; if that is empty too, auto-select one.
if self.device is None or not is_torch_device_available(self.device):
resolved = self.policy.device
if resolved:
self.device = resolved
logger.info("Resolved device from policy config: %s", self.device)
else:
self.device = auto_select_torch_device().type
logger.info("No policy config to resolve device from; auto-selected device: %s", self.device)
@classmethod
def __get_path_fields__(cls) -> list[str]:
return ["policy"]
+12 -7
View File
@@ -272,11 +272,16 @@ def build_rollout_context(
# )
# --- 4. Features + action-key reconciliation ---------------------
# TODO(Steven): Only `.pos` joint features are used for policy inference — velocity and
# torque channels are observation-only and must be excluded from the state
# and action tensors that the policy sees.
all_obs_features = robot.observation_features
observation_features_hw = {
k: v for k, v in all_obs_features.items() if v is float or isinstance(v, tuple)
k: v
for k, v in all_obs_features.items()
if isinstance(v, tuple) or (v is float and k.endswith(".pos"))
}
action_features_hw = robot.action_features
action_features_hw = {k: v for k, v in robot.action_features.items() if k.endswith(".pos")}
# The action side is always needed: sync inference reads action names from
# ``dataset_features[ACTION]`` to map policy tensors back to robot actions.
@@ -293,7 +298,7 @@ def build_rollout_context(
)
dataset_features = combine_feature_dicts(action_dataset_features, observation_dataset_features)
hw_features = hw_to_dataset_features(observation_features_hw, "observation")
raw_action_keys = list(robot.action_features.keys())
raw_action_keys = list(action_features_hw.keys())
policy_action_names = getattr(policy_config, "action_feature_names", None)
ordered_action_keys = _resolve_action_key_order(
list(policy_action_names) if policy_action_names else None,
@@ -301,7 +306,7 @@ def build_rollout_context(
)
# Validate visual features if no rename_map is active
rename_map = cfg.dataset.rename_map if cfg.dataset else {}
rename_map = cfg.rename_map
if not rename_map:
expected_visuals = {k for k, v in full_config.input_features.items() if v.type == FeatureType.VISUAL}
provided_visuals = {
@@ -366,7 +371,7 @@ def build_rollout_context(
if dataset is not None:
dataset_stats = rename_stats(
dataset.meta.stats,
cfg.dataset.rename_map if cfg.dataset else {},
cfg.rename_map,
)
preprocessor, postprocessor = make_pre_post_processors(
@@ -374,8 +379,8 @@ def build_rollout_context(
pretrained_path=cfg.policy.pretrained_path,
dataset_stats=dataset_stats,
preprocessor_overrides={
"device_processor": {"device": cfg.device or getattr(policy_config, "device", "cpu")},
"rename_observations_processor": {"rename_map": cfg.dataset.rename_map if cfg.dataset else {}},
"device_processor": {"device": cfg.device},
"rename_observations_processor": {"rename_map": cfg.rename_map},
},
)
+4
View File
@@ -72,6 +72,10 @@ class BaseStrategy(RolloutStrategy):
dt = time.perf_counter() - loop_start
if (sleep_t := control_interval - dt) > 0:
precise_sleep(sleep_t)
else:
logging.warning(
f"Record loop is running slower ({1 / dt:.1f} Hz) than the target FPS ({cfg.fps} Hz). Dataset frames might be dropped and robot control might be unstable. Common causes are: 1) Camera FPS not keeping up 2) Policy inference taking too long 3) CPU starvation"
)
def teardown(self, ctx: RolloutContext) -> None:
"""Disconnect hardware and stop inference."""
+1
View File
@@ -62,6 +62,7 @@ class RolloutStrategy(abc.ABC):
self._interpolator = ActionInterpolator(multiplier=ctx.runtime.cfg.interpolation_multiplier)
self._engine = ctx.policy.inference
logger.info("Starting inference engine...")
self._engine.reset()
self._engine.start()
self._warmup_flushed = False
logger.info("Inference engine started")
+8
View File
@@ -508,6 +508,10 @@ class DAggerStrategy(RolloutStrategy):
dt = time.perf_counter() - loop_start
if (sleep_t := control_interval - dt) > 0:
precise_sleep(sleep_t)
else:
logging.warning(
f"Record loop is running slower ({1 / dt:.1f} Hz) than the target FPS ({cfg.fps} Hz). Dataset frames might be dropped and robot control might be unstable. Common causes are: 1) Camera FPS not keeping up 2) Policy inference taking too long 3) CPU starvation"
)
finally:
logger.info("DAgger continuous control loop ended — pausing engine")
@@ -648,6 +652,10 @@ class DAggerStrategy(RolloutStrategy):
dt = time.perf_counter() - loop_start
if (sleep_t := control_interval - dt) > 0:
precise_sleep(sleep_t)
else:
logging.warning(
f"Record loop is running slower ({1 / dt:.1f} Hz) than the target FPS ({cfg.fps} Hz). Dataset frames might be dropped and robot control might be unstable. Common causes are: 1) Camera FPS not keeping up 2) Policy inference taking too long 3) CPU starvation"
)
finally:
logger.info("DAgger corrections-only loop ended — pausing engine")
@@ -189,6 +189,10 @@ class HighlightStrategy(RolloutStrategy):
dt = time.perf_counter() - loop_start
if (sleep_t := control_interval - dt) > 0:
precise_sleep(sleep_t)
else:
logging.warning(
f"Record loop is running slower ({1 / dt:.1f} Hz) than the target FPS ({cfg.fps} Hz). Dataset frames might be dropped and robot control might be unstable. Common causes are: 1) Camera FPS not keeping up 2) Policy inference taking too long 3) CPU starvation"
)
finally:
logger.info("Highlight control loop ended")
+4
View File
@@ -160,6 +160,10 @@ class SentryStrategy(RolloutStrategy):
dt = time.perf_counter() - loop_start
if (sleep_t := control_interval - dt) > 0:
precise_sleep(sleep_t)
else:
logging.warning(
f"Record loop is running slower ({1 / dt:.1f} Hz) than the target FPS ({cfg.fps} Hz). Dataset frames might be dropped and robot control might be unstable. Common causes are: 1) Camera FPS not keeping up 2) Policy inference taking too long 3) CPU starvation"
)
finally:
logger.info("Sentry control loop ended — saving final episode")