mirror of
https://github.com/huggingface/lerobot.git
synced 2026-05-11 14:49:43 +00:00
feat(annotate): emit VQA per-camera and propagate camera field
Module 3 now produces one (vqa, user) + (vqa, assistant) pair per emission tick *per camera* rather than only against the dataset's first camera. Each emitted row carries the `camera` field added in PR 1 (language-columns), so the resolver can disambiguate per-camera VQA via `emitted_at(t, style=vqa, role=assistant, camera=...)` without ambiguity. - `frames.py`: `FrameProvider` Protocol gains a `camera_keys` property and a `camera_key=` argument on `frames_at` / `video_for_episode`. `VideoFrameProvider` exposes every `observation.images.*` key the dataset declares (not just the first) and keys its decode cache on `(episode, camera, timestamp)` so per-camera reads don't collide. Module 1 / 2 keep their old single-camera behaviour by leaving `camera_key=None` (falls back to the default camera). - `modules/general_vqa.py`: `run_episode` iterates `frame_provider .camera_keys` for each emission tick, builds one prompt per camera, batches all of them through the VLM, and stamps the resulting rows with `camera=<that key>`. Empty `camera_keys` (null provider) makes the module a no-op rather than silently emitting untagged rows. - `writer.py`: `_normalize_persistent_row` / `_normalize_event_row` carry `camera` through and call `validate_camera_field` so the invariant is enforced at the writer boundary. Event sort key now includes `camera` for deterministic ordering when several cameras share `(timestamp, style, role)`. `speech_atom` sets `camera=None`. - `validator.py`: `StagingValidator` gains a `dataset_camera_keys` field; `_check_camera_field` enforces the invariant and cross-checks every view-dependent row's `camera` against the dataset's known video keys. New `_check_vqa_uniqueness_per_frame_camera` flags duplicate `(vqa, role)` pairs at the same `(t, camera)`. - `lerobot_annotate.py`: passes the live frame provider's `camera_keys` into the validator so the cross-check uses the actual dataset camera set. - Tests: `_StubFrameProvider` exposes `camera_keys` and accepts the new `camera_key=` kwarg. `test_module3_vqa_unique_per_frame_and_camera` configures two cameras and asserts both are represented, that every emitted row has a `camera` tag, and that uniqueness holds per `(timestamp, camera, role)`. Co-Authored-By: Claude Opus 4.7 (1M context) <noreply@anthropic.com>
This commit is contained in:
@@ -34,10 +34,29 @@ from .reader import EpisodeRecord
|
||||
class FrameProvider(Protocol):
|
||||
"""Decodes camera frames at episode-relative timestamps."""
|
||||
|
||||
def frames_at(self, record: EpisodeRecord, timestamps: list[float]) -> list[Any]:
|
||||
"""Return one PIL.Image per timestamp; empty list if no camera available."""
|
||||
@property
|
||||
def camera_keys(self) -> list[str]:
|
||||
"""All ``observation.images.*`` feature keys this provider can decode."""
|
||||
|
||||
def video_for_episode(self, record: EpisodeRecord, max_frames: int) -> list[Any]:
|
||||
def frames_at(
|
||||
self,
|
||||
record: EpisodeRecord,
|
||||
timestamps: list[float],
|
||||
camera_key: str | None = None,
|
||||
) -> list[Any]:
|
||||
"""Return one PIL.Image per timestamp from ``camera_key`` (or default).
|
||||
|
||||
Empty list if the camera is unavailable. ``camera_key=None`` falls back
|
||||
to the provider's default camera so existing single-camera callers
|
||||
(Module 1, Module 2) keep working unchanged.
|
||||
"""
|
||||
|
||||
def video_for_episode(
|
||||
self,
|
||||
record: EpisodeRecord,
|
||||
max_frames: int,
|
||||
camera_key: str | None = None,
|
||||
) -> list[Any]:
|
||||
"""Return up to ``max_frames`` PIL images covering the whole episode.
|
||||
|
||||
Sampling is uniform across the episode duration. The returned list is
|
||||
@@ -51,10 +70,24 @@ class FrameProvider(Protocol):
|
||||
class _NullProvider:
|
||||
"""No-op provider used when the dataset has no video keys or in tests."""
|
||||
|
||||
def frames_at(self, record: EpisodeRecord, timestamps: list[float]) -> list[Any]:
|
||||
@property
|
||||
def camera_keys(self) -> list[str]:
|
||||
return []
|
||||
|
||||
def video_for_episode(self, record: EpisodeRecord, max_frames: int) -> list[Any]:
|
||||
def frames_at(
|
||||
self,
|
||||
record: EpisodeRecord,
|
||||
timestamps: list[float],
|
||||
camera_key: str | None = None,
|
||||
) -> list[Any]:
|
||||
return []
|
||||
|
||||
def video_for_episode(
|
||||
self,
|
||||
record: EpisodeRecord,
|
||||
max_frames: int,
|
||||
camera_key: str | None = None,
|
||||
) -> list[Any]:
|
||||
return []
|
||||
|
||||
|
||||
@@ -64,12 +97,18 @@ def null_provider() -> FrameProvider:
|
||||
|
||||
@dataclass
|
||||
class VideoFrameProvider:
|
||||
"""Decodes frames from the dataset's first ``observation.images.*`` stream.
|
||||
"""Decodes frames from the dataset's ``observation.images.*`` streams.
|
||||
|
||||
The first camera key is used unconditionally — Module 1/2/3 prompts care
|
||||
about *what is happening*, not which camera angle the model sees, so a
|
||||
single canonical viewpoint is enough. Override ``camera_key`` if you
|
||||
want a specific stream.
|
||||
By default the *first* camera key is used for Module 1 (subtask
|
||||
decomposition) and Module 2 (interjection scenarios) — those prompts care
|
||||
about *what is happening*, not which angle. Module 3 (VQA) instead
|
||||
iterates over every camera in :attr:`camera_keys` so each frame's
|
||||
grounded answer (bbox/keypoint/...) is tagged with the camera it was
|
||||
grounded against.
|
||||
|
||||
``camera_key`` overrides the default-camera choice but does not restrict
|
||||
:attr:`camera_keys`. Pass ``camera_key`` explicitly to ``frames_at`` /
|
||||
``video_for_episode`` to read a non-default stream.
|
||||
|
||||
Caches up to ``cache_size`` decoded frames per process to keep
|
||||
co-timestamped Module 2 + Module 1 plan-update calls cheap.
|
||||
@@ -81,24 +120,37 @@ class VideoFrameProvider:
|
||||
cache_size: int = 256
|
||||
_meta: Any = field(default=None, init=False, repr=False)
|
||||
_cache: dict = field(default_factory=dict, init=False, repr=False)
|
||||
_camera_keys: list[str] = field(default_factory=list, init=False, repr=False)
|
||||
|
||||
def __post_init__(self) -> None:
|
||||
from lerobot.datasets.dataset_metadata import LeRobotDatasetMetadata # noqa: PLC0415
|
||||
|
||||
self._meta = LeRobotDatasetMetadata(repo_id="local", root=self.root)
|
||||
keys = list(self._meta.video_keys or [])
|
||||
self._camera_keys = keys
|
||||
if self.camera_key is None:
|
||||
keys = self._meta.video_keys
|
||||
self.camera_key = keys[0] if keys else None
|
||||
|
||||
def frames_at(self, record: EpisodeRecord, timestamps: list[float]) -> list[Any]:
|
||||
if not timestamps or self.camera_key is None:
|
||||
@property
|
||||
def camera_keys(self) -> list[str]:
|
||||
"""All ``observation.images.*`` keys available on this dataset."""
|
||||
return list(self._camera_keys)
|
||||
|
||||
def frames_at(
|
||||
self,
|
||||
record: EpisodeRecord,
|
||||
timestamps: list[float],
|
||||
camera_key: str | None = None,
|
||||
) -> list[Any]:
|
||||
target = camera_key if camera_key is not None else self.camera_key
|
||||
if not timestamps or target is None:
|
||||
return []
|
||||
|
||||
out: list[Any] = []
|
||||
misses: list[float] = []
|
||||
miss_indices: list[int] = []
|
||||
for i, ts in enumerate(timestamps):
|
||||
key = (record.episode_index, round(float(ts), 6))
|
||||
key = (record.episode_index, target, round(float(ts), 6))
|
||||
cached = self._cache.get(key)
|
||||
if cached is not None:
|
||||
out.append(cached)
|
||||
@@ -108,20 +160,22 @@ class VideoFrameProvider:
|
||||
miss_indices.append(i)
|
||||
|
||||
if misses:
|
||||
decoded = self._decode(record.episode_index, misses)
|
||||
decoded = self._decode(record.episode_index, misses, target)
|
||||
# decoder may return fewer frames than requested when some
|
||||
# timestamps fall outside the video; pair what we have and
|
||||
# leave the rest as None to be filtered below.
|
||||
for i, img in zip(miss_indices, decoded):
|
||||
out[i] = img
|
||||
key = (record.episode_index, round(float(timestamps[i]), 6))
|
||||
key = (record.episode_index, target, round(float(timestamps[i]), 6))
|
||||
if len(self._cache) >= self.cache_size:
|
||||
self._cache.pop(next(iter(self._cache)))
|
||||
self._cache[key] = img
|
||||
# filter out any None left over from decode failures
|
||||
return [img for img in out if img is not None]
|
||||
|
||||
def _decode(self, episode_index: int, timestamps: list[float]) -> list[Any]:
|
||||
def _decode(
|
||||
self, episode_index: int, timestamps: list[float], camera_key: str
|
||||
) -> list[Any]:
|
||||
import os as _os # noqa: PLC0415
|
||||
|
||||
from PIL import Image # noqa: PLC0415
|
||||
@@ -129,9 +183,9 @@ class VideoFrameProvider:
|
||||
from lerobot.datasets.video_utils import decode_video_frames # noqa: PLC0415
|
||||
|
||||
ep = self._meta.episodes[episode_index]
|
||||
from_timestamp = ep[f"videos/{self.camera_key}/from_timestamp"]
|
||||
from_timestamp = ep[f"videos/{camera_key}/from_timestamp"]
|
||||
shifted = [from_timestamp + ts for ts in timestamps]
|
||||
video_path = self.root / self._meta.get_video_file_path(episode_index, self.camera_key)
|
||||
video_path = self.root / self._meta.get_video_file_path(episode_index, camera_key)
|
||||
# ``torchcodec`` import currently bad-allocs on cu128/torch-2.8 in
|
||||
# some environments; default to ``pyav`` (always available via
|
||||
# the ``av`` package) and let users override with
|
||||
@@ -156,13 +210,19 @@ class VideoFrameProvider:
|
||||
out.append(Image.fromarray(hwc, mode="RGB"))
|
||||
return out
|
||||
|
||||
def video_for_episode(self, record: EpisodeRecord, max_frames: int) -> list[Any]:
|
||||
def video_for_episode(
|
||||
self,
|
||||
record: EpisodeRecord,
|
||||
max_frames: int,
|
||||
camera_key: str | None = None,
|
||||
) -> list[Any]:
|
||||
"""Return up to ``max_frames`` images uniformly sampled across the episode.
|
||||
|
||||
The whole episode duration is covered; the model picks subtask
|
||||
boundaries from the temporal pooling it does internally.
|
||||
"""
|
||||
if max_frames <= 0 or self.camera_key is None or not record.frame_timestamps:
|
||||
target = camera_key if camera_key is not None else self.camera_key
|
||||
if max_frames <= 0 or target is None or not record.frame_timestamps:
|
||||
return []
|
||||
n_frames = min(max_frames, len(record.frame_timestamps))
|
||||
if n_frames == len(record.frame_timestamps):
|
||||
@@ -175,7 +235,7 @@ class VideoFrameProvider:
|
||||
else:
|
||||
step = (t_last - t0) / (n_frames - 1) if n_frames > 1 else 0.0
|
||||
timestamps = [float(t0 + i * step) for i in range(n_frames)]
|
||||
return self.frames_at(record, timestamps)
|
||||
return self.frames_at(record, timestamps, camera_key=target)
|
||||
|
||||
|
||||
def make_frame_provider(root: Path, camera_key: str | None = None) -> FrameProvider:
|
||||
|
||||
@@ -16,8 +16,15 @@
|
||||
"""Module 3: general VQA at a timed cadence.
|
||||
|
||||
Anchors ``K`` (question, answer) pairs to ``K`` consecutive frames per
|
||||
emission so each frame gets at most one ``(vqa, user)`` and one
|
||||
``(vqa, assistant)`` pair — keeps the resolver contract scalar.
|
||||
emission. For datasets with multiple cameras, every emission tick produces
|
||||
one ``(vqa, user)`` + ``(vqa, assistant)`` pair *per camera*: each pair is
|
||||
generated against that camera's frame and stamped with the matching
|
||||
``camera`` field on the emitted rows. The resolver disambiguates via
|
||||
``camera=...``; recipes that consume VQA do so through one sub-recipe
|
||||
per camera (see ``recipes/pi05_hirobot.yaml``).
|
||||
|
||||
Within a single (frame, camera) we still emit at most one ``(vqa, user)``
|
||||
and one ``(vqa, assistant)`` row, so the resolver contract stays scalar.
|
||||
|
||||
Question types covered (per the plan's Module 3 table): bbox, keypoint,
|
||||
count, attribute, spatial. The assistant's ``content`` is a JSON string
|
||||
@@ -98,23 +105,37 @@ class GeneralVqaModule:
|
||||
anchor_idx = _emission_anchor_indices(
|
||||
record.frame_timestamps, self.config.vqa_emission_hz, self.config.K
|
||||
)
|
||||
# Build all messages first, then issue them as a single batched
|
||||
# generate_json call so the client can fan them out concurrently.
|
||||
per_call: list[tuple[float, str, list[dict[str, Any]]]] = []
|
||||
cameras = self._target_cameras()
|
||||
if not cameras:
|
||||
# No camera available — keep behaviour parity with previous
|
||||
# text-only stub: emit nothing rather than producing untagged
|
||||
# rows that would fail validation.
|
||||
staging.write("module_3", [])
|
||||
return
|
||||
|
||||
# Build all messages first (one per (frame, camera)), then issue them
|
||||
# as a single batched generate_json call so the client can fan them
|
||||
# out concurrently.
|
||||
per_call: list[tuple[float, str, str, list[dict[str, Any]]]] = []
|
||||
for idx in anchor_idx:
|
||||
ts = float(record.frame_timestamps[idx])
|
||||
qtype = rng.choice(self.config.question_types)
|
||||
messages = self._build_messages(record, qtype, ts)
|
||||
per_call.append((ts, qtype, messages))
|
||||
for camera in cameras:
|
||||
messages = self._build_messages(record, qtype, ts, camera)
|
||||
# Skip cameras that decoded to zero frames at this ts: no point
|
||||
# asking the VLM to ground a bbox without an image.
|
||||
if not _has_image_block(messages):
|
||||
continue
|
||||
per_call.append((ts, camera, qtype, messages))
|
||||
|
||||
if not per_call:
|
||||
staging.write("module_3", [])
|
||||
return
|
||||
|
||||
results = self.vlm.generate_json([m for _, _, m in per_call])
|
||||
results = self.vlm.generate_json([m for _, _, _, m in per_call])
|
||||
|
||||
rows: list[dict[str, Any]] = []
|
||||
for (ts, _qtype, _messages), result in zip(per_call, results):
|
||||
for (ts, camera, _qtype, _messages), result in zip(per_call, results):
|
||||
qa = self._postprocess(result)
|
||||
if qa is None:
|
||||
continue
|
||||
@@ -125,6 +146,7 @@ class GeneralVqaModule:
|
||||
"content": question,
|
||||
"style": "vqa",
|
||||
"timestamp": ts,
|
||||
"camera": camera,
|
||||
"tool_calls": None,
|
||||
}
|
||||
)
|
||||
@@ -134,19 +156,35 @@ class GeneralVqaModule:
|
||||
"content": json.dumps(answer, sort_keys=True),
|
||||
"style": "vqa",
|
||||
"timestamp": ts,
|
||||
"camera": camera,
|
||||
"tool_calls": None,
|
||||
}
|
||||
)
|
||||
staging.write("module_3", rows)
|
||||
|
||||
def _target_cameras(self) -> list[str]:
|
||||
"""Return the cameras Module 3 should iterate per emission tick.
|
||||
|
||||
Defaults to every camera the provider exposes. Datasets with no
|
||||
cameras (or test/null providers) yield an empty list, which makes
|
||||
``run_episode`` a no-op.
|
||||
"""
|
||||
return list(getattr(self.frame_provider, "camera_keys", []) or [])
|
||||
|
||||
def _build_messages(
|
||||
self, record: EpisodeRecord, question_type: str, frame_timestamp: float
|
||||
self,
|
||||
record: EpisodeRecord,
|
||||
question_type: str,
|
||||
frame_timestamp: float,
|
||||
camera_key: str,
|
||||
) -> list[dict[str, Any]]:
|
||||
prompt = load_prompt("module_3_vqa").format(
|
||||
episode_task=record.episode_task,
|
||||
question_type=question_type,
|
||||
)
|
||||
images = self.frame_provider.frames_at(record, [frame_timestamp])
|
||||
images = self.frame_provider.frames_at(
|
||||
record, [frame_timestamp], camera_key=camera_key
|
||||
)
|
||||
content = [*to_image_blocks(images), {"type": "text", "text": prompt}]
|
||||
return [{"role": "user", "content": content}]
|
||||
|
||||
@@ -166,8 +204,24 @@ class GeneralVqaModule:
|
||||
return question.strip(), answer
|
||||
|
||||
def _generate_one(
|
||||
self, record: EpisodeRecord, question_type: str, frame_timestamp: float
|
||||
self,
|
||||
record: EpisodeRecord,
|
||||
question_type: str,
|
||||
frame_timestamp: float,
|
||||
camera_key: str,
|
||||
) -> tuple[str, dict[str, Any]] | None:
|
||||
messages = self._build_messages(record, question_type, frame_timestamp)
|
||||
messages = self._build_messages(record, question_type, frame_timestamp, camera_key)
|
||||
result = self.vlm.generate_json([messages])[0]
|
||||
return self._postprocess(result)
|
||||
|
||||
|
||||
def _has_image_block(messages: list[dict[str, Any]]) -> bool:
|
||||
"""Return True if any user content block is a populated image block."""
|
||||
for msg in messages:
|
||||
content = msg.get("content")
|
||||
if not isinstance(content, list):
|
||||
continue
|
||||
for block in content:
|
||||
if isinstance(block, dict) and block.get("type") == "image":
|
||||
return True
|
||||
return False
|
||||
|
||||
@@ -43,6 +43,8 @@ from lerobot.datasets.language import (
|
||||
LANGUAGE_EVENTS,
|
||||
LANGUAGE_PERSISTENT,
|
||||
column_for_style,
|
||||
is_view_dependent_style,
|
||||
validate_camera_field,
|
||||
)
|
||||
|
||||
from .reader import EpisodeRecord
|
||||
@@ -98,6 +100,11 @@ class StagingValidator:
|
||||
"""Walks the staging tree and produces a :class:`ValidationReport`."""
|
||||
|
||||
timestamp_atol: float = 0.0 # exact-match by default
|
||||
dataset_camera_keys: tuple[str, ...] | None = None
|
||||
"""Known ``observation.images.*`` keys on the dataset. When set, the
|
||||
validator additionally enforces that every view-dependent row's
|
||||
``camera`` field references one of these keys. Pass ``None`` (default)
|
||||
to skip that cross-check (e.g. in unit tests with no real dataset)."""
|
||||
|
||||
def validate(
|
||||
self,
|
||||
@@ -130,6 +137,9 @@ class StagingValidator:
|
||||
persistent: list[dict[str, Any]] = []
|
||||
for row in all_rows:
|
||||
self._check_column_routing(row, report, record.episode_index)
|
||||
self._check_camera_field(
|
||||
row, report, record.episode_index, self.dataset_camera_keys
|
||||
)
|
||||
if column_for_style(row.get("style")) == LANGUAGE_PERSISTENT:
|
||||
persistent.append(row)
|
||||
else:
|
||||
@@ -141,6 +151,59 @@ class StagingValidator:
|
||||
self._check_speech_interjection_pairs(events, report, record.episode_index)
|
||||
self._check_plan_memory_consistency(persistent, events, report, record.episode_index)
|
||||
self._check_vqa_json(events, report, record.episode_index)
|
||||
self._check_vqa_uniqueness_per_frame_camera(events, report, record.episode_index)
|
||||
|
||||
def _check_camera_field(
|
||||
self,
|
||||
row: dict[str, Any],
|
||||
report: ValidationReport,
|
||||
episode_index: int,
|
||||
dataset_camera_keys: Sequence[str] | None,
|
||||
) -> None:
|
||||
"""Enforce the camera invariant + that the key matches the dataset's cameras."""
|
||||
style = row.get("style")
|
||||
camera = row.get("camera")
|
||||
try:
|
||||
validate_camera_field(style, camera)
|
||||
except ValueError as exc:
|
||||
report.add_error(
|
||||
f"ep={episode_index} module={row.get('_module')}: {exc}"
|
||||
)
|
||||
return
|
||||
if (
|
||||
is_view_dependent_style(style)
|
||||
and dataset_camera_keys
|
||||
and camera not in dataset_camera_keys
|
||||
):
|
||||
report.add_error(
|
||||
f"ep={episode_index} module={row.get('_module')}: camera {camera!r} on style "
|
||||
f"{style!r} is not one of the dataset's video keys {sorted(dataset_camera_keys)!r}"
|
||||
)
|
||||
|
||||
def _check_vqa_uniqueness_per_frame_camera(
|
||||
self,
|
||||
events: Iterable[dict[str, Any]],
|
||||
report: ValidationReport,
|
||||
episode_index: int,
|
||||
) -> None:
|
||||
"""Ensure at most one (vqa, user) and one (vqa, assistant) per (t, camera)."""
|
||||
counts: dict[tuple[float, str, str], int] = {}
|
||||
for row in events:
|
||||
if row.get("style") != "vqa":
|
||||
continue
|
||||
ts = row.get("timestamp")
|
||||
camera = row.get("camera")
|
||||
role = row.get("role")
|
||||
if ts is None or camera is None or role is None:
|
||||
continue # other validators flag these
|
||||
key = (float(ts), str(camera), str(role))
|
||||
counts[key] = counts.get(key, 0) + 1
|
||||
for (ts, camera, role), n in counts.items():
|
||||
if n > 1:
|
||||
report.add_error(
|
||||
f"ep={episode_index}: {n} duplicate vqa rows at t={ts} "
|
||||
f"camera={camera!r} role={role!r}; expected at most one per (t, camera, role)"
|
||||
)
|
||||
|
||||
def _check_column_routing(
|
||||
self,
|
||||
|
||||
@@ -55,6 +55,7 @@ from lerobot.datasets.language import (
|
||||
LANGUAGE_PERSISTENT,
|
||||
PERSISTENT_STYLES,
|
||||
column_for_style,
|
||||
validate_camera_field,
|
||||
)
|
||||
|
||||
from .reader import EpisodeRecord
|
||||
@@ -88,7 +89,11 @@ def _row_persistent_sort_key(row: dict[str, Any]) -> tuple:
|
||||
|
||||
def _row_event_sort_key(row: dict[str, Any]) -> tuple:
|
||||
# events are bucketed per-frame, but within a frame we still want determinism
|
||||
return (row.get("style") or "", row.get("role") or "")
|
||||
return (
|
||||
row.get("style") or "",
|
||||
row.get("role") or "",
|
||||
row.get("camera") or "",
|
||||
)
|
||||
|
||||
|
||||
def _normalize_persistent_row(row: dict[str, Any]) -> dict[str, Any]:
|
||||
@@ -101,11 +106,14 @@ def _normalize_persistent_row(row: dict[str, Any]) -> dict[str, Any]:
|
||||
)
|
||||
if "timestamp" not in row:
|
||||
raise ValueError(f"persistent row missing timestamp: {row!r}")
|
||||
camera = row.get("camera")
|
||||
validate_camera_field(style, camera)
|
||||
return {
|
||||
"role": str(row["role"]),
|
||||
"content": None if row.get("content") is None else str(row["content"]),
|
||||
"style": style,
|
||||
"timestamp": float(row["timestamp"]),
|
||||
"camera": None if camera is None else str(camera),
|
||||
"tool_calls": _normalize_tool_calls(row.get("tool_calls")),
|
||||
}
|
||||
|
||||
@@ -119,10 +127,13 @@ def _normalize_event_row(row: dict[str, Any]) -> dict[str, Any]:
|
||||
)
|
||||
if column_for_style(style) != LANGUAGE_EVENTS:
|
||||
raise ValueError(f"event row with style {style!r} would not route to language_events")
|
||||
camera = row.get("camera")
|
||||
validate_camera_field(style, camera)
|
||||
return {
|
||||
"role": str(row["role"]),
|
||||
"content": None if row.get("content") is None else str(row["content"]),
|
||||
"style": style,
|
||||
"camera": None if camera is None else str(camera),
|
||||
"tool_calls": _normalize_tool_calls(row.get("tool_calls")),
|
||||
}
|
||||
|
||||
@@ -311,6 +322,7 @@ def speech_atom(timestamp: float, text: str) -> dict[str, Any]:
|
||||
"content": None,
|
||||
"style": None,
|
||||
"timestamp": float(timestamp),
|
||||
"camera": None,
|
||||
"tool_calls": [
|
||||
{
|
||||
"type": "function",
|
||||
|
||||
@@ -72,7 +72,9 @@ def annotate(cfg: AnnotationPipelineConfig) -> None:
|
||||
)
|
||||
module_3 = GeneralVqaModule(vlm=vlm, config=cfg.module_3, seed=cfg.seed, frame_provider=frame_provider)
|
||||
writer = LanguageColumnsWriter()
|
||||
validator = StagingValidator()
|
||||
validator = StagingValidator(
|
||||
dataset_camera_keys=tuple(getattr(frame_provider, "camera_keys", []) or []) or None,
|
||||
)
|
||||
|
||||
executor = Executor(
|
||||
config=cfg,
|
||||
|
||||
@@ -44,15 +44,20 @@ class _StubFrameProvider:
|
||||
"""Returns one sentinel object per requested timestamp."""
|
||||
|
||||
sentinel: Any = field(default_factory=lambda: object())
|
||||
calls: list[tuple[int, tuple[float, ...]]] = field(default_factory=list)
|
||||
video_calls: list[tuple[int, int]] = field(default_factory=list)
|
||||
cameras: tuple[str, ...] = ("observation.images.top",)
|
||||
calls: list[tuple[int, tuple[float, ...], str | None]] = field(default_factory=list)
|
||||
video_calls: list[tuple[int, int, str | None]] = field(default_factory=list)
|
||||
|
||||
def frames_at(self, record, timestamps):
|
||||
self.calls.append((record.episode_index, tuple(timestamps)))
|
||||
@property
|
||||
def camera_keys(self) -> list[str]:
|
||||
return list(self.cameras)
|
||||
|
||||
def frames_at(self, record, timestamps, camera_key=None):
|
||||
self.calls.append((record.episode_index, tuple(timestamps), camera_key))
|
||||
return [self.sentinel] * len(timestamps)
|
||||
|
||||
def video_for_episode(self, record, max_frames):
|
||||
self.video_calls.append((record.episode_index, max_frames))
|
||||
def video_for_episode(self, record, max_frames, camera_key=None):
|
||||
self.video_calls.append((record.episode_index, max_frames, camera_key))
|
||||
n = min(max_frames, len(record.frame_timestamps))
|
||||
return [self.sentinel] * n
|
||||
|
||||
@@ -148,7 +153,7 @@ def test_module2_mid_episode_emits_paired_interjection_and_speech(
|
||||
assert any(abs(s["timestamp"] - inter_t) < 1e-9 for s in speeches)
|
||||
|
||||
|
||||
def test_module3_vqa_unique_per_frame(single_episode_root: Path, tmp_path: Path) -> None:
|
||||
def test_module3_vqa_unique_per_frame_and_camera(single_episode_root: Path, tmp_path: Path) -> None:
|
||||
payload = {
|
||||
"question": "How many cups?",
|
||||
"answer": {"label": "cup", "count": 2, "note": "white & blue"},
|
||||
@@ -158,19 +163,34 @@ def test_module3_vqa_unique_per_frame(single_episode_root: Path, tmp_path: Path)
|
||||
vlm=vlm,
|
||||
config=Module3Config(vqa_emission_hz=1.0, K=3),
|
||||
seed=1,
|
||||
frame_provider=_StubFrameProvider(
|
||||
cameras=("observation.images.top", "observation.images.wrist")
|
||||
),
|
||||
)
|
||||
record = next(iter_episodes(single_episode_root))
|
||||
staging = EpisodeStaging(tmp_path / "stage", record.episode_index)
|
||||
module.run_episode(record, staging)
|
||||
rows = staging.read("module_3")
|
||||
user_ts = [r["timestamp"] for r in rows if r["role"] == "user" and r["style"] == "vqa"]
|
||||
assistant_ts = [r["timestamp"] for r in rows if r["role"] == "assistant" and r["style"] == "vqa"]
|
||||
# at most one user (vqa) per frame; same for assistant
|
||||
assert len(user_ts) == len(set(user_ts))
|
||||
assert len(assistant_ts) == len(set(assistant_ts))
|
||||
# every vqa row must carry a camera tag and one of the configured cameras
|
||||
for r in rows:
|
||||
assert r["style"] == "vqa"
|
||||
assert r.get("camera") in {"observation.images.top", "observation.images.wrist"}
|
||||
# at most one (vqa, user) and one (vqa, assistant) per (timestamp, camera)
|
||||
user_keys = [
|
||||
(r["timestamp"], r["camera"]) for r in rows if r["role"] == "user" and r["style"] == "vqa"
|
||||
]
|
||||
assistant_keys = [
|
||||
(r["timestamp"], r["camera"])
|
||||
for r in rows
|
||||
if r["role"] == "assistant" and r["style"] == "vqa"
|
||||
]
|
||||
assert len(user_keys) == len(set(user_keys))
|
||||
assert len(assistant_keys) == len(set(assistant_keys))
|
||||
# both cameras must be represented
|
||||
assert {c for _, c in user_keys} == {"observation.images.top", "observation.images.wrist"}
|
||||
# every emitted timestamp must be an exact source frame timestamp
|
||||
frame_set = set(record.frame_timestamps)
|
||||
for ts in user_ts + assistant_ts:
|
||||
for ts, _ in user_keys + assistant_keys:
|
||||
assert ts in frame_set
|
||||
|
||||
|
||||
@@ -254,11 +274,12 @@ def test_module3_attaches_frame_image_block_to_prompt(single_episode_root: Path,
|
||||
assert len(image_blocks) == 1, f"expected 1 image block per VQA prompt, got {content}"
|
||||
assert image_blocks[0]["image"] is provider.sentinel
|
||||
assert len(text_blocks) == 1
|
||||
# provider was called once per emission with the exact emission timestamp
|
||||
for ep_idx, ts_tuple in provider.calls:
|
||||
# provider was called once per emission per camera with the exact emission timestamp
|
||||
for ep_idx, ts_tuple, camera in provider.calls:
|
||||
assert ep_idx == record.episode_index
|
||||
assert len(ts_tuple) == 1
|
||||
assert ts_tuple[0] in record.frame_timestamps
|
||||
assert camera in provider.cameras
|
||||
|
||||
|
||||
def test_module3_assistant_content_is_valid_json(single_episode_root: Path, tmp_path: Path) -> None:
|
||||
@@ -271,6 +292,7 @@ def test_module3_assistant_content_is_valid_json(single_episode_root: Path, tmp_
|
||||
vlm=vlm,
|
||||
config=Module3Config(vqa_emission_hz=1.0, K=2),
|
||||
seed=2,
|
||||
frame_provider=_StubFrameProvider(),
|
||||
)
|
||||
record = next(iter_episodes(single_episode_root))
|
||||
staging = EpisodeStaging(tmp_path / "stage", record.episode_index)
|
||||
|
||||
Reference in New Issue
Block a user