Revert "refactor(converters): implement unified tensor conversion function (#…" (#1840)

This reverts commit a837685bf8.
This commit is contained in:
Steven Palma
2025-09-02 13:43:35 +02:00
committed by GitHub
parent a837685bf8
commit 15ffc01fb3
5 changed files with 63 additions and 313 deletions
+8 -8
View File
@@ -20,10 +20,10 @@ import pytest
import torch
from lerobot.configs.types import FeatureType, NormalizationMode, PolicyFeature
from lerobot.processor.converters import to_tensor
from lerobot.processor.normalize_processor import (
NormalizerProcessor,
UnnormalizerProcessor,
_convert_stats_to_tensors,
hotswap_stats,
)
from lerobot.processor.pipeline import IdentityProcessor, RobotProcessor, TransitionKey
@@ -51,7 +51,7 @@ def test_numpy_conversion():
"std": np.array([0.2, 0.2, 0.2]),
}
}
tensor_stats = to_tensor(stats)
tensor_stats = _convert_stats_to_tensors(stats)
assert isinstance(tensor_stats["observation.image"]["mean"], torch.Tensor)
assert isinstance(tensor_stats["observation.image"]["std"], torch.Tensor)
@@ -66,7 +66,7 @@ def test_tensor_conversion():
"std": torch.tensor([1.0, 1.0]),
}
}
tensor_stats = to_tensor(stats)
tensor_stats = _convert_stats_to_tensors(stats)
assert tensor_stats["action"]["mean"].dtype == torch.float32
assert tensor_stats["action"]["std"].dtype == torch.float32
@@ -79,7 +79,7 @@ def test_scalar_conversion():
"std": 0.1,
}
}
tensor_stats = to_tensor(stats)
tensor_stats = _convert_stats_to_tensors(stats)
assert torch.allclose(tensor_stats["reward"]["mean"], torch.tensor(0.5))
assert torch.allclose(tensor_stats["reward"]["std"], torch.tensor(0.1))
@@ -92,7 +92,7 @@ def test_list_conversion():
"max": [1.0, 1.0, 2.0],
}
}
tensor_stats = to_tensor(stats)
tensor_stats = _convert_stats_to_tensors(stats)
assert torch.allclose(tensor_stats["observation.state"]["min"], torch.tensor([0.0, -1.0, -2.0]))
assert torch.allclose(tensor_stats["observation.state"]["max"], torch.tensor([1.0, 1.0, 2.0]))
@@ -105,7 +105,7 @@ def test_unsupported_type():
}
}
with pytest.raises(TypeError, match="Unsupported type"):
to_tensor(stats)
_convert_stats_to_tensors(stats)
# Helper functions to create feature maps and norm maps
@@ -1017,7 +1017,7 @@ def test_hotswap_stats_basic_functionality():
assert new_processor.steps[1].stats == new_stats
# Check that tensor stats are updated correctly
expected_tensor_stats = to_tensor(new_stats)
expected_tensor_stats = _convert_stats_to_tensors(new_stats)
for key in expected_tensor_stats:
for stat_name in expected_tensor_stats[key]:
torch.testing.assert_close(
@@ -1223,7 +1223,7 @@ def test_hotswap_stats_multiple_normalizer_types():
assert step.stats == new_stats
# Check tensor stats conversion
expected_tensor_stats = to_tensor(new_stats)
expected_tensor_stats = _convert_stats_to_tensors(new_stats)
for key in expected_tensor_stats:
for stat_name in expected_tensor_stats[key]:
torch.testing.assert_close(