Mirror of https://github.com/huggingface/lerobot.git (synced 2026-05-16 00:59:46 +00:00)
# Subtask Token Generation - Quick Reference

## What Was Done

Added **autoregressive subtask token generation** to the PI05 model, with decoding and printing during both training and inference.

## Key Features

✅ **Training:** Prints ground truth subtask tokens for monitoring
✅ **Inference:** Generates and prints predicted subtask tokens using next token prediction
✅ **Autoregressive:** Each token is conditioned on the previously generated tokens
✅ **Greedy Decoding:** Selects the most likely token at each step
## Implementation Location

**File:** `src/lerobot/policies/pi05/modeling_pi05.py`

**New Method:** `_generate_subtask_tokens()` (lines 844-914)
- Autoregressive token generation
- Uses the PaliGemma language model head
- Greedy decoding with early stopping

**Modified Methods:**
- `sample_actions()` - Calls generation and prints during inference
- `predict_action_chunk()` - Passes the tokenizer to enable generation
- `forward()` - Prints ground truth tokens during training
- `__init__()` - Loads the tokenizer

## Console Output Examples

### Training:
```
[Training] Ground truth subtask 0: pick up the red block
[Training] Ground truth subtask 1: place in blue container
```

### Inference:
```
[Inference] Generated subtask 0: grasp the object
[Inference] Generated subtask 1: move to target location
```
## How to Use

### No Code Changes Required!

The implementation is automatic:

1. **Training:** Just run your training script
   - Subtasks will be printed to the console automatically

2. **Inference:** Just run your inference script
   - Subtasks will be generated and printed automatically

### To Disable (if needed):

To disable subtask generation during inference for better performance:

```python
# Setting the tokenizer to None skips subtask generation in sample_actions()
policy.tokenizer = None
actions = policy.predict_action_chunk(batch)
```
## Technical Specs

| Property | Value |
|----------|-------|
| **Generation Method** | Autoregressive (sequential) |
| **Decoding Strategy** | Greedy (argmax) |
| **Max Tokens** | 50 (configurable) |
| **Tokenizer** | google/paligemma-3b-pt-224 |
| **Attention** | Causal masking for generated tokens |
| **Performance Cost** | ~50 extra forward passes per inference |
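The greedy (argmax) strategy from the table can be illustrated with a minimal, framework-free sketch; a plain Python list stands in for the real logits tensor, so this is illustrative only:

```python
def greedy_step(logits_last):
    """Pick the index of the highest-scoring token (argmax).

    `logits_last` stands in for the logits of the final sequence
    position after the LM head; in the real model this is a tensor.
    """
    return max(range(len(logits_last)), key=logits_last.__getitem__)

# Toy vocabulary scores: token 2 is the most likely
scores = [0.1, 0.3, 2.5, -1.0]
print(greedy_step(scores))  # → 2
```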
## Architecture Flow

```
Training:  Ground Truth Tokens → Decode → Print → Loss Computation
Inference: Observations → Generate Tokens → Decode → Print → Action Prediction
```

## Method: `_generate_subtask_tokens()`

**Purpose:** Generate subtask tokens autoregressively

**Algorithm:**
```
1. Start with prefix = [images, high-level task, state]
2. For each position (up to max_length):
   a. Forward pass → get logits
   b. Apply LM head → token probabilities
   c. Select best token (greedy)
   d. Embed token
   e. Append to prefix
   f. Update masks (causal attention)
3. Stop when EOS or max length reached
4. Return generated tokens
```

**Key Parameters:**
- `images` - Visual observations
- `img_masks` - Image padding masks
- `tokens` - Instruction tokens with state
- `masks` - Token attention masks
- `tokenizer` - For EOS detection
- `max_length` - Maximum tokens to generate (default: 50)
- `device` - Computation device
## Files Created

📄 `SUMMARY.md` - Comprehensive summary
📄 `SUBTASK_GENERATION_CHANGES.md` - Detailed technical docs
📄 `SUBTASK_GENERATION_FLOW.md` - Visual flow diagrams
📄 `QUICK_REFERENCE.md` - This file
📄 `examples/dataset/test_subtask_generation.py` - Test script

## Quick Test

```bash
# Test that the tokenizer loads correctly
python examples/dataset/test_subtask_generation.py

# Run training to see ground truth subtasks
python your_training_script.py

# Run inference to see generated subtasks
python your_inference_script.py
```

## Troubleshooting

### No subtask output during inference?
- Check that the tokenizer loaded: `print(policy.tokenizer)`
- You should see: `PaliGemmaTokenizerFast(name_or_path='google/paligemma-3b-pt-224'...)`

### Tokenizer failed to load?
- Check your internet connection (the first run downloads the tokenizer)
- Check that the transformers library is installed: `pip install transformers`

### Performance too slow during inference?
- Disable subtask generation by setting `policy.tokenizer = None`
- Or implement KV caching for faster generation (future optimization)

## Integration Points

The implementation integrates seamlessly with existing code:

- **Training Loop:** No changes needed, prints happen automatically
- **Inference Loop:** No changes needed, generation happens automatically
- **Data Processing:** Uses the existing tokenizer from the processor
- **Loss Computation:** Already implemented in the training forward pass

## Future Enhancements

Possible improvements (not yet implemented):

- [ ] KV caching for faster generation
- [ ] Temperature/top-k/top-p sampling
- [ ] Beam search for better quality
- [ ] Optional flag to enable/disable printing
- [ ] Save generated subtasks to file
- [ ] Compute subtask prediction accuracy metrics
- [ ] Use generated subtasks in action prediction (hierarchical)
## Code Snippet - How Autoregressive Generation Works

```python
# Simplified, self-contained sketch of the generation loop
def argmax(xs):
    """Index of the largest value."""
    return max(range(len(xs)), key=xs.__getitem__)

def generate(model, prefix, eos_id, max_length=50):
    generated_tokens = []
    for _ in range(max_length):
        # Forward pass over the current prefix
        logits = model(prefix)

        # Greedy decode: most likely next token
        next_token = argmax(logits[-1])

        # Store
        generated_tokens.append(next_token)

        # Stop if EOS
        if next_token == eos_id:
            break

        # Append so the next step is conditioned on this token
        prefix = prefix + [next_token]
    return generated_tokens
```
## Questions?

See the detailed documentation files:
- `SUBTASK_GENERATION_CHANGES.md` - Full technical details
- `SUBTASK_GENERATION_FLOW.md` - Visual flow diagrams
- `SUMMARY.md` - Complete overview

---

**Implementation Status:** ✅ Complete and Ready to Use
# Subtask Token Generation and Decoding Implementation

## Overview

This document describes the implementation of subtask token generation and decoding in the PI05 model. The implementation enables the model to generate subtask tokens autoregressively during inference and decode them to human-readable text during both training and inference.

## Changes Made
### 1. Added Autoregressive Subtask Token Generation (`modeling_pi05.py`)

#### New Method: `_generate_subtask_tokens()`
**Location:** Lines 844-914

**Purpose:** Generates subtask tokens autoregressively using next token prediction during inference.

**How it works:**
1. Embeds the prefix (images + high-level task tokens + state)
2. Iteratively generates tokens one at a time:
   - Forward pass through the model to get logits
   - Apply the language model head to get token probabilities
   - Use greedy decoding to select the most likely next token
   - Embed the generated token and append it to the prefix
   - Update attention masks for causal attention
3. Stops when the EOS token is generated or the max length is reached
4. Returns a tensor of generated tokens

**Key Features:**
- Uses the `@torch.no_grad()` decorator for inference efficiency
- Implements greedy decoding (selects the highest-probability token)
- Uses causal attention masking for generated tokens
- Supports early stopping with EOS token detection
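The loop described above can be sketched in a few lines of framework-free Python. The stub `toy_model` below is an illustrative stand-in, not the real transformer, and `eos_id` is an assumed parameter name:

```python
def argmax(xs):
    """Index of the largest value in a list."""
    return max(range(len(xs)), key=xs.__getitem__)

def generate_subtask_tokens(model, prefix, eos_id, max_length=50):
    """Greedy autoregressive generation with EOS early stopping."""
    generated = []
    for _ in range(max_length):
        logits = model(prefix)           # forward pass over the current prefix
        next_token = argmax(logits[-1])  # greedy: most likely next token
        generated.append(next_token)
        if next_token == eos_id:         # early stopping on EOS
            break
        prefix = prefix + [next_token]   # condition the next step on this token
    return generated

# Toy model: predicts "current prefix length mod 4", so with eos_id=3
# the loop emits 1, 2, then EOS (3) and stops.
def toy_model(prefix):
    row = [0.0] * 4
    row[len(prefix) % 4] = 1.0
    return [row]  # logits per position; only the last is used

print(generate_subtask_tokens(toy_model, [7], eos_id=3))  # → [1, 2, 3]
```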
### 2. Updated `sample_actions()` Method
**Location:** Lines 916-1020

**Changes:**
- Added `tokenizer` parameter (optional)
- Added `max_subtask_tokens` parameter (default: 50)
- Calls `_generate_subtask_tokens()` if a tokenizer is provided
- Decodes and prints generated subtask tokens during inference

**Output Format:**
```
[Inference] Generated subtask {batch_idx}: {decoded_text}
```

### 3. Updated `PI05Policy.__init__()` Method
**Location:** Lines 1066-1099

**Changes:**
- Added tokenizer loading using `AutoTokenizer.from_pretrained("google/paligemma-3b-pt-224")`
- Stores the tokenizer as `self.tokenizer`
- Includes error handling with a warning if the tokenizer fails to load

### 4. Updated `predict_action_chunk()` Method
**Location:** Lines 1387-1409

**Changes:**
- Passes `self.tokenizer` to `model.sample_actions()`
- Enables subtask generation during inference

### 5. Updated `forward()` Method (Training)
**Location:** Lines 1411-1445

**Changes:**
- Added ground truth subtask token decoding during training
- Prints decoded subtask tokens when in training mode
- Uses subtask masks to filter out padding tokens

**Output Format:**
```
[Training] Ground truth subtask {batch_idx}: {decoded_text}
```
## Technical Details

### Autoregressive Generation Process

The generation process follows these steps:

```
1. Start with prefix: [images, high-level task, state]
2. For each generation step (up to max_subtask_tokens):
   a. Create attention masks (causal for generated tokens)
   b. Forward pass through transformer
   c. Apply language model head → logits
   d. Greedy decode: select argmax(logits)
   e. Embed selected token
   f. Append to prefix embeddings
   g. Update masks
3. Stop when EOS token or max length reached
4. Return generated token sequence
```
### Attention Masking

- **Prefix tokens (images + high-level task):** Can attend to all previous tokens (attention mask = 0)
- **Generated subtask tokens:** Use causal attention, can only attend to previous tokens (attention mask = 1)
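A minimal sketch of that masking rule, assuming the 0/1 convention above (0 = full attention within the prefix block, 1 = causal). The function name and list representation are illustrative, not the model's actual API:

```python
def build_can_attend(mask_types):
    """can[i][j] == True when position i may attend to position j.

    mask_types[i] is 0 for prefix tokens (full attention within the
    prefix) and 1 for generated tokens (causal: self and earlier only).
    """
    n = len(mask_types)
    can = [[False] * n for _ in range(n)]
    for i in range(n):
        for j in range(n):
            if mask_types[i] == 0:
                can[i][j] = mask_types[j] == 0  # prefix attends to the whole prefix
            else:
                can[i][j] = j <= i              # causal for generated tokens
    return can

# [IMG] [TASK] [T1] [T2]
can = build_can_attend([0, 0, 1, 1])
print(can[2])  # T1 attends to positions 0..2 → [True, True, True, False]
```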
### Tokenizer

- **Model:** `google/paligemma-3b-pt-224`
- **Type:** PaliGemma tokenizer (based on SentencePiece)
- **Usage:** Decodes token IDs to human-readable text
- **Special tokens:** Automatically filtered with `skip_special_tokens=True`

## Usage Examples

### During Training

The model automatically prints ground truth subtask tokens:

```python
# Training batch contains subtask_tokens
loss, loss_dict = policy.forward(batch)

# Output (automatically printed):
# [Training] Ground truth subtask 0: pick up the red block
# [Training] Ground truth subtask 1: move to the blue container
```

### During Inference

The model generates and prints predicted subtask tokens:

```python
# Inference - the tokenizer is passed automatically
actions = policy.predict_action_chunk(batch)

# Output (automatically printed):
# [Inference] Generated subtask 0: grasp the object
# [Inference] Generated subtask 1: place in target location
```

## Benefits

1. **Transparency:** See what subtasks the model predicts during inference
2. **Debugging:** Verify that subtask prediction is working correctly
3. **Interpretability:** Understand the model's reasoning process
4. **Monitoring:** Track subtask generation quality during training

## Performance Considerations

- **Training:** Minimal overhead (only decoding for printing, no generation)
- **Inference:** Additional computational cost due to autoregressive generation
  - Each token requires a forward pass through the transformer
  - For `max_subtask_tokens=50`, up to 50 forward passes
  - Can be disabled by not passing a tokenizer (for production deployments)

## Future Enhancements

Possible improvements to consider:

1. **Sampling strategies:** Add temperature, top-k, top-p sampling
2. **Beam search:** Generate multiple candidates and select the best
3. **Caching:** Use a KV cache to speed up autoregressive generation
4. **Logging:** Redirect prints to a logger instead of the console
5. **Metrics:** Track subtask prediction accuracy during training
6. **Optional flag:** Add a config option to enable/disable printing

## Testing

To test the implementation, run:

```bash
python examples/dataset/test_subtask_generation.py
```

This will demonstrate the subtask generation features and verify that the tokenizer is loaded correctly.
## Related Files

- `src/lerobot/policies/pi05/modeling_pi05.py` - Main implementation
- `src/lerobot/policies/pi05/processor_pi05.py` - Subtask token preprocessing
- `src/lerobot/policies/pi05/configuration_pi05.py` - Configuration
- `examples/dataset/test_subtask_generation.py` - Test script
# Subtask Token Generation Flow Diagram

## Overview

This document provides visual representations of how subtask tokens are processed during training and inference in the PI05 model.

## Training Flow
```
┌─────────────────────────────────────────────────────────────┐
│                   TRAINING FORWARD PASS                     │
└─────────────────────────────────────────────────────────────┘

Input Batch:
├─ images (observations)
├─ high_level_task (user prompt tokens)
├─ tokens (instruction prompt with state)
├─ subtask_tokens (ground truth subtask - TARGET)
└─ actions (ground truth actions)

    ↓

┌─────────────────────────────────────────────────────────────┐
│ Step 1: Decode & Print Ground Truth Subtask Tokens          │
│ ----------------------------------------------------------- │
│ • Extract valid tokens (remove padding)                     │
│ • Decode using tokenizer                                    │
│ • Print: "[Training] Ground truth subtask {i}: {text}"      │
└─────────────────────────────────────────────────────────────┘

    ↓

┌─────────────────────────────────────────────────────────────┐
│ Step 2: Prefix-Only Forward Pass                            │
│ ----------------------------------------------------------- │
│ Embed Prefix:                                               │
│   • images → image embeddings                               │
│   • high_level_task → language embeddings                   │
│   • tokens → language embeddings                            │
│   • subtask_tokens → language embeddings (with causal mask) │
│                                                             │
│ Forward through transformer (prefix only)                   │
│   → Get prefix_out                                          │
│                                                             │
│ Apply LM Head:                                              │
│   prefix_out → logits                                       │
│                                                             │
│ Extract subtask logits:                                     │
│   logits[start_index:end_index]                             │
│                                                             │
│ Compute Cross-Entropy Loss:                                 │
│   CE(predicted_logits, subtask_tokens) → subtask_loss       │
└─────────────────────────────────────────────────────────────┘

    ↓

┌─────────────────────────────────────────────────────────────┐
│ Step 3: Full Forward Pass (Prefix + Suffix)                 │
│ ----------------------------------------------------------- │
│ Add noisy actions to prefix:                                │
│   x_t = time * noise + (1 - time) * actions                 │
│                                                             │
│ Forward through transformer:                                │
│   [prefix_embs, action_embs] → [prefix_out, suffix_out]     │
│                                                             │
│ Predict velocity field:                                     │
│   suffix_out → v_t                                          │
│                                                             │
│ Compute Flow Matching Loss:                                 │
│   MSE(u_t, v_t) → flow_loss                                 │
└─────────────────────────────────────────────────────────────┘

    ↓

┌─────────────────────────────────────────────────────────────┐
│ Step 4: Combined Loss                                       │
│ ----------------------------------------------------------- │
│ total_loss = 10 * flow_loss + subtask_loss                  │
└─────────────────────────────────────────────────────────────┘

Output: loss, {flow_loss, subtask_loss, loss}
```
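The Step 4 weighting can be sketched numerically with scalar stand-ins for the real batched losses (the `mse` helper and toy values below are purely illustrative):

```python
def mse(u, v):
    """Mean squared error between two equal-length vectors."""
    return sum((a - b) ** 2 for a, b in zip(u, v)) / len(u)

def combined_loss(flow_loss, subtask_loss):
    """total_loss = 10 * flow_loss + subtask_loss, as in Step 4."""
    return 10.0 * flow_loss + subtask_loss

flow = mse([0.2, -0.1], [0.0, 0.1])  # stand-in for MSE(u_t, v_t)
subtask = 0.7                        # stand-in for CE(logits, tokens)
print(combined_loss(flow, subtask))
```

The factor of 10 rebalances the small-magnitude flow matching loss against the cross-entropy subtask loss, matching the formula in the diagram.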
## Inference Flow

```
┌─────────────────────────────────────────────────────────────┐
│                  INFERENCE FORWARD PASS                     │
└─────────────────────────────────────────────────────────────┘

Input Batch:
├─ images (observations)
├─ high_level_task (user prompt tokens)
└─ tokens (instruction prompt with state)

    ↓

┌─────────────────────────────────────────────────────────────┐
│ Step 1: Autoregressive Subtask Token Generation             │
│ ----------------------------------------------------------- │
│ Initialize:                                                 │
│   prefix_embs = [images, high_level_task, tokens]           │
│   generated_tokens = []                                     │
│                                                             │
│ For t in range(max_subtask_tokens):                         │
│   ┌───────────────────────────────────────────────┐         │
│   │ Forward Pass:                                 │         │
│   │   prefix_embs → transformer → prefix_out      │         │
│   │                                               │         │
│   │ Apply LM Head:                                │         │
│   │   prefix_out → logits                         │         │
│   │                                               │         │
│   │ Greedy Decode:                                │         │
│   │   next_token = argmax(logits[-1])             │         │
│   │                                               │         │
│   │ Store Token:                                  │         │
│   │   generated_tokens.append(next_token)         │         │
│   │                                               │         │
│   │ Check EOS:                                    │         │
│   │   if next_token == EOS: break                 │         │
│   │                                               │         │
│   │ Embed & Append:                               │         │
│   │   next_emb = embed(next_token)                │         │
│   │   prefix_embs = concat(prefix_embs, next_emb) │         │
│   │   Update masks (causal attention)             │         │
│   └───────────────────────────────────────────────┘         │
│                                                             │
│ Return: generated_tokens                                    │
└─────────────────────────────────────────────────────────────┘

    ↓

┌─────────────────────────────────────────────────────────────┐
│ Step 2: Decode & Print Generated Subtask Tokens             │
│ ----------------------------------------------------------- │
│ • Remove padding tokens (value = 0)                         │
│ • Decode using tokenizer                                    │
│ • Print: "[Inference] Generated subtask {i}: {text}"        │
└─────────────────────────────────────────────────────────────┘

    ↓

┌─────────────────────────────────────────────────────────────┐
│ Step 3: Embed Prefix (without subtask tokens)               │
│ ----------------------------------------------------------- │
│ prefix_embs = [images, high_level_task, tokens]             │
│ (Note: subtask_tokens = None during inference)              │
└─────────────────────────────────────────────────────────────┘

    ↓

┌─────────────────────────────────────────────────────────────┐
│ Step 4: Cache KV for Prefix                                 │
│ ----------------------------------------------------------- │
│ Forward pass through transformer with use_cache=True        │
│   → past_key_values                                         │
└─────────────────────────────────────────────────────────────┘

    ↓

┌─────────────────────────────────────────────────────────────┐
│ Step 5: Flow Matching Denoising Loop                        │
│ ----------------------------------------------------------- │
│ Initialize:                                                 │
│   x_t = noise (random action sequence)                      │
│   time = 1.0                                                │
│   dt = -1.0 / num_steps                                     │
│                                                             │
│ While time >= -dt/2:                                        │
│   ┌───────────────────────────────────────────────┐         │
│   │ Denoise Step:                                 │         │
│   │   • Embed x_t and time                        │         │
│   │   • Forward through transformer               │         │
│   │     (uses cached past_key_values)             │         │
│   │   • Predict velocity: v_t                     │         │
│   │                                               │         │
│   │ Euler Step:                                   │         │
│   │   x_t = x_t + dt * v_t                        │         │
│   │   time = time + dt                            │         │
│   └───────────────────────────────────────────────┘         │
│                                                             │
│ Return: x_t (final denoised actions)                        │
└─────────────────────────────────────────────────────────────┘

Output: actions (predicted action chunk)
```
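Step 5's Euler integration can be sketched with a toy velocity field; the real denoise step runs the transformer with cached keys/values, so the lambda below is only a stand-in:

```python
def denoise(velocity_fn, x_t, num_steps=10):
    """Euler integration from time=1.0 toward time=0.0, as in Step 5."""
    time = 1.0
    dt = -1.0 / num_steps
    while time >= -dt / 2:            # loop until time reaches ~0
        v_t = velocity_fn(x_t, time)  # predict the velocity field
        x_t = [x + dt * v for x, v in zip(x_t, v_t)]
        time += dt
    return x_t

# Toy velocity: push every coordinate toward zero at rate x
toy_velocity = lambda x_t, time: x_t
actions = denoise(toy_velocity, [1.0, -2.0], num_steps=10)
print(actions)  # each coordinate shrinks by a factor of (1 - 0.1)**10
```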
## Key Differences Between Training and Inference

| Aspect | Training | Inference |
|--------|----------|-----------|
| Subtask Tokens | **Ground truth provided** | **Generated autoregressively** |
| Subtask Processing | Decode & print for monitoring | Generate → decode → print |
| Forward Passes | 2 passes (prefix-only, then full) | Multiple passes (1 per token + denoising) |
| Loss Computation | Subtask loss + flow loss | No loss (inference only) |
| Subtask Usage | Used for loss calculation | Generated but not used in action prediction |

## Autoregressive Generation Detail

```
Time step t=0:
┌─────────────────────────────────────────────────────┐
│ Prefix: [IMG] [IMG] [TASK] [STATE] →                │
│   → Forward → LM Head → "pick" (token: 1234)        │
└─────────────────────────────────────────────────────┘

Time step t=1:
┌─────────────────────────────────────────────────────┐
│ Prefix: [IMG] [IMG] [TASK] [STATE] [pick] →         │
│   → Forward → LM Head → "up" (token: 5678)          │
└─────────────────────────────────────────────────────┘

Time step t=2:
┌─────────────────────────────────────────────────────┐
│ Prefix: [IMG] [IMG] [TASK] [STATE] [pick] [up] →    │
│   → Forward → LM Head → "the" (token: 9012)         │
└─────────────────────────────────────────────────────┘

... continues until EOS or max_length ...

Final Result: "pick up the red block"
```
## Attention Masking Pattern

```
During Subtask Generation:

Position:    0     1     2     3    4    5
Token:     [IMG] [IMG] [TASK] [T1] [T2] [T3]   (T* = generated tokens)
Mask Type:   0     0     0     1    1    1     (0 = full attn, 1 = causal)

Attention Pattern:
[IMG]  can attend to: [IMG]                            (itself)
[IMG]  can attend to: [IMG], [IMG]                     (all previous)
[TASK] can attend to: [IMG], [IMG], [TASK]             (all previous)
[T1]   can attend to: [IMG], [IMG], [TASK]             (causal: only previous)
[T2]   can attend to: [IMG], [IMG], [TASK], [T1]       (causal)
[T3]   can attend to: [IMG], [IMG], [TASK], [T1], [T2] (causal)
```
## Benefits of This Approach

1. **Training:**
   - Model learns to predict subtask tokens given observations
   - Joint training of subtask prediction and action prediction
   - Ground truth subtasks are visible for debugging

2. **Inference:**
   - Model can generate interpretable subtask descriptions
   - Autoregressive generation ensures coherent subtask text
   - Provides insight into the model's reasoning process
   - Can be used for hierarchical planning
## Implementation Notes

- **Greedy Decoding:** Always selects the most likely token (argmax)
- **No KV Cache:** Each generation step performs a full forward pass (can be optimized)
- **Max Length:** Limited to 50 tokens (configurable)
- **EOS Handling:** Stops early if all batches generate the EOS token
- **Padding Handling:** Filters out padding tokens before decoding
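The "stops early if all batches generate EOS" note can be sketched with a per-sample done flag; the `step_fn` stub and scripted token lists below stand in for the real batched forward pass:

```python
def generate_batch(step_fn, batch_size, eos_id, max_length=50):
    """Greedy batched generation that stops once every sample hit EOS."""
    done = [False] * batch_size
    out = [[] for _ in range(batch_size)]
    for t in range(max_length):
        tokens = step_fn(t)           # next token per batch element
        for b in range(batch_size):
            if not done[b]:
                out[b].append(tokens[b])
                if tokens[b] == eos_id:
                    done[b] = True
        if all(done):                 # early stop: every sample finished
            break
    return out

# Toy step function: sample 0 emits EOS (0) at t=1, sample 1 at t=2
script = [[5, 7], [0, 8], [9, 0]]
print(generate_batch(lambda t: script[t], batch_size=2, eos_id=0))
# → [[5, 0], [7, 8, 0]]
```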
# Subtask Token Generation Implementation - Summary

## What Was Implemented

I've added **autoregressive subtask token generation and decoding** to the PI05 model, enabling the model to:

1. **During Training:** Decode and print ground truth subtask tokens for monitoring
2. **During Inference:** Generate subtask tokens using next token prediction and print them
## Key Changes

### 1. New Method: `_generate_subtask_tokens()`
**File:** `src/lerobot/policies/pi05/modeling_pi05.py` (lines 844-914)

- Implements autoregressive token generation using greedy decoding
- Uses the PaliGemma language model head for token prediction
- Generates tokens one at a time, each conditioned on previous tokens
- Stops when the EOS token is generated or the max length (50 tokens) is reached

### 2. Updated `sample_actions()` Method
**File:** `src/lerobot/policies/pi05/modeling_pi05.py` (lines 916-1020)

- Added optional `tokenizer` and `max_subtask_tokens` parameters
- Calls `_generate_subtask_tokens()` during inference if a tokenizer is provided
- Decodes and prints generated subtask tokens

### 3. Updated `PI05Policy.__init__()`
**File:** `src/lerobot/policies/pi05/modeling_pi05.py` (lines 1066-1099)

- Loads the PaliGemma tokenizer (`google/paligemma-3b-pt-224`) for decoding
- Stores it as `self.tokenizer` for use throughout the policy

### 4. Updated `predict_action_chunk()`
**File:** `src/lerobot/policies/pi05/modeling_pi05.py` (lines 1387-1409)

- Passes the tokenizer to `sample_actions()` to enable subtask generation

### 5. Updated `forward()` (Training Method)
**File:** `src/lerobot/policies/pi05/modeling_pi05.py` (lines 1411-1445)

- Decodes and prints ground truth subtask tokens during training
- Helps monitor what the model is learning to predict
## How It Works

### During Inference:

```
1. Initialize with prefix: [images, high-level task, state]
2. Generate tokens autoregressively:
   - Forward pass → get logits
   - Select most likely token (greedy decoding)
   - Embed token and append to prefix
   - Repeat until EOS or max length
3. Decode generated tokens to text
4. Print: "[Inference] Generated subtask {i}: {text}"
5. Continue with action prediction (flow matching)
```

### During Training:

```
1. Extract ground truth subtask tokens from batch
2. Remove padding and decode to text
3. Print: "[Training] Ground truth subtask {i}: {text}"
4. Continue with normal training (subtask loss + flow loss)
```
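Step 2 of the training path (strip padding, then decode and print) can be sketched with a stand-in decoder. The real code uses the PaliGemma tokenizer; `toy_vocab` and the `decode` helper here are purely illustrative:

```python
def strip_padding(token_ids, pad_id=0):
    """Keep only the valid (non-padding) token ids."""
    return [t for t in token_ids if t != pad_id]

def decode(token_ids, vocab):
    """Stand-in for tokenizer.decode(..., skip_special_tokens=True)."""
    return " ".join(vocab[t] for t in token_ids)

toy_vocab = {1: "pick", 2: "up", 3: "the", 4: "red", 5: "block"}
subtask_tokens = [1, 2, 3, 4, 5, 0, 0, 0]  # padded to a fixed length

text = decode(strip_padding(subtask_tokens), toy_vocab)
print(f"[Training] Ground truth subtask 0: {text}")
# → [Training] Ground truth subtask 0: pick up the red block
```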
## Example Output

### Training:
```
[Training] Ground truth subtask 0: pick up the red block
[Training] Ground truth subtask 1: move to the blue container
[Training] Ground truth subtask 2: place the object down
```

### Inference:
```
[Inference] Generated subtask 0: grasp the object
[Inference] Generated subtask 1: move to target location
[Inference] Generated subtask 2: release the gripper
```

## Benefits

1. ✓ **Transparency:** See what subtasks the model predicts
2. ✓ **Debugging:** Verify subtask prediction works correctly
3. ✓ **Interpretability:** Understand the model's reasoning
4. ✓ **Monitoring:** Track subtask quality during training
5. ✓ **Research:** Enables hierarchical reasoning analysis

## Files Modified

- `src/lerobot/policies/pi05/modeling_pi05.py` (main implementation)

## Files Created

- `examples/dataset/test_subtask_generation.py` (demo script)
- `SUBTASK_GENERATION_CHANGES.md` (detailed documentation)
- `SUBTASK_GENERATION_FLOW.md` (visual flow diagrams)
- `SUMMARY.md` (this file)
## Testing

To verify the implementation:

```bash
python examples/dataset/test_subtask_generation.py
```

This will check that the tokenizer loads correctly and explain the features.

## Next Steps

To see subtask generation in action:

1. **During Training:**
   - Run your training script as usual
   - Watch the console for `[Training] Ground truth subtask` messages

2. **During Inference:**
   - Run your inference script as usual
   - Watch the console for `[Inference] Generated subtask` messages

## Technical Details

- **Generation Method:** Autoregressive (one token at a time)
- **Decoding Strategy:** Greedy (always select the most likely token)
- **Max Tokens:** 50 (configurable via the `max_subtask_tokens` parameter)
- **Attention:** Causal masking for generated tokens
- **Tokenizer:** PaliGemma tokenizer (`google/paligemma-3b-pt-224`)
- **Performance:** Adds ~50 forward passes during inference (can be optimized with KV caching)

## Notes

- The implementation follows the same pattern as training (using the LM head for prediction)
- Subtask generation happens before action prediction
- Generated subtasks are currently for visualization only (not used in action prediction)
- In the future, they could be used for hierarchical planning or multi-step reasoning

## Related Documentation

- See `SUBTASK_GENERATION_CHANGES.md` for detailed technical documentation
- See `SUBTASK_GENERATION_FLOW.md` for visual flow diagrams
- See the training forward pass (lines 735-842) for a reference implementation