mirror of
https://github.com/huggingface/lerobot.git
synced 2026-05-17 01:30:14 +00:00
143 lines
4.3 KiB
Python
143 lines
4.3 KiB
Python
import numpy as np
|
|
import torch
|
|
from typing import Any
|
|
from contextlib import nullcontext
|
|
from copy import copy
|
|
import logging
|
|
from dataclasses import dataclass
|
|
from lerobot.configs import parser
|
|
from lerobot.configs.policies import PreTrainedConfig
|
|
from lerobot.policies.pretrained import PreTrainedPolicy
|
|
|
|
|
|
import logging_mp
|
|
|
|
# Configure logging_mp at INFO level, then create the module-level logger
# (named after this module) that the helpers below log through.
logging_mp.basic_config(level=logging_mp.INFO)
logger_mp = logging_mp.get_logger(__name__)
|
|
|
|
|
|
def extract_observation(step: dict):
    """Select the observation entries of ``step`` and return them in a new dict.

    Only two kinds of entries are kept, in their original iteration order:

    * keys starting with ``"observation.images."`` -- if the value is a 3-D
      NumPy array whose last axis has size 1 or 3, it is transposed with axes
      ``(2, 0, 1)``; any other value is stored unchanged;
    * the key ``"observation.state"`` -- stored unchanged.

    Args:
        step: Mapping from string keys to values.

    Returns:
        A dict containing only the selected entries.
    """
    image_prefix = "observation.images."
    selected = {}

    for name, item in step.items():
        if name == "observation.state":
            selected[name] = item
            continue
        if not name.startswith(image_prefix):
            # Anything that is neither an image nor the state is dropped.
            continue

        # Only arrays whose trailing axis has 1 or 3 entries are moved to
        # the (2, 0, 1) axis order; everything else passes through as-is.
        needs_transpose = (
            isinstance(item, np.ndarray) and item.ndim == 3 and item.shape[-1] in (1, 3)
        )
        selected[name] = np.transpose(item, (2, 0, 1)) if needs_transpose else item

    return selected
|
|
|
|
|
|
def predict_action(
    observation: dict[str, np.ndarray],
    policy: PreTrainedPolicy,
    device: torch.device,
    use_amp: bool,
    task: str | None = None,
    use_dataset: bool | None = False,
):
    """Run one policy inference step and return the action on the CPU.

    A shallow copy of ``observation`` is prepared for the policy, so the
    caller's dict is not modified. For each entry:

    * when ``use_dataset`` is falsy, values without an ``unsqueeze`` attribute
      are left untouched, and entries whose name contains ``"images"`` are
      cast to float32, divided by 255 and permuted with ``(2, 0, 1)``;
    * the (possibly converted) value then gets a leading batch dimension and
      is moved to ``device``.

    The ``"task"`` entry is always overwritten with a one-element list holding
    ``task`` (or ``""`` when ``task`` is falsy).

    NOTE(review): the annotation says ``np.ndarray``, but NumPy arrays have no
    ``unsqueeze`` and would be skipped when ``use_dataset`` is falsy --
    presumably callers pass torch tensors; confirm against the call sites.

    Args:
        observation: Mapping from feature name to value.
        policy: Object whose ``select_action`` is called with the prepared batch.
        device: Target device for the batch entries.
        use_amp: Enable ``torch.autocast``; only honoured on CUDA devices.
        task: Optional task description forwarded to the policy.
        use_dataset: When truthy, the per-entry image conversion and the
            non-tensor skip are bypassed.

    Returns:
        The policy output with its first dimension squeezed, moved to the CPU.
    """
    batch = copy(observation)

    # Autocast is only activated for CUDA devices with AMP requested.
    amp_enabled = device.type == "cuda" and use_amp
    autocast_ctx = torch.autocast(device_type=device.type) if amp_enabled else nullcontext()

    with torch.inference_mode(), autocast_ctx:
        # Convert to pytorch format: channel first and float32 in [0,1] with batch dimension
        for name in batch:
            entry = batch[name]
            if not use_dataset:
                # Skip non-tensor observations (like task strings)
                if not hasattr(entry, "unsqueeze"):
                    continue
                if "images" in name:
                    entry = (entry.type(torch.float32) / 255).permute(2, 0, 1).contiguous()

            batch[name] = entry.unsqueeze(0).to(device)

        batch["task"] = [task if task else ""]

        # Query the policy on the current observation, then drop the batch
        # dimension and bring the result back to the CPU.
        action = policy.select_action(batch).squeeze(0).to("cpu")

    return action
|
|
|
|
|
|
def reset_policy(policy: PreTrainedPolicy):
    """Call ``policy.reset()``; nothing is returned.

    Args:
        policy: The policy object to reset.
    """
    policy.reset()
|
|
|
|
|
|
def cleanup_resources(image_info: dict[str, Any]):
    """Close and unlink the shared memory resources listed in ``image_info``.

    Every truthy entry of ``image_info["shm_resources"]`` has ``close()`` and
    then ``unlink()`` called on it; falsy entries (e.g. ``None``) are skipped.
    Exceptions raised by either call are not caught here.

    Args:
        image_info: Mapping that must contain the key ``"shm_resources"`` with
            an iterable of resource handles.
    """
    logger_mp.info("Cleaning up shared memory resources.")

    # filter(None, ...) lazily drops falsy placeholders, one entry at a time.
    for segment in filter(None, image_info["shm_resources"]):
        segment.close()
        segment.unlink()
|
|
|
|
|
|
def to_list(x):
    """Convert ``x`` to a Python list.

    * torch tensors are detached, moved to the CPU, flattened and converted;
    * NumPy arrays are flattened and converted;
    * lists and tuples are shallow-copied into a list (no flattening);
    * any other value is wrapped in a one-element list.
    """
    is_tensor = torch is not None and isinstance(x, torch.Tensor)
    if is_tensor:
        flat = x.detach().cpu().ravel()
        return flat.tolist()

    if isinstance(x, np.ndarray):
        return x.ravel().tolist()

    return list(x) if isinstance(x, (list, tuple)) else [x]
|
|
|
|
|
|
def to_scalar(x):
    """Return the first element of ``x`` as a Python ``float``.

    For torch tensors and NumPy arrays the first element of the flattened
    data is used; for lists and tuples, element ``0``; any other value is
    passed to ``float`` directly. Empty containers are not handled specially,
    so the indexing error propagates.
    """
    if torch is not None and isinstance(x, torch.Tensor):
        first = x.detach().cpu().ravel()[0].item()
    elif isinstance(x, np.ndarray):
        first = x.ravel()[0]
    elif isinstance(x, (list, tuple)):
        first = x[0]
    else:
        first = x
    return float(first)
|
|
|
|
|
|
@dataclass
class EvalRealConfig:
    """Configuration dataclass for an evaluation run.

    After construction, ``__post_init__`` checks the command line for a
    ``policy`` path argument. When one is present, ``policy`` is replaced by
    the config loaded from that path (with the ``policy`` CLI overrides
    applied) and its ``pretrained_path`` is set; otherwise a warning is logged
    and ``policy`` keeps whatever value was passed in.
    """

    repo_id: str
    policy: PreTrainedConfig | None = None

    root: str = ""
    episodes: int = 0
    frequency: float = 30.0

    # Basic control parameters
    arm: str = "G1_29"  # G1_29, G1_23
    ee: str = "dex3"  # dex3, dex1, inspire1, brainco

    # Mode flags
    motion: bool = False
    headless: bool = False
    visualization: bool = False
    send_real_robot: bool = False
    use_dataset: bool = False

    def __post_init__(self):
        """Load the policy config from the CLI-provided path, if there is one."""
        # HACK: We parse again the cli args here to get the pretrained path if there was one.
        policy_path = parser.get_path_arg("policy")

        if not policy_path:
            logging.warning(
                "No pretrained path was provided, evaluated policy will be built from scratch (random weights)."
            )
            return

        overrides = parser.get_cli_overrides("policy")
        self.policy = PreTrainedConfig.from_pretrained(policy_path, cli_overrides=overrides)
        self.policy.pretrained_path = policy_path

    @classmethod
    def __get_path_fields__(cls) -> list[str]:
        """This enables the parser to load config from the policy using `--policy.path=local/dir`"""
        return ["policy"]
|