mirror of
https://github.com/huggingface/lerobot.git
synced 2026-05-26 05:59:52 +00:00
fix normalization for dtype
This commit is contained in:
@@ -24,6 +24,7 @@ def create_stats_buffers(
|
|||||||
features: dict[str, PolicyFeature],
|
features: dict[str, PolicyFeature],
|
||||||
norm_map: dict[str, NormalizationMode],
|
norm_map: dict[str, NormalizationMode],
|
||||||
stats: dict[str, dict[str, Tensor]] | None = None,
|
stats: dict[str, dict[str, Tensor]] | None = None,
|
||||||
|
dtype: torch.dtype = torch.float32,
|
||||||
) -> dict[str, dict[str, nn.ParameterDict]]:
|
) -> dict[str, dict[str, nn.ParameterDict]]:
|
||||||
"""
|
"""
|
||||||
Create buffers per modality (e.g. "observation.image", "action") containing their mean, std, min, max
|
Create buffers per modality (e.g. "observation.image", "action") containing their mean, std, min, max
|
||||||
@@ -60,8 +61,8 @@ def create_stats_buffers(
|
|||||||
|
|
||||||
buffer = {}
|
buffer = {}
|
||||||
if norm_mode is NormalizationMode.MEAN_STD:
|
if norm_mode is NormalizationMode.MEAN_STD:
|
||||||
mean = torch.ones(shape, dtype=torch.float32) * torch.inf
|
mean = torch.ones(shape, dtype=dtype) * torch.inf
|
||||||
std = torch.ones(shape, dtype=torch.float32) * torch.inf
|
std = torch.ones(shape, dtype=dtype) * torch.inf
|
||||||
buffer = nn.ParameterDict(
|
buffer = nn.ParameterDict(
|
||||||
{
|
{
|
||||||
"mean": nn.Parameter(mean, requires_grad=False),
|
"mean": nn.Parameter(mean, requires_grad=False),
|
||||||
@@ -69,8 +70,8 @@ def create_stats_buffers(
|
|||||||
}
|
}
|
||||||
)
|
)
|
||||||
elif norm_mode is NormalizationMode.MIN_MAX:
|
elif norm_mode is NormalizationMode.MIN_MAX:
|
||||||
min = torch.ones(shape, dtype=torch.float32) * torch.inf
|
min = torch.ones(shape, dtype=dtype) * torch.inf
|
||||||
max = torch.ones(shape, dtype=torch.float32) * torch.inf
|
max = torch.ones(shape, dtype=dtype) * torch.inf
|
||||||
buffer = nn.ParameterDict(
|
buffer = nn.ParameterDict(
|
||||||
{
|
{
|
||||||
"min": nn.Parameter(min, requires_grad=False),
|
"min": nn.Parameter(min, requires_grad=False),
|
||||||
@@ -82,22 +83,22 @@ def create_stats_buffers(
|
|||||||
if stats:
|
if stats:
|
||||||
if isinstance(stats[key]["mean"], np.ndarray):
|
if isinstance(stats[key]["mean"], np.ndarray):
|
||||||
if norm_mode is NormalizationMode.MEAN_STD:
|
if norm_mode is NormalizationMode.MEAN_STD:
|
||||||
buffer["mean"].data = torch.from_numpy(stats[key]["mean"]).to(dtype=torch.float32)
|
buffer["mean"].data = torch.from_numpy(stats[key]["mean"]).to(dtype=dtype)
|
||||||
buffer["std"].data = torch.from_numpy(stats[key]["std"]).to(dtype=torch.float32)
|
buffer["std"].data = torch.from_numpy(stats[key]["std"]).to(dtype=dtype)
|
||||||
elif norm_mode is NormalizationMode.MIN_MAX:
|
elif norm_mode is NormalizationMode.MIN_MAX:
|
||||||
buffer["min"].data = torch.from_numpy(stats[key]["min"]).to(dtype=torch.float32)
|
buffer["min"].data = torch.from_numpy(stats[key]["min"]).to(dtype=dtype)
|
||||||
buffer["max"].data = torch.from_numpy(stats[key]["max"]).to(dtype=torch.float32)
|
buffer["max"].data = torch.from_numpy(stats[key]["max"]).to(dtype=dtype)
|
||||||
elif isinstance(stats[key]["mean"], torch.Tensor):
|
elif isinstance(stats[key]["mean"], torch.Tensor):
|
||||||
# Note: The clone is needed to make sure that the logic in save_pretrained doesn't see duplicated
|
# Note: The clone is needed to make sure that the logic in save_pretrained doesn't see duplicated
|
||||||
# tensors anywhere (for example, when we use the same stats for normalization and
|
# tensors anywhere (for example, when we use the same stats for normalization and
|
||||||
# unnormalization). See the logic here
|
# unnormalization). See the logic here
|
||||||
# https://github.com/huggingface/safetensors/blob/079781fd0dc455ba0fe851e2b4507c33d0c0d407/bindings/python/py_src/safetensors/torch.py#L97.
|
# https://github.com/huggingface/safetensors/blob/079781fd0dc455ba0fe851e2b4507c33d0c0d407/bindings/python/py_src/safetensors/torch.py#L97.
|
||||||
if norm_mode is NormalizationMode.MEAN_STD:
|
if norm_mode is NormalizationMode.MEAN_STD:
|
||||||
buffer["mean"].data = stats[key]["mean"].clone().to(dtype=torch.float32)
|
buffer["mean"].data = stats[key]["mean"].clone().to(dtype=dtype)
|
||||||
buffer["std"].data = stats[key]["std"].clone().to(dtype=torch.float32)
|
buffer["std"].data = stats[key]["std"].clone().to(dtype=dtype)
|
||||||
elif norm_mode is NormalizationMode.MIN_MAX:
|
elif norm_mode is NormalizationMode.MIN_MAX:
|
||||||
buffer["min"].data = stats[key]["min"].clone().to(dtype=torch.float32)
|
buffer["min"].data = stats[key]["min"].clone().to(dtype=dtype)
|
||||||
buffer["max"].data = stats[key]["max"].clone().to(dtype=torch.float32)
|
buffer["max"].data = stats[key]["max"].clone().to(dtype=dtype)
|
||||||
else:
|
else:
|
||||||
type_ = type(stats[key]["mean"])
|
type_ = type(stats[key]["mean"])
|
||||||
raise ValueError(f"np.ndarray or torch.Tensor expected, but type is '{type_}' instead.")
|
raise ValueError(f"np.ndarray or torch.Tensor expected, but type is '{type_}' instead.")
|
||||||
@@ -121,6 +122,7 @@ class Normalize(nn.Module):
|
|||||||
features: dict[str, PolicyFeature],
|
features: dict[str, PolicyFeature],
|
||||||
norm_map: dict[str, NormalizationMode],
|
norm_map: dict[str, NormalizationMode],
|
||||||
stats: dict[str, dict[str, Tensor]] | None = None,
|
stats: dict[str, dict[str, Tensor]] | None = None,
|
||||||
|
dtype: torch.dtype = torch.float32,
|
||||||
):
|
):
|
||||||
"""
|
"""
|
||||||
Args:
|
Args:
|
||||||
@@ -144,7 +146,7 @@ class Normalize(nn.Module):
|
|||||||
self.features = features
|
self.features = features
|
||||||
self.norm_map = norm_map
|
self.norm_map = norm_map
|
||||||
self.stats = stats
|
self.stats = stats
|
||||||
stats_buffers = create_stats_buffers(features, norm_map, stats)
|
stats_buffers = create_stats_buffers(features, norm_map, stats, dtype)
|
||||||
for key, buffer in stats_buffers.items():
|
for key, buffer in stats_buffers.items():
|
||||||
setattr(self, "buffer_" + key.replace(".", "_"), buffer)
|
setattr(self, "buffer_" + key.replace(".", "_"), buffer)
|
||||||
|
|
||||||
@@ -195,6 +197,7 @@ class Unnormalize(nn.Module):
|
|||||||
features: dict[str, PolicyFeature],
|
features: dict[str, PolicyFeature],
|
||||||
norm_map: dict[str, NormalizationMode],
|
norm_map: dict[str, NormalizationMode],
|
||||||
stats: dict[str, dict[str, Tensor]] | None = None,
|
stats: dict[str, dict[str, Tensor]] | None = None,
|
||||||
|
dtype: torch.dtype = torch.float32,
|
||||||
):
|
):
|
||||||
"""
|
"""
|
||||||
Args:
|
Args:
|
||||||
@@ -219,7 +222,7 @@ class Unnormalize(nn.Module):
|
|||||||
self.norm_map = norm_map
|
self.norm_map = norm_map
|
||||||
self.stats = stats
|
self.stats = stats
|
||||||
# `self.buffer_observation_state["mean"]` contains `torch.tensor(state_dim)`
|
# `self.buffer_observation_state["mean"]` contains `torch.tensor(state_dim)`
|
||||||
stats_buffers = create_stats_buffers(features, norm_map, stats)
|
stats_buffers = create_stats_buffers(features, norm_map, stats, dtype)
|
||||||
for key, buffer in stats_buffers.items():
|
for key, buffer in stats_buffers.items():
|
||||||
setattr(self, "buffer_" + key.replace(".", "_"), buffer)
|
setattr(self, "buffer_" + key.replace(".", "_"), buffer)
|
||||||
|
|
||||||
@@ -262,6 +265,7 @@ def _initialize_stats_buffers(
|
|||||||
features: dict[str, PolicyFeature],
|
features: dict[str, PolicyFeature],
|
||||||
norm_map: dict[str, NormalizationMode],
|
norm_map: dict[str, NormalizationMode],
|
||||||
stats: dict[str, dict[str, Tensor]] | None = None,
|
stats: dict[str, dict[str, Tensor]] | None = None,
|
||||||
|
dtype: torch.dtype = torch.float32,
|
||||||
) -> None:
|
) -> None:
|
||||||
"""Register statistics buffers (mean/std or min/max) on the given *module*.
|
"""Register statistics buffers (mean/std or min/max) on the given *module*.
|
||||||
|
|
||||||
@@ -282,8 +286,8 @@ def _initialize_stats_buffers(
|
|||||||
prefix = key.replace(".", "_")
|
prefix = key.replace(".", "_")
|
||||||
|
|
||||||
if norm_mode is NormalizationMode.MEAN_STD:
|
if norm_mode is NormalizationMode.MEAN_STD:
|
||||||
mean = torch.full(shape, torch.inf, dtype=torch.float32)
|
mean = torch.full(shape, torch.inf, dtype=dtype)
|
||||||
std = torch.full(shape, torch.inf, dtype=torch.float32)
|
std = torch.full(shape, torch.inf, dtype=dtype)
|
||||||
|
|
||||||
if stats and key in stats and "mean" in stats[key] and "std" in stats[key]:
|
if stats and key in stats and "mean" in stats[key] and "std" in stats[key]:
|
||||||
mean_data = stats[key]["mean"]
|
mean_data = stats[key]["mean"]
|
||||||
@@ -293,8 +297,8 @@ def _initialize_stats_buffers(
|
|||||||
# tensors anywhere (for example, when we use the same stats for normalization and
|
# tensors anywhere (for example, when we use the same stats for normalization and
|
||||||
# unnormalization). See the logic here
|
# unnormalization). See the logic here
|
||||||
# https://github.com/huggingface/safetensors/blob/079781fd0dc455ba0fe851e2b4507c33d0c0d407/bindings/python/py_src/safetensors/torch.py#L97.
|
# https://github.com/huggingface/safetensors/blob/079781fd0dc455ba0fe851e2b4507c33d0c0d407/bindings/python/py_src/safetensors/torch.py#L97.
|
||||||
mean = mean_data.clone().to(dtype=torch.float32)
|
mean = mean_data.clone().to(dtype=dtype)
|
||||||
std = std_data.clone().to(dtype=torch.float32)
|
std = std_data.clone().to(dtype=dtype)
|
||||||
else:
|
else:
|
||||||
raise ValueError(f"Unsupported stats type for key '{key}' (expected ndarray or Tensor).")
|
raise ValueError(f"Unsupported stats type for key '{key}' (expected ndarray or Tensor).")
|
||||||
|
|
||||||
@@ -303,15 +307,15 @@ def _initialize_stats_buffers(
|
|||||||
continue
|
continue
|
||||||
|
|
||||||
if norm_mode is NormalizationMode.MIN_MAX:
|
if norm_mode is NormalizationMode.MIN_MAX:
|
||||||
min_val = torch.full(shape, torch.inf, dtype=torch.float32)
|
min_val = torch.full(shape, torch.inf, dtype=dtype)
|
||||||
max_val = torch.full(shape, torch.inf, dtype=torch.float32)
|
max_val = torch.full(shape, torch.inf, dtype=dtype)
|
||||||
|
|
||||||
if stats and key in stats and "min" in stats[key] and "max" in stats[key]:
|
if stats and key in stats and "min" in stats[key] and "max" in stats[key]:
|
||||||
min_data = stats[key]["min"]
|
min_data = stats[key]["min"]
|
||||||
max_data = stats[key]["max"]
|
max_data = stats[key]["max"]
|
||||||
if isinstance(min_data, torch.Tensor):
|
if isinstance(min_data, torch.Tensor):
|
||||||
min_val = min_data.clone().to(dtype=torch.float32)
|
min_val = min_data.clone().to(dtype=dtype)
|
||||||
max_val = max_data.clone().to(dtype=torch.float32)
|
max_val = max_data.clone().to(dtype=dtype)
|
||||||
else:
|
else:
|
||||||
raise ValueError(f"Unsupported stats type for key '{key}' (expected ndarray or Tensor).")
|
raise ValueError(f"Unsupported stats type for key '{key}' (expected ndarray or Tensor).")
|
||||||
|
|
||||||
@@ -330,12 +334,13 @@ class NormalizeBuffer(nn.Module):
|
|||||||
features: dict[str, PolicyFeature],
|
features: dict[str, PolicyFeature],
|
||||||
norm_map: dict[str, NormalizationMode],
|
norm_map: dict[str, NormalizationMode],
|
||||||
stats: dict[str, dict[str, Tensor]] | None = None,
|
stats: dict[str, dict[str, Tensor]] | None = None,
|
||||||
|
dtype: torch.dtype = torch.float32,
|
||||||
):
|
):
|
||||||
super().__init__()
|
super().__init__()
|
||||||
self.features = features
|
self.features = features
|
||||||
self.norm_map = norm_map
|
self.norm_map = norm_map
|
||||||
|
|
||||||
_initialize_stats_buffers(self, features, norm_map, stats)
|
_initialize_stats_buffers(self, features, norm_map, stats, dtype)
|
||||||
|
|
||||||
def forward(self, batch: dict[str, Tensor]) -> dict[str, Tensor]:
|
def forward(self, batch: dict[str, Tensor]) -> dict[str, Tensor]:
|
||||||
batch = dict(batch)
|
batch = dict(batch)
|
||||||
@@ -379,12 +384,13 @@ class UnnormalizeBuffer(nn.Module):
|
|||||||
features: dict[str, PolicyFeature],
|
features: dict[str, PolicyFeature],
|
||||||
norm_map: dict[str, NormalizationMode],
|
norm_map: dict[str, NormalizationMode],
|
||||||
stats: dict[str, dict[str, Tensor]] | None = None,
|
stats: dict[str, dict[str, Tensor]] | None = None,
|
||||||
|
dtype: torch.dtype = torch.float32,
|
||||||
):
|
):
|
||||||
super().__init__()
|
super().__init__()
|
||||||
self.features = features
|
self.features = features
|
||||||
self.norm_map = norm_map
|
self.norm_map = norm_map
|
||||||
|
|
||||||
_initialize_stats_buffers(self, features, norm_map, stats)
|
_initialize_stats_buffers(self, features, norm_map, stats, dtype)
|
||||||
|
|
||||||
def forward(self, batch: dict[str, Tensor]) -> dict[str, Tensor]:
|
def forward(self, batch: dict[str, Tensor]) -> dict[str, Tensor]:
|
||||||
# batch = dict(batch)
|
# batch = dict(batch)
|
||||||
|
|||||||
@@ -673,19 +673,19 @@ class VLAFlowMatching(nn.Module):
|
|||||||
for params in self.state_proj.parameters():
|
for params in self.state_proj.parameters():
|
||||||
params.requires_grad = self.config.train_state_proj
|
params.requires_grad = self.config.train_state_proj
|
||||||
|
|
||||||
def sample_noise(self, shape, device):
|
def sample_noise(self, shape, device, dtype):
|
||||||
noise = torch.normal(
|
noise = torch.normal(
|
||||||
mean=0.0,
|
mean=0.0,
|
||||||
std=1.0,
|
std=1.0,
|
||||||
size=shape,
|
size=shape,
|
||||||
dtype=torch.float32,
|
dtype=dtype,
|
||||||
device=device,
|
device=device,
|
||||||
)
|
)
|
||||||
return noise
|
return noise
|
||||||
|
|
||||||
def sample_time(self, bsize, device):
|
def sample_time(self, bsize, device, dtype):
|
||||||
beta_dist = torch.distributions.Beta(concentration1=1.5, concentration0=1.0)
|
beta_dist = torch.distributions.Beta(concentration1=1.5, concentration0=1.0)
|
||||||
time_beta = beta_dist.sample((bsize,)).to(device=device, dtype=torch.float32)
|
time_beta = beta_dist.sample((bsize,)).to(device=device, dtype=dtype)
|
||||||
time = time_beta * 0.999 + 0.001
|
time = time_beta * 0.999 + 0.001
|
||||||
return time
|
return time
|
||||||
|
|
||||||
@@ -831,10 +831,10 @@ class VLAFlowMatching(nn.Module):
|
|||||||
) -> Tensor:
|
) -> Tensor:
|
||||||
"""Do a full training forward pass and compute the loss (batch_size x num_steps x num_motors)"""
|
"""Do a full training forward pass and compute the loss (batch_size x num_steps x num_motors)"""
|
||||||
if noise is None:
|
if noise is None:
|
||||||
noise = self.sample_noise(actions.shape, actions.device)
|
noise = self.sample_noise(actions.shape, actions.device, actions.dtype)
|
||||||
|
|
||||||
if time is None:
|
if time is None:
|
||||||
time = self.sample_time(actions.shape[0], actions.device)
|
time = self.sample_time(actions.shape[0], actions.device, actions.dtype)
|
||||||
|
|
||||||
time_expanded = time[:, None, None]
|
time_expanded = time[:, None, None]
|
||||||
x_t = time_expanded * noise + (1 - time_expanded) * actions
|
x_t = time_expanded * noise + (1 - time_expanded) * actions
|
||||||
@@ -868,10 +868,11 @@ class VLAFlowMatching(nn.Module):
|
|||||||
"""Do a full inference forward and compute the action (batch_size x num_steps x num_motors)"""
|
"""Do a full inference forward and compute the action (batch_size x num_steps x num_motors)"""
|
||||||
bsize = state.shape[0]
|
bsize = state.shape[0]
|
||||||
device = state.device
|
device = state.device
|
||||||
|
dtype = state.dtype
|
||||||
|
|
||||||
if noise is None:
|
if noise is None:
|
||||||
actions_shape = (bsize, self.config.chunk_size, self.config.max_action_dim)
|
actions_shape = (bsize, self.config.chunk_size, self.config.max_action_dim)
|
||||||
noise = self.sample_noise(actions_shape, device)
|
noise = self.sample_noise(actions_shape, device, dtype)
|
||||||
|
|
||||||
prefix_embs, prefix_pad_masks, prefix_att_masks = self.embed_prefix(
|
prefix_embs, prefix_pad_masks, prefix_att_masks = self.embed_prefix(
|
||||||
images, img_masks, lang_tokens, lang_masks, state=state
|
images, img_masks, lang_tokens, lang_masks, state=state
|
||||||
@@ -888,18 +889,13 @@ class VLAFlowMatching(nn.Module):
|
|||||||
fill_kv_cache=True,
|
fill_kv_cache=True,
|
||||||
)
|
)
|
||||||
dt = -1.0 / self.config.num_steps
|
dt = -1.0 / self.config.num_steps
|
||||||
dt = torch.tensor(dt, dtype=torch.float32, device=device)
|
dt = torch.tensor(dt, dtype=dtype, device=device)
|
||||||
|
|
||||||
x_t = noise
|
x_t = noise
|
||||||
time = torch.tensor(1.0, dtype=torch.float32, device=device)
|
time = torch.tensor(1.0, dtype=dtype, device=device)
|
||||||
while time >= -dt / 2:
|
while time >= -dt / 2:
|
||||||
expanded_time = time.expand(bsize)
|
expanded_time = time.expand(bsize)
|
||||||
v_t = self.denoise_step(
|
v_t = self.denoise_step(prefix_pad_masks, past_key_values, x_t, expanded_time, dtype)
|
||||||
prefix_pad_masks,
|
|
||||||
past_key_values,
|
|
||||||
x_t,
|
|
||||||
expanded_time,
|
|
||||||
)
|
|
||||||
# Euler step
|
# Euler step
|
||||||
x_t += dt * v_t
|
x_t += dt * v_t
|
||||||
time += dt
|
time += dt
|
||||||
@@ -911,6 +907,7 @@ class VLAFlowMatching(nn.Module):
|
|||||||
past_key_values,
|
past_key_values,
|
||||||
x_t,
|
x_t,
|
||||||
timestep,
|
timestep,
|
||||||
|
dtype,
|
||||||
):
|
):
|
||||||
"""Apply one denoising step of the noise `x_t` at a given timestep."""
|
"""Apply one denoising step of the noise `x_t` at a given timestep."""
|
||||||
suffix_embs, suffix_pad_masks, suffix_att_masks = self.embed_suffix(x_t, timestep)
|
suffix_embs, suffix_pad_masks, suffix_att_masks = self.embed_suffix(x_t, timestep)
|
||||||
@@ -936,6 +933,6 @@ class VLAFlowMatching(nn.Module):
|
|||||||
)
|
)
|
||||||
suffix_out = outputs_embeds[1]
|
suffix_out = outputs_embeds[1]
|
||||||
suffix_out = suffix_out[:, -self.config.chunk_size :]
|
suffix_out = suffix_out[:, -self.config.chunk_size :]
|
||||||
suffix_out = suffix_out.to(dtype=torch.float32)
|
suffix_out = suffix_out.to(dtype=dtype)
|
||||||
v_t = self.action_out_proj(suffix_out)
|
v_t = self.action_out_proj(suffix_out)
|
||||||
return v_t
|
return v_t
|
||||||
|
|||||||
Reference in New Issue
Block a user