refactor(train): Remove unnecessary tensor device handling in training loop

This commit is contained in:
AdilZouitine
2025-08-08 19:35:15 +02:00
parent 5bf82f8229
commit fd5d8b3d5f
-4
View File
@@ -209,10 +209,6 @@ def train(cfg: TrainPipelineConfig):
batch = preprocessor(batch)
train_tracker.dataloading_s = time.perf_counter() - start_time
for key in batch:
if isinstance(batch[key], torch.Tensor):
batch[key] = batch[key].to(device, non_blocking=device.type == "cuda")
train_tracker, output_dict = update_policy(
train_tracker,
policy,