[pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci
This commit is contained in:
pre-commit-ci[bot]
2025-03-24 13:16:38 +00:00
committed by AdilZouitine
parent 761a2dbcb3
commit 8e6d5f504c
97 changed files with 1596 additions and 492 deletions
@@ -171,7 +171,9 @@ class ACTConfig(PreTrainedConfig):
def validate_features(self) -> None:
if not self.image_features and not self.env_state_feature:
raise ValueError("You must provide at least one image or the environment state among the inputs.")
raise ValueError(
"You must provide at least one image or the environment state among the inputs."
)
@property
def observation_delta_indices(self) -> None:
+35 -11
View File
@@ -63,7 +63,9 @@ class ACTPolicy(PreTrainedPolicy):
config.validate_features()
self.config = config
self.normalize_inputs = Normalize(config.input_features, config.normalization_mapping, dataset_stats)
self.normalize_inputs = Normalize(
config.input_features, config.normalization_mapping, dataset_stats
)
self.normalize_targets = Normalize(
config.output_features, config.normalization_mapping, dataset_stats
)
@@ -120,8 +122,12 @@ class ACTPolicy(PreTrainedPolicy):
batch = self.normalize_inputs(batch)
if self.config.image_features:
batch = dict(batch) # shallow copy so that adding a key doesn't modify the original
batch["observation.images"] = [batch[key] for key in self.config.image_features]
batch = dict(
batch
) # shallow copy so that adding a key doesn't modify the original
batch["observation.images"] = [
batch[key] for key in self.config.image_features
]
# If we are doing temporal ensembling, do online updates where we keep track of the number of actions
# we are ensembling over.
@@ -148,8 +154,12 @@ class ACTPolicy(PreTrainedPolicy):
"""Run the batch through the model and compute the loss for training or validation."""
batch = self.normalize_inputs(batch)
if self.config.image_features:
batch = dict(batch) # shallow copy so that adding a key doesn't modify the original
batch["observation.images"] = [batch[key] for key in self.config.image_features]
batch = dict(
batch
) # shallow copy so that adding a key doesn't modify the original
batch["observation.images"] = [
batch[key] for key in self.config.image_features
]
batch = self.normalize_targets(batch)
actions_hat, (mu_hat, log_sigma_x2_hat) = self.model(batch)
@@ -406,14 +416,18 @@ class ACT(nn.Module):
n_1d_tokens += 1
self.encoder_1d_feature_pos_embed = nn.Embedding(n_1d_tokens, config.dim_model)
if self.config.image_features:
self.encoder_cam_feat_pos_embed = ACTSinusoidalPositionEmbedding2d(config.dim_model // 2)
self.encoder_cam_feat_pos_embed = ACTSinusoidalPositionEmbedding2d(
config.dim_model // 2
)
# Transformer decoder.
# Learnable positional embedding for the transformer's decoder (in the style of DETR object queries).
self.decoder_pos_embed = nn.Embedding(config.chunk_size, config.dim_model)
# Final action regression head on the output of the transformer's decoder.
self.action_head = nn.Linear(config.dim_model, self.config.action_feature.shape[0])
self.action_head = nn.Linear(
config.dim_model, self.config.action_feature.shape[0]
)
self._reset_parameters()
@@ -461,14 +475,20 @@ class ACT(nn.Module):
self.vae_encoder_cls_embed.weight, "1 d -> b 1 d", b=batch_size
) # (B, 1, D)
if self.config.robot_state_feature:
robot_state_embed = self.vae_encoder_robot_state_input_proj(batch["observation.state"])
robot_state_embed = self.vae_encoder_robot_state_input_proj(
batch["observation.state"]
)
robot_state_embed = robot_state_embed.unsqueeze(1) # (B, 1, D)
action_embed = self.vae_encoder_action_input_proj(
batch["action"]
) # (B, S, D)
if self.config.robot_state_feature:
vae_encoder_input = [cls_embed, robot_state_embed, action_embed] # (B, S+2, D)
vae_encoder_input = [
cls_embed,
robot_state_embed,
action_embed,
] # (B, S+2, D)
else:
vae_encoder_input = [cls_embed, action_embed]
vae_encoder_input = torch.cat(vae_encoder_input, axis=1)
@@ -517,7 +537,9 @@ class ACT(nn.Module):
)
# Robot state token.
if self.config.robot_state_feature:
encoder_in_tokens.append(self.encoder_robot_state_input_proj(batch["observation.state"]))
encoder_in_tokens.append(
self.encoder_robot_state_input_proj(batch["observation.state"])
)
# Environment state token.
if self.config.env_state_feature:
encoder_in_tokens.append(
@@ -534,7 +556,9 @@ class ACT(nn.Module):
# For a list of images, the H and W may vary but H*W is constant.
for img in batch["observation.images"]:
cam_features = self.backbone(img)["feature_map"]
cam_pos_embed = self.encoder_cam_feat_pos_embed(cam_features).to(dtype=cam_features.dtype)
cam_pos_embed = self.encoder_cam_feat_pos_embed(cam_features).to(
dtype=cam_features.dtype
)
cam_features = self.encoder_img_feat_input_proj(cam_features)
# Rearrange features to (sequence, batch, dim).