fix(processor): fallback to task message when recipe misses

Keep action-only samples trainable by rendering the task as a low-level user message when no recipe branch matches.

Co-authored-by: Cursor <cursoragent@cursor.com>
This commit is contained in:
pepijn
2026-05-19 15:31:24 +00:00
parent 15f79b5e5e
commit e425dfd624
2 changed files with 86 additions and 2 deletions
@@ -69,7 +69,9 @@ class RenderMessagesStep(ProcessorStep):
dataset_ctx=self.dataset_ctx, dataset_ctx=self.dataset_ctx,
) )
if rendered is None: if rendered is None:
return None rendered = _fallback_low_level_render(complementary_data.get("task"))
if rendered is None:
return None
new_transition = transition.copy() new_transition = transition.copy()
new_complementary_data = dict(new_transition.get(TransitionKey.COMPLEMENTARY_DATA) or {}) new_complementary_data = dict(new_transition.get(TransitionKey.COMPLEMENTARY_DATA) or {})
@@ -107,7 +109,9 @@ class RenderMessagesStep(ProcessorStep):
dataset_ctx=self.dataset_ctx, dataset_ctx=self.dataset_ctx,
) )
if rendered is None: if rendered is None:
continue rendered = _fallback_low_level_render(_batch_value(complementary_data.get("task"), i))
if rendered is None:
continue
keep_indices.append(i) keep_indices.append(i)
messages.append(rendered["messages"]) messages.append(rendered["messages"])
message_streams.append(rendered["message_streams"]) message_streams.append(rendered["message_streams"])
@@ -178,3 +182,16 @@ def _select_value(value: Any, indices: list[int]) -> Any:
if hasattr(value, "index_select") and hasattr(value, "new_tensor") and getattr(value, "ndim", 0) > 0: if hasattr(value, "index_select") and hasattr(value, "new_tensor") and getattr(value, "ndim", 0) > 0:
return value.index_select(0, value.new_tensor(indices).long()) return value.index_select(0, value.new_tensor(indices).long())
return value return value
def _fallback_low_level_render(task: Any) -> dict[str, Any] | None:
"""Keep action-only samples trainable when no recipe branch matches."""
if hasattr(task, "item"):
task = task.item()
if not isinstance(task, str) or not task:
return None
return {
"messages": [{"role": "user", "content": task}],
"message_streams": ["low_level"],
"target_message_indices": [],
}
@@ -54,3 +54,70 @@ def test_render_messages_step_renders_and_drops_raw_language():
assert data["messages"][-1]["content"] == "reach carefully" assert data["messages"][-1]["content"] == "reach carefully"
assert data["message_streams"] == ["high_level", "low_level"] assert data["message_streams"] == ["high_level", "low_level"]
assert data["target_message_indices"] == [1] assert data["target_message_indices"] == [1]
def test_render_messages_step_falls_back_to_low_level_task_when_recipe_misses():
recipe = TrainingRecipe(
messages=[
MessageTurn(
role="assistant",
content="${subtask}",
stream="high_level",
target=True,
if_present="subtask",
),
]
)
transition = create_transition(
complementary_data={
"task": "pick the cube",
"timestamp": torch.tensor(0.0),
"index": torch.tensor(7),
"language_persistent": [],
"language_events": [{"style": "unmatched", "timestamp": 0.0}],
}
)
out = RenderMessagesStep(recipe)(transition)
data = out[TransitionKey.COMPLEMENTARY_DATA]
assert data["messages"] == [{"role": "user", "content": "pick the cube"}]
assert data["message_streams"] == ["low_level"]
assert data["target_message_indices"] == []
def test_render_messages_step_falls_back_per_sample_in_batched_language():
recipe = TrainingRecipe(
messages=[
MessageTurn(
role="assistant",
content="${subtask}",
stream="high_level",
target=True,
if_present="subtask",
),
]
)
transition = create_transition(
action=torch.arange(4).reshape(2, 2),
complementary_data={
"task": ["pick the cube", "open the drawer"],
"timestamp": torch.tensor([0.0, 1.0]),
"index": torch.tensor([7, 8]),
"language_persistent": [[], []],
"language_events": [
[{"style": "unmatched", "timestamp": 0.0}],
[{"style": "unmatched", "timestamp": 1.0}],
],
},
)
out = RenderMessagesStep(recipe)(transition)
data = out[TransitionKey.COMPLEMENTARY_DATA]
assert data["messages"] == [
[{"role": "user", "content": "pick the cube"}],
[{"role": "user", "content": "open the drawer"}],
]
assert data["message_streams"] == [["low_level"], ["low_level"]]
assert data["target_message_indices"] == [[], []]