feat(DeviceProcessor): Enhance tensor processing with device detection and float dtype conversion

- Improved the _process_tensor method to preserve GPU placement for tensors already on a GPU, facilitating multi-GPU training scenarios.
- Introduced a new _detect_device method in TokenizerProcessor to ensure tokenized tensors match the device of existing tensors in transitions.
- Added comprehensive unit tests to validate the functionality of device detection and float dtype conversion across various scenarios.
This commit is contained in:
AdilZouitine
2025-08-08 19:33:24 +02:00
parent 8bde9d0ab7
commit 5ca3920611
4 changed files with 468 additions and 4 deletions
+137
View File
@@ -820,6 +820,143 @@ def test_complementary_data_none():
assert TransitionKey.COMPLEMENTARY_DATA not in result
@pytest.mark.skipif(not torch.cuda.is_available(), reason="CUDA not available")
def test_preserves_gpu_placement():
"""Test that DeviceProcessor preserves GPU placement when tensor is already on GPU."""
processor = DeviceProcessor(device="cuda:0")
# Create tensors already on GPU
observation = {
"observation.state": torch.randn(10).cuda(), # Already on GPU
"observation.image": torch.randn(3, 224, 224).cuda(), # Already on GPU
}
action = torch.randn(5).cuda() # Already on GPU
transition = create_transition(observation=observation, action=action)
result = processor(transition)
# Check that tensors remain on their original GPU
assert result[TransitionKey.OBSERVATION]["observation.state"].device.type == "cuda"
assert result[TransitionKey.OBSERVATION]["observation.image"].device.type == "cuda"
assert result[TransitionKey.ACTION].device.type == "cuda"
# Verify no unnecessary copies were made (same data pointer)
assert torch.equal(
result[TransitionKey.OBSERVATION]["observation.state"], observation["observation.state"]
)
@pytest.mark.skipif(torch.cuda.device_count() < 2, reason="Requires at least 2 GPUs")
def test_multi_gpu_preservation():
"""Test that DeviceProcessor preserves placement on different GPUs in multi-GPU setup."""
# Test 1: GPU-to-GPU preservation (cuda:0 config, cuda:1 input)
processor_gpu = DeviceProcessor(device="cuda:0")
# Create tensors on cuda:1 (simulating Accelerate placement)
cuda1_device = torch.device("cuda:1")
observation = {
"observation.state": torch.randn(10).to(cuda1_device),
"observation.image": torch.randn(3, 224, 224).to(cuda1_device),
}
action = torch.randn(5).to(cuda1_device)
transition = create_transition(observation=observation, action=action)
result = processor_gpu(transition)
# Check that tensors remain on cuda:1 (not moved to cuda:0)
assert result[TransitionKey.OBSERVATION]["observation.state"].device == cuda1_device
assert result[TransitionKey.OBSERVATION]["observation.image"].device == cuda1_device
assert result[TransitionKey.ACTION].device == cuda1_device
# Test 2: GPU-to-CPU should move to CPU (not preserve GPU)
processor_cpu = DeviceProcessor(device="cpu")
transition_gpu = create_transition(
observation={"observation.state": torch.randn(10).cuda()}, action=torch.randn(5).cuda()
)
result_cpu = processor_cpu(transition_gpu)
# Check that tensors are moved to CPU
assert result_cpu[TransitionKey.OBSERVATION]["observation.state"].device.type == "cpu"
assert result_cpu[TransitionKey.ACTION].device.type == "cpu"
@pytest.mark.skipif(torch.cuda.device_count() < 2, reason="Requires at least 2 GPUs")
def test_multi_gpu_with_cpu_tensors():
"""Test that CPU tensors are moved to configured device even in multi-GPU context."""
# Processor configured for cuda:1
processor = DeviceProcessor(device="cuda:1")
# Mix of CPU and GPU tensors
observation = {
"observation.cpu": torch.randn(10), # CPU tensor
"observation.gpu0": torch.randn(10).cuda(0), # Already on cuda:0
"observation.gpu1": torch.randn(10).cuda(1), # Already on cuda:1
}
action = torch.randn(5) # CPU tensor
transition = create_transition(observation=observation, action=action)
result = processor(transition)
# CPU tensor should move to configured device (cuda:1)
assert result[TransitionKey.OBSERVATION]["observation.cpu"].device.type == "cuda"
assert result[TransitionKey.OBSERVATION]["observation.cpu"].device.index == 1
assert result[TransitionKey.ACTION].device.type == "cuda"
assert result[TransitionKey.ACTION].device.index == 1
# GPU tensors should stay on their original devices
assert result[TransitionKey.OBSERVATION]["observation.gpu0"].device.index == 0
assert result[TransitionKey.OBSERVATION]["observation.gpu1"].device.index == 1
@pytest.mark.skipif(torch.cuda.device_count() < 2, reason="Requires at least 2 GPUs")
def test_multi_gpu_with_float_dtype():
"""Test float dtype conversion works correctly with multi-GPU preservation."""
processor = DeviceProcessor(device="cuda:0", float_dtype="float16")
# Create float tensors on different GPUs
observation = {
"observation.gpu0": torch.randn(5, dtype=torch.float32).cuda(0),
"observation.gpu1": torch.randn(5, dtype=torch.float32).cuda(1),
"observation.cpu": torch.randn(5, dtype=torch.float32), # CPU
}
transition = create_transition(observation=observation)
result = processor(transition)
# Check device placement
assert result[TransitionKey.OBSERVATION]["observation.gpu0"].device.index == 0
assert result[TransitionKey.OBSERVATION]["observation.gpu1"].device.index == 1
assert result[TransitionKey.OBSERVATION]["observation.cpu"].device.index == 0 # Moved to cuda:0
# Check dtype conversion happened for all
assert result[TransitionKey.OBSERVATION]["observation.gpu0"].dtype == torch.float16
assert result[TransitionKey.OBSERVATION]["observation.gpu1"].dtype == torch.float16
assert result[TransitionKey.OBSERVATION]["observation.cpu"].dtype == torch.float16
@pytest.mark.skipif(not torch.cuda.is_available(), reason="CUDA not available")
def test_simulated_accelerate_scenario():
"""Test a scenario simulating how Accelerate would use the processor."""
# Simulate different processes getting different GPU assignments
for gpu_id in range(min(torch.cuda.device_count(), 2)):
# Each "process" has a processor configured for cuda:0
# but data comes in already placed on the process's GPU
processor = DeviceProcessor(device="cuda:0")
# Simulate data already placed by Accelerate
device = torch.device(f"cuda:{gpu_id}")
observation = {"observation.state": torch.randn(1, 10).to(device)}
action = torch.randn(1, 5).to(device)
transition = create_transition(observation=observation, action=action)
result = processor(transition)
# Verify data stays on the GPU where Accelerate placed it
assert result[TransitionKey.OBSERVATION]["observation.state"].device == device
assert result[TransitionKey.ACTION].device == device
@pytest.mark.skipif(not torch.cuda.is_available(), reason="CUDA not available")
def test_policy_processor_integration():
"""Test integration with policy processors - input on GPU, output on CPU."""