Compare commits

..

3 Commits

Author SHA1 Message Date
Khalil Meftah ae655076ac fix(logging): remove unused list-valued metric expansion 2026-06-16 17:27:39 +02:00
Khalil Meftah 0efa3dc874 fix(stats): handle scalar stats robustly
- Wrap cast_stats_to_numpy with np.atleast_1d to prevent 0-d arrays
from scalar stats causing shape mismatches downstream.
2026-06-15 12:28:18 +02:00
Khalil Meftah 949f4fcbe9 fix(logging): batch wandb metrics
- Batch all metrics into a single wandb.log() call instead of one per
key, reducing API overhead.

- Add support for list-valued metrics by expanding them to indexed keys (e.g.
metric_0, metric_1).
2026-06-15 12:25:06 +02:00
4 changed files with 31 additions and 171 deletions
+11 -9
View File
@@ -180,24 +180,26 @@ class WandBLogger:
self._wandb_custom_step_key.add(new_custom_key)
self._wandb.define_metric(new_custom_key, hidden=True)
batch_data = {}
for k, v in d.items():
# Skip the custom step key here, it's added to the batch below.
if custom_step_key is not None and k == custom_step_key:
continue
if not isinstance(v, (int | float | str)):
logging.warning(
f'WandB logging of key "{k}" was ignored as its type "{type(v)}" is not handled by this wrapper.'
)
continue
# Do not log the custom step key itself.
if self._wandb_custom_step_key is not None and k in self._wandb_custom_step_key:
continue
batch_data[f"{mode}/{k}"] = v
if batch_data:
if custom_step_key is not None:
value_custom_step = d[custom_step_key]
data = {f"{mode}/{k}": v, f"{mode}/{custom_step_key}": value_custom_step}
self._wandb.log(data)
continue
self._wandb.log(data={f"{mode}/{k}": v}, step=step)
batch_data[f"{mode}/{custom_step_key}"] = d[custom_step_key]
self._wandb.log(batch_data)
else:
self._wandb.log(data=batch_data, step=step)
def log_video(self, video_path: str, step: int, mode: str = "train"):
if mode not in {"train", "eval"}:
-2
View File
@@ -73,8 +73,6 @@ class EvalConfig:
# `use_async_envs` specifies whether to use asynchronous environments (multiprocessing).
# Defaults to True; automatically downgraded to SyncVectorEnv when batch_size=1.
use_async_envs: bool = True
# Whether to record eval rollouts as a LeRobot v3.0 dataset on disk.
recording: bool = False
def __post_init__(self) -> None:
if self.batch_size == 0:
+1 -1
View File
@@ -153,7 +153,7 @@ def cast_stats_to_numpy(stats: dict) -> dict[str, dict[str, np.ndarray]]:
Returns:
dict: The statistics dictionary with values cast to numpy arrays.
"""
stats = {key: np.array(value) for key, value in flatten_dict(stats).items()}
stats = {key: np.atleast_1d(np.array(value)) for key, value in flatten_dict(stats).items()}
return unflatten_dict(stats)
+19 -159
View File
@@ -72,9 +72,8 @@ from termcolor import colored
from torch import Tensor, nn
from tqdm import trange
from lerobot.configs import FeatureType, parser
from lerobot.configs import parser
from lerobot.configs.eval import EvalPipelineConfig
from lerobot.datasets.lerobot_dataset import LeRobotDataset
from lerobot.envs import (
check_env_attributes_and_types,
close_envs,
@@ -85,7 +84,7 @@ from lerobot.envs import (
from lerobot.policies import PreTrainedPolicy, make_policy, make_pre_post_processors
from lerobot.processor import PolicyProcessorPipeline
from lerobot.types import PolicyAction
from lerobot.utils.constants import ACTION, DONE, OBS_IMAGE, OBS_IMAGES, OBS_STR, REWARD
from lerobot.utils.constants import ACTION, DONE, OBS_STR, REWARD
from lerobot.utils.device_utils import get_safe_torch_device
from lerobot.utils.import_utils import register_third_party_plugins
from lerobot.utils.io_utils import write_video
@@ -96,81 +95,6 @@ from lerobot.utils.utils import (
)
def _env_features_to_dataset_features(env_features: dict, raw_obs: dict | None = None) -> dict:
"""Convert EnvConfig.features (PolicyFeature objects) to the plain dict format for LeRobotDataset.create().
If raw_obs is provided, visual feature shapes are inferred from the actual observation
to avoid mismatches between the env config and the real observation resolution.
"""
features = {}
for key, ft in env_features.items():
if ft.type is FeatureType.VISUAL:
shape = tuple(ft.shape)
if raw_obs is not None and key in raw_obs and isinstance(raw_obs[key], np.ndarray):
shape = raw_obs[key].shape[1:] # strip batch dim
elif raw_obs is not None and "pixels" in raw_obs:
pixels = raw_obs["pixels"]
if isinstance(pixels, dict):
for cam_name, img in pixels.items():
if key == f"{OBS_IMAGES}.{cam_name}" or key == cam_name:
shape = img.shape[1:] # strip batch dim
elif key in ("pixels", OBS_IMAGE):
shape = pixels.shape[1:] # strip batch dim
features[key] = {"dtype": "video", "shape": shape, "names": ["height", "width", "channel"]}
else:
shape = tuple(ft.shape)
if raw_obs is not None and key in raw_obs and isinstance(raw_obs[key], np.ndarray):
shape = raw_obs[key].shape[1:] # strip batch dim
features[key] = {"dtype": "float32", "shape": shape, "names": None}
features["next.reward"] = {"dtype": "float32", "shape": (1,), "names": None}
features["next.success"] = {"dtype": "bool", "shape": (1,), "names": None}
features["next.done"] = {"dtype": "bool", "shape": (1,), "names": None}
return features
def _build_raw_frame(
raw_obs: dict,
env_idx: int,
action: np.ndarray,
reward: float,
success: bool,
done: bool,
task: str,
env_features: dict,
) -> dict:
"""Build a dataset frame from raw env observations for one env index.
Keys in the frame match the keys in env_features so they align with the
dataset schema created by _env_features_to_dataset_features().
"""
frame: dict[str, Any] = {}
for key in env_features:
if key == ACTION:
continue
if "pixels" in raw_obs and isinstance(raw_obs["pixels"], dict):
for cam_name, img in raw_obs["pixels"].items():
candidate = f"{OBS_IMAGES}.{cam_name}"
if candidate == key:
frame[key] = img[env_idx]
if key in frame:
continue
if "pixels" in raw_obs and not isinstance(raw_obs["pixels"], dict) and key in ("pixels", OBS_IMAGE):
frame[key] = raw_obs["pixels"][env_idx]
continue
raw_key = key
if raw_key in raw_obs and isinstance(raw_obs[raw_key], np.ndarray):
val = raw_obs[raw_key][env_idx]
if val.dtype == np.float64:
val = val.astype(np.float32)
frame[key] = val
frame[ACTION] = action
frame["next.reward"] = np.atleast_1d(np.float32(reward))
frame["next.success"] = np.atleast_1d(np.bool_(success))
frame["next.done"] = np.atleast_1d(np.bool_(done))
frame["task"] = task
return frame
def rollout(
env: gym.vector.VectorEnv,
policy: PreTrainedPolicy,
@@ -181,7 +105,6 @@ def rollout(
seeds: list[int] | None = None,
return_observations: bool = False,
render_callback: Callable[[gym.vector.VectorEnv], None] | None = None,
recording_dataset: Any | None = None,
) -> dict:
"""Run a batched policy rollout once through a batch of environments.
@@ -222,14 +145,6 @@ def rollout(
if render_callback is not None:
render_callback(env)
raw_observation = deepcopy(observation) if recording_dataset is not None else None
task_desc = ""
if recording_dataset is not None:
try:
task_desc = list(env.call("task_description"))[0]
except (AttributeError, NotImplementedError):
task_desc = ""
all_observations = []
all_actions = []
all_rewards = []
@@ -302,26 +217,6 @@ def rollout(
else:
successes = [False] * env.num_envs
if recording_dataset is not None and raw_observation is not None:
prev_done = done.copy()
for env_idx in range(env.num_envs):
if prev_done[env_idx]:
continue
frame = _build_raw_frame(
raw_observation,
env_idx,
action_numpy[env_idx],
reward[env_idx],
successes[env_idx],
bool(terminated[env_idx] | truncated[env_idx]),
task_desc,
recording_dataset.features,
)
recording_dataset.add_frame(frame)
if terminated[env_idx] or truncated[env_idx]:
recording_dataset.save_episode()
raw_observation = deepcopy(observation)
# Keep track of which environments are done so far.
# Mark the episode as done if we reach the maximum step limit.
# This ensures that the rollout always terminates cleanly at `max_steps`,
@@ -378,7 +273,6 @@ def eval_policy(
videos_dir: Path | None = None,
return_episode_data: bool = False,
start_seed: int | None = None,
recording_dataset: Any | None = None,
) -> dict:
"""
Args:
@@ -467,7 +361,6 @@ def eval_policy(
seeds=list(seeds) if seeds else None,
return_observations=return_episode_data,
render_callback=render_frame if max_episodes_rendered > 0 else None,
recording_dataset=recording_dataset,
)
# Figure out where in each rollout sequence the first done condition was encountered (results after
@@ -670,10 +563,6 @@ def eval_main(cfg: EvalPipelineConfig):
# Create environment-specific preprocessor and postprocessor (e.g., for LIBERO environments)
env_preprocessor, env_postprocessor = make_env_pre_post_processors(env_cfg=cfg.env, policy_cfg=cfg.policy)
recording_dir = Path(cfg.output_dir) / "recordings" if cfg.eval.recording else None
max_episodes_rendered = 0 if cfg.eval.recording else 10
videos_dir = None if cfg.eval.recording else Path(cfg.output_dir) / "videos"
with torch.no_grad(), torch.autocast(device_type=device.type) if cfg.policy.use_amp else nullcontext():
info = eval_policy_all(
envs=envs,
@@ -683,13 +572,10 @@ def eval_main(cfg: EvalPipelineConfig):
preprocessor=preprocessor,
postprocessor=postprocessor,
n_episodes=cfg.eval.n_episodes,
max_episodes_rendered=max_episodes_rendered,
videos_dir=videos_dir,
return_episode_data=False,
max_episodes_rendered=10,
videos_dir=Path(cfg.output_dir) / "videos",
start_seed=cfg.seed,
max_parallel_tasks=cfg.env.max_parallel_tasks,
recording_dir=recording_dir,
env_features=cfg.env.features if cfg.eval.recording else None,
)
print("Overall Aggregated Metrics:")
print(info["overall"])
@@ -732,7 +618,6 @@ def eval_one(
videos_dir: Path | None,
return_episode_data: bool,
start_seed: int | None,
recording_dataset: Any | None = None,
) -> TaskMetrics:
"""Evaluates one task_id of one suite using the provided vec env."""
@@ -750,7 +635,6 @@ def eval_one(
videos_dir=task_videos_dir,
return_episode_data=return_episode_data,
start_seed=start_seed,
recording_dataset=recording_dataset,
)
per_episode = task_result["per_episode"]
@@ -777,8 +661,6 @@ def run_one(
videos_dir: Path | None,
return_episode_data: bool,
start_seed: int | None,
recording_dir: Path | None = None,
env_features: dict | None = None,
):
"""
Run eval_one for a single (task_group, task_id, env).
@@ -790,39 +672,21 @@ def run_one(
task_videos_dir = videos_dir / f"{task_group}_{task_id}"
task_videos_dir.mkdir(parents=True, exist_ok=True)
recording_dataset = None
if recording_dir is not None and env_features is not None:
task_recording_dir = recording_dir / f"{task_group}_{task_id}"
fps = env.unwrapped.metadata.get("render_fps", 30)
sample_obs, _ = env.reset()
features = _env_features_to_dataset_features(env_features, raw_obs=sample_obs)
recording_dataset = LeRobotDataset.create(
repo_id=f"eval_{task_group}_{task_id}",
fps=fps,
features=features,
root=str(task_recording_dir),
use_videos=True,
)
try:
metrics = eval_one(
env,
policy=policy,
env_preprocessor=env_preprocessor,
env_postprocessor=env_postprocessor,
preprocessor=preprocessor,
postprocessor=postprocessor,
n_episodes=n_episodes,
max_episodes_rendered=max_episodes_rendered,
videos_dir=task_videos_dir,
return_episode_data=return_episode_data,
start_seed=start_seed,
recording_dataset=recording_dataset,
)
finally:
if recording_dataset is not None:
recording_dataset.finalize()
# Call the existing eval_one (assumed to return TaskMetrics-like dict)
metrics = eval_one(
env,
policy=policy,
env_preprocessor=env_preprocessor,
env_postprocessor=env_postprocessor,
preprocessor=preprocessor,
postprocessor=postprocessor,
n_episodes=n_episodes,
max_episodes_rendered=max_episodes_rendered,
videos_dir=task_videos_dir,
return_episode_data=return_episode_data,
start_seed=start_seed,
)
# ensure we always provide video_paths key to simplify accumulation
if max_episodes_rendered > 0:
metrics.setdefault("video_paths", [])
return task_group, task_id, metrics
@@ -838,8 +702,6 @@ def eval_policy_all(
n_episodes: int,
*,
max_episodes_rendered: int = 0,
recording_dir: Path | None = None,
env_features: dict | None = None,
videos_dir: Path | None = None,
return_episode_data: bool = False,
start_seed: int | None = None,
@@ -899,8 +761,6 @@ def eval_policy_all(
videos_dir=videos_dir,
return_episode_data=return_episode_data,
start_seed=start_seed,
recording_dir=recording_dir,
env_features=env_features,
)
if max_parallel_tasks <= 1: