mirror of
https://github.com/huggingface/lerobot.git
synced 2026-05-20 11:09:59 +00:00
fix(processor): fallback to task message when recipe misses
Keep action-only samples trainable by rendering the task as a low-level user message when no recipe branch matches. Co-authored-by: Cursor <cursoragent@cursor.com>
This commit is contained in:
@@ -69,7 +69,9 @@ class RenderMessagesStep(ProcessorStep):
|
||||
dataset_ctx=self.dataset_ctx,
|
||||
)
|
||||
if rendered is None:
|
||||
return None
|
||||
rendered = _fallback_low_level_render(complementary_data.get("task"))
|
||||
if rendered is None:
|
||||
return None
|
||||
|
||||
new_transition = transition.copy()
|
||||
new_complementary_data = dict(new_transition.get(TransitionKey.COMPLEMENTARY_DATA) or {})
|
||||
@@ -107,7 +109,9 @@ class RenderMessagesStep(ProcessorStep):
|
||||
dataset_ctx=self.dataset_ctx,
|
||||
)
|
||||
if rendered is None:
|
||||
continue
|
||||
rendered = _fallback_low_level_render(_batch_value(complementary_data.get("task"), i))
|
||||
if rendered is None:
|
||||
continue
|
||||
keep_indices.append(i)
|
||||
messages.append(rendered["messages"])
|
||||
message_streams.append(rendered["message_streams"])
|
||||
@@ -178,3 +182,16 @@ def _select_value(value: Any, indices: list[int]) -> Any:
|
||||
if hasattr(value, "index_select") and hasattr(value, "new_tensor") and getattr(value, "ndim", 0) > 0:
|
||||
return value.index_select(0, value.new_tensor(indices).long())
|
||||
return value
|
||||
|
||||
|
||||
def _fallback_low_level_render(task: Any) -> dict[str, Any] | None:
|
||||
"""Keep action-only samples trainable when no recipe branch matches."""
|
||||
if hasattr(task, "item"):
|
||||
task = task.item()
|
||||
if not isinstance(task, str) or not task:
|
||||
return None
|
||||
return {
|
||||
"messages": [{"role": "user", "content": task}],
|
||||
"message_streams": ["low_level"],
|
||||
"target_message_indices": [],
|
||||
}
|
||||
|
||||
@@ -54,3 +54,70 @@ def test_render_messages_step_renders_and_drops_raw_language():
|
||||
assert data["messages"][-1]["content"] == "reach carefully"
|
||||
assert data["message_streams"] == ["high_level", "low_level"]
|
||||
assert data["target_message_indices"] == [1]
|
||||
|
||||
|
||||
def test_render_messages_step_falls_back_to_low_level_task_when_recipe_misses():
|
||||
recipe = TrainingRecipe(
|
||||
messages=[
|
||||
MessageTurn(
|
||||
role="assistant",
|
||||
content="${subtask}",
|
||||
stream="high_level",
|
||||
target=True,
|
||||
if_present="subtask",
|
||||
),
|
||||
]
|
||||
)
|
||||
transition = create_transition(
|
||||
complementary_data={
|
||||
"task": "pick the cube",
|
||||
"timestamp": torch.tensor(0.0),
|
||||
"index": torch.tensor(7),
|
||||
"language_persistent": [],
|
||||
"language_events": [{"style": "unmatched", "timestamp": 0.0}],
|
||||
}
|
||||
)
|
||||
|
||||
out = RenderMessagesStep(recipe)(transition)
|
||||
data = out[TransitionKey.COMPLEMENTARY_DATA]
|
||||
|
||||
assert data["messages"] == [{"role": "user", "content": "pick the cube"}]
|
||||
assert data["message_streams"] == ["low_level"]
|
||||
assert data["target_message_indices"] == []
|
||||
|
||||
|
||||
def test_render_messages_step_falls_back_per_sample_in_batched_language():
|
||||
recipe = TrainingRecipe(
|
||||
messages=[
|
||||
MessageTurn(
|
||||
role="assistant",
|
||||
content="${subtask}",
|
||||
stream="high_level",
|
||||
target=True,
|
||||
if_present="subtask",
|
||||
),
|
||||
]
|
||||
)
|
||||
transition = create_transition(
|
||||
action=torch.arange(4).reshape(2, 2),
|
||||
complementary_data={
|
||||
"task": ["pick the cube", "open the drawer"],
|
||||
"timestamp": torch.tensor([0.0, 1.0]),
|
||||
"index": torch.tensor([7, 8]),
|
||||
"language_persistent": [[], []],
|
||||
"language_events": [
|
||||
[{"style": "unmatched", "timestamp": 0.0}],
|
||||
[{"style": "unmatched", "timestamp": 1.0}],
|
||||
],
|
||||
},
|
||||
)
|
||||
|
||||
out = RenderMessagesStep(recipe)(transition)
|
||||
data = out[TransitionKey.COMPLEMENTARY_DATA]
|
||||
|
||||
assert data["messages"] == [
|
||||
[{"role": "user", "content": "pick the cube"}],
|
||||
[{"role": "user", "content": "open the drawer"}],
|
||||
]
|
||||
assert data["message_streams"] == [["low_level"], ["low_level"]]
|
||||
assert data["target_message_indices"] == [[], []]
|
||||
|
||||
Reference in New Issue
Block a user