mirror of
https://github.com/huggingface/lerobot.git
synced 2026-05-22 03:59:42 +00:00
refactor(recipes): fold memory into action_execution, drop interjection, fuse smolvla2 forward
Recipe changes: * action_execution now bundles the memory update as a second assistant target gated on a new ``new_memory`` binding (fires only at subtask-boundary frames). No "Completed subtask: X" filler — the model emits the new subtask AND the updated memory back-to-back in one prefix. * user_interjection_response sub-recipe removed (current datasets don't have interjection / say() annotations). * Standalone memory_update sub-recipe removed (folded above). * Weights rebalanced: action_execution 0.85, ask_vqa_top/wrist 0.075 each (sums to 1.0). Runtime ``_msgs_for_memory`` updated to match the new boundary-frame prompt layout. Modeling: * SmolVLA2Policy now fuses the flow + text losses into a SINGLE backbone forward via ``_compute_fused_loss`` (one vlm_with_expert pass with [prefix, suffix] embeds, then both lm_head CE on lang slice + action_out_proj MSE on suffix). Mirrors pi052's existing ``_compute_all_losses_fused`` — saves one backbone pass per training step. Examples: * Removed the two training SLURM scaffolds; they were out-of-date with the recipe refactor. Co-Authored-By: Claude Opus 4.7 (1M context) <noreply@anthropic.com>
This commit is contained in:
@@ -1,75 +0,0 @@
|
|||||||
#!/bin/bash
|
|
||||||
#SBATCH --job-name=pi052-hirobot
|
|
||||||
#SBATCH --partition=hopper-prod
|
|
||||||
#SBATCH --qos=high
|
|
||||||
#SBATCH --time=48:00:00
|
|
||||||
#SBATCH --ntasks=1
|
|
||||||
#SBATCH --gpus-per-task=8
|
|
||||||
|
|
||||||
# π0.5 v2 training — reproduces the π0.5 paper's hierarchical recipe.
|
|
||||||
#
|
|
||||||
# Same recipe blend as the SmolVLA2 stack (recipes/pi052_hirobot.yaml),
|
|
||||||
# just on the PaliGemma 2B + Gemma-300m action-expert backbone the
|
|
||||||
# paper uses. The text head learns subtask prediction via cross-
|
|
||||||
# entropy on supervised spans; the action expert learns the flow
|
|
||||||
# field. Paper §IV.D mixes the two losses with α=10, which we encode
|
|
||||||
# as flow_loss_weight=10 / text_loss_weight=1.
|
|
||||||
|
|
||||||
set -euo pipefail
|
|
||||||
|
|
||||||
cd "${LEROBOT_ROOT:-$HOME/lerobot}"
|
|
||||||
|
|
||||||
export PATH="$HOME/miniconda3/bin:$HOME/.local/bin:$PATH"
|
|
||||||
export LD_LIBRARY_PATH="$HOME/miniconda3/lib:${LD_LIBRARY_PATH:-}"
|
|
||||||
export NCCL_TIMEOUT="${NCCL_TIMEOUT:-1800}"
|
|
||||||
export HF_HUB_DOWNLOAD_TIMEOUT="${HF_HUB_DOWNLOAD_TIMEOUT:-120}"
|
|
||||||
export WANDB_INIT_TIMEOUT="${WANDB_INIT_TIMEOUT:-300}"
|
|
||||||
|
|
||||||
DATASET="${DATASET:-pepijn223/super_poulain_full_tool3}"
|
|
||||||
POLICY_REPO_ID="${POLICY_REPO_ID:-pepijn223/pi052_hirobot_super_poulain}"
|
|
||||||
JOB_NAME="${JOB_NAME:-pi052-hirobot-super-poulain}"
|
|
||||||
NUM_PROCESSES="${NUM_PROCESSES:-8}"
|
|
||||||
BATCH_SIZE="${BATCH_SIZE:-32}"
|
|
||||||
STEPS="${STEPS:-15000}"
|
|
||||||
RUN_ID="${SLURM_JOB_ID:-$(date +%Y%m%d_%H%M%S)}"
|
|
||||||
OUTPUT_DIR="${OUTPUT_DIR:-/fsx/pepijn/outputs/train/pi052_hirobot_${STEPS}_${RUN_ID}}"
|
|
||||||
|
|
||||||
echo "Training pi052 on $DATASET"
|
|
||||||
echo " GPUs: $NUM_PROCESSES"
|
|
||||||
echo " batch: $BATCH_SIZE / GPU (global=$((NUM_PROCESSES * BATCH_SIZE)))"
|
|
||||||
echo " steps: $STEPS"
|
|
||||||
echo " output: $OUTPUT_DIR"
|
|
||||||
echo " loss mix: flow_loss_weight=10 (paper α), text_loss_weight=1"
|
|
||||||
echo " augmentation: image_transforms ON, prompt dropout {plan:0.30 memory:0.30 subtask:0.20}"
|
|
||||||
|
|
||||||
accelerate launch --multi_gpu --num_processes="$NUM_PROCESSES" \
|
|
||||||
-m lerobot.scripts.lerobot_train \
|
|
||||||
--policy.type=pi052 \
|
|
||||||
--policy.recipe_path=recipes/pi052_hirobot.yaml \
|
|
||||||
--dataset.repo_id="$DATASET" \
|
|
||||||
--dataset.revision=main \
|
|
||||||
--dataset.video_backend=pyav \
|
|
||||||
--output_dir="$OUTPUT_DIR" \
|
|
||||||
--job_name="$JOB_NAME" \
|
|
||||||
--policy.repo_id="$POLICY_REPO_ID" \
|
|
||||||
--policy.compile_model=false \
|
|
||||||
--policy.device=cuda \
|
|
||||||
--policy.tokenizer_max_length=512 \
|
|
||||||
--policy.text_loss_weight=1.0 \
|
|
||||||
--policy.flow_loss_weight=10.0 \
|
|
||||||
--policy.unfreeze_lm_head=true \
|
|
||||||
--steps="$STEPS" \
|
|
||||||
--policy.scheduler_decay_steps="$STEPS" \
|
|
||||||
--batch_size="$BATCH_SIZE" \
|
|
||||||
--wandb.enable=true \
|
|
||||||
--wandb.disable_artifact=true \
|
|
||||||
--wandb.project=hirobot \
|
|
||||||
--log_freq=100 \
|
|
||||||
--save_freq="$STEPS" \
|
|
||||||
--num_workers=0 \
|
|
||||||
--dataset.image_transforms.enable=true \
|
|
||||||
--dataset.image_transforms.max_num_transforms=3 \
|
|
||||||
--dataset.image_transforms.random_order=true \
|
|
||||||
--policy.plan_dropout_prob=0.30 \
|
|
||||||
--policy.memory_dropout_prob=0.30 \
|
|
||||||
--policy.subtask_dropout_prob=0.20
|
|
||||||
@@ -1,82 +0,0 @@
|
|||||||
#!/bin/bash
|
|
||||||
#SBATCH --job-name=smolvla2-hirobot
|
|
||||||
#SBATCH --partition=hopper-prod
|
|
||||||
#SBATCH --qos=high
|
|
||||||
#SBATCH --time=48:00:00
|
|
||||||
#SBATCH --ntasks=1
|
|
||||||
#SBATCH --gpus-per-task=8
|
|
||||||
|
|
||||||
# SmolVLA2 training on an annotated dataset.
|
|
||||||
#
|
|
||||||
# The high_level_subtask recipe (recipes/smolvla2_hirobot.yaml) was
|
|
||||||
# fixed in PR3 to supervise the LM head with the *current* active
|
|
||||||
# subtask span at every frame, not the next-span target which is
|
|
||||||
# empty on stable phases. With the old recipe the head learned to
|
|
||||||
# emit ``\n`` on every chunk boundary; the new one supervises a
|
|
||||||
# real, scene-grounded string at every frame.
|
|
||||||
#
|
|
||||||
# Two regularisers are still on:
|
|
||||||
#
|
|
||||||
# * --dataset.image_transforms.enable=true: torchvision-v2
|
|
||||||
# ColorJitter + SharpnessJitter + RandomAffine per frame; default
|
|
||||||
# envelope (brightness ±20% etc).
|
|
||||||
# * --policy.{plan,memory,subtask}_dropout_prob: randomly drop the
|
|
||||||
# context messages carrying the named recipe binding so the model
|
|
||||||
# handles missing/stale context. Mirrors Pi0.7 §V.E.
|
|
||||||
|
|
||||||
set -euo pipefail
|
|
||||||
|
|
||||||
cd "${LEROBOT_ROOT:-$HOME/lerobot}"
|
|
||||||
|
|
||||||
export PATH="$HOME/miniconda3/bin:$HOME/.local/bin:$PATH"
|
|
||||||
export LD_LIBRARY_PATH="$HOME/miniconda3/lib:${LD_LIBRARY_PATH:-}"
|
|
||||||
export NCCL_TIMEOUT="${NCCL_TIMEOUT:-1800}"
|
|
||||||
export HF_HUB_DOWNLOAD_TIMEOUT="${HF_HUB_DOWNLOAD_TIMEOUT:-120}"
|
|
||||||
export WANDB_INIT_TIMEOUT="${WANDB_INIT_TIMEOUT:-300}"
|
|
||||||
|
|
||||||
DATASET="${DATASET:-pepijn223/super_poulain_full_tool3}"
|
|
||||||
POLICY_REPO_ID="${POLICY_REPO_ID:-pepijn223/smolvla2_hirobot_super_poulain_tool6}"
|
|
||||||
JOB_NAME="${JOB_NAME:-smolvla2-hirobot-super-poulain-tool6}"
|
|
||||||
NUM_PROCESSES="${NUM_PROCESSES:-8}"
|
|
||||||
BATCH_SIZE="${BATCH_SIZE:-32}"
|
|
||||||
STEPS="${STEPS:-15000}"
|
|
||||||
RUN_ID="${SLURM_JOB_ID:-$(date +%Y%m%d_%H%M%S)}"
|
|
||||||
OUTPUT_DIR="${OUTPUT_DIR:-/fsx/pepijn/outputs/train/smolvla2_hirobot_super_poulain_tool3_${STEPS}_${RUN_ID}}"
|
|
||||||
|
|
||||||
echo "Training smolvla2 on $DATASET"
|
|
||||||
echo " GPUs: $NUM_PROCESSES"
|
|
||||||
echo " batch: $BATCH_SIZE / GPU (global=$((NUM_PROCESSES * BATCH_SIZE)))"
|
|
||||||
echo " steps: $STEPS"
|
|
||||||
echo " output: $OUTPUT_DIR"
|
|
||||||
echo " augmentation: image_transforms ON, prompt dropout {plan:0.30 memory:0.30 subtask:0.20}"
|
|
||||||
|
|
||||||
accelerate launch --multi_gpu --num_processes="$NUM_PROCESSES" \
|
|
||||||
-m lerobot.scripts.lerobot_train \
|
|
||||||
--policy.type=smolvla2 \
|
|
||||||
--policy.recipe_path=recipes/smolvla2_hirobot.yaml \
|
|
||||||
--dataset.repo_id="$DATASET" \
|
|
||||||
--dataset.revision=main \
|
|
||||||
--dataset.video_backend=pyav \
|
|
||||||
--output_dir="$OUTPUT_DIR" \
|
|
||||||
--job_name="$JOB_NAME" \
|
|
||||||
--policy.repo_id="$POLICY_REPO_ID" \
|
|
||||||
--policy.compile_model=false \
|
|
||||||
--policy.device=cuda \
|
|
||||||
--policy.tokenizer_max_length=512 \
|
|
||||||
--policy.text_loss_weight=1.0 \
|
|
||||||
--policy.flow_loss_weight=10.0 \
|
|
||||||
--steps="$STEPS" \
|
|
||||||
--policy.scheduler_decay_steps="$STEPS" \
|
|
||||||
--batch_size="$BATCH_SIZE" \
|
|
||||||
--wandb.enable=true \
|
|
||||||
--wandb.disable_artifact=true \
|
|
||||||
--wandb.project=hirobot \
|
|
||||||
--log_freq=100 \
|
|
||||||
--save_freq="$STEPS" \
|
|
||||||
--num_workers=0 \
|
|
||||||
--dataset.image_transforms.enable=true \
|
|
||||||
--dataset.image_transforms.max_num_transforms=3 \
|
|
||||||
--dataset.image_transforms.random_order=true \
|
|
||||||
--policy.plan_dropout_prob=0.30 \
|
|
||||||
--policy.memory_dropout_prob=0.30 \
|
|
||||||
--policy.subtask_dropout_prob=0.20
|
|
||||||
@@ -22,53 +22,50 @@
|
|||||||
# Pi 0.7 §V.A — subtask in the prompt + flow on actions.
|
# Pi 0.7 §V.A — subtask in the prompt + flow on actions.
|
||||||
#
|
#
|
||||||
# Flavor 2 — event-driven text-only recipes
|
# Flavor 2 — event-driven text-only recipes
|
||||||
# ``memory_update``, ``user_interjection_response``,
|
|
||||||
# ``ask_vqa_*``. Each handles a specific high-level event
|
# ``ask_vqa_*``. Each handles a specific high-level event
|
||||||
# with a TEXT output. ``if_present`` guards keep them from
|
# with a TEXT output. ``if_present`` guards keep them from
|
||||||
# firing on frames without the relevant annotation.
|
# firing on frames without the relevant annotation.
|
||||||
|
#
|
||||||
|
# Memory updates are folded INTO ``action_execution`` as a
|
||||||
|
# conditional second target gated on boundary frames — see
|
||||||
|
# ``smolvla2_hirobot.yaml`` for the rationale. The
|
||||||
|
# ``user_interjection_response`` recipe was dropped — the
|
||||||
|
# current datasets don't include interjection / say() annotations.
|
||||||
|
|
||||||
blend:
|
blend:
|
||||||
|
|
||||||
# ----------------------------------------------------------
|
# ----------------------------------------------------------
|
||||||
# FLAVOR 1: action_execution (main path)
|
# FLAVOR 1: action_execution (main path)
|
||||||
|
#
|
||||||
|
# Bundles memory updates inline. On most frames the binding
|
||||||
|
# ``new_memory: emitted_at(t, style=memory)`` returns None and
|
||||||
|
# only the subtask is supervised. On *boundary* frames (the
|
||||||
|
# exact timestamp a new memory was annotated — i.e. when a
|
||||||
|
# subtask just completed) the binding fires and the recipe
|
||||||
|
# supervises the new memory as a follow-up assistant turn,
|
||||||
|
# appended directly after the subtask turn — no "Completed
|
||||||
|
# subtask: …" user filler separates the two outputs.
|
||||||
# ----------------------------------------------------------
|
# ----------------------------------------------------------
|
||||||
action_execution:
|
action_execution:
|
||||||
weight: 0.60
|
weight: 0.85
|
||||||
|
bindings:
|
||||||
|
new_memory: "emitted_at(t, style=memory)"
|
||||||
messages:
|
messages:
|
||||||
- role: user
|
- role: user
|
||||||
stream: high_level
|
stream: high_level
|
||||||
content: "${task}\nPlan: ${plan}\nMemory: ${memory}"
|
content: "${task}\nPlan: ${plan}\nMemory: ${memory}"
|
||||||
- {role: assistant, content: "${subtask}", stream: low_level, target: true, if_present: subtask}
|
- {role: assistant, content: "${subtask}", stream: low_level, target: true, if_present: subtask}
|
||||||
|
# Memory-update tail — only renders at boundary frames where
|
||||||
|
# ``new_memory`` fires. The new memory is appended as a second
|
||||||
|
# assistant turn right after the subtask, with no intervening
|
||||||
|
# user filler: at a subtask boundary the model emits the new
|
||||||
|
# subtask AND the updated memory in one forward pass.
|
||||||
|
- {role: assistant, content: "${new_memory}", stream: high_level, target: true, if_present: new_memory}
|
||||||
|
|
||||||
# ----------------------------------------------------------
|
# ----------------------------------------------------------
|
||||||
# FLAVOR 2: event-driven text-only paths
|
# FLAVOR 2: event-driven text-only paths
|
||||||
# ----------------------------------------------------------
|
# ----------------------------------------------------------
|
||||||
|
|
||||||
memory_update:
|
|
||||||
weight: 0.10
|
|
||||||
bindings:
|
|
||||||
prior_memory: "nth_prev(style=memory, offset=1)"
|
|
||||||
current_memory: "emitted_at(t, style=memory)"
|
|
||||||
completed_subtask: "nth_prev(style=subtask, offset=1)"
|
|
||||||
messages:
|
|
||||||
- {role: user, content: "${task}", stream: high_level}
|
|
||||||
- {role: assistant, content: "Previous memory: ${prior_memory}", stream: high_level, if_present: prior_memory}
|
|
||||||
- {role: user, content: "Completed subtask: ${completed_subtask}", stream: high_level, if_present: completed_subtask}
|
|
||||||
- {role: assistant, content: "${current_memory}", stream: high_level, target: true, if_present: current_memory}
|
|
||||||
|
|
||||||
user_interjection_response:
|
|
||||||
weight: 0.15
|
|
||||||
bindings:
|
|
||||||
prior_plan: "nth_prev(style=plan, offset=1)"
|
|
||||||
current_plan: "emitted_at(t, style=plan)"
|
|
||||||
interjection: "emitted_at(t, style=interjection)"
|
|
||||||
speech: "emitted_at(t, role=assistant, tool_name=say)"
|
|
||||||
messages:
|
|
||||||
- {role: user, content: "${task}", stream: high_level}
|
|
||||||
- {role: assistant, content: "Previous plan:\n${prior_plan}", stream: high_level, if_present: prior_plan}
|
|
||||||
- {role: user, content: "${interjection}", stream: high_level, if_present: interjection}
|
|
||||||
- {role: assistant, content: "${current_plan}", stream: high_level, target: true, if_present: current_plan, tool_calls_from: speech}
|
|
||||||
|
|
||||||
ask_vqa_top:
|
ask_vqa_top:
|
||||||
weight: 0.075
|
weight: 0.075
|
||||||
bindings:
|
bindings:
|
||||||
|
|||||||
@@ -28,16 +28,17 @@
|
|||||||
# Each handles a specific high-level event with a TEXT
|
# Each handles a specific high-level event with a TEXT
|
||||||
# output (no action supervision). They fire when the
|
# output (no action supervision). They fire when the
|
||||||
# binding for the event resolves to non-None:
|
# binding for the event resolves to non-None:
|
||||||
# * ``memory_update``: at subtask boundary, predict new
|
|
||||||
# memory from task + prior memory + completed subtask.
|
|
||||||
# * ``user_interjection_response``: on user input, predict
|
|
||||||
# new plan + paired ``say()`` tool call.
|
|
||||||
# * ``ask_vqa_top`` / ``ask_vqa_wrist``: answer a
|
# * ``ask_vqa_top`` / ``ask_vqa_wrist``: answer a
|
||||||
# camera-grounded visual question.
|
# camera-grounded visual question.
|
||||||
# All use ``stream: high_level`` (no flow loss) and rely on
|
# All use ``stream: high_level`` (no flow loss) and rely on
|
||||||
# ``if_present`` guards so they only fire on frames where
|
# ``if_present`` guards so they only fire on frames where
|
||||||
# the relevant event annotation is present.
|
# the relevant event annotation is present.
|
||||||
#
|
#
|
||||||
|
# ``memory_update`` is folded into Flavor 1 (gated on the
|
||||||
|
# ``new_memory`` binding at boundary frames).
|
||||||
|
# ``user_interjection_response`` was dropped — the current
|
||||||
|
# datasets don't include interjection / say() annotations.
|
||||||
|
#
|
||||||
# How the chat tokenizer interprets the flavor split
|
# How the chat tokenizer interprets the flavor split
|
||||||
# ---------------------------------------------------
|
# ---------------------------------------------------
|
||||||
# * predict_actions = bool(targets_by_stream.get("low_level"))
|
# * predict_actions = bool(targets_by_stream.get("low_level"))
|
||||||
@@ -50,44 +51,38 @@ blend:
|
|||||||
|
|
||||||
# ----------------------------------------------------------
|
# ----------------------------------------------------------
|
||||||
# FLAVOR 1: action_execution (main path)
|
# FLAVOR 1: action_execution (main path)
|
||||||
|
#
|
||||||
|
# Bundles memory updates inline. On most frames the binding
|
||||||
|
# ``new_memory: emitted_at(t, style=memory)`` returns None and
|
||||||
|
# only the subtask is supervised. On *boundary* frames (the
|
||||||
|
# exact timestamp a new memory was annotated — i.e. when a
|
||||||
|
# subtask just completed) the binding fires and the recipe
|
||||||
|
# supervises the new memory as a follow-up assistant turn,
|
||||||
|
# appended directly after the subtask turn — no "Completed
|
||||||
|
# subtask: …" user filler separates the two outputs. Replaces
|
||||||
|
# the old standalone ``memory_update`` recipe (which did use
|
||||||
|
# that filler) and keeps everything inside action_execution.
|
||||||
# ----------------------------------------------------------
|
# ----------------------------------------------------------
|
||||||
action_execution:
|
action_execution:
|
||||||
weight: 0.60
|
weight: 0.85
|
||||||
|
bindings:
|
||||||
|
new_memory: "emitted_at(t, style=memory)"
|
||||||
messages:
|
messages:
|
||||||
- role: user
|
- role: user
|
||||||
stream: high_level
|
stream: high_level
|
||||||
content: "${task}\nPlan: ${plan}\nMemory: ${memory}"
|
content: "${task}\nPlan: ${plan}\nMemory: ${memory}"
|
||||||
- {role: assistant, content: "${subtask}", stream: low_level, target: true, if_present: subtask}
|
- {role: assistant, content: "${subtask}", stream: low_level, target: true, if_present: subtask}
|
||||||
|
# Memory-update tail — only renders at boundary frames where
|
||||||
|
# ``new_memory`` fires. The new memory is appended as a second
|
||||||
|
# assistant turn right after the subtask, with no intervening
|
||||||
|
# user filler: at a subtask boundary the model emits the new
|
||||||
|
# subtask AND the updated memory in one forward pass.
|
||||||
|
- {role: assistant, content: "${new_memory}", stream: high_level, target: true, if_present: new_memory}
|
||||||
|
|
||||||
# ----------------------------------------------------------
|
# ----------------------------------------------------------
|
||||||
# FLAVOR 2: event-driven text-only paths
|
# FLAVOR 2: event-driven text-only paths
|
||||||
# ----------------------------------------------------------
|
# ----------------------------------------------------------
|
||||||
|
|
||||||
memory_update:
|
|
||||||
weight: 0.10
|
|
||||||
bindings:
|
|
||||||
prior_memory: "nth_prev(style=memory, offset=1)"
|
|
||||||
current_memory: "emitted_at(t, style=memory)"
|
|
||||||
completed_subtask: "nth_prev(style=subtask, offset=1)"
|
|
||||||
messages:
|
|
||||||
- {role: user, content: "${task}", stream: high_level}
|
|
||||||
- {role: assistant, content: "Previous memory: ${prior_memory}", stream: high_level, if_present: prior_memory}
|
|
||||||
- {role: user, content: "Completed subtask: ${completed_subtask}", stream: high_level, if_present: completed_subtask}
|
|
||||||
- {role: assistant, content: "${current_memory}", stream: high_level, target: true, if_present: current_memory}
|
|
||||||
|
|
||||||
user_interjection_response:
|
|
||||||
weight: 0.15
|
|
||||||
bindings:
|
|
||||||
prior_plan: "nth_prev(style=plan, offset=1)"
|
|
||||||
current_plan: "emitted_at(t, style=plan)"
|
|
||||||
interjection: "emitted_at(t, style=interjection)"
|
|
||||||
speech: "emitted_at(t, role=assistant, tool_name=say)"
|
|
||||||
messages:
|
|
||||||
- {role: user, content: "${task}", stream: high_level}
|
|
||||||
- {role: assistant, content: "Previous plan:\n${prior_plan}", stream: high_level, if_present: prior_plan}
|
|
||||||
- {role: user, content: "${interjection}", stream: high_level, if_present: interjection}
|
|
||||||
- {role: assistant, content: "${current_plan}", stream: high_level, target: true, if_present: current_plan, tool_calls_from: speech}
|
|
||||||
|
|
||||||
ask_vqa_top:
|
ask_vqa_top:
|
||||||
weight: 0.075
|
weight: 0.075
|
||||||
bindings:
|
bindings:
|
||||||
|
|||||||
@@ -738,24 +738,30 @@ def _msgs_for_subtask(state: dict[str, Any]) -> list[dict[str, Any]]:
|
|||||||
|
|
||||||
|
|
||||||
def _msgs_for_memory(state: dict[str, Any]) -> list[dict[str, Any]]:
|
def _msgs_for_memory(state: dict[str, Any]) -> list[dict[str, Any]]:
|
||||||
"""``memory_update`` recipe layout."""
|
"""Memory-update prompt — matches the boundary-frame tail of
|
||||||
msgs: list[dict[str, Any]] = [
|
``action_execution`` in the v2 recipes.
|
||||||
{"role": "user", "content": state.get("task") or ""}
|
|
||||||
]
|
Recipe layout on a boundary frame:
|
||||||
|
user: "${task}\\nPlan: ${plan}\\nMemory: ${memory}"
|
||||||
|
assistant: "${subtask}"
|
||||||
|
assistant: → predicts new memory
|
||||||
|
|
||||||
|
At inference we fire this when the runtime detects a subtask
|
||||||
|
transition; the freshly-predicted subtask lives in
|
||||||
|
``state['current_subtask']``. No "Completed subtask: X" user
|
||||||
|
filler — the second assistant turn is generated immediately
|
||||||
|
after the subtask turn.
|
||||||
|
"""
|
||||||
|
head_parts = [state.get("task") or ""]
|
||||||
|
if state.get("current_plan"):
|
||||||
|
head_parts.append(f"Plan: {state['current_plan']}")
|
||||||
if state.get("current_memory"):
|
if state.get("current_memory"):
|
||||||
msgs.append(
|
head_parts.append(f"Memory: {state['current_memory']}")
|
||||||
{
|
msgs: list[dict[str, Any]] = [
|
||||||
"role": "assistant",
|
{"role": "user", "content": "\n".join(head_parts)},
|
||||||
"content": f"Previous memory: {state['current_memory']}",
|
]
|
||||||
}
|
|
||||||
)
|
|
||||||
if state.get("current_subtask"):
|
if state.get("current_subtask"):
|
||||||
msgs.append(
|
msgs.append({"role": "assistant", "content": state["current_subtask"]})
|
||||||
{
|
|
||||||
"role": "user",
|
|
||||||
"content": f"Completed subtask: {state['current_subtask']}",
|
|
||||||
}
|
|
||||||
)
|
|
||||||
return msgs
|
return msgs
|
||||||
|
|
||||||
|
|
||||||
|
|||||||
@@ -133,18 +133,29 @@ class SmolVLA2Policy(SmolVLAPolicy):
|
|||||||
device = batch[OBS_STATE].device
|
device = batch[OBS_STATE].device
|
||||||
total = torch.zeros((), device=device, dtype=torch.float32)
|
total = torch.zeros((), device=device, dtype=torch.float32)
|
||||||
|
|
||||||
# ------------------------------------------------------------
|
run_flow = (
|
||||||
# Flow loss path — only when at least one sample wants actions.
|
self.config.flow_loss_weight > 0
|
||||||
# ------------------------------------------------------------
|
and ACTION in batch
|
||||||
run_flow = self.config.flow_loss_weight > 0 and (
|
and (not has_per_sample_routing or bool(predict_actions_t.any().item()))
|
||||||
not has_per_sample_routing or bool(predict_actions_t.any().item())
|
|
||||||
)
|
)
|
||||||
if run_flow and ACTION in batch:
|
|
||||||
|
# ------------------------------------------------------------
|
||||||
|
# Fused path — one backbone forward for flow + text together.
|
||||||
|
# ------------------------------------------------------------
|
||||||
|
if run_flow and has_text_data:
|
||||||
|
flow_loss, text_loss, flow_diag = self._compute_fused_loss(
|
||||||
|
batch, text_labels, predict_actions_t, noise=noise, time=time
|
||||||
|
)
|
||||||
|
total = total + self.config.flow_loss_weight * flow_loss
|
||||||
|
total = total + self.config.text_loss_weight * text_loss
|
||||||
|
loss_dict["flow_loss"] = float(flow_loss.detach().item())
|
||||||
|
loss_dict["text_loss"] = float(text_loss.detach().item())
|
||||||
|
for k, v in flow_diag.items():
|
||||||
|
loss_dict[f"flow_{k}"] = v
|
||||||
|
elif run_flow:
|
||||||
per_sample_flow, flow_diag = super().forward(
|
per_sample_flow, flow_diag = super().forward(
|
||||||
batch, noise=noise, time=time, reduction="none"
|
batch, noise=noise, time=time, reduction="none"
|
||||||
)
|
)
|
||||||
# ``per_sample_flow`` has shape (B,) from the SmolVLA
|
|
||||||
# reduction="none" branch.
|
|
||||||
if has_per_sample_routing:
|
if has_per_sample_routing:
|
||||||
mask = predict_actions_t.to(per_sample_flow.dtype)
|
mask = predict_actions_t.to(per_sample_flow.dtype)
|
||||||
masked = per_sample_flow * mask
|
masked = per_sample_flow * mask
|
||||||
@@ -156,11 +167,7 @@ class SmolVLA2Policy(SmolVLAPolicy):
|
|||||||
loss_dict["flow_loss"] = float(flow_loss.detach().item())
|
loss_dict["flow_loss"] = float(flow_loss.detach().item())
|
||||||
for k, v in flow_diag.items():
|
for k, v in flow_diag.items():
|
||||||
loss_dict[f"flow_{k}"] = v
|
loss_dict[f"flow_{k}"] = v
|
||||||
|
elif has_text_data:
|
||||||
# ------------------------------------------------------------
|
|
||||||
# Text loss path — prefix-only forward → lm_head → CE.
|
|
||||||
# ------------------------------------------------------------
|
|
||||||
if has_text_data:
|
|
||||||
text_loss = self._compute_text_loss(batch, text_labels)
|
text_loss = self._compute_text_loss(batch, text_labels)
|
||||||
total = total + self.config.text_loss_weight * text_loss
|
total = total + self.config.text_loss_weight * text_loss
|
||||||
loss_dict["text_loss"] = float(text_loss.detach().item())
|
loss_dict["text_loss"] = float(text_loss.detach().item())
|
||||||
@@ -253,6 +260,143 @@ class SmolVLA2Policy(SmolVLAPolicy):
|
|||||||
)
|
)
|
||||||
return loss / valid_labels.sum().clamp(min=1)
|
return loss / valid_labels.sum().clamp(min=1)
|
||||||
|
|
||||||
|
# ------------------------------------------------------------------
|
||||||
|
# Fused flow + text loss (single backbone forward)
|
||||||
|
# ------------------------------------------------------------------
|
||||||
|
|
||||||
|
def _compute_fused_loss(
|
||||||
|
self,
|
||||||
|
batch: dict[str, Tensor],
|
||||||
|
text_labels: Tensor,
|
||||||
|
predict_actions_t: Tensor | None,
|
||||||
|
noise: Tensor | None = None,
|
||||||
|
time: Tensor | None = None,
|
||||||
|
) -> tuple[Tensor, Tensor, dict[str, Any]]:
|
||||||
|
"""One backbone forward → both flow MSE and text CE.
|
||||||
|
|
||||||
|
Mirrors ``SmolVLAModel.forward`` (prefix + suffix concat, one
|
||||||
|
``vlm_with_expert`` call) but captures **both** outputs:
|
||||||
|
|
||||||
|
* ``prefix_out[:, lang_start:lang_end]`` → ``lm_head`` → CE on
|
||||||
|
``text_labels`` (same slicing as ``_compute_text_loss``).
|
||||||
|
* ``suffix_out[:, -chunk_size:]`` → ``action_out_proj`` → flow
|
||||||
|
MSE against ``noise - actions`` (same as the parent forward).
|
||||||
|
|
||||||
|
Saves one backbone pass per training step vs. running the flow
|
||||||
|
and text paths separately — same trick PI052Policy uses in
|
||||||
|
``_compute_all_losses_fused``.
|
||||||
|
"""
|
||||||
|
from ..smolvla.modeling_smolvla import resize_with_pad # noqa: F401 (kept for parity)
|
||||||
|
|
||||||
|
cfg = self.config
|
||||||
|
if cfg.adapt_to_pi_aloha:
|
||||||
|
batch[OBS_STATE] = self._pi_aloha_decode_state(batch[OBS_STATE])
|
||||||
|
batch[ACTION] = self._pi_aloha_encode_actions_inv(batch[ACTION])
|
||||||
|
|
||||||
|
images, img_masks = self.prepare_images(batch)
|
||||||
|
state = self.prepare_state(batch)
|
||||||
|
lang_tokens = batch[OBS_LANGUAGE_TOKENS]
|
||||||
|
lang_masks = batch[OBS_LANGUAGE_ATTENTION_MASK]
|
||||||
|
actions = self.prepare_action(batch)
|
||||||
|
|
||||||
|
inner = self.model
|
||||||
|
if noise is None:
|
||||||
|
noise = inner.sample_noise(actions.shape, actions.device)
|
||||||
|
if time is None:
|
||||||
|
time = inner.sample_time(actions.shape[0], actions.device)
|
||||||
|
|
||||||
|
time_expanded = time[:, None, None]
|
||||||
|
x_t = time_expanded * noise + (1 - time_expanded) * actions
|
||||||
|
u_t = noise - actions
|
||||||
|
|
||||||
|
prefix_embs, prefix_pad_masks, prefix_att_masks = inner.embed_prefix(
|
||||||
|
images, img_masks, lang_tokens, lang_masks, state=state
|
||||||
|
)
|
||||||
|
suffix_embs, suffix_pad_masks, suffix_att_masks = inner.embed_suffix(x_t, time)
|
||||||
|
|
||||||
|
pad_masks = torch.cat([prefix_pad_masks, suffix_pad_masks], dim=1)
|
||||||
|
att_masks = torch.cat([prefix_att_masks, suffix_att_masks], dim=1)
|
||||||
|
att_2d_masks = make_att_2d_masks(pad_masks, att_masks)
|
||||||
|
position_ids = torch.cumsum(pad_masks, dim=1) - 1
|
||||||
|
|
||||||
|
out_pair, _ = inner.vlm_with_expert.forward(
|
||||||
|
attention_mask=att_2d_masks,
|
||||||
|
position_ids=position_ids,
|
||||||
|
past_key_values=None,
|
||||||
|
inputs_embeds=[prefix_embs, suffix_embs],
|
||||||
|
use_cache=False,
|
||||||
|
fill_kv_cache=False,
|
||||||
|
)
|
||||||
|
prefix_out, suffix_out = out_pair[0], out_pair[1]
|
||||||
|
if prefix_out is None or suffix_out is None:
|
||||||
|
raise RuntimeError(
|
||||||
|
"SmolVLA2: fused forward expected both prefix and suffix "
|
||||||
|
"hidden states from vlm_with_expert."
|
||||||
|
)
|
||||||
|
|
||||||
|
# ---------------- flow loss (per-sample maskable) ----------------
|
||||||
|
chunk = cfg.chunk_size
|
||||||
|
suffix_chunk = suffix_out[:, -chunk:].to(torch.float32)
|
||||||
|
v_t = inner.action_out_proj(suffix_chunk)
|
||||||
|
losses = F.mse_loss(u_t, v_t, reduction="none")
|
||||||
|
|
||||||
|
original_action_dim = cfg.action_feature.shape[0]
|
||||||
|
losses = losses[:, :, :original_action_dim]
|
||||||
|
flow_diag = {"losses_after_forward": float(losses.detach().mean().item())}
|
||||||
|
|
||||||
|
actions_is_pad = batch.get("action_is_pad")
|
||||||
|
if actions_is_pad is not None:
|
||||||
|
in_episode = ~actions_is_pad
|
||||||
|
losses = losses * in_episode.unsqueeze(-1)
|
||||||
|
flow_diag["losses_after_in_ep_bound"] = float(losses.detach().mean().item())
|
||||||
|
|
||||||
|
losses = losses[:, :, : cfg.max_action_dim]
|
||||||
|
flow_diag["losses_after_rm_padding"] = float(losses.detach().mean().item())
|
||||||
|
|
||||||
|
per_sample_flow = losses.mean(dim=(1, 2))
|
||||||
|
if predict_actions_t is not None:
|
||||||
|
mask = predict_actions_t.to(per_sample_flow.dtype)
|
||||||
|
flow_loss = (per_sample_flow * mask).sum() / mask.sum().clamp(min=1.0)
|
||||||
|
else:
|
||||||
|
flow_loss = per_sample_flow.mean()
|
||||||
|
|
||||||
|
# ---------------- text loss (lang slice of prefix) ---------------
|
||||||
|
num_lang = lang_tokens.shape[1]
|
||||||
|
state_for_dim = state if state.ndim >= 2 else state[:, None]
|
||||||
|
num_state = state_for_dim.shape[1] if state_for_dim.ndim >= 2 else 1
|
||||||
|
if num_state < 1:
|
||||||
|
num_state = 1
|
||||||
|
prefix_len = prefix_out.shape[1]
|
||||||
|
lang_end = prefix_len - num_state
|
||||||
|
lang_start = lang_end - num_lang
|
||||||
|
if lang_start < 0 or lang_end > prefix_len:
|
||||||
|
raise RuntimeError(
|
||||||
|
f"SmolVLA2: fused forward could not locate lang range "
|
||||||
|
f"(prefix_len={prefix_len}, num_lang={num_lang}, "
|
||||||
|
f"num_state={num_state})."
|
||||||
|
)
|
||||||
|
vlm = inner.vlm_with_expert.vlm
|
||||||
|
lang_hidden = prefix_out[:, lang_start:lang_end].to(vlm.lm_head.weight.dtype)
|
||||||
|
logits = vlm.lm_head(lang_hidden)
|
||||||
|
|
||||||
|
if text_labels.shape[1] != num_lang:
|
||||||
|
common = min(text_labels.shape[1], num_lang)
|
||||||
|
logits = logits[:, :common]
|
||||||
|
text_labels = text_labels[:, :common]
|
||||||
|
|
||||||
|
shift_logits = logits[:, :-1, :].contiguous()
|
||||||
|
shift_labels = text_labels[:, 1:].contiguous().long()
|
||||||
|
valid_labels = shift_labels != -100
|
||||||
|
ce = F.cross_entropy(
|
||||||
|
shift_logits.reshape(-1, shift_logits.shape[-1]),
|
||||||
|
shift_labels.reshape(-1),
|
||||||
|
ignore_index=-100,
|
||||||
|
reduction="sum",
|
||||||
|
)
|
||||||
|
text_loss = ce / valid_labels.sum().clamp(min=1)
|
||||||
|
|
||||||
|
return flow_loss, text_loss, flow_diag
|
||||||
|
|
||||||
# ------------------------------------------------------------------
|
# ------------------------------------------------------------------
|
||||||
# Inference: text generation
|
# Inference: text generation
|
||||||
# ------------------------------------------------------------------
|
# ------------------------------------------------------------------
|
||||||
|
|||||||
Reference in New Issue
Block a user