mirror of
https://github.com/huggingface/lerobot.git
synced 2026-06-17 08:17:02 +00:00
Compare commits
8 Commits
| Author | SHA1 | Date | |
|---|---|---|---|
| 5753f8c18b | |||
| 97bd373d15 | |||
| 10a73e3c95 | |||
| 27c9288b24 | |||
| 378897800a | |||
| fcb371eddd | |||
| 895eaf0d7c | |||
| edda8552ec |
@@ -4,6 +4,9 @@ GR00T is an NVIDIA foundation model family for generalized humanoid robot reason
|
|||||||
|
|
||||||
LeRobot integrates GR00T N1.7 through the `groot` policy type.
|
LeRobot integrates GR00T N1.7 through the `groot` policy type.
|
||||||
|
|
||||||
|
> [!WARNING]
|
||||||
|
> **Breaking change:** GR00T N1.5 support was removed from LeRobot, and current releases support GR00T N1.7 only. N1.5 checkpoints, configs, and `--policy.model_version=n1.5` are rejected with a clear error. To keep using an N1.5 checkpoint, pin the last release that supports it: `pip install 'lerobot==0.5.1'`. To use the current release, migrate to GR00T N1.7 (`model_version='n1.7'`, base model [`nvidia/GR00T-N1.7-3B`](https://huggingface.co/nvidia/GR00T-N1.7-3B)).
|
||||||
|
|
||||||
## Model Overview
|
## Model Overview
|
||||||
|
|
||||||
GR00T N1.7 uses a Cosmos-Reason2/Qwen3-VL backbone and provides checkpoints for SimplerEnv, DROID, and LIBERO.
|
GR00T N1.7 uses a Cosmos-Reason2/Qwen3-VL backbone and provides checkpoints for SimplerEnv, DROID, and LIBERO.
|
||||||
@@ -133,7 +136,7 @@ Replace the `XX` placeholders with final eval artifacts before merge.
|
|||||||
Download the suite checkpoint locally, then point `--policy.base_model_path` at the downloaded subdirectory. `--policy.path` is reserved for LeRobot checkpoints that contain a LeRobot `config.json` with a `type` field.
|
Download the suite checkpoint locally, then point `--policy.base_model_path` at the downloaded subdirectory. `--policy.path` is reserved for LeRobot checkpoints that contain a LeRobot `config.json` with a `type` field.
|
||||||
|
|
||||||
```bash
|
```bash
|
||||||
huggingface-cli download nvidia/GR00T-N1.7-LIBERO \
|
hf download nvidia/GR00T-N1.7-LIBERO \
|
||||||
--include "libero_spatial/*" \
|
--include "libero_spatial/*" \
|
||||||
--local-dir ./GR00T-N1.7-LIBERO
|
--local-dir ./GR00T-N1.7-LIBERO
|
||||||
|
|
||||||
|
|||||||
@@ -1,6 +1,13 @@
|
|||||||
## Research Paper
|
## Research Paper
|
||||||
|
|
||||||
Paper: https://research.nvidia.com/labs/gear/gr00t-n1_5/
|
GR00T N1 technical report (covers the GR00T N1.x family, including N1.7): https://arxiv.org/abs/2503.14734
|
||||||
|
|
||||||
|
GR00T N1.7 model card: https://huggingface.co/nvidia/GR00T-N1.7-3B
|
||||||
|
|
||||||
|
GR00T N1.5 research page (earlier version): https://research.nvidia.com/labs/gear/gr00t-n1_5/
|
||||||
|
|
||||||
|
> GR00T N1.5 support was removed from LeRobot; the last release supporting it is `lerobot==0.5.1`.
|
||||||
|
> Current releases support GR00T N1.7 only.
|
||||||
|
|
||||||
## Repository
|
## Repository
|
||||||
|
|
||||||
@@ -31,12 +38,22 @@ Hugging Face Models:
|
|||||||
|
|
||||||
## Original-vs-LeRobot parity test
|
## Original-vs-LeRobot parity test
|
||||||
|
|
||||||
`tests/policies/groot/test_groot_vs_original.py` verifies that this LeRobot
|
`tests/policies/groot/test_groot_vs_original.py` verifies this LeRobot
|
||||||
reimplementation of GR00T N1.7 (Qwen3-VL backbone + flow-matching action head)
|
reimplementation of GR00T N1.7 (Qwen3-VL backbone + flow-matching action head)
|
||||||
produces the **same raw model output** (`get_action(...)["action_pred"]`, the
|
against NVIDIA's original `gr00t` package with two comparisons, each parametrized
|
||||||
normalized flow-matching prediction) as NVIDIA's original `gr00t` package, given
|
over every embodiment tag present in the checkpoint:
|
||||||
byte-identical pre-processed inputs and the same flow-matching seed. It is
|
|
||||||
parametrized over every embodiment tag present in the checkpoint.
|
1. **Model parity** — given byte-identical pre-processed inputs and the same
|
||||||
|
flow-matching seed (recorded in each artifact), both implementations must produce
|
||||||
|
the **same raw model output** (`get_action(...)["action_pred"]`, the normalized
|
||||||
|
flow-matching prediction). Output shapes must match exactly; any action-horizon
|
||||||
|
or action-dim mismatch fails the test.
|
||||||
|
2. **Preprocessor parity** — given the identical raw observations (per-camera
|
||||||
|
frames, state vectors, language instruction), LeRobot's own preprocessor pipeline
|
||||||
|
(real Qwen3-VL chat template / tokenizer / image packing + checkpoint-driven
|
||||||
|
state normalization, no mocks) must produce the **same collated model inputs**
|
||||||
|
(`input_ids`, `attention_mask`, `pixel_values`, `image_grid_thw`, `state`,
|
||||||
|
`embodiment_id`) as the original package's processor.
|
||||||
|
|
||||||
### Why two environments
|
### Why two environments
|
||||||
|
|
||||||
@@ -48,25 +65,37 @@ is itself a defaulted dataclass, so the original config dataclasses fail to impo
|
|||||||
|
|
||||||
So the test uses a **producer / consumer** split across two venvs:
|
So the test uses a **producer / consumer** split across two venvs:
|
||||||
|
|
||||||
1. **Producer** — `tests/policies/groot/utils/dump_original_n1_7.py`, run in the *original*
|
1. **Producer** — `tests/policies/groot/utils/dump_original_n1_7.py`, run in the _original_
|
||||||
gr00t venv. For each embodiment it builds dummy inputs generically from the
|
gr00t venv. For each embodiment it builds dummy inputs generically from the
|
||||||
checkpoint metadata (state dims from `statistics.json`; camera/language keys from
|
checkpoint metadata (state dims from `statistics.json`; camera/language keys from
|
||||||
the processor modality configs), runs the original model, and saves the exact
|
the processor modality configs), runs the original model, and saves to one `.npz`
|
||||||
collated inputs + raw `action_pred` to one `.npz` per tag.
|
per tag: the raw observations (`raw::` keys), the exact collated inputs
|
||||||
2. **Consumer** — the pytest above, run in the *LeRobot* venv. It discovers every
|
(`in::` keys), the seed, and the raw `action_pred`.
|
||||||
`.npz`, replays the byte-identical inputs through the LeRobot model with the same
|
2. **Consumer** — the pytest above, run in the _LeRobot_ venv. It discovers every
|
||||||
seed, and asserts the outputs match.
|
`.npz`; the model-parity case replays the byte-identical collated inputs through
|
||||||
|
the LeRobot model with the recorded seed and asserts the outputs match, and the
|
||||||
|
preprocessor-parity case replays the raw observations through LeRobot's full
|
||||||
|
preprocessor pipeline and asserts the collated tensors match.
|
||||||
|
|
||||||
|
> Artifacts generated by older versions of the dump script contain no `raw::`
|
||||||
|
> fields; the preprocessor-parity case then **skips** with a regeneration hint.
|
||||||
|
> Re-run the producer to refresh them.
|
||||||
|
|
||||||
### Fairness controls
|
### Fairness controls
|
||||||
|
|
||||||
- **Same pre-processed inputs** — the original processor's `input_ids`,
|
- **Same pre-processed inputs (model parity)** — the original processor's `input_ids`,
|
||||||
`pixel_values`, `image_grid_thw`, `attention_mask`, `state`, `embodiment_id` are
|
`pixel_values`, `image_grid_thw`, `attention_mask`, `state`, `embodiment_id` are
|
||||||
fed verbatim to the LeRobot model (no re-tokenization / re-normalization).
|
fed verbatim to the LeRobot model (no re-tokenization / re-normalization), so the
|
||||||
|
model comparison isolates the model. LeRobot's own tokenization / image packing is
|
||||||
|
covered separately by the preprocessor-parity case, which compares its output
|
||||||
|
against those same collated tensors from identical raw observations.
|
||||||
- **Same precision + attention kernel** — both sides run **fp32 + SDPA**. The
|
- **Same precision + attention kernel** — both sides run **fp32 + SDPA**. The
|
||||||
original defaults to `use_flash_attention=True` (flash_attention_2 + bf16); the
|
original defaults to `use_flash_attention=True` (flash_attention_2 + bf16); the
|
||||||
producer forces SDPA + fp32. (With the defaults the gap is ~3e-2 — pure
|
producer forces SDPA + fp32. (With the defaults the gap is ~3e-2 — pure
|
||||||
kernel/rounding noise, not an implementation difference.)
|
kernel/rounding noise, not an implementation difference.)
|
||||||
- **Same flow-matching seed** — fixed (42) right before sampling on both sides.
|
- **Same flow-matching seed** — fixed right before sampling on both sides; the
|
||||||
|
producer records it in each artifact (`--seed`, default 42) and the consumer
|
||||||
|
replays the recorded value.
|
||||||
|
|
||||||
### How to run
|
### How to run
|
||||||
|
|
||||||
@@ -90,14 +119,14 @@ CUDA_VISIBLE_DEVICES=0 GROOT_PARITY_DEVICE=cuda \
|
|||||||
uv run pytest tests/policies/groot/test_groot_vs_original.py -v -s
|
uv run pytest tests/policies/groot/test_groot_vs_original.py -v -s
|
||||||
```
|
```
|
||||||
|
|
||||||
The `.npz` artifacts are local-only (gitignored, ~6–9 MB each) and are regenerated by
|
The `.npz` artifacts are local-only (gitignored, ~6–10 MB each) and are regenerated by
|
||||||
the producer; they are never committed. The test **skips** (does not fail) on CI or
|
the producer; they are never committed. The tests **skip** (do not fail) on CI or
|
||||||
when the checkpoint / artifacts are absent.
|
when the checkpoint / artifacts are absent.
|
||||||
|
|
||||||
#### Env knobs (all optional)
|
#### Env knobs (all optional)
|
||||||
|
|
||||||
| Var | Default | Purpose |
|
| Var | Default | Purpose |
|
||||||
|---|---|---|
|
| ----------------------------------------- | -------------------------------- | ------------------------------------- |
|
||||||
| `GROOT_N1_7_PARITY_DIR` | `tests/policies/groot/artifacts` | directory of per-tag `.npz` artifacts |
|
| `GROOT_N1_7_PARITY_DIR` | `tests/policies/groot/artifacts` | directory of per-tag `.npz` artifacts |
|
||||||
| `GROOT_N1_7_LIBERO_CKPT` | auto (HF cache) | override checkpoint dir |
|
| `GROOT_N1_7_LIBERO_CKPT` | auto (HF cache) | override checkpoint dir |
|
||||||
| `GROOT_PARITY_DEVICE` | `cuda` if available | `cpu` or `cuda` |
|
| `GROOT_PARITY_DEVICE` | `cuda` if available | `cpu` or `cuda` |
|
||||||
|
|||||||
@@ -14,6 +14,7 @@
|
|||||||
# limitations under the License.
|
# limitations under the License.
|
||||||
|
|
||||||
|
|
||||||
|
import logging
|
||||||
from typing import TYPE_CHECKING
|
from typing import TYPE_CHECKING
|
||||||
|
|
||||||
import torch
|
import torch
|
||||||
@@ -42,6 +43,9 @@ else:
|
|||||||
Timesteps = None
|
Timesteps = None
|
||||||
|
|
||||||
|
|
||||||
|
logger = logging.getLogger(__name__)
|
||||||
|
|
||||||
|
|
||||||
class TimestepEncoder(nn.Module):
|
class TimestepEncoder(nn.Module):
|
||||||
def __init__(self, embedding_dim, compute_dtype=torch.float32):
|
def __init__(self, embedding_dim, compute_dtype=torch.float32):
|
||||||
require_package("diffusers", extra="groot")
|
require_package("diffusers", extra="groot")
|
||||||
@@ -265,8 +269,8 @@ class DiT(ModelMixin, ConfigMixin):
|
|||||||
self.norm_out = nn.LayerNorm(self.inner_dim, elementwise_affine=False, eps=1e-6)
|
self.norm_out = nn.LayerNorm(self.inner_dim, elementwise_affine=False, eps=1e-6)
|
||||||
self.proj_out_1 = nn.Linear(self.inner_dim, 2 * self.inner_dim)
|
self.proj_out_1 = nn.Linear(self.inner_dim, 2 * self.inner_dim)
|
||||||
self.proj_out_2 = nn.Linear(self.inner_dim, self.config.output_dim)
|
self.proj_out_2 = nn.Linear(self.inner_dim, self.config.output_dim)
|
||||||
print(
|
logger.debug(
|
||||||
"Total number of DiT parameters: ",
|
"Total number of DiT parameters: %d",
|
||||||
sum(p.numel() for p in self.parameters() if p.requires_grad),
|
sum(p.numel() for p in self.parameters() if p.requires_grad),
|
||||||
)
|
)
|
||||||
|
|
||||||
@@ -426,8 +430,8 @@ class SelfAttentionTransformer(ModelMixin, ConfigMixin):
|
|||||||
for _ in range(self.config.num_layers)
|
for _ in range(self.config.num_layers)
|
||||||
]
|
]
|
||||||
)
|
)
|
||||||
print(
|
logger.debug(
|
||||||
"Total number of SelfAttentionTransformer parameters: ",
|
"Total number of SelfAttentionTransformer parameters: %d",
|
||||||
sum(p.numel() for p in self.parameters() if p.requires_grad),
|
sum(p.numel() for p in self.parameters() if p.requires_grad),
|
||||||
)
|
)
|
||||||
|
|
||||||
|
|||||||
@@ -15,6 +15,7 @@
|
|||||||
# limitations under the License.
|
# limitations under the License.
|
||||||
|
|
||||||
import json
|
import json
|
||||||
|
import logging
|
||||||
import os
|
import os
|
||||||
from dataclasses import dataclass, field
|
from dataclasses import dataclass, field
|
||||||
from pathlib import Path
|
from pathlib import Path
|
||||||
@@ -23,15 +24,33 @@ from lerobot.configs import FeatureType, NormalizationMode, PolicyFeature, PreTr
|
|||||||
from lerobot.optim import AdamWConfig, CosineDecayWithWarmupSchedulerConfig
|
from lerobot.optim import AdamWConfig, CosineDecayWithWarmupSchedulerConfig
|
||||||
from lerobot.utils.constants import ACTION, OBS_STATE
|
from lerobot.utils.constants import ACTION, OBS_STATE
|
||||||
|
|
||||||
|
logger = logging.getLogger(__name__)
|
||||||
|
|
||||||
GROOT_N1_7 = "n1.7"
|
GROOT_N1_7 = "n1.7"
|
||||||
# Legacy GR00T N1.5 identifier. N1.5 is NOT a supported model_version (it is
|
# Legacy GR00T N1.5 identifier. N1.5 is NOT a supported model_version (it is
|
||||||
# intentionally absent from _GROOT_MODEL_VERSION_ALIASES so normalize_groot_model_version
|
# intentionally absent from _GROOT_MODEL_VERSION_ALIASES so normalize_groot_model_version
|
||||||
# still rejects it). It is retained only so that infer_groot_model_version can recognise
|
# still rejects it). It is retained only so that infer_groot_model_version can recognise
|
||||||
# an N1.5 base path/checkpoint and the N1.7 config/loader can reject the mismatch.
|
# an N1.5 base path/checkpoint and the N1.7 config/loader can reject the mismatch.
|
||||||
GROOT_N1_5 = "n1.5"
|
GROOT_N1_5 = "n1.5"
|
||||||
|
# Canonical guidance appended to every error raised when an N1.5 checkpoint, config,
|
||||||
|
# or processor pipeline is detected. Keep this message in sync with docs/source/groot.mdx.
|
||||||
|
GROOT_N1_5_REMOVAL_GUIDANCE = (
|
||||||
|
"GR00T N1.5 support was removed from LeRobot. "
|
||||||
|
"To keep using an N1.5 checkpoint, pin the last release that supports it: "
|
||||||
|
"`pip install 'lerobot==0.5.1'`. To use the current release, migrate to GR00T N1.7 "
|
||||||
|
"(model_version='n1.7', base model nvidia/GR00T-N1.7-3B)."
|
||||||
|
)
|
||||||
GROOT_N1_7_BASE_MODEL = "nvidia/GR00T-N1.7-3B"
|
GROOT_N1_7_BASE_MODEL = "nvidia/GR00T-N1.7-3B"
|
||||||
GROOT_N1_7_BACKBONE_MODEL = "nvidia/Cosmos-Reason2-2B"
|
GROOT_N1_7_BACKBONE_MODEL = "nvidia/Cosmos-Reason2-2B"
|
||||||
|
# Default GR00T N1.7 training resolution. Fallback if processor_config lacks sizing. Prevents mismatched
|
||||||
|
# full-res patchification by forcing a resize. Mirrored by GR00T_N1_7_DEFAULTS in groot_n1_7.py.
|
||||||
|
N1_7_DEFAULT_IMAGE_TARGET_SIZE = (256, 256)
|
||||||
|
N1_7_DEFAULT_IMAGE_CROP_SIZE = (230, 230)
|
||||||
GROOT_ACTION_DECODE_TRANSFORM_LIBERO = "libero"
|
GROOT_ACTION_DECODE_TRANSFORM_LIBERO = "libero"
|
||||||
|
# Sentinel meaning "the user did not pick an action decode transform": __post_init__ resolves it
|
||||||
|
# to the embodiment default ('libero' for 'libero_sim', otherwise None). It is distinct from an
|
||||||
|
# explicit 'none' (resolved to None) so an opt-out survives a draccus save/load round-trip.
|
||||||
|
GROOT_ACTION_DECODE_TRANSFORM_AUTO = "auto"
|
||||||
|
|
||||||
_GROOT_MODEL_VERSION_ALIASES = {
|
_GROOT_MODEL_VERSION_ALIASES = {
|
||||||
"n1.7": GROOT_N1_7,
|
"n1.7": GROOT_N1_7,
|
||||||
@@ -41,7 +60,12 @@ _GROOT_MODEL_VERSION_ALIASES = {
|
|||||||
"1.7": GROOT_N1_7,
|
"1.7": GROOT_N1_7,
|
||||||
}
|
}
|
||||||
|
|
||||||
|
# Legacy N1.5 spellings, kept ONLY so they can be detected and rejected with
|
||||||
|
# GROOT_N1_5_REMOVAL_GUIDANCE (see GROOT_N1_5 above). Never map these to a supported version.
|
||||||
|
_GROOT_N1_5_VERSION_ALIASES = {"n1.5", "n1_5", "n1d5", "n15", "1.5"}
|
||||||
|
|
||||||
_GROOT_ACTION_DECODE_TRANSFORM_ALIASES = {
|
_GROOT_ACTION_DECODE_TRANSFORM_ALIASES = {
|
||||||
|
GROOT_ACTION_DECODE_TRANSFORM_AUTO: GROOT_ACTION_DECODE_TRANSFORM_AUTO,
|
||||||
"none": None,
|
"none": None,
|
||||||
"": None,
|
"": None,
|
||||||
GROOT_ACTION_DECODE_TRANSFORM_LIBERO: GROOT_ACTION_DECODE_TRANSFORM_LIBERO,
|
GROOT_ACTION_DECODE_TRANSFORM_LIBERO: GROOT_ACTION_DECODE_TRANSFORM_LIBERO,
|
||||||
@@ -52,9 +76,10 @@ def normalize_groot_model_version(model_version: str) -> str:
|
|||||||
normalized = _GROOT_MODEL_VERSION_ALIASES.get(model_version.lower())
|
normalized = _GROOT_MODEL_VERSION_ALIASES.get(model_version.lower())
|
||||||
if normalized is None:
|
if normalized is None:
|
||||||
supported = GROOT_N1_7
|
supported = GROOT_N1_7
|
||||||
raise ValueError(
|
message = f"Unsupported GR00T model_version '{model_version}'. Supported versions: {supported}."
|
||||||
f"Unsupported GR00T model_version '{model_version}'. Supported versions: {supported}."
|
if model_version.lower() in _GROOT_N1_5_VERSION_ALIASES:
|
||||||
)
|
message = f"{message} {GROOT_N1_5_REMOVAL_GUIDANCE}"
|
||||||
|
raise ValueError(message)
|
||||||
return normalized
|
return normalized
|
||||||
|
|
||||||
|
|
||||||
@@ -286,6 +311,8 @@ def _infer_groot_model_version_from_local_config(model_path: str) -> str | None:
|
|||||||
def _infer_groot_model_version_from_config(config: dict) -> str | None:
|
def _infer_groot_model_version_from_config(config: dict) -> str | None:
|
||||||
model_version = config.get("model_version")
|
model_version = config.get("model_version")
|
||||||
if isinstance(model_version, str):
|
if isinstance(model_version, str):
|
||||||
|
if model_version.lower() in _GROOT_N1_5_VERSION_ALIASES:
|
||||||
|
return GROOT_N1_5
|
||||||
try:
|
try:
|
||||||
return normalize_groot_model_version(model_version)
|
return normalize_groot_model_version(model_version)
|
||||||
except ValueError:
|
except ValueError:
|
||||||
@@ -298,8 +325,14 @@ def _infer_groot_model_version_from_config(config: dict) -> str | None:
|
|||||||
normalized = candidate.lower().replace("-", "_")
|
normalized = candidate.lower().replace("-", "_")
|
||||||
if normalized in {"gr00tn1d7", "gr00t_n1d7", "gr00t_n1_7"}:
|
if normalized in {"gr00tn1d7", "gr00t_n1d7", "gr00t_n1_7"}:
|
||||||
return GROOT_N1_7
|
return GROOT_N1_7
|
||||||
|
if normalized in {"gr00t_n1_5", "gr00tn1_5", "gr00t_n15", "gr00t_n1d5", "gr00tn1d5"}:
|
||||||
|
return GROOT_N1_5
|
||||||
if config.get("model_name") == GROOT_N1_7_BACKBONE_MODEL:
|
if config.get("model_name") == GROOT_N1_7_BACKBONE_MODEL:
|
||||||
return GROOT_N1_7
|
return GROOT_N1_7
|
||||||
|
# The Eagle VLM backbone is specific to pre-N1.7 GR00T checkpoints (N1.7 uses Cosmos/Qwen3-VL).
|
||||||
|
backbone_cfg = config.get("backbone_cfg")
|
||||||
|
if isinstance(backbone_cfg, dict) and "eagle_path" in backbone_cfg:
|
||||||
|
return GROOT_N1_5
|
||||||
return None
|
return None
|
||||||
|
|
||||||
|
|
||||||
@@ -310,29 +343,30 @@ class GrootConfig(PreTrainedConfig):
|
|||||||
|
|
||||||
# Basic policy settings
|
# Basic policy settings
|
||||||
n_obs_steps: int = 1
|
n_obs_steps: int = 1
|
||||||
chunk_size: int = 50
|
chunk_size: int = 40
|
||||||
n_action_steps: int = 50
|
n_action_steps: int = 40
|
||||||
|
|
||||||
# Dimension settings (must match pretrained GR00T model expectations)
|
# Dimension settings (must match pretrained GR00T model expectations)
|
||||||
# Maximum state dimension. Shorter states will be zero-padded.
|
# Maximum state dimension. Shorter states will be zero-padded.
|
||||||
max_state_dim: int = 64
|
max_state_dim: int = 132
|
||||||
|
|
||||||
# Maximum action dimension. Shorter actions will be zero-padded.
|
# Maximum action dimension. Shorter actions will be zero-padded.
|
||||||
max_action_dim: int = 32
|
max_action_dim: int = 132
|
||||||
|
|
||||||
# Normalization (start with identity, adjust as needed)
|
# GR00T normalizes state/action internally in its processor steps (min/max with
|
||||||
|
# q01/q99 percentiles, per embodiment), and the Qwen3-VL backbone's image processor
|
||||||
|
# handles image normalization. The policy therefore does NOT use LeRobot's
|
||||||
|
# NormalizerProcessorStep/UnnormalizerProcessorStep, so this mapping is intentionally
|
||||||
|
# IDENTITY for every feature and is not consulted by make_groot_pre_post_processors.
|
||||||
normalization_mapping: dict[str, NormalizationMode] = field(
|
normalization_mapping: dict[str, NormalizationMode] = field(
|
||||||
default_factory=lambda: {
|
default_factory=lambda: {
|
||||||
"VISUAL": NormalizationMode.IDENTITY,
|
"VISUAL": NormalizationMode.IDENTITY,
|
||||||
"STATE": NormalizationMode.MEAN_STD,
|
"STATE": NormalizationMode.IDENTITY,
|
||||||
"ACTION": NormalizationMode.MEAN_STD,
|
"ACTION": NormalizationMode.IDENTITY,
|
||||||
}
|
}
|
||||||
)
|
)
|
||||||
|
|
||||||
# Image preprocessing (adjust to match Groot's expected input)
|
# Groot-specific model parameters
|
||||||
image_size: tuple[int, int] = (224, 224)
|
|
||||||
|
|
||||||
# Groot-specific model parameters (from groot_finetune_script.py)
|
|
||||||
|
|
||||||
# Explicit GR00T model family selection. LeRobot supports GR00T N1.7 only.
|
# Explicit GR00T model family selection. LeRobot supports GR00T N1.7 only.
|
||||||
model_version: str = GROOT_N1_7
|
model_version: str = GROOT_N1_7
|
||||||
@@ -344,7 +378,9 @@ class GrootConfig(PreTrainedConfig):
|
|||||||
n1_7_backbone_model: str = GROOT_N1_7_BACKBONE_MODEL
|
n1_7_backbone_model: str = GROOT_N1_7_BACKBONE_MODEL
|
||||||
|
|
||||||
# Optional named action transform applied after raw N1.7 checkpoint decoding and before env.step().
|
# Optional named action transform applied after raw N1.7 checkpoint decoding and before env.step().
|
||||||
action_decode_transform: str | None = None
|
# 'auto' (default) resolves to the embodiment default ('libero' for 'libero_sim', otherwise no
|
||||||
|
# transform). Pass 'none' to explicitly disable the transform, including for 'libero_sim'.
|
||||||
|
action_decode_transform: str | None = GROOT_ACTION_DECODE_TRANSFORM_AUTO
|
||||||
|
|
||||||
# Embodiment tag to use for training (e.g. 'new_embodiment', 'gr1')
|
# Embodiment tag to use for training (e.g. 'new_embodiment', 'gr1')
|
||||||
embodiment_tag: str = "new_embodiment"
|
embodiment_tag: str = "new_embodiment"
|
||||||
@@ -384,17 +420,16 @@ class GrootConfig(PreTrainedConfig):
|
|||||||
warmup_ratio: float = 0.05
|
warmup_ratio: float = 0.05
|
||||||
use_bf16: bool = True
|
use_bf16: bool = True
|
||||||
|
|
||||||
# Dataset parameters
|
# TODO(Steven): Remove these deprecated fields in a future release.
|
||||||
# Video backend to use for training ('decord' or 'torchvision_av')
|
# Deprecated Isaac-GR00T runner/N1.5 fields below — unused by the LeRobot N1.7 implementation
|
||||||
|
# (nothing in src/lerobot reads them). They are kept only so config.json files saved by
|
||||||
|
# earlier lerobot releases still parse: draccus rejects unknown fields, so removing them
|
||||||
|
# would break every previously saved groot checkpoint at config-load time.
|
||||||
|
image_size: tuple[int, int] = (256, 256) # image sizing is handled by the backbone's image processor.
|
||||||
|
tokenizer_assets_repo: str | None = None
|
||||||
video_backend: str = "decord"
|
video_backend: str = "decord"
|
||||||
|
|
||||||
# Whether to balance dataset weights in mixture datasets
|
|
||||||
balance_dataset_weights: bool = True
|
balance_dataset_weights: bool = True
|
||||||
|
|
||||||
# Whether to sample trajectories weighted by their length
|
|
||||||
balance_trajectory_weights: bool = True
|
balance_trajectory_weights: bool = True
|
||||||
|
|
||||||
# Optional dataset paths for delegating training to Isaac-GR00T runner
|
|
||||||
dataset_paths: list[str] | None = None
|
dataset_paths: list[str] | None = None
|
||||||
output_dir: str = "./tmp/gr00t"
|
output_dir: str = "./tmp/gr00t"
|
||||||
save_steps: int = 1000
|
save_steps: int = 1000
|
||||||
@@ -405,6 +440,12 @@ class GrootConfig(PreTrainedConfig):
|
|||||||
resume: bool = False
|
resume: bool = False
|
||||||
|
|
||||||
def __post_init__(self):
|
def __post_init__(self):
|
||||||
|
if self.tokenizer_assets_repo is not None:
|
||||||
|
raise ValueError(
|
||||||
|
"Config sets 'tokenizer_assets_repo', which only existed for GR00T N1.5; this looks "
|
||||||
|
f"like a legacy GR00T N1.5 checkpoint or config. {GROOT_N1_5_REMOVAL_GUIDANCE}"
|
||||||
|
)
|
||||||
|
|
||||||
self.model_version = normalize_groot_model_version(self.model_version)
|
self.model_version = normalize_groot_model_version(self.model_version)
|
||||||
self.action_decode_transform = normalize_groot_action_decode_transform(self.action_decode_transform)
|
self.action_decode_transform = normalize_groot_action_decode_transform(self.action_decode_transform)
|
||||||
if self.base_model_path is None:
|
if self.base_model_path is None:
|
||||||
@@ -416,26 +457,48 @@ class GrootConfig(PreTrainedConfig):
|
|||||||
# 'libero_sim' embodiment grasps correctly instead of scoring 0% success.
|
# 'libero_sim' embodiment grasps correctly instead of scoring 0% success.
|
||||||
# This matches the embodiment-specific handling already done for the
|
# This matches the embodiment-specific handling already done for the
|
||||||
# action execution horizon (see infer_groot_n1_7_action_execution_horizon).
|
# action execution horizon (see infer_groot_n1_7_action_execution_horizon).
|
||||||
if self.action_decode_transform is None and self.embodiment_tag == "libero_sim":
|
# Only the 'auto' sentinel resolves to the embodiment default; an explicit
|
||||||
self.action_decode_transform = GROOT_ACTION_DECODE_TRANSFORM_LIBERO
|
# 'none' (normalized to None above) keeps the transform disabled.
|
||||||
|
if self.action_decode_transform == GROOT_ACTION_DECODE_TRANSFORM_AUTO:
|
||||||
|
self.action_decode_transform = (
|
||||||
|
GROOT_ACTION_DECODE_TRANSFORM_LIBERO if self.embodiment_tag == "libero_sim" else None
|
||||||
|
)
|
||||||
|
|
||||||
if self.max_state_dim == 64:
|
# GR00T N1.5-era default values (e.g. --policy.chunk_size=50 from old commands or
|
||||||
self.max_state_dim = 132
|
# stale configs) are migrated to the values the N1.7 checkpoints expect, with a
|
||||||
if self.max_action_dim == 32:
|
# warning. The dataclass defaults are already the N1.7 values, so a plain
|
||||||
self.max_action_dim = 132
|
# GrootConfig() never triggers this.
|
||||||
if self.chunk_size == 50:
|
legacy_default_remaps = (
|
||||||
self.chunk_size = 40
|
("max_state_dim", 64, 132),
|
||||||
if self.n_action_steps == 50:
|
("max_action_dim", 32, 132),
|
||||||
self.n_action_steps = 40
|
("chunk_size", 50, 40),
|
||||||
if tuple(self.image_size) == (224, 224):
|
("n_action_steps", 50, 40),
|
||||||
self.image_size = (256, 256)
|
("image_size", (224, 224), (256, 256)),
|
||||||
|
)
|
||||||
|
for field_name, legacy_value, n1_7_value in legacy_default_remaps:
|
||||||
|
current_value = getattr(self, field_name)
|
||||||
|
if isinstance(legacy_value, tuple):
|
||||||
|
current_value = tuple(current_value)
|
||||||
|
if current_value == legacy_value:
|
||||||
|
logger.warning(
|
||||||
|
"GrootConfig.%s=%s matches a legacy GR00T N1.5-era default; remapping it to %s, "
|
||||||
|
"the value expected by GR00T N1.7 checkpoints. Set a different value explicitly "
|
||||||
|
"if this is not what you want.",
|
||||||
|
field_name,
|
||||||
|
legacy_value,
|
||||||
|
n1_7_value,
|
||||||
|
)
|
||||||
|
setattr(self, field_name, n1_7_value)
|
||||||
|
|
||||||
inferred_version = infer_groot_model_version(self.base_model_path)
|
inferred_version = infer_groot_model_version(self.base_model_path)
|
||||||
if inferred_version is not None and inferred_version != self.model_version:
|
if inferred_version is not None and inferred_version != self.model_version:
|
||||||
raise ValueError(
|
message = (
|
||||||
f"GR00T model_version '{self.model_version}' does not match base_model_path "
|
f"GR00T model_version '{self.model_version}' does not match base_model_path "
|
||||||
f"'{self.base_model_path}', which looks like '{inferred_version}'."
|
f"'{self.base_model_path}', which looks like '{inferred_version}'."
|
||||||
)
|
)
|
||||||
|
if inferred_version == GROOT_N1_5:
|
||||||
|
message = f"{message} {GROOT_N1_5_REMOVAL_GUIDANCE}"
|
||||||
|
raise ValueError(message)
|
||||||
|
|
||||||
super().__post_init__()
|
super().__post_init__()
|
||||||
|
|
||||||
@@ -512,7 +575,9 @@ class GrootConfig(PreTrainedConfig):
|
|||||||
@property
|
@property
|
||||||
def action_delta_indices(self) -> list[int]:
|
def action_delta_indices(self) -> list[int]:
|
||||||
"""Return indices for delta actions."""
|
"""Return indices for delta actions."""
|
||||||
model_action_horizon = infer_groot_n1_7_action_horizon(self.base_model_path, self.embodiment_tag) or 40
|
model_action_horizon = (
|
||||||
|
infer_groot_n1_7_action_horizon(self.base_model_path, self.embodiment_tag) or 40
|
||||||
|
)
|
||||||
return list(range(min(self.chunk_size, model_action_horizon)))
|
return list(range(min(self.chunk_size, model_action_horizon)))
|
||||||
|
|
||||||
@property
|
@property
|
||||||
|
|||||||
@@ -32,6 +32,7 @@ from torch.distributions import Beta
|
|||||||
from lerobot.utils.import_utils import _transformers_available, require_package
|
from lerobot.utils.import_utils import _transformers_available, require_package
|
||||||
|
|
||||||
from .action_head.cross_attention_dit import AlternateVLDiT, DiT, SelfAttentionTransformer
|
from .action_head.cross_attention_dit import AlternateVLDiT, DiT, SelfAttentionTransformer
|
||||||
|
from .configuration_groot import N1_7_DEFAULT_IMAGE_CROP_SIZE, N1_7_DEFAULT_IMAGE_TARGET_SIZE
|
||||||
|
|
||||||
if TYPE_CHECKING or _transformers_available:
|
if TYPE_CHECKING or _transformers_available:
|
||||||
from transformers import AutoConfig, AutoModel, PretrainedConfig, PreTrainedModel
|
from transformers import AutoConfig, AutoModel, PretrainedConfig, PreTrainedModel
|
||||||
@@ -71,13 +72,13 @@ GR00T_N1_7_DEFAULTS: dict[str, Any] = {
|
|||||||
"backbone_embedding_dim": 2048,
|
"backbone_embedding_dim": 2048,
|
||||||
"tune_llm": False,
|
"tune_llm": False,
|
||||||
"tune_visual": False,
|
"tune_visual": False,
|
||||||
"select_layer": 12,
|
"select_layer": 16,
|
||||||
"reproject_vision": False,
|
"reproject_vision": False,
|
||||||
"use_flash_attention": True,
|
"use_flash_attention": True,
|
||||||
"load_bf16": False,
|
"load_bf16": False,
|
||||||
"backbone_trainable_params_fp32": True,
|
"backbone_trainable_params_fp32": True,
|
||||||
"image_crop_size": (230, 230),
|
"image_crop_size": N1_7_DEFAULT_IMAGE_CROP_SIZE,
|
||||||
"image_target_size": (256, 256),
|
"image_target_size": N1_7_DEFAULT_IMAGE_TARGET_SIZE,
|
||||||
"shortest_image_edge": None,
|
"shortest_image_edge": None,
|
||||||
"crop_fraction": None,
|
"crop_fraction": None,
|
||||||
"random_rotation_angle": None,
|
"random_rotation_angle": None,
|
||||||
@@ -819,11 +820,14 @@ def _cosmos_reason2_qwen3_vl_config() -> PretrainedConfig:
|
|||||||
|
|
||||||
|
|
||||||
def get_backbone_cls(config: GR00TN17Config):
|
def get_backbone_cls(config: GR00TN17Config):
|
||||||
if (
|
if "nvidia/Cosmos-Reason2" in config.model_name or "Qwen/Qwen3-VL" in config.model_name:
|
||||||
config.backbone_model_type == "qwen"
|
return Qwen3Backbone
|
||||||
or "nvidia/Cosmos-Reason2" in config.model_name
|
if config.backbone_model_type == "qwen":
|
||||||
or "Qwen/Qwen3-VL" in config.model_name
|
logger.warning(
|
||||||
):
|
"Unrecognized GR00T N1.7 backbone model name '%s'; assuming a Qwen3-VL-compatible "
|
||||||
|
"backbone because backbone_model_type='qwen'.",
|
||||||
|
config.model_name,
|
||||||
|
)
|
||||||
return Qwen3Backbone
|
return Qwen3Backbone
|
||||||
raise ValueError(f"Unsupported GR00T N1.7 backbone model: {config.model_name}")
|
raise ValueError(f"Unsupported GR00T N1.7 backbone model: {config.model_name}")
|
||||||
|
|
||||||
@@ -909,7 +913,7 @@ class GR00TN17(PreTrainedModel):
|
|||||||
"trust_remote_code": True
|
"trust_remote_code": True
|
||||||
}
|
}
|
||||||
load_backbone_weights = kwargs.pop("load_backbone_weights", False)
|
load_backbone_weights = kwargs.pop("load_backbone_weights", False)
|
||||||
for key in ("revision", "cache_dir", "local_files_only", "token"):
|
for key in ("cache_dir", "local_files_only", "token"):
|
||||||
if key in kwargs:
|
if key in kwargs:
|
||||||
transformers_loading_kwargs.setdefault(key, kwargs[key])
|
transformers_loading_kwargs.setdefault(key, kwargs[key])
|
||||||
|
|
||||||
|
|||||||
@@ -18,15 +18,12 @@
|
|||||||
Groot Policy Wrapper for LeRobot Integration
|
Groot Policy Wrapper for LeRobot Integration
|
||||||
|
|
||||||
Minimal integration that delegates to Isaac-GR00T N1.7 components where
|
Minimal integration that delegates to Isaac-GR00T N1.7 components where
|
||||||
possible without porting their code.
|
possible without porting their code. Dataset loading and training
|
||||||
|
orchestration are handled by LeRobot's standard training stack.
|
||||||
Notes:
|
|
||||||
- Dataset loading and full training orchestration is handled by Isaac-GR00T
|
|
||||||
TrainRunner in their codebase. If you want to invoke that flow end-to-end
|
|
||||||
from LeRobot, see `GrootPolicy.finetune_with_groot_runner` below.
|
|
||||||
"""
|
"""
|
||||||
|
|
||||||
import builtins
|
import builtins
|
||||||
|
import logging
|
||||||
import os
|
import os
|
||||||
from collections import deque
|
from collections import deque
|
||||||
from pathlib import Path
|
from pathlib import Path
|
||||||
@@ -42,6 +39,8 @@ from lerobot.utils.import_utils import require_package
|
|||||||
from ..pretrained import PreTrainedPolicy
|
from ..pretrained import PreTrainedPolicy
|
||||||
from ..utils import get_device_from_parameters
|
from ..utils import get_device_from_parameters
|
||||||
from .configuration_groot import (
|
from .configuration_groot import (
|
||||||
|
GROOT_N1_5,
|
||||||
|
GROOT_N1_5_REMOVAL_GUIDANCE,
|
||||||
GROOT_N1_7,
|
GROOT_N1_7,
|
||||||
GrootConfig,
|
GrootConfig,
|
||||||
infer_groot_model_version,
|
infer_groot_model_version,
|
||||||
@@ -50,6 +49,8 @@ from .configuration_groot import (
|
|||||||
normalize_groot_model_version,
|
normalize_groot_model_version,
|
||||||
)
|
)
|
||||||
|
|
||||||
|
logger = logging.getLogger(__name__)
|
||||||
|
|
||||||
T = TypeVar("T", bound="GrootPolicy")
|
T = TypeVar("T", bound="GrootPolicy")
|
||||||
|
|
||||||
|
|
||||||
@@ -92,9 +93,6 @@ class GrootPolicy(PreTrainedPolicy):
|
|||||||
transformers_loading_kwargs={"trust_remote_code": True},
|
transformers_loading_kwargs={"trust_remote_code": True},
|
||||||
)
|
)
|
||||||
|
|
||||||
model.compute_dtype = "bfloat16" if self.config.use_bf16 else model.compute_dtype
|
|
||||||
model.config.compute_dtype = model.compute_dtype
|
|
||||||
|
|
||||||
return model
|
return model
|
||||||
|
|
||||||
def reset(self):
|
def reset(self):
|
||||||
@@ -148,9 +146,10 @@ class GrootPolicy(PreTrainedPolicy):
|
|||||||
if config is not None
|
if config is not None
|
||||||
else infer_groot_model_version(str(pretrained_name_or_path)) or GROOT_N1_7
|
else infer_groot_model_version(str(pretrained_name_or_path)) or GROOT_N1_7
|
||||||
)
|
)
|
||||||
print(
|
logger.info(
|
||||||
f"The Groot policy is a wrapper around Nvidia's GR00T {requested_version} model.\n"
|
"The Groot policy wraps NVIDIA's GR00T %s model. Loading pretrained model from: %s",
|
||||||
f"Loading pretrained model from: {pretrained_name_or_path}"
|
requested_version,
|
||||||
|
pretrained_name_or_path,
|
||||||
)
|
)
|
||||||
|
|
||||||
model_id = str(pretrained_name_or_path)
|
model_id = str(pretrained_name_or_path)
|
||||||
@@ -181,7 +180,7 @@ class GrootPolicy(PreTrainedPolicy):
|
|||||||
|
|
||||||
if is_finetuned_checkpoint:
|
if is_finetuned_checkpoint:
|
||||||
# This is a fine-tuned LeRobot checkpoint - use parent class loading
|
# This is a fine-tuned LeRobot checkpoint - use parent class loading
|
||||||
print("Detected fine-tuned LeRobot checkpoint, loading with state dict...")
|
logger.info("Detected fine-tuned LeRobot checkpoint, loading with state dict...")
|
||||||
return super().from_pretrained(
|
return super().from_pretrained(
|
||||||
pretrained_name_or_path=pretrained_name_or_path,
|
pretrained_name_or_path=pretrained_name_or_path,
|
||||||
config=config,
|
config=config,
|
||||||
@@ -197,7 +196,7 @@ class GrootPolicy(PreTrainedPolicy):
|
|||||||
)
|
)
|
||||||
|
|
||||||
# This is a base GR00T model - load it fresh
|
# This is a base GR00T model - load it fresh
|
||||||
print("Detected base GR00T model, loading from HuggingFace...")
|
logger.info("Detected base GR00T model, loading from HuggingFace...")
|
||||||
|
|
||||||
if config is None:
|
if config is None:
|
||||||
model_version = infer_groot_model_version(str(pretrained_name_or_path)) or GROOT_N1_7
|
model_version = infer_groot_model_version(str(pretrained_name_or_path)) or GROOT_N1_7
|
||||||
@@ -229,10 +228,13 @@ class GrootPolicy(PreTrainedPolicy):
|
|||||||
config.model_version = normalize_groot_model_version(config.model_version)
|
config.model_version = normalize_groot_model_version(config.model_version)
|
||||||
inferred_version = infer_groot_model_version(config.base_model_path)
|
inferred_version = infer_groot_model_version(config.base_model_path)
|
||||||
if inferred_version is not None and inferred_version != config.model_version:
|
if inferred_version is not None and inferred_version != config.model_version:
|
||||||
raise ValueError(
|
message = (
|
||||||
f"GR00T model_version '{config.model_version}' does not match base_model_path "
|
f"GR00T model_version '{config.model_version}' does not match base_model_path "
|
||||||
f"'{config.base_model_path}', which looks like '{inferred_version}'."
|
f"'{config.base_model_path}', which looks like '{inferred_version}'."
|
||||||
)
|
)
|
||||||
|
if inferred_version == GROOT_N1_5:
|
||||||
|
message = f"{message} {GROOT_N1_5_REMOVAL_GUIDANCE}"
|
||||||
|
raise ValueError(message)
|
||||||
# Create a fresh policy instance - this will automatically load the GR00T model
|
# Create a fresh policy instance - this will automatically load the GR00T model
|
||||||
# in __init__ via _create_groot_model()
|
# in __init__ via _create_groot_model()
|
||||||
policy = cls(config)
|
policy = cls(config)
|
||||||
@@ -297,9 +299,7 @@ class GrootPolicy(PreTrainedPolicy):
|
|||||||
allowed_base.add("action_mask")
|
allowed_base.add("action_mask")
|
||||||
|
|
||||||
return {
|
return {
|
||||||
k: v
|
k: v for k, v in batch.items() if k in allowed_base and not (k.startswith("next.") or k == "info")
|
||||||
for k, v in batch.items()
|
|
||||||
if k in allowed_base and not (k.startswith("next.") or k == "info")
|
|
||||||
}
|
}
|
||||||
|
|
||||||
def _prepare_n1_7_rtc_inputs(
|
def _prepare_n1_7_rtc_inputs(
|
||||||
@@ -320,9 +320,7 @@ class GrootPolicy(PreTrainedPolicy):
|
|||||||
if prev_actions.ndim == 2:
|
if prev_actions.ndim == 2:
|
||||||
prev_actions = prev_actions.unsqueeze(0)
|
prev_actions = prev_actions.unsqueeze(0)
|
||||||
elif prev_actions.ndim != 3:
|
elif prev_actions.ndim != 3:
|
||||||
raise ValueError(
|
raise ValueError("prev_chunk_left_over must have shape (T, A) or (B, T, A) for GR00T N1.7 RTC.")
|
||||||
"prev_chunk_left_over must have shape (T, A) or (B, T, A) for GR00T N1.7 RTC."
|
|
||||||
)
|
|
||||||
|
|
||||||
state = inputs.get("state")
|
state = inputs.get("state")
|
||||||
if state is None:
|
if state is None:
|
||||||
@@ -331,9 +329,7 @@ class GrootPolicy(PreTrainedPolicy):
|
|||||||
if prev_actions.shape[0] == 1 and batch_size > 1:
|
if prev_actions.shape[0] == 1 and batch_size > 1:
|
||||||
prev_actions = prev_actions.expand(batch_size, -1, -1).clone()
|
prev_actions = prev_actions.expand(batch_size, -1, -1).clone()
|
||||||
elif prev_actions.shape[0] != batch_size:
|
elif prev_actions.shape[0] != batch_size:
|
||||||
raise ValueError(
|
raise ValueError("prev_chunk_left_over batch size must match the current GR00T N1.7 batch size.")
|
||||||
"prev_chunk_left_over batch size must match the current GR00T N1.7 batch size."
|
|
||||||
)
|
|
||||||
|
|
||||||
# The generic LeRobot RTC engine pads short leftovers with exact zero
|
# The generic LeRobot RTC engine pads short leftovers with exact zero
|
||||||
# rows for fixed-shape policy calls. Native GR00T N1.7 RTC treats every
|
# rows for fixed-shape policy calls. Native GR00T N1.7 RTC treats every
|
||||||
@@ -346,7 +342,9 @@ class GrootPolicy(PreTrainedPolicy):
|
|||||||
else:
|
else:
|
||||||
return inputs, None
|
return inputs, None
|
||||||
|
|
||||||
model_action_horizon = int(getattr(self._groot_model.config, "action_horizon", self.config.chunk_size))
|
model_action_horizon = int(
|
||||||
|
getattr(self._groot_model.config, "action_horizon", self.config.chunk_size)
|
||||||
|
)
|
||||||
max_action_dim = int(getattr(self._groot_model.config, "max_action_dim", self.config.max_action_dim))
|
max_action_dim = int(getattr(self._groot_model.config, "max_action_dim", self.config.max_action_dim))
|
||||||
if prev_actions.shape[1] > model_action_horizon:
|
if prev_actions.shape[1] > model_action_horizon:
|
||||||
prev_actions = prev_actions[:, -model_action_horizon:, :]
|
prev_actions = prev_actions[:, -model_action_horizon:, :]
|
||||||
@@ -409,6 +407,11 @@ class GrootPolicy(PreTrainedPolicy):
|
|||||||
|
|
||||||
# Isaac-GR00T returns a BatchFeature; loss key is typically 'loss'
|
# Isaac-GR00T returns a BatchFeature; loss key is typically 'loss'
|
||||||
loss = outputs.get("loss")
|
loss = outputs.get("loss")
|
||||||
|
if loss is None:
|
||||||
|
raise RuntimeError(
|
||||||
|
"GR00T model.forward did not return a 'loss'. Training batches must include "
|
||||||
|
"'action' and 'action_mask'; check the preprocessor output."
|
||||||
|
)
|
||||||
|
|
||||||
loss_dict = {"loss": loss.item()}
|
loss_dict = {"loss": loss.item()}
|
||||||
|
|
||||||
@@ -471,33 +474,21 @@ class GrootPolicy(PreTrainedPolicy):
|
|||||||
# Internal helpers
|
# Internal helpers
|
||||||
# -------------------------
|
# -------------------------
|
||||||
def _handle_flash_attention_compatibility(self) -> None:
|
def _handle_flash_attention_compatibility(self) -> None:
|
||||||
"""Handle Flash Attention compatibility issues by setting environment variables.
|
"""Log Flash Attention availability (diagnostic only).
|
||||||
|
|
||||||
This addresses the common 'undefined symbol' error that occurs when Flash Attention
|
The GR00T N1.7 backbone automatically falls back to SDPA when ``flash_attn`` is
|
||||||
is compiled against a different PyTorch version than what's currently installed.
|
unavailable (see ``Qwen3Backbone``), so this probe only emits a hint; it does not
|
||||||
|
change behaviour or mutate global state.
|
||||||
"""
|
"""
|
||||||
|
|
||||||
# Set environment variables to handle Flash Attention compatibility
|
|
||||||
# These help with symbol resolution issues
|
|
||||||
os.environ.setdefault("FLASH_ATTENTION_FORCE_BUILD", "0")
|
|
||||||
os.environ.setdefault("FLASH_ATTENTION_SKIP_CUDA_BUILD", "0")
|
|
||||||
|
|
||||||
# Try to import flash_attn and handle failures gracefully
|
|
||||||
try:
|
try:
|
||||||
import flash_attn
|
import flash_attn
|
||||||
|
|
||||||
print(f"[GROOT] Flash Attention version: {flash_attn.__version__}")
|
logger.debug("Flash Attention %s is available.", flash_attn.__version__)
|
||||||
except ImportError as e:
|
except ImportError:
|
||||||
print(f"[GROOT] Flash Attention not available: {e}")
|
logger.debug("Flash Attention is not installed; the GR00T backbone will use SDPA.")
|
||||||
print("[GROOT] Will use fallback attention mechanism")
|
except Exception as e: # noqa: BLE001
|
||||||
except Exception as e:
|
logger.warning(
|
||||||
if "undefined symbol" in str(e):
|
"Flash Attention failed to import (%s); the GR00T backbone will use SDPA. If this is "
|
||||||
print(f"[GROOT] Flash Attention compatibility issue detected: {e}")
|
"an 'undefined symbol' error, reinstall a flash-attn build matching your torch version.",
|
||||||
print("[GROOT] This is likely due to PyTorch/Flash Attention version mismatch")
|
e,
|
||||||
print("[GROOT] Consider reinstalling Flash Attention with compatible version:")
|
)
|
||||||
print(" pip uninstall flash-attn")
|
|
||||||
print(" pip install --no-build-isolation flash-attn==2.6.3")
|
|
||||||
print("[GROOT] Continuing with fallback attention mechanism")
|
|
||||||
else:
|
|
||||||
print(f"[GROOT] Flash Attention error: {e}")
|
|
||||||
print("[GROOT] Continuing with fallback attention mechanism")
|
|
||||||
|
|||||||
@@ -15,15 +15,18 @@
|
|||||||
# limitations under the License.
|
# limitations under the License.
|
||||||
|
|
||||||
import json
|
import json
|
||||||
|
import logging
|
||||||
from copy import copy
|
from copy import copy
|
||||||
from dataclasses import dataclass, field
|
from dataclasses import dataclass, field, fields, is_dataclass
|
||||||
from pathlib import Path
|
from pathlib import Path
|
||||||
from typing import TYPE_CHECKING, Any
|
from typing import TYPE_CHECKING, Any
|
||||||
|
|
||||||
import numpy as np
|
import numpy as np
|
||||||
import torch
|
import torch
|
||||||
|
import torchvision.transforms.v2.functional as tv_functional
|
||||||
from einops import rearrange
|
from einops import rearrange
|
||||||
from PIL import Image
|
from PIL import Image
|
||||||
|
from torchvision.transforms import InterpolationMode
|
||||||
|
|
||||||
from lerobot.utils.import_utils import _transformers_available
|
from lerobot.utils.import_utils import _transformers_available
|
||||||
|
|
||||||
@@ -56,10 +59,14 @@ from lerobot.utils.constants import (
|
|||||||
POLICY_POSTPROCESSOR_DEFAULT_NAME,
|
POLICY_POSTPROCESSOR_DEFAULT_NAME,
|
||||||
POLICY_PREPROCESSOR_DEFAULT_NAME,
|
POLICY_PREPROCESSOR_DEFAULT_NAME,
|
||||||
)
|
)
|
||||||
|
from lerobot.utils.device_utils import get_safe_torch_device
|
||||||
|
|
||||||
from .configuration_groot import (
|
from .configuration_groot import (
|
||||||
GROOT_ACTION_DECODE_TRANSFORM_LIBERO,
|
GROOT_ACTION_DECODE_TRANSFORM_LIBERO,
|
||||||
|
GROOT_N1_5_REMOVAL_GUIDANCE,
|
||||||
GROOT_N1_7_BACKBONE_MODEL,
|
GROOT_N1_7_BACKBONE_MODEL,
|
||||||
|
N1_7_DEFAULT_IMAGE_CROP_SIZE,
|
||||||
|
N1_7_DEFAULT_IMAGE_TARGET_SIZE,
|
||||||
GrootConfig,
|
GrootConfig,
|
||||||
is_raw_groot_n1_7_checkpoint,
|
is_raw_groot_n1_7_checkpoint,
|
||||||
)
|
)
|
||||||
@@ -80,12 +87,6 @@ N1_7_EMBODIMENT_MAPPING = {
|
|||||||
"new_embodiment": 10,
|
"new_embodiment": 10,
|
||||||
}
|
}
|
||||||
|
|
||||||
_N1_7_RAW_STATE_CACHE: dict[str, dict[str, np.ndarray]] = {}
|
|
||||||
|
|
||||||
|
|
||||||
def _n1_7_state_cache_key(value: str | None) -> str:
|
|
||||||
return value or "groot_n1_7_default"
|
|
||||||
|
|
||||||
|
|
||||||
@dataclass
|
@dataclass
|
||||||
class _GrootN17CheckpointProcessorAssets:
|
class _GrootN17CheckpointProcessorAssets:
|
||||||
@@ -451,43 +452,97 @@ def _has_modality_stats(stats: dict[str, dict[str, Any]] | None) -> bool:
|
|||||||
return any(bool(modality_stats) for modality_stats in stats.values())
|
return any(bool(modality_stats) for modality_stats in stats.values())
|
||||||
|
|
||||||
|
|
||||||
def _legacy_groot_processor_overrides(
|
# GR00T normalizes state/action inside its own processor steps and so deliberately has no
|
||||||
config: GrootConfig,
|
# NormalizerProcessorStep/UnnormalizerProcessorStep (see GrootConfig.normalization_mapping, which is
|
||||||
dataset_stats: dict[str, dict[str, torch.Tensor]] | None,
|
# IDENTITY for every feature). lerobot-train nonetheless emits these standard override keys
|
||||||
preprocessor_overrides: dict[str, Any] | None = None,
|
# unconditionally, so for a GR00T pipeline they legitimately match no step. They are dropped up front
|
||||||
postprocessor_overrides: dict[str, Any] | None = None,
|
# by _drop_groot_absent_standard_overrides so they neither break loading nor mask genuine typos.
|
||||||
) -> tuple[dict[str, Any], dict[str, Any]]:
|
_GROOT_ABSENT_STANDARD_OVERRIDE_KEYS = frozenset({"normalizer_processor", "unnormalizer_processor"})
|
||||||
"""Patch older serialized Groot processors with fields current processors expect."""
|
|
||||||
|
|
||||||
preprocessor_overrides = dict(preprocessor_overrides or {})
|
|
||||||
postprocessor_overrides = dict(postprocessor_overrides or {})
|
|
||||||
pack_inputs_key = "groot_n1_7_pack_inputs_v1"
|
|
||||||
|
|
||||||
pack_input_overrides = dict(preprocessor_overrides.get(pack_inputs_key, {}))
|
|
||||||
pack_input_overrides["normalize_min_max"] = True
|
|
||||||
preprocessor_overrides[pack_inputs_key] = pack_input_overrides
|
|
||||||
|
|
||||||
try:
|
|
||||||
env_action_dim = int(config.output_features[ACTION].shape[0])
|
|
||||||
except Exception:
|
|
||||||
env_action_dim = 0
|
|
||||||
action_unpack_overrides = dict(postprocessor_overrides.get("groot_action_unpack_unnormalize_v1", {}))
|
|
||||||
action_unpack_overrides["normalize_min_max"] = True
|
|
||||||
action_unpack_overrides["env_action_dim"] = env_action_dim
|
|
||||||
postprocessor_overrides["groot_action_unpack_unnormalize_v1"] = action_unpack_overrides
|
|
||||||
|
|
||||||
return preprocessor_overrides, postprocessor_overrides
|
|
||||||
|
|
||||||
|
|
||||||
def _local_processor_config_has_step(pretrained_path: str, config_filename: str, step_name: str) -> bool:
|
def _drop_groot_absent_standard_overrides(overrides: dict[str, Any] | None) -> dict[str, Any] | None:
|
||||||
path = Path(pretrained_path).expanduser()
|
"""Strip standard normalization override keys that a GR00T pipeline has no step for.
|
||||||
if not path.is_dir():
|
|
||||||
return False
|
``lerobot-train`` emits ``normalizer_processor``/``unnormalizer_processor`` overrides
|
||||||
config = _read_json(path / config_filename)
|
unconditionally, but GR00T normalizes inside its own steps and has no such step (see
|
||||||
steps = config.get("steps", [])
|
``GrootConfig.normalization_mapping``). Both override-application paths reject keys that match no
|
||||||
if not isinstance(steps, list):
|
step — ``_apply_groot_step_overrides`` raises for the freshly built raw-checkpoint pipeline, and
|
||||||
return False
|
``PolicyProcessorPipeline.from_pretrained`` raises via its used-override validation for the
|
||||||
return any(isinstance(step, dict) and step.get("registry_name") == step_name for step in steps)
|
serialized pipeline — so these keys are removed before either path runs. Any other unknown key
|
||||||
|
(e.g. a typo) is left in place and still raises.
|
||||||
|
"""
|
||||||
|
|
||||||
|
if not overrides:
|
||||||
|
return overrides
|
||||||
|
|
||||||
|
filtered: dict[str, Any] = {}
|
||||||
|
for key, value in overrides.items():
|
||||||
|
if key in _GROOT_ABSENT_STANDARD_OVERRIDE_KEYS:
|
||||||
|
logging.debug(
|
||||||
|
"Ignoring override key '%s': GR00T normalizes inside its own processor steps and has "
|
||||||
|
"no matching step (see GrootConfig.normalization_mapping).",
|
||||||
|
key,
|
||||||
|
)
|
||||||
|
continue
|
||||||
|
filtered[key] = value
|
||||||
|
return filtered
|
||||||
|
|
||||||
|
|
||||||
|
def _apply_groot_step_overrides(
|
||||||
|
pipeline: PolicyProcessorPipeline,
|
||||||
|
overrides: dict[str, Any] | None,
|
||||||
|
) -> None:
|
||||||
|
"""Apply ``from_pretrained``-style step overrides to a freshly built pipeline.
|
||||||
|
|
||||||
|
Raw N1.7 checkpoints build their processors from scratch instead of
|
||||||
|
deserializing them, so caller overrides must be applied to the constructed
|
||||||
|
steps. Override keys match a step's registry name or, as a convenience, its
|
||||||
|
class name (``PolicyProcessorPipeline.from_pretrained`` matches registered
|
||||||
|
steps by registry name only — prefer registry names so overrides keep
|
||||||
|
working after the checkpoint is converted and reloaded from a serialized
|
||||||
|
pipeline). Keys or fields that match nothing raise instead of being dropped
|
||||||
|
silently (standard normalization keys GR00T has no step for are removed
|
||||||
|
beforehand by ``_drop_groot_absent_standard_overrides``).
|
||||||
|
"""
|
||||||
|
|
||||||
|
if not overrides:
|
||||||
|
return
|
||||||
|
|
||||||
|
def _step_keys(step: ProcessorStep) -> set[str]:
|
||||||
|
keys = {type(step).__name__}
|
||||||
|
registry_name = getattr(type(step), "_registry_name", None)
|
||||||
|
if registry_name:
|
||||||
|
keys.add(registry_name)
|
||||||
|
return keys
|
||||||
|
|
||||||
|
for override_key, step_overrides in overrides.items():
|
||||||
|
matched_steps = [step for step in pipeline.steps if override_key in _step_keys(step)]
|
||||||
|
if not matched_steps:
|
||||||
|
available = [
|
||||||
|
getattr(type(step), "_registry_name", None) or type(step).__name__ for step in pipeline.steps
|
||||||
|
]
|
||||||
|
raise KeyError(
|
||||||
|
f"Override key '{override_key}' does not match any step of the GR00T processor pipeline "
|
||||||
|
f"built for this raw N1.7 checkpoint. Available step keys: {available}."
|
||||||
|
)
|
||||||
|
for step in matched_steps:
|
||||||
|
if not is_dataclass(step):
|
||||||
|
raise TypeError(
|
||||||
|
f"Cannot apply overrides to step '{override_key}': it is not a dataclass step."
|
||||||
|
)
|
||||||
|
init_field_names = {f.name for f in fields(step) if f.init}
|
||||||
|
for field_name, value in dict(step_overrides).items():
|
||||||
|
if field_name not in init_field_names:
|
||||||
|
raise TypeError(
|
||||||
|
f"Override field '{field_name}' is not a config field of step '{override_key}'. "
|
||||||
|
f"Available fields: {sorted(init_field_names)}."
|
||||||
|
)
|
||||||
|
setattr(step, field_name, value)
|
||||||
|
# Re-derive attributes computed from the overridden config (e.g.
|
||||||
|
# DeviceProcessorStep resolves its torch.device in __post_init__).
|
||||||
|
post_init = getattr(step, "__post_init__", None)
|
||||||
|
if callable(post_init):
|
||||||
|
post_init()
|
||||||
|
|
||||||
|
|
||||||
def make_groot_pre_post_processors_from_pretrained(
|
def make_groot_pre_post_processors_from_pretrained(
|
||||||
@@ -503,33 +558,51 @@ def make_groot_pre_post_processors_from_pretrained(
|
|||||||
PolicyProcessorPipeline[dict[str, Any], dict[str, Any]],
|
PolicyProcessorPipeline[dict[str, Any], dict[str, Any]],
|
||||||
PolicyProcessorPipeline[PolicyAction, PolicyAction],
|
PolicyProcessorPipeline[PolicyAction, PolicyAction],
|
||||||
]:
|
]:
|
||||||
"""Load Groot processors while preserving compatibility with older serialized configs."""
|
"""Load Groot processors for a raw N1.7 checkpoint or a serialized LeRobot pipeline."""
|
||||||
|
|
||||||
|
# Drop the standard normalizer/unnormalizer override keys lerobot-train emits unconditionally:
|
||||||
|
# GR00T has no such steps, so they would make both the raw-checkpoint and serialized override
|
||||||
|
# paths raise. This must happen before either branch below.
|
||||||
|
preprocessor_overrides = _drop_groot_absent_standard_overrides(preprocessor_overrides)
|
||||||
|
postprocessor_overrides = _drop_groot_absent_standard_overrides(postprocessor_overrides)
|
||||||
|
|
||||||
if is_raw_groot_n1_7_checkpoint(pretrained_path):
|
if is_raw_groot_n1_7_checkpoint(pretrained_path):
|
||||||
processor_cfg = copy(config)
|
processor_cfg = copy(config)
|
||||||
processor_cfg.base_model_path = str(pretrained_path)
|
processor_cfg.base_model_path = str(pretrained_path)
|
||||||
return make_groot_pre_post_processors(
|
preprocessor, postprocessor = make_groot_pre_post_processors(
|
||||||
config=processor_cfg,
|
config=processor_cfg,
|
||||||
dataset_stats=dataset_stats,
|
dataset_stats=dataset_stats,
|
||||||
)
|
)
|
||||||
|
# Raw checkpoints have no serialized pipelines to load overrides into,
|
||||||
|
# so apply the caller overrides (e.g. device and rename_map from
|
||||||
|
# lerobot-eval or the policy server) to the freshly built steps.
|
||||||
|
_apply_groot_step_overrides(preprocessor, preprocessor_overrides)
|
||||||
|
_apply_groot_step_overrides(postprocessor, postprocessor_overrides)
|
||||||
|
return preprocessor, postprocessor
|
||||||
|
|
||||||
if _local_processor_config_has_step(
|
preprocessor, postprocessor = _load_groot_processor_pipelines(
|
||||||
pretrained_path,
|
pretrained_path,
|
||||||
postprocessor_config_filename,
|
|
||||||
"groot_n1_7_action_decode_v1",
|
|
||||||
):
|
|
||||||
# Converted raw N1.7 checkpoints already carry the checkpoint-specific
|
|
||||||
# action decoder. Adding the legacy action-unpack override would target
|
|
||||||
# a step that is not present and break loading.
|
|
||||||
preprocessor_overrides = dict(preprocessor_overrides or {})
|
|
||||||
postprocessor_overrides = dict(postprocessor_overrides or {})
|
|
||||||
else:
|
|
||||||
preprocessor_overrides, postprocessor_overrides = _legacy_groot_processor_overrides(
|
|
||||||
config=config,
|
|
||||||
dataset_stats=dataset_stats,
|
|
||||||
preprocessor_overrides=preprocessor_overrides,
|
preprocessor_overrides=preprocessor_overrides,
|
||||||
postprocessor_overrides=postprocessor_overrides,
|
postprocessor_overrides=postprocessor_overrides,
|
||||||
|
preprocessor_config_filename=preprocessor_config_filename,
|
||||||
|
postprocessor_config_filename=postprocessor_config_filename,
|
||||||
)
|
)
|
||||||
|
_reconnect_groot_relative_absolute_steps(preprocessor, postprocessor)
|
||||||
|
_reconnect_groot_n1_7_pack_decode_steps(preprocessor, postprocessor)
|
||||||
|
return preprocessor, postprocessor
|
||||||
|
|
||||||
|
|
||||||
|
def _load_groot_processor_pipelines(
|
||||||
|
pretrained_path: str,
|
||||||
|
*,
|
||||||
|
preprocessor_overrides: dict[str, Any],
|
||||||
|
postprocessor_overrides: dict[str, Any],
|
||||||
|
preprocessor_config_filename: str,
|
||||||
|
postprocessor_config_filename: str,
|
||||||
|
) -> tuple[
|
||||||
|
PolicyProcessorPipeline[dict[str, Any], dict[str, Any]],
|
||||||
|
PolicyProcessorPipeline[PolicyAction, PolicyAction],
|
||||||
|
]:
|
||||||
preprocessor = PolicyProcessorPipeline.from_pretrained(
|
preprocessor = PolicyProcessorPipeline.from_pretrained(
|
||||||
pretrained_model_name_or_path=pretrained_path,
|
pretrained_model_name_or_path=pretrained_path,
|
||||||
config_filename=preprocessor_config_filename,
|
config_filename=preprocessor_config_filename,
|
||||||
@@ -544,7 +617,6 @@ def make_groot_pre_post_processors_from_pretrained(
|
|||||||
to_transition=policy_action_to_transition,
|
to_transition=policy_action_to_transition,
|
||||||
to_output=transition_to_policy_action,
|
to_output=transition_to_policy_action,
|
||||||
)
|
)
|
||||||
_reconnect_groot_relative_absolute_steps(preprocessor, postprocessor)
|
|
||||||
return preprocessor, postprocessor
|
return preprocessor, postprocessor
|
||||||
|
|
||||||
|
|
||||||
@@ -564,6 +636,28 @@ def _reconnect_groot_relative_absolute_steps(
|
|||||||
step.relative_step = relative_step
|
step.relative_step = relative_step
|
||||||
|
|
||||||
|
|
||||||
|
def _reconnect_groot_n1_7_pack_decode_steps(
|
||||||
|
preprocessor: PolicyProcessorPipeline,
|
||||||
|
postprocessor: PolicyProcessorPipeline,
|
||||||
|
) -> None:
|
||||||
|
"""Re-link a deserialized N1.7 action decode step to its pack step.
|
||||||
|
|
||||||
|
The pack step holds the per-instance raw-state cache that relative-action
|
||||||
|
decoding reads its reference state from; the link itself is not serialized.
|
||||||
|
"""
|
||||||
|
|
||||||
|
pack_step = next(
|
||||||
|
(step for step in preprocessor.steps if isinstance(step, GrootN17PackInputsStep)),
|
||||||
|
None,
|
||||||
|
)
|
||||||
|
if pack_step is None:
|
||||||
|
return
|
||||||
|
|
||||||
|
for step in postprocessor.steps:
|
||||||
|
if isinstance(step, GrootN17ActionDecodeStep) and step.pack_step is None:
|
||||||
|
step.pack_step = pack_step
|
||||||
|
|
||||||
|
|
||||||
def make_groot_pre_post_processors(
|
def make_groot_pre_post_processors(
|
||||||
config: GrootConfig, dataset_stats: dict[str, dict[str, torch.Tensor]] | None = None
|
config: GrootConfig, dataset_stats: dict[str, dict[str, torch.Tensor]] | None = None
|
||||||
) -> tuple[
|
) -> tuple[
|
||||||
@@ -607,9 +701,12 @@ def make_groot_pre_post_processors(
|
|||||||
else action_horizon
|
else action_horizon
|
||||||
)
|
)
|
||||||
checkpoint_stats = checkpoint_assets.stats if checkpoint_assets is not None else None
|
checkpoint_stats = checkpoint_assets.stats if checkpoint_assets is not None else None
|
||||||
padded_stats = checkpoint_stats if _has_modality_stats(checkpoint_stats) else (dataset_stats or {})
|
checkpoint_has_stats = _has_modality_stats(checkpoint_stats)
|
||||||
|
padded_stats = checkpoint_stats if checkpoint_has_stats else (dataset_stats or {})
|
||||||
embodiment_mapping = (
|
embodiment_mapping = (
|
||||||
checkpoint_assets.embodiment_mapping if checkpoint_assets is not None else dict(N1_7_EMBODIMENT_MAPPING)
|
checkpoint_assets.embodiment_mapping
|
||||||
|
if checkpoint_assets is not None
|
||||||
|
else dict(N1_7_EMBODIMENT_MAPPING)
|
||||||
)
|
)
|
||||||
formalize_language = checkpoint_assets.formalize_language if checkpoint_assets is not None else True
|
formalize_language = checkpoint_assets.formalize_language if checkpoint_assets is not None else True
|
||||||
clip_outliers = checkpoint_assets.clip_outliers if checkpoint_assets is not None else True
|
clip_outliers = checkpoint_assets.clip_outliers if checkpoint_assets is not None else True
|
||||||
@@ -618,7 +715,6 @@ def make_groot_pre_post_processors(
|
|||||||
env_action_dim = int(config.output_features[ACTION].shape[0])
|
env_action_dim = int(config.output_features[ACTION].shape[0])
|
||||||
except Exception:
|
except Exception:
|
||||||
env_action_dim = 0
|
env_action_dim = 0
|
||||||
state_cache_key = f"groot_n1_7:{config.embodiment_tag}"
|
|
||||||
pack_step = GrootN17PackInputsStep(
|
pack_step = GrootN17PackInputsStep(
|
||||||
state_horizon=1,
|
state_horizon=1,
|
||||||
action_horizon=action_horizon,
|
action_horizon=action_horizon,
|
||||||
@@ -636,25 +732,56 @@ def make_groot_pre_post_processors(
|
|||||||
video_modality_keys=video_modality_keys,
|
video_modality_keys=video_modality_keys,
|
||||||
raw_stats=checkpoint_assets.raw_stats if checkpoint_assets is not None else None,
|
raw_stats=checkpoint_assets.raw_stats if checkpoint_assets is not None else None,
|
||||||
modality_config=checkpoint_assets.modality_config if checkpoint_assets is not None else None,
|
modality_config=checkpoint_assets.modality_config if checkpoint_assets is not None else None,
|
||||||
state_cache_key=state_cache_key,
|
|
||||||
)
|
)
|
||||||
|
|
||||||
|
# Resolve the image preprocessing geometry. Honor the checkpoint's processor_config
|
||||||
|
# when it provides an image_target_size; otherwise fall back to the geometry the
|
||||||
|
# N1.7 backbone was trained on. Without this fallback a raw base checkpoint with no
|
||||||
|
# processor_config image sizing (e.g. fine-tuning nvidia/GR00T-N1.7-3B with a new
|
||||||
|
# embodiment, where checkpoint_assets is None) would patchify full-resolution camera
|
||||||
|
# frames, inflating the VLM token count and feeding the model a resolution it was not trained on.
|
||||||
|
if checkpoint_assets is not None and checkpoint_assets.image_target_size is not None:
|
||||||
|
image_target_size = checkpoint_assets.image_target_size
|
||||||
|
image_crop_size = checkpoint_assets.image_crop_size
|
||||||
|
shortest_image_edge = checkpoint_assets.shortest_image_edge
|
||||||
|
crop_fraction = checkpoint_assets.crop_fraction
|
||||||
|
else:
|
||||||
|
image_target_size = list(N1_7_DEFAULT_IMAGE_TARGET_SIZE)
|
||||||
|
image_crop_size = list(N1_7_DEFAULT_IMAGE_CROP_SIZE)
|
||||||
|
shortest_image_edge = None
|
||||||
|
crop_fraction = None
|
||||||
|
use_albumentations = checkpoint_assets.use_albumentations if checkpoint_assets is not None else False
|
||||||
|
|
||||||
input_steps: list[ProcessorStep] = [
|
input_steps: list[ProcessorStep] = [
|
||||||
RenameObservationsProcessorStep(rename_map={}),
|
RenameObservationsProcessorStep(rename_map={}),
|
||||||
AddBatchDimensionProcessorStep(),
|
AddBatchDimensionProcessorStep(),
|
||||||
pack_step,
|
pack_step,
|
||||||
GrootN17VLMEncodeStep(
|
GrootN17VLMEncodeStep(
|
||||||
model_name=config.n1_7_backbone_model,
|
model_name=config.n1_7_backbone_model,
|
||||||
image_crop_size=checkpoint_assets.image_crop_size if checkpoint_assets is not None else None,
|
image_crop_size=image_crop_size,
|
||||||
image_target_size=checkpoint_assets.image_target_size if checkpoint_assets is not None else None,
|
image_target_size=image_target_size,
|
||||||
shortest_image_edge=checkpoint_assets.shortest_image_edge if checkpoint_assets is not None else None,
|
shortest_image_edge=shortest_image_edge,
|
||||||
crop_fraction=checkpoint_assets.crop_fraction if checkpoint_assets is not None else None,
|
crop_fraction=crop_fraction,
|
||||||
use_albumentations=checkpoint_assets.use_albumentations if checkpoint_assets is not None else False,
|
use_albumentations=use_albumentations,
|
||||||
|
device=config.device,
|
||||||
),
|
),
|
||||||
DeviceProcessorStep(device=config.device),
|
DeviceProcessorStep(device=config.device),
|
||||||
]
|
]
|
||||||
|
|
||||||
if checkpoint_assets is None:
|
if checkpoint_assets is not None and not checkpoint_has_stats and not _has_modality_stats(padded_stats):
|
||||||
|
raise ValueError(
|
||||||
|
f"GR00T N1.7 checkpoint '{config.base_model_path}' has no statistics for embodiment tag "
|
||||||
|
f"'{config.embodiment_tag}', and no dataset stats were provided to fall back to, so "
|
||||||
|
"actions cannot be normalized or decoded. Pass dataset_stats, or set "
|
||||||
|
"config.embodiment_tag to an embodiment present in the checkpoint's statistics.json."
|
||||||
|
)
|
||||||
|
if checkpoint_assets is None or not checkpoint_has_stats:
|
||||||
|
# When the checkpoint sidecars have no stats for the configured
|
||||||
|
# embodiment tag (e.g. finetuning a raw base checkpoint with the
|
||||||
|
# default 'new_embodiment' tag), the pack step above normalized with
|
||||||
|
# the dataset stats; the decode step must invert with the same stats
|
||||||
|
# instead of using a checkpoint decoder whose empty stats would
|
||||||
|
# silently return normalized [-1, 1] actions.
|
||||||
action_decode_step: ProcessorStep = GrootActionUnpackUnnormalizeStep(
|
action_decode_step: ProcessorStep = GrootActionUnpackUnnormalizeStep(
|
||||||
env_action_dim=env_action_dim,
|
env_action_dim=env_action_dim,
|
||||||
stats=padded_stats,
|
stats=padded_stats,
|
||||||
@@ -669,7 +796,6 @@ def make_groot_pre_post_processors(
|
|||||||
use_percentiles=checkpoint_assets.use_percentiles,
|
use_percentiles=checkpoint_assets.use_percentiles,
|
||||||
use_relative_action=checkpoint_assets.use_relative_action,
|
use_relative_action=checkpoint_assets.use_relative_action,
|
||||||
pack_step=pack_step,
|
pack_step=pack_step,
|
||||||
state_cache_key=state_cache_key,
|
|
||||||
action_decode_transform=config.action_decode_transform,
|
action_decode_transform=config.action_decode_transform,
|
||||||
)
|
)
|
||||||
|
|
||||||
@@ -770,7 +896,7 @@ def _build_n1_7_processor(model_name: str = GROOT_N1_7_BACKBONE_MODEL) -> Proces
|
|||||||
try:
|
try:
|
||||||
from transformers import (
|
from transformers import (
|
||||||
AutoTokenizer,
|
AutoTokenizer,
|
||||||
Qwen2VLImageProcessorFast,
|
Qwen2VLImageProcessor,
|
||||||
Qwen3VLProcessor,
|
Qwen3VLProcessor,
|
||||||
Qwen3VLVideoProcessor,
|
Qwen3VLVideoProcessor,
|
||||||
)
|
)
|
||||||
@@ -781,7 +907,7 @@ def _build_n1_7_processor(model_name: str = GROOT_N1_7_BACKBONE_MODEL) -> Proces
|
|||||||
) from exc
|
) from exc
|
||||||
|
|
||||||
tokenizer = AutoTokenizer.from_pretrained(model_name, trust_remote_code=True)
|
tokenizer = AutoTokenizer.from_pretrained(model_name, trust_remote_code=True)
|
||||||
image_processor = Qwen2VLImageProcessorFast.from_pretrained(model_name, trust_remote_code=True)
|
image_processor = Qwen2VLImageProcessor.from_pretrained(model_name, trust_remote_code=True)
|
||||||
video_processor = Qwen3VLVideoProcessor.from_pretrained(model_name, trust_remote_code=True)
|
video_processor = Qwen3VLVideoProcessor.from_pretrained(model_name, trust_remote_code=True)
|
||||||
proc = Qwen3VLProcessor(
|
proc = Qwen3VLProcessor(
|
||||||
image_processor=image_processor,
|
image_processor=image_processor,
|
||||||
@@ -793,15 +919,22 @@ def _build_n1_7_processor(model_name: str = GROOT_N1_7_BACKBONE_MODEL) -> Proces
|
|||||||
return proc
|
return proc
|
||||||
|
|
||||||
|
|
||||||
def _transform_n1_7_image_for_vlm(
|
def _transform_n1_7_image_for_vlm_albumentations(
|
||||||
image: Image.Image,
|
image: Image.Image,
|
||||||
*,
|
*,
|
||||||
image_crop_size: list[int] | None,
|
image_crop_size: list[int] | None,
|
||||||
image_target_size: list[int] | None,
|
image_target_size: list[int] | None,
|
||||||
shortest_image_edge: int | None,
|
shortest_image_edge: int | None,
|
||||||
crop_fraction: float | None,
|
crop_fraction: float | None,
|
||||||
use_albumentations: bool = False,
|
|
||||||
) -> Image.Image:
|
) -> Image.Image:
|
||||||
|
"""cv2/INTER_AREA eval transform mirroring Isaac-GR00T's albumentations preprocessing.
|
||||||
|
|
||||||
|
Used only for checkpoints saved with ``use_albumentations=True``. cv2 is
|
||||||
|
CPU/numpy-only so this path cannot run on GPU; the default (non-albumentations)
|
||||||
|
geometry is handled on-device by :func:`_transform_n1_7_image_for_vlm_torch`. The
|
||||||
|
cv2/INTER_AREA resize and floored center-crop here intentionally differ from that
|
||||||
|
torch path and must stay bit-exact to the upstream reference.
|
||||||
|
"""
|
||||||
if image_target_size is None:
|
if image_target_size is None:
|
||||||
return image
|
return image
|
||||||
|
|
||||||
@@ -809,7 +942,6 @@ def _transform_n1_7_image_for_vlm(
|
|||||||
if image.mode != "RGB":
|
if image.mode != "RGB":
|
||||||
image = image.convert("RGB")
|
image = image.convert("RGB")
|
||||||
|
|
||||||
if use_albumentations:
|
|
||||||
try:
|
try:
|
||||||
import cv2
|
import cv2
|
||||||
except ImportError as exc:
|
except ImportError as exc:
|
||||||
@@ -851,28 +983,60 @@ def _transform_n1_7_image_for_vlm(
|
|||||||
image_np = cv2.resize(image_np, (target_w, target_h), interpolation=cv2.INTER_AREA)
|
image_np = cv2.resize(image_np, (target_w, target_h), interpolation=cv2.INTER_AREA)
|
||||||
return Image.fromarray(image_np)
|
return Image.fromarray(image_np)
|
||||||
|
|
||||||
square_edge = max(image.width, image.height)
|
|
||||||
if image.width != image.height:
|
def _transform_n1_7_image_for_vlm_torch(
|
||||||
padded = Image.new("RGB", (square_edge, square_edge))
|
image: torch.Tensor,
|
||||||
left = (square_edge - image.width) // 2
|
*,
|
||||||
top = (square_edge - image.height) // 2
|
image_crop_size: list[int] | None,
|
||||||
padded.paste(image, (left, top))
|
image_target_size: list[int] | None,
|
||||||
image = padded
|
shortest_image_edge: int | None,
|
||||||
|
crop_fraction: float | None,
|
||||||
|
) -> torch.Tensor:
|
||||||
|
"""Default (non-albumentations) N1.7 image transform: pad-to-square, resize to
|
||||||
|
``shortest_image_edge``, center-crop by ``crop_fraction``, resize to ``image_target_size``.
|
||||||
|
|
||||||
|
Operates on a ``(C, H, W)`` uint8 tensor and keeps the result on the input
|
||||||
|
tensor's device so the resize/crop run on GPU when the tensor is. Bicubic
|
||||||
|
interpolation with antialiasing matches PIL's ``Image.Resampling.BICUBIC``
|
||||||
|
closely (sub-``2/255`` per-pixel on worst-case inputs). The ``use_albumentations``
|
||||||
|
cv2/INTER_AREA path has no torch equivalent and stays on
|
||||||
|
:func:`_transform_n1_7_image_for_vlm_albumentations`.
|
||||||
|
"""
|
||||||
|
if image_target_size is None:
|
||||||
|
return image
|
||||||
|
|
||||||
|
target_h, target_w = image_target_size
|
||||||
|
_, height, width = image.shape
|
||||||
|
|
||||||
|
square_edge = max(height, width)
|
||||||
|
if height != width:
|
||||||
|
left = (square_edge - width) // 2
|
||||||
|
top = (square_edge - height) // 2
|
||||||
|
image = tv_functional.pad(
|
||||||
|
image, [left, top, square_edge - width - left, square_edge - height - top], fill=0
|
||||||
|
)
|
||||||
|
|
||||||
resize_edge = shortest_image_edge or target_h
|
resize_edge = shortest_image_edge or target_h
|
||||||
image = image.resize((resize_edge, resize_edge), Image.Resampling.BICUBIC)
|
image = tv_functional.resize(
|
||||||
|
image, [resize_edge, resize_edge], interpolation=InterpolationMode.BICUBIC, antialias=True
|
||||||
|
)
|
||||||
|
|
||||||
if crop_fraction is None and image_crop_size is not None:
|
if crop_fraction is None and image_crop_size is not None:
|
||||||
crop_fraction = image_crop_size[0] / float(target_h)
|
crop_fraction = image_crop_size[0] / float(target_h)
|
||||||
if crop_fraction is not None and 0.0 < crop_fraction < 1.0:
|
if crop_fraction is not None and 0.0 < crop_fraction < 1.0:
|
||||||
crop_w = max(1, int(round(image.width * crop_fraction)))
|
# Match the PIL helper's center crop exactly: round() the crop size but
|
||||||
crop_h = max(1, int(round(image.height * crop_fraction)))
|
# floor() the offset (torchvision.center_crop rounds the offset, which
|
||||||
left = max(0, (image.width - crop_w) // 2)
|
# shifts the region by 1px when (edge - crop) is odd).
|
||||||
top = max(0, (image.height - crop_h) // 2)
|
crop_h = max(1, int(round(image.shape[-2] * crop_fraction)))
|
||||||
image = image.crop((left, top, left + crop_w, top + crop_h))
|
crop_w = max(1, int(round(image.shape[-1] * crop_fraction)))
|
||||||
|
top = max(0, (image.shape[-2] - crop_h) // 2)
|
||||||
|
left = max(0, (image.shape[-1] - crop_w) // 2)
|
||||||
|
image = image[..., top : top + crop_h, left : left + crop_w]
|
||||||
|
|
||||||
if image.size != (target_w, target_h):
|
if tuple(image.shape[-2:]) != (target_h, target_w):
|
||||||
image = image.resize((target_w, target_h), Image.Resampling.BICUBIC)
|
image = tv_functional.resize(
|
||||||
|
image, [target_h, target_w], interpolation=InterpolationMode.BICUBIC, antialias=True
|
||||||
|
)
|
||||||
return image
|
return image
|
||||||
|
|
||||||
|
|
||||||
@@ -902,8 +1066,8 @@ class GrootN17PackInputsStep(ProcessorStep):
|
|||||||
video_modality_keys: list[str] | None = None
|
video_modality_keys: list[str] | None = None
|
||||||
raw_stats: dict[str, Any] | None = None
|
raw_stats: dict[str, Any] | None = None
|
||||||
modality_config: dict[str, Any] | None = None
|
modality_config: dict[str, Any] | None = None
|
||||||
state_cache_key: str = ""
|
|
||||||
_last_raw_state: dict[str, np.ndarray] | None = field(default=None, init=False, repr=False)
|
_last_raw_state: dict[str, np.ndarray] | None = field(default=None, init=False, repr=False)
|
||||||
|
_warned_image_keys: bool = field(default=False, init=False, repr=False)
|
||||||
|
|
||||||
def _ordered_image_keys(self, obs: dict[str, Any]) -> list[str]:
|
def _ordered_image_keys(self, obs: dict[str, Any]) -> list[str]:
|
||||||
available = {key for key in obs if key.startswith(OBS_IMAGES)}
|
available = {key for key in obs if key.startswith(OBS_IMAGES)}
|
||||||
@@ -913,19 +1077,56 @@ class GrootN17PackInputsStep(ProcessorStep):
|
|||||||
return sorted(available)
|
return sorted(available)
|
||||||
|
|
||||||
ordered: list[str] = []
|
ordered: list[str] = []
|
||||||
|
unmatched: list[str] = []
|
||||||
for modality_key in self.video_modality_keys:
|
for modality_key in self.video_modality_keys:
|
||||||
candidates = [f"{OBS_IMAGES}.{modality_key}"]
|
candidates = [f"{OBS_IMAGES}.{modality_key}"]
|
||||||
|
# Alias for datasets converted with generic camera names (e.g. the
|
||||||
|
# LIBERO conversions expose the wrist camera as
|
||||||
|
# `observation.images.image2`), so raw N1.7 LIBERO checkpoints
|
||||||
|
# match those datasets out of the box.
|
||||||
if modality_key == "wrist_image":
|
if modality_key == "wrist_image":
|
||||||
candidates.append(f"{OBS_IMAGES}.image2")
|
candidates.append(f"{OBS_IMAGES}.image2")
|
||||||
elif modality_key == "image":
|
|
||||||
candidates.append(f"{OBS_IMAGES}.image")
|
|
||||||
|
|
||||||
match = next((candidate for candidate in candidates if candidate in available), None)
|
match = next((candidate for candidate in candidates if candidate in available), None)
|
||||||
if match is not None:
|
if match is None:
|
||||||
|
unmatched.append(modality_key)
|
||||||
|
else:
|
||||||
ordered.append(match)
|
ordered.append(match)
|
||||||
|
|
||||||
if not ordered:
|
if not ordered:
|
||||||
|
if not self._warned_image_keys:
|
||||||
|
self._warned_image_keys = True
|
||||||
|
logging.warning(
|
||||||
|
"None of the GR00T N1.7 checkpoint video modality keys %s match a camera among %s; "
|
||||||
|
"falling back to feeding all cameras in alphabetical order, which is unlikely to be "
|
||||||
|
"the layout the checkpoint was trained with. Rename the dataset cameras (e.g. via "
|
||||||
|
"--rename_map) to match the checkpoint keys.",
|
||||||
|
self.video_modality_keys,
|
||||||
|
sorted(available),
|
||||||
|
)
|
||||||
return sorted(available)
|
return sorted(available)
|
||||||
|
unused = sorted(available - set(ordered))
|
||||||
|
if (unmatched or unused) and not self._warned_image_keys:
|
||||||
|
self._warned_image_keys = True
|
||||||
|
if unmatched:
|
||||||
|
logging.warning(
|
||||||
|
"GR00T N1.7 checkpoint video modality keys %s have no matching camera among %s; "
|
||||||
|
"the model will receive %d view(s) instead of the %d it was trained with. Rename "
|
||||||
|
"the dataset cameras (e.g. via --rename_map) to match the checkpoint keys %s.",
|
||||||
|
unmatched,
|
||||||
|
sorted(available),
|
||||||
|
len(ordered),
|
||||||
|
len(self.video_modality_keys),
|
||||||
|
self.video_modality_keys,
|
||||||
|
)
|
||||||
|
if unused:
|
||||||
|
logging.warning(
|
||||||
|
"Dropping camera(s) %s: the GR00T N1.7 checkpoint only consumes the video modality "
|
||||||
|
"keys %s, which matched %s.",
|
||||||
|
unused,
|
||||||
|
self.video_modality_keys,
|
||||||
|
ordered,
|
||||||
|
)
|
||||||
return ordered
|
return ordered
|
||||||
|
|
||||||
def __call__(self, transition: EnvTransition) -> EnvTransition:
|
def __call__(self, transition: EnvTransition) -> EnvTransition:
|
||||||
@@ -988,7 +1189,6 @@ class GrootN17PackInputsStep(ProcessorStep):
|
|||||||
start_idx += dim
|
start_idx += dim
|
||||||
if grouped:
|
if grouped:
|
||||||
self._last_raw_state = grouped
|
self._last_raw_state = grouped
|
||||||
_N1_7_RAW_STATE_CACHE[_n1_7_state_cache_key(self.state_cache_key)] = grouped
|
|
||||||
|
|
||||||
img_keys = self._ordered_image_keys(obs)
|
img_keys = self._ordered_image_keys(obs)
|
||||||
if img_keys:
|
if img_keys:
|
||||||
@@ -1101,7 +1301,6 @@ class GrootN17PackInputsStep(ProcessorStep):
|
|||||||
"video_modality_keys": self.video_modality_keys,
|
"video_modality_keys": self.video_modality_keys,
|
||||||
"raw_stats": self.raw_stats,
|
"raw_stats": self.raw_stats,
|
||||||
"modality_config": self.modality_config,
|
"modality_config": self.modality_config,
|
||||||
"state_cache_key": self.state_cache_key,
|
|
||||||
}
|
}
|
||||||
|
|
||||||
def get_cached_raw_state(self) -> dict[str, np.ndarray] | None:
|
def get_cached_raw_state(self) -> dict[str, np.ndarray] | None:
|
||||||
@@ -1139,6 +1338,12 @@ class GrootN17VLMEncodeStep(ProcessorStep):
|
|||||||
The packed video has shape ``(B, T, V, H, W, C)``. Each frame/view becomes
|
The packed video has shape ``(B, T, V, H, W, C)``. Each frame/view becomes
|
||||||
an image item in the same chat message so the resulting image tokens match
|
an image item in the same chat message so the resulting image tokens match
|
||||||
the temporal VLM packing used by Isaac-GR00T.
|
the temporal VLM packing used by Isaac-GR00T.
|
||||||
|
|
||||||
|
Images are handed to the torchvision-backed Qwen3-VL processor as ``(C, H, W)``
|
||||||
|
uint8 tensors (no per-frame PIL roundtrip), and, when ``device`` resolves to a
|
||||||
|
CUDA device, the resize/rescale/normalize/patchify run there. This keeps the
|
||||||
|
output bit-identical on CPU and moves the dominant preprocessing cost off
|
||||||
|
the critical path on GPU.
|
||||||
"""
|
"""
|
||||||
|
|
||||||
model_name: str = GROOT_N1_7_BACKBONE_MODEL
|
model_name: str = GROOT_N1_7_BACKBONE_MODEL
|
||||||
@@ -1147,6 +1352,7 @@ class GrootN17VLMEncodeStep(ProcessorStep):
|
|||||||
shortest_image_edge: int | None = None
|
shortest_image_edge: int | None = None
|
||||||
crop_fraction: float | None = None
|
crop_fraction: float | None = None
|
||||||
use_albumentations: bool = False
|
use_albumentations: bool = False
|
||||||
|
device: str | None = None
|
||||||
_proc: ProcessorMixin | None = field(default=None, init=False, repr=False)
|
_proc: ProcessorMixin | None = field(default=None, init=False, repr=False)
|
||||||
|
|
||||||
@property
|
@property
|
||||||
@@ -1155,6 +1361,69 @@ class GrootN17VLMEncodeStep(ProcessorStep):
|
|||||||
self._proc = _build_n1_7_processor(self.model_name)
|
self._proc = _build_n1_7_processor(self.model_name)
|
||||||
return self._proc
|
return self._proc
|
||||||
|
|
||||||
|
def _target_device(self) -> torch.device | None:
|
||||||
|
# The albumentations path is cv2/PIL only, so it cannot run on GPU.
|
||||||
|
if self.device is None or self.use_albumentations:
|
||||||
|
return None
|
||||||
|
try:
|
||||||
|
return get_safe_torch_device(self.device)
|
||||||
|
except (AssertionError, RuntimeError):
|
||||||
|
# A device serialized at train time (e.g. "cuda") may be unavailable
|
||||||
|
# when the processor is reloaded elsewhere (e.g. CPU-only eval), and
|
||||||
|
# this step is not in the standard device-override set. Fall back to
|
||||||
|
# the CPU path, which is bit-identical, instead of crashing.
|
||||||
|
return None
|
||||||
|
|
||||||
|
def _build_sample_images(
|
||||||
|
self, video: Any, batch_size: int, target_device: torch.device | None
|
||||||
|
) -> list[list[Any]]:
|
||||||
|
"""Return, per batch item, its ordered ``(timestep, view)`` frames.
|
||||||
|
|
||||||
|
``use_albumentations`` keeps the legacy per-frame PIL/cv2 transform;
|
||||||
|
otherwise frames are ``(C, H, W)`` uint8 tensors (moved to
|
||||||
|
``target_device`` when set) for the torchvision-backed Qwen processor.
|
||||||
|
"""
|
||||||
|
if self.use_albumentations:
|
||||||
|
video_np = np.asarray(video)
|
||||||
|
return [
|
||||||
|
[
|
||||||
|
_transform_n1_7_image_for_vlm_albumentations(
|
||||||
|
Image.fromarray(video_np[batch_idx, timestep, view_idx]),
|
||||||
|
image_crop_size=self.image_crop_size,
|
||||||
|
image_target_size=self.image_target_size,
|
||||||
|
shortest_image_edge=self.shortest_image_edge,
|
||||||
|
crop_fraction=self.crop_fraction,
|
||||||
|
)
|
||||||
|
for timestep in range(video_np.shape[1])
|
||||||
|
for view_idx in range(video_np.shape[2])
|
||||||
|
]
|
||||||
|
for batch_idx in range(batch_size)
|
||||||
|
]
|
||||||
|
|
||||||
|
video_t = video if torch.is_tensor(video) else torch.from_numpy(np.ascontiguousarray(video))
|
||||||
|
# (B, T, V, H, W, C) uint8 -> (B, T, V, C, H, W)
|
||||||
|
video_t = video_t.permute(0, 1, 2, 5, 3, 4).contiguous()
|
||||||
|
if target_device is not None and video_t.device != target_device:
|
||||||
|
video_t = video_t.to(target_device, non_blocking=(target_device.type == "cuda"))
|
||||||
|
|
||||||
|
frames_per_sample: list[list[Any]] = []
|
||||||
|
for batch_idx in range(batch_size):
|
||||||
|
sample = video_t[batch_idx] # (T, V, C, H, W)
|
||||||
|
frames_per_sample.append(
|
||||||
|
[
|
||||||
|
_transform_n1_7_image_for_vlm_torch(
|
||||||
|
sample[timestep, view_idx],
|
||||||
|
image_crop_size=self.image_crop_size,
|
||||||
|
image_target_size=self.image_target_size,
|
||||||
|
shortest_image_edge=self.shortest_image_edge,
|
||||||
|
crop_fraction=self.crop_fraction,
|
||||||
|
)
|
||||||
|
for timestep in range(sample.shape[0])
|
||||||
|
for view_idx in range(sample.shape[1])
|
||||||
|
]
|
||||||
|
)
|
||||||
|
return frames_per_sample
|
||||||
|
|
||||||
def __call__(self, transition: EnvTransition) -> EnvTransition:
|
def __call__(self, transition: EnvTransition) -> EnvTransition:
|
||||||
obs = transition.get(TransitionKey.OBSERVATION, {}) or {}
|
obs = transition.get(TransitionKey.OBSERVATION, {}) or {}
|
||||||
comp = transition.get(TransitionKey.COMPLEMENTARY_DATA, {}) or {}
|
comp = transition.get(TransitionKey.COMPLEMENTARY_DATA, {}) or {}
|
||||||
@@ -1162,33 +1431,25 @@ class GrootN17VLMEncodeStep(ProcessorStep):
|
|||||||
if video is None:
|
if video is None:
|
||||||
return transition
|
return transition
|
||||||
|
|
||||||
|
batch_size = int(video.shape[0])
|
||||||
languages = _prepare_n1_7_language_batch(
|
languages = _prepare_n1_7_language_batch(
|
||||||
comp.get("language"),
|
comp.get("language"),
|
||||||
video.shape[0],
|
batch_size,
|
||||||
formalize_language=False,
|
formalize_language=False,
|
||||||
)
|
)
|
||||||
|
|
||||||
|
target_device = self._target_device()
|
||||||
|
sample_images = self._build_sample_images(video, batch_size, target_device)
|
||||||
|
|
||||||
texts: list[str] = []
|
texts: list[str] = []
|
||||||
images: list[Image.Image] = []
|
images: list[Any] = []
|
||||||
for batch_idx in range(video.shape[0]):
|
for batch_idx in range(batch_size):
|
||||||
sample = video[batch_idx] # (T, V, H, W, C)
|
frames = sample_images[batch_idx]
|
||||||
sample_images = [
|
|
||||||
_transform_n1_7_image_for_vlm(
|
|
||||||
Image.fromarray(sample[timestep, view_idx]),
|
|
||||||
image_crop_size=self.image_crop_size,
|
|
||||||
image_target_size=self.image_target_size,
|
|
||||||
shortest_image_edge=self.shortest_image_edge,
|
|
||||||
crop_fraction=self.crop_fraction,
|
|
||||||
use_albumentations=self.use_albumentations,
|
|
||||||
)
|
|
||||||
for timestep in range(sample.shape[0])
|
|
||||||
for view_idx in range(sample.shape[1])
|
|
||||||
]
|
|
||||||
conversation = [
|
conversation = [
|
||||||
{
|
{
|
||||||
"role": "user",
|
"role": "user",
|
||||||
"content": [
|
"content": [
|
||||||
*[{"type": "image", "image": image} for image in sample_images],
|
*[{"type": "image", "image": image} for image in frames],
|
||||||
{"type": "text", "text": languages[batch_idx]},
|
{"type": "text", "text": languages[batch_idx]},
|
||||||
],
|
],
|
||||||
}
|
}
|
||||||
@@ -1200,9 +1461,17 @@ class GrootN17VLMEncodeStep(ProcessorStep):
|
|||||||
add_generation_prompt=False,
|
add_generation_prompt=False,
|
||||||
)
|
)
|
||||||
)
|
)
|
||||||
images.extend(sample_images)
|
images.extend(frames)
|
||||||
|
|
||||||
encoded = self.proc(text=texts, images=images, return_tensors="pt", padding=True)
|
proc_kwargs: dict[str, Any] = {
|
||||||
|
"text": texts,
|
||||||
|
"images": images,
|
||||||
|
"return_tensors": "pt",
|
||||||
|
"padding": True,
|
||||||
|
}
|
||||||
|
if target_device is not None:
|
||||||
|
proc_kwargs["device"] = str(target_device)
|
||||||
|
encoded = self.proc(**proc_kwargs)
|
||||||
for key, value in encoded.items():
|
for key, value in encoded.items():
|
||||||
comp[key] = value
|
comp[key] = value
|
||||||
obs.pop("video", None)
|
obs.pop("video", None)
|
||||||
@@ -1221,8 +1490,10 @@ class GrootN17VLMEncodeStep(ProcessorStep):
|
|||||||
"shortest_image_edge": self.shortest_image_edge,
|
"shortest_image_edge": self.shortest_image_edge,
|
||||||
"crop_fraction": self.crop_fraction,
|
"crop_fraction": self.crop_fraction,
|
||||||
"use_albumentations": self.use_albumentations,
|
"use_albumentations": self.use_albumentations,
|
||||||
|
"device": self.device,
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
||||||
def _stat_dim_from_entry(entry: dict[str, Any]) -> int:
|
def _stat_dim_from_entry(entry: dict[str, Any]) -> int:
|
||||||
for stat_name in ("mean", "q01", "min", "max", "std"):
|
for stat_name in ("mean", "q01", "min", "max", "std"):
|
||||||
value = entry.get(stat_name)
|
value = entry.get(stat_name)
|
||||||
@@ -1351,6 +1622,18 @@ class GrootN17ActionDecodeStep(ProcessorStep):
|
|||||||
group with the checkpoint stats, converts relative groups to absolute values
|
group with the checkpoint stats, converts relative groups to absolute values
|
||||||
using the raw state cached during packing, concatenates groups in checkpoint
|
using the raw state cached during packing, concatenates groups in checkpoint
|
||||||
order, and finally slices to the environment action dimension.
|
order, and finally slices to the environment action dimension.
|
||||||
|
|
||||||
|
Relative-action decoding reads the reference state from the connected
|
||||||
|
``pack_step`` (re-linked after ``from_pretrained`` by
|
||||||
|
``_reconnect_groot_n1_7_pack_decode_steps``), i.e. the state seen by the
|
||||||
|
most recent preprocess call. Engines that decode the whole chunk right
|
||||||
|
after prediction (RTC, async policy server) therefore use the
|
||||||
|
prediction-time state, matching Isaac-GR00T. The sync per-step queue path
|
||||||
|
instead decodes each popped (B, D) action against the latest observation:
|
||||||
|
the reference can be newer than the observation the chunk was predicted
|
||||||
|
from, and per-timestep relative stats are applied as if the popped action
|
||||||
|
were chunk step 0. Fixing that would require carrying the reference state
|
||||||
|
and chunk index alongside each queued action through the postprocessor.
|
||||||
"""
|
"""
|
||||||
|
|
||||||
env_action_dim: int = 0
|
env_action_dim: int = 0
|
||||||
@@ -1358,7 +1641,6 @@ class GrootN17ActionDecodeStep(ProcessorStep):
|
|||||||
modality_config: dict[str, Any] | None = None
|
modality_config: dict[str, Any] | None = None
|
||||||
use_percentiles: bool = False
|
use_percentiles: bool = False
|
||||||
use_relative_action: bool = False
|
use_relative_action: bool = False
|
||||||
state_cache_key: str = ""
|
|
||||||
action_decode_transform: str | None = None
|
action_decode_transform: str | None = None
|
||||||
pack_step: GrootN17PackInputsStep | None = field(default=None, repr=False)
|
pack_step: GrootN17PackInputsStep | None = field(default=None, repr=False)
|
||||||
|
|
||||||
@@ -1378,6 +1660,12 @@ class GrootN17ActionDecodeStep(ProcessorStep):
|
|||||||
return transition
|
return transition
|
||||||
|
|
||||||
action_np = action.detach().cpu().float().numpy()
|
action_np = action.detach().cpu().float().numpy()
|
||||||
|
# The sync action queue postprocesses popped actions as (B, D); decode
|
||||||
|
# them as single-step (B, 1, D) chunks and squeeze the horizon back at
|
||||||
|
# the end so both ranks share the chunk decode logic below.
|
||||||
|
squeeze_horizon = action_np.ndim == 2
|
||||||
|
if squeeze_horizon:
|
||||||
|
action_np = action_np[:, None, :]
|
||||||
valid_horizon = _n1_7_decode_valid_horizon(action_config, action_np)
|
valid_horizon = _n1_7_decode_valid_horizon(action_config, action_np)
|
||||||
if valid_horizon is not None:
|
if valid_horizon is not None:
|
||||||
action_np = action_np[:, :valid_horizon]
|
action_np = action_np[:, :valid_horizon]
|
||||||
@@ -1405,17 +1693,24 @@ class GrootN17ActionDecodeStep(ProcessorStep):
|
|||||||
use_relative_action=self.use_relative_action,
|
use_relative_action=self.use_relative_action,
|
||||||
use_percentiles=self.use_percentiles,
|
use_percentiles=self.use_percentiles,
|
||||||
)
|
)
|
||||||
|
# Per-timestep stats carry one row per chunk step; align them with
|
||||||
|
# the decoded horizon (chunks always start at step 0, and a popped
|
||||||
|
# (B, D) action is decoded as step 0).
|
||||||
|
if min_v.ndim == 2 and normalized.shape[1] <= min_v.shape[0]:
|
||||||
|
min_v = min_v[: normalized.shape[1]]
|
||||||
|
max_v = max_v[: normalized.shape[1]]
|
||||||
decoded_groups[key] = _unnormalize_min_max(normalized, min_v, max_v)
|
decoded_groups[key] = _unnormalize_min_max(normalized, min_v, max_v)
|
||||||
start_idx += dim
|
start_idx += dim
|
||||||
|
|
||||||
if self.use_relative_action:
|
if self.use_relative_action:
|
||||||
raw_state = self.pack_step.get_cached_raw_state() if self.pack_step is not None else None
|
raw_state = self.pack_step.get_cached_raw_state() if self.pack_step is not None else None
|
||||||
if raw_state is None:
|
|
||||||
raw_state = _N1_7_RAW_STATE_CACHE.get(_n1_7_state_cache_key(self.state_cache_key))
|
|
||||||
if raw_state is None:
|
if raw_state is None:
|
||||||
raise RuntimeError(
|
raise RuntimeError(
|
||||||
"GrootN17ActionDecodeStep requires cached raw state from GrootN17PackInputsStep "
|
"GrootN17ActionDecodeStep requires the raw state cached by its connected "
|
||||||
"to convert relative N1.7 actions back to absolute actions."
|
"GrootN17PackInputsStep to convert relative N1.7 actions back to absolute actions. "
|
||||||
|
"Build both pipelines through make_groot_pre_post_processors (or load them together "
|
||||||
|
"via make_groot_pre_post_processors_from_pretrained) and run the preprocessor on an "
|
||||||
|
"observation before decoding actions."
|
||||||
)
|
)
|
||||||
for idx, key in enumerate(action_keys):
|
for idx, key in enumerate(action_keys):
|
||||||
if not isinstance(key, str) or key not in decoded_groups or idx >= len(action_configs):
|
if not isinstance(key, str) or key not in decoded_groups or idx >= len(action_configs):
|
||||||
@@ -1451,6 +1746,8 @@ class GrootN17ActionDecodeStep(ProcessorStep):
|
|||||||
action_keys=action_keys,
|
action_keys=action_keys,
|
||||||
decoded_groups=decoded_groups,
|
decoded_groups=decoded_groups,
|
||||||
)
|
)
|
||||||
|
if squeeze_horizon:
|
||||||
|
decoded = decoded[:, 0]
|
||||||
new_transition = transition.copy()
|
new_transition = transition.copy()
|
||||||
new_transition[TransitionKey.ACTION] = torch.as_tensor(
|
new_transition[TransitionKey.ACTION] = torch.as_tensor(
|
||||||
decoded, dtype=action.dtype, device=action.device
|
decoded, dtype=action.dtype, device=action.device
|
||||||
@@ -1467,13 +1764,15 @@ class GrootN17ActionDecodeStep(ProcessorStep):
|
|||||||
"modality_config": self.modality_config,
|
"modality_config": self.modality_config,
|
||||||
"use_percentiles": self.use_percentiles,
|
"use_percentiles": self.use_percentiles,
|
||||||
"use_relative_action": self.use_relative_action,
|
"use_relative_action": self.use_relative_action,
|
||||||
"state_cache_key": self.state_cache_key,
|
|
||||||
"action_decode_transform": self.action_decode_transform,
|
"action_decode_transform": self.action_decode_transform,
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
||||||
|
# v2: unlike the N1.5-era v1 step, this step no longer collapses (B, T, D)
|
||||||
|
# action chunks to the last timestep, so old serialized v1 pipelines must not
|
||||||
|
# silently load into it (v1 is stubbed below with the removal guidance).
|
||||||
@dataclass
|
@dataclass
|
||||||
@ProcessorStepRegistry.register(name="groot_action_unpack_unnormalize_v1")
|
@ProcessorStepRegistry.register(name="groot_action_unpack_unnormalize_v2")
|
||||||
class GrootActionUnpackUnnormalizeStep(ProcessorStep):
|
class GrootActionUnpackUnnormalizeStep(ProcessorStep):
|
||||||
env_action_dim: int = 0
|
env_action_dim: int = 0
|
||||||
# Apply inverse of min-max normalization if it was used in preprocessor
|
# Apply inverse of min-max normalization if it was used in preprocessor
|
||||||
@@ -1585,3 +1884,37 @@ class GrootActionUnpackUnnormalizeStep(ProcessorStep):
|
|||||||
|
|
||||||
if reconstructed:
|
if reconstructed:
|
||||||
self.stats = reconstructed
|
self.stats = reconstructed
|
||||||
|
|
||||||
|
|
||||||
|
def _register_removed_n1_5_step_stub(registry_name: str) -> None:
|
||||||
|
"""Register a stub for a processor step that only GR00T N1.5 pipelines serialize.
|
||||||
|
|
||||||
|
Saved N1.5 checkpoints reference these registry names in their processor JSON
|
||||||
|
files. Deserializing them must fail with the canonical N1.5 removal guidance
|
||||||
|
instead of an opaque registry KeyError (or, for
|
||||||
|
``groot_action_unpack_unnormalize_v1``, silently loading the v2 step whose
|
||||||
|
action-chunk semantics changed).
|
||||||
|
"""
|
||||||
|
|
||||||
|
@ProcessorStepRegistry.register(name=registry_name)
|
||||||
|
class _RemovedGrootN15ProcessorStep(ProcessorStep):
|
||||||
|
def __init__(self, **_kwargs: Any) -> None:
|
||||||
|
raise ValueError(
|
||||||
|
f"Processor step '{registry_name}' belongs to a GR00T N1.5 processor pipeline. "
|
||||||
|
f"{GROOT_N1_5_REMOVAL_GUIDANCE}"
|
||||||
|
)
|
||||||
|
|
||||||
|
def __call__(self, transition: EnvTransition) -> EnvTransition:
|
||||||
|
raise NotImplementedError
|
||||||
|
|
||||||
|
def transform_features(self, features):
|
||||||
|
raise NotImplementedError
|
||||||
|
|
||||||
|
|
||||||
|
for _removed_n1_5_step_name in (
|
||||||
|
"groot_pack_inputs_v3",
|
||||||
|
"groot_eagle_encode_v3",
|
||||||
|
"groot_eagle_collate_v3",
|
||||||
|
"groot_action_unpack_unnormalize_v1",
|
||||||
|
):
|
||||||
|
_register_removed_n1_5_step_stub(_removed_n1_5_step_name)
|
||||||
|
|||||||
@@ -41,7 +41,7 @@ from lerobot.policies.groot.processor_groot import (
|
|||||||
GrootN17ActionDecodeStep,
|
GrootN17ActionDecodeStep,
|
||||||
GrootN17PackInputsStep,
|
GrootN17PackInputsStep,
|
||||||
GrootN17VLMEncodeStep,
|
GrootN17VLMEncodeStep,
|
||||||
_transform_n1_7_image_for_vlm,
|
_transform_n1_7_image_for_vlm_albumentations,
|
||||||
make_groot_pre_post_processors,
|
make_groot_pre_post_processors,
|
||||||
)
|
)
|
||||||
from lerobot.processor import (
|
from lerobot.processor import (
|
||||||
@@ -1529,13 +1529,12 @@ def test_groot_n1_7_vlm_image_transform_matches_albumentations_eval_path():
|
|||||||
|
|
||||||
image_np = (np.arange(360 * 360 * 3, dtype=np.uint32) % 251).astype(np.uint8).reshape(360, 360, 3)
|
image_np = (np.arange(360 * 360 * 3, dtype=np.uint32) % 251).astype(np.uint8).reshape(360, 360, 3)
|
||||||
|
|
||||||
transformed = _transform_n1_7_image_for_vlm(
|
transformed = _transform_n1_7_image_for_vlm_albumentations(
|
||||||
Image.fromarray(image_np),
|
Image.fromarray(image_np),
|
||||||
image_crop_size=[230, 230],
|
image_crop_size=[230, 230],
|
||||||
image_target_size=[256, 256],
|
image_target_size=[256, 256],
|
||||||
shortest_image_edge=256,
|
shortest_image_edge=256,
|
||||||
crop_fraction=0.95,
|
crop_fraction=0.95,
|
||||||
use_albumentations=True,
|
|
||||||
)
|
)
|
||||||
|
|
||||||
expected = cv2.resize(image_np, (256, 256), interpolation=cv2.INTER_AREA)
|
expected = cv2.resize(image_np, (256, 256), interpolation=cv2.INTER_AREA)
|
||||||
|
|||||||
Reference in New Issue
Block a user