mirror of
https://github.com/huggingface/lerobot.git
synced 2026-05-21 19:49:49 +00:00
feat: add async server-client streaming support for Groot policy (#2812)
This commit is contained in:
@@ -23,7 +23,7 @@ DEFAULT_INFERENCE_LATENCY = 1 / DEFAULT_FPS
|
|||||||
DEFAULT_OBS_QUEUE_TIMEOUT = 2
|
DEFAULT_OBS_QUEUE_TIMEOUT = 2
|
||||||
|
|
||||||
# All action chunking policies
|
# All action chunking policies
|
||||||
SUPPORTED_POLICIES = ["act", "smolvla", "diffusion", "tdmpc", "vqbet", "pi0", "pi05"]
|
SUPPORTED_POLICIES = ["act", "smolvla", "diffusion", "tdmpc", "vqbet", "pi0", "pi05", "groot"]
|
||||||
|
|
||||||
# TODO: Add all other robots
|
# TODO: Add all other robots
|
||||||
SUPPORTED_ROBOTS = ["so100_follower", "so101_follower", "bi_so_follower", "omx_follower"]
|
SUPPORTED_ROBOTS = ["so100_follower", "so101_follower", "bi_so_follower", "omx_follower"]
|
||||||
|
|||||||
@@ -32,16 +32,22 @@ Notes:
|
|||||||
from LeRobot, see `GrootPolicy.finetune_with_groot_runner` below.
|
from LeRobot, see `GrootPolicy.finetune_with_groot_runner` below.
|
||||||
"""
|
"""
|
||||||
|
|
||||||
|
import builtins
|
||||||
import os
|
import os
|
||||||
from collections import deque
|
from collections import deque
|
||||||
|
from pathlib import Path
|
||||||
|
from typing import TypeVar
|
||||||
|
|
||||||
import torch
|
import torch
|
||||||
from torch import Tensor
|
from torch import Tensor
|
||||||
|
|
||||||
|
from lerobot.configs.types import FeatureType, PolicyFeature
|
||||||
from lerobot.policies.groot.configuration_groot import GrootConfig
|
from lerobot.policies.groot.configuration_groot import GrootConfig
|
||||||
from lerobot.policies.groot.groot_n1 import GR00TN15
|
from lerobot.policies.groot.groot_n1 import GR00TN15
|
||||||
from lerobot.policies.pretrained import PreTrainedPolicy
|
from lerobot.policies.pretrained import PreTrainedPolicy
|
||||||
from lerobot.utils.constants import ACTION
|
from lerobot.utils.constants import ACTION, OBS_IMAGES
|
||||||
|
|
||||||
|
T = TypeVar("T", bound="GrootPolicy")
|
||||||
|
|
||||||
|
|
||||||
class GrootPolicy(PreTrainedPolicy):
|
class GrootPolicy(PreTrainedPolicy):
|
||||||
@@ -90,6 +96,129 @@ class GrootPolicy(PreTrainedPolicy):
|
|||||||
"""Reset policy state when environment resets."""
|
"""Reset policy state when environment resets."""
|
||||||
self._action_queue = deque([], maxlen=self.config.n_action_steps)
|
self._action_queue = deque([], maxlen=self.config.n_action_steps)
|
||||||
|
|
||||||
|
@classmethod
def from_pretrained(
    cls: builtins.type[T],
    pretrained_name_or_path: str | Path,
    *,
    config: GrootConfig | None = None,
    force_download: bool = False,
    resume_download: bool | None = None,
    proxies: dict | None = None,
    token: str | bool | None = None,
    cache_dir: str | Path | None = None,
    local_files_only: bool = False,
    revision: str | None = None,
    strict: bool = True,
    **kwargs,
) -> T:
    """Build a Groot policy from a base GR00T model or a fine-tuned checkpoint.

    The source is classified by probing for the single-file safetensors
    weights (``SAFETENSORS_SINGLE_FILE``):

    * Found (in a local directory, or downloadable from the Hub): the source
      is treated as a fine-tuned LeRobot checkpoint and loading is delegated
      to the parent class ``from_pretrained``.
    * Not found, or the probe raised any exception: the source is treated as
      a base GR00T model (e.g. 'nvidia/GR00T-N1.5-3B') and a fresh policy is
      constructed from a config pointing at it.

    Args:
        pretrained_name_or_path: Local directory or Hub repo id of the model.
        config: Optional GrootConfig. In the base-model path a supplied config
            is modified in place (``base_model_path`` is overwritten and
            matching ``kwargs`` are applied); when None a default config is
            created.
        force_download: Forwarded to the parent loader (fine-tuned path only;
            the probe itself never forces a download).
        resume_download: Forwarded to the parent loader (fine-tuned path only).
        proxies: Proxy settings used for the probe and the parent loader.
        token: HuggingFace authentication token.
        cache_dir: Cache directory path.
        local_files_only: Only use local files.
        revision: Specific model revision.
        strict: Forwarded to the parent loader (fine-tuned path only).
        **kwargs: Forwarded to the parent loader in the fine-tuned path; in
            the base-model path, entries whose key is an existing config
            attribute override it and all other entries are ignored.

    Returns:
        The loaded policy. In the base-model path it is put in eval mode
        before being returned.
    """
    from huggingface_hub import hf_hub_download
    from huggingface_hub.constants import SAFETENSORS_SINGLE_FILE
    from huggingface_hub.errors import HfHubHTTPError

    print(
        "The Groot policy is a wrapper around Nvidia's GR00T N1.5 model.\n"
        f"Loading pretrained model from: {pretrained_name_or_path}"
    )

    model_id = str(pretrained_name_or_path)

    # Hub options shared by the existence probe and the parent loader.
    hub_options = {
        "revision": revision,
        "cache_dir": cache_dir,
        "proxies": proxies,
        "token": token,
        "local_files_only": local_files_only,
    }

    # A fine-tuned LeRobot checkpoint is recognised by its safetensors file.
    try:
        if os.path.isdir(model_id):
            weights_path = os.path.join(model_id, SAFETENSORS_SINGLE_FILE)
            has_checkpoint_weights = os.path.exists(weights_path)
        else:
            # Fetching the file is the existence check; never force it.
            hf_hub_download(
                repo_id=model_id,
                filename=SAFETENSORS_SINGLE_FILE,
                force_download=False,
                **hub_options,
            )
            has_checkpoint_weights = True
    except HfHubHTTPError:
        has_checkpoint_weights = False
    except Exception:
        # Best-effort probe: any failure means "not a fine-tuned checkpoint".
        has_checkpoint_weights = False

    if has_checkpoint_weights:
        print("Detected fine-tuned LeRobot checkpoint, loading with state dict...")
        return super().from_pretrained(
            pretrained_name_or_path=pretrained_name_or_path,
            config=config,
            force_download=force_download,
            resume_download=resume_download,
            strict=strict,
            **hub_options,
            **kwargs,
        )

    print("Detected base GR00T model, loading from HuggingFace...")

    if config is not None:
        # Point the caller's config at the requested model.
        config.base_model_path = model_id
    else:
        config = GrootConfig(base_model_path=model_id)

        # Placeholder visual feature so that config validation has at least
        # one image input; the original notes say validate_features() adds
        # state/action features and real robot features come from the
        # preprocessor.
        if not config.input_features:
            placeholder_camera = PolicyFeature(
                type=FeatureType.VISUAL,
                shape=(3, 224, 224),  # Default image size from config
            )
            config.input_features = {f"{OBS_IMAGES}.camera": placeholder_camera}

    # Apply keyword overrides that correspond to existing config attributes.
    for option_name, option_value in kwargs.items():
        if hasattr(config, option_name):
            setattr(config, option_name, option_value)

    # Per the original notes, __init__ loads the GR00T model through
    # _create_groot_model() when the policy is instantiated.
    policy = cls(config)
    policy.eval()
    return policy
|
||||||
|
|
||||||
def get_optim_params(self) -> dict:
|
def get_optim_params(self) -> dict:
|
||||||
return self.parameters()
|
return self.parameters()
|
||||||
|
|
||||||
|
|||||||
Reference in New Issue
Block a user