mirror of
https://github.com/huggingface/lerobot.git
synced 2026-05-11 14:49:43 +00:00
fix normalization for dtype
This commit is contained in:
@@ -24,6 +24,7 @@ def create_stats_buffers(
|
||||
features: dict[str, PolicyFeature],
|
||||
norm_map: dict[str, NormalizationMode],
|
||||
stats: dict[str, dict[str, Tensor]] | None = None,
|
||||
dtype: torch.dtype = torch.float32,
|
||||
) -> dict[str, dict[str, nn.ParameterDict]]:
|
||||
"""
|
||||
Create buffers per modality (e.g. "observation.image", "action") containing their mean, std, min, max
|
||||
@@ -60,8 +61,8 @@ def create_stats_buffers(
|
||||
|
||||
buffer = {}
|
||||
if norm_mode is NormalizationMode.MEAN_STD:
|
||||
mean = torch.ones(shape, dtype=torch.float32) * torch.inf
|
||||
std = torch.ones(shape, dtype=torch.float32) * torch.inf
|
||||
mean = torch.ones(shape, dtype=dtype) * torch.inf
|
||||
std = torch.ones(shape, dtype=dtype) * torch.inf
|
||||
buffer = nn.ParameterDict(
|
||||
{
|
||||
"mean": nn.Parameter(mean, requires_grad=False),
|
||||
@@ -69,8 +70,8 @@ def create_stats_buffers(
|
||||
}
|
||||
)
|
||||
elif norm_mode is NormalizationMode.MIN_MAX:
|
||||
min = torch.ones(shape, dtype=torch.float32) * torch.inf
|
||||
max = torch.ones(shape, dtype=torch.float32) * torch.inf
|
||||
min = torch.ones(shape, dtype=dtype) * torch.inf
|
||||
max = torch.ones(shape, dtype=dtype) * torch.inf
|
||||
buffer = nn.ParameterDict(
|
||||
{
|
||||
"min": nn.Parameter(min, requires_grad=False),
|
||||
@@ -82,22 +83,22 @@ def create_stats_buffers(
|
||||
if stats:
|
||||
if isinstance(stats[key]["mean"], np.ndarray):
|
||||
if norm_mode is NormalizationMode.MEAN_STD:
|
||||
buffer["mean"].data = torch.from_numpy(stats[key]["mean"]).to(dtype=torch.float32)
|
||||
buffer["std"].data = torch.from_numpy(stats[key]["std"]).to(dtype=torch.float32)
|
||||
buffer["mean"].data = torch.from_numpy(stats[key]["mean"]).to(dtype=dtype)
|
||||
buffer["std"].data = torch.from_numpy(stats[key]["std"]).to(dtype=dtype)
|
||||
elif norm_mode is NormalizationMode.MIN_MAX:
|
||||
buffer["min"].data = torch.from_numpy(stats[key]["min"]).to(dtype=torch.float32)
|
||||
buffer["max"].data = torch.from_numpy(stats[key]["max"]).to(dtype=torch.float32)
|
||||
buffer["min"].data = torch.from_numpy(stats[key]["min"]).to(dtype=dtype)
|
||||
buffer["max"].data = torch.from_numpy(stats[key]["max"]).to(dtype=dtype)
|
||||
elif isinstance(stats[key]["mean"], torch.Tensor):
|
||||
# Note: The clone is needed to make sure that the logic in save_pretrained doesn't see duplicated
|
||||
# tensors anywhere (for example, when we use the same stats for normalization and
|
||||
# unnormalization). See the logic here
|
||||
# https://github.com/huggingface/safetensors/blob/079781fd0dc455ba0fe851e2b4507c33d0c0d407/bindings/python/py_src/safetensors/torch.py#L97.
|
||||
if norm_mode is NormalizationMode.MEAN_STD:
|
||||
buffer["mean"].data = stats[key]["mean"].clone().to(dtype=torch.float32)
|
||||
buffer["std"].data = stats[key]["std"].clone().to(dtype=torch.float32)
|
||||
buffer["mean"].data = stats[key]["mean"].clone().to(dtype=dtype)
|
||||
buffer["std"].data = stats[key]["std"].clone().to(dtype=dtype)
|
||||
elif norm_mode is NormalizationMode.MIN_MAX:
|
||||
buffer["min"].data = stats[key]["min"].clone().to(dtype=torch.float32)
|
||||
buffer["max"].data = stats[key]["max"].clone().to(dtype=torch.float32)
|
||||
buffer["min"].data = stats[key]["min"].clone().to(dtype=dtype)
|
||||
buffer["max"].data = stats[key]["max"].clone().to(dtype=dtype)
|
||||
else:
|
||||
type_ = type(stats[key]["mean"])
|
||||
raise ValueError(f"np.ndarray or torch.Tensor expected, but type is '{type_}' instead.")
|
||||
@@ -121,6 +122,7 @@ class Normalize(nn.Module):
|
||||
features: dict[str, PolicyFeature],
|
||||
norm_map: dict[str, NormalizationMode],
|
||||
stats: dict[str, dict[str, Tensor]] | None = None,
|
||||
dtype: torch.dtype = torch.float32,
|
||||
):
|
||||
"""
|
||||
Args:
|
||||
@@ -144,7 +146,7 @@ class Normalize(nn.Module):
|
||||
self.features = features
|
||||
self.norm_map = norm_map
|
||||
self.stats = stats
|
||||
stats_buffers = create_stats_buffers(features, norm_map, stats)
|
||||
stats_buffers = create_stats_buffers(features, norm_map, stats, dtype)
|
||||
for key, buffer in stats_buffers.items():
|
||||
setattr(self, "buffer_" + key.replace(".", "_"), buffer)
|
||||
|
||||
@@ -195,6 +197,7 @@ class Unnormalize(nn.Module):
|
||||
features: dict[str, PolicyFeature],
|
||||
norm_map: dict[str, NormalizationMode],
|
||||
stats: dict[str, dict[str, Tensor]] | None = None,
|
||||
dtype: torch.dtype = torch.float32,
|
||||
):
|
||||
"""
|
||||
Args:
|
||||
@@ -219,7 +222,7 @@ class Unnormalize(nn.Module):
|
||||
self.norm_map = norm_map
|
||||
self.stats = stats
|
||||
# `self.buffer_observation_state["mean"]` contains `torch.tensor(state_dim)`
|
||||
stats_buffers = create_stats_buffers(features, norm_map, stats)
|
||||
stats_buffers = create_stats_buffers(features, norm_map, stats, dtype)
|
||||
for key, buffer in stats_buffers.items():
|
||||
setattr(self, "buffer_" + key.replace(".", "_"), buffer)
|
||||
|
||||
@@ -262,6 +265,7 @@ def _initialize_stats_buffers(
|
||||
features: dict[str, PolicyFeature],
|
||||
norm_map: dict[str, NormalizationMode],
|
||||
stats: dict[str, dict[str, Tensor]] | None = None,
|
||||
dtype: torch.dtype = torch.float32,
|
||||
) -> None:
|
||||
"""Register statistics buffers (mean/std or min/max) on the given *module*.
|
||||
|
||||
@@ -282,8 +286,8 @@ def _initialize_stats_buffers(
|
||||
prefix = key.replace(".", "_")
|
||||
|
||||
if norm_mode is NormalizationMode.MEAN_STD:
|
||||
mean = torch.full(shape, torch.inf, dtype=torch.float32)
|
||||
std = torch.full(shape, torch.inf, dtype=torch.float32)
|
||||
mean = torch.full(shape, torch.inf, dtype=dtype)
|
||||
std = torch.full(shape, torch.inf, dtype=dtype)
|
||||
|
||||
if stats and key in stats and "mean" in stats[key] and "std" in stats[key]:
|
||||
mean_data = stats[key]["mean"]
|
||||
@@ -293,8 +297,8 @@ def _initialize_stats_buffers(
|
||||
# tensors anywhere (for example, when we use the same stats for normalization and
|
||||
# unnormalization). See the logic here
|
||||
# https://github.com/huggingface/safetensors/blob/079781fd0dc455ba0fe851e2b4507c33d0c0d407/bindings/python/py_src/safetensors/torch.py#L97.
|
||||
mean = mean_data.clone().to(dtype=torch.float32)
|
||||
std = std_data.clone().to(dtype=torch.float32)
|
||||
mean = mean_data.clone().to(dtype=dtype)
|
||||
std = std_data.clone().to(dtype=dtype)
|
||||
else:
|
||||
raise ValueError(f"Unsupported stats type for key '{key}' (expected ndarray or Tensor).")
|
||||
|
||||
@@ -303,15 +307,15 @@ def _initialize_stats_buffers(
|
||||
continue
|
||||
|
||||
if norm_mode is NormalizationMode.MIN_MAX:
|
||||
min_val = torch.full(shape, torch.inf, dtype=torch.float32)
|
||||
max_val = torch.full(shape, torch.inf, dtype=torch.float32)
|
||||
min_val = torch.full(shape, torch.inf, dtype=dtype)
|
||||
max_val = torch.full(shape, torch.inf, dtype=dtype)
|
||||
|
||||
if stats and key in stats and "min" in stats[key] and "max" in stats[key]:
|
||||
min_data = stats[key]["min"]
|
||||
max_data = stats[key]["max"]
|
||||
if isinstance(min_data, torch.Tensor):
|
||||
min_val = min_data.clone().to(dtype=torch.float32)
|
||||
max_val = max_data.clone().to(dtype=torch.float32)
|
||||
min_val = min_data.clone().to(dtype=dtype)
|
||||
max_val = max_data.clone().to(dtype=dtype)
|
||||
else:
|
||||
raise ValueError(f"Unsupported stats type for key '{key}' (expected ndarray or Tensor).")
|
||||
|
||||
@@ -330,12 +334,13 @@ class NormalizeBuffer(nn.Module):
|
||||
features: dict[str, PolicyFeature],
|
||||
norm_map: dict[str, NormalizationMode],
|
||||
stats: dict[str, dict[str, Tensor]] | None = None,
|
||||
dtype: torch.dtype = torch.float32,
|
||||
):
|
||||
super().__init__()
|
||||
self.features = features
|
||||
self.norm_map = norm_map
|
||||
|
||||
_initialize_stats_buffers(self, features, norm_map, stats)
|
||||
_initialize_stats_buffers(self, features, norm_map, stats, dtype)
|
||||
|
||||
def forward(self, batch: dict[str, Tensor]) -> dict[str, Tensor]:
|
||||
batch = dict(batch)
|
||||
@@ -379,12 +384,13 @@ class UnnormalizeBuffer(nn.Module):
|
||||
features: dict[str, PolicyFeature],
|
||||
norm_map: dict[str, NormalizationMode],
|
||||
stats: dict[str, dict[str, Tensor]] | None = None,
|
||||
dtype: torch.dtype = torch.float32,
|
||||
):
|
||||
super().__init__()
|
||||
self.features = features
|
||||
self.norm_map = norm_map
|
||||
|
||||
_initialize_stats_buffers(self, features, norm_map, stats)
|
||||
_initialize_stats_buffers(self, features, norm_map, stats, dtype)
|
||||
|
||||
def forward(self, batch: dict[str, Tensor]) -> dict[str, Tensor]:
|
||||
# batch = dict(batch)
|
||||
|
||||
@@ -673,19 +673,19 @@ class VLAFlowMatching(nn.Module):
|
||||
for params in self.state_proj.parameters():
|
||||
params.requires_grad = self.config.train_state_proj
|
||||
|
||||
def sample_noise(self, shape, device):
|
||||
def sample_noise(self, shape, device, dtype):
|
||||
noise = torch.normal(
|
||||
mean=0.0,
|
||||
std=1.0,
|
||||
size=shape,
|
||||
dtype=torch.float32,
|
||||
dtype=dtype,
|
||||
device=device,
|
||||
)
|
||||
return noise
|
||||
|
||||
def sample_time(self, bsize, device):
|
||||
def sample_time(self, bsize, device, dtype):
|
||||
beta_dist = torch.distributions.Beta(concentration1=1.5, concentration0=1.0)
|
||||
time_beta = beta_dist.sample((bsize,)).to(device=device, dtype=torch.float32)
|
||||
time_beta = beta_dist.sample((bsize,)).to(device=device, dtype=dtype)
|
||||
time = time_beta * 0.999 + 0.001
|
||||
return time
|
||||
|
||||
@@ -831,10 +831,10 @@ class VLAFlowMatching(nn.Module):
|
||||
) -> Tensor:
|
||||
"""Do a full training forward pass and compute the loss (batch_size x num_steps x num_motors)"""
|
||||
if noise is None:
|
||||
noise = self.sample_noise(actions.shape, actions.device)
|
||||
noise = self.sample_noise(actions.shape, actions.device, actions.dtype)
|
||||
|
||||
if time is None:
|
||||
time = self.sample_time(actions.shape[0], actions.device)
|
||||
time = self.sample_time(actions.shape[0], actions.device, actions.dtype)
|
||||
|
||||
time_expanded = time[:, None, None]
|
||||
x_t = time_expanded * noise + (1 - time_expanded) * actions
|
||||
@@ -868,10 +868,11 @@ class VLAFlowMatching(nn.Module):
|
||||
"""Do a full inference forward and compute the action (batch_size x num_steps x num_motors)"""
|
||||
bsize = state.shape[0]
|
||||
device = state.device
|
||||
dtype = state.dtype
|
||||
|
||||
if noise is None:
|
||||
actions_shape = (bsize, self.config.chunk_size, self.config.max_action_dim)
|
||||
noise = self.sample_noise(actions_shape, device)
|
||||
noise = self.sample_noise(actions_shape, device, dtype)
|
||||
|
||||
prefix_embs, prefix_pad_masks, prefix_att_masks = self.embed_prefix(
|
||||
images, img_masks, lang_tokens, lang_masks, state=state
|
||||
@@ -888,18 +889,13 @@ class VLAFlowMatching(nn.Module):
|
||||
fill_kv_cache=True,
|
||||
)
|
||||
dt = -1.0 / self.config.num_steps
|
||||
dt = torch.tensor(dt, dtype=torch.float32, device=device)
|
||||
dt = torch.tensor(dt, dtype=dtype, device=device)
|
||||
|
||||
x_t = noise
|
||||
time = torch.tensor(1.0, dtype=torch.float32, device=device)
|
||||
time = torch.tensor(1.0, dtype=dtype, device=device)
|
||||
while time >= -dt / 2:
|
||||
expanded_time = time.expand(bsize)
|
||||
v_t = self.denoise_step(
|
||||
prefix_pad_masks,
|
||||
past_key_values,
|
||||
x_t,
|
||||
expanded_time,
|
||||
)
|
||||
v_t = self.denoise_step(prefix_pad_masks, past_key_values, x_t, expanded_time, dtype)
|
||||
# Euler step
|
||||
x_t += dt * v_t
|
||||
time += dt
|
||||
@@ -911,6 +907,7 @@ class VLAFlowMatching(nn.Module):
|
||||
past_key_values,
|
||||
x_t,
|
||||
timestep,
|
||||
dtype,
|
||||
):
|
||||
"""Apply one denoising step of the noise `x_t` at a given timestep."""
|
||||
suffix_embs, suffix_pad_masks, suffix_att_masks = self.embed_suffix(x_t, timestep)
|
||||
@@ -936,6 +933,6 @@ class VLAFlowMatching(nn.Module):
|
||||
)
|
||||
suffix_out = outputs_embeds[1]
|
||||
suffix_out = suffix_out[:, -self.config.chunk_size :]
|
||||
suffix_out = suffix_out.to(dtype=torch.float32)
|
||||
suffix_out = suffix_out.to(dtype=dtype)
|
||||
v_t = self.action_out_proj(suffix_out)
|
||||
return v_t
|
||||
|
||||
Reference in New Issue
Block a user