mirror of
https://github.com/huggingface/lerobot.git
synced 2026-05-23 04:30:10 +00:00
fix(rollout): device and task propagation + feature pos + warn fps + move rename_map config
This commit is contained in:
@@ -14,7 +14,7 @@
|
||||
|
||||
"""Shared dataset recording configuration used by both ``lerobot-record`` and ``lerobot-rollout``."""
|
||||
|
||||
from dataclasses import dataclass, field
|
||||
from dataclasses import dataclass
|
||||
from datetime import datetime
|
||||
from pathlib import Path
|
||||
|
||||
@@ -68,8 +68,6 @@ class DatasetRecordConfig:
|
||||
# Number of threads per encoder instance. None = auto (codec default).
|
||||
# Lower values reduce CPU usage, maps to 'lp' (via svtav1-params) for libsvtav1 and 'threads' for h264/hevc..
|
||||
encoder_threads: int | None = None
|
||||
# Rename map for the observation to override the image and state keys
|
||||
rename_map: dict[str, str] = field(default_factory=dict)
|
||||
|
||||
def __post_init__(self) -> None:
|
||||
if self.repo_id:
|
||||
|
||||
@@ -26,6 +26,7 @@ from lerobot.configs import PreTrainedConfig, parser
|
||||
from lerobot.configs.dataset import DatasetRecordConfig
|
||||
from lerobot.robots.config import RobotConfig
|
||||
from lerobot.teleoperators.config import TeleoperatorConfig
|
||||
from lerobot.utils.device_utils import auto_select_torch_device, is_torch_device_available
|
||||
|
||||
from .inference import InferenceEngineConfig, SyncInferenceConfig
|
||||
|
||||
@@ -205,6 +206,8 @@ class RolloutConfig:
|
||||
# Use vocal synthesis to read events
|
||||
play_sounds: bool = True
|
||||
resume: bool = False
|
||||
# Rename map for mapping robot/dataset observation keys to policy keys
|
||||
rename_map: dict[str, str] = field(default_factory=dict)
|
||||
|
||||
# Torch compile
|
||||
use_torch_compile: bool = False
|
||||
@@ -285,6 +288,31 @@ class RolloutConfig:
|
||||
if self.policy is None:
|
||||
raise ValueError("--policy.path is required for rollout")
|
||||
|
||||
# --- Task resolution ---
|
||||
# When --dataset.rename_map (or any --dataset.* flag) is passed, draccus
|
||||
# creates a DatasetRecordConfig with single_task="". If the user set
|
||||
# the task via the top-level --task flag, propagate it so that all
|
||||
# downstream consumers (inference engine, dataset frame builders) see it.
|
||||
if self.dataset is not None and not self.dataset.single_task and self.task:
|
||||
logger.info("Propagating top-level task '%s' to dataset config", self.task)
|
||||
self.dataset.single_task = self.task
|
||||
elif self.dataset is not None and self.dataset.single_task and not self.task:
|
||||
logger.info("Propagating dataset single_task '%s' to top-level task", self.dataset.single_task)
|
||||
self.task = self.dataset.single_task
|
||||
|
||||
# --- Device resolution ---
|
||||
# Resolve device from the policy config when not explicitly set so all
|
||||
# components (policy.to, preprocessor, inference engine) use the same
|
||||
# device string instead of inconsistent fallbacks.
|
||||
if self.device is None or not is_torch_device_available(self.device):
|
||||
resolved = self.policy.device
|
||||
if resolved:
|
||||
self.device = resolved
|
||||
logger.info("Resolved device from policy config: %s", self.device)
|
||||
else:
|
||||
self.device = auto_select_torch_device().type
|
||||
logger.info("No policy config to resolve device from; auto-selected device: %s", self.device)
|
||||
|
||||
@classmethod
|
||||
def __get_path_fields__(cls) -> list[str]:
|
||||
return ["policy"]
|
||||
|
||||
@@ -272,11 +272,16 @@ def build_rollout_context(
|
||||
# )
|
||||
|
||||
# --- 4. Features + action-key reconciliation ---------------------
|
||||
# TODO(Steven): Only `.pos` joint features are used for policy inference — velocity and
|
||||
# torque channels are observation-only and must be excluded from the state
|
||||
# and action tensors that the policy sees.
|
||||
all_obs_features = robot.observation_features
|
||||
observation_features_hw = {
|
||||
k: v for k, v in all_obs_features.items() if v is float or isinstance(v, tuple)
|
||||
k: v
|
||||
for k, v in all_obs_features.items()
|
||||
if isinstance(v, tuple) or (v is float and k.endswith(".pos"))
|
||||
}
|
||||
action_features_hw = robot.action_features
|
||||
action_features_hw = {k: v for k, v in robot.action_features.items() if k.endswith(".pos")}
|
||||
|
||||
# The action side is always needed: sync inference reads action names from
|
||||
# ``dataset_features[ACTION]`` to map policy tensors back to robot actions.
|
||||
@@ -293,7 +298,7 @@ def build_rollout_context(
|
||||
)
|
||||
dataset_features = combine_feature_dicts(action_dataset_features, observation_dataset_features)
|
||||
hw_features = hw_to_dataset_features(observation_features_hw, "observation")
|
||||
raw_action_keys = list(robot.action_features.keys())
|
||||
raw_action_keys = list(action_features_hw.keys())
|
||||
policy_action_names = getattr(policy_config, "action_feature_names", None)
|
||||
ordered_action_keys = _resolve_action_key_order(
|
||||
list(policy_action_names) if policy_action_names else None,
|
||||
@@ -301,7 +306,7 @@ def build_rollout_context(
|
||||
)
|
||||
|
||||
# Validate visual features if no rename_map is active
|
||||
rename_map = cfg.dataset.rename_map if cfg.dataset else {}
|
||||
rename_map = cfg.rename_map
|
||||
if not rename_map:
|
||||
expected_visuals = {k for k, v in full_config.input_features.items() if v.type == FeatureType.VISUAL}
|
||||
provided_visuals = {
|
||||
@@ -366,7 +371,7 @@ def build_rollout_context(
|
||||
if dataset is not None:
|
||||
dataset_stats = rename_stats(
|
||||
dataset.meta.stats,
|
||||
cfg.dataset.rename_map if cfg.dataset else {},
|
||||
cfg.rename_map,
|
||||
)
|
||||
|
||||
preprocessor, postprocessor = make_pre_post_processors(
|
||||
@@ -374,8 +379,8 @@ def build_rollout_context(
|
||||
pretrained_path=cfg.policy.pretrained_path,
|
||||
dataset_stats=dataset_stats,
|
||||
preprocessor_overrides={
|
||||
"device_processor": {"device": cfg.device or getattr(policy_config, "device", "cpu")},
|
||||
"rename_observations_processor": {"rename_map": cfg.dataset.rename_map if cfg.dataset else {}},
|
||||
"device_processor": {"device": cfg.device},
|
||||
"rename_observations_processor": {"rename_map": cfg.rename_map},
|
||||
},
|
||||
)
|
||||
|
||||
|
||||
@@ -72,6 +72,10 @@ class BaseStrategy(RolloutStrategy):
|
||||
dt = time.perf_counter() - loop_start
|
||||
if (sleep_t := control_interval - dt) > 0:
|
||||
precise_sleep(sleep_t)
|
||||
else:
|
||||
logging.warning(
|
||||
f"Record loop is running slower ({1 / dt:.1f} Hz) than the target FPS ({cfg.fps} Hz). Dataset frames might be dropped and robot control might be unstable. Common causes are: 1) Camera FPS not keeping up 2) Policy inference taking too long 3) CPU starvation"
|
||||
)
|
||||
|
||||
def teardown(self, ctx: RolloutContext) -> None:
|
||||
"""Disconnect hardware and stop inference."""
|
||||
|
||||
@@ -62,6 +62,7 @@ class RolloutStrategy(abc.ABC):
|
||||
self._interpolator = ActionInterpolator(multiplier=ctx.runtime.cfg.interpolation_multiplier)
|
||||
self._engine = ctx.policy.inference
|
||||
logger.info("Starting inference engine...")
|
||||
self._engine.reset()
|
||||
self._engine.start()
|
||||
self._warmup_flushed = False
|
||||
logger.info("Inference engine started")
|
||||
|
||||
@@ -508,6 +508,10 @@ class DAggerStrategy(RolloutStrategy):
|
||||
dt = time.perf_counter() - loop_start
|
||||
if (sleep_t := control_interval - dt) > 0:
|
||||
precise_sleep(sleep_t)
|
||||
else:
|
||||
logging.warning(
|
||||
f"Record loop is running slower ({1 / dt:.1f} Hz) than the target FPS ({cfg.fps} Hz). Dataset frames might be dropped and robot control might be unstable. Common causes are: 1) Camera FPS not keeping up 2) Policy inference taking too long 3) CPU starvation"
|
||||
)
|
||||
|
||||
finally:
|
||||
logger.info("DAgger continuous control loop ended — pausing engine")
|
||||
@@ -648,6 +652,10 @@ class DAggerStrategy(RolloutStrategy):
|
||||
dt = time.perf_counter() - loop_start
|
||||
if (sleep_t := control_interval - dt) > 0:
|
||||
precise_sleep(sleep_t)
|
||||
else:
|
||||
logging.warning(
|
||||
f"Record loop is running slower ({1 / dt:.1f} Hz) than the target FPS ({cfg.fps} Hz). Dataset frames might be dropped and robot control might be unstable. Common causes are: 1) Camera FPS not keeping up 2) Policy inference taking too long 3) CPU starvation"
|
||||
)
|
||||
|
||||
finally:
|
||||
logger.info("DAgger corrections-only loop ended — pausing engine")
|
||||
|
||||
@@ -189,6 +189,10 @@ class HighlightStrategy(RolloutStrategy):
|
||||
dt = time.perf_counter() - loop_start
|
||||
if (sleep_t := control_interval - dt) > 0:
|
||||
precise_sleep(sleep_t)
|
||||
else:
|
||||
logging.warning(
|
||||
f"Record loop is running slower ({1 / dt:.1f} Hz) than the target FPS ({cfg.fps} Hz). Dataset frames might be dropped and robot control might be unstable. Common causes are: 1) Camera FPS not keeping up 2) Policy inference taking too long 3) CPU starvation"
|
||||
)
|
||||
|
||||
finally:
|
||||
logger.info("Highlight control loop ended")
|
||||
|
||||
@@ -160,6 +160,10 @@ class SentryStrategy(RolloutStrategy):
|
||||
dt = time.perf_counter() - loop_start
|
||||
if (sleep_t := control_interval - dt) > 0:
|
||||
precise_sleep(sleep_t)
|
||||
else:
|
||||
logging.warning(
|
||||
f"Record loop is running slower ({1 / dt:.1f} Hz) than the target FPS ({cfg.fps} Hz). Dataset frames might be dropped and robot control might be unstable. Common causes are: 1) Camera FPS not keeping up 2) Policy inference taking too long 3) CPU starvation"
|
||||
)
|
||||
|
||||
finally:
|
||||
logger.info("Sentry control loop ended — saving final episode")
|
||||
|
||||
Reference in New Issue
Block a user