mirror of
https://github.com/huggingface/lerobot.git
synced 2026-06-16 15:57:03 +00:00
Compare commits
9 Commits
| Author | SHA1 | Date | |
|---|---|---|---|
| 4688b9c27f | |||
| 5753f8c18b | |||
| 97bd373d15 | |||
| 10a73e3c95 | |||
| 27c9288b24 | |||
| 378897800a | |||
| fcb371eddd | |||
| 895eaf0d7c | |||
| edda8552ec |
+20
-24
@@ -5,7 +5,7 @@ GR00T is an NVIDIA foundation model family for generalized humanoid robot reason
|
||||
LeRobot integrates GR00T N1.7 through the `groot` policy type.
|
||||
|
||||
> [!WARNING]
|
||||
> **Breaking change:** GR00T N1.5 support was removed from LeRobot, and current releases support GR00T N1.7 only. N1.5 checkpoints, configs, and `--policy.model_version=n1.5` are rejected with a clear error. To keep using an N1.5 checkpoint, pin the last release that supports it: `pip install 'lerobot==0.5.1'`. To use the current release, migrate to GR00T N1.7 (`model_version='n1.7'`, base model [`nvidia/GR00T-N1.7-3B`](https://huggingface.co/nvidia/GR00T-N1.7-3B)).
|
||||
> **Breaking change:** GR00T N1.5 support was removed from LeRobot, and current releases support GR00T N1.7 only. N1.5 checkpoints and configs are rejected with a migration note. To keep using an N1.5 checkpoint, pin the last release that supports it: `pip install 'lerobot==0.5.1'`. To use the current release, migrate to GR00T N1.7 (base model [`nvidia/GR00T-N1.7-3B`](https://huggingface.co/nvidia/GR00T-N1.7-3B)).
|
||||
|
||||
## Model Overview
|
||||
|
||||
@@ -31,46 +31,43 @@ This approach allows the model to be highly adaptable through post-training for
|
||||
|
||||
## Installation Requirements
|
||||
|
||||
GR00T is intended for NVIDIA GPU-accelerated systems. The `groot` extra still includes Flash Attention on non-macOS platforms, and Flash Attention needs a compatible PyTorch/CUDA environment before it is installed. Install the dependencies in this order:
|
||||
GR00T is intended for NVIDIA GPU-accelerated systems. Install LeRobot with the GR00T extra:
|
||||
|
||||
1. Follow the Environment Setup in the [Installation Guide](./installation). Do not install `lerobot` yet.
|
||||
2. Install PyTorch, TorchVision, and the build dependencies used by Flash Attention:
|
||||
```bash
|
||||
pip install "lerobot[groot]"
|
||||
```
|
||||
|
||||
For a source checkout:
|
||||
|
||||
```bash
|
||||
pip install -e ".[groot]"
|
||||
```
|
||||
|
||||
### Optional: Flash Attention acceleration
|
||||
|
||||
Flash Attention is a purely optional performance optimization. **LeRobot neither installs nor requires it**, and setting it up is up to the user as it has environment-specific build requirements (a matching PyTorch/CUDA toolchain). To enable it:
|
||||
|
||||
1. Install a `flash-attn` build matching your PyTorch/CUDA environment (see the [Flash Attention project](https://github.com/Dao-AILab/flash-attention)):
|
||||
|
||||
```bash
|
||||
# Check https://pytorch.org/get-started/locally/ for the right CUDA wheel index for your system.
|
||||
pip install "torch>=2.7,<2.12.0" "torchvision>=0.22.0,<0.27.0" \
|
||||
--index-url https://download.pytorch.org/whl/cu128
|
||||
pip install "ninja>=1.11.1,<2.0.0" "packaging>=24.2,<26.0"
|
||||
```
|
||||
|
||||
3. Install and verify Flash Attention:
|
||||
|
||||
```bash
|
||||
pip install "flash-attn>=2.5.9,<3.0.0" --no-build-isolation
|
||||
python -c "import flash_attn; print(f'Flash Attention {flash_attn.__version__} imported successfully')"
|
||||
```
|
||||
|
||||
4. Install LeRobot with the GR00T extra:
|
||||
2. Install lerobot with the groot extra.
|
||||
|
||||
```bash
|
||||
pip install "lerobot[groot]"
|
||||
```
|
||||
|
||||
For a source checkout, use the same order, then install the local package with:
|
||||
|
||||
```bash
|
||||
pip install -e ".[groot]"
|
||||
```
|
||||
|
||||
If your CUDA/PyTorch build needs a different Flash Attention wheel or source build, follow the [Flash Attention project](https://github.com/Dao-AILab/flash-attention) instructions, but keep the same ordering: PyTorch first, Flash Attention next, then `lerobot[groot]`.
|
||||
3. Opt in by passing `--policy.use_flash_attention=true` when training/evaluating GR00T. If the kernel is missing or fails to import, the backbone transparently falls back to SDPA.
|
||||
|
||||
## Usage
|
||||
|
||||
To use GR00T N1.7:
|
||||
|
||||
```bash
|
||||
--policy.type=groot \
|
||||
--policy.model_version=n1.7
|
||||
--policy.type=groot
|
||||
```
|
||||
|
||||
## Training
|
||||
@@ -142,7 +139,6 @@ hf download nvidia/GR00T-N1.7-LIBERO \
|
||||
|
||||
lerobot-eval \
|
||||
--policy.type=groot \
|
||||
--policy.model_version=n1.7 \
|
||||
--policy.base_model_path=./GR00T-N1.7-LIBERO/libero_spatial \
|
||||
--policy.embodiment_tag=libero_sim \
|
||||
--env.type=libero \
|
||||
|
||||
+1
-3
@@ -208,8 +208,6 @@ groot = [
|
||||
"dm-tree>=0.1.8,<1.0.0",
|
||||
"timm>=1.0.0,<1.1.0",
|
||||
"decord>=0.6.0,<1.0.0; (platform_machine == 'AMD64' or platform_machine == 'x86_64')",
|
||||
"ninja>=1.11.1,<2.0.0",
|
||||
"flash-attn>=2.5.9,<3.0.0 ; sys_platform != 'darwin'"
|
||||
]
|
||||
sarm = ["lerobot[transformers-dep]", "pydantic>=2.0.0,<3.0.0", "faker>=33.0.0,<35.0.0", "lerobot[matplotlib-dep]", "lerobot[qwen-vl-utils-dep]"]
|
||||
robometer = ["lerobot[transformers-dep]", "lerobot[qwen-vl-utils-dep]", "lerobot[peft-dep]"]
|
||||
@@ -280,7 +278,7 @@ all = [
|
||||
"lerobot[pi]",
|
||||
"lerobot[molmoact2]",
|
||||
"lerobot[smolvla]",
|
||||
# "lerobot[groot]", TODO(Steven): Gr00t requires specific installation instructions for flash-attn
|
||||
"lerobot[groot]",
|
||||
"lerobot[xvla]",
|
||||
"lerobot[hilserl]",
|
||||
"lerobot[vla_jepa]",
|
||||
|
||||
@@ -14,9 +14,7 @@
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
import json
|
||||
import logging
|
||||
import os
|
||||
from dataclasses import dataclass, field
|
||||
from pathlib import Path
|
||||
|
||||
@@ -24,6 +22,8 @@ from lerobot.configs import FeatureType, NormalizationMode, PolicyFeature, PreTr
|
||||
from lerobot.optim import AdamWConfig, CosineDecayWithWarmupSchedulerConfig
|
||||
from lerobot.utils.constants import ACTION, OBS_STATE
|
||||
|
||||
from .utils import read_json
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
GROOT_N1_7 = "n1.7"
|
||||
@@ -42,6 +42,10 @@ GROOT_N1_5_REMOVAL_GUIDANCE = (
|
||||
)
|
||||
GROOT_N1_7_BASE_MODEL = "nvidia/GR00T-N1.7-3B"
|
||||
GROOT_N1_7_BACKBONE_MODEL = "nvidia/Cosmos-Reason2-2B"
|
||||
# Default GR00T N1.7 training resolution. Fallback if processor_config lacks sizing. Prevents mismatched
|
||||
# full-res patchification by forcing a resize. Mirrored by GR00T_N1_7_DEFAULTS in groot_n1_7.py.
|
||||
N1_7_DEFAULT_IMAGE_TARGET_SIZE = (256, 256)
|
||||
N1_7_DEFAULT_IMAGE_CROP_SIZE = (230, 230)
|
||||
GROOT_ACTION_DECODE_TRANSFORM_LIBERO = "libero"
|
||||
# Sentinel meaning "the user did not pick an action decode transform": __post_init__ resolves it
|
||||
# to the embodiment default ('libero' for 'libero_sim', otherwise None). It is distinct from an
|
||||
@@ -123,12 +127,7 @@ def is_raw_groot_n1_7_checkpoint(model_path: str | Path | None) -> bool:
|
||||
else:
|
||||
return False
|
||||
|
||||
try:
|
||||
with config_path.open() as f:
|
||||
config = json.load(f)
|
||||
except (OSError, json.JSONDecodeError):
|
||||
return False
|
||||
|
||||
config = read_json(config_path)
|
||||
return "type" not in config and _infer_groot_model_version_from_config(config) == GROOT_N1_7
|
||||
|
||||
|
||||
@@ -137,11 +136,7 @@ def infer_groot_n1_7_embodiment_tag(model_path: str | Path | None) -> str | None
|
||||
return None
|
||||
|
||||
processor_config_path = Path(model_path).expanduser() / "processor_config.json"
|
||||
try:
|
||||
with processor_config_path.open() as f:
|
||||
processor_config = json.load(f)
|
||||
except (OSError, json.JSONDecodeError):
|
||||
return None
|
||||
processor_config = read_json(processor_config_path)
|
||||
|
||||
modality_configs = processor_config.get("processor_kwargs", {}).get("modality_configs", {})
|
||||
if not isinstance(modality_configs, dict):
|
||||
@@ -160,11 +155,7 @@ def infer_groot_n1_7_action_horizon(
|
||||
return None
|
||||
|
||||
processor_config_path = Path(model_path).expanduser() / "processor_config.json"
|
||||
try:
|
||||
with processor_config_path.open() as f:
|
||||
processor_config = json.load(f)
|
||||
except (OSError, json.JSONDecodeError):
|
||||
return None
|
||||
processor_config = read_json(processor_config_path)
|
||||
|
||||
processor_kwargs = processor_config.get("processor_kwargs", {})
|
||||
if not isinstance(processor_kwargs, dict):
|
||||
@@ -206,83 +197,6 @@ def infer_groot_n1_7_action_execution_horizon(
|
||||
return action_horizon
|
||||
|
||||
|
||||
def resolve_groot_n1_7_backbone_model(model_name: str, cache_dir: str | Path | None = None) -> str:
|
||||
model_path = Path(model_name).expanduser()
|
||||
if model_path.exists():
|
||||
return str(model_path)
|
||||
|
||||
cached_snapshot = _find_cached_hf_snapshot(model_name, cache_dir=cache_dir)
|
||||
return str(cached_snapshot) if cached_snapshot is not None else model_name
|
||||
|
||||
|
||||
def _find_cached_hf_snapshot(repo_id: str, cache_dir: str | Path | None = None) -> Path | None:
|
||||
repo_cache_name = f"models--{repo_id.replace('/', '--')}"
|
||||
required_files = (
|
||||
"config.json",
|
||||
"tokenizer_config.json",
|
||||
"preprocessor_config.json",
|
||||
"video_preprocessor_config.json",
|
||||
)
|
||||
|
||||
for hub_cache in _candidate_hf_hub_caches(cache_dir):
|
||||
repo_cache = hub_cache / repo_cache_name
|
||||
snapshots_dir = repo_cache / "snapshots"
|
||||
if not snapshots_dir.is_dir():
|
||||
continue
|
||||
|
||||
candidates: list[Path] = []
|
||||
ref_path = repo_cache / "refs" / "main"
|
||||
try:
|
||||
ref = ref_path.read_text().strip()
|
||||
except OSError:
|
||||
ref = ""
|
||||
if ref:
|
||||
candidates.append(snapshots_dir / ref)
|
||||
candidates.extend(
|
||||
sorted(
|
||||
(path for path in snapshots_dir.iterdir() if path.is_dir()),
|
||||
key=lambda path: path.stat().st_mtime,
|
||||
reverse=True,
|
||||
)
|
||||
)
|
||||
|
||||
seen: set[Path] = set()
|
||||
for snapshot in candidates:
|
||||
if snapshot in seen:
|
||||
continue
|
||||
seen.add(snapshot)
|
||||
if all((snapshot / filename).exists() for filename in required_files):
|
||||
return snapshot
|
||||
return None
|
||||
|
||||
|
||||
def _candidate_hf_hub_caches(cache_dir: str | Path | None) -> list[Path]:
|
||||
candidates: list[Path] = []
|
||||
if cache_dir is not None:
|
||||
cache_path = Path(cache_dir).expanduser()
|
||||
candidates.append(cache_path)
|
||||
candidates.append(cache_path / "hub")
|
||||
|
||||
hub_cache = os.environ.get("HUGGINGFACE_HUB_CACHE")
|
||||
if hub_cache:
|
||||
candidates.append(Path(hub_cache).expanduser())
|
||||
|
||||
hf_home = os.environ.get("HF_HOME")
|
||||
if hf_home:
|
||||
candidates.append(Path(hf_home).expanduser() / "hub")
|
||||
|
||||
candidates.append(Path.home() / ".cache" / "huggingface" / "hub")
|
||||
|
||||
deduped: list[Path] = []
|
||||
seen: set[Path] = set()
|
||||
for candidate in candidates:
|
||||
resolved = candidate.resolve() if candidate.exists() else candidate
|
||||
if resolved not in seen:
|
||||
seen.add(resolved)
|
||||
deduped.append(candidate)
|
||||
return deduped
|
||||
|
||||
|
||||
def _infer_groot_model_version_from_local_config(model_path: str) -> str | None:
|
||||
path = Path(model_path).expanduser()
|
||||
if path.is_dir():
|
||||
@@ -292,16 +206,7 @@ def _infer_groot_model_version_from_local_config(model_path: str) -> str | None:
|
||||
else:
|
||||
return None
|
||||
|
||||
if not config_path.exists():
|
||||
return None
|
||||
|
||||
try:
|
||||
with config_path.open() as f:
|
||||
config = json.load(f)
|
||||
except (OSError, json.JSONDecodeError):
|
||||
return None
|
||||
|
||||
return _infer_groot_model_version_from_config(config)
|
||||
return _infer_groot_model_version_from_config(read_json(config_path))
|
||||
|
||||
|
||||
def _infer_groot_model_version_from_config(config: dict) -> str | None:
|
||||
@@ -321,9 +226,6 @@ def _infer_groot_model_version_from_config(config: dict) -> str | None:
|
||||
normalized = candidate.lower().replace("-", "_")
|
||||
if normalized in {"gr00tn1d7", "gr00t_n1d7", "gr00t_n1_7"}:
|
||||
return GROOT_N1_7
|
||||
# nvidia/GR00T-N1.5-3B ships model_type 'gr00t_n1_5' and architectures ['GR00T_N1_5'].
|
||||
# Recognise them so N1.5 checkpoints at generic local paths are rejected loudly
|
||||
# instead of being silently treated as N1.7 (see infer_groot_model_version).
|
||||
if normalized in {"gr00t_n1_5", "gr00tn1_5", "gr00t_n15", "gr00t_n1d5", "gr00tn1d5"}:
|
||||
return GROOT_N1_5
|
||||
if config.get("model_name") == GROOT_N1_7_BACKBONE_MODEL:
|
||||
@@ -365,31 +267,22 @@ class GrootConfig(PreTrainedConfig):
|
||||
}
|
||||
)
|
||||
|
||||
# Deprecated and unused: image sizing is handled by the backbone's image processor.
|
||||
# Kept only so config.json files saved with earlier versions still parse.
|
||||
image_size: tuple[int, int] = (256, 256)
|
||||
# Groot-specific model parameters
|
||||
|
||||
# Groot-specific model parameters (from groot_finetune_script.py)
|
||||
|
||||
# Explicit GR00T model family selection. LeRobot supports GR00T N1.7 only.
|
||||
model_version: str = GROOT_N1_7
|
||||
|
||||
# Path or HuggingFace model ID for the base Groot model
|
||||
# Path or HuggingFace model ID for the base GR00T N1.7 model whose backbone weights and
|
||||
# checkpoint sidecars (statistics.json, processor_config.json, ...) are loaded. This is the
|
||||
# model *source*, and is intentionally distinct from the inherited `pretrained_path`:
|
||||
# `pretrained_path` (`--policy.path`) points at a saved LeRobot checkpoint directory whose
|
||||
# `config.json` carries a `type` field, whereas a raw NVIDIA GR00T checkpoint has no such
|
||||
# field and so can only be loaded through `base_model_path` (`--policy.base_model_path`).
|
||||
# Defaults to GROOT_N1_7_BASE_MODEL when unset (resolved in __post_init__).
|
||||
base_model_path: str | None = None
|
||||
|
||||
# HF repo ID (or local path) for the GR00T N1.7 Cosmos/Qwen3-VL backbone processor.
|
||||
n1_7_backbone_model: str = GROOT_N1_7_BACKBONE_MODEL
|
||||
|
||||
# Optional named action transform applied after raw N1.7 checkpoint decoding and before env.step().
|
||||
# 'auto' (default) resolves to the embodiment default ('libero' for 'libero_sim', otherwise no
|
||||
# transform). Pass 'none' to explicitly disable the transform, including for 'libero_sim'.
|
||||
action_decode_transform: str | None = GROOT_ACTION_DECODE_TRANSFORM_AUTO
|
||||
|
||||
# Deprecated, GR00T N1.5 only — do not set. Kept so config.json files saved by lerobot<=0.5.1
|
||||
# still parse (draccus rejects unknown fields) and can be rejected in __post_init__ with a
|
||||
# clear error pointing at GROOT_N1_5_REMOVAL_GUIDANCE instead of a cryptic DecodingError.
|
||||
tokenizer_assets_repo: str | None = None
|
||||
|
||||
# Embodiment tag to use for training (e.g. 'new_embodiment', 'gr1')
|
||||
embodiment_tag: str = "new_embodiment"
|
||||
|
||||
@@ -407,20 +300,31 @@ class GrootConfig(PreTrainedConfig):
|
||||
# Whether to fine-tune the diffusion model
|
||||
tune_diffusion_model: bool = True
|
||||
|
||||
# LoRA parameters (from groot_finetune_script.py)
|
||||
# Rank for the LORA model. If 0, no LORA will be used.
|
||||
lora_rank: int = 0
|
||||
# Whether to fine-tune the VL LayerNorm + VL self-attention projector in the action head.
|
||||
tune_vlln: bool = True
|
||||
|
||||
# Alpha value for the LORA model
|
||||
lora_alpha: int = 16
|
||||
# Number of top LLM backbone layers to fine-tune (0 = none). Lets you adapt just the final
|
||||
# language layers without unfreezing the whole backbone; independent of `tune_llm`, which tunes
|
||||
# the entire LLM.
|
||||
tune_top_llm_layers: int = 0
|
||||
|
||||
# Dropout rate for the LORA model
|
||||
lora_dropout: float = 0.1
|
||||
# Inference-time knob: Number of flow-matching denoising steps used to decode an action chunk.
|
||||
# Trades inference latency for action quality.
|
||||
# None keeps the checkpoint value (GR00T N1.7 default: 4).
|
||||
num_inference_timesteps: int | None = None
|
||||
|
||||
# Whether to use the full model for LORA
|
||||
lora_full_model: bool = False
|
||||
# Inference-time knob: Real-Time Chunking (RTC) overlap-blend ramp rate, used when the RTC engine
|
||||
# supplies a previous-chunk prefix. Higher values blend the overlapping prefix more aggressively.
|
||||
# None keeps the checkpoint value (GR00T N1.7 default: 6.0).
|
||||
rtc_ramp_rate: float | None = None
|
||||
|
||||
# Training parameters (matching groot_finetune_script.py)
|
||||
# Inference-time knob: Whether to request the flash-attention-2 kernel for the Qwen3-VL backbone.
|
||||
# flash-attn is an optional, user-managed optimization; when it is absent (the default),
|
||||
# the backbone transparently falls back to SDPA, which is numerically equivalent.
|
||||
# Set to True only after installing a flash-attn build matching your torch/CUDA env.
|
||||
use_flash_attention: bool = False
|
||||
|
||||
# Training parameters
|
||||
optimizer_lr: float = 1e-4
|
||||
optimizer_betas: tuple[float, float] = (0.95, 0.999)
|
||||
optimizer_eps: float = 1e-8
|
||||
@@ -428,10 +332,19 @@ class GrootConfig(PreTrainedConfig):
|
||||
warmup_ratio: float = 0.05
|
||||
use_bf16: bool = True
|
||||
|
||||
# Deprecated Isaac-GR00T runner fields below — unused by the LeRobot N1.7 implementation
|
||||
# (nothing in src/lerobot reads them). They are kept only so config.json files saved by
|
||||
# earlier lerobot releases still parse: draccus rejects unknown fields, so removing them
|
||||
# would break every previously saved groot checkpoint at config-load time.
|
||||
# TODO(Steven): Remove these deprecated fields in a future release.
|
||||
# Deprecated Isaac-GR00T runner / GR00T N1.5 fields, plus the (never-wired) LoRA fields — all
|
||||
# unused by the LeRobot N1.7 implementation except the `tokenizer_assets_repo` N1.5 tripwire and
|
||||
# the `image_size` legacy remap in __post_init__. They are kept ONLY so a config.json saved by an
|
||||
# earlier lerobot release (notably a GR00T N1.5 checkpoint) still parses under draccus — which
|
||||
# rejects unknown fields — and is then rejected with a clear N1.5 removal message rather than an
|
||||
# opaque draccus decoding error.
|
||||
image_size: tuple[int, int] = (256, 256) # image sizing is handled by the backbone's image processor.
|
||||
tokenizer_assets_repo: str | None = None
|
||||
lora_rank: int = 0
|
||||
lora_alpha: int = 16
|
||||
lora_dropout: float = 0.1
|
||||
lora_full_model: bool = False
|
||||
video_backend: str = "decord"
|
||||
balance_dataset_weights: bool = True
|
||||
balance_trajectory_weights: bool = True
|
||||
@@ -445,16 +358,12 @@ class GrootConfig(PreTrainedConfig):
|
||||
resume: bool = False
|
||||
|
||||
def __post_init__(self):
|
||||
# 'tokenizer_assets_repo' only ever existed for GR00T N1.5 (lerobot<=0.5.1) and was
|
||||
# serialized into every groot checkpoint config.json, so a value here means a legacy
|
||||
# N1.5 checkpoint or config is being loaded.
|
||||
if self.tokenizer_assets_repo is not None:
|
||||
raise ValueError(
|
||||
"Config sets 'tokenizer_assets_repo', which only existed for GR00T N1.5; this looks "
|
||||
f"like a legacy GR00T N1.5 checkpoint or config. {GROOT_N1_5_REMOVAL_GUIDANCE}"
|
||||
)
|
||||
|
||||
self.model_version = normalize_groot_model_version(self.model_version)
|
||||
self.action_decode_transform = normalize_groot_action_decode_transform(self.action_decode_transform)
|
||||
if self.base_model_path is None:
|
||||
self.base_model_path = GROOT_N1_7_BASE_MODEL
|
||||
@@ -499,9 +408,9 @@ class GrootConfig(PreTrainedConfig):
|
||||
setattr(self, field_name, n1_7_value)
|
||||
|
||||
inferred_version = infer_groot_model_version(self.base_model_path)
|
||||
if inferred_version is not None and inferred_version != self.model_version:
|
||||
if inferred_version is not None and inferred_version != GROOT_N1_7:
|
||||
message = (
|
||||
f"GR00T model_version '{self.model_version}' does not match base_model_path "
|
||||
f"GR00T model_version '{GROOT_N1_7}' does not match base_model_path "
|
||||
f"'{self.base_model_path}', which looks like '{inferred_version}'."
|
||||
)
|
||||
if inferred_version == GROOT_N1_5:
|
||||
@@ -515,9 +424,6 @@ class GrootConfig(PreTrainedConfig):
|
||||
f"n_action_steps ({self.n_action_steps}) cannot exceed chunk_size ({self.chunk_size})"
|
||||
)
|
||||
|
||||
# groot_repo_path is now optional since we ported the components
|
||||
# No validation needed
|
||||
|
||||
def validate_features(self) -> None:
|
||||
"""Validate and set up input/output features for Groot."""
|
||||
image_features = [key for key, feat in self.input_features.items() if feat.type == FeatureType.VISUAL]
|
||||
@@ -582,22 +488,11 @@ class GrootConfig(PreTrainedConfig):
|
||||
|
||||
@property
|
||||
def action_delta_indices(self) -> list[int]:
|
||||
"""Return indices for delta actions.
|
||||
|
||||
The model action horizon is read from the checkpoint's processor_config.json
|
||||
when available; the result is cached (keyed on the inputs that determine it) so
|
||||
repeated access during dataset/training setup does not re-read from disk.
|
||||
"""
|
||||
cache_key = (self.base_model_path, self.embodiment_tag, self.chunk_size)
|
||||
cached = getattr(self, "_action_delta_indices_cache", None)
|
||||
if cached is not None and cached[0] == cache_key:
|
||||
return cached[1]
|
||||
"""Return indices for delta actions."""
|
||||
model_action_horizon = (
|
||||
infer_groot_n1_7_action_horizon(self.base_model_path, self.embodiment_tag) or 40
|
||||
)
|
||||
indices = list(range(min(self.chunk_size, model_action_horizon)))
|
||||
object.__setattr__(self, "_action_delta_indices_cache", (cache_key, indices))
|
||||
return indices
|
||||
return list(range(min(self.chunk_size, model_action_horizon)))
|
||||
|
||||
@property
|
||||
def reward_delta_indices(self) -> None:
|
||||
|
||||
@@ -16,7 +16,6 @@
|
||||
from __future__ import annotations
|
||||
|
||||
import importlib
|
||||
import json
|
||||
import logging
|
||||
from contextlib import suppress
|
||||
from copy import deepcopy
|
||||
@@ -32,9 +31,17 @@ from torch.distributions import Beta
|
||||
from lerobot.utils.import_utils import _transformers_available, require_package
|
||||
|
||||
from .action_head.cross_attention_dit import AlternateVLDiT, DiT, SelfAttentionTransformer
|
||||
from .configuration_groot import N1_7_DEFAULT_IMAGE_CROP_SIZE, N1_7_DEFAULT_IMAGE_TARGET_SIZE
|
||||
|
||||
if TYPE_CHECKING or _transformers_available:
|
||||
from transformers import AutoConfig, AutoModel, PretrainedConfig, PreTrainedModel
|
||||
from transformers import (
|
||||
AutoConfig,
|
||||
AutoModel,
|
||||
PretrainedConfig,
|
||||
PreTrainedModel,
|
||||
Qwen3VLConfig,
|
||||
Qwen3VLForConditionalGeneration,
|
||||
)
|
||||
from transformers.feature_extraction_utils import BatchFeature
|
||||
else:
|
||||
AutoConfig = None
|
||||
@@ -42,25 +49,17 @@ else:
|
||||
PretrainedConfig = object
|
||||
PreTrainedModel = object
|
||||
BatchFeature = None
|
||||
Qwen3VLConfig = None
|
||||
Qwen3VLForConditionalGeneration = None
|
||||
|
||||
try:
|
||||
import tree
|
||||
except ImportError:
|
||||
tree = None
|
||||
|
||||
try:
|
||||
from transformers import Qwen3VLConfig, Qwen3VLForConditionalGeneration
|
||||
except ImportError:
|
||||
Qwen3VLConfig = None
|
||||
Qwen3VLForConditionalGeneration = None
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
def _copy_default(value: Any) -> Any:
|
||||
return deepcopy(value)
|
||||
|
||||
|
||||
GR00T_N1_7_DEFAULTS: dict[str, Any] = {
|
||||
"model_dtype": "bfloat16",
|
||||
"dtype": "bfloat16",
|
||||
@@ -71,13 +70,13 @@ GR00T_N1_7_DEFAULTS: dict[str, Any] = {
|
||||
"backbone_embedding_dim": 2048,
|
||||
"tune_llm": False,
|
||||
"tune_visual": False,
|
||||
"select_layer": 16, # N1.7-3B checkpoint value; real checkpoint loads override this from config.json
|
||||
"select_layer": 16,
|
||||
"reproject_vision": False,
|
||||
"use_flash_attention": True,
|
||||
"use_flash_attention": False,
|
||||
"load_bf16": False,
|
||||
"backbone_trainable_params_fp32": True,
|
||||
"image_crop_size": (230, 230),
|
||||
"image_target_size": (256, 256),
|
||||
"image_crop_size": N1_7_DEFAULT_IMAGE_CROP_SIZE,
|
||||
"image_target_size": N1_7_DEFAULT_IMAGE_TARGET_SIZE,
|
||||
"shortest_image_edge": None,
|
||||
"crop_fraction": None,
|
||||
"random_rotation_angle": None,
|
||||
@@ -151,29 +150,10 @@ class GR00TN17Config(PretrainedConfig):
|
||||
def __init__(self, **kwargs):
|
||||
super().__init__(**kwargs)
|
||||
for key, value in GR00T_N1_7_DEFAULTS.items():
|
||||
setattr(self, key, _copy_default(kwargs.pop(key, value)))
|
||||
setattr(self, key, deepcopy(kwargs.pop(key, value)))
|
||||
for key, value in kwargs.items():
|
||||
setattr(self, key, value)
|
||||
|
||||
def to_filtered_dict(self, exclude_augment: bool = True) -> dict[str, Any]:
|
||||
cfg = self.to_dict()
|
||||
if not exclude_augment:
|
||||
return cfg
|
||||
exclude_keys = {
|
||||
"random_rotation_angle",
|
||||
"color_jitter_params",
|
||||
"use_albumentations_transforms",
|
||||
"formalize_language",
|
||||
"image_crop_size",
|
||||
"image_target_size",
|
||||
"shortest_image_edge",
|
||||
"crop_fraction",
|
||||
}
|
||||
return {k: v for k, v in cfg.items() if k not in exclude_keys}
|
||||
|
||||
def to_filtered_json(self, exclude_augment: bool = True, **kwargs) -> str:
|
||||
return json.dumps(self.to_filtered_dict(exclude_augment), indent=2, default=str, **kwargs)
|
||||
|
||||
|
||||
class CategorySpecificLinear(nn.Module):
|
||||
"""Linear layer with category-specific weights for multi-embodiment support."""
|
||||
@@ -274,13 +254,7 @@ class Qwen3Backbone(nn.Module):
|
||||
transformers_loading_kwargs: dict[str, Any] | None = None,
|
||||
load_pretrained_weights: bool = True,
|
||||
):
|
||||
if Qwen3VLForConditionalGeneration is None:
|
||||
raise ImportError(
|
||||
"Qwen3VLForConditionalGeneration is required for GR00T N1.7. "
|
||||
"Install the GR00T optional dependencies with `pip install 'lerobot[groot]'` "
|
||||
"or use a transformers version that provides Qwen3-VL support."
|
||||
)
|
||||
|
||||
require_package("transformers", extra="groot")
|
||||
super().__init__()
|
||||
transformers_loading_kwargs = transformers_loading_kwargs or {"trust_remote_code": True}
|
||||
|
||||
@@ -360,11 +334,6 @@ class Qwen3Backbone(nn.Module):
|
||||
if _is_cosmos_reason2_backbone(model_name):
|
||||
backbone_config = _cosmos_reason2_qwen3_vl_config()
|
||||
else:
|
||||
if AutoConfig is None:
|
||||
raise ImportError(
|
||||
"AutoConfig is required to initialize a GR00T N1.7 backbone from config. "
|
||||
"Install the GR00T optional dependencies with `pip install 'lerobot[groot]'`."
|
||||
)
|
||||
backbone_config = AutoConfig.from_pretrained(model_name, **config_kwargs)
|
||||
return Qwen3VLForConditionalGeneration._from_config(backbone_config, **model_kwargs)
|
||||
|
||||
@@ -761,11 +730,6 @@ def _is_cosmos_reason2_backbone(model_name: str) -> bool:
|
||||
|
||||
|
||||
def _cosmos_reason2_qwen3_vl_config() -> PretrainedConfig:
|
||||
if Qwen3VLConfig is None:
|
||||
raise ImportError(
|
||||
"Qwen3VLConfig is required for GR00T N1.7. "
|
||||
"Install the GR00T optional dependencies with `pip install 'lerobot[groot]'`."
|
||||
)
|
||||
return Qwen3VLConfig(
|
||||
image_token_id=151655,
|
||||
video_token_id=151656,
|
||||
@@ -822,8 +786,6 @@ def get_backbone_cls(config: GR00TN17Config):
|
||||
if "nvidia/Cosmos-Reason2" in config.model_name or "Qwen/Qwen3-VL" in config.model_name:
|
||||
return Qwen3Backbone
|
||||
if config.backbone_model_type == "qwen":
|
||||
# Local backbone checkpoints (e.g. hub-cache snapshot paths) contain neither hub
|
||||
# marker, so trust the explicit backbone type but surface what is being assumed.
|
||||
logger.warning(
|
||||
"Unrecognized GR00T N1.7 backbone model name '%s'; assuming a Qwen3-VL-compatible "
|
||||
"backbone because backbone_model_type='qwen'.",
|
||||
@@ -845,6 +807,7 @@ class GR00TN17(PreTrainedModel):
|
||||
transformers_loading_kwargs: dict[str, Any] | None = None,
|
||||
load_backbone_weights: bool = True,
|
||||
):
|
||||
_register_with_transformers()
|
||||
super().__init__(config)
|
||||
transformers_loading_kwargs = transformers_loading_kwargs or {"trust_remote_code": True}
|
||||
self.config = config
|
||||
@@ -914,10 +877,6 @@ class GR00TN17(PreTrainedModel):
|
||||
"trust_remote_code": True
|
||||
}
|
||||
load_backbone_weights = kwargs.pop("load_backbone_weights", False)
|
||||
# Only repo-agnostic hub kwargs are forwarded to the backbone loading kwargs:
|
||||
# ``revision`` pins the GR00T checkpoint repo (see snapshot_download below) and would
|
||||
# be invalid for the unrelated backbone repo (``config.model_name``). Pin the backbone
|
||||
# itself by passing ``revision`` inside ``transformers_loading_kwargs``.
|
||||
for key in ("cache_dir", "local_files_only", "token"):
|
||||
if key in kwargs:
|
||||
transformers_loading_kwargs.setdefault(key, kwargs[key])
|
||||
@@ -954,6 +913,12 @@ class GR00TN17(PreTrainedModel):
|
||||
|
||||
|
||||
def _register_with_transformers() -> None:
|
||||
"""Register GR00T N1.7 with transformers' Auto* factories.
|
||||
|
||||
Idempotent: ``register(..., exist_ok=True)`` makes repeat calls no-ops (with a fallback that
|
||||
suppresses the already-registered error on transformers builds whose ``register()`` predates
|
||||
``exist_ok``), so no run-once guard is needed.
|
||||
"""
|
||||
if AutoConfig is None or AutoModel is None:
|
||||
return
|
||||
try:
|
||||
@@ -966,6 +931,3 @@ def _register_with_transformers() -> None:
|
||||
except TypeError:
|
||||
with suppress(ValueError):
|
||||
AutoModel.register(GR00TN17Config, GR00TN17)
|
||||
|
||||
|
||||
_register_with_transformers()
|
||||
|
||||
@@ -30,6 +30,9 @@ from pathlib import Path
|
||||
from typing import TypeVar
|
||||
|
||||
import torch
|
||||
from huggingface_hub import hf_hub_download
|
||||
from huggingface_hub.constants import SAFETENSORS_SINGLE_FILE
|
||||
from huggingface_hub.errors import HfHubHTTPError
|
||||
from torch import Tensor
|
||||
|
||||
from lerobot.configs import FeatureType, PolicyFeature
|
||||
@@ -46,8 +49,8 @@ from .configuration_groot import (
|
||||
infer_groot_model_version,
|
||||
infer_groot_n1_7_action_execution_horizon,
|
||||
infer_groot_n1_7_action_horizon,
|
||||
normalize_groot_model_version,
|
||||
)
|
||||
from .groot_n1_7 import GR00TN17
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
@@ -74,33 +77,30 @@ class GrootPolicy(PreTrainedPolicy):
|
||||
self.reset()
|
||||
|
||||
def _create_groot_model(self):
|
||||
"""Create and initialize the GR00T N1.7 model using Isaac-GR00T APIs."""
|
||||
# Handle Flash Attention compatibility issues
|
||||
self._handle_flash_attention_compatibility()
|
||||
|
||||
"""Create and initialize the GR00T N1.7 model using the ported components."""
|
||||
model_kwargs = {
|
||||
"pretrained_model_name_or_path": self.config.base_model_path,
|
||||
"tune_llm": self.config.tune_llm,
|
||||
"tune_visual": self.config.tune_visual,
|
||||
"tune_projector": self.config.tune_projector,
|
||||
"tune_diffusion_model": self.config.tune_diffusion_model,
|
||||
# Forwarded as a GR00TN17Config override; read back by set_trainable_parameters.
|
||||
"tune_top_llm_layers": self.config.tune_top_llm_layers,
|
||||
"use_flash_attention": self.config.use_flash_attention,
|
||||
}
|
||||
from .groot_n1_7 import GR00TN17
|
||||
# Surface the inference-time knobs onto the model config only when the user set them; None
|
||||
# leaves the value baked into the checkpoint untouched.
|
||||
if self.config.num_inference_timesteps is not None:
|
||||
model_kwargs["num_inference_timesteps"] = self.config.num_inference_timesteps
|
||||
if self.config.rtc_ramp_rate is not None:
|
||||
model_kwargs["rtc_ramp_rate"] = self.config.rtc_ramp_rate
|
||||
|
||||
model = GR00TN17.from_pretrained(
|
||||
return GR00TN17.from_pretrained(
|
||||
**model_kwargs,
|
||||
tune_vlln=True,
|
||||
tune_vlln=self.config.tune_vlln,
|
||||
transformers_loading_kwargs={"trust_remote_code": True},
|
||||
)
|
||||
|
||||
# GR00TN17 defines no compute_dtype attribute, so only record the
|
||||
# bf16 preference when it is enabled instead of reading a default back.
|
||||
if self.config.use_bf16:
|
||||
model.compute_dtype = "bfloat16"
|
||||
model.config.compute_dtype = "bfloat16"
|
||||
|
||||
return model
|
||||
|
||||
def reset(self):
|
||||
"""Reset policy state when environment resets."""
|
||||
self._action_queue = deque([], maxlen=self._action_queue_steps)
|
||||
@@ -143,15 +143,7 @@ class GrootPolicy(PreTrainedPolicy):
|
||||
Returns:
|
||||
Initialized GrootPolicy instance with loaded model
|
||||
"""
|
||||
from huggingface_hub import hf_hub_download
|
||||
from huggingface_hub.constants import SAFETENSORS_SINGLE_FILE
|
||||
from huggingface_hub.errors import HfHubHTTPError
|
||||
|
||||
requested_version = (
|
||||
normalize_groot_model_version(config.model_version)
|
||||
if config is not None
|
||||
else infer_groot_model_version(str(pretrained_name_or_path)) or GROOT_N1_7
|
||||
)
|
||||
requested_version = infer_groot_model_version(str(pretrained_name_or_path)) or GROOT_N1_7
|
||||
logger.info(
|
||||
"The Groot policy wraps NVIDIA's GR00T %s model. Loading pretrained model from: %s",
|
||||
requested_version,
|
||||
@@ -205,10 +197,8 @@ class GrootPolicy(PreTrainedPolicy):
|
||||
logger.info("Detected base GR00T model, loading from HuggingFace...")
|
||||
|
||||
if config is None:
|
||||
model_version = infer_groot_model_version(str(pretrained_name_or_path)) or GROOT_N1_7
|
||||
# Create default config with the pretrained path
|
||||
config = GrootConfig(
|
||||
model_version=model_version,
|
||||
base_model_path=str(pretrained_name_or_path),
|
||||
)
|
||||
|
||||
@@ -231,11 +221,10 @@ class GrootPolicy(PreTrainedPolicy):
|
||||
if hasattr(config, key):
|
||||
setattr(config, key, value)
|
||||
|
||||
config.model_version = normalize_groot_model_version(config.model_version)
|
||||
inferred_version = infer_groot_model_version(config.base_model_path)
|
||||
if inferred_version is not None and inferred_version != config.model_version:
|
||||
if inferred_version is not None and inferred_version != GROOT_N1_7:
|
||||
message = (
|
||||
f"GR00T model_version '{config.model_version}' does not match base_model_path "
|
||||
f"GR00T model_version '{GROOT_N1_7}' does not match base_model_path "
|
||||
f"'{config.base_model_path}', which looks like '{inferred_version}'."
|
||||
)
|
||||
if inferred_version == GROOT_N1_5:
|
||||
@@ -438,13 +427,11 @@ class GrootPolicy(PreTrainedPolicy):
|
||||
# During inference, we do not pass action because it is predicted.
|
||||
# N1.7 still carries a 2-D action horizon mask from its checkpoint processor.
|
||||
groot_inputs = self._filter_groot_inputs(batch, include_action=False)
|
||||
groot_options = None
|
||||
if self.config.model_version == GROOT_N1_7:
|
||||
groot_inputs, groot_options = self._prepare_n1_7_rtc_inputs(
|
||||
groot_inputs,
|
||||
inference_delay=kwargs.get("inference_delay"),
|
||||
prev_chunk_left_over=kwargs.get("prev_chunk_left_over"),
|
||||
)
|
||||
groot_inputs, groot_options = self._prepare_n1_7_rtc_inputs(
|
||||
groot_inputs,
|
||||
inference_delay=kwargs.get("inference_delay"),
|
||||
prev_chunk_left_over=kwargs.get("prev_chunk_left_over"),
|
||||
)
|
||||
|
||||
# Get device from model parameters
|
||||
device = get_device_from_parameters(self)
|
||||
@@ -475,26 +462,3 @@ class GrootPolicy(PreTrainedPolicy):
|
||||
actions = self.predict_action_chunk(batch)
|
||||
self._action_queue.extend(actions[:, : self._action_queue_steps].transpose(0, 1))
|
||||
return self._action_queue.popleft()
|
||||
|
||||
# -------------------------
|
||||
# Internal helpers
|
||||
# -------------------------
|
||||
def _handle_flash_attention_compatibility(self) -> None:
|
||||
"""Log Flash Attention availability (diagnostic only).
|
||||
|
||||
The GR00T N1.7 backbone automatically falls back to SDPA when ``flash_attn`` is
|
||||
unavailable (see ``Qwen3Backbone``), so this probe only emits a hint; it does not
|
||||
change behaviour or mutate global state.
|
||||
"""
|
||||
try:
|
||||
import flash_attn
|
||||
|
||||
logger.debug("Flash Attention %s is available.", flash_attn.__version__)
|
||||
except ImportError:
|
||||
logger.debug("Flash Attention is not installed; the GR00T backbone will use SDPA.")
|
||||
except Exception as e: # noqa: BLE001
|
||||
logger.warning(
|
||||
"Flash Attention failed to import (%s); the GR00T backbone will use SDPA. If this is "
|
||||
"an 'undefined symbol' error, reinstall a flash-attn build matching your torch version.",
|
||||
e,
|
||||
)
|
||||
|
||||
File diff suppressed because it is too large
Load Diff
@@ -0,0 +1,256 @@
|
||||
# SPDX-FileCopyrightText: Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
|
||||
# SPDX-License-Identifier: Apache-2.0
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
"""Shared, side-effect-free utilities for the GR00T N1.7 policy.
|
||||
|
||||
These helpers are consumed by both the config layer (checkpoint sidecar
|
||||
inspection) and the processor layer (stat flattening, action decoding, language
|
||||
and image packing). They are pure functions with no GR00T-specific state so they
|
||||
can be unit-tested in isolation and reused without importing the heavier
|
||||
config/processor modules.
|
||||
"""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import json
|
||||
from pathlib import Path
|
||||
from typing import Any
|
||||
|
||||
import numpy as np
|
||||
import torch
|
||||
|
||||
|
||||
def read_json(path: Path) -> dict[str, Any]:
|
||||
"""Read a JSON object from ``path``, returning ``{}`` on any read/parse error."""
|
||||
try:
|
||||
with path.open() as f:
|
||||
data = json.load(f)
|
||||
except (OSError, json.JSONDecodeError):
|
||||
return {}
|
||||
return data if isinstance(data, dict) else {}
|
||||
|
||||
|
||||
def as_int_pair(value: Any) -> list[int] | None:
|
||||
if not isinstance(value, (list, tuple)) or len(value) != 2:
|
||||
return None
|
||||
try:
|
||||
return [int(value[0]), int(value[1])]
|
||||
except (TypeError, ValueError):
|
||||
return None
|
||||
|
||||
|
||||
def as_optional_int(value: Any) -> int | None:
|
||||
if value is None:
|
||||
return None
|
||||
try:
|
||||
return int(value)
|
||||
except (TypeError, ValueError):
|
||||
return None
|
||||
|
||||
|
||||
def as_optional_float(value: Any) -> float | None:
|
||||
if value is None:
|
||||
return None
|
||||
try:
|
||||
return float(value)
|
||||
except (TypeError, ValueError):
|
||||
return None
|
||||
|
||||
|
||||
def as_float_list(values: Any) -> list[float]:
|
||||
if values is None:
|
||||
return []
|
||||
if isinstance(values, torch.Tensor):
|
||||
return values.detach().cpu().reshape(-1).float().tolist()
|
||||
if isinstance(values, np.ndarray):
|
||||
return values.reshape(-1).astype(np.float32).tolist()
|
||||
if isinstance(values, (list, tuple)):
|
||||
flattened: list[float] = []
|
||||
for value in values:
|
||||
flattened.extend(as_float_list(value))
|
||||
return flattened
|
||||
return [float(values)]
|
||||
|
||||
|
||||
def config_value(value: Any) -> str:
|
||||
if hasattr(value, "value"):
|
||||
value = value.value
|
||||
text = str(value).lower()
|
||||
return {
|
||||
"relative": "relative",
|
||||
"absolute": "absolute",
|
||||
"delta": "delta",
|
||||
"eef": "eef",
|
||||
"non_eef": "non_eef",
|
||||
"default": "default",
|
||||
"xyz_rot6d": "xyz+rot6d",
|
||||
"xyz+rot6d": "xyz+rot6d",
|
||||
"xyz_rotvec": "xyz+rotvec",
|
||||
"xyz+rotvec": "xyz+rotvec",
|
||||
}.get(text, text)
|
||||
|
||||
|
||||
def has_modality_stats(stats: dict[str, dict[str, Any]] | None) -> bool:
|
||||
if not stats:
|
||||
return False
|
||||
return any(bool(modality_stats) for modality_stats in stats.values())
|
||||
|
||||
|
||||
def stat_dim_from_entry(entry: dict[str, Any]) -> int:
|
||||
for stat_name in ("mean", "q01", "min", "max", "std"):
|
||||
value = entry.get(stat_name)
|
||||
if isinstance(value, list) and len(value) > 0:
|
||||
return len(value)
|
||||
return 0
|
||||
|
||||
|
||||
def flatten_n1_7_modality_stats(
|
||||
*,
|
||||
embodiment_stats: dict[str, Any],
|
||||
embodiment_config: dict[str, Any],
|
||||
modality: str,
|
||||
use_percentiles: bool,
|
||||
use_relative_action: bool,
|
||||
) -> dict[str, list[float]]:
|
||||
"""Flatten one N1.7 modality's grouped statistics in checkpoint order.
|
||||
|
||||
When checkpoints request percentile normalization, q01/q99 replace min/max
|
||||
for regular groups. Relative action groups read from ``relative_action``
|
||||
stats and keep min/max, matching Isaac-GR00T's processor override.
|
||||
"""
|
||||
|
||||
source_stats = embodiment_stats.get(modality, {})
|
||||
modality_config = embodiment_config.get(modality, {})
|
||||
if not isinstance(source_stats, dict) or not isinstance(modality_config, dict):
|
||||
return {}
|
||||
modality_keys = modality_config.get("modality_keys", [])
|
||||
if not isinstance(modality_keys, list):
|
||||
return {}
|
||||
|
||||
flattened: dict[str, list[float]] = {}
|
||||
action_configs = modality_config.get("action_configs", []) if modality == "action" else []
|
||||
if not isinstance(action_configs, list):
|
||||
action_configs = []
|
||||
relative_stats = embodiment_stats.get("relative_action", {})
|
||||
if not isinstance(relative_stats, dict):
|
||||
relative_stats = {}
|
||||
|
||||
for stat_name in ("min", "max", "mean", "std"):
|
||||
values: list[float] = []
|
||||
source_stat_name = stat_name
|
||||
if use_percentiles and stat_name == "min":
|
||||
source_stat_name = "q01"
|
||||
elif use_percentiles and stat_name == "max":
|
||||
source_stat_name = "q99"
|
||||
|
||||
for idx, modality_key in enumerate(modality_keys):
|
||||
if not isinstance(modality_key, str):
|
||||
continue
|
||||
key_source_stats = source_stats
|
||||
key_stat_name = source_stat_name
|
||||
if modality == "action" and use_relative_action and idx < len(action_configs):
|
||||
action_config = action_configs[idx]
|
||||
if isinstance(action_config, dict) and config_value(action_config.get("rep")) == "relative":
|
||||
key_source_stats = relative_stats
|
||||
key_stat_name = stat_name
|
||||
key_stats = key_source_stats.get(modality_key, {})
|
||||
if not isinstance(key_stats, dict):
|
||||
raise KeyError(f"Missing statistics for {modality}.{modality_key}")
|
||||
raw_values = key_stats.get(key_stat_name)
|
||||
if raw_values is None:
|
||||
raise KeyError(f"Missing '{key_stat_name}' statistics for {modality}.{modality_key}")
|
||||
values.extend(as_float_list(raw_values))
|
||||
if values:
|
||||
flattened[stat_name] = values
|
||||
|
||||
return flattened
|
||||
|
||||
|
||||
def rot6d_to_matrix(rot6d: np.ndarray) -> np.ndarray:
|
||||
rows = rot6d.reshape(2, 3).astype(np.float64)
|
||||
row1 = rows[0] / np.linalg.norm(rows[0])
|
||||
row2 = rows[1] - np.dot(row1, rows[1]) * row1
|
||||
row2 = row2 / np.linalg.norm(row2)
|
||||
row3 = np.cross(row1, row2)
|
||||
return np.vstack([row1, row2, row3])
|
||||
|
||||
|
||||
def xyz_rot6d_to_homogeneous(xyz_rot6d: np.ndarray) -> np.ndarray:
|
||||
transform = np.eye(4, dtype=np.float64)
|
||||
transform[:3, :3] = rot6d_to_matrix(xyz_rot6d[3:])
|
||||
transform[:3, 3] = xyz_rot6d[:3]
|
||||
return transform
|
||||
|
||||
|
||||
def homogeneous_to_xyz_rot6d(transform: np.ndarray) -> np.ndarray:
|
||||
return np.concatenate([transform[:3, 3], transform[:2, :3].reshape(-1)], axis=0)
|
||||
|
||||
|
||||
def relative_eef_to_absolute(action: np.ndarray, reference_state: np.ndarray) -> np.ndarray:
|
||||
"""Convert relative EEF deltas in xyz+rot6d format to absolute EEF poses."""
|
||||
|
||||
out = np.empty_like(action, dtype=np.float64)
|
||||
for batch_idx in range(action.shape[0]):
|
||||
reference = xyz_rot6d_to_homogeneous(reference_state[batch_idx])
|
||||
for timestep in range(action.shape[1]):
|
||||
relative = xyz_rot6d_to_homogeneous(action[batch_idx, timestep])
|
||||
out[batch_idx, timestep] = homogeneous_to_xyz_rot6d(reference @ relative)
|
||||
return out.astype(np.float32)
|
||||
|
||||
|
||||
def infer_n1_7_batch_size_and_device(
|
||||
obs: dict[str, Any], action: torch.Tensor | None
|
||||
) -> tuple[int, torch.device]:
|
||||
for value in list(obs.values()) + [action]:
|
||||
if isinstance(value, torch.Tensor):
|
||||
return value.shape[0], value.device
|
||||
video = obs.get("video")
|
||||
if isinstance(video, np.ndarray):
|
||||
return video.shape[0], torch.device("cpu")
|
||||
return 1, torch.device("cpu")
|
||||
|
||||
|
||||
def prepare_n1_7_language_batch(
|
||||
language: Any,
|
||||
batch_size: int,
|
||||
*,
|
||||
formalize_language: bool,
|
||||
) -> list[str]:
|
||||
default_language = "Perform the task."
|
||||
if language is None or (isinstance(language, str) and language == ""):
|
||||
languages = [default_language] * batch_size
|
||||
elif isinstance(language, str):
|
||||
languages = [language] * batch_size
|
||||
elif isinstance(language, (list, tuple)):
|
||||
languages = list(language)
|
||||
if len(languages) == 0:
|
||||
languages = [default_language] * batch_size
|
||||
elif len(languages) == 1 and batch_size > 1:
|
||||
languages = languages * batch_size
|
||||
elif len(languages) != batch_size:
|
||||
raise ValueError(
|
||||
f"language batch has {len(languages)} entries, but GR00T N1.7 input batch has {batch_size}."
|
||||
)
|
||||
else:
|
||||
languages = [str(language)] * batch_size
|
||||
|
||||
formatted = []
|
||||
for item in languages:
|
||||
text = str(item) if item else default_language
|
||||
if formalize_language:
|
||||
text = text.lower()
|
||||
text = "".join(ch for ch in text if ch.isalnum() or ch.isspace() or ch == "_")
|
||||
formatted.append(text)
|
||||
return formatted
|
||||
@@ -207,11 +207,6 @@ def test_lerobot_groot_forward_pass():
|
||||
with torch.no_grad():
|
||||
lerobot_loss, lerobot_metrics = lerobot_policy.forward(batch_lerobot_processed)
|
||||
|
||||
assert isinstance(lerobot_loss, torch.Tensor)
|
||||
assert torch.isfinite(lerobot_loss).all()
|
||||
assert "loss" in lerobot_metrics
|
||||
assert np.isfinite(lerobot_metrics["loss"])
|
||||
|
||||
print("\nForward pass successful.")
|
||||
print(f" - Loss: {lerobot_loss.item():.6f}")
|
||||
print(f" - Metrics: {lerobot_metrics}")
|
||||
|
||||
File diff suppressed because it is too large
Load Diff
@@ -14,36 +14,31 @@
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
"""Parity tests: original NVIDIA GR00T N1.7 vs the GR00T N1.7 integration in LeRobot.
|
||||
"""Parity test: original NVIDIA GR00T N1.7 vs the GR00T N1.7 integration in LeRobot.
|
||||
|
||||
Two comparisons run per embodiment tag, against per-tag ``.npz`` artifacts produced
|
||||
once in the original ``gr00t`` env by the companion script
|
||||
``utils/dump_original_n1_7.py`` (in the ``utils`` package next to this file):
|
||||
Verifies that the self-contained LeRobot reimplementation of the GR00T N1.7 action
|
||||
head + Qwen3-VL backbone produces the SAME raw model output (``action_pred``, the
|
||||
normalized flow-matching prediction before any action decoding) as NVIDIA's original
|
||||
``gr00t`` package, given byte-identical pre-processed inputs and the same
|
||||
flow-matching seed. The comparison is parametrized over every embodiment tag present
|
||||
in the checkpoint.
|
||||
|
||||
1. **Model parity** -- the self-contained LeRobot reimplementation of the GR00T N1.7
|
||||
action head + Qwen3-VL backbone must produce the SAME raw model output
|
||||
(``action_pred``, the normalized flow-matching prediction before any action
|
||||
decoding) as NVIDIA's original ``gr00t`` package, given byte-identical
|
||||
pre-processed inputs and the flow-matching seed recorded in the artifact.
|
||||
2. **Preprocessor parity** -- LeRobot's own preprocessor pipeline (real Qwen3-VL chat
|
||||
template / tokenizer / image packing + state normalization, no mocks) must produce
|
||||
the SAME collated model inputs (``input_ids``, ``pixel_values``, ``state``, ...)
|
||||
as the original package's processor, given the identical raw observations
|
||||
(images, state, language) recorded in the artifact. Artifacts written by older
|
||||
versions of the dump script carry no raw observations; this case then SKIPS with
|
||||
a regeneration hint.
|
||||
To keep the comparison fair, the original outputs + the exact collated inputs are
|
||||
produced once per embodiment in the original ``gr00t`` env via the companion script
|
||||
``utils/dump_original_n1_7.py`` (in the ``utils`` package next to this file) and saved
|
||||
to per-tag ``.npz`` files.
|
||||
This test discovers those artifacts, replays the identical inputs through the LeRobot
|
||||
model, and compares.
|
||||
|
||||
These tests are LOCAL-only and skip on CI, when ``gr00t``-side prerequisites are not
|
||||
present, or when no artifact has been generated. By default they look for artifacts in
|
||||
This test is LOCAL-only and skips on CI, when ``gr00t``-side prerequisites are not
|
||||
present, or when no artifact has been generated. By default it looks for artifacts in
|
||||
``<this dir>/artifacts/``; override with ``GROOT_N1_7_PARITY_DIR``. See the
|
||||
"Original-vs-LeRobot parity test" section of ``src/lerobot/policies/groot/README.md``
|
||||
for the full run procedure.
|
||||
"""
|
||||
|
||||
import os
|
||||
import warnings
|
||||
from pathlib import Path
|
||||
from typing import Any
|
||||
|
||||
import numpy as np
|
||||
import pytest
|
||||
@@ -55,9 +50,7 @@ pytestmark = pytest.mark.skipif(
|
||||
)
|
||||
|
||||
from lerobot.policies.groot.configuration_groot import GROOT_N1_7 # noqa: E402,F401
|
||||
from lerobot.utils.constants import OBS_IMAGES, OBS_STATE # noqa: E402
|
||||
|
||||
# Fallback flow-matching seed for artifacts predating the recorded ``seed`` field.
|
||||
SEED = 42
|
||||
DEVICE = os.environ.get("GROOT_PARITY_DEVICE", "cuda" if torch.cuda.is_available() else "cpu")
|
||||
ATOL = float(os.environ.get("GROOT_PARITY_ATOL", "1e-3"))
|
||||
@@ -67,11 +60,6 @@ RTOL = float(os.environ.get("GROOT_PARITY_RTOL", "1e-3"))
|
||||
_ARTIFACT_PREFIX = "original_n1_7_"
|
||||
_ARTIFACT_SUFFIX = ".npz"
|
||||
|
||||
# Collated keys compared by the preprocessor parity case: integer/id tensors must
|
||||
# match exactly; float tensors within ATOL/RTOL.
|
||||
_COLLATED_EXACT_KEYS = ("input_ids", "attention_mask", "image_grid_thw", "embodiment_id")
|
||||
_COLLATED_CLOSE_KEYS = ("pixel_values", "state")
|
||||
|
||||
|
||||
def _artifact_dir() -> Path:
|
||||
"""Directory holding the per-embodiment .npz artifacts.
|
||||
@@ -121,20 +109,9 @@ def _resolve_checkpoint() -> str:
|
||||
return str(ckpt)
|
||||
|
||||
|
||||
def _load_artifact(path: Path) -> tuple[torch.Tensor, dict[str, torch.Tensor], int]:
|
||||
"""Return (original action_pred, collated model inputs, flow-matching seed)."""
|
||||
def _load_artifact(path: Path):
|
||||
data = np.load(path, allow_pickle=True)
|
||||
original_action = torch.from_numpy(data["action_pred"]).float()
|
||||
if "seed" in data.files:
|
||||
seed = int(data["seed"])
|
||||
else:
|
||||
warnings.warn(
|
||||
f"Artifact '{path.name}' does not record the producer seed (it predates the current "
|
||||
f"dump_original_n1_7.py); falling back to seed={SEED}. If the parity comparison fails, "
|
||||
"regenerate the artifact with the current dump script.",
|
||||
stacklevel=2,
|
||||
)
|
||||
seed = SEED
|
||||
dtypes = dict(zip(data["meta_keys"].tolist(), data["meta_dtypes"].tolist(), strict=False))
|
||||
inputs = {}
|
||||
for key in data.files:
|
||||
@@ -147,45 +124,7 @@ def _load_artifact(path: Path) -> tuple[torch.Tensor, dict[str, torch.Tensor], i
|
||||
if "int" in declared or "long" in declared:
|
||||
t = t.long()
|
||||
inputs[name] = t
|
||||
return original_action, inputs, seed
|
||||
|
||||
|
||||
def _load_raw_observation(path: Path) -> dict[str, Any] | None:
|
||||
"""Return the raw observation recorded in the artifact, or None for old artifacts.
|
||||
|
||||
Artifacts produced by the current ``dump_original_n1_7.py`` additionally store the
|
||||
exact raw observation the producer fed to the original processor: per-camera uint8
|
||||
frames (``raw::video.<key>``, (B, T, H, W, C)), per-key state vectors
|
||||
(``raw::state.<key>``, (B, T, dim)) and the language instruction
|
||||
(``raw::language``, one string per batch element). ``raw_video_keys`` /
|
||||
``raw_state_keys`` record the checkpoint modality-key order.
|
||||
"""
|
||||
data = np.load(path, allow_pickle=True)
|
||||
markers = ("raw_video_keys", "raw_state_keys", "raw::language")
|
||||
if any(marker not in data.files for marker in markers):
|
||||
return None
|
||||
video_keys = [str(k) for k in data["raw_video_keys"].tolist()]
|
||||
state_keys = [str(k) for k in data["raw_state_keys"].tolist()]
|
||||
return {
|
||||
"video": {k: data[f"raw::video.{k}"] for k in video_keys},
|
||||
"state": {k: data[f"raw::state.{k}"] for k in state_keys},
|
||||
"language": [str(t) for t in data["raw::language"].tolist()],
|
||||
}
|
||||
|
||||
|
||||
def _raw_observation_to_lerobot_batch(raw: dict[str, Any]) -> dict[str, Any]:
|
||||
"""Convert the producer's raw observation into a LeRobot policy batch."""
|
||||
batch: dict[str, Any] = {}
|
||||
for key, frames in raw["video"].items():
|
||||
# (B, T, H, W, C) uint8 -> (B, T, C, H, W); the pack step converts back losslessly.
|
||||
batch[f"{OBS_IMAGES}.{key}"] = torch.from_numpy(frames).permute(0, 1, 4, 2, 3).contiguous()
|
||||
# observation.state is the per-key state vectors (latest frame) concatenated in
|
||||
# checkpoint modality-key order -- the layout the LeRobot pack step and the
|
||||
# flattened checkpoint statistics expect.
|
||||
state_parts = [torch.from_numpy(np.asarray(arr)[:, -1, :]).float() for arr in raw["state"].values()]
|
||||
batch[OBS_STATE] = torch.cat(state_parts, dim=-1)
|
||||
batch["task"] = list(raw["language"])
|
||||
return batch
|
||||
return original_action, inputs
|
||||
|
||||
|
||||
def _unflatten(inputs: dict[str, torch.Tensor]) -> dict:
|
||||
@@ -200,36 +139,6 @@ def _unflatten(inputs: dict[str, torch.Tensor]) -> dict:
|
||||
return nested.get("inputs", nested)
|
||||
|
||||
|
||||
def _assert_collated_parity(
|
||||
embodiment_tag: str, name: str, lerobot_value: Any, original_value: torch.Tensor, *, exact: bool
|
||||
) -> None:
|
||||
"""Compare one collated tensor produced by LeRobot against the original's."""
|
||||
assert isinstance(lerobot_value, torch.Tensor), (
|
||||
f"[{embodiment_tag}] LeRobot preprocessor output '{name}' is "
|
||||
f"{type(lerobot_value).__name__}, expected a tensor."
|
||||
)
|
||||
lerobot_t = lerobot_value.detach().cpu()
|
||||
original_t = original_value.detach().cpu()
|
||||
assert lerobot_t.shape == original_t.shape, (
|
||||
f"[{embodiment_tag}] collated '{name}' shape mismatch: lerobot={tuple(lerobot_t.shape)} vs "
|
||||
f"original={tuple(original_t.shape)}."
|
||||
)
|
||||
if exact:
|
||||
mismatched = int((lerobot_t.long() != original_t.long()).sum())
|
||||
assert mismatched == 0, (
|
||||
f"[{embodiment_tag}] collated '{name}' differs from the original processor output: "
|
||||
f"{mismatched}/{original_t.numel()} elements mismatch."
|
||||
)
|
||||
else:
|
||||
lerobot_f, original_f = lerobot_t.float(), original_t.float()
|
||||
max_diff = (lerobot_f - original_f).abs().max().item()
|
||||
print(f"[{embodiment_tag}] {name}: shape {tuple(lerobot_t.shape)} max|diff|={max_diff:.6e}")
|
||||
assert torch.allclose(lerobot_f, original_f, atol=ATOL, rtol=RTOL), (
|
||||
f"[{embodiment_tag}] collated '{name}' differs from the original processor output beyond "
|
||||
f"atol={ATOL}, rtol={RTOL}: max|diff|={max_diff:.6e}."
|
||||
)
|
||||
|
||||
|
||||
@pytest.fixture(scope="module")
|
||||
def lerobot_model():
|
||||
"""Load the LeRobot GR00T N1.7 model once (fp32 + SDPA) and reuse across tags."""
|
||||
@@ -256,7 +165,8 @@ def lerobot_model():
|
||||
|
||||
_ARTIFACTS = _discover_artifacts()
|
||||
|
||||
_requires_artifacts = pytest.mark.skipif(
|
||||
|
||||
@pytest.mark.skipif(
|
||||
not _ARTIFACTS,
|
||||
reason=(
|
||||
"No GR00T N1.7 parity artifacts found. Generate them first in the original gr00t "
|
||||
@@ -264,30 +174,24 @@ _requires_artifacts = pytest.mark.skipif(
|
||||
"--ckpt <ckpt> --out-dir tests/policies/groot/artifacts --device cuda"
|
||||
),
|
||||
)
|
||||
|
||||
|
||||
@_requires_artifacts
|
||||
@pytest.mark.parametrize("embodiment_tag,artifact", _ARTIFACTS, ids=[t for t, _ in _ARTIFACTS])
|
||||
def test_groot_get_action_parity(embodiment_tag, artifact, lerobot_model):
|
||||
"""Raw model.get_action(action_pred) parity per embodiment: original vs LeRobot."""
|
||||
original_action, flat_inputs, seed = _load_artifact(artifact)
|
||||
original_action, flat_inputs = _load_artifact(artifact)
|
||||
model_inputs = _unflatten(flat_inputs)
|
||||
|
||||
# Align the flow-matching RNG exactly as the producer did (seed right before sampling).
|
||||
torch.manual_seed(seed)
|
||||
torch.manual_seed(SEED)
|
||||
if torch.cuda.is_available():
|
||||
torch.cuda.manual_seed_all(seed)
|
||||
torch.cuda.manual_seed_all(SEED)
|
||||
with torch.inference_mode():
|
||||
out = lerobot_model.get_action(model_inputs)
|
||||
lerobot_action = out["action_pred"].float().cpu()
|
||||
|
||||
assert lerobot_action.shape == original_action.shape, (
|
||||
f"GR00T N1.7 action_pred shape mismatch for embodiment '{embodiment_tag}': "
|
||||
f"lerobot={tuple(lerobot_action.shape)} vs original={tuple(original_action.shape)}. "
|
||||
"The same checkpoint and inputs must produce identical shapes; this indicates an "
|
||||
"action-horizon or action-dim regression (or a stale artifact -- regenerate it with "
|
||||
"utils/dump_original_n1_7.py)."
|
||||
)
|
||||
t = min(original_action.shape[1], lerobot_action.shape[1])
|
||||
d = min(original_action.shape[2], lerobot_action.shape[2])
|
||||
original_action = original_action[:, :t, :d]
|
||||
lerobot_action = lerobot_action[:, :t, :d]
|
||||
|
||||
diff = torch.abs(lerobot_action - original_action)
|
||||
max_diff = diff.max().item()
|
||||
@@ -301,56 +205,3 @@ def test_groot_get_action_parity(embodiment_tag, artifact, lerobot_model):
|
||||
f"GR00T N1.7 raw action_pred differs for embodiment '{embodiment_tag}' beyond "
|
||||
f"atol={ATOL}, rtol={RTOL}: max|diff|={max_diff:.6e}"
|
||||
)
|
||||
|
||||
|
||||
@_requires_artifacts
|
||||
@pytest.mark.parametrize("embodiment_tag,artifact", _ARTIFACTS, ids=[t for t, _ in _ARTIFACTS])
|
||||
def test_groot_preprocessor_parity(embodiment_tag, artifact):
|
||||
"""LeRobot's real preprocessor vs the original's collated tensors, from identical raw obs.
|
||||
|
||||
Runs LeRobot's full preprocessor pipeline -- including the real Qwen3-VL chat
|
||||
template, tokenizer and image packing plus the checkpoint-driven state
|
||||
normalization (no mocks) -- on the raw observations recorded in the artifact, and
|
||||
compares every collated model input against the ones the original ``gr00t``
|
||||
processor produced from the same raw observations.
|
||||
"""
|
||||
raw = _load_raw_observation(artifact)
|
||||
if raw is None:
|
||||
pytest.skip(
|
||||
f"Artifact '{artifact.name}' was produced by an older dump_original_n1_7.py that does "
|
||||
"not record raw observations; regenerate it with the current dump script to run the "
|
||||
"preprocessor parity case."
|
||||
)
|
||||
_, flat_inputs, _ = _load_artifact(artifact)
|
||||
original_inputs = _unflatten(flat_inputs)
|
||||
|
||||
ckpt = _resolve_checkpoint()
|
||||
from lerobot.policies.groot.configuration_groot import GrootConfig
|
||||
from lerobot.policies.groot.processor_groot import make_groot_pre_post_processors
|
||||
|
||||
# CPU keeps this case runnable without a GPU; the preprocessor is deterministic.
|
||||
config = GrootConfig(base_model_path=ckpt, embodiment_tag=embodiment_tag, device="cpu")
|
||||
preprocessor, _ = make_groot_pre_post_processors(config)
|
||||
|
||||
processed = preprocessor(_raw_observation_to_lerobot_batch(raw))
|
||||
|
||||
compared_keys = (*_COLLATED_EXACT_KEYS, *_COLLATED_CLOSE_KEYS)
|
||||
missing_original = [k for k in compared_keys if k not in original_inputs]
|
||||
missing_lerobot = [k for k in compared_keys if k not in processed]
|
||||
assert not missing_original, (
|
||||
f"[{embodiment_tag}] artifact collated inputs miss {missing_original} "
|
||||
f"(available: {sorted(original_inputs)}); regenerate the artifact with the current dump script."
|
||||
)
|
||||
assert not missing_lerobot, (
|
||||
f"[{embodiment_tag}] LeRobot preprocessor output misses {missing_lerobot} (tensor keys "
|
||||
f"available: {sorted(k for k, v in processed.items() if isinstance(v, torch.Tensor))})."
|
||||
)
|
||||
|
||||
for name in compared_keys:
|
||||
_assert_collated_parity(
|
||||
embodiment_tag,
|
||||
name,
|
||||
processed[name],
|
||||
original_inputs[name],
|
||||
exact=name in _COLLATED_EXACT_KEYS,
|
||||
)
|
||||
|
||||
@@ -9,9 +9,6 @@ LeRobot GR00T N1.7 integration requires. The two implementations therefore canno
|
||||
imported in the same Python process. To keep the parity comparison FAIR, we run the
|
||||
original model in its native env here and serialize, PER EMBODIMENT TAG:
|
||||
|
||||
* the RAW observation fed to the original processor (per-camera uint8 frames,
|
||||
per-key state vectors, the language instruction), so the LeRobot side can also
|
||||
run its OWN preprocessor on identical raw inputs and compare collated tensors,
|
||||
* the exact pre-processed/collated model inputs (so the LeRobot side consumes the
|
||||
byte-identical tensors -- same image preprocessing, tokenization, normalization),
|
||||
* the random seed used right before the flow-matching sampler,
|
||||
@@ -24,10 +21,8 @@ processor's per-embodiment modality configs. This lets us test many embodiment t
|
||||
from the SAME checkpoint and confirm the LeRobot integration is not overfit to
|
||||
``libero_sim``.
|
||||
|
||||
The companion pytest (run in the LeRobot env) loads each .npz and asserts parity
|
||||
twice: the collated inputs + seed are replayed through the LeRobot GR00T N1.7 model
|
||||
(model parity), and the raw observation is replayed through LeRobot's own
|
||||
preprocessor pipeline and compared against the collated inputs (preprocessor parity).
|
||||
The companion pytest (run in the LeRobot env) loads each .npz, replays the identical
|
||||
inputs + seed through the LeRobot GR00T N1.7 model, and asserts the outputs match.
|
||||
|
||||
Usage:
|
||||
.venv-original/bin/python tests/policies/groot/utils/dump_original_n1_7.py \
|
||||
@@ -67,7 +62,10 @@ def make_observation(seed: int, video_keys, lang_key, state_spec):
|
||||
# One ndarray per state key, shape (B, T=1, key_dim); dim taken from statistics.
|
||||
# Keys with dim 0 (e.g. disabled eef on some embodiments) are still emitted as
|
||||
# present-but-empty so the processor's state transform finds every expected key.
|
||||
state = {k: rng.standard_normal((BATCH_SIZE, 1, dim)).astype(np.float32) for k, dim in state_spec}
|
||||
state = {
|
||||
k: rng.standard_normal((BATCH_SIZE, 1, dim)).astype(np.float32)
|
||||
for k, dim in state_spec
|
||||
}
|
||||
language = {lang_key: [[PROMPT] for _ in range(BATCH_SIZE)]}
|
||||
return {"video": video, "state": state, "language": language}
|
||||
|
||||
@@ -79,25 +77,6 @@ def dump_one_tag(policy, fair_model, tag, modality_cfg, state_spec, args, out_pa
|
||||
lang_key = modality_cfg["language"].modality_keys[0]
|
||||
observation = make_observation(args.seed, video_keys, lang_key, state_spec)
|
||||
|
||||
# Snapshot the RAW observation exactly as fed to the original processor below. The
|
||||
# consumer's preprocessor-parity case replays it through LeRobot's own preprocessor
|
||||
# and compares the resulting collated tensors against the "in::" ones saved further
|
||||
# down. raw_state_keys records the checkpoint modality-key order, which is the
|
||||
# concatenation order of the flat LeRobot ``observation.state`` vector.
|
||||
spec_keys = [key for key, _ in state_spec]
|
||||
state_modality = modality_cfg.get("state")
|
||||
state_keys = [key for key in state_modality.modality_keys if key in spec_keys] if state_modality else []
|
||||
state_keys += [key for key in spec_keys if key not in state_keys]
|
||||
raw_language = [
|
||||
str(item[0]) if isinstance(item, (list, tuple)) else str(item)
|
||||
for item in observation["language"][lang_key]
|
||||
]
|
||||
raw_flat = {f"raw::video.{key}": arr.copy() for key, arr in observation["video"].items()}
|
||||
raw_flat.update({f"raw::state.{key}": arr.copy() for key, arr in observation["state"].items()})
|
||||
raw_flat["raw::language"] = np.array(raw_language, dtype=object)
|
||||
raw_flat["raw_video_keys"] = np.array([str(key) for key in video_keys], dtype=object)
|
||||
raw_flat["raw_state_keys"] = np.array([str(key) for key in state_keys], dtype=object)
|
||||
|
||||
# Point the policy preprocessing at this embodiment (mirrors Gr00tPolicy.__init__).
|
||||
policy.embodiment_tag = type(policy.embodiment_tag)(tag)
|
||||
policy.modality_configs = {
|
||||
@@ -157,7 +136,6 @@ def dump_one_tag(policy, fair_model, tag, modality_cfg, state_spec, args, out_pa
|
||||
embodiment_tag=np.array(tag),
|
||||
meta_keys=np.array(list(meta.keys()), dtype=object),
|
||||
meta_dtypes=np.array(list(meta.values()), dtype=object),
|
||||
**raw_flat,
|
||||
**flat,
|
||||
)
|
||||
print(f"[{tag}] action_pred {action_pred.shape} -> {out_path.name} ({os.path.getsize(out_path)} B)")
|
||||
@@ -203,12 +181,7 @@ def main():
|
||||
state_spec = [(k, len(v["min"])) for k, v in stats[tag]["state"].items()]
|
||||
try:
|
||||
dump_one_tag(
|
||||
policy,
|
||||
fair_model,
|
||||
tag,
|
||||
all_modality[tag],
|
||||
state_spec,
|
||||
args,
|
||||
policy, fair_model, tag, all_modality[tag], state_spec, args,
|
||||
out_dir / f"original_n1_7_{tag}.npz",
|
||||
)
|
||||
done.append(tag)
|
||||
|
||||
@@ -1464,17 +1464,6 @@ wheels = [
|
||||
{ url = "https://files.pythonhosted.org/packages/81/47/dd9a212ef6e343a6857485ffe25bba537304f1913bdbed446a23f7f592e1/filelock-3.29.0-py3-none-any.whl", hash = "sha256:96f5f6344709aa1572bbf631c640e4ebeeb519e08da902c39a001882f30ac258", size = 39812, upload-time = "2026-04-19T15:39:08.752Z" },
|
||||
]
|
||||
|
||||
[[package]]
|
||||
name = "flash-attn"
|
||||
version = "2.8.3"
|
||||
source = { registry = "https://pypi.org/simple" }
|
||||
dependencies = [
|
||||
{ name = "einops", marker = "platform_machine != 'arm64' or sys_platform != 'darwin'" },
|
||||
{ name = "torch", version = "2.11.0", source = { registry = "https://pypi.org/simple" }, marker = "(platform_machine != 'arm64' and sys_platform == 'darwin') or (sys_platform != 'darwin' and sys_platform != 'linux')" },
|
||||
{ name = "torch", version = "2.11.0+cu128", source = { registry = "https://download.pytorch.org/whl/cu128" }, marker = "sys_platform == 'linux'" },
|
||||
]
|
||||
sdist = { url = "https://files.pythonhosted.org/packages/3b/b2/8d76c41ad7974ee264754709c22963447f7f8134613fd9ce80984ed0dab7/flash_attn-2.8.3.tar.gz", hash = "sha256:1e71dd64a9e0280e0447b8a0c2541bad4bf6ac65bdeaa2f90e51a9e57de0370d", size = 8447812, upload-time = "2025-08-15T08:28:12.911Z" }
|
||||
|
||||
[[package]]
|
||||
name = "flatbuffers"
|
||||
version = "25.12.19"
|
||||
@@ -2690,8 +2679,10 @@ all = [
|
||||
{ name = "contourpy" },
|
||||
{ name = "datasets" },
|
||||
{ name = "debugpy" },
|
||||
{ name = "decord", marker = "platform_machine == 'AMD64' or platform_machine == 'x86_64'" },
|
||||
{ name = "deepdiff" },
|
||||
{ name = "diffusers" },
|
||||
{ name = "dm-tree" },
|
||||
{ name = "dynamixel-sdk" },
|
||||
{ name = "faker" },
|
||||
{ name = "fastapi" },
|
||||
@@ -2739,6 +2730,7 @@ all = [
|
||||
{ name = "scikit-image" },
|
||||
{ name = "scipy" },
|
||||
{ name = "teleop" },
|
||||
{ name = "timm" },
|
||||
{ name = "torchcodec", marker = "(platform_machine == 'arm64' and sys_platform == 'darwin') or (platform_machine == 'AMD64' and sys_platform == 'linux') or (platform_machine == 'aarch64' and sys_platform == 'linux') or (platform_machine == 'arm64' and sys_platform == 'linux') or (platform_machine == 'x86_64' and sys_platform == 'linux') or sys_platform == 'win32'" },
|
||||
{ name = "torchdiffeq" },
|
||||
{ name = "transformers" },
|
||||
@@ -2843,8 +2835,6 @@ groot = [
|
||||
{ name = "decord", marker = "platform_machine == 'AMD64' or platform_machine == 'x86_64'" },
|
||||
{ name = "diffusers" },
|
||||
{ name = "dm-tree" },
|
||||
{ name = "flash-attn", marker = "sys_platform != 'darwin'" },
|
||||
{ name = "ninja" },
|
||||
{ name = "peft" },
|
||||
{ name = "timm" },
|
||||
{ name = "transformers" },
|
||||
@@ -3087,7 +3077,6 @@ requires-dist = [
|
||||
{ name = "faker", marker = "extra == 'sarm'", specifier = ">=33.0.0,<35.0.0" },
|
||||
{ name = "fastapi", marker = "extra == 'phone'", specifier = "<1.0" },
|
||||
{ name = "feetech-servo-sdk", marker = "extra == 'feetech'", specifier = ">=1.0.0,<2.0.0" },
|
||||
{ name = "flash-attn", marker = "sys_platform != 'darwin' and extra == 'groot'", specifier = ">=2.5.9,<3.0.0" },
|
||||
{ name = "grpcio", marker = "extra == 'grpcio-dep'", specifier = "==1.73.1" },
|
||||
{ name = "grpcio-tools", marker = "extra == 'dev'", specifier = "==1.73.1" },
|
||||
{ name = "gym-aloha", marker = "extra == 'aloha'", specifier = ">=0.1.2,<0.2.0" },
|
||||
@@ -3132,6 +3121,7 @@ requires-dist = [
|
||||
{ name = "lerobot", extras = ["feetech"], marker = "extra == 'hopejr'" },
|
||||
{ name = "lerobot", extras = ["feetech"], marker = "extra == 'lekiwi'" },
|
||||
{ name = "lerobot", extras = ["gamepad"], marker = "extra == 'all'" },
|
||||
{ name = "lerobot", extras = ["groot"], marker = "extra == 'all'" },
|
||||
{ name = "lerobot", extras = ["grpcio-dep"], marker = "extra == 'async'" },
|
||||
{ name = "lerobot", extras = ["grpcio-dep"], marker = "extra == 'dev'" },
|
||||
{ name = "lerobot", extras = ["grpcio-dep"], marker = "extra == 'hilserl'" },
|
||||
@@ -3224,7 +3214,6 @@ requires-dist = [
|
||||
{ name = "motorbridge", marker = "extra == 'motorbridge-dep'", specifier = ">=0.3.2,<0.4.0" },
|
||||
{ name = "motorbridge-smart-servo", marker = "extra == 'motorbridge-smart-servo-dep'", specifier = ">=0.0.4,<0.1.0" },
|
||||
{ name = "mypy", marker = "extra == 'dev'", specifier = ">=1.19.1" },
|
||||
{ name = "ninja", marker = "extra == 'groot'", specifier = ">=1.11.1,<2.0.0" },
|
||||
{ name = "num2words", marker = "extra == 'smolvla'", specifier = ">=0.5.14,<0.6.0" },
|
||||
{ name = "numpy", specifier = ">=2.0.0,<2.3.0" },
|
||||
{ name = "onnx", marker = "extra == 'unitree-g1'", specifier = ">=1.16.0,<2.0.0" },
|
||||
@@ -4012,32 +4001,6 @@ wheels = [
|
||||
{ url = "https://files.pythonhosted.org/packages/9e/c9/b2622292ea83fbb4ec318f5b9ab867d0a28ab43c5717bb85b0a5f6b3b0a4/networkx-3.6.1-py3-none-any.whl", hash = "sha256:d47fbf302e7d9cbbb9e2555a0d267983d2aa476bac30e90dfbe5669bd57f3762", size = 2068504, upload-time = "2025-12-08T17:02:38.159Z" },
|
||||
]
|
||||
|
||||
[[package]]
|
||||
name = "ninja"
|
||||
version = "1.13.0"
|
||||
source = { registry = "https://pypi.org/simple" }
|
||||
sdist = { url = "https://files.pythonhosted.org/packages/43/73/79a0b22fc731989c708068427579e840a6cf4e937fe7ae5c5d0b7356ac22/ninja-1.13.0.tar.gz", hash = "sha256:4a40ce995ded54d9dc24f8ea37ff3bf62ad192b547f6c7126e7e25045e76f978", size = 242558, upload-time = "2025-08-11T15:10:19.421Z" }
|
||||
wheels = [
|
||||
{ url = "https://files.pythonhosted.org/packages/3c/74/d02409ed2aa865e051b7edda22ad416a39d81a84980f544f8de717cab133/ninja-1.13.0-py3-none-macosx_10_9_universal2.whl", hash = "sha256:fa2a8bfc62e31b08f83127d1613d10821775a0eb334197154c4d6067b7068ff1", size = 310125, upload-time = "2025-08-11T15:09:50.971Z" },
|
||||
{ url = "https://files.pythonhosted.org/packages/8e/de/6e1cd6b84b412ac1ef327b76f0641aeb5dcc01e9d3f9eee0286d0c34fd93/ninja-1.13.0-py3-none-manylinux2014_aarch64.manylinux_2_17_aarch64.whl", hash = "sha256:3d00c692fb717fd511abeb44b8c5d00340c36938c12d6538ba989fe764e79630", size = 177467, upload-time = "2025-08-11T15:09:52.767Z" },
|
||||
{ url = "https://files.pythonhosted.org/packages/c8/83/49320fb6e58ae3c079381e333575fdbcf1cca3506ee160a2dcce775046fa/ninja-1.13.0-py3-none-manylinux2014_i686.manylinux_2_17_i686.whl", hash = "sha256:be7f478ff9f96a128b599a964fc60a6a87b9fa332ee1bd44fa243ac88d50291c", size = 187834, upload-time = "2025-08-11T15:09:54.115Z" },
|
||||
{ url = "https://files.pythonhosted.org/packages/56/c7/ba22748fb59f7f896b609cd3e568d28a0a367a6d953c24c461fe04fc4433/ninja-1.13.0-py3-none-manylinux2014_ppc64le.manylinux_2_17_ppc64le.whl", hash = "sha256:60056592cf495e9a6a4bea3cd178903056ecb0943e4de45a2ea825edb6dc8d3e", size = 202736, upload-time = "2025-08-11T15:09:55.745Z" },
|
||||
{ url = "https://files.pythonhosted.org/packages/79/22/d1de07632b78ac8e6b785f41fa9aad7a978ec8c0a1bf15772def36d77aac/ninja-1.13.0-py3-none-manylinux2014_s390x.manylinux_2_17_s390x.whl", hash = "sha256:1c97223cdda0417f414bf864cfb73b72d8777e57ebb279c5f6de368de0062988", size = 179034, upload-time = "2025-08-11T15:09:57.394Z" },
|
||||
{ url = "https://files.pythonhosted.org/packages/ed/de/0e6edf44d6a04dabd0318a519125ed0415ce437ad5a1ec9b9be03d9048cf/ninja-1.13.0-py3-none-manylinux2014_x86_64.manylinux_2_17_x86_64.whl", hash = "sha256:fb46acf6b93b8dd0322adc3a4945452a4e774b75b91293bafcc7b7f8e6517dfa", size = 180716, upload-time = "2025-08-11T15:09:58.696Z" },
|
||||
{ url = "https://files.pythonhosted.org/packages/54/28/938b562f9057aaa4d6bfbeaa05e81899a47aebb3ba6751e36c027a7f5ff7/ninja-1.13.0-py3-none-manylinux_2_28_armv7l.manylinux_2_31_armv7l.whl", hash = "sha256:4be9c1b082d244b1ad7ef41eb8ab088aae8c109a9f3f0b3e56a252d3e00f42c1", size = 146843, upload-time = "2025-08-11T15:10:00.046Z" },
|
||||
{ url = "https://files.pythonhosted.org/packages/2a/fb/d06a3838de4f8ab866e44ee52a797b5491df823901c54943b2adb0389fbb/ninja-1.13.0-py3-none-manylinux_2_31_riscv64.whl", hash = "sha256:6739d3352073341ad284246f81339a384eec091d9851a886dfa5b00a6d48b3e2", size = 154402, upload-time = "2025-08-11T15:10:01.657Z" },
|
||||
{ url = "https://files.pythonhosted.org/packages/31/bf/0d7808af695ceddc763cf251b84a9892cd7f51622dc8b4c89d5012779f06/ninja-1.13.0-py3-none-musllinux_1_2_aarch64.whl", hash = "sha256:11be2d22027bde06f14c343f01d31446747dbb51e72d00decca2eb99be911e2f", size = 552388, upload-time = "2025-08-11T15:10:03.349Z" },
|
||||
{ url = "https://files.pythonhosted.org/packages/9d/70/c99d0c2c809f992752453cce312848abb3b1607e56d4cd1b6cded317351a/ninja-1.13.0-py3-none-musllinux_1_2_armv7l.whl", hash = "sha256:aa45b4037b313c2f698bc13306239b8b93b4680eb47e287773156ac9e9304714", size = 472501, upload-time = "2025-08-11T15:10:04.735Z" },
|
||||
{ url = "https://files.pythonhosted.org/packages/9f/43/c217b1153f0e499652f5e0766da8523ce3480f0a951039c7af115e224d55/ninja-1.13.0-py3-none-musllinux_1_2_i686.whl", hash = "sha256:5f8e1e8a1a30835eeb51db05cf5a67151ad37542f5a4af2a438e9490915e5b72", size = 638280, upload-time = "2025-08-11T15:10:06.512Z" },
|
||||
{ url = "https://files.pythonhosted.org/packages/8c/45/9151bba2c8d0ae2b6260f71696330590de5850e5574b7b5694dce6023e20/ninja-1.13.0-py3-none-musllinux_1_2_ppc64le.whl", hash = "sha256:3d7d7779d12cb20c6d054c61b702139fd23a7a964ec8f2c823f1ab1b084150db", size = 642420, upload-time = "2025-08-11T15:10:08.35Z" },
|
||||
{ url = "https://files.pythonhosted.org/packages/3c/fb/95752eb635bb8ad27d101d71bef15bc63049de23f299e312878fc21cb2da/ninja-1.13.0-py3-none-musllinux_1_2_riscv64.whl", hash = "sha256:d741a5e6754e0bda767e3274a0f0deeef4807f1fec6c0d7921a0244018926ae5", size = 585106, upload-time = "2025-08-11T15:10:09.818Z" },
|
||||
{ url = "https://files.pythonhosted.org/packages/c1/31/aa56a1a286703800c0cbe39fb4e82811c277772dc8cd084f442dd8e2938a/ninja-1.13.0-py3-none-musllinux_1_2_s390x.whl", hash = "sha256:e8bad11f8a00b64137e9b315b137d8bb6cbf3086fbdc43bf1f90fd33324d2e96", size = 707138, upload-time = "2025-08-11T15:10:11.366Z" },
|
||||
{ url = "https://files.pythonhosted.org/packages/34/6f/5f5a54a1041af945130abdb2b8529cbef0cdcbbf9bcf3f4195378319d29a/ninja-1.13.0-py3-none-musllinux_1_2_x86_64.whl", hash = "sha256:b4f2a072db3c0f944c32793e91532d8948d20d9ab83da9c0c7c15b5768072200", size = 581758, upload-time = "2025-08-11T15:10:13.295Z" },
|
||||
{ url = "https://files.pythonhosted.org/packages/95/97/51359c77527d45943fe7a94d00a3843b81162e6c4244b3579fe8fc54cb9c/ninja-1.13.0-py3-none-win32.whl", hash = "sha256:8cfbb80b4a53456ae8a39f90ae3d7a2129f45ea164f43fadfa15dc38c4aef1c9", size = 267201, upload-time = "2025-08-11T15:10:15.158Z" },
|
||||
{ url = "https://files.pythonhosted.org/packages/29/45/c0adfbfb0b5895aa18cec400c535b4f7ff3e52536e0403602fc1a23f7de9/ninja-1.13.0-py3-none-win_amd64.whl", hash = "sha256:fb8ee8719f8af47fed145cced4a85f0755dd55d45b2bddaf7431fa89803c5f3e", size = 309975, upload-time = "2025-08-11T15:10:16.697Z" },
|
||||
{ url = "https://files.pythonhosted.org/packages/df/93/a7b983643d1253bb223234b5b226e69de6cda02b76cdca7770f684b795f5/ninja-1.13.0-py3-none-win_arm64.whl", hash = "sha256:3c0b40b1f0bba764644385319028650087b4c1b18cdfa6f45cb39a3669b81aa9", size = 290806, upload-time = "2025-08-11T15:10:18.018Z" },
|
||||
]
|
||||
|
||||
[[package]]
|
||||
name = "nodeenv"
|
||||
version = "1.10.0"
|
||||
|
||||
Reference in New Issue
Block a user