mirror of
https://github.com/huggingface/lerobot.git
synced 2026-05-15 00:29:52 +00:00
sample on cpu
This commit is contained in:
@@ -99,10 +99,11 @@ def create_sinusoidal_pos_embedding( # see openpi `create_sinusoidal_pos_embedd
|
||||
|
||||
|
||||
def sample_beta(alpha, beta, bsize, device): # see openpi `sample_beta` (exact copy)
|
||||
alpha_t = torch.as_tensor(alpha, dtype=torch.float32, device=device)
|
||||
beta_t = torch.as_tensor(beta, dtype=torch.float32, device=device)
|
||||
# Beta sampling uses _sample_dirichlet which isn't implemented for MPS, so sample on CPU
|
||||
alpha_t = torch.tensor(alpha, dtype=torch.float32)
|
||||
beta_t = torch.tensor(beta, dtype=torch.float32)
|
||||
dist = torch.distributions.Beta(alpha_t, beta_t)
|
||||
return dist.sample((bsize,))
|
||||
return dist.sample((bsize,)).to(device)
|
||||
|
||||
|
||||
def make_att_2d_masks(pad_masks, att_masks): # see openpi `make_att_2d_masks` (exact copy)
|
||||
|
||||
Reference in New Issue
Block a user