mirror of
https://github.com/huggingface/lerobot.git
synced 2026-05-25 13:40:00 +00:00
fix(policies): support dp train when n_obs_steps=1 (#2430)
Co-authored-by: hukongtao <hukongtao@agibot.com> Co-authored-by: Steven Palma <imstevenpmwork@ieee.org>
This commit is contained in:
@@ -142,6 +142,9 @@ class DiffusionPolicy(PreTrainedPolicy):
|
|||||||
"""Run the batch through the model and compute the loss for training or validation."""
|
"""Run the batch through the model and compute the loss for training or validation."""
|
||||||
if self.config.image_features:
|
if self.config.image_features:
|
||||||
batch = dict(batch) # shallow copy so that adding a key doesn't modify the original
|
batch = dict(batch) # shallow copy so that adding a key doesn't modify the original
|
||||||
|
for key in self.config.image_features:
|
||||||
|
if self.config.n_obs_steps == 1 and batch[key].ndim == 4:
|
||||||
|
batch[key] = batch[key].unsqueeze(1)
|
||||||
batch[OBS_IMAGES] = torch.stack([batch[key] for key in self.config.image_features], dim=-4)
|
batch[OBS_IMAGES] = torch.stack([batch[key] for key in self.config.image_features], dim=-4)
|
||||||
loss = self.diffusion.compute_loss(batch)
|
loss = self.diffusion.compute_loss(batch)
|
||||||
# no output_dict so returning None
|
# no output_dict so returning None
|
||||||
|
|||||||
Reference in New Issue
Block a user