docs(groot): remove optional Flash Attention setup instructions and update base model path for evaluation

This commit is contained in:
lbenhorin
2026-07-03 15:28:19 +03:00
parent 234ad0c9c7
commit 18a1342ecd
+4 -21
View File
@@ -43,25 +43,6 @@ For a source checkout:
pip install -e ".[groot]"
```
### Optional: Flash Attention acceleration
Flash Attention is a purely optional performance optimization. **LeRobot neither installs nor requires it**, and setting it up is up to the user as it has environment-specific build requirements (a matching PyTorch/CUDA toolchain). To enable it:
1. Install a `flash-attn` build matching your PyTorch/CUDA environment (see the [Flash Attention project](https://github.com/Dao-AILab/flash-attention)):
```bash
# Check https://pytorch.org/get-started/locally/ for the right CUDA wheel index for your system.
pip install "torch>=2.7,<2.12.0" "torchvision>=0.22.0,<0.27.0" \
--index-url https://download.pytorch.org/whl/cu128
pip install "ninja>=1.11.1,<2.0.0" "packaging>=24.2,<26.0"
pip install "flash-attn>=2.5.9,<3.0.0" --no-build-isolation
python -c "import flash_attn; print(f'Flash Attention {flash_attn.__version__} imported successfully')"
```
2. Install lerobot with the groot extra.
3. Opt in by passing `--policy.use_flash_attention=true` when training/evaluating GR00T. If the kernel is missing or fails to import, the backbone transparently falls back to SDPA.
## Usage
To use GR00T N1.7:
@@ -141,7 +122,7 @@ lerobot-train \
--dataset.revision=main \
--dataset.video_backend=pyav \
--policy.type=groot \
--policy.base_model_path=$BASE_MODEL \
--policy.base_model_path=nvidia/GR00T-N1.7-3B \
--policy.embodiment_tag=libero_sim \
--policy.push_to_hub=false \
--policy.max_steps=20000 \
@@ -178,9 +159,11 @@ Preliminary LeRobot integration results (GR00T-LeRobot, `eval.n_episodes >= 50`
| **Average** | **93.75%** |
```bash
export MODEL_ID=your_trained_model_on_huggingface
lerobot-eval \
--policy.type=groot \
--policy.base_model_path=$BASE_MODEL \
--policy.base_model_path=$MODEL_ID \
--policy.embodiment_tag=libero_sim \
--env.type=libero \
--env.task=libero_spatial \