mirror of
https://github.com/huggingface/lerobot.git
synced 2026-05-20 02:59:50 +00:00
fix(processor): fallback to task message when recipe misses
Keep action-only samples trainable by rendering the task as a low-level user message when no recipe branch matches. Co-authored-by: Cursor <cursoragent@cursor.com>
This commit is contained in:
@@ -54,3 +54,70 @@ def test_render_messages_step_renders_and_drops_raw_language():
|
||||
assert data["messages"][-1]["content"] == "reach carefully"
|
||||
assert data["message_streams"] == ["high_level", "low_level"]
|
||||
assert data["target_message_indices"] == [1]
|
||||
|
||||
|
||||
def test_render_messages_step_falls_back_to_low_level_task_when_recipe_misses():
|
||||
recipe = TrainingRecipe(
|
||||
messages=[
|
||||
MessageTurn(
|
||||
role="assistant",
|
||||
content="${subtask}",
|
||||
stream="high_level",
|
||||
target=True,
|
||||
if_present="subtask",
|
||||
),
|
||||
]
|
||||
)
|
||||
transition = create_transition(
|
||||
complementary_data={
|
||||
"task": "pick the cube",
|
||||
"timestamp": torch.tensor(0.0),
|
||||
"index": torch.tensor(7),
|
||||
"language_persistent": [],
|
||||
"language_events": [{"style": "unmatched", "timestamp": 0.0}],
|
||||
}
|
||||
)
|
||||
|
||||
out = RenderMessagesStep(recipe)(transition)
|
||||
data = out[TransitionKey.COMPLEMENTARY_DATA]
|
||||
|
||||
assert data["messages"] == [{"role": "user", "content": "pick the cube"}]
|
||||
assert data["message_streams"] == ["low_level"]
|
||||
assert data["target_message_indices"] == []
|
||||
|
||||
|
||||
def test_render_messages_step_falls_back_per_sample_in_batched_language():
|
||||
recipe = TrainingRecipe(
|
||||
messages=[
|
||||
MessageTurn(
|
||||
role="assistant",
|
||||
content="${subtask}",
|
||||
stream="high_level",
|
||||
target=True,
|
||||
if_present="subtask",
|
||||
),
|
||||
]
|
||||
)
|
||||
transition = create_transition(
|
||||
action=torch.arange(4).reshape(2, 2),
|
||||
complementary_data={
|
||||
"task": ["pick the cube", "open the drawer"],
|
||||
"timestamp": torch.tensor([0.0, 1.0]),
|
||||
"index": torch.tensor([7, 8]),
|
||||
"language_persistent": [[], []],
|
||||
"language_events": [
|
||||
[{"style": "unmatched", "timestamp": 0.0}],
|
||||
[{"style": "unmatched", "timestamp": 1.0}],
|
||||
],
|
||||
},
|
||||
)
|
||||
|
||||
out = RenderMessagesStep(recipe)(transition)
|
||||
data = out[TransitionKey.COMPLEMENTARY_DATA]
|
||||
|
||||
assert data["messages"] == [
|
||||
[{"role": "user", "content": "pick the cube"}],
|
||||
[{"role": "user", "content": "open the drawer"}],
|
||||
]
|
||||
assert data["message_streams"] == [["low_level"], ["low_level"]]
|
||||
assert data["target_message_indices"] == [[], []]
|
||||
|
||||
Reference in New Issue
Block a user