diff --git a/src/lerobot/policies/pi05/modeling_pi05.py b/src/lerobot/policies/pi05/modeling_pi05.py index 4a74250a0..dc5eb20ec 100644 --- a/src/lerobot/policies/pi05/modeling_pi05.py +++ b/src/lerobot/policies/pi05/modeling_pi05.py @@ -99,10 +99,11 @@ def create_sinusoidal_pos_embedding( # see openpi `create_sinusoidal_pos_embedd def sample_beta(alpha, beta, bsize, device): # see openpi `sample_beta` (exact copy) - alpha_t = torch.as_tensor(alpha, dtype=torch.float32, device=device) - beta_t = torch.as_tensor(beta, dtype=torch.float32, device=device) + # Beta sampling uses _sample_dirichlet which isn't implemented for MPS, so sample on CPU + alpha_t = torch.tensor(alpha, dtype=torch.float32) + beta_t = torch.tensor(beta, dtype=torch.float32) dist = torch.distributions.Beta(alpha_t, beta_t) - return dist.sample((bsize,)) + return dist.sample((bsize,)).to(device) def make_att_2d_masks(pad_masks, att_masks): # see openpi `make_att_2d_masks` (exact copy)