diff --git a/src/lerobot/policies/act/modeling_act.py b/src/lerobot/policies/act/modeling_act.py index 5651fbfb1..1432b68a5 100644 --- a/src/lerobot/policies/act/modeling_act.py +++ b/src/lerobot/policies/act/modeling_act.py @@ -148,7 +148,7 @@ class ACTPolicy(PreTrainedPolicy): l1_loss = (abs_err * valid_mask).sum() / num_valid.clamp_min(1) loss_dict = {"l1_loss": l1_loss.item()} - if self.config.use_vae: + if self.config.use_vae and log_sigma_x2_hat is not None: # Calculate Dā‚–ā‚—(latent_pdf || standard_normal). Note: After computing the KL-divergence for # each dimension independently, we sum over the latent dimension to get the total # KL-divergence per batch element, then take the mean over the batch. diff --git a/src/lerobot/policies/diffusion/modeling_diffusion.py b/src/lerobot/policies/diffusion/modeling_diffusion.py index 9fbe1f703..8758a7e29 100644 --- a/src/lerobot/policies/diffusion/modeling_diffusion.py +++ b/src/lerobot/policies/diffusion/modeling_diffusion.py @@ -101,11 +101,23 @@ class DiffusionPolicy(PreTrainedPolicy): @torch.no_grad() def predict_action_chunk(self, batch: dict[str, Tensor], noise: Tensor | None = None) -> Tensor: - """Predict a chunk of actions given environment observations.""" - # stack n latest observations from the queue - batch = {k: torch.stack(list(self._queues[k]), dim=1) for k in batch if k in self._queues} - actions = self.diffusion.generate_actions(batch, noise=noise) + """Predict a chunk of actions given environment observations. + Supports two modes: + - Online (queues populated via select_action): stacks observations from internal queues. + - Offline (empty queues, e.g. dataloader batch): uses the batch directly. + """ + queues_populated = any(len(q) > 0 for q in self._queues.values()) + if queues_populated: + batch = {k: torch.stack(list(self._queues[k]), dim=1) for k in batch if k in self._queues} + else: + batch = dict(batch) + if self.config.image_features: + for key in self.config.image_features: + if batch[key].ndim == 4: + batch[key] = batch[key].unsqueeze(1) + batch[OBS_IMAGES] = torch.stack([batch[key] for key in self.config.image_features], dim=-4) + actions = self.diffusion.generate_actions(batch, noise=noise) return actions @torch.no_grad()