mirror of
https://github.com/huggingface/lerobot.git
synced 2026-05-22 03:59:42 +00:00
more changes
This commit is contained in:
+152
@@ -0,0 +1,152 @@
|
||||
# Subtask Token Generation Implementation - Summary
|
||||
|
||||
## What Was Implemented
|
||||
|
||||
I've successfully added **autoregressive subtask token generation and decoding** to the PI05 model, enabling the model to:
|
||||
|
||||
1. **During Training:** Decode and print ground truth subtask tokens for monitoring
|
||||
2. **During Inference:** Generate subtask tokens using next token prediction and print them
|
||||
|
||||
## Key Changes
|
||||
|
||||
### 1. New Method: `_generate_subtask_tokens()`
|
||||
**File:** `src/lerobot/policies/pi05/modeling_pi05.py` (lines 844-914)
|
||||
|
||||
- Implements autoregressive token generation using greedy decoding
|
||||
- Uses the PaliGemma language model head for token prediction
|
||||
- Generates tokens one at a time, each conditioned on previous tokens
|
||||
- Stops when EOS token is generated or max length (50 tokens) is reached
|
||||
|
||||
### 2. Updated `sample_actions()` Method
|
||||
**File:** `src/lerobot/policies/pi05/modeling_pi05.py` (lines 916-1020)
|
||||
|
||||
- Added optional `tokenizer` and `max_subtask_tokens` parameters
|
||||
- Calls `_generate_subtask_tokens()` during inference if tokenizer is provided
|
||||
- Decodes and prints generated subtask tokens
|
||||
|
||||
### 3. Updated `PI05Policy.__init__()`
|
||||
**File:** `src/lerobot/policies/pi05/modeling_pi05.py` (lines 1066-1099)
|
||||
|
||||
- Loads PaliGemma tokenizer (`google/paligemma-3b-pt-224`) for decoding
|
||||
- Stores as `self.tokenizer` for use throughout the policy
|
||||
|
||||
### 4. Updated `predict_action_chunk()`
|
||||
**File:** `src/lerobot/policies/pi05/modeling_pi05.py` (lines 1387-1409)
|
||||
|
||||
- Passes tokenizer to `sample_actions()` to enable subtask generation
|
||||
|
||||
### 5. Updated `forward()` (Training Method)
|
||||
**File:** `src/lerobot/policies/pi05/modeling_pi05.py` (lines 1411-1445)
|
||||
|
||||
- Decodes and prints ground truth subtask tokens during training
|
||||
- Helps monitor what the model is learning to predict
|
||||
|
||||
## How It Works
|
||||
|
||||
### During Inference:
|
||||
|
||||
```
|
||||
1. Initialize with prefix: [images, high-level task, state]
|
||||
2. Generate tokens autoregressively:
|
||||
- Forward pass → get logits
|
||||
- Select most likely token (greedy decoding)
|
||||
- Embed token and append to prefix
|
||||
- Repeat until EOS or max length
|
||||
3. Decode generated tokens to text
|
||||
4. Print: "[Inference] Generated subtask {i}: {text}"
|
||||
5. Continue with action prediction (flow matching)
|
||||
```
|
||||
|
||||
### During Training:
|
||||
|
||||
```
|
||||
1. Extract ground truth subtask tokens from batch
|
||||
2. Remove padding and decode to text
|
||||
3. Print: "[Training] Ground truth subtask {i}: {text}"
|
||||
4. Continue with normal training (subtask loss + flow loss)
|
||||
```
|
||||
|
||||
## Example Output
|
||||
|
||||
### Training:
|
||||
```
|
||||
[Training] Ground truth subtask 0: pick up the red block
|
||||
[Training] Ground truth subtask 1: move to the blue container
|
||||
[Training] Ground truth subtask 2: place the object down
|
||||
```
|
||||
|
||||
### Inference:
|
||||
```
|
||||
[Inference] Generated subtask 0: grasp the object
|
||||
[Inference] Generated subtask 1: move to target location
|
||||
[Inference] Generated subtask 2: release the gripper
|
||||
```
|
||||
|
||||
## Benefits
|
||||
|
||||
1. ✓ **Transparency:** See what subtasks the model predicts
|
||||
2. ✓ **Debugging:** Verify subtask prediction works correctly
|
||||
3. ✓ **Interpretability:** Understand the model's reasoning
|
||||
4. ✓ **Monitoring:** Track subtask quality during training
|
||||
5. ✓ **Research:** Enables hierarchical reasoning analysis
|
||||
|
||||
## Files Modified
|
||||
|
||||
- `src/lerobot/policies/pi05/modeling_pi05.py` (main implementation)
|
||||
|
||||
## Files Created
|
||||
|
||||
- `examples/dataset/test_subtask_generation.py` (demo script)
|
||||
- `SUBTASK_GENERATION_CHANGES.md` (detailed documentation)
|
||||
- `SUBTASK_GENERATION_FLOW.md` (visual flow diagrams)
|
||||
- `SUMMARY.md` (this file)
|
||||
|
||||
## Testing
|
||||
|
||||
To verify the implementation:
|
||||
|
||||
```bash
|
||||
python examples/dataset/test_subtask_generation.py
|
||||
```
|
||||
|
||||
This will check that the tokenizer loads correctly and explain the features.
|
||||
|
||||
## Next Steps
|
||||
|
||||
To see subtask generation in action:
|
||||
|
||||
1. **During Training:**
|
||||
- Run your training script as usual
|
||||
- Watch console for `[Training] Ground truth subtask` messages
|
||||
|
||||
2. **During Inference:**
|
||||
- Run your inference script as usual
|
||||
- Watch console for `[Inference] Generated subtask` messages
|
||||
|
||||
## Technical Details
|
||||
|
||||
- **Generation Method:** Autoregressive (one token at a time)
|
||||
- **Decoding Strategy:** Greedy (always select most likely token)
|
||||
- **Max Tokens:** 50 (configurable via `max_subtask_tokens` parameter)
|
||||
- **Attention:** Causal masking for generated tokens
|
||||
- **Tokenizer:** PaliGemma tokenizer (google/paligemma-3b-pt-224)
|
||||
- **Performance:** Adds ~50 forward passes during inference (can be optimized with KV caching)
|
||||
|
||||
## Notes
|
||||
|
||||
- The implementation follows the same pattern as training (using LM head for prediction)
|
||||
- Subtask generation happens before action prediction
|
||||
- Generated subtasks are currently for visualization only (not used in action prediction)
|
||||
- In future, could be used for hierarchical planning or multi-step reasoning
|
||||
|
||||
## Related Documentation
|
||||
|
||||
- See `SUBTASK_GENERATION_CHANGES.md` for detailed technical documentation
|
||||
- See `SUBTASK_GENERATION_FLOW.md` for visual flow diagrams
|
||||
- See training forward pass (lines 735-842) for reference implementation
|
||||
|
||||
|
||||
|
||||
|
||||
|
||||
|
||||
Reference in New Issue
Block a user