mirror of
https://github.com/huggingface/lerobot.git
synced 2026-05-27 14:39:43 +00:00
1ff10b935c
Resolves conflicts from 66 commits on the base branch: * pyproject.toml — keep base's transformers>=5.4.0,<5.6.0; add the sentencepiece-dep entry pi052 (FAST action tokenizer) needs. * policies/__init__.py — keep pi052 export; drop the RewardClassifierConfig export that base removed. * policies/factory.py — docstring list resolution (keep pi052; drop reward_classifier, removed by base). * annotations/steerable_pipeline/executor.py — adopt base's renamed _ensure_annotation_metadata_in_info (it already advertises the say tool); drop pi052's older _ensure_tools_in_info call. * configs/train.py — keep pi052's vqa_target_fraction; adopt base's SampleWeightingConfig (legacy RA-BC inline params already covered by the migration shim base added). * scripts/lerobot_train.py — merge pi052's per-policy processor rebuild + dataset_repo_id pass-through with base's active_cfg / is_reward_model_training tightening, and re-route vqa-weighted sampler to active_cfg.drop_n_last_frames. * datasets/language_render.py — adopt base's _select_one + timestamp tolerance (drops pi052's stale _select_latest / per-style sort_key). * tests — adopt base's parametrized per-camera blend + tolerance test; drop pi052 tests that overlap with base's tighter rewrites; keep pi052's flow-only / VQA-blend coverage; add a test_canonical_recipe_loads check on subtask_mem_vqa_speech.yaml. * policies/pi052/processor_pi052.py — import RenderMessagesStep directly from render_messages_processor (base intentionally dropped it from lerobot.processor's re-exports). * uv.lock — regenerated cleanly from base + pi052's pocket-tts / beartype. All 67 touched tests pass (30 pi052 + 37 recipe / language-render / pipeline / render-messages). Co-Authored-By: Claude Opus 4.7 (1M context) <noreply@anthropic.com>
128 lines
4.2 KiB
Python
128 lines
4.2 KiB
Python
#!/usr/bin/env python
|
|
|
|
import pytest
|
|
|
|
pytest.importorskip("datasets", reason="datasets is required (install lerobot[dataset])")
|
|
|
|
import torch # noqa: E402
|
|
|
|
from lerobot.configs.recipe import MessageTurn, TrainingRecipe # noqa: E402
|
|
from lerobot.processor.converters import create_transition # noqa: E402
|
|
from lerobot.processor.render_messages_processor import RenderMessagesStep # noqa: E402
|
|
from lerobot.types import TransitionKey # noqa: E402
|
|
|
|
|
|
def test_render_messages_step_noops_without_language_columns():
    """The step returns an equal transition when no language columns are present."""
    user_turn = MessageTurn(role="user", content="${task}", stream="high_level")
    assistant_turn = MessageTurn(
        role="assistant",
        content="${subtask}",
        stream="low_level",
        target=True,
    )
    recipe = TrainingRecipe(messages=[user_turn, assistant_turn])

    # Only "task" is supplied: neither language_persistent nor language_events.
    transition = create_transition(complementary_data={"task": "do it"})

    step = RenderMessagesStep(recipe)
    result = step(transition)

    assert result == transition
|
|
|
|
|
|
def test_render_messages_step_renders_and_drops_raw_language():
    """The step emits rendered message fields and removes the raw language columns."""
    user_turn = MessageTurn(role="user", content="${task}", stream="high_level")
    assistant_turn = MessageTurn(
        role="assistant",
        content="${subtask}",
        stream="low_level",
        target=True,
    )
    recipe = TrainingRecipe(messages=[user_turn, assistant_turn])

    # A single persistent row whose style matches the recipe's ${subtask} binding.
    subtask_row = {
        "role": "assistant",
        "content": "reach carefully",
        "style": "subtask",
        "timestamp": 0.0,
        "camera": None,
        "tool_calls": None,
    }
    complementary = {
        "task": "do it",
        "timestamp": torch.tensor(0.0),
        "index": torch.tensor(7),
        "language_persistent": [subtask_row],
        "language_events": [],
    }
    transition = create_transition(complementary_data=complementary)

    step = RenderMessagesStep(recipe)
    data = step(transition)[TransitionKey.COMPLEMENTARY_DATA]

    # The raw columns must be gone from the output.
    for raw_column in ("language_persistent", "language_events"):
        assert raw_column not in data

    last_message = data["messages"][-1]
    assert last_message["content"] == "reach carefully"
    assert data["message_streams"] == ["high_level", "low_level"]
    assert data["target_message_indices"] == [1]
|
|
|
|
|
|
def test_render_messages_step_falls_back_to_low_level_task_when_recipe_misses():
    """With no subtask row available, the output is one low-level user task message."""
    # The recipe's only turn is declared with if_present="subtask".
    gated_turn = MessageTurn(
        role="assistant",
        content="${subtask}",
        stream="high_level",
        target=True,
        if_present="subtask",
    )
    recipe = TrainingRecipe(messages=[gated_turn])

    # No persistent rows, and the only event carries a non-subtask style.
    unmatched_event = {"style": "unmatched", "timestamp": 0.0}
    complementary = {
        "task": "pick the cube",
        "timestamp": torch.tensor(0.0),
        "index": torch.tensor(7),
        "language_persistent": [],
        "language_events": [unmatched_event],
    }
    transition = create_transition(complementary_data=complementary)

    step = RenderMessagesStep(recipe)
    data = step(transition)[TransitionKey.COMPLEMENTARY_DATA]

    expected_messages = [{"role": "user", "content": "pick the cube"}]
    assert data["messages"] == expected_messages
    assert data["message_streams"] == ["low_level"]
    assert data["target_message_indices"] == []
|
|
|
|
|
|
def test_render_messages_step_falls_back_per_sample_in_batched_language():
    """For a batch of two, each sample gets its own low-level user task message."""
    # The recipe's only turn is declared with if_present="subtask".
    gated_turn = MessageTurn(
        role="assistant",
        content="${subtask}",
        stream="high_level",
        target=True,
        if_present="subtask",
    )
    recipe = TrainingRecipe(messages=[gated_turn])

    tasks = ["pick the cube", "open the drawer"]
    complementary = {
        # Pass a copy so the expected values below do not alias the step's input.
        "task": list(tasks),
        "timestamp": torch.tensor([0.0, 1.0]),
        "index": torch.tensor([7, 8]),
        "language_persistent": [[], []],
        "language_events": [
            [{"style": "unmatched", "timestamp": 0.0}],
            [{"style": "unmatched", "timestamp": 1.0}],
        ],
    }
    transition = create_transition(
        action=torch.arange(4).reshape(2, 2),
        complementary_data=complementary,
    )

    step = RenderMessagesStep(recipe)
    data = step(transition)[TransitionKey.COMPLEMENTARY_DATA]

    # One entry per sample in every rendered field.
    expected_messages = [[{"role": "user", "content": task}] for task in tasks]
    assert data["messages"] == expected_messages
    assert data["message_streams"] == [["low_level"] for _ in tasks]
    assert data["target_message_indices"] == [[] for _ in tasks]
|