diff --git a/src/lerobot/policies/pi052/modeling_pi052.py b/src/lerobot/policies/pi052/modeling_pi052.py index 9278ca04b..1283042d4 100644 --- a/src/lerobot/policies/pi052/modeling_pi052.py +++ b/src/lerobot/policies/pi052/modeling_pi052.py @@ -68,12 +68,9 @@ from .configuration_pi052 import PI052Config logger = logging.getLogger(__name__) -# ====================================================================== -# PI0.5 flow-matching model + helpers (moved here from pi05_backbone.py). -# pi052-specific; the generic dual-expert transformer (PaliGemmaWithExpertModel, -# sdpa_attention_forward, compute_layer_complete, get_gemma_config) lives in -# ``lerobot.policies.pi_gemma`` and is imported above. -# ====================================================================== +# PI0.5 flow-matching model + helpers (pi052-specific). The generic dual-expert +# transformer (PaliGemmaWithExpertModel, sdpa_attention_forward, +# compute_layer_complete, get_gemma_config) lives in lerobot.policies.pi_gemma. class ActionSelectKwargs(TypedDict, total=False): inference_delay: int | None diff --git a/src/lerobot/policies/pi_gemma.py b/src/lerobot/policies/pi_gemma.py index f8d71cdd7..c8631cbff 100644 --- a/src/lerobot/policies/pi_gemma.py +++ b/src/lerobot/policies/pi_gemma.py @@ -383,12 +383,10 @@ __all__ = [ ] -# ====================================================================== -# PI0.5 / PI052 dual-expert backbone (moved here from pi052/pi05_backbone.py). -# Generic PaliGemma + Gemma action-expert transformer machinery shared by the -# pi052 policy. ``GemmaVariantConfig`` is openpi's width/depth variant config -# (renamed from GemmaConfig to avoid clashing with transformers' GemmaConfig). -# ====================================================================== +# PI0.5 / PI052 dual-expert backbone: generic PaliGemma + Gemma action-expert +# transformer machinery used by the pi052 policy. GemmaVariantConfig is openpi's +# width/depth variant config (renamed from GemmaConfig to avoid clashing with +# transformers' GemmaConfig). def sdpa_attention_forward( module,