fix(rollout) require dataset in dagger + use duration too

This commit is contained in:
Steven Palma
2026-04-19 17:29:08 +02:00
parent bc06cb44ca
commit 14c7a25ce4
4 changed files with 20 additions and 11 deletions
-1
View File
@@ -23,7 +23,6 @@ from .configs import (
DAggerKeyboardConfig, DAggerKeyboardConfig,
DAggerPedalConfig, DAggerPedalConfig,
DAggerStrategyConfig, DAggerStrategyConfig,
DatasetRecordConfig,
HighlightStrategyConfig, HighlightStrategyConfig,
RolloutConfig, RolloutConfig,
RolloutStrategyConfig, RolloutStrategyConfig,
+14 -9
View File
@@ -216,7 +216,10 @@ class RolloutConfig:
if isinstance(self.strategy, DAggerStrategyConfig) and self.teleop is None: if isinstance(self.strategy, DAggerStrategyConfig) and self.teleop is None:
raise ValueError("DAgger strategy requires --teleop.type to be set") raise ValueError("DAgger strategy requires --teleop.type to be set")
needs_dataset = isinstance(self.strategy, (SentryStrategyConfig, HighlightStrategyConfig)) # TODO(Steven): DAgger shouldn't require a dataset (user may want to just rollout+intervene without recording), but for now we require it to simplify the implementation.
needs_dataset = isinstance(
self.strategy, (SentryStrategyConfig, HighlightStrategyConfig, DAggerStrategyConfig)
)
if needs_dataset and (self.dataset is None or not self.dataset.repo_id): if needs_dataset and (self.dataset is None or not self.dataset.repo_id):
raise ValueError(f"{self.strategy.type} strategy requires --dataset.repo_id to be set") raise ValueError(f"{self.strategy.type} strategy requires --dataset.repo_id to be set")
@@ -244,14 +247,16 @@ class RolloutConfig:
self.dataset.streaming_encoding = True self.dataset.streaming_encoding = True
# DAgger: streaming is mandatory only when the autonomous phase is also recorded. # DAgger: streaming is mandatory only when the autonomous phase is also recorded.
if ( if isinstance(self.strategy, DAggerStrategyConfig) and self.dataset is not None:
isinstance(self.strategy, DAggerStrategyConfig) if self.strategy.record_autonomous and not self.dataset.streaming_encoding:
and self.strategy.record_autonomous logger.warning("DAgger with record_autonomous=True forces streaming_encoding=True")
and self.dataset is not None self.dataset.streaming_encoding = True
and not self.dataset.streaming_encoding elif not self.strategy.record_autonomous and not self.dataset.streaming_encoding:
): logger.info(
logger.warning("DAgger with record_autonomous=True forces streaming_encoding=True") "Streaming encoding is disabled for DAgger corrections-only mode. "
self.dataset.streaming_encoding = True "Consider enabling it for faster episode saving: "
"--dataset.streaming_encoding=true --dataset.encoder_threads=2"
)
# --- Policy loading --- # --- Policy loading ---
if self.robot is None: if self.robot is None:
+5
View File
@@ -556,6 +556,7 @@ class DAggerStrategy(RolloutStrategy):
engine.resume() engine.resume()
last_action: dict[str, Any] | None = None last_action: dict[str, Any] | None = None
start_time = time.perf_counter()
record_tick = 0 record_tick = 0
recorded = 0 recorded = 0
logger.info( logger.info(
@@ -571,6 +572,10 @@ class DAggerStrategy(RolloutStrategy):
): ):
loop_start = time.perf_counter() loop_start = time.perf_counter()
if cfg.duration > 0 and (time.perf_counter() - start_time) >= cfg.duration:
logger.info("Duration limit reached (%.0fs)", cfg.duration)
break
# Process transitions # Process transitions
transition = events.consume_transition() transition = events.consume_transition()
if transition is not None: if transition is not None:
+1 -1
View File
@@ -25,7 +25,7 @@ from .highlight import HighlightStrategy
from .sentry import SentryStrategy from .sentry import SentryStrategy
if TYPE_CHECKING: if TYPE_CHECKING:
from lerobot.rollout import RolloutStrategyConfig from ..configs import RolloutStrategyConfig
def create_strategy(config: RolloutStrategyConfig) -> RolloutStrategy: def create_strategy(config: RolloutStrategyConfig) -> RolloutStrategy: