mirror of
https://github.com/huggingface/lerobot.git
synced 2026-06-18 16:57:12 +00:00
Compare commits
6 Commits
| Author | SHA1 | Date | |
|---|---|---|---|
| a130a9db39 | |||
| 4f5e6596be | |||
| afeeeb8982 | |||
| 040c6b3d66 | |||
| acd31c7de2 | |||
| 240393d238 |
@@ -136,7 +136,6 @@ Learn how to implement your own simulation environment or benchmark and distribu
|
|||||||
- **[X](https://x.com/LeRobotHF):** Follow us on X to stay up-to-date with the latest developments.
|
- **[X](https://x.com/LeRobotHF):** Follow us on X to stay up-to-date with the latest developments.
|
||||||
- **[Robot Learning Tutorial](https://huggingface.co/spaces/lerobot/robot-learning-tutorial):** A free, hands-on course to learn robot learning using LeRobot.
|
- **[Robot Learning Tutorial](https://huggingface.co/spaces/lerobot/robot-learning-tutorial):** A free, hands-on course to learn robot learning using LeRobot.
|
||||||
- **[T-Shirt Folding Experiment](https://huggingface.co/spaces/lerobot/robot-folding):** An end-to-end demonstration of folding t-shirts with LeRobot.
|
- **[T-Shirt Folding Experiment](https://huggingface.co/spaces/lerobot/robot-folding):** An end-to-end demonstration of folding t-shirts with LeRobot.
|
||||||
- **[LeLab](https://github.com/huggingface/leLab):** A web interface for LeRobot — teleoperate, calibrate, record datasets, replay, and train your SO arm from the browser, no CLI required.
|
|
||||||
|
|
||||||
## Citation
|
## Citation
|
||||||
|
|
||||||
|
|||||||
@@ -442,12 +442,11 @@ class OpenCVCamera(Camera):
|
|||||||
|
|
||||||
Stops on DeviceNotConnectedError, logs other errors and continues.
|
Stops on DeviceNotConnectedError, logs other errors and continues.
|
||||||
"""
|
"""
|
||||||
stop_event = self.stop_event
|
if self.stop_event is None:
|
||||||
if stop_event is None:
|
|
||||||
raise RuntimeError(f"{self}: stop_event is not initialized before starting read loop.")
|
raise RuntimeError(f"{self}: stop_event is not initialized before starting read loop.")
|
||||||
|
|
||||||
failure_count = 0
|
failure_count = 0
|
||||||
while not stop_event.is_set():
|
while not self.stop_event.is_set():
|
||||||
try:
|
try:
|
||||||
raw_frame = self._read_from_hardware()
|
raw_frame = self._read_from_hardware()
|
||||||
processed_frame = self._postprocess_image(raw_frame)
|
processed_frame = self._postprocess_image(raw_frame)
|
||||||
|
|||||||
@@ -471,12 +471,11 @@ class RealSenseCamera(Camera):
|
|||||||
|
|
||||||
Stops on DeviceNotConnectedError, logs other errors and continues.
|
Stops on DeviceNotConnectedError, logs other errors and continues.
|
||||||
"""
|
"""
|
||||||
stop_event = self.stop_event
|
if self.stop_event is None:
|
||||||
if stop_event is None:
|
|
||||||
raise RuntimeError(f"{self}: stop_event is not initialized before starting read loop.")
|
raise RuntimeError(f"{self}: stop_event is not initialized before starting read loop.")
|
||||||
|
|
||||||
failure_count = 0
|
failure_count = 0
|
||||||
while not stop_event.is_set():
|
while not self.stop_event.is_set():
|
||||||
try:
|
try:
|
||||||
frame = self._read_from_hardware()
|
frame = self._read_from_hardware()
|
||||||
color_frame_raw = frame.get_color_frame()
|
color_frame_raw = frame.get_color_frame()
|
||||||
|
|||||||
@@ -246,12 +246,11 @@ class ZMQCamera(Camera):
|
|||||||
"""
|
"""
|
||||||
Internal loop run by the background thread for asynchronous reading.
|
Internal loop run by the background thread for asynchronous reading.
|
||||||
"""
|
"""
|
||||||
stop_event = self.stop_event
|
if self.stop_event is None:
|
||||||
if stop_event is None:
|
|
||||||
raise RuntimeError(f"{self}: stop_event is not initialized.")
|
raise RuntimeError(f"{self}: stop_event is not initialized.")
|
||||||
|
|
||||||
failure_count = 0
|
failure_count = 0
|
||||||
while not stop_event.is_set():
|
while not self.stop_event.is_set():
|
||||||
try:
|
try:
|
||||||
frame = self._read_from_hardware()
|
frame = self._read_from_hardware()
|
||||||
capture_time = time.perf_counter()
|
capture_time = time.perf_counter()
|
||||||
|
|||||||
@@ -73,8 +73,17 @@ class EvalConfig:
|
|||||||
# `use_async_envs` specifies whether to use asynchronous environments (multiprocessing).
|
# `use_async_envs` specifies whether to use asynchronous environments (multiprocessing).
|
||||||
# Defaults to True; automatically downgraded to SyncVectorEnv when batch_size=1.
|
# Defaults to True; automatically downgraded to SyncVectorEnv when batch_size=1.
|
||||||
use_async_envs: bool = True
|
use_async_envs: bool = True
|
||||||
|
# Whether to record eval rollouts as a LeRobot dataset on disk.
|
||||||
|
recording: bool = False
|
||||||
|
# If set, push recorded eval datasets to the Hub under this repo id (one repo per task,
|
||||||
|
# suffixed by task and env index). Requires recording=true.
|
||||||
|
recording_repo_id: str | None = None
|
||||||
|
# Whether the pushed recording repositories should be private.
|
||||||
|
recording_private: bool = False
|
||||||
|
|
||||||
def __post_init__(self) -> None:
|
def __post_init__(self) -> None:
|
||||||
|
if self.recording_repo_id is not None and not self.recording:
|
||||||
|
raise ValueError("eval.recording_repo_id requires eval.recording=true.")
|
||||||
if self.batch_size == 0:
|
if self.batch_size == 0:
|
||||||
self.batch_size = self._auto_batch_size()
|
self.batch_size = self._auto_batch_size()
|
||||||
if self.batch_size > self.n_episodes:
|
if self.batch_size > self.n_episodes:
|
||||||
|
|||||||
@@ -74,8 +74,6 @@ class DatasetReader:
|
|||||||
self.episodes = episodes
|
self.episodes = episodes
|
||||||
self._tolerance_s = tolerance_s
|
self._tolerance_s = tolerance_s
|
||||||
self._video_backend = video_backend
|
self._video_backend = video_backend
|
||||||
if image_transforms is not None and not callable(image_transforms):
|
|
||||||
raise TypeError("image_transforms must be callable or None.")
|
|
||||||
self._image_transforms = image_transforms
|
self._image_transforms = image_transforms
|
||||||
self._return_uint8 = return_uint8
|
self._return_uint8 = return_uint8
|
||||||
|
|
||||||
@@ -88,16 +86,6 @@ class DatasetReader:
|
|||||||
check_delta_timestamps(delta_timestamps, meta.fps, tolerance_s)
|
check_delta_timestamps(delta_timestamps, meta.fps, tolerance_s)
|
||||||
self.delta_indices = get_delta_indices(delta_timestamps, meta.fps)
|
self.delta_indices = get_delta_indices(delta_timestamps, meta.fps)
|
||||||
|
|
||||||
def set_image_transforms(self, image_transforms: Callable | None) -> None:
|
|
||||||
"""Replace the transform applied to visual observations."""
|
|
||||||
if image_transforms is not None and not callable(image_transforms):
|
|
||||||
raise TypeError("image_transforms must be callable or None.")
|
|
||||||
self._image_transforms = image_transforms
|
|
||||||
|
|
||||||
def clear_image_transforms(self) -> None:
|
|
||||||
"""Remove the transform applied to visual observations."""
|
|
||||||
self._image_transforms = None
|
|
||||||
|
|
||||||
def try_load(self) -> bool:
|
def try_load(self) -> bool:
|
||||||
"""Attempt to load from local cache. Returns True if data is sufficient."""
|
"""Attempt to load from local cache. Returns True if data is sufficient."""
|
||||||
try:
|
try:
|
||||||
|
|||||||
@@ -201,6 +201,8 @@ class LeRobotDataset(torch.utils.data.Dataset):
|
|||||||
super().__init__()
|
super().__init__()
|
||||||
self.repo_id = repo_id
|
self.repo_id = repo_id
|
||||||
self._requested_root = Path(root) if root else None
|
self._requested_root = Path(root) if root else None
|
||||||
|
self.reader = None
|
||||||
|
self.set_image_transforms(image_transforms)
|
||||||
self.delta_timestamps = delta_timestamps
|
self.delta_timestamps = delta_timestamps
|
||||||
self.tolerance_s = tolerance_s
|
self.tolerance_s = tolerance_s
|
||||||
self.revision = revision if revision else CODEBASE_VERSION
|
self.revision = revision if revision else CODEBASE_VERSION
|
||||||
@@ -247,7 +249,6 @@ class LeRobotDataset(torch.utils.data.Dataset):
|
|||||||
image_transforms=image_transforms,
|
image_transforms=image_transforms,
|
||||||
return_uint8=self._return_uint8,
|
return_uint8=self._return_uint8,
|
||||||
)
|
)
|
||||||
self.image_transforms = image_transforms
|
|
||||||
|
|
||||||
# Load actual data
|
# Load actual data
|
||||||
if force_cache_sync or not self.reader.try_load():
|
if force_cache_sync or not self.reader.try_load():
|
||||||
@@ -504,14 +505,15 @@ class LeRobotDataset(torch.utils.data.Dataset):
|
|||||||
|
|
||||||
def set_image_transforms(self, image_transforms: Callable | None) -> None:
|
def set_image_transforms(self, image_transforms: Callable | None) -> None:
|
||||||
"""Replace the transform applied to visual observations."""
|
"""Replace the transform applied to visual observations."""
|
||||||
self._ensure_reader().set_image_transforms(image_transforms)
|
if image_transforms is not None and not callable(image_transforms):
|
||||||
|
raise TypeError("image_transforms must be callable or None.")
|
||||||
self.image_transforms = image_transforms
|
self.image_transforms = image_transforms
|
||||||
|
if self.reader is not None:
|
||||||
|
self.reader._image_transforms = image_transforms
|
||||||
|
|
||||||
def clear_image_transforms(self) -> None:
|
def clear_image_transforms(self) -> None:
|
||||||
"""Remove the transform applied to visual observations."""
|
"""Remove the transform applied to visual observations."""
|
||||||
if self.reader is not None:
|
self.set_image_transforms(None)
|
||||||
self.reader.set_image_transforms(None)
|
|
||||||
self.image_transforms = None
|
|
||||||
|
|
||||||
# ── Hub methods (stay on facade) ──────────────────────────────────
|
# ── Hub methods (stay on facade) ──────────────────────────────────
|
||||||
|
|
||||||
|
|||||||
@@ -72,8 +72,9 @@ from termcolor import colored
|
|||||||
from torch import Tensor, nn
|
from torch import Tensor, nn
|
||||||
from tqdm import trange
|
from tqdm import trange
|
||||||
|
|
||||||
from lerobot.configs import parser
|
from lerobot.configs import FeatureType, parser
|
||||||
from lerobot.configs.eval import EvalPipelineConfig
|
from lerobot.configs.eval import EvalPipelineConfig
|
||||||
|
from lerobot.datasets.lerobot_dataset import LeRobotDataset
|
||||||
from lerobot.envs import (
|
from lerobot.envs import (
|
||||||
check_env_attributes_and_types,
|
check_env_attributes_and_types,
|
||||||
close_envs,
|
close_envs,
|
||||||
@@ -84,7 +85,7 @@ from lerobot.envs import (
|
|||||||
from lerobot.policies import PreTrainedPolicy, make_policy, make_pre_post_processors
|
from lerobot.policies import PreTrainedPolicy, make_policy, make_pre_post_processors
|
||||||
from lerobot.processor import PolicyProcessorPipeline
|
from lerobot.processor import PolicyProcessorPipeline
|
||||||
from lerobot.types import PolicyAction
|
from lerobot.types import PolicyAction
|
||||||
from lerobot.utils.constants import ACTION, DONE, OBS_STR, REWARD
|
from lerobot.utils.constants import ACTION, DONE, OBS_IMAGE, OBS_IMAGES, OBS_STR, REWARD
|
||||||
from lerobot.utils.device_utils import get_safe_torch_device
|
from lerobot.utils.device_utils import get_safe_torch_device
|
||||||
from lerobot.utils.import_utils import register_third_party_plugins
|
from lerobot.utils.import_utils import register_third_party_plugins
|
||||||
from lerobot.utils.io_utils import write_video
|
from lerobot.utils.io_utils import write_video
|
||||||
@@ -95,6 +96,65 @@ from lerobot.utils.utils import (
|
|||||||
)
|
)
|
||||||
|
|
||||||
|
|
||||||
|
def _env_features_to_dataset_features(env_features: dict) -> dict:
|
||||||
|
"""Convert EnvConfig.features to the dict format expected by LeRobotDataset.create()."""
|
||||||
|
features = {}
|
||||||
|
for key, ft in env_features.items():
|
||||||
|
shape = tuple(ft.shape)
|
||||||
|
if ft.type is FeatureType.VISUAL:
|
||||||
|
features[key] = {"dtype": "video", "shape": shape, "names": ["height", "width", "channel"]}
|
||||||
|
else:
|
||||||
|
features[key] = {"dtype": "float32", "shape": shape, "names": None}
|
||||||
|
features["next.reward"] = {"dtype": "float32", "shape": (1,), "names": None}
|
||||||
|
features["next.success"] = {"dtype": "bool", "shape": (1,), "names": None}
|
||||||
|
features["next.done"] = {"dtype": "bool", "shape": (1,), "names": None}
|
||||||
|
return features
|
||||||
|
|
||||||
|
|
||||||
|
def _build_raw_frame(
|
||||||
|
raw_obs: dict,
|
||||||
|
env_idx: int,
|
||||||
|
action: np.ndarray,
|
||||||
|
reward: float,
|
||||||
|
success: bool,
|
||||||
|
done: bool,
|
||||||
|
task: str,
|
||||||
|
env_features: dict,
|
||||||
|
) -> dict:
|
||||||
|
"""Build a dataset frame from raw env observations for one env index.
|
||||||
|
|
||||||
|
Keys in the frame match the keys in env_features so they align with the
|
||||||
|
dataset schema created by _env_features_to_dataset_features().
|
||||||
|
"""
|
||||||
|
frame: dict[str, Any] = {}
|
||||||
|
for key in env_features:
|
||||||
|
if key == ACTION:
|
||||||
|
continue
|
||||||
|
if key.startswith("next."):
|
||||||
|
continue
|
||||||
|
if "pixels" in raw_obs and isinstance(raw_obs["pixels"], dict):
|
||||||
|
for cam_name, img in raw_obs["pixels"].items():
|
||||||
|
candidate = f"{OBS_IMAGES}.{cam_name}"
|
||||||
|
if candidate == key:
|
||||||
|
frame[key] = img[env_idx]
|
||||||
|
if key in frame:
|
||||||
|
continue
|
||||||
|
if "pixels" in raw_obs and not isinstance(raw_obs["pixels"], dict) and key in ("pixels", OBS_IMAGE):
|
||||||
|
frame[key] = raw_obs["pixels"][env_idx]
|
||||||
|
continue
|
||||||
|
if key in raw_obs and isinstance(raw_obs[key], np.ndarray):
|
||||||
|
val = raw_obs[key][env_idx]
|
||||||
|
if val.dtype == np.float64:
|
||||||
|
val = val.astype(np.float32)
|
||||||
|
frame[key] = val
|
||||||
|
frame[ACTION] = action
|
||||||
|
frame["next.reward"] = np.atleast_1d(np.float32(reward))
|
||||||
|
frame["next.success"] = np.atleast_1d(np.bool_(success))
|
||||||
|
frame["next.done"] = np.atleast_1d(np.bool_(done))
|
||||||
|
frame["task"] = task
|
||||||
|
return frame
|
||||||
|
|
||||||
|
|
||||||
def rollout(
|
def rollout(
|
||||||
env: gym.vector.VectorEnv,
|
env: gym.vector.VectorEnv,
|
||||||
policy: PreTrainedPolicy,
|
policy: PreTrainedPolicy,
|
||||||
@@ -105,6 +165,10 @@ def rollout(
|
|||||||
seeds: list[int] | None = None,
|
seeds: list[int] | None = None,
|
||||||
return_observations: bool = False,
|
return_observations: bool = False,
|
||||||
render_callback: Callable[[gym.vector.VectorEnv], None] | None = None,
|
render_callback: Callable[[gym.vector.VectorEnv], None] | None = None,
|
||||||
|
recording_dir: Path | None = None,
|
||||||
|
env_features: dict | None = None,
|
||||||
|
recording_repo_id: str | None = None,
|
||||||
|
recording_private: bool = False,
|
||||||
) -> dict:
|
) -> dict:
|
||||||
"""Run a batched policy rollout once through a batch of environments.
|
"""Run a batched policy rollout once through a batch of environments.
|
||||||
|
|
||||||
@@ -145,6 +209,33 @@ def rollout(
|
|||||||
if render_callback is not None:
|
if render_callback is not None:
|
||||||
render_callback(env)
|
render_callback(env)
|
||||||
|
|
||||||
|
recording_datasets: list[LeRobotDataset] | None = None
|
||||||
|
raw_observation = None
|
||||||
|
task_desc = ""
|
||||||
|
if recording_dir is not None and env_features is not None:
|
||||||
|
features = _env_features_to_dataset_features(env_features)
|
||||||
|
fps = env.unwrapped.metadata.get("render_fps", 30)
|
||||||
|
recording_datasets = []
|
||||||
|
for i in range(env.num_envs):
|
||||||
|
multi_env = env.num_envs > 1
|
||||||
|
root = str(recording_dir / f"env_{i}") if multi_env else str(recording_dir)
|
||||||
|
base_repo_id = recording_repo_id or "eval_recording"
|
||||||
|
repo_id = f"{base_repo_id}_env_{i}" if multi_env else base_repo_id
|
||||||
|
recording_datasets.append(
|
||||||
|
LeRobotDataset.create(
|
||||||
|
repo_id=repo_id,
|
||||||
|
fps=fps,
|
||||||
|
features=features,
|
||||||
|
root=root,
|
||||||
|
use_videos=True,
|
||||||
|
)
|
||||||
|
)
|
||||||
|
raw_observation = deepcopy(observation)
|
||||||
|
try:
|
||||||
|
task_desc = list(env.call("task_description"))[0]
|
||||||
|
except (AttributeError, NotImplementedError):
|
||||||
|
task_desc = ""
|
||||||
|
|
||||||
all_observations = []
|
all_observations = []
|
||||||
all_actions = []
|
all_actions = []
|
||||||
all_rewards = []
|
all_rewards = []
|
||||||
@@ -217,6 +308,26 @@ def rollout(
|
|||||||
else:
|
else:
|
||||||
successes = [False] * env.num_envs
|
successes = [False] * env.num_envs
|
||||||
|
|
||||||
|
if recording_datasets is not None and raw_observation is not None:
|
||||||
|
prev_done = done.copy()
|
||||||
|
for env_idx in range(env.num_envs):
|
||||||
|
if prev_done[env_idx]:
|
||||||
|
continue
|
||||||
|
frame = _build_raw_frame(
|
||||||
|
raw_observation,
|
||||||
|
env_idx,
|
||||||
|
action_numpy[env_idx],
|
||||||
|
reward[env_idx],
|
||||||
|
successes[env_idx],
|
||||||
|
bool(terminated[env_idx] | truncated[env_idx]),
|
||||||
|
task_desc,
|
||||||
|
recording_datasets[env_idx].features,
|
||||||
|
)
|
||||||
|
recording_datasets[env_idx].add_frame(frame)
|
||||||
|
if terminated[env_idx] or truncated[env_idx]:
|
||||||
|
recording_datasets[env_idx].save_episode()
|
||||||
|
raw_observation = deepcopy(observation)
|
||||||
|
|
||||||
# Keep track of which environments are done so far.
|
# Keep track of which environments are done so far.
|
||||||
# Mark the episode as done if we reach the maximum step limit.
|
# Mark the episode as done if we reach the maximum step limit.
|
||||||
# This ensures that the rollout always terminates cleanly at `max_steps`,
|
# This ensures that the rollout always terminates cleanly at `max_steps`,
|
||||||
@@ -255,6 +366,12 @@ def rollout(
|
|||||||
stacked_observations[key] = torch.stack([obs[key] for obs in all_observations], dim=1)
|
stacked_observations[key] = torch.stack([obs[key] for obs in all_observations], dim=1)
|
||||||
ret[OBS_STR] = stacked_observations
|
ret[OBS_STR] = stacked_observations
|
||||||
|
|
||||||
|
if recording_datasets is not None:
|
||||||
|
for ds in recording_datasets:
|
||||||
|
ds.finalize()
|
||||||
|
if recording_repo_id is not None:
|
||||||
|
ds.push_to_hub(private=recording_private)
|
||||||
|
|
||||||
if hasattr(policy, "use_original_modules"):
|
if hasattr(policy, "use_original_modules"):
|
||||||
policy.use_original_modules()
|
policy.use_original_modules()
|
||||||
|
|
||||||
@@ -273,6 +390,10 @@ def eval_policy(
|
|||||||
videos_dir: Path | None = None,
|
videos_dir: Path | None = None,
|
||||||
return_episode_data: bool = False,
|
return_episode_data: bool = False,
|
||||||
start_seed: int | None = None,
|
start_seed: int | None = None,
|
||||||
|
recording_dir: Path | None = None,
|
||||||
|
env_features: dict | None = None,
|
||||||
|
recording_repo_id: str | None = None,
|
||||||
|
recording_private: bool = False,
|
||||||
) -> dict:
|
) -> dict:
|
||||||
"""
|
"""
|
||||||
Args:
|
Args:
|
||||||
@@ -361,6 +482,10 @@ def eval_policy(
|
|||||||
seeds=list(seeds) if seeds else None,
|
seeds=list(seeds) if seeds else None,
|
||||||
return_observations=return_episode_data,
|
return_observations=return_episode_data,
|
||||||
render_callback=render_frame if max_episodes_rendered > 0 else None,
|
render_callback=render_frame if max_episodes_rendered > 0 else None,
|
||||||
|
recording_dir=recording_dir,
|
||||||
|
env_features=env_features,
|
||||||
|
recording_repo_id=recording_repo_id,
|
||||||
|
recording_private=recording_private,
|
||||||
)
|
)
|
||||||
|
|
||||||
# Figure out where in each rollout sequence the first done condition was encountered (results after
|
# Figure out where in each rollout sequence the first done condition was encountered (results after
|
||||||
@@ -563,6 +688,10 @@ def eval_main(cfg: EvalPipelineConfig):
|
|||||||
# Create environment-specific preprocessor and postprocessor (e.g., for LIBERO environments)
|
# Create environment-specific preprocessor and postprocessor (e.g., for LIBERO environments)
|
||||||
env_preprocessor, env_postprocessor = make_env_pre_post_processors(env_cfg=cfg.env, policy_cfg=cfg.policy)
|
env_preprocessor, env_postprocessor = make_env_pre_post_processors(env_cfg=cfg.env, policy_cfg=cfg.policy)
|
||||||
|
|
||||||
|
recording_dir = Path(cfg.output_dir) / "recordings" if cfg.eval.recording else None
|
||||||
|
max_episodes_rendered = 0 if cfg.eval.recording else 10
|
||||||
|
videos_dir = None if cfg.eval.recording else Path(cfg.output_dir) / "videos"
|
||||||
|
|
||||||
with torch.no_grad(), torch.autocast(device_type=device.type) if cfg.policy.use_amp else nullcontext():
|
with torch.no_grad(), torch.autocast(device_type=device.type) if cfg.policy.use_amp else nullcontext():
|
||||||
info = eval_policy_all(
|
info = eval_policy_all(
|
||||||
envs=envs,
|
envs=envs,
|
||||||
@@ -572,10 +701,15 @@ def eval_main(cfg: EvalPipelineConfig):
|
|||||||
preprocessor=preprocessor,
|
preprocessor=preprocessor,
|
||||||
postprocessor=postprocessor,
|
postprocessor=postprocessor,
|
||||||
n_episodes=cfg.eval.n_episodes,
|
n_episodes=cfg.eval.n_episodes,
|
||||||
max_episodes_rendered=10,
|
max_episodes_rendered=max_episodes_rendered,
|
||||||
videos_dir=Path(cfg.output_dir) / "videos",
|
videos_dir=videos_dir,
|
||||||
|
return_episode_data=False,
|
||||||
start_seed=cfg.seed,
|
start_seed=cfg.seed,
|
||||||
max_parallel_tasks=cfg.env.max_parallel_tasks,
|
max_parallel_tasks=cfg.env.max_parallel_tasks,
|
||||||
|
recording_dir=recording_dir,
|
||||||
|
env_features=cfg.env.features if cfg.eval.recording else None,
|
||||||
|
recording_repo_id=cfg.eval.recording_repo_id,
|
||||||
|
recording_private=cfg.eval.recording_private,
|
||||||
)
|
)
|
||||||
print("Overall Aggregated Metrics:")
|
print("Overall Aggregated Metrics:")
|
||||||
print(info["overall"])
|
print(info["overall"])
|
||||||
@@ -618,6 +752,10 @@ def eval_one(
|
|||||||
videos_dir: Path | None,
|
videos_dir: Path | None,
|
||||||
return_episode_data: bool,
|
return_episode_data: bool,
|
||||||
start_seed: int | None,
|
start_seed: int | None,
|
||||||
|
recording_dir: Path | None = None,
|
||||||
|
env_features: dict | None = None,
|
||||||
|
recording_repo_id: str | None = None,
|
||||||
|
recording_private: bool = False,
|
||||||
) -> TaskMetrics:
|
) -> TaskMetrics:
|
||||||
"""Evaluates one task_id of one suite using the provided vec env."""
|
"""Evaluates one task_id of one suite using the provided vec env."""
|
||||||
|
|
||||||
@@ -635,6 +773,10 @@ def eval_one(
|
|||||||
videos_dir=task_videos_dir,
|
videos_dir=task_videos_dir,
|
||||||
return_episode_data=return_episode_data,
|
return_episode_data=return_episode_data,
|
||||||
start_seed=start_seed,
|
start_seed=start_seed,
|
||||||
|
recording_dir=recording_dir,
|
||||||
|
env_features=env_features,
|
||||||
|
recording_repo_id=recording_repo_id,
|
||||||
|
recording_private=recording_private,
|
||||||
)
|
)
|
||||||
|
|
||||||
per_episode = task_result["per_episode"]
|
per_episode = task_result["per_episode"]
|
||||||
@@ -661,6 +803,10 @@ def run_one(
|
|||||||
videos_dir: Path | None,
|
videos_dir: Path | None,
|
||||||
return_episode_data: bool,
|
return_episode_data: bool,
|
||||||
start_seed: int | None,
|
start_seed: int | None,
|
||||||
|
recording_dir: Path | None = None,
|
||||||
|
env_features: dict | None = None,
|
||||||
|
recording_repo_id: str | None = None,
|
||||||
|
recording_private: bool = False,
|
||||||
):
|
):
|
||||||
"""
|
"""
|
||||||
Run eval_one for a single (task_group, task_id, env).
|
Run eval_one for a single (task_group, task_id, env).
|
||||||
@@ -672,7 +818,13 @@ def run_one(
|
|||||||
task_videos_dir = videos_dir / f"{task_group}_{task_id}"
|
task_videos_dir = videos_dir / f"{task_group}_{task_id}"
|
||||||
task_videos_dir.mkdir(parents=True, exist_ok=True)
|
task_videos_dir.mkdir(parents=True, exist_ok=True)
|
||||||
|
|
||||||
# Call the existing eval_one (assumed to return TaskMetrics-like dict)
|
task_recording_dir = None
|
||||||
|
task_repo_id = None
|
||||||
|
if recording_dir is not None and env_features is not None:
|
||||||
|
task_recording_dir = recording_dir / f"{task_group}_{task_id}"
|
||||||
|
if recording_repo_id is not None:
|
||||||
|
task_repo_id = f"{recording_repo_id}_{task_group}_{task_id}"
|
||||||
|
|
||||||
metrics = eval_one(
|
metrics = eval_one(
|
||||||
env,
|
env,
|
||||||
policy=policy,
|
policy=policy,
|
||||||
@@ -685,8 +837,12 @@ def run_one(
|
|||||||
videos_dir=task_videos_dir,
|
videos_dir=task_videos_dir,
|
||||||
return_episode_data=return_episode_data,
|
return_episode_data=return_episode_data,
|
||||||
start_seed=start_seed,
|
start_seed=start_seed,
|
||||||
|
recording_dir=task_recording_dir,
|
||||||
|
env_features=env_features,
|
||||||
|
recording_repo_id=task_repo_id,
|
||||||
|
recording_private=recording_private,
|
||||||
)
|
)
|
||||||
# ensure we always provide video_paths key to simplify accumulation
|
|
||||||
if max_episodes_rendered > 0:
|
if max_episodes_rendered > 0:
|
||||||
metrics.setdefault("video_paths", [])
|
metrics.setdefault("video_paths", [])
|
||||||
return task_group, task_id, metrics
|
return task_group, task_id, metrics
|
||||||
@@ -702,6 +858,10 @@ def eval_policy_all(
|
|||||||
n_episodes: int,
|
n_episodes: int,
|
||||||
*,
|
*,
|
||||||
max_episodes_rendered: int = 0,
|
max_episodes_rendered: int = 0,
|
||||||
|
recording_dir: Path | None = None,
|
||||||
|
env_features: dict | None = None,
|
||||||
|
recording_repo_id: str | None = None,
|
||||||
|
recording_private: bool = False,
|
||||||
videos_dir: Path | None = None,
|
videos_dir: Path | None = None,
|
||||||
return_episode_data: bool = False,
|
return_episode_data: bool = False,
|
||||||
start_seed: int | None = None,
|
start_seed: int | None = None,
|
||||||
@@ -761,6 +921,10 @@ def eval_policy_all(
|
|||||||
videos_dir=videos_dir,
|
videos_dir=videos_dir,
|
||||||
return_episode_data=return_episode_data,
|
return_episode_data=return_episode_data,
|
||||||
start_seed=start_seed,
|
start_seed=start_seed,
|
||||||
|
recording_dir=recording_dir,
|
||||||
|
env_features=env_features,
|
||||||
|
recording_repo_id=recording_repo_id,
|
||||||
|
recording_private=recording_private,
|
||||||
)
|
)
|
||||||
|
|
||||||
if max_parallel_tasks <= 1:
|
if max_parallel_tasks <= 1:
|
||||||
|
|||||||
Reference in New Issue
Block a user