mirror of
https://github.com/huggingface/lerobot.git
synced 2026-05-25 05:29:55 +00:00
fd18beb3a1
- name the three modules everywhere (plan / interjections / vqa) instead of module_1/2/3 — config classes, config fields, executor params, staging keys and phase names now carry the module name - rename examples/annotation -> examples/annotations; add the Apache header to run_hf_job.py - drop the unused GeneralVqaModule._generate_one - remove "PR 1" references from comments/docstrings - frames.py: rely on the always-defined LeRobotDatasetMetadata.camera_keys - executor.py: read/write meta/info.json via load_info / write_info - reader.py: load meta/tasks.parquet via io_utils.load_tasks - make --push_to_hub a bool; push the annotated dataset back to --repo_id - move the on-disk test dataset builder into tests/fixtures (build_annotation_dataset); run_e2e_smoke reuses it - clarify in the docs that the vqa module grounds each pair on a single frame (K = per-tick anchor count) - hoist stdlib dynamic imports to module scope Co-Authored-By: Claude Opus 4.7 (1M context) <noreply@anthropic.com>
211 lines
9.0 KiB
Python
211 lines
9.0 KiB
Python
#!/usr/bin/env python
|
|
|
|
# Copyright 2026 The HuggingFace Inc. team. All rights reserved.
|
|
#
|
|
# Licensed under the Apache License, Version 2.0 (the "License");
|
|
# you may not use this file except in compliance with the License.
|
|
# You may obtain a copy of the License at
|
|
#
|
|
# http://www.apache.org/licenses/LICENSE-2.0
|
|
#
|
|
# Unless required by applicable law or agreed to in writing, software
|
|
# distributed under the License is distributed on an "AS IS" BASIS,
|
|
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
|
# See the License for the specific language governing permissions and
|
|
# limitations under the License.
|
|
"""``interjections`` module: interjections + paired speech (EVENT styles + speech atoms).
|
|
|
|
Two sub-passes:
|
|
|
|
1. At ``t=0``, emit ONLY a speech tool-call atom (acknowledgement of the
|
|
canonical task). No interjection row — the canonical task is already the
|
|
user utterance from ``meta/tasks.parquet``.
|
|
|
|
2. For mid-episode interruptions, emit a co-timestamped pair:
|
|
{role:user, style:interjection, content:<text>}
|
|
speech atom (role:assistant, style:None, tool_calls=[say(...)])
|
|
Both rows go in ``language_events`` at the same timestamp.
|
|
|
|
The ``plan`` module's :meth:`run_plan_updates` reuses this module's
|
|
interjection timestamps to refresh the ``plan`` row at the same instant.
|
|
"""
|
|
|
|
from __future__ import annotations
|
|
|
|
import random
|
|
from collections.abc import Sequence
|
|
from dataclasses import dataclass, field
|
|
from typing import Any
|
|
|
|
from ..config import InterjectionsConfig
|
|
from ..frames import FrameProvider, null_provider, to_image_blocks
|
|
from ..prompts import load as load_prompt
|
|
from ..reader import EpisodeRecord, reconstruct_subtask_spans, snap_to_frame
|
|
from ..staging import EpisodeStaging
|
|
from ..vlm_client import VlmClient
|
|
from ..writer import speech_atom
|
|
|
|
|
|
@dataclass
class InterjectionsAndSpeechModule:
    """Produce the opening speech atom plus mid-episode interjection/speech pairs.

    Every row generated for an episode is written to staging under the
    ``"interjections"`` key by :meth:`run_episode`.
    """

    # Client used for every ``generate_json`` request issued by this module.
    vlm: VlmClient
    # Module settings; this class reads ``enabled``,
    # ``max_interjections_per_episode``, ``interjection_min_t``,
    # ``interjection_window_seconds`` and ``interjection_window_frames``.
    config: InterjectionsConfig
    # Base seed mixed into the per-episode RNG that picks subtask boundaries.
    seed: int = 1729
    # Supplies the frames attached to each interjection prompt.
    frame_provider: FrameProvider = field(default_factory=null_provider)

    @property
    def enabled(self) -> bool:
        """Mirror ``config.enabled``."""
        return self.config.enabled

    def run_episode(self, record: EpisodeRecord, staging: EpisodeStaging) -> None:
        """Build all rows for one episode and stage them under ``"interjections"``.

        When the episode has frame timestamps, a speech atom acknowledging the
        task is placed at the first timestamp (if the VLM returned usable
        text). Subtask spans are then rebuilt from the staged ``"plan"`` rows
        (the ``plan`` module is expected to have run first) so that the
        interjection prompts can name the real previous/next subtask.
        """
        timeline = record.frame_timestamps
        events: list[dict[str, Any]] = []
        last_frame_t: float | None = None

        if timeline:
            first_frame_t = float(timeline[0])
            greeting = self._initial_speech(record)
            if greeting:
                events.append(speech_atom(first_frame_t, greeting))
            last_frame_t = float(timeline[-1])

        plan_rows = staging.read("plan")
        spans = reconstruct_subtask_spans(plan_rows, episode_end_t=last_frame_t)
        events.extend(self._mid_episode_interjections(record, spans))
        staging.write("interjections", events)

    @staticmethod
    def _subtask_at(spans: Sequence[dict[str, Any]], t: float) -> str | None:
        """Return the ``"text"`` of the last leading span whose start is ``<= t``.

        The scan stops at the first span that does not satisfy
        ``start <= t``; ``None`` is returned when no span qualified.
        """
        label: str | None = None
        for span in spans:
            already_started = float(span["start"]) <= t
            if not already_started:
                break
            label = span.get("text")
        return label

    def _initial_speech(self, record: EpisodeRecord) -> str | None:
        """Ask the VLM for the task acknowledgement spoken at episode start.

        Returns the stripped ``"text"`` field of the JSON reply, or ``None``
        when the reply is not a dict, the field is not a string, or the
        string is blank.
        """
        template = load_prompt("module_2_initial_speech")
        prompt = template.format(episode_task=record.episode_task)
        conversation = [{"role": "user", "content": [{"type": "text", "text": prompt}]}]
        reply = self.vlm.generate_json([conversation])[0]
        if not isinstance(reply, dict):
            return None
        spoken = reply.get("text")
        if not isinstance(spoken, str):
            return None
        return spoken.strip() or None

    @staticmethod
    def _stripped_or_empty(value: Any) -> str:
        """Return ``value`` stripped when it is a string, otherwise ``""``."""
        return value.strip() if isinstance(value, str) else ""

    def _boundary_candidates(
        self, subtask_spans: Sequence[dict[str, Any]]
    ) -> list[tuple[float, str, str]]:
        """List ``(start, previous_text, next_text)`` for each eligible transition.

        A transition is the start of any span except the first. It is kept
        only when its start is not earlier than ``config.interjection_min_t``
        and the span it opens has non-blank text.
        """
        earliest_t = self.config.interjection_min_t
        candidates: list[tuple[float, str, str]] = []
        span_iter = iter(subtask_spans)
        previous = next(span_iter, None)
        for current in span_iter:
            # Shift first so ``continue`` below never leaves ``previous`` stale.
            before, previous = previous, current
            boundary_t = float(current["start"])
            if boundary_t < earliest_t:
                continue
            finished = (before.get("text") or "").strip()
            upcoming = (current.get("text") or "").strip()
            if upcoming:
                candidates.append((boundary_t, finished, upcoming))
        return candidates

    def _mid_episode_interjections(
        self,
        record: EpisodeRecord,
        subtask_spans: Sequence[dict[str, Any]],
    ) -> list[dict[str, Any]]:
        """Create interjection/speech pairs that agree with the recorded demo.

        The teleop recording is fixed: the robot has already carried out
        every step shown in the video. A counterfactual request such as
        "actually skip the wipe" would be contradicted by the footage that
        follows, which is the low-quality pattern observed with
        qwen36moe-10/11.

        Each interjection is therefore pinned to a subtask boundary and
        phrased as a user request for the *upcoming* subtask, so that the
        interjection text, the plan refresh and the action stream all agree.

        Returns a flat list in which each accepted boundary contributes a
        ``style="interjection"`` user row immediately followed by a speech
        atom at the same snapped timestamp. Boundaries whose VLM reply is
        malformed or blank are skipped.
        """
        limit = self.config.max_interjections_per_episode
        if limit <= 0:
            return []
        # At least one transition (subtask 0 -> subtask 1) is required.
        if len(subtask_spans) < 2:
            return []

        # Seeded per episode so that reruns (e.g. across SLURM jobs) pick the
        # same boundaries.
        rng = random.Random(f"{self.seed}:{record.episode_index}:interjection")

        # The first span's start is excluded: it is covered by the
        # initial-task speech atom.
        candidates = self._boundary_candidates(subtask_spans)
        if not candidates:
            return []

        picked = rng.sample(candidates, min(limit, len(candidates)))
        picked.sort(key=lambda candidate: candidate[0])

        rows: list[dict[str, Any]] = []
        for boundary_t, finished, upcoming in picked:
            anchor_t = snap_to_frame(boundary_t, record.frame_timestamps)
            # Frames on both sides of the boundary: the VLM sees the end of
            # the previous subtask and the beginning of the next one.
            frame_times = self._window_timestamps(anchor_t, record.frame_timestamps)
            prompt = load_prompt("module_2_interjection").format(
                episode_task=record.episode_task,
                prev_subtask=finished or "(starting from initial state)",
                next_subtask=upcoming,
                timestamp=anchor_t,
                window_seconds=self.config.interjection_window_seconds,
            )
            frames = self.frame_provider.frames_at(record, frame_times)
            blocks = [*to_image_blocks(frames), {"type": "text", "text": prompt}]
            conversation = [{"role": "user", "content": blocks}]
            reply = self.vlm.generate_json([conversation])[0]
            if not isinstance(reply, dict):
                continue

            user_text = self._stripped_or_empty(reply.get("interjection"))
            robot_text = self._stripped_or_empty(reply.get("speech"))
            # Both halves of the pair are required; drop the boundary otherwise.
            if not (user_text and robot_text):
                continue

            interjection_row = {
                "role": "user",
                "content": user_text,
                "style": "interjection",
                "timestamp": anchor_t,
                "tool_calls": None,
            }
            rows.extend([interjection_row, speech_atom(anchor_t, robot_text)])
        return rows

    def _window_timestamps(self, t_anchor: float, frame_timestamps: Sequence[float]) -> list[float]:
        """Pick a few distinct frame timestamps spread evenly around ``t_anchor``.

        ``config.interjection_window_frames`` targets are spaced across
        ``config.interjection_window_seconds``, centred on the anchor, so
        about half fall before the subtask boundary and half after it. Each
        target is clamped to ``[0.0, last frame timestamp]``, snapped with
        ``snap_to_frame`` and de-duplicated in first-seen order.

        Returns ``[t_anchor]`` when there are no frame timestamps, when only
        one frame is requested, or when nothing was collected.
        """
        if not frame_timestamps:
            return [t_anchor]

        frame_count = max(1, int(self.config.interjection_window_frames))
        if frame_count == 1:
            return [t_anchor]

        span_seconds = float(self.config.interjection_window_seconds)
        spacing = span_seconds / max(1, frame_count - 1)
        # Negative half-span: the first target sits half a window before the anchor.
        lead_offset = -span_seconds / 2.0
        final_t = float(frame_timestamps[-1])

        # ``dict.fromkeys`` drops repeats while keeping first-seen order.
        distinct = dict.fromkeys(
            snap_to_frame(
                min(final_t, max(0.0, t_anchor + lead_offset + spacing * k)),
                frame_timestamps,
            )
            for k in range(frame_count)
        )
        return list(distinct) or [t_anchor]
|