mirror of
https://github.com/huggingface/lerobot.git
synced 2026-05-22 20:19:43 +00:00
sample on cpu
This commit is contained in:
@@ -99,10 +99,11 @@ def create_sinusoidal_pos_embedding( # see openpi `create_sinusoidal_pos_embedd
|
|||||||
|
|
||||||
|
|
||||||
def sample_beta(alpha, beta, bsize, device): # see openpi `sample_beta` (exact copy)
|
def sample_beta(alpha, beta, bsize, device): # see openpi `sample_beta` (exact copy)
|
||||||
alpha_t = torch.as_tensor(alpha, dtype=torch.float32, device=device)
|
# Beta sampling uses _sample_dirichlet which isn't implemented for MPS, so sample on CPU
|
||||||
beta_t = torch.as_tensor(beta, dtype=torch.float32, device=device)
|
alpha_t = torch.tensor(alpha, dtype=torch.float32)
|
||||||
|
beta_t = torch.tensor(beta, dtype=torch.float32)
|
||||||
dist = torch.distributions.Beta(alpha_t, beta_t)
|
dist = torch.distributions.Beta(alpha_t, beta_t)
|
||||||
return dist.sample((bsize,))
|
return dist.sample((bsize,)).to(device)
|
||||||
|
|
||||||
|
|
||||||
def make_att_2d_masks(pad_masks, att_masks): # see openpi `make_att_2d_masks` (exact copy)
|
def make_att_2d_masks(pad_masks, att_masks): # see openpi `make_att_2d_masks` (exact copy)
|
||||||
|
|||||||
Reference in New Issue
Block a user