fix(language): address review — tools accessor, motion docs, conditional collate

* **`meta.tools` actually reads `info.json["tools"]`.** `DatasetInfo`
  had no `tools` field, so `from_dict` silently dropped the key (it
  warned about unknown fields then discarded them) and the property
  always returned `DEFAULT_TOOLS`. Added `tools: list[dict] | None`
  to the dataclass; `to_dict()` drops it when unset so existing
  datasets keep a clean `info.json`. Fixed the accessor to read
  `self.info.tools` (the previous `.get(...)` would have raised
  AttributeError on the dataclass anyway). Added regression tests:
  fallback when absent, round-trip from disk, and round-trip
  through `DatasetInfo.from_dict` / `to_dict`.

* **`motion` is not view-dependent — fix the docs.** The mdx claimed
  rows of style `motion` must carry `camera`, but `VIEW_DEPENDENT_STYLES
  = {"vqa", "trace"}` and the validator agrees: motion primitives are
  joint/Cartesian-frame, not pixel-space. Updated both call-out
  paragraphs in `language_and_recipes.mdx`.

* **Conditional `collate_fn` swap.** Added `meta.has_language_columns`
  and gate the `lerobot_collate_fn` swap in `lerobot_train.py` on it,
  so non-language datasets keep PyTorch's `default_collate`. Also
  added a pass-through test in `test_collate.py` that asserts on a
  plain tensor batch the custom collate matches `default_collate`
  key-for-key, plus a test for the `None`-sample drop path.

Co-Authored-By: Claude Opus 4.7 (1M context) <noreply@anthropic.com>
This commit is contained in:
Pepijn
2026-05-06 14:51:06 +02:00
parent 24d2ffe3c6
commit d55b581ca1
6 changed files with 156 additions and 8 deletions
+81
View File
@@ -385,3 +385,84 @@ def test_finalize_flushes_buffered_metadata(tmp_path):
assert episodes_dir.exists()
parquet_files = list(episodes_dir.rglob("*.parquet"))
assert len(parquet_files) > 0
# ── Tools accessor ───────────────────────────────────────────────────
def test_tools_falls_back_to_default_when_info_has_no_tools_field(tmp_path):
"""meta.tools returns DEFAULT_TOOLS when info.json doesn't declare any."""
from lerobot.datasets.language import DEFAULT_TOOLS
root = tmp_path / "no_tools"
meta = LeRobotDatasetMetadata.create(
repo_id="test/no_tools",
fps=DEFAULT_FPS,
features=SIMPLE_FEATURES,
root=root,
use_videos=False,
)
assert meta.tools == DEFAULT_TOOLS
# info.json on disk should NOT include a `tools` key for clean datasets
with open(root / INFO_PATH) as f:
info_on_disk = json.load(f)
assert "tools" not in info_on_disk
def test_tools_reads_declared_tools_from_info_json(tmp_path):
"""A `tools` list written into info.json survives load → meta.tools.
Regression test for the bug where ``DatasetInfo.from_dict`` silently
dropped the ``tools`` key (no matching dataclass field), so
``meta.tools`` always returned ``DEFAULT_TOOLS`` regardless of
what was on disk.
"""
from lerobot.datasets.io_utils import load_info
root = tmp_path / "with_tools"
meta = LeRobotDatasetMetadata.create(
repo_id="test/with_tools",
fps=DEFAULT_FPS,
features=SIMPLE_FEATURES,
root=root,
use_videos=False,
)
custom_tool = {
"type": "function",
"function": {
"name": "record_observation",
"description": "Capture a still image.",
"parameters": {
"type": "object",
"properties": {"label": {"type": "string"}},
"required": ["label"],
},
},
}
info_path = root / INFO_PATH
with open(info_path) as f:
raw = json.load(f)
raw["tools"] = [custom_tool]
with open(info_path, "w") as f:
json.dump(raw, f)
# Reload info from disk and rebind it on the metadata object
meta.info = load_info(root)
assert meta.tools == [custom_tool]
def test_tools_round_trip_through_dataset_info(tmp_path):
"""A `tools` list survives DatasetInfo.from_dict / to_dict."""
from lerobot.datasets.utils import DatasetInfo
raw = {
"codebase_version": "v3.1",
"fps": 30,
"features": SIMPLE_FEATURES,
"tools": [{"type": "function", "function": {"name": "say"}}],
}
info = DatasetInfo.from_dict(raw)
assert info.tools == raw["tools"]
assert info.to_dict()["tools"] == raw["tools"]
+44
View File
@@ -38,3 +38,47 @@ def test_lerobot_collate_preserves_messages_and_drops_raw_language():
assert out["target_message_indices"] == [[0], [0]]
assert "language_persistent" not in out
assert "language_events" not in out
def test_lerobot_collate_passes_through_standard_batch():
"""On a non-language batch, the collate must match ``default_collate``.
Guards against silent regressions: ``lerobot_train.py`` only opts into
``lerobot_collate_fn`` when the dataset declares language columns, but
if a future change ever wires it in unconditionally we want the
behavior to remain a transparent pass-through for ordinary tensor
batches.
"""
from torch.utils.data._utils.collate import default_collate
batch = [
{
"observation.image": torch.zeros(3, 4, 4),
"action": torch.tensor([0.0, 1.0]),
"index": torch.tensor(0),
},
{
"observation.image": torch.ones(3, 4, 4),
"action": torch.tensor([2.0, 3.0]),
"index": torch.tensor(1),
},
]
custom = lerobot_collate_fn(batch)
expected = default_collate(batch)
assert custom.keys() == expected.keys()
for key in expected:
assert torch.equal(custom[key], expected[key]), f"key={key} diverged"
def test_lerobot_collate_drops_none_samples():
"""Recipes that yielded no target message return ``None`` — those samples
must be filtered out, and an entirely-``None`` batch must collapse to ``None``.
"""
batch = [None, {"index": torch.tensor(0)}, None]
out = lerobot_collate_fn(batch)
assert out is not None
assert out["index"].tolist() == [0]
assert lerobot_collate_fn([None, None]) is None