mirror of
https://github.com/huggingface/lerobot.git
synced 2026-05-21 11:39:50 +00:00
fix(language): address review — tools accessor, motion docs, conditional collate
* **`meta.tools` actually reads `info.json["tools"]`.** `DatasetInfo`
had no `tools` field, so `from_dict` dropped the key (it warned
about unknown fields, then discarded them) and the property
always returned `DEFAULT_TOOLS`. Added `tools: list[dict] | None`
to the dataclass; `to_dict()` drops it when unset so existing
datasets keep a clean `info.json`. Fixed the accessor to read
`self.info.tools` (the previous `.get(...)` would have raised
AttributeError on the dataclass anyway). Added regression tests:
fallback when absent, round-trip from disk, and round-trip
through `DatasetInfo.from_dict` / `to_dict`.
* **`motion` is not view-dependent — fix the docs.** The mdx claimed
rows of style `motion` must carry `camera`, but `VIEW_DEPENDENT_STYLES
= {"vqa", "trace"}` excludes it, and the validator agrees: motion primitives are
joint/Cartesian-frame, not pixel-space. Updated both call-out
paragraphs in `language_and_recipes.mdx`.
* **Conditional `collate_fn` swap.** Added `meta.has_language_columns`
and gated the `lerobot_collate_fn` swap in `lerobot_train.py` on it,
so non-language datasets keep PyTorch's `default_collate`. Also
added a pass-through test in `test_collate.py` asserting that, on a
plain tensor batch, the custom collate matches `default_collate`
key-for-key, plus a test for the `None`-sample drop path.
Co-Authored-By: Claude Opus 4.7 (1M context) <noreply@anthropic.com>
This commit is contained in:
@@ -46,10 +46,11 @@ tool_calls: list[Json] | null
|
|||||||
```
|
```
|
||||||
|
|
||||||
The `camera` field tags rows whose `content` is grounded in a specific camera
|
The `camera` field tags rows whose `content` is grounded in a specific camera
|
||||||
view. Rows of view-dependent styles (`vqa`, and the reserved `motion` /
|
view. Rows of view-dependent styles (`vqa` and `trace`) MUST set `camera` to
|
||||||
`trace`) MUST set `camera` to the matching `observation.images.*` feature key.
|
the matching `observation.images.*` feature key. Rows of every other style —
|
||||||
Rows of every other style MUST leave `camera` as `null`. Pipeline writers and
|
including `motion`, which describes robot-frame primitives in joint / Cartesian
|
||||||
the validator enforce this via `validate_camera_field(style, camera)`.
|
terms — MUST leave `camera` as `null`. Pipeline writers and the validator
|
||||||
|
enforce this via `validate_camera_field(style, camera)`.
|
||||||
|
|
||||||
`meta/tasks.parquet` remains the canonical source for the task. The special `${task}` recipe binding always reads that task string and does not depend on language annotations.
|
`meta/tasks.parquet` remains the canonical source for the task. The special `${task}` recipe binding always reads that task string and does not depend on language annotations.
|
||||||
|
|
||||||
@@ -81,7 +82,7 @@ Exact event matching has no tolerance window, so writers must stamp event rows w
|
|||||||
|
|
||||||
### View-dependent resolution
|
### View-dependent resolution
|
||||||
|
|
||||||
For view-dependent styles (`vqa`, `motion`, `trace`), the resolver gains a
|
For view-dependent styles (`vqa` and `trace`), the resolver gains a
|
||||||
`camera=` filter parallel to `role=` and `tool_name=`. Datasets with multiple
|
`camera=` filter parallel to `role=` and `tool_name=`. Datasets with multiple
|
||||||
cameras typically emit one (`vqa`, `user`) + (`vqa`, `assistant`) pair per
|
cameras typically emit one (`vqa`, `user`) + (`vqa`, `assistant`) pair per
|
||||||
camera at the same timestamp; without `camera=`, those resolvers see two
|
camera at the same timestamp; without `camera=`, those resolvers see two
|
||||||
|
|||||||
@@ -316,6 +316,17 @@ class LeRobotDatasetMetadata:
|
|||||||
"""Keys to access visual modalities (regardless of their storage method)."""
|
"""Keys to access visual modalities (regardless of their storage method)."""
|
||||||
return [key for key, ft in self.features.items() if ft["dtype"] in ["video", "image"]]
|
return [key for key, ft in self.features.items() if ft["dtype"] in ["video", "image"]]
|
||||||
|
|
||||||
|
@property
|
||||||
|
def has_language_columns(self) -> bool:
|
||||||
|
"""Return ``True`` if the dataset declares any language column.
|
||||||
|
|
||||||
|
Used to gate language-aware code paths (collate, render step) so
|
||||||
|
unannotated datasets keep PyTorch's default collate behavior.
|
||||||
|
"""
|
||||||
|
from .language import LANGUAGE_COLUMNS # noqa: PLC0415 (avoid circular import)
|
||||||
|
|
||||||
|
return any(col in self.features for col in LANGUAGE_COLUMNS)
|
||||||
|
|
||||||
@property
|
@property
|
||||||
def tools(self) -> list[dict]:
|
def tools(self) -> list[dict]:
|
||||||
"""OpenAI-style tool schemas declared by this dataset.
|
"""OpenAI-style tool schemas declared by this dataset.
|
||||||
@@ -333,8 +344,8 @@ class LeRobotDatasetMetadata:
|
|||||||
"""
|
"""
|
||||||
from .language import DEFAULT_TOOLS # noqa: PLC0415 (avoid circular import)
|
from .language import DEFAULT_TOOLS # noqa: PLC0415 (avoid circular import)
|
||||||
|
|
||||||
declared = self.info.get("tools")
|
declared = self.info.tools
|
||||||
if isinstance(declared, list) and declared:
|
if declared:
|
||||||
return [dict(t) for t in declared]
|
return [dict(t) for t in declared]
|
||||||
return [dict(t) for t in DEFAULT_TOOLS]
|
return [dict(t) for t in DEFAULT_TOOLS]
|
||||||
|
|
||||||
|
|||||||
@@ -129,6 +129,9 @@ class DatasetInfo:
|
|||||||
# Optional metadata
|
# Optional metadata
|
||||||
robot_type: str | None = None
|
robot_type: str | None = None
|
||||||
splits: dict[str, str] = field(default_factory=dict)
|
splits: dict[str, str] = field(default_factory=dict)
|
||||||
|
# OpenAI-style tool schemas declared by the dataset. ``None`` means the
|
||||||
|
# dataset doesn't declare any — readers fall back to ``DEFAULT_TOOLS``.
|
||||||
|
tools: list[dict] | None = None
|
||||||
|
|
||||||
def __post_init__(self) -> None:
|
def __post_init__(self) -> None:
|
||||||
# Coerce feature shapes from list to tuple — JSON deserialisation
|
# Coerce feature shapes from list to tuple — JSON deserialisation
|
||||||
@@ -150,11 +153,15 @@ class DatasetInfo:
|
|||||||
"""Return a JSON-serialisable dict.
|
"""Return a JSON-serialisable dict.
|
||||||
|
|
||||||
Converts tuple shapes back to lists so ``json.dump`` can handle them.
|
Converts tuple shapes back to lists so ``json.dump`` can handle them.
|
||||||
|
Drops ``tools`` when unset so existing datasets keep a clean
|
||||||
|
``info.json``.
|
||||||
"""
|
"""
|
||||||
d = dataclasses.asdict(self)
|
d = dataclasses.asdict(self)
|
||||||
for ft in d["features"].values():
|
for ft in d["features"].values():
|
||||||
if isinstance(ft.get("shape"), tuple):
|
if isinstance(ft.get("shape"), tuple):
|
||||||
ft["shape"] = list(ft["shape"])
|
ft["shape"] = list(ft["shape"])
|
||||||
|
if d.get("tools") is None:
|
||||||
|
d.pop("tools", None)
|
||||||
return d
|
return d
|
||||||
|
|
||||||
@classmethod
|
@classmethod
|
||||||
|
|||||||
@@ -402,6 +402,10 @@ def train(cfg: TrainPipelineConfig, accelerator: "Accelerator | None" = None):
|
|||||||
shuffle = True
|
shuffle = True
|
||||||
sampler = None
|
sampler = None
|
||||||
|
|
||||||
|
# Only swap in the language-aware collate when the dataset actually
|
||||||
|
# declares language columns; otherwise stay on PyTorch's default
|
||||||
|
# collate so non-language training runs are unaffected.
|
||||||
|
collate_fn = lerobot_collate_fn if dataset.meta.has_language_columns else None
|
||||||
dataloader = torch.utils.data.DataLoader(
|
dataloader = torch.utils.data.DataLoader(
|
||||||
dataset,
|
dataset,
|
||||||
num_workers=cfg.num_workers,
|
num_workers=cfg.num_workers,
|
||||||
@@ -410,7 +414,7 @@ def train(cfg: TrainPipelineConfig, accelerator: "Accelerator | None" = None):
|
|||||||
sampler=sampler,
|
sampler=sampler,
|
||||||
pin_memory=device.type == "cuda",
|
pin_memory=device.type == "cuda",
|
||||||
drop_last=False,
|
drop_last=False,
|
||||||
collate_fn=lerobot_collate_fn,
|
collate_fn=collate_fn,
|
||||||
prefetch_factor=cfg.prefetch_factor if cfg.num_workers > 0 else None,
|
prefetch_factor=cfg.prefetch_factor if cfg.num_workers > 0 else None,
|
||||||
persistent_workers=cfg.persistent_workers and cfg.num_workers > 0,
|
persistent_workers=cfg.persistent_workers and cfg.num_workers > 0,
|
||||||
)
|
)
|
||||||
|
|||||||
@@ -385,3 +385,84 @@ def test_finalize_flushes_buffered_metadata(tmp_path):
|
|||||||
assert episodes_dir.exists()
|
assert episodes_dir.exists()
|
||||||
parquet_files = list(episodes_dir.rglob("*.parquet"))
|
parquet_files = list(episodes_dir.rglob("*.parquet"))
|
||||||
assert len(parquet_files) > 0
|
assert len(parquet_files) > 0
|
||||||
|
|
||||||
|
|
||||||
|
# ── Tools accessor ───────────────────────────────────────────────────
|
||||||
|
|
||||||
|
|
||||||
|
def test_tools_falls_back_to_default_when_info_has_no_tools_field(tmp_path):
|
||||||
|
"""meta.tools returns DEFAULT_TOOLS when info.json doesn't declare any."""
|
||||||
|
from lerobot.datasets.language import DEFAULT_TOOLS
|
||||||
|
|
||||||
|
root = tmp_path / "no_tools"
|
||||||
|
meta = LeRobotDatasetMetadata.create(
|
||||||
|
repo_id="test/no_tools",
|
||||||
|
fps=DEFAULT_FPS,
|
||||||
|
features=SIMPLE_FEATURES,
|
||||||
|
root=root,
|
||||||
|
use_videos=False,
|
||||||
|
)
|
||||||
|
|
||||||
|
assert meta.tools == DEFAULT_TOOLS
|
||||||
|
# info.json on disk should NOT include a `tools` key for clean datasets
|
||||||
|
with open(root / INFO_PATH) as f:
|
||||||
|
info_on_disk = json.load(f)
|
||||||
|
assert "tools" not in info_on_disk
|
||||||
|
|
||||||
|
|
||||||
|
def test_tools_reads_declared_tools_from_info_json(tmp_path):
|
||||||
|
"""A `tools` list written into info.json survives load → meta.tools.
|
||||||
|
|
||||||
|
Regression test for the bug where ``DatasetInfo.from_dict`` silently
|
||||||
|
dropped the ``tools`` key (no matching dataclass field), so
|
||||||
|
``meta.tools`` always returned ``DEFAULT_TOOLS`` regardless of
|
||||||
|
what was on disk.
|
||||||
|
"""
|
||||||
|
from lerobot.datasets.io_utils import load_info
|
||||||
|
|
||||||
|
root = tmp_path / "with_tools"
|
||||||
|
meta = LeRobotDatasetMetadata.create(
|
||||||
|
repo_id="test/with_tools",
|
||||||
|
fps=DEFAULT_FPS,
|
||||||
|
features=SIMPLE_FEATURES,
|
||||||
|
root=root,
|
||||||
|
use_videos=False,
|
||||||
|
)
|
||||||
|
|
||||||
|
custom_tool = {
|
||||||
|
"type": "function",
|
||||||
|
"function": {
|
||||||
|
"name": "record_observation",
|
||||||
|
"description": "Capture a still image.",
|
||||||
|
"parameters": {
|
||||||
|
"type": "object",
|
||||||
|
"properties": {"label": {"type": "string"}},
|
||||||
|
"required": ["label"],
|
||||||
|
},
|
||||||
|
},
|
||||||
|
}
|
||||||
|
info_path = root / INFO_PATH
|
||||||
|
with open(info_path) as f:
|
||||||
|
raw = json.load(f)
|
||||||
|
raw["tools"] = [custom_tool]
|
||||||
|
with open(info_path, "w") as f:
|
||||||
|
json.dump(raw, f)
|
||||||
|
|
||||||
|
# Reload info from disk and rebind it on the metadata object
|
||||||
|
meta.info = load_info(root)
|
||||||
|
assert meta.tools == [custom_tool]
|
||||||
|
|
||||||
|
|
||||||
|
def test_tools_round_trip_through_dataset_info(tmp_path):
|
||||||
|
"""A `tools` list survives DatasetInfo.from_dict / to_dict."""
|
||||||
|
from lerobot.datasets.utils import DatasetInfo
|
||||||
|
|
||||||
|
raw = {
|
||||||
|
"codebase_version": "v3.1",
|
||||||
|
"fps": 30,
|
||||||
|
"features": SIMPLE_FEATURES,
|
||||||
|
"tools": [{"type": "function", "function": {"name": "say"}}],
|
||||||
|
}
|
||||||
|
info = DatasetInfo.from_dict(raw)
|
||||||
|
assert info.tools == raw["tools"]
|
||||||
|
assert info.to_dict()["tools"] == raw["tools"]
|
||||||
|
|||||||
@@ -38,3 +38,47 @@ def test_lerobot_collate_preserves_messages_and_drops_raw_language():
|
|||||||
assert out["target_message_indices"] == [[0], [0]]
|
assert out["target_message_indices"] == [[0], [0]]
|
||||||
assert "language_persistent" not in out
|
assert "language_persistent" not in out
|
||||||
assert "language_events" not in out
|
assert "language_events" not in out
|
||||||
|
|
||||||
|
|
||||||
|
def test_lerobot_collate_passes_through_standard_batch():
|
||||||
|
"""On a non-language batch, the collate must match ``default_collate``.
|
||||||
|
|
||||||
|
Guards against silent regressions: ``lerobot_train.py`` only opts into
|
||||||
|
``lerobot_collate_fn`` when the dataset declares language columns, but
|
||||||
|
if a future change ever wires it in unconditionally we want the
|
||||||
|
behavior to remain a transparent pass-through for ordinary tensor
|
||||||
|
batches.
|
||||||
|
"""
|
||||||
|
from torch.utils.data._utils.collate import default_collate
|
||||||
|
|
||||||
|
batch = [
|
||||||
|
{
|
||||||
|
"observation.image": torch.zeros(3, 4, 4),
|
||||||
|
"action": torch.tensor([0.0, 1.0]),
|
||||||
|
"index": torch.tensor(0),
|
||||||
|
},
|
||||||
|
{
|
||||||
|
"observation.image": torch.ones(3, 4, 4),
|
||||||
|
"action": torch.tensor([2.0, 3.0]),
|
||||||
|
"index": torch.tensor(1),
|
||||||
|
},
|
||||||
|
]
|
||||||
|
|
||||||
|
custom = lerobot_collate_fn(batch)
|
||||||
|
expected = default_collate(batch)
|
||||||
|
|
||||||
|
assert custom.keys() == expected.keys()
|
||||||
|
for key in expected:
|
||||||
|
assert torch.equal(custom[key], expected[key]), f"key={key} diverged"
|
||||||
|
|
||||||
|
|
||||||
|
def test_lerobot_collate_drops_none_samples():
|
||||||
|
"""Recipes that yielded no target message return ``None`` — those samples
|
||||||
|
must be filtered out, and an entirely-``None`` batch must collapse to ``None``.
|
||||||
|
"""
|
||||||
|
batch = [None, {"index": torch.tensor(0)}, None]
|
||||||
|
out = lerobot_collate_fn(batch)
|
||||||
|
assert out is not None
|
||||||
|
assert out["index"].tolist() == [0]
|
||||||
|
|
||||||
|
assert lerobot_collate_fn([None, None]) is None
|
||||||
|
|||||||
Reference in New Issue
Block a user