feat: add async server-client streaming support for Groot policy (#2812)

This commit is contained in:
Maximilian Ofir
2026-01-19 22:13:48 +01:00
committed by GitHub
parent 5286ef8439
commit 66929c5935
2 changed files with 131 additions and 2 deletions
+1 -1
View File
@@ -23,7 +23,7 @@ DEFAULT_INFERENCE_LATENCY = 1 / DEFAULT_FPS
DEFAULT_OBS_QUEUE_TIMEOUT = 2 DEFAULT_OBS_QUEUE_TIMEOUT = 2
# All action chunking policies # All action chunking policies
SUPPORTED_POLICIES = ["act", "smolvla", "diffusion", "tdmpc", "vqbet", "pi0", "pi05"] SUPPORTED_POLICIES = ["act", "smolvla", "diffusion", "tdmpc", "vqbet", "pi0", "pi05", "groot"]
# TODO: Add all other robots # TODO: Add all other robots
SUPPORTED_ROBOTS = ["so100_follower", "so101_follower", "bi_so_follower", "omx_follower"] SUPPORTED_ROBOTS = ["so100_follower", "so101_follower", "bi_so_follower", "omx_follower"]
+130 -1
View File
@@ -32,16 +32,22 @@ Notes:
from LeRobot, see `GrootPolicy.finetune_with_groot_runner` below. from LeRobot, see `GrootPolicy.finetune_with_groot_runner` below.
""" """
import builtins
import os import os
from collections import deque from collections import deque
from pathlib import Path
from typing import TypeVar
import torch import torch
from torch import Tensor from torch import Tensor
from lerobot.configs.types import FeatureType, PolicyFeature
from lerobot.policies.groot.configuration_groot import GrootConfig from lerobot.policies.groot.configuration_groot import GrootConfig
from lerobot.policies.groot.groot_n1 import GR00TN15 from lerobot.policies.groot.groot_n1 import GR00TN15
from lerobot.policies.pretrained import PreTrainedPolicy from lerobot.policies.pretrained import PreTrainedPolicy
from lerobot.utils.constants import ACTION from lerobot.utils.constants import ACTION, OBS_IMAGES
T = TypeVar("T", bound="GrootPolicy")
class GrootPolicy(PreTrainedPolicy): class GrootPolicy(PreTrainedPolicy):
@@ -90,6 +96,129 @@ class GrootPolicy(PreTrainedPolicy):
"""Reset policy state when environment resets.""" """Reset policy state when environment resets."""
self._action_queue = deque([], maxlen=self.config.n_action_steps) self._action_queue = deque([], maxlen=self.config.n_action_steps)
@classmethod
def from_pretrained(
    cls: builtins.type[T],
    pretrained_name_or_path: str | Path,
    *,
    config: GrootConfig | None = None,
    force_download: bool = False,
    resume_download: bool | None = None,
    proxies: dict | None = None,
    token: str | bool | None = None,
    cache_dir: str | Path | None = None,
    local_files_only: bool = False,
    revision: str | None = None,
    strict: bool = True,
    **kwargs,
) -> T:
    """Load a Groot policy from a base GR00T model or a fine-tuned checkpoint.

    Two cases are handled, distinguished by whether a single-file
    safetensors weight file (``SAFETENSORS_SINGLE_FILE``) can be found at
    ``pretrained_name_or_path``:

    1. Fine-tuned LeRobot checkpoint (weight file present): loading is
       delegated to the parent class ``from_pretrained`` with every
       argument, including ``**kwargs``, forwarded unchanged.
    2. Base GR00T model (e.g. 'nvidia/GR00T-N1.5-3B'; weight file absent or
       the probe failed): a policy is constructed with ``cls(config)`` and
       switched to eval mode.

    Args:
        pretrained_name_or_path: Local directory or Hub repo id of the GR00T
            model or fine-tuned checkpoint.
        config: Optional GrootConfig. In the base-model case, a default one
            is created when None; otherwise the given object is updated in
            place (its ``base_model_path`` is overwritten).
        force_download: Force download even if cached. Only forwarded to
            the parent loader (fine-tuned case).
        resume_download: Resume interrupted download. Only forwarded to the
            parent loader (fine-tuned case).
        proxies: Proxy settings, used by the checkpoint probe and the
            parent loader.
        token: HuggingFace authentication token, used by the checkpoint
            probe and the parent loader.
        cache_dir: Cache directory path, used by the checkpoint probe and
            the parent loader.
        local_files_only: Only use local files, used by the checkpoint
            probe and the parent loader.
        revision: Specific model revision, used by the checkpoint probe and
            the parent loader.
        strict: Strict state dict loading. Only forwarded to the parent
            loader (fine-tuned case).
        **kwargs: Fine-tuned case: forwarded to the parent loader.
            Base-model case: each key that is an existing attribute of the
            config is set on it; all other keys are ignored.

    Returns:
        Initialized GrootPolicy instance with loaded model.
    """
    import logging

    from huggingface_hub import hf_hub_download
    from huggingface_hub.constants import SAFETENSORS_SINGLE_FILE
    from huggingface_hub.errors import HfHubHTTPError

    print(
        "The Groot policy is a wrapper around Nvidia's GR00T N1.5 model.\n"
        f"Loading pretrained model from: {pretrained_name_or_path}"
    )

    model_id = str(pretrained_name_or_path)
    is_finetuned_checkpoint = False

    # Check if this is a fine-tuned LeRobot checkpoint (has model.safetensors).
    # The probe is best-effort: any failure falls back to the base-model path.
    try:
        if os.path.isdir(model_id):
            is_finetuned_checkpoint = os.path.exists(os.path.join(model_id, SAFETENSORS_SINGLE_FILE))
        else:
            # Try to download the safetensors file to check if it exists
            hf_hub_download(
                repo_id=model_id,
                filename=SAFETENSORS_SINGLE_FILE,
                revision=revision,
                cache_dir=cache_dir,
                force_download=False,  # Just check, don't force download
                proxies=proxies,
                token=token,
                local_files_only=local_files_only,
            )
            is_finetuned_checkpoint = True
    except HfHubHTTPError:
        # The Hub answered with an HTTP error for this file: not a
        # fine-tuned checkpoint as far as this probe can tell.
        is_finetuned_checkpoint = False
    except Exception as err:
        # Anything else (offline cache miss, invalid repo id, network issue, ...)
        # keeps the original fallback, but is reported instead of being
        # swallowed silently so a misclassified checkpoint can be diagnosed.
        logging.getLogger(__name__).warning(
            "Could not check '%s' for %s (%s: %s); treating it as a base GR00T model.",
            model_id,
            SAFETENSORS_SINGLE_FILE,
            type(err).__name__,
            err,
        )
        is_finetuned_checkpoint = False

    if is_finetuned_checkpoint:
        # This is a fine-tuned LeRobot checkpoint - use parent class loading
        print("Detected fine-tuned LeRobot checkpoint, loading with state dict...")
        return super().from_pretrained(
            pretrained_name_or_path=pretrained_name_or_path,
            config=config,
            force_download=force_download,
            resume_download=resume_download,
            proxies=proxies,
            token=token,
            cache_dir=cache_dir,
            local_files_only=local_files_only,
            revision=revision,
            strict=strict,
            **kwargs,
        )

    # This is a base GR00T model - load it fresh
    print("Detected base GR00T model, loading from HuggingFace...")

    if config is None:
        # Create default config with the pretrained path
        config = GrootConfig(base_model_path=str(pretrained_name_or_path))

        # Add minimal visual feature required for validation
        # validate_features() will automatically add state and action features
        # These are placeholders - actual robot features come from the preprocessor
        if not config.input_features:
            config.input_features = {
                f"{OBS_IMAGES}.camera": PolicyFeature(
                    type=FeatureType.VISUAL,
                    shape=(3, 224, 224),  # Default image size from config
                ),
            }
    else:
        # Override the base_model_path with the provided path
        # (note: this mutates the config object supplied by the caller)
        config.base_model_path = str(pretrained_name_or_path)

    # Pass through any additional config overrides from kwargs;
    # keys the config does not already have are ignored.
    for key, value in kwargs.items():
        if hasattr(config, key):
            setattr(config, key, value)

    # Create a fresh policy instance - this will automatically load the GR00T model
    # in __init__ via _create_groot_model()
    policy = cls(config)
    policy.eval()
    return policy
def get_optim_params(self) -> dict: def get_optim_params(self) -> dict:
return self.parameters() return self.parameters()