Mirror of https://github.com/huggingface/lerobot.git · commit 5a6aa64570
Adds a nullable `camera` field to the language row struct (both persistent
and event variants) so view-dependent styles like `vqa` can carry which
`observation.images.*` view they were grounded against. Without this,
multi-camera datasets ended up with multiple `(vqa, role)` rows at the
same timestamp that the resolver could not disambiguate.
- `language.py`: add `camera` to PERSISTENT_ROW_FIELDS / EVENT_ROW_FIELDS and
  to both the Arrow struct types and the HF datasets feature mappings;
  introduce VIEW_DEPENDENT_STYLES = {vqa, motion, trace} plus the
  `is_view_dependent_style` and `validate_camera_field` helpers (camera is
  required iff the style is view-dependent; see the sketch after this
  message).
- `language_render.py`: thread an optional `camera=` kwarg through every
resolver (`active_at`, `emitted_at`, `nth_prev`, `nth_next`) and through
`_matching_rows` / `_select_*`, so recipes can disambiguate per-camera
VQA with `emitted_at(t, style=vqa, role=assistant, camera=...)`.
  Without a `camera` filter, multi-row matches keep raising the existing
  ambiguity error, which is the desired behaviour on multi-camera data.
- `recipes/pi05_hirobot.yaml`: replace the single `ask_vqa` branch with
`ask_vqa_top` and `ask_vqa_wrist` per-camera sub-recipes (each carrying
the matching image block), keeping the original 0.20 budget and
documenting the customization point for datasets with different cameras.
- Tests: schema test asserts the new field order; new tests cover
`is_view_dependent_style`, `validate_camera_field` (both required and
forbidden directions), per-camera `emitted_at` filtering, and the
ambiguity error when two cameras emit `(vqa, assistant)` at the same
timestamp without a `camera=` filter. RenderMessagesStep + dataset
passthrough fixtures updated to include the new field.
- `docs/source/language_and_recipes.mdx`: document the `camera` field,
the per-camera resolver pattern, and the canonical recipe convention.
Co-Authored-By: Claude Opus 4.7 (1M context) <noreply@anthropic.com>
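A minimal sketch of the contract those helpers and the resolver filter
implement, assuming dict-shaped rows like the ones in the test file below.
The names follow the commit message; the bodies here are illustrative, not
the shipped `language.py` / `language_render.py` code:

VIEW_DEPENDENT_STYLES = {"vqa", "motion", "trace"}


def is_view_dependent_style(style):
    # Styles whose rows are grounded against a specific observation.images.* view.
    return style in VIEW_DEPENDENT_STYLES


def validate_camera_field(row):
    # camera is required iff the row's style is view-dependent; the schema tests
    # cover both the required and the forbidden direction.
    view_dependent = is_view_dependent_style(row["style"])
    if view_dependent and row["camera"] is None:
        raise ValueError(f"style {row['style']!r} is view-dependent and requires a camera")
    if not view_dependent and row["camera"] is not None:
        raise ValueError(f"style {row['style']!r} is not view-dependent and must not set a camera")


def _matching_rows(rows, style=None, role=None, camera=None):
    # camera=None means "no camera filter": when two cameras emitted rows at the
    # same tick, the multi-row match keeps raising the existing ambiguity error.
    return [
        row
        for row in rows
        if (style is None or row["style"] == style)
        and (role is None or row["role"] == role)
        and (camera is None or row["camera"] == camera)
    ]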
308 lines · 10 KiB · Python
#!/usr/bin/env python

from pathlib import Path

import pytest

from lerobot.configs.recipe import MessageTurn, TrainingRecipe
from lerobot.datasets.language_render import active_at, emitted_at, nth_next, nth_prev, render_sample


# Row factories mirror the persistent / event row structs, including the new
# nullable "camera" field.
def persistent_row(role, content, style, timestamp, tool_calls=None, camera=None):
    return {
        "role": role,
        "content": content,
        "style": style,
        "timestamp": timestamp,
        "camera": camera,
        "tool_calls": tool_calls,
    }


def event_row(role, content, style, tool_calls=None, camera=None):
    return {
        "role": role,
        "content": content,
        "style": style,
        "camera": camera,
        "tool_calls": tool_calls,
    }


PERSISTENT = [
    persistent_row("assistant", "plan 0", "plan", 0.0),
    persistent_row("assistant", "memory 0", "memory", 0.0),
    persistent_row("assistant", "subtask 0", "subtask", 0.0),
    persistent_row("assistant", "memory 1", "memory", 1.0),
    persistent_row("assistant", "subtask 1", "subtask", 1.0),
]

EVENTS_AT_1 = [
    event_row("user", "what is visible?", "vqa", camera="observation.images.top"),
    event_row("assistant", '{"count": 2}', "vqa", camera="observation.images.top"),
]

EVENTS_AT_2 = [
    event_row("user", "skip wiping", "interjection"),
    event_row(
        "assistant",
        None,
        None,
        [{"type": "function", "function": {"name": "say", "arguments": {"text": "Skipping wiping."}}}],
    ),
]

# Same emission tick, two cameras: triggers per-camera disambiguation in
# resolvers, mirroring how Module 3 of the annotation pipeline writes one
# (vqa, user) + (vqa, assistant) pair per camera.
EVENTS_AT_3_TWO_CAMERAS = [
    event_row("user", "how many cups (top)?", "vqa", camera="observation.images.top"),
    event_row("assistant", '{"count": 3}', "vqa", camera="observation.images.top"),
    event_row("user", "how many cups (wrist)?", "vqa", camera="observation.images.wrist"),
    event_row("assistant", '{"count": 1}', "vqa", camera="observation.images.wrist"),
]


def test_resolver_temporal_semantics():
    # active_at picks the latest matching row at or before t; emitted_at only
    # matches events emitted exactly at t.
    assert active_at(0.5, persistent=PERSISTENT, style="subtask")["content"] == "subtask 0"
    assert active_at(1.0, persistent=PERSISTENT, style="subtask")["content"] == "subtask 1"
    assert emitted_at(0.5, persistent=PERSISTENT, events=[], style="vqa", role="assistant") is None
    assert (
        emitted_at(1.0, persistent=PERSISTENT, events=EVENTS_AT_1, style="vqa", role="assistant")["content"]
        == '{"count": 2}'
    )


def test_persistent_relative_resolvers_reject_event_styles():
    with pytest.raises(ValueError, match="event-only"):
        active_at(1.0, persistent=PERSISTENT, style="vqa")
    with pytest.raises(ValueError, match="event-only"):
        nth_prev(1.0, persistent=PERSISTENT, style="interjection")


def test_nth_prev_and_next():
    assert nth_prev(1.0, persistent=PERSISTENT, style="subtask", offset=1)["content"] == "subtask 0"
    assert nth_next(0.0, persistent=PERSISTENT, style="subtask", offset=1)["content"] == "subtask 1"


def test_substitution_if_present_multimodal_and_tool_calls():
    recipe = TrainingRecipe(
        messages=[
            MessageTurn(
                role="user",
                content=[
                    {"type": "image", "feature": "observation.images.top"},
                    {"type": "text", "text": "${task}: ${interjection}"},
                ],
                stream="high_level",
                if_present="interjection",
            ),
            MessageTurn(
                role="assistant",
                content="${plan}",
                stream="high_level",
                target=True,
                tool_calls_from="speech",
            ),
        ],
        bindings={"plan": "active_at(t, style=plan)"},
    )

    rendered = render_sample(
        recipe=recipe,
        persistent=PERSISTENT,
        events=EVENTS_AT_2,
        t=2.0,
        sample_idx=0,
        task="clean kitchen",
    )

    assert rendered["messages"][0]["content"][1]["text"] == "clean kitchen: skip wiping"
    assert rendered["messages"][1]["content"] == "plan 0"
    assert rendered["messages"][1]["tool_calls"][0]["function"]["name"] == "say"
    assert rendered["message_streams"] == ["high_level", "high_level"]
    assert rendered["target_message_indices"] == [1]


def test_exact_event_miss_returns_none_when_target_skips():
    recipe = TrainingRecipe(
        messages=[
            MessageTurn(role="user", content="${vqa_query}", stream="high_level", if_present="vqa_query"),
            MessageTurn(
                role="assistant",
                content="${vqa}",
                stream="high_level",
                target=True,
                if_present="vqa",
            ),
        ]
    )

    assert (
        render_sample(recipe=recipe, persistent=PERSISTENT, events=EVENTS_AT_2, t=0.0, sample_idx=0) is None
    )


def test_deterministic_blend_sampling():
    recipe = TrainingRecipe(
        blend={
            "a": TrainingRecipe(
                weight=1.0,
                messages=[
                    MessageTurn(role="user", content="${task}", stream="high_level"),
                    MessageTurn(role="assistant", content="a", stream="high_level", target=True),
                ],
            ),
            "b": TrainingRecipe(
                weight=1.0,
                messages=[
                    MessageTurn(role="user", content="${task}", stream="high_level"),
                    MessageTurn(role="assistant", content="b", stream="high_level", target=True),
                ],
            ),
        }
    )

    # The blend branch is derived deterministically from sample_idx, so
    # re-rendering the same sample must agree.
    first = render_sample(
        recipe=recipe, persistent=PERSISTENT, events=EVENTS_AT_2, t=0.0, sample_idx=123, task="x"
    )
    second = render_sample(
        recipe=recipe, persistent=PERSISTENT, events=EVENTS_AT_2, t=0.0, sample_idx=123, task="x"
    )
    assert first == second


def test_emitted_at_filters_vqa_by_camera():
    top = emitted_at(
        3.0,
        persistent=PERSISTENT,
        events=EVENTS_AT_3_TWO_CAMERAS,
        style="vqa",
        role="assistant",
        camera="observation.images.top",
    )
    wrist = emitted_at(
        3.0,
        persistent=PERSISTENT,
        events=EVENTS_AT_3_TWO_CAMERAS,
        style="vqa",
        role="assistant",
        camera="observation.images.wrist",
    )
    assert top["content"] == '{"count": 3}'
    assert wrist["content"] == '{"count": 1}'


def test_emitted_at_raises_on_ambiguous_per_camera_vqa():
    # Without a camera= filter, the two per-camera (vqa, assistant) rows at the
    # same tick stay ambiguous.
    with pytest.raises(ValueError, match="Ambiguous resolver"):
        emitted_at(
            3.0,
            persistent=PERSISTENT,
            events=EVENTS_AT_3_TWO_CAMERAS,
            style="vqa",
            role="assistant",
        )


def test_per_camera_blend_renders_both_views():
    recipe = TrainingRecipe(
        blend={
            "top": TrainingRecipe(
                weight=1.0,
                bindings={
                    "vqa_query": "emitted_at(t, style=vqa, role=user, camera=observation.images.top)",
                    "vqa": "emitted_at(t, style=vqa, role=assistant, camera=observation.images.top)",
                },
                messages=[
                    MessageTurn(
                        role="user",
                        content=[
                            {"type": "image", "feature": "observation.images.top"},
                            {"type": "text", "text": "${vqa_query}"},
                        ],
                        stream="high_level",
                        if_present="vqa_query",
                    ),
                    MessageTurn(
                        role="assistant",
                        content="${vqa}",
                        stream="high_level",
                        target=True,
                        if_present="vqa",
                    ),
                ],
            ),
            "wrist": TrainingRecipe(
                weight=1.0,
                bindings={
                    "vqa_query": "emitted_at(t, style=vqa, role=user, camera=observation.images.wrist)",
                    "vqa": "emitted_at(t, style=vqa, role=assistant, camera=observation.images.wrist)",
                },
                messages=[
                    MessageTurn(
                        role="user",
                        content=[
                            {"type": "image", "feature": "observation.images.wrist"},
                            {"type": "text", "text": "${vqa_query}"},
                        ],
                        stream="high_level",
                        if_present="vqa_query",
                    ),
                    MessageTurn(
                        role="assistant",
                        content="${vqa}",
                        stream="high_level",
                        target=True,
                        if_present="vqa",
                    ),
                ],
            ),
        }
    )

    rendered_top = render_sample(
        recipe=recipe.blend["top"],
        persistent=PERSISTENT,
        events=EVENTS_AT_3_TWO_CAMERAS,
        t=3.0,
        sample_idx=0,
    )
    rendered_wrist = render_sample(
        recipe=recipe.blend["wrist"],
        persistent=PERSISTENT,
        events=EVENTS_AT_3_TWO_CAMERAS,
        t=3.0,
        sample_idx=0,
    )

    assert rendered_top["messages"][0]["content"][0]["feature"] == "observation.images.top"
    assert rendered_top["messages"][0]["content"][1]["text"] == "how many cups (top)?"
    assert rendered_top["messages"][1]["content"] == '{"count": 3}'

    assert rendered_wrist["messages"][0]["content"][0]["feature"] == "observation.images.wrist"
    assert rendered_wrist["messages"][0]["content"][1]["text"] == "how many cups (wrist)?"
    assert rendered_wrist["messages"][1]["content"] == '{"count": 1}'


def test_canonical_recipe_can_render_low_level_branch():
    recipe = TrainingRecipe.from_yaml(Path("src/lerobot/configs/recipes/pi05_hirobot.yaml"))
    low_level = TrainingRecipe(blend={"low": recipe.blend["low_level_execution"]})

    rendered = render_sample(
        recipe=low_level,
        persistent=PERSISTENT,
        events=[],
        t=0.5,
        sample_idx=0,
        task="clean kitchen",
    )

    assert rendered["messages"][-1] == {"role": "assistant", "content": "subtask 0"}
    assert rendered["message_streams"][-1] == "low_level"
    assert rendered["target_message_indices"] == [1]