mirror of
https://github.com/huggingface/lerobot.git
synced 2026-05-15 16:49:55 +00:00
refactor(converters): implement unified tensor conversion function (#1830)
- Introduced `to_tensor` function using `singledispatch` to handle various input types, including scalars, arrays, and dictionaries, converting them to PyTorch tensors. - Replaced previous tensor conversion logic in `gym_action_processor`, `normalize_processor`, and `test_converters` with the new `to_tensor` function for improved readability and maintainability. - Updated tests to cover new functionality and ensure correct tensor conversion behavior.
This commit is contained in:
@@ -20,10 +20,10 @@ import pytest
|
||||
import torch
|
||||
|
||||
from lerobot.configs.types import FeatureType, NormalizationMode, PolicyFeature
|
||||
from lerobot.processor.converters import to_tensor
|
||||
from lerobot.processor.normalize_processor import (
|
||||
NormalizerProcessor,
|
||||
UnnormalizerProcessor,
|
||||
_convert_stats_to_tensors,
|
||||
hotswap_stats,
|
||||
)
|
||||
from lerobot.processor.pipeline import IdentityProcessor, RobotProcessor, TransitionKey
|
||||
@@ -51,7 +51,7 @@ def test_numpy_conversion():
|
||||
"std": np.array([0.2, 0.2, 0.2]),
|
||||
}
|
||||
}
|
||||
tensor_stats = _convert_stats_to_tensors(stats)
|
||||
tensor_stats = to_tensor(stats)
|
||||
|
||||
assert isinstance(tensor_stats["observation.image"]["mean"], torch.Tensor)
|
||||
assert isinstance(tensor_stats["observation.image"]["std"], torch.Tensor)
|
||||
@@ -66,7 +66,7 @@ def test_tensor_conversion():
|
||||
"std": torch.tensor([1.0, 1.0]),
|
||||
}
|
||||
}
|
||||
tensor_stats = _convert_stats_to_tensors(stats)
|
||||
tensor_stats = to_tensor(stats)
|
||||
|
||||
assert tensor_stats["action"]["mean"].dtype == torch.float32
|
||||
assert tensor_stats["action"]["std"].dtype == torch.float32
|
||||
@@ -79,7 +79,7 @@ def test_scalar_conversion():
|
||||
"std": 0.1,
|
||||
}
|
||||
}
|
||||
tensor_stats = _convert_stats_to_tensors(stats)
|
||||
tensor_stats = to_tensor(stats)
|
||||
|
||||
assert torch.allclose(tensor_stats["reward"]["mean"], torch.tensor(0.5))
|
||||
assert torch.allclose(tensor_stats["reward"]["std"], torch.tensor(0.1))
|
||||
@@ -92,7 +92,7 @@ def test_list_conversion():
|
||||
"max": [1.0, 1.0, 2.0],
|
||||
}
|
||||
}
|
||||
tensor_stats = _convert_stats_to_tensors(stats)
|
||||
tensor_stats = to_tensor(stats)
|
||||
|
||||
assert torch.allclose(tensor_stats["observation.state"]["min"], torch.tensor([0.0, -1.0, -2.0]))
|
||||
assert torch.allclose(tensor_stats["observation.state"]["max"], torch.tensor([1.0, 1.0, 2.0]))
|
||||
@@ -105,7 +105,7 @@ def test_unsupported_type():
|
||||
}
|
||||
}
|
||||
with pytest.raises(TypeError, match="Unsupported type"):
|
||||
_convert_stats_to_tensors(stats)
|
||||
to_tensor(stats)
|
||||
|
||||
|
||||
# Helper functions to create feature maps and norm maps
|
||||
@@ -1017,7 +1017,7 @@ def test_hotswap_stats_basic_functionality():
|
||||
assert new_processor.steps[1].stats == new_stats
|
||||
|
||||
# Check that tensor stats are updated correctly
|
||||
expected_tensor_stats = _convert_stats_to_tensors(new_stats)
|
||||
expected_tensor_stats = to_tensor(new_stats)
|
||||
for key in expected_tensor_stats:
|
||||
for stat_name in expected_tensor_stats[key]:
|
||||
torch.testing.assert_close(
|
||||
@@ -1223,7 +1223,7 @@ def test_hotswap_stats_multiple_normalizer_types():
|
||||
assert step.stats == new_stats
|
||||
|
||||
# Check tensor stats conversion
|
||||
expected_tensor_stats = _convert_stats_to_tensors(new_stats)
|
||||
expected_tensor_stats = to_tensor(new_stats)
|
||||
for key in expected_tensor_stats:
|
||||
for stat_name in expected_tensor_stats[key]:
|
||||
torch.testing.assert_close(
|
||||
|
||||
Reference in New Issue
Block a user