chore: improve visual stats reshaping logic and update docstring for clarity

This commit is contained in:
Khalil Meftah
2026-05-07 11:14:57 +02:00
parent ac83f4797c
commit 84f74cf0bf
+12 -3
View File
@@ -137,11 +137,20 @@ class _NormalizationMixin:
self._reshape_visual_stats() self._reshape_visual_stats()
def _reshape_visual_stats(self) -> None: def _reshape_visual_stats(self) -> None:
"""Reshape visual stats from ``[C]`` to ``[C, 1, 1]`` for image broadcasting.""" """Reshape flat ``(C,)`` visual stats to ``(C, 1, 1)`` for image broadcasting.
No-op for stats from :func:`~lerobot.datasets.compute_stats.compute_stats`
(already ``(C, 1, 1)``). Needed by RL training, which can start without
a dataset and supplies stats manually via JSON config.
"""
for key, feature in self.features.items(): for key, feature in self.features.items():
if feature.type == FeatureType.VISUAL and key in self._tensor_stats: if feature.type != FeatureType.VISUAL:
continue
if key not in self._tensor_stats:
continue
for stat_name, stat_tensor in self._tensor_stats[key].items(): for stat_name, stat_tensor in self._tensor_stats[key].items():
if isinstance(stat_tensor, Tensor) and stat_tensor.ndim == 1: if not isinstance(stat_tensor, Tensor) or stat_tensor.ndim != 1:
continue
self._tensor_stats[key][stat_name] = stat_tensor.reshape(-1, 1, 1) self._tensor_stats[key][stat_name] = stat_tensor.reshape(-1, 1, 1)
def to( def to(