diff --git a/tests/policies/multi_task_dit/test_multi_task_dit.py b/tests/policies/multi_task_dit/test_multi_task_dit.py index 5b70422d4..e4d456d19 100644 --- a/tests/policies/multi_task_dit/test_multi_task_dit.py +++ b/tests/policies/multi_task_dit/test_multi_task_dit.py @@ -147,6 +147,7 @@ def test_multi_task_dit_policy_forward(batch_size: int, state_dim: int, action_d ) policy = MultiTaskDiTPolicy(config=config) + policy.to(config.device) policy.train() # Use preprocessor to handle tokenization @@ -336,6 +337,7 @@ def test_multi_task_dit_policy_select_action(batch_size: int, state_dim: int, ac ) policy = MultiTaskDiTPolicy(config=config) + policy.to(config.device) policy.eval() policy.reset() # Reset queues before inference @@ -390,6 +392,7 @@ def test_multi_task_dit_policy_diffusion_objective(): config.validate_features() policy = MultiTaskDiTPolicy(config=config) + policy.to(config.device) policy.train() # Use preprocessor to handle tokenization @@ -468,6 +471,7 @@ def test_multi_task_dit_policy_flow_matching_objective(): config.validate_features() policy = MultiTaskDiTPolicy(config=config) + policy.to(config.device) policy.train() # Use preprocessor to handle tokenization @@ -533,16 +537,12 @@ def test_multi_task_dit_policy_save_and_load(tmp_path): ) policy = MultiTaskDiTPolicy(config=config) + policy.to(config.device) policy.eval() - # Get device before saving - device = next(policy.parameters()).device - policy.save_pretrained(root) loaded_policy = MultiTaskDiTPolicy.from_pretrained(root, config=config) - - # Explicitly move loaded_policy to the same device - loaded_policy.to(device) + loaded_policy.to(config.device) loaded_policy.eval() batch = create_train_batch( @@ -565,10 +565,6 @@ def test_multi_task_dit_policy_save_and_load(tmp_path): with seeded_context(12): # Process batch through preprocessor processed_batch = preprocessor(batch) - # Move batch to the same device as the policy - for key in processed_batch: - if isinstance(processed_batch[key], torch.Tensor): - processed_batch[key] = processed_batch[key].to(device) # Collect policy values before saving loss, _ = policy.forward(processed_batch) @@ -608,6 +604,7 @@ def test_multi_task_dit_policy_get_optim_params(): ) policy = MultiTaskDiTPolicy(config=config) + policy.to(config.device) param_groups = policy.get_optim_params() # Should have 2 parameter groups: non-vision and vision encoder