Compare commits

5 Commits

Author SHA1 Message Date
Steven Palma 4688b9c27f refactor(groot): N1.7 style cleanup (utils, imports, flash-attn, config)
Mechanical refactor of the GR00T N1.7 policy to match the repo's architecture and
style standards. No change to policy algorithm/numerics; only UX/CLI and packaging
changes. Tests are intentionally left untouched (out of scope) and still need to be
updated for the removed `model_version` field.

Cleanup & consolidation:
- Add `groot/utils.py` holding the pure, side-effect-free helpers (JSON I/O, value
  coercion, stat flattening, rot6d/SE3 math, language/batch prep) shared by the
  config and processor layers.
- Remove dead code: the unused `resolve_groot_n1_7_backbone_model` cache-resolver
  cluster, `GR00TN17Config.to_filtered_dict/json`, and the `_copy_default` wrapper.
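
A minimal sketch of what the shared JSON helper could look like, for readers who want
to see the shape of the consolidation. Returning an empty dict on a missing or
malformed file is an assumption inferred from the call sites in the diff, which call
`.get(...)` directly on the result; the real `read_json` in `groot/utils.py` may differ.

```python
from __future__ import annotations

import json
from pathlib import Path
from typing import Any


def read_json(path: str | Path) -> dict[str, Any]:
    """Read a JSON object from `path`; return {} if it is unreadable (assumed behaviour)."""
    try:
        with Path(path).expanduser().open() as f:
            data = json.load(f)
    except (OSError, json.JSONDecodeError):
        # Missing file, permission problem, or malformed JSON: treat as "no config".
        return {}
    # Callers expect a mapping, so anything else is also treated as "no config".
    return data if isinstance(data, dict) else {}
```

With this shape, the repeated `try/open/json.load/except` blocks in the config layer
collapse to a single call, and `config.get(...)` stays safe on every path.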

Imports & execution guards:
- Hoist nested imports to module top; use relative imports within the package and
  absolute imports for external modules. The version-gated Qwen3-VL classes are now
  imported under the single `_transformers_available` guard (transformers is pinned
  >=5.4, which ships them).
- No import-time side effects: `_register_with_transformers()` now runs in
  `GR00TN17.__init__` (idempotent via `register(exist_ok=True)`), and the N1.5 step
  stubs register lazily before pipeline deserialization (idempotent via the
  registry, no run-once globals).
- Gate optional deps at the point of use with `require_package(..., extra="groot")`.
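
The "register at construction, idempotent via `exist_ok`" pattern can be sketched with
a toy registry. `Registry`, `REGISTRY`, and `ToyModel` are hypothetical stand-ins for
the transformers Auto* factories and `GR00TN17`; only the `exist_ok=True` idea comes
from the commit.

```python
from __future__ import annotations


class Registry:
    """Toy stand-in for an Auto*-style factory registry (hypothetical)."""

    def __init__(self) -> None:
        self.entries: dict[str, type] = {}

    def register(self, key: str, cls: type, exist_ok: bool = False) -> None:
        if key in self.entries and not exist_ok:
            raise ValueError(f"'{key}' is already registered")
        self.entries[key] = cls


REGISTRY = Registry()


class ToyModel:
    def __init__(self) -> None:
        # Registration happens on first construction, not at import time.
        # exist_ok=True makes every later construction a harmless no-op,
        # so no module-level "already registered" flag is needed.
        REGISTRY.register("toy_model", ToyModel, exist_ok=True)
```

Importing the module leaves the registry empty; constructing the model any number of
times leaves exactly one entry.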

Dependencies & docs:
- Drop `flash-attn` (and its build-only dep `ninja`) from the `groot` extra; default
  to SDPA (numerically equivalent) with opt-in via `--policy.use_flash_attention`.
  Un-comment `lerobot[groot]` in the `all` extra and regenerate `uv.lock`.
- Rewrite the `groot.mdx` install section: flash-attn is a purely optional,
  user-managed optimization that LeRobot neither installs nor requires.
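
A hedged sketch of the opt-in with SDPA fallback. The function name and the
`module_name` parameter are hypothetical (the latter exists only to make the fallback
easy to exercise); `"flash_attention_2"` and `"sdpa"` are the attention implementation
names used by transformers. The actual selection logic in the backbone may differ.

```python
import importlib


def select_attn_implementation(use_flash_attention: bool, module_name: str = "flash_attn") -> str:
    """Return "flash_attention_2" only when requested and importable, else "sdpa"."""
    if not use_flash_attention:
        return "sdpa"
    try:
        # A broken build can fail with errors other than ImportError
        # (e.g. a CUDA/ABI mismatch), so catch broadly and fall back.
        importlib.import_module(module_name)
    except Exception:
        return "sdpa"
    return "flash_attention_2"
```

The default stays on SDPA, so an environment without `flash-attn` behaves the same
whether or not the flag is passed.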

Config & CLI:
- Surface previously-frozen knobs on `GrootConfig` (plumbed into `GR00TN17Config`;
  no-ops at their defaults): inference — `num_inference_timesteps`, `rtc_ramp_rate`,
  `use_flash_attention`; fine-tuning — `tune_top_llm_layers` (partial-LLM tuning)
  and `tune_vlln` (previously hardwired to True).
- Convert the single-valued `model_version` and `n1_7_backbone_model` fields to
  internal constants.
- Keep `base_model_path`: it is NOT equivalent to `pretrained_path` (raw NVIDIA
  checkpoints have no LeRobot `type` field and load only via `base_model_path`) and
  is genuinely user-tunable.
- Keep the deprecated Isaac-GR00T/N1.5 fields (and the dead LoRA fields) as a
  back-compat block so a v0.5.1 N1.5 `config.json` still parses under draccus and is
  rejected with the friendly N1.5 removal message instead of an opaque decode error.
2026-06-16 14:45:37 +02:00
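
The back-compat block described under "Config & CLI" in 4688b9c27f can be illustrated
with a small stand-alone sketch. Everything here is hypothetical: `ToyGrootConfig`,
`strict_parse`, and the message text are stand-ins, `strict_parse` only mimics a strict
decoder such as draccus, and treating a set `tokenizer_assets_repo` as the N1.5 signal
is an assumption based on the "N1.5 tripwire" comment in the diff.

```python
from __future__ import annotations

from dataclasses import dataclass, fields

# Hypothetical wording; the real guidance lives in GROOT_N1_5_REMOVAL_GUIDANCE.
N1_5_REMOVAL_MESSAGE = "GR00T N1.5 support was removed; pin lerobot==0.5.1 or migrate to N1.7."


@dataclass
class ToyGrootConfig:
    base_model_path: str | None = None
    # Deprecated fields: unused, kept only so older config.json files still parse.
    tokenizer_assets_repo: str | None = None
    lora_rank: int = 0

    def __post_init__(self) -> None:
        if self.tokenizer_assets_repo is not None:
            raise ValueError(N1_5_REMOVAL_MESSAGE)


def strict_parse(raw: dict) -> ToyGrootConfig:
    """Reject unknown keys, mimicking a strict config decoder."""
    known = {f.name for f in fields(ToyGrootConfig)}
    unknown = sorted(set(raw) - known)
    if unknown:
        raise TypeError(f"unknown fields: {unknown}")
    return ToyGrootConfig(**raw)
```

Because the deprecated fields still exist, an old config reaches `__post_init__` and
gets the migration message; had they been deleted, the same file would fail earlier
with the decoder's generic unknown-field error.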
Steven Palma 5753f8c18b fix(groot): GPU/tensor N1.7 image preprocessing + resize to trained resolution
GR00T training was dataloader-bound (0->100->0 GPU-utilization sawtooth).
GrootN17VLMEncodeStep ran the Qwen3-VL image processor per frame on PIL images
on the single CPU main-loop thread, and that cost is timed inside dataloading_s
(preprocessor(batch) runs in the main process, not the dataloader workers), so
adding workers cannot hide it.

- Feed the torchvision-backed Qwen3-VL processor (C,H,W) uint8 tensors instead
  of a per-frame Image.fromarray PIL roundtrip, and run resize/normalize/patchify
  on config.device (GPU) when available. Bit-identical on CPU when no resize is
  configured; with a resize only the PIL->torchvision bicubic backend differs
  (<2/255 per pixel). The use_albumentations path stays PIL/cv2; reloading on a
  machine that lacks the saved device falls back to CPU.

- Default image_target_size/crop to the N1.7 backbone's training geometry
  (256x256 / 230x230) when a checkpoint ships no image sizing (checkpoint_assets
  is None, e.g. finetuning nvidia/GR00T-N1.7-3B via repo-id with a new
  embodiment). Previously image_target_size=None disabled the resize, so
  full-resolution frames were patchified into ~4.7x more vision tokens than the
  model was trained on -- inflating dataloading_s (patchify) and update_s (VLM
  sequence) and skewing the input distribution. Checkpoints that pin their own
  sizing are honored; the default constants are shared with GR00T_N1_7_DEFAULTS.
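
The "~4.7x more vision tokens" figure can be reproduced with simple arithmetic. At a
fixed patch size the token count scales with pixel area, so the area ratio is a
reasonable proxy. The 640x480 camera resolution below is an assumption (the commit
message does not state it), chosen because it yields a ratio consistent with the
quoted figure; `vision_token_ratio` is a hypothetical helper.

```python
def vision_token_ratio(full_hw: tuple, target_hw: tuple) -> float:
    """Pixel-area ratio, a proxy for the vision-token ratio at a fixed patch size."""
    full_h, full_w = full_hw
    target_h, target_w = target_hw
    return (full_h * full_w) / (target_h * target_w)


# Assumed 480x640 camera frame versus the 256x256 default target size.
ratio = vision_token_ratio((480, 640), (256, 256))
print(f"{ratio:.4f}")  # 4.6875
```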

Net: preprocessing leaves the CPU critical path and the VLM sees the resolution
it was trained on, giving faster training/inference and a consistent train/serve
distribution. Affects inference too (shared preprocessor); existing checkpoints
still load (backward compatible) but must be retrained to gain the benefits.
2026-06-15 18:20:49 +02:00
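
The fallback to the training geometry described in 5753f8c18b can be sketched as
follows. The two constants and the `image_target_size` / `image_crop_size` keys appear
in the diff; `resolve_image_sizing` and the dict-shaped `checkpoint_assets` argument
are hypothetical simplifications of the processor logic.

```python
from __future__ import annotations

N1_7_DEFAULT_IMAGE_TARGET_SIZE = (256, 256)
N1_7_DEFAULT_IMAGE_CROP_SIZE = (230, 230)


def resolve_image_sizing(checkpoint_assets: dict | None) -> tuple[tuple, tuple]:
    """Honor sizing pinned by a checkpoint; otherwise use the N1.7 training geometry."""
    assets = checkpoint_assets or {}
    target = assets.get("image_target_size") or N1_7_DEFAULT_IMAGE_TARGET_SIZE
    crop = assets.get("image_crop_size") or N1_7_DEFAULT_IMAGE_CROP_SIZE
    # Normalize lists loaded from JSON into tuples.
    return tuple(target), tuple(crop)
```

A checkpoint that ships no sizing (for example, `checkpoint_assets is None`) now
resolves to 256x256 / 230x230 instead of disabling the resize.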
Kartik 97bd373d15 Merge pull request #15 from huggingface/fix/groot_n17_core
fix(groot): N1.7 config defaults, N1.5 rejection, and processor/model runtime fixes
2026-06-13 23:05:51 +02:00
Kartik 10a73e3c95 Merge pull request #14 from huggingface/fix/groot_n17_backbone
fix(groot): N1.7 backbone loading and DiT parameter-count logging
2026-06-13 21:47:35 +02:00
Kartik 27c9288b24 Merge pull request #13 from huggingface/fix/groot_n17_docs
docs(groot): document the N1.5 removal and the N1.7 parity test
2026-06-13 21:47:05 +02:00
9 changed files with 510 additions and 1039 deletions
+20 -24
@@ -5,7 +5,7 @@ GR00T is an NVIDIA foundation model family for generalized humanoid robot reason
LeRobot integrates GR00T N1.7 through the `groot` policy type.
> [!WARNING]
> **Breaking change:** GR00T N1.5 support was removed from LeRobot, and current releases support GR00T N1.7 only. N1.5 checkpoints, configs, and `--policy.model_version=n1.5` are rejected with a clear error. To keep using an N1.5 checkpoint, pin the last release that supports it: `pip install 'lerobot==0.5.1'`. To use the current release, migrate to GR00T N1.7 (`model_version='n1.7'`, base model [`nvidia/GR00T-N1.7-3B`](https://huggingface.co/nvidia/GR00T-N1.7-3B)).
> **Breaking change:** GR00T N1.5 support was removed from LeRobot, and current releases support GR00T N1.7 only. N1.5 checkpoints and configs are rejected with a migration note. To keep using an N1.5 checkpoint, pin the last release that supports it: `pip install 'lerobot==0.5.1'`. To use the current release, migrate to GR00T N1.7 (base model [`nvidia/GR00T-N1.7-3B`](https://huggingface.co/nvidia/GR00T-N1.7-3B)).
## Model Overview
@@ -31,46 +31,43 @@ This approach allows the model to be highly adaptable through post-training for
## Installation Requirements
GR00T is intended for NVIDIA GPU-accelerated systems. The `groot` extra still includes Flash Attention on non-macOS platforms, and Flash Attention needs a compatible PyTorch/CUDA environment before it is installed. Install the dependencies in this order:
GR00T is intended for NVIDIA GPU-accelerated systems. Install LeRobot with the GR00T extra:
1. Follow the Environment Setup in the [Installation Guide](./installation). Do not install `lerobot` yet.
2. Install PyTorch, TorchVision, and the build dependencies used by Flash Attention:
```bash
pip install "lerobot[groot]"
```
For a source checkout:
```bash
pip install -e ".[groot]"
```
### Optional: Flash Attention acceleration
Flash Attention is a purely optional performance optimization. **LeRobot neither installs nor requires it**, and setting it up is up to the user as it has environment-specific build requirements (a matching PyTorch/CUDA toolchain). To enable it:
1. Install a `flash-attn` build matching your PyTorch/CUDA environment (see the [Flash Attention project](https://github.com/Dao-AILab/flash-attention)):
```bash
# Check https://pytorch.org/get-started/locally/ for the right CUDA wheel index for your system.
pip install "torch>=2.7,<2.12.0" "torchvision>=0.22.0,<0.27.0" \
--index-url https://download.pytorch.org/whl/cu128
pip install "ninja>=1.11.1,<2.0.0" "packaging>=24.2,<26.0"
```
3. Install and verify Flash Attention:
```bash
pip install "flash-attn>=2.5.9,<3.0.0" --no-build-isolation
python -c "import flash_attn; print(f'Flash Attention {flash_attn.__version__} imported successfully')"
```
4. Install LeRobot with the GR00T extra:
2. Install lerobot with the groot extra.
```bash
pip install "lerobot[groot]"
```
For a source checkout, use the same order, then install the local package with:
```bash
pip install -e ".[groot]"
```
If your CUDA/PyTorch build needs a different Flash Attention wheel or source build, follow the [Flash Attention project](https://github.com/Dao-AILab/flash-attention) instructions, but keep the same ordering: PyTorch first, Flash Attention next, then `lerobot[groot]`.
3. Opt in by passing `--policy.use_flash_attention=true` when training/evaluating GR00T. If the kernel is missing or fails to import, the backbone transparently falls back to SDPA.
## Usage
To use GR00T N1.7:
```bash
--policy.type=groot \
--policy.model_version=n1.7
--policy.type=groot
```
## Training
@@ -142,7 +139,6 @@ hf download nvidia/GR00T-N1.7-LIBERO \
lerobot-eval \
--policy.type=groot \
--policy.model_version=n1.7 \
--policy.base_model_path=./GR00T-N1.7-LIBERO/libero_spatial \
--policy.embodiment_tag=libero_sim \
--env.type=libero \
+1 -3
@@ -208,8 +208,6 @@ groot = [
"dm-tree>=0.1.8,<1.0.0",
"timm>=1.0.0,<1.1.0",
"decord>=0.6.0,<1.0.0; (platform_machine == 'AMD64' or platform_machine == 'x86_64')",
"ninja>=1.11.1,<2.0.0",
"flash-attn>=2.5.9,<3.0.0 ; sys_platform != 'darwin'"
]
sarm = ["lerobot[transformers-dep]", "pydantic>=2.0.0,<3.0.0", "faker>=33.0.0,<35.0.0", "lerobot[matplotlib-dep]", "lerobot[qwen-vl-utils-dep]"]
robometer = ["lerobot[transformers-dep]", "lerobot[qwen-vl-utils-dep]", "lerobot[peft-dep]"]
@@ -280,7 +278,7 @@ all = [
"lerobot[pi]",
"lerobot[molmoact2]",
"lerobot[smolvla]",
# "lerobot[groot]", TODO(Steven): Gr00t requires specific installation instructions for flash-attn
"lerobot[groot]",
"lerobot[xvla]",
"lerobot[hilserl]",
"lerobot[vla_jepa]",
+48 -172
@@ -14,9 +14,7 @@
# See the License for the specific language governing permissions and
# limitations under the License.
import json
import logging
import os
from dataclasses import dataclass, field
from pathlib import Path
@@ -24,6 +22,8 @@ from lerobot.configs import FeatureType, NormalizationMode, PolicyFeature, PreTr
from lerobot.optim import AdamWConfig, CosineDecayWithWarmupSchedulerConfig
from lerobot.utils.constants import ACTION, OBS_STATE
from .utils import read_json
logger = logging.getLogger(__name__)
GROOT_N1_7 = "n1.7"
@@ -42,12 +42,8 @@ GROOT_N1_5_REMOVAL_GUIDANCE = (
)
GROOT_N1_7_BASE_MODEL = "nvidia/GR00T-N1.7-3B"
GROOT_N1_7_BACKBONE_MODEL = "nvidia/Cosmos-Reason2-2B"
# Image preprocessing geometry the GR00T N1.7 backbone was trained on. The processor
# falls back to these when a checkpoint ships no image sizing in its processor_config
# (e.g. fine-tuning the raw nvidia/GR00T-N1.7-3B base with a new embodiment), so frames
# are resized to the expected resolution instead of being patchified at full camera
# resolution (which both slows training and is a train/checkpoint distribution mismatch).
# Mirrored by GR00T_N1_7_DEFAULTS in groot_n1_7.py.
# Default GR00T N1.7 training resolution. Fallback if processor_config lacks sizing. Prevents mismatched
# full-res patchification by forcing a resize. Mirrored by GR00T_N1_7_DEFAULTS in groot_n1_7.py.
N1_7_DEFAULT_IMAGE_TARGET_SIZE = (256, 256)
N1_7_DEFAULT_IMAGE_CROP_SIZE = (230, 230)
GROOT_ACTION_DECODE_TRANSFORM_LIBERO = "libero"
@@ -131,12 +127,7 @@ def is_raw_groot_n1_7_checkpoint(model_path: str | Path | None) -> bool:
else:
return False
try:
with config_path.open() as f:
config = json.load(f)
except (OSError, json.JSONDecodeError):
return False
config = read_json(config_path)
return "type" not in config and _infer_groot_model_version_from_config(config) == GROOT_N1_7
@@ -145,11 +136,7 @@ def infer_groot_n1_7_embodiment_tag(model_path: str | Path | None) -> str | None
return None
processor_config_path = Path(model_path).expanduser() / "processor_config.json"
try:
with processor_config_path.open() as f:
processor_config = json.load(f)
except (OSError, json.JSONDecodeError):
return None
processor_config = read_json(processor_config_path)
modality_configs = processor_config.get("processor_kwargs", {}).get("modality_configs", {})
if not isinstance(modality_configs, dict):
@@ -168,11 +155,7 @@ def infer_groot_n1_7_action_horizon(
return None
processor_config_path = Path(model_path).expanduser() / "processor_config.json"
try:
with processor_config_path.open() as f:
processor_config = json.load(f)
except (OSError, json.JSONDecodeError):
return None
processor_config = read_json(processor_config_path)
processor_kwargs = processor_config.get("processor_kwargs", {})
if not isinstance(processor_kwargs, dict):
@@ -214,83 +197,6 @@ def infer_groot_n1_7_action_execution_horizon(
return action_horizon
def resolve_groot_n1_7_backbone_model(model_name: str, cache_dir: str | Path | None = None) -> str:
model_path = Path(model_name).expanduser()
if model_path.exists():
return str(model_path)
cached_snapshot = _find_cached_hf_snapshot(model_name, cache_dir=cache_dir)
return str(cached_snapshot) if cached_snapshot is not None else model_name
def _find_cached_hf_snapshot(repo_id: str, cache_dir: str | Path | None = None) -> Path | None:
repo_cache_name = f"models--{repo_id.replace('/', '--')}"
required_files = (
"config.json",
"tokenizer_config.json",
"preprocessor_config.json",
"video_preprocessor_config.json",
)
for hub_cache in _candidate_hf_hub_caches(cache_dir):
repo_cache = hub_cache / repo_cache_name
snapshots_dir = repo_cache / "snapshots"
if not snapshots_dir.is_dir():
continue
candidates: list[Path] = []
ref_path = repo_cache / "refs" / "main"
try:
ref = ref_path.read_text().strip()
except OSError:
ref = ""
if ref:
candidates.append(snapshots_dir / ref)
candidates.extend(
sorted(
(path for path in snapshots_dir.iterdir() if path.is_dir()),
key=lambda path: path.stat().st_mtime,
reverse=True,
)
)
seen: set[Path] = set()
for snapshot in candidates:
if snapshot in seen:
continue
seen.add(snapshot)
if all((snapshot / filename).exists() for filename in required_files):
return snapshot
return None
def _candidate_hf_hub_caches(cache_dir: str | Path | None) -> list[Path]:
candidates: list[Path] = []
if cache_dir is not None:
cache_path = Path(cache_dir).expanduser()
candidates.append(cache_path)
candidates.append(cache_path / "hub")
hub_cache = os.environ.get("HUGGINGFACE_HUB_CACHE")
if hub_cache:
candidates.append(Path(hub_cache).expanduser())
hf_home = os.environ.get("HF_HOME")
if hf_home:
candidates.append(Path(hf_home).expanduser() / "hub")
candidates.append(Path.home() / ".cache" / "huggingface" / "hub")
deduped: list[Path] = []
seen: set[Path] = set()
for candidate in candidates:
resolved = candidate.resolve() if candidate.exists() else candidate
if resolved not in seen:
seen.add(resolved)
deduped.append(candidate)
return deduped
def _infer_groot_model_version_from_local_config(model_path: str) -> str | None:
path = Path(model_path).expanduser()
if path.is_dir():
@@ -300,16 +206,7 @@ def _infer_groot_model_version_from_local_config(model_path: str) -> str | None:
else:
return None
if not config_path.exists():
return None
try:
with config_path.open() as f:
config = json.load(f)
except (OSError, json.JSONDecodeError):
return None
return _infer_groot_model_version_from_config(config)
return _infer_groot_model_version_from_config(read_json(config_path))
def _infer_groot_model_version_from_config(config: dict) -> str | None:
@@ -372,15 +269,15 @@ class GrootConfig(PreTrainedConfig):
# Groot-specific model parameters
# Explicit GR00T model family selection. LeRobot supports GR00T N1.7 only.
model_version: str = GROOT_N1_7
# Path or HuggingFace model ID for the base Groot model
# Path or HuggingFace model ID for the base GR00T N1.7 model whose backbone weights and
# checkpoint sidecars (statistics.json, processor_config.json, ...) are loaded. This is the
# model *source*, and is intentionally distinct from the inherited `pretrained_path`:
# `pretrained_path` (`--policy.path`) points at a saved LeRobot checkpoint directory whose
# `config.json` carries a `type` field, whereas a raw NVIDIA GR00T checkpoint has no such
# field and so can only be loaded through `base_model_path` (`--policy.base_model_path`).
# Defaults to GROOT_N1_7_BASE_MODEL when unset (resolved in __post_init__).
base_model_path: str | None = None
# HF repo ID (or local path) for the GR00T N1.7 Cosmos/Qwen3-VL backbone processor.
n1_7_backbone_model: str = GROOT_N1_7_BACKBONE_MODEL
# Optional named action transform applied after raw N1.7 checkpoint decoding and before env.step().
# 'auto' (default) resolves to the embodiment default ('libero' for 'libero_sim', otherwise no
# transform). Pass 'none' to explicitly disable the transform, including for 'libero_sim'.
@@ -389,40 +286,6 @@ class GrootConfig(PreTrainedConfig):
# Embodiment tag to use for training (e.g. 'new_embodiment', 'gr1')
embodiment_tag: str = "new_embodiment"
# Inference-only override for the number of flow-matching denoising steps used to decode an
# action chunk. None = use the model checkpoint default (currently 4). Higher values trade
# inference speed for action quality; applied at base-model load via _create_groot_model.
num_inference_timesteps: int | None = None
# If set, caps the number of open-loop actions executed before replanning (inference cadence).
# Overrides the value inferred from the checkpoint/embodiment in _resolve_action_queue_steps.
execution_horizon: int | None = None
# Opt-in. Copy a pretrained embodiment category slot's action-head weights into the target
# embodiment slot at base-model build (in _create_groot_model), to warm-start a cold
# 'new_embodiment' slot. Accepts an embodiment name (e.g.
# 'oxe_droid_relative_eef_relative_joint') or an int embodiment id. Runs on every fresh
# base-model build (so it applies during lerobot-train, which uses __init__ not
# from_pretrained); on a fine-tuned checkpoint reload it is harmlessly overwritten.
warm_start_embodiment_slot: int | str | None = None
# Opt-in relative-action support for the 'new_embodiment' slot (sync-safe, GR00T-native).
# When True, GR00T converts absolute->relative inside its own pack step (training) and
# reconstructs absolute inside its own flat decode step (inference), using a cached
# reference state. The dataset stays absolute; compute relative ACTION stats with
# `lerobot-edit-dataset --operation.relative_action true --operation.relative_exclude_joints
# "['gripper']"` (this only rewrites stats, not actions).
use_relative_actions: bool = False
# Joint names kept absolute (not converted to relative) when use_relative_actions is True.
# Case-insensitive token match against action_feature_names.
relative_exclude_joints: list[str] = field(default_factory=lambda: ["gripper"])
# Action dimension names from dataset metadata; auto-populated by the factory from dataset
# meta (see factory.py:528). Used to build the relative-action mask so the gripper can be
# identified and kept absolute. When None, the gripper cannot be identified.
action_feature_names: list[str] | None = None
# Fine-tuning control arguments
# Whether to fine-tune the llm backbone
@@ -437,20 +300,31 @@ class GrootConfig(PreTrainedConfig):
# Whether to fine-tune the diffusion model
tune_diffusion_model: bool = True
# LoRA parameters (from groot_finetune_script.py)
# Rank for the LORA model. If 0, no LORA will be used.
lora_rank: int = 0
# Whether to fine-tune the VL LayerNorm + VL self-attention projector in the action head.
tune_vlln: bool = True
# Alpha value for the LORA model
lora_alpha: int = 16
# Number of top LLM backbone layers to fine-tune (0 = none). Lets you adapt just the final
# language layers without unfreezing the whole backbone; independent of `tune_llm`, which tunes
# the entire LLM.
tune_top_llm_layers: int = 0
# Dropout rate for the LORA model
lora_dropout: float = 0.1
# Inference-time knob: Number of flow-matching denoising steps used to decode an action chunk.
# Trades inference latency for action quality.
# None keeps the checkpoint value (GR00T N1.7 default: 4).
num_inference_timesteps: int | None = None
# Whether to use the full model for LORA
lora_full_model: bool = False
# Inference-time knob: Real-Time Chunking (RTC) overlap-blend ramp rate, used when the RTC engine
# supplies a previous-chunk prefix. Higher values blend the overlapping prefix more aggressively.
# None keeps the checkpoint value (GR00T N1.7 default: 6.0).
rtc_ramp_rate: float | None = None
# Training parameters (matching groot_finetune_script.py)
# Inference-time knob: Whether to request the flash-attention-2 kernel for the Qwen3-VL backbone.
# flash-attn is an optional, user-managed optimization; when it is absent (the default),
# the backbone transparently falls back to SDPA, which is numerically equivalent.
# Set to True only after installing a flash-attn build matching your torch/CUDA env.
use_flash_attention: bool = False
# Training parameters
optimizer_lr: float = 1e-4
optimizer_betas: tuple[float, float] = (0.95, 0.999)
optimizer_eps: float = 1e-8
@@ -459,12 +333,18 @@ class GrootConfig(PreTrainedConfig):
use_bf16: bool = True
# TODO(Steven): Remove these deprecated fields in a future release.
# Deprecated Isaac-GR00T runner/N1.5 fields below — unused by the LeRobot N1.7 implementation
# (nothing in src/lerobot reads them). They are kept only so config.json files saved by
# earlier lerobot releases still parse: draccus rejects unknown fields, so removing them
# would break every previously saved groot checkpoint at config-load time.
# Deprecated Isaac-GR00T runner / GR00T N1.5 fields, plus the (never-wired) LoRA fields — all
# unused by the LeRobot N1.7 implementation except the `tokenizer_assets_repo` N1.5 tripwire and
# the `image_size` legacy remap in __post_init__. They are kept ONLY so a config.json saved by an
# earlier lerobot release (notably a GR00T N1.5 checkpoint) still parses under draccus — which
# rejects unknown fields — and is then rejected with a clear N1.5 removal message rather than an
# opaque draccus decoding error.
image_size: tuple[int, int] = (256, 256) # image sizing is handled by the backbone's image processor.
tokenizer_assets_repo: str | None = None
lora_rank: int = 0
lora_alpha: int = 16
lora_dropout: float = 0.1
lora_full_model: bool = False
video_backend: str = "decord"
balance_dataset_weights: bool = True
balance_trajectory_weights: bool = True
@@ -484,7 +364,6 @@ class GrootConfig(PreTrainedConfig):
f"like a legacy GR00T N1.5 checkpoint or config. {GROOT_N1_5_REMOVAL_GUIDANCE}"
)
self.model_version = normalize_groot_model_version(self.model_version)
self.action_decode_transform = normalize_groot_action_decode_transform(self.action_decode_transform)
if self.base_model_path is None:
self.base_model_path = GROOT_N1_7_BASE_MODEL
@@ -529,9 +408,9 @@ class GrootConfig(PreTrainedConfig):
setattr(self, field_name, n1_7_value)
inferred_version = infer_groot_model_version(self.base_model_path)
if inferred_version is not None and inferred_version != self.model_version:
if inferred_version is not None and inferred_version != GROOT_N1_7:
message = (
f"GR00T model_version '{self.model_version}' does not match base_model_path "
f"GR00T model_version '{GROOT_N1_7}' does not match base_model_path "
f"'{self.base_model_path}', which looks like '{inferred_version}'."
)
if inferred_version == GROOT_N1_5:
@@ -545,9 +424,6 @@ class GrootConfig(PreTrainedConfig):
f"n_action_steps ({self.n_action_steps}) cannot exceed chunk_size ({self.chunk_size})"
)
# groot_repo_path is now optional since we ported the components
# No validation needed
def validate_features(self) -> None:
"""Validate and set up input/output features for Groot."""
image_features = [key for key, feat in self.input_features.items() if feat.type == FeatureType.VISUAL]
+20 -53
@@ -16,7 +16,6 @@
from __future__ import annotations
import importlib
import json
import logging
from contextlib import suppress
from copy import deepcopy
@@ -35,7 +34,14 @@ from .action_head.cross_attention_dit import AlternateVLDiT, DiT, SelfAttentionT
from .configuration_groot import N1_7_DEFAULT_IMAGE_CROP_SIZE, N1_7_DEFAULT_IMAGE_TARGET_SIZE
if TYPE_CHECKING or _transformers_available:
from transformers import AutoConfig, AutoModel, PretrainedConfig, PreTrainedModel
from transformers import (
AutoConfig,
AutoModel,
PretrainedConfig,
PreTrainedModel,
Qwen3VLConfig,
Qwen3VLForConditionalGeneration,
)
from transformers.feature_extraction_utils import BatchFeature
else:
AutoConfig = None
@@ -43,25 +49,17 @@ else:
PretrainedConfig = object
PreTrainedModel = object
BatchFeature = None
Qwen3VLConfig = None
Qwen3VLForConditionalGeneration = None
try:
import tree
except ImportError:
tree = None
try:
from transformers import Qwen3VLConfig, Qwen3VLForConditionalGeneration
except ImportError:
Qwen3VLConfig = None
Qwen3VLForConditionalGeneration = None
logger = logging.getLogger(__name__)
def _copy_default(value: Any) -> Any:
return deepcopy(value)
GR00T_N1_7_DEFAULTS: dict[str, Any] = {
"model_dtype": "bfloat16",
"dtype": "bfloat16",
@@ -74,7 +72,7 @@ GR00T_N1_7_DEFAULTS: dict[str, Any] = {
"tune_visual": False,
"select_layer": 16,
"reproject_vision": False,
"use_flash_attention": True,
"use_flash_attention": False,
"load_bf16": False,
"backbone_trainable_params_fp32": True,
"image_crop_size": N1_7_DEFAULT_IMAGE_CROP_SIZE,
@@ -152,29 +150,10 @@ class GR00TN17Config(PretrainedConfig):
def __init__(self, **kwargs):
super().__init__(**kwargs)
for key, value in GR00T_N1_7_DEFAULTS.items():
setattr(self, key, _copy_default(kwargs.pop(key, value)))
setattr(self, key, deepcopy(kwargs.pop(key, value)))
for key, value in kwargs.items():
setattr(self, key, value)
def to_filtered_dict(self, exclude_augment: bool = True) -> dict[str, Any]:
cfg = self.to_dict()
if not exclude_augment:
return cfg
exclude_keys = {
"random_rotation_angle",
"color_jitter_params",
"use_albumentations_transforms",
"formalize_language",
"image_crop_size",
"image_target_size",
"shortest_image_edge",
"crop_fraction",
}
return {k: v for k, v in cfg.items() if k not in exclude_keys}
def to_filtered_json(self, exclude_augment: bool = True, **kwargs) -> str:
return json.dumps(self.to_filtered_dict(exclude_augment), indent=2, default=str, **kwargs)
class CategorySpecificLinear(nn.Module):
"""Linear layer with category-specific weights for multi-embodiment support."""
@@ -275,13 +254,7 @@ class Qwen3Backbone(nn.Module):
transformers_loading_kwargs: dict[str, Any] | None = None,
load_pretrained_weights: bool = True,
):
if Qwen3VLForConditionalGeneration is None:
raise ImportError(
"Qwen3VLForConditionalGeneration is required for GR00T N1.7. "
"Install the GR00T optional dependencies with `pip install 'lerobot[groot]'` "
"or use a transformers version that provides Qwen3-VL support."
)
require_package("transformers", extra="groot")
super().__init__()
transformers_loading_kwargs = transformers_loading_kwargs or {"trust_remote_code": True}
@@ -361,11 +334,6 @@ class Qwen3Backbone(nn.Module):
if _is_cosmos_reason2_backbone(model_name):
backbone_config = _cosmos_reason2_qwen3_vl_config()
else:
if AutoConfig is None:
raise ImportError(
"AutoConfig is required to initialize a GR00T N1.7 backbone from config. "
"Install the GR00T optional dependencies with `pip install 'lerobot[groot]'`."
)
backbone_config = AutoConfig.from_pretrained(model_name, **config_kwargs)
return Qwen3VLForConditionalGeneration._from_config(backbone_config, **model_kwargs)
@@ -762,11 +730,6 @@ def _is_cosmos_reason2_backbone(model_name: str) -> bool:
def _cosmos_reason2_qwen3_vl_config() -> PretrainedConfig:
if Qwen3VLConfig is None:
raise ImportError(
"Qwen3VLConfig is required for GR00T N1.7. "
"Install the GR00T optional dependencies with `pip install 'lerobot[groot]'`."
)
return Qwen3VLConfig(
image_token_id=151655,
video_token_id=151656,
@@ -844,6 +807,7 @@ class GR00TN17(PreTrainedModel):
transformers_loading_kwargs: dict[str, Any] | None = None,
load_backbone_weights: bool = True,
):
_register_with_transformers()
super().__init__(config)
transformers_loading_kwargs = transformers_loading_kwargs or {"trust_remote_code": True}
self.config = config
@@ -949,6 +913,12 @@ class GR00TN17(PreTrainedModel):
def _register_with_transformers() -> None:
"""Register GR00T N1.7 with transformers' Auto* factories.
Idempotent: ``register(..., exist_ok=True)`` makes repeat calls no-ops (with a fallback that
suppresses the already-registered error on transformers builds whose ``register()`` predates
``exist_ok``), so no run-once guard is needed.
"""
if AutoConfig is None or AutoModel is None:
return
try:
@@ -961,6 +931,3 @@ def _register_with_transformers() -> None:
except TypeError:
with suppress(ValueError):
AutoModel.register(GR00TN17Config, GR00TN17)
_register_with_transformers()
+25 -180
@@ -30,6 +30,9 @@ from pathlib import Path
from typing import TypeVar
import torch
from huggingface_hub import hf_hub_download
from huggingface_hub.constants import SAFETENSORS_SINGLE_FILE
from huggingface_hub.errors import HfHubHTTPError
from torch import Tensor
from lerobot.configs import FeatureType, PolicyFeature
@@ -46,106 +49,14 @@ from .configuration_groot import (
infer_groot_model_version,
infer_groot_n1_7_action_execution_horizon,
infer_groot_n1_7_action_horizon,
normalize_groot_model_version,
)
from .groot_n1_7 import GR00TN17
logger = logging.getLogger(__name__)
T = TypeVar("T", bound="GrootPolicy")
def _resolve_embodiment_id(value: int | str) -> int:
"""Resolve an embodiment id from an int or an N1.7 embodiment name.
Names are looked up in N1_7_EMBODIMENT_MAPPING (e.g. 'new_embodiment' -> 10).
Raises ValueError listing the known keys if the name is unknown.
"""
from .processor_groot import N1_7_EMBODIMENT_MAPPING
if isinstance(value, bool): # bool is a subclass of int; reject it explicitly.
raise ValueError(f"Embodiment id must be an int or embodiment name, got bool {value!r}.")
if isinstance(value, int):
return value
if value in N1_7_EMBODIMENT_MAPPING:
return N1_7_EMBODIMENT_MAPPING[value]
raise ValueError(
f"Unknown GR00T N1.7 embodiment name '{value}'. Known names: "
f"{sorted(N1_7_EMBODIMENT_MAPPING.keys())}."
)


def _warm_start_embodiment_slot(model, source_id: int, target_id: int) -> None:
    """Copy category-specific action-head weights from one embodiment slot to another.

    Used at base-model load (training only) to warm-start a cold target embodiment slot
    (e.g. 'new_embodiment') from a pretrained slot. Copies the per-category ``W``/``b``
    parameters across every CategorySpecificLinear in the action head's state encoder,
    action encoder, and action decoder. No-ops (with a logged warning) if the ids are out
    of range or identical.
    """
    if source_id == target_id:
        logger.warning(
            "GR00T warm_start_embodiment_slot: source and target embodiment id are both %d; "
            "skipping (nothing to copy).",
            source_id,
        )
        return

    action_head = getattr(model, "action_head", None)
    if action_head is None:
        logger.warning("GR00T warm_start_embodiment_slot: model has no action_head; skipping.")
        return

    # Each entry is (submodule, [CategorySpecificLinear attribute names]).
    linear_groups = [
        (getattr(action_head, "state_encoder", None), ["layer1", "layer2"]),
        (getattr(action_head, "action_encoder", None), ["W1", "W2", "W3"]),
        (getattr(action_head, "action_decoder", None), ["layer1", "layer2"]),
    ]

    copied: list[str] = []
    with torch.no_grad():
        for submodule, attr_names in linear_groups:
            if submodule is None:
                continue
            submodule_name = type(submodule).__name__
            for attr_name in attr_names:
                lin = getattr(submodule, attr_name, None)
                if lin is None or not hasattr(lin, "W") or not hasattr(lin, "b"):
                    continue
                num_categories = lin.W.shape[0]
                if not (0 <= source_id < num_categories and 0 <= target_id < num_categories):
                    logger.warning(
                        "GR00T warm_start_embodiment_slot: source_id=%d/target_id=%d out of range "
                        "for %s.%s (num_categories=%d); skipping this layer.",
                        source_id,
                        target_id,
                        submodule_name,
                        attr_name,
                        num_categories,
                    )
                    continue
                lin.W.data[target_id] = lin.W.data[source_id].clone()
                lin.b.data[target_id] = lin.b.data[source_id].clone()
                copied.append(f"{submodule_name}.{attr_name}")

    if copied:
        logger.info(
            "GR00T warm_start_embodiment_slot: copied action-head weights from embodiment slot %d "
            "to slot %d for: %s.",
            source_id,
            target_id,
            ", ".join(copied),
        )
    else:
        logger.warning(
            "GR00T warm_start_embodiment_slot: no action-head weights were copied "
            "(source_id=%d, target_id=%d).",
            source_id,
            target_id,
        )
class GrootPolicy(PreTrainedPolicy):
"""Wrapper around external Groot model for LeRobot integration."""
@@ -166,46 +77,30 @@ class GrootPolicy(PreTrainedPolicy):
self.reset()
def _create_groot_model(self):
"""Create and initialize the GR00T N1.7 model using Isaac-GR00T APIs."""
# Handle Flash Attention compatibility issues
self._handle_flash_attention_compatibility()
"""Create and initialize the GR00T N1.7 model using the ported components."""
model_kwargs = {
"pretrained_model_name_or_path": self.config.base_model_path,
"tune_llm": self.config.tune_llm,
"tune_visual": self.config.tune_visual,
"tune_projector": self.config.tune_projector,
"tune_diffusion_model": self.config.tune_diffusion_model,
# Forwarded as a GR00TN17Config override; read back by set_trainable_parameters.
"tune_top_llm_layers": self.config.tune_top_llm_layers,
"use_flash_attention": self.config.use_flash_attention,
}
from .groot_n1_7 import GR00TN17
# Surface the inference-time knobs onto the model config only when the user set them; None
# leaves the value baked into the checkpoint untouched.
if self.config.num_inference_timesteps is not None:
model_kwargs["num_inference_timesteps"] = self.config.num_inference_timesteps
if self.config.rtc_ramp_rate is not None:
model_kwargs["rtc_ramp_rate"] = self.config.rtc_ramp_rate
model = GR00TN17.from_pretrained(
return GR00TN17.from_pretrained(
**model_kwargs,
tune_vlln=True,
tune_vlln=self.config.tune_vlln,
transformers_loading_kwargs={"trust_remote_code": True},
)
# Inference-only override for the number of flow-matching denoising steps. The action
# head reads self.num_inference_timesteps in get_action_with_features; dt (1/n) and the
# t schedule adapt automatically.
if self.config.num_inference_timesteps is not None:
n = int(self.config.num_inference_timesteps)
model.config.num_inference_timesteps = n
model.action_head.num_inference_timesteps = n
# Opt-in: warm-start a cold embodiment slot (e.g. 'new_embodiment') from a pretrained
# slot's action-head weights. Done here (not in from_pretrained) so it applies on every
# fresh base-model build -- training via make_policy instantiates GrootPolicy(config)
# directly (factory uses __init__ when cfg.pretrained_path is unset), it does NOT go
# through from_pretrained. On a fine-tuned checkpoint reload this also runs but is
# immediately overwritten by the loaded state_dict, so it is a harmless no-op there.
if self.config.warm_start_embodiment_slot is not None:
source_id = _resolve_embodiment_id(self.config.warm_start_embodiment_slot)
target_id = _resolve_embodiment_id(self.config.embodiment_tag)
_warm_start_embodiment_slot(model, source_id, target_id)
return model
def reset(self):
"""Reset policy state when environment resets."""
self._action_queue = deque([], maxlen=self._action_queue_steps)
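
The per-slot copy done by `_warm_start_embodiment_slot` in this hunk can be sketched in isolation. This is a minimal illustration, assuming only that the leading axis of each `W`/`b` parameter indexes the embodiment category (as `lin.W.shape[0]` suggests). NumPy arrays stand in for the torch parameters, and `copy_slot` is a hypothetical helper that does not exist in the policy.

```python
import numpy as np


def copy_slot(weight: np.ndarray, bias: np.ndarray, source_id: int, target_id: int) -> bool:
    """Copy one category's parameters onto another in place.

    Hypothetical stand-in for the per-layer body of the warm-start helper.
    Returns True when a copy happened, False when the request was skipped.
    """
    num_categories = weight.shape[0]
    if source_id == target_id:
        return False  # nothing to copy
    if not (0 <= source_id < num_categories and 0 <= target_id < num_categories):
        return False  # out-of-range slot: skip this layer
    weight[target_id] = weight[source_id].copy()
    bias[target_id] = bias[source_id].copy()
    return True


rng = np.random.default_rng(0)
W = rng.normal(size=(4, 3, 2))  # assumed layout: (num_categories, in_dim, out_dim)
b = rng.normal(size=(4, 2))     # assumed layout: (num_categories, out_dim)

copied = copy_slot(W, b, source_id=1, target_id=3)
out_of_range = copy_slot(W, b, source_id=1, target_id=7)
```

After the first call, slot 3 holds the same values as slot 1 while the other slots are untouched; the second call is skipped because slot 7 does not exist.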
@@ -248,15 +143,7 @@ class GrootPolicy(PreTrainedPolicy):
Returns:
Initialized GrootPolicy instance with loaded model
"""
from huggingface_hub import hf_hub_download
from huggingface_hub.constants import SAFETENSORS_SINGLE_FILE
from huggingface_hub.errors import HfHubHTTPError
requested_version = (
normalize_groot_model_version(config.model_version)
if config is not None
else infer_groot_model_version(str(pretrained_name_or_path)) or GROOT_N1_7
)
requested_version = infer_groot_model_version(str(pretrained_name_or_path)) or GROOT_N1_7
logger.info(
"The Groot policy wraps NVIDIA's GR00T %s model. Loading pretrained model from: %s",
requested_version,
@@ -310,10 +197,8 @@ class GrootPolicy(PreTrainedPolicy):
logger.info("Detected base GR00T model, loading from HuggingFace...")
if config is None:
model_version = infer_groot_model_version(str(pretrained_name_or_path)) or GROOT_N1_7
# Create default config with the pretrained path
config = GrootConfig(
model_version=model_version,
base_model_path=str(pretrained_name_or_path),
)
@@ -336,11 +221,10 @@ class GrootPolicy(PreTrainedPolicy):
if hasattr(config, key):
setattr(config, key, value)
config.model_version = normalize_groot_model_version(config.model_version)
inferred_version = infer_groot_model_version(config.base_model_path)
if inferred_version is not None and inferred_version != config.model_version:
if inferred_version is not None and inferred_version != GROOT_N1_7:
message = (
f"GR00T model_version '{config.model_version}' does not match base_model_path "
f"GR00T model_version '{GROOT_N1_7}' does not match base_model_path "
f"'{config.base_model_path}', which looks like '{inferred_version}'."
)
if inferred_version == GROOT_N1_5:
@@ -371,11 +255,7 @@ class GrootPolicy(PreTrainedPolicy):
horizons.append(checkpoint_action_horizon)
if execution_horizon is not None:
horizons.append(execution_horizon)
# An explicit config override caps the open-loop horizon (inference cadence), overriding
# the value inferred from the checkpoint/embodiment.
if self.config.execution_horizon is not None:
horizons.append(max(1, int(self.config.execution_horizon)))
return max(1, min(horizons))
return min(horizons)
def _resolve_prediction_horizon(self, actions: Tensor) -> int:
"""Return the policy-facing action horizon for a native GR00T prediction."""
@@ -543,27 +423,15 @@ class GrootPolicy(PreTrainedPolicy):
"""
self.eval()
# Freeze the relative-action reference at the exact chunk-prediction event so every popped
# delta of this chunk is reconstructed (in the postprocessor) against this S_T, not the
# per-tick latest state. Driven by the predict event, so it is correct under any runtime
# n_action_steps/execution_horizon. No-op for non-relative checkpoints (holder absent/unused).
from .processor_groot import _GROOT_REF_HOLDER_KEY
holder = batch.get(_GROOT_REF_HOLDER_KEY)
if holder is not None:
holder.freeze()
# Preprocessing is handled by the processor pipeline, so we just filter the batch.
# During inference, we do not pass action because it is predicted.
# N1.7 still carries a 2-D action horizon mask from its checkpoint processor.
groot_inputs = self._filter_groot_inputs(batch, include_action=False)
groot_options = None
if self.config.model_version == GROOT_N1_7:
groot_inputs, groot_options = self._prepare_n1_7_rtc_inputs(
groot_inputs,
inference_delay=kwargs.get("inference_delay"),
prev_chunk_left_over=kwargs.get("prev_chunk_left_over"),
)
groot_inputs, groot_options = self._prepare_n1_7_rtc_inputs(
groot_inputs,
inference_delay=kwargs.get("inference_delay"),
prev_chunk_left_over=kwargs.get("prev_chunk_left_over"),
)
# Get device from model parameters
device = get_device_from_parameters(self)
@@ -594,26 +462,3 @@ class GrootPolicy(PreTrainedPolicy):
actions = self.predict_action_chunk(batch)
self._action_queue.extend(actions[:, : self._action_queue_steps].transpose(0, 1))
return self._action_queue.popleft()
# -------------------------
# Internal helpers
# -------------------------
def _handle_flash_attention_compatibility(self) -> None:
"""Log Flash Attention availability (diagnostic only).
The GR00T N1.7 backbone automatically falls back to SDPA when ``flash_attn`` is
unavailable (see ``Qwen3Backbone``), so this probe only emits a hint; it does not
change behaviour or mutate global state.
"""
try:
import flash_attn
logger.debug("Flash Attention %s is available.", flash_attn.__version__)
except ImportError:
logger.debug("Flash Attention is not installed; the GR00T backbone will use SDPA.")
except Exception as e: # noqa: BLE001
logger.warning(
"Flash Attention failed to import (%s); the GR00T backbone will use SDPA. If this is "
"an 'undefined symbol' error, reinstall a flash-attn build matching your torch version.",
e,
)
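
The diagnostic probe above only reports whether `flash_attn` imports. A related pattern, opting in explicitly and falling back to SDPA when the package is absent, might look like the sketch below. `pick_attention_backend` is a hypothetical helper, and this is not how the GR00T backbone is known to make the choice; `"flash_attention_2"` and `"sdpa"` are the attention implementation names used by `transformers`.

```python
import importlib.util


def pick_attention_backend(use_flash_attention: bool, package: str = "flash_attn") -> str:
    """Return an attention implementation name (hypothetical helper).

    Flash Attention is selected only when the caller opted in AND the package
    can be located; otherwise the default SDPA path is used.
    """
    # find_spec locates a top-level package without importing it.
    if use_flash_attention and importlib.util.find_spec(package) is not None:
        return "flash_attention_2"
    return "sdpa"


default_backend = pick_attention_backend(use_flash_attention=False)
```

`find_spec` is only a presence check: a `flash-attn` build that does not match the installed torch version can still fail when it is actually imported, which is the case the warning branch of the probe above describes.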
File diff suppressed because it is too large
+256
@@ -0,0 +1,256 @@
# SPDX-FileCopyrightText: Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
# SPDX-License-Identifier: Apache-2.0
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Shared, side-effect-free utilities for the GR00T N1.7 policy.
These helpers are consumed by both the config layer (checkpoint sidecar
inspection) and the processor layer (stat flattening, action decoding, language
and image packing). They are pure functions with no GR00T-specific state so they
can be unit-tested in isolation and reused without importing the heavier
config/processor modules.
"""

from __future__ import annotations

import json
from pathlib import Path
from typing import Any

import numpy as np
import torch


def read_json(path: Path) -> dict[str, Any]:
    """Read a JSON object from ``path``, returning ``{}`` on any read/parse error."""
    try:
        with path.open() as f:
            data = json.load(f)
    except (OSError, json.JSONDecodeError):
        return {}
    return data if isinstance(data, dict) else {}


def as_int_pair(value: Any) -> list[int] | None:
    if not isinstance(value, (list, tuple)) or len(value) != 2:
        return None
    try:
        return [int(value[0]), int(value[1])]
    except (TypeError, ValueError):
        return None


def as_optional_int(value: Any) -> int | None:
    if value is None:
        return None
    try:
        return int(value)
    except (TypeError, ValueError):
        return None


def as_optional_float(value: Any) -> float | None:
    if value is None:
        return None
    try:
        return float(value)
    except (TypeError, ValueError):
        return None


def as_float_list(values: Any) -> list[float]:
    if values is None:
        return []
    if isinstance(values, torch.Tensor):
        return values.detach().cpu().reshape(-1).float().tolist()
    if isinstance(values, np.ndarray):
        return values.reshape(-1).astype(np.float32).tolist()
    if isinstance(values, (list, tuple)):
        flattened: list[float] = []
        for value in values:
            flattened.extend(as_float_list(value))
        return flattened
    return [float(values)]


def config_value(value: Any) -> str:
    if hasattr(value, "value"):
        value = value.value
    text = str(value).lower()
    return {
        "relative": "relative",
        "absolute": "absolute",
        "delta": "delta",
        "eef": "eef",
        "non_eef": "non_eef",
        "default": "default",
        "xyz_rot6d": "xyz+rot6d",
        "xyz+rot6d": "xyz+rot6d",
        "xyz_rotvec": "xyz+rotvec",
        "xyz+rotvec": "xyz+rotvec",
    }.get(text, text)


def has_modality_stats(stats: dict[str, dict[str, Any]] | None) -> bool:
    if not stats:
        return False
    return any(bool(modality_stats) for modality_stats in stats.values())


def stat_dim_from_entry(entry: dict[str, Any]) -> int:
    for stat_name in ("mean", "q01", "min", "max", "std"):
        value = entry.get(stat_name)
        if isinstance(value, list) and len(value) > 0:
            return len(value)
    return 0


def flatten_n1_7_modality_stats(
    *,
    embodiment_stats: dict[str, Any],
    embodiment_config: dict[str, Any],
    modality: str,
    use_percentiles: bool,
    use_relative_action: bool,
) -> dict[str, list[float]]:
    """Flatten one N1.7 modality's grouped statistics in checkpoint order.

    When checkpoints request percentile normalization, q01/q99 replace min/max
    for regular groups. Relative action groups read from ``relative_action``
    stats and keep min/max, matching Isaac-GR00T's processor override.
    """
    source_stats = embodiment_stats.get(modality, {})
    modality_config = embodiment_config.get(modality, {})
    if not isinstance(source_stats, dict) or not isinstance(modality_config, dict):
        return {}

    modality_keys = modality_config.get("modality_keys", [])
    if not isinstance(modality_keys, list):
        return {}

    flattened: dict[str, list[float]] = {}
    action_configs = modality_config.get("action_configs", []) if modality == "action" else []
    if not isinstance(action_configs, list):
        action_configs = []
    relative_stats = embodiment_stats.get("relative_action", {})
    if not isinstance(relative_stats, dict):
        relative_stats = {}

    for stat_name in ("min", "max", "mean", "std"):
        values: list[float] = []
        source_stat_name = stat_name
        if use_percentiles and stat_name == "min":
            source_stat_name = "q01"
        elif use_percentiles and stat_name == "max":
            source_stat_name = "q99"

        for idx, modality_key in enumerate(modality_keys):
            if not isinstance(modality_key, str):
                continue
            key_source_stats = source_stats
            key_stat_name = source_stat_name
            if modality == "action" and use_relative_action and idx < len(action_configs):
                action_config = action_configs[idx]
                if isinstance(action_config, dict) and config_value(action_config.get("rep")) == "relative":
                    key_source_stats = relative_stats
                    key_stat_name = stat_name

            key_stats = key_source_stats.get(modality_key, {})
            if not isinstance(key_stats, dict):
                raise KeyError(f"Missing statistics for {modality}.{modality_key}")
            raw_values = key_stats.get(key_stat_name)
            if raw_values is None:
                raise KeyError(f"Missing '{key_stat_name}' statistics for {modality}.{modality_key}")
            values.extend(as_float_list(raw_values))

        if values:
            flattened[stat_name] = values
    return flattened


def rot6d_to_matrix(rot6d: np.ndarray) -> np.ndarray:
    rows = rot6d.reshape(2, 3).astype(np.float64)
    row1 = rows[0] / np.linalg.norm(rows[0])
    row2 = rows[1] - np.dot(row1, rows[1]) * row1
    row2 = row2 / np.linalg.norm(row2)
    row3 = np.cross(row1, row2)
    return np.vstack([row1, row2, row3])


def xyz_rot6d_to_homogeneous(xyz_rot6d: np.ndarray) -> np.ndarray:
    transform = np.eye(4, dtype=np.float64)
    transform[:3, :3] = rot6d_to_matrix(xyz_rot6d[3:])
    transform[:3, 3] = xyz_rot6d[:3]
    return transform


def homogeneous_to_xyz_rot6d(transform: np.ndarray) -> np.ndarray:
    return np.concatenate([transform[:3, 3], transform[:2, :3].reshape(-1)], axis=0)


def relative_eef_to_absolute(action: np.ndarray, reference_state: np.ndarray) -> np.ndarray:
    """Convert relative EEF deltas in xyz+rot6d format to absolute EEF poses."""
    out = np.empty_like(action, dtype=np.float64)
    for batch_idx in range(action.shape[0]):
        reference = xyz_rot6d_to_homogeneous(reference_state[batch_idx])
        for timestep in range(action.shape[1]):
            relative = xyz_rot6d_to_homogeneous(action[batch_idx, timestep])
            out[batch_idx, timestep] = homogeneous_to_xyz_rot6d(reference @ relative)
    return out.astype(np.float32)


def infer_n1_7_batch_size_and_device(
    obs: dict[str, Any], action: torch.Tensor | None
) -> tuple[int, torch.device]:
    for value in list(obs.values()) + [action]:
        if isinstance(value, torch.Tensor):
            return value.shape[0], value.device
    video = obs.get("video")
    if isinstance(video, np.ndarray):
        return video.shape[0], torch.device("cpu")
    return 1, torch.device("cpu")


def prepare_n1_7_language_batch(
    language: Any,
    batch_size: int,
    *,
    formalize_language: bool,
) -> list[str]:
    default_language = "Perform the task."
    if language is None or (isinstance(language, str) and language == ""):
        languages = [default_language] * batch_size
    elif isinstance(language, str):
        languages = [language] * batch_size
    elif isinstance(language, (list, tuple)):
        languages = list(language)
        if len(languages) == 0:
            languages = [default_language] * batch_size
        elif len(languages) == 1 and batch_size > 1:
            languages = languages * batch_size
        elif len(languages) != batch_size:
            raise ValueError(
                f"language batch has {len(languages)} entries, but GR00T N1.7 input batch has {batch_size}."
            )
    else:
        languages = [str(language)] * batch_size

    formatted = []
    for item in languages:
        text = str(item) if item else default_language
        if formalize_language:
            text = text.lower()
            text = "".join(ch for ch in text if ch.isalnum() or ch.isspace() or ch == "_")
        formatted.append(text)
    return formatted
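
The rot6d helpers in this file store a rotation as the first two rows of its matrix and rebuild the third row with Gram-Schmidt plus a cross product. The sketch below restates compact versions of those helpers so it runs on its own, then composes one relative delta with a reference pose the way `relative_eef_to_absolute` does per timestep. The shortened names `to_homogeneous` and `from_homogeneous` and the sample poses are illustrative assumptions.

```python
import numpy as np


def rot6d_to_matrix(rot6d):
    """Rebuild a 3x3 rotation from its first two rows (Gram-Schmidt + cross product)."""
    rows = np.asarray(rot6d, dtype=np.float64).reshape(2, 3)
    row1 = rows[0] / np.linalg.norm(rows[0])
    row2 = rows[1] - np.dot(row1, rows[1]) * row1
    row2 = row2 / np.linalg.norm(row2)
    return np.vstack([row1, row2, np.cross(row1, row2)])


def to_homogeneous(pose):
    """xyz+rot6d (9 values) -> 4x4 homogeneous transform."""
    pose = np.asarray(pose, dtype=np.float64)
    transform = np.eye(4)
    transform[:3, :3] = rot6d_to_matrix(pose[3:])
    transform[:3, 3] = pose[:3]
    return transform


def from_homogeneous(transform):
    """4x4 homogeneous transform -> xyz+rot6d (translation, then first two rotation rows)."""
    return np.concatenate([transform[:3, 3], transform[:2, :3].reshape(-1)])


# Reference pose: at x=1, rotated 90 degrees about z.
reference = np.array([1.0, 0.0, 0.0, 0.0, -1.0, 0.0, 1.0, 0.0, 0.0])
# Relative delta: move 1 unit along the local x axis, no extra rotation.
delta = np.array([1.0, 0.0, 0.0, 1.0, 0.0, 0.0, 0.0, 1.0, 0.0])

absolute = from_homogeneous(to_homogeneous(reference) @ to_homogeneous(delta))
round_trip = from_homogeneous(to_homogeneous(reference))
```

Because the delta is applied on the right (`reference @ relative`), the translation is expressed in the reference frame: the local x step lands on the world y axis, so `absolute` is the pose at xyz = (1, 1, 0) with the reference rotation unchanged.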
+2 -3
@@ -41,7 +41,7 @@ from lerobot.policies.groot.processor_groot import (
GrootN17ActionDecodeStep,
GrootN17PackInputsStep,
GrootN17VLMEncodeStep,
_transform_n1_7_image_for_vlm,
_transform_n1_7_image_for_vlm_albumentations,
make_groot_pre_post_processors,
)
from lerobot.processor import (
@@ -1529,13 +1529,12 @@ def test_groot_n1_7_vlm_image_transform_matches_albumentations_eval_path():
image_np = (np.arange(360 * 360 * 3, dtype=np.uint32) % 251).astype(np.uint8).reshape(360, 360, 3)
transformed = _transform_n1_7_image_for_vlm(
transformed = _transform_n1_7_image_for_vlm_albumentations(
Image.fromarray(image_np),
image_crop_size=[230, 230],
image_target_size=[256, 256],
shortest_image_edge=256,
crop_fraction=0.95,
use_albumentations=True,
)
expected = cv2.resize(image_np, (256, 256), interpolation=cv2.INTER_AREA)
Generated
+4 -41
@@ -1464,17 +1464,6 @@ wheels = [
{ url = "https://files.pythonhosted.org/packages/81/47/dd9a212ef6e343a6857485ffe25bba537304f1913bdbed446a23f7f592e1/filelock-3.29.0-py3-none-any.whl", hash = "sha256:96f5f6344709aa1572bbf631c640e4ebeeb519e08da902c39a001882f30ac258", size = 39812, upload-time = "2026-04-19T15:39:08.752Z" },
]
[[package]]
name = "flash-attn"
version = "2.8.3"
source = { registry = "https://pypi.org/simple" }
dependencies = [
{ name = "einops", marker = "platform_machine != 'arm64' or sys_platform != 'darwin'" },
{ name = "torch", version = "2.11.0", source = { registry = "https://pypi.org/simple" }, marker = "(platform_machine != 'arm64' and sys_platform == 'darwin') or (sys_platform != 'darwin' and sys_platform != 'linux')" },
{ name = "torch", version = "2.11.0+cu128", source = { registry = "https://download.pytorch.org/whl/cu128" }, marker = "sys_platform == 'linux'" },
]
sdist = { url = "https://files.pythonhosted.org/packages/3b/b2/8d76c41ad7974ee264754709c22963447f7f8134613fd9ce80984ed0dab7/flash_attn-2.8.3.tar.gz", hash = "sha256:1e71dd64a9e0280e0447b8a0c2541bad4bf6ac65bdeaa2f90e51a9e57de0370d", size = 8447812, upload-time = "2025-08-15T08:28:12.911Z" }
[[package]]
name = "flatbuffers"
version = "25.12.19"
@@ -2690,8 +2679,10 @@ all = [
{ name = "contourpy" },
{ name = "datasets" },
{ name = "debugpy" },
{ name = "decord", marker = "platform_machine == 'AMD64' or platform_machine == 'x86_64'" },
{ name = "deepdiff" },
{ name = "diffusers" },
{ name = "dm-tree" },
{ name = "dynamixel-sdk" },
{ name = "faker" },
{ name = "fastapi" },
@@ -2739,6 +2730,7 @@ all = [
{ name = "scikit-image" },
{ name = "scipy" },
{ name = "teleop" },
{ name = "timm" },
{ name = "torchcodec", marker = "(platform_machine == 'arm64' and sys_platform == 'darwin') or (platform_machine == 'AMD64' and sys_platform == 'linux') or (platform_machine == 'aarch64' and sys_platform == 'linux') or (platform_machine == 'arm64' and sys_platform == 'linux') or (platform_machine == 'x86_64' and sys_platform == 'linux') or sys_platform == 'win32'" },
{ name = "torchdiffeq" },
{ name = "transformers" },
@@ -2843,8 +2835,6 @@ groot = [
{ name = "decord", marker = "platform_machine == 'AMD64' or platform_machine == 'x86_64'" },
{ name = "diffusers" },
{ name = "dm-tree" },
{ name = "flash-attn", marker = "sys_platform != 'darwin'" },
{ name = "ninja" },
{ name = "peft" },
{ name = "timm" },
{ name = "transformers" },
@@ -3087,7 +3077,6 @@ requires-dist = [
{ name = "faker", marker = "extra == 'sarm'", specifier = ">=33.0.0,<35.0.0" },
{ name = "fastapi", marker = "extra == 'phone'", specifier = "<1.0" },
{ name = "feetech-servo-sdk", marker = "extra == 'feetech'", specifier = ">=1.0.0,<2.0.0" },
{ name = "flash-attn", marker = "sys_platform != 'darwin' and extra == 'groot'", specifier = ">=2.5.9,<3.0.0" },
{ name = "grpcio", marker = "extra == 'grpcio-dep'", specifier = "==1.73.1" },
{ name = "grpcio-tools", marker = "extra == 'dev'", specifier = "==1.73.1" },
{ name = "gym-aloha", marker = "extra == 'aloha'", specifier = ">=0.1.2,<0.2.0" },
@@ -3132,6 +3121,7 @@ requires-dist = [
{ name = "lerobot", extras = ["feetech"], marker = "extra == 'hopejr'" },
{ name = "lerobot", extras = ["feetech"], marker = "extra == 'lekiwi'" },
{ name = "lerobot", extras = ["gamepad"], marker = "extra == 'all'" },
{ name = "lerobot", extras = ["groot"], marker = "extra == 'all'" },
{ name = "lerobot", extras = ["grpcio-dep"], marker = "extra == 'async'" },
{ name = "lerobot", extras = ["grpcio-dep"], marker = "extra == 'dev'" },
{ name = "lerobot", extras = ["grpcio-dep"], marker = "extra == 'hilserl'" },
@@ -3224,7 +3214,6 @@ requires-dist = [
{ name = "motorbridge", marker = "extra == 'motorbridge-dep'", specifier = ">=0.3.2,<0.4.0" },
{ name = "motorbridge-smart-servo", marker = "extra == 'motorbridge-smart-servo-dep'", specifier = ">=0.0.4,<0.1.0" },
{ name = "mypy", marker = "extra == 'dev'", specifier = ">=1.19.1" },
{ name = "ninja", marker = "extra == 'groot'", specifier = ">=1.11.1,<2.0.0" },
{ name = "num2words", marker = "extra == 'smolvla'", specifier = ">=0.5.14,<0.6.0" },
{ name = "numpy", specifier = ">=2.0.0,<2.3.0" },
{ name = "onnx", marker = "extra == 'unitree-g1'", specifier = ">=1.16.0,<2.0.0" },
@@ -4012,32 +4001,6 @@ wheels = [
{ url = "https://files.pythonhosted.org/packages/9e/c9/b2622292ea83fbb4ec318f5b9ab867d0a28ab43c5717bb85b0a5f6b3b0a4/networkx-3.6.1-py3-none-any.whl", hash = "sha256:d47fbf302e7d9cbbb9e2555a0d267983d2aa476bac30e90dfbe5669bd57f3762", size = 2068504, upload-time = "2025-12-08T17:02:38.159Z" },
]
[[package]]
name = "ninja"
version = "1.13.0"
source = { registry = "https://pypi.org/simple" }
sdist = { url = "https://files.pythonhosted.org/packages/43/73/79a0b22fc731989c708068427579e840a6cf4e937fe7ae5c5d0b7356ac22/ninja-1.13.0.tar.gz", hash = "sha256:4a40ce995ded54d9dc24f8ea37ff3bf62ad192b547f6c7126e7e25045e76f978", size = 242558, upload-time = "2025-08-11T15:10:19.421Z" }
wheels = [
{ url = "https://files.pythonhosted.org/packages/3c/74/d02409ed2aa865e051b7edda22ad416a39d81a84980f544f8de717cab133/ninja-1.13.0-py3-none-macosx_10_9_universal2.whl", hash = "sha256:fa2a8bfc62e31b08f83127d1613d10821775a0eb334197154c4d6067b7068ff1", size = 310125, upload-time = "2025-08-11T15:09:50.971Z" },
{ url = "https://files.pythonhosted.org/packages/8e/de/6e1cd6b84b412ac1ef327b76f0641aeb5dcc01e9d3f9eee0286d0c34fd93/ninja-1.13.0-py3-none-manylinux2014_aarch64.manylinux_2_17_aarch64.whl", hash = "sha256:3d00c692fb717fd511abeb44b8c5d00340c36938c12d6538ba989fe764e79630", size = 177467, upload-time = "2025-08-11T15:09:52.767Z" },
{ url = "https://files.pythonhosted.org/packages/c8/83/49320fb6e58ae3c079381e333575fdbcf1cca3506ee160a2dcce775046fa/ninja-1.13.0-py3-none-manylinux2014_i686.manylinux_2_17_i686.whl", hash = "sha256:be7f478ff9f96a128b599a964fc60a6a87b9fa332ee1bd44fa243ac88d50291c", size = 187834, upload-time = "2025-08-11T15:09:54.115Z" },
{ url = "https://files.pythonhosted.org/packages/56/c7/ba22748fb59f7f896b609cd3e568d28a0a367a6d953c24c461fe04fc4433/ninja-1.13.0-py3-none-manylinux2014_ppc64le.manylinux_2_17_ppc64le.whl", hash = "sha256:60056592cf495e9a6a4bea3cd178903056ecb0943e4de45a2ea825edb6dc8d3e", size = 202736, upload-time = "2025-08-11T15:09:55.745Z" },
{ url = "https://files.pythonhosted.org/packages/79/22/d1de07632b78ac8e6b785f41fa9aad7a978ec8c0a1bf15772def36d77aac/ninja-1.13.0-py3-none-manylinux2014_s390x.manylinux_2_17_s390x.whl", hash = "sha256:1c97223cdda0417f414bf864cfb73b72d8777e57ebb279c5f6de368de0062988", size = 179034, upload-time = "2025-08-11T15:09:57.394Z" },
{ url = "https://files.pythonhosted.org/packages/ed/de/0e6edf44d6a04dabd0318a519125ed0415ce437ad5a1ec9b9be03d9048cf/ninja-1.13.0-py3-none-manylinux2014_x86_64.manylinux_2_17_x86_64.whl", hash = "sha256:fb46acf6b93b8dd0322adc3a4945452a4e774b75b91293bafcc7b7f8e6517dfa", size = 180716, upload-time = "2025-08-11T15:09:58.696Z" },
{ url = "https://files.pythonhosted.org/packages/54/28/938b562f9057aaa4d6bfbeaa05e81899a47aebb3ba6751e36c027a7f5ff7/ninja-1.13.0-py3-none-manylinux_2_28_armv7l.manylinux_2_31_armv7l.whl", hash = "sha256:4be9c1b082d244b1ad7ef41eb8ab088aae8c109a9f3f0b3e56a252d3e00f42c1", size = 146843, upload-time = "2025-08-11T15:10:00.046Z" },
{ url = "https://files.pythonhosted.org/packages/2a/fb/d06a3838de4f8ab866e44ee52a797b5491df823901c54943b2adb0389fbb/ninja-1.13.0-py3-none-manylinux_2_31_riscv64.whl", hash = "sha256:6739d3352073341ad284246f81339a384eec091d9851a886dfa5b00a6d48b3e2", size = 154402, upload-time = "2025-08-11T15:10:01.657Z" },
{ url = "https://files.pythonhosted.org/packages/31/bf/0d7808af695ceddc763cf251b84a9892cd7f51622dc8b4c89d5012779f06/ninja-1.13.0-py3-none-musllinux_1_2_aarch64.whl", hash = "sha256:11be2d22027bde06f14c343f01d31446747dbb51e72d00decca2eb99be911e2f", size = 552388, upload-time = "2025-08-11T15:10:03.349Z" },
{ url = "https://files.pythonhosted.org/packages/9d/70/c99d0c2c809f992752453cce312848abb3b1607e56d4cd1b6cded317351a/ninja-1.13.0-py3-none-musllinux_1_2_armv7l.whl", hash = "sha256:aa45b4037b313c2f698bc13306239b8b93b4680eb47e287773156ac9e9304714", size = 472501, upload-time = "2025-08-11T15:10:04.735Z" },
{ url = "https://files.pythonhosted.org/packages/9f/43/c217b1153f0e499652f5e0766da8523ce3480f0a951039c7af115e224d55/ninja-1.13.0-py3-none-musllinux_1_2_i686.whl", hash = "sha256:5f8e1e8a1a30835eeb51db05cf5a67151ad37542f5a4af2a438e9490915e5b72", size = 638280, upload-time = "2025-08-11T15:10:06.512Z" },
{ url = "https://files.pythonhosted.org/packages/8c/45/9151bba2c8d0ae2b6260f71696330590de5850e5574b7b5694dce6023e20/ninja-1.13.0-py3-none-musllinux_1_2_ppc64le.whl", hash = "sha256:3d7d7779d12cb20c6d054c61b702139fd23a7a964ec8f2c823f1ab1b084150db", size = 642420, upload-time = "2025-08-11T15:10:08.35Z" },
{ url = "https://files.pythonhosted.org/packages/3c/fb/95752eb635bb8ad27d101d71bef15bc63049de23f299e312878fc21cb2da/ninja-1.13.0-py3-none-musllinux_1_2_riscv64.whl", hash = "sha256:d741a5e6754e0bda767e3274a0f0deeef4807f1fec6c0d7921a0244018926ae5", size = 585106, upload-time = "2025-08-11T15:10:09.818Z" },
{ url = "https://files.pythonhosted.org/packages/c1/31/aa56a1a286703800c0cbe39fb4e82811c277772dc8cd084f442dd8e2938a/ninja-1.13.0-py3-none-musllinux_1_2_s390x.whl", hash = "sha256:e8bad11f8a00b64137e9b315b137d8bb6cbf3086fbdc43bf1f90fd33324d2e96", size = 707138, upload-time = "2025-08-11T15:10:11.366Z" },
{ url = "https://files.pythonhosted.org/packages/34/6f/5f5a54a1041af945130abdb2b8529cbef0cdcbbf9bcf3f4195378319d29a/ninja-1.13.0-py3-none-musllinux_1_2_x86_64.whl", hash = "sha256:b4f2a072db3c0f944c32793e91532d8948d20d9ab83da9c0c7c15b5768072200", size = 581758, upload-time = "2025-08-11T15:10:13.295Z" },
{ url = "https://files.pythonhosted.org/packages/95/97/51359c77527d45943fe7a94d00a3843b81162e6c4244b3579fe8fc54cb9c/ninja-1.13.0-py3-none-win32.whl", hash = "sha256:8cfbb80b4a53456ae8a39f90ae3d7a2129f45ea164f43fadfa15dc38c4aef1c9", size = 267201, upload-time = "2025-08-11T15:10:15.158Z" },
{ url = "https://files.pythonhosted.org/packages/29/45/c0adfbfb0b5895aa18cec400c535b4f7ff3e52536e0403602fc1a23f7de9/ninja-1.13.0-py3-none-win_amd64.whl", hash = "sha256:fb8ee8719f8af47fed145cced4a85f0755dd55d45b2bddaf7431fa89803c5f3e", size = 309975, upload-time = "2025-08-11T15:10:16.697Z" },
{ url = "https://files.pythonhosted.org/packages/df/93/a7b983643d1253bb223234b5b226e69de6cda02b76cdca7770f684b795f5/ninja-1.13.0-py3-none-win_arm64.whl", hash = "sha256:3c0b40b1f0bba764644385319028650087b4c1b18cdfa6f45cb39a3669b81aa9", size = 290806, upload-time = "2025-08-11T15:10:18.018Z" },
]
[[package]]
name = "nodeenv"
version = "1.10.0"