Compare commits

..

1 Commits

Author SHA1 Message Date
Khalil Meftah 84abfe5c60 fix(policies): support offline batch inference for ACT and Diffusion
- Guard ACT's KL divergence computation against None latent params to
prevent crashes during eval when use_vae is set but the forward path
returns no VAE outputs.
- Add offline batch fallback to Diffusion's predict_action_chunk() so
it works with dataloader batches (empty queues) in addition to the
existing online rollout path (populated queues). This enables batched
action prediction for offline evaluation.
2026-06-15 11:35:06 +02:00
4 changed files with 27 additions and 23 deletions
+9 -17
View File
@@ -180,32 +180,24 @@ class WandBLogger:
self._wandb_custom_step_key.add(new_custom_key)
self._wandb.define_metric(new_custom_key, hidden=True)
batch_data = {}
for k, v in d.items():
# Skip the custom step key here, it's added to the batch below.
if custom_step_key is not None and k == custom_step_key:
continue
if isinstance(v, list):
for i, elem in enumerate(v):
if isinstance(elem, (int | float)):
batch_data[f"{mode}/{k}_{i}"] = elem
continue
if not isinstance(v, (int | float | str)):
logging.warning(
f'WandB logging of key "{k}" was ignored as its type "{type(v)}" is not handled by this wrapper.'
)
continue
batch_data[f"{mode}/{k}"] = v
# Do not log the custom step key itself.
if self._wandb_custom_step_key is not None and k in self._wandb_custom_step_key:
continue
if batch_data:
if custom_step_key is not None:
batch_data[f"{mode}/{custom_step_key}"] = d[custom_step_key]
self._wandb.log(batch_data)
else:
self._wandb.log(data=batch_data, step=step)
value_custom_step = d[custom_step_key]
data = {f"{mode}/{k}": v, f"{mode}/{custom_step_key}": value_custom_step}
self._wandb.log(data)
continue
self._wandb.log(data={f"{mode}/{k}": v}, step=step)
def log_video(self, video_path: str, step: int, mode: str = "train"):
if mode not in {"train", "eval"}:
+1 -1
View File
@@ -153,7 +153,7 @@ def cast_stats_to_numpy(stats: dict) -> dict[str, dict[str, np.ndarray]]:
Returns:
dict: The statistics dictionary with values cast to numpy arrays.
"""
stats = {key: np.atleast_1d(np.array(value)) for key, value in flatten_dict(stats).items()}
stats = {key: np.array(value) for key, value in flatten_dict(stats).items()}
return unflatten_dict(stats)
+1 -1
View File
@@ -148,7 +148,7 @@ class ACTPolicy(PreTrainedPolicy):
l1_loss = (abs_err * valid_mask).sum() / num_valid.clamp_min(1)
loss_dict = {"l1_loss": l1_loss.item()}
if self.config.use_vae:
if self.config.use_vae and log_sigma_x2_hat is not None:
# Calculate Dₖₗ(latent_pdf || standard_normal). Note: After computing the KL-divergence for
# each dimension independently, we sum over the latent dimension to get the total
# KL-divergence per batch element, then take the mean over the batch.
@@ -101,11 +101,23 @@ class DiffusionPolicy(PreTrainedPolicy):
@torch.no_grad()
def predict_action_chunk(self, batch: dict[str, Tensor], noise: Tensor | None = None) -> Tensor:
"""Predict a chunk of actions given environment observations."""
# stack n latest observations from the queue
batch = {k: torch.stack(list(self._queues[k]), dim=1) for k in batch if k in self._queues}
actions = self.diffusion.generate_actions(batch, noise=noise)
"""Predict a chunk of actions given environment observations.
Supports two modes:
- Online (queues populated via select_action): stacks observations from internal queues.
- Offline (empty queues, e.g. dataloader batch): uses the batch directly.
"""
queues_populated = any(len(q) > 0 for q in self._queues.values())
if queues_populated:
batch = {k: torch.stack(list(self._queues[k]), dim=1) for k in batch if k in self._queues}
else:
batch = dict(batch)
if self.config.image_features:
for key in self.config.image_features:
if batch[key].ndim == 4:
batch[key] = batch[key].unsqueeze(1)
batch[OBS_IMAGES] = torch.stack([batch[key] for key in self.config.image_features], dim=-4)
actions = self.diffusion.generate_actions(batch, noise=noise)
return actions
@torch.no_grad()