diff --git a/src/lerobot/processor/converters.py b/src/lerobot/processor/converters.py
index 2551b2b6b..566b67402 100644
--- a/src/lerobot/processor/converters.py
+++ b/src/lerobot/processor/converters.py
@@ -139,7 +139,7 @@ def _(value: dict, *, device=None, **kwargs) -> dict:
     return result
 
 
-def _from_tensor(x: torch.Tensor | Any) -> np.ndarray | float | int | Any:
+def from_tensor_to_numpy(x: torch.Tensor | Any) -> np.ndarray | float | int | Any:
     """
     Convert a PyTorch tensor to a numpy array or scalar if applicable.
 
@@ -421,17 +421,17 @@ def transition_to_dataset_frame(
 
     # Create observation.state vector.
     if obs_state_names:
-        vals = [_from_tensor(obs.get(f"{OBS_STATE}.{n}", 0.0)) for n in obs_state_names]
+        vals = [from_tensor_to_numpy(obs.get(f"{OBS_STATE}.{n}", 0.0)) for n in obs_state_names]
         batch[OBS_STATE] = np.asarray(vals, dtype=np.float32)
 
     # Create action vector.
     if action_names:
-        vals = [_from_tensor(act.get(f"{ACTION}.{n}", 0.0)) for n in action_names]
+        vals = [from_tensor_to_numpy(act.get(f"{ACTION}.{n}", 0.0)) for n in action_names]
         batch[ACTION] = np.asarray(vals, dtype=np.float32)
 
     # Add transition metadata.
     if tr.get(TransitionKey.REWARD) is not None:
-        reward_val = _from_tensor(tr[TransitionKey.REWARD])
+        reward_val = from_tensor_to_numpy(tr[TransitionKey.REWARD])
         # Check if features expect array format, otherwise keep as scalar.
         if REWARD in features and features[REWARD].get("shape") == (1,):
             batch[REWARD] = np.array([reward_val], dtype=np.float32)
@@ -439,14 +439,14 @@ def transition_to_dataset_frame(
             batch[REWARD] = reward_val
 
     if tr.get(TransitionKey.DONE) is not None:
-        done_val = _from_tensor(tr[TransitionKey.DONE])
+        done_val = from_tensor_to_numpy(tr[TransitionKey.DONE])
         if DONE in features and features[DONE].get("shape") == (1,):
             batch[DONE] = np.array([done_val], dtype=bool)
         else:
             batch[DONE] = done_val
 
     if tr.get(TransitionKey.TRUNCATED) is not None:
-        truncated_val = _from_tensor(tr[TransitionKey.TRUNCATED])
+        truncated_val = from_tensor_to_numpy(tr[TransitionKey.TRUNCATED])
         if TRUNCATED in features and features[TRUNCATED].get("shape") == (1,):
             batch[TRUNCATED] = np.array([truncated_val], dtype=bool)
         else:
diff --git a/src/lerobot/processor/normalize_processor.py b/src/lerobot/processor/normalize_processor.py
index d61b84660..9b502e54d 100644
--- a/src/lerobot/processor/normalize_processor.py
+++ b/src/lerobot/processor/normalize_processor.py
@@ -27,7 +27,7 @@ from torch import Tensor
 from lerobot.configs.types import FeatureType, NormalizationMode, PolicyFeature
 from lerobot.datasets.lerobot_dataset import LeRobotDataset
 
-from .converters import to_tensor
+from .converters import from_tensor_to_numpy, to_tensor
 from .core import EnvTransition, TransitionKey
 from .pipeline import PolicyProcessorPipeline, ProcessorStep, ProcessorStepRegistry
 
@@ -101,7 +101,6 @@ class _NormalizationMixin:
         self.stats = self.stats or {}
         if self.dtype is None:
             self.dtype = torch.float32
-
         self._tensor_stats = to_tensor(self.stats, device=self.device, dtype=self.dtype)
 
     def to(
@@ -158,6 +157,16 @@ class _NormalizationMixin:
                 dtype=torch.float32, device=self.device
             )
 
+        # Reconstruct the original stats dict from tensor stats for compatibility with to() method
+        # and other functions that rely on self.stats
+
+        self.stats = {}
+        for key, tensor_dict in self._tensor_stats.items():
+            self.stats[key] = {}
+            for stat_name, tensor in tensor_dict.items():
+                # Convert tensor back to python/numpy format
+                self.stats[key][stat_name] = from_tensor_to_numpy(tensor)
+
     def get_config(self) -> dict[str, Any]:
         """
         Returns a serializable dictionary of the processor's configuration.
diff --git a/tests/processor/test_normalize_processor.py b/tests/processor/test_normalize_processor.py
index a109392fc..25ad66860 100644
--- a/tests/processor/test_normalize_processor.py
+++ b/tests/processor/test_normalize_processor.py
@@ -1586,3 +1586,116 @@ def test_dtype_adaptation_device_processor_bfloat16_normalizer_float32():
     for stat_tensor in normalizer._tensor_stats["observation.state"].values():
         assert stat_tensor.dtype == torch.bfloat16
         assert stat_tensor.device.type == "cuda"
+
+
+def test_stats_reconstruction_after_load_state_dict():
+    """
+    Test that stats dict is properly reconstructed from _tensor_stats after loading.
+
+    This test ensures the bug where stats became empty after loading is fixed.
+    The bug occurred when:
+    1. Only _tensor_stats were saved via state_dict()
+    2. stats field became empty {} after loading
+    3. Calling to() method or hotswap_stats would fail because they depend on self.stats
+    """
+
+    # Create normalizer with stats
+    features = {
+        "observation.image": PolicyFeature(FeatureType.VISUAL, (3, 96, 96)),
+        "observation.state": PolicyFeature(FeatureType.STATE, (2,)),
+        "action": PolicyFeature(FeatureType.ACTION, (2,)),
+    }
+    norm_map = {
+        FeatureType.VISUAL: NormalizationMode.MEAN_STD,
+        FeatureType.STATE: NormalizationMode.MIN_MAX,
+        FeatureType.ACTION: NormalizationMode.MEAN_STD,
+    }
+    stats = {
+        "observation.image": {
+            "mean": np.array([0.5, 0.5, 0.5]),
+            "std": np.array([0.2, 0.2, 0.2]),
+        },
+        "observation.state": {
+            "min": np.array([0.0, -1.0]),
+            "max": np.array([1.0, 1.0]),
+        },
+        "action": {
+            "mean": np.array([0.0, 0.0]),
+            "std": np.array([1.0, 2.0]),
+        },
+    }
+
+    original_normalizer = NormalizerProcessorStep(features=features, norm_map=norm_map, stats=stats)
+
+    # Save state dict (simulating save/load)
+    state_dict = original_normalizer.state_dict()
+
+    # Create new normalizer with empty stats (simulating load)
+    new_normalizer = NormalizerProcessorStep(features=features, norm_map=norm_map, stats={})
+
+    # Before fix: this would cause stats to remain empty
+    new_normalizer.load_state_dict(state_dict)
+
+    # Verify that stats dict is properly reconstructed from _tensor_stats
+    assert new_normalizer.stats is not None
+    assert new_normalizer.stats != {}
+
+    # Check that all expected keys are present
+    assert "observation.image" in new_normalizer.stats
+    assert "observation.state" in new_normalizer.stats
+    assert "action" in new_normalizer.stats
+
+    # Check that values are correct (converted back from tensors)
+    np.testing.assert_allclose(new_normalizer.stats["observation.image"]["mean"], [0.5, 0.5, 0.5])
+    np.testing.assert_allclose(new_normalizer.stats["observation.image"]["std"], [0.2, 0.2, 0.2])
+    np.testing.assert_allclose(new_normalizer.stats["observation.state"]["min"], [0.0, -1.0])
+    np.testing.assert_allclose(new_normalizer.stats["observation.state"]["max"], [1.0, 1.0])
+    np.testing.assert_allclose(new_normalizer.stats["action"]["mean"], [0.0, 0.0])
+    np.testing.assert_allclose(new_normalizer.stats["action"]["std"], [1.0, 2.0])
+
+    # Test that methods that depend on self.stats work correctly after loading
+    # This would fail before the bug fix because self.stats was empty
+
+    # Test 1: to() method should work without crashing
+    try:
+        new_normalizer.to(device="cpu", dtype=torch.float32)
+        # If we reach here, the bug is fixed
+    except (KeyError, AttributeError) as e:
+        pytest.fail(f"to() method failed after loading state_dict: {e}")
+
+    # Test 2: hotswap_stats should work
+    new_stats = {
+        "observation.image": {"mean": [0.3, 0.3, 0.3], "std": [0.1, 0.1, 0.1]},
+        "observation.state": {"min": [-1.0, -2.0], "max": [2.0, 2.0]},
+        "action": {"mean": [0.1, 0.1], "std": [0.5, 0.5]},
+    }
+
+    pipeline = DataProcessorPipeline([new_normalizer])
+    try:
+        new_pipeline = hotswap_stats(pipeline, new_stats)
+        # If we reach here, hotswap_stats worked correctly
+        assert new_pipeline.steps[0].stats == new_stats
+    except (KeyError, AttributeError) as e:
+        pytest.fail(f"hotswap_stats failed after loading state_dict: {e}")
+
+    # Test 3: The normalizer should work functionally the same as the original
+    observation = {
+        "observation.image": torch.tensor([0.7, 0.5, 0.3]),
+        "observation.state": torch.tensor([0.5, 0.0]),
+    }
+    action = torch.tensor([1.0, -0.5])
+    transition = create_transition(observation=observation, action=action)
+
+    original_result = original_normalizer(transition)
+    new_result = new_normalizer(transition)
+
+    # Results should be identical (within floating point precision)
+    torch.testing.assert_close(
+        original_result[TransitionKey.OBSERVATION]["observation.image"],
+        new_result[TransitionKey.OBSERVATION]["observation.image"],
+    )
+    torch.testing.assert_close(
+        original_result[TransitionKey.OBSERVATION]["observation.state"],
+        new_result[TransitionKey.OBSERVATION]["observation.state"],
+    )
+    torch.testing.assert_close(original_result[TransitionKey.ACTION], new_result[TransitionKey.ACTION])