mirror of
https://github.com/huggingface/lerobot.git
synced 2026-07-03 16:17:15 +00:00
Compare commits
8 Commits
| Author | SHA1 | Date | |
|---|---|---|---|
| 78f778a1ff | |||
| 00f59a2cf4 | |||
| 49cb1ee7db | |||
| 07d6c5b8be | |||
| b23b6edcd9 | |||
| d7b09f77c5 | |||
| 34e70f43b8 | |||
| a35e6a4b46 |
+81
-126
@@ -43,6 +43,25 @@ For a source checkout:
|
||||
pip install -e ".[groot]"
|
||||
```
|
||||
|
||||
### Optional: Flash Attention acceleration
|
||||
|
||||
Flash Attention is a purely optional performance optimization. **LeRobot neither installs nor requires it**, and setting it up is up to the user as it has environment-specific build requirements (a matching PyTorch/CUDA toolchain). To enable it:
|
||||
|
||||
1. Install a `flash-attn` build matching your PyTorch/CUDA environment (see the [Flash Attention project](https://github.com/Dao-AILab/flash-attention)):
|
||||
|
||||
```bash
|
||||
# Check https://pytorch.org/get-started/locally/ for the right CUDA wheel index for your system.
|
||||
pip install "torch>=2.7,<2.12.0" "torchvision>=0.22.0,<0.27.0" \
|
||||
--index-url https://download.pytorch.org/whl/cu128
|
||||
pip install "ninja>=1.11.1,<2.0.0" "packaging>=24.2,<26.0"
|
||||
pip install "flash-attn>=2.5.9,<3.0.0" --no-build-isolation
|
||||
python -c "import flash_attn; print(f'Flash Attention {flash_attn.__version__} imported successfully')"
|
||||
```
|
||||
|
||||
2. Install lerobot with the groot extra.
|
||||
|
||||
3. Opt in by passing `--policy.use_flash_attention=true` when training/evaluating GR00T. If the kernel is missing or fails to import, the backbone transparently falls back to SDPA.
|
||||
|
||||
## Usage
|
||||
|
||||
To use GR00T N1.7:
|
||||
@@ -57,49 +76,26 @@ To use GR00T N1.7:
|
||||
|
||||
Here's a complete training command for finetuning the base GR00T model on your own dataset:
|
||||
|
||||
This command is using the `new_embodiment` flag, which is used for the SO-101 robot, [read more about how GR00T handles different embodiments.](https://github.com/NVIDIA/Isaac-GR00T/blob/main/getting_started/policy.md#--embodiment-tag).
|
||||
|
||||
```bash
|
||||
# install extra deps for training
|
||||
pip install "lerobot[training]"
|
||||
|
||||
hf auth login
|
||||
wandb login
|
||||
|
||||
export DATASET_NAME=your_data_set
|
||||
export HF_USER=your_hf_username
|
||||
export DATASET=$HF_USER/$DATASET_NAME
|
||||
export REPO_ID="${DATASET}_GR00T17" #this is the model that will be uploaded to huggingface
|
||||
export OUTPUT_DIR=outputs/train/$REPO_ID
|
||||
|
||||
lerobot-train \
|
||||
--dataset.repo_id=$DATASET \
|
||||
--dataset.image_transforms.enable=true \
|
||||
--policy.type=groot \
|
||||
--policy.device=cuda \
|
||||
--policy.base_model_path=nvidia/GR00T-N1.7-3B \
|
||||
--policy.embodiment_tag=new_embodiment \
|
||||
--policy.chunk_size=16 \
|
||||
--policy.n_action_steps=16 \
|
||||
--policy.use_relative_actions=true \
|
||||
--policy.relative_exclude_joints='["gripper"]' \
|
||||
--policy.use_bf16=true \
|
||||
--policy.push_to_hub=true \
|
||||
--policy.repo_id=$REPO_ID \
|
||||
--seed=42 \
|
||||
--batch_size=64 \
|
||||
--steps=20000 \
|
||||
--save_checkpoint=true \
|
||||
--save_freq=5000 \
|
||||
--use_policy_training_preset=true \
|
||||
--env_eval_freq=0 \
|
||||
--eval_steps=0 \
|
||||
--log_freq=10 \
|
||||
# Using a multi-GPU setup
|
||||
accelerate launch \
|
||||
--multi_gpu \
|
||||
--num_processes=$NUM_GPUS \
|
||||
$(which lerobot-train) \
|
||||
--output_dir=$OUTPUT_DIR \
|
||||
--job_name=$DATASET \
|
||||
--save_checkpoint=true \
|
||||
--batch_size=$BATCH_SIZE \
|
||||
--steps=$NUM_STEPS \
|
||||
--save_freq=$SAVE_FREQ \
|
||||
--log_freq=$LOG_FREQ \
|
||||
--policy.push_to_hub=true \
|
||||
--policy.type=groot \
|
||||
--policy.repo_id=$REPO_ID \
|
||||
--policy.tune_diffusion_model=false \
|
||||
--dataset.repo_id=$DATASET_ID \
|
||||
--wandb.enable=true \
|
||||
--wandb.disable_artifact=true
|
||||
|
||||
--wandb.disable_artifact=true \
|
||||
--job_name=$JOB_NAME
|
||||
```
|
||||
|
||||
## Performance Results
|
||||
@@ -111,66 +107,39 @@ lerobot-train \
|
||||
|
||||
GR00T N1.7 has demonstrated strong performance on the LIBERO benchmark suite. To reproduce LeRobot results, follow the instructions in the [LIBERO](./libero) section.
|
||||
|
||||
### Train on LIBERO
|
||||
### GR00T N1.7 LIBERO Checkpoints
|
||||
|
||||
Example training command for a LIBERO suite (here `libero_spatial`):
|
||||
NVIDIA publishes GR00T N1.7 LIBERO checkpoints at [`nvidia/GR00T-N1.7-LIBERO`](https://huggingface.co/nvidia/GR00T-N1.7-LIBERO), with one subdirectory per LIBERO suite:
|
||||
|
||||
| Suite | Checkpoint subdirectory |
|
||||
| -------------- | ----------------------- |
|
||||
| LIBERO Spatial | `libero_spatial` |
|
||||
| LIBERO Object | `libero_object` |
|
||||
| LIBERO Goal | `libero_goal` |
|
||||
| LIBERO 10 | `libero_10` |
|
||||
|
||||
Preliminary LeRobot integration results:
|
||||
|
||||
| Suite | Status | Success rate | n_episodes |
|
||||
| -------------- | ------ | -----------: | ---------: |
|
||||
| LIBERO Spatial | ✓ | ~95% | XX |
|
||||
| LIBERO Object | ✓ | XX% | XX |
|
||||
| LIBERO Goal | ✓ | XX% | XX |
|
||||
| LIBERO 10 | ✓ | XX% | XX |
|
||||
| **Average** | ✓ | **XX%** | **XX** |
|
||||
|
||||
Replace the `XX` placeholders with final eval artifacts before merge.
|
||||
|
||||
Download the suite checkpoint locally, then point `--policy.base_model_path` at the downloaded subdirectory. `--policy.path` is reserved for LeRobot checkpoints that contain a LeRobot `config.json` with a `type` field.
|
||||
|
||||
```bash
|
||||
IMAGE_TRANSFORMS='{
|
||||
"brightness": {"weight": 1.0, "type": "ColorJitter", "kwargs": {"brightness": [0.7, 1.3]}},
|
||||
"contrast": {"weight": 1.0, "type": "ColorJitter", "kwargs": {"contrast": [0.6, 1.4]}},
|
||||
"saturation": {"weight": 1.0, "type": "ColorJitter", "kwargs": {"saturation": [0.5, 1.5]}},
|
||||
"hue": {"weight": 1.0, "type": "ColorJitter", "kwargs": {"hue": [-0.08, 0.08]}}
|
||||
}'
|
||||
|
||||
lerobot-train \
|
||||
--dataset.repo_id=IPEC-COMMUNITY/libero_spatial_no_noops_1.0.0_lerobot \
|
||||
--dataset.root=/datasets/libero_spatial \
|
||||
--dataset.revision=main \
|
||||
--dataset.video_backend=pyav \
|
||||
--dataset.image_transforms.enable=true \
|
||||
--dataset.image_transforms.max_num_transforms=4 \
|
||||
--dataset.image_transforms.tfs="$IMAGE_TRANSFORMS" \
|
||||
--policy.type=groot \
|
||||
--policy.base_model_path=nvidia/GR00T-N1.7-3B \
|
||||
--policy.embodiment_tag=libero_sim \
|
||||
--policy.push_to_hub=false \
|
||||
--policy.max_steps=20000 \
|
||||
--batch_size=320 \
|
||||
--steps=20000 \
|
||||
--save_freq=2000 \
|
||||
--env_eval_freq=0 \
|
||||
--eval_steps=0 \
|
||||
--log_freq=10 \
|
||||
--wandb.enable=true \
|
||||
--wandb.project=lerobot \
|
||||
--wandb.mode=online \
|
||||
--wandb.disable_artifact=true \
|
||||
--num_workers=4 \
|
||||
--prefetch_factor=2 \
|
||||
--persistent_workers=true \
|
||||
--output_dir=$OUTPUT_DIR \
|
||||
--job_name=$JOB_NAME
|
||||
```
|
||||
|
||||
### GR00T N1.7 LIBERO Results
|
||||
|
||||
Preliminary LeRobot integration results (GR00T-LeRobot, `eval.n_episodes >= 50` per suite):
|
||||
|
||||
| Suite | Success rate |
|
||||
| ---------------------- | -----------: |
|
||||
| LIBERO Spatial | 94% |
|
||||
| LIBERO Object | 98% |
|
||||
| LIBERO Goal | 93% |
|
||||
| LIBERO 10 (Long) | 90% |
|
||||
| **Average** | **93.75%** |
|
||||
|
||||
```bash
|
||||
export MODEL_ID=your_trained_model_on_huggingface
|
||||
hf download nvidia/GR00T-N1.7-LIBERO \
|
||||
--include "libero_spatial/*" \
|
||||
--local-dir ./GR00T-N1.7-LIBERO
|
||||
|
||||
lerobot-eval \
|
||||
--policy.type=groot \
|
||||
--policy.base_model_path=$MODEL_ID \
|
||||
--policy.base_model_path=./GR00T-N1.7-LIBERO/libero_spatial \
|
||||
--policy.embodiment_tag=libero_sim \
|
||||
--env.type=libero \
|
||||
--env.task=libero_spatial \
|
||||
@@ -184,41 +153,27 @@ Use `eval.n_episodes >= 50` per suite when reporting success rates.
|
||||
Once you have trained your model using your parameters you can run inference in your downstream task. Follow the instructions in [Policy Deployment (lerobot-rollout)](./inference). For example:
|
||||
|
||||
```bash
|
||||
# install extra deps for roullout and real hardware
|
||||
pip install "lerobot[feetech,viz]"
|
||||
|
||||
export MODEL_ID=your_trained_model_on_huggingface
|
||||
|
||||
# make sure that camera index matches your setup!
|
||||
# find index using `uv run lerobot-find-cameras opencv`
|
||||
WRIST_CAM='wrist: {type: opencv, index_or_path: 2, width: 640, height: 480, fps: 30, fourcc: "MJPG"}'
|
||||
FRONT_CAM='front: {type: opencv, index_or_path: 0, width: 640, height: 480, fps: 30, fourcc: "MJPG"}'
|
||||
export ROBOT_CAMERAS="{ $WRIST_CAM, $FRONT_CAM }"
|
||||
export ROBOT_ID=follower_robot
|
||||
export ROBOT_PORT=/dev/ttyACM0
|
||||
|
||||
uv run lerobot-rollout \
|
||||
--strategy.type=base \
|
||||
--policy.path=$MODEL_ID \
|
||||
--policy.base_model_path=nvidia/GR00T-N1.7-3B \
|
||||
--policy.n_action_steps=8 \
|
||||
--robot.type=so101_follower \
|
||||
--robot.port=$ROBOT_PORT \
|
||||
--robot.id=$ROBOT_ID \
|
||||
--robot.cameras="$ROBOT_CAMERAS" \
|
||||
--task="place the vial in the rack" \
|
||||
--duration=60 \
|
||||
--device=cuda \
|
||||
lerobot-rollout\
|
||||
--strategy.type=sentry \
|
||||
--strategy.upload_every_n_episodes=5 \
|
||||
--robot.type=bi_so_follower \
|
||||
--robot.left_arm_port=/dev/ttyACM1 \
|
||||
--robot.right_arm_port=/dev/ttyACM0 \
|
||||
--robot.id=bimanual_follower \
|
||||
--robot.cameras='{ right: {"type": "opencv", "index_or_path": 0, "width": 640, "height": 480, "fps": 30},
|
||||
left: {"type": "opencv", "index_or_path": 2, "width": 640, "height": 480, "fps": 30},
|
||||
top: {"type": "opencv", "index_or_path": 4, "width": 640, "height": 480, "fps": 30},
|
||||
}' \
|
||||
--display_data=true \
|
||||
--inference.type=rtc \
|
||||
--inference.rtc.enabled=false \
|
||||
--inference.rtc.execution_horizon=8 \
|
||||
--inference.queue_threshold=0
|
||||
--dataset.repo_id=<user>/eval_groot-bimanual \
|
||||
--dataset.single_task="Grab and handover the red cube to the other arm" \
|
||||
--dataset.streaming_encoding=true \
|
||||
--dataset.encoder_threads=2 \
|
||||
# --dataset.rgb_encoder.vcodec=auto \
|
||||
--policy.path=<user>/groot-bimanual \ # your trained model
|
||||
--duration=600
|
||||
```
|
||||
|
||||
> [!NOTE]
|
||||
> Value of `inference.queue_threshold` should not exeed 0.5 to ensure stable inference.
|
||||
|
||||
## License
|
||||
|
||||
GR00T N1.7 is released under the [NVIDIA Open Model License Agreement](https://www.nvidia.com/en-us/agreements/enterprise-software/nvidia-open-model-license/).
|
||||
|
||||
@@ -15,12 +15,11 @@
|
||||
# limitations under the License.
|
||||
|
||||
import logging
|
||||
import math
|
||||
from dataclasses import dataclass, field
|
||||
from pathlib import Path
|
||||
|
||||
from lerobot.configs import FeatureType, NormalizationMode, PolicyFeature, PreTrainedConfig
|
||||
from lerobot.optim import AdamWConfig, DiffuserSchedulerConfig
|
||||
from lerobot.optim import AdamWConfig, CosineDecayWithWarmupSchedulerConfig
|
||||
from lerobot.utils.constants import ACTION, OBS_STATE
|
||||
|
||||
from .utils import read_json
|
||||
@@ -337,14 +336,11 @@ class GrootConfig(PreTrainedConfig):
|
||||
|
||||
# Training parameters
|
||||
optimizer_lr: float = 1e-4
|
||||
# Isaac-GR00T N1.7 fine-tunes with AdamW betas (0.9, 0.999).
|
||||
optimizer_betas: tuple[float, float] = (0.9, 0.999)
|
||||
optimizer_betas: tuple[float, float] = (0.95, 0.999)
|
||||
optimizer_eps: float = 1e-8
|
||||
optimizer_weight_decay: float = 1e-5
|
||||
warmup_ratio: float = 0.05
|
||||
use_bf16: bool = True
|
||||
# The native N1.7 fine-tuning recipe keeps model parameters in FP32 and computes under BF16 autocast.
|
||||
model_params_fp32: bool = True
|
||||
|
||||
# TODO(Steven): Remove these deprecated fields in a future release.
|
||||
# Deprecated Isaac-GR00T runner / GR00T N1.5 fields, plus the (never-wired) LoRA fields — all
|
||||
@@ -484,20 +480,15 @@ class GrootConfig(PreTrainedConfig):
|
||||
betas=self.optimizer_betas,
|
||||
eps=self.optimizer_eps,
|
||||
weight_decay=self.optimizer_weight_decay,
|
||||
grad_clip_norm=1.0,
|
||||
)
|
||||
|
||||
def get_scheduler_preset(self) -> DiffuserSchedulerConfig:
|
||||
"""Return scheduler configuration.
|
||||
|
||||
Isaac-GR00T uses the HF Trainer cosine schedule with ~5% warmup over the
|
||||
actual training update count; DiffuserSchedulerConfig wraps the same
|
||||
diffusers/transformers `get_scheduler("cosine")` implementation and
|
||||
derives num_training_steps from the outer --steps value at runtime.
|
||||
"""
|
||||
return DiffuserSchedulerConfig(
|
||||
name="cosine",
|
||||
num_warmup_steps=math.ceil(self.max_steps * self.warmup_ratio),
|
||||
def get_scheduler_preset(self) -> CosineDecayWithWarmupSchedulerConfig:
|
||||
"""Return scheduler configuration."""
|
||||
return CosineDecayWithWarmupSchedulerConfig(
|
||||
num_warmup_steps=int(10000 * self.warmup_ratio), # 5% warmup by default
|
||||
num_decay_steps=10000, # Adjust based on training steps
|
||||
peak_lr=self.optimizer_lr,
|
||||
decay_lr=self.optimizer_lr * 0.1,
|
||||
)
|
||||
|
||||
@property
|
||||
@@ -513,11 +504,6 @@ class GrootConfig(PreTrainedConfig):
|
||||
)
|
||||
return list(range(min(self.chunk_size, model_action_horizon)))
|
||||
|
||||
@property
|
||||
def drop_n_last_frames(self) -> int:
|
||||
"""Exclude episode tails that cannot supply a complete N1.7 action chunk."""
|
||||
return max(0, len(self.action_delta_indices) - 1)
|
||||
|
||||
@property
|
||||
def reward_delta_indices(self) -> None:
|
||||
"""Return indices for delta rewards (None for Groot)."""
|
||||
|
||||
@@ -60,19 +60,6 @@ except ImportError:
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
def _tie_unused_qwen_lm_head(model: nn.Module) -> None:
|
||||
"""Restore the TF4 weight tie so the unused LM head stays frozen and is omitted on save."""
|
||||
lm_head = getattr(model, "lm_head", None)
|
||||
get_input_embeddings = getattr(model, "get_input_embeddings", None)
|
||||
if lm_head is None or not callable(get_input_embeddings):
|
||||
return
|
||||
input_embeddings = get_input_embeddings()
|
||||
embedding_weight = getattr(input_embeddings, "weight", None)
|
||||
if embedding_weight is None:
|
||||
return
|
||||
lm_head.weight = embedding_weight
|
||||
|
||||
|
||||
GR00T_N1_7_DEFAULTS: dict[str, Any] = {
|
||||
"model_dtype": "bfloat16",
|
||||
"dtype": "bfloat16",
|
||||
@@ -301,7 +288,6 @@ class Qwen3Backbone(nn.Module):
|
||||
config_kwargs=transformers_loading_kwargs,
|
||||
).eval()
|
||||
|
||||
_tie_unused_qwen_lm_head(self.model)
|
||||
while len(self.language_model.layers) > select_layer:
|
||||
self.language_model.layers.pop(-1)
|
||||
|
||||
@@ -617,7 +603,7 @@ class GR00TN17ActionHead(nn.Module):
|
||||
|
||||
pred = self.action_decoder(model_output, embodiment_id)
|
||||
pred_actions = pred[:, -actions.shape[1] :]
|
||||
action_mask = action_input.action_mask
|
||||
action_mask = action_input.action_mask.to(dtype=pred_actions.dtype)
|
||||
action_loss = F.mse_loss(pred_actions, velocity, reduction="none") * action_mask
|
||||
loss = action_loss.sum() / (action_mask.sum() + 1e-6)
|
||||
return BatchFeature(
|
||||
|
||||
@@ -34,7 +34,6 @@ from huggingface_hub import hf_hub_download
|
||||
from huggingface_hub.constants import SAFETENSORS_SINGLE_FILE
|
||||
from huggingface_hub.errors import HfHubHTTPError
|
||||
from torch import Tensor
|
||||
from transformers.trainer_pt_utils import get_parameter_names
|
||||
|
||||
from lerobot.configs import FeatureType, PolicyFeature
|
||||
from lerobot.utils.constants import ACTION, OBS_IMAGES
|
||||
@@ -51,7 +50,7 @@ from .configuration_groot import (
|
||||
infer_groot_n1_7_action_execution_horizon,
|
||||
infer_groot_n1_7_action_horizon,
|
||||
)
|
||||
from .groot_n1_7 import GR00TN17, _tie_unused_qwen_lm_head
|
||||
from .groot_n1_7 import GR00TN17
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
@@ -97,49 +96,11 @@ class GrootPolicy(PreTrainedPolicy):
|
||||
if self.config.rtc_ramp_rate is not None:
|
||||
model_kwargs["rtc_ramp_rate"] = self.config.rtc_ramp_rate
|
||||
|
||||
model = GR00TN17.from_pretrained(
|
||||
return GR00TN17.from_pretrained(
|
||||
**model_kwargs,
|
||||
tune_vlln=self.config.tune_vlln,
|
||||
transformers_loading_kwargs={"trust_remote_code": True},
|
||||
)
|
||||
backbone = getattr(model, "backbone", None)
|
||||
qwen_model = getattr(backbone, "model", None)
|
||||
if qwen_model is not None:
|
||||
_tie_unused_qwen_lm_head(qwen_model)
|
||||
if self.config.model_params_fp32:
|
||||
self._cast_model_parameters_to_fp32(model)
|
||||
return model
|
||||
|
||||
@staticmethod
|
||||
def _cast_model_parameters_to_fp32(model: torch.nn.Module) -> None:
|
||||
for parameter in model.parameters():
|
||||
if parameter.is_floating_point():
|
||||
parameter.data = parameter.data.to(torch.float32)
|
||||
|
||||
@staticmethod
|
||||
def _build_weight_decay_parameter_groups(model: torch.nn.Module) -> list[dict[str, object]]:
|
||||
forbidden_name_patterns = [
|
||||
r"bias",
|
||||
r"layernorm",
|
||||
r"rmsnorm",
|
||||
r"(?:^|\.)norm(?:$|\.)",
|
||||
r"_norm(?:$|\.)",
|
||||
]
|
||||
decay_names = set(get_parameter_names(model, [torch.nn.LayerNorm], forbidden_name_patterns))
|
||||
decay_params = [
|
||||
parameter
|
||||
for name, parameter in model.named_parameters()
|
||||
if parameter.requires_grad and name in decay_names
|
||||
]
|
||||
no_decay_params = [
|
||||
parameter
|
||||
for name, parameter in model.named_parameters()
|
||||
if parameter.requires_grad and name not in decay_names
|
||||
]
|
||||
return [
|
||||
{"params": decay_params},
|
||||
{"params": no_decay_params, "weight_decay": 0.0},
|
||||
]
|
||||
|
||||
def reset(self):
|
||||
"""Reset policy state when environment resets."""
|
||||
@@ -277,9 +238,8 @@ class GrootPolicy(PreTrainedPolicy):
|
||||
policy.eval()
|
||||
return policy
|
||||
|
||||
def get_optim_params(self): # type: ignore[override]
|
||||
"""Isaac-GR00T excludes biases and normalization parameters from weight decay."""
|
||||
return self._build_weight_decay_parameter_groups(self)
|
||||
def get_optim_params(self) -> dict:
|
||||
return self.parameters()
|
||||
|
||||
def _resolve_action_queue_steps(self) -> int:
|
||||
n_action_steps = int(self.config.n_action_steps)
|
||||
|
||||
@@ -15,7 +15,6 @@
|
||||
# limitations under the License.
|
||||
|
||||
import logging
|
||||
import random
|
||||
from copy import copy, deepcopy
|
||||
from dataclasses import dataclass, field, fields, is_dataclass
|
||||
from pathlib import Path
|
||||
@@ -137,7 +136,6 @@ class _GrootN17CheckpointProcessorAssets:
|
||||
video_horizon: int | None
|
||||
use_percentiles: bool
|
||||
use_relative_action: bool
|
||||
state_dropout_prob: float
|
||||
clip_outliers: bool
|
||||
video_modality_keys: list[str] | None
|
||||
image_crop_size: list[int] | None
|
||||
@@ -145,7 +143,6 @@ class _GrootN17CheckpointProcessorAssets:
|
||||
shortest_image_edge: int | None
|
||||
crop_fraction: float | None
|
||||
use_albumentations: bool
|
||||
letter_box_transform: bool
|
||||
|
||||
|
||||
@dataclass(frozen=True)
|
||||
@@ -184,9 +181,6 @@ def _load_n1_7_checkpoint_processor_assets(config: GrootConfig) -> _GrootN17Chec
|
||||
modality_config = {}
|
||||
|
||||
use_relative_action = bool(processor_kwargs.get("use_relative_action", False))
|
||||
state_dropout_prob = as_optional_float(processor_kwargs.get("state_dropout_prob"))
|
||||
if state_dropout_prob is None:
|
||||
state_dropout_prob = 0.0
|
||||
stats = _load_n1_7_checkpoint_stats(
|
||||
checkpoint_path,
|
||||
processor_kwargs,
|
||||
@@ -205,9 +199,6 @@ def _load_n1_7_checkpoint_processor_assets(config: GrootConfig) -> _GrootN17Chec
|
||||
use_albumentations = processor_kwargs.get("use_albumentations", False)
|
||||
if not isinstance(use_albumentations, bool):
|
||||
use_albumentations = False
|
||||
letter_box_transform = processor_kwargs.get("letter_box_transform", False)
|
||||
if not isinstance(letter_box_transform, bool):
|
||||
letter_box_transform = False
|
||||
|
||||
valid_action_horizon = _load_n1_7_checkpoint_action_horizon(processor_kwargs, config.embodiment_tag)
|
||||
video_horizon = _load_n1_7_checkpoint_video_horizon(processor_kwargs, config.embodiment_tag)
|
||||
@@ -227,7 +218,6 @@ def _load_n1_7_checkpoint_processor_assets(config: GrootConfig) -> _GrootN17Chec
|
||||
video_horizon=video_horizon,
|
||||
use_percentiles=bool(processor_kwargs.get("use_percentiles", False)),
|
||||
use_relative_action=use_relative_action,
|
||||
state_dropout_prob=state_dropout_prob,
|
||||
clip_outliers=clip_outliers,
|
||||
video_modality_keys=video_modality_keys,
|
||||
image_crop_size=as_int_pair(processor_kwargs.get("image_crop_size")),
|
||||
@@ -235,7 +225,6 @@ def _load_n1_7_checkpoint_processor_assets(config: GrootConfig) -> _GrootN17Chec
|
||||
shortest_image_edge=as_optional_int(processor_kwargs.get("shortest_image_edge")),
|
||||
crop_fraction=as_optional_float(processor_kwargs.get("crop_fraction")),
|
||||
use_albumentations=use_albumentations,
|
||||
letter_box_transform=letter_box_transform,
|
||||
)
|
||||
|
||||
|
||||
@@ -451,22 +440,6 @@ def _apply_groot_step_overrides(
|
||||
post_init()
|
||||
|
||||
|
||||
def _set_groot_preprocessor_training(
|
||||
preprocessor: PolicyProcessorPipeline,
|
||||
*,
|
||||
training: bool,
|
||||
) -> None:
|
||||
"""Set the runtime-only mode of GR00T stochastic processor steps.
|
||||
|
||||
Any dataclass step exposing a ``training`` field participates, so processor
|
||||
steps can opt into train-time-only behavior (dropout, augmentation) without
|
||||
this helper enumerating them.
|
||||
"""
|
||||
for step in preprocessor.steps:
|
||||
if is_dataclass(step) and any(f.name == "training" for f in fields(step)):
|
||||
setattr(step, "training", training)
|
||||
|
||||
|
||||
def make_groot_pre_post_processors_from_pretrained(
|
||||
config: GrootConfig,
|
||||
pretrained_path: str,
|
||||
@@ -515,7 +488,6 @@ def make_groot_pre_post_processors_from_pretrained(
|
||||
_reconnect_groot_relative_absolute_steps(preprocessor, postprocessor)
|
||||
_reconnect_groot_n1_7_pack_decode_steps(preprocessor, postprocessor)
|
||||
_apply_groot_action_decode_transform(postprocessor, config.action_decode_transform)
|
||||
_set_groot_preprocessor_training(preprocessor, training=dataset_meta is not None)
|
||||
return preprocessor, postprocessor
|
||||
|
||||
|
||||
@@ -1030,6 +1002,7 @@ def _build_n1_7_relative_action_processor_assets(
|
||||
}
|
||||
for group in groups
|
||||
]
|
||||
# 40 matches the action horizon of the only N1.7 base model (nvidia/GR00T-N1.7-3B)
|
||||
action_horizon = min(config.chunk_size, 40)
|
||||
modality_config: dict[str, Any] = {
|
||||
"state": {"modality_keys": [group.key for group in groups]},
|
||||
@@ -1079,7 +1052,6 @@ def _build_n1_7_relative_action_processor_assets(
|
||||
video_horizon=base_assets.video_horizon if base_assets is not None else None,
|
||||
use_percentiles=use_percentiles,
|
||||
use_relative_action=True,
|
||||
state_dropout_prob=base_assets.state_dropout_prob if base_assets is not None else 0.0,
|
||||
clip_outliers=base_assets.clip_outliers if base_assets is not None else True,
|
||||
video_modality_keys=video_modality_keys,
|
||||
image_crop_size=base_assets.image_crop_size if base_assets is not None else None,
|
||||
@@ -1087,7 +1059,6 @@ def _build_n1_7_relative_action_processor_assets(
|
||||
shortest_image_edge=base_assets.shortest_image_edge if base_assets is not None else None,
|
||||
crop_fraction=base_assets.crop_fraction if base_assets is not None else None,
|
||||
use_albumentations=base_assets.use_albumentations if base_assets is not None else False,
|
||||
letter_box_transform=base_assets.letter_box_transform if base_assets is not None else False,
|
||||
)
|
||||
|
||||
|
||||
@@ -1184,8 +1155,6 @@ def make_groot_pre_post_processors(
|
||||
embodiment_tag=config.embodiment_tag,
|
||||
embodiment_mapping=embodiment_mapping,
|
||||
normalize_min_max=True,
|
||||
training=dataset_meta is not None,
|
||||
state_dropout_prob=(checkpoint_assets.state_dropout_prob if checkpoint_assets is not None else 0.0),
|
||||
stats=padded_stats,
|
||||
clip_outliers=clip_outliers,
|
||||
video_modality_keys=video_modality_keys,
|
||||
@@ -1211,7 +1180,6 @@ def make_groot_pre_post_processors(
|
||||
shortest_image_edge = None
|
||||
crop_fraction = None
|
||||
use_albumentations = checkpoint_assets.use_albumentations if checkpoint_assets is not None else False
|
||||
letter_box_transform = checkpoint_assets.letter_box_transform if checkpoint_assets is not None else False
|
||||
|
||||
input_steps: list[ProcessorStep] = [
|
||||
RenameObservationsProcessorStep(rename_map={}),
|
||||
@@ -1224,8 +1192,6 @@ def make_groot_pre_post_processors(
|
||||
shortest_image_edge=shortest_image_edge,
|
||||
crop_fraction=crop_fraction,
|
||||
use_albumentations=use_albumentations,
|
||||
letter_box_transform=letter_box_transform,
|
||||
training=dataset_meta is not None,
|
||||
device=config.device,
|
||||
),
|
||||
DeviceProcessorStep(device=config.device),
|
||||
@@ -1350,8 +1316,6 @@ def _transform_n1_7_image_for_vlm_albumentations(
|
||||
image_target_size: list[int] | None,
|
||||
shortest_image_edge: int | None,
|
||||
crop_fraction: float | None,
|
||||
letter_box_transform: bool = False,
|
||||
crop_position: tuple[float, float] | None = None,
|
||||
) -> np.ndarray:
|
||||
"""cv2/INTER_AREA eval transform mirroring Isaac-GR00T's albumentations preprocessing.
|
||||
|
||||
@@ -1361,12 +1325,6 @@ def _transform_n1_7_image_for_vlm_albumentations(
|
||||
cv2/INTER_AREA resize and floored center-crop here intentionally differ from that
|
||||
torch path and must stay bit-exact to the upstream reference. The hot path accepts
|
||||
and returns numpy arrays to avoid per-frame PIL round-trips.
|
||||
|
||||
``crop_position`` selects where the ``crop_fraction`` window sits: ``None``
|
||||
keeps the deterministic center crop (eval contract), while ``(y, x)``
|
||||
fractions in [0, 1] place the window for Isaac's train-time random crop
|
||||
(0.5, 0.5 == center). Training samples one position per sample and reuses
|
||||
it across camera views.
|
||||
"""
|
||||
if image_target_size is None:
|
||||
return image
|
||||
@@ -1382,18 +1340,6 @@ def _transform_n1_7_image_for_vlm_albumentations(
|
||||
if not image_np.flags.c_contiguous:
|
||||
image_np = np.ascontiguousarray(image_np)
|
||||
|
||||
if letter_box_transform:
|
||||
height, width = image_np.shape[:2]
|
||||
if height != width:
|
||||
square_edge = max(height, width)
|
||||
pad_h = square_edge - height
|
||||
pad_w = square_edge - width
|
||||
top = pad_h // 2
|
||||
bottom = pad_h - top
|
||||
left = pad_w // 2
|
||||
right = pad_w - left
|
||||
image_np = cv2.copyMakeBorder(image_np, top, bottom, left, right, cv2.BORDER_CONSTANT, value=0)
|
||||
|
||||
resize_edge = shortest_image_edge or target_h
|
||||
|
||||
def resize_shortest_edge(frame: np.ndarray) -> np.ndarray:
|
||||
@@ -1418,13 +1364,8 @@ def _transform_n1_7_image_for_vlm_albumentations(
|
||||
height, width = image_np.shape[:2]
|
||||
crop_h = max(1, int(height * crop_fraction))
|
||||
crop_w = max(1, int(width * crop_fraction))
|
||||
if crop_position is None:
|
||||
top = max(0, (height - crop_h) // 2)
|
||||
left = max(0, (width - crop_w) // 2)
|
||||
else:
|
||||
pos_y, pos_x = crop_position
|
||||
top = int(round((height - crop_h) * min(max(pos_y, 0.0), 1.0)))
|
||||
left = int(round((width - crop_w) * min(max(pos_x, 0.0), 1.0)))
|
||||
top = max(0, (height - crop_h) // 2)
|
||||
left = max(0, (width - crop_w) // 2)
|
||||
image_np = image_np[top : top + crop_h, left : left + crop_w]
|
||||
|
||||
return resize_shortest_edge(image_np)
|
||||
@@ -1437,12 +1378,9 @@ def _transform_n1_7_image_for_vlm_torch(
|
||||
image_target_size: list[int] | None,
|
||||
shortest_image_edge: int | None,
|
||||
crop_fraction: float | None,
|
||||
letter_box_transform: bool = False,
|
||||
) -> torch.Tensor:
|
||||
"""Default (non-albumentations) N1.7 image transform.
|
||||
|
||||
Optionally pads to square, then resizes to ``shortest_image_edge``, center-crops
|
||||
by ``crop_fraction``, and resizes to ``image_target_size``.
|
||||
"""Default (non-albumentations) N1.7 image transform: pad-to-square, resize to
|
||||
``shortest_image_edge``, center-crop by ``crop_fraction``, resize to ``image_target_size``.
|
||||
|
||||
Operates on a ``(C, H, W)`` uint8 tensor and keeps the result on the input
|
||||
tensor's device so the resize/crop run on GPU when the tensor is. Bicubic
|
||||
@@ -1457,14 +1395,13 @@ def _transform_n1_7_image_for_vlm_torch(
|
||||
target_h, target_w = image_target_size
|
||||
_, height, width = image.shape
|
||||
|
||||
if letter_box_transform:
|
||||
square_edge = max(height, width)
|
||||
if height != width:
|
||||
left = (square_edge - width) // 2
|
||||
top = (square_edge - height) // 2
|
||||
image = tv_functional.pad(
|
||||
image, [left, top, square_edge - width - left, square_edge - height - top], fill=0
|
||||
)
|
||||
square_edge = max(height, width)
|
||||
if height != width:
|
||||
left = (square_edge - width) // 2
|
||||
top = (square_edge - height) // 2
|
||||
image = tv_functional.pad(
|
||||
image, [left, top, square_edge - width - left, square_edge - height - top], fill=0
|
||||
)
|
||||
|
||||
resize_edge = shortest_image_edge or target_h
|
||||
image = tv_functional.resize(
|
||||
@@ -1511,8 +1448,6 @@ class GrootN17PackInputsStep(ProcessorStep):
|
||||
embodiment_tag: str = "new_embodiment"
|
||||
embodiment_mapping: dict[str, int] = field(default_factory=lambda: dict(N1_7_EMBODIMENT_MAPPING))
|
||||
normalize_min_max: bool = True
|
||||
training: bool = False
|
||||
state_dropout_prob: float = 0.0
|
||||
stats: dict[str, dict[str, Any]] | None = None
|
||||
clip_outliers: bool = True
|
||||
use_percentiles: bool = False
|
||||
@@ -1844,13 +1779,6 @@ class GrootN17PackInputsStep(ProcessorStep):
|
||||
if dim < self.max_state_dim:
|
||||
pad = torch.zeros(bsz, 1, self.max_state_dim - dim, dtype=state.dtype, device=state.device)
|
||||
state = torch.cat([state, pad], dim=2)
|
||||
if self.training and torch.is_grad_enabled() and self.state_dropout_prob > 0:
|
||||
drop_state = torch.tensor(
|
||||
[random.random() < self.state_dropout_prob for _ in range(bsz)],
|
||||
dtype=torch.bool,
|
||||
device=state.device,
|
||||
).view(bsz, 1, 1)
|
||||
state = state.masked_fill(drop_state, 0)
|
||||
obs["state"] = state
|
||||
|
||||
action = transition.get(TransitionKey.ACTION)
|
||||
@@ -1962,7 +1890,6 @@ class GrootN17PackInputsStep(ProcessorStep):
|
||||
"embodiment_tag": self.embodiment_tag,
|
||||
"embodiment_mapping": self.embodiment_mapping,
|
||||
"normalize_min_max": self.normalize_min_max,
|
||||
"state_dropout_prob": self.state_dropout_prob,
|
||||
"clip_outliers": self.clip_outliers,
|
||||
"use_percentiles": self.use_percentiles,
|
||||
"video_modality_keys": self.video_modality_keys,
|
||||
@@ -2019,12 +1946,6 @@ class GrootN17VLMEncodeStep(ProcessorStep):
|
||||
shortest_image_edge: int | None = None
|
||||
crop_fraction: float | None = None
|
||||
use_albumentations: bool = False
|
||||
letter_box_transform: bool = False
|
||||
# Runtime-only train/eval mode: True enables Isaac's train-time random crop
|
||||
# (one window per sample, replayed across views); False keeps the
|
||||
# deterministic center crop. Never serialized - reloaded pipelines default
|
||||
# to eval and are re-enabled only when processors are built with dataset_meta.
|
||||
training: bool = False
|
||||
device: str | None = None
|
||||
_proc: ProcessorMixin | None = field(default=None, init=False, repr=False)
|
||||
|
||||
@@ -2058,29 +1979,20 @@ class GrootN17VLMEncodeStep(ProcessorStep):
|
||||
"""
|
||||
if self.use_albumentations:
|
||||
video_np = np.asarray(video)
|
||||
train_crop = self.training and torch.is_grad_enabled()
|
||||
sample_images: list[list[Any]] = []
|
||||
for batch_idx in range(batch_size):
|
||||
# Isaac-GR00T samples ONE crop window per sample and replays it
|
||||
# across every (timestep, view) frame of that sample, keeping
|
||||
# cross-view geometry consistent. Eval keeps the center crop.
|
||||
crop_position = (random.random(), random.random()) if train_crop else None
|
||||
sample_images.append(
|
||||
[
|
||||
_transform_n1_7_image_for_vlm_albumentations(
|
||||
video_np[batch_idx, timestep, view_idx],
|
||||
image_crop_size=self.image_crop_size,
|
||||
image_target_size=self.image_target_size,
|
||||
shortest_image_edge=self.shortest_image_edge,
|
||||
crop_fraction=self.crop_fraction,
|
||||
letter_box_transform=self.letter_box_transform,
|
||||
crop_position=crop_position,
|
||||
)
|
||||
for timestep in range(video_np.shape[1])
|
||||
for view_idx in range(video_np.shape[2])
|
||||
]
|
||||
)
|
||||
return sample_images
|
||||
return [
|
||||
[
|
||||
_transform_n1_7_image_for_vlm_albumentations(
|
||||
video_np[batch_idx, timestep, view_idx],
|
||||
image_crop_size=self.image_crop_size,
|
||||
image_target_size=self.image_target_size,
|
||||
shortest_image_edge=self.shortest_image_edge,
|
||||
crop_fraction=self.crop_fraction,
|
||||
)
|
||||
for timestep in range(video_np.shape[1])
|
||||
for view_idx in range(video_np.shape[2])
|
||||
]
|
||||
for batch_idx in range(batch_size)
|
||||
]
|
||||
|
||||
video_t = video if torch.is_tensor(video) else torch.from_numpy(np.ascontiguousarray(video))
|
||||
# (B, T, V, H, W, C) uint8 -> (B, T, V, C, H, W)
|
||||
@@ -2099,7 +2011,6 @@ class GrootN17VLMEncodeStep(ProcessorStep):
|
||||
image_target_size=self.image_target_size,
|
||||
shortest_image_edge=self.shortest_image_edge,
|
||||
crop_fraction=self.crop_fraction,
|
||||
letter_box_transform=self.letter_box_transform,
|
||||
)
|
||||
for timestep in range(sample.shape[0])
|
||||
for view_idx in range(sample.shape[1])
|
||||
@@ -2173,7 +2084,6 @@ class GrootN17VLMEncodeStep(ProcessorStep):
|
||||
"shortest_image_edge": self.shortest_image_edge,
|
||||
"crop_fraction": self.crop_fraction,
|
||||
"use_albumentations": self.use_albumentations,
|
||||
"letter_box_transform": self.letter_box_transform,
|
||||
"device": self.device,
|
||||
}
|
||||
|
||||
|
||||
@@ -43,10 +43,8 @@ from lerobot.policies.groot.processor_groot import (
|
||||
GrootN17ActionDecodeStep,
|
||||
GrootN17PackInputsStep,
|
||||
GrootN17VLMEncodeStep,
|
||||
N1_7_NATIVE_ACTION_HORIZON,
|
||||
_make_relative_action_training_stats,
|
||||
_transform_n1_7_image_for_vlm_albumentations,
|
||||
_transform_n1_7_image_for_vlm_torch,
|
||||
make_groot_pre_post_processors,
|
||||
)
|
||||
from lerobot.processor import (
|
||||
@@ -82,14 +80,6 @@ def _groot_config() -> GrootConfig:
|
||||
)
|
||||
|
||||
|
||||
def _native_action_chunk(rows: list[list[float]]) -> torch.Tensor:
|
||||
chunk = torch.tensor(rows, dtype=torch.float32)
|
||||
if chunk.shape[0] >= N1_7_NATIVE_ACTION_HORIZON:
|
||||
return chunk[:N1_7_NATIVE_ACTION_HORIZON]
|
||||
tail = chunk[-1:].repeat(N1_7_NATIVE_ACTION_HORIZON - chunk.shape[0], 1)
|
||||
return torch.cat([chunk, tail], dim=0)
|
||||
|
||||
|
||||
def _raw_n1_7_libero_config(model_path) -> GrootConfig:
|
||||
input_features, output_features = _groot_features(state_dim=8, action_dim=7)
|
||||
return GrootConfig(
|
||||
@@ -246,7 +236,6 @@ def _write_raw_n1_7_libero_checkpoint(path):
|
||||
"shortest_image_edge": 256,
|
||||
"crop_fraction": 0.95,
|
||||
"use_albumentations": True,
|
||||
"letter_box_transform": False,
|
||||
"max_action_horizon": 40,
|
||||
"max_state_dim": 132,
|
||||
"max_action_dim": 132,
|
||||
@@ -611,7 +600,6 @@ def test_raw_n1_7_libero_checkpoint_processors_use_checkpoint_assets(tmp_path):
|
||||
assert vlm_encode.shortest_image_edge == 256
|
||||
assert vlm_encode.crop_fraction == 0.95
|
||||
assert vlm_encode.use_albumentations is True
|
||||
assert vlm_encode.letter_box_transform is False
|
||||
assert decode_actions.raw_stats["action"]["gripper"]["q99"] == [115.0]
|
||||
assert decode_actions.env_action_dim == 7
|
||||
assert decode_actions.use_percentiles is True
|
||||
@@ -685,7 +673,6 @@ def test_groot_n1_7_saved_processors_round_trip_checkpoint_specific_fields(tmp_p
|
||||
config_filename="policy_postprocessor.json",
|
||||
)
|
||||
pack_inputs = next(step for step in loaded_preprocessor.steps if isinstance(step, GrootN17PackInputsStep))
|
||||
vlm_encode = next(step for step in loaded_preprocessor.steps if isinstance(step, GrootN17VLMEncodeStep))
|
||||
decode_actions = next(
|
||||
step for step in loaded_postprocessor.steps if isinstance(step, GrootN17ActionDecodeStep)
|
||||
)
|
||||
@@ -694,7 +681,6 @@ def test_groot_n1_7_saved_processors_round_trip_checkpoint_specific_fields(tmp_p
|
||||
assert pack_inputs.action_horizon == 40
|
||||
assert pack_inputs.video_modality_keys == ["image", "wrist_image"]
|
||||
assert pack_inputs.clip_outliers is True
|
||||
assert vlm_encode.letter_box_transform is False
|
||||
torch.testing.assert_close(
|
||||
pack_inputs.stats[OBS_STATE]["min"],
|
||||
torch.tensor([1.0, 2.0, 3.0, 4.0, 5.0, 6.0, 7.0, 8.0]),
|
||||
@@ -1863,58 +1849,6 @@ def test_groot_n1_7_vlm_image_transform_matches_albumentations_eval_path():
|
||||
np.testing.assert_array_equal(np.asarray(transformed), expected)
|
||||
|
||||
|
||||
def test_groot_n1_7_albumentations_letterbox_is_opt_in():
|
||||
pytest.importorskip("cv2", exc_type=ImportError)
|
||||
|
||||
image = np.full((3, 5, 3), 255, dtype=np.uint8)
|
||||
|
||||
default = _transform_n1_7_image_for_vlm_albumentations(
|
||||
image,
|
||||
image_crop_size=None,
|
||||
image_target_size=[10, 10],
|
||||
shortest_image_edge=10,
|
||||
crop_fraction=None,
|
||||
)
|
||||
letterboxed = _transform_n1_7_image_for_vlm_albumentations(
|
||||
image,
|
||||
image_crop_size=None,
|
||||
image_target_size=[10, 10],
|
||||
shortest_image_edge=10,
|
||||
crop_fraction=None,
|
||||
letter_box_transform=True,
|
||||
)
|
||||
|
||||
assert default.shape == (10, 17, 3)
|
||||
assert default.min() == 255
|
||||
assert letterboxed.shape == (10, 10, 3)
|
||||
assert letterboxed.min() < 255
|
||||
|
||||
|
||||
def test_groot_n1_7_torch_letterbox_is_opt_in():
|
||||
image = torch.full((3, 3, 5), 255, dtype=torch.uint8)
|
||||
|
||||
default = _transform_n1_7_image_for_vlm_torch(
|
||||
image,
|
||||
image_crop_size=None,
|
||||
image_target_size=[10, 10],
|
||||
shortest_image_edge=10,
|
||||
crop_fraction=None,
|
||||
)
|
||||
letterboxed = _transform_n1_7_image_for_vlm_torch(
|
||||
image,
|
||||
image_crop_size=None,
|
||||
image_target_size=[10, 10],
|
||||
shortest_image_edge=10,
|
||||
crop_fraction=None,
|
||||
letter_box_transform=True,
|
||||
)
|
||||
|
||||
assert tuple(default.shape) == (3, 10, 10)
|
||||
assert int(default.min()) == 255
|
||||
assert tuple(letterboxed.shape) == (3, 10, 10)
|
||||
assert int(letterboxed.min()) < 255
|
||||
|
||||
|
||||
def test_groot_n1_7_vlm_encode_transforms_non_square_two_camera_sample_like_core_albumentations():
|
||||
cv2 = pytest.importorskip("cv2", exc_type=ImportError)
|
||||
|
||||
@@ -1985,7 +1919,6 @@ def test_groot_n1_7_vlm_encode_config_round_trips_model_name():
|
||||
shortest_image_edge=256,
|
||||
crop_fraction=0.95,
|
||||
use_albumentations=True,
|
||||
letter_box_transform=True,
|
||||
)
|
||||
|
||||
restored = GrootN17VLMEncodeStep(**step.get_config())
|
||||
@@ -1996,7 +1929,6 @@ def test_groot_n1_7_vlm_encode_config_round_trips_model_name():
|
||||
assert restored.shortest_image_edge == 256
|
||||
assert restored.crop_fraction == 0.95
|
||||
assert restored.use_albumentations is True
|
||||
assert restored.letter_box_transform is True
|
||||
|
||||
|
||||
def test_groot_n1_7_processor_uses_qwen_component_assets(monkeypatch):
|
||||
@@ -2154,7 +2086,7 @@ def test_groot_n1_7_relative_action_training_processors_save_native_grouped_stat
|
||||
samples = [
|
||||
{
|
||||
OBS_STATE: torch.tensor([10.0, 20.0, 30.0, 40.0, 50.0, 0.0]),
|
||||
ACTION: _native_action_chunk(
|
||||
ACTION: torch.tensor(
|
||||
[
|
||||
[8.0, 17.0, 26.0, 35.0, 44.0, 0.0],
|
||||
[12.0, 23.0, 34.0, 45.0, 56.0, 100.0],
|
||||
@@ -2163,7 +2095,7 @@ def test_groot_n1_7_relative_action_training_processors_save_native_grouped_stat
|
||||
},
|
||||
{
|
||||
OBS_STATE: torch.tensor([0.0, 0.0, 0.0, 0.0, 0.0, 50.0]),
|
||||
ACTION: _native_action_chunk(
|
||||
ACTION: torch.tensor(
|
||||
[
|
||||
[-1.0, -2.0, -3.0, -4.0, -5.0, 25.0],
|
||||
[1.0, 2.0, 3.0, 4.0, 5.0, 75.0],
|
||||
@@ -2190,12 +2122,10 @@ def test_groot_n1_7_relative_action_training_processors_save_native_grouped_stat
|
||||
action_names=action_names,
|
||||
preserve_action_horizon=True,
|
||||
)
|
||||
expected_relative_action_min_prefix = torch.tensor(
|
||||
[-2.0, -3.0, -4.0, -5.0, -6.0, 1.0, 2.0, 3.0, 4.0, 5.0]
|
||||
)
|
||||
expected_relative_action_max_prefix = torch.tensor(
|
||||
[-1.0, -2.0, -3.0, -4.0, -5.0, 2.0, 3.0, 4.0, 5.0, 6.0]
|
||||
)
|
||||
expected_relative_action_stats = {
|
||||
"min": torch.tensor([-2.0, -3.0, -4.0, -5.0, -6.0, 1.0, 2.0, 3.0, 4.0, 5.0, 0.0]),
|
||||
"max": torch.tensor([-1.0, -2.0, -3.0, -4.0, -5.0, 2.0, 3.0, 4.0, 5.0, 6.0, 100.0]),
|
||||
}
|
||||
|
||||
preprocessor, postprocessor = make_groot_pre_post_processors(
|
||||
config, dataset_stats=relative_dataset_stats, dataset_meta=_RelativeStatsDataset.meta
|
||||
@@ -2218,26 +2148,17 @@ def test_groot_n1_7_relative_action_training_processors_save_native_grouped_stat
|
||||
{"rep": "RELATIVE", "type": "NON_EEF", "format": "DEFAULT", "state_key": None},
|
||||
{"rep": "ABSOLUTE", "type": "NON_EEF", "format": "DEFAULT", "state_key": None},
|
||||
]
|
||||
pack_relative_min = pack_config["raw_stats"]["relative_action"]["single_arm"]["min"]
|
||||
assert pack_relative_min[:2] == [
|
||||
assert pack_config["raw_stats"]["relative_action"]["single_arm"]["min"] == [
|
||||
[-2.0, -3.0, -4.0, -5.0, -6.0],
|
||||
[1.0, 2.0, 3.0, 4.0, 5.0],
|
||||
]
|
||||
assert len(pack_relative_min) == N1_7_NATIVE_ACTION_HORIZON
|
||||
assert (
|
||||
pack_config["raw_stats"]["relative_action"]["single_arm"]["count"] == [2] * N1_7_NATIVE_ACTION_HORIZON
|
||||
)
|
||||
assert pack_config["raw_stats"]["relative_action"]["single_arm"]["count"] == [2, 2]
|
||||
assert pack_config["raw_stats"]["action"]["gripper"]["min"] == [0.0]
|
||||
assert pack_config["raw_stats"]["action"]["gripper"]["max"] == [100.0]
|
||||
|
||||
pack_state = load_file(tmp_path / pack_entry["state_file"])
|
||||
expected_flat_dim = N1_7_NATIVE_ACTION_HORIZON * 5 + 1
|
||||
assert pack_state[f"{ACTION}.min"].shape == (expected_flat_dim,)
|
||||
assert pack_state[f"{ACTION}.max"].shape == (expected_flat_dim,)
|
||||
torch.testing.assert_close(pack_state[f"{ACTION}.min"][:10], expected_relative_action_min_prefix)
|
||||
torch.testing.assert_close(pack_state[f"{ACTION}.max"][:10], expected_relative_action_max_prefix)
|
||||
assert pack_state[f"{ACTION}.min"][-1].item() == 0.0
|
||||
assert pack_state[f"{ACTION}.max"][-1].item() == 100.0
|
||||
torch.testing.assert_close(pack_state[f"{ACTION}.min"], expected_relative_action_stats["min"])
|
||||
torch.testing.assert_close(pack_state[f"{ACTION}.max"], expected_relative_action_stats["max"])
|
||||
|
||||
postprocessor_config = json.loads((tmp_path / "policy_postprocessor.json").read_text())
|
||||
assert not any(
|
||||
@@ -2250,16 +2171,11 @@ def test_groot_n1_7_relative_action_training_processors_save_native_grouped_stat
|
||||
)
|
||||
decode_config = decode_entry["config"]
|
||||
assert decode_config["use_relative_action"] is True
|
||||
decode_relative_max = decode_config["raw_stats"]["relative_action"]["single_arm"]["max"]
|
||||
assert decode_relative_max[:2] == [
|
||||
assert decode_config["raw_stats"]["relative_action"]["single_arm"]["max"] == [
|
||||
[-1.0, -2.0, -3.0, -4.0, -5.0],
|
||||
[2.0, 3.0, 4.0, 5.0, 6.0],
|
||||
]
|
||||
assert len(decode_relative_max) == N1_7_NATIVE_ACTION_HORIZON
|
||||
assert (
|
||||
decode_config["raw_stats"]["relative_action"]["single_arm"]["count"]
|
||||
== [2] * N1_7_NATIVE_ACTION_HORIZON
|
||||
)
|
||||
assert decode_config["raw_stats"]["relative_action"]["single_arm"]["count"] == [2, 2]
|
||||
assert decode_config["raw_stats"]["action"]["gripper"]["max"] == [100.0]
|
||||
|
||||
|
||||
@@ -2299,7 +2215,7 @@ def test_groot_n1_7_relative_action_processors_compute_stats_from_runtime_datase
|
||||
samples = [
|
||||
{
|
||||
OBS_STATE: torch.tensor([10.0, 20.0, 30.0, 40.0, 50.0, 0.0]),
|
||||
ACTION: _native_action_chunk(
|
||||
ACTION: torch.tensor(
|
||||
[
|
||||
[8.0, 17.0, 26.0, 35.0, 44.0, 0.0],
|
||||
[12.0, 23.0, 34.0, 45.0, 56.0, 100.0],
|
||||
@@ -2308,7 +2224,7 @@ def test_groot_n1_7_relative_action_processors_compute_stats_from_runtime_datase
|
||||
},
|
||||
{
|
||||
OBS_STATE: torch.tensor([0.0, 0.0, 0.0, 0.0, 0.0, 50.0]),
|
||||
ACTION: _native_action_chunk(
|
||||
ACTION: torch.tensor(
|
||||
[
|
||||
[-1.0, -2.0, -3.0, -4.0, -5.0, 25.0],
|
||||
[1.0, 2.0, 3.0, 4.0, 5.0, 75.0],
|
||||
@@ -2339,9 +2255,7 @@ def test_groot_n1_7_relative_action_processors_compute_stats_from_runtime_datase
|
||||
assert kwargs["root"] == runtime_meta.root
|
||||
assert kwargs["revision"] == runtime_meta.revision
|
||||
assert kwargs["download_videos"] is False
|
||||
assert kwargs["delta_timestamps"][ACTION] == [
|
||||
index / runtime_meta.fps for index in range(N1_7_NATIVE_ACTION_HORIZON)
|
||||
]
|
||||
assert kwargs["delta_timestamps"][ACTION] == [0.0, 1 / runtime_meta.fps]
|
||||
return _RelativeStatsDataset()
|
||||
|
||||
monkeypatch.setattr("lerobot.policies.groot.processor_groot.LeRobotDataset", _fake_lerobot_dataset)
|
||||
@@ -2352,15 +2266,11 @@ def test_groot_n1_7_relative_action_processors_compute_stats_from_runtime_datase
|
||||
assert not any(isinstance(step, RelativeActionsProcessorStep) for step in preprocessor.steps)
|
||||
assert isinstance(postprocessor.steps[0], GrootN17ActionDecodeStep)
|
||||
pack_step = next(step for step in preprocessor.steps if isinstance(step, GrootN17PackInputsStep))
|
||||
assert pack_step.action_horizon == N1_7_NATIVE_ACTION_HORIZON
|
||||
assert pack_step.valid_action_horizon == 2
|
||||
pack_relative_min = pack_step.raw_stats["relative_action"]["single_arm"]["min"]
|
||||
assert pack_relative_min[:2] == [
|
||||
assert pack_step.raw_stats["relative_action"]["single_arm"]["min"] == [
|
||||
[-2.0, -3.0, -4.0, -5.0, -6.0],
|
||||
[1.0, 2.0, 3.0, 4.0, 5.0],
|
||||
]
|
||||
assert len(pack_relative_min) == N1_7_NATIVE_ACTION_HORIZON
|
||||
assert pack_step.raw_stats["relative_action"]["single_arm"]["count"] == [2] * N1_7_NATIVE_ACTION_HORIZON
|
||||
assert pack_step.raw_stats["relative_action"]["single_arm"]["count"] == [2, 2]
|
||||
assert pack_step.raw_stats["action"]["gripper"]["max"] == [100.0]
|
||||
|
||||
|
||||
@@ -2405,14 +2315,14 @@ def test_groot_n1_7_generated_relative_stats_match_oss_gr00t_reference_numbers()
|
||||
}
|
||||
state_a = torch.tensor([10.0, 20.0, 30.0, 40.0, 50.0, 25.0])
|
||||
state_b = torch.tensor([0.0, -10.0, 10.0, -20.0, 20.0, 75.0])
|
||||
action_a = _native_action_chunk(
|
||||
action_a = torch.tensor(
|
||||
[
|
||||
[11.0, 22.0, 33.0, 44.0, 55.0, 20.0],
|
||||
[12.0, 24.0, 36.0, 48.0, 60.0, 80.0],
|
||||
[13.0, 26.0, 39.0, 52.0, 65.0, 90.0],
|
||||
]
|
||||
)
|
||||
action_b = _native_action_chunk(
|
||||
action_b = torch.tensor(
|
||||
[
|
||||
[-1.0, -8.0, 13.0, -16.0, 25.0, 30.0],
|
||||
[-2.0, -6.0, 16.0, -12.0, 30.0, 40.0],
|
||||
@@ -2489,13 +2399,12 @@ def test_groot_n1_7_generated_relative_stats_match_oss_gr00t_reference_numbers()
|
||||
]
|
||||
)
|
||||
|
||||
torch.testing.assert_close(torch.as_tensor(relative_dataset_stats[ACTION]["min"][:3, :5]), oss_arm_min)
|
||||
torch.testing.assert_close(torch.as_tensor(relative_dataset_stats[ACTION]["max"][:3, :5]), oss_arm_max)
|
||||
torch.testing.assert_close(torch.as_tensor(relative_dataset_stats[ACTION]["mean"][:3, :5]), oss_arm_mean)
|
||||
torch.testing.assert_close(torch.as_tensor(relative_dataset_stats[ACTION]["std"][:3, :5]), oss_arm_std)
|
||||
torch.testing.assert_close(torch.as_tensor(relative_dataset_stats[ACTION]["q01"][:3, :5]), oss_arm_q01)
|
||||
torch.testing.assert_close(torch.as_tensor(relative_dataset_stats[ACTION]["q99"][:3, :5]), oss_arm_q99)
|
||||
assert torch.as_tensor(relative_dataset_stats[ACTION]["min"]).shape[0] == N1_7_NATIVE_ACTION_HORIZON
|
||||
torch.testing.assert_close(torch.as_tensor(relative_dataset_stats[ACTION]["min"][:, :5]), oss_arm_min)
|
||||
torch.testing.assert_close(torch.as_tensor(relative_dataset_stats[ACTION]["max"][:, :5]), oss_arm_max)
|
||||
torch.testing.assert_close(torch.as_tensor(relative_dataset_stats[ACTION]["mean"][:, :5]), oss_arm_mean)
|
||||
torch.testing.assert_close(torch.as_tensor(relative_dataset_stats[ACTION]["std"][:, :5]), oss_arm_std)
|
||||
torch.testing.assert_close(torch.as_tensor(relative_dataset_stats[ACTION]["q01"][:, :5]), oss_arm_q01)
|
||||
torch.testing.assert_close(torch.as_tensor(relative_dataset_stats[ACTION]["q99"][:, :5]), oss_arm_q99)
|
||||
|
||||
preprocessor, postprocessor = make_groot_pre_post_processors(
|
||||
config,
|
||||
@@ -2506,16 +2415,16 @@ def test_groot_n1_7_generated_relative_stats_match_oss_gr00t_reference_numbers()
|
||||
decode_step = next(step for step in postprocessor.steps if isinstance(step, GrootN17ActionDecodeStep))
|
||||
|
||||
assert pack_step.use_percentiles is True
|
||||
pack_relative_min = torch.as_tensor(pack_step.raw_stats["relative_action"]["single_arm"]["min"])
|
||||
pack_relative_q99 = torch.as_tensor(pack_step.raw_stats["relative_action"]["single_arm"]["q99"])
|
||||
assert pack_relative_min.shape == (N1_7_NATIVE_ACTION_HORIZON, 5)
|
||||
assert pack_relative_q99.shape == (N1_7_NATIVE_ACTION_HORIZON, 5)
|
||||
torch.testing.assert_close(pack_relative_min[:3], oss_arm_min)
|
||||
torch.testing.assert_close(pack_relative_q99[:3], oss_arm_q99)
|
||||
assert pack_step.stats[ACTION]["min"][:15] == pytest.approx(oss_arm_min.flatten().tolist())
|
||||
assert pack_step.stats[ACTION]["max"][:15] == pytest.approx(oss_arm_max.flatten().tolist())
|
||||
assert pack_step.stats[ACTION]["min"][-1] == pytest.approx(20.0)
|
||||
assert pack_step.stats[ACTION]["max"][-1] == pytest.approx(90.0)
|
||||
torch.testing.assert_close(
|
||||
torch.as_tensor(pack_step.raw_stats["relative_action"]["single_arm"]["min"]),
|
||||
oss_arm_min,
|
||||
)
|
||||
torch.testing.assert_close(
|
||||
torch.as_tensor(pack_step.raw_stats["relative_action"]["single_arm"]["q99"]),
|
||||
oss_arm_q99,
|
||||
)
|
||||
assert pack_step.stats[ACTION]["min"] == pytest.approx([*oss_arm_min.flatten().tolist(), 20.0])
|
||||
assert pack_step.stats[ACTION]["max"] == pytest.approx([*oss_arm_max.flatten().tolist(), 90.0])
|
||||
|
||||
packed = pack_step(
|
||||
{
|
||||
@@ -2534,13 +2443,7 @@ def test_groot_n1_7_generated_relative_stats_match_oss_gr00t_reference_numbers()
|
||||
torch.testing.assert_close(packed[TransitionKey.ACTION][0, :3, :6], expected_normalized)
|
||||
|
||||
decoded = decode_step({TransitionKey.ACTION: packed[TransitionKey.ACTION]})
|
||||
assert decoded[TransitionKey.ACTION].shape == (1, N1_7_NATIVE_ACTION_HORIZON, 6)
|
||||
torch.testing.assert_close(
|
||||
decoded[TransitionKey.ACTION][:, :3],
|
||||
action_a.unsqueeze(0)[:, :3],
|
||||
atol=1e-5,
|
||||
rtol=1e-5,
|
||||
)
|
||||
torch.testing.assert_close(decoded[TransitionKey.ACTION], action_a.unsqueeze(0), atol=1e-5, rtol=1e-5)
|
||||
|
||||
|
||||
def test_groot_n1_7_relative_action_stats_skip_padded_tail_chunks():
|
||||
|
||||
@@ -1,100 +0,0 @@
|
||||
#!/usr/bin/env python
|
||||
|
||||
# Copyright 2026 The HuggingFace Inc. team. All rights reserved.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
"""Isaac-GR00T N1.7 raw-state dropout training contract.
|
||||
|
||||
Isaac-GR00T zeroes the entire proprioceptive state of a sample with probability
|
||||
``state_dropout_prob`` (configured in the checkpoint's processor sidecar) during
|
||||
training only. Baseline LeRobot kept the processor deterministic, so this
|
||||
regularization never activated. These tests pin the train/eval split.
|
||||
"""
|
||||
|
||||
import torch
|
||||
|
||||
from lerobot.policies.groot.processor_groot import GrootN17PackInputsStep
|
||||
from lerobot.types import TransitionKey
|
||||
from lerobot.utils.constants import OBS_STATE
|
||||
|
||||
|
||||
def _make_transition():
|
||||
return {
|
||||
TransitionKey.OBSERVATION: {OBS_STATE: torch.tensor([[1.0, 2.0], [3.0, 4.0]])},
|
||||
TransitionKey.COMPLEMENTARY_DATA: {"task": ["Move", "Move"]},
|
||||
}
|
||||
|
||||
|
||||
def test_groot_n1_7_training_applies_raw_state_dropout_before_encoder():
|
||||
step = GrootN17PackInputsStep(
|
||||
max_state_dim=4,
|
||||
max_action_dim=4,
|
||||
normalize_min_max=False,
|
||||
training=True,
|
||||
state_dropout_prob=1.0,
|
||||
)
|
||||
|
||||
output = step(_make_transition())
|
||||
|
||||
expected = torch.zeros(2, 1, 4)
|
||||
torch.testing.assert_close(output[TransitionKey.OBSERVATION]["state"], expected)
|
||||
|
||||
|
||||
def test_groot_n1_7_training_state_dropout_is_disabled_under_no_grad():
|
||||
step = GrootN17PackInputsStep(
|
||||
max_state_dim=4,
|
||||
max_action_dim=4,
|
||||
normalize_min_max=False,
|
||||
training=True,
|
||||
state_dropout_prob=1.0,
|
||||
)
|
||||
|
||||
with torch.no_grad():
|
||||
output = step(_make_transition())
|
||||
|
||||
expected = torch.tensor([[[1.0, 2.0, 0.0, 0.0]], [[3.0, 4.0, 0.0, 0.0]]])
|
||||
torch.testing.assert_close(output[TransitionKey.OBSERVATION]["state"], expected)
|
||||
|
||||
|
||||
def test_groot_n1_7_eval_mode_state_dropout_is_inactive():
|
||||
step = GrootN17PackInputsStep(
|
||||
max_state_dim=4,
|
||||
max_action_dim=4,
|
||||
normalize_min_max=False,
|
||||
training=False,
|
||||
state_dropout_prob=1.0,
|
||||
)
|
||||
|
||||
output = step(_make_transition())
|
||||
|
||||
expected = torch.tensor([[[1.0, 2.0, 0.0, 0.0]], [[3.0, 4.0, 0.0, 0.0]]])
|
||||
torch.testing.assert_close(output[TransitionKey.OBSERVATION]["state"], expected)
|
||||
|
||||
|
||||
def test_groot_n1_7_pack_step_serializes_dropout_prob_but_not_training_mode():
|
||||
step = GrootN17PackInputsStep(
|
||||
max_state_dim=4,
|
||||
max_action_dim=4,
|
||||
normalize_min_max=False,
|
||||
training=True,
|
||||
state_dropout_prob=0.2,
|
||||
)
|
||||
|
||||
serialized = step.get_config()
|
||||
restored = GrootN17PackInputsStep(**serialized)
|
||||
|
||||
assert "training" not in serialized
|
||||
assert serialized["state_dropout_prob"] == 0.2
|
||||
assert restored.training is False
|
||||
assert restored.state_dropout_prob == 0.2
|
||||
@@ -1,156 +0,0 @@
|
||||
#!/usr/bin/env python
|
||||
|
||||
# Copyright 2026 The HuggingFace Inc. team. All rights reserved.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
"""Isaac-GR00T N1.7 train-time random crop contract (crop geometry only).
|
||||
|
||||
Isaac-GR00T crops a random ``crop_fraction`` window during training and the
|
||||
deterministic center window at eval, replaying the sampled window across all
|
||||
camera views of a sample (gr00t/data/transform/video.py, n1.5-release onward:
|
||||
"If mode is 'train', return a random crop transform. If mode is 'eval', return
|
||||
a center crop transform."). This mirrors LeRobot's own Diffusion/VQBeT
|
||||
``crop_is_random`` pattern. Color jitter is intentionally out of scope here.
|
||||
"""
|
||||
|
||||
import random
|
||||
|
||||
import numpy as np
|
||||
import torch
|
||||
|
||||
from lerobot.policies.groot.processor_groot import (
|
||||
GrootN17VLMEncodeStep,
|
||||
_transform_n1_7_image_for_vlm_albumentations,
|
||||
)
|
||||
|
||||
|
||||
def _structured_image(h=480, w=640):
|
||||
yy, xx = np.mgrid[0:h, 0:w]
|
||||
return np.stack(
|
||||
[(xx * 255 / w), (yy * 255 / h), ((xx + yy) * 255 / (h + w))], axis=-1
|
||||
).astype(np.uint8)
|
||||
|
||||
|
||||
def test_crop_position_none_is_bitexact_center_crop():
|
||||
"""crop_position=None must remain byte-identical to the pre-change eval path."""
|
||||
img = _structured_image()
|
||||
ref = _transform_n1_7_image_for_vlm_albumentations(
|
||||
img, image_crop_size=None, image_target_size=[256, 256],
|
||||
shortest_image_edge=256, crop_fraction=0.95,
|
||||
)
|
||||
out = _transform_n1_7_image_for_vlm_albumentations(
|
||||
img, image_crop_size=None, image_target_size=[256, 256],
|
||||
shortest_image_edge=256, crop_fraction=0.95, crop_position=None,
|
||||
)
|
||||
np.testing.assert_array_equal(ref, out)
|
||||
|
||||
|
||||
def test_crop_position_center_matches_center_crop():
|
||||
img = _structured_image()
|
||||
center = _transform_n1_7_image_for_vlm_albumentations(
|
||||
img, image_crop_size=None, image_target_size=[256, 256],
|
||||
shortest_image_edge=256, crop_fraction=0.95, crop_position=None,
|
||||
)
|
||||
explicit = _transform_n1_7_image_for_vlm_albumentations(
|
||||
img, image_crop_size=None, image_target_size=[256, 256],
|
||||
shortest_image_edge=256, crop_fraction=0.95, crop_position=(0.5, 0.5),
|
||||
)
|
||||
# int-floor center vs rounded positional center may differ by <=1 px of grid
|
||||
assert center.shape == explicit.shape
|
||||
diff = np.abs(center.astype(np.int16) - explicit.astype(np.int16))
|
||||
assert diff.mean() < 3.0
|
||||
|
||||
|
||||
def test_crop_position_corners_differ_from_center():
|
||||
img = _structured_image()
|
||||
|
||||
def crop_at(position):
|
||||
return _transform_n1_7_image_for_vlm_albumentations(
|
||||
img,
|
||||
image_crop_size=None,
|
||||
image_target_size=[256, 256],
|
||||
shortest_image_edge=256,
|
||||
crop_fraction=0.95,
|
||||
crop_position=position,
|
||||
)
|
||||
|
||||
center = crop_at(None)
|
||||
tl = crop_at((0.0, 0.0))
|
||||
br = crop_at((1.0, 1.0))
|
||||
assert not np.array_equal(center, tl)
|
||||
assert not np.array_equal(tl, br)
|
||||
|
||||
|
||||
def _video(img, views=2):
|
||||
return np.stack([img] * views, axis=0).reshape(1, 1, views, *img.shape)
|
||||
|
||||
|
||||
def _step(training):
|
||||
return GrootN17VLMEncodeStep(
|
||||
image_target_size=[256, 256],
|
||||
shortest_image_edge=256,
|
||||
crop_fraction=0.95,
|
||||
use_albumentations=True,
|
||||
training=training,
|
||||
)
|
||||
|
||||
|
||||
def test_training_crop_replays_one_window_across_views():
|
||||
video = _video(_structured_image())
|
||||
frames = _step(training=True)._build_sample_images(video, batch_size=1, target_device=None)[0]
|
||||
np.testing.assert_array_equal(np.asarray(frames[0]), np.asarray(frames[1]))
|
||||
|
||||
|
||||
def test_training_crop_differs_from_eval_center_crop():
|
||||
video = _video(_structured_image())
|
||||
random.seed(3) # a draw that is not the exact center
|
||||
train_frame = np.asarray(
|
||||
_step(training=True)._build_sample_images(video, batch_size=1, target_device=None)[0][0]
|
||||
)
|
||||
eval_frame = np.asarray(
|
||||
_step(training=False)._build_sample_images(video, batch_size=1, target_device=None)[0][0]
|
||||
)
|
||||
assert not np.array_equal(train_frame, eval_frame)
|
||||
|
||||
|
||||
def test_training_crop_is_disabled_under_no_grad():
|
||||
video = _video(_structured_image())
|
||||
with torch.no_grad():
|
||||
no_grad_frame = np.asarray(
|
||||
_step(training=True)._build_sample_images(video, batch_size=1, target_device=None)[0][0]
|
||||
)
|
||||
eval_frame = np.asarray(
|
||||
_step(training=False)._build_sample_images(video, batch_size=1, target_device=None)[0][0]
|
||||
)
|
||||
np.testing.assert_array_equal(no_grad_frame, eval_frame)
|
||||
|
||||
|
||||
def test_training_mode_is_not_serialized():
|
||||
step = _step(training=True)
|
||||
serialized = step.get_config()
|
||||
assert "training" not in serialized
|
||||
restored = GrootN17VLMEncodeStep(**serialized)
|
||||
assert restored.training is False
|
||||
|
||||
|
||||
def test_training_crop_respects_global_seed():
|
||||
video = _video(_structured_image())
|
||||
|
||||
def draw():
|
||||
random.seed(11)
|
||||
return np.asarray(
|
||||
_step(training=True)._build_sample_images(video, batch_size=1, target_device=None)[0][0]
|
||||
)
|
||||
|
||||
np.testing.assert_array_equal(draw(), draw())
|
||||
@@ -1,121 +0,0 @@
|
||||
#!/usr/bin/env python
|
||||
|
||||
# Copyright 2026 The HuggingFace Inc. team. All rights reserved.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
"""Isaac-GR00T N1.7 optimizer/scheduler/precision training contract.
|
||||
|
||||
Pins the LeRobot GR00T fine-tuning recipe to the native Isaac-GR00T contract:
|
||||
AdamW(lr=1e-4, betas=(0.9, 0.999), eps=1e-8, weight_decay=1e-5, grad clip 1.0),
|
||||
HF cosine schedule with ~5% warmup over the actual update count, FP32 master
|
||||
parameters under BF16 autocast, transformers-style weight-decay grouping, the
|
||||
frozen LM-head weight tie, and episode-tail exclusion for incomplete chunks.
|
||||
"""
|
||||
|
||||
import pytest
|
||||
import torch
|
||||
|
||||
from lerobot.optim.schedulers import DiffuserSchedulerConfig
|
||||
from lerobot.policies.groot.configuration_groot import GrootConfig
|
||||
from lerobot.policies.groot.groot_n1_7 import _tie_unused_qwen_lm_head
|
||||
from lerobot.policies.groot.modeling_groot import GrootPolicy
|
||||
|
||||
|
||||
def test_groot_n1_7_optimizer_matches_isaac_training_contract():
|
||||
optimizer = GrootConfig().get_optimizer_preset()
|
||||
|
||||
assert optimizer.lr == pytest.approx(1e-4)
|
||||
assert optimizer.betas == pytest.approx((0.9, 0.999))
|
||||
assert optimizer.eps == pytest.approx(1e-8)
|
||||
assert optimizer.weight_decay == pytest.approx(1e-5)
|
||||
assert optimizer.grad_clip_norm == pytest.approx(1.0)
|
||||
|
||||
|
||||
def test_groot_n1_7_sampler_excludes_incomplete_action_tails():
|
||||
config = GrootConfig(chunk_size=16, n_action_steps=16)
|
||||
|
||||
assert len(config.action_delta_indices) == 16
|
||||
assert config.drop_n_last_frames == 15
|
||||
|
||||
|
||||
def test_groot_n1_7_scheduler_matches_isaac_hf_cosine_contract():
|
||||
config = GrootConfig(max_steps=20_000)
|
||||
scheduler_config = config.get_scheduler_preset()
|
||||
|
||||
assert isinstance(scheduler_config, DiffuserSchedulerConfig)
|
||||
assert scheduler_config.name == "cosine"
|
||||
assert scheduler_config.num_warmup_steps == 1_000
|
||||
|
||||
parameter = torch.nn.Parameter(torch.ones(()))
|
||||
optimizer = torch.optim.AdamW([parameter], lr=config.optimizer_lr)
|
||||
scheduler = scheduler_config.build(optimizer, num_training_steps=20_000)
|
||||
lr_factor = scheduler.lr_lambdas[0]
|
||||
|
||||
assert lr_factor(0) == pytest.approx(0.0)
|
||||
assert lr_factor(1_000) == pytest.approx(1.0)
|
||||
assert lr_factor(10_500) == pytest.approx(0.5)
|
||||
assert lr_factor(20_000) == pytest.approx(0.0, abs=1e-12)
|
||||
|
||||
|
||||
def test_groot_n1_7_scheduler_rounds_fractional_warmup_up_like_transformers():
|
||||
scheduler_config = GrootConfig(max_steps=777).get_scheduler_preset()
|
||||
|
||||
assert scheduler_config.num_warmup_steps == 39
|
||||
|
||||
|
||||
def test_groot_n1_7_model_parameters_use_fp32_checkpoint_and_optimizer_precision():
|
||||
module = torch.nn.Module()
|
||||
module.trainable = torch.nn.Parameter(torch.ones(3, dtype=torch.bfloat16))
|
||||
module.frozen = torch.nn.Parameter(torch.ones(3, dtype=torch.bfloat16), requires_grad=False)
|
||||
|
||||
GrootPolicy._cast_model_parameters_to_fp32(module)
|
||||
|
||||
assert module.trainable.dtype == torch.float32
|
||||
assert module.frozen.dtype == torch.float32
|
||||
|
||||
|
||||
def test_groot_n1_7_ties_unused_qwen_lm_head_to_frozen_input_embeddings():
|
||||
class DummyQwen(torch.nn.Module):
|
||||
def __init__(self):
|
||||
super().__init__()
|
||||
self.embed_tokens = torch.nn.Embedding(7, 3)
|
||||
self.lm_head = torch.nn.Linear(3, 7, bias=False)
|
||||
|
||||
def get_input_embeddings(self):
|
||||
return self.embed_tokens
|
||||
|
||||
model = DummyQwen()
|
||||
_tie_unused_qwen_lm_head(model)
|
||||
|
||||
assert model.lm_head.weight is model.embed_tokens.weight
|
||||
assert len(list(model.parameters())) == 1
|
||||
|
||||
|
||||
def test_groot_n1_7_optimizer_groups_match_transformers_weight_decay_rules():
|
||||
module = torch.nn.Module()
|
||||
module.linear = torch.nn.Linear(3, 2)
|
||||
module.norm = torch.nn.LayerNorm(2)
|
||||
module.frozen = torch.nn.Parameter(torch.ones(1), requires_grad=False)
|
||||
|
||||
groups = GrootPolicy._build_weight_decay_parameter_groups(module)
|
||||
|
||||
assert len(groups) == 2
|
||||
assert "weight_decay" not in groups[0]
|
||||
assert groups[1]["weight_decay"] == 0.0
|
||||
assert groups[0]["params"] == [module.linear.weight]
|
||||
assert {id(parameter) for parameter in groups[1]["params"]} == {
|
||||
id(module.linear.bias),
|
||||
id(module.norm.weight),
|
||||
id(module.norm.bias),
|
||||
}
|
||||
Reference in New Issue
Block a user