fix(processor): fallback to task message when recipe misses

Keep action-only samples trainable by rendering the task as a low-level user message when no recipe branch matches.

Co-authored-by: Cursor <cursoragent@cursor.com>
This commit is contained in:
pepijn
2026-05-19 15:31:24 +00:00
parent 15f79b5e5e
commit e425dfd624
2 changed files with 86 additions and 2 deletions
@@ -54,3 +54,70 @@ def test_render_messages_step_renders_and_drops_raw_language():
assert data["messages"][-1]["content"] == "reach carefully"
assert data["message_streams"] == ["high_level", "low_level"]
assert data["target_message_indices"] == [1]
def test_render_messages_step_falls_back_to_low_level_task_when_recipe_misses():
recipe = TrainingRecipe(
messages=[
MessageTurn(
role="assistant",
content="${subtask}",
stream="high_level",
target=True,
if_present="subtask",
),
]
)
transition = create_transition(
complementary_data={
"task": "pick the cube",
"timestamp": torch.tensor(0.0),
"index": torch.tensor(7),
"language_persistent": [],
"language_events": [{"style": "unmatched", "timestamp": 0.0}],
}
)
out = RenderMessagesStep(recipe)(transition)
data = out[TransitionKey.COMPLEMENTARY_DATA]
assert data["messages"] == [{"role": "user", "content": "pick the cube"}]
assert data["message_streams"] == ["low_level"]
assert data["target_message_indices"] == []
def test_render_messages_step_falls_back_per_sample_in_batched_language():
recipe = TrainingRecipe(
messages=[
MessageTurn(
role="assistant",
content="${subtask}",
stream="high_level",
target=True,
if_present="subtask",
),
]
)
transition = create_transition(
action=torch.arange(4).reshape(2, 2),
complementary_data={
"task": ["pick the cube", "open the drawer"],
"timestamp": torch.tensor([0.0, 1.0]),
"index": torch.tensor([7, 8]),
"language_persistent": [[], []],
"language_events": [
[{"style": "unmatched", "timestamp": 0.0}],
[{"style": "unmatched", "timestamp": 1.0}],
],
},
)
out = RenderMessagesStep(recipe)(transition)
data = out[TransitionKey.COMPLEMENTARY_DATA]
assert data["messages"] == [
[{"role": "user", "content": "pick the cube"}],
[{"role": "user", "content": "open the drawer"}],
]
assert data["message_streams"] == [["low_level"], ["low_level"]]
assert data["target_message_indices"] == [[], []]