mirror of
https://github.com/huggingface/lerobot.git
synced 2026-05-16 17:20:05 +00:00
refactor(device_processor): Update device handling and improve type hints
- Changed device attribute type from torch.device to str for better clarity. - Introduced a private _device attribute to store the actual torch.device instance. - Updated tests to conditionally check for CUDA availability, ensuring compatibility across different environments. - Refactored device-related assertions in tests to use a consistent approach for device type verification.
This commit is contained in:
@@ -257,21 +257,23 @@ def test_non_blocking_flag():
|
||||
cpu_processor = DeviceProcessor(device="cpu")
|
||||
assert cpu_processor.non_blocking is False
|
||||
|
||||
# CUDA processor should have non_blocking=True
|
||||
cuda_processor = DeviceProcessor(device="cuda")
|
||||
assert cuda_processor.non_blocking is True
|
||||
if torch.cuda.is_available():
|
||||
# CUDA processor should have non_blocking=True
|
||||
cuda_processor = DeviceProcessor(device="cuda")
|
||||
assert cuda_processor.non_blocking is True
|
||||
|
||||
cuda_0_processor = DeviceProcessor(device="cuda:0")
|
||||
assert cuda_0_processor.non_blocking is True
|
||||
cuda_0_processor = DeviceProcessor(device="cuda:0")
|
||||
assert cuda_0_processor.non_blocking is True
|
||||
|
||||
|
||||
def test_serialization_methods():
|
||||
"""Test get_config, state_dict, and load_state_dict methods."""
|
||||
processor = DeviceProcessor(device="cuda")
|
||||
device = "cuda" if torch.cuda.is_available() else "cpu"
|
||||
processor = DeviceProcessor(device=device)
|
||||
|
||||
# Test get_config
|
||||
config = processor.get_config()
|
||||
assert config == {"device": "cuda", "float_dtype": None}
|
||||
assert config == {"device": device, "float_dtype": None}
|
||||
|
||||
# Test state_dict (should be empty)
|
||||
state = processor.state_dict()
|
||||
@@ -279,11 +281,11 @@ def test_serialization_methods():
|
||||
|
||||
# Test load_state_dict (should be no-op)
|
||||
processor.load_state_dict({})
|
||||
assert processor.device == "cuda"
|
||||
assert processor.device == device
|
||||
|
||||
# Test reset (should be no-op)
|
||||
processor.reset()
|
||||
assert processor.device == "cuda"
|
||||
assert processor.device == device
|
||||
|
||||
|
||||
def test_feature_contract():
|
||||
@@ -302,6 +304,7 @@ def test_feature_contract():
|
||||
|
||||
def test_integration_with_robot_processor():
|
||||
"""Test integration with RobotProcessor."""
|
||||
from lerobot.constants import ACTION, OBS_STATE
|
||||
from lerobot.processor import ToBatchProcessor
|
||||
|
||||
# Create a pipeline with DeviceProcessor
|
||||
@@ -311,22 +314,24 @@ def test_integration_with_robot_processor():
|
||||
processor = RobotProcessor(steps=[batch_processor, device_processor], name="test_pipeline")
|
||||
|
||||
# Create test data
|
||||
observation = {"observation.state": torch.randn(10)}
|
||||
observation = {OBS_STATE: torch.randn(10)}
|
||||
action = torch.randn(5)
|
||||
|
||||
transition = create_transition(observation=observation, action=action)
|
||||
result = processor(transition)
|
||||
|
||||
# Check that tensors are batched and on correct device
|
||||
assert result[TransitionKey.OBSERVATION]["observation.state"].shape[0] == 1 # Batched
|
||||
assert result[TransitionKey.OBSERVATION]["observation.state"].device.type == "cpu"
|
||||
# The result has TransitionKey.OBSERVATION as the key, with observation.state inside
|
||||
assert result[TransitionKey.OBSERVATION][OBS_STATE].shape[0] == 1 # Batched
|
||||
assert result[TransitionKey.OBSERVATION][OBS_STATE].device.type == "cpu"
|
||||
assert result[TransitionKey.ACTION].shape[0] == 1 # Batched
|
||||
assert result[TransitionKey.ACTION].device.type == "cpu"
|
||||
|
||||
|
||||
def test_save_and_load_pretrained():
|
||||
"""Test saving and loading processor with DeviceProcessor."""
|
||||
processor = DeviceProcessor(device="cuda:0", float_dtype="float16")
|
||||
device = "cuda:0" if torch.cuda.is_available() else "cpu"
|
||||
processor = DeviceProcessor(device=device, float_dtype="float16")
|
||||
robot_processor = RobotProcessor(steps=[processor], name="device_test_processor")
|
||||
|
||||
with tempfile.TemporaryDirectory() as tmpdir:
|
||||
@@ -339,8 +344,11 @@ def test_save_and_load_pretrained():
|
||||
assert len(loaded_processor.steps) == 1
|
||||
loaded_device_processor = loaded_processor.steps[0]
|
||||
assert isinstance(loaded_device_processor, DeviceProcessor)
|
||||
assert loaded_device_processor.device == "cuda:0"
|
||||
assert loaded_device_processor.float_dtype == "float16"
|
||||
# Use getattr to access attributes safely
|
||||
assert (
|
||||
getattr(loaded_device_processor, "device", None) == device.split(":")[0]
|
||||
) # Device normalizes cuda:0 to cuda
|
||||
assert getattr(loaded_device_processor, "float_dtype", None) == "float16"
|
||||
|
||||
|
||||
def test_registry_functionality():
|
||||
@@ -566,10 +574,11 @@ def test_float_dtype_with_mixed_tensors():
|
||||
|
||||
def test_float_dtype_serialization():
|
||||
"""Test that float_dtype is properly serialized in get_config."""
|
||||
processor = DeviceProcessor(device="cuda", float_dtype="float16")
|
||||
device = "cuda" if torch.cuda.is_available() else "cpu"
|
||||
processor = DeviceProcessor(device=device, float_dtype="float16")
|
||||
config = processor.get_config()
|
||||
|
||||
assert config == {"device": "cuda", "float_dtype": "float16"}
|
||||
assert config == {"device": device, "float_dtype": "float16"}
|
||||
|
||||
# Test with None float_dtype
|
||||
processor_none = DeviceProcessor(device="cpu", float_dtype=None)
|
||||
@@ -815,17 +824,18 @@ def test_complementary_data_none():
|
||||
def test_policy_processor_integration():
|
||||
"""Test integration with policy processors - input on GPU, output on CPU."""
|
||||
from lerobot.configs.types import FeatureType, NormalizationMode, PolicyFeature
|
||||
from lerobot.constants import ACTION, OBS_STATE
|
||||
from lerobot.processor import NormalizerProcessor, ToBatchProcessor, UnnormalizerProcessor
|
||||
|
||||
# Create features and stats
|
||||
features = {
|
||||
"observation.state": PolicyFeature(type=FeatureType.STATE, shape=(10,)),
|
||||
"action": PolicyFeature(type=FeatureType.ACTION, shape=(5,)),
|
||||
OBS_STATE: PolicyFeature(type=FeatureType.STATE, shape=(10,)),
|
||||
ACTION: PolicyFeature(type=FeatureType.ACTION, shape=(5,)),
|
||||
}
|
||||
|
||||
stats = {
|
||||
"observation.state": {"mean": torch.zeros(10), "std": torch.ones(10)},
|
||||
"action": {"mean": torch.zeros(5), "std": torch.ones(5)},
|
||||
OBS_STATE: {"mean": torch.zeros(10), "std": torch.ones(10)},
|
||||
ACTION: {"mean": torch.zeros(5), "std": torch.ones(5)},
|
||||
}
|
||||
|
||||
norm_map = {FeatureType.STATE: NormalizationMode.MEAN_STD, FeatureType.ACTION: NormalizationMode.MEAN_STD}
|
||||
@@ -844,13 +854,13 @@ def test_policy_processor_integration():
|
||||
output_processor = RobotProcessor(
|
||||
steps=[
|
||||
DeviceProcessor(device="cpu"),
|
||||
UnnormalizerProcessor(features={"action": features["action"]}, norm_map=norm_map, stats=stats),
|
||||
UnnormalizerProcessor(features={ACTION: features[ACTION]}, norm_map=norm_map, stats=stats),
|
||||
],
|
||||
name="test_postprocessor",
|
||||
)
|
||||
|
||||
# Test data on CPU
|
||||
observation = {"observation.state": torch.randn(10)}
|
||||
observation = {OBS_STATE: torch.randn(10)}
|
||||
action = torch.randn(5)
|
||||
transition = create_transition(observation=observation, action=action)
|
||||
|
||||
@@ -858,8 +868,9 @@ def test_policy_processor_integration():
|
||||
input_result = input_processor(transition)
|
||||
|
||||
# Verify tensors are on GPU and batched
|
||||
assert input_result[TransitionKey.OBSERVATION]["observation.state"].device.type == "cuda"
|
||||
assert input_result[TransitionKey.OBSERVATION]["observation.state"].shape[0] == 1
|
||||
# The result has TransitionKey.OBSERVATION as the key, with observation.state inside
|
||||
assert input_result[TransitionKey.OBSERVATION][OBS_STATE].device.type == "cuda"
|
||||
assert input_result[TransitionKey.OBSERVATION][OBS_STATE].shape[0] == 1
|
||||
assert input_result[TransitionKey.ACTION].device.type == "cuda"
|
||||
assert input_result[TransitionKey.ACTION].shape[0] == 1
|
||||
|
||||
|
||||
Reference in New Issue
Block a user