mirror of
https://github.com/huggingface/lerobot.git
synced 2026-05-20 02:59:50 +00:00
sync recent changes
This commit is contained in:
@@ -0,0 +1,142 @@
|
||||
import numpy as np
|
||||
import torch
|
||||
from typing import Any
|
||||
from contextlib import nullcontext
|
||||
from copy import copy
|
||||
import logging
|
||||
from dataclasses import dataclass
|
||||
from lerobot.configs import parser
|
||||
from lerobot.configs.policies import PreTrainedConfig
|
||||
from lerobot.policies.pretrained import PreTrainedPolicy
|
||||
|
||||
|
||||
import logging_mp
|
||||
|
||||
logging_mp.basic_config(level=logging_mp.INFO)
|
||||
logger_mp = logging_mp.get_logger(__name__)
|
||||
|
||||
|
||||
def extract_observation(step: dict):
|
||||
observation = {}
|
||||
|
||||
for key, value in step.items():
|
||||
if key.startswith("observation.images."):
|
||||
if isinstance(value, np.ndarray) and value.ndim == 3 and value.shape[-1] in [1, 3]:
|
||||
value = np.transpose(value, (2, 0, 1))
|
||||
observation[key] = value
|
||||
|
||||
elif key == "observation.state":
|
||||
observation[key] = value
|
||||
|
||||
return observation
|
||||
|
||||
|
||||
def predict_action(
|
||||
observation: dict[str, np.ndarray],
|
||||
policy: PreTrainedPolicy,
|
||||
device: torch.device,
|
||||
use_amp: bool,
|
||||
task: str | None = None,
|
||||
use_dataset: bool | None = False,
|
||||
):
|
||||
observation = copy(observation)
|
||||
with (
|
||||
torch.inference_mode(),
|
||||
torch.autocast(device_type=device.type) if device.type == "cuda" and use_amp else nullcontext(),
|
||||
):
|
||||
# Convert to pytorch format: channel first and float32 in [0,1] with batch dimension
|
||||
for name in observation:
|
||||
if not use_dataset:
|
||||
# Skip non-tensor observations (like task strings)
|
||||
if not hasattr(observation[name], "unsqueeze"):
|
||||
continue
|
||||
if "images" in name:
|
||||
observation[name] = observation[name].type(torch.float32) / 255
|
||||
observation[name] = observation[name].permute(2, 0, 1).contiguous()
|
||||
|
||||
observation[name] = observation[name].unsqueeze(0).to(device)
|
||||
|
||||
observation["task"] = [task if task else ""]
|
||||
|
||||
# Compute the next action with the policy
|
||||
# based on the current observation
|
||||
action = policy.select_action(observation)
|
||||
|
||||
# Remove batch dimension
|
||||
action = action.squeeze(0)
|
||||
|
||||
# Move to cpu, if not already the case
|
||||
action = action.to("cpu")
|
||||
|
||||
return action
|
||||
|
||||
|
||||
def reset_policy(policy: PreTrainedPolicy):
|
||||
policy.reset()
|
||||
|
||||
|
||||
def cleanup_resources(image_info: dict[str, Any]):
|
||||
"""Safely close and unlink shared memory resources."""
|
||||
logger_mp.info("Cleaning up shared memory resources.")
|
||||
for shm in image_info["shm_resources"]:
|
||||
if shm:
|
||||
shm.close()
|
||||
shm.unlink()
|
||||
|
||||
|
||||
def to_list(x):
|
||||
if torch is not None and isinstance(x, torch.Tensor):
|
||||
return x.detach().cpu().ravel().tolist()
|
||||
if isinstance(x, np.ndarray):
|
||||
return x.ravel().tolist()
|
||||
if isinstance(x, (list, tuple)):
|
||||
return list(x)
|
||||
return [x]
|
||||
|
||||
|
||||
def to_scalar(x):
|
||||
if torch is not None and isinstance(x, torch.Tensor):
|
||||
return float(x.detach().cpu().ravel()[0].item())
|
||||
if isinstance(x, np.ndarray):
|
||||
return float(x.ravel()[0])
|
||||
if isinstance(x, (list, tuple)):
|
||||
return float(x[0])
|
||||
return float(x)
|
||||
|
||||
|
||||
@dataclass
|
||||
class EvalRealConfig:
|
||||
repo_id: str
|
||||
policy: PreTrainedConfig | None = None
|
||||
|
||||
root: str = ""
|
||||
episodes: int = 0
|
||||
frequency: float = 30.0
|
||||
|
||||
# Basic control parameters
|
||||
arm: str = "G1_29" # G1_29, G1_23
|
||||
ee: str = "dex3" # dex3, dex1, inspire1, brainco
|
||||
|
||||
# Mode flags
|
||||
motion: bool = False
|
||||
headless: bool = False
|
||||
visualization: bool = False
|
||||
send_real_robot: bool = False
|
||||
use_dataset: bool = False
|
||||
|
||||
def __post_init__(self):
|
||||
# HACK: We parse again the cli args here to get the pretrained path if there was one.
|
||||
policy_path = parser.get_path_arg("policy")
|
||||
if policy_path:
|
||||
cli_overrides = parser.get_cli_overrides("policy")
|
||||
self.policy = PreTrainedConfig.from_pretrained(policy_path, cli_overrides=cli_overrides)
|
||||
self.policy.pretrained_path = policy_path
|
||||
else:
|
||||
logging.warning(
|
||||
"No pretrained path was provided, evaluated policy will be built from scratch (random weights)."
|
||||
)
|
||||
|
||||
@classmethod
|
||||
def __get_path_fields__(cls) -> list[str]:
|
||||
"""This enables the parser to load config from the policy using `--policy.path=local/dir`"""
|
||||
return ["policy"]
|
||||
Reference in New Issue
Block a user