fix(language): address review — tools accessor, motion docs, conditional collate

* **`meta.tools` actually reads `info.json["tools"]`.** `DatasetInfo`
  had no `tools` field, so `from_dict` silently dropped the key (it
  warned about unknown fields then discarded them) and the property
  always returned `DEFAULT_TOOLS`. Added `tools: list[dict] | None`
  to the dataclass; `to_dict()` drops it when unset so existing
  datasets keep a clean `info.json`. Fixed the accessor to read
  `self.info.tools` (the previous `.get(...)` would have raised
  AttributeError on the dataclass anyway). Added regression tests:
  fallback when absent, round-trip from disk, and round-trip
  through `DatasetInfo.from_dict` / `to_dict`.

* **`motion` is not view-dependent — fix the docs.** The mdx claimed
  rows of style `motion` must carry `camera`, but `VIEW_DEPENDENT_STYLES
  = {"vqa", "trace"}` and the validator agrees: motion primitives are
  joint/Cartesian-frame, not pixel-space. Updated both call-out
  paragraphs in `language_and_recipes.mdx`.

* **Conditional `collate_fn` swap.** Added `meta.has_language_columns`
  and gate the `lerobot_collate_fn` swap in `lerobot_train.py` on it,
  so non-language datasets keep PyTorch's `default_collate`. Also
  added a pass-through test in `test_collate.py` that asserts on a
  plain tensor batch the custom collate matches `default_collate`
  key-for-key, plus a test for the `None`-sample drop path.

Co-Authored-By: Claude Opus 4.7 (1M context) <noreply@anthropic.com>
This commit is contained in:
Pepijn
2026-05-06 14:51:06 +02:00
parent 24d2ffe3c6
commit d55b581ca1
6 changed files with 156 additions and 8 deletions
+6 -5
View File
@@ -46,10 +46,11 @@ tool_calls: list[Json] | null
``` ```
The `camera` field tags rows whose `content` is grounded in a specific camera The `camera` field tags rows whose `content` is grounded in a specific camera
view. Rows of view-dependent styles (`vqa`, and the reserved `motion` / view. Rows of view-dependent styles (`vqa` and `trace`) MUST set `camera` to
`trace`) MUST set `camera` to the matching `observation.images.*` feature key. the matching `observation.images.*` feature key. Rows of every other style —
Rows of every other style MUST leave `camera` as `null`. Pipeline writers and including `motion`, which describes robot-frame primitives in joint / Cartesian
the validator enforce this via `validate_camera_field(style, camera)`. terms — MUST leave `camera` as `null`. Pipeline writers and the validator
enforce this via `validate_camera_field(style, camera)`.
`meta/tasks.parquet` remains the canonical source for the task. The special `${task}` recipe binding always reads that task string and does not depend on language annotations. `meta/tasks.parquet` remains the canonical source for the task. The special `${task}` recipe binding always reads that task string and does not depend on language annotations.
@@ -81,7 +82,7 @@ Exact event matching has no tolerance window, so writers must stamp event rows w
### View-dependent resolution ### View-dependent resolution
For view-dependent styles (`vqa`, `motion`, `trace`), the resolver gains a For view-dependent styles (`vqa` and `trace`), the resolver gains a
`camera=` filter parallel to `role=` and `tool_name=`. Datasets with multiple `camera=` filter parallel to `role=` and `tool_name=`. Datasets with multiple
cameras typically emit one (`vqa`, `user`) + (`vqa`, `assistant`) pair per cameras typically emit one (`vqa`, `user`) + (`vqa`, `assistant`) pair per
camera at the same timestamp; without `camera=`, those resolvers see two camera at the same timestamp; without `camera=`, those resolvers see two
+13 -2
View File
@@ -316,6 +316,17 @@ class LeRobotDatasetMetadata:
"""Keys to access visual modalities (regardless of their storage method).""" """Keys to access visual modalities (regardless of their storage method)."""
return [key for key, ft in self.features.items() if ft["dtype"] in ["video", "image"]] return [key for key, ft in self.features.items() if ft["dtype"] in ["video", "image"]]
@property
def has_language_columns(self) -> bool:
"""Return ``True`` if the dataset declares any language column.
Used to gate language-aware code paths (collate, render step) so
unannotated datasets keep PyTorch's default collate behavior.
"""
from .language import LANGUAGE_COLUMNS # noqa: PLC0415 (avoid circular import)
return any(col in self.features for col in LANGUAGE_COLUMNS)
@property @property
def tools(self) -> list[dict]: def tools(self) -> list[dict]:
"""OpenAI-style tool schemas declared by this dataset. """OpenAI-style tool schemas declared by this dataset.
@@ -333,8 +344,8 @@ class LeRobotDatasetMetadata:
""" """
from .language import DEFAULT_TOOLS # noqa: PLC0415 (avoid circular import) from .language import DEFAULT_TOOLS # noqa: PLC0415 (avoid circular import)
declared = self.info.get("tools") declared = self.info.tools
if isinstance(declared, list) and declared: if declared:
return [dict(t) for t in declared] return [dict(t) for t in declared]
return [dict(t) for t in DEFAULT_TOOLS] return [dict(t) for t in DEFAULT_TOOLS]
+7
View File
@@ -129,6 +129,9 @@ class DatasetInfo:
# Optional metadata # Optional metadata
robot_type: str | None = None robot_type: str | None = None
splits: dict[str, str] = field(default_factory=dict) splits: dict[str, str] = field(default_factory=dict)
# OpenAI-style tool schemas declared by the dataset. ``None`` means the
# dataset doesn't declare any — readers fall back to ``DEFAULT_TOOLS``.
tools: list[dict] | None = None
def __post_init__(self) -> None: def __post_init__(self) -> None:
# Coerce feature shapes from list to tuple — JSON deserialisation # Coerce feature shapes from list to tuple — JSON deserialisation
@@ -150,11 +153,15 @@ class DatasetInfo:
"""Return a JSON-serialisable dict. """Return a JSON-serialisable dict.
Converts tuple shapes back to lists so ``json.dump`` can handle them. Converts tuple shapes back to lists so ``json.dump`` can handle them.
Drops ``tools`` when unset so existing datasets keep a clean
``info.json``.
""" """
d = dataclasses.asdict(self) d = dataclasses.asdict(self)
for ft in d["features"].values(): for ft in d["features"].values():
if isinstance(ft.get("shape"), tuple): if isinstance(ft.get("shape"), tuple):
ft["shape"] = list(ft["shape"]) ft["shape"] = list(ft["shape"])
if d.get("tools") is None:
d.pop("tools", None)
return d return d
@classmethod @classmethod
+5 -1
View File
@@ -402,6 +402,10 @@ def train(cfg: TrainPipelineConfig, accelerator: "Accelerator | None" = None):
shuffle = True shuffle = True
sampler = None sampler = None
# Only swap in the language-aware collate when the dataset actually
# declares language columns; otherwise stay on PyTorch's default
# collate so non-language training runs are unaffected.
collate_fn = lerobot_collate_fn if dataset.meta.has_language_columns else None
dataloader = torch.utils.data.DataLoader( dataloader = torch.utils.data.DataLoader(
dataset, dataset,
num_workers=cfg.num_workers, num_workers=cfg.num_workers,
@@ -410,7 +414,7 @@ def train(cfg: TrainPipelineConfig, accelerator: "Accelerator | None" = None):
sampler=sampler, sampler=sampler,
pin_memory=device.type == "cuda", pin_memory=device.type == "cuda",
drop_last=False, drop_last=False,
collate_fn=lerobot_collate_fn, collate_fn=collate_fn,
prefetch_factor=cfg.prefetch_factor if cfg.num_workers > 0 else None, prefetch_factor=cfg.prefetch_factor if cfg.num_workers > 0 else None,
persistent_workers=cfg.persistent_workers and cfg.num_workers > 0, persistent_workers=cfg.persistent_workers and cfg.num_workers > 0,
) )
+81
View File
@@ -385,3 +385,84 @@ def test_finalize_flushes_buffered_metadata(tmp_path):
assert episodes_dir.exists() assert episodes_dir.exists()
parquet_files = list(episodes_dir.rglob("*.parquet")) parquet_files = list(episodes_dir.rglob("*.parquet"))
assert len(parquet_files) > 0 assert len(parquet_files) > 0
# ── Tools accessor ───────────────────────────────────────────────────
def test_tools_falls_back_to_default_when_info_has_no_tools_field(tmp_path):
"""meta.tools returns DEFAULT_TOOLS when info.json doesn't declare any."""
from lerobot.datasets.language import DEFAULT_TOOLS
root = tmp_path / "no_tools"
meta = LeRobotDatasetMetadata.create(
repo_id="test/no_tools",
fps=DEFAULT_FPS,
features=SIMPLE_FEATURES,
root=root,
use_videos=False,
)
assert meta.tools == DEFAULT_TOOLS
# info.json on disk should NOT include a `tools` key for clean datasets
with open(root / INFO_PATH) as f:
info_on_disk = json.load(f)
assert "tools" not in info_on_disk
def test_tools_reads_declared_tools_from_info_json(tmp_path):
"""A `tools` list written into info.json survives load → meta.tools.
Regression test for the bug where ``DatasetInfo.from_dict`` silently
dropped the ``tools`` key (no matching dataclass field), so
``meta.tools`` always returned ``DEFAULT_TOOLS`` regardless of
what was on disk.
"""
from lerobot.datasets.io_utils import load_info
root = tmp_path / "with_tools"
meta = LeRobotDatasetMetadata.create(
repo_id="test/with_tools",
fps=DEFAULT_FPS,
features=SIMPLE_FEATURES,
root=root,
use_videos=False,
)
custom_tool = {
"type": "function",
"function": {
"name": "record_observation",
"description": "Capture a still image.",
"parameters": {
"type": "object",
"properties": {"label": {"type": "string"}},
"required": ["label"],
},
},
}
info_path = root / INFO_PATH
with open(info_path) as f:
raw = json.load(f)
raw["tools"] = [custom_tool]
with open(info_path, "w") as f:
json.dump(raw, f)
# Reload info from disk and rebind it on the metadata object
meta.info = load_info(root)
assert meta.tools == [custom_tool]
def test_tools_round_trip_through_dataset_info(tmp_path):
"""A `tools` list survives DatasetInfo.from_dict / to_dict."""
from lerobot.datasets.utils import DatasetInfo
raw = {
"codebase_version": "v3.1",
"fps": 30,
"features": SIMPLE_FEATURES,
"tools": [{"type": "function", "function": {"name": "say"}}],
}
info = DatasetInfo.from_dict(raw)
assert info.tools == raw["tools"]
assert info.to_dict()["tools"] == raw["tools"]
+44
View File
@@ -38,3 +38,47 @@ def test_lerobot_collate_preserves_messages_and_drops_raw_language():
assert out["target_message_indices"] == [[0], [0]] assert out["target_message_indices"] == [[0], [0]]
assert "language_persistent" not in out assert "language_persistent" not in out
assert "language_events" not in out assert "language_events" not in out
def test_lerobot_collate_passes_through_standard_batch():
"""On a non-language batch, the collate must match ``default_collate``.
Guards against silent regressions: ``lerobot_train.py`` only opts into
``lerobot_collate_fn`` when the dataset declares language columns, but
if a future change ever wires it in unconditionally we want the
behavior to remain a transparent pass-through for ordinary tensor
batches.
"""
from torch.utils.data._utils.collate import default_collate
batch = [
{
"observation.image": torch.zeros(3, 4, 4),
"action": torch.tensor([0.0, 1.0]),
"index": torch.tensor(0),
},
{
"observation.image": torch.ones(3, 4, 4),
"action": torch.tensor([2.0, 3.0]),
"index": torch.tensor(1),
},
]
custom = lerobot_collate_fn(batch)
expected = default_collate(batch)
assert custom.keys() == expected.keys()
for key in expected:
assert torch.equal(custom[key], expected[key]), f"key={key} diverged"
def test_lerobot_collate_drops_none_samples():
"""Recipes that yielded no target message return ``None`` — those samples
must be filtered out, and an entirely-``None`` batch must collapse to ``None``.
"""
batch = [None, {"index": torch.tensor(0)}, None]
out = lerobot_collate_fn(batch)
assert out is not None
assert out["index"].tolist() == [0]
assert lerobot_collate_fn([None, None]) is None