review: skip-count fix, atomic writes, dedupe span reconstruction, role guards

**#1 Plan-update phase reports correct skip count.**
``_run_plan_update_phase`` only ran ``run_plan_updates`` for episodes
with at least one interjection, yet hardcoded ``episodes_skipped=0``,
so the summary undercounted skipped episodes. Now returns
``len(records) - processed`` so that processed + skipped == total.

**#2 ``run_hf_job.py`` installs ``openai``.**
The ``CMD`` block runs ``pip install --no-deps`` for the lerobot branch,
then explicitly lists the transitive deps. ``openai`` was missing, and
since ``VlmConfig.backend`` defaults to ``"openai"``, the job would have
``ImportError``'d when ``vlm_client._make_openai_client`` ran.

**#3 Dedupe subtask-span reconstruction.**
Module 1's ``_reconstruct_subtasks_from_rows`` (no ``and spans`` guard)
and Module 2's ``_read_subtask_spans`` (with the guard) had
near-identical logic. Promoted to ``reconstruct_subtask_spans`` in
``reader.py`` using the safer guarded form. Both modules now import
the single helper.

**#5 Atomic ``staging.py`` JSONL writes.**
Mirroring the parquet-writer fix from an earlier review round:
``EpisodeStaging.write`` now writes to a sibling ``.tmp`` file and
swaps it in with ``Path.replace``. A crash mid-write can no longer
leave a half-written JSONL that ``read()`` would then fail to parse.
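
The pattern in isolation (a minimal sketch, not the exact
``EpisodeStaging`` code; ``Path.replace`` wraps ``os.replace``, an
atomic rename when source and target live on the same filesystem):

```python
import json
from pathlib import Path
from typing import Any, Iterable

def atomic_write_jsonl(path: Path, rows: Iterable[dict[str, Any]]) -> None:
    # Stage the complete payload in a sibling temp file first...
    tmp = path.with_suffix(path.suffix + ".tmp")
    with tmp.open("w", encoding="utf-8") as f:
        for row in rows:
            f.write(json.dumps(row, ensure_ascii=False, sort_keys=True) + "\n")
    # ...then rename over the target: readers see the old file or the
    # new one, never a truncated in-between state.
    tmp.replace(path)
```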

**#6 Atomic ``info.json`` write.**
Same pattern in ``executor._ensure_annotation_metadata_in_info`` —
``info.json`` is load-bearing for dataset metadata, so partial writes
brick the dataset.

**#7 Writer's role-key guard.**
``_normalize_persistent_row`` and ``_normalize_event_row`` accessed
``row["role"]`` directly while every other field used ``.get()``.
Both now pre-validate ``"role" in row`` and raise a friendly
``ValueError`` naming the row, so a future module that accidentally
drops ``role`` fails with a triagable message instead of a bare
``KeyError`` deep in the writer.
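
A toy repro of the difference (hypothetical row; the real checks sit in
the two normalizers, per the diff below):

```python
row = {"style": "subtask", "timestamp": 3.2, "content": "grasp the cup"}

# Before: row["role"] raised a bare KeyError deep in normalization.
# After: the guard fails first and quotes the offending row:
if "role" not in row:
    raise ValueError(f"persistent row missing role: {row!r}")
# ValueError: persistent row missing role: {'style': 'subtask', ...}
```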

**#8 Last subtask span's ``end`` extends to episode end.**
``reconstruct_subtask_spans`` (the new shared helper) takes an optional
``episode_end_t``. When provided, the final span's ``end`` is closed
to that timestamp instead of equalling its own ``start`` (zero
duration). Module 1's plan-update pass and Module 2's interjection
anchoring both pass ``record.frame_timestamps[-1]``, so downstream
"current subtask at refresh_t" lookups no longer miss refreshes that
land inside the final span.
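
A toy trace of the shared helper (rows trimmed to the fields it reads;
outputs follow from the implementation in the diff below):

```python
# reconstruct_subtask_spans as defined in reader.py (see diff below)
rows = [
    {"style": "subtask", "timestamp": 0.0, "content": "reach"},
    {"style": "subtask", "timestamp": 2.5, "content": "grasp"},
    {"style": "subtask", "timestamp": 5.0, "content": "place"},
]

reconstruct_subtask_spans(rows)
# [{'text': 'reach', 'start': 0.0, 'end': 2.5},
#  {'text': 'grasp', 'start': 2.5, 'end': 5.0},
#  {'text': 'place', 'start': 5.0, 'end': 5.0}]  <- zero-duration tail

reconstruct_subtask_spans(rows, episode_end_t=7.2)
# same first two spans, but the tail becomes
# {'text': 'place', 'start': 5.0, 'end': 7.2}, so a plan refresh at
# t=6.0 now resolves to "place" instead of matching no span.
```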

Sweep: 66 passed, 0 failed. Pre-commit clean.

Co-Authored-By: Claude Opus 4.7 (1M context) <noreply@anthropic.com>
@@ -21,16 +21,18 @@ from huggingface_hub import get_token, run_job
 token = os.environ.get("HF_TOKEN") or get_token()
 if not token:
-    raise RuntimeError(
-        "No HF token. Run `huggingface-cli login` or `export HF_TOKEN=hf_...`"
-    )
+    raise RuntimeError("No HF token. Run `huggingface-cli login` or `export HF_TOKEN=hf_...`")
 
 CMD = (
     "apt-get update -qq && apt-get install -y -qq git ffmpeg && "
     "pip install --no-deps "
     "'lerobot @ git+https://github.com/huggingface/lerobot.git@feat/language-annotation-pipeline' && "
     "pip install --upgrade-strategy only-if-needed "
-    "datasets pyarrow av jsonlines draccus gymnasium torchcodec mergedeep pyyaml-include toml typing-inspect && "
+    # Mirror lerobot's [annotations] runtime deps. ``openai`` is required
+    # because ``VlmConfig.backend`` defaults to ``"openai"`` (which talks
+    # to a vllm/transformers/ktransformers OpenAI-compatible server).
+    "datasets pyarrow av jsonlines draccus gymnasium torchcodec mergedeep pyyaml-include "
+    "toml typing-inspect openai && "
     "export VLLM_MEMORY_PROFILER_ESTIMATE_CUDAGRAPHS=0 && "
     "export VLLM_VIDEO_BACKEND=pyav && "
     "lerobot-annotate "
@@ -170,7 +170,11 @@ class Executor:
                 changed = True
         if changed:
-            info_path.write_text(json.dumps(info, indent=2))
+            # Atomic replace — info.json is load-bearing for dataset
+            # metadata, so a crash mid-write would brick the dataset.
+            tmp_info = info_path.with_suffix(info_path.suffix + ".tmp")
+            tmp_info.write_text(json.dumps(info, indent=2))
+            tmp_info.replace(info_path)
         print(
             "[annotate] meta/info.json: "
             f"language_features={list(language_feature_info())}, "
@@ -254,4 +258,10 @@ class Executor:
             if interjection_times:
                 self.module_1.run_plan_updates(record, staging, interjection_times, interjection_texts)
                 processed += 1
-        return PhaseResult(name="module_1_plan_update", episodes_processed=processed, episodes_skipped=0)
+        # Episodes without any interjections are skipped (no plan refresh
+        # needed); count them so the summary's processed+skipped == total.
+        return PhaseResult(
+            name="module_1_plan_update",
+            episodes_processed=processed,
+            episodes_skipped=len(records) - processed,
+        )
@@ -40,7 +40,7 @@ from typing import Any
 from ..config import Module2Config
 from ..frames import FrameProvider, null_provider, to_image_blocks
 from ..prompts import load as load_prompt
-from ..reader import EpisodeRecord, snap_to_frame
+from ..reader import EpisodeRecord, reconstruct_subtask_spans, snap_to_frame
 from ..staging import EpisodeStaging
 from ..vlm_client import VlmClient
 from ..writer import speech_atom
@@ -69,24 +69,11 @@ class InterjectionsAndSpeechModule:
         # Pull Module 1's subtask spans for this episode so the
         # interjection prompt can ground itself in the actual current
         # subtask at each chosen timestamp. Module 1 ran first.
-        subtask_spans = self._read_subtask_spans(staging)
+        episode_end_t = float(record.frame_timestamps[-1]) if record.frame_timestamps else None
+        subtask_spans = reconstruct_subtask_spans(staging.read("module_1"), episode_end_t=episode_end_t)
         rows.extend(self._mid_episode_interjections(record, subtask_spans))
         staging.write("module_2", rows)
 
-    @staticmethod
-    def _read_subtask_spans(staging: EpisodeStaging) -> list[dict[str, Any]]:
-        rows = [r for r in staging.read("module_1") if r.get("style") == "subtask"]
-        rows.sort(key=lambda r: float(r["timestamp"]))
-        spans: list[dict[str, Any]] = []
-        last_t: float | None = None
-        for r in rows:
-            t = float(r["timestamp"])
-            if last_t is not None and spans:
-                spans[-1]["end"] = t
-            spans.append({"text": r.get("content") or "", "start": t, "end": t})
-            last_t = t
-        return spans
-
     @staticmethod
     def _subtask_at(spans: Sequence[dict[str, Any]], t: float) -> str | None:
         current: str | None = None
@@ -31,7 +31,7 @@ from ..frames import (
     to_video_url_block,
 )
 from ..prompts import load as load_prompt
-from ..reader import EpisodeRecord, snap_to_frame
+from ..reader import EpisodeRecord, reconstruct_subtask_spans, snap_to_frame
 from ..staging import EpisodeStaging
 from ..vlm_client import VlmClient
@@ -259,7 +259,12 @@ class PlanSubtasksMemoryModule:
         without telling it what the user said).
         """
         existing = staging.read("module_1")
-        spans = self._reconstruct_subtasks_from_rows(existing)
+        # Pass the episode's last frame timestamp so the final subtask
+        # span is closed (otherwise its ``end`` equals its ``start``,
+        # zero duration, and the "current subtask at refresh_t" lookup
+        # in ``_generate_plan`` misses any refresh that lands inside it).
+        episode_end_t = float(record.frame_timestamps[-1]) if record.frame_timestamps else None
+        spans = reconstruct_subtask_spans(existing, episode_end_t=episode_end_t)
         already_planned: set[float] = {float(r["timestamp"]) for r in existing if r.get("style") == "plan"}
         new_rows = list(existing)
@@ -286,21 +291,6 @@ class PlanSubtasksMemoryModule:
             )
         staging.write("module_1", new_rows)
 
-    @staticmethod
-    def _reconstruct_subtasks_from_rows(rows: Sequence[dict[str, Any]]) -> list[dict[str, Any]]:
-        out = []
-        last_t: float | None = None
-        for row in sorted(
-            (r for r in rows if r.get("style") == "subtask"),
-            key=lambda r: float(r["timestamp"]),
-        ):
-            t = float(row["timestamp"])
-            if last_t is not None:
-                out[-1]["end"] = t
-            out.append({"text": row.get("content") or "", "start": t, "end": t})
-            last_t = t
-        return out
-
     def _generate_subtasks(self, record: EpisodeRecord, *, task: str | None = None) -> list[dict[str, Any]]:
         if record.row_count == 0 or not record.frame_timestamps:
             return []
@@ -70,6 +70,37 @@ class EpisodeRecord:
         return self._frames_df_cache
 
 
+def reconstruct_subtask_spans(
+    rows: Sequence[dict[str, Any]],
+    *,
+    episode_end_t: float | None = None,
+) -> list[dict[str, Any]]:
+    """Turn ``style="subtask"`` rows into ``{text, start, end}`` spans.
+
+    Each span's ``end`` is the next span's ``start``. The final span's
+    ``end`` defaults to its own ``start`` (zero duration); pass
+    ``episode_end_t`` to extend it to the episode's last frame instead,
+    which is what downstream consumers (memory, interjection boundary
+    selection) expect.
+
+    Used by Module 1 (plan-update pass) and Module 2 (interjection
+    anchoring), which both need the same span shape.
+    """
+    sorted_rows = sorted(
+        (r for r in rows if r.get("style") == "subtask"),
+        key=lambda r: float(r["timestamp"]),
+    )
+    spans: list[dict[str, Any]] = []
+    for r in sorted_rows:
+        t = float(r["timestamp"])
+        if spans:
+            spans[-1]["end"] = t
+        spans.append({"text": r.get("content") or "", "start": t, "end": t})
+    if spans and episode_end_t is not None and float(episode_end_t) > spans[-1]["start"]:
+        spans[-1]["end"] = float(episode_end_t)
+    return spans
+
+
 def snap_to_frame(t: float, frame_timestamps: Sequence[float]) -> float:
     """Snap an arbitrary float to the nearest exact source frame timestamp.
@@ -61,10 +61,16 @@ class EpisodeStaging:
     def write(self, module: ModuleName, rows: Iterable[dict[str, Any]]) -> Path:
         path = self.path_for(module)
         path.parent.mkdir(parents=True, exist_ok=True)
-        with path.open("w", encoding="utf-8") as f:
+        # Atomic replace: a crash mid-write would otherwise leave a
+        # half-written JSONL file that ``read()`` would then fail to
+        # parse. Write to a sibling .tmp and rename so the target path
+        # only ever points at a complete file.
+        tmp_path = path.with_suffix(path.suffix + ".tmp")
+        with tmp_path.open("w", encoding="utf-8") as f:
             for row in rows:
                 f.write(json.dumps(row, ensure_ascii=False, sort_keys=True))
                 f.write("\n")
+        tmp_path.replace(path)
         return path
 
     def read(self, module: ModuleName) -> list[dict[str, Any]]:
@@ -99,6 +99,13 @@ def _normalize_persistent_row(row: dict[str, Any]) -> dict[str, Any]:
     )
     if "timestamp" not in row:
         raise ValueError(f"persistent row missing timestamp: {row!r}")
+    if "role" not in row:
+        # Surface a friendly error from the writer rather than letting
+        # the raw KeyError bubble out of the dict access below — modules
+        # are expected to always emit ``role``, but the validator
+        # currently doesn't check this so a future bug would otherwise
+        # be hard to triage.
+        raise ValueError(f"persistent row missing role: {row!r}")
     camera = row.get("camera")
     validate_camera_field(style, camera)
     return {
@@ -120,6 +127,8 @@ def _normalize_event_row(row: dict[str, Any]) -> dict[str, Any]:
     )
     if column_for_style(style) != LANGUAGE_EVENTS:
         raise ValueError(f"event row with style {style!r} would not route to language_events")
+    if "role" not in row:
+        raise ValueError(f"event row missing role: {row!r}")
     camera = row.get("camera")
     validate_camera_field(style, camera)
     return {