Compare commits

...

1 Commit

Author SHA1 Message Date
AdilZouitine ab94626b92 fix normalization for dtype 2025-08-03 18:07:08 +02:00
2 changed files with 43 additions and 40 deletions
+30 -24
@@ -24,6 +24,7 @@ def create_stats_buffers(
features: dict[str, PolicyFeature],
norm_map: dict[str, NormalizationMode],
stats: dict[str, dict[str, Tensor]] | None = None,
+ dtype: torch.dtype = torch.float32,
) -> dict[str, dict[str, nn.ParameterDict]]:
"""
Create buffers per modality (e.g. "observation.image", "action") containing their mean, std, min, max
@@ -60,8 +61,8 @@ def create_stats_buffers(
buffer = {}
if norm_mode is NormalizationMode.MEAN_STD:
- mean = torch.ones(shape, dtype=torch.float32) * torch.inf
- std = torch.ones(shape, dtype=torch.float32) * torch.inf
+ mean = torch.ones(shape, dtype=dtype) * torch.inf
+ std = torch.ones(shape, dtype=dtype) * torch.inf
buffer = nn.ParameterDict(
{
"mean": nn.Parameter(mean, requires_grad=False),
@@ -69,8 +70,8 @@ def create_stats_buffers(
}
)
elif norm_mode is NormalizationMode.MIN_MAX:
- min = torch.ones(shape, dtype=torch.float32) * torch.inf
- max = torch.ones(shape, dtype=torch.float32) * torch.inf
+ min = torch.ones(shape, dtype=dtype) * torch.inf
+ max = torch.ones(shape, dtype=dtype) * torch.inf
buffer = nn.ParameterDict(
{
"min": nn.Parameter(min, requires_grad=False),
@@ -82,22 +83,22 @@ def create_stats_buffers(
if stats:
if isinstance(stats[key]["mean"], np.ndarray):
if norm_mode is NormalizationMode.MEAN_STD:
buffer["mean"].data = torch.from_numpy(stats[key]["mean"]).to(dtype=torch.float32)
buffer["std"].data = torch.from_numpy(stats[key]["std"]).to(dtype=torch.float32)
buffer["mean"].data = torch.from_numpy(stats[key]["mean"]).to(dtype=dtype)
buffer["std"].data = torch.from_numpy(stats[key]["std"]).to(dtype=dtype)
elif norm_mode is NormalizationMode.MIN_MAX:
buffer["min"].data = torch.from_numpy(stats[key]["min"]).to(dtype=torch.float32)
buffer["max"].data = torch.from_numpy(stats[key]["max"]).to(dtype=torch.float32)
buffer["min"].data = torch.from_numpy(stats[key]["min"]).to(dtype=dtype)
buffer["max"].data = torch.from_numpy(stats[key]["max"]).to(dtype=dtype)
elif isinstance(stats[key]["mean"], torch.Tensor):
# Note: The clone is needed to make sure that the logic in save_pretrained doesn't see duplicated
# tensors anywhere (for example, when we use the same stats for normalization and
# unnormalization). See the logic here
# https://github.com/huggingface/safetensors/blob/079781fd0dc455ba0fe851e2b4507c33d0c0d407/bindings/python/py_src/safetensors/torch.py#L97.
if norm_mode is NormalizationMode.MEAN_STD:
buffer["mean"].data = stats[key]["mean"].clone().to(dtype=torch.float32)
buffer["std"].data = stats[key]["std"].clone().to(dtype=torch.float32)
buffer["mean"].data = stats[key]["mean"].clone().to(dtype=dtype)
buffer["std"].data = stats[key]["std"].clone().to(dtype=dtype)
elif norm_mode is NormalizationMode.MIN_MAX:
buffer["min"].data = stats[key]["min"].clone().to(dtype=torch.float32)
buffer["max"].data = stats[key]["max"].clone().to(dtype=torch.float32)
buffer["min"].data = stats[key]["min"].clone().to(dtype=dtype)
buffer["max"].data = stats[key]["max"].clone().to(dtype=dtype)
else:
type_ = type(stats[key]["mean"])
raise ValueError(f"np.ndarray or torch.Tensor expected, but type is '{type_}' instead.")
@@ -121,6 +122,7 @@ class Normalize(nn.Module):
features: dict[str, PolicyFeature],
norm_map: dict[str, NormalizationMode],
stats: dict[str, dict[str, Tensor]] | None = None,
+ dtype: torch.dtype = torch.float32,
):
"""
Args:
@@ -144,7 +146,7 @@ class Normalize(nn.Module):
self.features = features
self.norm_map = norm_map
self.stats = stats
- stats_buffers = create_stats_buffers(features, norm_map, stats)
+ stats_buffers = create_stats_buffers(features, norm_map, stats, dtype)
for key, buffer in stats_buffers.items():
setattr(self, "buffer_" + key.replace(".", "_"), buffer)
@@ -195,6 +197,7 @@ class Unnormalize(nn.Module):
features: dict[str, PolicyFeature],
norm_map: dict[str, NormalizationMode],
stats: dict[str, dict[str, Tensor]] | None = None,
+ dtype: torch.dtype = torch.float32,
):
"""
Args:
@@ -219,7 +222,7 @@ class Unnormalize(nn.Module):
self.norm_map = norm_map
self.stats = stats
# `self.buffer_observation_state["mean"]` contains `torch.tensor(state_dim)`
- stats_buffers = create_stats_buffers(features, norm_map, stats)
+ stats_buffers = create_stats_buffers(features, norm_map, stats, dtype)
for key, buffer in stats_buffers.items():
setattr(self, "buffer_" + key.replace(".", "_"), buffer)
@@ -262,6 +265,7 @@ def _initialize_stats_buffers(
features: dict[str, PolicyFeature],
norm_map: dict[str, NormalizationMode],
stats: dict[str, dict[str, Tensor]] | None = None,
+ dtype: torch.dtype = torch.float32,
) -> None:
"""Register statistics buffers (mean/std or min/max) on the given *module*.
@@ -282,8 +286,8 @@ def _initialize_stats_buffers(
prefix = key.replace(".", "_")
if norm_mode is NormalizationMode.MEAN_STD:
- mean = torch.full(shape, torch.inf, dtype=torch.float32)
- std = torch.full(shape, torch.inf, dtype=torch.float32)
+ mean = torch.full(shape, torch.inf, dtype=dtype)
+ std = torch.full(shape, torch.inf, dtype=dtype)
if stats and key in stats and "mean" in stats[key] and "std" in stats[key]:
mean_data = stats[key]["mean"]
@@ -293,8 +297,8 @@ def _initialize_stats_buffers(
# tensors anywhere (for example, when we use the same stats for normalization and
# unnormalization). See the logic here
# https://github.com/huggingface/safetensors/blob/079781fd0dc455ba0fe851e2b4507c33d0c0d407/bindings/python/py_src/safetensors/torch.py#L97.
- mean = mean_data.clone().to(dtype=torch.float32)
- std = std_data.clone().to(dtype=torch.float32)
+ mean = mean_data.clone().to(dtype=dtype)
+ std = std_data.clone().to(dtype=dtype)
else:
raise ValueError(f"Unsupported stats type for key '{key}' (expected ndarray or Tensor).")
@@ -303,15 +307,15 @@ def _initialize_stats_buffers(
continue
if norm_mode is NormalizationMode.MIN_MAX:
- min_val = torch.full(shape, torch.inf, dtype=torch.float32)
- max_val = torch.full(shape, torch.inf, dtype=torch.float32)
+ min_val = torch.full(shape, torch.inf, dtype=dtype)
+ max_val = torch.full(shape, torch.inf, dtype=dtype)
if stats and key in stats and "min" in stats[key] and "max" in stats[key]:
min_data = stats[key]["min"]
max_data = stats[key]["max"]
if isinstance(min_data, torch.Tensor):
- min_val = min_data.clone().to(dtype=torch.float32)
- max_val = max_data.clone().to(dtype=torch.float32)
+ min_val = min_data.clone().to(dtype=dtype)
+ max_val = max_data.clone().to(dtype=dtype)
else:
raise ValueError(f"Unsupported stats type for key '{key}' (expected ndarray or Tensor).")
@@ -330,12 +334,13 @@ class NormalizeBuffer(nn.Module):
features: dict[str, PolicyFeature],
norm_map: dict[str, NormalizationMode],
stats: dict[str, dict[str, Tensor]] | None = None,
+ dtype: torch.dtype = torch.float32,
):
super().__init__()
self.features = features
self.norm_map = norm_map
- _initialize_stats_buffers(self, features, norm_map, stats)
+ _initialize_stats_buffers(self, features, norm_map, stats, dtype)
def forward(self, batch: dict[str, Tensor]) -> dict[str, Tensor]:
batch = dict(batch)
@@ -379,12 +384,13 @@ class UnnormalizeBuffer(nn.Module):
features: dict[str, PolicyFeature],
norm_map: dict[str, NormalizationMode],
stats: dict[str, dict[str, Tensor]] | None = None,
+ dtype: torch.dtype = torch.float32,
):
super().__init__()
self.features = features
self.norm_map = norm_map
- _initialize_stats_buffers(self, features, norm_map, stats)
+ _initialize_stats_buffers(self, features, norm_map, stats, dtype)
def forward(self, batch: dict[str, Tensor]) -> dict[str, Tensor]:
# batch = dict(batch)
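NormalizeBuffer and UnnormalizeBuffer get the same treatment via `_initialize_stats_buffers`, which registers plain buffers on the module rather than nn.ParameterDicts. A hedged sketch of that registration pattern (buffer names here are illustrative, not necessarily LeRobot's exact ones):

import torch
from torch import nn

class TinyNormalize(nn.Module):
    def __init__(self, shape, dtype=torch.float32):
        super().__init__()
        # register_buffer keeps the stats in the state_dict and moves them with .to()
        self.register_buffer("mean", torch.full(shape, torch.inf, dtype=dtype))
        self.register_buffer("std", torch.full(shape, torch.inf, dtype=dtype))

    def forward(self, x):
        return (x - self.mean) / (self.std + 1e-8)

module = TinyNormalize((6,), dtype=torch.bfloat16)
module.mean.copy_(torch.zeros(6))
module.std.copy_(torch.ones(6))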
@@ -673,19 +673,19 @@ class VLAFlowMatching(nn.Module):
for params in self.state_proj.parameters():
params.requires_grad = self.config.train_state_proj
- def sample_noise(self, shape, device):
+ def sample_noise(self, shape, device, dtype):
noise = torch.normal(
mean=0.0,
std=1.0,
size=shape,
- dtype=torch.float32,
+ dtype=dtype,
device=device,
)
return noise
- def sample_time(self, bsize, device):
+ def sample_time(self, bsize, device, dtype):
beta_dist = torch.distributions.Beta(concentration1=1.5, concentration0=1.0)
- time_beta = beta_dist.sample((bsize,)).to(device=device, dtype=torch.float32)
+ time_beta = beta_dist.sample((bsize,)).to(device=device, dtype=dtype)
time = time_beta * 0.999 + 0.001
return time
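Both samplers now take the compute dtype from the caller instead of pinning float32. A standalone sketch of the same pattern; note the Beta draw itself still happens in float32 inside torch.distributions and is cast afterwards, exactly as in the hunk above:

import torch

def sample_noise(shape, device, dtype):
    # standard normal noise drawn directly in the requested dtype
    return torch.normal(mean=0.0, std=1.0, size=shape, dtype=dtype, device=device)

def sample_time(bsize, device, dtype):
    beta = torch.distributions.Beta(concentration1=1.5, concentration0=1.0)
    t = beta.sample((bsize,)).to(device=device, dtype=dtype)
    return t * 0.999 + 0.001  # map into [0.001, 1.0] so t is never exactly 0

noise = sample_noise((2, 8, 4), "cpu", torch.bfloat16)
assert noise.dtype == torch.bfloat16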
@@ -831,10 +831,10 @@ class VLAFlowMatching(nn.Module):
) -> Tensor:
"""Do a full training forward pass and compute the loss (batch_size x num_steps x num_motors)"""
if noise is None:
- noise = self.sample_noise(actions.shape, actions.device)
+ noise = self.sample_noise(actions.shape, actions.device, actions.dtype)
if time is None:
- time = self.sample_time(actions.shape[0], actions.device)
+ time = self.sample_time(actions.shape[0], actions.device, actions.dtype)
time_expanded = time[:, None, None]
x_t = time_expanded * noise + (1 - time_expanded) * actions
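Threading `actions.dtype` into both samplers matters here because torch's type promotion would otherwise silently upcast the interpolation: float32 noise mixed with bf16 actions yields a float32 `x_t`. A small demonstration of the promotion the fix avoids:

import torch

actions = torch.randn(2, 10, 6, dtype=torch.bfloat16)
time_expanded = torch.rand(2, dtype=torch.bfloat16)[:, None, None]

noise_f32 = torch.randn(2, 10, 6, dtype=torch.float32)
x_t = time_expanded * noise_f32 + (1 - time_expanded) * actions
print(x_t.dtype)  # torch.float32 -- promoted, the pre-fix behavior

noise_bf16 = torch.randn(2, 10, 6, dtype=torch.bfloat16)
x_t = time_expanded * noise_bf16 + (1 - time_expanded) * actions
print(x_t.dtype)  # torch.bfloat16 -- stays in the compute dtype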
@@ -868,10 +868,11 @@ class VLAFlowMatching(nn.Module):
"""Do a full inference forward and compute the action (batch_size x num_steps x num_motors)"""
bsize = state.shape[0]
device = state.device
+ dtype = state.dtype
if noise is None:
actions_shape = (bsize, self.config.chunk_size, self.config.max_action_dim)
- noise = self.sample_noise(actions_shape, device)
+ noise = self.sample_noise(actions_shape, device, dtype)
prefix_embs, prefix_pad_masks, prefix_att_masks = self.embed_prefix(
images, img_masks, lang_tokens, lang_masks, state=state
@@ -888,18 +889,13 @@ class VLAFlowMatching(nn.Module):
fill_kv_cache=True,
)
dt = -1.0 / self.config.num_steps
- dt = torch.tensor(dt, dtype=torch.float32, device=device)
+ dt = torch.tensor(dt, dtype=dtype, device=device)
x_t = noise
- time = torch.tensor(1.0, dtype=torch.float32, device=device)
+ time = torch.tensor(1.0, dtype=dtype, device=device)
while time >= -dt / 2:
expanded_time = time.expand(bsize)
- v_t = self.denoise_step(
-     prefix_pad_masks,
-     past_key_values,
-     x_t,
-     expanded_time,
- )
+ v_t = self.denoise_step(prefix_pad_masks, past_key_values, x_t, expanded_time, dtype)
# Euler step
x_t += dt * v_t
time += dt
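The integration loop now runs entirely in the runtime dtype, scalars included. A self-contained sketch of the same Euler scheme; the lambda stands in for the real denoiser, and the `-dt / 2` tolerance absorbs the accumulation error of stepping `time` in low precision:

import torch

def euler_sample(x_t, velocity_fn, num_steps, dtype, device):
    dt = torch.tensor(-1.0 / num_steps, dtype=dtype, device=device)
    time = torch.tensor(1.0, dtype=dtype, device=device)
    while time >= -dt / 2:  # integrate from t=1 (noise) down to t=0 (sample)
        v_t = velocity_fn(x_t, time.expand(x_t.shape[0]))
        x_t = x_t + dt * v_t  # Euler step
        time = time + dt
    return x_t

x = torch.randn(2, 8, 4, dtype=torch.bfloat16)
out = euler_sample(x, lambda x_t, t: -x_t, num_steps=10, dtype=torch.bfloat16, device="cpu")
assert out.dtype == torch.bfloat16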
@@ -911,6 +907,7 @@ class VLAFlowMatching(nn.Module):
past_key_values,
x_t,
timestep,
+ dtype,
):
"""Apply one denoising step of the noise `x_t` at a given timestep."""
suffix_embs, suffix_pad_masks, suffix_att_masks = self.embed_suffix(x_t, timestep)
@@ -936,6 +933,6 @@ class VLAFlowMatching(nn.Module):
)
suffix_out = outputs_embeds[1]
suffix_out = suffix_out[:, -self.config.chunk_size :]
- suffix_out = suffix_out.to(dtype=torch.float32)
+ suffix_out = suffix_out.to(dtype=dtype)
v_t = self.action_out_proj(suffix_out)
return v_t
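The final cast keeps `action_out_proj`'s input in the same dtype as its weights; torch matmuls raise on mixed dtypes rather than promoting. Illustration, assuming the projection layer was itself created in the runtime dtype:

import torch

proj = torch.nn.Linear(8, 6, dtype=torch.bfloat16)

h = torch.randn(2, 4, 8, dtype=torch.bfloat16)
v_t = proj(h)  # ok: input dtype matches weight dtype

h32 = torch.randn(2, 4, 8, dtype=torch.float32)
# proj(h32) would raise a "mat1 and mat2 must have the same dtype" error;
# the .to(dtype=dtype) cast in the diff plays exactly this role
v_t = proj(h32.to(proj.weight.dtype))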