diff --git a/src/lerobot/policies/multi_task_dit/modeling_multi_task_dit.py b/src/lerobot/policies/multi_task_dit/modeling_multi_task_dit.py index 39a17aa07..6220a5f10 100644 --- a/src/lerobot/policies/multi_task_dit/modeling_multi_task_dit.py +++ b/src/lerobot/policies/multi_task_dit/modeling_multi_task_dit.py @@ -112,9 +112,10 @@ class MultiTaskDiTPolicy(PreTrainedPolicy): conditioning_vec = self.observation_encoder.encode(batch) actions = self.objective.conditional_sample(self.noise_predictor, batch_size, conditioning_vec) - start_idx = n_obs_steps - 1 - end_idx = start_idx + self.config.n_action_steps - return actions[:, start_idx:end_idx] + start = n_obs_steps - 1 + end = start + self.config.n_action_steps + actions = actions[:, start:end] + return actions def reset(self): """Clear observation and action queues. Should be called on `env.reset()`"""