refactor(converters): rename _from_tensor to from_tensor_to_numpy for clarity (#1902)

- Updated the function name from _from_tensor to from_tensor_to_numpy to better reflect its purpose of converting PyTorch tensors to numpy arrays or scalars.
- Adjusted all references to the renamed function throughout the codebase to maintain consistency.
- Enhanced the _NormalizationMixin class to rebuild the stats dictionary from the loaded tensor stats using the renamed function, so that to() and other code relying on self.stats keep working after a state dict is loaded.
- Added tests to verify the correct reconstruction of stats and functionality of methods dependent on self.stats after loading.
This commit is contained in:
Adil Zouitine
2025-09-09 17:51:47 +02:00
committed by GitHub
parent a74b90edd1
commit acf0ba7fb3
3 changed files with 130 additions and 8 deletions
+6 -6
View File
@@ -139,7 +139,7 @@ def _(value: dict, *, device=None, **kwargs) -> dict:
return result
def _from_tensor(x: torch.Tensor | Any) -> np.ndarray | float | int | Any:
def from_tensor_to_numpy(x: torch.Tensor | Any) -> np.ndarray | float | int | Any:
"""
Convert a PyTorch tensor to a numpy array or scalar if applicable.
@@ -421,17 +421,17 @@ def transition_to_dataset_frame(
# Create observation.state vector.
if obs_state_names:
vals = [_from_tensor(obs.get(f"{OBS_STATE}.{n}", 0.0)) for n in obs_state_names]
vals = [from_tensor_to_numpy(obs.get(f"{OBS_STATE}.{n}", 0.0)) for n in obs_state_names]
batch[OBS_STATE] = np.asarray(vals, dtype=np.float32)
# Create action vector.
if action_names:
vals = [_from_tensor(act.get(f"{ACTION}.{n}", 0.0)) for n in action_names]
vals = [from_tensor_to_numpy(act.get(f"{ACTION}.{n}", 0.0)) for n in action_names]
batch[ACTION] = np.asarray(vals, dtype=np.float32)
# Add transition metadata.
if tr.get(TransitionKey.REWARD) is not None:
reward_val = _from_tensor(tr[TransitionKey.REWARD])
reward_val = from_tensor_to_numpy(tr[TransitionKey.REWARD])
# Check if features expect array format, otherwise keep as scalar.
if REWARD in features and features[REWARD].get("shape") == (1,):
batch[REWARD] = np.array([reward_val], dtype=np.float32)
@@ -439,14 +439,14 @@ def transition_to_dataset_frame(
batch[REWARD] = reward_val
if tr.get(TransitionKey.DONE) is not None:
done_val = _from_tensor(tr[TransitionKey.DONE])
done_val = from_tensor_to_numpy(tr[TransitionKey.DONE])
if DONE in features and features[DONE].get("shape") == (1,):
batch[DONE] = np.array([done_val], dtype=bool)
else:
batch[DONE] = done_val
if tr.get(TransitionKey.TRUNCATED) is not None:
truncated_val = _from_tensor(tr[TransitionKey.TRUNCATED])
truncated_val = from_tensor_to_numpy(tr[TransitionKey.TRUNCATED])
if TRUNCATED in features and features[TRUNCATED].get("shape") == (1,):
batch[TRUNCATED] = np.array([truncated_val], dtype=bool)
else:
+11 -2
View File
@@ -27,7 +27,7 @@ from torch import Tensor
from lerobot.configs.types import FeatureType, NormalizationMode, PolicyFeature
from lerobot.datasets.lerobot_dataset import LeRobotDataset
from .converters import to_tensor
from .converters import from_tensor_to_numpy, to_tensor
from .core import EnvTransition, TransitionKey
from .pipeline import PolicyProcessorPipeline, ProcessorStep, ProcessorStepRegistry
@@ -101,7 +101,6 @@ class _NormalizationMixin:
self.stats = self.stats or {}
if self.dtype is None:
self.dtype = torch.float32
self._tensor_stats = to_tensor(self.stats, device=self.device, dtype=self.dtype)
def to(
@@ -158,6 +157,16 @@ class _NormalizationMixin:
dtype=torch.float32, device=self.device
)
# Reconstruct the original stats dict from tensor stats for compatibility with to() method
# and other functions that rely on self.stats
self.stats = {}
for key, tensor_dict in self._tensor_stats.items():
self.stats[key] = {}
for stat_name, tensor in tensor_dict.items():
# Convert tensor back to python/numpy format
self.stats[key][stat_name] = from_tensor_to_numpy(tensor)
def get_config(self) -> dict[str, Any]:
"""
Returns a serializable dictionary of the processor's configuration.
+113
View File
@@ -1586,3 +1586,116 @@ def test_dtype_adaptation_device_processor_bfloat16_normalizer_float32():
for stat_tensor in normalizer._tensor_stats["observation.state"].values():
assert stat_tensor.dtype == torch.bfloat16
assert stat_tensor.device.type == "cuda"
def test_stats_reconstruction_after_load_state_dict():
    """
    Check that ``stats`` is rebuilt from ``_tensor_stats`` by ``load_state_dict``.

    Regression scenario covered here:
    1. ``state_dict()`` only carries the tensor stats.
    2. A normalizer created with ``stats={}`` loads that state dict.
    3. ``to()`` and ``hotswap_stats`` read ``self.stats``, so they must find
       the reconstructed values rather than an empty dict.
    """
    # Reference values, kept as plain lists so they can double as the expected
    # side of the comparisons further down.
    expected = {
        "observation.image": {"mean": [0.5, 0.5, 0.5], "std": [0.2, 0.2, 0.2]},
        "observation.state": {"min": [0.0, -1.0], "max": [1.0, 1.0]},
        "action": {"mean": [0.0, 0.0], "std": [1.0, 2.0]},
    }
    stats = {
        key: {stat_name: np.array(values) for stat_name, values in per_key.items()}
        for key, per_key in expected.items()
    }

    features = {
        "observation.image": PolicyFeature(FeatureType.VISUAL, (3, 96, 96)),
        "observation.state": PolicyFeature(FeatureType.STATE, (2,)),
        "action": PolicyFeature(FeatureType.ACTION, (2,)),
    }
    norm_map = {
        FeatureType.VISUAL: NormalizationMode.MEAN_STD,
        FeatureType.STATE: NormalizationMode.MIN_MAX,
        FeatureType.ACTION: NormalizationMode.MEAN_STD,
    }

    # Round trip: export from a populated normalizer, import into an empty one.
    original_normalizer = NormalizerProcessorStep(features=features, norm_map=norm_map, stats=stats)
    state_dict = original_normalizer.state_dict()
    new_normalizer = NormalizerProcessorStep(features=features, norm_map=norm_map, stats={})
    new_normalizer.load_state_dict(state_dict)

    # The stats dict must exist and be non-empty after loading.
    assert new_normalizer.stats is not None
    assert new_normalizer.stats != {}

    # Every feature key must be back ...
    for key in expected:
        assert key in new_normalizer.stats

    # ... and each statistic must match the value it was created with.
    for key, per_key in expected.items():
        for stat_name, values in per_key.items():
            np.testing.assert_allclose(new_normalizer.stats[key][stat_name], values)

    # Check 1: to() depends on self.stats and must not raise after loading.
    try:
        new_normalizer.to(device="cpu", dtype=torch.float32)
    except (KeyError, AttributeError) as e:
        pytest.fail(f"to() method failed after loading state_dict: {e}")

    # Check 2: hotswap_stats must accept the loaded normalizer.
    new_stats = {
        "observation.image": {"mean": [0.3, 0.3, 0.3], "std": [0.1, 0.1, 0.1]},
        "observation.state": {"min": [-1.0, -2.0], "max": [2.0, 2.0]},
        "action": {"mean": [0.1, 0.1], "std": [0.5, 0.5]},
    }
    pipeline = DataProcessorPipeline([new_normalizer])
    try:
        new_pipeline = hotswap_stats(pipeline, new_stats)
        assert new_pipeline.steps[0].stats == new_stats
    except (KeyError, AttributeError) as e:
        pytest.fail(f"hotswap_stats failed after loading state_dict: {e}")

    # Check 3: the loaded normalizer must produce the same output as the original.
    transition = create_transition(
        observation={
            "observation.image": torch.tensor([0.7, 0.5, 0.3]),
            "observation.state": torch.tensor([0.5, 0.0]),
        },
        action=torch.tensor([1.0, -0.5]),
    )
    original_result = original_normalizer(transition)
    new_result = new_normalizer(transition)

    for obs_key in ("observation.image", "observation.state"):
        torch.testing.assert_close(
            original_result[TransitionKey.OBSERVATION][obs_key],
            new_result[TransitionKey.OBSERVATION][obs_key],
        )
    torch.testing.assert_close(original_result[TransitionKey.ACTION], new_result[TransitionKey.ACTION])