refactor(device_processor): Update device handling and improve type hints

- Changed the device attribute type from torch.device to str for clarity.
- Introduced a private _device attribute to store the actual torch.device instance (see the sketch below).
- Updated tests to conditionally check for CUDA availability, ensuring compatibility across different environments.
- Refactored device-related assertions in tests to use a consistent approach for device type verification.
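
A minimal sketch of the device-handling pattern described above, assuming a simplified step class (the name DeviceStepSketch and its fields are illustrative, not the actual lerobot DeviceProcessor API): the public device attribute stays a plain, serializable string, while a private _device holds the resolved torch.device used for tensor transfers.

from dataclasses import dataclass, field

import torch


@dataclass
class DeviceStepSketch:
    # Public, serializable string form; this is what get_config() reports.
    device: str = "cpu"
    # Private resolved handle used for the actual .to(...) calls.
    _device: torch.device = field(init=False, repr=False)

    def __post_init__(self):
        # Assumed fallback to CPU when CUDA is unavailable (the tests below
        # instead choose the device string conditionally before construction).
        if self.device.startswith("cuda") and not torch.cuda.is_available():
            self.device = "cpu"
        self._device = torch.device(self.device)
        # Non-blocking transfers only pay off for CUDA targets.
        self.non_blocking = self._device.type == "cuda"

    def __call__(self, tensor: torch.Tensor) -> torch.Tensor:
        return tensor.to(self._device, non_blocking=self.non_blocking)

    def get_config(self) -> dict:
        return {"device": self.device}

Under these assumptions, DeviceStepSketch(device="cuda:0")(torch.randn(3)) moves the tensor via _device while get_config() round-trips only the string; the actual DeviceProcessor additionally normalizes the string, which is why the tests below expect "cuda:0" to come back as "cuda" after save/load.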
Author: Adil Zouitine
Date: 2025-08-06 18:08:15 +02:00
Commit: 0535f2a59a (parent: 2805ae347c)
2 changed files with 40 additions and 27 deletions
@@ -257,21 +257,23 @@ def test_non_blocking_flag():
     cpu_processor = DeviceProcessor(device="cpu")
     assert cpu_processor.non_blocking is False
 
-    # CUDA processor should have non_blocking=True
-    cuda_processor = DeviceProcessor(device="cuda")
-    assert cuda_processor.non_blocking is True
+    if torch.cuda.is_available():
+        # CUDA processor should have non_blocking=True
+        cuda_processor = DeviceProcessor(device="cuda")
+        assert cuda_processor.non_blocking is True
 
-    cuda_0_processor = DeviceProcessor(device="cuda:0")
-    assert cuda_0_processor.non_blocking is True
+        cuda_0_processor = DeviceProcessor(device="cuda:0")
+        assert cuda_0_processor.non_blocking is True
 
 
 def test_serialization_methods():
     """Test get_config, state_dict, and load_state_dict methods."""
-    processor = DeviceProcessor(device="cuda")
+    device = "cuda" if torch.cuda.is_available() else "cpu"
+    processor = DeviceProcessor(device=device)
 
     # Test get_config
     config = processor.get_config()
-    assert config == {"device": "cuda", "float_dtype": None}
+    assert config == {"device": device, "float_dtype": None}
 
     # Test state_dict (should be empty)
     state = processor.state_dict()
@@ -279,11 +281,11 @@ def test_serialization_methods():
 
     # Test load_state_dict (should be no-op)
     processor.load_state_dict({})
-    assert processor.device == "cuda"
+    assert processor.device == device
 
     # Test reset (should be no-op)
     processor.reset()
-    assert processor.device == "cuda"
+    assert processor.device == device
 
 
 def test_feature_contract():
@@ -302,6 +304,7 @@ def test_feature_contract():
 
 def test_integration_with_robot_processor():
     """Test integration with RobotProcessor."""
+    from lerobot.constants import ACTION, OBS_STATE
     from lerobot.processor import ToBatchProcessor
 
     # Create a pipeline with DeviceProcessor
@@ -311,22 +314,24 @@ def test_integration_with_robot_processor():
     processor = RobotProcessor(steps=[batch_processor, device_processor], name="test_pipeline")
 
     # Create test data
-    observation = {"observation.state": torch.randn(10)}
+    observation = {OBS_STATE: torch.randn(10)}
     action = torch.randn(5)
     transition = create_transition(observation=observation, action=action)
 
     result = processor(transition)
 
     # Check that tensors are batched and on correct device
-    assert result[TransitionKey.OBSERVATION]["observation.state"].shape[0] == 1  # Batched
-    assert result[TransitionKey.OBSERVATION]["observation.state"].device.type == "cpu"
+    # The result has TransitionKey.OBSERVATION as the key, with observation.state inside
+    assert result[TransitionKey.OBSERVATION][OBS_STATE].shape[0] == 1  # Batched
+    assert result[TransitionKey.OBSERVATION][OBS_STATE].device.type == "cpu"
     assert result[TransitionKey.ACTION].shape[0] == 1  # Batched
     assert result[TransitionKey.ACTION].device.type == "cpu"
 
 
 def test_save_and_load_pretrained():
     """Test saving and loading processor with DeviceProcessor."""
-    processor = DeviceProcessor(device="cuda:0", float_dtype="float16")
+    device = "cuda:0" if torch.cuda.is_available() else "cpu"
+    processor = DeviceProcessor(device=device, float_dtype="float16")
     robot_processor = RobotProcessor(steps=[processor], name="device_test_processor")
 
     with tempfile.TemporaryDirectory() as tmpdir:
@@ -339,8 +344,11 @@ def test_save_and_load_pretrained():
         assert len(loaded_processor.steps) == 1
         loaded_device_processor = loaded_processor.steps[0]
         assert isinstance(loaded_device_processor, DeviceProcessor)
-        assert loaded_device_processor.device == "cuda:0"
-        assert loaded_device_processor.float_dtype == "float16"
+        # Use getattr to access attributes safely
+        assert (
+            getattr(loaded_device_processor, "device", None) == device.split(":")[0]
+        )  # Device normalizes cuda:0 to cuda
+        assert getattr(loaded_device_processor, "float_dtype", None) == "float16"
 
 
 def test_registry_functionality():
@@ -566,10 +574,11 @@ def test_float_dtype_with_mixed_tensors():
 
 def test_float_dtype_serialization():
     """Test that float_dtype is properly serialized in get_config."""
-    processor = DeviceProcessor(device="cuda", float_dtype="float16")
+    device = "cuda" if torch.cuda.is_available() else "cpu"
+    processor = DeviceProcessor(device=device, float_dtype="float16")
 
     config = processor.get_config()
-    assert config == {"device": "cuda", "float_dtype": "float16"}
+    assert config == {"device": device, "float_dtype": "float16"}
 
     # Test with None float_dtype
     processor_none = DeviceProcessor(device="cpu", float_dtype=None)
@@ -815,17 +824,18 @@ def test_complementary_data_none():
 
 def test_policy_processor_integration():
     """Test integration with policy processors - input on GPU, output on CPU."""
     from lerobot.configs.types import FeatureType, NormalizationMode, PolicyFeature
+    from lerobot.constants import ACTION, OBS_STATE
     from lerobot.processor import NormalizerProcessor, ToBatchProcessor, UnnormalizerProcessor
 
     # Create features and stats
     features = {
-        "observation.state": PolicyFeature(type=FeatureType.STATE, shape=(10,)),
-        "action": PolicyFeature(type=FeatureType.ACTION, shape=(5,)),
+        OBS_STATE: PolicyFeature(type=FeatureType.STATE, shape=(10,)),
+        ACTION: PolicyFeature(type=FeatureType.ACTION, shape=(5,)),
     }
     stats = {
-        "observation.state": {"mean": torch.zeros(10), "std": torch.ones(10)},
-        "action": {"mean": torch.zeros(5), "std": torch.ones(5)},
+        OBS_STATE: {"mean": torch.zeros(10), "std": torch.ones(10)},
+        ACTION: {"mean": torch.zeros(5), "std": torch.ones(5)},
     }
 
     norm_map = {FeatureType.STATE: NormalizationMode.MEAN_STD, FeatureType.ACTION: NormalizationMode.MEAN_STD}
@@ -844,13 +854,13 @@ def test_policy_processor_integration():
     output_processor = RobotProcessor(
         steps=[
             DeviceProcessor(device="cpu"),
-            UnnormalizerProcessor(features={"action": features["action"]}, norm_map=norm_map, stats=stats),
+            UnnormalizerProcessor(features={ACTION: features[ACTION]}, norm_map=norm_map, stats=stats),
         ],
         name="test_postprocessor",
     )
 
     # Test data on CPU
-    observation = {"observation.state": torch.randn(10)}
+    observation = {OBS_STATE: torch.randn(10)}
     action = torch.randn(5)
     transition = create_transition(observation=observation, action=action)
 
@@ -858,8 +868,9 @@ def test_policy_processor_integration():
     input_result = input_processor(transition)
 
     # Verify tensors are on GPU and batched
-    assert input_result[TransitionKey.OBSERVATION]["observation.state"].device.type == "cuda"
-    assert input_result[TransitionKey.OBSERVATION]["observation.state"].shape[0] == 1
+    # The result has TransitionKey.OBSERVATION as the key, with observation.state inside
+    assert input_result[TransitionKey.OBSERVATION][OBS_STATE].device.type == "cuda"
+    assert input_result[TransitionKey.OBSERVATION][OBS_STATE].shape[0] == 1
     assert input_result[TransitionKey.ACTION].device.type == "cuda"
     assert input_result[TransitionKey.ACTION].shape[0] == 1