mirror of
https://github.com/huggingface/lerobot.git
synced 2026-05-18 10:10:08 +00:00
153 lines
4.9 KiB
Markdown
153 lines
4.9 KiB
Markdown
# Subtask Token Generation Implementation - Summary
|
|
|
|
## What Was Implemented
|
|
|
|
I've successfully added **autoregressive subtask token generation and decoding** to the PI05 model, enabling the model to:
|
|
|
|
1. **During Training:** Decode and print ground truth subtask tokens for monitoring
|
|
2. **During Inference:** Generate subtask tokens using next token prediction and print them
|
|
|
|
## Key Changes
|
|
|
|
### 1. New Method: `_generate_subtask_tokens()`
|
|
**File:** `src/lerobot/policies/pi05/modeling_pi05.py` (lines 844-914)
|
|
|
|
- Implements autoregressive token generation using greedy decoding
|
|
- Uses the PaliGemma language model head for token prediction
|
|
- Generates tokens one at a time, each conditioned on previous tokens
|
|
- Stops when EOS token is generated or max length (50 tokens) is reached
|
|
|
|
### 2. Updated `sample_actions()` Method
|
|
**File:** `src/lerobot/policies/pi05/modeling_pi05.py` (lines 916-1020)
|
|
|
|
- Added optional `tokenizer` and `max_subtask_tokens` parameters
|
|
- Calls `_generate_subtask_tokens()` during inference if tokenizer is provided
|
|
- Decodes and prints generated subtask tokens
|
|
|
|
### 3. Updated `PI05Policy.__init__()`
|
|
**File:** `src/lerobot/policies/pi05/modeling_pi05.py` (lines 1066-1099)
|
|
|
|
- Loads PaliGemma tokenizer (`google/paligemma-3b-pt-224`) for decoding
|
|
- Stores as `self.tokenizer` for use throughout the policy
|
|
|
|
### 4. Updated `predict_action_chunk()`
|
|
**File:** `src/lerobot/policies/pi05/modeling_pi05.py` (lines 1387-1409)
|
|
|
|
- Passes tokenizer to `sample_actions()` to enable subtask generation
|
|
|
|
### 5. Updated `forward()` (Training Method)
|
|
**File:** `src/lerobot/policies/pi05/modeling_pi05.py` (lines 1411-1445)
|
|
|
|
- Decodes and prints ground truth subtask tokens during training
|
|
- Helps monitor what the model is learning to predict
|
|
|
|
## How It Works
|
|
|
|
### During Inference:
|
|
|
|
```
|
|
1. Initialize with prefix: [images, high-level task, state]
|
|
2. Generate tokens autoregressively:
|
|
- Forward pass → get logits
|
|
- Select most likely token (greedy decoding)
|
|
- Embed token and append to prefix
|
|
- Repeat until EOS or max length
|
|
3. Decode generated tokens to text
|
|
4. Print: "[Inference] Generated subtask {i}: {text}"
|
|
5. Continue with action prediction (flow matching)
|
|
```
|
|
|
|
### During Training:
|
|
|
|
```
|
|
1. Extract ground truth subtask tokens from batch
|
|
2. Remove padding and decode to text
|
|
3. Print: "[Training] Ground truth subtask {i}: {text}"
|
|
4. Continue with normal training (subtask loss + flow loss)
|
|
```
|
|
|
|
## Example Output
|
|
|
|
### Training:
|
|
```
|
|
[Training] Ground truth subtask 0: pick up the red block
|
|
[Training] Ground truth subtask 1: move to the blue container
|
|
[Training] Ground truth subtask 2: place the object down
|
|
```
|
|
|
|
### Inference:
|
|
```
|
|
[Inference] Generated subtask 0: grasp the object
|
|
[Inference] Generated subtask 1: move to target location
|
|
[Inference] Generated subtask 2: release the gripper
|
|
```
|
|
|
|
## Benefits
|
|
|
|
1. ✓ **Transparency:** See what subtasks the model predicts
|
|
2. ✓ **Debugging:** Verify subtask prediction works correctly
|
|
3. ✓ **Interpretability:** Understand the model's reasoning
|
|
4. ✓ **Monitoring:** Track subtask quality during training
|
|
5. ✓ **Research:** Enables hierarchical reasoning analysis
|
|
|
|
## Files Modified
|
|
|
|
- `src/lerobot/policies/pi05/modeling_pi05.py` (main implementation)
|
|
|
|
## Files Created
|
|
|
|
- `examples/dataset/test_subtask_generation.py` (demo script)
|
|
- `SUBTASK_GENERATION_CHANGES.md` (detailed documentation)
|
|
- `SUBTASK_GENERATION_FLOW.md` (visual flow diagrams)
|
|
- `SUMMARY.md` (this file)
|
|
|
|
## Testing
|
|
|
|
To verify the implementation:
|
|
|
|
```bash
|
|
python examples/dataset/test_subtask_generation.py
|
|
```
|
|
|
|
This will check that the tokenizer loads correctly and explain the features.
|
|
|
|
## Next Steps
|
|
|
|
To see subtask generation in action:
|
|
|
|
1. **During Training:**
|
|
- Run your training script as usual
|
|
- Watch console for `[Training] Ground truth subtask` messages
|
|
|
|
2. **During Inference:**
|
|
- Run your inference script as usual
|
|
- Watch console for `[Inference] Generated subtask` messages
|
|
|
|
## Technical Details
|
|
|
|
- **Generation Method:** Autoregressive (one token at a time)
|
|
- **Decoding Strategy:** Greedy (always select most likely token)
|
|
- **Max Tokens:** 50 (configurable via `max_subtask_tokens` parameter)
|
|
- **Attention:** Causal masking for generated tokens
|
|
- **Tokenizer:** PaliGemma tokenizer (google/paligemma-3b-pt-224)
|
|
- **Performance:** Adds ~50 forward passes during inference (can be optimized with KV caching)
|
|
|
|
## Notes
|
|
|
|
- The implementation follows the same pattern as training (using LM head for prediction)
|
|
- Subtask generation happens before action prediction
|
|
- Generated subtasks are currently for visualization only (not used in action prediction)
|
|
- In future, could be used for hierarchical planning or multi-step reasoning
|
|
|
|
## Related Documentation
|
|
|
|
- See `SUBTASK_GENERATION_CHANGES.md` for detailed technical documentation
|
|
- See `SUBTASK_GENERATION_FLOW.md` for visual flow diagrams
|
|
- See training forward pass (lines 735-842) for reference implementation
|
|
|
|
|
|
|
|
|
|
|
|
|