mirror of
https://github.com/huggingface/lerobot.git
synced 2026-05-14 16:19:45 +00:00
feat(annotate): per-episode progress logs in executor
This commit is contained in:
@@ -101,7 +101,7 @@ class Executor:
|
||||
raise ValueError(f"No episodes found under {root}/data/")
|
||||
|
||||
executor_kind = select_executor_class(n, self.config.executor)
|
||||
logger.info("annotate: %d episodes; executor=%s", n, executor_kind)
|
||||
print(f"[annotate] {n} episodes total; executor={executor_kind}", flush=True)
|
||||
|
||||
staging_dir = self.config.resolved_staging_dir(root)
|
||||
staging_dir.mkdir(parents=True, exist_ok=True)
|
||||
@@ -117,11 +117,15 @@ class Executor:
|
||||
# Phase 4: Module 3 (VQA)
|
||||
phases.append(self._run_module_phase("module_3", records, staging_dir, self.module_3))
|
||||
|
||||
print("[annotate] running validator...", flush=True)
|
||||
report = self.validator.validate(records, staging_dir)
|
||||
if not report.ok and not self.config.skip_validation:
|
||||
raise RuntimeError(f"Staging validation failed: {report.summary()}")
|
||||
print(f"[annotate] validator: {report.summary()}", flush=True)
|
||||
|
||||
print(f"[annotate] writing parquet shards into {root}/data/...", flush=True)
|
||||
written = self.writer.write_all(records, staging_dir, root)
|
||||
print(f"[annotate] wrote {len(written)} shard(s); pipeline complete", flush=True)
|
||||
return PipelineRunSummary(phases=phases, written_paths=written, validation_report=report)
|
||||
|
||||
def _run_module_phase(
|
||||
@@ -131,16 +135,33 @@ class Executor:
|
||||
staging_dir: Path,
|
||||
module: Any,
|
||||
) -> PhaseResult:
|
||||
import time as _time # noqa: PLC0415
|
||||
|
||||
if not module.enabled:
|
||||
print(f"[annotate] phase={name} skipped (module disabled)", flush=True)
|
||||
return PhaseResult(name=name, episodes_processed=0, episodes_skipped=len(records))
|
||||
n = len(records)
|
||||
print(f"[annotate] phase={name} starting on {n} episode(s)", flush=True)
|
||||
t0 = _time.time()
|
||||
processed = 0
|
||||
for record in records:
|
||||
for i, record in enumerate(records, 1):
|
||||
ep_start = _time.time()
|
||||
staging = EpisodeStaging(staging_dir, record.episode_index)
|
||||
module.run_episode(record, staging)
|
||||
processed += 1
|
||||
elapsed = _time.time() - ep_start
|
||||
print(
|
||||
f"[annotate] {name} episode {i}/{n} "
|
||||
f"(idx={record.episode_index}) done in {elapsed:.1f}s",
|
||||
flush=True,
|
||||
)
|
||||
total = _time.time() - t0
|
||||
print(f"[annotate] phase={name} complete: {processed}/{n} in {total:.1f}s", flush=True)
|
||||
return PhaseResult(name=name, episodes_processed=processed, episodes_skipped=0)
|
||||
|
||||
def _run_plan_update_phase(self, records: list[EpisodeRecord], staging_dir: Path) -> PhaseResult:
|
||||
def _run_plan_update_phase( # noqa: PLR0915
|
||||
self, records: list[EpisodeRecord], staging_dir: Path
|
||||
) -> PhaseResult:
|
||||
"""Re-emit ``plan`` rows at each interjection timestamp from Module 2.
|
||||
|
||||
Module 1 owns the prompt; Module 2 produced the timestamps. This phase
|
||||
|
||||
Reference in New Issue
Block a user