mirror of
https://github.com/huggingface/lerobot.git
synced 2026-05-16 09:09:48 +00:00
5a6aa64570
Adds a nullable `camera` field to the language row struct (both persistent
and event variants) so view-dependent styles like `vqa` can carry which
`observation.images.*` view they were grounded against. Without this,
multi-camera datasets ended up with multiple `(vqa, role)` rows at the
same timestamp that the resolver could not disambiguate.
- `language.py`: add `camera` to PERSISTENT_ROW_FIELDS / EVENT_ROW_FIELDS,
to both Arrow struct types and the HF datasets feature mappings;
introduce VIEW_DEPENDENT_STYLES = {vqa, motion, trace} plus
`is_view_dependent_style` and `validate_camera_field` helpers (camera
required if and only if the style is view-dependent).
- `language_render.py`: thread an optional `camera=` kwarg through every
resolver (`active_at`, `emitted_at`, `nth_prev`, `nth_next`) and through
`_matching_rows` / `_select_*`, so recipes can disambiguate per-camera
VQA with `emitted_at(t, style=vqa, role=assistant, camera=...)`.
Without a `camera` filter, multi-row matches keep raising the existing
ambiguity error — which is the desired behaviour on multi-camera data.
- `recipes/pi05_hirobot.yaml`: replace the single `ask_vqa` branch with
`ask_vqa_top` and `ask_vqa_wrist` per-camera sub-recipes (each carrying
the matching image block), keeping the original 0.20 budget and
documenting the customization point for datasets with different cameras.
- Tests: schema test asserts the new field order; new tests cover
`is_view_dependent_style`, `validate_camera_field` (both required and
forbidden directions), per-camera `emitted_at` filtering, and the
ambiguity error when two cameras emit `(vqa, assistant)` at the same
timestamp without a `camera=` filter. RenderMessagesStep + dataset
passthrough fixtures updated to include the new field.
- `docs/source/language_and_recipes.mdx`: document the `camera` field,
the per-camera resolver pattern, and the canonical recipe convention.
Co-Authored-By: Claude Opus 4.7 (1M context) <noreply@anthropic.com>
147 lines
5.1 KiB
Python
147 lines
5.1 KiB
Python
#!/usr/bin/env python
|
|
|
|
import numpy as np
|
|
import pandas as pd
|
|
import pyarrow as pa
|
|
import pytest
|
|
|
|
from lerobot.datasets import LeRobotDataset
|
|
from lerobot.datasets.io_utils import write_info
|
|
from lerobot.datasets.language import (
|
|
EVENT_ONLY_STYLES,
|
|
LANGUAGE_EVENTS,
|
|
LANGUAGE_PERSISTENT,
|
|
PERSISTENT_STYLES,
|
|
STYLE_REGISTRY,
|
|
VIEW_DEPENDENT_STYLES,
|
|
column_for_style,
|
|
is_view_dependent_style,
|
|
language_events_arrow_type,
|
|
language_feature_info,
|
|
language_persistent_arrow_type,
|
|
validate_camera_field,
|
|
)
|
|
from lerobot.datasets.utils import DEFAULT_DATA_PATH
|
|
|
|
|
|
def test_language_arrow_schema_has_expected_fields():
    """Both language row structs expose their fields in the canonical order.

    Persistent rows carry a ``timestamp`` and event rows do not. Both variants
    carry the nullable ``camera`` field between ``style`` and ``tool_calls``.
    """
    expected_persistent_fields = ["role", "content", "style", "timestamp", "camera", "tool_calls"]
    expected_event_fields = ["role", "content", "style", "camera", "tool_calls"]

    # Each arrow type is a list type; its element type is the row struct under test.
    persistent_struct = language_persistent_arrow_type().value_type
    event_struct = language_events_arrow_type().value_type

    assert isinstance(persistent_struct, pa.StructType)
    assert persistent_struct.names == expected_persistent_fields

    assert isinstance(event_struct, pa.StructType)
    assert event_struct.names == expected_event_fields
|
|
|
|
|
|
def test_style_registry_routes_columns():
    """Persistent styles route to the persistent column.

    Event-only styles, and a ``None`` style, route to the events column.
    """
    assert PERSISTENT_STYLES == {"subtask", "plan", "memory", "motion"}
    assert EVENT_ONLY_STYLES == {"interjection", "vqa", "trace"}
    assert (PERSISTENT_STYLES | EVENT_ONLY_STYLES) <= STYLE_REGISTRY

    expected_routing = {
        "subtask": LANGUAGE_PERSISTENT,
        "plan": LANGUAGE_PERSISTENT,
        "memory": LANGUAGE_PERSISTENT,
        "motion": LANGUAGE_PERSISTENT,
        "interjection": LANGUAGE_EVENTS,
        "vqa": LANGUAGE_EVENTS,
        "trace": LANGUAGE_EVENTS,
        None: LANGUAGE_EVENTS,
    }
    for style_name, column in expected_routing.items():
        assert column_for_style(style_name) == column, f"style={style_name!r}"
|
|
|
|
|
|
def test_view_dependent_styles():
    """`is_view_dependent_style` agrees with the `VIEW_DEPENDENT_STYLES` set."""
    assert VIEW_DEPENDENT_STYLES == {"vqa", "motion", "trace"}

    for style_name in ("vqa", "motion", "trace"):
        assert is_view_dependent_style(style_name), f"style={style_name!r}"

    # A None style is treated like any other non-view-dependent style.
    for style_name in ("subtask", "plan", "interjection", None):
        assert not is_view_dependent_style(style_name), f"style={style_name!r}"
|
|
|
|
|
|
def test_validate_camera_field_requires_camera_for_view_dependent_styles():
    """View-dependent styles accept a camera key and reject a missing or empty one."""
    accepted_pairs = (
        ("vqa", "observation.images.top"),
        ("motion", "observation.images.wrist"),
        ("trace", "observation.images.front"),
    )
    for style_name, camera_key in accepted_pairs:
        validate_camera_field(style_name, camera_key)

    # Both None and the empty string count as "no camera".
    rejected_pairs = (("vqa", None), ("motion", ""))
    for style_name, camera_key in rejected_pairs:
        with pytest.raises(ValueError, match="view-dependent"):
            validate_camera_field(style_name, camera_key)
|
|
|
|
|
|
def test_validate_camera_field_rejects_camera_on_non_view_dependent_styles():
    """Non-view-dependent styles, and a ``None`` style, must leave ``camera`` unset."""
    for style_name in ("subtask", "plan", "memory", "interjection", None):
        validate_camera_field(style_name, None)

    for style_name in ("subtask", "interjection", None):
        with pytest.raises(ValueError, match="must have camera=None"):
            validate_camera_field(style_name, "observation.images.top")
|
|
|
|
|
|
def test_unknown_style_rejected():
    """Routing an unregistered style name raises instead of picking a column."""
    unregistered_style = "surprise"
    with pytest.raises(ValueError, match="Unknown language style"):
        column_for_style(unregistered_style)
|
|
|
|
|
|
def test_lerobot_dataset_passes_language_columns_through(tmp_path, empty_lerobot_dataset_factory):
    """Language columns patched into saved data round-trip through `LeRobotDataset`.

    The test builds a two-frame episode and writes the same persistent
    ``subtask`` row on both frames. It adds a single camera-tagged ``vqa``
    event on the first frame only. It then registers the language features in
    the dataset info and reloads the dataset, checking that every frame yields
    exactly what was written.
    """
    root = tmp_path / "language_dataset"
    dataset = empty_lerobot_dataset_factory(
        root=root,
        features={"state": {"dtype": "float32", "shape": (2,), "names": None}},
        use_videos=False,
    )
    for state_values in ([0.0, 1.0], [1.0, 2.0]):
        dataset.add_frame({"state": np.array(state_values, dtype=np.float32), "task": "tidy"})
    dataset.save_episode()
    dataset.finalize()

    # Non-view-dependent style, so camera stays None.
    subtask_row = {
        "role": "assistant",
        "content": "reach for the cup",
        "style": "subtask",
        "timestamp": 0.0,
        "camera": None,
        "tool_calls": None,
    }
    persistent = [subtask_row]
    # View-dependent style, so it names the camera view it refers to.
    event = {
        "role": "user",
        "content": "what is visible?",
        "style": "vqa",
        "camera": "observation.images.top",
        "tool_calls": None,
    }
    events_per_frame = [[event], []]

    # Inject the language columns directly into the written data file
    # (chunk 0, file 0), one list-valued cell per frame.
    data_path = root / DEFAULT_DATA_PATH.format(chunk_index=0, file_index=0)
    frames = pd.read_parquet(data_path)
    frames[LANGUAGE_PERSISTENT] = [persistent, persistent]
    frames[LANGUAGE_EVENTS] = events_per_frame
    frames.to_parquet(data_path)

    # Declare the new columns' features so the reloaded dataset knows about them.
    info = dataset.meta.info
    info["features"].update(language_feature_info())
    write_info(info, root)

    reloaded = LeRobotDataset(repo_id=dataset.repo_id, root=root)
    for frame_index, expected_events in enumerate(events_per_frame):
        item = reloaded[frame_index]
        assert item[LANGUAGE_PERSISTENT] == persistent
        assert item[LANGUAGE_EVENTS] == expected_events
|