mirror of
https://github.com/huggingface/lerobot.git
synced 2026-05-20 19:19:56 +00:00
feat(annotate): emit VQA per-camera and propagate camera field
Module 3 now produces one (vqa, user) + (vqa, assistant) pair per emission tick *per camera* rather than only against the dataset's first camera. Each emitted row carries the `camera` field added in PR 1 (language-columns), so the resolver can disambiguate per-camera VQA via `emitted_at(t, style=vqa, role=assistant, camera=...)` without ambiguity. - `frames.py`: `FrameProvider` Protocol gains a `camera_keys` property and a `camera_key=` argument on `frames_at` / `video_for_episode`. `VideoFrameProvider` exposes every `observation.images.*` key the dataset declares (not just the first) and keys its decode cache on `(episode, camera, timestamp)` so per-camera reads don't collide. Module 1 / 2 keep their old single-camera behaviour by leaving `camera_key=None` (falls back to the default camera). - `modules/general_vqa.py`: `run_episode` iterates `frame_provider .camera_keys` for each emission tick, builds one prompt per camera, batches all of them through the VLM, and stamps the resulting rows with `camera=<that key>`. Empty `camera_keys` (null provider) makes the module a no-op rather than silently emitting untagged rows. - `writer.py`: `_normalize_persistent_row` / `_normalize_event_row` carry `camera` through and call `validate_camera_field` so the invariant is enforced at the writer boundary. Event sort key now includes `camera` for deterministic ordering when several cameras share `(timestamp, style, role)`. `speech_atom` sets `camera=None`. - `validator.py`: `StagingValidator` gains a `dataset_camera_keys` field; `_check_camera_field` enforces the invariant and cross-checks every view-dependent row's `camera` against the dataset's known video keys. New `_check_vqa_uniqueness_per_frame_camera` flags duplicate `(vqa, role)` pairs at the same `(t, camera)`. - `lerobot_annotate.py`: passes the live frame provider's `camera_keys` into the validator so the cross-check uses the actual dataset camera set. - Tests: `_StubFrameProvider` exposes `camera_keys` and accepts the new `camera_key=` kwarg. `test_module3_vqa_unique_per_frame_and_camera` configures two cameras and asserts both are represented, that every emitted row has a `camera` tag, and that uniqueness holds per `(timestamp, camera, role)`. Co-Authored-By: Claude Opus 4.7 (1M context) <noreply@anthropic.com>
This commit is contained in:
@@ -34,10 +34,29 @@ from .reader import EpisodeRecord
|
|||||||
class FrameProvider(Protocol):
|
class FrameProvider(Protocol):
|
||||||
"""Decodes camera frames at episode-relative timestamps."""
|
"""Decodes camera frames at episode-relative timestamps."""
|
||||||
|
|
||||||
def frames_at(self, record: EpisodeRecord, timestamps: list[float]) -> list[Any]:
|
@property
|
||||||
"""Return one PIL.Image per timestamp; empty list if no camera available."""
|
def camera_keys(self) -> list[str]:
|
||||||
|
"""All ``observation.images.*`` feature keys this provider can decode."""
|
||||||
|
|
||||||
def video_for_episode(self, record: EpisodeRecord, max_frames: int) -> list[Any]:
|
def frames_at(
|
||||||
|
self,
|
||||||
|
record: EpisodeRecord,
|
||||||
|
timestamps: list[float],
|
||||||
|
camera_key: str | None = None,
|
||||||
|
) -> list[Any]:
|
||||||
|
"""Return one PIL.Image per timestamp from ``camera_key`` (or default).
|
||||||
|
|
||||||
|
Empty list if the camera is unavailable. ``camera_key=None`` falls back
|
||||||
|
to the provider's default camera so existing single-camera callers
|
||||||
|
(Module 1, Module 2) keep working unchanged.
|
||||||
|
"""
|
||||||
|
|
||||||
|
def video_for_episode(
|
||||||
|
self,
|
||||||
|
record: EpisodeRecord,
|
||||||
|
max_frames: int,
|
||||||
|
camera_key: str | None = None,
|
||||||
|
) -> list[Any]:
|
||||||
"""Return up to ``max_frames`` PIL images covering the whole episode.
|
"""Return up to ``max_frames`` PIL images covering the whole episode.
|
||||||
|
|
||||||
Sampling is uniform across the episode duration. The returned list is
|
Sampling is uniform across the episode duration. The returned list is
|
||||||
@@ -51,10 +70,24 @@ class FrameProvider(Protocol):
|
|||||||
class _NullProvider:
|
class _NullProvider:
|
||||||
"""No-op provider used when the dataset has no video keys or in tests."""
|
"""No-op provider used when the dataset has no video keys or in tests."""
|
||||||
|
|
||||||
def frames_at(self, record: EpisodeRecord, timestamps: list[float]) -> list[Any]:
|
@property
|
||||||
|
def camera_keys(self) -> list[str]:
|
||||||
return []
|
return []
|
||||||
|
|
||||||
def video_for_episode(self, record: EpisodeRecord, max_frames: int) -> list[Any]:
|
def frames_at(
|
||||||
|
self,
|
||||||
|
record: EpisodeRecord,
|
||||||
|
timestamps: list[float],
|
||||||
|
camera_key: str | None = None,
|
||||||
|
) -> list[Any]:
|
||||||
|
return []
|
||||||
|
|
||||||
|
def video_for_episode(
|
||||||
|
self,
|
||||||
|
record: EpisodeRecord,
|
||||||
|
max_frames: int,
|
||||||
|
camera_key: str | None = None,
|
||||||
|
) -> list[Any]:
|
||||||
return []
|
return []
|
||||||
|
|
||||||
|
|
||||||
@@ -64,12 +97,18 @@ def null_provider() -> FrameProvider:
|
|||||||
|
|
||||||
@dataclass
|
@dataclass
|
||||||
class VideoFrameProvider:
|
class VideoFrameProvider:
|
||||||
"""Decodes frames from the dataset's first ``observation.images.*`` stream.
|
"""Decodes frames from the dataset's ``observation.images.*`` streams.
|
||||||
|
|
||||||
The first camera key is used unconditionally — Module 1/2/3 prompts care
|
By default the *first* camera key is used for Module 1 (subtask
|
||||||
about *what is happening*, not which camera angle the model sees, so a
|
decomposition) and Module 2 (interjection scenarios) — those prompts care
|
||||||
single canonical viewpoint is enough. Override ``camera_key`` if you
|
about *what is happening*, not which angle. Module 3 (VQA) instead
|
||||||
want a specific stream.
|
iterates over every camera in :attr:`camera_keys` so each frame's
|
||||||
|
grounded answer (bbox/keypoint/...) is tagged with the camera it was
|
||||||
|
grounded against.
|
||||||
|
|
||||||
|
``camera_key`` overrides the default-camera choice but does not restrict
|
||||||
|
:attr:`camera_keys`. Pass ``camera_key`` explicitly to ``frames_at`` /
|
||||||
|
``video_for_episode`` to read a non-default stream.
|
||||||
|
|
||||||
Caches up to ``cache_size`` decoded frames per process to keep
|
Caches up to ``cache_size`` decoded frames per process to keep
|
||||||
co-timestamped Module 2 + Module 1 plan-update calls cheap.
|
co-timestamped Module 2 + Module 1 plan-update calls cheap.
|
||||||
@@ -81,24 +120,37 @@ class VideoFrameProvider:
|
|||||||
cache_size: int = 256
|
cache_size: int = 256
|
||||||
_meta: Any = field(default=None, init=False, repr=False)
|
_meta: Any = field(default=None, init=False, repr=False)
|
||||||
_cache: dict = field(default_factory=dict, init=False, repr=False)
|
_cache: dict = field(default_factory=dict, init=False, repr=False)
|
||||||
|
_camera_keys: list[str] = field(default_factory=list, init=False, repr=False)
|
||||||
|
|
||||||
def __post_init__(self) -> None:
|
def __post_init__(self) -> None:
|
||||||
from lerobot.datasets.dataset_metadata import LeRobotDatasetMetadata # noqa: PLC0415
|
from lerobot.datasets.dataset_metadata import LeRobotDatasetMetadata # noqa: PLC0415
|
||||||
|
|
||||||
self._meta = LeRobotDatasetMetadata(repo_id="local", root=self.root)
|
self._meta = LeRobotDatasetMetadata(repo_id="local", root=self.root)
|
||||||
|
keys = list(self._meta.video_keys or [])
|
||||||
|
self._camera_keys = keys
|
||||||
if self.camera_key is None:
|
if self.camera_key is None:
|
||||||
keys = self._meta.video_keys
|
|
||||||
self.camera_key = keys[0] if keys else None
|
self.camera_key = keys[0] if keys else None
|
||||||
|
|
||||||
def frames_at(self, record: EpisodeRecord, timestamps: list[float]) -> list[Any]:
|
@property
|
||||||
if not timestamps or self.camera_key is None:
|
def camera_keys(self) -> list[str]:
|
||||||
|
"""All ``observation.images.*`` keys available on this dataset."""
|
||||||
|
return list(self._camera_keys)
|
||||||
|
|
||||||
|
def frames_at(
|
||||||
|
self,
|
||||||
|
record: EpisodeRecord,
|
||||||
|
timestamps: list[float],
|
||||||
|
camera_key: str | None = None,
|
||||||
|
) -> list[Any]:
|
||||||
|
target = camera_key if camera_key is not None else self.camera_key
|
||||||
|
if not timestamps or target is None:
|
||||||
return []
|
return []
|
||||||
|
|
||||||
out: list[Any] = []
|
out: list[Any] = []
|
||||||
misses: list[float] = []
|
misses: list[float] = []
|
||||||
miss_indices: list[int] = []
|
miss_indices: list[int] = []
|
||||||
for i, ts in enumerate(timestamps):
|
for i, ts in enumerate(timestamps):
|
||||||
key = (record.episode_index, round(float(ts), 6))
|
key = (record.episode_index, target, round(float(ts), 6))
|
||||||
cached = self._cache.get(key)
|
cached = self._cache.get(key)
|
||||||
if cached is not None:
|
if cached is not None:
|
||||||
out.append(cached)
|
out.append(cached)
|
||||||
@@ -108,20 +160,22 @@ class VideoFrameProvider:
|
|||||||
miss_indices.append(i)
|
miss_indices.append(i)
|
||||||
|
|
||||||
if misses:
|
if misses:
|
||||||
decoded = self._decode(record.episode_index, misses)
|
decoded = self._decode(record.episode_index, misses, target)
|
||||||
# decoder may return fewer frames than requested when some
|
# decoder may return fewer frames than requested when some
|
||||||
# timestamps fall outside the video; pair what we have and
|
# timestamps fall outside the video; pair what we have and
|
||||||
# leave the rest as None to be filtered below.
|
# leave the rest as None to be filtered below.
|
||||||
for i, img in zip(miss_indices, decoded):
|
for i, img in zip(miss_indices, decoded):
|
||||||
out[i] = img
|
out[i] = img
|
||||||
key = (record.episode_index, round(float(timestamps[i]), 6))
|
key = (record.episode_index, target, round(float(timestamps[i]), 6))
|
||||||
if len(self._cache) >= self.cache_size:
|
if len(self._cache) >= self.cache_size:
|
||||||
self._cache.pop(next(iter(self._cache)))
|
self._cache.pop(next(iter(self._cache)))
|
||||||
self._cache[key] = img
|
self._cache[key] = img
|
||||||
# filter out any None left over from decode failures
|
# filter out any None left over from decode failures
|
||||||
return [img for img in out if img is not None]
|
return [img for img in out if img is not None]
|
||||||
|
|
||||||
def _decode(self, episode_index: int, timestamps: list[float]) -> list[Any]:
|
def _decode(
|
||||||
|
self, episode_index: int, timestamps: list[float], camera_key: str
|
||||||
|
) -> list[Any]:
|
||||||
import os as _os # noqa: PLC0415
|
import os as _os # noqa: PLC0415
|
||||||
|
|
||||||
from PIL import Image # noqa: PLC0415
|
from PIL import Image # noqa: PLC0415
|
||||||
@@ -129,9 +183,9 @@ class VideoFrameProvider:
|
|||||||
from lerobot.datasets.video_utils import decode_video_frames # noqa: PLC0415
|
from lerobot.datasets.video_utils import decode_video_frames # noqa: PLC0415
|
||||||
|
|
||||||
ep = self._meta.episodes[episode_index]
|
ep = self._meta.episodes[episode_index]
|
||||||
from_timestamp = ep[f"videos/{self.camera_key}/from_timestamp"]
|
from_timestamp = ep[f"videos/{camera_key}/from_timestamp"]
|
||||||
shifted = [from_timestamp + ts for ts in timestamps]
|
shifted = [from_timestamp + ts for ts in timestamps]
|
||||||
video_path = self.root / self._meta.get_video_file_path(episode_index, self.camera_key)
|
video_path = self.root / self._meta.get_video_file_path(episode_index, camera_key)
|
||||||
# ``torchcodec`` import currently bad-allocs on cu128/torch-2.8 in
|
# ``torchcodec`` import currently bad-allocs on cu128/torch-2.8 in
|
||||||
# some environments; default to ``pyav`` (always available via
|
# some environments; default to ``pyav`` (always available via
|
||||||
# the ``av`` package) and let users override with
|
# the ``av`` package) and let users override with
|
||||||
@@ -156,13 +210,19 @@ class VideoFrameProvider:
|
|||||||
out.append(Image.fromarray(hwc, mode="RGB"))
|
out.append(Image.fromarray(hwc, mode="RGB"))
|
||||||
return out
|
return out
|
||||||
|
|
||||||
def video_for_episode(self, record: EpisodeRecord, max_frames: int) -> list[Any]:
|
def video_for_episode(
|
||||||
|
self,
|
||||||
|
record: EpisodeRecord,
|
||||||
|
max_frames: int,
|
||||||
|
camera_key: str | None = None,
|
||||||
|
) -> list[Any]:
|
||||||
"""Return up to ``max_frames`` images uniformly sampled across the episode.
|
"""Return up to ``max_frames`` images uniformly sampled across the episode.
|
||||||
|
|
||||||
The whole episode duration is covered; the model picks subtask
|
The whole episode duration is covered; the model picks subtask
|
||||||
boundaries from the temporal pooling it does internally.
|
boundaries from the temporal pooling it does internally.
|
||||||
"""
|
"""
|
||||||
if max_frames <= 0 or self.camera_key is None or not record.frame_timestamps:
|
target = camera_key if camera_key is not None else self.camera_key
|
||||||
|
if max_frames <= 0 or target is None or not record.frame_timestamps:
|
||||||
return []
|
return []
|
||||||
n_frames = min(max_frames, len(record.frame_timestamps))
|
n_frames = min(max_frames, len(record.frame_timestamps))
|
||||||
if n_frames == len(record.frame_timestamps):
|
if n_frames == len(record.frame_timestamps):
|
||||||
@@ -175,7 +235,7 @@ class VideoFrameProvider:
|
|||||||
else:
|
else:
|
||||||
step = (t_last - t0) / (n_frames - 1) if n_frames > 1 else 0.0
|
step = (t_last - t0) / (n_frames - 1) if n_frames > 1 else 0.0
|
||||||
timestamps = [float(t0 + i * step) for i in range(n_frames)]
|
timestamps = [float(t0 + i * step) for i in range(n_frames)]
|
||||||
return self.frames_at(record, timestamps)
|
return self.frames_at(record, timestamps, camera_key=target)
|
||||||
|
|
||||||
|
|
||||||
def make_frame_provider(root: Path, camera_key: str | None = None) -> FrameProvider:
|
def make_frame_provider(root: Path, camera_key: str | None = None) -> FrameProvider:
|
||||||
|
|||||||
@@ -16,8 +16,15 @@
|
|||||||
"""Module 3: general VQA at a timed cadence.
|
"""Module 3: general VQA at a timed cadence.
|
||||||
|
|
||||||
Anchors ``K`` (question, answer) pairs to ``K`` consecutive frames per
|
Anchors ``K`` (question, answer) pairs to ``K`` consecutive frames per
|
||||||
emission so each frame gets at most one ``(vqa, user)`` and one
|
emission. For datasets with multiple cameras, every emission tick produces
|
||||||
``(vqa, assistant)`` pair — keeps the resolver contract scalar.
|
one ``(vqa, user)`` + ``(vqa, assistant)`` pair *per camera*: each pair is
|
||||||
|
generated against that camera's frame and stamped with the matching
|
||||||
|
``camera`` field on the emitted rows. The resolver disambiguates via
|
||||||
|
``camera=...``; recipes that consume VQA do so through one sub-recipe
|
||||||
|
per camera (see ``recipes/pi05_hirobot.yaml``).
|
||||||
|
|
||||||
|
Within a single (frame, camera) we still emit at most one ``(vqa, user)``
|
||||||
|
and one ``(vqa, assistant)`` row, so the resolver contract stays scalar.
|
||||||
|
|
||||||
Question types covered (per the plan's Module 3 table): bbox, keypoint,
|
Question types covered (per the plan's Module 3 table): bbox, keypoint,
|
||||||
count, attribute, spatial. The assistant's ``content`` is a JSON string
|
count, attribute, spatial. The assistant's ``content`` is a JSON string
|
||||||
@@ -98,23 +105,37 @@ class GeneralVqaModule:
|
|||||||
anchor_idx = _emission_anchor_indices(
|
anchor_idx = _emission_anchor_indices(
|
||||||
record.frame_timestamps, self.config.vqa_emission_hz, self.config.K
|
record.frame_timestamps, self.config.vqa_emission_hz, self.config.K
|
||||||
)
|
)
|
||||||
# Build all messages first, then issue them as a single batched
|
cameras = self._target_cameras()
|
||||||
# generate_json call so the client can fan them out concurrently.
|
if not cameras:
|
||||||
per_call: list[tuple[float, str, list[dict[str, Any]]]] = []
|
# No camera available — keep behaviour parity with previous
|
||||||
|
# text-only stub: emit nothing rather than producing untagged
|
||||||
|
# rows that would fail validation.
|
||||||
|
staging.write("module_3", [])
|
||||||
|
return
|
||||||
|
|
||||||
|
# Build all messages first (one per (frame, camera)), then issue them
|
||||||
|
# as a single batched generate_json call so the client can fan them
|
||||||
|
# out concurrently.
|
||||||
|
per_call: list[tuple[float, str, str, list[dict[str, Any]]]] = []
|
||||||
for idx in anchor_idx:
|
for idx in anchor_idx:
|
||||||
ts = float(record.frame_timestamps[idx])
|
ts = float(record.frame_timestamps[idx])
|
||||||
qtype = rng.choice(self.config.question_types)
|
qtype = rng.choice(self.config.question_types)
|
||||||
messages = self._build_messages(record, qtype, ts)
|
for camera in cameras:
|
||||||
per_call.append((ts, qtype, messages))
|
messages = self._build_messages(record, qtype, ts, camera)
|
||||||
|
# Skip cameras that decoded to zero frames at this ts: no point
|
||||||
|
# asking the VLM to ground a bbox without an image.
|
||||||
|
if not _has_image_block(messages):
|
||||||
|
continue
|
||||||
|
per_call.append((ts, camera, qtype, messages))
|
||||||
|
|
||||||
if not per_call:
|
if not per_call:
|
||||||
staging.write("module_3", [])
|
staging.write("module_3", [])
|
||||||
return
|
return
|
||||||
|
|
||||||
results = self.vlm.generate_json([m for _, _, m in per_call])
|
results = self.vlm.generate_json([m for _, _, _, m in per_call])
|
||||||
|
|
||||||
rows: list[dict[str, Any]] = []
|
rows: list[dict[str, Any]] = []
|
||||||
for (ts, _qtype, _messages), result in zip(per_call, results):
|
for (ts, camera, _qtype, _messages), result in zip(per_call, results):
|
||||||
qa = self._postprocess(result)
|
qa = self._postprocess(result)
|
||||||
if qa is None:
|
if qa is None:
|
||||||
continue
|
continue
|
||||||
@@ -125,6 +146,7 @@ class GeneralVqaModule:
|
|||||||
"content": question,
|
"content": question,
|
||||||
"style": "vqa",
|
"style": "vqa",
|
||||||
"timestamp": ts,
|
"timestamp": ts,
|
||||||
|
"camera": camera,
|
||||||
"tool_calls": None,
|
"tool_calls": None,
|
||||||
}
|
}
|
||||||
)
|
)
|
||||||
@@ -134,19 +156,35 @@ class GeneralVqaModule:
|
|||||||
"content": json.dumps(answer, sort_keys=True),
|
"content": json.dumps(answer, sort_keys=True),
|
||||||
"style": "vqa",
|
"style": "vqa",
|
||||||
"timestamp": ts,
|
"timestamp": ts,
|
||||||
|
"camera": camera,
|
||||||
"tool_calls": None,
|
"tool_calls": None,
|
||||||
}
|
}
|
||||||
)
|
)
|
||||||
staging.write("module_3", rows)
|
staging.write("module_3", rows)
|
||||||
|
|
||||||
|
def _target_cameras(self) -> list[str]:
|
||||||
|
"""Return the cameras Module 3 should iterate per emission tick.
|
||||||
|
|
||||||
|
Defaults to every camera the provider exposes. Datasets with no
|
||||||
|
cameras (or test/null providers) yield an empty list, which makes
|
||||||
|
``run_episode`` a no-op.
|
||||||
|
"""
|
||||||
|
return list(getattr(self.frame_provider, "camera_keys", []) or [])
|
||||||
|
|
||||||
def _build_messages(
|
def _build_messages(
|
||||||
self, record: EpisodeRecord, question_type: str, frame_timestamp: float
|
self,
|
||||||
|
record: EpisodeRecord,
|
||||||
|
question_type: str,
|
||||||
|
frame_timestamp: float,
|
||||||
|
camera_key: str,
|
||||||
) -> list[dict[str, Any]]:
|
) -> list[dict[str, Any]]:
|
||||||
prompt = load_prompt("module_3_vqa").format(
|
prompt = load_prompt("module_3_vqa").format(
|
||||||
episode_task=record.episode_task,
|
episode_task=record.episode_task,
|
||||||
question_type=question_type,
|
question_type=question_type,
|
||||||
)
|
)
|
||||||
images = self.frame_provider.frames_at(record, [frame_timestamp])
|
images = self.frame_provider.frames_at(
|
||||||
|
record, [frame_timestamp], camera_key=camera_key
|
||||||
|
)
|
||||||
content = [*to_image_blocks(images), {"type": "text", "text": prompt}]
|
content = [*to_image_blocks(images), {"type": "text", "text": prompt}]
|
||||||
return [{"role": "user", "content": content}]
|
return [{"role": "user", "content": content}]
|
||||||
|
|
||||||
@@ -166,8 +204,24 @@ class GeneralVqaModule:
|
|||||||
return question.strip(), answer
|
return question.strip(), answer
|
||||||
|
|
||||||
def _generate_one(
|
def _generate_one(
|
||||||
self, record: EpisodeRecord, question_type: str, frame_timestamp: float
|
self,
|
||||||
|
record: EpisodeRecord,
|
||||||
|
question_type: str,
|
||||||
|
frame_timestamp: float,
|
||||||
|
camera_key: str,
|
||||||
) -> tuple[str, dict[str, Any]] | None:
|
) -> tuple[str, dict[str, Any]] | None:
|
||||||
messages = self._build_messages(record, question_type, frame_timestamp)
|
messages = self._build_messages(record, question_type, frame_timestamp, camera_key)
|
||||||
result = self.vlm.generate_json([messages])[0]
|
result = self.vlm.generate_json([messages])[0]
|
||||||
return self._postprocess(result)
|
return self._postprocess(result)
|
||||||
|
|
||||||
|
|
||||||
|
def _has_image_block(messages: list[dict[str, Any]]) -> bool:
|
||||||
|
"""Return True if any user content block is a populated image block."""
|
||||||
|
for msg in messages:
|
||||||
|
content = msg.get("content")
|
||||||
|
if not isinstance(content, list):
|
||||||
|
continue
|
||||||
|
for block in content:
|
||||||
|
if isinstance(block, dict) and block.get("type") == "image":
|
||||||
|
return True
|
||||||
|
return False
|
||||||
|
|||||||
@@ -43,6 +43,8 @@ from lerobot.datasets.language import (
|
|||||||
LANGUAGE_EVENTS,
|
LANGUAGE_EVENTS,
|
||||||
LANGUAGE_PERSISTENT,
|
LANGUAGE_PERSISTENT,
|
||||||
column_for_style,
|
column_for_style,
|
||||||
|
is_view_dependent_style,
|
||||||
|
validate_camera_field,
|
||||||
)
|
)
|
||||||
|
|
||||||
from .reader import EpisodeRecord
|
from .reader import EpisodeRecord
|
||||||
@@ -98,6 +100,11 @@ class StagingValidator:
|
|||||||
"""Walks the staging tree and produces a :class:`ValidationReport`."""
|
"""Walks the staging tree and produces a :class:`ValidationReport`."""
|
||||||
|
|
||||||
timestamp_atol: float = 0.0 # exact-match by default
|
timestamp_atol: float = 0.0 # exact-match by default
|
||||||
|
dataset_camera_keys: tuple[str, ...] | None = None
|
||||||
|
"""Known ``observation.images.*`` keys on the dataset. When set, the
|
||||||
|
validator additionally enforces that every view-dependent row's
|
||||||
|
``camera`` field references one of these keys. Pass ``None`` (default)
|
||||||
|
to skip that cross-check (e.g. in unit tests with no real dataset)."""
|
||||||
|
|
||||||
def validate(
|
def validate(
|
||||||
self,
|
self,
|
||||||
@@ -130,6 +137,9 @@ class StagingValidator:
|
|||||||
persistent: list[dict[str, Any]] = []
|
persistent: list[dict[str, Any]] = []
|
||||||
for row in all_rows:
|
for row in all_rows:
|
||||||
self._check_column_routing(row, report, record.episode_index)
|
self._check_column_routing(row, report, record.episode_index)
|
||||||
|
self._check_camera_field(
|
||||||
|
row, report, record.episode_index, self.dataset_camera_keys
|
||||||
|
)
|
||||||
if column_for_style(row.get("style")) == LANGUAGE_PERSISTENT:
|
if column_for_style(row.get("style")) == LANGUAGE_PERSISTENT:
|
||||||
persistent.append(row)
|
persistent.append(row)
|
||||||
else:
|
else:
|
||||||
@@ -141,6 +151,59 @@ class StagingValidator:
|
|||||||
self._check_speech_interjection_pairs(events, report, record.episode_index)
|
self._check_speech_interjection_pairs(events, report, record.episode_index)
|
||||||
self._check_plan_memory_consistency(persistent, events, report, record.episode_index)
|
self._check_plan_memory_consistency(persistent, events, report, record.episode_index)
|
||||||
self._check_vqa_json(events, report, record.episode_index)
|
self._check_vqa_json(events, report, record.episode_index)
|
||||||
|
self._check_vqa_uniqueness_per_frame_camera(events, report, record.episode_index)
|
||||||
|
|
||||||
|
def _check_camera_field(
|
||||||
|
self,
|
||||||
|
row: dict[str, Any],
|
||||||
|
report: ValidationReport,
|
||||||
|
episode_index: int,
|
||||||
|
dataset_camera_keys: Sequence[str] | None,
|
||||||
|
) -> None:
|
||||||
|
"""Enforce the camera invariant + that the key matches the dataset's cameras."""
|
||||||
|
style = row.get("style")
|
||||||
|
camera = row.get("camera")
|
||||||
|
try:
|
||||||
|
validate_camera_field(style, camera)
|
||||||
|
except ValueError as exc:
|
||||||
|
report.add_error(
|
||||||
|
f"ep={episode_index} module={row.get('_module')}: {exc}"
|
||||||
|
)
|
||||||
|
return
|
||||||
|
if (
|
||||||
|
is_view_dependent_style(style)
|
||||||
|
and dataset_camera_keys
|
||||||
|
and camera not in dataset_camera_keys
|
||||||
|
):
|
||||||
|
report.add_error(
|
||||||
|
f"ep={episode_index} module={row.get('_module')}: camera {camera!r} on style "
|
||||||
|
f"{style!r} is not one of the dataset's video keys {sorted(dataset_camera_keys)!r}"
|
||||||
|
)
|
||||||
|
|
||||||
|
def _check_vqa_uniqueness_per_frame_camera(
|
||||||
|
self,
|
||||||
|
events: Iterable[dict[str, Any]],
|
||||||
|
report: ValidationReport,
|
||||||
|
episode_index: int,
|
||||||
|
) -> None:
|
||||||
|
"""Ensure at most one (vqa, user) and one (vqa, assistant) per (t, camera)."""
|
||||||
|
counts: dict[tuple[float, str, str], int] = {}
|
||||||
|
for row in events:
|
||||||
|
if row.get("style") != "vqa":
|
||||||
|
continue
|
||||||
|
ts = row.get("timestamp")
|
||||||
|
camera = row.get("camera")
|
||||||
|
role = row.get("role")
|
||||||
|
if ts is None or camera is None or role is None:
|
||||||
|
continue # other validators flag these
|
||||||
|
key = (float(ts), str(camera), str(role))
|
||||||
|
counts[key] = counts.get(key, 0) + 1
|
||||||
|
for (ts, camera, role), n in counts.items():
|
||||||
|
if n > 1:
|
||||||
|
report.add_error(
|
||||||
|
f"ep={episode_index}: {n} duplicate vqa rows at t={ts} "
|
||||||
|
f"camera={camera!r} role={role!r}; expected at most one per (t, camera, role)"
|
||||||
|
)
|
||||||
|
|
||||||
def _check_column_routing(
|
def _check_column_routing(
|
||||||
self,
|
self,
|
||||||
|
|||||||
@@ -55,6 +55,7 @@ from lerobot.datasets.language import (
|
|||||||
LANGUAGE_PERSISTENT,
|
LANGUAGE_PERSISTENT,
|
||||||
PERSISTENT_STYLES,
|
PERSISTENT_STYLES,
|
||||||
column_for_style,
|
column_for_style,
|
||||||
|
validate_camera_field,
|
||||||
)
|
)
|
||||||
|
|
||||||
from .reader import EpisodeRecord
|
from .reader import EpisodeRecord
|
||||||
@@ -88,7 +89,11 @@ def _row_persistent_sort_key(row: dict[str, Any]) -> tuple:
|
|||||||
|
|
||||||
def _row_event_sort_key(row: dict[str, Any]) -> tuple:
|
def _row_event_sort_key(row: dict[str, Any]) -> tuple:
|
||||||
# events are bucketed per-frame, but within a frame we still want determinism
|
# events are bucketed per-frame, but within a frame we still want determinism
|
||||||
return (row.get("style") or "", row.get("role") or "")
|
return (
|
||||||
|
row.get("style") or "",
|
||||||
|
row.get("role") or "",
|
||||||
|
row.get("camera") or "",
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
def _normalize_persistent_row(row: dict[str, Any]) -> dict[str, Any]:
|
def _normalize_persistent_row(row: dict[str, Any]) -> dict[str, Any]:
|
||||||
@@ -101,11 +106,14 @@ def _normalize_persistent_row(row: dict[str, Any]) -> dict[str, Any]:
|
|||||||
)
|
)
|
||||||
if "timestamp" not in row:
|
if "timestamp" not in row:
|
||||||
raise ValueError(f"persistent row missing timestamp: {row!r}")
|
raise ValueError(f"persistent row missing timestamp: {row!r}")
|
||||||
|
camera = row.get("camera")
|
||||||
|
validate_camera_field(style, camera)
|
||||||
return {
|
return {
|
||||||
"role": str(row["role"]),
|
"role": str(row["role"]),
|
||||||
"content": None if row.get("content") is None else str(row["content"]),
|
"content": None if row.get("content") is None else str(row["content"]),
|
||||||
"style": style,
|
"style": style,
|
||||||
"timestamp": float(row["timestamp"]),
|
"timestamp": float(row["timestamp"]),
|
||||||
|
"camera": None if camera is None else str(camera),
|
||||||
"tool_calls": _normalize_tool_calls(row.get("tool_calls")),
|
"tool_calls": _normalize_tool_calls(row.get("tool_calls")),
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -119,10 +127,13 @@ def _normalize_event_row(row: dict[str, Any]) -> dict[str, Any]:
|
|||||||
)
|
)
|
||||||
if column_for_style(style) != LANGUAGE_EVENTS:
|
if column_for_style(style) != LANGUAGE_EVENTS:
|
||||||
raise ValueError(f"event row with style {style!r} would not route to language_events")
|
raise ValueError(f"event row with style {style!r} would not route to language_events")
|
||||||
|
camera = row.get("camera")
|
||||||
|
validate_camera_field(style, camera)
|
||||||
return {
|
return {
|
||||||
"role": str(row["role"]),
|
"role": str(row["role"]),
|
||||||
"content": None if row.get("content") is None else str(row["content"]),
|
"content": None if row.get("content") is None else str(row["content"]),
|
||||||
"style": style,
|
"style": style,
|
||||||
|
"camera": None if camera is None else str(camera),
|
||||||
"tool_calls": _normalize_tool_calls(row.get("tool_calls")),
|
"tool_calls": _normalize_tool_calls(row.get("tool_calls")),
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -311,6 +322,7 @@ def speech_atom(timestamp: float, text: str) -> dict[str, Any]:
|
|||||||
"content": None,
|
"content": None,
|
||||||
"style": None,
|
"style": None,
|
||||||
"timestamp": float(timestamp),
|
"timestamp": float(timestamp),
|
||||||
|
"camera": None,
|
||||||
"tool_calls": [
|
"tool_calls": [
|
||||||
{
|
{
|
||||||
"type": "function",
|
"type": "function",
|
||||||
|
|||||||
@@ -72,7 +72,9 @@ def annotate(cfg: AnnotationPipelineConfig) -> None:
|
|||||||
)
|
)
|
||||||
module_3 = GeneralVqaModule(vlm=vlm, config=cfg.module_3, seed=cfg.seed, frame_provider=frame_provider)
|
module_3 = GeneralVqaModule(vlm=vlm, config=cfg.module_3, seed=cfg.seed, frame_provider=frame_provider)
|
||||||
writer = LanguageColumnsWriter()
|
writer = LanguageColumnsWriter()
|
||||||
validator = StagingValidator()
|
validator = StagingValidator(
|
||||||
|
dataset_camera_keys=tuple(getattr(frame_provider, "camera_keys", []) or []) or None,
|
||||||
|
)
|
||||||
|
|
||||||
executor = Executor(
|
executor = Executor(
|
||||||
config=cfg,
|
config=cfg,
|
||||||
|
|||||||
@@ -44,15 +44,20 @@ class _StubFrameProvider:
|
|||||||
"""Returns one sentinel object per requested timestamp."""
|
"""Returns one sentinel object per requested timestamp."""
|
||||||
|
|
||||||
sentinel: Any = field(default_factory=lambda: object())
|
sentinel: Any = field(default_factory=lambda: object())
|
||||||
calls: list[tuple[int, tuple[float, ...]]] = field(default_factory=list)
|
cameras: tuple[str, ...] = ("observation.images.top",)
|
||||||
video_calls: list[tuple[int, int]] = field(default_factory=list)
|
calls: list[tuple[int, tuple[float, ...], str | None]] = field(default_factory=list)
|
||||||
|
video_calls: list[tuple[int, int, str | None]] = field(default_factory=list)
|
||||||
|
|
||||||
def frames_at(self, record, timestamps):
|
@property
|
||||||
self.calls.append((record.episode_index, tuple(timestamps)))
|
def camera_keys(self) -> list[str]:
|
||||||
|
return list(self.cameras)
|
||||||
|
|
||||||
|
def frames_at(self, record, timestamps, camera_key=None):
|
||||||
|
self.calls.append((record.episode_index, tuple(timestamps), camera_key))
|
||||||
return [self.sentinel] * len(timestamps)
|
return [self.sentinel] * len(timestamps)
|
||||||
|
|
||||||
def video_for_episode(self, record, max_frames):
|
def video_for_episode(self, record, max_frames, camera_key=None):
|
||||||
self.video_calls.append((record.episode_index, max_frames))
|
self.video_calls.append((record.episode_index, max_frames, camera_key))
|
||||||
n = min(max_frames, len(record.frame_timestamps))
|
n = min(max_frames, len(record.frame_timestamps))
|
||||||
return [self.sentinel] * n
|
return [self.sentinel] * n
|
||||||
|
|
||||||
@@ -148,7 +153,7 @@ def test_module2_mid_episode_emits_paired_interjection_and_speech(
|
|||||||
assert any(abs(s["timestamp"] - inter_t) < 1e-9 for s in speeches)
|
assert any(abs(s["timestamp"] - inter_t) < 1e-9 for s in speeches)
|
||||||
|
|
||||||
|
|
||||||
def test_module3_vqa_unique_per_frame(single_episode_root: Path, tmp_path: Path) -> None:
|
def test_module3_vqa_unique_per_frame_and_camera(single_episode_root: Path, tmp_path: Path) -> None:
|
||||||
payload = {
|
payload = {
|
||||||
"question": "How many cups?",
|
"question": "How many cups?",
|
||||||
"answer": {"label": "cup", "count": 2, "note": "white & blue"},
|
"answer": {"label": "cup", "count": 2, "note": "white & blue"},
|
||||||
@@ -158,19 +163,34 @@ def test_module3_vqa_unique_per_frame(single_episode_root: Path, tmp_path: Path)
|
|||||||
vlm=vlm,
|
vlm=vlm,
|
||||||
config=Module3Config(vqa_emission_hz=1.0, K=3),
|
config=Module3Config(vqa_emission_hz=1.0, K=3),
|
||||||
seed=1,
|
seed=1,
|
||||||
|
frame_provider=_StubFrameProvider(
|
||||||
|
cameras=("observation.images.top", "observation.images.wrist")
|
||||||
|
),
|
||||||
)
|
)
|
||||||
record = next(iter_episodes(single_episode_root))
|
record = next(iter_episodes(single_episode_root))
|
||||||
staging = EpisodeStaging(tmp_path / "stage", record.episode_index)
|
staging = EpisodeStaging(tmp_path / "stage", record.episode_index)
|
||||||
module.run_episode(record, staging)
|
module.run_episode(record, staging)
|
||||||
rows = staging.read("module_3")
|
rows = staging.read("module_3")
|
||||||
user_ts = [r["timestamp"] for r in rows if r["role"] == "user" and r["style"] == "vqa"]
|
# every vqa row must carry a camera tag and one of the configured cameras
|
||||||
assistant_ts = [r["timestamp"] for r in rows if r["role"] == "assistant" and r["style"] == "vqa"]
|
for r in rows:
|
||||||
# at most one user (vqa) per frame; same for assistant
|
assert r["style"] == "vqa"
|
||||||
assert len(user_ts) == len(set(user_ts))
|
assert r.get("camera") in {"observation.images.top", "observation.images.wrist"}
|
||||||
assert len(assistant_ts) == len(set(assistant_ts))
|
# at most one (vqa, user) and one (vqa, assistant) per (timestamp, camera)
|
||||||
|
user_keys = [
|
||||||
|
(r["timestamp"], r["camera"]) for r in rows if r["role"] == "user" and r["style"] == "vqa"
|
||||||
|
]
|
||||||
|
assistant_keys = [
|
||||||
|
(r["timestamp"], r["camera"])
|
||||||
|
for r in rows
|
||||||
|
if r["role"] == "assistant" and r["style"] == "vqa"
|
||||||
|
]
|
||||||
|
assert len(user_keys) == len(set(user_keys))
|
||||||
|
assert len(assistant_keys) == len(set(assistant_keys))
|
||||||
|
# both cameras must be represented
|
||||||
|
assert {c for _, c in user_keys} == {"observation.images.top", "observation.images.wrist"}
|
||||||
# every emitted timestamp must be an exact source frame timestamp
|
# every emitted timestamp must be an exact source frame timestamp
|
||||||
frame_set = set(record.frame_timestamps)
|
frame_set = set(record.frame_timestamps)
|
||||||
for ts in user_ts + assistant_ts:
|
for ts, _ in user_keys + assistant_keys:
|
||||||
assert ts in frame_set
|
assert ts in frame_set
|
||||||
|
|
||||||
|
|
||||||
@@ -254,11 +274,12 @@ def test_module3_attaches_frame_image_block_to_prompt(single_episode_root: Path,
|
|||||||
assert len(image_blocks) == 1, f"expected 1 image block per VQA prompt, got {content}"
|
assert len(image_blocks) == 1, f"expected 1 image block per VQA prompt, got {content}"
|
||||||
assert image_blocks[0]["image"] is provider.sentinel
|
assert image_blocks[0]["image"] is provider.sentinel
|
||||||
assert len(text_blocks) == 1
|
assert len(text_blocks) == 1
|
||||||
# provider was called once per emission with the exact emission timestamp
|
# provider was called once per emission per camera with the exact emission timestamp
|
||||||
for ep_idx, ts_tuple in provider.calls:
|
for ep_idx, ts_tuple, camera in provider.calls:
|
||||||
assert ep_idx == record.episode_index
|
assert ep_idx == record.episode_index
|
||||||
assert len(ts_tuple) == 1
|
assert len(ts_tuple) == 1
|
||||||
assert ts_tuple[0] in record.frame_timestamps
|
assert ts_tuple[0] in record.frame_timestamps
|
||||||
|
assert camera in provider.cameras
|
||||||
|
|
||||||
|
|
||||||
def test_module3_assistant_content_is_valid_json(single_episode_root: Path, tmp_path: Path) -> None:
|
def test_module3_assistant_content_is_valid_json(single_episode_root: Path, tmp_path: Path) -> None:
|
||||||
@@ -271,6 +292,7 @@ def test_module3_assistant_content_is_valid_json(single_episode_root: Path, tmp_
|
|||||||
vlm=vlm,
|
vlm=vlm,
|
||||||
config=Module3Config(vqa_emission_hz=1.0, K=2),
|
config=Module3Config(vqa_emission_hz=1.0, K=2),
|
||||||
seed=2,
|
seed=2,
|
||||||
|
frame_provider=_StubFrameProvider(),
|
||||||
)
|
)
|
||||||
record = next(iter_episodes(single_episode_root))
|
record = next(iter_episodes(single_episode_root))
|
||||||
staging = EpisodeStaging(tmp_path / "stage", record.episode_index)
|
staging = EpisodeStaging(tmp_path / "stage", record.episode_index)
|
||||||
|
|||||||
Reference in New Issue
Block a user