mirror of
https://github.com/huggingface/lerobot.git
synced 2026-05-18 02:00:03 +00:00
fix(pi): keep training sampling outside compiled forwards (#3487)
Move PI0 and PI0.5 noise/time sampling out of the compiled PyTorch cores and into the policy wrappers, so that the cores receive the sampled noise and time as tensor inputs. This keeps Beta sampling out of torch.compile on MPS, avoiding aten::_sample_dirichlet compilation errors while preserving the CUDA training path.
Validation:
(1) .venv/bin/python -m pre_commit run --files src/lerobot/policies/pi0/modeling_pi0.py src/lerobot/policies/pi05/modeling_pi05.py
(2) .venv/bin/python -m pytest -sv -rs tests/policies/pi0_pi05/test_pi0.py tests/policies/pi0_pi05/test_pi05.py tests/policies/pi0_pi05/test_pi0_rtc.py tests/policies/pi0_pi05/test_pi05_rtc.py
Co-authored-by: Pepijn <138571049+pkooij@users.noreply.github.com>
This commit is contained in:
@@ -748,16 +748,8 @@ class PI0Pytorch(nn.Module): # see openpi `PI0Pytorch`
|
|||||||
|
|
||||||
return embs, pad_masks, att_masks, adarms_cond
|
return embs, pad_masks, att_masks, adarms_cond
|
||||||
|
|
||||||
def forward(
|
def forward(self, images, img_masks, lang_tokens, lang_masks, state, actions, noise, time) -> Tensor:
|
||||||
self, images, img_masks, lang_tokens, lang_masks, state, actions, noise=None, time=None
|
|
||||||
) -> Tensor:
|
|
||||||
"""Do a full training forward pass and compute the loss."""
|
"""Do a full training forward pass and compute the loss."""
|
||||||
if noise is None:
|
|
||||||
noise = self.sample_noise(actions.shape, actions.device)
|
|
||||||
|
|
||||||
if time is None:
|
|
||||||
time = self.sample_time(actions.shape[0], actions.device)
|
|
||||||
|
|
||||||
time_expanded = time[:, None, None]
|
time_expanded = time[:, None, None]
|
||||||
x_t = time_expanded * noise + (1 - time_expanded) * actions
|
x_t = time_expanded * noise + (1 - time_expanded) * actions
|
||||||
u_t = noise - actions
|
u_t = noise - actions
|
||||||
@@ -1292,8 +1284,11 @@ class PI0Policy(PreTrainedPolicy):
|
|||||||
state = self.prepare_state(batch)
|
state = self.prepare_state(batch)
|
||||||
actions = self.prepare_action(batch)
|
actions = self.prepare_action(batch)
|
||||||
|
|
||||||
|
noise = self.model.sample_noise(actions.shape, actions.device)
|
||||||
|
time = self.model.sample_time(actions.shape[0], actions.device)
|
||||||
|
|
||||||
# Compute loss
|
# Compute loss
|
||||||
losses = self.model.forward(images, img_masks, lang_tokens, lang_masks, state, actions)
|
losses = self.model.forward(images, img_masks, lang_tokens, lang_masks, state, actions, noise, time)
|
||||||
|
|
||||||
# Truncate losses to actual action dimensions
|
# Truncate losses to actual action dimensions
|
||||||
original_action_dim = self.config.output_features[ACTION].shape[0]
|
original_action_dim = self.config.output_features[ACTION].shape[0]
|
||||||
|
|||||||
@@ -728,14 +728,8 @@ class PI05Pytorch(nn.Module): # see openpi `PI0Pytorch`
|
|||||||
|
|
||||||
return embs, pad_masks, att_masks, adarms_cond
|
return embs, pad_masks, att_masks, adarms_cond
|
||||||
|
|
||||||
def forward(self, images, img_masks, tokens, masks, actions, noise=None, time=None) -> Tensor:
|
def forward(self, images, img_masks, tokens, masks, actions, noise, time) -> Tensor:
|
||||||
"""Do a full training forward pass and compute the loss."""
|
"""Do a full training forward pass and compute the loss."""
|
||||||
if noise is None:
|
|
||||||
noise = self.sample_noise(actions.shape, actions.device)
|
|
||||||
|
|
||||||
if time is None:
|
|
||||||
time = self.sample_time(actions.shape[0], actions.device)
|
|
||||||
|
|
||||||
time_expanded = time[:, None, None]
|
time_expanded = time[:, None, None]
|
||||||
x_t = time_expanded * noise + (1 - time_expanded) * actions
|
x_t = time_expanded * noise + (1 - time_expanded) * actions
|
||||||
u_t = noise - actions
|
u_t = noise - actions
|
||||||
@@ -1262,8 +1256,11 @@ class PI05Policy(PreTrainedPolicy):
|
|||||||
|
|
||||||
actions = self.prepare_action(batch)
|
actions = self.prepare_action(batch)
|
||||||
|
|
||||||
|
noise = self.model.sample_noise(actions.shape, actions.device)
|
||||||
|
time = self.model.sample_time(actions.shape[0], actions.device)
|
||||||
|
|
||||||
# Compute loss (no separate state needed for PI05)
|
# Compute loss (no separate state needed for PI05)
|
||||||
losses = self.model.forward(images, img_masks, tokens, masks, actions)
|
losses = self.model.forward(images, img_masks, tokens, masks, actions, noise, time)
|
||||||
|
|
||||||
# Truncate losses to actual action dimensions
|
# Truncate losses to actual action dimensions
|
||||||
original_action_dim = self.config.output_features[ACTION].shape[0]
|
original_action_dim = self.config.output_features[ACTION].shape[0]
|
||||||
|
|||||||
Reference in New Issue
Block a user