mirror of
https://github.com/huggingface/lerobot.git
synced 2026-06-16 15:57:03 +00:00
Compare commits
1 Commits
| Author | SHA1 | Date | |
|---|---|---|---|
| 4688b9c27f |
+20
-24
@@ -5,7 +5,7 @@ GR00T is an NVIDIA foundation model family for generalized humanoid robot reason
|
||||
LeRobot integrates GR00T N1.7 through the `groot` policy type.
|
||||
|
||||
> [!WARNING]
|
||||
> **Breaking change:** GR00T N1.5 support was removed from LeRobot, and current releases support GR00T N1.7 only. N1.5 checkpoints, configs, and `--policy.model_version=n1.5` are rejected with a clear error. To keep using an N1.5 checkpoint, pin the last release that supports it: `pip install 'lerobot==0.5.1'`. To use the current release, migrate to GR00T N1.7 (`model_version='n1.7'`, base model [`nvidia/GR00T-N1.7-3B`](https://huggingface.co/nvidia/GR00T-N1.7-3B)).
|
||||
> **Breaking change:** GR00T N1.5 support was removed from LeRobot, and current releases support GR00T N1.7 only. N1.5 checkpoints and configs are rejected with a migration note. To keep using an N1.5 checkpoint, pin the last release that supports it: `pip install 'lerobot==0.5.1'`. To use the current release, migrate to GR00T N1.7 (base model [`nvidia/GR00T-N1.7-3B`](https://huggingface.co/nvidia/GR00T-N1.7-3B)).
|
||||
|
||||
## Model Overview
|
||||
|
||||
@@ -31,46 +31,43 @@ This approach allows the model to be highly adaptable through post-training for
|
||||
|
||||
## Installation Requirements
|
||||
|
||||
GR00T is intended for NVIDIA GPU-accelerated systems. The `groot` extra still includes Flash Attention on non-macOS platforms, and Flash Attention needs a compatible PyTorch/CUDA environment before it is installed. Install the dependencies in this order:
|
||||
GR00T is intended for NVIDIA GPU-accelerated systems. Install LeRobot with the GR00T extra:
|
||||
|
||||
1. Follow the Environment Setup in the [Installation Guide](./installation). Do not install `lerobot` yet.
|
||||
2. Install PyTorch, TorchVision, and the build dependencies used by Flash Attention:
|
||||
```bash
|
||||
pip install "lerobot[groot]"
|
||||
```
|
||||
|
||||
For a source checkout:
|
||||
|
||||
```bash
|
||||
pip install -e ".[groot]"
|
||||
```
|
||||
|
||||
### Optional: Flash Attention acceleration
|
||||
|
||||
Flash Attention is a purely optional performance optimization. **LeRobot neither installs nor requires it**, and setting it up is up to the user as it has environment-specific build requirements (a matching PyTorch/CUDA toolchain). To enable it:
|
||||
|
||||
1. Install a `flash-attn` build matching your PyTorch/CUDA environment (see the [Flash Attention project](https://github.com/Dao-AILab/flash-attention)):
|
||||
|
||||
```bash
|
||||
# Check https://pytorch.org/get-started/locally/ for the right CUDA wheel index for your system.
|
||||
pip install "torch>=2.7,<2.12.0" "torchvision>=0.22.0,<0.27.0" \
|
||||
--index-url https://download.pytorch.org/whl/cu128
|
||||
pip install "ninja>=1.11.1,<2.0.0" "packaging>=24.2,<26.0"
|
||||
```
|
||||
|
||||
3. Install and verify Flash Attention:
|
||||
|
||||
```bash
|
||||
pip install "flash-attn>=2.5.9,<3.0.0" --no-build-isolation
|
||||
python -c "import flash_attn; print(f'Flash Attention {flash_attn.__version__} imported successfully')"
|
||||
```
|
||||
|
||||
4. Install LeRobot with the GR00T extra:
|
||||
2. Install lerobot with the groot extra.
|
||||
|
||||
```bash
|
||||
pip install "lerobot[groot]"
|
||||
```
|
||||
|
||||
For a source checkout, use the same order, then install the local package with:
|
||||
|
||||
```bash
|
||||
pip install -e ".[groot]"
|
||||
```
|
||||
|
||||
If your CUDA/PyTorch build needs a different Flash Attention wheel or source build, follow the [Flash Attention project](https://github.com/Dao-AILab/flash-attention) instructions, but keep the same ordering: PyTorch first, Flash Attention next, then `lerobot[groot]`.
|
||||
3. Opt in by passing `--policy.use_flash_attention=true` when training/evaluating GR00T. If the kernel is missing or fails to import, the backbone transparently falls back to SDPA.
|
||||
|
||||
## Usage
|
||||
|
||||
To use GR00T N1.7:
|
||||
|
||||
```bash
|
||||
--policy.type=groot \
|
||||
--policy.model_version=n1.7
|
||||
--policy.type=groot
|
||||
```
|
||||
|
||||
## Training
|
||||
@@ -142,7 +139,6 @@ hf download nvidia/GR00T-N1.7-LIBERO \
|
||||
|
||||
lerobot-eval \
|
||||
--policy.type=groot \
|
||||
--policy.model_version=n1.7 \
|
||||
--policy.base_model_path=./GR00T-N1.7-LIBERO/libero_spatial \
|
||||
--policy.embodiment_tag=libero_sim \
|
||||
--env.type=libero \
|
||||
|
||||
+1
-3
@@ -208,8 +208,6 @@ groot = [
|
||||
"dm-tree>=0.1.8,<1.0.0",
|
||||
"timm>=1.0.0,<1.1.0",
|
||||
"decord>=0.6.0,<1.0.0; (platform_machine == 'AMD64' or platform_machine == 'x86_64')",
|
||||
"ninja>=1.11.1,<2.0.0",
|
||||
"flash-attn>=2.5.9,<3.0.0 ; sys_platform != 'darwin'"
|
||||
]
|
||||
sarm = ["lerobot[transformers-dep]", "pydantic>=2.0.0,<3.0.0", "faker>=33.0.0,<35.0.0", "lerobot[matplotlib-dep]", "lerobot[qwen-vl-utils-dep]"]
|
||||
robometer = ["lerobot[transformers-dep]", "lerobot[qwen-vl-utils-dep]", "lerobot[peft-dep]"]
|
||||
@@ -280,7 +278,7 @@ all = [
|
||||
"lerobot[pi]",
|
||||
"lerobot[molmoact2]",
|
||||
"lerobot[smolvla]",
|
||||
# "lerobot[groot]", TODO(Steven): Gr00t requires specific installation instructions for flash-attn
|
||||
"lerobot[groot]",
|
||||
"lerobot[xvla]",
|
||||
"lerobot[hilserl]",
|
||||
"lerobot[vla_jepa]",
|
||||
|
||||
@@ -14,9 +14,7 @@
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
import json
|
||||
import logging
|
||||
import os
|
||||
from dataclasses import dataclass, field
|
||||
from pathlib import Path
|
||||
|
||||
@@ -24,6 +22,8 @@ from lerobot.configs import FeatureType, NormalizationMode, PolicyFeature, PreTr
|
||||
from lerobot.optim import AdamWConfig, CosineDecayWithWarmupSchedulerConfig
|
||||
from lerobot.utils.constants import ACTION, OBS_STATE
|
||||
|
||||
from .utils import read_json
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
GROOT_N1_7 = "n1.7"
|
||||
@@ -127,12 +127,7 @@ def is_raw_groot_n1_7_checkpoint(model_path: str | Path | None) -> bool:
|
||||
else:
|
||||
return False
|
||||
|
||||
try:
|
||||
with config_path.open() as f:
|
||||
config = json.load(f)
|
||||
except (OSError, json.JSONDecodeError):
|
||||
return False
|
||||
|
||||
config = read_json(config_path)
|
||||
return "type" not in config and _infer_groot_model_version_from_config(config) == GROOT_N1_7
|
||||
|
||||
|
||||
@@ -141,11 +136,7 @@ def infer_groot_n1_7_embodiment_tag(model_path: str | Path | None) -> str | None
|
||||
return None
|
||||
|
||||
processor_config_path = Path(model_path).expanduser() / "processor_config.json"
|
||||
try:
|
||||
with processor_config_path.open() as f:
|
||||
processor_config = json.load(f)
|
||||
except (OSError, json.JSONDecodeError):
|
||||
return None
|
||||
processor_config = read_json(processor_config_path)
|
||||
|
||||
modality_configs = processor_config.get("processor_kwargs", {}).get("modality_configs", {})
|
||||
if not isinstance(modality_configs, dict):
|
||||
@@ -164,11 +155,7 @@ def infer_groot_n1_7_action_horizon(
|
||||
return None
|
||||
|
||||
processor_config_path = Path(model_path).expanduser() / "processor_config.json"
|
||||
try:
|
||||
with processor_config_path.open() as f:
|
||||
processor_config = json.load(f)
|
||||
except (OSError, json.JSONDecodeError):
|
||||
return None
|
||||
processor_config = read_json(processor_config_path)
|
||||
|
||||
processor_kwargs = processor_config.get("processor_kwargs", {})
|
||||
if not isinstance(processor_kwargs, dict):
|
||||
@@ -210,83 +197,6 @@ def infer_groot_n1_7_action_execution_horizon(
|
||||
return action_horizon
|
||||
|
||||
|
||||
def resolve_groot_n1_7_backbone_model(model_name: str, cache_dir: str | Path | None = None) -> str:
|
||||
model_path = Path(model_name).expanduser()
|
||||
if model_path.exists():
|
||||
return str(model_path)
|
||||
|
||||
cached_snapshot = _find_cached_hf_snapshot(model_name, cache_dir=cache_dir)
|
||||
return str(cached_snapshot) if cached_snapshot is not None else model_name
|
||||
|
||||
|
||||
def _find_cached_hf_snapshot(repo_id: str, cache_dir: str | Path | None = None) -> Path | None:
|
||||
repo_cache_name = f"models--{repo_id.replace('/', '--')}"
|
||||
required_files = (
|
||||
"config.json",
|
||||
"tokenizer_config.json",
|
||||
"preprocessor_config.json",
|
||||
"video_preprocessor_config.json",
|
||||
)
|
||||
|
||||
for hub_cache in _candidate_hf_hub_caches(cache_dir):
|
||||
repo_cache = hub_cache / repo_cache_name
|
||||
snapshots_dir = repo_cache / "snapshots"
|
||||
if not snapshots_dir.is_dir():
|
||||
continue
|
||||
|
||||
candidates: list[Path] = []
|
||||
ref_path = repo_cache / "refs" / "main"
|
||||
try:
|
||||
ref = ref_path.read_text().strip()
|
||||
except OSError:
|
||||
ref = ""
|
||||
if ref:
|
||||
candidates.append(snapshots_dir / ref)
|
||||
candidates.extend(
|
||||
sorted(
|
||||
(path for path in snapshots_dir.iterdir() if path.is_dir()),
|
||||
key=lambda path: path.stat().st_mtime,
|
||||
reverse=True,
|
||||
)
|
||||
)
|
||||
|
||||
seen: set[Path] = set()
|
||||
for snapshot in candidates:
|
||||
if snapshot in seen:
|
||||
continue
|
||||
seen.add(snapshot)
|
||||
if all((snapshot / filename).exists() for filename in required_files):
|
||||
return snapshot
|
||||
return None
|
||||
|
||||
|
||||
def _candidate_hf_hub_caches(cache_dir: str | Path | None) -> list[Path]:
|
||||
candidates: list[Path] = []
|
||||
if cache_dir is not None:
|
||||
cache_path = Path(cache_dir).expanduser()
|
||||
candidates.append(cache_path)
|
||||
candidates.append(cache_path / "hub")
|
||||
|
||||
hub_cache = os.environ.get("HUGGINGFACE_HUB_CACHE")
|
||||
if hub_cache:
|
||||
candidates.append(Path(hub_cache).expanduser())
|
||||
|
||||
hf_home = os.environ.get("HF_HOME")
|
||||
if hf_home:
|
||||
candidates.append(Path(hf_home).expanduser() / "hub")
|
||||
|
||||
candidates.append(Path.home() / ".cache" / "huggingface" / "hub")
|
||||
|
||||
deduped: list[Path] = []
|
||||
seen: set[Path] = set()
|
||||
for candidate in candidates:
|
||||
resolved = candidate.resolve() if candidate.exists() else candidate
|
||||
if resolved not in seen:
|
||||
seen.add(resolved)
|
||||
deduped.append(candidate)
|
||||
return deduped
|
||||
|
||||
|
||||
def _infer_groot_model_version_from_local_config(model_path: str) -> str | None:
|
||||
path = Path(model_path).expanduser()
|
||||
if path.is_dir():
|
||||
@@ -296,16 +206,7 @@ def _infer_groot_model_version_from_local_config(model_path: str) -> str | None:
|
||||
else:
|
||||
return None
|
||||
|
||||
if not config_path.exists():
|
||||
return None
|
||||
|
||||
try:
|
||||
with config_path.open() as f:
|
||||
config = json.load(f)
|
||||
except (OSError, json.JSONDecodeError):
|
||||
return None
|
||||
|
||||
return _infer_groot_model_version_from_config(config)
|
||||
return _infer_groot_model_version_from_config(read_json(config_path))
|
||||
|
||||
|
||||
def _infer_groot_model_version_from_config(config: dict) -> str | None:
|
||||
@@ -368,15 +269,15 @@ class GrootConfig(PreTrainedConfig):
|
||||
|
||||
# Groot-specific model parameters
|
||||
|
||||
# Explicit GR00T model family selection. LeRobot supports GR00T N1.7 only.
|
||||
model_version: str = GROOT_N1_7
|
||||
|
||||
# Path or HuggingFace model ID for the base Groot model
|
||||
# Path or HuggingFace model ID for the base GR00T N1.7 model whose backbone weights and
|
||||
# checkpoint sidecars (statistics.json, processor_config.json, ...) are loaded. This is the
|
||||
# model *source*, and is intentionally distinct from the inherited `pretrained_path`:
|
||||
# `pretrained_path` (`--policy.path`) points at a saved LeRobot checkpoint directory whose
|
||||
# `config.json` carries a `type` field, whereas a raw NVIDIA GR00T checkpoint has no such
|
||||
# field and so can only be loaded through `base_model_path` (`--policy.base_model_path`).
|
||||
# Defaults to GROOT_N1_7_BASE_MODEL when unset (resolved in __post_init__).
|
||||
base_model_path: str | None = None
|
||||
|
||||
# HF repo ID (or local path) for the GR00T N1.7 Cosmos/Qwen3-VL backbone processor.
|
||||
n1_7_backbone_model: str = GROOT_N1_7_BACKBONE_MODEL
|
||||
|
||||
# Optional named action transform applied after raw N1.7 checkpoint decoding and before env.step().
|
||||
# 'auto' (default) resolves to the embodiment default ('libero' for 'libero_sim', otherwise no
|
||||
# transform). Pass 'none' to explicitly disable the transform, including for 'libero_sim'.
|
||||
@@ -399,20 +300,31 @@ class GrootConfig(PreTrainedConfig):
|
||||
# Whether to fine-tune the diffusion model
|
||||
tune_diffusion_model: bool = True
|
||||
|
||||
# LoRA parameters (from groot_finetune_script.py)
|
||||
# Rank for the LORA model. If 0, no LORA will be used.
|
||||
lora_rank: int = 0
|
||||
# Whether to fine-tune the VL LayerNorm + VL self-attention projector in the action head.
|
||||
tune_vlln: bool = True
|
||||
|
||||
# Alpha value for the LORA model
|
||||
lora_alpha: int = 16
|
||||
# Number of top LLM backbone layers to fine-tune (0 = none). Lets you adapt just the final
|
||||
# language layers without unfreezing the whole backbone; independent of `tune_llm`, which tunes
|
||||
# the entire LLM.
|
||||
tune_top_llm_layers: int = 0
|
||||
|
||||
# Dropout rate for the LORA model
|
||||
lora_dropout: float = 0.1
|
||||
# Inference-time knob: Number of flow-matching denoising steps used to decode an action chunk.
|
||||
# Trades inference latency for action quality.
|
||||
# None keeps the checkpoint value (GR00T N1.7 default: 4).
|
||||
num_inference_timesteps: int | None = None
|
||||
|
||||
# Whether to use the full model for LORA
|
||||
lora_full_model: bool = False
|
||||
# Inference-time knob: Real-Time Chunking (RTC) overlap-blend ramp rate, used when the RTC engine
|
||||
# supplies a previous-chunk prefix. Higher values blend the overlapping prefix more aggressively.
|
||||
# None keeps the checkpoint value (GR00T N1.7 default: 6.0).
|
||||
rtc_ramp_rate: float | None = None
|
||||
|
||||
# Training parameters (matching groot_finetune_script.py)
|
||||
# Inference-time knob: Whether to request the flash-attention-2 kernel for the Qwen3-VL backbone.
|
||||
# flash-attn is an optional, user-managed optimization; when it is absent (the default),
|
||||
# the backbone transparently falls back to SDPA, which is numerically equivalent.
|
||||
# Set to True only after installing a flash-attn build matching your torch/CUDA env.
|
||||
use_flash_attention: bool = False
|
||||
|
||||
# Training parameters
|
||||
optimizer_lr: float = 1e-4
|
||||
optimizer_betas: tuple[float, float] = (0.95, 0.999)
|
||||
optimizer_eps: float = 1e-8
|
||||
@@ -421,12 +333,18 @@ class GrootConfig(PreTrainedConfig):
|
||||
use_bf16: bool = True
|
||||
|
||||
# TODO(Steven): Remove these deprecated fields in a future release.
|
||||
# Deprecated Isaac-GR00T runner/N1.5 fields below — unused by the LeRobot N1.7 implementation
|
||||
# (nothing in src/lerobot reads them). They are kept only so config.json files saved by
|
||||
# earlier lerobot releases still parse: draccus rejects unknown fields, so removing them
|
||||
# would break every previously saved groot checkpoint at config-load time.
|
||||
# Deprecated Isaac-GR00T runner / GR00T N1.5 fields, plus the (never-wired) LoRA fields — all
|
||||
# unused by the LeRobot N1.7 implementation except the `tokenizer_assets_repo` N1.5 tripwire and
|
||||
# the `image_size` legacy remap in __post_init__. They are kept ONLY so a config.json saved by an
|
||||
# earlier lerobot release (notably a GR00T N1.5 checkpoint) still parses under draccus — which
|
||||
# rejects unknown fields — and is then rejected with a clear N1.5 removal message rather than an
|
||||
# opaque draccus decoding error.
|
||||
image_size: tuple[int, int] = (256, 256) # image sizing is handled by the backbone's image processor.
|
||||
tokenizer_assets_repo: str | None = None
|
||||
lora_rank: int = 0
|
||||
lora_alpha: int = 16
|
||||
lora_dropout: float = 0.1
|
||||
lora_full_model: bool = False
|
||||
video_backend: str = "decord"
|
||||
balance_dataset_weights: bool = True
|
||||
balance_trajectory_weights: bool = True
|
||||
@@ -446,7 +364,6 @@ class GrootConfig(PreTrainedConfig):
|
||||
f"like a legacy GR00T N1.5 checkpoint or config. {GROOT_N1_5_REMOVAL_GUIDANCE}"
|
||||
)
|
||||
|
||||
self.model_version = normalize_groot_model_version(self.model_version)
|
||||
self.action_decode_transform = normalize_groot_action_decode_transform(self.action_decode_transform)
|
||||
if self.base_model_path is None:
|
||||
self.base_model_path = GROOT_N1_7_BASE_MODEL
|
||||
@@ -491,9 +408,9 @@ class GrootConfig(PreTrainedConfig):
|
||||
setattr(self, field_name, n1_7_value)
|
||||
|
||||
inferred_version = infer_groot_model_version(self.base_model_path)
|
||||
if inferred_version is not None and inferred_version != self.model_version:
|
||||
if inferred_version is not None and inferred_version != GROOT_N1_7:
|
||||
message = (
|
||||
f"GR00T model_version '{self.model_version}' does not match base_model_path "
|
||||
f"GR00T model_version '{GROOT_N1_7}' does not match base_model_path "
|
||||
f"'{self.base_model_path}', which looks like '{inferred_version}'."
|
||||
)
|
||||
if inferred_version == GROOT_N1_5:
|
||||
@@ -507,9 +424,6 @@ class GrootConfig(PreTrainedConfig):
|
||||
f"n_action_steps ({self.n_action_steps}) cannot exceed chunk_size ({self.chunk_size})"
|
||||
)
|
||||
|
||||
# groot_repo_path is now optional since we ported the components
|
||||
# No validation needed
|
||||
|
||||
def validate_features(self) -> None:
|
||||
"""Validate and set up input/output features for Groot."""
|
||||
image_features = [key for key, feat in self.input_features.items() if feat.type == FeatureType.VISUAL]
|
||||
|
||||
@@ -16,7 +16,6 @@
|
||||
from __future__ import annotations
|
||||
|
||||
import importlib
|
||||
import json
|
||||
import logging
|
||||
from contextlib import suppress
|
||||
from copy import deepcopy
|
||||
@@ -35,7 +34,14 @@ from .action_head.cross_attention_dit import AlternateVLDiT, DiT, SelfAttentionT
|
||||
from .configuration_groot import N1_7_DEFAULT_IMAGE_CROP_SIZE, N1_7_DEFAULT_IMAGE_TARGET_SIZE
|
||||
|
||||
if TYPE_CHECKING or _transformers_available:
|
||||
from transformers import AutoConfig, AutoModel, PretrainedConfig, PreTrainedModel
|
||||
from transformers import (
|
||||
AutoConfig,
|
||||
AutoModel,
|
||||
PretrainedConfig,
|
||||
PreTrainedModel,
|
||||
Qwen3VLConfig,
|
||||
Qwen3VLForConditionalGeneration,
|
||||
)
|
||||
from transformers.feature_extraction_utils import BatchFeature
|
||||
else:
|
||||
AutoConfig = None
|
||||
@@ -43,25 +49,17 @@ else:
|
||||
PretrainedConfig = object
|
||||
PreTrainedModel = object
|
||||
BatchFeature = None
|
||||
Qwen3VLConfig = None
|
||||
Qwen3VLForConditionalGeneration = None
|
||||
|
||||
try:
|
||||
import tree
|
||||
except ImportError:
|
||||
tree = None
|
||||
|
||||
try:
|
||||
from transformers import Qwen3VLConfig, Qwen3VLForConditionalGeneration
|
||||
except ImportError:
|
||||
Qwen3VLConfig = None
|
||||
Qwen3VLForConditionalGeneration = None
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
def _copy_default(value: Any) -> Any:
|
||||
return deepcopy(value)
|
||||
|
||||
|
||||
GR00T_N1_7_DEFAULTS: dict[str, Any] = {
|
||||
"model_dtype": "bfloat16",
|
||||
"dtype": "bfloat16",
|
||||
@@ -74,7 +72,7 @@ GR00T_N1_7_DEFAULTS: dict[str, Any] = {
|
||||
"tune_visual": False,
|
||||
"select_layer": 16,
|
||||
"reproject_vision": False,
|
||||
"use_flash_attention": True,
|
||||
"use_flash_attention": False,
|
||||
"load_bf16": False,
|
||||
"backbone_trainable_params_fp32": True,
|
||||
"image_crop_size": N1_7_DEFAULT_IMAGE_CROP_SIZE,
|
||||
@@ -152,29 +150,10 @@ class GR00TN17Config(PretrainedConfig):
|
||||
def __init__(self, **kwargs):
|
||||
super().__init__(**kwargs)
|
||||
for key, value in GR00T_N1_7_DEFAULTS.items():
|
||||
setattr(self, key, _copy_default(kwargs.pop(key, value)))
|
||||
setattr(self, key, deepcopy(kwargs.pop(key, value)))
|
||||
for key, value in kwargs.items():
|
||||
setattr(self, key, value)
|
||||
|
||||
def to_filtered_dict(self, exclude_augment: bool = True) -> dict[str, Any]:
|
||||
cfg = self.to_dict()
|
||||
if not exclude_augment:
|
||||
return cfg
|
||||
exclude_keys = {
|
||||
"random_rotation_angle",
|
||||
"color_jitter_params",
|
||||
"use_albumentations_transforms",
|
||||
"formalize_language",
|
||||
"image_crop_size",
|
||||
"image_target_size",
|
||||
"shortest_image_edge",
|
||||
"crop_fraction",
|
||||
}
|
||||
return {k: v for k, v in cfg.items() if k not in exclude_keys}
|
||||
|
||||
def to_filtered_json(self, exclude_augment: bool = True, **kwargs) -> str:
|
||||
return json.dumps(self.to_filtered_dict(exclude_augment), indent=2, default=str, **kwargs)
|
||||
|
||||
|
||||
class CategorySpecificLinear(nn.Module):
|
||||
"""Linear layer with category-specific weights for multi-embodiment support."""
|
||||
@@ -275,13 +254,7 @@ class Qwen3Backbone(nn.Module):
|
||||
transformers_loading_kwargs: dict[str, Any] | None = None,
|
||||
load_pretrained_weights: bool = True,
|
||||
):
|
||||
if Qwen3VLForConditionalGeneration is None:
|
||||
raise ImportError(
|
||||
"Qwen3VLForConditionalGeneration is required for GR00T N1.7. "
|
||||
"Install the GR00T optional dependencies with `pip install 'lerobot[groot]'` "
|
||||
"or use a transformers version that provides Qwen3-VL support."
|
||||
)
|
||||
|
||||
require_package("transformers", extra="groot")
|
||||
super().__init__()
|
||||
transformers_loading_kwargs = transformers_loading_kwargs or {"trust_remote_code": True}
|
||||
|
||||
@@ -361,11 +334,6 @@ class Qwen3Backbone(nn.Module):
|
||||
if _is_cosmos_reason2_backbone(model_name):
|
||||
backbone_config = _cosmos_reason2_qwen3_vl_config()
|
||||
else:
|
||||
if AutoConfig is None:
|
||||
raise ImportError(
|
||||
"AutoConfig is required to initialize a GR00T N1.7 backbone from config. "
|
||||
"Install the GR00T optional dependencies with `pip install 'lerobot[groot]'`."
|
||||
)
|
||||
backbone_config = AutoConfig.from_pretrained(model_name, **config_kwargs)
|
||||
return Qwen3VLForConditionalGeneration._from_config(backbone_config, **model_kwargs)
|
||||
|
||||
@@ -762,11 +730,6 @@ def _is_cosmos_reason2_backbone(model_name: str) -> bool:
|
||||
|
||||
|
||||
def _cosmos_reason2_qwen3_vl_config() -> PretrainedConfig:
|
||||
if Qwen3VLConfig is None:
|
||||
raise ImportError(
|
||||
"Qwen3VLConfig is required for GR00T N1.7. "
|
||||
"Install the GR00T optional dependencies with `pip install 'lerobot[groot]'`."
|
||||
)
|
||||
return Qwen3VLConfig(
|
||||
image_token_id=151655,
|
||||
video_token_id=151656,
|
||||
@@ -844,6 +807,7 @@ class GR00TN17(PreTrainedModel):
|
||||
transformers_loading_kwargs: dict[str, Any] | None = None,
|
||||
load_backbone_weights: bool = True,
|
||||
):
|
||||
_register_with_transformers()
|
||||
super().__init__(config)
|
||||
transformers_loading_kwargs = transformers_loading_kwargs or {"trust_remote_code": True}
|
||||
self.config = config
|
||||
@@ -949,6 +913,12 @@ class GR00TN17(PreTrainedModel):
|
||||
|
||||
|
||||
def _register_with_transformers() -> None:
|
||||
"""Register GR00T N1.7 with transformers' Auto* factories.
|
||||
|
||||
Idempotent: ``register(..., exist_ok=True)`` makes repeat calls no-ops (with a fallback that
|
||||
suppresses the already-registered error on transformers builds whose ``register()`` predates
|
||||
``exist_ok``), so no run-once guard is needed.
|
||||
"""
|
||||
if AutoConfig is None or AutoModel is None:
|
||||
return
|
||||
try:
|
||||
@@ -961,6 +931,3 @@ def _register_with_transformers() -> None:
|
||||
except TypeError:
|
||||
with suppress(ValueError):
|
||||
AutoModel.register(GR00TN17Config, GR00TN17)
|
||||
|
||||
|
||||
_register_with_transformers()
|
||||
|
||||
@@ -30,6 +30,9 @@ from pathlib import Path
|
||||
from typing import TypeVar
|
||||
|
||||
import torch
|
||||
from huggingface_hub import hf_hub_download
|
||||
from huggingface_hub.constants import SAFETENSORS_SINGLE_FILE
|
||||
from huggingface_hub.errors import HfHubHTTPError
|
||||
from torch import Tensor
|
||||
|
||||
from lerobot.configs import FeatureType, PolicyFeature
|
||||
@@ -46,8 +49,8 @@ from .configuration_groot import (
|
||||
infer_groot_model_version,
|
||||
infer_groot_n1_7_action_execution_horizon,
|
||||
infer_groot_n1_7_action_horizon,
|
||||
normalize_groot_model_version,
|
||||
)
|
||||
from .groot_n1_7 import GR00TN17
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
@@ -74,27 +77,30 @@ class GrootPolicy(PreTrainedPolicy):
|
||||
self.reset()
|
||||
|
||||
def _create_groot_model(self):
|
||||
"""Create and initialize the GR00T N1.7 model using Isaac-GR00T APIs."""
|
||||
# Handle Flash Attention compatibility issues
|
||||
self._handle_flash_attention_compatibility()
|
||||
|
||||
"""Create and initialize the GR00T N1.7 model using the ported components."""
|
||||
model_kwargs = {
|
||||
"pretrained_model_name_or_path": self.config.base_model_path,
|
||||
"tune_llm": self.config.tune_llm,
|
||||
"tune_visual": self.config.tune_visual,
|
||||
"tune_projector": self.config.tune_projector,
|
||||
"tune_diffusion_model": self.config.tune_diffusion_model,
|
||||
# Forwarded as a GR00TN17Config override; read back by set_trainable_parameters.
|
||||
"tune_top_llm_layers": self.config.tune_top_llm_layers,
|
||||
"use_flash_attention": self.config.use_flash_attention,
|
||||
}
|
||||
from .groot_n1_7 import GR00TN17
|
||||
# Surface the inference-time knobs onto the model config only when the user set them; None
|
||||
# leaves the value baked into the checkpoint untouched.
|
||||
if self.config.num_inference_timesteps is not None:
|
||||
model_kwargs["num_inference_timesteps"] = self.config.num_inference_timesteps
|
||||
if self.config.rtc_ramp_rate is not None:
|
||||
model_kwargs["rtc_ramp_rate"] = self.config.rtc_ramp_rate
|
||||
|
||||
model = GR00TN17.from_pretrained(
|
||||
return GR00TN17.from_pretrained(
|
||||
**model_kwargs,
|
||||
tune_vlln=True,
|
||||
tune_vlln=self.config.tune_vlln,
|
||||
transformers_loading_kwargs={"trust_remote_code": True},
|
||||
)
|
||||
|
||||
return model
|
||||
|
||||
def reset(self):
|
||||
"""Reset policy state when environment resets."""
|
||||
self._action_queue = deque([], maxlen=self._action_queue_steps)
|
||||
@@ -137,15 +143,7 @@ class GrootPolicy(PreTrainedPolicy):
|
||||
Returns:
|
||||
Initialized GrootPolicy instance with loaded model
|
||||
"""
|
||||
from huggingface_hub import hf_hub_download
|
||||
from huggingface_hub.constants import SAFETENSORS_SINGLE_FILE
|
||||
from huggingface_hub.errors import HfHubHTTPError
|
||||
|
||||
requested_version = (
|
||||
normalize_groot_model_version(config.model_version)
|
||||
if config is not None
|
||||
else infer_groot_model_version(str(pretrained_name_or_path)) or GROOT_N1_7
|
||||
)
|
||||
requested_version = infer_groot_model_version(str(pretrained_name_or_path)) or GROOT_N1_7
|
||||
logger.info(
|
||||
"The Groot policy wraps NVIDIA's GR00T %s model. Loading pretrained model from: %s",
|
||||
requested_version,
|
||||
@@ -199,10 +197,8 @@ class GrootPolicy(PreTrainedPolicy):
|
||||
logger.info("Detected base GR00T model, loading from HuggingFace...")
|
||||
|
||||
if config is None:
|
||||
model_version = infer_groot_model_version(str(pretrained_name_or_path)) or GROOT_N1_7
|
||||
# Create default config with the pretrained path
|
||||
config = GrootConfig(
|
||||
model_version=model_version,
|
||||
base_model_path=str(pretrained_name_or_path),
|
||||
)
|
||||
|
||||
@@ -225,11 +221,10 @@ class GrootPolicy(PreTrainedPolicy):
|
||||
if hasattr(config, key):
|
||||
setattr(config, key, value)
|
||||
|
||||
config.model_version = normalize_groot_model_version(config.model_version)
|
||||
inferred_version = infer_groot_model_version(config.base_model_path)
|
||||
if inferred_version is not None and inferred_version != config.model_version:
|
||||
if inferred_version is not None and inferred_version != GROOT_N1_7:
|
||||
message = (
|
||||
f"GR00T model_version '{config.model_version}' does not match base_model_path "
|
||||
f"GR00T model_version '{GROOT_N1_7}' does not match base_model_path "
|
||||
f"'{config.base_model_path}', which looks like '{inferred_version}'."
|
||||
)
|
||||
if inferred_version == GROOT_N1_5:
|
||||
@@ -432,13 +427,11 @@ class GrootPolicy(PreTrainedPolicy):
|
||||
# During inference, we do not pass action because it is predicted.
|
||||
# N1.7 still carries a 2-D action horizon mask from its checkpoint processor.
|
||||
groot_inputs = self._filter_groot_inputs(batch, include_action=False)
|
||||
groot_options = None
|
||||
if self.config.model_version == GROOT_N1_7:
|
||||
groot_inputs, groot_options = self._prepare_n1_7_rtc_inputs(
|
||||
groot_inputs,
|
||||
inference_delay=kwargs.get("inference_delay"),
|
||||
prev_chunk_left_over=kwargs.get("prev_chunk_left_over"),
|
||||
)
|
||||
groot_inputs, groot_options = self._prepare_n1_7_rtc_inputs(
|
||||
groot_inputs,
|
||||
inference_delay=kwargs.get("inference_delay"),
|
||||
prev_chunk_left_over=kwargs.get("prev_chunk_left_over"),
|
||||
)
|
||||
|
||||
# Get device from model parameters
|
||||
device = get_device_from_parameters(self)
|
||||
@@ -469,26 +462,3 @@ class GrootPolicy(PreTrainedPolicy):
|
||||
actions = self.predict_action_chunk(batch)
|
||||
self._action_queue.extend(actions[:, : self._action_queue_steps].transpose(0, 1))
|
||||
return self._action_queue.popleft()
|
||||
|
||||
# -------------------------
|
||||
# Internal helpers
|
||||
# -------------------------
|
||||
def _handle_flash_attention_compatibility(self) -> None:
|
||||
"""Log Flash Attention availability (diagnostic only).
|
||||
|
||||
The GR00T N1.7 backbone automatically falls back to SDPA when ``flash_attn`` is
|
||||
unavailable (see ``Qwen3Backbone``), so this probe only emits a hint; it does not
|
||||
change behaviour or mutate global state.
|
||||
"""
|
||||
try:
|
||||
import flash_attn
|
||||
|
||||
logger.debug("Flash Attention %s is available.", flash_attn.__version__)
|
||||
except ImportError:
|
||||
logger.debug("Flash Attention is not installed; the GR00T backbone will use SDPA.")
|
||||
except Exception as e: # noqa: BLE001
|
||||
logger.warning(
|
||||
"Flash Attention failed to import (%s); the GR00T backbone will use SDPA. If this is "
|
||||
"an 'undefined symbol' error, reinstall a flash-attn build matching your torch version.",
|
||||
e,
|
||||
)
|
||||
|
||||
@@ -14,7 +14,6 @@
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
import json
|
||||
import logging
|
||||
from copy import copy
|
||||
from dataclasses import dataclass, field, fields, is_dataclass
|
||||
@@ -28,12 +27,22 @@ from einops import rearrange
|
||||
from PIL import Image
|
||||
from torchvision.transforms import InterpolationMode
|
||||
|
||||
from lerobot.utils.import_utils import _transformers_available
|
||||
from lerobot.utils.import_utils import _transformers_available, require_package
|
||||
|
||||
if TYPE_CHECKING or _transformers_available:
|
||||
from transformers import ProcessorMixin
|
||||
from transformers import (
|
||||
AutoTokenizer,
|
||||
ProcessorMixin,
|
||||
Qwen2VLImageProcessor,
|
||||
Qwen3VLProcessor,
|
||||
Qwen3VLVideoProcessor,
|
||||
)
|
||||
else:
|
||||
AutoTokenizer = None
|
||||
ProcessorMixin = object
|
||||
Qwen2VLImageProcessor = None
|
||||
Qwen3VLProcessor = None
|
||||
Qwen3VLVideoProcessor = None
|
||||
|
||||
from lerobot.processor import (
|
||||
AbsoluteActionsProcessorStep,
|
||||
@@ -70,6 +79,19 @@ from .configuration_groot import (
|
||||
GrootConfig,
|
||||
is_raw_groot_n1_7_checkpoint,
|
||||
)
|
||||
from .utils import (
|
||||
as_int_pair,
|
||||
as_optional_float,
|
||||
as_optional_int,
|
||||
config_value,
|
||||
flatten_n1_7_modality_stats,
|
||||
has_modality_stats,
|
||||
infer_n1_7_batch_size_and_device,
|
||||
prepare_n1_7_language_batch,
|
||||
read_json,
|
||||
relative_eef_to_absolute,
|
||||
stat_dim_from_entry,
|
||||
)
|
||||
|
||||
N1_7_EMBODIMENT_MAPPING = {
|
||||
"oxe_droid_relative_eef_relative_joint": 24,
|
||||
@@ -128,12 +150,12 @@ def _load_n1_7_checkpoint_processor_assets(config: GrootConfig) -> _GrootN17Chec
|
||||
return None
|
||||
|
||||
checkpoint_path = Path(config.base_model_path).expanduser()
|
||||
processor_config = _read_json(checkpoint_path / "processor_config.json")
|
||||
processor_config = read_json(checkpoint_path / "processor_config.json")
|
||||
processor_kwargs = processor_config.get("processor_kwargs", {})
|
||||
if not isinstance(processor_kwargs, dict):
|
||||
processor_kwargs = {}
|
||||
|
||||
all_stats = _read_json(checkpoint_path / "statistics.json")
|
||||
all_stats = read_json(checkpoint_path / "statistics.json")
|
||||
raw_stats = all_stats.get(config.embodiment_tag)
|
||||
if not isinstance(raw_stats, dict):
|
||||
raw_stats = {}
|
||||
@@ -185,25 +207,16 @@ def _load_n1_7_checkpoint_processor_assets(config: GrootConfig) -> _GrootN17Chec
|
||||
use_relative_action=use_relative_action,
|
||||
clip_outliers=clip_outliers,
|
||||
video_modality_keys=video_modality_keys,
|
||||
image_crop_size=_as_int_pair(processor_kwargs.get("image_crop_size")),
|
||||
image_target_size=_as_int_pair(processor_kwargs.get("image_target_size")),
|
||||
shortest_image_edge=_as_optional_int(processor_kwargs.get("shortest_image_edge")),
|
||||
crop_fraction=_as_optional_float(processor_kwargs.get("crop_fraction")),
|
||||
image_crop_size=as_int_pair(processor_kwargs.get("image_crop_size")),
|
||||
image_target_size=as_int_pair(processor_kwargs.get("image_target_size")),
|
||||
shortest_image_edge=as_optional_int(processor_kwargs.get("shortest_image_edge")),
|
||||
crop_fraction=as_optional_float(processor_kwargs.get("crop_fraction")),
|
||||
use_albumentations=use_albumentations,
|
||||
)
|
||||
|
||||
|
||||
def _read_json(path: Path) -> dict[str, Any]:
|
||||
try:
|
||||
with path.open() as f:
|
||||
data = json.load(f)
|
||||
except (OSError, json.JSONDecodeError):
|
||||
return {}
|
||||
return data if isinstance(data, dict) else {}
|
||||
|
||||
|
||||
def _load_n1_7_embodiment_mapping(checkpoint_path: Path) -> dict[str, int] | None:
|
||||
mapping = _read_json(checkpoint_path / "embodiment_id.json")
|
||||
mapping = read_json(checkpoint_path / "embodiment_id.json")
|
||||
if not mapping:
|
||||
return None
|
||||
parsed: dict[str, int] = {}
|
||||
@@ -234,7 +247,7 @@ def _load_n1_7_checkpoint_stats(
|
||||
"""
|
||||
|
||||
if raw_stats is None:
|
||||
all_stats = _read_json(checkpoint_path / "statistics.json")
|
||||
all_stats = read_json(checkpoint_path / "statistics.json")
|
||||
raw_stats = all_stats.get(embodiment_tag)
|
||||
if not isinstance(raw_stats, dict):
|
||||
return {}
|
||||
@@ -249,14 +262,14 @@ def _load_n1_7_checkpoint_stats(
|
||||
|
||||
use_percentiles = processor_kwargs.get("use_percentiles", False)
|
||||
return {
|
||||
OBS_STATE: _flatten_n1_7_modality_stats(
|
||||
OBS_STATE: flatten_n1_7_modality_stats(
|
||||
embodiment_stats=raw_stats,
|
||||
embodiment_config=modality_config,
|
||||
modality="state",
|
||||
use_percentiles=bool(use_percentiles),
|
||||
use_relative_action=use_relative_action,
|
||||
),
|
||||
ACTION: _flatten_n1_7_modality_stats(
|
||||
ACTION: flatten_n1_7_modality_stats(
|
||||
embodiment_stats=raw_stats,
|
||||
embodiment_config=modality_config,
|
||||
modality="action",
|
||||
@@ -324,134 +337,6 @@ def _load_n1_7_checkpoint_video_modality_keys(
|
||||
return keys or None
|
||||
|
||||
|
||||
def _as_int_pair(value: Any) -> list[int] | None:
|
||||
if not isinstance(value, (list, tuple)) or len(value) != 2:
|
||||
return None
|
||||
try:
|
||||
return [int(value[0]), int(value[1])]
|
||||
except (TypeError, ValueError):
|
||||
return None
|
||||
|
||||
|
||||
def _as_optional_int(value: Any) -> int | None:
|
||||
if value is None:
|
||||
return None
|
||||
try:
|
||||
return int(value)
|
||||
except (TypeError, ValueError):
|
||||
return None
|
||||
|
||||
|
||||
def _as_optional_float(value: Any) -> float | None:
|
||||
if value is None:
|
||||
return None
|
||||
try:
|
||||
return float(value)
|
||||
except (TypeError, ValueError):
|
||||
return None
|
||||
|
||||
|
||||
def _flatten_n1_7_modality_stats(
|
||||
*,
|
||||
embodiment_stats: dict[str, Any],
|
||||
embodiment_config: dict[str, Any],
|
||||
modality: str,
|
||||
use_percentiles: bool,
|
||||
use_relative_action: bool,
|
||||
) -> dict[str, list[float]]:
|
||||
"""Flatten one N1.7 modality's grouped statistics in checkpoint order.
|
||||
|
||||
When checkpoints request percentile normalization, q01/q99 replace min/max
|
||||
for regular groups. Relative action groups read from ``relative_action``
|
||||
stats and keep min/max, matching Isaac-GR00T's processor override.
|
||||
"""
|
||||
|
||||
source_stats = embodiment_stats.get(modality, {})
|
||||
modality_config = embodiment_config.get(modality, {})
|
||||
if not isinstance(source_stats, dict) or not isinstance(modality_config, dict):
|
||||
return {}
|
||||
modality_keys = modality_config.get("modality_keys", [])
|
||||
if not isinstance(modality_keys, list):
|
||||
return {}
|
||||
|
||||
flattened: dict[str, list[float]] = {}
|
||||
action_configs = modality_config.get("action_configs", []) if modality == "action" else []
|
||||
if not isinstance(action_configs, list):
|
||||
action_configs = []
|
||||
relative_stats = embodiment_stats.get("relative_action", {})
|
||||
if not isinstance(relative_stats, dict):
|
||||
relative_stats = {}
|
||||
|
||||
for stat_name in ("min", "max", "mean", "std"):
|
||||
values: list[float] = []
|
||||
source_stat_name = stat_name
|
||||
if use_percentiles and stat_name == "min":
|
||||
source_stat_name = "q01"
|
||||
elif use_percentiles and stat_name == "max":
|
||||
source_stat_name = "q99"
|
||||
|
||||
for idx, modality_key in enumerate(modality_keys):
|
||||
if not isinstance(modality_key, str):
|
||||
continue
|
||||
key_source_stats = source_stats
|
||||
key_stat_name = source_stat_name
|
||||
if modality == "action" and use_relative_action and idx < len(action_configs):
|
||||
action_config = action_configs[idx]
|
||||
if isinstance(action_config, dict) and _config_value(action_config.get("rep")) == "relative":
|
||||
key_source_stats = relative_stats
|
||||
key_stat_name = stat_name
|
||||
key_stats = key_source_stats.get(modality_key, {})
|
||||
if not isinstance(key_stats, dict):
|
||||
raise KeyError(f"Missing statistics for {modality}.{modality_key}")
|
||||
raw_values = key_stats.get(key_stat_name)
|
||||
if raw_values is None:
|
||||
raise KeyError(f"Missing '{key_stat_name}' statistics for {modality}.{modality_key}")
|
||||
values.extend(_as_float_list(raw_values))
|
||||
if values:
|
||||
flattened[stat_name] = values
|
||||
|
||||
return flattened
|
||||
|
||||
|
||||
def _as_float_list(values: Any) -> list[float]:
|
||||
if values is None:
|
||||
return []
|
||||
if isinstance(values, torch.Tensor):
|
||||
return values.detach().cpu().reshape(-1).float().tolist()
|
||||
if isinstance(values, np.ndarray):
|
||||
return values.reshape(-1).astype(np.float32).tolist()
|
||||
if isinstance(values, (list, tuple)):
|
||||
flattened: list[float] = []
|
||||
for value in values:
|
||||
flattened.extend(_as_float_list(value))
|
||||
return flattened
|
||||
return [float(values)]
|
||||
|
||||
|
||||
def _config_value(value: Any) -> str:
|
||||
if hasattr(value, "value"):
|
||||
value = value.value
|
||||
text = str(value).lower()
|
||||
return {
|
||||
"relative": "relative",
|
||||
"absolute": "absolute",
|
||||
"delta": "delta",
|
||||
"eef": "eef",
|
||||
"non_eef": "non_eef",
|
||||
"default": "default",
|
||||
"xyz_rot6d": "xyz+rot6d",
|
||||
"xyz+rot6d": "xyz+rot6d",
|
||||
"xyz_rotvec": "xyz+rotvec",
|
||||
"xyz+rotvec": "xyz+rotvec",
|
||||
}.get(text, text)
|
||||
|
||||
|
||||
def _has_modality_stats(stats: dict[str, dict[str, Any]] | None) -> bool:
|
||||
if not stats:
|
||||
return False
|
||||
return any(bool(modality_stats) for modality_stats in stats.values())
|
||||
|
||||
|
||||
# GR00T normalizes state/action inside its own processor steps and so deliberately has no
|
||||
# NormalizerProcessorStep/UnnormalizerProcessorStep (see GrootConfig.normalization_mapping, which is
|
||||
# IDENTITY for every feature). lerobot-train nonetheless emits these standard override keys
|
||||
@@ -603,6 +488,9 @@ def _load_groot_processor_pipelines(
|
||||
PolicyProcessorPipeline[dict[str, Any], dict[str, Any]],
|
||||
PolicyProcessorPipeline[PolicyAction, PolicyAction],
|
||||
]:
|
||||
# Register the GR00T N1.5 rejection stubs before deserializing, so a saved N1.5 pipeline
|
||||
# referencing their registry names fails with the canonical removal guidance.
|
||||
_register_removed_n1_5_step_stubs()
|
||||
preprocessor = PolicyProcessorPipeline.from_pretrained(
|
||||
pretrained_model_name_or_path=pretrained_path,
|
||||
config_filename=preprocessor_config_filename,
|
||||
@@ -701,7 +589,7 @@ def make_groot_pre_post_processors(
|
||||
else action_horizon
|
||||
)
|
||||
checkpoint_stats = checkpoint_assets.stats if checkpoint_assets is not None else None
|
||||
checkpoint_has_stats = _has_modality_stats(checkpoint_stats)
|
||||
checkpoint_has_stats = has_modality_stats(checkpoint_stats)
|
||||
padded_stats = checkpoint_stats if checkpoint_has_stats else (dataset_stats or {})
|
||||
embodiment_mapping = (
|
||||
checkpoint_assets.embodiment_mapping
|
||||
@@ -757,7 +645,7 @@ def make_groot_pre_post_processors(
|
||||
AddBatchDimensionProcessorStep(),
|
||||
pack_step,
|
||||
GrootN17VLMEncodeStep(
|
||||
model_name=config.n1_7_backbone_model,
|
||||
model_name=GROOT_N1_7_BACKBONE_MODEL,
|
||||
image_crop_size=image_crop_size,
|
||||
image_target_size=image_target_size,
|
||||
shortest_image_edge=shortest_image_edge,
|
||||
@@ -768,7 +656,7 @@ def make_groot_pre_post_processors(
|
||||
DeviceProcessorStep(device=config.device),
|
||||
]
|
||||
|
||||
if checkpoint_assets is not None and not checkpoint_has_stats and not _has_modality_stats(padded_stats):
|
||||
if checkpoint_assets is not None and not checkpoint_has_stats and not has_modality_stats(padded_stats):
|
||||
raise ValueError(
|
||||
f"GR00T N1.7 checkpoint '{config.base_model_path}' has no statistics for embodiment tag "
|
||||
f"'{config.embodiment_tag}', and no dataset stats were provided to fall back to, so "
|
||||
@@ -846,66 +734,8 @@ def _align_video_horizon(video: np.ndarray, horizon: int | None) -> np.ndarray:
|
||||
return np.concatenate([pad, video], axis=1)
|
||||
|
||||
|
||||
def _infer_n1_7_batch_size_and_device(
|
||||
obs: dict[str, Any], action: torch.Tensor | None
|
||||
) -> tuple[int, torch.device]:
|
||||
for value in list(obs.values()) + [action]:
|
||||
if isinstance(value, torch.Tensor):
|
||||
return value.shape[0], value.device
|
||||
video = obs.get("video")
|
||||
if isinstance(video, np.ndarray):
|
||||
return video.shape[0], torch.device("cpu")
|
||||
return 1, torch.device("cpu")
|
||||
|
||||
|
||||
def _prepare_n1_7_language_batch(
|
||||
language: Any,
|
||||
batch_size: int,
|
||||
*,
|
||||
formalize_language: bool,
|
||||
) -> list[str]:
|
||||
default_language = "Perform the task."
|
||||
if language is None or (isinstance(language, str) and language == ""):
|
||||
languages = [default_language] * batch_size
|
||||
elif isinstance(language, str):
|
||||
languages = [language] * batch_size
|
||||
elif isinstance(language, (list, tuple)):
|
||||
languages = list(language)
|
||||
if len(languages) == 0:
|
||||
languages = [default_language] * batch_size
|
||||
elif len(languages) == 1 and batch_size > 1:
|
||||
languages = languages * batch_size
|
||||
elif len(languages) != batch_size:
|
||||
raise ValueError(
|
||||
f"language batch has {len(languages)} entries, but GR00T N1.7 input batch has {batch_size}."
|
||||
)
|
||||
else:
|
||||
languages = [str(language)] * batch_size
|
||||
|
||||
formatted = []
|
||||
for item in languages:
|
||||
text = str(item) if item else default_language
|
||||
if formalize_language:
|
||||
text = text.lower()
|
||||
text = "".join(ch for ch in text if ch.isalnum() or ch.isspace() or ch == "_")
|
||||
formatted.append(text)
|
||||
return formatted
|
||||
|
||||
|
||||
def _build_n1_7_processor(model_name: str = GROOT_N1_7_BACKBONE_MODEL) -> ProcessorMixin:
|
||||
try:
|
||||
from transformers import (
|
||||
AutoTokenizer,
|
||||
Qwen2VLImageProcessor,
|
||||
Qwen3VLProcessor,
|
||||
Qwen3VLVideoProcessor,
|
||||
)
|
||||
except ImportError as exc:
|
||||
raise ImportError(
|
||||
"GR00T N1.7 preprocessing requires a transformers version with Qwen3-VL processor support. "
|
||||
"Install the GR00T optional dependencies with `pip install 'lerobot[groot]'`."
|
||||
) from exc
|
||||
|
||||
require_package("transformers", extra="groot")
|
||||
tokenizer = AutoTokenizer.from_pretrained(model_name, trust_remote_code=True)
|
||||
image_processor = Qwen2VLImageProcessor.from_pretrained(model_name, trust_remote_code=True)
|
||||
video_processor = Qwen3VLVideoProcessor.from_pretrained(model_name, trust_remote_code=True)
|
||||
@@ -1201,8 +1031,8 @@ class GrootN17PackInputsStep(ProcessorStep):
|
||||
for k in image_keys_to_remove:
|
||||
obs.pop(k, None)
|
||||
|
||||
bsz, _device = _infer_n1_7_batch_size_and_device(obs, transition.get(TransitionKey.ACTION))
|
||||
comp["language"] = _prepare_n1_7_language_batch(
|
||||
bsz, _device = infer_n1_7_batch_size_and_device(obs, transition.get(TransitionKey.ACTION))
|
||||
comp["language"] = prepare_n1_7_language_batch(
|
||||
comp.get(self.language_key),
|
||||
bsz,
|
||||
formalize_language=self.formalize_language,
|
||||
@@ -1269,7 +1099,7 @@ class GrootN17PackInputsStep(ProcessorStep):
|
||||
comp["action_mask"] = action_mask
|
||||
|
||||
emb_id = self.embodiment_mapping.get(self.embodiment_tag, 0)
|
||||
bsz, device = _infer_n1_7_batch_size_and_device(obs, transition.get(TransitionKey.ACTION))
|
||||
bsz, device = infer_n1_7_batch_size_and_device(obs, transition.get(TransitionKey.ACTION))
|
||||
if "action_mask" not in comp:
|
||||
action_mask = torch.zeros(bsz, self.action_horizon, dtype=torch.float32, device=device)
|
||||
valid_horizon = min(self.valid_action_horizon, self.action_horizon)
|
||||
@@ -1432,7 +1262,7 @@ class GrootN17VLMEncodeStep(ProcessorStep):
|
||||
return transition
|
||||
|
||||
batch_size = int(video.shape[0])
|
||||
languages = _prepare_n1_7_language_batch(
|
||||
languages = prepare_n1_7_language_batch(
|
||||
comp.get("language"),
|
||||
batch_size,
|
||||
formalize_language=False,
|
||||
@@ -1494,14 +1324,6 @@ class GrootN17VLMEncodeStep(ProcessorStep):
|
||||
}
|
||||
|
||||
|
||||
def _stat_dim_from_entry(entry: dict[str, Any]) -> int:
|
||||
for stat_name in ("mean", "q01", "min", "max", "std"):
|
||||
value = entry.get(stat_name)
|
||||
if isinstance(value, list) and len(value) > 0:
|
||||
return len(value)
|
||||
return 0
|
||||
|
||||
|
||||
def _n1_7_decode_stats_for_action(
|
||||
raw_stats: dict[str, Any],
|
||||
key: str,
|
||||
@@ -1512,7 +1334,7 @@ def _n1_7_decode_stats_for_action(
|
||||
) -> tuple[np.ndarray, np.ndarray]:
|
||||
"""Select the min/max arrays needed to decode one checkpoint action group."""
|
||||
|
||||
is_relative = use_relative_action and _config_value(action_config.get("rep")) == "relative"
|
||||
is_relative = use_relative_action and config_value(action_config.get("rep")) == "relative"
|
||||
modality = "relative_action" if is_relative else "action"
|
||||
stats = raw_stats.get(modality, {}).get(key, {})
|
||||
if not isinstance(stats, dict):
|
||||
@@ -1537,38 +1359,6 @@ def _n1_7_decode_valid_horizon(action_config: dict[str, Any], action_np: np.ndar
|
||||
return max(1, min(action_np.shape[1], len(delta_indices)))
|
||||
|
||||
|
||||
def _rot6d_to_matrix(rot6d: np.ndarray) -> np.ndarray:
|
||||
rows = rot6d.reshape(2, 3).astype(np.float64)
|
||||
row1 = rows[0] / np.linalg.norm(rows[0])
|
||||
row2 = rows[1] - np.dot(row1, rows[1]) * row1
|
||||
row2 = row2 / np.linalg.norm(row2)
|
||||
row3 = np.cross(row1, row2)
|
||||
return np.vstack([row1, row2, row3])
|
||||
|
||||
|
||||
def _xyz_rot6d_to_homogeneous(xyz_rot6d: np.ndarray) -> np.ndarray:
|
||||
transform = np.eye(4, dtype=np.float64)
|
||||
transform[:3, :3] = _rot6d_to_matrix(xyz_rot6d[3:])
|
||||
transform[:3, 3] = xyz_rot6d[:3]
|
||||
return transform
|
||||
|
||||
|
||||
def _homogeneous_to_xyz_rot6d(transform: np.ndarray) -> np.ndarray:
|
||||
return np.concatenate([transform[:3, 3], transform[:2, :3].reshape(-1)], axis=0)
|
||||
|
||||
|
||||
def _relative_eef_to_absolute(action: np.ndarray, reference_state: np.ndarray) -> np.ndarray:
|
||||
"""Convert relative EEF deltas in xyz+rot6d format to absolute EEF poses."""
|
||||
|
||||
out = np.empty_like(action, dtype=np.float64)
|
||||
for batch_idx in range(action.shape[0]):
|
||||
reference = _xyz_rot6d_to_homogeneous(reference_state[batch_idx])
|
||||
for timestep in range(action.shape[1]):
|
||||
relative = _xyz_rot6d_to_homogeneous(action[batch_idx, timestep])
|
||||
out[batch_idx, timestep] = _homogeneous_to_xyz_rot6d(reference @ relative)
|
||||
return out.astype(np.float32)
|
||||
|
||||
|
||||
def _n1_7_action_group_slice(
|
||||
action_keys: list[Any], decoded_groups: dict[str, np.ndarray], target_key: str
|
||||
) -> slice:
|
||||
@@ -1677,7 +1467,7 @@ class GrootN17ActionDecodeStep(ProcessorStep):
|
||||
stats_entry = self.raw_stats.get("action", {}).get(key, {})
|
||||
if not isinstance(stats_entry, dict):
|
||||
continue
|
||||
dim = _stat_dim_from_entry(stats_entry)
|
||||
dim = stat_dim_from_entry(stats_entry)
|
||||
if dim <= 0:
|
||||
continue
|
||||
cfg = (
|
||||
@@ -1716,18 +1506,18 @@ class GrootN17ActionDecodeStep(ProcessorStep):
|
||||
if not isinstance(key, str) or key not in decoded_groups or idx >= len(action_configs):
|
||||
continue
|
||||
cfg = action_configs[idx]
|
||||
if not isinstance(cfg, dict) or _config_value(cfg.get("rep")) != "relative":
|
||||
if not isinstance(cfg, dict) or config_value(cfg.get("rep")) != "relative":
|
||||
continue
|
||||
state_key = cfg.get("state_key") or key
|
||||
if state_key not in raw_state:
|
||||
raise KeyError(f"Missing cached raw state '{state_key}' for relative N1.7 action '{key}'")
|
||||
reference = raw_state[state_key]
|
||||
action_type = _config_value(cfg.get("type"))
|
||||
action_format = _config_value(cfg.get("format"))
|
||||
action_type = config_value(cfg.get("type"))
|
||||
action_format = config_value(cfg.get("format"))
|
||||
if action_type == "non_eef":
|
||||
decoded_groups[key] = decoded_groups[key] + reference[:, None, :]
|
||||
elif action_type == "eef" and action_format == "xyz+rot6d":
|
||||
decoded_groups[key] = _relative_eef_to_absolute(decoded_groups[key], reference)
|
||||
decoded_groups[key] = relative_eef_to_absolute(decoded_groups[key], reference)
|
||||
else:
|
||||
raise ValueError(f"Unsupported relative N1.7 action config for '{key}': {cfg}")
|
||||
|
||||
@@ -1886,15 +1676,27 @@ class GrootActionUnpackUnnormalizeStep(ProcessorStep):
|
||||
self.stats = reconstructed
|
||||
|
||||
|
||||
def _register_removed_n1_5_step_stub(registry_name: str) -> None:
|
||||
"""Register a stub for a processor step that only GR00T N1.5 pipelines serialize.
|
||||
# Registry names that only GR00T N1.5 processor pipelines serialize. Saved N1.5 checkpoints
|
||||
# reference these in their processor JSON, so deserializing one must fail with the canonical N1.5
|
||||
# removal guidance instead of an opaque registry KeyError (or, for
|
||||
# ``groot_action_unpack_unnormalize_v1``, silently loading the v2 step whose action-chunk
|
||||
# semantics changed).
|
||||
_REMOVED_N1_5_STEP_NAMES = (
|
||||
"groot_pack_inputs_v3",
|
||||
"groot_eagle_encode_v3",
|
||||
"groot_eagle_collate_v3",
|
||||
"groot_action_unpack_unnormalize_v1",
|
||||
)
|
||||
|
||||
Saved N1.5 checkpoints reference these registry names in their processor JSON
|
||||
files. Deserializing them must fail with the canonical N1.5 removal guidance
|
||||
instead of an opaque registry KeyError (or, for
|
||||
``groot_action_unpack_unnormalize_v1``, silently loading the v2 step whose
|
||||
action-chunk semantics changed).
|
||||
|
||||
def _register_removed_n1_5_step_stub(registry_name: str) -> None:
|
||||
"""Register a single rejecting stub for a removed GR00T N1.5 processor step name.
|
||||
|
||||
Idempotent: ``ProcessorStepRegistry.register`` raises on a duplicate name, so already-registered
|
||||
names are skipped. This lets the caller re-run on every processor load without a run-once guard.
|
||||
"""
|
||||
if registry_name in ProcessorStepRegistry.list():
|
||||
return
|
||||
|
||||
@ProcessorStepRegistry.register(name=registry_name)
|
||||
class _RemovedGrootN15ProcessorStep(ProcessorStep):
|
||||
@@ -1911,10 +1713,12 @@ def _register_removed_n1_5_step_stub(registry_name: str) -> None:
|
||||
raise NotImplementedError
|
||||
|
||||
|
||||
for _removed_n1_5_step_name in (
|
||||
"groot_pack_inputs_v3",
|
||||
"groot_eagle_encode_v3",
|
||||
"groot_eagle_collate_v3",
|
||||
"groot_action_unpack_unnormalize_v1",
|
||||
):
|
||||
_register_removed_n1_5_step_stub(_removed_n1_5_step_name)
|
||||
def _register_removed_n1_5_step_stubs() -> None:
|
||||
"""Register the GR00T N1.5 removal stubs, lazily.
|
||||
|
||||
Deferred from import time so importing this module has no global side effects; invoked just
|
||||
before a GR00T processor pipeline is deserialized (the only point at which a saved N1.5 pipeline
|
||||
could reference these registry names). Idempotent via :func:`_register_removed_n1_5_step_stub`.
|
||||
"""
|
||||
for registry_name in _REMOVED_N1_5_STEP_NAMES:
|
||||
_register_removed_n1_5_step_stub(registry_name)
|
||||
|
||||
@@ -0,0 +1,256 @@
|
||||
# SPDX-FileCopyrightText: Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
|
||||
# SPDX-License-Identifier: Apache-2.0
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
"""Shared, side-effect-free utilities for the GR00T N1.7 policy.
|
||||
|
||||
These helpers are consumed by both the config layer (checkpoint sidecar
|
||||
inspection) and the processor layer (stat flattening, action decoding, language
|
||||
and image packing). They are pure functions with no GR00T-specific state so they
|
||||
can be unit-tested in isolation and reused without importing the heavier
|
||||
config/processor modules.
|
||||
"""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import json
|
||||
from pathlib import Path
|
||||
from typing import Any
|
||||
|
||||
import numpy as np
|
||||
import torch
|
||||
|
||||
|
||||
def read_json(path: Path) -> dict[str, Any]:
|
||||
"""Read a JSON object from ``path``, returning ``{}`` on any read/parse error."""
|
||||
try:
|
||||
with path.open() as f:
|
||||
data = json.load(f)
|
||||
except (OSError, json.JSONDecodeError):
|
||||
return {}
|
||||
return data if isinstance(data, dict) else {}
|
||||
|
||||
|
||||
def as_int_pair(value: Any) -> list[int] | None:
|
||||
if not isinstance(value, (list, tuple)) or len(value) != 2:
|
||||
return None
|
||||
try:
|
||||
return [int(value[0]), int(value[1])]
|
||||
except (TypeError, ValueError):
|
||||
return None
|
||||
|
||||
|
||||
def as_optional_int(value: Any) -> int | None:
|
||||
if value is None:
|
||||
return None
|
||||
try:
|
||||
return int(value)
|
||||
except (TypeError, ValueError):
|
||||
return None
|
||||
|
||||
|
||||
def as_optional_float(value: Any) -> float | None:
|
||||
if value is None:
|
||||
return None
|
||||
try:
|
||||
return float(value)
|
||||
except (TypeError, ValueError):
|
||||
return None
|
||||
|
||||
|
||||
def as_float_list(values: Any) -> list[float]:
|
||||
if values is None:
|
||||
return []
|
||||
if isinstance(values, torch.Tensor):
|
||||
return values.detach().cpu().reshape(-1).float().tolist()
|
||||
if isinstance(values, np.ndarray):
|
||||
return values.reshape(-1).astype(np.float32).tolist()
|
||||
if isinstance(values, (list, tuple)):
|
||||
flattened: list[float] = []
|
||||
for value in values:
|
||||
flattened.extend(as_float_list(value))
|
||||
return flattened
|
||||
return [float(values)]
|
||||
|
||||
|
||||
def config_value(value: Any) -> str:
|
||||
if hasattr(value, "value"):
|
||||
value = value.value
|
||||
text = str(value).lower()
|
||||
return {
|
||||
"relative": "relative",
|
||||
"absolute": "absolute",
|
||||
"delta": "delta",
|
||||
"eef": "eef",
|
||||
"non_eef": "non_eef",
|
||||
"default": "default",
|
||||
"xyz_rot6d": "xyz+rot6d",
|
||||
"xyz+rot6d": "xyz+rot6d",
|
||||
"xyz_rotvec": "xyz+rotvec",
|
||||
"xyz+rotvec": "xyz+rotvec",
|
||||
}.get(text, text)
|
||||
|
||||
|
||||
def has_modality_stats(stats: dict[str, dict[str, Any]] | None) -> bool:
|
||||
if not stats:
|
||||
return False
|
||||
return any(bool(modality_stats) for modality_stats in stats.values())
|
||||
|
||||
|
||||
def stat_dim_from_entry(entry: dict[str, Any]) -> int:
|
||||
for stat_name in ("mean", "q01", "min", "max", "std"):
|
||||
value = entry.get(stat_name)
|
||||
if isinstance(value, list) and len(value) > 0:
|
||||
return len(value)
|
||||
return 0
|
||||
|
||||
|
||||
def flatten_n1_7_modality_stats(
|
||||
*,
|
||||
embodiment_stats: dict[str, Any],
|
||||
embodiment_config: dict[str, Any],
|
||||
modality: str,
|
||||
use_percentiles: bool,
|
||||
use_relative_action: bool,
|
||||
) -> dict[str, list[float]]:
|
||||
"""Flatten one N1.7 modality's grouped statistics in checkpoint order.
|
||||
|
||||
When checkpoints request percentile normalization, q01/q99 replace min/max
|
||||
for regular groups. Relative action groups read from ``relative_action``
|
||||
stats and keep min/max, matching Isaac-GR00T's processor override.
|
||||
"""
|
||||
|
||||
source_stats = embodiment_stats.get(modality, {})
|
||||
modality_config = embodiment_config.get(modality, {})
|
||||
if not isinstance(source_stats, dict) or not isinstance(modality_config, dict):
|
||||
return {}
|
||||
modality_keys = modality_config.get("modality_keys", [])
|
||||
if not isinstance(modality_keys, list):
|
||||
return {}
|
||||
|
||||
flattened: dict[str, list[float]] = {}
|
||||
action_configs = modality_config.get("action_configs", []) if modality == "action" else []
|
||||
if not isinstance(action_configs, list):
|
||||
action_configs = []
|
||||
relative_stats = embodiment_stats.get("relative_action", {})
|
||||
if not isinstance(relative_stats, dict):
|
||||
relative_stats = {}
|
||||
|
||||
for stat_name in ("min", "max", "mean", "std"):
|
||||
values: list[float] = []
|
||||
source_stat_name = stat_name
|
||||
if use_percentiles and stat_name == "min":
|
||||
source_stat_name = "q01"
|
||||
elif use_percentiles and stat_name == "max":
|
||||
source_stat_name = "q99"
|
||||
|
||||
for idx, modality_key in enumerate(modality_keys):
|
||||
if not isinstance(modality_key, str):
|
||||
continue
|
||||
key_source_stats = source_stats
|
||||
key_stat_name = source_stat_name
|
||||
if modality == "action" and use_relative_action and idx < len(action_configs):
|
||||
action_config = action_configs[idx]
|
||||
if isinstance(action_config, dict) and config_value(action_config.get("rep")) == "relative":
|
||||
key_source_stats = relative_stats
|
||||
key_stat_name = stat_name
|
||||
key_stats = key_source_stats.get(modality_key, {})
|
||||
if not isinstance(key_stats, dict):
|
||||
raise KeyError(f"Missing statistics for {modality}.{modality_key}")
|
||||
raw_values = key_stats.get(key_stat_name)
|
||||
if raw_values is None:
|
||||
raise KeyError(f"Missing '{key_stat_name}' statistics for {modality}.{modality_key}")
|
||||
values.extend(as_float_list(raw_values))
|
||||
if values:
|
||||
flattened[stat_name] = values
|
||||
|
||||
return flattened
|
||||
|
||||
|
||||
def rot6d_to_matrix(rot6d: np.ndarray) -> np.ndarray:
|
||||
rows = rot6d.reshape(2, 3).astype(np.float64)
|
||||
row1 = rows[0] / np.linalg.norm(rows[0])
|
||||
row2 = rows[1] - np.dot(row1, rows[1]) * row1
|
||||
row2 = row2 / np.linalg.norm(row2)
|
||||
row3 = np.cross(row1, row2)
|
||||
return np.vstack([row1, row2, row3])
|
||||
|
||||
|
||||
def xyz_rot6d_to_homogeneous(xyz_rot6d: np.ndarray) -> np.ndarray:
|
||||
transform = np.eye(4, dtype=np.float64)
|
||||
transform[:3, :3] = rot6d_to_matrix(xyz_rot6d[3:])
|
||||
transform[:3, 3] = xyz_rot6d[:3]
|
||||
return transform
|
||||
|
||||
|
||||
def homogeneous_to_xyz_rot6d(transform: np.ndarray) -> np.ndarray:
|
||||
return np.concatenate([transform[:3, 3], transform[:2, :3].reshape(-1)], axis=0)
|
||||
|
||||
|
||||
def relative_eef_to_absolute(action: np.ndarray, reference_state: np.ndarray) -> np.ndarray:
|
||||
"""Convert relative EEF deltas in xyz+rot6d format to absolute EEF poses."""
|
||||
|
||||
out = np.empty_like(action, dtype=np.float64)
|
||||
for batch_idx in range(action.shape[0]):
|
||||
reference = xyz_rot6d_to_homogeneous(reference_state[batch_idx])
|
||||
for timestep in range(action.shape[1]):
|
||||
relative = xyz_rot6d_to_homogeneous(action[batch_idx, timestep])
|
||||
out[batch_idx, timestep] = homogeneous_to_xyz_rot6d(reference @ relative)
|
||||
return out.astype(np.float32)
|
||||
|
||||
|
||||
def infer_n1_7_batch_size_and_device(
|
||||
obs: dict[str, Any], action: torch.Tensor | None
|
||||
) -> tuple[int, torch.device]:
|
||||
for value in list(obs.values()) + [action]:
|
||||
if isinstance(value, torch.Tensor):
|
||||
return value.shape[0], value.device
|
||||
video = obs.get("video")
|
||||
if isinstance(video, np.ndarray):
|
||||
return video.shape[0], torch.device("cpu")
|
||||
return 1, torch.device("cpu")
|
||||
|
||||
|
||||
def prepare_n1_7_language_batch(
|
||||
language: Any,
|
||||
batch_size: int,
|
||||
*,
|
||||
formalize_language: bool,
|
||||
) -> list[str]:
|
||||
default_language = "Perform the task."
|
||||
if language is None or (isinstance(language, str) and language == ""):
|
||||
languages = [default_language] * batch_size
|
||||
elif isinstance(language, str):
|
||||
languages = [language] * batch_size
|
||||
elif isinstance(language, (list, tuple)):
|
||||
languages = list(language)
|
||||
if len(languages) == 0:
|
||||
languages = [default_language] * batch_size
|
||||
elif len(languages) == 1 and batch_size > 1:
|
||||
languages = languages * batch_size
|
||||
elif len(languages) != batch_size:
|
||||
raise ValueError(
|
||||
f"language batch has {len(languages)} entries, but GR00T N1.7 input batch has {batch_size}."
|
||||
)
|
||||
else:
|
||||
languages = [str(language)] * batch_size
|
||||
|
||||
formatted = []
|
||||
for item in languages:
|
||||
text = str(item) if item else default_language
|
||||
if formalize_language:
|
||||
text = text.lower()
|
||||
text = "".join(ch for ch in text if ch.isalnum() or ch.isspace() or ch == "_")
|
||||
formatted.append(text)
|
||||
return formatted
|
||||
@@ -1464,17 +1464,6 @@ wheels = [
|
||||
{ url = "https://files.pythonhosted.org/packages/81/47/dd9a212ef6e343a6857485ffe25bba537304f1913bdbed446a23f7f592e1/filelock-3.29.0-py3-none-any.whl", hash = "sha256:96f5f6344709aa1572bbf631c640e4ebeeb519e08da902c39a001882f30ac258", size = 39812, upload-time = "2026-04-19T15:39:08.752Z" },
|
||||
]
|
||||
|
||||
[[package]]
|
||||
name = "flash-attn"
|
||||
version = "2.8.3"
|
||||
source = { registry = "https://pypi.org/simple" }
|
||||
dependencies = [
|
||||
{ name = "einops", marker = "platform_machine != 'arm64' or sys_platform != 'darwin'" },
|
||||
{ name = "torch", version = "2.11.0", source = { registry = "https://pypi.org/simple" }, marker = "(platform_machine != 'arm64' and sys_platform == 'darwin') or (sys_platform != 'darwin' and sys_platform != 'linux')" },
|
||||
{ name = "torch", version = "2.11.0+cu128", source = { registry = "https://download.pytorch.org/whl/cu128" }, marker = "sys_platform == 'linux'" },
|
||||
]
|
||||
sdist = { url = "https://files.pythonhosted.org/packages/3b/b2/8d76c41ad7974ee264754709c22963447f7f8134613fd9ce80984ed0dab7/flash_attn-2.8.3.tar.gz", hash = "sha256:1e71dd64a9e0280e0447b8a0c2541bad4bf6ac65bdeaa2f90e51a9e57de0370d", size = 8447812, upload-time = "2025-08-15T08:28:12.911Z" }
|
||||
|
||||
[[package]]
|
||||
name = "flatbuffers"
|
||||
version = "25.12.19"
|
||||
@@ -2690,8 +2679,10 @@ all = [
|
||||
{ name = "contourpy" },
|
||||
{ name = "datasets" },
|
||||
{ name = "debugpy" },
|
||||
{ name = "decord", marker = "platform_machine == 'AMD64' or platform_machine == 'x86_64'" },
|
||||
{ name = "deepdiff" },
|
||||
{ name = "diffusers" },
|
||||
{ name = "dm-tree" },
|
||||
{ name = "dynamixel-sdk" },
|
||||
{ name = "faker" },
|
||||
{ name = "fastapi" },
|
||||
@@ -2739,6 +2730,7 @@ all = [
|
||||
{ name = "scikit-image" },
|
||||
{ name = "scipy" },
|
||||
{ name = "teleop" },
|
||||
{ name = "timm" },
|
||||
{ name = "torchcodec", marker = "(platform_machine == 'arm64' and sys_platform == 'darwin') or (platform_machine == 'AMD64' and sys_platform == 'linux') or (platform_machine == 'aarch64' and sys_platform == 'linux') or (platform_machine == 'arm64' and sys_platform == 'linux') or (platform_machine == 'x86_64' and sys_platform == 'linux') or sys_platform == 'win32'" },
|
||||
{ name = "torchdiffeq" },
|
||||
{ name = "transformers" },
|
||||
@@ -2843,8 +2835,6 @@ groot = [
|
||||
{ name = "decord", marker = "platform_machine == 'AMD64' or platform_machine == 'x86_64'" },
|
||||
{ name = "diffusers" },
|
||||
{ name = "dm-tree" },
|
||||
{ name = "flash-attn", marker = "sys_platform != 'darwin'" },
|
||||
{ name = "ninja" },
|
||||
{ name = "peft" },
|
||||
{ name = "timm" },
|
||||
{ name = "transformers" },
|
||||
@@ -3087,7 +3077,6 @@ requires-dist = [
|
||||
{ name = "faker", marker = "extra == 'sarm'", specifier = ">=33.0.0,<35.0.0" },
|
||||
{ name = "fastapi", marker = "extra == 'phone'", specifier = "<1.0" },
|
||||
{ name = "feetech-servo-sdk", marker = "extra == 'feetech'", specifier = ">=1.0.0,<2.0.0" },
|
||||
{ name = "flash-attn", marker = "sys_platform != 'darwin' and extra == 'groot'", specifier = ">=2.5.9,<3.0.0" },
|
||||
{ name = "grpcio", marker = "extra == 'grpcio-dep'", specifier = "==1.73.1" },
|
||||
{ name = "grpcio-tools", marker = "extra == 'dev'", specifier = "==1.73.1" },
|
||||
{ name = "gym-aloha", marker = "extra == 'aloha'", specifier = ">=0.1.2,<0.2.0" },
|
||||
@@ -3132,6 +3121,7 @@ requires-dist = [
|
||||
{ name = "lerobot", extras = ["feetech"], marker = "extra == 'hopejr'" },
|
||||
{ name = "lerobot", extras = ["feetech"], marker = "extra == 'lekiwi'" },
|
||||
{ name = "lerobot", extras = ["gamepad"], marker = "extra == 'all'" },
|
||||
{ name = "lerobot", extras = ["groot"], marker = "extra == 'all'" },
|
||||
{ name = "lerobot", extras = ["grpcio-dep"], marker = "extra == 'async'" },
|
||||
{ name = "lerobot", extras = ["grpcio-dep"], marker = "extra == 'dev'" },
|
||||
{ name = "lerobot", extras = ["grpcio-dep"], marker = "extra == 'hilserl'" },
|
||||
@@ -3224,7 +3214,6 @@ requires-dist = [
|
||||
{ name = "motorbridge", marker = "extra == 'motorbridge-dep'", specifier = ">=0.3.2,<0.4.0" },
|
||||
{ name = "motorbridge-smart-servo", marker = "extra == 'motorbridge-smart-servo-dep'", specifier = ">=0.0.4,<0.1.0" },
|
||||
{ name = "mypy", marker = "extra == 'dev'", specifier = ">=1.19.1" },
|
||||
{ name = "ninja", marker = "extra == 'groot'", specifier = ">=1.11.1,<2.0.0" },
|
||||
{ name = "num2words", marker = "extra == 'smolvla'", specifier = ">=0.5.14,<0.6.0" },
|
||||
{ name = "numpy", specifier = ">=2.0.0,<2.3.0" },
|
||||
{ name = "onnx", marker = "extra == 'unitree-g1'", specifier = ">=1.16.0,<2.0.0" },
|
||||
@@ -4012,32 +4001,6 @@ wheels = [
|
||||
{ url = "https://files.pythonhosted.org/packages/9e/c9/b2622292ea83fbb4ec318f5b9ab867d0a28ab43c5717bb85b0a5f6b3b0a4/networkx-3.6.1-py3-none-any.whl", hash = "sha256:d47fbf302e7d9cbbb9e2555a0d267983d2aa476bac30e90dfbe5669bd57f3762", size = 2068504, upload-time = "2025-12-08T17:02:38.159Z" },
|
||||
]
|
||||
|
||||
[[package]]
|
||||
name = "ninja"
|
||||
version = "1.13.0"
|
||||
source = { registry = "https://pypi.org/simple" }
|
||||
sdist = { url = "https://files.pythonhosted.org/packages/43/73/79a0b22fc731989c708068427579e840a6cf4e937fe7ae5c5d0b7356ac22/ninja-1.13.0.tar.gz", hash = "sha256:4a40ce995ded54d9dc24f8ea37ff3bf62ad192b547f6c7126e7e25045e76f978", size = 242558, upload-time = "2025-08-11T15:10:19.421Z" }
|
||||
wheels = [
|
||||
{ url = "https://files.pythonhosted.org/packages/3c/74/d02409ed2aa865e051b7edda22ad416a39d81a84980f544f8de717cab133/ninja-1.13.0-py3-none-macosx_10_9_universal2.whl", hash = "sha256:fa2a8bfc62e31b08f83127d1613d10821775a0eb334197154c4d6067b7068ff1", size = 310125, upload-time = "2025-08-11T15:09:50.971Z" },
|
||||
{ url = "https://files.pythonhosted.org/packages/8e/de/6e1cd6b84b412ac1ef327b76f0641aeb5dcc01e9d3f9eee0286d0c34fd93/ninja-1.13.0-py3-none-manylinux2014_aarch64.manylinux_2_17_aarch64.whl", hash = "sha256:3d00c692fb717fd511abeb44b8c5d00340c36938c12d6538ba989fe764e79630", size = 177467, upload-time = "2025-08-11T15:09:52.767Z" },
|
||||
{ url = "https://files.pythonhosted.org/packages/c8/83/49320fb6e58ae3c079381e333575fdbcf1cca3506ee160a2dcce775046fa/ninja-1.13.0-py3-none-manylinux2014_i686.manylinux_2_17_i686.whl", hash = "sha256:be7f478ff9f96a128b599a964fc60a6a87b9fa332ee1bd44fa243ac88d50291c", size = 187834, upload-time = "2025-08-11T15:09:54.115Z" },
|
||||
{ url = "https://files.pythonhosted.org/packages/56/c7/ba22748fb59f7f896b609cd3e568d28a0a367a6d953c24c461fe04fc4433/ninja-1.13.0-py3-none-manylinux2014_ppc64le.manylinux_2_17_ppc64le.whl", hash = "sha256:60056592cf495e9a6a4bea3cd178903056ecb0943e4de45a2ea825edb6dc8d3e", size = 202736, upload-time = "2025-08-11T15:09:55.745Z" },
|
||||
{ url = "https://files.pythonhosted.org/packages/79/22/d1de07632b78ac8e6b785f41fa9aad7a978ec8c0a1bf15772def36d77aac/ninja-1.13.0-py3-none-manylinux2014_s390x.manylinux_2_17_s390x.whl", hash = "sha256:1c97223cdda0417f414bf864cfb73b72d8777e57ebb279c5f6de368de0062988", size = 179034, upload-time = "2025-08-11T15:09:57.394Z" },
|
||||
{ url = "https://files.pythonhosted.org/packages/ed/de/0e6edf44d6a04dabd0318a519125ed0415ce437ad5a1ec9b9be03d9048cf/ninja-1.13.0-py3-none-manylinux2014_x86_64.manylinux_2_17_x86_64.whl", hash = "sha256:fb46acf6b93b8dd0322adc3a4945452a4e774b75b91293bafcc7b7f8e6517dfa", size = 180716, upload-time = "2025-08-11T15:09:58.696Z" },
|
||||
{ url = "https://files.pythonhosted.org/packages/54/28/938b562f9057aaa4d6bfbeaa05e81899a47aebb3ba6751e36c027a7f5ff7/ninja-1.13.0-py3-none-manylinux_2_28_armv7l.manylinux_2_31_armv7l.whl", hash = "sha256:4be9c1b082d244b1ad7ef41eb8ab088aae8c109a9f3f0b3e56a252d3e00f42c1", size = 146843, upload-time = "2025-08-11T15:10:00.046Z" },
|
||||
{ url = "https://files.pythonhosted.org/packages/2a/fb/d06a3838de4f8ab866e44ee52a797b5491df823901c54943b2adb0389fbb/ninja-1.13.0-py3-none-manylinux_2_31_riscv64.whl", hash = "sha256:6739d3352073341ad284246f81339a384eec091d9851a886dfa5b00a6d48b3e2", size = 154402, upload-time = "2025-08-11T15:10:01.657Z" },
|
||||
{ url = "https://files.pythonhosted.org/packages/31/bf/0d7808af695ceddc763cf251b84a9892cd7f51622dc8b4c89d5012779f06/ninja-1.13.0-py3-none-musllinux_1_2_aarch64.whl", hash = "sha256:11be2d22027bde06f14c343f01d31446747dbb51e72d00decca2eb99be911e2f", size = 552388, upload-time = "2025-08-11T15:10:03.349Z" },
|
||||
{ url = "https://files.pythonhosted.org/packages/9d/70/c99d0c2c809f992752453cce312848abb3b1607e56d4cd1b6cded317351a/ninja-1.13.0-py3-none-musllinux_1_2_armv7l.whl", hash = "sha256:aa45b4037b313c2f698bc13306239b8b93b4680eb47e287773156ac9e9304714", size = 472501, upload-time = "2025-08-11T15:10:04.735Z" },
|
||||
{ url = "https://files.pythonhosted.org/packages/9f/43/c217b1153f0e499652f5e0766da8523ce3480f0a951039c7af115e224d55/ninja-1.13.0-py3-none-musllinux_1_2_i686.whl", hash = "sha256:5f8e1e8a1a30835eeb51db05cf5a67151ad37542f5a4af2a438e9490915e5b72", size = 638280, upload-time = "2025-08-11T15:10:06.512Z" },
|
||||
{ url = "https://files.pythonhosted.org/packages/8c/45/9151bba2c8d0ae2b6260f71696330590de5850e5574b7b5694dce6023e20/ninja-1.13.0-py3-none-musllinux_1_2_ppc64le.whl", hash = "sha256:3d7d7779d12cb20c6d054c61b702139fd23a7a964ec8f2c823f1ab1b084150db", size = 642420, upload-time = "2025-08-11T15:10:08.35Z" },
|
||||
{ url = "https://files.pythonhosted.org/packages/3c/fb/95752eb635bb8ad27d101d71bef15bc63049de23f299e312878fc21cb2da/ninja-1.13.0-py3-none-musllinux_1_2_riscv64.whl", hash = "sha256:d741a5e6754e0bda767e3274a0f0deeef4807f1fec6c0d7921a0244018926ae5", size = 585106, upload-time = "2025-08-11T15:10:09.818Z" },
|
||||
{ url = "https://files.pythonhosted.org/packages/c1/31/aa56a1a286703800c0cbe39fb4e82811c277772dc8cd084f442dd8e2938a/ninja-1.13.0-py3-none-musllinux_1_2_s390x.whl", hash = "sha256:e8bad11f8a00b64137e9b315b137d8bb6cbf3086fbdc43bf1f90fd33324d2e96", size = 707138, upload-time = "2025-08-11T15:10:11.366Z" },
|
||||
{ url = "https://files.pythonhosted.org/packages/34/6f/5f5a54a1041af945130abdb2b8529cbef0cdcbbf9bcf3f4195378319d29a/ninja-1.13.0-py3-none-musllinux_1_2_x86_64.whl", hash = "sha256:b4f2a072db3c0f944c32793e91532d8948d20d9ab83da9c0c7c15b5768072200", size = 581758, upload-time = "2025-08-11T15:10:13.295Z" },
|
||||
{ url = "https://files.pythonhosted.org/packages/95/97/51359c77527d45943fe7a94d00a3843b81162e6c4244b3579fe8fc54cb9c/ninja-1.13.0-py3-none-win32.whl", hash = "sha256:8cfbb80b4a53456ae8a39f90ae3d7a2129f45ea164f43fadfa15dc38c4aef1c9", size = 267201, upload-time = "2025-08-11T15:10:15.158Z" },
|
||||
{ url = "https://files.pythonhosted.org/packages/29/45/c0adfbfb0b5895aa18cec400c535b4f7ff3e52536e0403602fc1a23f7de9/ninja-1.13.0-py3-none-win_amd64.whl", hash = "sha256:fb8ee8719f8af47fed145cced4a85f0755dd55d45b2bddaf7431fa89803c5f3e", size = 309975, upload-time = "2025-08-11T15:10:16.697Z" },
|
||||
{ url = "https://files.pythonhosted.org/packages/df/93/a7b983643d1253bb223234b5b226e69de6cda02b76cdca7770f684b795f5/ninja-1.13.0-py3-none-win_arm64.whl", hash = "sha256:3c0b40b1f0bba764644385319028650087b4c1b18cdfa6f45cb39a3669b81aa9", size = 290806, upload-time = "2025-08-11T15:10:18.018Z" },
|
||||
]
|
||||
|
||||
[[package]]
|
||||
name = "nodeenv"
|
||||
version = "1.10.0"
|
||||
|
||||
Reference in New Issue
Block a user