chore: improve visual stats reshaping logic and update docstring for clarity

This commit is contained in:
Khalil Meftah
2026-05-07 11:14:57 +02:00
parent ac83f4797c
commit 84f74cf0bf
+14 -5
View File
@@ -137,12 +137,21 @@ class _NormalizationMixin:
self._reshape_visual_stats()
def _reshape_visual_stats(self) -> None:
"""Reshape visual stats from ``[C]`` to ``[C, 1, 1]`` for image broadcasting."""
"""Reshape flat ``(C,)`` visual stats to ``(C, 1, 1)`` for image broadcasting.
No-op for stats from :func:`~lerobot.datasets.compute_stats.compute_stats`
(already ``(C, 1, 1)``). Needed by RL training, which can start without
a dataset and supplies stats manually via JSON config.
"""
for key, feature in self.features.items():
if feature.type == FeatureType.VISUAL and key in self._tensor_stats:
for stat_name, stat_tensor in self._tensor_stats[key].items():
if isinstance(stat_tensor, Tensor) and stat_tensor.ndim == 1:
self._tensor_stats[key][stat_name] = stat_tensor.reshape(-1, 1, 1)
if feature.type != FeatureType.VISUAL:
continue
if key not in self._tensor_stats:
continue
for stat_name, stat_tensor in self._tensor_stats[key].items():
if not isinstance(stat_tensor, Tensor) or stat_tensor.ndim != 1:
continue
self._tensor_stats[key][stat_name] = stat_tensor.reshape(-1, 1, 1)
def to(
self, device: torch.device | str | None = None, dtype: torch.dtype | None = None