mirror of
https://github.com/huggingface/lerobot.git
synced 2026-05-27 14:39:43 +00:00
Merge branch 'feat/language-annotation-pipeline' into feat/smolvla-on-steerable
Resolves conflicts from 66 commits on the base branch: * pyproject.toml — keep base's transformers>=5.4.0,<5.6.0; add the sentencepiece-dep entry pi052 (FAST action tokenizer) needs. * policies/__init__.py — keep pi052 export; drop the RewardClassifierConfig export that base removed. * policies/factory.py — docstring list resolution (keep pi052; drop reward_classifier, removed by base). * annotations/steerable_pipeline/executor.py — adopt base's renamed _ensure_annotation_metadata_in_info (it already advertises the say tool); drop pi052's older _ensure_tools_in_info call. * configs/train.py — keep pi052's vqa_target_fraction; adopt base's SampleWeightingConfig (legacy RA-BC inline params already covered by the migration shim base added). * scripts/lerobot_train.py — merge pi052's per-policy processor rebuild + dataset_repo_id pass-through with base's active_cfg / is_reward_model_training tightening, and re-route vqa-weighted sampler to active_cfg.drop_n_last_frames. * datasets/language_render.py — adopt base's _select_one + timestamp tolerance (drops pi052's stale _select_latest / per-style sort_key). * tests — adopt base's parametrized per-camera blend + tolerance test; drop pi052 tests that overlap with base's tighter rewrites; keep pi052's flow-only / VQA-blend coverage; add a test_canonical_recipe_loads check on subtask_mem_vqa_speech.yaml. * policies/pi052/processor_pi052.py — import RenderMessagesStep directly from render_messages_processor (base intentionally dropped it from lerobot.processor's re-exports). * uv.lock — regenerated cleanly from base + pi052's pocket-tts / beartype. All 67 touched tests pass (30 pi052 + 37 recipe / language-render / pipeline / render-messages). Co-Authored-By: Claude Opus 4.7 (1M context) <noreply@anthropic.com>
This commit is contained in:
@@ -113,7 +113,7 @@ def assert_metadata_consistency(aggr_ds, ds_0, ds_1):
|
||||
"""Test that metadata is correctly aggregated."""
|
||||
# Test basic info
|
||||
assert aggr_ds.fps == ds_0.fps == ds_1.fps, "FPS should be the same across all datasets"
|
||||
assert aggr_ds.meta.info["robot_type"] == ds_0.meta.info["robot_type"] == ds_1.meta.info["robot_type"], (
|
||||
assert aggr_ds.meta.info.robot_type == ds_0.meta.info.robot_type == ds_1.meta.info.robot_type, (
|
||||
"Robot type should be the same"
|
||||
)
|
||||
|
||||
@@ -153,8 +153,8 @@ def assert_video_frames_integrity(aggr_ds, ds_0, ds_1):
|
||||
|
||||
video_keys = list(
|
||||
filter(
|
||||
lambda key: aggr_ds.meta.info["features"][key]["dtype"] == "video",
|
||||
aggr_ds.meta.info["features"].keys(),
|
||||
lambda key: aggr_ds.meta.info.features[key]["dtype"] == "video",
|
||||
aggr_ds.meta.info.features.keys(),
|
||||
)
|
||||
)
|
||||
|
||||
|
||||
@@ -161,7 +161,7 @@ def test_init_loads_existing_metadata(tmp_path, lerobot_dataset_metadata_factory
|
||||
|
||||
assert meta.total_episodes == 3
|
||||
assert meta.total_frames == 150
|
||||
assert meta.fps == info["fps"]
|
||||
assert meta.fps == info.fps
|
||||
|
||||
|
||||
# ── Property accessors ───────────────────────────────────────────────
|
||||
@@ -385,3 +385,84 @@ def test_finalize_flushes_buffered_metadata(tmp_path):
|
||||
assert episodes_dir.exists()
|
||||
parquet_files = list(episodes_dir.rglob("*.parquet"))
|
||||
assert len(parquet_files) > 0
|
||||
|
||||
|
||||
# ── Tools accessor ───────────────────────────────────────────────────
|
||||
|
||||
|
||||
def test_tools_falls_back_to_default_when_info_has_no_tools_field(tmp_path):
|
||||
"""meta.tools returns DEFAULT_TOOLS when info.json doesn't declare any."""
|
||||
from lerobot.datasets.language import DEFAULT_TOOLS
|
||||
|
||||
root = tmp_path / "no_tools"
|
||||
meta = LeRobotDatasetMetadata.create(
|
||||
repo_id="test/no_tools",
|
||||
fps=DEFAULT_FPS,
|
||||
features=SIMPLE_FEATURES,
|
||||
root=root,
|
||||
use_videos=False,
|
||||
)
|
||||
|
||||
assert meta.tools == DEFAULT_TOOLS
|
||||
# info.json on disk should NOT include a `tools` key for clean datasets
|
||||
with open(root / INFO_PATH) as f:
|
||||
info_on_disk = json.load(f)
|
||||
assert "tools" not in info_on_disk
|
||||
|
||||
|
||||
def test_tools_reads_declared_tools_from_info_json(tmp_path):
|
||||
"""A `tools` list written into info.json survives load → meta.tools.
|
||||
|
||||
Regression test for the bug where ``DatasetInfo.from_dict`` silently
|
||||
dropped the ``tools`` key (no matching dataclass field), so
|
||||
``meta.tools`` always returned ``DEFAULT_TOOLS`` regardless of
|
||||
what was on disk.
|
||||
"""
|
||||
from lerobot.datasets.io_utils import load_info
|
||||
|
||||
root = tmp_path / "with_tools"
|
||||
meta = LeRobotDatasetMetadata.create(
|
||||
repo_id="test/with_tools",
|
||||
fps=DEFAULT_FPS,
|
||||
features=SIMPLE_FEATURES,
|
||||
root=root,
|
||||
use_videos=False,
|
||||
)
|
||||
|
||||
custom_tool = {
|
||||
"type": "function",
|
||||
"function": {
|
||||
"name": "record_observation",
|
||||
"description": "Capture a still image.",
|
||||
"parameters": {
|
||||
"type": "object",
|
||||
"properties": {"label": {"type": "string"}},
|
||||
"required": ["label"],
|
||||
},
|
||||
},
|
||||
}
|
||||
info_path = root / INFO_PATH
|
||||
with open(info_path) as f:
|
||||
raw = json.load(f)
|
||||
raw["tools"] = [custom_tool]
|
||||
with open(info_path, "w") as f:
|
||||
json.dump(raw, f)
|
||||
|
||||
# Reload info from disk and rebind it on the metadata object
|
||||
meta.info = load_info(root)
|
||||
assert meta.tools == [custom_tool]
|
||||
|
||||
|
||||
def test_tools_round_trip_through_dataset_info(tmp_path):
|
||||
"""A `tools` list survives DatasetInfo.from_dict / to_dict."""
|
||||
from lerobot.datasets.utils import DatasetInfo
|
||||
|
||||
raw = {
|
||||
"codebase_version": "v3.1",
|
||||
"fps": 30,
|
||||
"features": SIMPLE_FEATURES,
|
||||
"tools": [{"type": "function", "function": {"name": "say"}}],
|
||||
}
|
||||
info = DatasetInfo.from_dict(raw)
|
||||
assert info.tools == raw["tools"]
|
||||
assert info.to_dict()["tools"] == raw["tools"]
|
||||
|
||||
@@ -1,13 +1,17 @@
|
||||
#!/usr/bin/env python
|
||||
|
||||
import numpy as np
|
||||
import pandas as pd
|
||||
import pyarrow as pa
|
||||
import pytest
|
||||
|
||||
from lerobot.datasets import LeRobotDataset
|
||||
from lerobot.datasets.io_utils import write_info
|
||||
from lerobot.datasets.language import (
|
||||
pytest.importorskip("datasets", reason="datasets is required (install lerobot[dataset])")
|
||||
pytest.importorskip("pandas", reason="pandas is required (install lerobot[dataset])")
|
||||
|
||||
import numpy as np # noqa: E402
|
||||
import pandas as pd # noqa: E402
|
||||
import pyarrow as pa # noqa: E402
|
||||
|
||||
from lerobot.datasets import LeRobotDataset # noqa: E402
|
||||
from lerobot.datasets.io_utils import write_info # noqa: E402
|
||||
from lerobot.datasets.language import ( # noqa: E402
|
||||
EVENT_ONLY_STYLES,
|
||||
LANGUAGE_EVENTS,
|
||||
LANGUAGE_PERSISTENT,
|
||||
@@ -21,7 +25,7 @@ from lerobot.datasets.language import (
|
||||
language_persistent_arrow_type,
|
||||
validate_camera_field,
|
||||
)
|
||||
from lerobot.datasets.utils import DEFAULT_DATA_PATH
|
||||
from lerobot.datasets.utils import DEFAULT_DATA_PATH # noqa: E402
|
||||
|
||||
|
||||
def test_language_arrow_schema_has_expected_fields():
|
||||
|
||||
@@ -1,11 +1,18 @@
|
||||
#!/usr/bin/env python
|
||||
|
||||
from pathlib import Path
|
||||
|
||||
import pytest
|
||||
|
||||
from lerobot.configs.recipe import MessageTurn, TrainingRecipe
|
||||
from lerobot.datasets.language_render import active_at, emitted_at, nth_next, nth_prev, render_sample
|
||||
pytest.importorskip("datasets", reason="datasets is required (install lerobot[dataset])")
|
||||
|
||||
from lerobot.configs.recipe import MessageTurn, TrainingRecipe # noqa: E402
|
||||
from lerobot.datasets.language_render import ( # noqa: E402
|
||||
EMITTED_AT_TOLERANCE_S,
|
||||
active_at,
|
||||
emitted_at,
|
||||
nth_next,
|
||||
nth_prev,
|
||||
render_sample,
|
||||
)
|
||||
|
||||
|
||||
def persistent_row(role, content, style, timestamp, tool_calls=None, camera=None):
|
||||
@@ -201,84 +208,50 @@ def test_emitted_at_raises_on_ambiguous_per_camera_vqa():
|
||||
)
|
||||
|
||||
|
||||
def test_per_camera_blend_renders_both_views():
|
||||
recipe = TrainingRecipe(
|
||||
blend={
|
||||
"top": TrainingRecipe(
|
||||
weight=1.0,
|
||||
bindings={
|
||||
"vqa_query": ("emitted_at(t, style=vqa, role=user, camera=observation.images.top)"),
|
||||
"vqa": ("emitted_at(t, style=vqa, role=assistant, camera=observation.images.top)"),
|
||||
},
|
||||
messages=[
|
||||
MessageTurn(
|
||||
role="user",
|
||||
content=[
|
||||
{"type": "image", "feature": "observation.images.top"},
|
||||
{"type": "text", "text": "${vqa_query}"},
|
||||
],
|
||||
stream="high_level",
|
||||
if_present="vqa_query",
|
||||
),
|
||||
MessageTurn(
|
||||
role="assistant",
|
||||
content="${vqa}",
|
||||
stream="high_level",
|
||||
target=True,
|
||||
if_present="vqa",
|
||||
),
|
||||
],
|
||||
def _vqa_subrecipe(camera: str) -> TrainingRecipe:
|
||||
return TrainingRecipe(
|
||||
weight=1.0,
|
||||
bindings={
|
||||
"vqa_query": f"emitted_at(t, style=vqa, role=user, camera={camera})",
|
||||
"vqa": f"emitted_at(t, style=vqa, role=assistant, camera={camera})",
|
||||
},
|
||||
messages=[
|
||||
MessageTurn(
|
||||
role="user",
|
||||
content=[{"type": "image", "feature": camera}, {"type": "text", "text": "${vqa_query}"}],
|
||||
stream="high_level",
|
||||
if_present="vqa_query",
|
||||
),
|
||||
"wrist": TrainingRecipe(
|
||||
weight=1.0,
|
||||
bindings={
|
||||
"vqa_query": ("emitted_at(t, style=vqa, role=user, camera=observation.images.wrist)"),
|
||||
"vqa": ("emitted_at(t, style=vqa, role=assistant, camera=observation.images.wrist)"),
|
||||
},
|
||||
messages=[
|
||||
MessageTurn(
|
||||
role="user",
|
||||
content=[
|
||||
{"type": "image", "feature": "observation.images.wrist"},
|
||||
{"type": "text", "text": "${vqa_query}"},
|
||||
],
|
||||
stream="high_level",
|
||||
if_present="vqa_query",
|
||||
),
|
||||
MessageTurn(
|
||||
role="assistant",
|
||||
content="${vqa}",
|
||||
stream="high_level",
|
||||
target=True,
|
||||
if_present="vqa",
|
||||
),
|
||||
],
|
||||
MessageTurn(
|
||||
role="assistant",
|
||||
content="${vqa}",
|
||||
stream="high_level",
|
||||
target=True,
|
||||
if_present="vqa",
|
||||
),
|
||||
}
|
||||
],
|
||||
)
|
||||
|
||||
rendered_top = render_sample(
|
||||
recipe=recipe.blend["top"],
|
||||
persistent=PERSISTENT,
|
||||
events=EVENTS_AT_3_TWO_CAMERAS,
|
||||
t=3.0,
|
||||
sample_idx=0,
|
||||
)
|
||||
rendered_wrist = render_sample(
|
||||
recipe=recipe.blend["wrist"],
|
||||
|
||||
@pytest.mark.parametrize(
|
||||
("camera", "expected_query", "expected_answer"),
|
||||
[
|
||||
("observation.images.top", "how many cups (top)?", '{"count": 3}'),
|
||||
("observation.images.wrist", "how many cups (wrist)?", '{"count": 1}'),
|
||||
],
|
||||
)
|
||||
def test_per_camera_blend_renders_both_views(camera, expected_query, expected_answer):
|
||||
rendered = render_sample(
|
||||
recipe=_vqa_subrecipe(camera),
|
||||
persistent=PERSISTENT,
|
||||
events=EVENTS_AT_3_TWO_CAMERAS,
|
||||
t=3.0,
|
||||
sample_idx=0,
|
||||
)
|
||||
|
||||
assert rendered_top["messages"][0]["content"][0]["feature"] == "observation.images.top"
|
||||
assert rendered_top["messages"][0]["content"][1]["text"] == "how many cups (top)?"
|
||||
assert rendered_top["messages"][1]["content"] == '{"count": 3}'
|
||||
|
||||
assert rendered_wrist["messages"][0]["content"][0]["feature"] == "observation.images.wrist"
|
||||
assert rendered_wrist["messages"][0]["content"][1]["text"] == "how many cups (wrist)?"
|
||||
assert rendered_wrist["messages"][1]["content"] == '{"count": 1}'
|
||||
assert rendered["messages"][0]["content"][0]["feature"] == camera
|
||||
assert rendered["messages"][0]["content"][1]["text"] == expected_query
|
||||
assert rendered["messages"][1]["content"] == expected_answer
|
||||
|
||||
|
||||
def test_resolve_task_picks_rephrasing_deterministically_per_sample():
|
||||
@@ -448,12 +421,65 @@ def test_vqa_frame_is_consumed_over_the_weighted_blend():
|
||||
assert rendered["messages"][-1]["content"] == "a subtask"
|
||||
|
||||
|
||||
def test_canonical_recipe_can_render_low_level_branch():
|
||||
"""The shipped ``subtasks_vqa.yaml`` recipe's ``low_level_execution``
|
||||
branch renders — a flow-only ``user(${subtask})`` turn (no text-CE
|
||||
target; its supervision is the action-expert flow loss)."""
|
||||
recipe = TrainingRecipe.from_yaml(Path("src/lerobot/configs/recipes/subtasks_vqa.yaml"))
|
||||
low_level = TrainingRecipe(blend={"low": recipe.blend["low_level_execution"]})
|
||||
def test_emitted_at_persistent_tolerates_small_timestamp_drift():
|
||||
"""Persistent ``emitted_at`` should match within EMITTED_AT_TOLERANCE_S
|
||||
so callers that derive ``t`` arithmetically (``frame_idx / fps``) still
|
||||
line up with the parquet-stored timestamp.
|
||||
"""
|
||||
rows = [persistent_row("assistant", "memo", "memory", 1.0)]
|
||||
# Half a tolerance window — bit-different float, comfortably inside
|
||||
inside = emitted_at(1.0 + EMITTED_AT_TOLERANCE_S / 2, persistent=rows, events=[], style="memory")
|
||||
assert inside is not None and inside["content"] == "memo"
|
||||
|
||||
# Just past the window — no match
|
||||
outside = emitted_at(1.0 + EMITTED_AT_TOLERANCE_S * 2, persistent=rows, events=[], style="memory")
|
||||
assert outside is None
|
||||
|
||||
|
||||
def test_render_sample_rejects_non_dict_language_rows():
|
||||
"""``_normalize_rows`` must surface malformed inputs as TypeError.
|
||||
|
||||
A pipeline that hands the renderer a non-dict (e.g. a stray string)
|
||||
is a real upstream bug — silent skipping would let it propagate.
|
||||
"""
|
||||
recipe = TrainingRecipe(
|
||||
messages=[
|
||||
MessageTurn(role="user", content="${task}", stream="high_level"),
|
||||
MessageTurn(role="assistant", content="ok", stream="high_level", target=True),
|
||||
]
|
||||
)
|
||||
with pytest.raises(TypeError, match="must be dictionaries"):
|
||||
render_sample(
|
||||
recipe=recipe,
|
||||
persistent=["not a dict"],
|
||||
events=[],
|
||||
t=0.0,
|
||||
sample_idx=0,
|
||||
task="x",
|
||||
)
|
||||
|
||||
|
||||
def test_low_level_branch_renders_active_subtask():
|
||||
low_level = TrainingRecipe(
|
||||
blend={
|
||||
"low": TrainingRecipe(
|
||||
weight=1.0,
|
||||
messages=[
|
||||
MessageTurn(
|
||||
role="user",
|
||||
content="${task}\nPlan: ${plan}\nMemory: ${memory}",
|
||||
stream="high_level",
|
||||
),
|
||||
MessageTurn(
|
||||
role="assistant",
|
||||
content="${subtask}",
|
||||
stream="low_level",
|
||||
target=True,
|
||||
),
|
||||
],
|
||||
)
|
||||
}
|
||||
)
|
||||
|
||||
rendered = render_sample(
|
||||
recipe=low_level,
|
||||
@@ -464,6 +490,6 @@ def test_canonical_recipe_can_render_low_level_branch():
|
||||
task="clean kitchen",
|
||||
)
|
||||
|
||||
assert rendered["messages"][-1] == {"role": "user", "content": "subtask 0"}
|
||||
assert rendered["messages"][-1] == {"role": "assistant", "content": "subtask 0"}
|
||||
assert rendered["message_streams"][-1] == "low_level"
|
||||
assert rendered["target_message_indices"] == []
|
||||
assert rendered["target_message_indices"] == [1]
|
||||
|
||||
@@ -80,18 +80,18 @@ def _write_dataset_tree(
|
||||
)
|
||||
tasks = tasks_factory(total_tasks=1)
|
||||
episodes = episodes_factory(
|
||||
features=info["features"],
|
||||
fps=info["fps"],
|
||||
features=info.features,
|
||||
fps=info.fps,
|
||||
total_episodes=1,
|
||||
total_frames=3,
|
||||
tasks=tasks,
|
||||
)
|
||||
stats = stats_factory(features=info["features"])
|
||||
stats = stats_factory(features=info.features)
|
||||
hf_dataset = hf_dataset_factory(
|
||||
features=info["features"],
|
||||
features=info.features,
|
||||
tasks=tasks,
|
||||
episodes=episodes,
|
||||
fps=info["fps"],
|
||||
fps=info.fps,
|
||||
)
|
||||
|
||||
create_info(root, info)
|
||||
@@ -416,6 +416,18 @@ def test_create_initial_counts_zero(tmp_path):
|
||||
assert dataset.num_frames == 0
|
||||
|
||||
|
||||
def test_create_propagates_video_files_size_in_mb(tmp_path):
|
||||
"""video_files_size_in_mb passed to create() is reflected in the dataset metadata."""
|
||||
dataset = LeRobotDataset.create(
|
||||
repo_id=DUMMY_REPO_ID,
|
||||
fps=DEFAULT_FPS,
|
||||
features=SIMPLE_FEATURES,
|
||||
root=tmp_path / "ds",
|
||||
video_files_size_in_mb=42.0,
|
||||
)
|
||||
assert dataset.meta.video_files_size_in_mb == 42.0
|
||||
|
||||
|
||||
def test_add_frame_works_in_write_mode(tmp_path):
|
||||
"""add_frame() succeeds on a dataset created via create()."""
|
||||
dataset = LeRobotDataset.create(
|
||||
|
||||
Reference in New Issue
Block a user