fix normalization for dtype

This commit is contained in:
AdilZouitine
2025-08-03 18:07:08 +02:00
parent f771e3eaf1
commit ab94626b92
2 changed files with 43 additions and 40 deletions
+30 -24
View File
@@ -24,6 +24,7 @@ def create_stats_buffers(
features: dict[str, PolicyFeature], features: dict[str, PolicyFeature],
norm_map: dict[str, NormalizationMode], norm_map: dict[str, NormalizationMode],
stats: dict[str, dict[str, Tensor]] | None = None, stats: dict[str, dict[str, Tensor]] | None = None,
dtype: torch.dtype = torch.float32,
) -> dict[str, dict[str, nn.ParameterDict]]: ) -> dict[str, dict[str, nn.ParameterDict]]:
""" """
Create buffers per modality (e.g. "observation.image", "action") containing their mean, std, min, max Create buffers per modality (e.g. "observation.image", "action") containing their mean, std, min, max
@@ -60,8 +61,8 @@ def create_stats_buffers(
buffer = {} buffer = {}
if norm_mode is NormalizationMode.MEAN_STD: if norm_mode is NormalizationMode.MEAN_STD:
mean = torch.ones(shape, dtype=torch.float32) * torch.inf mean = torch.ones(shape, dtype=dtype) * torch.inf
std = torch.ones(shape, dtype=torch.float32) * torch.inf std = torch.ones(shape, dtype=dtype) * torch.inf
buffer = nn.ParameterDict( buffer = nn.ParameterDict(
{ {
"mean": nn.Parameter(mean, requires_grad=False), "mean": nn.Parameter(mean, requires_grad=False),
@@ -69,8 +70,8 @@ def create_stats_buffers(
} }
) )
elif norm_mode is NormalizationMode.MIN_MAX: elif norm_mode is NormalizationMode.MIN_MAX:
min = torch.ones(shape, dtype=torch.float32) * torch.inf min = torch.ones(shape, dtype=dtype) * torch.inf
max = torch.ones(shape, dtype=torch.float32) * torch.inf max = torch.ones(shape, dtype=dtype) * torch.inf
buffer = nn.ParameterDict( buffer = nn.ParameterDict(
{ {
"min": nn.Parameter(min, requires_grad=False), "min": nn.Parameter(min, requires_grad=False),
@@ -82,22 +83,22 @@ def create_stats_buffers(
if stats: if stats:
if isinstance(stats[key]["mean"], np.ndarray): if isinstance(stats[key]["mean"], np.ndarray):
if norm_mode is NormalizationMode.MEAN_STD: if norm_mode is NormalizationMode.MEAN_STD:
buffer["mean"].data = torch.from_numpy(stats[key]["mean"]).to(dtype=torch.float32) buffer["mean"].data = torch.from_numpy(stats[key]["mean"]).to(dtype=dtype)
buffer["std"].data = torch.from_numpy(stats[key]["std"]).to(dtype=torch.float32) buffer["std"].data = torch.from_numpy(stats[key]["std"]).to(dtype=dtype)
elif norm_mode is NormalizationMode.MIN_MAX: elif norm_mode is NormalizationMode.MIN_MAX:
buffer["min"].data = torch.from_numpy(stats[key]["min"]).to(dtype=torch.float32) buffer["min"].data = torch.from_numpy(stats[key]["min"]).to(dtype=dtype)
buffer["max"].data = torch.from_numpy(stats[key]["max"]).to(dtype=torch.float32) buffer["max"].data = torch.from_numpy(stats[key]["max"]).to(dtype=dtype)
elif isinstance(stats[key]["mean"], torch.Tensor): elif isinstance(stats[key]["mean"], torch.Tensor):
# Note: The clone is needed to make sure that the logic in save_pretrained doesn't see duplicated # Note: The clone is needed to make sure that the logic in save_pretrained doesn't see duplicated
# tensors anywhere (for example, when we use the same stats for normalization and # tensors anywhere (for example, when we use the same stats for normalization and
# unnormalization). See the logic here # unnormalization). See the logic here
# https://github.com/huggingface/safetensors/blob/079781fd0dc455ba0fe851e2b4507c33d0c0d407/bindings/python/py_src/safetensors/torch.py#L97. # https://github.com/huggingface/safetensors/blob/079781fd0dc455ba0fe851e2b4507c33d0c0d407/bindings/python/py_src/safetensors/torch.py#L97.
if norm_mode is NormalizationMode.MEAN_STD: if norm_mode is NormalizationMode.MEAN_STD:
buffer["mean"].data = stats[key]["mean"].clone().to(dtype=torch.float32) buffer["mean"].data = stats[key]["mean"].clone().to(dtype=dtype)
buffer["std"].data = stats[key]["std"].clone().to(dtype=torch.float32) buffer["std"].data = stats[key]["std"].clone().to(dtype=dtype)
elif norm_mode is NormalizationMode.MIN_MAX: elif norm_mode is NormalizationMode.MIN_MAX:
buffer["min"].data = stats[key]["min"].clone().to(dtype=torch.float32) buffer["min"].data = stats[key]["min"].clone().to(dtype=dtype)
buffer["max"].data = stats[key]["max"].clone().to(dtype=torch.float32) buffer["max"].data = stats[key]["max"].clone().to(dtype=dtype)
else: else:
type_ = type(stats[key]["mean"]) type_ = type(stats[key]["mean"])
raise ValueError(f"np.ndarray or torch.Tensor expected, but type is '{type_}' instead.") raise ValueError(f"np.ndarray or torch.Tensor expected, but type is '{type_}' instead.")
@@ -121,6 +122,7 @@ class Normalize(nn.Module):
features: dict[str, PolicyFeature], features: dict[str, PolicyFeature],
norm_map: dict[str, NormalizationMode], norm_map: dict[str, NormalizationMode],
stats: dict[str, dict[str, Tensor]] | None = None, stats: dict[str, dict[str, Tensor]] | None = None,
dtype: torch.dtype = torch.float32,
): ):
""" """
Args: Args:
@@ -144,7 +146,7 @@ class Normalize(nn.Module):
self.features = features self.features = features
self.norm_map = norm_map self.norm_map = norm_map
self.stats = stats self.stats = stats
stats_buffers = create_stats_buffers(features, norm_map, stats) stats_buffers = create_stats_buffers(features, norm_map, stats, dtype)
for key, buffer in stats_buffers.items(): for key, buffer in stats_buffers.items():
setattr(self, "buffer_" + key.replace(".", "_"), buffer) setattr(self, "buffer_" + key.replace(".", "_"), buffer)
@@ -195,6 +197,7 @@ class Unnormalize(nn.Module):
features: dict[str, PolicyFeature], features: dict[str, PolicyFeature],
norm_map: dict[str, NormalizationMode], norm_map: dict[str, NormalizationMode],
stats: dict[str, dict[str, Tensor]] | None = None, stats: dict[str, dict[str, Tensor]] | None = None,
dtype: torch.dtype = torch.float32,
): ):
""" """
Args: Args:
@@ -219,7 +222,7 @@ class Unnormalize(nn.Module):
self.norm_map = norm_map self.norm_map = norm_map
self.stats = stats self.stats = stats
# `self.buffer_observation_state["mean"]` contains `torch.tensor(state_dim)` # `self.buffer_observation_state["mean"]` contains `torch.tensor(state_dim)`
stats_buffers = create_stats_buffers(features, norm_map, stats) stats_buffers = create_stats_buffers(features, norm_map, stats, dtype)
for key, buffer in stats_buffers.items(): for key, buffer in stats_buffers.items():
setattr(self, "buffer_" + key.replace(".", "_"), buffer) setattr(self, "buffer_" + key.replace(".", "_"), buffer)
@@ -262,6 +265,7 @@ def _initialize_stats_buffers(
features: dict[str, PolicyFeature], features: dict[str, PolicyFeature],
norm_map: dict[str, NormalizationMode], norm_map: dict[str, NormalizationMode],
stats: dict[str, dict[str, Tensor]] | None = None, stats: dict[str, dict[str, Tensor]] | None = None,
dtype: torch.dtype = torch.float32,
) -> None: ) -> None:
"""Register statistics buffers (mean/std or min/max) on the given *module*. """Register statistics buffers (mean/std or min/max) on the given *module*.
@@ -282,8 +286,8 @@ def _initialize_stats_buffers(
prefix = key.replace(".", "_") prefix = key.replace(".", "_")
if norm_mode is NormalizationMode.MEAN_STD: if norm_mode is NormalizationMode.MEAN_STD:
mean = torch.full(shape, torch.inf, dtype=torch.float32) mean = torch.full(shape, torch.inf, dtype=dtype)
std = torch.full(shape, torch.inf, dtype=torch.float32) std = torch.full(shape, torch.inf, dtype=dtype)
if stats and key in stats and "mean" in stats[key] and "std" in stats[key]: if stats and key in stats and "mean" in stats[key] and "std" in stats[key]:
mean_data = stats[key]["mean"] mean_data = stats[key]["mean"]
@@ -293,8 +297,8 @@ def _initialize_stats_buffers(
# tensors anywhere (for example, when we use the same stats for normalization and # tensors anywhere (for example, when we use the same stats for normalization and
# unnormalization). See the logic here # unnormalization). See the logic here
# https://github.com/huggingface/safetensors/blob/079781fd0dc455ba0fe851e2b4507c33d0c0d407/bindings/python/py_src/safetensors/torch.py#L97. # https://github.com/huggingface/safetensors/blob/079781fd0dc455ba0fe851e2b4507c33d0c0d407/bindings/python/py_src/safetensors/torch.py#L97.
mean = mean_data.clone().to(dtype=torch.float32) mean = mean_data.clone().to(dtype=dtype)
std = std_data.clone().to(dtype=torch.float32) std = std_data.clone().to(dtype=dtype)
else: else:
raise ValueError(f"Unsupported stats type for key '{key}' (expected ndarray or Tensor).") raise ValueError(f"Unsupported stats type for key '{key}' (expected ndarray or Tensor).")
@@ -303,15 +307,15 @@ def _initialize_stats_buffers(
continue continue
if norm_mode is NormalizationMode.MIN_MAX: if norm_mode is NormalizationMode.MIN_MAX:
min_val = torch.full(shape, torch.inf, dtype=torch.float32) min_val = torch.full(shape, torch.inf, dtype=dtype)
max_val = torch.full(shape, torch.inf, dtype=torch.float32) max_val = torch.full(shape, torch.inf, dtype=dtype)
if stats and key in stats and "min" in stats[key] and "max" in stats[key]: if stats and key in stats and "min" in stats[key] and "max" in stats[key]:
min_data = stats[key]["min"] min_data = stats[key]["min"]
max_data = stats[key]["max"] max_data = stats[key]["max"]
if isinstance(min_data, torch.Tensor): if isinstance(min_data, torch.Tensor):
min_val = min_data.clone().to(dtype=torch.float32) min_val = min_data.clone().to(dtype=dtype)
max_val = max_data.clone().to(dtype=torch.float32) max_val = max_data.clone().to(dtype=dtype)
else: else:
raise ValueError(f"Unsupported stats type for key '{key}' (expected ndarray or Tensor).") raise ValueError(f"Unsupported stats type for key '{key}' (expected ndarray or Tensor).")
@@ -330,12 +334,13 @@ class NormalizeBuffer(nn.Module):
features: dict[str, PolicyFeature], features: dict[str, PolicyFeature],
norm_map: dict[str, NormalizationMode], norm_map: dict[str, NormalizationMode],
stats: dict[str, dict[str, Tensor]] | None = None, stats: dict[str, dict[str, Tensor]] | None = None,
dtype: torch.dtype = torch.float32,
): ):
super().__init__() super().__init__()
self.features = features self.features = features
self.norm_map = norm_map self.norm_map = norm_map
_initialize_stats_buffers(self, features, norm_map, stats) _initialize_stats_buffers(self, features, norm_map, stats, dtype)
def forward(self, batch: dict[str, Tensor]) -> dict[str, Tensor]: def forward(self, batch: dict[str, Tensor]) -> dict[str, Tensor]:
batch = dict(batch) batch = dict(batch)
@@ -379,12 +384,13 @@ class UnnormalizeBuffer(nn.Module):
features: dict[str, PolicyFeature], features: dict[str, PolicyFeature],
norm_map: dict[str, NormalizationMode], norm_map: dict[str, NormalizationMode],
stats: dict[str, dict[str, Tensor]] | None = None, stats: dict[str, dict[str, Tensor]] | None = None,
dtype: torch.dtype = torch.float32,
): ):
super().__init__() super().__init__()
self.features = features self.features = features
self.norm_map = norm_map self.norm_map = norm_map
_initialize_stats_buffers(self, features, norm_map, stats) _initialize_stats_buffers(self, features, norm_map, stats, dtype)
def forward(self, batch: dict[str, Tensor]) -> dict[str, Tensor]: def forward(self, batch: dict[str, Tensor]) -> dict[str, Tensor]:
# batch = dict(batch) # batch = dict(batch)
@@ -673,19 +673,19 @@ class VLAFlowMatching(nn.Module):
for params in self.state_proj.parameters(): for params in self.state_proj.parameters():
params.requires_grad = self.config.train_state_proj params.requires_grad = self.config.train_state_proj
def sample_noise(self, shape, device): def sample_noise(self, shape, device, dtype):
noise = torch.normal( noise = torch.normal(
mean=0.0, mean=0.0,
std=1.0, std=1.0,
size=shape, size=shape,
dtype=torch.float32, dtype=dtype,
device=device, device=device,
) )
return noise return noise
def sample_time(self, bsize, device): def sample_time(self, bsize, device, dtype):
beta_dist = torch.distributions.Beta(concentration1=1.5, concentration0=1.0) beta_dist = torch.distributions.Beta(concentration1=1.5, concentration0=1.0)
time_beta = beta_dist.sample((bsize,)).to(device=device, dtype=torch.float32) time_beta = beta_dist.sample((bsize,)).to(device=device, dtype=dtype)
time = time_beta * 0.999 + 0.001 time = time_beta * 0.999 + 0.001
return time return time
@@ -831,10 +831,10 @@ class VLAFlowMatching(nn.Module):
) -> Tensor: ) -> Tensor:
"""Do a full training forward pass and compute the loss (batch_size x num_steps x num_motors)""" """Do a full training forward pass and compute the loss (batch_size x num_steps x num_motors)"""
if noise is None: if noise is None:
noise = self.sample_noise(actions.shape, actions.device) noise = self.sample_noise(actions.shape, actions.device, actions.dtype)
if time is None: if time is None:
time = self.sample_time(actions.shape[0], actions.device) time = self.sample_time(actions.shape[0], actions.device, actions.dtype)
time_expanded = time[:, None, None] time_expanded = time[:, None, None]
x_t = time_expanded * noise + (1 - time_expanded) * actions x_t = time_expanded * noise + (1 - time_expanded) * actions
@@ -868,10 +868,11 @@ class VLAFlowMatching(nn.Module):
"""Do a full inference forward and compute the action (batch_size x num_steps x num_motors)""" """Do a full inference forward and compute the action (batch_size x num_steps x num_motors)"""
bsize = state.shape[0] bsize = state.shape[0]
device = state.device device = state.device
dtype = state.dtype
if noise is None: if noise is None:
actions_shape = (bsize, self.config.chunk_size, self.config.max_action_dim) actions_shape = (bsize, self.config.chunk_size, self.config.max_action_dim)
noise = self.sample_noise(actions_shape, device) noise = self.sample_noise(actions_shape, device, dtype)
prefix_embs, prefix_pad_masks, prefix_att_masks = self.embed_prefix( prefix_embs, prefix_pad_masks, prefix_att_masks = self.embed_prefix(
images, img_masks, lang_tokens, lang_masks, state=state images, img_masks, lang_tokens, lang_masks, state=state
@@ -888,18 +889,13 @@ class VLAFlowMatching(nn.Module):
fill_kv_cache=True, fill_kv_cache=True,
) )
dt = -1.0 / self.config.num_steps dt = -1.0 / self.config.num_steps
dt = torch.tensor(dt, dtype=torch.float32, device=device) dt = torch.tensor(dt, dtype=dtype, device=device)
x_t = noise x_t = noise
time = torch.tensor(1.0, dtype=torch.float32, device=device) time = torch.tensor(1.0, dtype=dtype, device=device)
while time >= -dt / 2: while time >= -dt / 2:
expanded_time = time.expand(bsize) expanded_time = time.expand(bsize)
v_t = self.denoise_step( v_t = self.denoise_step(prefix_pad_masks, past_key_values, x_t, expanded_time, dtype)
prefix_pad_masks,
past_key_values,
x_t,
expanded_time,
)
# Euler step # Euler step
x_t += dt * v_t x_t += dt * v_t
time += dt time += dt
@@ -911,6 +907,7 @@ class VLAFlowMatching(nn.Module):
past_key_values, past_key_values,
x_t, x_t,
timestep, timestep,
dtype,
): ):
"""Apply one denoising step of the noise `x_t` at a given timestep.""" """Apply one denoising step of the noise `x_t` at a given timestep."""
suffix_embs, suffix_pad_masks, suffix_att_masks = self.embed_suffix(x_t, timestep) suffix_embs, suffix_pad_masks, suffix_att_masks = self.embed_suffix(x_t, timestep)
@@ -936,6 +933,6 @@ class VLAFlowMatching(nn.Module):
) )
suffix_out = outputs_embeds[1] suffix_out = outputs_embeds[1]
suffix_out = suffix_out[:, -self.config.chunk_size :] suffix_out = suffix_out[:, -self.config.chunk_size :]
suffix_out = suffix_out.to(dtype=torch.float32) suffix_out = suffix_out.to(dtype=dtype)
v_t = self.action_out_proj(suffix_out) v_t = self.action_out_proj(suffix_out)
return v_t return v_t