mirror of
https://github.com/huggingface/lerobot.git
synced 2026-06-18 08:47:05 +00:00
Compare commits
8 Commits
| Author | SHA1 | Date | |
|---|---|---|---|
| | 226a4c5a8c | | |
| | 05a9ca274b | | |
| | 13ed657056 | | |
| | 559cba212d | | |
| | 378897800a | | |
| | fcb371eddd | | |
| | 895eaf0d7c | | |
| | edda8552ec | | |
@@ -4,6 +4,9 @@ GR00T is an NVIDIA foundation model family for generalized humanoid robot reason

 LeRobot integrates GR00T N1.7 through the `groot` policy type.

+> [!WARNING]
+> **Breaking change:** GR00T N1.5 support was removed from LeRobot, and current releases support GR00T N1.7 only. N1.5 checkpoints, configs, and `--policy.model_version=n1.5` are rejected with a clear error. To keep using an N1.5 checkpoint, pin the last release that supports it: `pip install 'lerobot==0.5.1'`. To use the current release, migrate to GR00T N1.7 (`model_version='n1.7'`, base model [`nvidia/GR00T-N1.7-3B`](https://huggingface.co/nvidia/GR00T-N1.7-3B)).
+
 ## Model Overview

 GR00T N1.7 uses a Cosmos-Reason2/Qwen3-VL backbone and provides checkpoints for SimplerEnv, DROID, and LIBERO.
@@ -133,7 +136,7 @@ Replace the `XX` placeholders with final eval artifacts before merge.
|
|||||||
Download the suite checkpoint locally, then point `--policy.base_model_path` at the downloaded subdirectory. `--policy.path` is reserved for LeRobot checkpoints that contain a LeRobot `config.json` with a `type` field.
|
Download the suite checkpoint locally, then point `--policy.base_model_path` at the downloaded subdirectory. `--policy.path` is reserved for LeRobot checkpoints that contain a LeRobot `config.json` with a `type` field.
|
||||||
|
|
||||||
```bash
|
```bash
|
||||||
huggingface-cli download nvidia/GR00T-N1.7-LIBERO \
|
hf download nvidia/GR00T-N1.7-LIBERO \
|
||||||
--include "libero_spatial/*" \
|
--include "libero_spatial/*" \
|
||||||
--local-dir ./GR00T-N1.7-LIBERO
|
--local-dir ./GR00T-N1.7-LIBERO
|
||||||
|
|
||||||
|
|||||||
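For scripted setups, the same download can be expressed through the `huggingface_hub` Python API. The following is a minimal sketch, assuming `huggingface_hub` is installed and network access is available; the helper names `build_download_kwargs` and `download_suite` are illustrative and not part of LeRobot.

```python
from typing import Any


def build_download_kwargs(suite: str = "libero_spatial") -> dict[str, Any]:
    """Mirror the CLI flags above as keyword arguments for snapshot_download."""
    return {
        "repo_id": "nvidia/GR00T-N1.7-LIBERO",
        "allow_patterns": [f"{suite}/*"],  # same filter as --include
        "local_dir": "./GR00T-N1.7-LIBERO",  # same target as --local-dir
    }


def download_suite(suite: str = "libero_spatial") -> str:
    """Download one suite and return the subdirectory to use as --policy.base_model_path."""
    # Imported lazily so that building the arguments does not require the package.
    from huggingface_hub import snapshot_download

    local_root = snapshot_download(**build_download_kwargs(suite))
    return f"{local_root}/{suite}"
```

Calling `download_suite()` should leave the checkpoint under `./GR00T-N1.7-LIBERO/libero_spatial`, which is the subdirectory the paragraph above asks you to pass as `--policy.base_model_path`.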
@@ -1,6 +1,13 @@
 ## Research Paper

-Paper: https://research.nvidia.com/labs/gear/gr00t-n1_5/
+GR00T N1 technical report (covers the GR00T N1.x family, including N1.7): https://arxiv.org/abs/2503.14734
+
+GR00T N1.7 model card: https://huggingface.co/nvidia/GR00T-N1.7-3B
+
+GR00T N1.5 research page (earlier version): https://research.nvidia.com/labs/gear/gr00t-n1_5/
+
+> GR00T N1.5 support was removed from LeRobot; the last release supporting it is `lerobot==0.5.1`.
+> Current releases support GR00T N1.7 only.

 ## Repository
@@ -31,12 +38,22 @@ Hugging Face Models:
|
|||||||
|
|
||||||
## Original-vs-LeRobot parity test
|
## Original-vs-LeRobot parity test
|
||||||
|
|
||||||
`tests/policies/groot/test_groot_vs_original.py` verifies that this LeRobot
|
`tests/policies/groot/test_groot_vs_original.py` verifies this LeRobot
|
||||||
reimplementation of GR00T N1.7 (Qwen3-VL backbone + flow-matching action head)
|
reimplementation of GR00T N1.7 (Qwen3-VL backbone + flow-matching action head)
|
||||||
produces the **same raw model output** (`get_action(...)["action_pred"]`, the
|
against NVIDIA's original `gr00t` package with two comparisons, each parametrized
|
||||||
normalized flow-matching prediction) as NVIDIA's original `gr00t` package, given
|
over every embodiment tag present in the checkpoint:
|
||||||
byte-identical pre-processed inputs and the same flow-matching seed. It is
|
|
||||||
parametrized over every embodiment tag present in the checkpoint.
|
1. **Model parity** — given byte-identical pre-processed inputs and the same
|
||||||
|
flow-matching seed (recorded in each artifact), both implementations must produce
|
||||||
|
the **same raw model output** (`get_action(...)["action_pred"]`, the normalized
|
||||||
|
flow-matching prediction). Output shapes must match exactly; any action-horizon
|
||||||
|
or action-dim mismatch fails the test.
|
||||||
|
2. **Preprocessor parity** — given the identical raw observations (per-camera
|
||||||
|
frames, state vectors, language instruction), LeRobot's own preprocessor pipeline
|
||||||
|
(real Qwen3-VL chat template / tokenizer / image packing + checkpoint-driven
|
||||||
|
state normalization, no mocks) must produce the **same collated model inputs**
|
||||||
|
(`input_ids`, `attention_mask`, `pixel_values`, `image_grid_thw`, `state`,
|
||||||
|
`embodiment_id`) as the original package's processor.
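Both comparisons reduce to the same check: shapes must match exactly, then values must agree within a tolerance. The sketch below illustrates that check with NumPy; `assert_parity` is a hypothetical helper (not the function used by the test), the `1e-3` defaults follow the documented `GROOT_PARITY_ATOL` / `GROOT_PARITY_RTOL` defaults, and the array shape is a toy stand-in for `action_pred`.

```python
import numpy as np


def assert_parity(reference: np.ndarray, candidate: np.ndarray,
                  atol: float = 1e-3, rtol: float = 1e-3) -> float:
    """Fail on any shape mismatch, then compare values within tolerance.

    Returns the maximum absolute difference so it can be reported.
    """
    if reference.shape != candidate.shape:
        raise AssertionError(
            f"shape mismatch: original {reference.shape} vs LeRobot {candidate.shape}"
        )
    np.testing.assert_allclose(candidate, reference, atol=atol, rtol=rtol)
    return float(np.max(np.abs(candidate - reference))) if reference.size else 0.0


# Toy stand-ins shaped (batch, action_horizon, action_dim).
original = np.zeros((1, 40, 132), dtype=np.float32)
lerobot = original + 1e-5
max_abs_diff = assert_parity(original, lerobot)
```

Checking the shape before the values keeps an action-horizon or action-dim mismatch from being hidden by broadcasting.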
|
||||||
|
|
||||||
### Why two environments
|
### Why two environments
|
||||||
|
|
||||||
@@ -48,25 +65,37 @@ is itself a defaulted dataclass, so the original config dataclasses fail to impo
|
|||||||
|
|
||||||
So the test uses a **producer / consumer** split across two venvs:
|
So the test uses a **producer / consumer** split across two venvs:
|
||||||
|
|
||||||
1. **Producer** — `tests/policies/groot/utils/dump_original_n1_7.py`, run in the *original*
|
1. **Producer** — `tests/policies/groot/utils/dump_original_n1_7.py`, run in the _original_
|
||||||
gr00t venv. For each embodiment it builds dummy inputs generically from the
|
gr00t venv. For each embodiment it builds dummy inputs generically from the
|
||||||
checkpoint metadata (state dims from `statistics.json`; camera/language keys from
|
checkpoint metadata (state dims from `statistics.json`; camera/language keys from
|
||||||
the processor modality configs), runs the original model, and saves the exact
|
the processor modality configs), runs the original model, and saves to one `.npz`
|
||||||
collated inputs + raw `action_pred` to one `.npz` per tag.
|
per tag: the raw observations (`raw::` keys), the exact collated inputs
|
||||||
2. **Consumer** — the pytest above, run in the *LeRobot* venv. It discovers every
|
(`in::` keys), the seed, and the raw `action_pred`.
|
||||||
`.npz`, replays the byte-identical inputs through the LeRobot model with the same
|
2. **Consumer** — the pytest above, run in the _LeRobot_ venv. It discovers every
|
||||||
seed, and asserts the outputs match.
|
`.npz`; the model-parity case replays the byte-identical collated inputs through
|
||||||
|
the LeRobot model with the recorded seed and asserts the outputs match, and the
|
||||||
|
preprocessor-parity case replays the raw observations through LeRobot's full
|
||||||
|
preprocessor pipeline and asserts the collated tensors match.
|
||||||
|
|
||||||
|
> Artifacts generated by older versions of the dump script contain no `raw::`
|
||||||
|
> fields; the preprocessor-parity case then **skips** with a regeneration hint.
|
||||||
|
> Re-run the producer to refresh them.
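The skip rule for older artifacts can be pictured as a check on key prefixes. This is a sketch under stated assumptions: only the `raw::` and `in::` prefixes come from the description above, while the key names after the prefix, the `seed` and `action_pred` key spellings, and both helper functions are hypothetical.

```python
def split_artifact_keys(keys):
    """Group artifact keys by the `raw::` / `in::` prefixes described above."""
    raw = sorted(k for k in keys if k.startswith("raw::"))
    collated = sorted(k for k in keys if k.startswith("in::"))
    return raw, collated


def preprocessor_case_action(keys):
    """Return "run" when raw observations are present, otherwise a skip reason."""
    raw, _ = split_artifact_keys(keys)
    if not raw:
        return "skip: artifact has no raw:: fields; re-run the producer to regenerate it"
    return "run"


# Hypothetical key listings for an old and a refreshed artifact.
old_artifact = ["in::input_ids", "in::pixel_values", "seed", "action_pred"]
new_artifact = old_artifact + ["raw::state", "raw::language"]
old_decision = preprocessor_case_action(old_artifact)
new_decision = preprocessor_case_action(new_artifact)
```

With a real artifact, the key list would come from the loaded `.npz` file rather than a hand-written list.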
|
||||||
|
|
||||||
### Fairness controls
|
### Fairness controls
|
||||||
|
|
||||||
- **Same pre-processed inputs** — the original processor's `input_ids`,
|
- **Same pre-processed inputs (model parity)** — the original processor's `input_ids`,
|
||||||
`pixel_values`, `image_grid_thw`, `attention_mask`, `state`, `embodiment_id` are
|
`pixel_values`, `image_grid_thw`, `attention_mask`, `state`, `embodiment_id` are
|
||||||
fed verbatim to the LeRobot model (no re-tokenization / re-normalization).
|
fed verbatim to the LeRobot model (no re-tokenization / re-normalization), so the
|
||||||
|
model comparison isolates the model. LeRobot's own tokenization / image packing is
|
||||||
|
covered separately by the preprocessor-parity case, which compares its output
|
||||||
|
against those same collated tensors from identical raw observations.
|
||||||
- **Same precision + attention kernel** — both sides run **fp32 + SDPA**. The
|
- **Same precision + attention kernel** — both sides run **fp32 + SDPA**. The
|
||||||
original defaults to `use_flash_attention=True` (flash_attention_2 + bf16); the
|
original defaults to `use_flash_attention=True` (flash_attention_2 + bf16); the
|
||||||
producer forces SDPA + fp32. (With the defaults the gap is ~3e-2 — pure
|
producer forces SDPA + fp32. (With the defaults the gap is ~3e-2 — pure
|
||||||
kernel/rounding noise, not an implementation difference.)
|
kernel/rounding noise, not an implementation difference.)
|
||||||
- **Same flow-matching seed** — fixed (42) right before sampling on both sides.
|
- **Same flow-matching seed** — fixed right before sampling on both sides; the
|
||||||
|
producer records it in each artifact (`--seed`, default 42) and the consumer
|
||||||
|
replays the recorded value.
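Why the seed has to be fixed immediately before sampling, and replayed from the artifact, can be shown with a toy sampler. This sketch uses Python's `random` module purely as a stand-in for the framework generator; `sample_initial_noise` is a hypothetical function and the shapes are arbitrary.

```python
import random


def sample_initial_noise(seed: int, horizon: int = 4, action_dim: int = 3) -> list[list[float]]:
    """Seed right before sampling, then draw the Gaussian noise a sampler would start from."""
    rng = random.Random(seed)  # fresh generator, so earlier random draws cannot shift the stream
    return [[rng.gauss(0.0, 1.0) for _ in range(action_dim)] for _ in range(horizon)]


recorded_seed = 42  # default of the producer's --seed flag
producer_noise = sample_initial_noise(recorded_seed)
consumer_noise = sample_initial_noise(recorded_seed)  # consumer replays the recorded value
other_noise = sample_initial_noise(recorded_seed + 1)
```

Identical seeds give identical starting noise, so any remaining output difference can be attributed to the model implementation rather than to sampling.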
|
||||||
|
|
||||||
### How to run
|
### How to run
|
||||||
|
|
||||||
@@ -90,15 +119,15 @@ CUDA_VISIBLE_DEVICES=0 GROOT_PARITY_DEVICE=cuda \
|
|||||||
uv run pytest tests/policies/groot/test_groot_vs_original.py -v -s
|
uv run pytest tests/policies/groot/test_groot_vs_original.py -v -s
|
||||||
```
|
```
|
||||||
|
|
||||||
The `.npz` artifacts are local-only (gitignored, ~6–9 MB each) and are regenerated by
|
The `.npz` artifacts are local-only (gitignored, ~6–10 MB each) and are regenerated by
|
||||||
the producer; they are never committed. The test **skips** (does not fail) on CI or
|
the producer; they are never committed. The tests **skip** (do not fail) on CI or
|
||||||
when the checkpoint / artifacts are absent.
|
when the checkpoint / artifacts are absent.
|
||||||
|
|
||||||
#### Env knobs (all optional)
|
#### Env knobs (all optional)
|
||||||
|
|
||||||
| Var | Default | Purpose |
|
| Var | Default | Purpose |
|
||||||
|---|---|---|
|
| ----------------------------------------- | -------------------------------- | ------------------------------------- |
|
||||||
| `GROOT_N1_7_PARITY_DIR` | `tests/policies/groot/artifacts` | directory of per-tag `.npz` artifacts |
|
| `GROOT_N1_7_PARITY_DIR` | `tests/policies/groot/artifacts` | directory of per-tag `.npz` artifacts |
|
||||||
| `GROOT_N1_7_LIBERO_CKPT` | auto (HF cache) | override checkpoint dir |
|
| `GROOT_N1_7_LIBERO_CKPT` | auto (HF cache) | override checkpoint dir |
|
||||||
| `GROOT_PARITY_DEVICE` | `cuda` if available | `cpu` or `cuda` |
|
| `GROOT_PARITY_DEVICE` | `cuda` if available | `cpu` or `cuda` |
|
||||||
| `GROOT_PARITY_ATOL` / `GROOT_PARITY_RTOL` | `1e-3` | comparison tolerance |
|
| `GROOT_PARITY_ATOL` / `GROOT_PARITY_RTOL` | `1e-3` | comparison tolerance |
|
||||||
|
|||||||
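The knobs in the table can be read with plain `os.environ` lookups. The following is a minimal sketch, not the test's actual code: `ParityEnv` and `read_parity_env` are hypothetical names, and falling back to `cpu` when CUDA is unavailable is an assumption, since the table only states "`cuda` if available".

```python
import os
from dataclasses import dataclass
from typing import Mapping, Optional


@dataclass(frozen=True)
class ParityEnv:
    artifact_dir: str
    checkpoint_dir: Optional[str]  # None means "auto-resolve from the HF cache"
    device: str
    atol: float
    rtol: float


def read_parity_env(env: Mapping[str, str] = os.environ, cuda_available: bool = False) -> ParityEnv:
    """Resolve the optional knobs, falling back to the documented defaults."""
    return ParityEnv(
        artifact_dir=env.get("GROOT_N1_7_PARITY_DIR", "tests/policies/groot/artifacts"),
        checkpoint_dir=env.get("GROOT_N1_7_LIBERO_CKPT"),
        device=env.get("GROOT_PARITY_DEVICE", "cuda" if cuda_available else "cpu"),
        atol=float(env.get("GROOT_PARITY_ATOL", "1e-3")),
        rtol=float(env.get("GROOT_PARITY_RTOL", "1e-3")),
    )
```

Passing the environment as a mapping keeps the helper easy to exercise without mutating the real process environment.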
@@ -14,6 +14,7 @@
 # limitations under the License.


+import logging
 from typing import TYPE_CHECKING

 import torch
@@ -42,6 +43,9 @@ else:
     Timesteps = None


+logger = logging.getLogger(__name__)
+
+
 class TimestepEncoder(nn.Module):
     def __init__(self, embedding_dim, compute_dtype=torch.float32):
         require_package("diffusers", extra="groot")
@@ -265,8 +269,8 @@ class DiT(ModelMixin, ConfigMixin):
         self.norm_out = nn.LayerNorm(self.inner_dim, elementwise_affine=False, eps=1e-6)
         self.proj_out_1 = nn.Linear(self.inner_dim, 2 * self.inner_dim)
         self.proj_out_2 = nn.Linear(self.inner_dim, self.config.output_dim)
-        print(
-            "Total number of DiT parameters: ",
+        logger.debug(
+            "Total number of DiT parameters: %d",
             sum(p.numel() for p in self.parameters() if p.requires_grad),
         )
@@ -426,8 +430,8 @@ class SelfAttentionTransformer(ModelMixin, ConfigMixin):
                 for _ in range(self.config.num_layers)
             ]
         )
-        print(
-            "Total number of SelfAttentionTransformer parameters: ",
+        logger.debug(
+            "Total number of SelfAttentionTransformer parameters: %d",
             sum(p.numel() for p in self.parameters() if p.requires_grad),
         )
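The switch from `print` to `logger.debug` with a `%d` placeholder means the parameter count is only emitted when the logger is enabled for DEBUG. The sketch below shows that behavior with the standard `logging` module; the logger name, the list-collecting handler, and the placeholder parameter count are all illustrative.

```python
import logging

logger = logging.getLogger("groot_demo")  # hypothetical logger name for this sketch


class _ListHandler(logging.Handler):
    """Collect formatted messages so the effect of the log level is visible."""

    def __init__(self) -> None:
        super().__init__()
        self.messages: list[str] = []

    def emit(self, record: logging.LogRecord) -> None:
        self.messages.append(record.getMessage())


handler = _ListHandler()
logger.addHandler(handler)
logger.propagate = False  # keep the demo output out of the root logger

num_params = 1_234_567  # placeholder for sum(p.numel() for p in model.parameters())

logger.setLevel(logging.INFO)
logger.debug("Total number of DiT parameters: %d", num_params)  # dropped at INFO

logger.setLevel(logging.DEBUG)
logger.debug("Total number of DiT parameters: %d", num_params)  # emitted at DEBUG
```

Passing the value as an argument instead of pre-formatting the string lets `logging` skip the formatting entirely when the record is filtered out.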
@@ -15,6 +15,7 @@
 # limitations under the License.

 import json
+import logging
 import os
 from dataclasses import dataclass, field
 from pathlib import Path
@@ -23,15 +24,37 @@ from lerobot.configs import FeatureType, NormalizationMode, PolicyFeature, PreTr
|
|||||||
from lerobot.optim import AdamWConfig, CosineDecayWithWarmupSchedulerConfig
|
from lerobot.optim import AdamWConfig, CosineDecayWithWarmupSchedulerConfig
|
||||||
from lerobot.utils.constants import ACTION, OBS_STATE
|
from lerobot.utils.constants import ACTION, OBS_STATE
|
||||||
|
|
||||||
|
logger = logging.getLogger(__name__)
|
||||||
|
|
||||||
GROOT_N1_7 = "n1.7"
|
GROOT_N1_7 = "n1.7"
|
||||||
# Legacy GR00T N1.5 identifier. N1.5 is NOT a supported model_version (it is
|
# Legacy GR00T N1.5 identifier. N1.5 is NOT a supported model_version (it is
|
||||||
# intentionally absent from _GROOT_MODEL_VERSION_ALIASES so normalize_groot_model_version
|
# intentionally absent from _GROOT_MODEL_VERSION_ALIASES so normalize_groot_model_version
|
||||||
# still rejects it). It is retained only so that infer_groot_model_version can recognise
|
# still rejects it). It is retained only so that infer_groot_model_version can recognise
|
||||||
# an N1.5 base path/checkpoint and the N1.7 config/loader can reject the mismatch.
|
# an N1.5 base path/checkpoint and the N1.7 config/loader can reject the mismatch.
|
||||||
GROOT_N1_5 = "n1.5"
|
GROOT_N1_5 = "n1.5"
|
||||||
|
# Canonical guidance appended to every error raised when an N1.5 checkpoint, config,
|
||||||
|
# or processor pipeline is detected. Keep this message in sync with docs/source/groot.mdx.
|
||||||
|
GROOT_N1_5_REMOVAL_GUIDANCE = (
|
||||||
|
"GR00T N1.5 support was removed from LeRobot. "
|
||||||
|
"To keep using an N1.5 checkpoint, pin the last release that supports it: "
|
||||||
|
"`pip install 'lerobot==0.5.1'`. To use the current release, migrate to GR00T N1.7 "
|
||||||
|
"(model_version='n1.7', base model nvidia/GR00T-N1.7-3B)."
|
||||||
|
)
|
||||||
GROOT_N1_7_BASE_MODEL = "nvidia/GR00T-N1.7-3B"
|
GROOT_N1_7_BASE_MODEL = "nvidia/GR00T-N1.7-3B"
|
||||||
GROOT_N1_7_BACKBONE_MODEL = "nvidia/Cosmos-Reason2-2B"
|
GROOT_N1_7_BACKBONE_MODEL = "nvidia/Cosmos-Reason2-2B"
|
||||||
|
# Image preprocessing geometry the GR00T N1.7 backbone was trained on. The processor
|
||||||
|
# falls back to these when a checkpoint ships no image sizing in its processor_config
|
||||||
|
# (e.g. fine-tuning the raw nvidia/GR00T-N1.7-3B base with a new embodiment), so frames
|
||||||
|
# are resized to the expected resolution instead of being patchified at full camera
|
||||||
|
# resolution (which both slows training and is a train/checkpoint distribution mismatch).
|
||||||
|
# Mirrored by GR00T_N1_7_DEFAULTS in groot_n1_7.py.
|
||||||
|
N1_7_DEFAULT_IMAGE_TARGET_SIZE = (256, 256)
|
||||||
|
N1_7_DEFAULT_IMAGE_CROP_SIZE = (230, 230)
|
||||||
GROOT_ACTION_DECODE_TRANSFORM_LIBERO = "libero"
|
GROOT_ACTION_DECODE_TRANSFORM_LIBERO = "libero"
|
||||||
|
# Sentinel meaning "the user did not pick an action decode transform": __post_init__ resolves it
|
||||||
|
# to the embodiment default ('libero' for 'libero_sim', otherwise None). It is distinct from an
|
||||||
|
# explicit 'none' (resolved to None) so an opt-out survives a draccus save/load round-trip.
|
||||||
|
GROOT_ACTION_DECODE_TRANSFORM_AUTO = "auto"
|
||||||
|
|
||||||
_GROOT_MODEL_VERSION_ALIASES = {
|
_GROOT_MODEL_VERSION_ALIASES = {
|
||||||
"n1.7": GROOT_N1_7,
|
"n1.7": GROOT_N1_7,
|
||||||
@@ -41,7 +64,12 @@ _GROOT_MODEL_VERSION_ALIASES = {
|
|||||||
"1.7": GROOT_N1_7,
|
"1.7": GROOT_N1_7,
|
||||||
}
|
}
|
||||||
|
|
||||||
|
# Legacy N1.5 spellings, kept ONLY so they can be detected and rejected with
|
||||||
|
# GROOT_N1_5_REMOVAL_GUIDANCE (see GROOT_N1_5 above). Never map these to a supported version.
|
||||||
|
_GROOT_N1_5_VERSION_ALIASES = {"n1.5", "n1_5", "n1d5", "n15", "1.5"}
|
||||||
|
|
||||||
_GROOT_ACTION_DECODE_TRANSFORM_ALIASES = {
|
_GROOT_ACTION_DECODE_TRANSFORM_ALIASES = {
|
||||||
|
GROOT_ACTION_DECODE_TRANSFORM_AUTO: GROOT_ACTION_DECODE_TRANSFORM_AUTO,
|
||||||
"none": None,
|
"none": None,
|
||||||
"": None,
|
"": None,
|
||||||
GROOT_ACTION_DECODE_TRANSFORM_LIBERO: GROOT_ACTION_DECODE_TRANSFORM_LIBERO,
|
GROOT_ACTION_DECODE_TRANSFORM_LIBERO: GROOT_ACTION_DECODE_TRANSFORM_LIBERO,
|
||||||
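The `auto` sentinel described in the comments above can be sketched as a two-step resolution: normalize the alias, then resolve `auto` per embodiment. This is an illustrative reimplementation, not the library function; the alias table contains only the entries visible in the diff, and treating `None` like `'none'` and lowercasing the input are assumptions.

```python
AUTO = "auto"
LIBERO = "libero"

# Only the aliases visible in the diff; the real table may contain more entries.
_ALIASES = {AUTO: AUTO, "none": None, "": None, LIBERO: LIBERO}


def resolve_action_decode_transform(value, embodiment_tag):
    """Normalize the user's choice, then resolve the 'auto' sentinel per embodiment."""
    key = "none" if value is None else value.lower()  # assumption: None behaves like 'none'
    if key not in _ALIASES:
        raise ValueError(f"Unsupported action_decode_transform '{value}'.")
    normalized = _ALIASES[key]
    if normalized == AUTO:
        # Embodiment default: 'libero' for 'libero_sim', otherwise no transform.
        return LIBERO if embodiment_tag == "libero_sim" else None
    return normalized  # an explicit 'none' stays disabled, even for 'libero_sim'
```

Keeping `auto` distinct from `none` is what lets an explicit opt-out survive a save/load round-trip: a saved `None` is never reinterpreted as "not chosen yet".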
@@ -52,9 +80,10 @@ def normalize_groot_model_version(model_version: str) -> str:
     normalized = _GROOT_MODEL_VERSION_ALIASES.get(model_version.lower())
     if normalized is None:
         supported = GROOT_N1_7
-        raise ValueError(
-            f"Unsupported GR00T model_version '{model_version}'. Supported versions: {supported}."
-        )
+        message = f"Unsupported GR00T model_version '{model_version}'. Supported versions: {supported}."
+        if model_version.lower() in _GROOT_N1_5_VERSION_ALIASES:
+            message = f"{message} {GROOT_N1_5_REMOVAL_GUIDANCE}"
+        raise ValueError(message)
     return normalized
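The effect of this change is easiest to see from the caller's side: an unknown version gets the generic error, while any N1.5 spelling additionally gets the migration guidance. The sketch below is a standalone copy for illustration, with shortened names and an abbreviated alias table (only the spellings visible in the diff); `error_for` is a hypothetical helper.

```python
GROOT_N1_7 = "n1.7"
# Abbreviated: only the alias spellings visible in the diff.
_VERSION_ALIASES = {"n1.7": GROOT_N1_7, "1.7": GROOT_N1_7}
_N1_5_ALIASES = {"n1.5", "n1_5", "n1d5", "n15", "1.5"}
REMOVAL_GUIDANCE = (
    "GR00T N1.5 support was removed from LeRobot. "
    "To keep using an N1.5 checkpoint, pin the last release that supports it: "
    "`pip install 'lerobot==0.5.1'`. To use the current release, migrate to GR00T N1.7 "
    "(model_version='n1.7', base model nvidia/GR00T-N1.7-3B)."
)


def normalize_model_version(model_version: str) -> str:
    normalized = _VERSION_ALIASES.get(model_version.lower())
    if normalized is None:
        message = f"Unsupported GR00T model_version '{model_version}'. Supported versions: {GROOT_N1_7}."
        if model_version.lower() in _N1_5_ALIASES:
            message = f"{message} {REMOVAL_GUIDANCE}"  # legacy spelling: add the pin/migrate hint
        raise ValueError(message)
    return normalized


def error_for(version: str) -> str:
    """Return the error text for a version, or an empty string if it is accepted."""
    try:
        normalize_model_version(version)
    except ValueError as exc:
        return str(exc)
    return ""
```

Here `error_for("n1.5")` carries the `lerobot==0.5.1` pin hint, whereas `error_for("n2.0")` contains only the generic "Unsupported" message.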
@@ -286,6 +315,8 @@ def _infer_groot_model_version_from_local_config(model_path: str) -> str | None:
 def _infer_groot_model_version_from_config(config: dict) -> str | None:
     model_version = config.get("model_version")
     if isinstance(model_version, str):
+        if model_version.lower() in _GROOT_N1_5_VERSION_ALIASES:
+            return GROOT_N1_5
         try:
             return normalize_groot_model_version(model_version)
         except ValueError:
@@ -298,8 +329,14 @@ def _infer_groot_model_version_from_config(config: dict) -> str | None:
|
|||||||
normalized = candidate.lower().replace("-", "_")
|
normalized = candidate.lower().replace("-", "_")
|
||||||
if normalized in {"gr00tn1d7", "gr00t_n1d7", "gr00t_n1_7"}:
|
if normalized in {"gr00tn1d7", "gr00t_n1d7", "gr00t_n1_7"}:
|
||||||
return GROOT_N1_7
|
return GROOT_N1_7
|
||||||
|
if normalized in {"gr00t_n1_5", "gr00tn1_5", "gr00t_n15", "gr00t_n1d5", "gr00tn1d5"}:
|
||||||
|
return GROOT_N1_5
|
||||||
if config.get("model_name") == GROOT_N1_7_BACKBONE_MODEL:
|
if config.get("model_name") == GROOT_N1_7_BACKBONE_MODEL:
|
||||||
return GROOT_N1_7
|
return GROOT_N1_7
|
||||||
|
# The Eagle VLM backbone is specific to pre-N1.7 GR00T checkpoints (N1.7 uses Cosmos/Qwen3-VL).
|
||||||
|
backbone_cfg = config.get("backbone_cfg")
|
||||||
|
if isinstance(backbone_cfg, dict) and "eagle_path" in backbone_cfg:
|
||||||
|
return GROOT_N1_5
|
||||||
return None
|
return None
|
||||||
|
|
||||||
|
|
||||||
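Taken together, the two hunks above let the config-based inference recognise an N1.5 checkpoint instead of returning `None`. The sketch below reproduces only the checks visible in the diff (the class-name candidate matching is omitted because its inputs are not shown); `infer_version` is a hypothetical name and the N1.7 alias set is abbreviated.

```python
GROOT_N1_7, GROOT_N1_5 = "n1.7", "n1.5"
BACKBONE_MODEL = "nvidia/Cosmos-Reason2-2B"
_N1_5_ALIASES = {"n1.5", "n1_5", "n1d5", "n15", "1.5"}
_N1_7_ALIASES = {"n1.7", "1.7"}  # abbreviated


def infer_version(config: dict):
    """Best-effort guess of the GR00T version from a checkpoint config dict."""
    model_version = config.get("model_version")
    if isinstance(model_version, str):
        lowered = model_version.lower()
        if lowered in _N1_5_ALIASES:
            return GROOT_N1_5
        if lowered in _N1_7_ALIASES:
            return GROOT_N1_7
    if config.get("model_name") == BACKBONE_MODEL:
        return GROOT_N1_7
    # The Eagle VLM backbone only appears in pre-N1.7 checkpoints.
    backbone_cfg = config.get("backbone_cfg")
    if isinstance(backbone_cfg, dict) and "eagle_path" in backbone_cfg:
        return GROOT_N1_5
    return None
```

Returning `"n1.5"` rather than `None` is what allows the caller to raise a version-mismatch error with the removal guidance instead of failing later during loading.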
@@ -310,29 +347,30 @@ class GrootConfig(PreTrainedConfig):
|
|||||||
|
|
||||||
# Basic policy settings
|
# Basic policy settings
|
||||||
n_obs_steps: int = 1
|
n_obs_steps: int = 1
|
||||||
chunk_size: int = 50
|
chunk_size: int = 40
|
||||||
n_action_steps: int = 50
|
n_action_steps: int = 40
|
||||||
|
|
||||||
# Dimension settings (must match pretrained GR00T model expectations)
|
# Dimension settings (must match pretrained GR00T model expectations)
|
||||||
# Maximum state dimension. Shorter states will be zero-padded.
|
# Maximum state dimension. Shorter states will be zero-padded.
|
||||||
max_state_dim: int = 64
|
max_state_dim: int = 132
|
||||||
|
|
||||||
# Maximum action dimension. Shorter actions will be zero-padded.
|
# Maximum action dimension. Shorter actions will be zero-padded.
|
||||||
max_action_dim: int = 32
|
max_action_dim: int = 132
|
||||||
|
|
||||||
# Normalization (start with identity, adjust as needed)
|
# GR00T normalizes state/action internally in its processor steps (min/max with
|
||||||
|
# q01/q99 percentiles, per embodiment), and the Qwen3-VL backbone's image processor
|
||||||
|
# handles image normalization. The policy therefore does NOT use LeRobot's
|
||||||
|
# NormalizerProcessorStep/UnnormalizerProcessorStep, so this mapping is intentionally
|
||||||
|
# IDENTITY for every feature and is not consulted by make_groot_pre_post_processors.
|
||||||
normalization_mapping: dict[str, NormalizationMode] = field(
|
normalization_mapping: dict[str, NormalizationMode] = field(
|
||||||
default_factory=lambda: {
|
default_factory=lambda: {
|
||||||
"VISUAL": NormalizationMode.IDENTITY,
|
"VISUAL": NormalizationMode.IDENTITY,
|
||||||
"STATE": NormalizationMode.MEAN_STD,
|
"STATE": NormalizationMode.IDENTITY,
|
||||||
"ACTION": NormalizationMode.MEAN_STD,
|
"ACTION": NormalizationMode.IDENTITY,
|
||||||
}
|
}
|
||||||
)
|
)
|
||||||
|
|
||||||
# Image preprocessing (adjust to match Groot's expected input)
|
# Groot-specific model parameters
|
||||||
image_size: tuple[int, int] = (224, 224)
|
|
||||||
|
|
||||||
# Groot-specific model parameters (from groot_finetune_script.py)
|
|
||||||
|
|
||||||
# Explicit GR00T model family selection. LeRobot supports GR00T N1.7 only.
|
# Explicit GR00T model family selection. LeRobot supports GR00T N1.7 only.
|
||||||
model_version: str = GROOT_N1_7
|
model_version: str = GROOT_N1_7
|
||||||
@@ -344,11 +382,47 @@ class GrootConfig(PreTrainedConfig):
|
|||||||
n1_7_backbone_model: str = GROOT_N1_7_BACKBONE_MODEL
|
n1_7_backbone_model: str = GROOT_N1_7_BACKBONE_MODEL
|
||||||
|
|
||||||
# Optional named action transform applied after raw N1.7 checkpoint decoding and before env.step().
|
# Optional named action transform applied after raw N1.7 checkpoint decoding and before env.step().
|
||||||
action_decode_transform: str | None = None
|
# 'auto' (default) resolves to the embodiment default ('libero' for 'libero_sim', otherwise no
|
||||||
|
# transform). Pass 'none' to explicitly disable the transform, including for 'libero_sim'.
|
||||||
|
action_decode_transform: str | None = GROOT_ACTION_DECODE_TRANSFORM_AUTO
|
||||||
|
|
||||||
# Embodiment tag to use for training (e.g. 'new_embodiment', 'gr1')
|
# Embodiment tag to use for training (e.g. 'new_embodiment', 'gr1')
|
||||||
embodiment_tag: str = "new_embodiment"
|
embodiment_tag: str = "new_embodiment"
|
||||||
|
|
||||||
|
# Inference-only override for the number of flow-matching denoising steps used to decode an
|
||||||
|
# action chunk. None = use the model checkpoint default (currently 4). Higher values trade
|
||||||
|
# inference speed for action quality; applied at base-model load via _create_groot_model.
|
||||||
|
num_inference_timesteps: int | None = None
|
||||||
|
|
||||||
|
# If set, caps the number of open-loop actions executed before replanning (inference cadence).
|
||||||
|
# Overrides the value inferred from the checkpoint/embodiment in _resolve_action_queue_steps.
|
||||||
|
execution_horizon: int | None = None
|
||||||
|
|
||||||
|
# Opt-in. Copy a pretrained embodiment category slot's action-head weights into the target
|
||||||
|
# embodiment slot at base-model build (in _create_groot_model), to warm-start a cold
|
||||||
|
# 'new_embodiment' slot. Accepts an embodiment name (e.g.
|
||||||
|
# 'oxe_droid_relative_eef_relative_joint') or an int embodiment id. Runs on every fresh
|
||||||
|
# base-model build (so it applies during lerobot-train, which uses __init__ not
|
||||||
|
# from_pretrained); on a fine-tuned checkpoint reload it is harmlessly overwritten.
|
||||||
|
warm_start_embodiment_slot: int | str | None = None
|
||||||
|
|
||||||
|
# Opt-in relative-action support for the 'new_embodiment' slot (sync-safe, GR00T-native).
|
||||||
|
# When True, GR00T converts absolute->relative inside its own pack step (training) and
|
||||||
|
# reconstructs absolute inside its own flat decode step (inference), using a cached
|
||||||
|
# reference state. The dataset stays absolute; compute relative ACTION stats with
|
||||||
|
# `lerobot-edit-dataset --operation.relative_action true --operation.relative_exclude_joints
|
||||||
|
# "['gripper']"` (this only rewrites stats, not actions).
|
||||||
|
use_relative_actions: bool = False
|
||||||
|
|
||||||
|
# Joint names kept absolute (not converted to relative) when use_relative_actions is True.
|
||||||
|
# Case-insensitive token match against action_feature_names.
|
||||||
|
relative_exclude_joints: list[str] = field(default_factory=lambda: ["gripper"])
|
||||||
|
|
||||||
|
# Action dimension names from dataset metadata; auto-populated by the factory from dataset
|
||||||
|
# meta (see factory.py:528). Used to build the relative-action mask so the gripper can be
|
||||||
|
# identified and kept absolute. When None, the gripper cannot be identified.
|
||||||
|
action_feature_names: list[str] | None = None
|
||||||
|
|
||||||
# Fine-tuning control arguments
|
# Fine-tuning control arguments
|
||||||
|
|
||||||
# Whether to fine-tune the llm backbone
|
# Whether to fine-tune the llm backbone
|
||||||
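The interaction between `action_feature_names` and `relative_exclude_joints` can be illustrated with a small mask builder. This is a sketch under assumptions: "token match" is interpreted here as splitting each feature name on non-alphanumeric characters, returning `None` when no names are available is a guess at the fallback, and `build_relative_mask` and the feature names are hypothetical.

```python
import re


def build_relative_mask(action_feature_names, relative_exclude_joints=("gripper",)):
    """True = convert this action dimension to relative; False = keep it absolute."""
    if action_feature_names is None:
        return None  # without names the gripper cannot be identified
    excluded = {name.lower() for name in relative_exclude_joints}
    mask = []
    for feature in action_feature_names:
        # Case-insensitive token match: "Gripper.pos" -> {"gripper", "pos"}
        tokens = set(re.split(r"[^a-z0-9]+", feature.lower()))
        mask.append(tokens.isdisjoint(excluded))
    return mask


names = ["shoulder_pan.pos", "elbow_flex.pos", "Gripper.pos"]
mask = build_relative_mask(names)
```

For the three names above the mask is `[True, True, False]`: the arm joints become relative while the gripper stays absolute.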
@@ -384,17 +458,16 @@ class GrootConfig(PreTrainedConfig):
|
|||||||
warmup_ratio: float = 0.05
|
warmup_ratio: float = 0.05
|
||||||
use_bf16: bool = True
|
use_bf16: bool = True
|
||||||
|
|
||||||
# Dataset parameters
|
# TODO(Steven): Remove these deprecated fields in a future release.
|
||||||
# Video backend to use for training ('decord' or 'torchvision_av')
|
# Deprecated Isaac-GR00T runner/N1.5 fields below — unused by the LeRobot N1.7 implementation
|
||||||
|
# (nothing in src/lerobot reads them). They are kept only so config.json files saved by
|
||||||
|
# earlier lerobot releases still parse: draccus rejects unknown fields, so removing them
|
||||||
|
# would break every previously saved groot checkpoint at config-load time.
|
||||||
|
image_size: tuple[int, int] = (256, 256) # image sizing is handled by the backbone's image processor.
|
||||||
|
tokenizer_assets_repo: str | None = None
|
||||||
video_backend: str = "decord"
|
video_backend: str = "decord"
|
||||||
|
|
||||||
# Whether to balance dataset weights in mixture datasets
|
|
||||||
balance_dataset_weights: bool = True
|
balance_dataset_weights: bool = True
|
||||||
|
|
||||||
# Whether to sample trajectories weighted by their length
|
|
||||||
balance_trajectory_weights: bool = True
|
balance_trajectory_weights: bool = True
|
||||||
|
|
||||||
# Optional dataset paths for delegating training to Isaac-GR00T runner
|
|
||||||
dataset_paths: list[str] | None = None
|
dataset_paths: list[str] | None = None
|
||||||
output_dir: str = "./tmp/gr00t"
|
output_dir: str = "./tmp/gr00t"
|
||||||
save_steps: int = 1000
|
save_steps: int = 1000
|
||||||
@@ -405,6 +478,12 @@ class GrootConfig(PreTrainedConfig):
|
|||||||
resume: bool = False
|
resume: bool = False
|
||||||
|
|
||||||
def __post_init__(self):
|
def __post_init__(self):
|
||||||
|
if self.tokenizer_assets_repo is not None:
|
||||||
|
raise ValueError(
|
||||||
|
"Config sets 'tokenizer_assets_repo', which only existed for GR00T N1.5; this looks "
|
||||||
|
f"like a legacy GR00T N1.5 checkpoint or config. {GROOT_N1_5_REMOVAL_GUIDANCE}"
|
||||||
|
)
|
||||||
|
|
||||||
self.model_version = normalize_groot_model_version(self.model_version)
|
self.model_version = normalize_groot_model_version(self.model_version)
|
||||||
self.action_decode_transform = normalize_groot_action_decode_transform(self.action_decode_transform)
|
self.action_decode_transform = normalize_groot_action_decode_transform(self.action_decode_transform)
|
||||||
if self.base_model_path is None:
|
if self.base_model_path is None:
|
||||||
@@ -416,26 +495,48 @@ class GrootConfig(PreTrainedConfig):
|
|||||||
# 'libero_sim' embodiment grasps correctly instead of scoring 0% success.
|
# 'libero_sim' embodiment grasps correctly instead of scoring 0% success.
|
||||||
# This matches the embodiment-specific handling already done for the
|
# This matches the embodiment-specific handling already done for the
|
||||||
# action execution horizon (see infer_groot_n1_7_action_execution_horizon).
|
# action execution horizon (see infer_groot_n1_7_action_execution_horizon).
|
||||||
if self.action_decode_transform is None and self.embodiment_tag == "libero_sim":
|
# Only the 'auto' sentinel resolves to the embodiment default; an explicit
|
||||||
self.action_decode_transform = GROOT_ACTION_DECODE_TRANSFORM_LIBERO
|
# 'none' (normalized to None above) keeps the transform disabled.
|
||||||
|
if self.action_decode_transform == GROOT_ACTION_DECODE_TRANSFORM_AUTO:
|
||||||
|
self.action_decode_transform = (
|
||||||
|
GROOT_ACTION_DECODE_TRANSFORM_LIBERO if self.embodiment_tag == "libero_sim" else None
|
||||||
|
)
|
||||||
|
|
||||||
if self.max_state_dim == 64:
|
# GR00T N1.5-era default values (e.g. --policy.chunk_size=50 from old commands or
|
||||||
self.max_state_dim = 132
|
# stale configs) are migrated to the values the N1.7 checkpoints expect, with a
|
||||||
if self.max_action_dim == 32:
|
# warning. The dataclass defaults are already the N1.7 values, so a plain
|
||||||
self.max_action_dim = 132
|
# GrootConfig() never triggers this.
|
||||||
if self.chunk_size == 50:
|
legacy_default_remaps = (
|
||||||
self.chunk_size = 40
|
("max_state_dim", 64, 132),
|
||||||
if self.n_action_steps == 50:
|
("max_action_dim", 32, 132),
|
||||||
self.n_action_steps = 40
|
("chunk_size", 50, 40),
|
||||||
if tuple(self.image_size) == (224, 224):
|
("n_action_steps", 50, 40),
|
||||||
self.image_size = (256, 256)
|
("image_size", (224, 224), (256, 256)),
|
||||||
|
)
|
||||||
|
for field_name, legacy_value, n1_7_value in legacy_default_remaps:
|
||||||
|
current_value = getattr(self, field_name)
|
||||||
|
if isinstance(legacy_value, tuple):
|
||||||
|
current_value = tuple(current_value)
|
||||||
|
if current_value == legacy_value:
|
||||||
|
logger.warning(
|
||||||
|
"GrootConfig.%s=%s matches a legacy GR00T N1.5-era default; remapping it to %s, "
|
||||||
|
"the value expected by GR00T N1.7 checkpoints. Set a different value explicitly "
|
||||||
|
"if this is not what you want.",
|
||||||
|
field_name,
|
||||||
|
legacy_value,
|
||||||
|
n1_7_value,
|
||||||
|
)
|
||||||
|
setattr(self, field_name, n1_7_value)
|
||||||
|
|
||||||
inferred_version = infer_groot_model_version(self.base_model_path)
|
inferred_version = infer_groot_model_version(self.base_model_path)
|
||||||
if inferred_version is not None and inferred_version != self.model_version:
|
if inferred_version is not None and inferred_version != self.model_version:
|
||||||
raise ValueError(
|
message = (
|
||||||
f"GR00T model_version '{self.model_version}' does not match base_model_path "
|
f"GR00T model_version '{self.model_version}' does not match base_model_path "
|
||||||
f"'{self.base_model_path}', which looks like '{inferred_version}'."
|
f"'{self.base_model_path}', which looks like '{inferred_version}'."
|
||||||
)
|
)
|
||||||
|
if inferred_version == GROOT_N1_5:
|
||||||
|
message = f"{message} {GROOT_N1_5_REMOVAL_GUIDANCE}"
|
||||||
|
raise ValueError(message)
|
||||||
|
|
||||||
super().__post_init__()
|
super().__post_init__()
|
||||||
|
|
||||||
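The table-driven remap that replaces the chain of `if` statements can be exercised in isolation. The sketch below uses a hypothetical `MiniConfig` dataclass that holds only the five remapped fields; the remap table is copied from the diff, and the warning text is shortened.

```python
import logging
from dataclasses import dataclass

logger = logging.getLogger(__name__)

LEGACY_DEFAULT_REMAPS = (
    ("max_state_dim", 64, 132),
    ("max_action_dim", 32, 132),
    ("chunk_size", 50, 40),
    ("n_action_steps", 50, 40),
    ("image_size", (224, 224), (256, 256)),
)


@dataclass
class MiniConfig:
    """Hypothetical stand-in holding only the remapped fields."""

    max_state_dim: int = 132
    max_action_dim: int = 132
    chunk_size: int = 40
    n_action_steps: int = 40
    image_size: tuple = (256, 256)

    def __post_init__(self):
        for field_name, legacy_value, n1_7_value in LEGACY_DEFAULT_REMAPS:
            current = getattr(self, field_name)
            if isinstance(legacy_value, tuple):
                current = tuple(current)  # values loaded from JSON arrive as lists
            if current == legacy_value:
                logger.warning(
                    "%s=%s is a legacy default; remapping to %s", field_name, legacy_value, n1_7_value
                )
                setattr(self, field_name, n1_7_value)


stale = MiniConfig(chunk_size=50, image_size=[224, 224])  # values from an old command or config
custom = MiniConfig(chunk_size=16)  # deliberate non-default value
```

Because a deliberately chosen `chunk_size=50` cannot be told apart from the stale default, the remap logs a warning rather than changing the value silently.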
@@ -512,7 +613,9 @@ class GrootConfig(PreTrainedConfig):
|
|||||||
@property
|
@property
|
||||||
def action_delta_indices(self) -> list[int]:
|
def action_delta_indices(self) -> list[int]:
|
||||||
"""Return indices for delta actions."""
|
"""Return indices for delta actions."""
|
||||||
model_action_horizon = infer_groot_n1_7_action_horizon(self.base_model_path, self.embodiment_tag) or 40
|
model_action_horizon = (
|
||||||
|
infer_groot_n1_7_action_horizon(self.base_model_path, self.embodiment_tag) or 40
|
||||||
|
)
|
||||||
return list(range(min(self.chunk_size, model_action_horizon)))
|
return list(range(min(self.chunk_size, model_action_horizon)))
|
||||||
|
|
||||||
@property
|
@property
|
||||||
|
|||||||
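The `action_delta_indices` property above reduces to a clamp between the configured chunk size and the horizon inferred from the checkpoint. A minimal sketch, assuming the inferred horizon is passed in directly (the free function and its parameters are hypothetical; the `or 40` fallback and the `min` come from the diff):

```python
from typing import Optional


def action_delta_indices(chunk_size: int, inferred_horizon: Optional[int]) -> list:
    """Return delta indices clamped to the smaller of chunk size and model horizon."""
    # `or 40` falls back to 40 when inference returns None (or any falsy value).
    model_action_horizon = inferred_horizon or 40
    return list(range(min(chunk_size, model_action_horizon)))


print(len(action_delta_indices(50, None)))
```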
@@ -32,6 +32,7 @@ from torch.distributions import Beta
|
|||||||
from lerobot.utils.import_utils import _transformers_available, require_package
|
from lerobot.utils.import_utils import _transformers_available, require_package
|
||||||
|
|
||||||
from .action_head.cross_attention_dit import AlternateVLDiT, DiT, SelfAttentionTransformer
|
from .action_head.cross_attention_dit import AlternateVLDiT, DiT, SelfAttentionTransformer
|
||||||
|
from .configuration_groot import N1_7_DEFAULT_IMAGE_CROP_SIZE, N1_7_DEFAULT_IMAGE_TARGET_SIZE
|
||||||
|
|
||||||
if TYPE_CHECKING or _transformers_available:
|
if TYPE_CHECKING or _transformers_available:
|
||||||
from transformers import AutoConfig, AutoModel, PretrainedConfig, PreTrainedModel
|
from transformers import AutoConfig, AutoModel, PretrainedConfig, PreTrainedModel
|
||||||
@@ -71,13 +72,13 @@ GR00T_N1_7_DEFAULTS: dict[str, Any] = {
|
|||||||
"backbone_embedding_dim": 2048,
|
"backbone_embedding_dim": 2048,
|
||||||
"tune_llm": False,
|
"tune_llm": False,
|
||||||
"tune_visual": False,
|
"tune_visual": False,
|
||||||
"select_layer": 12,
|
"select_layer": 16,
|
||||||
"reproject_vision": False,
|
"reproject_vision": False,
|
||||||
"use_flash_attention": True,
|
"use_flash_attention": True,
|
||||||
"load_bf16": False,
|
"load_bf16": False,
|
||||||
"backbone_trainable_params_fp32": True,
|
"backbone_trainable_params_fp32": True,
|
||||||
"image_crop_size": (230, 230),
|
"image_crop_size": N1_7_DEFAULT_IMAGE_CROP_SIZE,
|
||||||
"image_target_size": (256, 256),
|
"image_target_size": N1_7_DEFAULT_IMAGE_TARGET_SIZE,
|
||||||
"shortest_image_edge": None,
|
"shortest_image_edge": None,
|
||||||
"crop_fraction": None,
|
"crop_fraction": None,
|
||||||
"random_rotation_angle": None,
|
"random_rotation_angle": None,
|
||||||
@@ -819,11 +820,14 @@ def _cosmos_reason2_qwen3_vl_config() -> PretrainedConfig:
|
|||||||
|
|
||||||
|
|
||||||
def get_backbone_cls(config: GR00TN17Config):
|
def get_backbone_cls(config: GR00TN17Config):
|
||||||
if (
|
if "nvidia/Cosmos-Reason2" in config.model_name or "Qwen/Qwen3-VL" in config.model_name:
|
||||||
config.backbone_model_type == "qwen"
|
return Qwen3Backbone
|
||||||
or "nvidia/Cosmos-Reason2" in config.model_name
|
if config.backbone_model_type == "qwen":
|
||||||
or "Qwen/Qwen3-VL" in config.model_name
|
logger.warning(
|
||||||
):
|
"Unrecognized GR00T N1.7 backbone model name '%s'; assuming a Qwen3-VL-compatible "
|
||||||
|
"backbone because backbone_model_type='qwen'.",
|
||||||
|
config.model_name,
|
||||||
|
)
|
||||||
return Qwen3Backbone
|
return Qwen3Backbone
|
||||||
raise ValueError(f"Unsupported GR00T N1.7 backbone model: {config.model_name}")
|
raise ValueError(f"Unsupported GR00T N1.7 backbone model: {config.model_name}")
|
||||||
|
|
||||||
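The reordering in `get_backbone_cls` separates a recognized model name (silent) from the `backbone_model_type == "qwen"` fallback (logged). The sketch below mirrors that control flow only; `pick_backbone` is a hypothetical helper that returns the class name as a string rather than the real `Qwen3Backbone` class, and the model names in the usage line are made up.

```python
import logging

logger = logging.getLogger(__name__)

# Name fragments checked by the diff before falling back to the type field.
KNOWN_NAME_FRAGMENTS = ("nvidia/Cosmos-Reason2", "Qwen/Qwen3-VL")


def pick_backbone(model_name: str, backbone_model_type: str) -> str:
    """Return the backbone name, warning only when the model name is unrecognized."""
    if any(fragment in model_name for fragment in KNOWN_NAME_FRAGMENTS):
        return "Qwen3Backbone"
    if backbone_model_type == "qwen":
        logger.warning(
            "Unrecognized backbone model name '%s'; assuming a Qwen3-VL-compatible backbone.",
            model_name,
        )
        return "Qwen3Backbone"
    raise ValueError(f"Unsupported GR00T N1.7 backbone model: {model_name}")


print(pick_backbone("Qwen/Qwen3-VL-example", "qwen"))
```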
@@ -909,7 +913,7 @@ class GR00TN17(PreTrainedModel):
|
|||||||
"trust_remote_code": True
|
"trust_remote_code": True
|
||||||
}
|
}
|
||||||
load_backbone_weights = kwargs.pop("load_backbone_weights", False)
|
load_backbone_weights = kwargs.pop("load_backbone_weights", False)
|
||||||
for key in ("revision", "cache_dir", "local_files_only", "token"):
|
for key in ("cache_dir", "local_files_only", "token"):
|
||||||
if key in kwargs:
|
if key in kwargs:
|
||||||
transformers_loading_kwargs.setdefault(key, kwargs[key])
|
transformers_loading_kwargs.setdefault(key, kwargs[key])
|
||||||
|
|
||||||
|
|||||||
@@ -18,15 +18,12 @@
|
|||||||
Groot Policy Wrapper for LeRobot Integration
|
Groot Policy Wrapper for LeRobot Integration
|
||||||
|
|
||||||
Minimal integration that delegates to Isaac-GR00T N1.7 components where
|
Minimal integration that delegates to Isaac-GR00T N1.7 components where
|
||||||
possible without porting their code.
|
possible without porting their code. Dataset loading and training
|
||||||
|
orchestration are handled by LeRobot's standard training stack.
|
||||||
Notes:
|
|
||||||
- Dataset loading and full training orchestration is handled by Isaac-GR00T
|
|
||||||
TrainRunner in their codebase. If you want to invoke that flow end-to-end
|
|
||||||
from LeRobot, see `GrootPolicy.finetune_with_groot_runner` below.
|
|
||||||
"""
|
"""
|
||||||
|
|
||||||
import builtins
|
import builtins
|
||||||
|
import logging
|
||||||
import os
|
import os
|
||||||
from collections import deque
|
from collections import deque
|
||||||
from pathlib import Path
|
from pathlib import Path
|
||||||
@@ -42,6 +39,8 @@ from lerobot.utils.import_utils import require_package
|
|||||||
from ..pretrained import PreTrainedPolicy
|
from ..pretrained import PreTrainedPolicy
|
||||||
from ..utils import get_device_from_parameters
|
from ..utils import get_device_from_parameters
|
||||||
from .configuration_groot import (
|
from .configuration_groot import (
|
||||||
|
GROOT_N1_5,
|
||||||
|
GROOT_N1_5_REMOVAL_GUIDANCE,
|
||||||
GROOT_N1_7,
|
GROOT_N1_7,
|
||||||
GrootConfig,
|
GrootConfig,
|
||||||
infer_groot_model_version,
|
infer_groot_model_version,
|
||||||
@@ -50,9 +49,103 @@ from .configuration_groot import (
|
|||||||
normalize_groot_model_version,
|
normalize_groot_model_version,
|
||||||
)
|
)
|
||||||
|
|
||||||
|
logger = logging.getLogger(__name__)
|
||||||
|
|
||||||
T = TypeVar("T", bound="GrootPolicy")
|
T = TypeVar("T", bound="GrootPolicy")
|
||||||
|
|
||||||
|
|
||||||
|
def _resolve_embodiment_id(value: int | str) -> int:
|
||||||
|
"""Resolve an embodiment id from an int or an N1.7 embodiment name.
|
||||||
|
|
||||||
|
Names are looked up in N1_7_EMBODIMENT_MAPPING (e.g. 'new_embodiment' -> 10).
|
||||||
|
Raises ValueError listing the known keys if the name is unknown.
|
||||||
|
"""
|
||||||
|
from .processor_groot import N1_7_EMBODIMENT_MAPPING
|
||||||
|
|
||||||
|
if isinstance(value, bool): # bool is a subclass of int; reject it explicitly.
|
||||||
|
raise ValueError(f"Embodiment id must be an int or embodiment name, got bool {value!r}.")
|
||||||
|
if isinstance(value, int):
|
||||||
|
return value
|
||||||
|
if value in N1_7_EMBODIMENT_MAPPING:
|
||||||
|
return N1_7_EMBODIMENT_MAPPING[value]
|
||||||
|
raise ValueError(
|
||||||
|
f"Unknown GR00T N1.7 embodiment name '{value}'. Known names: "
|
||||||
|
f"{sorted(N1_7_EMBODIMENT_MAPPING.keys())}."
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
def _warm_start_embodiment_slot(model, source_id: int, target_id: int) -> None:
|
||||||
|
"""Copy category-specific action-head weights from one embodiment slot to another.
|
||||||
|
|
||||||
|
Used at base-model load (training only) to warm-start a cold target embodiment slot
|
||||||
|
(e.g. 'new_embodiment') from a pretrained slot. Copies the per-category ``W``/``b``
|
||||||
|
parameters across every CategorySpecificLinear in the action head's state encoder,
|
||||||
|
action encoder, and action decoder. No-ops (with a logged warning) if the ids are out
|
||||||
|
of range or identical.
|
||||||
|
"""
|
||||||
|
if source_id == target_id:
|
||||||
|
logger.warning(
|
||||||
|
"GR00T warm_start_embodiment_slot: source and target embodiment id are both %d; "
|
||||||
|
"skipping (nothing to copy).",
|
||||||
|
source_id,
|
||||||
|
)
|
||||||
|
return
|
||||||
|
|
||||||
|
action_head = getattr(model, "action_head", None)
|
||||||
|
if action_head is None:
|
||||||
|
logger.warning("GR00T warm_start_embodiment_slot: model has no action_head; skipping.")
|
||||||
|
return
|
||||||
|
|
||||||
|
# Each entry is (submodule, [CategorySpecificLinear attribute names]).
|
||||||
|
linear_groups = [
|
||||||
|
(getattr(action_head, "state_encoder", None), ["layer1", "layer2"]),
|
||||||
|
(getattr(action_head, "action_encoder", None), ["W1", "W2", "W3"]),
|
||||||
|
(getattr(action_head, "action_decoder", None), ["layer1", "layer2"]),
|
||||||
|
]
|
||||||
|
|
||||||
|
copied: list[str] = []
|
||||||
|
with torch.no_grad():
|
||||||
|
for submodule, attr_names in linear_groups:
|
||||||
|
if submodule is None:
|
||||||
|
continue
|
||||||
|
submodule_name = type(submodule).__name__
|
||||||
|
for attr_name in attr_names:
|
||||||
|
lin = getattr(submodule, attr_name, None)
|
||||||
|
if lin is None or not hasattr(lin, "W") or not hasattr(lin, "b"):
|
||||||
|
continue
|
||||||
|
num_categories = lin.W.shape[0]
|
||||||
|
if not (0 <= source_id < num_categories and 0 <= target_id < num_categories):
|
||||||
|
logger.warning(
|
||||||
|
"GR00T warm_start_embodiment_slot: source_id=%d/target_id=%d out of range "
|
||||||
|
"for %s.%s (num_categories=%d); skipping this layer.",
|
||||||
|
source_id,
|
||||||
|
target_id,
|
||||||
|
submodule_name,
|
||||||
|
attr_name,
|
||||||
|
num_categories,
|
||||||
|
)
|
||||||
|
continue
|
||||||
|
lin.W.data[target_id] = lin.W.data[source_id].clone()
|
||||||
|
lin.b.data[target_id] = lin.b.data[source_id].clone()
|
||||||
|
copied.append(f"{submodule_name}.{attr_name}")
|
||||||
|
|
||||||
|
if copied:
|
||||||
|
logger.info(
|
||||||
|
"GR00T warm_start_embodiment_slot: copied action-head weights from embodiment slot %d "
|
||||||
|
"to slot %d for: %s.",
|
||||||
|
source_id,
|
||||||
|
target_id,
|
||||||
|
", ".join(copied),
|
||||||
|
)
|
||||||
|
else:
|
||||||
|
logger.warning(
|
||||||
|
"GR00T warm_start_embodiment_slot: no action-head weights were copied "
|
||||||
|
"(source_id=%d, target_id=%d).",
|
||||||
|
source_id,
|
||||||
|
target_id,
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
class GrootPolicy(PreTrainedPolicy):
|
class GrootPolicy(PreTrainedPolicy):
|
||||||
"""Wrapper around external Groot model for LeRobot integration."""
|
"""Wrapper around external Groot model for LeRobot integration."""
|
||||||
|
|
||||||
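The two helpers added in this hunk can be illustrated without PyTorch. The sketch below is an assumption-laden miniature: plain lists stand in for the per-category `W`/`b` tensors, `EMBODIMENT_MAPPING` is a hypothetical table (only `'new_embodiment' -> 10` appears in the diff), and `warm_start_slot` returns a flag instead of logging.

```python
import copy

# Only 'new_embodiment' -> 10 is taken from the diff; the other entry is made up.
EMBODIMENT_MAPPING = {"new_embodiment": 10, "example_robot": 3}


def resolve_embodiment_id(value):
    """Resolve an int id or a known embodiment name to an int id."""
    # bool is a subclass of int, so it must be rejected before the int check.
    if isinstance(value, bool):
        raise ValueError(f"Embodiment id must be an int or name, got bool {value!r}.")
    if isinstance(value, int):
        return value
    if value in EMBODIMENT_MAPPING:
        return EMBODIMENT_MAPPING[value]
    raise ValueError(f"Unknown embodiment name '{value}'. Known names: {sorted(EMBODIMENT_MAPPING)}.")


def warm_start_slot(per_category_weights: list, source_id: int, target_id: int) -> bool:
    """Copy slot source_id over slot target_id; return True only if a copy happened."""
    num_categories = len(per_category_weights)
    if source_id == target_id:
        return False  # nothing to copy
    if not (0 <= source_id < num_categories and 0 <= target_id < num_categories):
        return False  # out of range: skip rather than raise, as the diff does
    # Deep copy so later updates to the target slot do not alias the source slot.
    per_category_weights[target_id] = copy.deepcopy(per_category_weights[source_id])
    return True


weights = [[1.0], [2.0], [0.0]]
warm_start_slot(weights, source_id=1, target_id=2)
print(weights)
```

The explicit `bool` check matters because `isinstance(True, int)` is true in Python, so without it `True` would silently resolve to slot 1.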
@@ -92,8 +185,24 @@ class GrootPolicy(PreTrainedPolicy):
|
|||||||
transformers_loading_kwargs={"trust_remote_code": True},
|
transformers_loading_kwargs={"trust_remote_code": True},
|
||||||
)
|
)
|
||||||
|
|
||||||
model.compute_dtype = "bfloat16" if self.config.use_bf16 else model.compute_dtype
|
# Inference-only override for the number of flow-matching denoising steps. The action
|
||||||
model.config.compute_dtype = model.compute_dtype
|
# head reads self.num_inference_timesteps in get_action_with_features; dt (1/n) and the
|
||||||
|
# t schedule adapt automatically.
|
||||||
|
if self.config.num_inference_timesteps is not None:
|
||||||
|
n = int(self.config.num_inference_timesteps)
|
||||||
|
model.config.num_inference_timesteps = n
|
||||||
|
model.action_head.num_inference_timesteps = n
|
||||||
|
|
||||||
|
# Opt-in: warm-start a cold embodiment slot (e.g. 'new_embodiment') from a pretrained
|
||||||
|
# slot's action-head weights. Done here (not in from_pretrained) so it applies on every
|
||||||
|
# fresh base-model build -- training via make_policy instantiates GrootPolicy(config)
|
||||||
|
# directly (factory uses __init__ when cfg.pretrained_path is unset), it does NOT go
|
||||||
|
# through from_pretrained. On a fine-tuned checkpoint reload this also runs but is
|
||||||
|
# immediately overwritten by the loaded state_dict, so it is a harmless no-op there.
|
||||||
|
if self.config.warm_start_embodiment_slot is not None:
|
||||||
|
source_id = _resolve_embodiment_id(self.config.warm_start_embodiment_slot)
|
||||||
|
target_id = _resolve_embodiment_id(self.config.embodiment_tag)
|
||||||
|
_warm_start_embodiment_slot(model, source_id, target_id)
|
||||||
|
|
||||||
return model
|
return model
|
||||||
|
|
||||||
@@ -148,9 +257,10 @@ class GrootPolicy(PreTrainedPolicy):
|
|||||||
if config is not None
|
if config is not None
|
||||||
else infer_groot_model_version(str(pretrained_name_or_path)) or GROOT_N1_7
|
else infer_groot_model_version(str(pretrained_name_or_path)) or GROOT_N1_7
|
||||||
)
|
)
|
||||||
print(
|
logger.info(
|
||||||
f"The Groot policy is a wrapper around Nvidia's GR00T {requested_version} model.\n"
|
"The Groot policy wraps NVIDIA's GR00T %s model. Loading pretrained model from: %s",
|
||||||
f"Loading pretrained model from: {pretrained_name_or_path}"
|
requested_version,
|
||||||
|
pretrained_name_or_path,
|
||||||
)
|
)
|
||||||
|
|
||||||
model_id = str(pretrained_name_or_path)
|
model_id = str(pretrained_name_or_path)
|
||||||
@@ -181,7 +291,7 @@ class GrootPolicy(PreTrainedPolicy):
|
|||||||
|
|
||||||
if is_finetuned_checkpoint:
|
if is_finetuned_checkpoint:
|
||||||
# This is a fine-tuned LeRobot checkpoint - use parent class loading
|
# This is a fine-tuned LeRobot checkpoint - use parent class loading
|
||||||
print("Detected fine-tuned LeRobot checkpoint, loading with state dict...")
|
logger.info("Detected fine-tuned LeRobot checkpoint, loading with state dict...")
|
||||||
return super().from_pretrained(
|
return super().from_pretrained(
|
||||||
pretrained_name_or_path=pretrained_name_or_path,
|
pretrained_name_or_path=pretrained_name_or_path,
|
||||||
config=config,
|
config=config,
|
||||||
@@ -197,7 +307,7 @@ class GrootPolicy(PreTrainedPolicy):
|
|||||||
)
|
)
|
||||||
|
|
||||||
# This is a base GR00T model - load it fresh
|
# This is a base GR00T model - load it fresh
|
||||||
print("Detected base GR00T model, loading from HuggingFace...")
|
logger.info("Detected base GR00T model, loading from HuggingFace...")
|
||||||
|
|
||||||
if config is None:
|
if config is None:
|
||||||
model_version = infer_groot_model_version(str(pretrained_name_or_path)) or GROOT_N1_7
|
model_version = infer_groot_model_version(str(pretrained_name_or_path)) or GROOT_N1_7
|
||||||
@@ -229,10 +339,13 @@ class GrootPolicy(PreTrainedPolicy):
|
|||||||
config.model_version = normalize_groot_model_version(config.model_version)
|
config.model_version = normalize_groot_model_version(config.model_version)
|
||||||
inferred_version = infer_groot_model_version(config.base_model_path)
|
inferred_version = infer_groot_model_version(config.base_model_path)
|
||||||
if inferred_version is not None and inferred_version != config.model_version:
|
if inferred_version is not None and inferred_version != config.model_version:
|
||||||
raise ValueError(
|
message = (
|
||||||
f"GR00T model_version '{config.model_version}' does not match base_model_path "
|
f"GR00T model_version '{config.model_version}' does not match base_model_path "
|
||||||
f"'{config.base_model_path}', which looks like '{inferred_version}'."
|
f"'{config.base_model_path}', which looks like '{inferred_version}'."
|
||||||
)
|
)
|
||||||
|
if inferred_version == GROOT_N1_5:
|
||||||
|
message = f"{message} {GROOT_N1_5_REMOVAL_GUIDANCE}"
|
||||||
|
raise ValueError(message)
|
||||||
# Create a fresh policy instance - this will automatically load the GR00T model
|
# Create a fresh policy instance - this will automatically load the GR00T model
|
||||||
# in __init__ via _create_groot_model()
|
# in __init__ via _create_groot_model()
|
||||||
policy = cls(config)
|
policy = cls(config)
|
||||||
@@ -258,7 +371,11 @@ class GrootPolicy(PreTrainedPolicy):
|
|||||||
horizons.append(checkpoint_action_horizon)
|
horizons.append(checkpoint_action_horizon)
|
||||||
if execution_horizon is not None:
|
if execution_horizon is not None:
|
||||||
horizons.append(execution_horizon)
|
horizons.append(execution_horizon)
|
||||||
return min(horizons)
|
# An explicit config override caps the open-loop horizon (inference cadence), overriding
|
||||||
|
# the value inferred from the checkpoint/embodiment.
|
||||||
|
if self.config.execution_horizon is not None:
|
||||||
|
horizons.append(max(1, int(self.config.execution_horizon)))
|
||||||
|
return max(1, min(horizons))
|
||||||
|
|
||||||
def _resolve_prediction_horizon(self, actions: Tensor) -> int:
|
def _resolve_prediction_horizon(self, actions: Tensor) -> int:
|
||||||
"""Return the policy-facing action horizon for a native GR00T prediction."""
|
"""Return the policy-facing action horizon for a native GR00T prediction."""
|
||||||
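The horizon resolution changed in this hunk is a minimum over several candidates with a floor of 1. The sketch below is a hedged reconstruction: the hunk does not show how `horizons` is initialized, so seeding it with `n_action_steps` is an assumption, and the free function with its parameter names is hypothetical.

```python
from typing import Optional


def resolve_execution_horizon(
    n_action_steps: int,
    checkpoint_action_horizon: Optional[int],
    execution_horizon: Optional[int],
    config_execution_horizon: Optional[int],
) -> int:
    """Return the open-loop horizon as the smallest candidate, never below 1."""
    horizons = [n_action_steps]  # assumption: the initial candidate is not visible in the hunk
    if checkpoint_action_horizon is not None:
        horizons.append(checkpoint_action_horizon)
    if execution_horizon is not None:
        horizons.append(execution_horizon)
    # The explicit config value joins the same min(), floored at 1.
    if config_execution_horizon is not None:
        horizons.append(max(1, int(config_execution_horizon)))
    return max(1, min(horizons))


print(resolve_execution_horizon(40, 16, None, 4))
```

Because the override is appended to the same list that feeds `min`, in this reading it can shorten the horizon but cannot extend it past the checkpoint or embodiment value.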
@@ -297,9 +414,7 @@ class GrootPolicy(PreTrainedPolicy):
|
|||||||
allowed_base.add("action_mask")
|
allowed_base.add("action_mask")
|
||||||
|
|
||||||
return {
|
return {
|
||||||
k: v
|
k: v for k, v in batch.items() if k in allowed_base and not (k.startswith("next.") or k == "info")
|
||||||
for k, v in batch.items()
|
|
||||||
if k in allowed_base and not (k.startswith("next.") or k == "info")
|
|
||||||
}
|
}
|
||||||
|
|
||||||
def _prepare_n1_7_rtc_inputs(
|
def _prepare_n1_7_rtc_inputs(
|
||||||
@@ -320,9 +435,7 @@ class GrootPolicy(PreTrainedPolicy):
|
|||||||
if prev_actions.ndim == 2:
|
if prev_actions.ndim == 2:
|
||||||
prev_actions = prev_actions.unsqueeze(0)
|
prev_actions = prev_actions.unsqueeze(0)
|
||||||
elif prev_actions.ndim != 3:
|
elif prev_actions.ndim != 3:
|
||||||
raise ValueError(
|
raise ValueError("prev_chunk_left_over must have shape (T, A) or (B, T, A) for GR00T N1.7 RTC.")
|
||||||
"prev_chunk_left_over must have shape (T, A) or (B, T, A) for GR00T N1.7 RTC."
|
|
||||||
)
|
|
||||||
|
|
||||||
state = inputs.get("state")
|
state = inputs.get("state")
|
||||||
if state is None:
|
if state is None:
|
||||||
@@ -331,9 +444,7 @@ class GrootPolicy(PreTrainedPolicy):
|
|||||||
if prev_actions.shape[0] == 1 and batch_size > 1:
|
if prev_actions.shape[0] == 1 and batch_size > 1:
|
||||||
prev_actions = prev_actions.expand(batch_size, -1, -1).clone()
|
prev_actions = prev_actions.expand(batch_size, -1, -1).clone()
|
||||||
elif prev_actions.shape[0] != batch_size:
|
elif prev_actions.shape[0] != batch_size:
|
||||||
raise ValueError(
|
raise ValueError("prev_chunk_left_over batch size must match the current GR00T N1.7 batch size.")
|
||||||
"prev_chunk_left_over batch size must match the current GR00T N1.7 batch size."
|
|
||||||
)
|
|
||||||
|
|
||||||
# The generic LeRobot RTC engine pads short leftovers with exact zero
|
# The generic LeRobot RTC engine pads short leftovers with exact zero
|
||||||
# rows for fixed-shape policy calls. Native GR00T N1.7 RTC treats every
|
# rows for fixed-shape policy calls. Native GR00T N1.7 RTC treats every
|
||||||
@@ -346,7 +457,9 @@ class GrootPolicy(PreTrainedPolicy):
|
|||||||
else:
|
else:
|
||||||
return inputs, None
|
return inputs, None
|
||||||
|
|
||||||
model_action_horizon = int(getattr(self._groot_model.config, "action_horizon", self.config.chunk_size))
|
model_action_horizon = int(
|
||||||
|
getattr(self._groot_model.config, "action_horizon", self.config.chunk_size)
|
||||||
|
)
|
||||||
max_action_dim = int(getattr(self._groot_model.config, "max_action_dim", self.config.max_action_dim))
|
max_action_dim = int(getattr(self._groot_model.config, "max_action_dim", self.config.max_action_dim))
|
||||||
if prev_actions.shape[1] > model_action_horizon:
|
if prev_actions.shape[1] > model_action_horizon:
|
||||||
prev_actions = prev_actions[:, -model_action_horizon:, :]
|
prev_actions = prev_actions[:, -model_action_horizon:, :]
|
||||||
@@ -409,6 +522,11 @@ class GrootPolicy(PreTrainedPolicy):
|
|||||||
|
|
||||||
# Isaac-GR00T returns a BatchFeature; loss key is typically 'loss'
|
# Isaac-GR00T returns a BatchFeature; loss key is typically 'loss'
|
||||||
loss = outputs.get("loss")
|
loss = outputs.get("loss")
|
||||||
|
if loss is None:
|
||||||
|
raise RuntimeError(
|
||||||
|
"GR00T model.forward did not return a 'loss'. Training batches must include "
|
||||||
|
"'action' and 'action_mask'; check the preprocessor output."
|
||||||
|
)
|
||||||
|
|
||||||
loss_dict = {"loss": loss.item()}
|
loss_dict = {"loss": loss.item()}
|
||||||
|
|
||||||
@@ -425,6 +543,16 @@ class GrootPolicy(PreTrainedPolicy):
|
|||||||
"""
|
"""
|
||||||
self.eval()
|
self.eval()
|
||||||
|
|
||||||
|
# Freeze the relative-action reference at the exact chunk-prediction event so every popped
|
||||||
|
# delta of this chunk is reconstructed (in the postprocessor) against this S_T, not the
|
||||||
|
# per-tick latest state. Driven by the predict event, so it is correct under any runtime
|
||||||
|
# n_action_steps/execution_horizon. No-op for non-relative checkpoints (holder absent/unused).
|
||||||
|
from .processor_groot import _GROOT_REF_HOLDER_KEY
|
||||||
|
|
||||||
|
holder = batch.get(_GROOT_REF_HOLDER_KEY)
|
||||||
|
if holder is not None:
|
||||||
|
holder.freeze()
|
||||||
|
|
||||||
# Preprocessing is handled by the processor pipeline, so we just filter the batch.
|
# Preprocessing is handled by the processor pipeline, so we just filter the batch.
|
||||||
# During inference, we do not pass action because it is predicted.
|
# During inference, we do not pass action because it is predicted.
|
||||||
# N1.7 still carries a 2-D action horizon mask from its checkpoint processor.
|
# N1.7 still carries a 2-D action horizon mask from its checkpoint processor.
|
||||||
@@ -471,33 +599,21 @@ class GrootPolicy(PreTrainedPolicy):
|
|||||||
# Internal helpers
|
# Internal helpers
|
||||||
# -------------------------
|
# -------------------------
|
||||||
def _handle_flash_attention_compatibility(self) -> None:
|
def _handle_flash_attention_compatibility(self) -> None:
|
||||||
"""Handle Flash Attention compatibility issues by setting environment variables.
|
"""Log Flash Attention availability (diagnostic only).
|
||||||
|
|
||||||
This addresses the common 'undefined symbol' error that occurs when Flash Attention
|
The GR00T N1.7 backbone automatically falls back to SDPA when ``flash_attn`` is
|
||||||
is compiled against a different PyTorch version than what's currently installed.
|
unavailable (see ``Qwen3Backbone``), so this probe only emits a hint; it does not
|
||||||
|
change behaviour or mutate global state.
|
||||||
"""
|
"""
|
||||||
|
|
||||||
# Set environment variables to handle Flash Attention compatibility
|
|
||||||
# These help with symbol resolution issues
|
|
||||||
os.environ.setdefault("FLASH_ATTENTION_FORCE_BUILD", "0")
|
|
||||||
os.environ.setdefault("FLASH_ATTENTION_SKIP_CUDA_BUILD", "0")
|
|
||||||
|
|
||||||
# Try to import flash_attn and handle failures gracefully
|
|
||||||
try:
|
try:
|
||||||
import flash_attn
|
import flash_attn
|
||||||
|
|
||||||
print(f"[GROOT] Flash Attention version: {flash_attn.__version__}")
|
logger.debug("Flash Attention %s is available.", flash_attn.__version__)
|
||||||
except ImportError as e:
|
except ImportError:
|
||||||
print(f"[GROOT] Flash Attention not available: {e}")
|
logger.debug("Flash Attention is not installed; the GR00T backbone will use SDPA.")
|
||||||
print("[GROOT] Will use fallback attention mechanism")
|
except Exception as e: # noqa: BLE001
|
||||||
except Exception as e:
|
logger.warning(
|
||||||
if "undefined symbol" in str(e):
|
"Flash Attention failed to import (%s); the GR00T backbone will use SDPA. If this is "
|
||||||
print(f"[GROOT] Flash Attention compatibility issue detected: {e}")
|
"an 'undefined symbol' error, reinstall a flash-attn build matching your torch version.",
|
||||||
print("[GROOT] This is likely due to PyTorch/Flash Attention version mismatch")
|
e,
|
||||||
print("[GROOT] Consider reinstalling Flash Attention with compatible version:")
|
)
|
||||||
print(" pip uninstall flash-attn")
|
|
||||||
print(" pip install --no-build-isolation flash-attn==2.6.3")
|
|
||||||
print("[GROOT] Continuing with fallback attention mechanism")
|
|
||||||
else:
|
|
||||||
print(f"[GROOT] Flash Attention error: {e}")
|
|
||||||
print("[GROOT] Continuing with fallback attention mechanism")
|
|
||||||
|
|||||||
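The rewritten `_handle_flash_attention_compatibility` is a diagnostic import probe with three outcomes and no side effects on the environment. A generic sketch of that pattern follows; `probe_optional_module` and its string return values are hypothetical, introduced only so the outcome can be inspected.

```python
import importlib
import logging

logger = logging.getLogger(__name__)


def probe_optional_module(module_name: str) -> str:
    """Try to import an optional module and report its status without mutating global state."""
    try:
        module = importlib.import_module(module_name)
    except ImportError:
        # Covers the ordinary "not installed" case (ModuleNotFoundError subclasses ImportError).
        logger.debug("%s is not installed; a fallback will be used.", module_name)
        return "missing"
    except Exception as exc:  # noqa: BLE001
        # Any other failure raised during import, e.g. by module-level code.
        logger.warning("%s failed to import (%s); a fallback will be used.", module_name, exc)
        return "broken"
    logger.debug("%s %s is available.", module_name, getattr(module, "__version__", "unknown"))
    return "available"


print(probe_optional_module("json"))
```

Catching `ImportError` before the broad `Exception` keeps the common case at debug level and reserves the warning for installations that are present but broken.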
File diff suppressed because it is too large