From ab94626b92753c9cc59af2b370ca8cbaa70abc15 Mon Sep 17 00:00:00 2001 From: AdilZouitine Date: Sun, 3 Aug 2025 18:07:08 +0200 Subject: [PATCH] fix normalization for dtype --- src/lerobot/policies/normalize.py | 54 ++++++++++--------- .../policies/smolvla/modeling_smolvla.py | 29 +++++----- 2 files changed, 43 insertions(+), 40 deletions(-) diff --git a/src/lerobot/policies/normalize.py b/src/lerobot/policies/normalize.py index 119055873..0a2e84655 100644 --- a/src/lerobot/policies/normalize.py +++ b/src/lerobot/policies/normalize.py @@ -24,6 +24,7 @@ def create_stats_buffers( features: dict[str, PolicyFeature], norm_map: dict[str, NormalizationMode], stats: dict[str, dict[str, Tensor]] | None = None, + dtype: torch.dtype = torch.float32, ) -> dict[str, dict[str, nn.ParameterDict]]: """ Create buffers per modality (e.g. "observation.image", "action") containing their mean, std, min, max @@ -60,8 +61,8 @@ def create_stats_buffers( buffer = {} if norm_mode is NormalizationMode.MEAN_STD: - mean = torch.ones(shape, dtype=torch.float32) * torch.inf - std = torch.ones(shape, dtype=torch.float32) * torch.inf + mean = torch.ones(shape, dtype=dtype) * torch.inf + std = torch.ones(shape, dtype=dtype) * torch.inf buffer = nn.ParameterDict( { "mean": nn.Parameter(mean, requires_grad=False), @@ -69,8 +70,8 @@ def create_stats_buffers( } ) elif norm_mode is NormalizationMode.MIN_MAX: - min = torch.ones(shape, dtype=torch.float32) * torch.inf - max = torch.ones(shape, dtype=torch.float32) * torch.inf + min = torch.ones(shape, dtype=dtype) * torch.inf + max = torch.ones(shape, dtype=dtype) * torch.inf buffer = nn.ParameterDict( { "min": nn.Parameter(min, requires_grad=False), @@ -82,22 +83,22 @@ def create_stats_buffers( if stats: if isinstance(stats[key]["mean"], np.ndarray): if norm_mode is NormalizationMode.MEAN_STD: - buffer["mean"].data = torch.from_numpy(stats[key]["mean"]).to(dtype=torch.float32) - buffer["std"].data = torch.from_numpy(stats[key]["std"]).to(dtype=torch.float32) + buffer["mean"].data = torch.from_numpy(stats[key]["mean"]).to(dtype=dtype) + buffer["std"].data = torch.from_numpy(stats[key]["std"]).to(dtype=dtype) elif norm_mode is NormalizationMode.MIN_MAX: - buffer["min"].data = torch.from_numpy(stats[key]["min"]).to(dtype=torch.float32) - buffer["max"].data = torch.from_numpy(stats[key]["max"]).to(dtype=torch.float32) + buffer["min"].data = torch.from_numpy(stats[key]["min"]).to(dtype=dtype) + buffer["max"].data = torch.from_numpy(stats[key]["max"]).to(dtype=dtype) elif isinstance(stats[key]["mean"], torch.Tensor): # Note: The clone is needed to make sure that the logic in save_pretrained doesn't see duplicated # tensors anywhere (for example, when we use the same stats for normalization and # unnormalization). See the logic here # https://github.com/huggingface/safetensors/blob/079781fd0dc455ba0fe851e2b4507c33d0c0d407/bindings/python/py_src/safetensors/torch.py#L97. if norm_mode is NormalizationMode.MEAN_STD: - buffer["mean"].data = stats[key]["mean"].clone().to(dtype=torch.float32) - buffer["std"].data = stats[key]["std"].clone().to(dtype=torch.float32) + buffer["mean"].data = stats[key]["mean"].clone().to(dtype=dtype) + buffer["std"].data = stats[key]["std"].clone().to(dtype=dtype) elif norm_mode is NormalizationMode.MIN_MAX: - buffer["min"].data = stats[key]["min"].clone().to(dtype=torch.float32) - buffer["max"].data = stats[key]["max"].clone().to(dtype=torch.float32) + buffer["min"].data = stats[key]["min"].clone().to(dtype=dtype) + buffer["max"].data = stats[key]["max"].clone().to(dtype=dtype) else: type_ = type(stats[key]["mean"]) raise ValueError(f"np.ndarray or torch.Tensor expected, but type is '{type_}' instead.") @@ -121,6 +122,7 @@ class Normalize(nn.Module): features: dict[str, PolicyFeature], norm_map: dict[str, NormalizationMode], stats: dict[str, dict[str, Tensor]] | None = None, + dtype: torch.dtype = torch.float32, ): """ Args: @@ -144,7 +146,7 @@ class Normalize(nn.Module): self.features = features self.norm_map = norm_map self.stats = stats - stats_buffers = create_stats_buffers(features, norm_map, stats) + stats_buffers = create_stats_buffers(features, norm_map, stats, dtype) for key, buffer in stats_buffers.items(): setattr(self, "buffer_" + key.replace(".", "_"), buffer) @@ -195,6 +197,7 @@ class Unnormalize(nn.Module): features: dict[str, PolicyFeature], norm_map: dict[str, NormalizationMode], stats: dict[str, dict[str, Tensor]] | None = None, + dtype: torch.dtype = torch.float32, ): """ Args: @@ -219,7 +222,7 @@ class Unnormalize(nn.Module): self.norm_map = norm_map self.stats = stats # `self.buffer_observation_state["mean"]` contains `torch.tensor(state_dim)` - stats_buffers = create_stats_buffers(features, norm_map, stats) + stats_buffers = create_stats_buffers(features, norm_map, stats, dtype) for key, buffer in stats_buffers.items(): setattr(self, "buffer_" + key.replace(".", "_"), buffer) @@ -262,6 +265,7 @@ def _initialize_stats_buffers( features: dict[str, PolicyFeature], norm_map: dict[str, NormalizationMode], stats: dict[str, dict[str, Tensor]] | None = None, + dtype: torch.dtype = torch.float32, ) -> None: """Register statistics buffers (mean/std or min/max) on the given *module*. @@ -282,8 +286,8 @@ def _initialize_stats_buffers( prefix = key.replace(".", "_") if norm_mode is NormalizationMode.MEAN_STD: - mean = torch.full(shape, torch.inf, dtype=torch.float32) - std = torch.full(shape, torch.inf, dtype=torch.float32) + mean = torch.full(shape, torch.inf, dtype=dtype) + std = torch.full(shape, torch.inf, dtype=dtype) if stats and key in stats and "mean" in stats[key] and "std" in stats[key]: mean_data = stats[key]["mean"] @@ -293,8 +297,8 @@ def _initialize_stats_buffers( # tensors anywhere (for example, when we use the same stats for normalization and # unnormalization). See the logic here # https://github.com/huggingface/safetensors/blob/079781fd0dc455ba0fe851e2b4507c33d0c0d407/bindings/python/py_src/safetensors/torch.py#L97. - mean = mean_data.clone().to(dtype=torch.float32) - std = std_data.clone().to(dtype=torch.float32) + mean = mean_data.clone().to(dtype=dtype) + std = std_data.clone().to(dtype=dtype) else: raise ValueError(f"Unsupported stats type for key '{key}' (expected ndarray or Tensor).") @@ -303,15 +307,15 @@ def _initialize_stats_buffers( continue if norm_mode is NormalizationMode.MIN_MAX: - min_val = torch.full(shape, torch.inf, dtype=torch.float32) - max_val = torch.full(shape, torch.inf, dtype=torch.float32) + min_val = torch.full(shape, torch.inf, dtype=dtype) + max_val = torch.full(shape, torch.inf, dtype=dtype) if stats and key in stats and "min" in stats[key] and "max" in stats[key]: min_data = stats[key]["min"] max_data = stats[key]["max"] if isinstance(min_data, torch.Tensor): - min_val = min_data.clone().to(dtype=torch.float32) - max_val = max_data.clone().to(dtype=torch.float32) + min_val = min_data.clone().to(dtype=dtype) + max_val = max_data.clone().to(dtype=dtype) else: raise ValueError(f"Unsupported stats type for key '{key}' (expected ndarray or Tensor).") @@ -330,12 +334,13 @@ class NormalizeBuffer(nn.Module): features: dict[str, PolicyFeature], norm_map: dict[str, NormalizationMode], stats: dict[str, dict[str, Tensor]] | None = None, + dtype: torch.dtype = torch.float32, ): super().__init__() self.features = features self.norm_map = norm_map - _initialize_stats_buffers(self, features, norm_map, stats) + _initialize_stats_buffers(self, features, norm_map, stats, dtype) def forward(self, batch: dict[str, Tensor]) -> dict[str, Tensor]: batch = dict(batch) @@ -379,12 +384,13 @@ class UnnormalizeBuffer(nn.Module): features: dict[str, PolicyFeature], norm_map: dict[str, NormalizationMode], stats: dict[str, dict[str, Tensor]] | None = None, + dtype: torch.dtype = torch.float32, ): super().__init__() self.features = features self.norm_map = norm_map - _initialize_stats_buffers(self, features, norm_map, stats) + _initialize_stats_buffers(self, features, norm_map, stats, dtype) def forward(self, batch: dict[str, Tensor]) -> dict[str, Tensor]: # batch = dict(batch) diff --git a/src/lerobot/policies/smolvla/modeling_smolvla.py b/src/lerobot/policies/smolvla/modeling_smolvla.py index 469645e84..bbd8c4d62 100644 --- a/src/lerobot/policies/smolvla/modeling_smolvla.py +++ b/src/lerobot/policies/smolvla/modeling_smolvla.py @@ -673,19 +673,19 @@ class VLAFlowMatching(nn.Module): for params in self.state_proj.parameters(): params.requires_grad = self.config.train_state_proj - def sample_noise(self, shape, device): + def sample_noise(self, shape, device, dtype): noise = torch.normal( mean=0.0, std=1.0, size=shape, - dtype=torch.float32, + dtype=dtype, device=device, ) return noise - def sample_time(self, bsize, device): + def sample_time(self, bsize, device, dtype): beta_dist = torch.distributions.Beta(concentration1=1.5, concentration0=1.0) - time_beta = beta_dist.sample((bsize,)).to(device=device, dtype=torch.float32) + time_beta = beta_dist.sample((bsize,)).to(device=device, dtype=dtype) time = time_beta * 0.999 + 0.001 return time @@ -831,10 +831,10 @@ class VLAFlowMatching(nn.Module): ) -> Tensor: """Do a full training forward pass and compute the loss (batch_size x num_steps x num_motors)""" if noise is None: - noise = self.sample_noise(actions.shape, actions.device) + noise = self.sample_noise(actions.shape, actions.device, actions.dtype) if time is None: - time = self.sample_time(actions.shape[0], actions.device) + time = self.sample_time(actions.shape[0], actions.device, actions.dtype) time_expanded = time[:, None, None] x_t = time_expanded * noise + (1 - time_expanded) * actions @@ -868,10 +868,11 @@ class VLAFlowMatching(nn.Module): """Do a full inference forward and compute the action (batch_size x num_steps x num_motors)""" bsize = state.shape[0] device = state.device + dtype = state.dtype if noise is None: actions_shape = (bsize, self.config.chunk_size, self.config.max_action_dim) - noise = self.sample_noise(actions_shape, device) + noise = self.sample_noise(actions_shape, device, dtype) prefix_embs, prefix_pad_masks, prefix_att_masks = self.embed_prefix( images, img_masks, lang_tokens, lang_masks, state=state @@ -888,18 +889,13 @@ class VLAFlowMatching(nn.Module): fill_kv_cache=True, ) dt = -1.0 / self.config.num_steps - dt = torch.tensor(dt, dtype=torch.float32, device=device) + dt = torch.tensor(dt, dtype=dtype, device=device) x_t = noise - time = torch.tensor(1.0, dtype=torch.float32, device=device) + time = torch.tensor(1.0, dtype=dtype, device=device) while time >= -dt / 2: expanded_time = time.expand(bsize) - v_t = self.denoise_step( - prefix_pad_masks, - past_key_values, - x_t, - expanded_time, - ) + v_t = self.denoise_step(prefix_pad_masks, past_key_values, x_t, expanded_time, dtype) # Euler step x_t += dt * v_t time += dt @@ -911,6 +907,7 @@ class VLAFlowMatching(nn.Module): past_key_values, x_t, timestep, + dtype, ): """Apply one denoising step of the noise `x_t` at a given timestep.""" suffix_embs, suffix_pad_masks, suffix_att_masks = self.embed_suffix(x_t, timestep) @@ -936,6 +933,6 @@ class VLAFlowMatching(nn.Module): ) suffix_out = outputs_embeds[1] suffix_out = suffix_out[:, -self.config.chunk_size :] - suffix_out = suffix_out.to(dtype=torch.float32) + suffix_out = suffix_out.to(dtype=dtype) v_t = self.action_out_proj(suffix_out) return v_t