Compare commits

..

1 Commits

Author SHA1 Message Date
Khalil Meftah 84abfe5c60 fix(policies): support offline batch inference for ACT and Diffusion
- Guard ACT's KL divergence computation against None latent params to
prevent crashes during eval when use_vae is set but the forward path
returns no VAE outputs.
- Add offline batch fallback to Diffusion's predict_action_chunk() so
it works with dataloader batches (empty queues) in addition to the
existing online rollout path (populated queues). This enables batched
action prediction for offline evaluation.
2026-06-15 11:35:06 +02:00
4 changed files with 19 additions and 33 deletions
-20
View File
@@ -126,26 +126,6 @@ def preprocess_observation(observations: dict[str, np.ndarray]) -> dict[str, Ten
if "camera_obs" in observations:
return_observations[f"{OBS_STR}.camera_obs"] = observations["camera_obs"]
# Pass through any remaining ndarray/tensor keys not already handled above,
# so env plugins can expose extra observation keys via get_env_processors().
_handled = {"pixels", "environment_state", "agent_pos", "robot_state", "policy", "camera_obs"}
for key, value in observations.items():
if key in _handled:
continue
target = f"{OBS_STR}.{key}"
if target in return_observations:
continue
if isinstance(value, np.ndarray):
val = torch.from_numpy(value).float()
if val.dim() == 1:
val = val.unsqueeze(0)
return_observations[target] = val
elif isinstance(value, Tensor):
val = value.float()
if val.dim() == 1:
val = val.unsqueeze(0)
return_observations[target] = val
return return_observations
+1 -1
View File
@@ -148,7 +148,7 @@ class ACTPolicy(PreTrainedPolicy):
l1_loss = (abs_err * valid_mask).sum() / num_valid.clamp_min(1)
loss_dict = {"l1_loss": l1_loss.item()}
if self.config.use_vae:
if self.config.use_vae and log_sigma_x2_hat is not None:
# Calculate Dₖₗ(latent_pdf || standard_normal). Note: After computing the KL-divergence for
# each dimension independently, we sum over the latent dimension to get the total
# KL-divergence per batch element, then take the mean over the batch.
@@ -101,11 +101,23 @@ class DiffusionPolicy(PreTrainedPolicy):
@torch.no_grad()
def predict_action_chunk(self, batch: dict[str, Tensor], noise: Tensor | None = None) -> Tensor:
"""Predict a chunk of actions given environment observations."""
# stack n latest observations from the queue
batch = {k: torch.stack(list(self._queues[k]), dim=1) for k in batch if k in self._queues}
actions = self.diffusion.generate_actions(batch, noise=noise)
"""Predict a chunk of actions given environment observations.
Supports two modes:
- Online (queues populated via select_action): stacks observations from internal queues.
- Offline (empty queues, e.g. dataloader batch): uses the batch directly.
"""
queues_populated = any(len(q) > 0 for q in self._queues.values())
if queues_populated:
batch = {k: torch.stack(list(self._queues[k]), dim=1) for k in batch if k in self._queues}
else:
batch = dict(batch)
if self.config.image_features:
for key in self.config.image_features:
if batch[key].ndim == 4:
batch[key] = batch[key].unsqueeze(1)
batch[OBS_IMAGES] = torch.stack([batch[key] for key in self.config.image_features], dim=-4)
actions = self.diffusion.generate_actions(batch, noise=noise)
return actions
@torch.no_grad()
+2 -8
View File
@@ -216,15 +216,9 @@ def register_third_party_plugins() -> None:
This function uses `importlib.metadata` to find packages installed in the environment
(including editable installs) starting with 'lerobot_robot_', 'lerobot_camera_',
'lerobot_teleoperator_', 'lerobot_policy_', or 'lerobot_env_' and imports them.
'lerobot_teleoperator_', or 'lerobot_policy_' and imports them.
"""
prefixes = (
"lerobot_robot_",
"lerobot_camera_",
"lerobot_teleoperator_",
"lerobot_policy_",
"lerobot_env_",
)
prefixes = ("lerobot_robot_", "lerobot_camera_", "lerobot_teleoperator_", "lerobot_policy_")
imported: list[str] = []
failed: list[str] = []