mirror of
https://github.com/huggingface/lerobot.git
synced 2026-05-20 11:09:59 +00:00
fix(processor): fallback to task message when recipe misses
Keep action-only samples trainable by rendering the task as a low-level user message when no recipe branch matches.

Co-authored-by: Cursor <cursoragent@cursor.com>
This commit is contained in:
@@ -69,7 +69,9 @@ class RenderMessagesStep(ProcessorStep):
|
|||||||
dataset_ctx=self.dataset_ctx,
|
dataset_ctx=self.dataset_ctx,
|
||||||
)
|
)
|
||||||
if rendered is None:
|
if rendered is None:
|
||||||
return None
|
rendered = _fallback_low_level_render(complementary_data.get("task"))
|
||||||
|
if rendered is None:
|
||||||
|
return None
|
||||||
|
|
||||||
new_transition = transition.copy()
|
new_transition = transition.copy()
|
||||||
new_complementary_data = dict(new_transition.get(TransitionKey.COMPLEMENTARY_DATA) or {})
|
new_complementary_data = dict(new_transition.get(TransitionKey.COMPLEMENTARY_DATA) or {})
|
||||||
@@ -107,7 +109,9 @@ class RenderMessagesStep(ProcessorStep):
|
|||||||
dataset_ctx=self.dataset_ctx,
|
dataset_ctx=self.dataset_ctx,
|
||||||
)
|
)
|
||||||
if rendered is None:
|
if rendered is None:
|
||||||
continue
|
rendered = _fallback_low_level_render(_batch_value(complementary_data.get("task"), i))
|
||||||
|
if rendered is None:
|
||||||
|
continue
|
||||||
keep_indices.append(i)
|
keep_indices.append(i)
|
||||||
messages.append(rendered["messages"])
|
messages.append(rendered["messages"])
|
||||||
message_streams.append(rendered["message_streams"])
|
message_streams.append(rendered["message_streams"])
|
||||||
@@ -178,3 +182,16 @@ def _select_value(value: Any, indices: list[int]) -> Any:
|
|||||||
if hasattr(value, "index_select") and hasattr(value, "new_tensor") and getattr(value, "ndim", 0) > 0:
|
if hasattr(value, "index_select") and hasattr(value, "new_tensor") and getattr(value, "ndim", 0) > 0:
|
||||||
return value.index_select(0, value.new_tensor(indices).long())
|
return value.index_select(0, value.new_tensor(indices).long())
|
||||||
return value
|
return value
|
||||||
|
|
||||||
|
|
||||||
|
def _fallback_low_level_render(task: Any) -> dict[str, Any] | None:
|
||||||
|
"""Keep action-only samples trainable when no recipe branch matches."""
|
||||||
|
if hasattr(task, "item"):
|
||||||
|
task = task.item()
|
||||||
|
if not isinstance(task, str) or not task:
|
||||||
|
return None
|
||||||
|
return {
|
||||||
|
"messages": [{"role": "user", "content": task}],
|
||||||
|
"message_streams": ["low_level"],
|
||||||
|
"target_message_indices": [],
|
||||||
|
}
|
||||||
|
|||||||
@@ -54,3 +54,70 @@ def test_render_messages_step_renders_and_drops_raw_language():
|
|||||||
assert data["messages"][-1]["content"] == "reach carefully"
|
assert data["messages"][-1]["content"] == "reach carefully"
|
||||||
assert data["message_streams"] == ["high_level", "low_level"]
|
assert data["message_streams"] == ["high_level", "low_level"]
|
||||||
assert data["target_message_indices"] == [1]
|
assert data["target_message_indices"] == [1]
|
||||||
|
|
||||||
|
|
||||||
|
def test_render_messages_step_falls_back_to_low_level_task_when_recipe_misses():
    """Unbatched sample: the step should emit the bare task as one low-level user message.

    The recipe's only turn is declared with ``if_present="subtask"`` and the
    sample carries no ``subtask`` entry, so the output is expected to be the
    task-only fallback with no target indices.
    """
    subtask_turn = MessageTurn(
        role="assistant",
        content="${subtask}",
        stream="high_level",
        target=True,
        if_present="subtask",
    )
    step = RenderMessagesStep(TrainingRecipe(messages=[subtask_turn]))

    complementary = {
        "task": "pick the cube",
        "timestamp": torch.tensor(0.0),
        "index": torch.tensor(7),
        "language_persistent": [],
        "language_events": [{"style": "unmatched", "timestamp": 0.0}],
    }
    result = step(create_transition(complementary_data=complementary))
    rendered = result[TransitionKey.COMPLEMENTARY_DATA]

    assert rendered["messages"] == [{"role": "user", "content": "pick the cube"}]
    assert rendered["message_streams"] == ["low_level"]
    assert rendered["target_message_indices"] == []
|
||||||
|
|
||||||
|
|
||||||
|
def test_render_messages_step_falls_back_per_sample_in_batched_language():
    """Batched input: each sample should get its own task-only low-level message.

    Both samples carry a single ``"unmatched"`` language event and no
    ``subtask`` entry, while the recipe's only turn is declared with
    ``if_present="subtask"``; the expected output is one fallback rendering
    per sample, in batch order.
    """
    tasks = ["pick the cube", "open the drawer"]
    subtask_turn = MessageTurn(
        role="assistant",
        content="${subtask}",
        stream="high_level",
        target=True,
        if_present="subtask",
    )
    step = RenderMessagesStep(TrainingRecipe(messages=[subtask_turn]))

    batch = create_transition(
        action=torch.arange(4).reshape(2, 2),
        complementary_data={
            "task": tasks,
            "timestamp": torch.tensor([0.0, 1.0]),
            "index": torch.tensor([7, 8]),
            "language_persistent": [[], []],
            "language_events": [
                [{"style": "unmatched", "timestamp": 0.0}],
                [{"style": "unmatched", "timestamp": 1.0}],
            ],
        },
    )
    rendered = step(batch)[TransitionKey.COMPLEMENTARY_DATA]

    assert rendered["messages"] == [
        [{"role": "user", "content": "pick the cube"}],
        [{"role": "user", "content": "open the drawer"}],
    ]
    assert rendered["message_streams"] == [["low_level"], ["low_level"]]
    assert rendered["target_message_indices"] == [[], []]
|
||||||
|
|||||||
Reference in New Issue
Block a user