This commit is contained in:
Pepijn
2026-02-21 17:28:26 +01:00
parent 2697f65cf6
commit acae8417aa
+2 -2
View File
@@ -274,8 +274,8 @@ def train(cfg: TrainPipelineConfig, accelerator: Accelerator | None = None):
continue
chunk_data = hf[idx:end_idx]
actions = torch.tensor(np.stack([np.array(a, copy=False) for a in chunk_data["action"]])).float()
state = torch.tensor(np.array(chunk_data["observation.state"][0], copy=False)).float()
actions = torch.tensor(np.stack([np.asarray(a) for a in chunk_data["action"]])).float()
state = torch.tensor(np.asarray(chunk_data["observation.state"][0])).float()
mask = [True] * actions.shape[-1]
delta = to_delta_actions(actions.unsqueeze(0), state.unsqueeze(0), mask).squeeze(0)