mirror of
https://github.com/huggingface/lerobot.git
synced 2026-06-18 00:37:10 +00:00
Compare commits
1 Commits
| Author | SHA1 | Date | |
|---|---|---|---|
| 84abfe5c60 |
@@ -148,7 +148,7 @@ class ACTPolicy(PreTrainedPolicy):
|
|||||||
l1_loss = (abs_err * valid_mask).sum() / num_valid.clamp_min(1)
|
l1_loss = (abs_err * valid_mask).sum() / num_valid.clamp_min(1)
|
||||||
|
|
||||||
loss_dict = {"l1_loss": l1_loss.item()}
|
loss_dict = {"l1_loss": l1_loss.item()}
|
||||||
if self.config.use_vae:
|
if self.config.use_vae and log_sigma_x2_hat is not None:
|
||||||
# Calculate Dₖₗ(latent_pdf || standard_normal). Note: After computing the KL-divergence for
|
# Calculate Dₖₗ(latent_pdf || standard_normal). Note: After computing the KL-divergence for
|
||||||
# each dimension independently, we sum over the latent dimension to get the total
|
# each dimension independently, we sum over the latent dimension to get the total
|
||||||
# KL-divergence per batch element, then take the mean over the batch.
|
# KL-divergence per batch element, then take the mean over the batch.
|
||||||
|
|||||||
@@ -101,11 +101,23 @@ class DiffusionPolicy(PreTrainedPolicy):
|
|||||||
|
|
||||||
@torch.no_grad()
|
@torch.no_grad()
|
||||||
def predict_action_chunk(self, batch: dict[str, Tensor], noise: Tensor | None = None) -> Tensor:
|
def predict_action_chunk(self, batch: dict[str, Tensor], noise: Tensor | None = None) -> Tensor:
|
||||||
"""Predict a chunk of actions given environment observations."""
|
"""Predict a chunk of actions given environment observations.
|
||||||
# stack n latest observations from the queue
|
|
||||||
batch = {k: torch.stack(list(self._queues[k]), dim=1) for k in batch if k in self._queues}
|
|
||||||
actions = self.diffusion.generate_actions(batch, noise=noise)
|
|
||||||
|
|
||||||
|
Supports two modes:
|
||||||
|
- Online (queues populated via select_action): stacks observations from internal queues.
|
||||||
|
- Offline (empty queues, e.g. dataloader batch): uses the batch directly.
|
||||||
|
"""
|
||||||
|
queues_populated = any(len(q) > 0 for q in self._queues.values())
|
||||||
|
if queues_populated:
|
||||||
|
batch = {k: torch.stack(list(self._queues[k]), dim=1) for k in batch if k in self._queues}
|
||||||
|
else:
|
||||||
|
batch = dict(batch)
|
||||||
|
if self.config.image_features:
|
||||||
|
for key in self.config.image_features:
|
||||||
|
if batch[key].ndim == 4:
|
||||||
|
batch[key] = batch[key].unsqueeze(1)
|
||||||
|
batch[OBS_IMAGES] = torch.stack([batch[key] for key in self.config.image_features], dim=-4)
|
||||||
|
actions = self.diffusion.generate_actions(batch, noise=noise)
|
||||||
return actions
|
return actions
|
||||||
|
|
||||||
@torch.no_grad()
|
@torch.no_grad()
|
||||||
|
|||||||
Reference in New Issue
Block a user