refactor(processors): Improve Normalization Processor Performance and Device/Dtype Adaptability (#1880)

* refactor(processors): reorder processor steps for consistency across implementations

- Updated the order of processor steps in multiple files to ensure consistency, placing AddBatchDimensionProcessorStep and DeviceProcessorStep before NormalizerProcessorStep.
- Adjusted related test assertions to reflect the new order of steps in the preprocessor, enhancing clarity and maintainability.

* refactor(normalization): remove dtype specification in tensor conversion for adaptation logic

- Updated tensor conversion in the _NormalizationMixin class to remove explicit dtype specification, allowing for automatic adaptation of tensor types.
- Adjusted related tests to ensure proper functionality with the new tensor conversion logic, verifying that normalizers adapt correctly to input types.
This commit is contained in:
Adil Zouitine
2025-09-08 10:46:35 +02:00
committed by GitHub
parent f1cfdfced9
commit d32006440c
17 changed files with 677 additions and 72 deletions
+3 -2
View File
@@ -41,13 +41,14 @@ def make_act_pre_post_processors(
input_steps = [ input_steps = [
RenameProcessorStep(rename_map={}), RenameProcessorStep(rename_map={}),
AddBatchDimensionProcessorStep(),
DeviceProcessorStep(device=config.device),
NormalizerProcessorStep( NormalizerProcessorStep(
features={**config.input_features, **config.output_features}, features={**config.input_features, **config.output_features},
norm_map=config.normalization_mapping, norm_map=config.normalization_mapping,
stats=dataset_stats, stats=dataset_stats,
device=config.device,
), ),
AddBatchDimensionProcessorStep(),
DeviceProcessorStep(device=config.device),
] ]
output_steps = [ output_steps = [
DeviceProcessorStep(device="cpu"), DeviceProcessorStep(device="cpu"),
@@ -42,13 +42,13 @@ def make_diffusion_pre_post_processors(
input_steps = [ input_steps = [
RenameProcessorStep(rename_map={}), RenameProcessorStep(rename_map={}),
AddBatchDimensionProcessorStep(),
DeviceProcessorStep(device=config.device),
NormalizerProcessorStep( NormalizerProcessorStep(
features={**config.input_features, **config.output_features}, features={**config.input_features, **config.output_features},
norm_map=config.normalization_mapping, norm_map=config.normalization_mapping,
stats=dataset_stats, stats=dataset_stats,
), ),
AddBatchDimensionProcessorStep(),
DeviceProcessorStep(device=config.device),
] ]
output_steps = [ output_steps = [
DeviceProcessorStep(device="cpu"), DeviceProcessorStep(device="cpu"),
+5 -5
View File
@@ -81,11 +81,6 @@ def make_pi0_pre_post_processors(
# Add remaining processors # Add remaining processors
input_steps: list[ProcessorStep] = [ input_steps: list[ProcessorStep] = [
RenameProcessorStep(rename_map={}), # To mimic the same processor as pretrained one RenameProcessorStep(rename_map={}), # To mimic the same processor as pretrained one
NormalizerProcessorStep(
features={**config.input_features, **config.output_features},
norm_map=config.normalization_mapping,
stats=dataset_stats,
),
AddBatchDimensionProcessorStep(), AddBatchDimensionProcessorStep(),
Pi0NewLineProcessor(), # Add newlines before tokenization for PaliGemma Pi0NewLineProcessor(), # Add newlines before tokenization for PaliGemma
TokenizerProcessorStep( TokenizerProcessorStep(
@@ -95,6 +90,11 @@ def make_pi0_pre_post_processors(
padding="max_length", padding="max_length",
), ),
DeviceProcessorStep(device=config.device), DeviceProcessorStep(device=config.device),
NormalizerProcessorStep(
features={**config.input_features, **config.output_features},
norm_map=config.normalization_mapping,
stats=dataset_stats,
),
] ]
output_steps: list[ProcessorStep] = [ output_steps: list[ProcessorStep] = [
@@ -42,13 +42,13 @@ def make_pi0fast_pre_post_processors(
input_steps = [ input_steps = [
RenameProcessorStep(rename_map={}), # To mimic the same processor as pretrained one RenameProcessorStep(rename_map={}), # To mimic the same processor as pretrained one
AddBatchDimensionProcessorStep(),
DeviceProcessorStep(device=config.device),
NormalizerProcessorStep( NormalizerProcessorStep(
features={**config.input_features, **config.output_features}, features={**config.input_features, **config.output_features},
norm_map=config.normalization_mapping, norm_map=config.normalization_mapping,
stats=dataset_stats, stats=dataset_stats,
), ),
AddBatchDimensionProcessorStep(),
DeviceProcessorStep(device=config.device),
] ]
output_steps = [ output_steps = [
DeviceProcessorStep(device="cpu"), DeviceProcessorStep(device="cpu"),
+2 -2
View File
@@ -43,13 +43,13 @@ def make_sac_pre_post_processors(
input_steps = [ input_steps = [
RenameProcessorStep(rename_map={}), RenameProcessorStep(rename_map={}),
AddBatchDimensionProcessorStep(),
DeviceProcessorStep(device=config.device),
NormalizerProcessorStep( NormalizerProcessorStep(
features={**config.input_features, **config.output_features}, features={**config.input_features, **config.output_features},
norm_map=config.normalization_mapping, norm_map=config.normalization_mapping,
stats=dataset_stats, stats=dataset_stats,
), ),
AddBatchDimensionProcessorStep(),
DeviceProcessorStep(device=config.device),
] ]
output_steps = [ output_steps = [
DeviceProcessorStep(device="cpu"), DeviceProcessorStep(device="cpu"),
@@ -46,11 +46,6 @@ def make_smolvla_pre_post_processors(
input_steps = [ input_steps = [
RenameProcessorStep(rename_map={}), # To mimic the same processor as pretrained one RenameProcessorStep(rename_map={}), # To mimic the same processor as pretrained one
NormalizerProcessorStep(
features={**config.input_features, **config.output_features},
norm_map=config.normalization_mapping,
stats=dataset_stats,
),
AddBatchDimensionProcessorStep(), AddBatchDimensionProcessorStep(),
SmolVLANewLineProcessor(), SmolVLANewLineProcessor(),
TokenizerProcessorStep( TokenizerProcessorStep(
@@ -60,6 +55,11 @@ def make_smolvla_pre_post_processors(
max_length=config.tokenizer_max_length, max_length=config.tokenizer_max_length,
), ),
DeviceProcessorStep(device=config.device), DeviceProcessorStep(device=config.device),
NormalizerProcessorStep(
features={**config.input_features, **config.output_features},
norm_map=config.normalization_mapping,
stats=dataset_stats,
),
] ]
output_steps = [ output_steps = [
DeviceProcessorStep(device="cpu"), DeviceProcessorStep(device="cpu"),
@@ -42,13 +42,13 @@ def make_tdmpc_pre_post_processors(
input_steps = [ input_steps = [
RenameProcessorStep(rename_map={}), RenameProcessorStep(rename_map={}),
AddBatchDimensionProcessorStep(),
DeviceProcessorStep(device=config.device),
NormalizerProcessorStep( NormalizerProcessorStep(
features={**config.input_features, **config.output_features}, features={**config.input_features, **config.output_features},
norm_map=config.normalization_mapping, norm_map=config.normalization_mapping,
stats=dataset_stats, stats=dataset_stats,
), ),
AddBatchDimensionProcessorStep(),
DeviceProcessorStep(device=config.device),
] ]
output_steps = [ output_steps = [
DeviceProcessorStep(device="cpu"), DeviceProcessorStep(device="cpu"),
@@ -43,13 +43,13 @@ def make_vqbet_pre_post_processors(
input_steps = [ input_steps = [
RenameProcessorStep(rename_map={}), # Let the possibility to the user to rename the keys RenameProcessorStep(rename_map={}), # Let the possibility to the user to rename the keys
AddBatchDimensionProcessorStep(),
DeviceProcessorStep(device=config.device),
NormalizerProcessorStep( NormalizerProcessorStep(
features={**config.input_features, **config.output_features}, features={**config.input_features, **config.output_features},
norm_map=config.normalization_mapping, norm_map=config.normalization_mapping,
stats=dataset_stats, stats=dataset_stats,
), ),
AddBatchDimensionProcessorStep(),
DeviceProcessorStep(device=config.device),
] ]
output_steps = [ output_steps = [
DeviceProcessorStep(device="cpu"), DeviceProcessorStep(device="cpu"),
+24 -19
View File
@@ -29,6 +29,7 @@ class _NormalizationMixin:
norm_map: dict[FeatureType, NormalizationMode] norm_map: dict[FeatureType, NormalizationMode]
stats: dict[str, dict[str, Any]] | None = None stats: dict[str, dict[str, Any]] | None = None
device: torch.device | str | None = None device: torch.device | str | None = None
dtype: torch.dtype | None = None
eps: float = 1e-8 eps: float = 1e-8
normalize_observation_keys: set[str] | None = None normalize_observation_keys: set[str] | None = None
@@ -56,12 +57,20 @@ class _NormalizationMixin:
# Convert stats to tensors and move to the target device once during initialization. # Convert stats to tensors and move to the target device once during initialization.
self.stats = self.stats or {} self.stats = self.stats or {}
self._tensor_stats = to_tensor(self.stats, device=self.device) if self.dtype is None:
self.dtype = torch.float32
def to(self, device: torch.device | str) -> _NormalizationMixin: self._tensor_stats = to_tensor(self.stats, device=self.device, dtype=self.dtype)
def to(
self, device: torch.device | str | None = None, dtype: torch.dtype | None = None
) -> _NormalizationMixin:
"""Moves the processor's normalization stats to the specified device and returns self.""" """Moves the processor's normalization stats to the specified device and returns self."""
self.device = device if device is not None:
self._tensor_stats = to_tensor(self.stats, device=self.device) self.device = device
if dtype is not None:
self.dtype = dtype
self._tensor_stats = to_tensor(self.stats, device=self.device, dtype=self.dtype)
return self return self
def state_dict(self) -> dict[str, Tensor]: def state_dict(self) -> dict[str, Tensor]:
@@ -98,12 +107,14 @@ class _NormalizationMixin:
if self.normalize_observation_keys is not None and key not in self.normalize_observation_keys: if self.normalize_observation_keys is not None and key not in self.normalize_observation_keys:
continue continue
if feature.type != FeatureType.ACTION and key in new_observation: if feature.type != FeatureType.ACTION and key in new_observation:
tensor = torch.as_tensor(new_observation[key], dtype=torch.float32) # Convert to tensor but preserve original dtype for adaptation logic
tensor = torch.as_tensor(new_observation[key])
new_observation[key] = self._apply_transform(tensor, key, feature.type, inverse=inverse) new_observation[key] = self._apply_transform(tensor, key, feature.type, inverse=inverse)
return new_observation return new_observation
def _normalize_action(self, action: Any, inverse: bool) -> Tensor: def _normalize_action(self, action: Any, inverse: bool) -> Tensor:
tensor = torch.as_tensor(action, dtype=torch.float32) # Convert to tensor but preserve original dtype for adaptation logic
tensor = torch.as_tensor(action)
processed_action = self._apply_transform(tensor, "action", FeatureType.ACTION, inverse=inverse) processed_action = self._apply_transform(tensor, "action", FeatureType.ACTION, inverse=inverse)
return processed_action return processed_action
@@ -118,19 +129,13 @@ class _NormalizationMixin:
if norm_mode not in (NormalizationMode.MEAN_STD, NormalizationMode.MIN_MAX): if norm_mode not in (NormalizationMode.MEAN_STD, NormalizationMode.MIN_MAX):
raise ValueError(f"Unsupported normalization mode: {norm_mode}") raise ValueError(f"Unsupported normalization mode: {norm_mode}")
# Ensure input tensor is on the same device as the stats. # For Accelerate compatibility: Ensure stats are on the same device and dtype as the input tensor
if self.device and tensor.device != self.device: if self._tensor_stats and key in self._tensor_stats:
tensor = tensor.to(self.device) first_stat = next(iter(self._tensor_stats[key].values()))
if first_stat.device != tensor.device or first_stat.dtype != tensor.dtype:
self.to(device=tensor.device, dtype=tensor.dtype)
# For Accelerate compatibility: move stats to match input tensor device
input_device = tensor.device
stats = self._tensor_stats[key] stats = self._tensor_stats[key]
tensor = tensor.to(dtype=torch.float32)
# Move stats to input device if needed
stats_device = next(iter(stats.values())).device
if stats_device != input_device:
stats = to_tensor({key: self._tensor_stats[key]}, device=input_device)[key]
if norm_mode == NormalizationMode.MEAN_STD and "mean" in stats and "std" in stats: if norm_mode == NormalizationMode.MEAN_STD and "mean" in stats and "std" in stats:
mean, std = stats["mean"], stats["std"] mean, std = stats["mean"], stats["std"]
@@ -147,7 +152,7 @@ class _NormalizationMixin:
# to prevent division by zero. This consistently maps an input equal to # to prevent division by zero. This consistently maps an input equal to
# min_val to -1, ensuring a stable transformation. # min_val to -1, ensuring a stable transformation.
denom = torch.where( denom = torch.where(
denom == 0, torch.tensor(self.eps, device=input_device, dtype=torch.float32), denom denom == 0, torch.tensor(self.eps, device=tensor.device, dtype=tensor.dtype), denom
) )
if inverse: if inverse:
# Map from [-1, 1] back to [min, max] # Map from [-1, 1] back to [min, max]
@@ -268,5 +273,5 @@ def hotswap_stats(
if isinstance(step, _NormalizationMixin): if isinstance(step, _NormalizationMixin):
step.stats = stats step.stats = stats
# Re-initialize tensor_stats on the correct device. # Re-initialize tensor_stats on the correct device.
step._tensor_stats = to_tensor(stats, device=step.device) step._tensor_stats = to_tensor(stats, device=step.device, dtype=step.dtype)
return rp return rp
+70 -3
View File
@@ -87,9 +87,9 @@ def test_make_act_processor_basic():
# Check steps in preprocessor # Check steps in preprocessor
assert len(preprocessor.steps) == 4 assert len(preprocessor.steps) == 4
assert isinstance(preprocessor.steps[0], RenameProcessorStep) assert isinstance(preprocessor.steps[0], RenameProcessorStep)
assert isinstance(preprocessor.steps[1], NormalizerProcessorStep) assert isinstance(preprocessor.steps[1], AddBatchDimensionProcessorStep)
assert isinstance(preprocessor.steps[2], AddBatchDimensionProcessorStep) assert isinstance(preprocessor.steps[2], DeviceProcessorStep)
assert isinstance(preprocessor.steps[3], DeviceProcessorStep) assert isinstance(preprocessor.steps[3], NormalizerProcessorStep)
# Check steps in postprocessor # Check steps in postprocessor
assert len(postprocessor.steps) == 2 assert len(postprocessor.steps) == 2
@@ -308,6 +308,17 @@ def test_act_processor_mixed_precision():
for step in preprocessor.steps: for step in preprocessor.steps:
if isinstance(step, DeviceProcessorStep): if isinstance(step, DeviceProcessorStep):
modified_steps.append(DeviceProcessorStep(device=config.device, float_dtype="float16")) modified_steps.append(DeviceProcessorStep(device=config.device, float_dtype="float16"))
elif isinstance(step, NormalizerProcessorStep):
# Update normalizer to use the same device as the device processor
modified_steps.append(
NormalizerProcessorStep(
features=step.features,
norm_map=step.norm_map,
stats=step.stats,
device=config.device,
dtype=torch.float16, # Match the float16 dtype
)
)
else: else:
modified_steps.append(step) modified_steps.append(step)
preprocessor.steps = modified_steps preprocessor.steps = modified_steps
@@ -353,3 +364,59 @@ def test_act_processor_batch_consistency():
processed_batched = preprocessor(transition_batched) processed_batched = preprocessor(transition_batched)
assert processed_batched[TransitionKey.OBSERVATION][OBS_STATE].shape[0] == 8 assert processed_batched[TransitionKey.OBSERVATION][OBS_STATE].shape[0] == 8
assert processed_batched[TransitionKey.ACTION].shape[0] == 8 assert processed_batched[TransitionKey.ACTION].shape[0] == 8
@pytest.mark.skipif(not torch.cuda.is_available(), reason="CUDA not available")
def test_act_processor_bfloat16_device_float32_normalizer():
    """ACT pipeline: a float32-configured normalizer fed bfloat16 tensors must output bfloat16.

    The device step is replaced by one that casts floats to bfloat16, while the
    normalizer is rebuilt with an explicit float32 dtype. After one pass through
    the pipeline, both the outputs and the normalizer's cached statistics are
    expected to be bfloat16.
    """
    config = create_default_config()
    config.device = "cuda"
    stats = create_default_stats()

    preprocessor, _ = make_act_pre_post_processors(
        config,
        stats,
        preprocessor_kwargs={"to_transition": lambda x: x, "to_output": lambda x: x},
    )

    def _retarget(step):
        """Return the replacement for `step`; unrelated steps pass through unchanged."""
        if isinstance(step, DeviceProcessorStep):
            # The device step is the one that produces bfloat16 tensors.
            return DeviceProcessorStep(device=config.device, float_dtype="bfloat16")
        if isinstance(step, NormalizerProcessorStep):
            # Configured as float32 on purpose so the adaptation path is exercised.
            return NormalizerProcessorStep(
                features=step.features,
                norm_map=step.norm_map,
                stats=step.stats,
                device=config.device,
                dtype=torch.float32,
            )
        return step

    preprocessor.steps = [_retarget(step) for step in preprocessor.steps]

    # The normalizer sits at index 3 in the ACT preprocessor.
    normalizer_step = preprocessor.steps[3]
    assert normalizer_step.dtype == torch.float32

    # Inputs start out as float32; the device step is what downcasts them.
    transition = create_transition(
        {OBS_STATE: torch.randn(7, dtype=torch.float32)},
        torch.randn(4, dtype=torch.float32),
    )

    processed = preprocessor(transition)

    # Pipeline outputs must come out in bfloat16.
    assert processed[TransitionKey.OBSERVATION][OBS_STATE].dtype == torch.bfloat16
    assert processed[TransitionKey.ACTION].dtype == torch.bfloat16

    # The normalizer must have switched its own dtype and cached stats as well.
    assert normalizer_step.dtype == torch.bfloat16
    assert all(
        stat.dtype == torch.bfloat16 for stat in normalizer_step._tensor_stats[OBS_STATE].values()
    )
+77 -3
View File
@@ -90,9 +90,9 @@ def test_make_diffusion_processor_basic():
# Check steps in preprocessor # Check steps in preprocessor
assert len(preprocessor.steps) == 4 assert len(preprocessor.steps) == 4
assert isinstance(preprocessor.steps[0], RenameProcessorStep) assert isinstance(preprocessor.steps[0], RenameProcessorStep)
assert isinstance(preprocessor.steps[1], NormalizerProcessorStep) assert isinstance(preprocessor.steps[1], AddBatchDimensionProcessorStep)
assert isinstance(preprocessor.steps[2], AddBatchDimensionProcessorStep) assert isinstance(preprocessor.steps[2], DeviceProcessorStep)
assert isinstance(preprocessor.steps[3], DeviceProcessorStep) assert isinstance(preprocessor.steps[3], NormalizerProcessorStep)
# Check steps in postprocessor # Check steps in postprocessor
assert len(postprocessor.steps) == 2 assert len(postprocessor.steps) == 2
@@ -299,6 +299,17 @@ def test_diffusion_processor_mixed_precision():
for step in factory_preprocessor.steps: for step in factory_preprocessor.steps:
if isinstance(step, DeviceProcessorStep): if isinstance(step, DeviceProcessorStep):
modified_steps.append(DeviceProcessorStep(device=config.device, float_dtype="float16")) modified_steps.append(DeviceProcessorStep(device=config.device, float_dtype="float16"))
elif isinstance(step, NormalizerProcessorStep):
# Update normalizer to use the same device as the device processor
modified_steps.append(
NormalizerProcessorStep(
features=step.features,
norm_map=step.norm_map,
stats=step.stats,
device=config.device,
dtype=torch.float16, # Match the float16 dtype
)
)
else: else:
modified_steps.append(step) modified_steps.append(step)
@@ -379,3 +390,66 @@ def test_diffusion_processor_batch_consistency():
assert processed[TransitionKey.OBSERVATION][OBS_STATE].shape[0] == expected_batch assert processed[TransitionKey.OBSERVATION][OBS_STATE].shape[0] == expected_batch
assert processed[TransitionKey.OBSERVATION][OBS_IMAGE].shape[0] == expected_batch assert processed[TransitionKey.OBSERVATION][OBS_IMAGE].shape[0] == expected_batch
assert processed[TransitionKey.ACTION].shape[0] == expected_batch assert processed[TransitionKey.ACTION].shape[0] == expected_batch
@pytest.mark.skipif(not torch.cuda.is_available(), reason="CUDA not available")
def test_diffusion_processor_bfloat16_device_float32_normalizer():
    """Diffusion pipeline: a float32-configured normalizer fed bfloat16 tensors must output bfloat16.

    The factory-built steps are copied into a fresh pipeline in which the device
    step casts floats to bfloat16 and the normalizer is explicitly float32. One
    pass through the pipeline should leave outputs and normalizer stats in bfloat16.
    """
    config = create_default_config()
    config.device = "cuda"
    stats = create_default_stats()

    # Only the step list of the factory-built preprocessor is reused.
    factory_preprocessor, _ = make_diffusion_pre_post_processors(config, stats)

    def _retarget(step):
        """Return the replacement for `step`; unrelated steps pass through unchanged."""
        if isinstance(step, DeviceProcessorStep):
            # The device step is the one that produces bfloat16 tensors.
            return DeviceProcessorStep(device=config.device, float_dtype="bfloat16")
        if isinstance(step, NormalizerProcessorStep):
            # Configured as float32 on purpose so the adaptation path is exercised.
            return NormalizerProcessorStep(
                features=step.features,
                norm_map=step.norm_map,
                stats=step.stats,
                device=config.device,
                dtype=torch.float32,
            )
        return step

    modified_steps = [_retarget(step) for step in factory_preprocessor.steps]

    # Identity converters keep the pipeline operating directly on transitions.
    preprocessor = DataProcessorPipeline(modified_steps, to_transition=lambda x: x, to_output=lambda x: x)

    # The normalizer sits at index 3 in the diffusion preprocessor.
    normalizer_step = modified_steps[3]
    assert normalizer_step.dtype == torch.float32

    # State and image observations both start out as float32.
    observation = {
        OBS_STATE: torch.randn(7, dtype=torch.float32),
        OBS_IMAGE: torch.randn(3, 224, 224, dtype=torch.float32),
    }
    action = torch.randn(6, dtype=torch.float32)
    transition = create_transition(observation, action)

    processed = preprocessor(transition)
    processed_obs = processed[TransitionKey.OBSERVATION]

    # Every output tensor must be bfloat16; per the original test note, OBS_IMAGE
    # uses IDENTITY normalization but is still expected to carry the converted dtype.
    assert processed_obs[OBS_STATE].dtype == torch.bfloat16
    assert processed_obs[OBS_IMAGE].dtype == torch.bfloat16
    assert processed[TransitionKey.ACTION].dtype == torch.bfloat16

    # The normalizer must have switched its own dtype and cached state stats.
    # No image stats are checked because IDENTITY normalization has none.
    assert normalizer_step.dtype == torch.bfloat16
    assert all(
        stat.dtype == torch.bfloat16 for stat in normalizer_step._tensor_stats[OBS_STATE].values()
    )
@@ -1497,3 +1497,92 @@ def test_roundtrip_normalize_unnormalize_non_identity():
out[TransitionKey.OBSERVATION]["observation.state"], obs["observation.state"], atol=1e-5 out[TransitionKey.OBSERVATION]["observation.state"], obs["observation.state"], atol=1e-5
) )
assert torch.allclose(out[TransitionKey.ACTION], act, atol=1e-5) assert torch.allclose(out[TransitionKey.ACTION], act, atol=1e-5)
def test_dtype_adaptation_bfloat16_input_float32_normalizer():
    """A float32-configured normalizer given a bfloat16 observation must adapt and emit bfloat16."""
    state_key = "observation.state"
    raw_values = [1.0, 2.0, 3.0, 4.0, 5.0]
    mean_values = [0.0, 0.0, 0.0, 0.0, 0.0]
    std_values = [1.0, 1.0, 1.0, 1.0, 1.0]

    normalizer = NormalizerProcessorStep(
        features={state_key: PolicyFeature(FeatureType.STATE, (5,))},
        norm_map={FeatureType.STATE: NormalizationMode.MEAN_STD},
        stats={state_key: {"mean": np.array(mean_values), "std": np.array(std_values)}},
        dtype=torch.float32,
    )

    # Before any input is seen, both the step and its cached stats are float32.
    assert normalizer.dtype == torch.float32
    assert all(stat.dtype == torch.float32 for stat in normalizer._tensor_stats[state_key].values())

    # Feed a bfloat16 observation through the normalizer.
    transition = create_transition(
        observation={state_key: torch.tensor(raw_values, dtype=torch.bfloat16)}
    )
    result = normalizer(transition)

    # 1. The step and its cached stats have switched to bfloat16.
    assert normalizer.dtype == torch.bfloat16
    assert all(stat.dtype == torch.bfloat16 for stat in normalizer._tensor_stats[state_key].values())

    # 2. The output keeps the input dtype.
    output_tensor = result[TransitionKey.OBSERVATION][state_key]
    assert output_tensor.dtype == torch.bfloat16

    # 3. The values equal (input - mean) / std, computed here in bfloat16.
    inputs_bf16 = torch.tensor(raw_values, dtype=torch.bfloat16)
    mean_bf16 = torch.tensor(mean_values, dtype=torch.bfloat16)
    std_bf16 = torch.tensor(std_values, dtype=torch.bfloat16)
    expected = (inputs_bf16 - mean_bf16) / std_bf16
    # Loose tolerance because bfloat16 has reduced precision.
    assert torch.allclose(output_tensor, expected, atol=1e-2)
@pytest.mark.skipif(not torch.cuda.is_available(), reason="CUDA not available")
def test_dtype_adaptation_device_processor_bfloat16_normalizer_float32():
    """Chain a bfloat16 CUDA device step into a float32 normalizer and expect bfloat16 CUDA output."""
    from lerobot.processor import DeviceProcessorStep

    state_key = "observation.state"

    # Stage 1 casts to bfloat16 and moves to CUDA; stage 2 is configured as float32.
    device_processor = DeviceProcessorStep(device="cuda", float_dtype="bfloat16")
    normalizer = NormalizerProcessorStep(
        features={state_key: PolicyFeature(FeatureType.STATE, (3,))},
        norm_map={FeatureType.STATE: NormalizationMode.MEAN_STD},
        stats={state_key: {"mean": np.array([0.0, 0.0, 0.0]), "std": np.array([1.0, 1.0, 1.0])}},
        dtype=torch.float32,
    )
    assert normalizer.dtype == torch.float32

    # The input is a float32 tensor created without an explicit device.
    transition = create_transition(
        observation={state_key: torch.tensor([1.0, 2.0, 3.0], dtype=torch.float32)}
    )

    # Stage 1: the device step alone must yield a bfloat16 tensor on CUDA.
    after_device = device_processor(transition)
    intermediate_tensor = after_device[TransitionKey.OBSERVATION][state_key]
    assert intermediate_tensor.dtype == torch.bfloat16
    assert intermediate_tensor.device.type == "cuda"

    # Stage 2: the normalizer receives that tensor and must keep dtype and device.
    after_normalizer = normalizer(after_device)
    final_tensor = after_normalizer[TransitionKey.OBSERVATION][state_key]
    assert final_tensor.dtype == torch.bfloat16
    assert final_tensor.device.type == "cuda"

    # The normalizer's own dtype and cached stats must now match the input.
    assert normalizer.dtype == torch.bfloat16
    for stat in normalizer._tensor_stats[state_key].values():
        assert stat.dtype == torch.bfloat16
        assert stat.device.type == "cuda"
+73 -5
View File
@@ -116,11 +116,11 @@ def test_make_pi0_processor_basic():
# Check steps in preprocessor # Check steps in preprocessor
assert len(preprocessor.steps) == 6 assert len(preprocessor.steps) == 6
assert isinstance(preprocessor.steps[0], RenameProcessorStep) assert isinstance(preprocessor.steps[0], RenameProcessorStep)
assert isinstance(preprocessor.steps[1], NormalizerProcessorStep) assert isinstance(preprocessor.steps[1], AddBatchDimensionProcessorStep)
assert isinstance(preprocessor.steps[2], AddBatchDimensionProcessorStep) assert isinstance(preprocessor.steps[2], Pi0NewLineProcessor)
assert isinstance(preprocessor.steps[3], Pi0NewLineProcessor) # Step 3 would be TokenizerProcessorStep but it's mocked
# Step 4 would be TokenizerProcessorStep but it's mocked assert isinstance(preprocessor.steps[4], DeviceProcessorStep)
assert isinstance(preprocessor.steps[5], DeviceProcessorStep) assert isinstance(preprocessor.steps[5], NormalizerProcessorStep)
# Check steps in postprocessor # Check steps in postprocessor
assert len(postprocessor.steps) == 2 assert len(postprocessor.steps) == 2
@@ -377,3 +377,71 @@ def test_pi0_newline_processor_state_dict():
# Test get_config # Test get_config
config = processor.get_config() config = processor.get_config()
assert config == {} assert config == {}
@pytest.mark.skipif(not torch.cuda.is_available(), reason="CUDA not available")
def test_pi0_processor_bfloat16_device_float32_normalizer():
    """PI0 pipeline: a float32-configured normalizer fed bfloat16 tensors must output bfloat16.

    The preprocessor is built with the tokenizer step patched by a mock. The
    device step is then replaced by one that casts floats to bfloat16 and the
    normalizer by an explicitly float32 one. One pass through the pipeline
    should leave outputs and normalizer stats in bfloat16.
    """
    config = create_default_config()
    stats = create_default_stats()
    config.device = "cuda"

    identity_kwargs = {"to_transition": lambda x: x, "to_output": lambda x: x}

    # The tokenizer is patched only while the factory constructs the pipeline.
    with patch("lerobot.policies.pi0.processor_pi0.TokenizerProcessorStep", MockTokenizerProcessorStep):
        preprocessor, _ = make_pi0_pre_post_processors(
            config,
            stats,
            preprocessor_kwargs=identity_kwargs,
            postprocessor_kwargs={"to_transition": lambda x: x, "to_output": lambda x: x},
        )

    def _retarget(step):
        """Return the replacement for `step`; unrelated steps pass through unchanged."""
        if isinstance(step, DeviceProcessorStep):
            # The device step is the one that produces bfloat16 tensors.
            return DeviceProcessorStep(device=config.device, float_dtype="bfloat16")
        if isinstance(step, NormalizerProcessorStep):
            # Configured as float32 on purpose so the adaptation path is exercised.
            return NormalizerProcessorStep(
                features=step.features,
                norm_map=step.norm_map,
                stats=step.stats,
                device=config.device,
                dtype=torch.float32,
            )
        return step

    preprocessor.steps = [_retarget(step) for step in preprocessor.steps]

    # In the PI0 preprocessor the normalizer is the step at index 5.
    normalizer_step = preprocessor.steps[5]
    assert normalizer_step.dtype == torch.float32

    # State (size 10), image and action (size 6) all start out as float32.
    observation = {
        OBS_STATE: torch.randn(10, dtype=torch.float32),
        OBS_IMAGE: torch.randn(3, 224, 224, dtype=torch.float32),
    }
    action = torch.randn(6, dtype=torch.float32)
    transition = create_transition(
        observation, action, complementary_data={"task": "test bfloat16 adaptation"}
    )

    processed = preprocessor(transition)
    processed_obs = processed[TransitionKey.OBSERVATION]

    # Every output tensor must be bfloat16; per the original test note, OBS_IMAGE
    # uses IDENTITY normalization but is still expected to carry the converted dtype.
    assert processed_obs[OBS_STATE].dtype == torch.bfloat16
    assert processed_obs[OBS_IMAGE].dtype == torch.bfloat16
    assert processed[TransitionKey.ACTION].dtype == torch.bfloat16

    # The normalizer must have switched its own dtype and cached state stats.
    # No image stats are checked because IDENTITY normalization has none.
    assert normalizer_step.dtype == torch.bfloat16
    assert all(
        stat.dtype == torch.bfloat16 for stat in normalizer_step._tensor_stats[OBS_STATE].values()
    )
+77 -5
View File
@@ -92,9 +92,9 @@ def test_make_sac_processor_basic():
# Check steps in preprocessor # Check steps in preprocessor
assert len(preprocessor.steps) == 4 assert len(preprocessor.steps) == 4
assert isinstance(preprocessor.steps[0], RenameProcessorStep) assert isinstance(preprocessor.steps[0], RenameProcessorStep)
assert isinstance(preprocessor.steps[1], NormalizerProcessorStep) assert isinstance(preprocessor.steps[1], AddBatchDimensionProcessorStep)
assert isinstance(preprocessor.steps[2], AddBatchDimensionProcessorStep) assert isinstance(preprocessor.steps[2], DeviceProcessorStep)
assert isinstance(preprocessor.steps[3], DeviceProcessorStep) assert isinstance(preprocessor.steps[3], NormalizerProcessorStep)
# Check steps in postprocessor # Check steps in postprocessor
assert len(postprocessor.steps) == 2 assert len(postprocessor.steps) == 2
@@ -307,9 +307,24 @@ def test_sac_processor_mixed_precision():
) )
# Replace DeviceProcessorStep with one that uses float16 # Replace DeviceProcessorStep with one that uses float16
for i, step in enumerate(preprocessor.steps): modified_steps = []
for step in preprocessor.steps:
if isinstance(step, DeviceProcessorStep): if isinstance(step, DeviceProcessorStep):
preprocessor.steps[i] = DeviceProcessorStep(device=config.device, float_dtype="float16") modified_steps.append(DeviceProcessorStep(device=config.device, float_dtype="float16"))
elif isinstance(step, NormalizerProcessorStep):
# Update normalizer to use the same device as the device processor
modified_steps.append(
NormalizerProcessorStep(
features=step.features,
norm_map=step.norm_map,
stats=step.stats,
device=config.device,
dtype=torch.float16, # Match the float16 dtype
)
)
else:
modified_steps.append(step)
preprocessor.steps = modified_steps
# Create test data # Create test data
observation = {OBS_STATE: torch.randn(10, dtype=torch.float32)} observation = {OBS_STATE: torch.randn(10, dtype=torch.float32)}
@@ -374,3 +389,60 @@ def test_sac_processor_edge_cases():
assert processed[TransitionKey.OBSERVATION][OBS_STATE].shape == (1, 10) assert processed[TransitionKey.OBSERVATION][OBS_STATE].shape == (1, 10)
# When action is None, it may still be present with None value # When action is None, it may still be present with None value
assert TransitionKey.ACTION not in processed or processed[TransitionKey.ACTION] is None assert TransitionKey.ACTION not in processed or processed[TransitionKey.ACTION] is None
@pytest.mark.skipif(not torch.cuda.is_available(), reason="CUDA not available")
def test_sac_processor_bfloat16_device_float32_normalizer():
    """Check that a float32-configured normalizer follows bfloat16 inputs.

    The SAC preprocessor is rebuilt so that its device step casts floats to
    bfloat16 while its normalizer step is constructed with
    ``dtype=torch.float32``. After one pass through the pipeline, the
    outputs, the normalizer's ``dtype`` attribute and its state stats are
    all expected to be bfloat16.
    """
    config = create_default_config()
    config.device = "cuda"
    stats = create_default_stats()

    preprocessor, _ = make_sac_pre_post_processors(
        config,
        stats,
        preprocessor_kwargs={"to_transition": lambda x: x, "to_output": lambda x: x},
        postprocessor_kwargs={"to_transition": lambda x: x, "to_output": lambda x: x},
    )

    def _swap(step):
        """Return the replacement for ``step`` (or ``step`` itself)."""
        if isinstance(step, DeviceProcessorStep):
            # The device step is the one that casts floats to bfloat16.
            return DeviceProcessorStep(device=config.device, float_dtype="bfloat16")
        if isinstance(step, NormalizerProcessorStep):
            # Deliberately float32: the test expects it to adapt at call time.
            return NormalizerProcessorStep(
                features=step.features,
                norm_map=step.norm_map,
                stats=step.stats,
                device=config.device,
                dtype=torch.float32,
            )
        return step

    preprocessor.steps = [_swap(step) for step in preprocessor.steps]

    # The normalizer sits at index 3 in the SAC preprocessor.
    normalizer_step = preprocessor.steps[3]
    assert normalizer_step.dtype == torch.float32

    # Inputs start out as float32.
    state = torch.randn(10, dtype=torch.float32)
    action = torch.randn(5, dtype=torch.float32)
    transition = create_transition({OBS_STATE: state}, action)

    processed = preprocessor(transition)

    # Outputs carry the dtype chosen by the device step.
    assert processed[TransitionKey.OBSERVATION][OBS_STATE].dtype == torch.bfloat16
    assert processed[TransitionKey.ACTION].dtype == torch.bfloat16

    # The normalizer's own configuration and cached stats followed the input.
    assert normalizer_step.dtype == torch.bfloat16
    assert all(
        stat_tensor.dtype == torch.bfloat16
        for stat_tensor in normalizer_step._tensor_stats[OBS_STATE].values()
    )
+75 -5
View File
@@ -123,11 +123,11 @@ def test_make_smolvla_processor_basic():
# Check steps in preprocessor # Check steps in preprocessor
assert len(preprocessor.steps) == 6 assert len(preprocessor.steps) == 6
assert isinstance(preprocessor.steps[0], RenameProcessorStep) assert isinstance(preprocessor.steps[0], RenameProcessorStep)
assert isinstance(preprocessor.steps[1], NormalizerProcessorStep) assert isinstance(preprocessor.steps[1], AddBatchDimensionProcessorStep)
assert isinstance(preprocessor.steps[2], AddBatchDimensionProcessorStep) assert isinstance(preprocessor.steps[2], SmolVLANewLineProcessor)
assert isinstance(preprocessor.steps[3], SmolVLANewLineProcessor) # Step 3 would be TokenizerProcessorStep but it's mocked
# Step 4 would be TokenizerProcessorStep but it's mocked assert isinstance(preprocessor.steps[4], DeviceProcessorStep)
assert isinstance(preprocessor.steps[5], DeviceProcessorStep) assert isinstance(preprocessor.steps[5], NormalizerProcessorStep)
# Check steps in postprocessor # Check steps in postprocessor
assert len(postprocessor.steps) == 2 assert len(postprocessor.steps) == 2
@@ -404,3 +404,73 @@ def test_smolvla_newline_processor_transform_features():
} }
result = processor.transform_features(features) result = processor.transform_features(features)
assert result == features # Should return unchanged assert result == features # Should return unchanged
@pytest.mark.skipif(not torch.cuda.is_available(), reason="CUDA not available")
def test_smolvla_processor_bfloat16_device_float32_normalizer():
    """Check that a float32-configured normalizer follows bfloat16 inputs.

    The SmolVLA preprocessor is built with the tokenizer step patched out,
    then rebuilt so that its device step casts floats to bfloat16 while its
    normalizer step is constructed with ``dtype=torch.float32``. After one
    pass through the pipeline, the outputs, the normalizer's ``dtype``
    attribute and its state stats are all expected to be bfloat16.
    """
    config = create_default_config()
    config.device = "cuda"
    stats = create_default_stats()

    # Only the factory call needs the tokenizer step replaced by the mock.
    with patch(
        "lerobot.policies.smolvla.processor_smolvla.TokenizerProcessorStep", MockTokenizerProcessorStep
    ):
        preprocessor, _ = make_smolvla_pre_post_processors(
            config,
            stats,
            preprocessor_kwargs={"to_transition": lambda x: x, "to_output": lambda x: x},
            postprocessor_kwargs={"to_transition": lambda x: x, "to_output": lambda x: x},
        )

    def _swap(step):
        """Return the replacement for ``step`` (or ``step`` itself)."""
        if isinstance(step, DeviceProcessorStep):
            # The device step is the one that casts floats to bfloat16.
            return DeviceProcessorStep(device=config.device, float_dtype="bfloat16")
        if isinstance(step, NormalizerProcessorStep):
            # Deliberately float32: the test expects it to adapt at call time.
            return NormalizerProcessorStep(
                features=step.features,
                norm_map=step.norm_map,
                stats=step.stats,
                device=config.device,
                dtype=torch.float32,
            )
        return step

    preprocessor.steps = [_swap(step) for step in preprocessor.steps]

    # SmolVLA has its NormalizerProcessorStep at index 5.
    normalizer_step = preprocessor.steps[5]
    assert normalizer_step.dtype == torch.float32

    # State and image inputs both start out as float32.
    observation = {
        OBS_STATE: torch.randn(8, dtype=torch.float32),
        OBS_IMAGE: torch.randn(3, 224, 224, dtype=torch.float32),
    }
    action = torch.randn(7, dtype=torch.float32)
    transition = create_transition(
        observation, action, complementary_data={"task": "test bfloat16 adaptation"}
    )

    processed = preprocessor(transition)
    processed_obs = processed[TransitionKey.OBSERVATION]

    # Outputs carry the dtype chosen by the device step.
    assert processed_obs[OBS_STATE].dtype == torch.bfloat16
    # IDENTITY normalization still gets dtype conversion.
    assert processed_obs[OBS_IMAGE].dtype == torch.bfloat16
    assert processed[TransitionKey.ACTION].dtype == torch.bfloat16

    # The normalizer's own configuration and cached state stats followed the input.
    assert normalizer_step.dtype == torch.bfloat16
    assert all(
        stat_tensor.dtype == torch.bfloat16
        for stat_tensor in normalizer_step._tensor_stats[OBS_STATE].values()
    )
    # OBS_IMAGE uses IDENTITY normalization, so no stats to check
+84 -5
View File
@@ -95,9 +95,9 @@ def test_make_tdmpc_processor_basic():
# Check steps in preprocessor # Check steps in preprocessor
assert len(preprocessor.steps) == 4 assert len(preprocessor.steps) == 4
assert isinstance(preprocessor.steps[0], RenameProcessorStep) assert isinstance(preprocessor.steps[0], RenameProcessorStep)
assert isinstance(preprocessor.steps[1], NormalizerProcessorStep) assert isinstance(preprocessor.steps[1], AddBatchDimensionProcessorStep)
assert isinstance(preprocessor.steps[2], AddBatchDimensionProcessorStep) assert isinstance(preprocessor.steps[2], DeviceProcessorStep)
assert isinstance(preprocessor.steps[3], DeviceProcessorStep) assert isinstance(preprocessor.steps[3], NormalizerProcessorStep)
# Check steps in postprocessor # Check steps in postprocessor
assert len(postprocessor.steps) == 2 assert len(postprocessor.steps) == 2
@@ -331,9 +331,24 @@ def test_tdmpc_processor_mixed_precision():
) )
# Replace DeviceProcessorStep with one that uses float16 # Replace DeviceProcessorStep with one that uses float16
for i, step in enumerate(preprocessor.steps): modified_steps = []
for step in preprocessor.steps:
if isinstance(step, DeviceProcessorStep): if isinstance(step, DeviceProcessorStep):
preprocessor.steps[i] = DeviceProcessorStep(device=config.device, float_dtype="float16") modified_steps.append(DeviceProcessorStep(device=config.device, float_dtype="float16"))
elif isinstance(step, NormalizerProcessorStep):
# Update normalizer to use the same device as the device processor
modified_steps.append(
NormalizerProcessorStep(
features=step.features,
norm_map=step.norm_map,
stats=step.stats,
device=config.device,
dtype=torch.float16, # Match the float16 dtype
)
)
else:
modified_steps.append(step)
preprocessor.steps = modified_steps
# Create test data # Create test data
observation = { observation = {
@@ -410,3 +425,67 @@ def test_tdmpc_processor_edge_cases():
processed = preprocessor(transition) processed = preprocessor(transition)
assert processed[TransitionKey.OBSERVATION][OBS_IMAGE].shape == (1, 3, 224, 224) assert processed[TransitionKey.OBSERVATION][OBS_IMAGE].shape == (1, 3, 224, 224)
assert OBS_STATE not in processed[TransitionKey.OBSERVATION] assert OBS_STATE not in processed[TransitionKey.OBSERVATION]
@pytest.mark.skipif(not torch.cuda.is_available(), reason="CUDA not available")
def test_tdmpc_processor_bfloat16_device_float32_normalizer():
    """Check that a float32-configured normalizer follows bfloat16 inputs.

    The TDMPC preprocessor is rebuilt so that its device step casts floats to
    bfloat16 while its normalizer step is constructed with
    ``dtype=torch.float32``. After one pass through the pipeline, the
    outputs, the normalizer's ``dtype`` attribute and its state stats are
    all expected to be bfloat16.
    """
    config = create_default_config()
    config.device = "cuda"
    stats = create_default_stats()

    preprocessor, _ = make_tdmpc_pre_post_processors(
        config,
        stats,
        preprocessor_kwargs={"to_transition": lambda x: x, "to_output": lambda x: x},
    )

    def _swap(step):
        """Return the replacement for ``step`` (or ``step`` itself)."""
        if isinstance(step, DeviceProcessorStep):
            # The device step is the one that casts floats to bfloat16.
            return DeviceProcessorStep(device=config.device, float_dtype="bfloat16")
        if isinstance(step, NormalizerProcessorStep):
            # Deliberately float32: the test expects it to adapt at call time.
            return NormalizerProcessorStep(
                features=step.features,
                norm_map=step.norm_map,
                stats=step.stats,
                device=config.device,
                dtype=torch.float32,
            )
        return step

    preprocessor.steps = [_swap(step) for step in preprocessor.steps]

    # The normalizer sits at index 3 in the TDMPC preprocessor.
    normalizer_step = preprocessor.steps[3]
    assert normalizer_step.dtype == torch.float32

    # State and image inputs both start out as float32.
    observation = {
        OBS_STATE: torch.randn(12, dtype=torch.float32),
        OBS_IMAGE: torch.randn(3, 224, 224, dtype=torch.float32),
    }
    action = torch.randn(6, dtype=torch.float32)
    transition = create_transition(observation, action)

    processed = preprocessor(transition)
    processed_obs = processed[TransitionKey.OBSERVATION]

    # Outputs carry the dtype chosen by the device step.
    assert processed_obs[OBS_STATE].dtype == torch.bfloat16
    # IDENTITY normalization still gets dtype conversion.
    assert processed_obs[OBS_IMAGE].dtype == torch.bfloat16
    assert processed[TransitionKey.ACTION].dtype == torch.bfloat16

    # The normalizer's own configuration and cached state stats followed the input.
    assert normalizer_step.dtype == torch.bfloat16
    assert all(
        stat_tensor.dtype == torch.bfloat16
        for stat_tensor in normalizer_step._tensor_stats[OBS_STATE].values()
    )
    # OBS_IMAGE uses IDENTITY normalization, so no stats to check
+85 -5
View File
@@ -95,9 +95,9 @@ def test_make_vqbet_processor_basic():
# Check steps in preprocessor # Check steps in preprocessor
assert len(preprocessor.steps) == 4 assert len(preprocessor.steps) == 4
assert isinstance(preprocessor.steps[0], RenameProcessorStep) assert isinstance(preprocessor.steps[0], RenameProcessorStep)
assert isinstance(preprocessor.steps[1], NormalizerProcessorStep) assert isinstance(preprocessor.steps[1], AddBatchDimensionProcessorStep)
assert isinstance(preprocessor.steps[2], AddBatchDimensionProcessorStep) assert isinstance(preprocessor.steps[2], DeviceProcessorStep)
assert isinstance(preprocessor.steps[3], DeviceProcessorStep) assert isinstance(preprocessor.steps[3], NormalizerProcessorStep)
# Check steps in postprocessor # Check steps in postprocessor
assert len(postprocessor.steps) == 2 assert len(postprocessor.steps) == 2
@@ -324,9 +324,24 @@ def test_vqbet_processor_mixed_precision():
) )
# Replace DeviceProcessorStep with one that uses float16 # Replace DeviceProcessorStep with one that uses float16
for i, step in enumerate(preprocessor.steps): modified_steps = []
for step in preprocessor.steps:
if isinstance(step, DeviceProcessorStep): if isinstance(step, DeviceProcessorStep):
preprocessor.steps[i] = DeviceProcessorStep(device=config.device, float_dtype="float16") modified_steps.append(DeviceProcessorStep(device=config.device, float_dtype="float16"))
elif isinstance(step, NormalizerProcessorStep):
# Update normalizer to use the same device as the device processor
modified_steps.append(
NormalizerProcessorStep(
features=step.features,
norm_map=step.norm_map,
stats=step.stats,
device=config.device,
dtype=torch.float16, # Match the float16 dtype
)
)
else:
modified_steps.append(step)
preprocessor.steps = modified_steps
# Create test data # Create test data
observation = { observation = {
@@ -405,3 +420,68 @@ def test_vqbet_processor_sequential_processing():
assert result[TransitionKey.OBSERVATION][OBS_STATE].shape == (1, 8) assert result[TransitionKey.OBSERVATION][OBS_STATE].shape == (1, 8)
assert result[TransitionKey.OBSERVATION][OBS_IMAGE].shape == (1, 3, 224, 224) assert result[TransitionKey.OBSERVATION][OBS_IMAGE].shape == (1, 3, 224, 224)
assert result[TransitionKey.ACTION].shape == (1, 7) assert result[TransitionKey.ACTION].shape == (1, 7)
@pytest.mark.skipif(not torch.cuda.is_available(), reason="CUDA not available")
def test_vqbet_processor_bfloat16_device_float32_normalizer():
    """Check that a float32-configured normalizer follows bfloat16 inputs.

    The VQ-BeT preprocessor is rebuilt so that its device step casts floats
    to bfloat16 while its normalizer step is constructed with
    ``dtype=torch.float32``. After one pass through the pipeline, the
    outputs, the normalizer's ``dtype`` attribute and its state stats are
    all expected to be bfloat16.
    """
    config = create_default_config()
    config.device = "cuda"
    stats = create_default_stats()

    preprocessor, _ = make_vqbet_pre_post_processors(
        config,
        stats,
        preprocessor_kwargs={"to_transition": lambda x: x, "to_output": lambda x: x},
        postprocessor_kwargs={"to_transition": lambda x: x, "to_output": lambda x: x},
    )

    def _swap(step):
        """Return the replacement for ``step`` (or ``step`` itself)."""
        if isinstance(step, DeviceProcessorStep):
            # The device step is the one that casts floats to bfloat16.
            return DeviceProcessorStep(device=config.device, float_dtype="bfloat16")
        if isinstance(step, NormalizerProcessorStep):
            # Deliberately float32: the test expects it to adapt at call time.
            return NormalizerProcessorStep(
                features=step.features,
                norm_map=step.norm_map,
                stats=step.stats,
                device=config.device,
                dtype=torch.float32,
            )
        return step

    preprocessor.steps = [_swap(step) for step in preprocessor.steps]

    # The normalizer sits at index 3 in the VQ-BeT preprocessor.
    normalizer_step = preprocessor.steps[3]
    assert normalizer_step.dtype == torch.float32

    # State and image inputs both start out as float32.
    observation = {
        OBS_STATE: torch.randn(8, dtype=torch.float32),
        OBS_IMAGE: torch.randn(3, 224, 224, dtype=torch.float32),
    }
    action = torch.randn(7, dtype=torch.float32)
    transition = create_transition(observation, action)

    processed = preprocessor(transition)
    processed_obs = processed[TransitionKey.OBSERVATION]

    # Outputs carry the dtype chosen by the device step.
    assert processed_obs[OBS_STATE].dtype == torch.bfloat16
    # IDENTITY normalization still gets dtype conversion.
    assert processed_obs[OBS_IMAGE].dtype == torch.bfloat16
    assert processed[TransitionKey.ACTION].dtype == torch.bfloat16

    # The normalizer's own configuration and cached state stats followed the input.
    assert normalizer_step.dtype == torch.bfloat16
    assert all(
        stat_tensor.dtype == torch.bfloat16
        for stat_tensor in normalizer_step._tensor_stats[OBS_STATE].values()
    )
    # OBS_IMAGE uses IDENTITY normalization, so no stats to check