mirror of
https://github.com/huggingface/lerobot.git
synced 2026-05-24 04:59:47 +00:00
fix(rollout): device and task propagation + feature pos + warn fps + move rename_map config
This commit is contained in:
@@ -14,7 +14,7 @@
|
|||||||
|
|
||||||
"""Shared dataset recording configuration used by both ``lerobot-record`` and ``lerobot-rollout``."""
|
"""Shared dataset recording configuration used by both ``lerobot-record`` and ``lerobot-rollout``."""
|
||||||
|
|
||||||
from dataclasses import dataclass, field
|
from dataclasses import dataclass
|
||||||
from datetime import datetime
|
from datetime import datetime
|
||||||
from pathlib import Path
|
from pathlib import Path
|
||||||
|
|
||||||
@@ -68,8 +68,6 @@ class DatasetRecordConfig:
|
|||||||
# Number of threads per encoder instance. None = auto (codec default).
|
# Number of threads per encoder instance. None = auto (codec default).
|
||||||
# Lower values reduce CPU usage. Maps to 'lp' (via svtav1-params) for libsvtav1 and 'threads' for h264/hevc.
|
# Lower values reduce CPU usage. Maps to 'lp' (via svtav1-params) for libsvtav1 and 'threads' for h264/hevc.
|
||||||
encoder_threads: int | None = None
|
encoder_threads: int | None = None
|
||||||
# Rename map for the observation to override the image and state keys
|
|
||||||
rename_map: dict[str, str] = field(default_factory=dict)
|
|
||||||
|
|
||||||
def __post_init__(self) -> None:
|
def __post_init__(self) -> None:
|
||||||
if self.repo_id:
|
if self.repo_id:
|
||||||
|
|||||||
@@ -26,6 +26,7 @@ from lerobot.configs import PreTrainedConfig, parser
|
|||||||
from lerobot.configs.dataset import DatasetRecordConfig
|
from lerobot.configs.dataset import DatasetRecordConfig
|
||||||
from lerobot.robots.config import RobotConfig
|
from lerobot.robots.config import RobotConfig
|
||||||
from lerobot.teleoperators.config import TeleoperatorConfig
|
from lerobot.teleoperators.config import TeleoperatorConfig
|
||||||
|
from lerobot.utils.device_utils import auto_select_torch_device, is_torch_device_available
|
||||||
|
|
||||||
from .inference import InferenceEngineConfig, SyncInferenceConfig
|
from .inference import InferenceEngineConfig, SyncInferenceConfig
|
||||||
|
|
||||||
@@ -205,6 +206,8 @@ class RolloutConfig:
|
|||||||
# Use vocal synthesis to read events
|
# Use vocal synthesis to read events
|
||||||
play_sounds: bool = True
|
play_sounds: bool = True
|
||||||
resume: bool = False
|
resume: bool = False
|
||||||
|
# Rename map for mapping robot/dataset observation keys to policy keys
|
||||||
|
rename_map: dict[str, str] = field(default_factory=dict)
|
||||||
|
|
||||||
# Torch compile
|
# Torch compile
|
||||||
use_torch_compile: bool = False
|
use_torch_compile: bool = False
|
||||||
@@ -285,6 +288,31 @@ class RolloutConfig:
|
|||||||
if self.policy is None:
|
if self.policy is None:
|
||||||
raise ValueError("--policy.path is required for rollout")
|
raise ValueError("--policy.path is required for rollout")
|
||||||
|
|
||||||
|
# --- Task resolution ---
|
||||||
|
# When any --dataset.* flag (e.g. --dataset.repo_id) is passed, draccus
|
||||||
|
# creates a DatasetRecordConfig with single_task="". If the user set
|
||||||
|
# the task via the top-level --task flag, propagate it so that all
|
||||||
|
# downstream consumers (inference engine, dataset frame builders) see it.
|
||||||
|
if self.dataset is not None and not self.dataset.single_task and self.task:
|
||||||
|
logger.info("Propagating top-level task '%s' to dataset config", self.task)
|
||||||
|
self.dataset.single_task = self.task
|
||||||
|
elif self.dataset is not None and self.dataset.single_task and not self.task:
|
||||||
|
logger.info("Propagating dataset single_task '%s' to top-level task", self.dataset.single_task)
|
||||||
|
self.task = self.dataset.single_task
|
||||||
|
|
||||||
|
# --- Device resolution ---
|
||||||
|
# Resolve device from the policy config when not explicitly set so all
|
||||||
|
# components (policy.to, preprocessor, inference engine) use the same
|
||||||
|
# device string instead of inconsistent fallbacks.
|
||||||
|
if self.device is None or not is_torch_device_available(self.device):
|
||||||
|
resolved = self.policy.device
|
||||||
|
if resolved:
|
||||||
|
self.device = resolved
|
||||||
|
logger.info("Resolved device from policy config: %s", self.device)
|
||||||
|
else:
|
||||||
|
self.device = auto_select_torch_device().type
|
||||||
|
logger.info("No policy config to resolve device from; auto-selected device: %s", self.device)
|
||||||
|
|
||||||
@classmethod
|
@classmethod
|
||||||
def __get_path_fields__(cls) -> list[str]:
|
def __get_path_fields__(cls) -> list[str]:
|
||||||
return ["policy"]
|
return ["policy"]
|
||||||
|
|||||||
@@ -272,11 +272,16 @@ def build_rollout_context(
|
|||||||
# )
|
# )
|
||||||
|
|
||||||
# --- 4. Features + action-key reconciliation ---------------------
|
# --- 4. Features + action-key reconciliation ---------------------
|
||||||
|
# TODO(Steven): Only `.pos` joint features are used for policy inference — velocity and
|
||||||
|
# torque channels are observation-only and must be excluded from the state
|
||||||
|
# and action tensors that the policy sees.
|
||||||
all_obs_features = robot.observation_features
|
all_obs_features = robot.observation_features
|
||||||
observation_features_hw = {
|
observation_features_hw = {
|
||||||
k: v for k, v in all_obs_features.items() if v is float or isinstance(v, tuple)
|
k: v
|
||||||
|
for k, v in all_obs_features.items()
|
||||||
|
if isinstance(v, tuple) or (v is float and k.endswith(".pos"))
|
||||||
}
|
}
|
||||||
action_features_hw = robot.action_features
|
action_features_hw = {k: v for k, v in robot.action_features.items() if k.endswith(".pos")}
|
||||||
|
|
||||||
# The action side is always needed: sync inference reads action names from
|
# The action side is always needed: sync inference reads action names from
|
||||||
# ``dataset_features[ACTION]`` to map policy tensors back to robot actions.
|
# ``dataset_features[ACTION]`` to map policy tensors back to robot actions.
|
||||||
@@ -293,7 +298,7 @@ def build_rollout_context(
|
|||||||
)
|
)
|
||||||
dataset_features = combine_feature_dicts(action_dataset_features, observation_dataset_features)
|
dataset_features = combine_feature_dicts(action_dataset_features, observation_dataset_features)
|
||||||
hw_features = hw_to_dataset_features(observation_features_hw, "observation")
|
hw_features = hw_to_dataset_features(observation_features_hw, "observation")
|
||||||
raw_action_keys = list(robot.action_features.keys())
|
raw_action_keys = list(action_features_hw.keys())
|
||||||
policy_action_names = getattr(policy_config, "action_feature_names", None)
|
policy_action_names = getattr(policy_config, "action_feature_names", None)
|
||||||
ordered_action_keys = _resolve_action_key_order(
|
ordered_action_keys = _resolve_action_key_order(
|
||||||
list(policy_action_names) if policy_action_names else None,
|
list(policy_action_names) if policy_action_names else None,
|
||||||
@@ -301,7 +306,7 @@ def build_rollout_context(
|
|||||||
)
|
)
|
||||||
|
|
||||||
# Validate visual features if no rename_map is active
|
# Validate visual features if no rename_map is active
|
||||||
rename_map = cfg.dataset.rename_map if cfg.dataset else {}
|
rename_map = cfg.rename_map
|
||||||
if not rename_map:
|
if not rename_map:
|
||||||
expected_visuals = {k for k, v in full_config.input_features.items() if v.type == FeatureType.VISUAL}
|
expected_visuals = {k for k, v in full_config.input_features.items() if v.type == FeatureType.VISUAL}
|
||||||
provided_visuals = {
|
provided_visuals = {
|
||||||
@@ -366,7 +371,7 @@ def build_rollout_context(
|
|||||||
if dataset is not None:
|
if dataset is not None:
|
||||||
dataset_stats = rename_stats(
|
dataset_stats = rename_stats(
|
||||||
dataset.meta.stats,
|
dataset.meta.stats,
|
||||||
cfg.dataset.rename_map if cfg.dataset else {},
|
cfg.rename_map,
|
||||||
)
|
)
|
||||||
|
|
||||||
preprocessor, postprocessor = make_pre_post_processors(
|
preprocessor, postprocessor = make_pre_post_processors(
|
||||||
@@ -374,8 +379,8 @@ def build_rollout_context(
|
|||||||
pretrained_path=cfg.policy.pretrained_path,
|
pretrained_path=cfg.policy.pretrained_path,
|
||||||
dataset_stats=dataset_stats,
|
dataset_stats=dataset_stats,
|
||||||
preprocessor_overrides={
|
preprocessor_overrides={
|
||||||
"device_processor": {"device": cfg.device or getattr(policy_config, "device", "cpu")},
|
"device_processor": {"device": cfg.device},
|
||||||
"rename_observations_processor": {"rename_map": cfg.dataset.rename_map if cfg.dataset else {}},
|
"rename_observations_processor": {"rename_map": cfg.rename_map},
|
||||||
},
|
},
|
||||||
)
|
)
|
||||||
|
|
||||||
|
|||||||
@@ -72,6 +72,10 @@ class BaseStrategy(RolloutStrategy):
|
|||||||
dt = time.perf_counter() - loop_start
|
dt = time.perf_counter() - loop_start
|
||||||
if (sleep_t := control_interval - dt) > 0:
|
if (sleep_t := control_interval - dt) > 0:
|
||||||
precise_sleep(sleep_t)
|
precise_sleep(sleep_t)
|
||||||
|
else:
|
||||||
|
logging.warning(
|
||||||
|
f"Record loop is running slower ({1 / dt:.1f} Hz) than the target FPS ({cfg.fps} Hz). Dataset frames might be dropped and robot control might be unstable. Common causes are: 1) Camera FPS not keeping up 2) Policy inference taking too long 3) CPU starvation"
|
||||||
|
)
|
||||||
|
|
||||||
def teardown(self, ctx: RolloutContext) -> None:
|
def teardown(self, ctx: RolloutContext) -> None:
|
||||||
"""Disconnect hardware and stop inference."""
|
"""Disconnect hardware and stop inference."""
|
||||||
|
|||||||
@@ -62,6 +62,7 @@ class RolloutStrategy(abc.ABC):
|
|||||||
self._interpolator = ActionInterpolator(multiplier=ctx.runtime.cfg.interpolation_multiplier)
|
self._interpolator = ActionInterpolator(multiplier=ctx.runtime.cfg.interpolation_multiplier)
|
||||||
self._engine = ctx.policy.inference
|
self._engine = ctx.policy.inference
|
||||||
logger.info("Starting inference engine...")
|
logger.info("Starting inference engine...")
|
||||||
|
self._engine.reset()
|
||||||
self._engine.start()
|
self._engine.start()
|
||||||
self._warmup_flushed = False
|
self._warmup_flushed = False
|
||||||
logger.info("Inference engine started")
|
logger.info("Inference engine started")
|
||||||
|
|||||||
@@ -508,6 +508,10 @@ class DAggerStrategy(RolloutStrategy):
|
|||||||
dt = time.perf_counter() - loop_start
|
dt = time.perf_counter() - loop_start
|
||||||
if (sleep_t := control_interval - dt) > 0:
|
if (sleep_t := control_interval - dt) > 0:
|
||||||
precise_sleep(sleep_t)
|
precise_sleep(sleep_t)
|
||||||
|
else:
|
||||||
|
logging.warning(
|
||||||
|
f"Record loop is running slower ({1 / dt:.1f} Hz) than the target FPS ({cfg.fps} Hz). Dataset frames might be dropped and robot control might be unstable. Common causes are: 1) Camera FPS not keeping up 2) Policy inference taking too long 3) CPU starvation"
|
||||||
|
)
|
||||||
|
|
||||||
finally:
|
finally:
|
||||||
logger.info("DAgger continuous control loop ended — pausing engine")
|
logger.info("DAgger continuous control loop ended — pausing engine")
|
||||||
@@ -648,6 +652,10 @@ class DAggerStrategy(RolloutStrategy):
|
|||||||
dt = time.perf_counter() - loop_start
|
dt = time.perf_counter() - loop_start
|
||||||
if (sleep_t := control_interval - dt) > 0:
|
if (sleep_t := control_interval - dt) > 0:
|
||||||
precise_sleep(sleep_t)
|
precise_sleep(sleep_t)
|
||||||
|
else:
|
||||||
|
logging.warning(
|
||||||
|
f"Record loop is running slower ({1 / dt:.1f} Hz) than the target FPS ({cfg.fps} Hz). Dataset frames might be dropped and robot control might be unstable. Common causes are: 1) Camera FPS not keeping up 2) Policy inference taking too long 3) CPU starvation"
|
||||||
|
)
|
||||||
|
|
||||||
finally:
|
finally:
|
||||||
logger.info("DAgger corrections-only loop ended — pausing engine")
|
logger.info("DAgger corrections-only loop ended — pausing engine")
|
||||||
|
|||||||
@@ -189,6 +189,10 @@ class HighlightStrategy(RolloutStrategy):
|
|||||||
dt = time.perf_counter() - loop_start
|
dt = time.perf_counter() - loop_start
|
||||||
if (sleep_t := control_interval - dt) > 0:
|
if (sleep_t := control_interval - dt) > 0:
|
||||||
precise_sleep(sleep_t)
|
precise_sleep(sleep_t)
|
||||||
|
else:
|
||||||
|
logging.warning(
|
||||||
|
f"Record loop is running slower ({1 / dt:.1f} Hz) than the target FPS ({cfg.fps} Hz). Dataset frames might be dropped and robot control might be unstable. Common causes are: 1) Camera FPS not keeping up 2) Policy inference taking too long 3) CPU starvation"
|
||||||
|
)
|
||||||
|
|
||||||
finally:
|
finally:
|
||||||
logger.info("Highlight control loop ended")
|
logger.info("Highlight control loop ended")
|
||||||
|
|||||||
@@ -160,6 +160,10 @@ class SentryStrategy(RolloutStrategy):
|
|||||||
dt = time.perf_counter() - loop_start
|
dt = time.perf_counter() - loop_start
|
||||||
if (sleep_t := control_interval - dt) > 0:
|
if (sleep_t := control_interval - dt) > 0:
|
||||||
precise_sleep(sleep_t)
|
precise_sleep(sleep_t)
|
||||||
|
else:
|
||||||
|
logging.warning(
|
||||||
|
f"Record loop is running slower ({1 / dt:.1f} Hz) than the target FPS ({cfg.fps} Hz). Dataset frames might be dropped and robot control might be unstable. Common causes are: 1) Camera FPS not keeping up 2) Policy inference taking too long 3) CPU starvation"
|
||||||
|
)
|
||||||
|
|
||||||
finally:
|
finally:
|
||||||
logger.info("Sentry control loop ended — saving final episode")
|
logger.info("Sentry control loop ended — saving final episode")
|
||||||
|
|||||||
Reference in New Issue
Block a user