mirror of
https://github.com/huggingface/lerobot.git
synced 2026-05-23 20:50:02 +00:00
[pre-commit.ci] auto fixes from pre-commit.com hooks
for more information, see https://pre-commit.ci
This commit is contained in:
committed by
Michel Aractingi
parent
cdcf346061
commit
1c8daf11fd
@@ -205,11 +205,16 @@ class DiffusionConfig(PreTrainedConfig):
|
||||
|
||||
def validate_features(self) -> None:
|
||||
if len(self.image_features) == 0 and self.env_state_feature is None:
|
||||
raise ValueError("You must provide at least one image or the environment state among the inputs.")
|
||||
raise ValueError(
|
||||
"You must provide at least one image or the environment state among the inputs."
|
||||
)
|
||||
|
||||
if self.crop_shape is not None:
|
||||
for key, image_ft in self.image_features.items():
|
||||
if self.crop_shape[0] > image_ft.shape[1] or self.crop_shape[1] > image_ft.shape[2]:
|
||||
if (
|
||||
self.crop_shape[0] > image_ft.shape[1]
|
||||
or self.crop_shape[1] > image_ft.shape[2]
|
||||
):
|
||||
raise ValueError(
|
||||
f"`crop_shape` should fit within the images shapes. Got {self.crop_shape} "
|
||||
f"for `crop_shape` and {image_ft.shape} for "
|
||||
|
||||
@@ -70,7 +70,9 @@ class DiffusionPolicy(PreTrainedPolicy):
|
||||
config.validate_features()
|
||||
self.config = config
|
||||
|
||||
self.normalize_inputs = Normalize(config.input_features, config.normalization_mapping, dataset_stats)
|
||||
self.normalize_inputs = Normalize(
|
||||
config.input_features, config.normalization_mapping, dataset_stats
|
||||
)
|
||||
self.normalize_targets = Normalize(
|
||||
config.output_features, config.normalization_mapping, dataset_stats
|
||||
)
|
||||
@@ -97,7 +99,9 @@ class DiffusionPolicy(PreTrainedPolicy):
|
||||
if self.config.image_features:
|
||||
self._queues["observation.images"] = deque(maxlen=self.config.n_obs_steps)
|
||||
if self.config.env_state_feature:
|
||||
self._queues["observation.environment_state"] = deque(maxlen=self.config.n_obs_steps)
|
||||
self._queues["observation.environment_state"] = deque(
|
||||
maxlen=self.config.n_obs_steps
|
||||
)
|
||||
|
||||
@torch.no_grad
|
||||
def select_action(self, batch: dict[str, Tensor]) -> Tensor:
|
||||
@@ -123,7 +127,9 @@ class DiffusionPolicy(PreTrainedPolicy):
|
||||
"""
|
||||
batch = self.normalize_inputs(batch)
|
||||
if self.config.image_features:
|
||||
batch = dict(batch) # shallow copy so that adding a key doesn't modify the original
|
||||
batch = dict(
|
||||
batch
|
||||
) # shallow copy so that adding a key doesn't modify the original
|
||||
batch["observation.images"] = torch.stack(
|
||||
[batch[key] for key in self.config.image_features], dim=-4
|
||||
)
|
||||
@@ -151,7 +157,9 @@ class DiffusionPolicy(PreTrainedPolicy):
|
||||
"""Run the batch through the model and compute the loss for training or validation."""
|
||||
batch = self.normalize_inputs(batch)
|
||||
if self.config.image_features:
|
||||
batch = dict(batch) # shallow copy so that adding a key doesn't modify the original
|
||||
batch = dict(
|
||||
batch
|
||||
) # shallow copy so that adding a key doesn't modify the original
|
||||
batch["observation.images"] = torch.stack(
|
||||
[batch[key] for key in self.config.image_features], dim=-4
|
||||
)
|
||||
@@ -515,11 +523,15 @@ class DiffusionRgbEncoder(nn.Module):
|
||||
|
||||
# Note: we have a check in the config class to make sure all images have the same shape.
|
||||
images_shape = next(iter(config.image_features.values())).shape
|
||||
dummy_shape_h_w = config.crop_shape if config.crop_shape is not None else images_shape[1:]
|
||||
dummy_shape_h_w = (
|
||||
config.crop_shape if config.crop_shape is not None else images_shape[1:]
|
||||
)
|
||||
dummy_shape = (1, images_shape[0], *dummy_shape_h_w)
|
||||
feature_map_shape = get_output_shape(self.backbone, dummy_shape)[1:]
|
||||
|
||||
self.pool = SpatialSoftmax(feature_map_shape, num_kp=config.spatial_softmax_num_keypoints)
|
||||
self.pool = SpatialSoftmax(
|
||||
feature_map_shape, num_kp=config.spatial_softmax_num_keypoints
|
||||
)
|
||||
self.feature_dim = config.spatial_softmax_num_keypoints * 2
|
||||
self.out = nn.Linear(config.spatial_softmax_num_keypoints * 2, self.feature_dim)
|
||||
self.relu = nn.ReLU()
|
||||
@@ -719,7 +731,9 @@ class DiffusionConditionalUnet1d(nn.Module):
|
||||
)
|
||||
|
||||
self.final_conv = nn.Sequential(
|
||||
DiffusionConv1dBlock(config.down_dims[0], config.down_dims[0], kernel_size=config.kernel_size),
|
||||
DiffusionConv1dBlock(
|
||||
config.down_dims[0], config.down_dims[0], kernel_size=config.kernel_size
|
||||
),
|
||||
nn.Conv1d(config.down_dims[0], config.action_feature.shape[0], 1),
|
||||
)
|
||||
|
||||
|
||||
Reference in New Issue
Block a user