mirror of
https://github.com/huggingface/lerobot.git
synced 2026-05-22 20:19:43 +00:00
fix(rollout) require dataset in dagger + use duration too
This commit is contained in:
@@ -23,7 +23,6 @@ from .configs import (
|
|||||||
DAggerKeyboardConfig,
|
DAggerKeyboardConfig,
|
||||||
DAggerPedalConfig,
|
DAggerPedalConfig,
|
||||||
DAggerStrategyConfig,
|
DAggerStrategyConfig,
|
||||||
DatasetRecordConfig,
|
|
||||||
HighlightStrategyConfig,
|
HighlightStrategyConfig,
|
||||||
RolloutConfig,
|
RolloutConfig,
|
||||||
RolloutStrategyConfig,
|
RolloutStrategyConfig,
|
||||||
|
|||||||
@@ -216,7 +216,10 @@ class RolloutConfig:
|
|||||||
if isinstance(self.strategy, DAggerStrategyConfig) and self.teleop is None:
|
if isinstance(self.strategy, DAggerStrategyConfig) and self.teleop is None:
|
||||||
raise ValueError("DAgger strategy requires --teleop.type to be set")
|
raise ValueError("DAgger strategy requires --teleop.type to be set")
|
||||||
|
|
||||||
needs_dataset = isinstance(self.strategy, (SentryStrategyConfig, HighlightStrategyConfig))
|
# TODO(Steven): DAgger shouldn't require a dataset (user may want to just rollout+intervene without recording), but for now we require it to simplify the implementation.
|
||||||
|
needs_dataset = isinstance(
|
||||||
|
self.strategy, (SentryStrategyConfig, HighlightStrategyConfig, DAggerStrategyConfig)
|
||||||
|
)
|
||||||
if needs_dataset and (self.dataset is None or not self.dataset.repo_id):
|
if needs_dataset and (self.dataset is None or not self.dataset.repo_id):
|
||||||
raise ValueError(f"{self.strategy.type} strategy requires --dataset.repo_id to be set")
|
raise ValueError(f"{self.strategy.type} strategy requires --dataset.repo_id to be set")
|
||||||
|
|
||||||
@@ -244,14 +247,16 @@ class RolloutConfig:
|
|||||||
self.dataset.streaming_encoding = True
|
self.dataset.streaming_encoding = True
|
||||||
|
|
||||||
# DAgger: streaming is mandatory only when the autonomous phase is also recorded.
|
# DAgger: streaming is mandatory only when the autonomous phase is also recorded.
|
||||||
if (
|
if isinstance(self.strategy, DAggerStrategyConfig) and self.dataset is not None:
|
||||||
isinstance(self.strategy, DAggerStrategyConfig)
|
if self.strategy.record_autonomous and not self.dataset.streaming_encoding:
|
||||||
and self.strategy.record_autonomous
|
logger.warning("DAgger with record_autonomous=True forces streaming_encoding=True")
|
||||||
and self.dataset is not None
|
self.dataset.streaming_encoding = True
|
||||||
and not self.dataset.streaming_encoding
|
elif not self.strategy.record_autonomous and not self.dataset.streaming_encoding:
|
||||||
):
|
logger.info(
|
||||||
logger.warning("DAgger with record_autonomous=True forces streaming_encoding=True")
|
"Streaming encoding is disabled for DAgger corrections-only mode. "
|
||||||
self.dataset.streaming_encoding = True
|
"Consider enabling it for faster episode saving: "
|
||||||
|
"--dataset.streaming_encoding=true --dataset.encoder_threads=2"
|
||||||
|
)
|
||||||
|
|
||||||
# --- Policy loading ---
|
# --- Policy loading ---
|
||||||
if self.robot is None:
|
if self.robot is None:
|
||||||
|
|||||||
@@ -556,6 +556,7 @@ class DAggerStrategy(RolloutStrategy):
|
|||||||
engine.resume()
|
engine.resume()
|
||||||
|
|
||||||
last_action: dict[str, Any] | None = None
|
last_action: dict[str, Any] | None = None
|
||||||
|
start_time = time.perf_counter()
|
||||||
record_tick = 0
|
record_tick = 0
|
||||||
recorded = 0
|
recorded = 0
|
||||||
logger.info(
|
logger.info(
|
||||||
@@ -571,6 +572,10 @@ class DAggerStrategy(RolloutStrategy):
|
|||||||
):
|
):
|
||||||
loop_start = time.perf_counter()
|
loop_start = time.perf_counter()
|
||||||
|
|
||||||
|
if cfg.duration > 0 and (time.perf_counter() - start_time) >= cfg.duration:
|
||||||
|
logger.info("Duration limit reached (%.0fs)", cfg.duration)
|
||||||
|
break
|
||||||
|
|
||||||
# Process transitions
|
# Process transitions
|
||||||
transition = events.consume_transition()
|
transition = events.consume_transition()
|
||||||
if transition is not None:
|
if transition is not None:
|
||||||
|
|||||||
@@ -25,7 +25,7 @@ from .highlight import HighlightStrategy
|
|||||||
from .sentry import SentryStrategy
|
from .sentry import SentryStrategy
|
||||||
|
|
||||||
if TYPE_CHECKING:
|
if TYPE_CHECKING:
|
||||||
from lerobot.rollout import RolloutStrategyConfig
|
from ..configs import RolloutStrategyConfig
|
||||||
|
|
||||||
|
|
||||||
def create_strategy(config: RolloutStrategyConfig) -> RolloutStrategy:
|
def create_strategy(config: RolloutStrategyConfig) -> RolloutStrategy:
|
||||||
|
|||||||
Reference in New Issue
Block a user