mirror of
https://github.com/huggingface/lerobot.git
synced 2026-05-23 20:50:02 +00:00
chore: improve visual stats reshaping logic and update docstring for clarity
This commit is contained in:
@@ -137,12 +137,21 @@ class _NormalizationMixin:
|
|||||||
self._reshape_visual_stats()
|
self._reshape_visual_stats()
|
||||||
|
|
||||||
def _reshape_visual_stats(self) -> None:
|
def _reshape_visual_stats(self) -> None:
|
||||||
"""Reshape visual stats from ``[C]`` to ``[C, 1, 1]`` for image broadcasting."""
|
"""Reshape flat ``(C,)`` visual stats to ``(C, 1, 1)`` for image broadcasting.
|
||||||
|
|
||||||
|
No-op for stats from :func:`~lerobot.datasets.compute_stats.compute_stats`
|
||||||
|
(already ``(C, 1, 1)``). Needed by RL training, which can start without
|
||||||
|
a dataset and supplies stats manually via JSON config.
|
||||||
|
"""
|
||||||
for key, feature in self.features.items():
|
for key, feature in self.features.items():
|
||||||
if feature.type == FeatureType.VISUAL and key in self._tensor_stats:
|
if feature.type != FeatureType.VISUAL:
|
||||||
for stat_name, stat_tensor in self._tensor_stats[key].items():
|
continue
|
||||||
if isinstance(stat_tensor, Tensor) and stat_tensor.ndim == 1:
|
if key not in self._tensor_stats:
|
||||||
self._tensor_stats[key][stat_name] = stat_tensor.reshape(-1, 1, 1)
|
continue
|
||||||
|
for stat_name, stat_tensor in self._tensor_stats[key].items():
|
||||||
|
if not isinstance(stat_tensor, Tensor) or stat_tensor.ndim != 1:
|
||||||
|
continue
|
||||||
|
self._tensor_stats[key][stat_name] = stat_tensor.reshape(-1, 1, 1)
|
||||||
|
|
||||||
def to(
|
def to(
|
||||||
self, device: torch.device | str | None = None, dtype: torch.dtype | None = None
|
self, device: torch.device | str | None = None, dtype: torch.dtype | None = None
|
||||||
|
|||||||
Reference in New Issue
Block a user