mirror of
https://github.com/huggingface/lerobot.git
synced 2026-05-17 01:30:14 +00:00
review: skip-count fix, atomic writes, dedupe span reconstruction, role guards
**#1 Plan-update phase reports correct skip count.** ``_run_plan_update_phase`` only ran ``run_plan_updates`` for episodes with at least one interjection but hardcoded ``episodes_skipped=0``. The summary undercounted skipped episodes. Now returns ``len(records) - processed`` so processed + skipped == total. **#2 ``run_hf_job.py`` installs ``openai``.** The ``CMD`` block does ``pip install --no-deps lerobot[branch]`` then explicitly lists transitive deps. ``openai`` was missing — and since ``VlmConfig.backend`` defaults to ``"openai"``, the job would have ``ImportError``'d when ``vlm_client._make_openai_client`` ran. **#3 Dedupe subtask-span reconstruction.** Module 1's ``_reconstruct_subtasks_from_rows`` (no ``and spans`` guard) and Module 2's ``_read_subtask_spans`` (with the guard) had near- identical logic. Promoted to ``reconstruct_subtask_spans`` in ``reader.py`` using the safer guarded form. Both modules now import the single helper. **#5 Atomic staging.py JSONL writes.** Mirroring the parquet-writer fix from an earlier review round: ``EpisodeStaging.write`` now writes to a sibling ``.tmp`` and ``Path.replace`` atomically. A crash mid-write can no longer leave a half-written JSONL that ``read()`` would then fail to parse. **#6 Atomic ``info.json`` write.** Same pattern in ``executor._ensure_annotation_metadata_in_info`` — ``info.json`` is load-bearing for dataset metadata, so partial writes brick the dataset. **#7 Writer's role-key guard.** ``_normalize_persistent_row`` and ``_normalize_event_row`` accessed ``row["role"]`` directly while every other field used ``.get()``. Pre-validate ``"role" in row`` and raise a friendly ``ValueError`` naming the row, so a future module that accidentally drops ``role`` fails with a triagable message instead of a bare KeyError deep in the writer. **#8 Last subtask span's ``end`` extends to episode end.** ``reconstruct_subtask_spans`` (the new shared helper) takes an optional ``episode_end_t``. When provided, the final span's ``end`` is closed to that timestamp instead of equalling its own ``start`` (zero duration). Both Module 1's plan-update pass and Module 2's interjection anchoring pass ``record.frame_timestamps[-1]``, so downstream "current subtask at refresh_t" lookups no longer miss refreshes that land inside the final span. Sweep: 66 passed, 0 failed. Pre-commit clean. Co-Authored-By: Claude Opus 4.7 (1M context) <noreply@anthropic.com>
This commit is contained in:
@@ -21,16 +21,18 @@ from huggingface_hub import get_token, run_job
|
||||
|
||||
token = os.environ.get("HF_TOKEN") or get_token()
|
||||
if not token:
|
||||
raise RuntimeError(
|
||||
"No HF token. Run `huggingface-cli login` or `export HF_TOKEN=hf_...`"
|
||||
)
|
||||
raise RuntimeError("No HF token. Run `huggingface-cli login` or `export HF_TOKEN=hf_...`")
|
||||
|
||||
CMD = (
|
||||
"apt-get update -qq && apt-get install -y -qq git ffmpeg && "
|
||||
"pip install --no-deps "
|
||||
"'lerobot @ git+https://github.com/huggingface/lerobot.git@feat/language-annotation-pipeline' && "
|
||||
"pip install --upgrade-strategy only-if-needed "
|
||||
"datasets pyarrow av jsonlines draccus gymnasium torchcodec mergedeep pyyaml-include toml typing-inspect && "
|
||||
# Mirror lerobot's [annotations] runtime deps. ``openai`` is required
|
||||
# because ``VlmConfig.backend`` defaults to ``"openai"`` (which talks
|
||||
# to a vllm/transformers/ktransformers OpenAI-compatible server).
|
||||
"datasets pyarrow av jsonlines draccus gymnasium torchcodec mergedeep pyyaml-include "
|
||||
"toml typing-inspect openai && "
|
||||
"export VLLM_MEMORY_PROFILER_ESTIMATE_CUDAGRAPHS=0 && "
|
||||
"export VLLM_VIDEO_BACKEND=pyav && "
|
||||
"lerobot-annotate "
|
||||
|
||||
@@ -170,7 +170,11 @@ class Executor:
|
||||
changed = True
|
||||
|
||||
if changed:
|
||||
info_path.write_text(json.dumps(info, indent=2))
|
||||
# Atomic replace — info.json is load-bearing for dataset
|
||||
# metadata, so a crash mid-write would brick the dataset.
|
||||
tmp_info = info_path.with_suffix(info_path.suffix + ".tmp")
|
||||
tmp_info.write_text(json.dumps(info, indent=2))
|
||||
tmp_info.replace(info_path)
|
||||
print(
|
||||
"[annotate] meta/info.json: "
|
||||
f"language_features={list(language_feature_info())}, "
|
||||
@@ -254,4 +258,10 @@ class Executor:
|
||||
if interjection_times:
|
||||
self.module_1.run_plan_updates(record, staging, interjection_times, interjection_texts)
|
||||
processed += 1
|
||||
return PhaseResult(name="module_1_plan_update", episodes_processed=processed, episodes_skipped=0)
|
||||
# Episodes without any interjections are skipped (no plan refresh
|
||||
# needed); count them so the summary's processed+skipped == total.
|
||||
return PhaseResult(
|
||||
name="module_1_plan_update",
|
||||
episodes_processed=processed,
|
||||
episodes_skipped=len(records) - processed,
|
||||
)
|
||||
|
||||
@@ -40,7 +40,7 @@ from typing import Any
|
||||
from ..config import Module2Config
|
||||
from ..frames import FrameProvider, null_provider, to_image_blocks
|
||||
from ..prompts import load as load_prompt
|
||||
from ..reader import EpisodeRecord, snap_to_frame
|
||||
from ..reader import EpisodeRecord, reconstruct_subtask_spans, snap_to_frame
|
||||
from ..staging import EpisodeStaging
|
||||
from ..vlm_client import VlmClient
|
||||
from ..writer import speech_atom
|
||||
@@ -69,24 +69,11 @@ class InterjectionsAndSpeechModule:
|
||||
# Pull Module 1's subtask spans for this episode so the
|
||||
# interjection prompt can ground itself in the actual current
|
||||
# subtask at each chosen timestamp. Module 1 ran first.
|
||||
subtask_spans = self._read_subtask_spans(staging)
|
||||
episode_end_t = float(record.frame_timestamps[-1]) if record.frame_timestamps else None
|
||||
subtask_spans = reconstruct_subtask_spans(staging.read("module_1"), episode_end_t=episode_end_t)
|
||||
rows.extend(self._mid_episode_interjections(record, subtask_spans))
|
||||
staging.write("module_2", rows)
|
||||
|
||||
@staticmethod
|
||||
def _read_subtask_spans(staging: EpisodeStaging) -> list[dict[str, Any]]:
|
||||
rows = [r for r in staging.read("module_1") if r.get("style") == "subtask"]
|
||||
rows.sort(key=lambda r: float(r["timestamp"]))
|
||||
spans: list[dict[str, Any]] = []
|
||||
last_t: float | None = None
|
||||
for r in rows:
|
||||
t = float(r["timestamp"])
|
||||
if last_t is not None and spans:
|
||||
spans[-1]["end"] = t
|
||||
spans.append({"text": r.get("content") or "", "start": t, "end": t})
|
||||
last_t = t
|
||||
return spans
|
||||
|
||||
@staticmethod
|
||||
def _subtask_at(spans: Sequence[dict[str, Any]], t: float) -> str | None:
|
||||
current: str | None = None
|
||||
|
||||
@@ -31,7 +31,7 @@ from ..frames import (
|
||||
to_video_url_block,
|
||||
)
|
||||
from ..prompts import load as load_prompt
|
||||
from ..reader import EpisodeRecord, snap_to_frame
|
||||
from ..reader import EpisodeRecord, reconstruct_subtask_spans, snap_to_frame
|
||||
from ..staging import EpisodeStaging
|
||||
from ..vlm_client import VlmClient
|
||||
|
||||
@@ -259,7 +259,12 @@ class PlanSubtasksMemoryModule:
|
||||
without telling it what the user said).
|
||||
"""
|
||||
existing = staging.read("module_1")
|
||||
spans = self._reconstruct_subtasks_from_rows(existing)
|
||||
# Pass the episode's last frame timestamp so the final subtask
|
||||
# span is closed (otherwise its ``end`` equals its ``start``,
|
||||
# zero duration, and the "current subtask at refresh_t" lookup
|
||||
# in ``_generate_plan`` misses any refresh that lands inside it).
|
||||
episode_end_t = float(record.frame_timestamps[-1]) if record.frame_timestamps else None
|
||||
spans = reconstruct_subtask_spans(existing, episode_end_t=episode_end_t)
|
||||
already_planned: set[float] = {float(r["timestamp"]) for r in existing if r.get("style") == "plan"}
|
||||
new_rows = list(existing)
|
||||
|
||||
@@ -286,21 +291,6 @@ class PlanSubtasksMemoryModule:
|
||||
)
|
||||
staging.write("module_1", new_rows)
|
||||
|
||||
@staticmethod
|
||||
def _reconstruct_subtasks_from_rows(rows: Sequence[dict[str, Any]]) -> list[dict[str, Any]]:
|
||||
out = []
|
||||
last_t: float | None = None
|
||||
for row in sorted(
|
||||
(r for r in rows if r.get("style") == "subtask"),
|
||||
key=lambda r: float(r["timestamp"]),
|
||||
):
|
||||
t = float(row["timestamp"])
|
||||
if last_t is not None:
|
||||
out[-1]["end"] = t
|
||||
out.append({"text": row.get("content") or "", "start": t, "end": t})
|
||||
last_t = t
|
||||
return out
|
||||
|
||||
def _generate_subtasks(self, record: EpisodeRecord, *, task: str | None = None) -> list[dict[str, Any]]:
|
||||
if record.row_count == 0 or not record.frame_timestamps:
|
||||
return []
|
||||
|
||||
@@ -70,6 +70,37 @@ class EpisodeRecord:
|
||||
return self._frames_df_cache
|
||||
|
||||
|
||||
def reconstruct_subtask_spans(
|
||||
rows: Sequence[dict[str, Any]],
|
||||
*,
|
||||
episode_end_t: float | None = None,
|
||||
) -> list[dict[str, Any]]:
|
||||
"""Turn ``style="subtask"`` rows into ``{text, start, end}`` spans.
|
||||
|
||||
Each span's ``end`` is the next span's ``start``. The final span's
|
||||
``end`` defaults to its own ``start`` (zero-duration) — pass
|
||||
``episode_end_t`` to extend it to the episode's last frame instead,
|
||||
which is what downstream consumers (memory, interjection boundary
|
||||
selection) expect.
|
||||
|
||||
Used by Module 1 (plan-update pass) and Module 2 (interjection
|
||||
anchoring), which both need the same span shape.
|
||||
"""
|
||||
sorted_rows = sorted(
|
||||
(r for r in rows if r.get("style") == "subtask"),
|
||||
key=lambda r: float(r["timestamp"]),
|
||||
)
|
||||
spans: list[dict[str, Any]] = []
|
||||
for r in sorted_rows:
|
||||
t = float(r["timestamp"])
|
||||
if spans:
|
||||
spans[-1]["end"] = t
|
||||
spans.append({"text": r.get("content") or "", "start": t, "end": t})
|
||||
if spans and episode_end_t is not None and float(episode_end_t) > spans[-1]["start"]:
|
||||
spans[-1]["end"] = float(episode_end_t)
|
||||
return spans
|
||||
|
||||
|
||||
def snap_to_frame(t: float, frame_timestamps: Sequence[float]) -> float:
|
||||
"""Snap an arbitrary float to the nearest exact source frame timestamp.
|
||||
|
||||
|
||||
@@ -61,10 +61,16 @@ class EpisodeStaging:
|
||||
def write(self, module: ModuleName, rows: Iterable[dict[str, Any]]) -> Path:
|
||||
path = self.path_for(module)
|
||||
path.parent.mkdir(parents=True, exist_ok=True)
|
||||
with path.open("w", encoding="utf-8") as f:
|
||||
# Atomic replace: a crash mid-write would otherwise leave a
|
||||
# half-written JSONL file that ``read()`` would then fail to
|
||||
# parse. Write to a sibling .tmp and rename so the target path
|
||||
# only ever points at a complete file.
|
||||
tmp_path = path.with_suffix(path.suffix + ".tmp")
|
||||
with tmp_path.open("w", encoding="utf-8") as f:
|
||||
for row in rows:
|
||||
f.write(json.dumps(row, ensure_ascii=False, sort_keys=True))
|
||||
f.write("\n")
|
||||
tmp_path.replace(path)
|
||||
return path
|
||||
|
||||
def read(self, module: ModuleName) -> list[dict[str, Any]]:
|
||||
|
||||
@@ -99,6 +99,13 @@ def _normalize_persistent_row(row: dict[str, Any]) -> dict[str, Any]:
|
||||
)
|
||||
if "timestamp" not in row:
|
||||
raise ValueError(f"persistent row missing timestamp: {row!r}")
|
||||
if "role" not in row:
|
||||
# Surface a friendly error from the writer rather than letting
|
||||
# the raw KeyError bubble out of the dict access below — modules
|
||||
# are expected to always emit ``role``, but the validator
|
||||
# currently doesn't check this so a future bug would otherwise
|
||||
# be hard to triage.
|
||||
raise ValueError(f"persistent row missing role: {row!r}")
|
||||
camera = row.get("camera")
|
||||
validate_camera_field(style, camera)
|
||||
return {
|
||||
@@ -120,6 +127,8 @@ def _normalize_event_row(row: dict[str, Any]) -> dict[str, Any]:
|
||||
)
|
||||
if column_for_style(style) != LANGUAGE_EVENTS:
|
||||
raise ValueError(f"event row with style {style!r} would not route to language_events")
|
||||
if "role" not in row:
|
||||
raise ValueError(f"event row missing role: {row!r}")
|
||||
camera = row.get("camera")
|
||||
validate_camera_field(style, camera)
|
||||
return {
|
||||
|
||||
Reference in New Issue
Block a user