mirror of
https://github.com/huggingface/lerobot.git
synced 2026-05-18 18:20:08 +00:00
device + task + warn fix
This commit is contained in:
@@ -225,10 +225,10 @@ class RolloutConfig:
|
|||||||
if needs_dataset and (self.dataset is None or not self.dataset.repo_id):
|
if needs_dataset and (self.dataset is None or not self.dataset.repo_id):
|
||||||
raise ValueError(f"{self.strategy.type} strategy requires --dataset.repo_id to be set")
|
raise ValueError(f"{self.strategy.type} strategy requires --dataset.repo_id to be set")
|
||||||
|
|
||||||
if isinstance(self.strategy, BaseStrategyConfig) and self.dataset is not None:
|
# if isinstance(self.strategy, BaseStrategyConfig) and self.dataset is not None:
|
||||||
raise ValueError(
|
# raise ValueError(
|
||||||
"Base strategy does not record data. Use sentry, highlight, or dagger for recording."
|
# "Base strategy does not record data. Use sentry, highlight, or dagger for recording."
|
||||||
)
|
# )
|
||||||
|
|
||||||
# Sentry MUST use streaming encoding to avoid disk I/O blocking the control loop
|
# Sentry MUST use streaming encoding to avoid disk I/O blocking the control loop
|
||||||
if (
|
if (
|
||||||
@@ -285,6 +285,26 @@ class RolloutConfig:
|
|||||||
if self.policy is None:
|
if self.policy is None:
|
||||||
raise ValueError("--policy.path is required for rollout")
|
raise ValueError("--policy.path is required for rollout")
|
||||||
|
|
||||||
|
# --- Task resolution ---
|
||||||
|
# When --dataset.rename_map (or any --dataset.* flag) is passed, draccus
|
||||||
|
# creates a DatasetRecordConfig with single_task="". If the user set
|
||||||
|
# the task via the top-level --task flag, propagate it so that all
|
||||||
|
# downstream consumers (inference engine, dataset frame builders) see it.
|
||||||
|
if self.dataset is not None and not self.dataset.single_task and self.task:
|
||||||
|
self.dataset.single_task = self.task
|
||||||
|
elif self.dataset is not None and self.dataset.single_task and not self.task:
|
||||||
|
self.task = self.dataset.single_task
|
||||||
|
|
||||||
|
# --- Device resolution ---
|
||||||
|
# Resolve device from the policy config when not explicitly set so all
|
||||||
|
# components (policy.to, preprocessor, inference engine) use the same
|
||||||
|
# device string instead of inconsistent fallbacks.
|
||||||
|
if self.device is None and self.policy is not None:
|
||||||
|
resolved = getattr(self.policy, "device", None)
|
||||||
|
if resolved:
|
||||||
|
self.device = resolved
|
||||||
|
logger.info("Resolved device from policy config: %s", self.device)
|
||||||
|
|
||||||
@classmethod
|
@classmethod
|
||||||
def __get_path_fields__(cls) -> list[str]:
|
def __get_path_fields__(cls) -> list[str]:
|
||||||
return ["policy"]
|
return ["policy"]
|
||||||
|
|||||||
@@ -417,7 +417,7 @@ def build_rollout_context(
|
|||||||
pretrained_path=cfg.policy.pretrained_path,
|
pretrained_path=cfg.policy.pretrained_path,
|
||||||
dataset_stats=dataset_stats,
|
dataset_stats=dataset_stats,
|
||||||
preprocessor_overrides={
|
preprocessor_overrides={
|
||||||
"device_processor": {"device": cfg.device or getattr(policy_config, "device", "cpu")},
|
"device_processor": {"device": cfg.device},
|
||||||
"rename_observations_processor": {"rename_map": cfg.dataset.rename_map if cfg.dataset else {}},
|
"rename_observations_processor": {"rename_map": cfg.dataset.rename_map if cfg.dataset else {}},
|
||||||
},
|
},
|
||||||
)
|
)
|
||||||
@@ -428,13 +428,21 @@ def build_rollout_context(
|
|||||||
for step in preprocessor.steps:
|
for step in preprocessor.steps:
|
||||||
if isinstance(step, NormalizerProcessorStep):
|
if isinstance(step, NormalizerProcessorStep):
|
||||||
n_stats = sum(len(v) for v in step._tensor_stats.values()) if step._tensor_stats else 0
|
n_stats = sum(len(v) for v in step._tensor_stats.values()) if step._tensor_stats else 0
|
||||||
logger.info("Preprocessor normalizer: %d stat tensors, keys=%s", n_stats, list(step._tensor_stats.keys())[:3])
|
logger.info(
|
||||||
|
"Preprocessor normalizer: %d stat tensors, keys=%s",
|
||||||
|
n_stats,
|
||||||
|
list(step._tensor_stats.keys())[:3],
|
||||||
|
)
|
||||||
if n_stats == 0:
|
if n_stats == 0:
|
||||||
logger.error("PREPROCESSOR NORMALIZER HAS NO STATS — observations will NOT be normalized!")
|
logger.error("PREPROCESSOR NORMALIZER HAS NO STATS — observations will NOT be normalized!")
|
||||||
for step in postprocessor.steps:
|
for step in postprocessor.steps:
|
||||||
if isinstance(step, UnnormalizerProcessorStep):
|
if isinstance(step, UnnormalizerProcessorStep):
|
||||||
n_stats = sum(len(v) for v in step._tensor_stats.values()) if step._tensor_stats else 0
|
n_stats = sum(len(v) for v in step._tensor_stats.values()) if step._tensor_stats else 0
|
||||||
logger.info("Postprocessor unnormalizer: %d stat tensors, keys=%s", n_stats, list(step._tensor_stats.keys())[:3])
|
logger.info(
|
||||||
|
"Postprocessor unnormalizer: %d stat tensors, keys=%s",
|
||||||
|
n_stats,
|
||||||
|
list(step._tensor_stats.keys())[:3],
|
||||||
|
)
|
||||||
if n_stats == 0:
|
if n_stats == 0:
|
||||||
logger.error("POSTPROCESSOR UNNORMALIZER HAS NO STATS — actions will NOT be denormalized!")
|
logger.error("POSTPROCESSOR UNNORMALIZER HAS NO STATS — actions will NOT be denormalized!")
|
||||||
|
|
||||||
|
|||||||
@@ -80,6 +80,7 @@ class BaseStrategy(RolloutStrategy):
|
|||||||
self._log_telemetry(obs_processed, action_dict, ctx.runtime)
|
self._log_telemetry(obs_processed, action_dict, ctx.runtime)
|
||||||
|
|
||||||
dt = time.perf_counter() - loop_start
|
dt = time.perf_counter() - loop_start
|
||||||
|
self._warn_if_slow(dt, control_interval, cfg.fps)
|
||||||
if (sleep_t := control_interval - dt) > 0:
|
if (sleep_t := control_interval - dt) > 0:
|
||||||
precise_sleep(sleep_t)
|
precise_sleep(sleep_t)
|
||||||
|
|
||||||
|
|||||||
@@ -146,6 +146,20 @@ class RolloutStrategy(abc.ABC):
|
|||||||
compress_images=cfg.display_compressed_images,
|
compress_images=cfg.display_compressed_images,
|
||||||
)
|
)
|
||||||
|
|
||||||
|
@staticmethod
def _warn_if_slow(dt: float, control_interval: float, fps: float) -> None:
    """Emit a warning when one control-loop iteration overran its time budget.

    Args:
        dt: Measured duration of the iteration just completed (callers pass a
            ``time.perf_counter()`` difference, i.e. seconds).
        control_interval: Time budget for a single iteration, in the same unit
            as ``dt``.
        fps: Target loop frequency; only used for the log message.
    """
    # Negate the original comparison (instead of using ``<=``) so that a NaN
    # duration is still treated as "not slow".
    if not dt > control_interval:
        return

    # Guard against a zero duration so the rate computation cannot divide by zero.
    achieved_hz = 1.0 / dt if dt > 0 else 0
    logger.warning(
        "Control loop is running slower (%.1f Hz) than target FPS (%.0f Hz). "
        "Dataset frames might be dropped and robot control might be unstable. "
        "Common causes: 1) Camera FPS not keeping up "
        "2) Policy inference taking too long 3) CPU starvation",
        achieved_hz,
        fps,
    )
|
||||||
|
|
||||||
@abc.abstractmethod
|
@abc.abstractmethod
|
||||||
def setup(self, ctx: RolloutContext) -> None:
|
def setup(self, ctx: RolloutContext) -> None:
|
||||||
"""Strategy-specific initialisation (keyboard listeners, buffers, etc.)."""
|
"""Strategy-specific initialisation (keyboard listeners, buffers, etc.)."""
|
||||||
|
|||||||
@@ -506,6 +506,7 @@ class DAggerStrategy(RolloutStrategy):
|
|||||||
episode_start = time.perf_counter()
|
episode_start = time.perf_counter()
|
||||||
|
|
||||||
dt = time.perf_counter() - loop_start
|
dt = time.perf_counter() - loop_start
|
||||||
|
self._warn_if_slow(dt, control_interval, cfg.fps)
|
||||||
if (sleep_t := control_interval - dt) > 0:
|
if (sleep_t := control_interval - dt) > 0:
|
||||||
precise_sleep(sleep_t)
|
precise_sleep(sleep_t)
|
||||||
|
|
||||||
@@ -646,6 +647,7 @@ class DAggerStrategy(RolloutStrategy):
|
|||||||
last_action = ctx.processors.robot_action_processor((action_dict, obs))
|
last_action = ctx.processors.robot_action_processor((action_dict, obs))
|
||||||
|
|
||||||
dt = time.perf_counter() - loop_start
|
dt = time.perf_counter() - loop_start
|
||||||
|
self._warn_if_slow(dt, control_interval, cfg.fps)
|
||||||
if (sleep_t := control_interval - dt) > 0:
|
if (sleep_t := control_interval - dt) > 0:
|
||||||
precise_sleep(sleep_t)
|
precise_sleep(sleep_t)
|
||||||
|
|
||||||
|
|||||||
@@ -187,6 +187,7 @@ class HighlightStrategy(RolloutStrategy):
|
|||||||
ring.append(frame)
|
ring.append(frame)
|
||||||
|
|
||||||
dt = time.perf_counter() - loop_start
|
dt = time.perf_counter() - loop_start
|
||||||
|
self._warn_if_slow(dt, control_interval, cfg.fps)
|
||||||
if (sleep_t := control_interval - dt) > 0:
|
if (sleep_t := control_interval - dt) > 0:
|
||||||
precise_sleep(sleep_t)
|
precise_sleep(sleep_t)
|
||||||
|
|
||||||
|
|||||||
@@ -158,6 +158,7 @@ class SentryStrategy(RolloutStrategy):
|
|||||||
episode_start = time.perf_counter()
|
episode_start = time.perf_counter()
|
||||||
|
|
||||||
dt = time.perf_counter() - loop_start
|
dt = time.perf_counter() - loop_start
|
||||||
|
self._warn_if_slow(dt, control_interval, cfg.fps)
|
||||||
if (sleep_t := control_interval - dt) > 0:
|
if (sleep_t := control_interval - dt) > 0:
|
||||||
precise_sleep(sleep_t)
|
precise_sleep(sleep_t)
|
||||||
|
|
||||||
|
|||||||
Reference in New Issue
Block a user