fix(rollout): device and task propagation + feature pos + warn fps + move rename_map config

This commit is contained in:
Steven Palma
2026-04-22 12:55:22 +02:00
parent 195b777367
commit 8d1401abe3
8 changed files with 62 additions and 10 deletions
+1 -3
View File
@@ -14,7 +14,7 @@
"""Shared dataset recording configuration used by both ``lerobot-record`` and ``lerobot-rollout``.""" """Shared dataset recording configuration used by both ``lerobot-record`` and ``lerobot-rollout``."""
from dataclasses import dataclass, field from dataclasses import dataclass
from datetime import datetime from datetime import datetime
from pathlib import Path from pathlib import Path
@@ -68,8 +68,6 @@ class DatasetRecordConfig:
# Number of threads per encoder instance. None = auto (codec default). # Number of threads per encoder instance. None = auto (codec default).
# Lower values reduce CPU usage, maps to 'lp' (via svtav1-params) for libsvtav1 and 'threads' for h264/hevc.. # Lower values reduce CPU usage, maps to 'lp' (via svtav1-params) for libsvtav1 and 'threads' for h264/hevc..
encoder_threads: int | None = None encoder_threads: int | None = None
# Rename map for the observation to override the image and state keys
rename_map: dict[str, str] = field(default_factory=dict)
def __post_init__(self) -> None: def __post_init__(self) -> None:
if self.repo_id: if self.repo_id:
+28
View File
@@ -26,6 +26,7 @@ from lerobot.configs import PreTrainedConfig, parser
from lerobot.configs.dataset import DatasetRecordConfig from lerobot.configs.dataset import DatasetRecordConfig
from lerobot.robots.config import RobotConfig from lerobot.robots.config import RobotConfig
from lerobot.teleoperators.config import TeleoperatorConfig from lerobot.teleoperators.config import TeleoperatorConfig
from lerobot.utils.device_utils import auto_select_torch_device, is_torch_device_available
from .inference import InferenceEngineConfig, SyncInferenceConfig from .inference import InferenceEngineConfig, SyncInferenceConfig
@@ -205,6 +206,8 @@ class RolloutConfig:
# Use vocal synthesis to read events # Use vocal synthesis to read events
play_sounds: bool = True play_sounds: bool = True
resume: bool = False resume: bool = False
# Rename map for mapping robot/dataset observation keys to policy keys
rename_map: dict[str, str] = field(default_factory=dict)
# Torch compile # Torch compile
use_torch_compile: bool = False use_torch_compile: bool = False
@@ -285,6 +288,31 @@ class RolloutConfig:
if self.policy is None: if self.policy is None:
raise ValueError("--policy.path is required for rollout") raise ValueError("--policy.path is required for rollout")
# --- Task resolution ---
# When --dataset.rename_map (or any --dataset.* flag) is passed, draccus
# creates a DatasetRecordConfig with single_task="". If the user set
# the task via the top-level --task flag, propagate it so that all
# downstream consumers (inference engine, dataset frame builders) see it.
if self.dataset is not None and not self.dataset.single_task and self.task:
logger.info("Propagating top-level task '%s' to dataset config", self.task)
self.dataset.single_task = self.task
elif self.dataset is not None and self.dataset.single_task and not self.task:
logger.info("Propagating dataset single_task '%s' to top-level task", self.dataset.single_task)
self.task = self.dataset.single_task
# --- Device resolution ---
# Resolve device from the policy config when not explicitly set so all
# components (policy.to, preprocessor, inference engine) use the same
# device string instead of inconsistent fallbacks.
if self.device is None or not is_torch_device_available(self.device):
resolved = self.policy.device
if resolved:
self.device = resolved
logger.info("Resolved device from policy config: %s", self.device)
else:
self.device = auto_select_torch_device().type
logger.info("No policy config to resolve device from; auto-selected device: %s", self.device)
@classmethod @classmethod
def __get_path_fields__(cls) -> list[str]: def __get_path_fields__(cls) -> list[str]:
return ["policy"] return ["policy"]
+12 -7
View File
@@ -272,11 +272,16 @@ def build_rollout_context(
# ) # )
# --- 4. Features + action-key reconciliation --------------------- # --- 4. Features + action-key reconciliation ---------------------
# TODO(Steven): Only `.pos` joint features are used for policy inference — velocity and
# torque channels are observation-only and must be excluded from the state
# and action tensors that the policy sees.
all_obs_features = robot.observation_features all_obs_features = robot.observation_features
observation_features_hw = { observation_features_hw = {
k: v for k, v in all_obs_features.items() if v is float or isinstance(v, tuple) k: v
for k, v in all_obs_features.items()
if isinstance(v, tuple) or (v is float and k.endswith(".pos"))
} }
action_features_hw = robot.action_features action_features_hw = {k: v for k, v in robot.action_features.items() if k.endswith(".pos")}
# The action side is always needed: sync inference reads action names from # The action side is always needed: sync inference reads action names from
# ``dataset_features[ACTION]`` to map policy tensors back to robot actions. # ``dataset_features[ACTION]`` to map policy tensors back to robot actions.
@@ -293,7 +298,7 @@ def build_rollout_context(
) )
dataset_features = combine_feature_dicts(action_dataset_features, observation_dataset_features) dataset_features = combine_feature_dicts(action_dataset_features, observation_dataset_features)
hw_features = hw_to_dataset_features(observation_features_hw, "observation") hw_features = hw_to_dataset_features(observation_features_hw, "observation")
raw_action_keys = list(robot.action_features.keys()) raw_action_keys = list(action_features_hw.keys())
policy_action_names = getattr(policy_config, "action_feature_names", None) policy_action_names = getattr(policy_config, "action_feature_names", None)
ordered_action_keys = _resolve_action_key_order( ordered_action_keys = _resolve_action_key_order(
list(policy_action_names) if policy_action_names else None, list(policy_action_names) if policy_action_names else None,
@@ -301,7 +306,7 @@ def build_rollout_context(
) )
# Validate visual features if no rename_map is active # Validate visual features if no rename_map is active
rename_map = cfg.dataset.rename_map if cfg.dataset else {} rename_map = cfg.rename_map
if not rename_map: if not rename_map:
expected_visuals = {k for k, v in full_config.input_features.items() if v.type == FeatureType.VISUAL} expected_visuals = {k for k, v in full_config.input_features.items() if v.type == FeatureType.VISUAL}
provided_visuals = { provided_visuals = {
@@ -366,7 +371,7 @@ def build_rollout_context(
if dataset is not None: if dataset is not None:
dataset_stats = rename_stats( dataset_stats = rename_stats(
dataset.meta.stats, dataset.meta.stats,
cfg.dataset.rename_map if cfg.dataset else {}, cfg.rename_map,
) )
preprocessor, postprocessor = make_pre_post_processors( preprocessor, postprocessor = make_pre_post_processors(
@@ -374,8 +379,8 @@ def build_rollout_context(
pretrained_path=cfg.policy.pretrained_path, pretrained_path=cfg.policy.pretrained_path,
dataset_stats=dataset_stats, dataset_stats=dataset_stats,
preprocessor_overrides={ preprocessor_overrides={
"device_processor": {"device": cfg.device or getattr(policy_config, "device", "cpu")}, "device_processor": {"device": cfg.device},
"rename_observations_processor": {"rename_map": cfg.dataset.rename_map if cfg.dataset else {}}, "rename_observations_processor": {"rename_map": cfg.rename_map},
}, },
) )
+4
View File
@@ -72,6 +72,10 @@ class BaseStrategy(RolloutStrategy):
dt = time.perf_counter() - loop_start dt = time.perf_counter() - loop_start
if (sleep_t := control_interval - dt) > 0: if (sleep_t := control_interval - dt) > 0:
precise_sleep(sleep_t) precise_sleep(sleep_t)
else:
logging.warning(
f"Record loop is running slower ({1 / dt:.1f} Hz) than the target FPS ({cfg.fps} Hz). Dataset frames might be dropped and robot control might be unstable. Common causes are: 1) Camera FPS not keeping up 2) Policy inference taking too long 3) CPU starvation"
)
def teardown(self, ctx: RolloutContext) -> None: def teardown(self, ctx: RolloutContext) -> None:
"""Disconnect hardware and stop inference.""" """Disconnect hardware and stop inference."""
+1
View File
@@ -62,6 +62,7 @@ class RolloutStrategy(abc.ABC):
self._interpolator = ActionInterpolator(multiplier=ctx.runtime.cfg.interpolation_multiplier) self._interpolator = ActionInterpolator(multiplier=ctx.runtime.cfg.interpolation_multiplier)
self._engine = ctx.policy.inference self._engine = ctx.policy.inference
logger.info("Starting inference engine...") logger.info("Starting inference engine...")
self._engine.reset()
self._engine.start() self._engine.start()
self._warmup_flushed = False self._warmup_flushed = False
logger.info("Inference engine started") logger.info("Inference engine started")
+8
View File
@@ -508,6 +508,10 @@ class DAggerStrategy(RolloutStrategy):
dt = time.perf_counter() - loop_start dt = time.perf_counter() - loop_start
if (sleep_t := control_interval - dt) > 0: if (sleep_t := control_interval - dt) > 0:
precise_sleep(sleep_t) precise_sleep(sleep_t)
else:
logging.warning(
f"Record loop is running slower ({1 / dt:.1f} Hz) than the target FPS ({cfg.fps} Hz). Dataset frames might be dropped and robot control might be unstable. Common causes are: 1) Camera FPS not keeping up 2) Policy inference taking too long 3) CPU starvation"
)
finally: finally:
logger.info("DAgger continuous control loop ended — pausing engine") logger.info("DAgger continuous control loop ended — pausing engine")
@@ -648,6 +652,10 @@ class DAggerStrategy(RolloutStrategy):
dt = time.perf_counter() - loop_start dt = time.perf_counter() - loop_start
if (sleep_t := control_interval - dt) > 0: if (sleep_t := control_interval - dt) > 0:
precise_sleep(sleep_t) precise_sleep(sleep_t)
else:
logging.warning(
f"Record loop is running slower ({1 / dt:.1f} Hz) than the target FPS ({cfg.fps} Hz). Dataset frames might be dropped and robot control might be unstable. Common causes are: 1) Camera FPS not keeping up 2) Policy inference taking too long 3) CPU starvation"
)
finally: finally:
logger.info("DAgger corrections-only loop ended — pausing engine") logger.info("DAgger corrections-only loop ended — pausing engine")
@@ -189,6 +189,10 @@ class HighlightStrategy(RolloutStrategy):
dt = time.perf_counter() - loop_start dt = time.perf_counter() - loop_start
if (sleep_t := control_interval - dt) > 0: if (sleep_t := control_interval - dt) > 0:
precise_sleep(sleep_t) precise_sleep(sleep_t)
else:
logging.warning(
f"Record loop is running slower ({1 / dt:.1f} Hz) than the target FPS ({cfg.fps} Hz). Dataset frames might be dropped and robot control might be unstable. Common causes are: 1) Camera FPS not keeping up 2) Policy inference taking too long 3) CPU starvation"
)
finally: finally:
logger.info("Highlight control loop ended") logger.info("Highlight control loop ended")
+4
View File
@@ -160,6 +160,10 @@ class SentryStrategy(RolloutStrategy):
dt = time.perf_counter() - loop_start dt = time.perf_counter() - loop_start
if (sleep_t := control_interval - dt) > 0: if (sleep_t := control_interval - dt) > 0:
precise_sleep(sleep_t) precise_sleep(sleep_t)
else:
logging.warning(
f"Record loop is running slower ({1 / dt:.1f} Hz) than the target FPS ({cfg.fps} Hz). Dataset frames might be dropped and robot control might be unstable. Common causes are: 1) Camera FPS not keeping up 2) Policy inference taking too long 3) CPU starvation"
)
finally: finally:
logger.info("Sentry control loop ended — saving final episode") logger.info("Sentry control loop ended — saving final episode")