diff --git a/src/lerobot/policies/multi_task_dit/modeling_multi_task_dit.py b/src/lerobot/policies/multi_task_dit/modeling_multi_task_dit.py index 3620ad8fd..82c8cd750 100644 --- a/src/lerobot/policies/multi_task_dit/modeling_multi_task_dit.py +++ b/src/lerobot/policies/multi_task_dit/modeling_multi_task_dit.py @@ -804,11 +804,6 @@ class MultiTaskDiTPolicy(PreTrainedPolicy): batch = dict(batch) batch[OBS_IMAGES] = torch.stack([batch[key] for key in self.config.image_features], dim=-4) - n_obs_steps = batch["observation.state"].shape[1] - horizon = batch["action"].shape[1] - assert horizon == self.config.horizon - assert n_obs_steps == self.config.n_obs_steps - conditioning_vec = self.observation_encoder.encode(batch) loss = self.objective.compute_loss(self.noise_predictor, batch, conditioning_vec)