diff --git a/src/lerobot/processor/normalize_processor.py b/src/lerobot/processor/normalize_processor.py
index 89b6f1c18..1649b4b31 100644
--- a/src/lerobot/processor/normalize_processor.py
+++ b/src/lerobot/processor/normalize_processor.py
@@ -137,12 +137,21 @@ class _NormalizationMixin:
         self._reshape_visual_stats()
 
     def _reshape_visual_stats(self) -> None:
-        """Reshape visual stats from ``[C]`` to ``[C, 1, 1]`` for image broadcasting."""
+        """Reshape flat ``(C,)`` visual stats to ``(C, 1, 1)`` for image broadcasting.
+
+        No-op for stats from :func:`~lerobot.datasets.compute_stats.compute_stats`
+        (already ``(C, 1, 1)``). Needed by RL training, which can start without
+        a dataset and supplies stats manually via JSON config.
+        """
         for key, feature in self.features.items():
-            if feature.type == FeatureType.VISUAL and key in self._tensor_stats:
-                for stat_name, stat_tensor in self._tensor_stats[key].items():
-                    if isinstance(stat_tensor, Tensor) and stat_tensor.ndim == 1:
-                        self._tensor_stats[key][stat_name] = stat_tensor.reshape(-1, 1, 1)
+            if feature.type != FeatureType.VISUAL:
+                continue
+            if key not in self._tensor_stats:
+                continue
+            for stat_name, stat_tensor in self._tensor_stats[key].items():
+                if not isinstance(stat_tensor, Tensor) or stat_tensor.ndim != 1:
+                    continue
+                self._tensor_stats[key][stat_name] = stat_tensor.reshape(-1, 1, 1)
 
     def to(
         self, device: torch.device | str | None = None, dtype: torch.dtype | None = None