mirror of
https://github.com/huggingface/lerobot.git
synced 2026-05-18 18:20:08 +00:00
Subtask Token Generation Flow Diagram
Overview
This document provides visual representations of how subtask tokens are processed during training and inference in the PI05 model.
Training Flow
┌─────────────────────────────────────────────────────────────┐
│ TRAINING FORWARD PASS │
└─────────────────────────────────────────────────────────────┘
Input Batch:
├─ images (observations)
├─ high_level_task (user prompt tokens)
├─ tokens (instruction prompt with state)
├─ subtask_tokens (ground truth subtask - TARGET)
└─ actions (ground truth actions)
↓
┌─────────────────────────────────────────────────────────────┐
│ Step 1: Decode & Print Ground Truth Subtask Tokens │
│ ---------------------------------------------------------- │
│ • Extract valid tokens (remove padding) │
│ • Decode using tokenizer │
│ • Print: "[Training] Ground truth subtask {i}: {text}" │
└─────────────────────────────────────────────────────────────┘
↓
┌─────────────────────────────────────────────────────────────┐
│ Step 2: Prefix-Only Forward Pass │
│ ---------------------------------------------------------- │
│ Embed Prefix: │
│ • images → image embeddings │
│ • high_level_task → language embeddings │
│ • tokens → language embeddings │
│ • subtask_tokens → language embeddings (with causal mask)│
│ │
│ Forward through transformer (prefix only) │
│ → Get prefix_out │
│ │
│ Apply LM Head: │
│ prefix_out → logits │
│ │
│ Extract subtask logits: │
│ logits[start_index:end_index] │
│ │
│ Compute Cross-Entropy Loss: │
│ CE(predicted_logits, subtask_tokens) → subtask_loss │
└─────────────────────────────────────────────────────────────┘
↓
┌─────────────────────────────────────────────────────────────┐
│ Step 3: Full Forward Pass (Prefix + Suffix) │
│ ---------------------------------------------------------- │
│ Build noisy actions (appended after the prefix as suffix): │
│ x_t = time * noise + (1-time) * actions │
│ │
│ Forward through transformer: │
│ [prefix_embs, action_embs] → [prefix_out, suffix_out] │
│ │
│ Predict velocity field: │
│ suffix_out → v_t │
│ │
│ Compute Flow Matching Loss: │
│ u_t = noise - actions (target velocity) │
│ MSE(u_t, v_t) → flow_loss │
└─────────────────────────────────────────────────────────────┘
↓
┌─────────────────────────────────────────────────────────────┐
│ Step 4: Combined Loss │
│ ---------------------------------------------------------- │
│ total_loss = 10 * flow_loss + subtask_loss │
└─────────────────────────────────────────────────────────────┘
Output: loss, {flow_loss, subtask_loss, loss}
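Steps 2 to 4 above can be sketched in plain Python as follows. This is a minimal, non-authoritative sketch; several parts are assumptions:

- **Logit shift.** The causal-LM convention is assumed: the logits at position p score the token at position p + 1. `start_index` is taken to be the position of the first subtask token, so the slice is shifted back by one. If the real `start_index` already points one position earlier, no shift is needed.
- **Time sampling.** The flow time is sampled uniformly for simplicity.
- **Velocity model.** `predict_velocity` is a hypothetical stand-in for the action expert.

```python
import math
import random
from typing import Callable, Dict, List, Sequence

# Hypothetical stand-in for the action expert: maps (x_t, time) to a predicted velocity.
VelocityFn = Callable[[List[float], float], List[float]]


def subtask_ce_loss(logits: Sequence[Sequence[float]], subtask_tokens: Sequence[int],
                    start_index: int) -> float:
    """Mean next-token cross-entropy over one sample's subtask span.

    Assumption: logits[p] scores the token at position p + 1, and start_index
    is the prefix position of the first subtask token.
    """
    if start_index < 1:
        raise ValueError("start_index must leave room for the shifted logits")
    total = 0.0
    for offset, tok in enumerate(subtask_tokens):
        row = logits[start_index - 1 + offset]
        m = max(row)
        log_z = m + math.log(sum(math.exp(x - m) for x in row))  # stable log-sum-exp
        total += log_z - row[tok]  # equals -log softmax(row)[tok]
    return total / len(subtask_tokens)


def flow_matching_loss(actions: Sequence[float], predict_velocity: VelocityFn,
                       rng: random.Random) -> float:
    """MSE between the target velocity u_t and the predicted v_t at a random time."""
    time = rng.random()  # simplification: uniform time sampling
    noise = [rng.gauss(0.0, 1.0) for _ in actions]
    # time = 1 is pure noise, time = 0 is the clean action chunk.
    x_t = [time * n + (1.0 - time) * a for n, a in zip(noise, actions)]
    u_t = [n - a for n, a in zip(noise, actions)]  # d x_t / d time
    v_t = predict_velocity(x_t, time)
    return sum((u - v) ** 2 for u, v in zip(u_t, v_t)) / len(actions)


def total_loss(flow_loss: float, subtask_loss: float, flow_weight: float = 10.0) -> Dict[str, float]:
    # Step 4: weighting as in the diagram, 10 * flow_loss + subtask_loss.
    loss = flow_weight * flow_loss + subtask_loss
    return {"flow_loss": flow_loss, "subtask_loss": subtask_loss, "loss": loss}
```

A quick sanity check uses two reference cases:

- Uniform logits over a 3-token vocabulary give a subtask loss of log 3.
- A velocity predictor that returns exactly `noise - actions` drives the flow loss to (numerically) zero.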
Inference Flow
┌─────────────────────────────────────────────────────────────┐
│ INFERENCE FORWARD PASS │
└─────────────────────────────────────────────────────────────┘
Input Batch:
├─ images (observations)
├─ high_level_task (user prompt tokens)
└─ tokens (instruction prompt with state)
↓
┌─────────────────────────────────────────────────────────────┐
│ Step 1: Autoregressive Subtask Token Generation │
│ ---------------------------------------------------------- │
│ │
│ Initialize: │
│ prefix_embs = [images, high_level_task, tokens] │
│ generated_tokens = [] │
│ │
│ For t in range(max_subtask_tokens): │
│ ┌───────────────────────────────────────────┐ │
│ │ Forward Pass: │ │
│ │ prefix_embs → transformer → prefix_out │ │
│ │ │ │
│ │ Apply LM Head: │ │
│ │ prefix_out → logits │ │
│ │ │ │
│ │ Greedy Decode: │ │
│ │ next_token = argmax(logits[-1]) │ │
│ │ │ │
│ │ Store Token: │ │
│ │ generated_tokens.append(next_token) │ │
│ │ │ │
│ │ Check EOS: │ │
│ │ if all sequences emitted EOS: break │ │
│ │ │ │
│ │ Embed & Append: │ │
│ │ next_emb = embed(next_token) │ │
│ │ prefix_embs = concat(prefix_embs, next_emb)│ │
│ │ Update masks (causal attention) │ │
│ └───────────────────────────────────────────┘ │
│ │
│ Return: generated_tokens │
└─────────────────────────────────────────────────────────────┘
↓
┌─────────────────────────────────────────────────────────────┐
│ Step 2: Decode & Print Generated Subtask Tokens │
│ ---------------------------------------------------------- │
│ • Remove padding tokens (value = 0) │
│ • Decode using tokenizer │
│ • Print: "[Inference] Generated subtask {i}: {text}" │
└─────────────────────────────────────────────────────────────┘
↓
┌─────────────────────────────────────────────────────────────┐
│ Step 3: Embed Prefix (without subtask tokens) │
│ ---------------------------------------------------------- │
│ prefix_embs = [images, high_level_task, tokens] │
│ (Note: subtask_tokens = None during inference) │
└─────────────────────────────────────────────────────────────┘
↓
┌─────────────────────────────────────────────────────────────┐
│ Step 4: Cache KV for Prefix │
│ ---------------------------------------------------------- │
│ Forward pass through transformer with use_cache=True │
│ → past_key_values │
└─────────────────────────────────────────────────────────────┘
↓
┌─────────────────────────────────────────────────────────────┐
│ Step 5: Flow Matching Denoising Loop │
│ ---------------------------------------------------------- │
│ Initialize: │
│ x_t = noise ~ N(0, I) (shape of the action chunk) │
│ time = 1.0 │
│ dt = -1.0 / num_steps │
│ │
│ While time >= -dt/2: │
│ ┌───────────────────────────────────────────┐ │
│ │ Denoise Step: │ │
│ │ • Embed x_t and time │ │
│ │ • Forward through transformer │ │
│ │ (uses cached past_key_values) │ │
│ │ • Predict velocity: v_t │ │
│ │ │ │
│ │ Euler Step: │ │
│ │ x_t = x_t + dt * v_t │ │
│ │ time = time + dt │ │
│ └───────────────────────────────────────────┘ │
│ │
│ Return: x_t (final denoised actions) │
└─────────────────────────────────────────────────────────────┘
Output: actions (predicted action chunk)
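The Euler loop in Step 5 can be sketched as below. This is a minimal sketch in which `predict_velocity` is a hypothetical stand-in for "embed x_t and time, run the transformer against the cached prefix, predict v_t". The half-step tolerance in the loop condition absorbs floating-point drift in `time`, so the loop runs exactly `num_steps` times.

```python
from typing import Callable, List

# Hypothetical stand-in for one denoise step of the model.
VelocityFn = Callable[[List[float], float], List[float]]


def denoise(noise: List[float], predict_velocity: VelocityFn, num_steps: int = 10) -> List[float]:
    x_t = list(noise)  # start from pure noise at time = 1.0
    time = 1.0
    dt = -1.0 / num_steps  # integrate backwards from time 1 to time 0
    while time >= -dt / 2:
        v_t = predict_velocity(x_t, time)
        x_t = [x + dt * v for x, v in zip(x_t, v_t)]  # Euler step
        time += dt
    return x_t  # final denoised actions
```

If the velocity were exactly the training target `noise - actions`, which is constant along the straight path, the integrated steps sum to one full unit of time. Ten Euler steps would then recover the clean actions up to rounding.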
Key Differences Between Training and Inference
| Aspect | Training | Inference |
|---|---|---|
| Subtask Tokens | Ground truth provided | Generated autoregressively |
| Subtask Processing | Decode & print for monitoring | Generate → decode → print |
| Forward Passes | 2 passes (prefix-only, then full) | One pass per generated token, one to cache the prefix, then one per denoising step |
| Loss Computation | Subtask loss + flow loss | No loss (inference only) |
| Subtask Usage | Used for loss calculation | Generated but not used in action prediction |
Autoregressive Generation Detail
Time step t=0:
┌─────────────────────────────────────────────────────┐
│ Prefix: [IMG] [IMG] [TASK] [STATE] → │
│ → Forward → LM Head → "pick" (token: 1234) │
└─────────────────────────────────────────────────────┘
Time step t=1:
┌─────────────────────────────────────────────────────┐
│ Prefix: [IMG] [IMG] [TASK] [STATE] [pick] → │
│ → Forward → LM Head → "up" (token: 5678) │
└─────────────────────────────────────────────────────┘
Time step t=2:
┌─────────────────────────────────────────────────────┐
│ Prefix: [IMG] [IMG] [TASK] [STATE] [pick] [up] → │
│ → Forward → LM Head → "the" (token: 9012) │
└─────────────────────────────────────────────────────┘
... continues until EOS or max_length ...
Final Result: "pick up the red block"
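The token-by-token loop above, followed by the decode step, can be sketched as below. This is a minimal sketch; the following pieces are hypothetical:

- **Model interface.** `next_token_logits` stands in for "embed, transformer, LM head" over the full sequence and returns one logits row per sequence.
- **Vocabulary.** The toy word-level `vocab` replaces the real tokenizer, which would decode subword ids.
- **EOS handling.** A sequence that emits EOS early is padded until the whole batch has finished.

```python
from typing import Callable, Dict, List, Sequence

# Hypothetical model interface: one row of next-token logits per sequence.
NextTokenLogits = Callable[[List[List[int]]], List[List[float]]]


def greedy_generate(prefix_ids: Sequence[Sequence[int]], next_token_logits: NextTokenLogits,
                    eos_id: int, max_subtask_tokens: int = 50, pad_id: int = 0) -> List[List[int]]:
    batch = [list(p) for p in prefix_ids]
    generated: List[List[int]] = [[] for _ in batch]
    finished = [False] * len(batch)
    for _ in range(max_subtask_tokens):
        # No KV cache: the full, growing sequence is re-processed every step.
        logits = next_token_logits(batch)
        for i, row in enumerate(logits):
            tok = max(range(len(row)), key=row.__getitem__)  # greedy argmax
            if finished[i]:
                tok = pad_id  # this sequence already ended: pad it
            elif tok == eos_id:
                finished[i] = True
            generated[i].append(tok)
            batch[i].append(tok)
        if all(finished):  # stop once every sequence has emitted EOS
            break
    return generated


def decode_subtask(token_ids: Sequence[int], vocab: Dict[int, str], eos_id: int,
                   pad_id: int = 0) -> str:
    ids = [t for t in token_ids if t != pad_id]  # drop padding before decoding
    if eos_id in ids:
        ids = ids[: ids.index(eos_id)]  # keep only the text before EOS
    return " ".join(vocab[t] for t in ids)
```

With a real Hugging Face tokenizer, the join would instead be `tokenizer.decode(ids, skip_special_tokens=True)`. The result is printed with `print(f"[Inference] Generated subtask {i}: {text}")`.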
Attention Masking Pattern
During Subtask Generation:
Position: 0 1 2 3 4 5 (token positions)
Token: [IMG] [IMG] [TASK] [T1] [T2] [T3] (T* = generated tokens)
Mask Type: 0 0 0 1 1 1 (0=full attn, 1=causal)
Attention Pattern:
[IMG] can attend to: [IMG], [IMG], [TASK] (full attention within the prefix)
[IMG] can attend to: [IMG], [IMG], [TASK] (full attention within the prefix)
[TASK] can attend to: [IMG], [IMG], [TASK] (full attention within the prefix)
[T1] can attend to: [IMG], [IMG], [TASK], [T1] (causal: prefix and itself)
[T2] can attend to: [IMG], [IMG], [TASK], [T1], [T2] (causal)
[T3] can attend to: [IMG], [IMG], [TASK], [T1], [T2], [T3] (causal)
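The 1D mask types can be expanded into a 2D attention mask. This sketch assumes PI05 uses the cumulative-sum rule found in openpi-style π0 code: a query may attend to a key whose running count of 1-flags is less than or equal to its own, and padded positions are masked out on both sides. Under this rule, the following holds:

- The three prefix positions see each other bidirectionally.
- Each generated token sees the prefix, the earlier generated tokens, and itself.

```python
from typing import List


def make_att_2d_mask(att_mask: List[int], pad_mask: List[bool]) -> List[List[bool]]:
    """allowed[q][k] is True when query position q may attend to key position k."""
    # Running count of 1-flags; a 0 flag keeps a token in the previous block.
    cumsum, total = [], 0
    for flag in att_mask:
        total += flag
        cumsum.append(total)
    n = len(att_mask)
    return [
        [cumsum[k] <= cumsum[q] and pad_mask[q] and pad_mask[k] for k in range(n)]
        for q in range(n)
    ]
```

For the mask types `[0, 0, 0, 1, 1, 1]`, the rows are as follows:

- Rows 0 to 2 allow keys 0 to 2.
- Row 3 allows keys 0 to 3.
- Row 5 allows all six keys.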
Benefits of This Approach
- Training:
  - The model learns to predict subtask tokens from the observations and prompts
  - Subtask prediction and action prediction are trained jointly
  - Ground-truth subtasks are printed, which helps debugging
- Inference:
  - The model generates interpretable subtask descriptions
  - Autoregressive generation conditions each token on the previous ones, which keeps the subtask text coherent
  - The generated text gives some insight into the model's reasoning
  - The subtasks can be used for hierarchical planning
Implementation Notes
- Greedy Decoding: Always selects the most likely token (argmax)
- No KV Cache: Each generation step runs a full forward pass over the prefix plus all tokens generated so far (a KV cache would avoid this recomputation)
- Max Length: Limited to 50 tokens (configurable)
- EOS Handling: Stops early once every sequence in the batch has generated an EOS token
- Padding Handling: Filters out padding tokens before decoding
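The cost of the missing KV cache can be estimated with a rough token count. The sketch below counts how many token positions pass through the transformer while generating, with one forward pass per generated token. It is deliberately simplified: it ignores the fact that per-position attention cost also grows with sequence length, and the prefix length used below is hypothetical.

```python
def tokens_processed(prefix_len: int, num_generated: int, kv_cache: bool) -> int:
    """Rough count of token positions run through the transformer during generation."""
    if num_generated == 0:
        return 0
    if kv_cache:
        # First pass encodes the whole prefix; later passes feed one new token each.
        return prefix_len + (num_generated - 1)
    # Without a cache, pass t re-encodes the prefix plus the t tokens generated so far.
    return sum(prefix_len + t for t in range(num_generated))
```

For a hypothetical 800-token prefix and the 50-token limit, the count is 41,225 token positions without a cache versus 849 with one. This gap is why the note above flags generation as a candidate for optimization.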