mirror of
https://github.com/huggingface/lerobot.git
synced 2026-05-26 22:20:06 +00:00
test_datasets.py are passing!
This commit is contained in:
@@ -429,7 +429,7 @@ class TDMPCPolicy(nn.Module):
|
||||
batch[key] = batch[key].transpose(1, 0)
|
||||
|
||||
action = batch["action"]
|
||||
reward = batch["next.reward"][:, :, None] # add extra channel dimension
|
||||
reward = batch["next.reward"]
|
||||
# idxs = batch["index"] # TODO(rcadene): use idxs to update sampling weights
|
||||
done = torch.zeros_like(reward, dtype=torch.bool, device=reward.device)
|
||||
mask = torch.ones_like(reward, dtype=torch.bool, device=reward.device)
|
||||
|
||||
Reference in New Issue
Block a user