mirror of
https://github.com/huggingface/lerobot.git
synced 2026-05-11 14:49:43 +00:00
chore: improve visual stats reshaping logic and update docstring for clarity
This commit is contained in:
@@ -137,12 +137,21 @@ class _NormalizationMixin:
|
||||
self._reshape_visual_stats()
|
||||
|
||||
def _reshape_visual_stats(self) -> None:
|
||||
"""Reshape visual stats from ``[C]`` to ``[C, 1, 1]`` for image broadcasting."""
|
||||
"""Reshape flat ``(C,)`` visual stats to ``(C, 1, 1)`` for image broadcasting.
|
||||
|
||||
No-op for stats from :func:`~lerobot.datasets.compute_stats.compute_stats`
|
||||
(already ``(C, 1, 1)``). Needed by RL training, which can start without
|
||||
a dataset and supplies stats manually via JSON config.
|
||||
"""
|
||||
for key, feature in self.features.items():
|
||||
if feature.type == FeatureType.VISUAL and key in self._tensor_stats:
|
||||
for stat_name, stat_tensor in self._tensor_stats[key].items():
|
||||
if isinstance(stat_tensor, Tensor) and stat_tensor.ndim == 1:
|
||||
self._tensor_stats[key][stat_name] = stat_tensor.reshape(-1, 1, 1)
|
||||
if feature.type != FeatureType.VISUAL:
|
||||
continue
|
||||
if key not in self._tensor_stats:
|
||||
continue
|
||||
for stat_name, stat_tensor in self._tensor_stats[key].items():
|
||||
if not isinstance(stat_tensor, Tensor) or stat_tensor.ndim != 1:
|
||||
continue
|
||||
self._tensor_stats[key][stat_name] = stat_tensor.reshape(-1, 1, 1)
|
||||
|
||||
def to(
|
||||
self, device: torch.device | str | None = None, dtype: torch.dtype | None = None
|
||||
|
||||
Reference in New Issue
Block a user