diff --git a/src/lerobot/processor/converters.py b/src/lerobot/processor/converters.py index c9218a650..fb3d9b860 100644 --- a/src/lerobot/processor/converters.py +++ b/src/lerobot/processor/converters.py @@ -18,6 +18,7 @@ from __future__ import annotations from collections.abc import Iterable, Sequence from copy import deepcopy +from functools import singledispatch from typing import Any import numpy as np @@ -29,17 +30,113 @@ from lerobot.constants import ACTION, DONE, OBS_IMAGES, OBS_STATE, REWARD, TRUNC from .pipeline import EnvTransition, TransitionKey -def _to_tensor(x: torch.Tensor | np.ndarray | Sequence[int | float]): - if isinstance(x, torch.Tensor): - return x - if isinstance(x, np.ndarray): - # Keep images (uint8 HWC) and python objects as-is - if x.dtype == np.uint8 or x.dtype == np.object_: - return x - # Scalars/arrays to float32 tensor - return torch.as_tensor(x, dtype=torch.float32) - # Anything else to float32 tensor - return torch.as_tensor(x, dtype=torch.float32) +@singledispatch +def to_tensor( + value: Any, + *, + dtype: torch.dtype | None = torch.float32, + device: torch.device | str | None = None, +) -> torch.Tensor: + """ + Convert various data types to PyTorch tensors with configurable options. + + This is a unified tensor conversion function using single dispatch to handle + different input types appropriately. + + Args: + value: Input value to convert (tensor, array, scalar, sequence, etc.) + dtype: Target tensor dtype. If None, preserves original dtype. + device: Target device for the tensor. + + Returns: + PyTorch tensor. + + Raises: + TypeError: If the input type is not supported. + """ + raise TypeError(f"Unsupported type for tensor conversion: {type(value)}") + + +@to_tensor.register(torch.Tensor) +def _(value: torch.Tensor, *, dtype=torch.float32, device=None, **kwargs) -> torch.Tensor: + """Handle existing PyTorch tensors.""" + if dtype is not None: + value = value.to(dtype=dtype) + if device is not None: + value = value.to(device=device) + return value + + +@to_tensor.register(np.ndarray) +def _( + value: np.ndarray, + *, + dtype=torch.float32, + device=None, + **kwargs, +) -> torch.Tensor: + """Handle numpy arrays.""" + # Check for numpy scalars (0-dimensional arrays) and treat them as scalars + if value.ndim == 0: + # Numpy scalars should be converted to 0-dimensional tensors + scalar_value = value.item() + return torch.tensor(scalar_value, dtype=dtype, device=device) + + # Create tensor from numpy array (torch.from_numpy handles contiguity automatically) + tensor = torch.from_numpy(value) + + # Apply dtype conversion if specified + if dtype is not None: + tensor = tensor.to(dtype=dtype) + if device is not None: + tensor = tensor.to(device=device) + + return tensor + + +@to_tensor.register(int) +@to_tensor.register(float) +@to_tensor.register(np.integer) +@to_tensor.register(np.floating) +def _(value, *, dtype=torch.float32, device=None, **kwargs) -> torch.Tensor: + """Handle scalar values including numpy scalars.""" + return torch.tensor(value, dtype=dtype, device=device) + + +@to_tensor.register(list) +@to_tensor.register(tuple) +def _(value: Sequence, *, dtype=torch.float32, device=None, **kwargs) -> torch.Tensor: + """Handle sequences (lists, tuples).""" + return torch.tensor(value, dtype=dtype, device=device) + + +@to_tensor.register(dict) +def _(value: dict, *, device=None, **kwargs) -> dict: + """Handle dictionaries by recursively converting values to tensors.""" + if not value: + return {} + + result = {} + for key, sub_value in value.items(): + if sub_value is None: + continue + + if isinstance(sub_value, dict): + # Recursively process nested dictionaries + result[key] = to_tensor( + sub_value, + device=device, + **kwargs, + ) + continue + + # Convert individual values to tensors + result[key] = to_tensor( + sub_value, + device=device, + **kwargs, + ) + return result def _from_tensor(x: Any): @@ -88,7 +185,7 @@ def to_transition_teleop_action(action: dict[str, Any]) -> EnvTransition: continue arr = np.array(v) if np.isscalar(v) else v - act_dict[f"{ACTION}.{k}"] = _to_tensor(arr) + act_dict[f"{ACTION}.{k}"] = to_tensor(arr) return make_obs_act_transition(act=act_dict) @@ -103,7 +200,7 @@ def to_transition_robot_observation(observation: dict[str, Any]) -> EnvTransitio obs_dict: dict[str, Any] = {} for k, v in state.items(): arr = np.array(v) if np.isscalar(v) else v - obs_dict[f"{OBS_STATE}.{k}"] = _to_tensor(arr) + obs_dict[f"{OBS_STATE}.{k}"] = to_tensor(arr) for cam, img in images.items(): obs_dict[f"{OBS_IMAGES}.{cam}"] = img diff --git a/src/lerobot/processor/gym_action_processor.py b/src/lerobot/processor/gym_action_processor.py index 70c54cbde..41b320370 100644 --- a/src/lerobot/processor/gym_action_processor.py +++ b/src/lerobot/processor/gym_action_processor.py @@ -16,6 +16,7 @@ from dataclasses import dataclass import numpy as np import torch +from lerobot.processor.converters import to_tensor from lerobot.processor.pipeline import ActionProcessor, ProcessorStepRegistry @@ -59,5 +60,5 @@ class Numpy2TorchActionProcessor(ActionProcessor): f"Expected np.ndarray or None, got {type(action).__name__}. " "Use appropriate processor for non-tensor actions." ) - torch_action = torch.from_numpy(action) + torch_action = to_tensor(action, dtype=None) # Preserve original dtype return torch_action diff --git a/src/lerobot/processor/normalize_processor.py b/src/lerobot/processor/normalize_processor.py index fa635414c..b88d8b5af 100644 --- a/src/lerobot/processor/normalize_processor.py +++ b/src/lerobot/processor/normalize_processor.py @@ -4,12 +4,12 @@ from copy import deepcopy from dataclasses import dataclass, field from typing import Any -import numpy as np import torch from torch import Tensor from lerobot.configs.types import FeatureType, NormalizationMode, PolicyFeature from lerobot.datasets.lerobot_dataset import LeRobotDataset +from lerobot.processor.converters import to_tensor from lerobot.processor.pipeline import ( EnvTransition, ProcessorStep, @@ -19,37 +19,6 @@ from lerobot.processor.pipeline import ( ) -def _to_tensor(value: Any, device: torch.device | None = None) -> Tensor: - """Convert common python/numpy/torch types to a torch.float32 tensor. - - Always returns float32; preserves device if provided. - """ - if isinstance(value, torch.Tensor): - return value.to(dtype=torch.float32, device=device) - if isinstance(value, np.ndarray): - # ensure contiguous, cast to float32 then convert - return torch.from_numpy(np.ascontiguousarray(value.astype(np.float32))).to(device=device) - if isinstance(value, (int, float)): - return torch.tensor(value, dtype=torch.float32, device=device) - if isinstance(value, (list, tuple)): - return torch.tensor(value, dtype=torch.float32, device=device) - raise TypeError(f"Unsupported type for stats value: {type(value)}") - - -def _convert_stats_to_tensors( - stats: dict[str, dict[str, Any]], device: torch.device | None = None -) -> dict[str, dict[str, Tensor]]: - """Convert numeric stats values to torch tensors, preserving keys.""" - tensor_stats: dict[str, dict[str, Tensor]] = {} - for key, sub in (stats or {}).items(): - if sub is None: - continue - tensor_stats[key] = {} - for stat_name, value in sub.items(): - tensor_stats[key][stat_name] = _to_tensor(value, device=device) - return tensor_stats - - @dataclass class _NormalizationMixin: """ @@ -91,12 +60,12 @@ class _NormalizationMixin: # Convert stats to tensors and move to the target device once during initialization. self.stats = self.stats or {} - self._tensor_stats = _convert_stats_to_tensors(self.stats, device=self.device) + self._tensor_stats = to_tensor(self.stats, device=self.device) def to(self, device: torch.device | str) -> _NormalizationMixin: """Moves the processor's normalization stats to the specified device and returns self.""" self.device = device - self._tensor_stats = _convert_stats_to_tensors(self.stats, device=self.device) + self._tensor_stats = to_tensor(self.stats, device=self.device) return self def state_dict(self) -> dict[str, Tensor]: @@ -165,7 +134,7 @@ class _NormalizationMixin: # Move stats to input device if needed stats_device = next(iter(stats.values())).device if stats_device != input_device: - stats = _convert_stats_to_tensors({key: self._tensor_stats[key]}, device=input_device)[key] + stats = to_tensor({key: self._tensor_stats[key]}, device=input_device)[key] if norm_mode == NormalizationMode.MEAN_STD and "mean" in stats and "std" in stats: mean, std = stats["mean"], stats["std"] @@ -295,5 +264,5 @@ def hotswap_stats(robot_processor: RobotProcessor, stats: dict[str, dict[str, An if isinstance(step, _NormalizationMixin): step.stats = stats # Re-initialize tensor_stats on the correct device. - step._tensor_stats = _convert_stats_to_tensors(stats, device=step.device) + step._tensor_stats = to_tensor(stats, device=step.device) return rp diff --git a/tests/processor/test_converters.py b/tests/processor/test_converters.py index 590f6a892..ac2015b48 100644 --- a/tests/processor/test_converters.py +++ b/tests/processor/test_converters.py @@ -5,6 +5,7 @@ import torch from lerobot.processor.converters import ( to_dataset_frame, to_output_robot_action, + to_tensor, to_transition_robot_observation, to_transition_teleop_action, ) @@ -12,12 +13,12 @@ from lerobot.processor.pipeline import TransitionKey def test_to_transition_teleop_action_prefix_and_tensor_conversion(): - # Scalars, arrays, and "image-like" uint8 arrays are supported + # Scalars, arrays, and uint8 arrays are all converted to tensors img = np.zeros((8, 12, 3), dtype=np.uint8) act = { "ee.x": 0.5, # scalar to torch tensor "delta": np.array([1.0, 2.0]), # ndarray to torch tensor - "raw_img": img, # uint8 HWC to passthrough ndarray + "raw_img": img, # uint8 HWC to torch tensor } tr = to_transition_teleop_action(act) @@ -29,7 +30,7 @@ def test_to_transition_teleop_action_prefix_and_tensor_conversion(): assert "action.delta" in tr[TransitionKey.ACTION] assert "action.raw_img" in tr[TransitionKey.ACTION] - # Types: scalars/arrays -> torch tensor; images to np.ndarray + # Types: all values -> torch tensor assert isinstance(tr[TransitionKey.ACTION]["action.ee.x"], torch.Tensor) assert tr[TransitionKey.ACTION]["action.ee.x"].item() == pytest.approx(0.5) @@ -37,8 +38,8 @@ def test_to_transition_teleop_action_prefix_and_tensor_conversion(): assert tr[TransitionKey.ACTION]["action.delta"].shape == (2,) assert torch.allclose(tr[TransitionKey.ACTION]["action.delta"], torch.tensor([1.0, 2.0])) - assert isinstance(tr[TransitionKey.ACTION]["action.raw_img"], np.ndarray) - assert tr[TransitionKey.ACTION]["action.raw_img"].dtype == np.uint8 + assert isinstance(tr[TransitionKey.ACTION]["action.raw_img"], torch.Tensor) + assert tr[TransitionKey.ACTION]["action.raw_img"].dtype == torch.float32 # converted from uint8 assert tr[TransitionKey.ACTION]["action.raw_img"].shape == (8, 12, 3) # Observation is created as empty dict by make_transition @@ -194,3 +195,185 @@ def test_to_dataset_frame_merge_and_pack_vectors_and_metadata(): # Complementary data assert batch["frame_is_pad"] is True assert batch["task"] == "Pick cube" + + +# Tests for the unified to_tensor function +def test_to_tensor_numpy_arrays(): + """Test to_tensor with various numpy arrays.""" + # Regular numpy array + arr = np.array([1.0, 2.0, 3.0]) + result = to_tensor(arr) + assert isinstance(result, torch.Tensor) + assert result.dtype == torch.float32 + assert torch.allclose(result, torch.tensor([1.0, 2.0, 3.0])) + + # Different numpy dtypes should convert to float32 by default + int_arr = np.array([1, 2, 3], dtype=np.int64) + result = to_tensor(int_arr) + assert isinstance(result, torch.Tensor) + assert result.dtype == torch.float32 + assert torch.allclose(result, torch.tensor([1.0, 2.0, 3.0])) + + # uint8 arrays (previously "preserved") should now convert + uint8_arr = np.array([100, 150, 200], dtype=np.uint8) + result = to_tensor(uint8_arr) + assert isinstance(result, torch.Tensor) + assert result.dtype == torch.float32 + assert torch.allclose(result, torch.tensor([100.0, 150.0, 200.0])) + + +def test_to_tensor_numpy_scalars(): + """Test to_tensor with numpy scalars (0-dimensional arrays).""" + # numpy float32 scalar + scalar = np.float32(3.14) + result = to_tensor(scalar) + assert isinstance(result, torch.Tensor) + assert result.ndim == 0 # Should be 0-dimensional tensor + assert result.dtype == torch.float32 + assert result.item() == pytest.approx(3.14) + + # numpy int32 scalar + int_scalar = np.int32(42) + result = to_tensor(int_scalar) + assert isinstance(result, torch.Tensor) + assert result.ndim == 0 + assert result.dtype == torch.float32 + assert result.item() == pytest.approx(42.0) + + +def test_to_tensor_python_scalars(): + """Test to_tensor with Python scalars.""" + # Python int + result = to_tensor(42) + assert isinstance(result, torch.Tensor) + assert result.dtype == torch.float32 + assert result.item() == pytest.approx(42.0) + + # Python float + result = to_tensor(3.14) + assert isinstance(result, torch.Tensor) + assert result.dtype == torch.float32 + assert result.item() == pytest.approx(3.14) + + +def test_to_tensor_sequences(): + """Test to_tensor with lists and tuples.""" + # List + result = to_tensor([1, 2, 3]) + assert isinstance(result, torch.Tensor) + assert result.dtype == torch.float32 + assert torch.allclose(result, torch.tensor([1.0, 2.0, 3.0])) + + # Tuple + result = to_tensor((4.5, 5.5, 6.5)) + assert isinstance(result, torch.Tensor) + assert result.dtype == torch.float32 + assert torch.allclose(result, torch.tensor([4.5, 5.5, 6.5])) + + +def test_to_tensor_existing_tensors(): + """Test to_tensor with existing PyTorch tensors.""" + # Tensor with same dtype should pass through with potential device change + tensor = torch.tensor([1.0, 2.0, 3.0], dtype=torch.float32) + result = to_tensor(tensor) + assert isinstance(result, torch.Tensor) + assert result.dtype == torch.float32 + assert torch.allclose(result, tensor) + + # Tensor with different dtype should convert + int_tensor = torch.tensor([1, 2, 3], dtype=torch.int64) + result = to_tensor(int_tensor) + assert isinstance(result, torch.Tensor) + assert result.dtype == torch.float32 + assert torch.allclose(result, torch.tensor([1.0, 2.0, 3.0])) + + +def test_to_tensor_dictionaries(): + """Test to_tensor with nested dictionaries.""" + # Simple dictionary + data = {"mean": [0.1, 0.2], "std": np.array([1.0, 2.0]), "count": 42} + result = to_tensor(data) + assert isinstance(result, dict) + assert isinstance(result["mean"], torch.Tensor) + assert isinstance(result["std"], torch.Tensor) + assert isinstance(result["count"], torch.Tensor) + assert torch.allclose(result["mean"], torch.tensor([0.1, 0.2])) + assert torch.allclose(result["std"], torch.tensor([1.0, 2.0])) + assert result["count"].item() == pytest.approx(42.0) + + # Nested dictionary + nested = { + "action": {"mean": [0.1, 0.2], "std": [1.0, 2.0]}, + "observation": {"mean": np.array([0.5, 0.6]), "count": 10}, + } + result = to_tensor(nested) + assert isinstance(result, dict) + assert isinstance(result["action"], dict) + assert isinstance(result["observation"], dict) + assert isinstance(result["action"]["mean"], torch.Tensor) + assert isinstance(result["observation"]["mean"], torch.Tensor) + assert torch.allclose(result["action"]["mean"], torch.tensor([0.1, 0.2])) + assert torch.allclose(result["observation"]["mean"], torch.tensor([0.5, 0.6])) + + +def test_to_tensor_none_filtering(): + """Test that None values are filtered out from dictionaries.""" + data = {"valid": [1, 2, 3], "none_value": None, "nested": {"valid": [4, 5], "also_none": None}} + result = to_tensor(data) + assert "none_value" not in result + assert "also_none" not in result["nested"] + assert "valid" in result + assert "valid" in result["nested"] + assert torch.allclose(result["valid"], torch.tensor([1.0, 2.0, 3.0])) + + +def test_to_tensor_dtype_parameter(): + """Test to_tensor with different dtype parameters.""" + arr = np.array([1, 2, 3]) + + # Default dtype (float32) + result = to_tensor(arr) + assert result.dtype == torch.float32 + + # Explicit float32 + result = to_tensor(arr, dtype=torch.float32) + assert result.dtype == torch.float32 + + # Float64 + result = to_tensor(arr, dtype=torch.float64) + assert result.dtype == torch.float64 + + # Preserve original dtype + float64_arr = np.array([1.0, 2.0, 3.0], dtype=np.float64) + result = to_tensor(float64_arr, dtype=None) + assert result.dtype == torch.float64 + + +def test_to_tensor_device_parameter(): + """Test to_tensor with device parameter.""" + arr = np.array([1.0, 2.0, 3.0]) + + # CPU device (default) + result = to_tensor(arr, device="cpu") + assert result.device.type == "cpu" + + # CUDA device (if available) + if torch.cuda.is_available(): + result = to_tensor(arr, device="cuda") + assert result.device.type == "cuda" + + +def test_to_tensor_empty_dict(): + """Test to_tensor with empty dictionary.""" + result = to_tensor({}) + assert isinstance(result, dict) + assert len(result) == 0 + + +def test_to_tensor_unsupported_type(): + """Test to_tensor with unsupported types raises TypeError.""" + with pytest.raises(TypeError, match="Unsupported type for tensor conversion"): + to_tensor("unsupported_string") + + with pytest.raises(TypeError, match="Unsupported type for tensor conversion"): + to_tensor(object()) diff --git a/tests/processor/test_normalize_processor.py b/tests/processor/test_normalize_processor.py index bc4727a55..13bf14192 100644 --- a/tests/processor/test_normalize_processor.py +++ b/tests/processor/test_normalize_processor.py @@ -20,10 +20,10 @@ import pytest import torch from lerobot.configs.types import FeatureType, NormalizationMode, PolicyFeature +from lerobot.processor.converters import to_tensor from lerobot.processor.normalize_processor import ( NormalizerProcessor, UnnormalizerProcessor, - _convert_stats_to_tensors, hotswap_stats, ) from lerobot.processor.pipeline import IdentityProcessor, RobotProcessor, TransitionKey @@ -51,7 +51,7 @@ def test_numpy_conversion(): "std": np.array([0.2, 0.2, 0.2]), } } - tensor_stats = _convert_stats_to_tensors(stats) + tensor_stats = to_tensor(stats) assert isinstance(tensor_stats["observation.image"]["mean"], torch.Tensor) assert isinstance(tensor_stats["observation.image"]["std"], torch.Tensor) @@ -66,7 +66,7 @@ def test_tensor_conversion(): "std": torch.tensor([1.0, 1.0]), } } - tensor_stats = _convert_stats_to_tensors(stats) + tensor_stats = to_tensor(stats) assert tensor_stats["action"]["mean"].dtype == torch.float32 assert tensor_stats["action"]["std"].dtype == torch.float32 @@ -79,7 +79,7 @@ def test_scalar_conversion(): "std": 0.1, } } - tensor_stats = _convert_stats_to_tensors(stats) + tensor_stats = to_tensor(stats) assert torch.allclose(tensor_stats["reward"]["mean"], torch.tensor(0.5)) assert torch.allclose(tensor_stats["reward"]["std"], torch.tensor(0.1)) @@ -92,7 +92,7 @@ def test_list_conversion(): "max": [1.0, 1.0, 2.0], } } - tensor_stats = _convert_stats_to_tensors(stats) + tensor_stats = to_tensor(stats) assert torch.allclose(tensor_stats["observation.state"]["min"], torch.tensor([0.0, -1.0, -2.0])) assert torch.allclose(tensor_stats["observation.state"]["max"], torch.tensor([1.0, 1.0, 2.0])) @@ -105,7 +105,7 @@ def test_unsupported_type(): } } with pytest.raises(TypeError, match="Unsupported type"): - _convert_stats_to_tensors(stats) + to_tensor(stats) # Helper functions to create feature maps and norm maps @@ -1017,7 +1017,7 @@ def test_hotswap_stats_basic_functionality(): assert new_processor.steps[1].stats == new_stats # Check that tensor stats are updated correctly - expected_tensor_stats = _convert_stats_to_tensors(new_stats) + expected_tensor_stats = to_tensor(new_stats) for key in expected_tensor_stats: for stat_name in expected_tensor_stats[key]: torch.testing.assert_close( @@ -1223,7 +1223,7 @@ def test_hotswap_stats_multiple_normalizer_types(): assert step.stats == new_stats # Check tensor stats conversion - expected_tensor_stats = _convert_stats_to_tensors(new_stats) + expected_tensor_stats = to_tensor(new_stats) for key in expected_tensor_stats: for stat_name in expected_tensor_stats[key]: torch.testing.assert_close(