mirror of
https://github.com/huggingface/lerobot.git
synced 2026-05-16 09:09:48 +00:00
24d2ffe3c6
`lerobot.processor` re-exported `RenderMessagesStep` at the package
level, so importing anything from `lerobot.processor` pulled in
`lerobot.datasets.language` → `lerobot.datasets/__init__.py` →
`require_package("datasets")`, which fails in the Tier 1 base install
that intentionally omits the `[dataset]` extra. This import chain broke
pytest collection for unrelated suites (`tests/policies/pi0_pi05/...`,
`tests/envs/...`, etc.).
* Stop re-exporting `RenderMessagesStep` from `lerobot.processor`. The
only consumer (the test) already imports from the submodule.
Document the deliberate omission in the module docstring.
* Add `pytest.importorskip("datasets", ...)` (plus a matching skip for
  `pandas` where needed) at the top of the four test modules added in
  this PR that exercise the language stack:
- tests/datasets/test_language.py
- tests/datasets/test_language_render.py
- tests/processor/test_render_messages_processor.py
- tests/utils/test_collate.py
Co-Authored-By: Claude Opus 4.7 (1M context) <noreply@anthropic.com>
41 lines · 1.3 KiB · Python
#!/usr/bin/env python

import pytest

# Skip this whole test module when the optional `datasets` package is not
# installed (e.g. a base install without the `[dataset]` extra). This test
# exercises the language stack, which needs `datasets`. The skip must run
# before the lerobot import below, which is why the later imports carry
# `# noqa: E402` (module-level import not at top of file).
pytest.importorskip("datasets", reason="datasets is required (install lerobot[dataset])")

import torch  # noqa: E402

from lerobot.utils.collate import lerobot_collate_fn  # noqa: E402
def test_lerobot_collate_preserves_messages_and_drops_raw_language():
    """Check that ``lerobot_collate_fn`` keeps message fields and drops raw language.

    Two samples that differ only in their ``index`` and message text are
    collated. The assertions pin three things:

    * ``index`` comes back as an object whose ``.tolist()`` is ``[0, 1]``
      (presumably the per-sample scalar tensors stacked; confirm against
      ``lerobot_collate_fn``).
    * ``messages``, ``message_streams`` and ``target_message_indices`` come
      back as per-sample lists, in sample order.
    * The raw ``language_persistent`` / ``language_events`` keys are absent
      from the collated batch.
    """
    contents = ("a", "b")
    samples = [
        {
            "index": torch.tensor(position),
            "messages": [{"role": "assistant", "content": text}],
            "message_streams": ["low_level"],
            "target_message_indices": [0],
            "language_persistent": [{"content": "raw"}],
            "language_events": [],
        }
        for position, text in enumerate(contents)
    ]

    collated = lerobot_collate_fn(samples)

    assert collated["index"].tolist() == [0, 1]

    # Each sample's message list must survive collation unchanged and in order.
    for position, text in enumerate(contents):
        assert collated["messages"][position][0]["content"] == text
    assert collated["message_streams"] == [["low_level"], ["low_level"]]
    assert collated["target_message_indices"] == [[0], [0]]

    # The raw language columns must not leak into the collated batch.
    for dropped_key in ("language_persistent", "language_events"):
        assert dropped_key not in collated