remove redundant asserts

This commit is contained in:
Bryson Jones
2025-12-10 13:52:11 -08:00
parent 5524a0d7a7
commit 10cfc17705
@@ -804,11 +804,6 @@ class MultiTaskDiTPolicy(PreTrainedPolicy):
batch = dict(batch)
batch[OBS_IMAGES] = torch.stack([batch[key] for key in self.config.image_features], dim=-4)
n_obs_steps = batch["observation.state"].shape[1]
horizon = batch["action"].shape[1]
assert horizon == self.config.horizon
assert n_obs_steps == self.config.n_obs_steps
conditioning_vec = self.observation_encoder.encode(batch)
loss = self.objective.compute_loss(self.noise_predictor, batch, conditioning_vec)