refactor(converters): implement unified tensor conversion function (#1841)

- Introduced `to_tensor` function using `singledispatch` to handle various input types, including scalars, arrays, and dictionaries, converting them to PyTorch tensors.
- Replaced previous tensor conversion logic in `gym_action_processor`, `normalize_processor`, and `test_converters` with the new `to_tensor` function for improved readability and maintainability.
- Updated tests to cover new functionality and ensure correct tensor conversion behavior.

Co-authored-by: AdilZouitine <adilzouitinegm@gmail.com>
This commit is contained in:
Steven Palma
2025-09-02 13:47:04 +02:00
committed by GitHub
parent 15ffc01fb3
commit 2c802ac134
5 changed files with 313 additions and 63 deletions
+110 -13
View File
@@ -18,6 +18,7 @@ from __future__ import annotations
from collections.abc import Iterable, Sequence
from copy import deepcopy
from functools import singledispatch
from typing import Any
import numpy as np
@@ -29,17 +30,113 @@ from lerobot.constants import ACTION, DONE, OBS_IMAGES, OBS_STATE, REWARD, TRUNC
from .pipeline import EnvTransition, TransitionKey
def _to_tensor(x: torch.Tensor | np.ndarray | Sequence[int | float]):
if isinstance(x, torch.Tensor):
return x
if isinstance(x, np.ndarray):
# Keep images (uint8 HWC) and python objects as-is
if x.dtype == np.uint8 or x.dtype == np.object_:
return x
# Scalars/arrays to float32 tensor
return torch.as_tensor(x, dtype=torch.float32)
# Anything else to float32 tensor
return torch.as_tensor(x, dtype=torch.float32)
@singledispatch
def to_tensor(
    value: Any,
    *,
    dtype: torch.dtype | None = torch.float32,
    device: torch.device | str | None = None,
) -> torch.Tensor:
    """
    Convert various data types to PyTorch tensors with configurable options.

    This is a unified tensor conversion function using single dispatch to handle
    different input types appropriately.

    Args:
        value: Input value to convert (tensor, array, scalar, sequence, dict).
        dtype: Target tensor dtype. If None, preserves original dtype.
        device: Target device for the tensor.

    Returns:
        PyTorch tensor (or a dict of converted values when `value` is a dict).

    Raises:
        TypeError: If the input type is not supported.
    """
    raise TypeError(f"Unsupported type for tensor conversion: {type(value)}")


@to_tensor.register(torch.Tensor)
def _(value: torch.Tensor, *, dtype=torch.float32, device=None, **kwargs) -> torch.Tensor:
    """Handle existing PyTorch tensors: cast/move only when requested."""
    if dtype is not None:
        value = value.to(dtype=dtype)
    if device is not None:
        value = value.to(device=device)
    return value


@to_tensor.register(np.ndarray)
def _(
    value: np.ndarray,
    *,
    dtype=torch.float32,
    device=None,
    **kwargs,
) -> torch.Tensor:
    """Handle numpy arrays, including 0-dimensional (scalar) arrays."""
    if value.ndim == 0:
        # torch.as_tensor keeps the numpy dtype for 0-d arrays, so dtype=None
        # really preserves the source dtype (`.item()` + torch.tensor would
        # silently promote to the torch default dtype instead).
        tensor = torch.as_tensor(value)
    else:
        # torch.from_numpy rejects negative-stride arrays (e.g. arr[::-1]);
        # ascontiguousarray normalizes layout and is a no-op when already
        # contiguous, so the zero-copy fast path is kept where possible.
        tensor = torch.from_numpy(np.ascontiguousarray(value))
    # Apply dtype/device conversions only when explicitly requested.
    if dtype is not None:
        tensor = tensor.to(dtype=dtype)
    if device is not None:
        tensor = tensor.to(device=device)
    return tensor


@to_tensor.register(int)
@to_tensor.register(float)
@to_tensor.register(np.integer)
@to_tensor.register(np.floating)
def _(value, *, dtype=torch.float32, device=None, **kwargs) -> torch.Tensor:
    """Handle scalar values including numpy scalars."""
    return torch.tensor(value, dtype=dtype, device=device)


@to_tensor.register(list)
@to_tensor.register(tuple)
def _(value: Sequence, *, dtype=torch.float32, device=None, **kwargs) -> torch.Tensor:
    """Handle sequences (lists, tuples)."""
    return torch.tensor(value, dtype=dtype, device=device)


@to_tensor.register(dict)
def _(value: dict, *, device=None, **kwargs) -> dict:
    """Handle dictionaries by recursively converting values to tensors.

    None values are dropped at every level; nested dicts are processed by the
    same dispatch (this handler), so arbitrary nesting is supported.
    """
    return {
        key: to_tensor(sub_value, device=device, **kwargs)
        for key, sub_value in value.items()
        if sub_value is not None
    }
def _from_tensor(x: Any):
@@ -88,7 +185,7 @@ def to_transition_teleop_action(action: dict[str, Any]) -> EnvTransition:
continue
arr = np.array(v) if np.isscalar(v) else v
act_dict[f"{ACTION}.{k}"] = _to_tensor(arr)
act_dict[f"{ACTION}.{k}"] = to_tensor(arr)
return make_obs_act_transition(act=act_dict)
@@ -103,7 +200,7 @@ def to_transition_robot_observation(observation: dict[str, Any]) -> EnvTransitio
obs_dict: dict[str, Any] = {}
for k, v in state.items():
arr = np.array(v) if np.isscalar(v) else v
obs_dict[f"{OBS_STATE}.{k}"] = _to_tensor(arr)
obs_dict[f"{OBS_STATE}.{k}"] = to_tensor(arr)
for cam, img in images.items():
obs_dict[f"{OBS_IMAGES}.{cam}"] = img
@@ -16,6 +16,7 @@ from dataclasses import dataclass
import numpy as np
import torch
from lerobot.processor.converters import to_tensor
from lerobot.processor.pipeline import ActionProcessor, ProcessorStepRegistry
@@ -59,5 +60,5 @@ class Numpy2TorchActionProcessor(ActionProcessor):
f"Expected np.ndarray or None, got {type(action).__name__}. "
"Use appropriate processor for non-tensor actions."
)
torch_action = torch.from_numpy(action)
torch_action = to_tensor(action, dtype=None) # Preserve original dtype
return torch_action
+5 -36
View File
@@ -4,12 +4,12 @@ from copy import deepcopy
from dataclasses import dataclass, field
from typing import Any
import numpy as np
import torch
from torch import Tensor
from lerobot.configs.types import FeatureType, NormalizationMode, PolicyFeature
from lerobot.datasets.lerobot_dataset import LeRobotDataset
from lerobot.processor.converters import to_tensor
from lerobot.processor.pipeline import (
EnvTransition,
ProcessorStep,
@@ -19,37 +19,6 @@ from lerobot.processor.pipeline import (
)
def _to_tensor(value: Any, device: torch.device | None = None) -> Tensor:
"""Convert common python/numpy/torch types to a torch.float32 tensor.
Always returns float32; preserves device if provided.
"""
if isinstance(value, torch.Tensor):
return value.to(dtype=torch.float32, device=device)
if isinstance(value, np.ndarray):
# ensure contiguous, cast to float32 then convert
return torch.from_numpy(np.ascontiguousarray(value.astype(np.float32))).to(device=device)
if isinstance(value, (int, float)):
return torch.tensor(value, dtype=torch.float32, device=device)
if isinstance(value, (list, tuple)):
return torch.tensor(value, dtype=torch.float32, device=device)
raise TypeError(f"Unsupported type for stats value: {type(value)}")
def _convert_stats_to_tensors(
    stats: dict[str, dict[str, Any]], device: torch.device | None = None
) -> dict[str, dict[str, Tensor]]:
    """Convert numeric stats values to torch tensors, preserving keys.

    Entries whose sub-dict is None are skipped entirely.
    """
    return {
        key: {stat_name: _to_tensor(value, device=device) for stat_name, value in sub.items()}
        for key, sub in (stats or {}).items()
        if sub is not None
    }
@dataclass
class _NormalizationMixin:
"""
@@ -91,12 +60,12 @@ class _NormalizationMixin:
# Convert stats to tensors and move to the target device once during initialization.
self.stats = self.stats or {}
self._tensor_stats = _convert_stats_to_tensors(self.stats, device=self.device)
self._tensor_stats = to_tensor(self.stats, device=self.device)
def to(self, device: torch.device | str) -> _NormalizationMixin:
"""Moves the processor's normalization stats to the specified device and returns self."""
self.device = device
self._tensor_stats = _convert_stats_to_tensors(self.stats, device=self.device)
self._tensor_stats = to_tensor(self.stats, device=self.device)
return self
def state_dict(self) -> dict[str, Tensor]:
@@ -165,7 +134,7 @@ class _NormalizationMixin:
# Move stats to input device if needed
stats_device = next(iter(stats.values())).device
if stats_device != input_device:
stats = _convert_stats_to_tensors({key: self._tensor_stats[key]}, device=input_device)[key]
stats = to_tensor({key: self._tensor_stats[key]}, device=input_device)[key]
if norm_mode == NormalizationMode.MEAN_STD and "mean" in stats and "std" in stats:
mean, std = stats["mean"], stats["std"]
@@ -295,5 +264,5 @@ def hotswap_stats(robot_processor: RobotProcessor, stats: dict[str, dict[str, An
if isinstance(step, _NormalizationMixin):
step.stats = stats
# Re-initialize tensor_stats on the correct device.
step._tensor_stats = _convert_stats_to_tensors(stats, device=step.device)
step._tensor_stats = to_tensor(stats, device=step.device)
return rp
+188 -5
View File
@@ -5,6 +5,7 @@ import torch
from lerobot.processor.converters import (
to_dataset_frame,
to_output_robot_action,
to_tensor,
to_transition_robot_observation,
to_transition_teleop_action,
)
@@ -12,12 +13,12 @@ from lerobot.processor.pipeline import TransitionKey
def test_to_transition_teleop_action_prefix_and_tensor_conversion():
# Scalars, arrays, and "image-like" uint8 arrays are supported
# Scalars, arrays, and uint8 arrays are all converted to tensors
img = np.zeros((8, 12, 3), dtype=np.uint8)
act = {
"ee.x": 0.5, # scalar to torch tensor
"delta": np.array([1.0, 2.0]), # ndarray to torch tensor
"raw_img": img, # uint8 HWC to passthrough ndarray
"raw_img": img, # uint8 HWC to torch tensor
}
tr = to_transition_teleop_action(act)
@@ -29,7 +30,7 @@ def test_to_transition_teleop_action_prefix_and_tensor_conversion():
assert "action.delta" in tr[TransitionKey.ACTION]
assert "action.raw_img" in tr[TransitionKey.ACTION]
# Types: scalars/arrays -> torch tensor; images to np.ndarray
# Types: all values -> torch tensor
assert isinstance(tr[TransitionKey.ACTION]["action.ee.x"], torch.Tensor)
assert tr[TransitionKey.ACTION]["action.ee.x"].item() == pytest.approx(0.5)
@@ -37,8 +38,8 @@ def test_to_transition_teleop_action_prefix_and_tensor_conversion():
assert tr[TransitionKey.ACTION]["action.delta"].shape == (2,)
assert torch.allclose(tr[TransitionKey.ACTION]["action.delta"], torch.tensor([1.0, 2.0]))
assert isinstance(tr[TransitionKey.ACTION]["action.raw_img"], np.ndarray)
assert tr[TransitionKey.ACTION]["action.raw_img"].dtype == np.uint8
assert isinstance(tr[TransitionKey.ACTION]["action.raw_img"], torch.Tensor)
assert tr[TransitionKey.ACTION]["action.raw_img"].dtype == torch.float32 # converted from uint8
assert tr[TransitionKey.ACTION]["action.raw_img"].shape == (8, 12, 3)
# Observation is created as empty dict by make_transition
@@ -194,3 +195,185 @@ def test_to_dataset_frame_merge_and_pack_vectors_and_metadata():
# Complementary data
assert batch["frame_is_pad"] is True
assert batch["task"] == "Pick cube"
# Tests for the unified to_tensor function
def test_to_tensor_numpy_arrays():
    """Test to_tensor with various numpy arrays."""
    expected = torch.tensor([1.0, 2.0, 3.0])

    # A plain float array converts directly.
    out = to_tensor(np.array([1.0, 2.0, 3.0]))
    assert isinstance(out, torch.Tensor)
    assert out.dtype == torch.float32
    assert torch.allclose(out, expected)

    # Integer arrays are cast to float32 by default.
    out = to_tensor(np.array([1, 2, 3], dtype=np.int64))
    assert isinstance(out, torch.Tensor)
    assert out.dtype == torch.float32
    assert torch.allclose(out, expected)

    # uint8 arrays (previously "preserved") are converted as well.
    out = to_tensor(np.array([100, 150, 200], dtype=np.uint8))
    assert isinstance(out, torch.Tensor)
    assert out.dtype == torch.float32
    assert torch.allclose(out, torch.tensor([100.0, 150.0, 200.0]))
def test_to_tensor_numpy_scalars():
    """Test to_tensor with numpy scalars (0-dimensional values)."""
    # Both float and int numpy scalars become 0-dim float32 tensors.
    for scalar, expected in ((np.float32(3.14), 3.14), (np.int32(42), 42.0)):
        out = to_tensor(scalar)
        assert isinstance(out, torch.Tensor)
        assert out.ndim == 0  # scalar input yields a 0-dimensional tensor
        assert out.dtype == torch.float32
        assert out.item() == pytest.approx(expected)
def test_to_tensor_python_scalars():
    """Test to_tensor with plain Python ints and floats."""
    for value in (42, 3.14):
        out = to_tensor(value)
        assert isinstance(out, torch.Tensor)
        assert out.dtype == torch.float32
        assert out.item() == pytest.approx(float(value))
def test_to_tensor_sequences():
    """Test to_tensor with lists and tuples."""
    cases = [
        ([1, 2, 3], [1.0, 2.0, 3.0]),  # list of ints
        ((4.5, 5.5, 6.5), [4.5, 5.5, 6.5]),  # tuple of floats
    ]
    for seq, expected in cases:
        out = to_tensor(seq)
        assert isinstance(out, torch.Tensor)
        assert out.dtype == torch.float32
        assert torch.allclose(out, torch.tensor(expected))
def test_to_tensor_existing_tensors():
    """Test to_tensor with tensors that already exist."""
    # A float32 tensor passes through with its values intact.
    src = torch.tensor([1.0, 2.0, 3.0], dtype=torch.float32)
    out = to_tensor(src)
    assert isinstance(out, torch.Tensor)
    assert out.dtype == torch.float32
    assert torch.allclose(out, src)

    # Non-float dtypes are cast to float32.
    out = to_tensor(torch.tensor([1, 2, 3], dtype=torch.int64))
    assert isinstance(out, torch.Tensor)
    assert out.dtype == torch.float32
    assert torch.allclose(out, torch.tensor([1.0, 2.0, 3.0]))
def test_to_tensor_dictionaries():
    """Test to_tensor with flat and nested dictionaries."""
    # Flat dictionary: every value becomes a tensor.
    flat = {"mean": [0.1, 0.2], "std": np.array([1.0, 2.0]), "count": 42}
    out = to_tensor(flat)
    assert isinstance(out, dict)
    for name in ("mean", "std", "count"):
        assert isinstance(out[name], torch.Tensor)
    assert torch.allclose(out["mean"], torch.tensor([0.1, 0.2]))
    assert torch.allclose(out["std"], torch.tensor([1.0, 2.0]))
    assert out["count"].item() == pytest.approx(42.0)

    # Nested dictionary: sub-dicts are converted recursively.
    nested = {
        "action": {"mean": [0.1, 0.2], "std": [1.0, 2.0]},
        "observation": {"mean": np.array([0.5, 0.6]), "count": 10},
    }
    out = to_tensor(nested)
    assert isinstance(out, dict)
    assert isinstance(out["action"], dict)
    assert isinstance(out["observation"], dict)
    assert isinstance(out["action"]["mean"], torch.Tensor)
    assert isinstance(out["observation"]["mean"], torch.Tensor)
    assert torch.allclose(out["action"]["mean"], torch.tensor([0.1, 0.2]))
    assert torch.allclose(out["observation"]["mean"], torch.tensor([0.5, 0.6]))
def test_to_tensor_none_filtering():
    """None entries are dropped from dictionaries at every nesting level."""
    payload = {"valid": [1, 2, 3], "none_value": None, "nested": {"valid": [4, 5], "also_none": None}}
    out = to_tensor(payload)
    assert "none_value" not in out
    assert "also_none" not in out["nested"]
    assert "valid" in out
    assert "valid" in out["nested"]
    assert torch.allclose(out["valid"], torch.tensor([1.0, 2.0, 3.0]))
def test_to_tensor_dtype_parameter():
    """The dtype argument controls (or preserves) the output dtype."""
    arr = np.array([1, 2, 3])

    # Default and explicit float32 behave identically.
    assert to_tensor(arr).dtype == torch.float32
    assert to_tensor(arr, dtype=torch.float32).dtype == torch.float32

    # An explicit wider dtype is honored.
    assert to_tensor(arr, dtype=torch.float64).dtype == torch.float64

    # dtype=None keeps the source dtype.
    src = np.array([1.0, 2.0, 3.0], dtype=np.float64)
    assert to_tensor(src, dtype=None).dtype == torch.float64
def test_to_tensor_device_parameter():
    """The device argument places the result on the requested device."""
    data = np.array([1.0, 2.0, 3.0])

    # CPU is always available.
    assert to_tensor(data, device="cpu").device.type == "cpu"

    # CUDA branch is exercised only where a GPU is present.
    if torch.cuda.is_available():
        assert to_tensor(data, device="cuda").device.type == "cuda"
def test_to_tensor_empty_dict():
    """An empty dictionary converts to an empty dictionary."""
    out = to_tensor({})
    assert isinstance(out, dict)
    assert not out
def test_to_tensor_unsupported_type():
    """Unregistered types fall through to the dispatch base and raise TypeError."""
    for bad in ("unsupported_string", object()):
        with pytest.raises(TypeError, match="Unsupported type for tensor conversion"):
            to_tensor(bad)
+8 -8
View File
@@ -20,10 +20,10 @@ import pytest
import torch
from lerobot.configs.types import FeatureType, NormalizationMode, PolicyFeature
from lerobot.processor.converters import to_tensor
from lerobot.processor.normalize_processor import (
NormalizerProcessor,
UnnormalizerProcessor,
_convert_stats_to_tensors,
hotswap_stats,
)
from lerobot.processor.pipeline import IdentityProcessor, RobotProcessor, TransitionKey
@@ -51,7 +51,7 @@ def test_numpy_conversion():
"std": np.array([0.2, 0.2, 0.2]),
}
}
tensor_stats = _convert_stats_to_tensors(stats)
tensor_stats = to_tensor(stats)
assert isinstance(tensor_stats["observation.image"]["mean"], torch.Tensor)
assert isinstance(tensor_stats["observation.image"]["std"], torch.Tensor)
@@ -66,7 +66,7 @@ def test_tensor_conversion():
"std": torch.tensor([1.0, 1.0]),
}
}
tensor_stats = _convert_stats_to_tensors(stats)
tensor_stats = to_tensor(stats)
assert tensor_stats["action"]["mean"].dtype == torch.float32
assert tensor_stats["action"]["std"].dtype == torch.float32
@@ -79,7 +79,7 @@ def test_scalar_conversion():
"std": 0.1,
}
}
tensor_stats = _convert_stats_to_tensors(stats)
tensor_stats = to_tensor(stats)
assert torch.allclose(tensor_stats["reward"]["mean"], torch.tensor(0.5))
assert torch.allclose(tensor_stats["reward"]["std"], torch.tensor(0.1))
@@ -92,7 +92,7 @@ def test_list_conversion():
"max": [1.0, 1.0, 2.0],
}
}
tensor_stats = _convert_stats_to_tensors(stats)
tensor_stats = to_tensor(stats)
assert torch.allclose(tensor_stats["observation.state"]["min"], torch.tensor([0.0, -1.0, -2.0]))
assert torch.allclose(tensor_stats["observation.state"]["max"], torch.tensor([1.0, 1.0, 2.0]))
@@ -105,7 +105,7 @@ def test_unsupported_type():
}
}
with pytest.raises(TypeError, match="Unsupported type"):
_convert_stats_to_tensors(stats)
to_tensor(stats)
# Helper functions to create feature maps and norm maps
@@ -1017,7 +1017,7 @@ def test_hotswap_stats_basic_functionality():
assert new_processor.steps[1].stats == new_stats
# Check that tensor stats are updated correctly
expected_tensor_stats = _convert_stats_to_tensors(new_stats)
expected_tensor_stats = to_tensor(new_stats)
for key in expected_tensor_stats:
for stat_name in expected_tensor_stats[key]:
torch.testing.assert_close(
@@ -1223,7 +1223,7 @@ def test_hotswap_stats_multiple_normalizer_types():
assert step.stats == new_stats
# Check tensor stats conversion
expected_tensor_stats = _convert_stats_to_tensors(new_stats)
expected_tensor_stats = to_tensor(new_stats)
for key in expected_tensor_stats:
for stat_name in expected_tensor_stats[key]:
torch.testing.assert_close(