refactor(recipes): fold memory into action_execution, drop interjection, fuse smolvla2 forward

Recipe changes:
* action_execution now bundles the memory update as a second
  assistant target gated on a new ``new_memory`` binding (fires
  only at subtask-boundary frames). No "Completed subtask: X"
  filler — the model emits the new subtask AND the updated
  memory back-to-back in one prefix.
* user_interjection_response sub-recipe removed (current
  datasets don't have interjection / say() annotations).
* Standalone memory_update sub-recipe removed (folded above).
* Weights rebalanced: action_execution 0.85, ask_vqa_top/wrist
  0.075 each (sums to 1.0).

Runtime ``_msgs_for_memory`` updated to match the new
boundary-frame prompt layout.

Modeling:
* SmolVLA2Policy now fuses the flow + text losses into a SINGLE
  backbone forward via ``_compute_fused_loss`` (one
  vlm_with_expert pass with [prefix, suffix] embeds, then both
  lm_head CE on lang slice + action_out_proj MSE on suffix).
  Mirrors pi052's existing ``_compute_all_losses_fused`` —
  saves one backbone pass per training step.

Examples:
* Removed the two training SLURM scaffolds; they were
  out-of-date with the recipe refactor.

Co-Authored-By: Claude Opus 4.7 (1M context) <noreply@anthropic.com>
This commit is contained in:
Pepijn
2026-05-13 12:51:09 +02:00
parent 058b8f3958
commit b2aa372fcf
6 changed files with 228 additions and 243 deletions
-75
View File
@@ -1,75 +0,0 @@
#!/bin/bash
#SBATCH --job-name=pi052-hirobot
#SBATCH --partition=hopper-prod
#SBATCH --qos=high
#SBATCH --time=48:00:00
#SBATCH --ntasks=1
#SBATCH --gpus-per-task=8
# π0.5 v2 training — reproduces the π0.5 paper's hierarchical recipe.
#
# Same recipe blend as the SmolVLA2 stack (recipes/pi052_hirobot.yaml),
# just on the PaliGemma 2B + Gemma-300m action-expert backbone the
# paper uses. The text head learns subtask prediction via cross-
# entropy on supervised spans; the action expert learns the flow
# field. Paper §IV.D mixes the two losses with α=10, which we encode
# as flow_loss_weight=10 / text_loss_weight=1.
set -euo pipefail
cd "${LEROBOT_ROOT:-$HOME/lerobot}"
export PATH="$HOME/miniconda3/bin:$HOME/.local/bin:$PATH"
export LD_LIBRARY_PATH="$HOME/miniconda3/lib:${LD_LIBRARY_PATH:-}"
export NCCL_TIMEOUT="${NCCL_TIMEOUT:-1800}"
export HF_HUB_DOWNLOAD_TIMEOUT="${HF_HUB_DOWNLOAD_TIMEOUT:-120}"
export WANDB_INIT_TIMEOUT="${WANDB_INIT_TIMEOUT:-300}"
DATASET="${DATASET:-pepijn223/super_poulain_full_tool3}"
POLICY_REPO_ID="${POLICY_REPO_ID:-pepijn223/pi052_hirobot_super_poulain}"
JOB_NAME="${JOB_NAME:-pi052-hirobot-super-poulain}"
NUM_PROCESSES="${NUM_PROCESSES:-8}"
BATCH_SIZE="${BATCH_SIZE:-32}"
STEPS="${STEPS:-15000}"
RUN_ID="${SLURM_JOB_ID:-$(date +%Y%m%d_%H%M%S)}"
OUTPUT_DIR="${OUTPUT_DIR:-/fsx/pepijn/outputs/train/pi052_hirobot_${STEPS}_${RUN_ID}}"
echo "Training pi052 on $DATASET"
echo " GPUs: $NUM_PROCESSES"
echo " batch: $BATCH_SIZE / GPU (global=$((NUM_PROCESSES * BATCH_SIZE)))"
echo " steps: $STEPS"
echo " output: $OUTPUT_DIR"
echo " loss mix: flow_loss_weight=10 (paper α), text_loss_weight=1"
echo " augmentation: image_transforms ON, prompt dropout {plan:0.30 memory:0.30 subtask:0.20}"
accelerate launch --multi_gpu --num_processes="$NUM_PROCESSES" \
-m lerobot.scripts.lerobot_train \
--policy.type=pi052 \
--policy.recipe_path=recipes/pi052_hirobot.yaml \
--dataset.repo_id="$DATASET" \
--dataset.revision=main \
--dataset.video_backend=pyav \
--output_dir="$OUTPUT_DIR" \
--job_name="$JOB_NAME" \
--policy.repo_id="$POLICY_REPO_ID" \
--policy.compile_model=false \
--policy.device=cuda \
--policy.tokenizer_max_length=512 \
--policy.text_loss_weight=1.0 \
--policy.flow_loss_weight=10.0 \
--policy.unfreeze_lm_head=true \
--steps="$STEPS" \
--policy.scheduler_decay_steps="$STEPS" \
--batch_size="$BATCH_SIZE" \
--wandb.enable=true \
--wandb.disable_artifact=true \
--wandb.project=hirobot \
--log_freq=100 \
--save_freq="$STEPS" \
--num_workers=0 \
--dataset.image_transforms.enable=true \
--dataset.image_transforms.max_num_transforms=3 \
--dataset.image_transforms.random_order=true \
--policy.plan_dropout_prob=0.30 \
--policy.memory_dropout_prob=0.30 \
--policy.subtask_dropout_prob=0.20
-82
View File
@@ -1,82 +0,0 @@
#!/bin/bash
#SBATCH --job-name=smolvla2-hirobot
#SBATCH --partition=hopper-prod
#SBATCH --qos=high
#SBATCH --time=48:00:00
#SBATCH --ntasks=1
#SBATCH --gpus-per-task=8
# SmolVLA2 training on an annotated dataset.
#
# The high_level_subtask recipe (recipes/smolvla2_hirobot.yaml) was
# fixed in PR3 to supervise the LM head with the *current* active
# subtask span at every frame, not the next-span target which is
# empty on stable phases. With the old recipe the head learned to
# emit ``\n`` on every chunk boundary; the new one supervises a
# real, scene-grounded string at every frame.
#
# Two regularisers are still on:
#
# * --dataset.image_transforms.enable=true: torchvision-v2
# ColorJitter + SharpnessJitter + RandomAffine per frame; default
# envelope (brightness ±20% etc).
# * --policy.{plan,memory,subtask}_dropout_prob: randomly drop the
# context messages carrying the named recipe binding so the model
# handles missing/stale context. Mirrors Pi0.7 §V.E.
set -euo pipefail
cd "${LEROBOT_ROOT:-$HOME/lerobot}"
export PATH="$HOME/miniconda3/bin:$HOME/.local/bin:$PATH"
export LD_LIBRARY_PATH="$HOME/miniconda3/lib:${LD_LIBRARY_PATH:-}"
export NCCL_TIMEOUT="${NCCL_TIMEOUT:-1800}"
export HF_HUB_DOWNLOAD_TIMEOUT="${HF_HUB_DOWNLOAD_TIMEOUT:-120}"
export WANDB_INIT_TIMEOUT="${WANDB_INIT_TIMEOUT:-300}"
DATASET="${DATASET:-pepijn223/super_poulain_full_tool3}"
POLICY_REPO_ID="${POLICY_REPO_ID:-pepijn223/smolvla2_hirobot_super_poulain_tool6}"
JOB_NAME="${JOB_NAME:-smolvla2-hirobot-super-poulain-tool6}"
NUM_PROCESSES="${NUM_PROCESSES:-8}"
BATCH_SIZE="${BATCH_SIZE:-32}"
STEPS="${STEPS:-15000}"
RUN_ID="${SLURM_JOB_ID:-$(date +%Y%m%d_%H%M%S)}"
OUTPUT_DIR="${OUTPUT_DIR:-/fsx/pepijn/outputs/train/smolvla2_hirobot_super_poulain_tool3_${STEPS}_${RUN_ID}}"
echo "Training smolvla2 on $DATASET"
echo " GPUs: $NUM_PROCESSES"
echo " batch: $BATCH_SIZE / GPU (global=$((NUM_PROCESSES * BATCH_SIZE)))"
echo " steps: $STEPS"
echo " output: $OUTPUT_DIR"
echo " augmentation: image_transforms ON, prompt dropout {plan:0.30 memory:0.30 subtask:0.20}"
accelerate launch --multi_gpu --num_processes="$NUM_PROCESSES" \
-m lerobot.scripts.lerobot_train \
--policy.type=smolvla2 \
--policy.recipe_path=recipes/smolvla2_hirobot.yaml \
--dataset.repo_id="$DATASET" \
--dataset.revision=main \
--dataset.video_backend=pyav \
--output_dir="$OUTPUT_DIR" \
--job_name="$JOB_NAME" \
--policy.repo_id="$POLICY_REPO_ID" \
--policy.compile_model=false \
--policy.device=cuda \
--policy.tokenizer_max_length=512 \
--policy.text_loss_weight=1.0 \
--policy.flow_loss_weight=10.0 \
--steps="$STEPS" \
--policy.scheduler_decay_steps="$STEPS" \
--batch_size="$BATCH_SIZE" \
--wandb.enable=true \
--wandb.disable_artifact=true \
--wandb.project=hirobot \
--log_freq=100 \
--save_freq="$STEPS" \
--num_workers=0 \
--dataset.image_transforms.enable=true \
--dataset.image_transforms.max_num_transforms=3 \
--dataset.image_transforms.random_order=true \
--policy.plan_dropout_prob=0.30 \
--policy.memory_dropout_prob=0.30 \
--policy.subtask_dropout_prob=0.20
+24 -27
View File
@@ -22,53 +22,50 @@
# Pi 0.7 §V.A — subtask in the prompt + flow on actions. # Pi 0.7 §V.A — subtask in the prompt + flow on actions.
# #
# Flavor 2 — event-driven text-only recipes # Flavor 2 — event-driven text-only recipes
# ``memory_update``, ``user_interjection_response``,
# ``ask_vqa_*``. Each handles a specific high-level event # ``ask_vqa_*``. Each handles a specific high-level event
# with a TEXT output. ``if_present`` guards keep them from # with a TEXT output. ``if_present`` guards keep them from
# firing on frames without the relevant annotation. # firing on frames without the relevant annotation.
#
# Memory updates are folded INTO ``action_execution`` as a
# conditional second target gated on boundary frames — see
# ``smolvla2_hirobot.yaml`` for the rationale. The
# ``user_interjection_response`` recipe was dropped — the
# current datasets don't include interjection / say() annotations.
blend: blend:
# ---------------------------------------------------------- # ----------------------------------------------------------
# FLAVOR 1: action_execution (main path) # FLAVOR 1: action_execution (main path)
#
# Bundles memory updates inline. On most frames the binding
# ``new_memory: emitted_at(t, style=memory)`` returns None and
# only the subtask is supervised. On *boundary* frames (the
# exact timestamp a new memory was annotated — i.e. when a
# subtask just completed) the binding fires and the recipe
# supervises the new memory as a follow-up assistant turn,
# with a "Completed subtask: …" user message in between to
# separate the two outputs in the rendered prefix.
# ---------------------------------------------------------- # ----------------------------------------------------------
action_execution: action_execution:
weight: 0.60 weight: 0.85
bindings:
new_memory: "emitted_at(t, style=memory)"
messages: messages:
- role: user - role: user
stream: high_level stream: high_level
content: "${task}\nPlan: ${plan}\nMemory: ${memory}" content: "${task}\nPlan: ${plan}\nMemory: ${memory}"
- {role: assistant, content: "${subtask}", stream: low_level, target: true, if_present: subtask} - {role: assistant, content: "${subtask}", stream: low_level, target: true, if_present: subtask}
# Memory-update tail — only renders at boundary frames where
# ``new_memory`` fires. The new memory is appended as a second
# assistant turn right after the subtask, with no intervening
# user filler: at a subtask boundary the model emits the new
# subtask AND the updated memory in one forward pass.
- {role: assistant, content: "${new_memory}", stream: high_level, target: true, if_present: new_memory}
# ---------------------------------------------------------- # ----------------------------------------------------------
# FLAVOR 2: event-driven text-only paths # FLAVOR 2: event-driven text-only paths
# ---------------------------------------------------------- # ----------------------------------------------------------
memory_update:
weight: 0.10
bindings:
prior_memory: "nth_prev(style=memory, offset=1)"
current_memory: "emitted_at(t, style=memory)"
completed_subtask: "nth_prev(style=subtask, offset=1)"
messages:
- {role: user, content: "${task}", stream: high_level}
- {role: assistant, content: "Previous memory: ${prior_memory}", stream: high_level, if_present: prior_memory}
- {role: user, content: "Completed subtask: ${completed_subtask}", stream: high_level, if_present: completed_subtask}
- {role: assistant, content: "${current_memory}", stream: high_level, target: true, if_present: current_memory}
user_interjection_response:
weight: 0.15
bindings:
prior_plan: "nth_prev(style=plan, offset=1)"
current_plan: "emitted_at(t, style=plan)"
interjection: "emitted_at(t, style=interjection)"
speech: "emitted_at(t, role=assistant, tool_name=say)"
messages:
- {role: user, content: "${task}", stream: high_level}
- {role: assistant, content: "Previous plan:\n${prior_plan}", stream: high_level, if_present: prior_plan}
- {role: user, content: "${interjection}", stream: high_level, if_present: interjection}
- {role: assistant, content: "${current_plan}", stream: high_level, target: true, if_present: current_plan, tool_calls_from: speech}
ask_vqa_top: ask_vqa_top:
weight: 0.075 weight: 0.075
bindings: bindings:
@@ -28,16 +28,17 @@
# Each handles a specific high-level event with a TEXT # Each handles a specific high-level event with a TEXT
# output (no action supervision). They fire when the # output (no action supervision). They fire when the
# binding for the event resolves to non-None: # binding for the event resolves to non-None:
# * ``memory_update``: at subtask boundary, predict new
# memory from task + prior memory + completed subtask.
# * ``user_interjection_response``: on user input, predict
# new plan + paired ``say()`` tool call.
# * ``ask_vqa_top`` / ``ask_vqa_wrist``: answer a # * ``ask_vqa_top`` / ``ask_vqa_wrist``: answer a
# camera-grounded visual question. # camera-grounded visual question.
# All use ``stream: high_level`` (no flow loss) and rely on # All use ``stream: high_level`` (no flow loss) and rely on
# ``if_present`` guards so they only fire on frames where # ``if_present`` guards so they only fire on frames where
# the relevant event annotation is present. # the relevant event annotation is present.
# #
# ``memory_update`` is folded into Flavor 1 (gated on the
# ``new_memory`` binding at boundary frames).
# ``user_interjection_response`` was dropped — the current
# datasets don't include interjection / say() annotations.
#
# How the chat tokenizer interprets the flavor split # How the chat tokenizer interprets the flavor split
# --------------------------------------------------- # ---------------------------------------------------
# * predict_actions = bool(targets_by_stream.get("low_level")) # * predict_actions = bool(targets_by_stream.get("low_level"))
@@ -50,44 +51,38 @@ blend:
# ---------------------------------------------------------- # ----------------------------------------------------------
# FLAVOR 1: action_execution (main path) # FLAVOR 1: action_execution (main path)
#
# Bundles memory updates inline. On most frames the binding
# ``new_memory: emitted_at(t, style=memory)`` returns None and
# only the subtask is supervised. On *boundary* frames (the
# exact timestamp a new memory was annotated — i.e. when a
# subtask just completed) the binding fires and the recipe
# supervises the new memory as a follow-up assistant turn,
# with a "Completed subtask: …" user message in between to
# separate the two outputs in the chat sequence. Mirrors the
# behaviour of the old standalone ``memory_update`` recipe
# but keeps everything inside the unified action_execution.
# ---------------------------------------------------------- # ----------------------------------------------------------
action_execution: action_execution:
weight: 0.60 weight: 0.85
bindings:
new_memory: "emitted_at(t, style=memory)"
messages: messages:
- role: user - role: user
stream: high_level stream: high_level
content: "${task}\nPlan: ${plan}\nMemory: ${memory}" content: "${task}\nPlan: ${plan}\nMemory: ${memory}"
- {role: assistant, content: "${subtask}", stream: low_level, target: true, if_present: subtask} - {role: assistant, content: "${subtask}", stream: low_level, target: true, if_present: subtask}
# Memory-update tail — only renders at boundary frames where
# ``new_memory`` fires. The new memory is appended as a second
# assistant turn right after the subtask, with no intervening
# user filler: at a subtask boundary the model emits the new
# subtask AND the updated memory in one forward pass.
- {role: assistant, content: "${new_memory}", stream: high_level, target: true, if_present: new_memory}
# ---------------------------------------------------------- # ----------------------------------------------------------
# FLAVOR 2: event-driven text-only paths # FLAVOR 2: event-driven text-only paths
# ---------------------------------------------------------- # ----------------------------------------------------------
memory_update:
weight: 0.10
bindings:
prior_memory: "nth_prev(style=memory, offset=1)"
current_memory: "emitted_at(t, style=memory)"
completed_subtask: "nth_prev(style=subtask, offset=1)"
messages:
- {role: user, content: "${task}", stream: high_level}
- {role: assistant, content: "Previous memory: ${prior_memory}", stream: high_level, if_present: prior_memory}
- {role: user, content: "Completed subtask: ${completed_subtask}", stream: high_level, if_present: completed_subtask}
- {role: assistant, content: "${current_memory}", stream: high_level, target: true, if_present: current_memory}
user_interjection_response:
weight: 0.15
bindings:
prior_plan: "nth_prev(style=plan, offset=1)"
current_plan: "emitted_at(t, style=plan)"
interjection: "emitted_at(t, style=interjection)"
speech: "emitted_at(t, role=assistant, tool_name=say)"
messages:
- {role: user, content: "${task}", stream: high_level}
- {role: assistant, content: "Previous plan:\n${prior_plan}", stream: high_level, if_present: prior_plan}
- {role: user, content: "${interjection}", stream: high_level, if_present: interjection}
- {role: assistant, content: "${current_plan}", stream: high_level, target: true, if_present: current_plan, tool_calls_from: speech}
ask_vqa_top: ask_vqa_top:
weight: 0.075 weight: 0.075
bindings: bindings:
@@ -738,24 +738,30 @@ def _msgs_for_subtask(state: dict[str, Any]) -> list[dict[str, Any]]:
def _msgs_for_memory(state: dict[str, Any]) -> list[dict[str, Any]]: def _msgs_for_memory(state: dict[str, Any]) -> list[dict[str, Any]]:
"""``memory_update`` recipe layout.""" """Memory-update prompt — matches the boundary-frame tail of
msgs: list[dict[str, Any]] = [ ``action_execution`` in the v2 recipes.
{"role": "user", "content": state.get("task") or ""}
] Recipe layout on a boundary frame:
user: "${task}\\nPlan: ${plan}\\nMemory: ${memory}"
assistant: "${subtask}"
assistant: → predicts new memory
At inference we fire this when the runtime detects a subtask
transition; the freshly-predicted subtask lives in
``state['current_subtask']``. No "Completed subtask: X" user
filler — the second assistant turn is generated immediately
after the subtask turn.
"""
head_parts = [state.get("task") or ""]
if state.get("current_plan"):
head_parts.append(f"Plan: {state['current_plan']}")
if state.get("current_memory"): if state.get("current_memory"):
msgs.append( head_parts.append(f"Memory: {state['current_memory']}")
{ msgs: list[dict[str, Any]] = [
"role": "assistant", {"role": "user", "content": "\n".join(head_parts)},
"content": f"Previous memory: {state['current_memory']}", ]
}
)
if state.get("current_subtask"): if state.get("current_subtask"):
msgs.append( msgs.append({"role": "assistant", "content": state["current_subtask"]})
{
"role": "user",
"content": f"Completed subtask: {state['current_subtask']}",
}
)
return msgs return msgs
@@ -133,18 +133,29 @@ class SmolVLA2Policy(SmolVLAPolicy):
device = batch[OBS_STATE].device device = batch[OBS_STATE].device
total = torch.zeros((), device=device, dtype=torch.float32) total = torch.zeros((), device=device, dtype=torch.float32)
# ------------------------------------------------------------ run_flow = (
# Flow loss path — only when at least one sample wants actions. self.config.flow_loss_weight > 0
# ------------------------------------------------------------ and ACTION in batch
run_flow = self.config.flow_loss_weight > 0 and ( and (not has_per_sample_routing or bool(predict_actions_t.any().item()))
not has_per_sample_routing or bool(predict_actions_t.any().item())
) )
if run_flow and ACTION in batch:
# ------------------------------------------------------------
# Fused path — one backbone forward for flow + text together.
# ------------------------------------------------------------
if run_flow and has_text_data:
flow_loss, text_loss, flow_diag = self._compute_fused_loss(
batch, text_labels, predict_actions_t, noise=noise, time=time
)
total = total + self.config.flow_loss_weight * flow_loss
total = total + self.config.text_loss_weight * text_loss
loss_dict["flow_loss"] = float(flow_loss.detach().item())
loss_dict["text_loss"] = float(text_loss.detach().item())
for k, v in flow_diag.items():
loss_dict[f"flow_{k}"] = v
elif run_flow:
per_sample_flow, flow_diag = super().forward( per_sample_flow, flow_diag = super().forward(
batch, noise=noise, time=time, reduction="none" batch, noise=noise, time=time, reduction="none"
) )
# ``per_sample_flow`` has shape (B,) from the SmolVLA
# reduction="none" branch.
if has_per_sample_routing: if has_per_sample_routing:
mask = predict_actions_t.to(per_sample_flow.dtype) mask = predict_actions_t.to(per_sample_flow.dtype)
masked = per_sample_flow * mask masked = per_sample_flow * mask
@@ -156,11 +167,7 @@ class SmolVLA2Policy(SmolVLAPolicy):
loss_dict["flow_loss"] = float(flow_loss.detach().item()) loss_dict["flow_loss"] = float(flow_loss.detach().item())
for k, v in flow_diag.items(): for k, v in flow_diag.items():
loss_dict[f"flow_{k}"] = v loss_dict[f"flow_{k}"] = v
elif has_text_data:
# ------------------------------------------------------------
# Text loss path — prefix-only forward → lm_head → CE.
# ------------------------------------------------------------
if has_text_data:
text_loss = self._compute_text_loss(batch, text_labels) text_loss = self._compute_text_loss(batch, text_labels)
total = total + self.config.text_loss_weight * text_loss total = total + self.config.text_loss_weight * text_loss
loss_dict["text_loss"] = float(text_loss.detach().item()) loss_dict["text_loss"] = float(text_loss.detach().item())
@@ -253,6 +260,143 @@ class SmolVLA2Policy(SmolVLAPolicy):
) )
return loss / valid_labels.sum().clamp(min=1) return loss / valid_labels.sum().clamp(min=1)
# ------------------------------------------------------------------
# Fused flow + text loss (single backbone forward)
# ------------------------------------------------------------------
def _compute_fused_loss(
self,
batch: dict[str, Tensor],
text_labels: Tensor,
predict_actions_t: Tensor | None,
noise: Tensor | None = None,
time: Tensor | None = None,
) -> tuple[Tensor, Tensor, dict[str, Any]]:
"""One backbone forward → both flow MSE and text CE.
Mirrors ``SmolVLAModel.forward`` (prefix + suffix concat, one
``vlm_with_expert`` call) but captures **both** outputs:
* ``prefix_out[:, lang_start:lang_end]`` ``lm_head`` CE on
``text_labels`` (same slicing as ``_compute_text_loss``).
* ``suffix_out[:, -chunk_size:]`` ``action_out_proj`` flow
MSE against ``noise - actions`` (same as the parent forward).
Saves one backbone pass per training step vs. running the flow
and text paths separately same trick PI052Policy uses in
``_compute_all_losses_fused``.
"""
from ..smolvla.modeling_smolvla import resize_with_pad # noqa: F401 (kept for parity)
cfg = self.config
if cfg.adapt_to_pi_aloha:
batch[OBS_STATE] = self._pi_aloha_decode_state(batch[OBS_STATE])
batch[ACTION] = self._pi_aloha_encode_actions_inv(batch[ACTION])
images, img_masks = self.prepare_images(batch)
state = self.prepare_state(batch)
lang_tokens = batch[OBS_LANGUAGE_TOKENS]
lang_masks = batch[OBS_LANGUAGE_ATTENTION_MASK]
actions = self.prepare_action(batch)
inner = self.model
if noise is None:
noise = inner.sample_noise(actions.shape, actions.device)
if time is None:
time = inner.sample_time(actions.shape[0], actions.device)
time_expanded = time[:, None, None]
x_t = time_expanded * noise + (1 - time_expanded) * actions
u_t = noise - actions
prefix_embs, prefix_pad_masks, prefix_att_masks = inner.embed_prefix(
images, img_masks, lang_tokens, lang_masks, state=state
)
suffix_embs, suffix_pad_masks, suffix_att_masks = inner.embed_suffix(x_t, time)
pad_masks = torch.cat([prefix_pad_masks, suffix_pad_masks], dim=1)
att_masks = torch.cat([prefix_att_masks, suffix_att_masks], dim=1)
att_2d_masks = make_att_2d_masks(pad_masks, att_masks)
position_ids = torch.cumsum(pad_masks, dim=1) - 1
out_pair, _ = inner.vlm_with_expert.forward(
attention_mask=att_2d_masks,
position_ids=position_ids,
past_key_values=None,
inputs_embeds=[prefix_embs, suffix_embs],
use_cache=False,
fill_kv_cache=False,
)
prefix_out, suffix_out = out_pair[0], out_pair[1]
if prefix_out is None or suffix_out is None:
raise RuntimeError(
"SmolVLA2: fused forward expected both prefix and suffix "
"hidden states from vlm_with_expert."
)
# ---------------- flow loss (per-sample maskable) ----------------
chunk = cfg.chunk_size
suffix_chunk = suffix_out[:, -chunk:].to(torch.float32)
v_t = inner.action_out_proj(suffix_chunk)
losses = F.mse_loss(u_t, v_t, reduction="none")
original_action_dim = cfg.action_feature.shape[0]
losses = losses[:, :, :original_action_dim]
flow_diag = {"losses_after_forward": float(losses.detach().mean().item())}
actions_is_pad = batch.get("action_is_pad")
if actions_is_pad is not None:
in_episode = ~actions_is_pad
losses = losses * in_episode.unsqueeze(-1)
flow_diag["losses_after_in_ep_bound"] = float(losses.detach().mean().item())
losses = losses[:, :, : cfg.max_action_dim]
flow_diag["losses_after_rm_padding"] = float(losses.detach().mean().item())
per_sample_flow = losses.mean(dim=(1, 2))
if predict_actions_t is not None:
mask = predict_actions_t.to(per_sample_flow.dtype)
flow_loss = (per_sample_flow * mask).sum() / mask.sum().clamp(min=1.0)
else:
flow_loss = per_sample_flow.mean()
# ---------------- text loss (lang slice of prefix) ---------------
num_lang = lang_tokens.shape[1]
state_for_dim = state if state.ndim >= 2 else state[:, None]
num_state = state_for_dim.shape[1] if state_for_dim.ndim >= 2 else 1
if num_state < 1:
num_state = 1
prefix_len = prefix_out.shape[1]
lang_end = prefix_len - num_state
lang_start = lang_end - num_lang
if lang_start < 0 or lang_end > prefix_len:
raise RuntimeError(
f"SmolVLA2: fused forward could not locate lang range "
f"(prefix_len={prefix_len}, num_lang={num_lang}, "
f"num_state={num_state})."
)
vlm = inner.vlm_with_expert.vlm
lang_hidden = prefix_out[:, lang_start:lang_end].to(vlm.lm_head.weight.dtype)
logits = vlm.lm_head(lang_hidden)
if text_labels.shape[1] != num_lang:
common = min(text_labels.shape[1], num_lang)
logits = logits[:, :common]
text_labels = text_labels[:, :common]
shift_logits = logits[:, :-1, :].contiguous()
shift_labels = text_labels[:, 1:].contiguous().long()
valid_labels = shift_labels != -100
ce = F.cross_entropy(
shift_logits.reshape(-1, shift_logits.shape[-1]),
shift_labels.reshape(-1),
ignore_index=-100,
reduction="sum",
)
text_loss = ce / valid_labels.sum().clamp(min=1)
return flow_loss, text_loss, flow_diag
# ------------------------------------------------------------------ # ------------------------------------------------------------------
# Inference: text generation # Inference: text generation
# ------------------------------------------------------------------ # ------------------------------------------------------------------