From d32006440c13e47fc22cd60deddf3679f415ec99 Mon Sep 17 00:00:00 2001
From: Adil Zouitine
Date: Mon, 8 Sep 2025 10:46:35 +0200
Subject: [PATCH] refactor(processors): Improve Normalization Processor
 Performance and Device/Dtype Adaptability (#1880)

* refactor(processors): reorder processor steps for consistency across
  implementations

- Updated the order of processor steps in multiple files to ensure
  consistency, placing AddBatchDimensionProcessorStep and
  DeviceProcessorStep before NormalizerProcessorStep.
- Adjusted related test assertions to reflect the new order of steps in
  the preprocessor, enhancing clarity and maintainability.

* refactor(normalization): remove dtype specification in tensor
  conversion for adaptation logic

- Updated tensor conversion in the _NormalizationMixin class to remove
  explicit dtype specification, allowing for automatic adaptation of
  tensor types.
- Adjusted related tests to ensure proper functionality with the new
  tensor conversion logic, verifying that normalizers adapt correctly
  to input types.
---
 src/lerobot/policies/act/processor_act.py     |  5 +-
 .../policies/diffusion/processor_diffusion.py |  4 +-
 src/lerobot/policies/pi0/processor_pi0.py     | 10 +--
 .../policies/pi0fast/processor_pi0fast.py     |  4 +-
 src/lerobot/policies/sac/processor_sac.py     |  4 +-
 .../policies/smolvla/processor_smolvla.py     | 10 +--
 src/lerobot/policies/tdmpc/processor_tdmpc.py |  4 +-
 src/lerobot/policies/vqbet/processor_vqbet.py |  4 +-
 src/lerobot/processor/normalize_processor.py  | 43 +++++----
 tests/processor/test_act_processor.py         | 73 ++++++++++++++-
 tests/processor/test_diffusion_processor.py   | 80 ++++++++++++++++-
 tests/processor/test_normalize_processor.py   | 89 ++++++++++++++++++
 tests/processor/test_pi0_processor.py         | 78 ++++++++++++++--
 tests/processor/test_sac_processor.py         | 82 +++++++++++++++--
 tests/processor/test_smolvla_processor.py     | 80 +++++++++++++++--
 tests/processor/test_tdmpc_processor.py       | 89 ++++++++++++++++--
 tests/processor/test_vqbet_processor.py       | 90 +++++++++++++++++--
 17 files changed, 677 insertions(+), 72 deletions(-)

diff --git a/src/lerobot/policies/act/processor_act.py b/src/lerobot/policies/act/processor_act.py
index 698740ce8..aec922839 100644
--- a/src/lerobot/policies/act/processor_act.py
+++ b/src/lerobot/policies/act/processor_act.py
@@ -41,13 +41,14 @@ def make_act_pre_post_processors(
 
     input_steps = [
         RenameProcessorStep(rename_map={}),
+        AddBatchDimensionProcessorStep(),
+        DeviceProcessorStep(device=config.device),
         NormalizerProcessorStep(
             features={**config.input_features, **config.output_features},
             norm_map=config.normalization_mapping,
             stats=dataset_stats,
+            device=config.device,
         ),
-        AddBatchDimensionProcessorStep(),
-        DeviceProcessorStep(device=config.device),
     ]
     output_steps = [
         DeviceProcessorStep(device="cpu"),
diff --git a/src/lerobot/policies/diffusion/processor_diffusion.py b/src/lerobot/policies/diffusion/processor_diffusion.py
index 9914cd0c1..2d7868b25 100644
--- a/src/lerobot/policies/diffusion/processor_diffusion.py
+++ b/src/lerobot/policies/diffusion/processor_diffusion.py
@@ -42,13 +42,13 @@ def make_diffusion_pre_post_processors(
 
     input_steps = [
         RenameProcessorStep(rename_map={}),
+        AddBatchDimensionProcessorStep(),
+        DeviceProcessorStep(device=config.device),
         NormalizerProcessorStep(
             features={**config.input_features, **config.output_features},
             norm_map=config.normalization_mapping,
             stats=dataset_stats,
         ),
-        AddBatchDimensionProcessorStep(),
-        DeviceProcessorStep(device=config.device),
     ]
     output_steps = [
         DeviceProcessorStep(device="cpu"),
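For illustration, a minimal sketch of the canonical step order the act and diffusion hunks above establish: batch first, then device placement, then normalization, so the normalizer receives tensors that are already batched and already on the accelerator. It assumes a lerobot install; the step classes are imported from lerobot.processor as in the tests later in this patch, while the config-type import path is an assumption:

    import numpy as np
    from lerobot.configs.types import FeatureType, NormalizationMode, PolicyFeature  # assumed path
    from lerobot.processor import (
        AddBatchDimensionProcessorStep,
        DeviceProcessorStep,
        NormalizerProcessorStep,
    )

    # Feature/stat definitions mirror the ones used in test_normalize_processor.py below.
    features = {"observation.state": PolicyFeature(FeatureType.STATE, (5,))}
    norm_map = {FeatureType.STATE: NormalizationMode.MEAN_STD}
    stats = {"observation.state": {"mean": np.zeros(5), "std": np.ones(5)}}

    input_steps = [
        AddBatchDimensionProcessorStep(),
        DeviceProcessorStep(device="cpu"),
        NormalizerProcessorStep(features=features, norm_map=norm_map, stats=stats),
    ]

Running the normalizer after the device step also lets its cached stats live on the target device (the act hunk now passes device=config.device), which is presumably the performance win named in the subject line.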
diff --git a/src/lerobot/policies/pi0/processor_pi0.py b/src/lerobot/policies/pi0/processor_pi0.py
index 86cd76f84..acdc0dca9 100644
--- a/src/lerobot/policies/pi0/processor_pi0.py
+++ b/src/lerobot/policies/pi0/processor_pi0.py
@@ -81,11 +81,6 @@ def make_pi0_pre_post_processors(
     # Add remaining processors
     input_steps: list[ProcessorStep] = [
         RenameProcessorStep(rename_map={}),  # To mimic the same processor as pretrained one
-        NormalizerProcessorStep(
-            features={**config.input_features, **config.output_features},
-            norm_map=config.normalization_mapping,
-            stats=dataset_stats,
-        ),
         AddBatchDimensionProcessorStep(),
         Pi0NewLineProcessor(),  # Add newlines before tokenization for PaliGemma
         TokenizerProcessorStep(
@@ -95,6 +90,11 @@ def make_pi0_pre_post_processors(
             padding="max_length",
         ),
         DeviceProcessorStep(device=config.device),
+        NormalizerProcessorStep(
+            features={**config.input_features, **config.output_features},
+            norm_map=config.normalization_mapping,
+            stats=dataset_stats,
+        ),
     ]
 
     output_steps: list[ProcessorStep] = [
diff --git a/src/lerobot/policies/pi0fast/processor_pi0fast.py b/src/lerobot/policies/pi0fast/processor_pi0fast.py
index c815c8379..38882c21f 100644
--- a/src/lerobot/policies/pi0fast/processor_pi0fast.py
+++ b/src/lerobot/policies/pi0fast/processor_pi0fast.py
@@ -42,13 +42,13 @@ def make_pi0fast_pre_post_processors(
 
     input_steps = [
         RenameProcessorStep(rename_map={}),  # To mimic the same processor as pretrained one
+        AddBatchDimensionProcessorStep(),
+        DeviceProcessorStep(device=config.device),
         NormalizerProcessorStep(
             features={**config.input_features, **config.output_features},
             norm_map=config.normalization_mapping,
             stats=dataset_stats,
         ),
-        AddBatchDimensionProcessorStep(),
-        DeviceProcessorStep(device=config.device),
     ]
     output_steps = [
         DeviceProcessorStep(device="cpu"),
diff --git a/src/lerobot/policies/sac/processor_sac.py b/src/lerobot/policies/sac/processor_sac.py
index 4f0f8c5a3..9130a196d 100644
--- a/src/lerobot/policies/sac/processor_sac.py
+++ b/src/lerobot/policies/sac/processor_sac.py
@@ -43,13 +43,13 @@ def make_sac_pre_post_processors(
 
     input_steps = [
         RenameProcessorStep(rename_map={}),
+        AddBatchDimensionProcessorStep(),
+        DeviceProcessorStep(device=config.device),
         NormalizerProcessorStep(
             features={**config.input_features, **config.output_features},
             norm_map=config.normalization_mapping,
             stats=dataset_stats,
         ),
-        AddBatchDimensionProcessorStep(),
-        DeviceProcessorStep(device=config.device),
     ]
     output_steps = [
         DeviceProcessorStep(device="cpu"),
diff --git a/src/lerobot/policies/smolvla/processor_smolvla.py b/src/lerobot/policies/smolvla/processor_smolvla.py
index 2123efb50..1f2abcede 100644
--- a/src/lerobot/policies/smolvla/processor_smolvla.py
+++ b/src/lerobot/policies/smolvla/processor_smolvla.py
@@ -46,11 +46,6 @@ def make_smolvla_pre_post_processors(
 
     input_steps = [
         RenameProcessorStep(rename_map={}),  # To mimic the same processor as pretrained one
-        NormalizerProcessorStep(
-            features={**config.input_features, **config.output_features},
-            norm_map=config.normalization_mapping,
-            stats=dataset_stats,
-        ),
         AddBatchDimensionProcessorStep(),
         SmolVLANewLineProcessor(),
         TokenizerProcessorStep(
@@ -60,6 +55,11 @@ def make_smolvla_pre_post_processors(
             max_length=config.tokenizer_max_length,
         ),
         DeviceProcessorStep(device=config.device),
+        NormalizerProcessorStep(
+            features={**config.input_features, **config.output_features},
+            norm_map=config.normalization_mapping,
+            stats=dataset_stats,
+        ),
     ]
     output_steps = [
         DeviceProcessorStep(device="cpu"),
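A dependency-free sketch of the core adaptation logic, shown before the normalize_processor.py diff below that introduces it. The helper name is hypothetical; in the patch the equivalent check lives in _NormalizationMixin._apply_transform, which moves the cached stats, once, to the device/dtype of the incoming tensor:

    import torch

    def adapt_stats(stats: dict[str, torch.Tensor], x: torch.Tensor) -> dict[str, torch.Tensor]:
        # Mirror of the check in _apply_transform: if the cached stats do not match
        # the incoming tensor's device/dtype, convert them once and reuse thereafter.
        first = next(iter(stats.values()))
        if first.device != x.device or first.dtype != x.dtype:
            stats = {k: v.to(device=x.device, dtype=x.dtype) for k, v in stats.items()}
        return stats

    x = torch.randn(5, dtype=torch.bfloat16)
    stats = {"mean": torch.zeros(5), "std": torch.ones(5)}  # float32 by default
    stats = adapt_stats(stats, x)
    normalized = (x - stats["mean"]) / (stats["std"] + 1e-8)
    assert normalized.dtype == torch.bfloat16  # output follows the input dtype

The design point, per the commit message and tests: the output dtype follows the input, so a pipeline running under Accelerate in bfloat16 never silently upcasts back to float32 inside the normalizer.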
diff --git a/src/lerobot/policies/tdmpc/processor_tdmpc.py b/src/lerobot/policies/tdmpc/processor_tdmpc.py
index 28a66fc3e..d131972cb 100644
--- a/src/lerobot/policies/tdmpc/processor_tdmpc.py
+++ b/src/lerobot/policies/tdmpc/processor_tdmpc.py
@@ -42,13 +42,13 @@ def make_tdmpc_pre_post_processors(
 
     input_steps = [
         RenameProcessorStep(rename_map={}),
+        AddBatchDimensionProcessorStep(),
+        DeviceProcessorStep(device=config.device),
         NormalizerProcessorStep(
             features={**config.input_features, **config.output_features},
             norm_map=config.normalization_mapping,
             stats=dataset_stats,
         ),
-        AddBatchDimensionProcessorStep(),
-        DeviceProcessorStep(device=config.device),
     ]
     output_steps = [
         DeviceProcessorStep(device="cpu"),
diff --git a/src/lerobot/policies/vqbet/processor_vqbet.py b/src/lerobot/policies/vqbet/processor_vqbet.py
index 7743b6ee0..ad78d1b6a 100644
--- a/src/lerobot/policies/vqbet/processor_vqbet.py
+++ b/src/lerobot/policies/vqbet/processor_vqbet.py
@@ -43,13 +43,13 @@ def make_vqbet_pre_post_processors(
 
     input_steps = [
         RenameProcessorStep(rename_map={}),  # Let the possibility to the user to rename the keys
+        AddBatchDimensionProcessorStep(),
+        DeviceProcessorStep(device=config.device),
         NormalizerProcessorStep(
             features={**config.input_features, **config.output_features},
             norm_map=config.normalization_mapping,
             stats=dataset_stats,
         ),
-        AddBatchDimensionProcessorStep(),
-        DeviceProcessorStep(device=config.device),
     ]
     output_steps = [
         DeviceProcessorStep(device="cpu"),
diff --git a/src/lerobot/processor/normalize_processor.py b/src/lerobot/processor/normalize_processor.py
index 0e97c2568..ab67b1708 100644
--- a/src/lerobot/processor/normalize_processor.py
+++ b/src/lerobot/processor/normalize_processor.py
@@ -29,6 +29,7 @@ class _NormalizationMixin:
     norm_map: dict[FeatureType, NormalizationMode]
     stats: dict[str, dict[str, Any]] | None = None
     device: torch.device | str | None = None
+    dtype: torch.dtype | None = None
     eps: float = 1e-8
     normalize_observation_keys: set[str] | None = None
 
@@ -56,12 +57,20 @@ class _NormalizationMixin:
 
         # Convert stats to tensors and move to the target device once during initialization.
         self.stats = self.stats or {}
-        self._tensor_stats = to_tensor(self.stats, device=self.device)
+        if self.dtype is None:
+            self.dtype = torch.float32
 
-    def to(self, device: torch.device | str) -> _NormalizationMixin:
+        self._tensor_stats = to_tensor(self.stats, device=self.device, dtype=self.dtype)
+
+    def to(
+        self, device: torch.device | str | None = None, dtype: torch.dtype | None = None
+    ) -> _NormalizationMixin:
         """Moves the processor's normalization stats to the specified device and returns self."""
-        self.device = device
-        self._tensor_stats = to_tensor(self.stats, device=self.device)
+        if device is not None:
+            self.device = device
+        if dtype is not None:
+            self.dtype = dtype
+        self._tensor_stats = to_tensor(self.stats, device=self.device, dtype=self.dtype)
         return self
 
     def state_dict(self) -> dict[str, Tensor]:
@@ -98,12 +107,14 @@ class _NormalizationMixin:
             if self.normalize_observation_keys is not None and key not in self.normalize_observation_keys:
                 continue
             if feature.type != FeatureType.ACTION and key in new_observation:
-                tensor = torch.as_tensor(new_observation[key], dtype=torch.float32)
+                # Convert to tensor but preserve original dtype for adaptation logic
+                tensor = torch.as_tensor(new_observation[key])
                 new_observation[key] = self._apply_transform(tensor, key, feature.type, inverse=inverse)
         return new_observation
 
     def _normalize_action(self, action: Any, inverse: bool) -> Tensor:
-        tensor = torch.as_tensor(action, dtype=torch.float32)
+        # Convert to tensor but preserve original dtype for adaptation logic
+        tensor = torch.as_tensor(action)
         processed_action = self._apply_transform(tensor, "action", FeatureType.ACTION, inverse=inverse)
         return processed_action
 
@@ -118,19 +129,13 @@ class _NormalizationMixin:
         if norm_mode not in (NormalizationMode.MEAN_STD, NormalizationMode.MIN_MAX):
             raise ValueError(f"Unsupported normalization mode: {norm_mode}")
 
-        # Ensure input tensor is on the same device as the stats.
-        if self.device and tensor.device != self.device:
-            tensor = tensor.to(self.device)
+        # For Accelerate compatibility: Ensure stats are on the same device and dtype as the input tensor
+        if self._tensor_stats and key in self._tensor_stats:
+            first_stat = next(iter(self._tensor_stats[key].values()))
+            if first_stat.device != tensor.device or first_stat.dtype != tensor.dtype:
+                self.to(device=tensor.device, dtype=tensor.dtype)
 
-        # For Accelerate compatibility: move stats to match input tensor device
-        input_device = tensor.device
         stats = self._tensor_stats[key]
-        tensor = tensor.to(dtype=torch.float32)
-
-        # Move stats to input device if needed
-        stats_device = next(iter(stats.values())).device
-        if stats_device != input_device:
-            stats = to_tensor({key: self._tensor_stats[key]}, device=input_device)[key]
 
         if norm_mode == NormalizationMode.MEAN_STD and "mean" in stats and "std" in stats:
             mean, std = stats["mean"], stats["std"]
@@ -147,7 +152,7 @@ class _NormalizationMixin:
             # to prevent division by zero. This consistently maps an input equal to
             # min_val to -1, ensuring a stable transformation.
             denom = torch.where(
-                denom == 0, torch.tensor(self.eps, device=input_device, dtype=torch.float32), denom
+                denom == 0, torch.tensor(self.eps, device=tensor.device, dtype=tensor.dtype), denom
             )
             if inverse:
                 # Map from [-1, 1] back to [min, max]
@@ -268,5 +273,5 @@ def hotswap_stats(
         if isinstance(step, _NormalizationMixin):
             step.stats = stats
             # Re-initialize tensor_stats on the correct device.
-            step._tensor_stats = to_tensor(stats, device=step.device)
+            step._tensor_stats = to_tensor(stats, device=step.device, dtype=step.dtype)
     return rp
diff --git a/tests/processor/test_act_processor.py b/tests/processor/test_act_processor.py
index b52df9a27..c577405a8 100644
--- a/tests/processor/test_act_processor.py
+++ b/tests/processor/test_act_processor.py
@@ -87,9 +87,9 @@ def test_make_act_processor_basic():
     # Check steps in preprocessor
     assert len(preprocessor.steps) == 4
     assert isinstance(preprocessor.steps[0], RenameProcessorStep)
-    assert isinstance(preprocessor.steps[1], NormalizerProcessorStep)
-    assert isinstance(preprocessor.steps[2], AddBatchDimensionProcessorStep)
-    assert isinstance(preprocessor.steps[3], DeviceProcessorStep)
+    assert isinstance(preprocessor.steps[1], AddBatchDimensionProcessorStep)
+    assert isinstance(preprocessor.steps[2], DeviceProcessorStep)
+    assert isinstance(preprocessor.steps[3], NormalizerProcessorStep)
 
     # Check steps in postprocessor
     assert len(postprocessor.steps) == 2
@@ -308,6 +308,17 @@ def test_act_processor_mixed_precision():
     for step in preprocessor.steps:
         if isinstance(step, DeviceProcessorStep):
             modified_steps.append(DeviceProcessorStep(device=config.device, float_dtype="float16"))
+        elif isinstance(step, NormalizerProcessorStep):
+            # Update normalizer to use the same device as the device processor
+            modified_steps.append(
+                NormalizerProcessorStep(
+                    features=step.features,
+                    norm_map=step.norm_map,
+                    stats=step.stats,
+                    device=config.device,
+                    dtype=torch.float16,  # Match the float16 dtype
+                )
+            )
         else:
             modified_steps.append(step)
     preprocessor.steps = modified_steps
@@ -353,3 +364,59 @@ def test_act_processor_batch_consistency():
     processed_batched = preprocessor(transition_batched)
     assert processed_batched[TransitionKey.OBSERVATION][OBS_STATE].shape[0] == 8
     assert processed_batched[TransitionKey.ACTION].shape[0] == 8
+
+
+@pytest.mark.skipif(not torch.cuda.is_available(), reason="CUDA not available")
+def test_act_processor_bfloat16_device_float32_normalizer():
+    """Test: DeviceProcessor(bfloat16) + NormalizerProcessor(float32) → output bfloat16 via automatic adaptation"""
+    config = create_default_config()
+    config.device = "cuda"
+    stats = create_default_stats()
+
+    preprocessor, _ = make_act_pre_post_processors(
+        config,
+        stats,
+        preprocessor_kwargs={"to_transition": lambda x: x, "to_output": lambda x: x},
+    )
+
+    # Modify the pipeline to use bfloat16 device processor with float32 normalizer
+    modified_steps = []
+    for step in preprocessor.steps:
+        if isinstance(step, DeviceProcessorStep):
+            # Device processor converts to bfloat16
+            modified_steps.append(DeviceProcessorStep(device=config.device, float_dtype="bfloat16"))
+        elif isinstance(step, NormalizerProcessorStep):
+            # Normalizer stays configured as float32 (will auto-adapt to bfloat16)
+            modified_steps.append(
+                NormalizerProcessorStep(
+                    features=step.features,
+                    norm_map=step.norm_map,
+                    stats=step.stats,
+                    device=config.device,
+                    dtype=torch.float32,  # Deliberately configured as float32
+                )
+            )
+        else:
+            modified_steps.append(step)
+    preprocessor.steps = modified_steps
+
+    # Verify initial normalizer configuration
+    normalizer_step = preprocessor.steps[3]  # NormalizerProcessorStep
+    assert normalizer_step.dtype == torch.float32
+
+    # Create test data
+    observation = {OBS_STATE: torch.randn(7, dtype=torch.float32)}  # Start with float32
+    action = torch.randn(4, dtype=torch.float32)
+    transition = create_transition(observation, action)
+
+    # Process through full pipeline
+    processed = preprocessor(transition)
+
+    # Verify: DeviceProcessor → bfloat16, NormalizerProcessor adapts → final output is bfloat16
+    assert processed[TransitionKey.OBSERVATION][OBS_STATE].dtype == torch.bfloat16
+    assert processed[TransitionKey.ACTION].dtype == torch.bfloat16
+
+    # Verify normalizer automatically adapted its internal state
+    assert normalizer_step.dtype == torch.bfloat16
+    for stat_tensor in normalizer_step._tensor_stats[OBS_STATE].values():
+        assert stat_tensor.dtype == torch.bfloat16
diff --git a/tests/processor/test_diffusion_processor.py b/tests/processor/test_diffusion_processor.py
index 8ab477fab..6e660937c 100644
--- a/tests/processor/test_diffusion_processor.py
+++ b/tests/processor/test_diffusion_processor.py
@@ -90,9 +90,9 @@ def test_make_diffusion_processor_basic():
     # Check steps in preprocessor
     assert len(preprocessor.steps) == 4
     assert isinstance(preprocessor.steps[0], RenameProcessorStep)
-    assert isinstance(preprocessor.steps[1], NormalizerProcessorStep)
-    assert isinstance(preprocessor.steps[2], AddBatchDimensionProcessorStep)
-    assert isinstance(preprocessor.steps[3], DeviceProcessorStep)
+    assert isinstance(preprocessor.steps[1], AddBatchDimensionProcessorStep)
+    assert isinstance(preprocessor.steps[2], DeviceProcessorStep)
+    assert isinstance(preprocessor.steps[3], NormalizerProcessorStep)
 
     # Check steps in postprocessor
     assert len(postprocessor.steps) == 2
@@ -299,6 +299,17 @@ def test_diffusion_processor_mixed_precision():
     for step in factory_preprocessor.steps:
         if isinstance(step, DeviceProcessorStep):
             modified_steps.append(DeviceProcessorStep(device=config.device, float_dtype="float16"))
+        elif isinstance(step, NormalizerProcessorStep):
+            # Update normalizer to use the same device as the device processor
+            modified_steps.append(
+                NormalizerProcessorStep(
+                    features=step.features,
+                    norm_map=step.norm_map,
+                    stats=step.stats,
+                    device=config.device,
+                    dtype=torch.float16,  # Match the float16 dtype
+                )
+            )
         else:
             modified_steps.append(step)
 
@@ -379,3 +390,66 @@ def test_diffusion_processor_batch_consistency():
     assert processed[TransitionKey.OBSERVATION][OBS_STATE].shape[0] == expected_batch
     assert processed[TransitionKey.OBSERVATION][OBS_IMAGE].shape[0] == expected_batch
     assert processed[TransitionKey.ACTION].shape[0] == expected_batch
+
+
+@pytest.mark.skipif(not torch.cuda.is_available(), reason="CUDA not available")
+def test_diffusion_processor_bfloat16_device_float32_normalizer():
+    """Test: DeviceProcessor(bfloat16) + NormalizerProcessor(float32) → output bfloat16 via automatic adaptation"""
+    config = create_default_config()
+    config.device = "cuda"
+    stats = create_default_stats()
+
+    # Get the steps from the factory function
+    factory_preprocessor, _ = make_diffusion_pre_post_processors(config, stats)
+
+    # Modify the pipeline to use bfloat16 device processor with float32 normalizer
+    modified_steps = []
+    for step in factory_preprocessor.steps:
+        if isinstance(step, DeviceProcessorStep):
+            # Device processor converts to bfloat16
+            modified_steps.append(DeviceProcessorStep(device=config.device, float_dtype="bfloat16"))
+        elif isinstance(step, NormalizerProcessorStep):
+            # Normalizer stays configured as float32 (will auto-adapt to bfloat16)
+            modified_steps.append(
+                NormalizerProcessorStep(
+                    features=step.features,
+                    norm_map=step.norm_map,
+                    stats=step.stats,
+                    device=config.device,
+                    dtype=torch.float32,  # Deliberately configured as float32
+                )
+            )
+        else:
+            modified_steps.append(step)
+
+    # Create new processor with modified steps
+    preprocessor = DataProcessorPipeline(modified_steps, to_transition=lambda x: x, to_output=lambda x: x)
+
+    # Verify initial normalizer configuration
+    normalizer_step = modified_steps[3]  # NormalizerProcessorStep
+    assert normalizer_step.dtype == torch.float32
+
+    # Create test data with both state and visual observations
+    observation = {
+        OBS_STATE: torch.randn(7, dtype=torch.float32),
+        OBS_IMAGE: torch.randn(3, 224, 224, dtype=torch.float32),
+    }
+    action = torch.randn(6, dtype=torch.float32)
+    transition = create_transition(observation, action)
+
+    # Process through full pipeline
+    processed = preprocessor(transition)
+
+    # Verify: DeviceProcessor → bfloat16, NormalizerProcessor adapts → final output is bfloat16
+    assert processed[TransitionKey.OBSERVATION][OBS_STATE].dtype == torch.bfloat16
+    assert (
+        processed[TransitionKey.OBSERVATION][OBS_IMAGE].dtype == torch.bfloat16
+    )  # IDENTITY normalization still gets dtype conversion
+    assert processed[TransitionKey.ACTION].dtype == torch.bfloat16
+
+    # Verify normalizer automatically adapted its internal state
+    assert normalizer_step.dtype == torch.bfloat16
+    # Check state stats (has normalization)
+    for stat_tensor in normalizer_step._tensor_stats[OBS_STATE].values():
+        assert stat_tensor.dtype == torch.bfloat16
+    # OBS_IMAGE uses IDENTITY normalization, so no stats to check
diff --git a/tests/processor/test_normalize_processor.py b/tests/processor/test_normalize_processor.py
index f8bf4660e..a109392fc 100644
--- a/tests/processor/test_normalize_processor.py
+++ b/tests/processor/test_normalize_processor.py
@@ -1497,3 +1497,92 @@ def test_roundtrip_normalize_unnormalize_non_identity():
         out[TransitionKey.OBSERVATION]["observation.state"], obs["observation.state"], atol=1e-5
     )
     assert torch.allclose(out[TransitionKey.ACTION], act, atol=1e-5)
+
+
+def test_dtype_adaptation_bfloat16_input_float32_normalizer():
+    """Test automatic dtype adaptation: NormalizerProcessor(float32) adapts to bfloat16 input → bfloat16 output"""
+    features = {"observation.state": PolicyFeature(FeatureType.STATE, (5,))}
+    norm_map = {FeatureType.STATE: NormalizationMode.MEAN_STD}
+    stats = {
+        "observation.state": {
+            "mean": np.array([0.0, 0.0, 0.0, 0.0, 0.0]),
+            "std": np.array([1.0, 1.0, 1.0, 1.0, 1.0]),
+        }
+    }
+
+    # Create normalizer configured with float32 dtype
+    normalizer = NormalizerProcessorStep(
+        features=features, norm_map=norm_map, stats=stats, dtype=torch.float32
+    )
+
+    # Verify initial configuration
+    assert normalizer.dtype == torch.float32
+    for stat_tensor in normalizer._tensor_stats["observation.state"].values():
+        assert stat_tensor.dtype == torch.float32
+
+    # Create bfloat16 input tensor
+    observation = {"observation.state": torch.tensor([1.0, 2.0, 3.0, 4.0, 5.0], dtype=torch.bfloat16)}
+    transition = create_transition(observation=observation)
+
+    # Process the transition
+    result = normalizer(transition)
+
+    # Verify that:
+    # 1. Stats were automatically adapted to bfloat16
+    assert normalizer.dtype == torch.bfloat16
+    for stat_tensor in normalizer._tensor_stats["observation.state"].values():
+        assert stat_tensor.dtype == torch.bfloat16
+
+    # 2. Output is in bfloat16
+    output_tensor = result[TransitionKey.OBSERVATION]["observation.state"]
+    assert output_tensor.dtype == torch.bfloat16
+
+    # 3. Normalization was applied correctly (output should be (original - mean) / std)
+    expected = (
+        torch.tensor([1.0, 2.0, 3.0, 4.0, 5.0], dtype=torch.bfloat16)
+        - torch.tensor([0.0, 0.0, 0.0, 0.0, 0.0], dtype=torch.bfloat16)
+    ) / torch.tensor([1.0, 1.0, 1.0, 1.0, 1.0], dtype=torch.bfloat16)
+    assert torch.allclose(output_tensor, expected, atol=1e-2)  # bfloat16 has lower precision
+
+
+@pytest.mark.skipif(not torch.cuda.is_available(), reason="CUDA not available")
+def test_dtype_adaptation_device_processor_bfloat16_normalizer_float32():
+    """Test policy pipeline scenario: DeviceProcessor(bfloat16) + NormalizerProcessor(float32) → bfloat16 output"""
+    from lerobot.processor import DeviceProcessorStep
+
+    features = {"observation.state": PolicyFeature(FeatureType.STATE, (3,))}
+    norm_map = {FeatureType.STATE: NormalizationMode.MEAN_STD}
+    stats = {"observation.state": {"mean": np.array([0.0, 0.0, 0.0]), "std": np.array([1.0, 1.0, 1.0])}}
+
+    # Create pipeline: DeviceProcessor(bfloat16) → NormalizerProcessor(float32)
+    device_processor = DeviceProcessorStep(device="cuda", float_dtype="bfloat16")
+    normalizer = NormalizerProcessorStep(
+        features=features, norm_map=norm_map, stats=stats, dtype=torch.float32
+    )
+
+    # Verify initial normalizer configuration
+    assert normalizer.dtype == torch.float32
+
+    # Create CPU input
+    observation = {"observation.state": torch.tensor([1.0, 2.0, 3.0], dtype=torch.float32)}
+    transition = create_transition(observation=observation)
+
+    # Step 1: DeviceProcessor converts to bfloat16 + moves to CUDA
+    processed_1 = device_processor(transition)
+    intermediate_tensor = processed_1[TransitionKey.OBSERVATION]["observation.state"]
+    assert intermediate_tensor.dtype == torch.bfloat16
+    assert intermediate_tensor.device.type == "cuda"
+
+    # Step 2: NormalizerProcessor receives bfloat16 input and adapts
+    final_result = normalizer(processed_1)
+    final_tensor = final_result[TransitionKey.OBSERVATION]["observation.state"]
+
+    # Verify final output is bfloat16 (automatic adaptation worked)
+    assert final_tensor.dtype == torch.bfloat16
+    assert final_tensor.device.type == "cuda"
+
+    # Verify normalizer adapted its internal state
+    assert normalizer.dtype == torch.bfloat16
+    for stat_tensor in normalizer._tensor_stats["observation.state"].values():
+        assert stat_tensor.dtype == torch.bfloat16
+        assert stat_tensor.device.type == "cuda"
diff --git a/tests/processor/test_pi0_processor.py b/tests/processor/test_pi0_processor.py
index 2c8c0921a..27b354e9d 100644
--- a/tests/processor/test_pi0_processor.py
+++ b/tests/processor/test_pi0_processor.py
@@ -116,11 +116,11 @@ def test_make_pi0_processor_basic():
     # Check steps in preprocessor
     assert len(preprocessor.steps) == 6
     assert isinstance(preprocessor.steps[0], RenameProcessorStep)
-    assert isinstance(preprocessor.steps[1], NormalizerProcessorStep)
-    assert isinstance(preprocessor.steps[2], AddBatchDimensionProcessorStep)
-    assert isinstance(preprocessor.steps[3], Pi0NewLineProcessor)
-    # Step 4 would be TokenizerProcessorStep but it's mocked
-    assert isinstance(preprocessor.steps[5], DeviceProcessorStep)
+    assert isinstance(preprocessor.steps[1], AddBatchDimensionProcessorStep)
+    assert isinstance(preprocessor.steps[2], Pi0NewLineProcessor)
+    # Step 3 would be TokenizerProcessorStep but it's mocked
+    assert isinstance(preprocessor.steps[4], DeviceProcessorStep)
+    assert isinstance(preprocessor.steps[5], NormalizerProcessorStep)
 
     # Check steps in postprocessor
     assert len(postprocessor.steps) == 2
@@ -377,3 +377,71 @@ def test_pi0_newline_processor_state_dict():
     # Test get_config
     config = processor.get_config()
     assert config == {}
+
+
+@pytest.mark.skipif(not torch.cuda.is_available(), reason="CUDA not available")
+def test_pi0_processor_bfloat16_device_float32_normalizer():
+    """Test: DeviceProcessor(bfloat16) + NormalizerProcessor(float32) → output bfloat16 via automatic adaptation"""
+    config = create_default_config()
+    stats = create_default_stats()
+    config.device = "cuda"
+
+    with patch("lerobot.policies.pi0.processor_pi0.TokenizerProcessorStep", MockTokenizerProcessorStep):
+        preprocessor, _ = make_pi0_pre_post_processors(
+            config,
+            stats,
+            preprocessor_kwargs={"to_transition": lambda x: x, "to_output": lambda x: x},
+            postprocessor_kwargs={"to_transition": lambda x: x, "to_output": lambda x: x},
+        )
+
+    # Modify the pipeline to use bfloat16 device processor with float32 normalizer
+    modified_steps = []
+    for step in preprocessor.steps:
+        if isinstance(step, DeviceProcessorStep):
+            # Device processor converts to bfloat16
+            modified_steps.append(DeviceProcessorStep(device=config.device, float_dtype="bfloat16"))
+        elif isinstance(step, NormalizerProcessorStep):
+            # Normalizer stays configured as float32 (will auto-adapt to bfloat16)
+            modified_steps.append(
+                NormalizerProcessorStep(
+                    features=step.features,
+                    norm_map=step.norm_map,
+                    stats=step.stats,
+                    device=config.device,
+                    dtype=torch.float32,  # Deliberately configured as float32
+                )
+            )
+        else:
+            modified_steps.append(step)
+    preprocessor.steps = modified_steps
+
+    # Verify initial normalizer configuration (PI0 has NormalizerProcessorStep at index 5)
+    normalizer_step = preprocessor.steps[5]  # NormalizerProcessorStep
+    assert normalizer_step.dtype == torch.float32
+
+    # Create test data with both state and visual observations
+    observation = {
+        OBS_STATE: torch.randn(10, dtype=torch.float32),  # PI0 expects size 10
+        OBS_IMAGE: torch.randn(3, 224, 224, dtype=torch.float32),
+    }
+    action = torch.randn(6, dtype=torch.float32)  # PI0 expects size 6
+    transition = create_transition(
+        observation, action, complementary_data={"task": "test bfloat16 adaptation"}
+    )
+
+    # Process through full pipeline
+    processed = preprocessor(transition)
+
+    # Verify: DeviceProcessor → bfloat16, NormalizerProcessor adapts → final output is bfloat16
+    assert processed[TransitionKey.OBSERVATION][OBS_STATE].dtype == torch.bfloat16
+    assert (
+        processed[TransitionKey.OBSERVATION][OBS_IMAGE].dtype == torch.bfloat16
+    )  # IDENTITY normalization still gets dtype conversion
+    assert processed[TransitionKey.ACTION].dtype == torch.bfloat16
+
+    # Verify normalizer automatically adapted its internal state
+    assert normalizer_step.dtype == torch.bfloat16
+    # Check state stats (has normalization)
+    for stat_tensor in normalizer_step._tensor_stats[OBS_STATE].values():
+        assert stat_tensor.dtype == torch.bfloat16
+    # OBS_IMAGE uses IDENTITY normalization, so no stats to check
diff --git a/tests/processor/test_sac_processor.py b/tests/processor/test_sac_processor.py
index 7a2d2d19c..522411e0f 100644
--- a/tests/processor/test_sac_processor.py
+++ b/tests/processor/test_sac_processor.py
@@ -92,9 +92,9 @@ def test_make_sac_processor_basic():
     # Check steps in preprocessor
     assert len(preprocessor.steps) == 4
     assert isinstance(preprocessor.steps[0], RenameProcessorStep)
-    assert isinstance(preprocessor.steps[1], NormalizerProcessorStep)
-    assert isinstance(preprocessor.steps[2], AddBatchDimensionProcessorStep)
-    assert isinstance(preprocessor.steps[3], DeviceProcessorStep)
+    assert isinstance(preprocessor.steps[1], AddBatchDimensionProcessorStep)
+    assert isinstance(preprocessor.steps[2], DeviceProcessorStep)
+    assert isinstance(preprocessor.steps[3], NormalizerProcessorStep)
 
     # Check steps in postprocessor
     assert len(postprocessor.steps) == 2
@@ -307,9 +307,24 @@ def test_sac_processor_mixed_precision():
     )
 
     # Replace DeviceProcessorStep with one that uses float16
-    for i, step in enumerate(preprocessor.steps):
+    modified_steps = []
+    for step in preprocessor.steps:
         if isinstance(step, DeviceProcessorStep):
-            preprocessor.steps[i] = DeviceProcessorStep(device=config.device, float_dtype="float16")
+            modified_steps.append(DeviceProcessorStep(device=config.device, float_dtype="float16"))
+        elif isinstance(step, NormalizerProcessorStep):
+            # Update normalizer to use the same device as the device processor
+            modified_steps.append(
+                NormalizerProcessorStep(
+                    features=step.features,
+                    norm_map=step.norm_map,
+                    stats=step.stats,
+                    device=config.device,
+                    dtype=torch.float16,  # Match the float16 dtype
+                )
+            )
+        else:
+            modified_steps.append(step)
+    preprocessor.steps = modified_steps
 
     # Create test data
     observation = {OBS_STATE: torch.randn(10, dtype=torch.float32)}
@@ -374,3 +389,60 @@ def test_sac_processor_edge_cases():
     assert processed[TransitionKey.OBSERVATION][OBS_STATE].shape == (1, 10)
     # When action is None, it may still be present with None value
     assert TransitionKey.ACTION not in processed or processed[TransitionKey.ACTION] is None
+
+
+@pytest.mark.skipif(not torch.cuda.is_available(), reason="CUDA not available")
+def test_sac_processor_bfloat16_device_float32_normalizer():
+    """Test: DeviceProcessor(bfloat16) + NormalizerProcessor(float32) → output bfloat16 via automatic adaptation"""
+    config = create_default_config()
+    config.device = "cuda"
+    stats = create_default_stats()
+
+    preprocessor, _ = make_sac_pre_post_processors(
+        config,
+        stats,
+        preprocessor_kwargs={"to_transition": lambda x: x, "to_output": lambda x: x},
+        postprocessor_kwargs={"to_transition": lambda x: x, "to_output": lambda x: x},
+    )
+
+    # Modify the pipeline to use bfloat16 device processor with float32 normalizer
+    modified_steps = []
+    for step in preprocessor.steps:
+        if isinstance(step, DeviceProcessorStep):
+            # Device processor converts to bfloat16
+            modified_steps.append(DeviceProcessorStep(device=config.device, float_dtype="bfloat16"))
+        elif isinstance(step, NormalizerProcessorStep):
+            # Normalizer stays configured as float32 (will auto-adapt to bfloat16)
+            modified_steps.append(
+                NormalizerProcessorStep(
+                    features=step.features,
+                    norm_map=step.norm_map,
+                    stats=step.stats,
+                    device=config.device,
+                    dtype=torch.float32,  # Deliberately configured as float32
+                )
+            )
+        else:
+            modified_steps.append(step)
+    preprocessor.steps = modified_steps
+
+    # Verify initial normalizer configuration
+    normalizer_step = preprocessor.steps[3]  # NormalizerProcessorStep
+    assert normalizer_step.dtype == torch.float32
+
+    # Create test data
+    observation = {OBS_STATE: torch.randn(10, dtype=torch.float32)}  # Start with float32
+    action = torch.randn(5, dtype=torch.float32)
+    transition = create_transition(observation, action)
+
+    # Process through full pipeline
+    processed = preprocessor(transition)
+
+    # Verify: DeviceProcessor → bfloat16, NormalizerProcessor adapts → final output is bfloat16
+    assert processed[TransitionKey.OBSERVATION][OBS_STATE].dtype == torch.bfloat16
+    assert processed[TransitionKey.ACTION].dtype == torch.bfloat16
+
+    # Verify normalizer automatically adapted its internal state
+    assert normalizer_step.dtype == torch.bfloat16
+    for stat_tensor in normalizer_step._tensor_stats[OBS_STATE].values():
+        assert stat_tensor.dtype == torch.bfloat16
diff --git a/tests/processor/test_smolvla_processor.py b/tests/processor/test_smolvla_processor.py
index 9a187fc7e..5d2e9c7ce 100644
--- a/tests/processor/test_smolvla_processor.py
+++ b/tests/processor/test_smolvla_processor.py
@@ -123,11 +123,11 @@ def test_make_smolvla_processor_basic():
     # Check steps in preprocessor
     assert len(preprocessor.steps) == 6
     assert isinstance(preprocessor.steps[0], RenameProcessorStep)
-    assert isinstance(preprocessor.steps[1], NormalizerProcessorStep)
-    assert isinstance(preprocessor.steps[2], AddBatchDimensionProcessorStep)
-    assert isinstance(preprocessor.steps[3], SmolVLANewLineProcessor)
-    # Step 4 would be TokenizerProcessorStep but it's mocked
-    assert isinstance(preprocessor.steps[5], DeviceProcessorStep)
+    assert isinstance(preprocessor.steps[1], AddBatchDimensionProcessorStep)
+    assert isinstance(preprocessor.steps[2], SmolVLANewLineProcessor)
+    # Step 3 would be TokenizerProcessorStep but it's mocked
+    assert isinstance(preprocessor.steps[4], DeviceProcessorStep)
+    assert isinstance(preprocessor.steps[5], NormalizerProcessorStep)
 
     # Check steps in postprocessor
     assert len(postprocessor.steps) == 2
@@ -404,3 +404,73 @@ def test_smolvla_newline_processor_transform_features():
     }
     result = processor.transform_features(features)
     assert result == features  # Should return unchanged
+
+
+@pytest.mark.skipif(not torch.cuda.is_available(), reason="CUDA not available")
+def test_smolvla_processor_bfloat16_device_float32_normalizer():
+    """Test: DeviceProcessor(bfloat16) + NormalizerProcessor(float32) → output bfloat16 via automatic adaptation"""
+    config = create_default_config()
+    config.device = "cuda"
+    stats = create_default_stats()
+
+    with patch(
+        "lerobot.policies.smolvla.processor_smolvla.TokenizerProcessorStep", MockTokenizerProcessorStep
+    ):
+        preprocessor, _ = make_smolvla_pre_post_processors(
+            config,
+            stats,
+            preprocessor_kwargs={"to_transition": lambda x: x, "to_output": lambda x: x},
+            postprocessor_kwargs={"to_transition": lambda x: x, "to_output": lambda x: x},
+        )
+
+    # Modify the pipeline to use bfloat16 device processor with float32 normalizer
+    modified_steps = []
+    for step in preprocessor.steps:
+        if isinstance(step, DeviceProcessorStep):
+            # Device processor converts to bfloat16
+            modified_steps.append(DeviceProcessorStep(device=config.device, float_dtype="bfloat16"))
+        elif isinstance(step, NormalizerProcessorStep):
+            # Normalizer stays configured as float32 (will auto-adapt to bfloat16)
+            modified_steps.append(
+                NormalizerProcessorStep(
+                    features=step.features,
+                    norm_map=step.norm_map,
+                    stats=step.stats,
+                    device=config.device,
+                    dtype=torch.float32,  # Deliberately configured as float32
+                )
+            )
+        else:
+            modified_steps.append(step)
+    preprocessor.steps = modified_steps
+
+    # Verify initial normalizer configuration (SmolVLA has NormalizerProcessorStep at index 5)
+    normalizer_step = preprocessor.steps[5]  # NormalizerProcessorStep
+    assert normalizer_step.dtype == torch.float32
+
+    # Create test data with both state and visual observations
+    observation = {
+        OBS_STATE: torch.randn(8, dtype=torch.float32),
+        OBS_IMAGE: torch.randn(3, 224, 224, dtype=torch.float32),
+    }
+    action = torch.randn(7, dtype=torch.float32)
+    transition = create_transition(
+        observation, action, complementary_data={"task": "test bfloat16 adaptation"}
+    )
+
+    # Process through full pipeline
+    processed = preprocessor(transition)
+
+    # Verify: DeviceProcessor → bfloat16, NormalizerProcessor adapts → final output is bfloat16
+    assert processed[TransitionKey.OBSERVATION][OBS_STATE].dtype == torch.bfloat16
+    assert (
+        processed[TransitionKey.OBSERVATION][OBS_IMAGE].dtype == torch.bfloat16
+    )  # IDENTITY normalization still gets dtype conversion
+    assert processed[TransitionKey.ACTION].dtype == torch.bfloat16
+
+    # Verify normalizer automatically adapted its internal state
+    assert normalizer_step.dtype == torch.bfloat16
+    # Check state stats (has normalization)
+    for stat_tensor in normalizer_step._tensor_stats[OBS_STATE].values():
+        assert stat_tensor.dtype == torch.bfloat16
+    # OBS_IMAGE uses IDENTITY normalization, so no stats to check
diff --git a/tests/processor/test_tdmpc_processor.py b/tests/processor/test_tdmpc_processor.py
index eb38143d7..bc9ff2bdc 100644
--- a/tests/processor/test_tdmpc_processor.py
+++ b/tests/processor/test_tdmpc_processor.py
@@ -95,9 +95,9 @@ def test_make_tdmpc_processor_basic():
     # Check steps in preprocessor
     assert len(preprocessor.steps) == 4
     assert isinstance(preprocessor.steps[0], RenameProcessorStep)
-    assert isinstance(preprocessor.steps[1], NormalizerProcessorStep)
-    assert isinstance(preprocessor.steps[2], AddBatchDimensionProcessorStep)
-    assert isinstance(preprocessor.steps[3], DeviceProcessorStep)
+    assert isinstance(preprocessor.steps[1], AddBatchDimensionProcessorStep)
+    assert isinstance(preprocessor.steps[2], DeviceProcessorStep)
+    assert isinstance(preprocessor.steps[3], NormalizerProcessorStep)
 
     # Check steps in postprocessor
     assert len(postprocessor.steps) == 2
@@ -331,9 +331,24 @@ def test_tdmpc_processor_mixed_precision():
     )
 
     # Replace DeviceProcessorStep with one that uses float16
-    for i, step in enumerate(preprocessor.steps):
+    modified_steps = []
+    for step in preprocessor.steps:
         if isinstance(step, DeviceProcessorStep):
-            preprocessor.steps[i] = DeviceProcessorStep(device=config.device, float_dtype="float16")
+            modified_steps.append(DeviceProcessorStep(device=config.device, float_dtype="float16"))
+        elif isinstance(step, NormalizerProcessorStep):
+            # Update normalizer to use the same device as the device processor
+            modified_steps.append(
+                NormalizerProcessorStep(
+                    features=step.features,
+                    norm_map=step.norm_map,
+                    stats=step.stats,
+                    device=config.device,
+                    dtype=torch.float16,  # Match the float16 dtype
+                )
+            )
+        else:
+            modified_steps.append(step)
+    preprocessor.steps = modified_steps
 
     # Create test data
     observation = {
@@ -410,3 +425,67 @@ def test_tdmpc_processor_edge_cases():
     processed = preprocessor(transition)
     assert processed[TransitionKey.OBSERVATION][OBS_IMAGE].shape == (1, 3, 224, 224)
     assert OBS_STATE not in processed[TransitionKey.OBSERVATION]
+
+
+@pytest.mark.skipif(not torch.cuda.is_available(), reason="CUDA not available")
+def test_tdmpc_processor_bfloat16_device_float32_normalizer():
+    """Test: DeviceProcessor(bfloat16) + NormalizerProcessor(float32) → output bfloat16 via automatic adaptation"""
+    config = create_default_config()
+    config.device = "cuda"
+    stats = create_default_stats()
+
+    preprocessor, _ = make_tdmpc_pre_post_processors(
+        config,
+        stats,
+        preprocessor_kwargs={"to_transition": lambda x: x, "to_output": lambda x: x},
+    )
+
+    # Modify the pipeline to use bfloat16 device processor with float32 normalizer
+    modified_steps = []
+    for step in preprocessor.steps:
+        if isinstance(step, DeviceProcessorStep):
+            # Device processor converts to bfloat16
+            modified_steps.append(DeviceProcessorStep(device=config.device, float_dtype="bfloat16"))
+        elif isinstance(step, NormalizerProcessorStep):
+            # Normalizer stays configured as float32 (will auto-adapt to bfloat16)
+            modified_steps.append(
+                NormalizerProcessorStep(
+                    features=step.features,
+                    norm_map=step.norm_map,
+                    stats=step.stats,
+                    device=config.device,
+                    dtype=torch.float32,  # Deliberately configured as float32
+                )
+            )
+        else:
+            modified_steps.append(step)
+    preprocessor.steps = modified_steps
+
+    # Verify initial normalizer configuration
+    normalizer_step = preprocessor.steps[3]  # NormalizerProcessorStep
+    assert normalizer_step.dtype == torch.float32
+
+    # Create test data with both state and visual observations
+    observation = {
+        OBS_STATE: torch.randn(12, dtype=torch.float32),
+        OBS_IMAGE: torch.randn(3, 224, 224, dtype=torch.float32),
+    }
+    action = torch.randn(6, dtype=torch.float32)
+    transition = create_transition(observation, action)
+
+    # Process through full pipeline
+    processed = preprocessor(transition)
+
+    # Verify: DeviceProcessor → bfloat16, NormalizerProcessor adapts → final output is bfloat16
+    assert processed[TransitionKey.OBSERVATION][OBS_STATE].dtype == torch.bfloat16
+    assert (
+        processed[TransitionKey.OBSERVATION][OBS_IMAGE].dtype == torch.bfloat16
+    )  # IDENTITY normalization still gets dtype conversion
+    assert processed[TransitionKey.ACTION].dtype == torch.bfloat16
+
+    # Verify normalizer automatically adapted its internal state
+    assert normalizer_step.dtype == torch.bfloat16
+    # Check state stats (has normalization)
+    for stat_tensor in normalizer_step._tensor_stats[OBS_STATE].values():
+        assert stat_tensor.dtype == torch.bfloat16
+    # OBS_IMAGE uses IDENTITY normalization, so no stats to check
diff --git a/tests/processor/test_vqbet_processor.py b/tests/processor/test_vqbet_processor.py
index 787a3d524..bc24c5e0f 100644
--- a/tests/processor/test_vqbet_processor.py
+++ b/tests/processor/test_vqbet_processor.py
@@ -95,9 +95,9 @@ def test_make_vqbet_processor_basic():
     # Check steps in preprocessor
     assert len(preprocessor.steps) == 4
     assert isinstance(preprocessor.steps[0], RenameProcessorStep)
-    assert isinstance(preprocessor.steps[1], NormalizerProcessorStep)
-    assert isinstance(preprocessor.steps[2], AddBatchDimensionProcessorStep)
-    assert isinstance(preprocessor.steps[3], DeviceProcessorStep)
+    assert isinstance(preprocessor.steps[1], AddBatchDimensionProcessorStep)
+    assert isinstance(preprocessor.steps[2], DeviceProcessorStep)
+    assert isinstance(preprocessor.steps[3], NormalizerProcessorStep)
 
     # Check steps in postprocessor
     assert len(postprocessor.steps) == 2
@@ -324,9 +324,24 @@ def test_vqbet_processor_mixed_precision():
     )
 
     # Replace DeviceProcessorStep with one that uses float16
-    for i, step in enumerate(preprocessor.steps):
+    modified_steps = []
+    for step in preprocessor.steps:
         if isinstance(step, DeviceProcessorStep):
-            preprocessor.steps[i] = DeviceProcessorStep(device=config.device, float_dtype="float16")
+            modified_steps.append(DeviceProcessorStep(device=config.device, float_dtype="float16"))
+        elif isinstance(step, NormalizerProcessorStep):
+            # Update normalizer to use the same device as the device processor
+            modified_steps.append(
+                NormalizerProcessorStep(
+                    features=step.features,
+                    norm_map=step.norm_map,
+                    stats=step.stats,
+                    device=config.device,
+                    dtype=torch.float16,  # Match the float16 dtype
+                )
+            )
+        else:
+            modified_steps.append(step)
+    preprocessor.steps = modified_steps
 
     # Create test data
     observation = {
@@ -405,3 +420,68 @@ def test_vqbet_processor_sequential_processing():
     assert result[TransitionKey.OBSERVATION][OBS_STATE].shape == (1, 8)
     assert result[TransitionKey.OBSERVATION][OBS_IMAGE].shape == (1, 3, 224, 224)
     assert result[TransitionKey.ACTION].shape == (1, 7)
+
+
+@pytest.mark.skipif(not torch.cuda.is_available(), reason="CUDA not available")
+def test_vqbet_processor_bfloat16_device_float32_normalizer():
+    """Test: DeviceProcessor(bfloat16) + NormalizerProcessor(float32) → output bfloat16 via automatic adaptation"""
+    config = create_default_config()
+    config.device = "cuda"
+    stats = create_default_stats()
+
+    preprocessor, _ = make_vqbet_pre_post_processors(
+        config,
+        stats,
+        preprocessor_kwargs={"to_transition": lambda x: x, "to_output": lambda x: x},
+        postprocessor_kwargs={"to_transition": lambda x: x, "to_output": lambda x: x},
+    )
+
+    # Modify the pipeline to use bfloat16 device processor with float32 normalizer
+    modified_steps = []
+    for step in preprocessor.steps:
+        if isinstance(step, DeviceProcessorStep):
+            # Device processor converts to bfloat16
+            modified_steps.append(DeviceProcessorStep(device=config.device, float_dtype="bfloat16"))
+        elif isinstance(step, NormalizerProcessorStep):
+            # Normalizer stays configured as float32 (will auto-adapt to bfloat16)
+            modified_steps.append(
+                NormalizerProcessorStep(
+                    features=step.features,
+                    norm_map=step.norm_map,
+                    stats=step.stats,
+                    device=config.device,
+                    dtype=torch.float32,  # Deliberately configured as float32
+                )
+            )
+        else:
+            modified_steps.append(step)
+    preprocessor.steps = modified_steps
+
+    # Verify initial normalizer configuration
+    normalizer_step = preprocessor.steps[3]  # NormalizerProcessorStep
+    assert normalizer_step.dtype == torch.float32
+
+    # Create test data with both state and visual observations
+    observation = {
+        OBS_STATE: torch.randn(8, dtype=torch.float32),
+        OBS_IMAGE: torch.randn(3, 224, 224, dtype=torch.float32),
+    }
+    action = torch.randn(7, dtype=torch.float32)
+    transition = create_transition(observation, action)
+
+    # Process through full pipeline
+    processed = preprocessor(transition)
+
+    # Verify: DeviceProcessor → bfloat16, NormalizerProcessor adapts → final output is bfloat16
+    assert processed[TransitionKey.OBSERVATION][OBS_STATE].dtype == torch.bfloat16
+    assert (
+        processed[TransitionKey.OBSERVATION][OBS_IMAGE].dtype == torch.bfloat16
+    )  # IDENTITY normalization still gets dtype conversion
+    assert processed[TransitionKey.ACTION].dtype == torch.bfloat16
+
+    # Verify normalizer automatically adapted its internal state
+    assert normalizer_step.dtype == torch.bfloat16
+    # Check state stats (has normalization)
+    for stat_tensor in normalizer_step._tensor_stats[OBS_STATE].values():
+        assert stat_tensor.dtype == torch.bfloat16
+    # OBS_IMAGE uses IDENTITY normalization, so no stats to check
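A closing, dependency-free sketch of the MIN_MAX path whose eps guard the normalize_processor.py hunk above retargets to the input's device/dtype. The forward mapping to [-1, 1] is inferred from the hunk's inverse-mapping comment ("Map from [-1, 1] back to [min, max]"), so treat the exact scaling as an assumption:

    import torch

    def min_max_normalize(
        x: torch.Tensor, min_val: torch.Tensor, max_val: torch.Tensor, eps: float = 1e-8
    ) -> torch.Tensor:
        denom = max_val - min_val
        # As in the denom guard above: a zero range is replaced by eps on the
        # input's own device/dtype, so x == min_val maps stably to -1.
        denom = torch.where(denom == 0, torch.tensor(eps, device=x.device, dtype=x.dtype), denom)
        return (x - min_val) / denom * 2.0 - 1.0

    x = torch.tensor([0.0, 0.5, 1.0, 2.0], dtype=torch.bfloat16)
    min_val = torch.tensor([0.0, 0.0, 0.0, 2.0], dtype=torch.bfloat16)  # last feature is constant
    max_val = torch.tensor([1.0, 1.0, 1.0, 2.0], dtype=torch.bfloat16)
    print(min_max_normalize(x, min_val, max_val))  # constant feature maps to -1.0, output stays bfloat16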