mirror of
https://github.com/huggingface/lerobot.git
synced 2026-05-11 14:49:43 +00:00
fix
This commit is contained in:
@@ -274,8 +274,8 @@ def train(cfg: TrainPipelineConfig, accelerator: Accelerator | None = None):
|
||||
continue
|
||||
|
||||
chunk_data = hf[idx:end_idx]
|
||||
actions = torch.tensor(np.stack([np.array(a, copy=False) for a in chunk_data["action"]])).float()
|
||||
state = torch.tensor(np.array(chunk_data["observation.state"][0], copy=False)).float()
|
||||
actions = torch.tensor(np.stack([np.asarray(a) for a in chunk_data["action"]])).float()
|
||||
state = torch.tensor(np.asarray(chunk_data["observation.state"][0])).float()
|
||||
|
||||
mask = [True] * actions.shape[-1]
|
||||
delta = to_delta_actions(actions.unsqueeze(0), state.unsqueeze(0), mask).squeeze(0)
|
||||
|
||||
Reference in New Issue
Block a user