mirror of
https://github.com/huggingface/lerobot.git
synced 2026-05-24 13:09:43 +00:00
fix(test): add missing device placement in multi-task DiT tests (#3349)
This commit is contained in:
@@ -147,6 +147,7 @@ def test_multi_task_dit_policy_forward(batch_size: int, state_dim: int, action_d
|
|||||||
)
|
)
|
||||||
|
|
||||||
policy = MultiTaskDiTPolicy(config=config)
|
policy = MultiTaskDiTPolicy(config=config)
|
||||||
|
policy.to(config.device)
|
||||||
policy.train()
|
policy.train()
|
||||||
|
|
||||||
# Use preprocessor to handle tokenization
|
# Use preprocessor to handle tokenization
|
||||||
@@ -336,6 +337,7 @@ def test_multi_task_dit_policy_select_action(batch_size: int, state_dim: int, ac
|
|||||||
)
|
)
|
||||||
|
|
||||||
policy = MultiTaskDiTPolicy(config=config)
|
policy = MultiTaskDiTPolicy(config=config)
|
||||||
|
policy.to(config.device)
|
||||||
policy.eval()
|
policy.eval()
|
||||||
policy.reset() # Reset queues before inference
|
policy.reset() # Reset queues before inference
|
||||||
|
|
||||||
@@ -390,6 +392,7 @@ def test_multi_task_dit_policy_diffusion_objective():
|
|||||||
config.validate_features()
|
config.validate_features()
|
||||||
|
|
||||||
policy = MultiTaskDiTPolicy(config=config)
|
policy = MultiTaskDiTPolicy(config=config)
|
||||||
|
policy.to(config.device)
|
||||||
policy.train()
|
policy.train()
|
||||||
|
|
||||||
# Use preprocessor to handle tokenization
|
# Use preprocessor to handle tokenization
|
||||||
@@ -468,6 +471,7 @@ def test_multi_task_dit_policy_flow_matching_objective():
|
|||||||
config.validate_features()
|
config.validate_features()
|
||||||
|
|
||||||
policy = MultiTaskDiTPolicy(config=config)
|
policy = MultiTaskDiTPolicy(config=config)
|
||||||
|
policy.to(config.device)
|
||||||
policy.train()
|
policy.train()
|
||||||
|
|
||||||
# Use preprocessor to handle tokenization
|
# Use preprocessor to handle tokenization
|
||||||
@@ -533,16 +537,12 @@ def test_multi_task_dit_policy_save_and_load(tmp_path):
|
|||||||
)
|
)
|
||||||
|
|
||||||
policy = MultiTaskDiTPolicy(config=config)
|
policy = MultiTaskDiTPolicy(config=config)
|
||||||
|
policy.to(config.device)
|
||||||
policy.eval()
|
policy.eval()
|
||||||
|
|
||||||
# Get device before saving
|
|
||||||
device = next(policy.parameters()).device
|
|
||||||
|
|
||||||
policy.save_pretrained(root)
|
policy.save_pretrained(root)
|
||||||
loaded_policy = MultiTaskDiTPolicy.from_pretrained(root, config=config)
|
loaded_policy = MultiTaskDiTPolicy.from_pretrained(root, config=config)
|
||||||
|
loaded_policy.to(config.device)
|
||||||
# Explicitly move loaded_policy to the same device
|
|
||||||
loaded_policy.to(device)
|
|
||||||
loaded_policy.eval()
|
loaded_policy.eval()
|
||||||
|
|
||||||
batch = create_train_batch(
|
batch = create_train_batch(
|
||||||
@@ -565,10 +565,6 @@ def test_multi_task_dit_policy_save_and_load(tmp_path):
|
|||||||
with seeded_context(12):
|
with seeded_context(12):
|
||||||
# Process batch through preprocessor
|
# Process batch through preprocessor
|
||||||
processed_batch = preprocessor(batch)
|
processed_batch = preprocessor(batch)
|
||||||
# Move batch to the same device as the policy
|
|
||||||
for key in processed_batch:
|
|
||||||
if isinstance(processed_batch[key], torch.Tensor):
|
|
||||||
processed_batch[key] = processed_batch[key].to(device)
|
|
||||||
# Collect policy values before saving
|
# Collect policy values before saving
|
||||||
loss, _ = policy.forward(processed_batch)
|
loss, _ = policy.forward(processed_batch)
|
||||||
|
|
||||||
@@ -608,6 +604,7 @@ def test_multi_task_dit_policy_get_optim_params():
|
|||||||
)
|
)
|
||||||
|
|
||||||
policy = MultiTaskDiTPolicy(config=config)
|
policy = MultiTaskDiTPolicy(config=config)
|
||||||
|
policy.to(config.device)
|
||||||
param_groups = policy.get_optim_params()
|
param_groups = policy.get_optim_params()
|
||||||
|
|
||||||
# Should have 2 parameter groups: non-vision and vision encoder
|
# Should have 2 parameter groups: non-vision and vision encoder
|
||||||
|
|||||||
Reference in New Issue
Block a user