diff --git a/src/lerobot/policies/pi0/modeling_pi0.py b/src/lerobot/policies/pi0/modeling_pi0.py index 3534c7ae8..b3a877b7b 100644 --- a/src/lerobot/policies/pi0/modeling_pi0.py +++ b/src/lerobot/policies/pi0/modeling_pi0.py @@ -748,16 +748,8 @@ class PI0Pytorch(nn.Module): # see openpi `PI0Pytorch` return embs, pad_masks, att_masks, adarms_cond - def forward( - self, images, img_masks, lang_tokens, lang_masks, state, actions, noise=None, time=None - ) -> Tensor: + def forward(self, images, img_masks, lang_tokens, lang_masks, state, actions, noise, time) -> Tensor: """Do a full training forward pass and compute the loss.""" - if noise is None: - noise = self.sample_noise(actions.shape, actions.device) - - if time is None: - time = self.sample_time(actions.shape[0], actions.device) - time_expanded = time[:, None, None] x_t = time_expanded * noise + (1 - time_expanded) * actions u_t = noise - actions @@ -1292,8 +1284,11 @@ class PI0Policy(PreTrainedPolicy): state = self.prepare_state(batch) actions = self.prepare_action(batch) + noise = self.model.sample_noise(actions.shape, actions.device) + time = self.model.sample_time(actions.shape[0], actions.device) + # Compute loss - losses = self.model.forward(images, img_masks, lang_tokens, lang_masks, state, actions) + losses = self.model.forward(images, img_masks, lang_tokens, lang_masks, state, actions, noise, time) # Truncate losses to actual action dimensions original_action_dim = self.config.output_features[ACTION].shape[0] diff --git a/src/lerobot/policies/pi05/modeling_pi05.py b/src/lerobot/policies/pi05/modeling_pi05.py index 56786fbcd..bb206d608 100644 --- a/src/lerobot/policies/pi05/modeling_pi05.py +++ b/src/lerobot/policies/pi05/modeling_pi05.py @@ -728,14 +728,8 @@ class PI05Pytorch(nn.Module): # see openpi `PI0Pytorch` return embs, pad_masks, att_masks, adarms_cond - def forward(self, images, img_masks, tokens, masks, actions, noise=None, time=None) -> Tensor: + def forward(self, images, img_masks, tokens, masks, actions, noise, time) -> Tensor: """Do a full training forward pass and compute the loss.""" - if noise is None: - noise = self.sample_noise(actions.shape, actions.device) - - if time is None: - time = self.sample_time(actions.shape[0], actions.device) - time_expanded = time[:, None, None] x_t = time_expanded * noise + (1 - time_expanded) * actions u_t = noise - actions @@ -1262,8 +1256,11 @@ class PI05Policy(PreTrainedPolicy): actions = self.prepare_action(batch) + noise = self.model.sample_noise(actions.shape, actions.device) + time = self.model.sample_time(actions.shape[0], actions.device) + # Compute loss (no separate state needed for PI05) - losses = self.model.forward(images, img_masks, tokens, masks, actions) + losses = self.model.forward(images, img_masks, tokens, masks, actions, noise, time) # Truncate losses to actual action dimensions original_action_dim = self.config.output_features[ACTION].shape[0]