mirror of
https://github.com/huggingface/lerobot.git
synced 2026-06-19 01:07:18 +00:00
fix(policies,recipe): register PI052Config + allow flow-only sub-recipes
Two regressions surfaced by the first training run:
1. ``--policy.type=pi052`` failed with ``invalid choice``. PI052Config
wasn't imported in ``policies/__init__.py``, so its
``@register_subclass("pi052")`` decorator never ran and draccus
didn't see it as a valid policy type. Mirror PI05Config /
SmolVLA2Config in the top-level imports + __all__.
2. ``low_level_execution`` (user-only ``${subtask}`` recipe used for
π0.5-style flow conditioning) tripped
``ValueError: Message recipes must contain at least one target
turn.`` The validator was too strict — a recipe with only a
``stream: low_level`` turn still drives meaningful supervision
(flow MSE on the action expert via ``predict_actions=True``).
Allow either ``target: true`` OR ``stream: low_level`` to satisfy
the "supervises something" requirement.
Co-Authored-By: Claude Opus 4.7 (1M context) <noreply@anthropic.com>
This commit is contained in:
@@ -134,7 +134,16 @@ class TrainingRecipe:
|
||||
return cls.from_dict(data)
|
||||
|
||||
def _validate_message_recipe(self) -> None:
|
||||
"""Ensure every templated binding is known and at least one turn is a target."""
|
||||
"""Ensure every templated binding is known and the recipe supervises something.
|
||||
|
||||
A recipe is valid if it has at least one of:
|
||||
|
||||
* a ``target: true`` assistant turn (drives text-CE supervision), or
|
||||
* a ``stream: low_level`` turn (drives flow / action supervision via
|
||||
``predict_actions=True``, even when no assistant turn is targeted —
|
||||
e.g. π0.5-style ``low_level_execution`` where the action expert
|
||||
conditions on a user-only ``${subtask}`` prompt).
|
||||
"""
|
||||
assert self.messages is not None
|
||||
known_bindings = set(DEFAULT_BINDINGS) | set(self.bindings or {}) | {"task"}
|
||||
|
||||
@@ -143,8 +152,14 @@ class TrainingRecipe:
|
||||
if missing:
|
||||
raise ValueError(f"MessageTurn references unknown binding(s): {sorted(missing)}")
|
||||
|
||||
if not any(turn.target for turn in self.messages):
|
||||
raise ValueError("Message recipes must contain at least one target turn.")
|
||||
has_target = any(turn.target for turn in self.messages)
|
||||
has_low_level = any(turn.stream == "low_level" for turn in self.messages)
|
||||
if not (has_target or has_low_level):
|
||||
raise ValueError(
|
||||
"Message recipes must contain at least one supervised turn — "
|
||||
"either ``target: true`` (text CE) or ``stream: low_level`` "
|
||||
"(flow/action loss)."
|
||||
)
|
||||
|
||||
def _validate_blend_recipe(self) -> None:
|
||||
"""Ensure each blend component is a non-empty, weighted message recipe."""
|
||||
|
||||
@@ -20,6 +20,7 @@ from .multi_task_dit.configuration_multi_task_dit import MultiTaskDiTConfig as M
|
||||
from .pi0.configuration_pi0 import PI0Config as PI0Config
|
||||
from .pi0_fast.configuration_pi0_fast import PI0FastConfig as PI0FastConfig
|
||||
from .pi05.configuration_pi05 import PI05Config as PI05Config
|
||||
from .pi052.configuration_pi052 import PI052Config as PI052Config
|
||||
from .pretrained import PreTrainedPolicy as PreTrainedPolicy
|
||||
from .rtc import ActionInterpolator as ActionInterpolator
|
||||
from .sac.configuration_sac import SACConfig as SACConfig
|
||||
@@ -46,6 +47,7 @@ __all__ = [
|
||||
"PI0Config",
|
||||
"PI0FastConfig",
|
||||
"PI05Config",
|
||||
"PI052Config",
|
||||
"RewardClassifierConfig",
|
||||
"SACConfig",
|
||||
"SARMConfig",
|
||||
|
||||
Reference in New Issue
Block a user