mirror of
https://github.com/huggingface/lerobot.git
synced 2026-05-16 17:20:05 +00:00
feat(pi052): FAST action CE loss + knowledge insulation + processor wiring
Three additions ported from ``pi05_full`` on branch ``feat/add-pi05``,
giving pi052 the full training capabilities described in paper §III.B–C,
alongside the recipe-driven text supervision it already had:
* **Config flags** in PI052Config:
- ``enable_fast_action_loss`` default False
- ``action_tokenizer_name`` default "physical-intelligence/fast"
- ``max_action_tokens`` default 256
- ``fast_skip_tokens`` default 128
- ``fast_action_loss_weight`` default 1.0
- ``knowledge_insulation`` default False
* **Processor wiring** (processor_pi052.py): when
``enable_fast_action_loss=True``, append an
``ActionTokenizerProcessorStep`` after the text tokenizer. It
tokenises the action tensor with the FAST tokenizer and writes
ACTION_TOKENS / ACTION_TOKEN_MASK into ``COMPLEMENTARY_DATA`` —
the existing batch-collation pipeline forwards them as
``batch['action.tokens']`` / ``batch['action.token_mask']``.
* **FAST CE loss** (modeling_pi052.py::_compute_fast_action_loss):
Re-embeds the prefix [images, language], appends the FAST token
embeddings (using PaliGemma's shared embed_language_tokens),
forwards through the backbone, slices the trailing
``fast_len`` positions, applies the LM head, computes shifted
next-token CE with the action-mask gating the loss. The loss is
summed into ``forward()``'s total with ``fast_action_loss_weight``.
* **Knowledge insulation** (modeling_pi052.py::_compute_layer_ki +
_paligemma_forward_ki): port of pi05_full's per-layer attention
that detaches VLM K/V on the action-query path so action loss
gradients cannot flow back into the VLM's K/V projections. Bound
per-instance via ``types.MethodType`` so it doesn't leak into
stock ``pi05`` policies that share PaliGemmaWithExpertModel.
Activated automatically when ``config.knowledge_insulation=True``.
Combined with the existing recipe-driven text head, pi052 now
supports the full three-loss objective:
L = text_w·H(text) + fast_w·H(FAST actions) + flow_w·MSE(flow)
matching Eq. (1) of arxiv:2504.16054 §IV.D (α=10 by default for the
flow term, 1.0 each for text and FAST CE).
Co-Authored-By: Claude Opus 4.7 (1M context) <noreply@anthropic.com>
This commit is contained in:
@@ -99,6 +99,50 @@ class PI052Config(PI05Config):
|
||||
memory_dropout_prob: float = 0.0
subtask_dropout_prob: float = 0.0

# FAST discrete-action supervision — paper §III.B-C ------------------
# When enabled, actions are *also* tokenised via the FAST tokenizer
# ("physical-intelligence/fast") and supervised with cross-entropy
# on the PaliGemma LM head — exactly as in the paper's pre-training
# objective (Eq. 1 mixes FAST CE + flow MSE + subtask CE). The
# ActionTokenizerProcessorStep is wired into the preprocessor
# pipeline when this flag is set; the loss is computed in
# PI052Policy.forward.
enable_fast_action_loss: bool = False
"""If True, tokenise actions with the FAST tokenizer and add a
cross-entropy loss on the LM head. Off by default because most
fine-tuning runs only need the flow head + text supervision; the
FAST CE term is most useful when training from a base PaliGemma
rather than an existing π0.5 checkpoint."""

action_tokenizer_name: str = "physical-intelligence/fast"
"""HF identifier for the FAST action tokenizer."""

max_action_tokens: int = 256
"""Maximum number of FAST tokens per action chunk."""

fast_skip_tokens: int = 128
"""Number of low-vocab tokens the FAST tokenizer skips to avoid
collisions with PaliGemma's text vocabulary."""

fast_action_loss_weight: float = 1.0
"""Weight on the FAST-action-token CE loss. Paper §III.C uses 1.0."""

# Knowledge insulation — paper §III.B --------------------------------
# When enabled, gradients from the action expert's flow loss are
# *blocked* from flowing back into the VLM's K/V projections. This
# prevents the action loss from over-fitting the language backbone
# to robot-specific features. Implemented as a custom per-layer
# attention forward that uses ``.detach()`` on VLM K/V when computing
# attention for action queries — ported from
# ``pi05_full/modeling_pi05.py::compute_layer_complete_knowledge_insulation``
# on branch ``feat/add-pi05``.
knowledge_insulation: bool = False
"""If True, route every transformer layer through the KI attention
path, blocking action→VLM gradient flow on K/V. When set,
``PI052Policy.__init__`` binds ``_paligemma_forward_ki`` per-instance
onto the PaliGemmaWithExpertModel backbone (see modeling_pi052.py),
so stock ``pi05`` policies sharing that class are unaffected."""
def __post_init__(self) -> None:
|
||||
super().__post_init__()
|
||||
# Backbone needs gradients flowing through the text head when
|
||||
|
||||
@@ -39,10 +39,12 @@ from __future__ import annotations
|
||||
|
||||
import logging
|
||||
import math
|
||||
import types
|
||||
from typing import Any
|
||||
|
||||
import torch
|
||||
from torch import Tensor
|
||||
from torch.nn import functional as F
|
||||
|
||||
from lerobot.utils.constants import OBS_LANGUAGE_ATTENTION_MASK, OBS_LANGUAGE_TOKENS
|
||||
|
||||
@@ -53,6 +55,190 @@ from .configuration_pi052 import PI052Config
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
# ----------------------------------------------------------------------
|
||||
# Knowledge insulation — ported from pi05_full (branch ``feat/add-pi05``)
|
||||
# ----------------------------------------------------------------------
|
||||
#
|
||||
# Per-layer attention that splits the queries into VLM and action
|
||||
# parts, computing attention for action queries with .detach()'d VLM
|
||||
# K/V so the action loss's gradient cannot flow back into the VLM's K
|
||||
# and V projections. Forward output is bit-equivalent to the standard
|
||||
# layer; backward differs only on the path action_loss → VLM K/V.
|
||||
|
||||
def _compute_layer_ki(
    layer_idx,
    inputs_embeds,
    attention_mask,
    position_ids,
    adarms_cond,
    paligemma,
    gemma_expert,
):
    """One dual-expert transformer layer with knowledge insulation (KI).

    Runs attention jointly over the VLM tokens (``inputs_embeds[0]``)
    and the action-expert tokens (``inputs_embeds[1]``), but detaches
    the VLM K/V on the path the action queries use so gradients from
    the action loss cannot flow back into the VLM's K and V
    projections. The forward output is bit-equivalent to the standard
    layer; the backward graph differs only on action_loss → VLM K/V.

    Args:
        layer_idx: index of the transformer layer in both experts.
        inputs_embeds: ``[vlm_hidden, action_hidden]`` per-expert states.
        attention_mask: 4-D additive/boolean mask over the joint sequence.
        position_ids: rotary position ids for the joint sequence.
        adarms_cond: per-expert adaRMS conditioning (or ``[None, None]``).
        paligemma: the PaliGemma VLM wrapper.
        gemma_expert: the Gemma action-expert wrapper.

    Returns:
        List of per-expert hidden states after attention + MLP.
    """
    from transformers.models.gemma import modeling_gemma  # noqa: PLC0415

    models = [paligemma.language_model, gemma_expert.model]
    query_states, key_states, value_states, gates = [], [], [], []

    # Boundary between VLM and action tokens in the joint sequence.
    vlm_len = inputs_embeds[0].shape[1]

    for i, hidden_states in enumerate(inputs_embeds):
        layer = models[i].layers[layer_idx]
        hidden_states, gate = layer.input_layernorm(hidden_states, cond=adarms_cond[i])
        gates.append(gate)
        input_shape = hidden_states.shape[:-1]
        hidden_shape = (*input_shape, -1, layer.self_attn.head_dim)
        q = layer.self_attn.q_proj(hidden_states).view(hidden_shape).transpose(1, 2)
        k = layer.self_attn.k_proj(hidden_states).view(hidden_shape).transpose(1, 2)
        v = layer.self_attn.v_proj(hidden_states).view(hidden_shape).transpose(1, 2)
        query_states.append(q)
        key_states.append(k)
        value_states.append(v)

    # (batch, heads, seq, head_dim): join experts along the sequence dim.
    query_states = torch.cat(query_states, dim=2)
    key_states = torch.cat(key_states, dim=2)
    value_states = torch.cat(value_states, dim=2)

    # rotary_emb only inspects device/dtype of its first argument to
    # build cos/sin tables; the content of ``dummy`` is never read.
    dummy = torch.zeros(
        query_states.shape[0],
        query_states.shape[2],
        query_states.shape[-1],
        device=query_states.device,
        dtype=query_states.dtype,
    )
    cos, sin = paligemma.model.language_model.rotary_emb(dummy, position_ids)
    query_states, key_states = modeling_gemma.apply_rotary_pos_emb(
        query_states, key_states, cos, sin, unsqueeze_dim=1
    )

    batch_size = query_states.shape[0]
    scaling = paligemma.language_model.layers[layer_idx].self_attn.scaling

    # Split queries / K / V at the VLM-vs-action boundary.
    q_vlm = query_states[:, :, :vlm_len, :]
    q_action = query_states[:, :, vlm_len:, :]
    k_vlm = key_states[:, :, :vlm_len, :]
    k_action = key_states[:, :, vlm_len:, :]
    v_vlm = value_states[:, :, :vlm_len, :]
    v_action = value_states[:, :, vlm_len:, :]

    # Detach VLM K/V *only* on the path the action queries use; the
    # VLM queries still attend over the full grad-carrying K/V.
    k_for_vlm = key_states  # full (gradients flow)
    v_for_vlm = value_states
    k_for_action = torch.cat([k_vlm.detach(), k_action], dim=2)
    v_for_action = torch.cat([v_vlm.detach(), v_action], dim=2)

    mask_for_vlm = attention_mask[:, :, :vlm_len, :]
    mask_for_action = attention_mask[:, :, vlm_len:, :]

    att_vlm, _ = modeling_gemma.eager_attention_forward(
        paligemma.language_model.layers[layer_idx].self_attn,
        q_vlm, k_for_vlm, v_for_vlm, mask_for_vlm, scaling,
    )
    att_action, _ = modeling_gemma.eager_attention_forward(
        paligemma.language_model.layers[layer_idx].self_attn,
        q_action, k_for_action, v_for_action, mask_for_action, scaling,
    )
    # eager_attention_forward yields (batch, seq, heads, head_dim);
    # re-join the VLM and action segments along the sequence dim.
    att = torch.cat([att_vlm, att_action], dim=1)

    # Flatten heads → (batch, seq, heads * head_dim). Derived from the
    # tensor shape instead of the original hard-coded ``1 * 8 *
    # head_dim``, which silently assumed exactly 8 attention heads.
    att = att.reshape(batch_size, att.shape[1], -1)

    outputs_embeds = []
    start = 0
    for i, hidden_states in enumerate(inputs_embeds):
        layer = models[i].layers[layer_idx]
        end = start + hidden_states.shape[1]
        if att.dtype != layer.self_attn.o_proj.weight.dtype:
            att = att.to(layer.self_attn.o_proj.weight.dtype)
        out_emb = layer.self_attn.o_proj(att[:, start:end])
        # First residual (post-attention), gated for adaRMS.
        out_emb = modeling_gemma._gated_residual(hidden_states, out_emb, gates[i])  # noqa: SLF001
        after_first = out_emb.clone()
        out_emb, gate = layer.post_attention_layernorm(out_emb.clone(), cond=adarms_cond[i])
        if layer.mlp.up_proj.weight.dtype == torch.bfloat16:
            out_emb = out_emb.to(dtype=torch.bfloat16)
        out_emb = layer.mlp(out_emb)
        # Second residual (post-MLP).
        out_emb = modeling_gemma._gated_residual(after_first, out_emb, gate)  # noqa: SLF001
        outputs_embeds.append(out_emb)
        start = end
    return outputs_embeds
|
||||
|
||||
|
||||
def _paligemma_forward_ki(
|
||||
self,
|
||||
attention_mask=None,
|
||||
position_ids=None,
|
||||
past_key_values=None,
|
||||
inputs_embeds=None,
|
||||
use_cache=None,
|
||||
adarms_cond=None,
|
||||
fill_kv_cache=None,
|
||||
):
|
||||
"""Replacement ``PaliGemmaWithExpertModel.forward`` that routes the
|
||||
dual-expert layer pass through :func:`_compute_layer_ki`.
|
||||
|
||||
Bound onto the model instance when ``config.knowledge_insulation``
|
||||
is True (see ``PI052Policy.__init__``). Single-expert branches
|
||||
(VLM-only or action-only) reuse the parent's implementation
|
||||
because there's no KI signal to add — KI only matters when
|
||||
actions and VLM tokens are forwarded together.
|
||||
"""
|
||||
from ..pi05.modeling_pi05 import layernorm_forward # noqa: PLC0415
|
||||
|
||||
if adarms_cond is None:
|
||||
adarms_cond = [None, None]
|
||||
|
||||
# Single-expert paths: defer to the bound class method via super().
|
||||
if inputs_embeds[0] is None or inputs_embeds[1] is None:
|
||||
return type(self).__bases__[0].forward(
|
||||
self,
|
||||
attention_mask=attention_mask,
|
||||
position_ids=position_ids,
|
||||
past_key_values=past_key_values,
|
||||
inputs_embeds=inputs_embeds,
|
||||
use_cache=use_cache,
|
||||
adarms_cond=adarms_cond,
|
||||
) if hasattr(self, "_pi052_orig_forward") else self._pi052_orig_forward(
|
||||
attention_mask=attention_mask,
|
||||
position_ids=position_ids,
|
||||
past_key_values=past_key_values,
|
||||
inputs_embeds=inputs_embeds,
|
||||
use_cache=use_cache,
|
||||
adarms_cond=adarms_cond,
|
||||
)
|
||||
|
||||
models = [self.paligemma.model.language_model, self.gemma_expert.model]
|
||||
num_layers = self.paligemma.config.text_config.num_hidden_layers
|
||||
use_gc = (
|
||||
hasattr(self.gemma_expert.model, "gradient_checkpointing")
|
||||
and self.gemma_expert.model.gradient_checkpointing
|
||||
and self.training
|
||||
) or (
|
||||
hasattr(self, "gradient_checkpointing") and self.gradient_checkpointing and self.training
|
||||
)
|
||||
|
||||
for layer_idx in range(num_layers):
|
||||
if use_gc:
|
||||
inputs_embeds = torch.utils.checkpoint.checkpoint(
|
||||
_compute_layer_ki,
|
||||
layer_idx, inputs_embeds, attention_mask, position_ids, adarms_cond,
|
||||
use_reentrant=False, preserve_rng_state=False,
|
||||
paligemma=self.paligemma, gemma_expert=self.gemma_expert,
|
||||
)
|
||||
else:
|
||||
inputs_embeds = _compute_layer_ki(
|
||||
layer_idx, inputs_embeds, attention_mask, position_ids, adarms_cond,
|
||||
paligemma=self.paligemma, gemma_expert=self.gemma_expert,
|
||||
)
|
||||
|
||||
outputs_embeds = []
|
||||
for i, hidden_states in enumerate(inputs_embeds):
|
||||
out_emb, _ = layernorm_forward(models[i].norm, hidden_states, adarms_cond[i])
|
||||
outputs_embeds.append(out_emb)
|
||||
return [outputs_embeds[0], outputs_embeds[1]], None
|
||||
|
||||
|
||||
class PI052Policy(PI05Policy):
|
||||
"""π0.5 with the PaliGemma LM head re-enabled."""
|
||||
|
||||
@@ -68,6 +254,20 @@ class PI052Policy(PI05Policy):
|
||||
if config.text_loss_weight > 0 and config.unfreeze_lm_head:
|
||||
self._unfreeze_lm_head()
|
||||
|
||||
# Knowledge insulation: bind a custom ``forward`` on the
|
||||
# PaliGemmaWithExpertModel instance that uses
|
||||
# :func:`_compute_layer_ki` for the dual-expert layer pass.
|
||||
# The bind is per-instance, so this doesn't leak into stock
|
||||
# ``pi05`` policies that share the same class.
|
||||
if getattr(config, "knowledge_insulation", False):
|
||||
backbone = self.model.paligemma_with_expert
|
||||
backbone._pi052_orig_forward = backbone.forward
|
||||
backbone.forward = types.MethodType(_paligemma_forward_ki, backbone)
|
||||
logger.info(
|
||||
"PI052: knowledge insulation enabled — action→VLM K/V "
|
||||
"gradients are blocked in attention."
|
||||
)
|
||||
|
||||
# ------------------------------------------------------------------
|
||||
# Head unfreeze helper
|
||||
# ------------------------------------------------------------------
|
||||
@@ -143,6 +343,22 @@ class PI052Policy(PI05Policy):
|
||||
else total + self.config.text_loss_weight * text_loss
|
||||
)
|
||||
|
||||
# FAST action-token CE loss (paper §III.C). When
|
||||
# ``enable_fast_action_loss=True`` the preprocessor wrote
|
||||
# ACTION_TOKENS / ACTION_TOKEN_MASK into the batch — we
|
||||
# forward them through the PaliGemma backbone alongside the
|
||||
# language prefix and compute CE on the action positions.
|
||||
if getattr(self.config, "enable_fast_action_loss", False):
|
||||
from lerobot.utils.constants import ACTION_TOKEN_MASK, ACTION_TOKENS # noqa: PLC0415
|
||||
|
||||
action_tokens = batch.get(ACTION_TOKENS)
|
||||
action_mask = batch.get(ACTION_TOKEN_MASK)
|
||||
if action_tokens is not None and action_mask is not None:
|
||||
fast_loss = self._compute_fast_action_loss(batch, action_tokens, action_mask)
|
||||
loss_dict["fast_action_loss"] = float(fast_loss.detach().item())
|
||||
weighted = self.config.fast_action_loss_weight * fast_loss
|
||||
total = weighted if total is None else total + weighted
|
||||
|
||||
if total is None:
|
||||
# Both flow and text disabled — make this an obvious bug
|
||||
# rather than a silent zero loss.
|
||||
@@ -161,6 +377,82 @@ class PI052Policy(PI05Policy):
|
||||
# Text loss
|
||||
# ------------------------------------------------------------------
|
||||
|
||||
def _compute_fast_action_loss(
    self,
    batch: dict[str, Tensor],
    action_tokens: Tensor,
    action_mask: Tensor,
) -> Tensor:
    """Cross-entropy on FAST-tokenised actions via PaliGemma's LM head.

    Mirrors the paper's §III.C action-token loss: append the FAST
    tokens to the language prefix, forward through the backbone,
    slice the per-token logits at the action positions, and
    compute CE against the FAST targets (shifted for next-token
    prediction). The token mask gates the loss to valid positions.

    Args:
        batch: training batch containing tokenized language inputs
            (``OBS_LANGUAGE_TOKENS`` / ``OBS_LANGUAGE_ATTENTION_MASK``)
            and image observations.
        action_tokens: FAST token ids — assumed ``(batch, fast_len)``
            int tensor (shape is taken from it below).
        action_mask: boolean mask, same shape as ``action_tokens``,
            True at valid token positions.

    Returns:
        Scalar masked-mean CE loss over the shifted action positions.
    """
    from ..pi05.modeling_pi05 import make_att_2d_masks  # noqa: PLC0415

    images, img_masks = self.model._preprocess_images(batch)
    lang_tokens = batch[OBS_LANGUAGE_TOKENS]
    lang_masks = batch[OBS_LANGUAGE_ATTENTION_MASK]

    # Re-embed prefix to get [images, language] embeddings, then
    # append the FAST action token embeddings as additional
    # prefix tokens. PaliGemma's embed_language_tokens is shared
    # between text and FAST tokens so we can re-use it directly.
    prefix_embs, prefix_pad, prefix_att = self.model.embed_prefix(
        images, img_masks, lang_tokens, lang_masks
    )
    emb_dim = prefix_embs.shape[-1]
    fast_emb = self.model.paligemma_with_expert.embed_language_tokens(action_tokens)
    # Gemma-style sqrt(hidden_dim) embedding scale, matching what the
    # backbone applies to ordinary text tokens.
    fast_emb = fast_emb * math.sqrt(emb_dim)

    # Concat onto the prefix. Pad masks: language uses
    # ``lang_masks``; FAST uses ``action_mask`` (True at valid
    # token positions). Attention masks add ``True`` (causal)
    # for FAST so they can attend to the bidirectional prefix
    # but only causally among themselves.
    bsize, fast_len = action_tokens.shape
    device = prefix_embs.device
    ones_att = torch.ones((bsize, fast_len), dtype=torch.bool, device=device)
    full_embs = torch.cat([prefix_embs, fast_emb], dim=1)
    full_pad = torch.cat([prefix_pad, action_mask.to(prefix_pad.dtype)], dim=1)
    full_att = torch.cat([prefix_att, ones_att], dim=1)

    att_2d = make_att_2d_masks(full_pad, full_att)
    # Positions count only non-pad tokens (cumsum over the pad mask).
    position_ids = torch.cumsum(full_pad, dim=1) - 1

    # VLM-only forward: the action-expert slot is None, so only the
    # first (VLM) output stream is produced.
    (vlm_out, _), _ = self.model.paligemma_with_expert.forward(
        attention_mask=att_2d,
        position_ids=position_ids,
        past_key_values=None,
        inputs_embeds=[full_embs, None],
        use_cache=False,
        fill_kv_cache=True,
    )
    if vlm_out is None:
        raise RuntimeError("PI052 FAST loss: VLM forward returned no hidden states.")

    # Slice the last ``fast_len`` positions — those correspond to
    # the FAST tokens we just appended.
    fast_hidden = vlm_out[:, -fast_len:, :]
    lm_head = self.model.paligemma_with_expert.paligemma.lm_head
    fast_logits = lm_head(fast_hidden.to(lm_head.weight.dtype))

    # Shift for next-token prediction: position t predicts token t+1.
    # NOTE(review): this drops supervision of the *first* FAST token
    # (it would be predicted from the last prefix position) — confirm
    # this matches the pi05_full reference implementation.
    shift_logits = fast_logits[:, :-1, :].contiguous()
    shift_targets = action_tokens[:, 1:].contiguous()
    shift_mask = action_mask[:, 1:].contiguous().float()

    # Per-token CE, then masked mean; clamp_min avoids 0/0 when a
    # batch happens to contain no valid shifted positions.
    per_tok = F.cross_entropy(
        shift_logits.view(-1, shift_logits.size(-1)),
        shift_targets.view(-1),
        reduction="none",
    ).view(shift_targets.shape)
    loss = (per_tok * shift_mask).sum() / shift_mask.sum().clamp_min(1.0)
    return loss
|
||||
|
||||
def _compute_text_loss(self, batch: dict[str, Tensor], text_labels: Tensor) -> Tensor:
|
||||
"""Cross-entropy on PaliGemma's LM head over the supervised span.
|
||||
|
||||
|
||||
@@ -43,6 +43,7 @@ import torch
|
||||
from lerobot.configs.recipe import TrainingRecipe
|
||||
from lerobot.processor import (
|
||||
AbsoluteActionsProcessorStep,
|
||||
ActionTokenizerProcessorStep,
|
||||
AddBatchDimensionProcessorStep,
|
||||
DeviceProcessorStep,
|
||||
NormalizerProcessorStep,
|
||||
@@ -101,9 +102,25 @@ def make_pi052_pre_post_processors(
|
||||
memory_dropout_prob=getattr(config, "memory_dropout_prob", 0.0),
|
||||
subtask_dropout_prob=getattr(config, "subtask_dropout_prob", 0.0),
|
||||
),
|
||||
DeviceProcessorStep(device=config.device),
|
||||
]
|
||||
|
||||
# FAST tokenizer for discrete-action CE supervision (paper §III.C).
|
||||
# Only inserted when explicitly enabled — keeps the post-training-
|
||||
# style recipe (flow + text) as the default. When on, the step
|
||||
# writes ACTION_TOKENS / ACTION_TOKEN_MASK into
|
||||
# ``COMPLEMENTARY_DATA`` and the modeling forward picks them up.
|
||||
if getattr(config, "enable_fast_action_loss", False):
|
||||
input_steps.append(
|
||||
ActionTokenizerProcessorStep(
|
||||
action_tokenizer_name=config.action_tokenizer_name,
|
||||
max_action_tokens=config.max_action_tokens,
|
||||
fast_skip_tokens=config.fast_skip_tokens,
|
||||
paligemma_tokenizer_name="google/paligemma-3b-pt-224",
|
||||
)
|
||||
)
|
||||
|
||||
input_steps.append(DeviceProcessorStep(device=config.device))
|
||||
|
||||
output_steps = [
|
||||
UnnormalizerProcessorStep(
|
||||
features=config.output_features,
|
||||
|
||||
Reference in New Issue
Block a user