diff --git a/src/lerobot/processor/render_messages_processor.py b/src/lerobot/processor/render_messages_processor.py index 4c9a25c4c..3a35436e4 100644 --- a/src/lerobot/processor/render_messages_processor.py +++ b/src/lerobot/processor/render_messages_processor.py @@ -69,7 +69,9 @@ class RenderMessagesStep(ProcessorStep): dataset_ctx=self.dataset_ctx, ) if rendered is None: - return None + rendered = _fallback_low_level_render(complementary_data.get("task")) + if rendered is None: + return None new_transition = transition.copy() new_complementary_data = dict(new_transition.get(TransitionKey.COMPLEMENTARY_DATA) or {}) @@ -107,7 +109,9 @@ class RenderMessagesStep(ProcessorStep): dataset_ctx=self.dataset_ctx, ) if rendered is None: - continue + rendered = _fallback_low_level_render(_batch_value(complementary_data.get("task"), i)) + if rendered is None: + continue keep_indices.append(i) messages.append(rendered["messages"]) message_streams.append(rendered["message_streams"]) @@ -178,3 +182,16 @@ def _select_value(value: Any, indices: list[int]) -> Any: if hasattr(value, "index_select") and hasattr(value, "new_tensor") and getattr(value, "ndim", 0) > 0: return value.index_select(0, value.new_tensor(indices).long()) return value + + +def _fallback_low_level_render(task: Any) -> dict[str, Any] | None: + """Keep action-only samples trainable when no recipe branch matches.""" + if hasattr(task, "item"): + task = task.item() + if not isinstance(task, str) or not task: + return None + return { + "messages": [{"role": "user", "content": task}], + "message_streams": ["low_level"], + "target_message_indices": [], + } diff --git a/tests/processor/test_render_messages_processor.py b/tests/processor/test_render_messages_processor.py index ff808f38f..1d2dfb326 100644 --- a/tests/processor/test_render_messages_processor.py +++ b/tests/processor/test_render_messages_processor.py @@ -54,3 +54,70 @@ def test_render_messages_step_renders_and_drops_raw_language(): assert 
def test_render_messages_step_falls_back_to_low_level_task_when_recipe_misses():
    """Check that a single sample falls back to a user turn with its task when the recipe misses."""
    subtask_turn = MessageTurn(
        role="assistant",
        content="${subtask}",
        stream="high_level",
        target=True,
        if_present="subtask",
    )
    recipe = TrainingRecipe(messages=[subtask_turn])
    sample = {
        "task": "pick the cube",
        "timestamp": torch.tensor(0.0),
        "index": torch.tensor(7),
        "language_persistent": [],
        "language_events": [{"style": "unmatched", "timestamp": 0.0}],
    }
    transition = create_transition(complementary_data=sample)

    step = RenderMessagesStep(recipe)
    rendered = step(transition)[TransitionKey.COMPLEMENTARY_DATA]

    expected_turn = {"role": "user", "content": "pick the cube"}
    assert rendered["messages"] == [expected_turn]
    assert rendered["message_streams"] == ["low_level"]
    assert rendered["target_message_indices"] == []


def test_render_messages_step_falls_back_per_sample_in_batched_language():
    """Check that each sample of a batch gets its own task fallback when the recipe misses."""
    subtask_turn = MessageTurn(
        role="assistant",
        content="${subtask}",
        stream="high_level",
        target=True,
        if_present="subtask",
    )
    recipe = TrainingRecipe(messages=[subtask_turn])
    batch = {
        "task": ["pick the cube", "open the drawer"],
        "timestamp": torch.tensor([0.0, 1.0]),
        "index": torch.tensor([7, 8]),
        "language_persistent": [[], []],
        "language_events": [
            [{"style": "unmatched", "timestamp": 0.0}],
            [{"style": "unmatched", "timestamp": 1.0}],
        ],
    }
    transition = create_transition(
        action=torch.arange(4).reshape(2, 2),
        complementary_data=batch,
    )

    step = RenderMessagesStep(recipe)
    rendered = step(transition)[TransitionKey.COMPLEMENTARY_DATA]

    # One single-turn conversation per batch element, in batch order.
    expected_messages = [
        [{"role": "user", "content": text}] for text in ("pick the cube", "open the drawer")
    ]
    assert rendered["messages"] == expected_messages
    assert rendered["message_streams"] == [["low_level"], ["low_level"]]
    assert rendered["target_message_indices"] == [[], []]