refactor(converters): implement unified tensor conversion function (#1830)

- Introduced `to_tensor` function using `singledispatch` to handle various input types, including scalars, arrays, and dictionaries, converting them to PyTorch tensors.
- Replaced previous tensor conversion logic in `gym_action_processor`, `normalize_processor`, and `test_converters` with the new `to_tensor` function for improved readability and maintainability.
- Updated tests to cover new functionality and ensure correct tensor conversion behavior.
This commit is contained in:
Adil Zouitine
2025-09-02 13:28:26 +02:00
committed by GitHub
parent d32b76cc66
commit a837685bf8
5 changed files with 313 additions and 63 deletions
+8 -8
View File
@@ -20,10 +20,10 @@ import pytest
import torch
from lerobot.configs.types import FeatureType, NormalizationMode, PolicyFeature
from lerobot.processor.converters import to_tensor
from lerobot.processor.normalize_processor import (
NormalizerProcessor,
UnnormalizerProcessor,
_convert_stats_to_tensors,
hotswap_stats,
)
from lerobot.processor.pipeline import IdentityProcessor, RobotProcessor, TransitionKey
@@ -51,7 +51,7 @@ def test_numpy_conversion():
"std": np.array([0.2, 0.2, 0.2]),
}
}
tensor_stats = _convert_stats_to_tensors(stats)
tensor_stats = to_tensor(stats)
assert isinstance(tensor_stats["observation.image"]["mean"], torch.Tensor)
assert isinstance(tensor_stats["observation.image"]["std"], torch.Tensor)
@@ -66,7 +66,7 @@ def test_tensor_conversion():
"std": torch.tensor([1.0, 1.0]),
}
}
tensor_stats = _convert_stats_to_tensors(stats)
tensor_stats = to_tensor(stats)
assert tensor_stats["action"]["mean"].dtype == torch.float32
assert tensor_stats["action"]["std"].dtype == torch.float32
@@ -79,7 +79,7 @@ def test_scalar_conversion():
"std": 0.1,
}
}
tensor_stats = _convert_stats_to_tensors(stats)
tensor_stats = to_tensor(stats)
assert torch.allclose(tensor_stats["reward"]["mean"], torch.tensor(0.5))
assert torch.allclose(tensor_stats["reward"]["std"], torch.tensor(0.1))
@@ -92,7 +92,7 @@ def test_list_conversion():
"max": [1.0, 1.0, 2.0],
}
}
tensor_stats = _convert_stats_to_tensors(stats)
tensor_stats = to_tensor(stats)
assert torch.allclose(tensor_stats["observation.state"]["min"], torch.tensor([0.0, -1.0, -2.0]))
assert torch.allclose(tensor_stats["observation.state"]["max"], torch.tensor([1.0, 1.0, 2.0]))
@@ -105,7 +105,7 @@ def test_unsupported_type():
}
}
with pytest.raises(TypeError, match="Unsupported type"):
_convert_stats_to_tensors(stats)
to_tensor(stats)
# Helper functions to create feature maps and norm maps
@@ -1017,7 +1017,7 @@ def test_hotswap_stats_basic_functionality():
assert new_processor.steps[1].stats == new_stats
# Check that tensor stats are updated correctly
expected_tensor_stats = _convert_stats_to_tensors(new_stats)
expected_tensor_stats = to_tensor(new_stats)
for key in expected_tensor_stats:
for stat_name in expected_tensor_stats[key]:
torch.testing.assert_close(
@@ -1223,7 +1223,7 @@ def test_hotswap_stats_multiple_normalizer_types():
assert step.stats == new_stats
# Check tensor stats conversion
expected_tensor_stats = _convert_stats_to_tensors(new_stats)
expected_tensor_stats = to_tensor(new_stats)
for key in expected_tensor_stats:
for stat_name in expected_tensor_stats[key]:
torch.testing.assert_close(