From b3d74f80f042a137b5ba46cf977e85a4da0e6ad3 Mon Sep 17 00:00:00 2001 From: Khalil Meftah Date: Fri, 19 Jun 2026 18:31:12 +0200 Subject: [PATCH] Fix batch wandb logging metrics and handle scalar stats (#3821) * fix(logging): batch wandb metrics - Batch all metrics into a single wandb.log() call instead of one per key, reducing API overhead. - Add support for list-valued metrics by expanding them to indexed keys (e.g. metric_0, metric_1). * fix(stats): handle scalar stats robustly - Wrap cast_stats_to_numpy with np.atleast_1d to prevent 0-d arrays from scalar stats causing shape mismatches downstream. * fix(logging): remove unused list-valued metric expansion --------- Co-authored-by: Steven Palma --- src/lerobot/common/wandb_utils.py | 20 +++++++++++--------- src/lerobot/datasets/io_utils.py | 2 +- 2 files changed, 12 insertions(+), 10 deletions(-) diff --git a/src/lerobot/common/wandb_utils.py b/src/lerobot/common/wandb_utils.py index b782cd751..c229b5eaa 100644 --- a/src/lerobot/common/wandb_utils.py +++ b/src/lerobot/common/wandb_utils.py @@ -180,24 +180,26 @@ class WandBLogger: self._wandb_custom_step_key.add(new_custom_key) self._wandb.define_metric(new_custom_key, hidden=True) + batch_data = {} for k, v in d.items(): + # Skip the custom step key here, it's added to the batch below. + if custom_step_key is not None and k == custom_step_key: + continue + if not isinstance(v, (int | float | str)): logging.warning( f'WandB logging of key "{k}" was ignored as its type "{type(v)}" is not handled by this wrapper.' ) continue - # Do not log the custom step key itself. - if self._wandb_custom_step_key is not None and k in self._wandb_custom_step_key: - continue + batch_data[f"{mode}/{k}"] = v + if batch_data: if custom_step_key is not None: - value_custom_step = d[custom_step_key] - data = {f"{mode}/{k}": v, f"{mode}/{custom_step_key}": value_custom_step} - self._wandb.log(data) - continue - - self._wandb.log(data={f"{mode}/{k}": v}, step=step) + batch_data[f"{mode}/{custom_step_key}"] = d[custom_step_key] + self._wandb.log(batch_data) + else: + self._wandb.log(data=batch_data, step=step) def log_video(self, video_path: str, step: int, mode: str = "train"): if mode not in {"train", "eval"}: diff --git a/src/lerobot/datasets/io_utils.py b/src/lerobot/datasets/io_utils.py index b6344942c..be94f3b3a 100644 --- a/src/lerobot/datasets/io_utils.py +++ b/src/lerobot/datasets/io_utils.py @@ -154,7 +154,7 @@ def cast_stats_to_numpy(stats: dict) -> dict[str, dict[str, np.ndarray]]: Returns: dict: The statistics dictionary with values cast to numpy arrays. """ - stats = {key: np.array(value) for key, value in flatten_dict(stats).items()} + stats = {key: np.atleast_1d(np.array(value)) for key, value in flatten_dict(stats).items()} return unflatten_dict(stats)