Files
lerobot/eval_robot/utils/utils.py
T
2025-11-21 14:13:05 +01:00

143 lines
4.3 KiB
Python

import numpy as np
import torch
from typing import Any
from contextlib import nullcontext
from copy import copy
import logging
from dataclasses import dataclass
from lerobot.configs import parser
from lerobot.configs.policies import PreTrainedConfig
from lerobot.policies.pretrained import PreTrainedPolicy
import logging_mp
# Configure the logging_mp package at INFO level as a side effect of importing this module.
logging_mp.basic_config(level=logging_mp.INFO)
# Module-level logger named after this module; used by the helpers below.
logger_mp = logging_mp.get_logger(__name__)
def extract_observation(step: dict):
    """Pick the policy inputs out of a single ``step`` mapping.

    Only two kinds of keys are kept; everything else is dropped:

    * ``"observation.state"`` is copied through unchanged.
    * Keys starting with ``"observation.images."`` are copied through. When the
      value is a 3-D ``np.ndarray`` whose last axis has length 1 or 3, that last
      axis is moved to the front (presumably channel-last -> channel-first).

    Args:
        step: Mapping from string keys to values.

    Returns:
        A new dict holding the selected entries, in the iteration order of ``step``.
    """
    image_prefix = "observation.images."
    selected = {}

    for name, item in step.items():
        if not name.startswith(image_prefix):
            # Non-image keys: only the state vector is of interest.
            if name == "observation.state":
                selected[name] = item
            continue

        looks_channel_last = (
            isinstance(item, np.ndarray) and item.ndim == 3 and item.shape[-1] in (1, 3)
        )
        # Anything that does not look like a channel-last array is kept as is.
        selected[name] = np.transpose(item, (2, 0, 1)) if looks_channel_last else item

    return selected
def predict_action(
    observation: dict[str, np.ndarray],
    policy: PreTrainedPolicy,
    device: torch.device,
    use_amp: bool,
    task: str | None = None,
    use_dataset: bool | None = False,
):
    """Run one inference step of ``policy`` and return the action on the CPU.

    The observation is shallow-copied, so the caller's dict is never modified.
    Every value that exposes ``unsqueeze`` (i.e. behaves like a tensor) gets a
    leading batch dimension and is moved to ``device``. Values without
    ``unsqueeze`` (strings and similar) are forwarded to the policy untouched,
    regardless of ``use_dataset``.

    NOTE(review): the annotation says ``np.ndarray`` values, but numpy arrays
    have no ``unsqueeze`` and are therefore passed through unconverted --
    callers presumably hand in ``torch.Tensor`` values; confirm.

    Args:
        observation: Mapping from observation name to value.
        policy: Object whose ``select_action(observation)`` yields the batched action.
        device: Device the tensors are moved to; autocast is only enabled on
            ``"cuda"`` devices.
        use_amp: Enable ``torch.autocast`` (only honoured on CUDA).
        task: Task description; an empty string is used when falsy.
        use_dataset: When falsy, entries whose name contains ``"images"`` are
            converted to float32, divided by 255 and permuted with ``(2, 0, 1)``
            (channel-last -> channel-first). When truthy that conversion is
            skipped -- presumably because dataset frames are already in that
            format; confirm against the caller.

    Returns:
        The result of ``policy.select_action`` with the batch dimension
        squeezed out, moved to the CPU.
    """
    # Shallow copy is enough: entries are rebound below, never mutated in place.
    observation = copy(observation)
    with (
        torch.inference_mode(),
        torch.autocast(device_type=device.type) if device.type == "cuda" and use_amp else nullcontext(),
    ):
        # Convert to pytorch format: channel first and float32 in [0,1] with batch dimension
        for name in observation:
            value = observation[name]
            # Skip non-tensor observations (like task strings). This guard must
            # hold in dataset mode too, otherwise ``unsqueeze`` below raises.
            if not hasattr(value, "unsqueeze"):
                continue
            if not use_dataset and "images" in name:
                value = value.type(torch.float32) / 255
                value = value.permute(2, 0, 1).contiguous()
            observation[name] = value.unsqueeze(0).to(device)

        # Added after the loop: inserting a key while iterating would raise.
        observation["task"] = [task if task else ""]

        # Compute the next action with the policy
        # based on the current observation
        action = policy.select_action(observation)

        # Remove batch dimension
        action = action.squeeze(0)

        # Move to cpu, if not already the case
        action = action.to("cpu")

    return action
def reset_policy(policy: PreTrainedPolicy) -> None:
    """Call ``policy.reset()``.

    NOTE(review): presumably this clears per-episode state kept by the policy;
    confirm against ``PreTrainedPolicy.reset``.
    """
    policy.reset()
def cleanup_resources(image_info: dict[str, Any]) -> None:
    """Safely close and unlink shared memory resources.

    Each truthy entry of ``image_info["shm_resources"]`` is closed and then
    unlinked. Entries are handled independently: a failure while releasing one
    resource is logged and does not prevent the remaining ones from being
    released. A missing or empty ``"shm_resources"`` entry is treated as
    "nothing to clean up".

    NOTE(review): entries are presumably
    ``multiprocessing.shared_memory.SharedMemory`` objects; only ``close()``
    and ``unlink()`` are required of them.

    Args:
        image_info: Mapping that may hold an iterable of shared-memory handles
            under the ``"shm_resources"`` key.
    """
    logger_mp.info("Cleaning up shared memory resources.")
    for shm in image_info.get("shm_resources") or ():
        if not shm:
            continue
        try:
            shm.close()
        except (BufferError, OSError) as exc:
            # e.g. a view onto the buffer is still alive; still try to unlink.
            logger_mp.warning(f"Failed to close shared memory resource: {exc}")
        try:
            shm.unlink()
        except FileNotFoundError:
            # Already unlinked elsewhere; nothing left to remove.
            pass
        except OSError as exc:
            logger_mp.warning(f"Failed to unlink shared memory resource: {exc}")
def to_list(x):
    """Return ``x`` as a plain Python list.

    * ``torch.Tensor``: detached, moved to the CPU, flattened, then ``tolist()``.
    * ``np.ndarray``: flattened, then ``tolist()``.
    * ``list`` / ``tuple``: shallow ``list(x)`` copy (nested sequences are not flattened).
    * anything else: wrapped in a one-element list.
    """
    if isinstance(x, (list, tuple)):
        return list(x)

    if isinstance(x, np.ndarray):
        flat_array = x.ravel()
        return flat_array.tolist()

    # The ``is not None`` test is kept from the original as a defensive guard.
    if torch is not None and isinstance(x, torch.Tensor):
        host_tensor = x.detach().cpu()
        return host_tensor.ravel().tolist()

    return [x]
def to_scalar(x):
    """Return the first element of ``x`` as a Python ``float``.

    Tensors and arrays are flattened first, so "first" means the first element
    in flattened order. Lists and tuples contribute ``x[0]``. Any other value
    is passed to ``float()`` directly.

    Raises:
        IndexError: if ``x`` is an empty tensor, array, list or tuple.
    """
    # The ``is not None`` test is kept from the original as a defensive guard.
    if torch is not None and isinstance(x, torch.Tensor):
        first = x.detach().cpu().ravel()[0].item()
    elif isinstance(x, np.ndarray):
        first = x.ravel()[0]
    elif isinstance(x, (list, tuple)):
        first = x[0]
    else:
        first = x
    return float(first)
@dataclass
class EvalRealConfig:
    """Options for an evaluation run, together with the policy config to evaluate.

    ``policy`` is filled in by ``__post_init__`` from the ``policy`` path given
    on the command line, when there is one.
    """

    repo_id: str
    policy: PreTrainedConfig | None = None
    root: str = ""
    episodes: int = 0
    frequency: float = 30.0
    # Basic control parameters
    arm: str = "G1_29"  # G1_29, G1_23
    ee: str = "dex3"  # dex3, dex1, inspire1, brainco
    # Mode flags
    motion: bool = False
    headless: bool = False
    visualization: bool = False
    send_real_robot: bool = False
    use_dataset: bool = False

    @classmethod
    def __get_path_fields__(cls) -> list[str]:
        """This enables the parser to load config from the policy using `--policy.path=local/dir`"""
        return ["policy"]

    def __post_init__(self):
        """Load the policy config from the CLI ``policy`` path, or warn if none was given."""
        # HACK: We parse again the cli args here to get the pretrained path if there was one.
        pretrained_path = parser.get_path_arg("policy")
        if not pretrained_path:
            logging.warning(
                "No pretrained path was provided, evaluated policy will be built from scratch (random weights)."
            )
            return

        overrides = parser.get_cli_overrides("policy")
        loaded_config = PreTrainedConfig.from_pretrained(pretrained_path, cli_overrides=overrides)
        # Remember where the config came from.
        loaded_config.pretrained_path = pretrained_path
        self.policy = loaded_config