Files
lerobot/SUBTASK_GENERATION_FLOW.md
T
Jade Choghari 7e232fb114 more changes
2025-12-13 21:02:07 +00:00

16 KiB

Subtask Token Generation Flow Diagram

Overview

This document provides visual representations of how subtask tokens are processed during training and inference in the PI05 model.

Training Flow

┌─────────────────────────────────────────────────────────────┐
│                     TRAINING FORWARD PASS                    │
└─────────────────────────────────────────────────────────────┘

Input Batch:
├─ images              (observations)
├─ high_level_task     (user prompt tokens)
├─ tokens              (instruction prompt with state)
├─ subtask_tokens      (ground truth subtask - TARGET)
└─ actions             (ground truth actions)

                          ↓

┌─────────────────────────────────────────────────────────────┐
│  Step 1: Decode & Print Ground Truth Subtask Tokens        │
│  ---------------------------------------------------------- │
│  • Extract valid tokens (remove padding)                    │
│  • Decode using tokenizer                                   │
│  • Print: "[Training] Ground truth subtask {i}: {text}"    │
└─────────────────────────────────────────────────────────────┘

                          ↓

┌─────────────────────────────────────────────────────────────┐
│  Step 2: Prefix-Only Forward Pass                          │
│  ---------------------------------------------------------- │
│  Embed Prefix:                                              │
│    • images → image embeddings                              │
│    • high_level_task → language embeddings                  │
│    • tokens → language embeddings                           │
│    • subtask_tokens → language embeddings (with causal mask)│
│                                                             │
│  Forward through transformer (prefix only)                  │
│    → Get prefix_out                                         │
│                                                             │
│  Apply LM Head:                                             │
│    prefix_out → logits                                      │
│                                                             │
│  Extract subtask logits:                                    │
│    logits[start_index:end_index]                           │
│                                                             │
│  Compute Cross-Entropy Loss:                                │
│    CE(predicted_logits, subtask_tokens) → subtask_loss     │
└─────────────────────────────────────────────────────────────┘

                          ↓

┌─────────────────────────────────────────────────────────────┐
│  Step 3: Full Forward Pass (Prefix + Suffix)               │
│  ---------------------------------------------------------- │
│  Add noisy actions to prefix:                               │
│    x_t = time * noise + (1-time) * actions                 │
│                                                             │
│  Forward through transformer:                               │
│    [prefix_embs, action_embs] → [prefix_out, suffix_out]  │
│                                                             │
│  Predict velocity field:                                    │
│    suffix_out → v_t                                        │
│                                                             │
│  Compute Flow Matching Loss:                                │
│    MSE(u_t, v_t) → flow_loss                              │
└─────────────────────────────────────────────────────────────┘

                          ↓

┌─────────────────────────────────────────────────────────────┐
│  Step 4: Combined Loss                                      │
│  ---------------------------------------------------------- │
│  total_loss = 10 * flow_loss + subtask_loss               │
└─────────────────────────────────────────────────────────────┘

Output: loss, {flow_loss, subtask_loss, loss}

Inference Flow

┌─────────────────────────────────────────────────────────────┐
│                    INFERENCE FORWARD PASS                    │
└─────────────────────────────────────────────────────────────┘

Input Batch:
├─ images              (observations)
├─ high_level_task     (user prompt tokens)
└─ tokens              (instruction prompt with state)

                          ↓

┌─────────────────────────────────────────────────────────────┐
│  Step 1: Autoregressive Subtask Token Generation           │
│  ---------------------------------------------------------- │
│                                                             │
│  Initialize:                                                │
│    prefix_embs = [images, high_level_task, tokens]         │
│    generated_tokens = []                                    │
│                                                             │
│  For t in range(max_subtask_tokens):                       │
│    ┌───────────────────────────────────────────┐          │
│    │ Forward Pass:                              │          │
│    │   prefix_embs → transformer → prefix_out   │          │
│    │                                             │          │
│    │ Apply LM Head:                              │          │
│    │   prefix_out → logits                      │          │
│    │                                             │          │
│    │ Greedy Decode:                              │          │
│    │   next_token = argmax(logits[-1])         │          │
│    │                                             │          │
│    │ Store Token:                                │          │
│    │   generated_tokens.append(next_token)      │          │
│    │                                             │          │
│    │ Check EOS:                                  │          │
│    │   if next_token == EOS: break              │          │
│    │                                             │          │
│    │ Embed & Append:                             │          │
│    │   next_emb = embed(next_token)             │          │
│    │   prefix_embs = concat(prefix_embs, next_emb)│        │
│    │   Update masks (causal attention)          │          │
│    └───────────────────────────────────────────┘          │
│                                                             │
│  Return: generated_tokens                                   │
└─────────────────────────────────────────────────────────────┘

                          ↓

┌─────────────────────────────────────────────────────────────┐
│  Step 2: Decode & Print Generated Subtask Tokens           │
│  ---------------------------------------------------------- │
│  • Remove padding tokens (value = 0)                        │
│  • Decode using tokenizer                                   │
│  • Print: "[Inference] Generated subtask {i}: {text}"      │
└─────────────────────────────────────────────────────────────┘

                          ↓

┌─────────────────────────────────────────────────────────────┐
│  Step 3: Embed Prefix (without subtask tokens)             │
│  ---------------------------------------------------------- │
│  prefix_embs = [images, high_level_task, tokens]           │
│  (Note: subtask_tokens = None during inference)            │
└─────────────────────────────────────────────────────────────┘

                          ↓

┌─────────────────────────────────────────────────────────────┐
│  Step 4: Cache KV for Prefix                                │
│  ---------------------------------------------------------- │
│  Forward pass through transformer with use_cache=True       │
│    → past_key_values                                        │
└─────────────────────────────────────────────────────────────┘

                          ↓

┌─────────────────────────────────────────────────────────────┐
│  Step 5: Flow Matching Denoising Loop                       │
│  ---------------------------------------------------------- │
│  Initialize:                                                │
│    x_t = noise  (random action sequence)                   │
│    time = 1.0                                              │
│    dt = -1.0 / num_steps                                   │
│                                                             │
│  While time >= -dt/2:                                      │
│    ┌───────────────────────────────────────────┐          │
│    │ Denoise Step:                              │          │
│    │   • Embed x_t and time                     │          │
│    │   • Forward through transformer            │          │
│    │     (uses cached past_key_values)          │          │
│    │   • Predict velocity: v_t                  │          │
│    │                                             │          │
│    │ Euler Step:                                 │          │
│    │   x_t = x_t + dt * v_t                     │          │
│    │   time = time + dt                         │          │
│    └───────────────────────────────────────────┘          │
│                                                             │
│  Return: x_t (final denoised actions)                      │
└─────────────────────────────────────────────────────────────┘

Output: actions (predicted action chunk)

Key Differences Between Training and Inference

Aspect Training Inference
Subtask Tokens Ground truth provided Generated autoregressively
Subtask Processing Decode & print for monitoring Generate → decode → print
Forward Passes 2 passes (prefix-only, then full) Multiple passes (1 per token + denoising)
Loss Computation Subtask loss + flow loss No loss (inference only)
Subtask Usage Used for loss calculation Generated but not used in action prediction

Autoregressive Generation Detail

Time step t=0:
┌─────────────────────────────────────────────────────┐
│ Prefix: [IMG] [IMG] [TASK] [STATE] →               │
│   → Forward → LM Head → "pick" (token: 1234)       │
└─────────────────────────────────────────────────────┘

Time step t=1:
┌─────────────────────────────────────────────────────┐
│ Prefix: [IMG] [IMG] [TASK] [STATE] [pick] →        │
│   → Forward → LM Head → "up" (token: 5678)         │
└─────────────────────────────────────────────────────┘

Time step t=2:
┌─────────────────────────────────────────────────────┐
│ Prefix: [IMG] [IMG] [TASK] [STATE] [pick] [up] →   │
│   → Forward → LM Head → "the" (token: 9012)        │
└─────────────────────────────────────────────────────┘

... continues until EOS or max_length ...

Final Result: "pick up the red block"

Attention Masking Pattern

During Subtask Generation:

Position:    0    1    2    3    4    5    6    (token positions)
Token:     [IMG] [IMG] [TASK] [T1] [T2] [T3]    (T* = generated tokens)
Mask Type:   0    0     0     1    1    1      (0=full attn, 1=causal)

Attention Pattern:
  [IMG] can attend to: [IMG] (itself)
  [IMG] can attend to: [IMG], [IMG] (all previous)
  [TASK] can attend to: [IMG], [IMG], [TASK] (all previous)
  [T1] can attend to: [IMG], [IMG], [TASK] (causal: only previous)
  [T2] can attend to: [IMG], [IMG], [TASK], [T1] (causal)
  [T3] can attend to: [IMG], [IMG], [TASK], [T1], [T2] (causal)

Benefits of This Approach

  1. Training:

    • Model learns to predict subtask tokens given observations
    • Joint training of subtask prediction and action prediction
    • Ground truth subtasks are visible for debugging
  2. Inference:

    • Model can generate interpretable subtask descriptions
    • Autoregressive generation ensures coherent subtask text
    • Provides insight into model's reasoning process
    • Can be used for hierarchical planning

Implementation Notes

  • Greedy Decoding: Always selects the most likely token (argmax)
  • No KV Cache: Each generation step performs full forward pass (can be optimized)
  • Max Length: Limited to 50 tokens (configurable)
  • EOS Handling: Stops early if all batches generate EOS token
  • Padding Handling: Filters out padding tokens before decoding