Compare commits

...

2 Commits

Author SHA1 Message Date
Khalil Meftah 0efa3dc874 fix(stats): handle scalar stats robustly
- Wrap cast_stats_to_numpy with np.atleast_1d to prevent 0-d arrays
from scalar stats causing shape mismatches downstream.
2026-06-15 12:28:18 +02:00
Khalil Meftah 949f4fcbe9 fix(logging): batch wandb metrics
- Batch all metrics into a single wandb.log() call instead of one per
key, reducing API overhead.

- Add support for list-valued metrics by expanding them to indexed keys (e.g.
metric_0, metric_1).
2026-06-15 12:25:06 +02:00
2 changed files with 18 additions and 10 deletions
+17 -9
View File
@@ -180,24 +180,32 @@ class WandBLogger:
self._wandb_custom_step_key.add(new_custom_key)
self._wandb.define_metric(new_custom_key, hidden=True)
batch_data = {}
for k, v in d.items():
# Skip the custom step key here, it's added to the batch below.
if custom_step_key is not None and k == custom_step_key:
continue
if isinstance(v, list):
for i, elem in enumerate(v):
if isinstance(elem, (int | float)):
batch_data[f"{mode}/{k}_{i}"] = elem
continue
if not isinstance(v, (int | float | str)):
logging.warning(
f'WandB logging of key "{k}" was ignored as its type "{type(v)}" is not handled by this wrapper.'
)
continue
# Do not log the custom step key itself.
if self._wandb_custom_step_key is not None and k in self._wandb_custom_step_key:
continue
batch_data[f"{mode}/{k}"] = v
if batch_data:
if custom_step_key is not None:
value_custom_step = d[custom_step_key]
data = {f"{mode}/{k}": v, f"{mode}/{custom_step_key}": value_custom_step}
self._wandb.log(data)
continue
self._wandb.log(data={f"{mode}/{k}": v}, step=step)
batch_data[f"{mode}/{custom_step_key}"] = d[custom_step_key]
self._wandb.log(batch_data)
else:
self._wandb.log(data=batch_data, step=step)
def log_video(self, video_path: str, step: int, mode: str = "train"):
if mode not in {"train", "eval"}:
+1 -1
View File
@@ -153,7 +153,7 @@ def cast_stats_to_numpy(stats: dict) -> dict[str, dict[str, np.ndarray]]:
Returns:
dict: The statistics dictionary with values cast to numpy arrays.
"""
stats = {key: np.array(value) for key, value in flatten_dict(stats).items()}
stats = {key: np.atleast_1d(np.array(value)) for key, value in flatten_dict(stats).items()}
return unflatten_dict(stats)