Compare commits

..

31 Commits

Author SHA1 Message Date
lbenhorin 7fa2dfefa2 docs(groot): add note on inference.queue_threshold value for stable inference 2026-07-03 15:34:48 +03:00
lbenhorin 5bbdf02750 docs(groot): update training command with image transformation parameters 2026-07-03 15:31:18 +03:00
lbenhorin a28e93f009 Merge branch 'docs/groot-train-eval-commands' into nvidia-gr00t-n17-lerobot 2026-07-03 15:29:32 +03:00
lbenhorin 18a1342ecd docs(groot): remove optional Flash Attention setup instructions and update base model path for evaluation 2026-07-03 15:28:19 +03:00
Andy Wrenn 234ad0c9c7 Remove sample so101 training command 2026-07-03 05:24:49 -07:00
Andy Wrenn 7e2178e66b Add sample so101 training command 2026-07-03 05:20:49 -07:00
lbenhorin 0a1c2cb76c docs(groot): update training and rollout commands with new parameters and dependencies 2026-07-03 12:02:08 +03:00
nv-sachdevkartik a17da38b2a docs(groot): remove checkpoint download note above LIBERO eval 2026-07-02 14:21:30 +00:00
nv-sachdevkartik e99c65f38a docs(groot): restore suite checkpoint download intro sentence 2026-07-02 14:10:09 +00:00
nv-sachdevkartik e88a0d6aef docs(groot): drop hf download step from LIBERO eval, fix intro 2026-07-02 14:09:48 +00:00
nv-sachdevkartik 488650678b docs(groot): use $BASE_MODEL for base_model_path in LIBERO eval 2026-07-02 14:06:03 +00:00
nv-sachdevkartik 92beebcaa9 docs(groot): remove LIBERO checkpoints subdirectory section 2026-07-02 14:05:06 +00:00
nv-sachdevkartik 602d710a73 docs(groot): add LIBERO training command example 2026-07-02 14:02:02 +00:00
nv-sachdevkartik 8a60f06a75 docs(groot): use literal HF repo IDs for dataset/policy repo_id
Public-facing Hub references (--dataset.repo_id, --policy.repo_id) shown as
concrete IDs; local-only values ($OUTPUT_DIR, $JOB_NAME) stay as placeholders.
2026-07-02 13:01:04 +00:00
nv-sachdevkartik be0bfd90cd docs(groot): keep BASE_MODEL export in training command 2026-07-02 12:52:58 +00:00
nv-sachdevkartik a8ded211aa docs(groot): drop export block, reference env vars directly
Use $DATASET_ID / $BASE_MODEL / $REPO_ID / $OUTPUT_DIR / $JOB_NAME as
bare placeholders in the commands without concrete export assignments.
2026-07-02 12:52:38 +00:00
nv-sachdevkartik 62ff497ebc docs(groot): parameterize commands with env vars + fill LIBERO results
- Introduce BASE_MODEL / DATASET_ID / REPO_ID / JOB_NAME / OUTPUT_DIR env vars
  in the training command and reuse OUTPUT_DIR + BASE_MODEL in the rollout cmd.
- Fill the LIBERO benchmark table with GR00T-LeRobot success rates
  (Spatial 94%, Object 98%, Goal 93%, LIBERO 10/Long 90%; avg 93.75%),
  drop the OSS column and XX placeholders. LeRobot-focused.
2026-07-02 12:49:35 +00:00
nv-sachdevkartik 9ec4530248 docs(groot): update Training & hardware Evaluation commands
Replace the multi-GPU accelerate-launch Training snippet with the current
single-command 'uv run lerobot-train' N1.7 recipe (relative actions excluding
gripper, bf16, flash attention, chunk/n_action_steps=16, bs64/20k steps).

Replace the bimanual 'Evaluate in your hardware setup' rollout example with the
SO-101 follower RTC 'uv run lerobot-rollout' command (strategy.type=base,
inference.type=rtc, wrist+front cameras, place-the-vial task).

Docs-only; no source/test changes.
2026-07-02 12:41:13 +00:00
acwrenn53 1ce1c01337 Merge pull request #42 from johnnynunez/split/groot-n17-train-random-crop
feat(groot): train-time random crop for N1.7 (crop geometry only, zero new deps)
2026-07-02 03:18:43 -07:00
johnnynunez f53490c15e feat(groot): train-time random crop for N1.7 (eval keeps center crop)
Isaac-GR00T crops a random crop_fraction window during training and the
deterministic center window at eval, replaying the sampled window across all
camera views of a sample. This contract is unchanged since the N1.5 release
(gr00t/data/transform/video.py: "If mode is 'train', return a random crop
transform. If mode is 'eval', return a center crop transform.") and mirrors
LeRobot's own Diffusion/VQBeT crop_is_random pattern. The LeRobot N1.7 port
used the eval center crop for training too, so the fine-tuned projector/DiT
never sees frame borders and trains on a single fixed appearance point.

Scope: crop geometry ONLY - no color jitter, no new dependencies. The random
window is plain numpy slicing inside the existing cv2 eval transform:

- _transform_n1_7_image_for_vlm_albumentations gains crop_position=(y, x)
  fractions; None keeps the center crop byte-identical to before (verified
  by test)
- GrootN17VLMEncodeStep gains a runtime-only 'training' flag (never
  serialized; reloaded pipelines default to eval); training samples ONE
  window per sample and reuses it across (timestep, view) frames - Isaac's
  cross-view consistency
- gated on torch.is_grad_enabled() so no_grad validation and frozen-eval
  paths are unaffected
- wired via dataset_meta is not None in make_groot_pre_post_processors and
  the existing _set_groot_preprocessor_training on serialized reloads

Verification: tests/policies/groot/test_groot_train_random_crop.py (8 passed:
center-crop bit-exactness with crop_position=None, corner/center windows,
cross-view replay, train!=eval, no_grad gating, seed reproducibility,
serialization contract) + groot suite 23 passed / 5 skipped on RTX PRO 6000 /
CUDA 13.3.
2026-07-02 03:17:47 +02:00
acwrenn53 459d416bbf Merge pull request #41 from johnnynunez/split/groot-n17-state-dropout
feat(groot): activate checkpoint-configured N1.7 raw-state dropout during training
2026-07-01 16:16:48 -07:00
acwrenn53 bb6f5a2c7e Merge pull request #39 from johnnynunez/split/groot-n17-training-optim-contract
fix(groot): align N1.7 fine-tuning optimizer/scheduler/precision with Isaac-GR00T
2026-07-01 16:16:34 -07:00
johnnynunez f42cdcf137 fix(groot): align N1.7 fine-tuning optimizer/scheduler/precision with Isaac-GR00T
Evidence from the LeRobot-vs-OSS checkpoint comparison: the LeRobot/HF 8k
checkpoint's DiT moved only ~19% as far from base as the OSS-trained one
(0.0547 vs 0.285 relative L2) - undertrained because the scheduler decayed over
a hardcoded 10k steps regardless of --steps, on top of beta1/clip mismatches.

- AdamW betas (0.95, 0.999) -> (0.9, 0.999) and grad_clip_norm 10.0 -> 1.0
  (Isaac defaults)
- scheduler: hardcoded CosineDecayWithWarmup(10k decay, floor 10% peak) ->
  DiffuserSchedulerConfig HF cosine with ceil(max_steps * warmup_ratio) warmup,
  deriving num_training_steps from the outer --steps at runtime
- model_params_fp32 (default true): keep master weights in FP32 and compute
  under BF16 autocast like the native N1.7 recipe (fixes optimizer-update
  numerics vs pure-BF16 params)
- weight-decay grouping via transformers get_parameter_names: biases and norm
  parameters excluded from decay
- restore the TF4 lm_head/embedding weight tie so the unused Qwen LM head stays
  frozen and deduplicated in checkpoints
- action_mask kept in native dtype for the masked flow-matching loss
- drop_n_last_frames: exclude episode tails that cannot supply a complete
  action chunk (Isaac sampler behavior)

Verification: tests/policies/groot/test_groot_training_optim_contract.py
(7 passed) + remaining groot suite 11 passed/5 skipped on RTX PRO 6000 /
CUDA 13.3. Note: tests/policies/groot/test_groot_n1_7.py does not collect on
the base branch (pre-existing ImportError, fixed in PR #37).
2026-07-02 01:04:23 +02:00
johnnynunez 20c0f07858 feat(groot): activate checkpoint-configured N1.7 raw-state dropout during training
Isaac-GR00T applies dual state regularization during fine-tuning: raw-state
zeroing driven by the processor sidecar's state_dropout_prob (0.2 for the
inspected N1.7 checkpoint) plus encoded-feature dropout. Baseline LeRobot kept
the processor in deterministic mode, so the raw-state dropout never activated
(RCA Tier-2 contributor to the LeRobot-trained SO-101 failures).

- GrootN17PackInputsStep: runtime-only 'training' flag + state_dropout_prob;
  whole-sample state zeroing gated on torch.is_grad_enabled() so eval and
  no_grad validation paths are unaffected
- sidecar loader reads state_dropout_prob from processor_config.json
- state_dropout_prob serializes with the step; the training flag intentionally
  does not (reloaded pipelines default to eval, re-enabled only when processors
  are rebuilt with dataset_meta)
- _set_groot_preprocessor_training toggles any dataclass step exposing a
  'training' field on serialized-pipeline reloads

Verification: tests/policies/groot/test_groot_state_dropout.py (4 passed) on
RTX PRO 6000 / CUDA 13.3.
2026-07-02 00:54:20 +02:00
Andy Wrenn da9ce79678 fix(groot): make N1.7 letterbox opt-in 2026-06-30 15:46:56 -07:00
Steven Palma c74eb20abd fix(test): add guard 2026-06-30 15:46:56 -07:00
Steven Palma 22c1d0765a chore(policies): add explicit dataset dependecy to gr00t implementation 2026-06-30 15:46:56 -07:00
Steven Palma 86e60499d0 chore(groot): move cv2 to the top as its in the default install tag 2026-06-30 15:46:56 -07:00
Steven Palma 73c3a66d51 fix(ci): guard dependecy checks 2026-06-30 15:46:56 -07:00
Steven Palma b422269de4 fix(style): pre-commit 2026-06-30 15:46:56 -07:00
Steven Palma 44b6950f06 chore(policies): add guards, warnings and comments + recover tests n1.5 check 2026-06-30 15:46:56 -07:00
9 changed files with 833 additions and 156 deletions
+125 -80
View File
@@ -43,25 +43,6 @@ For a source checkout:
pip install -e ".[groot]"
```
### Optional: Flash Attention acceleration
Flash Attention is a purely optional performance optimization. **LeRobot neither installs nor requires it**, and setting it up is up to the user as it has environment-specific build requirements (a matching PyTorch/CUDA toolchain). To enable it:
1. Install a `flash-attn` build matching your PyTorch/CUDA environment (see the [Flash Attention project](https://github.com/Dao-AILab/flash-attention)):
```bash
# Check https://pytorch.org/get-started/locally/ for the right CUDA wheel index for your system.
pip install "torch>=2.7,<2.12.0" "torchvision>=0.22.0,<0.27.0" \
--index-url https://download.pytorch.org/whl/cu128
pip install "ninja>=1.11.1,<2.0.0" "packaging>=24.2,<26.0"
pip install "flash-attn>=2.5.9,<3.0.0" --no-build-isolation
python -c "import flash_attn; print(f'Flash Attention {flash_attn.__version__} imported successfully')"
```
2. Install lerobot with the groot extra.
3. Opt in by passing `--policy.use_flash_attention=true` when training/evaluating GR00T. If the kernel is missing or fails to import, the backbone transparently falls back to SDPA.
## Usage
To use GR00T N1.7:
@@ -76,26 +57,49 @@ To use GR00T N1.7:
Here's a complete training command for finetuning the base GR00T model on your own dataset:
This command is using the `new_embodiment` flag, which is used for the SO-101 robot, [read more about how GR00T handles different embodiments.](https://github.com/NVIDIA/Isaac-GR00T/blob/main/getting_started/policy.md#--embodiment-tag).
```bash
# Using a multi-GPU setup
accelerate launch \
--multi_gpu \
--num_processes=$NUM_GPUS \
$(which lerobot-train) \
--output_dir=$OUTPUT_DIR \
--save_checkpoint=true \
--batch_size=$BATCH_SIZE \
--steps=$NUM_STEPS \
--save_freq=$SAVE_FREQ \
--log_freq=$LOG_FREQ \
--policy.push_to_hub=true \
# install extra deps for training
pip install "lerobot[training]"
hf auth login
wandb login
export DATASET_NAME=your_data_set
export HF_USER=your_hf_username
export DATASET=$HF_USER/$DATASET_NAME
export REPO_ID="${DATASET}_GR00T17" #this is the model that will be uploaded to huggingface
export OUTPUT_DIR=outputs/train/$REPO_ID
lerobot-train \
--dataset.repo_id=$DATASET \
--dataset.image_transforms.enable=true \
--policy.type=groot \
--policy.device=cuda \
--policy.base_model_path=nvidia/GR00T-N1.7-3B \
--policy.embodiment_tag=new_embodiment \
--policy.chunk_size=16 \
--policy.n_action_steps=16 \
--policy.use_relative_actions=true \
--policy.relative_exclude_joints='["gripper"]' \
--policy.use_bf16=true \
--policy.push_to_hub=true \
--policy.repo_id=$REPO_ID \
--policy.tune_diffusion_model=false \
--dataset.repo_id=$DATASET_ID \
--seed=42 \
--batch_size=64 \
--steps=20000 \
--save_checkpoint=true \
--save_freq=5000 \
--use_policy_training_preset=true \
--env_eval_freq=0 \
--eval_steps=0 \
--log_freq=10 \
--output_dir=$OUTPUT_DIR \
--job_name=$DATASET \
--wandb.enable=true \
--wandb.disable_artifact=true \
--job_name=$JOB_NAME
--wandb.disable_artifact=true
```
## Performance Results
@@ -107,39 +111,66 @@ accelerate launch \
GR00T N1.7 has demonstrated strong performance on the LIBERO benchmark suite. To reproduce LeRobot results, follow the instructions in the [LIBERO](./libero) section.
### GR00T N1.7 LIBERO Checkpoints
### Train on LIBERO
NVIDIA publishes GR00T N1.7 LIBERO checkpoints at [`nvidia/GR00T-N1.7-LIBERO`](https://huggingface.co/nvidia/GR00T-N1.7-LIBERO), with one subdirectory per LIBERO suite:
| Suite | Checkpoint subdirectory |
| -------------- | ----------------------- |
| LIBERO Spatial | `libero_spatial` |
| LIBERO Object | `libero_object` |
| LIBERO Goal | `libero_goal` |
| LIBERO 10 | `libero_10` |
Preliminary LeRobot integration results:
| Suite | Status | Success rate | n_episodes |
| -------------- | ------ | -----------: | ---------: |
| LIBERO Spatial | ✓ | ~95% | XX |
| LIBERO Object | ✓ | XX% | XX |
| LIBERO Goal | ✓ | XX% | XX |
| LIBERO 10 | ✓ | XX% | XX |
| **Average** | ✓ | **XX%** | **XX** |
Replace the `XX` placeholders with final eval artifacts before merge.
Download the suite checkpoint locally, then point `--policy.base_model_path` at the downloaded subdirectory. `--policy.path` is reserved for LeRobot checkpoints that contain a LeRobot `config.json` with a `type` field.
Example training command for a LIBERO suite (here `libero_spatial`):
```bash
hf download nvidia/GR00T-N1.7-LIBERO \
--include "libero_spatial/*" \
--local-dir ./GR00T-N1.7-LIBERO
IMAGE_TRANSFORMS='{
"brightness": {"weight": 1.0, "type": "ColorJitter", "kwargs": {"brightness": [0.7, 1.3]}},
"contrast": {"weight": 1.0, "type": "ColorJitter", "kwargs": {"contrast": [0.6, 1.4]}},
"saturation": {"weight": 1.0, "type": "ColorJitter", "kwargs": {"saturation": [0.5, 1.5]}},
"hue": {"weight": 1.0, "type": "ColorJitter", "kwargs": {"hue": [-0.08, 0.08]}}
}'
lerobot-train \
--dataset.repo_id=IPEC-COMMUNITY/libero_spatial_no_noops_1.0.0_lerobot \
--dataset.root=/datasets/libero_spatial \
--dataset.revision=main \
--dataset.video_backend=pyav \
--dataset.image_transforms.enable=true \
--dataset.image_transforms.max_num_transforms=4 \
--dataset.image_transforms.tfs="$IMAGE_TRANSFORMS" \
--policy.type=groot \
--policy.base_model_path=nvidia/GR00T-N1.7-3B \
--policy.embodiment_tag=libero_sim \
--policy.push_to_hub=false \
--policy.max_steps=20000 \
--batch_size=320 \
--steps=20000 \
--save_freq=2000 \
--env_eval_freq=0 \
--eval_steps=0 \
--log_freq=10 \
--wandb.enable=true \
--wandb.project=lerobot \
--wandb.mode=online \
--wandb.disable_artifact=true \
--num_workers=4 \
--prefetch_factor=2 \
--persistent_workers=true \
--output_dir=$OUTPUT_DIR \
--job_name=$JOB_NAME
```
### GR00T N1.7 LIBERO Results
Preliminary LeRobot integration results (GR00T-LeRobot, `eval.n_episodes >= 50` per suite):
| Suite | Success rate |
| ---------------------- | -----------: |
| LIBERO Spatial | 94% |
| LIBERO Object | 98% |
| LIBERO Goal | 93% |
| LIBERO 10 (Long) | 90% |
| **Average** | **93.75%** |
```bash
export MODEL_ID=your_trained_model_on_huggingface
lerobot-eval \
--policy.type=groot \
--policy.base_model_path=./GR00T-N1.7-LIBERO/libero_spatial \
--policy.base_model_path=$MODEL_ID \
--policy.embodiment_tag=libero_sim \
--env.type=libero \
--env.task=libero_spatial \
@@ -153,27 +184,41 @@ Use `eval.n_episodes >= 50` per suite when reporting success rates.
Once you have trained your model using your parameters you can run inference in your downstream task. Follow the instructions in [Policy Deployment (lerobot-rollout)](./inference). For example:
```bash
lerobot-rollout\
--strategy.type=sentry \
--strategy.upload_every_n_episodes=5 \
--robot.type=bi_so_follower \
--robot.left_arm_port=/dev/ttyACM1 \
--robot.right_arm_port=/dev/ttyACM0 \
--robot.id=bimanual_follower \
--robot.cameras='{ right: {"type": "opencv", "index_or_path": 0, "width": 640, "height": 480, "fps": 30},
left: {"type": "opencv", "index_or_path": 2, "width": 640, "height": 480, "fps": 30},
top: {"type": "opencv", "index_or_path": 4, "width": 640, "height": 480, "fps": 30},
}' \
# install extra deps for roullout and real hardware
pip install "lerobot[feetech,viz]"
export MODEL_ID=your_trained_model_on_huggingface
# make sure that camera index matches your setup!
# find index using `uv run lerobot-find-cameras opencv`
WRIST_CAM='wrist: {type: opencv, index_or_path: 2, width: 640, height: 480, fps: 30, fourcc: "MJPG"}'
FRONT_CAM='front: {type: opencv, index_or_path: 0, width: 640, height: 480, fps: 30, fourcc: "MJPG"}'
export ROBOT_CAMERAS="{ $WRIST_CAM, $FRONT_CAM }"
export ROBOT_ID=follower_robot
export ROBOT_PORT=/dev/ttyACM0
uv run lerobot-rollout \
--strategy.type=base \
--policy.path=$MODEL_ID \
--policy.base_model_path=nvidia/GR00T-N1.7-3B \
--policy.n_action_steps=8 \
--robot.type=so101_follower \
--robot.port=$ROBOT_PORT \
--robot.id=$ROBOT_ID \
--robot.cameras="$ROBOT_CAMERAS" \
--task="place the vial in the rack" \
--duration=60 \
--device=cuda \
--display_data=true \
--dataset.repo_id=<user>/eval_groot-bimanual \
--dataset.single_task="Grab and handover the red cube to the other arm" \
--dataset.streaming_encoding=true \
--dataset.encoder_threads=2 \
# --dataset.rgb_encoder.vcodec=auto \
--policy.path=<user>/groot-bimanual \ # your trained model
--duration=600
--inference.type=rtc \
--inference.rtc.enabled=false \
--inference.rtc.execution_horizon=8 \
--inference.queue_threshold=0
```
> [!NOTE]
> Value of `inference.queue_threshold` should not exeed 0.5 to ensure stable inference.
## License
GR00T N1.7 is released under the [NVIDIA Open Model License Agreement](https://www.nvidia.com/en-us/agreements/enterprise-software/nvidia-open-model-license/).
@@ -15,11 +15,12 @@
# limitations under the License.
import logging
import math
from dataclasses import dataclass, field
from pathlib import Path
from lerobot.configs import FeatureType, NormalizationMode, PolicyFeature, PreTrainedConfig
from lerobot.optim import AdamWConfig, CosineDecayWithWarmupSchedulerConfig
from lerobot.optim import AdamWConfig, DiffuserSchedulerConfig
from lerobot.utils.constants import ACTION, OBS_STATE
from .utils import read_json
@@ -336,11 +337,14 @@ class GrootConfig(PreTrainedConfig):
# Training parameters
optimizer_lr: float = 1e-4
optimizer_betas: tuple[float, float] = (0.95, 0.999)
# Isaac-GR00T N1.7 fine-tunes with AdamW betas (0.9, 0.999).
optimizer_betas: tuple[float, float] = (0.9, 0.999)
optimizer_eps: float = 1e-8
optimizer_weight_decay: float = 1e-5
warmup_ratio: float = 0.05
use_bf16: bool = True
# The native N1.7 fine-tuning recipe keeps model parameters in FP32 and computes under BF16 autocast.
model_params_fp32: bool = True
# TODO(Steven): Remove these deprecated fields in a future release.
# Deprecated Isaac-GR00T runner / GR00T N1.5 fields, plus the (never-wired) LoRA fields — all
@@ -480,15 +484,20 @@ class GrootConfig(PreTrainedConfig):
betas=self.optimizer_betas,
eps=self.optimizer_eps,
weight_decay=self.optimizer_weight_decay,
grad_clip_norm=1.0,
)
def get_scheduler_preset(self) -> CosineDecayWithWarmupSchedulerConfig:
"""Return scheduler configuration."""
return CosineDecayWithWarmupSchedulerConfig(
num_warmup_steps=int(10000 * self.warmup_ratio), # 5% warmup by default
num_decay_steps=10000, # Adjust based on training steps
peak_lr=self.optimizer_lr,
decay_lr=self.optimizer_lr * 0.1,
def get_scheduler_preset(self) -> DiffuserSchedulerConfig:
"""Return scheduler configuration.
Isaac-GR00T uses the HF Trainer cosine schedule with ~5% warmup over the
actual training update count; DiffuserSchedulerConfig wraps the same
diffusers/transformers `get_scheduler("cosine")` implementation and
derives num_training_steps from the outer --steps value at runtime.
"""
return DiffuserSchedulerConfig(
name="cosine",
num_warmup_steps=math.ceil(self.max_steps * self.warmup_ratio),
)
@property
@@ -504,6 +513,11 @@ class GrootConfig(PreTrainedConfig):
)
return list(range(min(self.chunk_size, model_action_horizon)))
@property
def drop_n_last_frames(self) -> int:
"""Exclude episode tails that cannot supply a complete N1.7 action chunk."""
return max(0, len(self.action_delta_indices) - 1)
@property
def reward_delta_indices(self) -> None:
"""Return indices for delta rewards (None for Groot)."""
+15 -1
View File
@@ -60,6 +60,19 @@ except ImportError:
logger = logging.getLogger(__name__)
def _tie_unused_qwen_lm_head(model: nn.Module) -> None:
"""Restore the TF4 weight tie so the unused LM head stays frozen and is omitted on save."""
lm_head = getattr(model, "lm_head", None)
get_input_embeddings = getattr(model, "get_input_embeddings", None)
if lm_head is None or not callable(get_input_embeddings):
return
input_embeddings = get_input_embeddings()
embedding_weight = getattr(input_embeddings, "weight", None)
if embedding_weight is None:
return
lm_head.weight = embedding_weight
GR00T_N1_7_DEFAULTS: dict[str, Any] = {
"model_dtype": "bfloat16",
"dtype": "bfloat16",
@@ -288,6 +301,7 @@ class Qwen3Backbone(nn.Module):
config_kwargs=transformers_loading_kwargs,
).eval()
_tie_unused_qwen_lm_head(self.model)
while len(self.language_model.layers) > select_layer:
self.language_model.layers.pop(-1)
@@ -603,7 +617,7 @@ class GR00TN17ActionHead(nn.Module):
pred = self.action_decoder(model_output, embodiment_id)
pred_actions = pred[:, -actions.shape[1] :]
action_mask = action_input.action_mask.to(dtype=pred_actions.dtype)
action_mask = action_input.action_mask
action_loss = F.mse_loss(pred_actions, velocity, reduction="none") * action_mask
loss = action_loss.sum() / (action_mask.sum() + 1e-6)
return BatchFeature(
+44 -4
View File
@@ -34,6 +34,7 @@ from huggingface_hub import hf_hub_download
from huggingface_hub.constants import SAFETENSORS_SINGLE_FILE
from huggingface_hub.errors import HfHubHTTPError
from torch import Tensor
from transformers.trainer_pt_utils import get_parameter_names
from lerobot.configs import FeatureType, PolicyFeature
from lerobot.utils.constants import ACTION, OBS_IMAGES
@@ -50,7 +51,7 @@ from .configuration_groot import (
infer_groot_n1_7_action_execution_horizon,
infer_groot_n1_7_action_horizon,
)
from .groot_n1_7 import GR00TN17
from .groot_n1_7 import GR00TN17, _tie_unused_qwen_lm_head
logger = logging.getLogger(__name__)
@@ -96,11 +97,49 @@ class GrootPolicy(PreTrainedPolicy):
if self.config.rtc_ramp_rate is not None:
model_kwargs["rtc_ramp_rate"] = self.config.rtc_ramp_rate
return GR00TN17.from_pretrained(
model = GR00TN17.from_pretrained(
**model_kwargs,
tune_vlln=self.config.tune_vlln,
transformers_loading_kwargs={"trust_remote_code": True},
)
backbone = getattr(model, "backbone", None)
qwen_model = getattr(backbone, "model", None)
if qwen_model is not None:
_tie_unused_qwen_lm_head(qwen_model)
if self.config.model_params_fp32:
self._cast_model_parameters_to_fp32(model)
return model
@staticmethod
def _cast_model_parameters_to_fp32(model: torch.nn.Module) -> None:
for parameter in model.parameters():
if parameter.is_floating_point():
parameter.data = parameter.data.to(torch.float32)
@staticmethod
def _build_weight_decay_parameter_groups(model: torch.nn.Module) -> list[dict[str, object]]:
forbidden_name_patterns = [
r"bias",
r"layernorm",
r"rmsnorm",
r"(?:^|\.)norm(?:$|\.)",
r"_norm(?:$|\.)",
]
decay_names = set(get_parameter_names(model, [torch.nn.LayerNorm], forbidden_name_patterns))
decay_params = [
parameter
for name, parameter in model.named_parameters()
if parameter.requires_grad and name in decay_names
]
no_decay_params = [
parameter
for name, parameter in model.named_parameters()
if parameter.requires_grad and name not in decay_names
]
return [
{"params": decay_params},
{"params": no_decay_params, "weight_decay": 0.0},
]
def reset(self):
"""Reset policy state when environment resets."""
@@ -238,8 +277,9 @@ class GrootPolicy(PreTrainedPolicy):
policy.eval()
return policy
def get_optim_params(self) -> dict:
return self.parameters()
def get_optim_params(self): # type: ignore[override]
"""Isaac-GR00T excludes biases and normalization parameters from weight decay."""
return self._build_weight_decay_parameter_groups(self)
def _resolve_action_queue_steps(self) -> int:
n_action_steps = int(self.config.n_action_steps)
+116 -26
View File
@@ -15,6 +15,7 @@
# limitations under the License.
import logging
import random
from copy import copy, deepcopy
from dataclasses import dataclass, field, fields, is_dataclass
from pathlib import Path
@@ -136,6 +137,7 @@ class _GrootN17CheckpointProcessorAssets:
video_horizon: int | None
use_percentiles: bool
use_relative_action: bool
state_dropout_prob: float
clip_outliers: bool
video_modality_keys: list[str] | None
image_crop_size: list[int] | None
@@ -143,6 +145,7 @@ class _GrootN17CheckpointProcessorAssets:
shortest_image_edge: int | None
crop_fraction: float | None
use_albumentations: bool
letter_box_transform: bool
@dataclass(frozen=True)
@@ -181,6 +184,9 @@ def _load_n1_7_checkpoint_processor_assets(config: GrootConfig) -> _GrootN17Chec
modality_config = {}
use_relative_action = bool(processor_kwargs.get("use_relative_action", False))
state_dropout_prob = as_optional_float(processor_kwargs.get("state_dropout_prob"))
if state_dropout_prob is None:
state_dropout_prob = 0.0
stats = _load_n1_7_checkpoint_stats(
checkpoint_path,
processor_kwargs,
@@ -199,6 +205,9 @@ def _load_n1_7_checkpoint_processor_assets(config: GrootConfig) -> _GrootN17Chec
use_albumentations = processor_kwargs.get("use_albumentations", False)
if not isinstance(use_albumentations, bool):
use_albumentations = False
letter_box_transform = processor_kwargs.get("letter_box_transform", False)
if not isinstance(letter_box_transform, bool):
letter_box_transform = False
valid_action_horizon = _load_n1_7_checkpoint_action_horizon(processor_kwargs, config.embodiment_tag)
video_horizon = _load_n1_7_checkpoint_video_horizon(processor_kwargs, config.embodiment_tag)
@@ -218,6 +227,7 @@ def _load_n1_7_checkpoint_processor_assets(config: GrootConfig) -> _GrootN17Chec
video_horizon=video_horizon,
use_percentiles=bool(processor_kwargs.get("use_percentiles", False)),
use_relative_action=use_relative_action,
state_dropout_prob=state_dropout_prob,
clip_outliers=clip_outliers,
video_modality_keys=video_modality_keys,
image_crop_size=as_int_pair(processor_kwargs.get("image_crop_size")),
@@ -225,6 +235,7 @@ def _load_n1_7_checkpoint_processor_assets(config: GrootConfig) -> _GrootN17Chec
shortest_image_edge=as_optional_int(processor_kwargs.get("shortest_image_edge")),
crop_fraction=as_optional_float(processor_kwargs.get("crop_fraction")),
use_albumentations=use_albumentations,
letter_box_transform=letter_box_transform,
)
@@ -440,6 +451,22 @@ def _apply_groot_step_overrides(
post_init()
def _set_groot_preprocessor_training(
preprocessor: PolicyProcessorPipeline,
*,
training: bool,
) -> None:
"""Set the runtime-only mode of GR00T stochastic processor steps.
Any dataclass step exposing a ``training`` field participates, so processor
steps can opt into train-time-only behavior (dropout, augmentation) without
this helper enumerating them.
"""
for step in preprocessor.steps:
if is_dataclass(step) and any(f.name == "training" for f in fields(step)):
setattr(step, "training", training)
def make_groot_pre_post_processors_from_pretrained(
config: GrootConfig,
pretrained_path: str,
@@ -488,6 +515,7 @@ def make_groot_pre_post_processors_from_pretrained(
_reconnect_groot_relative_absolute_steps(preprocessor, postprocessor)
_reconnect_groot_n1_7_pack_decode_steps(preprocessor, postprocessor)
_apply_groot_action_decode_transform(postprocessor, config.action_decode_transform)
_set_groot_preprocessor_training(preprocessor, training=dataset_meta is not None)
return preprocessor, postprocessor
@@ -1002,7 +1030,6 @@ def _build_n1_7_relative_action_processor_assets(
}
for group in groups
]
# 40 matches the action horizon of the only N1.7 base model (nvidia/GR00T-N1.7-3B)
action_horizon = min(config.chunk_size, 40)
modality_config: dict[str, Any] = {
"state": {"modality_keys": [group.key for group in groups]},
@@ -1052,6 +1079,7 @@ def _build_n1_7_relative_action_processor_assets(
video_horizon=base_assets.video_horizon if base_assets is not None else None,
use_percentiles=use_percentiles,
use_relative_action=True,
state_dropout_prob=base_assets.state_dropout_prob if base_assets is not None else 0.0,
clip_outliers=base_assets.clip_outliers if base_assets is not None else True,
video_modality_keys=video_modality_keys,
image_crop_size=base_assets.image_crop_size if base_assets is not None else None,
@@ -1059,6 +1087,7 @@ def _build_n1_7_relative_action_processor_assets(
shortest_image_edge=base_assets.shortest_image_edge if base_assets is not None else None,
crop_fraction=base_assets.crop_fraction if base_assets is not None else None,
use_albumentations=base_assets.use_albumentations if base_assets is not None else False,
letter_box_transform=base_assets.letter_box_transform if base_assets is not None else False,
)
@@ -1155,6 +1184,8 @@ def make_groot_pre_post_processors(
embodiment_tag=config.embodiment_tag,
embodiment_mapping=embodiment_mapping,
normalize_min_max=True,
training=dataset_meta is not None,
state_dropout_prob=(checkpoint_assets.state_dropout_prob if checkpoint_assets is not None else 0.0),
stats=padded_stats,
clip_outliers=clip_outliers,
video_modality_keys=video_modality_keys,
@@ -1180,6 +1211,7 @@ def make_groot_pre_post_processors(
shortest_image_edge = None
crop_fraction = None
use_albumentations = checkpoint_assets.use_albumentations if checkpoint_assets is not None else False
letter_box_transform = checkpoint_assets.letter_box_transform if checkpoint_assets is not None else False
input_steps: list[ProcessorStep] = [
RenameObservationsProcessorStep(rename_map={}),
@@ -1192,6 +1224,8 @@ def make_groot_pre_post_processors(
shortest_image_edge=shortest_image_edge,
crop_fraction=crop_fraction,
use_albumentations=use_albumentations,
letter_box_transform=letter_box_transform,
training=dataset_meta is not None,
device=config.device,
),
DeviceProcessorStep(device=config.device),
@@ -1316,6 +1350,8 @@ def _transform_n1_7_image_for_vlm_albumentations(
image_target_size: list[int] | None,
shortest_image_edge: int | None,
crop_fraction: float | None,
letter_box_transform: bool = False,
crop_position: tuple[float, float] | None = None,
) -> np.ndarray:
"""cv2/INTER_AREA eval transform mirroring Isaac-GR00T's albumentations preprocessing.
@@ -1325,6 +1361,12 @@ def _transform_n1_7_image_for_vlm_albumentations(
cv2/INTER_AREA resize and floored center-crop here intentionally differ from that
torch path and must stay bit-exact to the upstream reference. The hot path accepts
and returns numpy arrays to avoid per-frame PIL round-trips.
``crop_position`` selects where the ``crop_fraction`` window sits: ``None``
keeps the deterministic center crop (eval contract), while ``(y, x)``
fractions in [0, 1] place the window for Isaac's train-time random crop
(0.5, 0.5 == center). Training samples one position per sample and reuses
it across camera views.
"""
if image_target_size is None:
return image
@@ -1340,6 +1382,18 @@ def _transform_n1_7_image_for_vlm_albumentations(
if not image_np.flags.c_contiguous:
image_np = np.ascontiguousarray(image_np)
if letter_box_transform:
height, width = image_np.shape[:2]
if height != width:
square_edge = max(height, width)
pad_h = square_edge - height
pad_w = square_edge - width
top = pad_h // 2
bottom = pad_h - top
left = pad_w // 2
right = pad_w - left
image_np = cv2.copyMakeBorder(image_np, top, bottom, left, right, cv2.BORDER_CONSTANT, value=0)
resize_edge = shortest_image_edge or target_h
def resize_shortest_edge(frame: np.ndarray) -> np.ndarray:
@@ -1364,8 +1418,13 @@ def _transform_n1_7_image_for_vlm_albumentations(
height, width = image_np.shape[:2]
crop_h = max(1, int(height * crop_fraction))
crop_w = max(1, int(width * crop_fraction))
top = max(0, (height - crop_h) // 2)
left = max(0, (width - crop_w) // 2)
if crop_position is None:
top = max(0, (height - crop_h) // 2)
left = max(0, (width - crop_w) // 2)
else:
pos_y, pos_x = crop_position
top = int(round((height - crop_h) * min(max(pos_y, 0.0), 1.0)))
left = int(round((width - crop_w) * min(max(pos_x, 0.0), 1.0)))
image_np = image_np[top : top + crop_h, left : left + crop_w]
return resize_shortest_edge(image_np)
@@ -1378,9 +1437,12 @@ def _transform_n1_7_image_for_vlm_torch(
image_target_size: list[int] | None,
shortest_image_edge: int | None,
crop_fraction: float | None,
letter_box_transform: bool = False,
) -> torch.Tensor:
"""Default (non-albumentations) N1.7 image transform: pad-to-square, resize to
``shortest_image_edge``, center-crop by ``crop_fraction``, resize to ``image_target_size``.
"""Default (non-albumentations) N1.7 image transform.
Optionally pads to square, then resizes to ``shortest_image_edge``, center-crops
by ``crop_fraction``, and resizes to ``image_target_size``.
Operates on a ``(C, H, W)`` uint8 tensor and keeps the result on the input
tensor's device so the resize/crop run on GPU when the tensor is. Bicubic
@@ -1395,13 +1457,14 @@ def _transform_n1_7_image_for_vlm_torch(
target_h, target_w = image_target_size
_, height, width = image.shape
square_edge = max(height, width)
if height != width:
left = (square_edge - width) // 2
top = (square_edge - height) // 2
image = tv_functional.pad(
image, [left, top, square_edge - width - left, square_edge - height - top], fill=0
)
if letter_box_transform:
square_edge = max(height, width)
if height != width:
left = (square_edge - width) // 2
top = (square_edge - height) // 2
image = tv_functional.pad(
image, [left, top, square_edge - width - left, square_edge - height - top], fill=0
)
resize_edge = shortest_image_edge or target_h
image = tv_functional.resize(
@@ -1448,6 +1511,8 @@ class GrootN17PackInputsStep(ProcessorStep):
embodiment_tag: str = "new_embodiment"
embodiment_mapping: dict[str, int] = field(default_factory=lambda: dict(N1_7_EMBODIMENT_MAPPING))
normalize_min_max: bool = True
training: bool = False
state_dropout_prob: float = 0.0
stats: dict[str, dict[str, Any]] | None = None
clip_outliers: bool = True
use_percentiles: bool = False
@@ -1779,6 +1844,13 @@ class GrootN17PackInputsStep(ProcessorStep):
if dim < self.max_state_dim:
pad = torch.zeros(bsz, 1, self.max_state_dim - dim, dtype=state.dtype, device=state.device)
state = torch.cat([state, pad], dim=2)
if self.training and torch.is_grad_enabled() and self.state_dropout_prob > 0:
drop_state = torch.tensor(
[random.random() < self.state_dropout_prob for _ in range(bsz)],
dtype=torch.bool,
device=state.device,
).view(bsz, 1, 1)
state = state.masked_fill(drop_state, 0)
obs["state"] = state
action = transition.get(TransitionKey.ACTION)
@@ -1890,6 +1962,7 @@ class GrootN17PackInputsStep(ProcessorStep):
"embodiment_tag": self.embodiment_tag,
"embodiment_mapping": self.embodiment_mapping,
"normalize_min_max": self.normalize_min_max,
"state_dropout_prob": self.state_dropout_prob,
"clip_outliers": self.clip_outliers,
"use_percentiles": self.use_percentiles,
"video_modality_keys": self.video_modality_keys,
@@ -1946,6 +2019,12 @@ class GrootN17VLMEncodeStep(ProcessorStep):
shortest_image_edge: int | None = None
crop_fraction: float | None = None
use_albumentations: bool = False
letter_box_transform: bool = False
# Runtime-only train/eval mode: True enables Isaac's train-time random crop
# (one window per sample, replayed across views); False keeps the
# deterministic center crop. Never serialized - reloaded pipelines default
# to eval and are re-enabled only when processors are built with dataset_meta.
training: bool = False
device: str | None = None
_proc: ProcessorMixin | None = field(default=None, init=False, repr=False)
@@ -1979,20 +2058,29 @@ class GrootN17VLMEncodeStep(ProcessorStep):
"""
if self.use_albumentations:
video_np = np.asarray(video)
return [
[
_transform_n1_7_image_for_vlm_albumentations(
video_np[batch_idx, timestep, view_idx],
image_crop_size=self.image_crop_size,
image_target_size=self.image_target_size,
shortest_image_edge=self.shortest_image_edge,
crop_fraction=self.crop_fraction,
)
for timestep in range(video_np.shape[1])
for view_idx in range(video_np.shape[2])
]
for batch_idx in range(batch_size)
]
train_crop = self.training and torch.is_grad_enabled()
sample_images: list[list[Any]] = []
for batch_idx in range(batch_size):
# Isaac-GR00T samples ONE crop window per sample and replays it
# across every (timestep, view) frame of that sample, keeping
# cross-view geometry consistent. Eval keeps the center crop.
crop_position = (random.random(), random.random()) if train_crop else None
sample_images.append(
[
_transform_n1_7_image_for_vlm_albumentations(
video_np[batch_idx, timestep, view_idx],
image_crop_size=self.image_crop_size,
image_target_size=self.image_target_size,
shortest_image_edge=self.shortest_image_edge,
crop_fraction=self.crop_fraction,
letter_box_transform=self.letter_box_transform,
crop_position=crop_position,
)
for timestep in range(video_np.shape[1])
for view_idx in range(video_np.shape[2])
]
)
return sample_images
video_t = video if torch.is_tensor(video) else torch.from_numpy(np.ascontiguousarray(video))
# (B, T, V, H, W, C) uint8 -> (B, T, V, C, H, W)
@@ -2011,6 +2099,7 @@ class GrootN17VLMEncodeStep(ProcessorStep):
image_target_size=self.image_target_size,
shortest_image_edge=self.shortest_image_edge,
crop_fraction=self.crop_fraction,
letter_box_transform=self.letter_box_transform,
)
for timestep in range(sample.shape[0])
for view_idx in range(sample.shape[1])
@@ -2084,6 +2173,7 @@ class GrootN17VLMEncodeStep(ProcessorStep):
"shortest_image_edge": self.shortest_image_edge,
"crop_fraction": self.crop_fraction,
"use_albumentations": self.use_albumentations,
"letter_box_transform": self.letter_box_transform,
"device": self.device,
}
+133 -36
View File
@@ -43,8 +43,10 @@ from lerobot.policies.groot.processor_groot import (
GrootN17ActionDecodeStep,
GrootN17PackInputsStep,
GrootN17VLMEncodeStep,
N1_7_NATIVE_ACTION_HORIZON,
_make_relative_action_training_stats,
_transform_n1_7_image_for_vlm_albumentations,
_transform_n1_7_image_for_vlm_torch,
make_groot_pre_post_processors,
)
from lerobot.processor import (
@@ -80,6 +82,14 @@ def _groot_config() -> GrootConfig:
)
def _native_action_chunk(rows: list[list[float]]) -> torch.Tensor:
chunk = torch.tensor(rows, dtype=torch.float32)
if chunk.shape[0] >= N1_7_NATIVE_ACTION_HORIZON:
return chunk[:N1_7_NATIVE_ACTION_HORIZON]
tail = chunk[-1:].repeat(N1_7_NATIVE_ACTION_HORIZON - chunk.shape[0], 1)
return torch.cat([chunk, tail], dim=0)
def _raw_n1_7_libero_config(model_path) -> GrootConfig:
input_features, output_features = _groot_features(state_dim=8, action_dim=7)
return GrootConfig(
@@ -236,6 +246,7 @@ def _write_raw_n1_7_libero_checkpoint(path):
"shortest_image_edge": 256,
"crop_fraction": 0.95,
"use_albumentations": True,
"letter_box_transform": False,
"max_action_horizon": 40,
"max_state_dim": 132,
"max_action_dim": 132,
@@ -600,6 +611,7 @@ def test_raw_n1_7_libero_checkpoint_processors_use_checkpoint_assets(tmp_path):
assert vlm_encode.shortest_image_edge == 256
assert vlm_encode.crop_fraction == 0.95
assert vlm_encode.use_albumentations is True
assert vlm_encode.letter_box_transform is False
assert decode_actions.raw_stats["action"]["gripper"]["q99"] == [115.0]
assert decode_actions.env_action_dim == 7
assert decode_actions.use_percentiles is True
@@ -673,6 +685,7 @@ def test_groot_n1_7_saved_processors_round_trip_checkpoint_specific_fields(tmp_p
config_filename="policy_postprocessor.json",
)
pack_inputs = next(step for step in loaded_preprocessor.steps if isinstance(step, GrootN17PackInputsStep))
vlm_encode = next(step for step in loaded_preprocessor.steps if isinstance(step, GrootN17VLMEncodeStep))
decode_actions = next(
step for step in loaded_postprocessor.steps if isinstance(step, GrootN17ActionDecodeStep)
)
@@ -681,6 +694,7 @@ def test_groot_n1_7_saved_processors_round_trip_checkpoint_specific_fields(tmp_p
assert pack_inputs.action_horizon == 40
assert pack_inputs.video_modality_keys == ["image", "wrist_image"]
assert pack_inputs.clip_outliers is True
assert vlm_encode.letter_box_transform is False
torch.testing.assert_close(
pack_inputs.stats[OBS_STATE]["min"],
torch.tensor([1.0, 2.0, 3.0, 4.0, 5.0, 6.0, 7.0, 8.0]),
@@ -1849,6 +1863,58 @@ def test_groot_n1_7_vlm_image_transform_matches_albumentations_eval_path():
np.testing.assert_array_equal(np.asarray(transformed), expected)
def test_groot_n1_7_albumentations_letterbox_is_opt_in():
pytest.importorskip("cv2", exc_type=ImportError)
image = np.full((3, 5, 3), 255, dtype=np.uint8)
default = _transform_n1_7_image_for_vlm_albumentations(
image,
image_crop_size=None,
image_target_size=[10, 10],
shortest_image_edge=10,
crop_fraction=None,
)
letterboxed = _transform_n1_7_image_for_vlm_albumentations(
image,
image_crop_size=None,
image_target_size=[10, 10],
shortest_image_edge=10,
crop_fraction=None,
letter_box_transform=True,
)
assert default.shape == (10, 17, 3)
assert default.min() == 255
assert letterboxed.shape == (10, 10, 3)
assert letterboxed.min() < 255
def test_groot_n1_7_torch_letterbox_is_opt_in():
image = torch.full((3, 3, 5), 255, dtype=torch.uint8)
default = _transform_n1_7_image_for_vlm_torch(
image,
image_crop_size=None,
image_target_size=[10, 10],
shortest_image_edge=10,
crop_fraction=None,
)
letterboxed = _transform_n1_7_image_for_vlm_torch(
image,
image_crop_size=None,
image_target_size=[10, 10],
shortest_image_edge=10,
crop_fraction=None,
letter_box_transform=True,
)
assert tuple(default.shape) == (3, 10, 10)
assert int(default.min()) == 255
assert tuple(letterboxed.shape) == (3, 10, 10)
assert int(letterboxed.min()) < 255
def test_groot_n1_7_vlm_encode_transforms_non_square_two_camera_sample_like_core_albumentations():
cv2 = pytest.importorskip("cv2", exc_type=ImportError)
@@ -1919,6 +1985,7 @@ def test_groot_n1_7_vlm_encode_config_round_trips_model_name():
shortest_image_edge=256,
crop_fraction=0.95,
use_albumentations=True,
letter_box_transform=True,
)
restored = GrootN17VLMEncodeStep(**step.get_config())
@@ -1929,6 +1996,7 @@ def test_groot_n1_7_vlm_encode_config_round_trips_model_name():
assert restored.shortest_image_edge == 256
assert restored.crop_fraction == 0.95
assert restored.use_albumentations is True
assert restored.letter_box_transform is True
def test_groot_n1_7_processor_uses_qwen_component_assets(monkeypatch):
@@ -2086,7 +2154,7 @@ def test_groot_n1_7_relative_action_training_processors_save_native_grouped_stat
samples = [
{
OBS_STATE: torch.tensor([10.0, 20.0, 30.0, 40.0, 50.0, 0.0]),
ACTION: torch.tensor(
ACTION: _native_action_chunk(
[
[8.0, 17.0, 26.0, 35.0, 44.0, 0.0],
[12.0, 23.0, 34.0, 45.0, 56.0, 100.0],
@@ -2095,7 +2163,7 @@ def test_groot_n1_7_relative_action_training_processors_save_native_grouped_stat
},
{
OBS_STATE: torch.tensor([0.0, 0.0, 0.0, 0.0, 0.0, 50.0]),
ACTION: torch.tensor(
ACTION: _native_action_chunk(
[
[-1.0, -2.0, -3.0, -4.0, -5.0, 25.0],
[1.0, 2.0, 3.0, 4.0, 5.0, 75.0],
@@ -2122,10 +2190,12 @@ def test_groot_n1_7_relative_action_training_processors_save_native_grouped_stat
action_names=action_names,
preserve_action_horizon=True,
)
expected_relative_action_stats = {
"min": torch.tensor([-2.0, -3.0, -4.0, -5.0, -6.0, 1.0, 2.0, 3.0, 4.0, 5.0, 0.0]),
"max": torch.tensor([-1.0, -2.0, -3.0, -4.0, -5.0, 2.0, 3.0, 4.0, 5.0, 6.0, 100.0]),
}
expected_relative_action_min_prefix = torch.tensor(
[-2.0, -3.0, -4.0, -5.0, -6.0, 1.0, 2.0, 3.0, 4.0, 5.0]
)
expected_relative_action_max_prefix = torch.tensor(
[-1.0, -2.0, -3.0, -4.0, -5.0, 2.0, 3.0, 4.0, 5.0, 6.0]
)
preprocessor, postprocessor = make_groot_pre_post_processors(
config, dataset_stats=relative_dataset_stats, dataset_meta=_RelativeStatsDataset.meta
@@ -2148,17 +2218,26 @@ def test_groot_n1_7_relative_action_training_processors_save_native_grouped_stat
{"rep": "RELATIVE", "type": "NON_EEF", "format": "DEFAULT", "state_key": None},
{"rep": "ABSOLUTE", "type": "NON_EEF", "format": "DEFAULT", "state_key": None},
]
assert pack_config["raw_stats"]["relative_action"]["single_arm"]["min"] == [
pack_relative_min = pack_config["raw_stats"]["relative_action"]["single_arm"]["min"]
assert pack_relative_min[:2] == [
[-2.0, -3.0, -4.0, -5.0, -6.0],
[1.0, 2.0, 3.0, 4.0, 5.0],
]
assert pack_config["raw_stats"]["relative_action"]["single_arm"]["count"] == [2, 2]
assert len(pack_relative_min) == N1_7_NATIVE_ACTION_HORIZON
assert (
pack_config["raw_stats"]["relative_action"]["single_arm"]["count"] == [2] * N1_7_NATIVE_ACTION_HORIZON
)
assert pack_config["raw_stats"]["action"]["gripper"]["min"] == [0.0]
assert pack_config["raw_stats"]["action"]["gripper"]["max"] == [100.0]
pack_state = load_file(tmp_path / pack_entry["state_file"])
torch.testing.assert_close(pack_state[f"{ACTION}.min"], expected_relative_action_stats["min"])
torch.testing.assert_close(pack_state[f"{ACTION}.max"], expected_relative_action_stats["max"])
expected_flat_dim = N1_7_NATIVE_ACTION_HORIZON * 5 + 1
assert pack_state[f"{ACTION}.min"].shape == (expected_flat_dim,)
assert pack_state[f"{ACTION}.max"].shape == (expected_flat_dim,)
torch.testing.assert_close(pack_state[f"{ACTION}.min"][:10], expected_relative_action_min_prefix)
torch.testing.assert_close(pack_state[f"{ACTION}.max"][:10], expected_relative_action_max_prefix)
assert pack_state[f"{ACTION}.min"][-1].item() == 0.0
assert pack_state[f"{ACTION}.max"][-1].item() == 100.0
postprocessor_config = json.loads((tmp_path / "policy_postprocessor.json").read_text())
assert not any(
@@ -2171,11 +2250,16 @@ def test_groot_n1_7_relative_action_training_processors_save_native_grouped_stat
)
decode_config = decode_entry["config"]
assert decode_config["use_relative_action"] is True
assert decode_config["raw_stats"]["relative_action"]["single_arm"]["max"] == [
decode_relative_max = decode_config["raw_stats"]["relative_action"]["single_arm"]["max"]
assert decode_relative_max[:2] == [
[-1.0, -2.0, -3.0, -4.0, -5.0],
[2.0, 3.0, 4.0, 5.0, 6.0],
]
assert decode_config["raw_stats"]["relative_action"]["single_arm"]["count"] == [2, 2]
assert len(decode_relative_max) == N1_7_NATIVE_ACTION_HORIZON
assert (
decode_config["raw_stats"]["relative_action"]["single_arm"]["count"]
== [2] * N1_7_NATIVE_ACTION_HORIZON
)
assert decode_config["raw_stats"]["action"]["gripper"]["max"] == [100.0]
@@ -2215,7 +2299,7 @@ def test_groot_n1_7_relative_action_processors_compute_stats_from_runtime_datase
samples = [
{
OBS_STATE: torch.tensor([10.0, 20.0, 30.0, 40.0, 50.0, 0.0]),
ACTION: torch.tensor(
ACTION: _native_action_chunk(
[
[8.0, 17.0, 26.0, 35.0, 44.0, 0.0],
[12.0, 23.0, 34.0, 45.0, 56.0, 100.0],
@@ -2224,7 +2308,7 @@ def test_groot_n1_7_relative_action_processors_compute_stats_from_runtime_datase
},
{
OBS_STATE: torch.tensor([0.0, 0.0, 0.0, 0.0, 0.0, 50.0]),
ACTION: torch.tensor(
ACTION: _native_action_chunk(
[
[-1.0, -2.0, -3.0, -4.0, -5.0, 25.0],
[1.0, 2.0, 3.0, 4.0, 5.0, 75.0],
@@ -2255,7 +2339,9 @@ def test_groot_n1_7_relative_action_processors_compute_stats_from_runtime_datase
assert kwargs["root"] == runtime_meta.root
assert kwargs["revision"] == runtime_meta.revision
assert kwargs["download_videos"] is False
assert kwargs["delta_timestamps"][ACTION] == [0.0, 1 / runtime_meta.fps]
assert kwargs["delta_timestamps"][ACTION] == [
index / runtime_meta.fps for index in range(N1_7_NATIVE_ACTION_HORIZON)
]
return _RelativeStatsDataset()
monkeypatch.setattr("lerobot.policies.groot.processor_groot.LeRobotDataset", _fake_lerobot_dataset)
@@ -2266,11 +2352,15 @@ def test_groot_n1_7_relative_action_processors_compute_stats_from_runtime_datase
assert not any(isinstance(step, RelativeActionsProcessorStep) for step in preprocessor.steps)
assert isinstance(postprocessor.steps[0], GrootN17ActionDecodeStep)
pack_step = next(step for step in preprocessor.steps if isinstance(step, GrootN17PackInputsStep))
assert pack_step.raw_stats["relative_action"]["single_arm"]["min"] == [
assert pack_step.action_horizon == N1_7_NATIVE_ACTION_HORIZON
assert pack_step.valid_action_horizon == 2
pack_relative_min = pack_step.raw_stats["relative_action"]["single_arm"]["min"]
assert pack_relative_min[:2] == [
[-2.0, -3.0, -4.0, -5.0, -6.0],
[1.0, 2.0, 3.0, 4.0, 5.0],
]
assert pack_step.raw_stats["relative_action"]["single_arm"]["count"] == [2, 2]
assert len(pack_relative_min) == N1_7_NATIVE_ACTION_HORIZON
assert pack_step.raw_stats["relative_action"]["single_arm"]["count"] == [2] * N1_7_NATIVE_ACTION_HORIZON
assert pack_step.raw_stats["action"]["gripper"]["max"] == [100.0]
@@ -2315,14 +2405,14 @@ def test_groot_n1_7_generated_relative_stats_match_oss_gr00t_reference_numbers()
}
state_a = torch.tensor([10.0, 20.0, 30.0, 40.0, 50.0, 25.0])
state_b = torch.tensor([0.0, -10.0, 10.0, -20.0, 20.0, 75.0])
action_a = torch.tensor(
action_a = _native_action_chunk(
[
[11.0, 22.0, 33.0, 44.0, 55.0, 20.0],
[12.0, 24.0, 36.0, 48.0, 60.0, 80.0],
[13.0, 26.0, 39.0, 52.0, 65.0, 90.0],
]
)
action_b = torch.tensor(
action_b = _native_action_chunk(
[
[-1.0, -8.0, 13.0, -16.0, 25.0, 30.0],
[-2.0, -6.0, 16.0, -12.0, 30.0, 40.0],
@@ -2399,12 +2489,13 @@ def test_groot_n1_7_generated_relative_stats_match_oss_gr00t_reference_numbers()
]
)
torch.testing.assert_close(torch.as_tensor(relative_dataset_stats[ACTION]["min"][:, :5]), oss_arm_min)
torch.testing.assert_close(torch.as_tensor(relative_dataset_stats[ACTION]["max"][:, :5]), oss_arm_max)
torch.testing.assert_close(torch.as_tensor(relative_dataset_stats[ACTION]["mean"][:, :5]), oss_arm_mean)
torch.testing.assert_close(torch.as_tensor(relative_dataset_stats[ACTION]["std"][:, :5]), oss_arm_std)
torch.testing.assert_close(torch.as_tensor(relative_dataset_stats[ACTION]["q01"][:, :5]), oss_arm_q01)
torch.testing.assert_close(torch.as_tensor(relative_dataset_stats[ACTION]["q99"][:, :5]), oss_arm_q99)
torch.testing.assert_close(torch.as_tensor(relative_dataset_stats[ACTION]["min"][:3, :5]), oss_arm_min)
torch.testing.assert_close(torch.as_tensor(relative_dataset_stats[ACTION]["max"][:3, :5]), oss_arm_max)
torch.testing.assert_close(torch.as_tensor(relative_dataset_stats[ACTION]["mean"][:3, :5]), oss_arm_mean)
torch.testing.assert_close(torch.as_tensor(relative_dataset_stats[ACTION]["std"][:3, :5]), oss_arm_std)
torch.testing.assert_close(torch.as_tensor(relative_dataset_stats[ACTION]["q01"][:3, :5]), oss_arm_q01)
torch.testing.assert_close(torch.as_tensor(relative_dataset_stats[ACTION]["q99"][:3, :5]), oss_arm_q99)
assert torch.as_tensor(relative_dataset_stats[ACTION]["min"]).shape[0] == N1_7_NATIVE_ACTION_HORIZON
preprocessor, postprocessor = make_groot_pre_post_processors(
config,
@@ -2415,16 +2506,16 @@ def test_groot_n1_7_generated_relative_stats_match_oss_gr00t_reference_numbers()
decode_step = next(step for step in postprocessor.steps if isinstance(step, GrootN17ActionDecodeStep))
assert pack_step.use_percentiles is True
torch.testing.assert_close(
torch.as_tensor(pack_step.raw_stats["relative_action"]["single_arm"]["min"]),
oss_arm_min,
)
torch.testing.assert_close(
torch.as_tensor(pack_step.raw_stats["relative_action"]["single_arm"]["q99"]),
oss_arm_q99,
)
assert pack_step.stats[ACTION]["min"] == pytest.approx([*oss_arm_min.flatten().tolist(), 20.0])
assert pack_step.stats[ACTION]["max"] == pytest.approx([*oss_arm_max.flatten().tolist(), 90.0])
pack_relative_min = torch.as_tensor(pack_step.raw_stats["relative_action"]["single_arm"]["min"])
pack_relative_q99 = torch.as_tensor(pack_step.raw_stats["relative_action"]["single_arm"]["q99"])
assert pack_relative_min.shape == (N1_7_NATIVE_ACTION_HORIZON, 5)
assert pack_relative_q99.shape == (N1_7_NATIVE_ACTION_HORIZON, 5)
torch.testing.assert_close(pack_relative_min[:3], oss_arm_min)
torch.testing.assert_close(pack_relative_q99[:3], oss_arm_q99)
assert pack_step.stats[ACTION]["min"][:15] == pytest.approx(oss_arm_min.flatten().tolist())
assert pack_step.stats[ACTION]["max"][:15] == pytest.approx(oss_arm_max.flatten().tolist())
assert pack_step.stats[ACTION]["min"][-1] == pytest.approx(20.0)
assert pack_step.stats[ACTION]["max"][-1] == pytest.approx(90.0)
packed = pack_step(
{
@@ -2443,7 +2534,13 @@ def test_groot_n1_7_generated_relative_stats_match_oss_gr00t_reference_numbers()
torch.testing.assert_close(packed[TransitionKey.ACTION][0, :3, :6], expected_normalized)
decoded = decode_step({TransitionKey.ACTION: packed[TransitionKey.ACTION]})
torch.testing.assert_close(decoded[TransitionKey.ACTION], action_a.unsqueeze(0), atol=1e-5, rtol=1e-5)
assert decoded[TransitionKey.ACTION].shape == (1, N1_7_NATIVE_ACTION_HORIZON, 6)
torch.testing.assert_close(
decoded[TransitionKey.ACTION][:, :3],
action_a.unsqueeze(0)[:, :3],
atol=1e-5,
rtol=1e-5,
)
def test_groot_n1_7_relative_action_stats_skip_padded_tail_chunks():
@@ -0,0 +1,100 @@
#!/usr/bin/env python
# Copyright 2026 The HuggingFace Inc. team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Isaac-GR00T N1.7 raw-state dropout training contract.
Isaac-GR00T zeroes the entire proprioceptive state of a sample with probability
``state_dropout_prob`` (configured in the checkpoint's processor sidecar) during
training only. Baseline LeRobot kept the processor deterministic, so this
regularization never activated. These tests pin the train/eval split.
"""
import torch
from lerobot.policies.groot.processor_groot import GrootN17PackInputsStep
from lerobot.types import TransitionKey
from lerobot.utils.constants import OBS_STATE
def _make_transition():
return {
TransitionKey.OBSERVATION: {OBS_STATE: torch.tensor([[1.0, 2.0], [3.0, 4.0]])},
TransitionKey.COMPLEMENTARY_DATA: {"task": ["Move", "Move"]},
}
def test_groot_n1_7_training_applies_raw_state_dropout_before_encoder():
step = GrootN17PackInputsStep(
max_state_dim=4,
max_action_dim=4,
normalize_min_max=False,
training=True,
state_dropout_prob=1.0,
)
output = step(_make_transition())
expected = torch.zeros(2, 1, 4)
torch.testing.assert_close(output[TransitionKey.OBSERVATION]["state"], expected)
def test_groot_n1_7_training_state_dropout_is_disabled_under_no_grad():
step = GrootN17PackInputsStep(
max_state_dim=4,
max_action_dim=4,
normalize_min_max=False,
training=True,
state_dropout_prob=1.0,
)
with torch.no_grad():
output = step(_make_transition())
expected = torch.tensor([[[1.0, 2.0, 0.0, 0.0]], [[3.0, 4.0, 0.0, 0.0]]])
torch.testing.assert_close(output[TransitionKey.OBSERVATION]["state"], expected)
def test_groot_n1_7_eval_mode_state_dropout_is_inactive():
step = GrootN17PackInputsStep(
max_state_dim=4,
max_action_dim=4,
normalize_min_max=False,
training=False,
state_dropout_prob=1.0,
)
output = step(_make_transition())
expected = torch.tensor([[[1.0, 2.0, 0.0, 0.0]], [[3.0, 4.0, 0.0, 0.0]]])
torch.testing.assert_close(output[TransitionKey.OBSERVATION]["state"], expected)
def test_groot_n1_7_pack_step_serializes_dropout_prob_but_not_training_mode():
step = GrootN17PackInputsStep(
max_state_dim=4,
max_action_dim=4,
normalize_min_max=False,
training=True,
state_dropout_prob=0.2,
)
serialized = step.get_config()
restored = GrootN17PackInputsStep(**serialized)
assert "training" not in serialized
assert serialized["state_dropout_prob"] == 0.2
assert restored.training is False
assert restored.state_dropout_prob == 0.2
@@ -0,0 +1,156 @@
#!/usr/bin/env python
# Copyright 2026 The HuggingFace Inc. team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Isaac-GR00T N1.7 train-time random crop contract (crop geometry only).
Isaac-GR00T crops a random ``crop_fraction`` window during training and the
deterministic center window at eval, replaying the sampled window across all
camera views of a sample (gr00t/data/transform/video.py, n1.5-release onward:
"If mode is 'train', return a random crop transform. If mode is 'eval', return
a center crop transform."). This mirrors LeRobot's own Diffusion/VQBeT
``crop_is_random`` pattern. Color jitter is intentionally out of scope here.
"""
import random
import numpy as np
import torch
from lerobot.policies.groot.processor_groot import (
GrootN17VLMEncodeStep,
_transform_n1_7_image_for_vlm_albumentations,
)
def _structured_image(h=480, w=640):
yy, xx = np.mgrid[0:h, 0:w]
return np.stack(
[(xx * 255 / w), (yy * 255 / h), ((xx + yy) * 255 / (h + w))], axis=-1
).astype(np.uint8)
def test_crop_position_none_is_bitexact_center_crop():
"""crop_position=None must remain byte-identical to the pre-change eval path."""
img = _structured_image()
ref = _transform_n1_7_image_for_vlm_albumentations(
img, image_crop_size=None, image_target_size=[256, 256],
shortest_image_edge=256, crop_fraction=0.95,
)
out = _transform_n1_7_image_for_vlm_albumentations(
img, image_crop_size=None, image_target_size=[256, 256],
shortest_image_edge=256, crop_fraction=0.95, crop_position=None,
)
np.testing.assert_array_equal(ref, out)
def test_crop_position_center_matches_center_crop():
img = _structured_image()
center = _transform_n1_7_image_for_vlm_albumentations(
img, image_crop_size=None, image_target_size=[256, 256],
shortest_image_edge=256, crop_fraction=0.95, crop_position=None,
)
explicit = _transform_n1_7_image_for_vlm_albumentations(
img, image_crop_size=None, image_target_size=[256, 256],
shortest_image_edge=256, crop_fraction=0.95, crop_position=(0.5, 0.5),
)
# int-floor center vs rounded positional center may differ by <=1 px of grid
assert center.shape == explicit.shape
diff = np.abs(center.astype(np.int16) - explicit.astype(np.int16))
assert diff.mean() < 3.0
def test_crop_position_corners_differ_from_center():
img = _structured_image()
def crop_at(position):
return _transform_n1_7_image_for_vlm_albumentations(
img,
image_crop_size=None,
image_target_size=[256, 256],
shortest_image_edge=256,
crop_fraction=0.95,
crop_position=position,
)
center = crop_at(None)
tl = crop_at((0.0, 0.0))
br = crop_at((1.0, 1.0))
assert not np.array_equal(center, tl)
assert not np.array_equal(tl, br)
def _video(img, views=2):
return np.stack([img] * views, axis=0).reshape(1, 1, views, *img.shape)
def _step(training):
return GrootN17VLMEncodeStep(
image_target_size=[256, 256],
shortest_image_edge=256,
crop_fraction=0.95,
use_albumentations=True,
training=training,
)
def test_training_crop_replays_one_window_across_views():
video = _video(_structured_image())
frames = _step(training=True)._build_sample_images(video, batch_size=1, target_device=None)[0]
np.testing.assert_array_equal(np.asarray(frames[0]), np.asarray(frames[1]))
def test_training_crop_differs_from_eval_center_crop():
video = _video(_structured_image())
random.seed(3) # a draw that is not the exact center
train_frame = np.asarray(
_step(training=True)._build_sample_images(video, batch_size=1, target_device=None)[0][0]
)
eval_frame = np.asarray(
_step(training=False)._build_sample_images(video, batch_size=1, target_device=None)[0][0]
)
assert not np.array_equal(train_frame, eval_frame)
def test_training_crop_is_disabled_under_no_grad():
video = _video(_structured_image())
with torch.no_grad():
no_grad_frame = np.asarray(
_step(training=True)._build_sample_images(video, batch_size=1, target_device=None)[0][0]
)
eval_frame = np.asarray(
_step(training=False)._build_sample_images(video, batch_size=1, target_device=None)[0][0]
)
np.testing.assert_array_equal(no_grad_frame, eval_frame)
def test_training_mode_is_not_serialized():
step = _step(training=True)
serialized = step.get_config()
assert "training" not in serialized
restored = GrootN17VLMEncodeStep(**serialized)
assert restored.training is False
def test_training_crop_respects_global_seed():
video = _video(_structured_image())
def draw():
random.seed(11)
return np.asarray(
_step(training=True)._build_sample_images(video, batch_size=1, target_device=None)[0][0]
)
np.testing.assert_array_equal(draw(), draw())
@@ -0,0 +1,121 @@
#!/usr/bin/env python
# Copyright 2026 The HuggingFace Inc. team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Isaac-GR00T N1.7 optimizer/scheduler/precision training contract.
Pins the LeRobot GR00T fine-tuning recipe to the native Isaac-GR00T contract:
AdamW(lr=1e-4, betas=(0.9, 0.999), eps=1e-8, weight_decay=1e-5, grad clip 1.0),
HF cosine schedule with ~5% warmup over the actual update count, FP32 master
parameters under BF16 autocast, transformers-style weight-decay grouping, the
frozen LM-head weight tie, and episode-tail exclusion for incomplete chunks.
"""
import pytest
import torch
from lerobot.optim.schedulers import DiffuserSchedulerConfig
from lerobot.policies.groot.configuration_groot import GrootConfig
from lerobot.policies.groot.groot_n1_7 import _tie_unused_qwen_lm_head
from lerobot.policies.groot.modeling_groot import GrootPolicy
def test_groot_n1_7_optimizer_matches_isaac_training_contract():
optimizer = GrootConfig().get_optimizer_preset()
assert optimizer.lr == pytest.approx(1e-4)
assert optimizer.betas == pytest.approx((0.9, 0.999))
assert optimizer.eps == pytest.approx(1e-8)
assert optimizer.weight_decay == pytest.approx(1e-5)
assert optimizer.grad_clip_norm == pytest.approx(1.0)
def test_groot_n1_7_sampler_excludes_incomplete_action_tails():
config = GrootConfig(chunk_size=16, n_action_steps=16)
assert len(config.action_delta_indices) == 16
assert config.drop_n_last_frames == 15
def test_groot_n1_7_scheduler_matches_isaac_hf_cosine_contract():
config = GrootConfig(max_steps=20_000)
scheduler_config = config.get_scheduler_preset()
assert isinstance(scheduler_config, DiffuserSchedulerConfig)
assert scheduler_config.name == "cosine"
assert scheduler_config.num_warmup_steps == 1_000
parameter = torch.nn.Parameter(torch.ones(()))
optimizer = torch.optim.AdamW([parameter], lr=config.optimizer_lr)
scheduler = scheduler_config.build(optimizer, num_training_steps=20_000)
lr_factor = scheduler.lr_lambdas[0]
assert lr_factor(0) == pytest.approx(0.0)
assert lr_factor(1_000) == pytest.approx(1.0)
assert lr_factor(10_500) == pytest.approx(0.5)
assert lr_factor(20_000) == pytest.approx(0.0, abs=1e-12)
def test_groot_n1_7_scheduler_rounds_fractional_warmup_up_like_transformers():
scheduler_config = GrootConfig(max_steps=777).get_scheduler_preset()
assert scheduler_config.num_warmup_steps == 39
def test_groot_n1_7_model_parameters_use_fp32_checkpoint_and_optimizer_precision():
module = torch.nn.Module()
module.trainable = torch.nn.Parameter(torch.ones(3, dtype=torch.bfloat16))
module.frozen = torch.nn.Parameter(torch.ones(3, dtype=torch.bfloat16), requires_grad=False)
GrootPolicy._cast_model_parameters_to_fp32(module)
assert module.trainable.dtype == torch.float32
assert module.frozen.dtype == torch.float32
def test_groot_n1_7_ties_unused_qwen_lm_head_to_frozen_input_embeddings():
class DummyQwen(torch.nn.Module):
def __init__(self):
super().__init__()
self.embed_tokens = torch.nn.Embedding(7, 3)
self.lm_head = torch.nn.Linear(3, 7, bias=False)
def get_input_embeddings(self):
return self.embed_tokens
model = DummyQwen()
_tie_unused_qwen_lm_head(model)
assert model.lm_head.weight is model.embed_tokens.weight
assert len(list(model.parameters())) == 1
def test_groot_n1_7_optimizer_groups_match_transformers_weight_decay_rules():
module = torch.nn.Module()
module.linear = torch.nn.Linear(3, 2)
module.norm = torch.nn.LayerNorm(2)
module.frozen = torch.nn.Parameter(torch.ones(1), requires_grad=False)
groups = GrootPolicy._build_weight_decay_parameter_groups(module)
assert len(groups) == 2
assert "weight_decay" not in groups[0]
assert groups[1]["weight_decay"] == 0.0
assert groups[0]["params"] == [module.linear.weight]
assert {id(parameter) for parameter in groups[1]["params"]} == {
id(module.linear.bias),
id(module.norm.weight),
id(module.norm.bias),
}