refactor(processors): Improve Normalization Processor Performance and Device/Dtype Adaptability (#1880)

* refactor(processors): reorder processor steps for consistency across implementations

- Updated the order of processor steps in multiple files to ensure consistency, placing AddBatchDimensionProcessorStep and DeviceProcessorStep before NormalizerProcessorStep.
- Adjusted related test assertions to reflect the new order of steps in the preprocessor, enhancing clarity and maintainability.

* refactor(normalization): remove dtype specification in tensor conversion for adaptation logic

- Updated tensor conversion in the _NormalizationMixin class to remove explicit dtype specification, allowing for automatic adaptation of tensor types.
- Adjusted related tests to ensure proper functionality with the new tensor conversion logic, verifying that normalizers adapt correctly to input types.
This commit is contained in:
Adil Zouitine
2025-09-08 10:46:35 +02:00
committed by GitHub
parent f1cfdfced9
commit d32006440c
17 changed files with 677 additions and 72 deletions
+70 -3
View File
@@ -87,9 +87,9 @@ def test_make_act_processor_basic():
# Check steps in preprocessor
assert len(preprocessor.steps) == 4
assert isinstance(preprocessor.steps[0], RenameProcessorStep)
assert isinstance(preprocessor.steps[1], NormalizerProcessorStep)
assert isinstance(preprocessor.steps[2], AddBatchDimensionProcessorStep)
assert isinstance(preprocessor.steps[3], DeviceProcessorStep)
assert isinstance(preprocessor.steps[1], AddBatchDimensionProcessorStep)
assert isinstance(preprocessor.steps[2], DeviceProcessorStep)
assert isinstance(preprocessor.steps[3], NormalizerProcessorStep)
# Check steps in postprocessor
assert len(postprocessor.steps) == 2
@@ -308,6 +308,17 @@ def test_act_processor_mixed_precision():
for step in preprocessor.steps:
if isinstance(step, DeviceProcessorStep):
modified_steps.append(DeviceProcessorStep(device=config.device, float_dtype="float16"))
elif isinstance(step, NormalizerProcessorStep):
# Update normalizer to use the same device as the device processor
modified_steps.append(
NormalizerProcessorStep(
features=step.features,
norm_map=step.norm_map,
stats=step.stats,
device=config.device,
dtype=torch.float16, # Match the float16 dtype
)
)
else:
modified_steps.append(step)
preprocessor.steps = modified_steps
@@ -353,3 +364,59 @@ def test_act_processor_batch_consistency():
processed_batched = preprocessor(transition_batched)
assert processed_batched[TransitionKey.OBSERVATION][OBS_STATE].shape[0] == 8
assert processed_batched[TransitionKey.ACTION].shape[0] == 8
@pytest.mark.skipif(not torch.cuda.is_available(), reason="CUDA not available")
def test_act_processor_bfloat16_device_float32_normalizer():
"""Test: DeviceProcessor(bfloat16) + NormalizerProcessor(float32) → output bfloat16 via automatic adaptation"""
config = create_default_config()
config.device = "cuda"
stats = create_default_stats()
preprocessor, _ = make_act_pre_post_processors(
config,
stats,
preprocessor_kwargs={"to_transition": lambda x: x, "to_output": lambda x: x},
)
# Modify the pipeline to use bfloat16 device processor with float32 normalizer
modified_steps = []
for step in preprocessor.steps:
if isinstance(step, DeviceProcessorStep):
# Device processor converts to bfloat16
modified_steps.append(DeviceProcessorStep(device=config.device, float_dtype="bfloat16"))
elif isinstance(step, NormalizerProcessorStep):
# Normalizer stays configured as float32 (will auto-adapt to bfloat16)
modified_steps.append(
NormalizerProcessorStep(
features=step.features,
norm_map=step.norm_map,
stats=step.stats,
device=config.device,
dtype=torch.float32, # Deliberately configured as float32
)
)
else:
modified_steps.append(step)
preprocessor.steps = modified_steps
# Verify initial normalizer configuration
normalizer_step = preprocessor.steps[3] # NormalizerProcessorStep
assert normalizer_step.dtype == torch.float32
# Create test data
observation = {OBS_STATE: torch.randn(7, dtype=torch.float32)} # Start with float32
action = torch.randn(4, dtype=torch.float32)
transition = create_transition(observation, action)
# Process through full pipeline
processed = preprocessor(transition)
# Verify: DeviceProcessor → bfloat16, NormalizerProcessor adapts → final output is bfloat16
assert processed[TransitionKey.OBSERVATION][OBS_STATE].dtype == torch.bfloat16
assert processed[TransitionKey.ACTION].dtype == torch.bfloat16
# Verify normalizer automatically adapted its internal state
assert normalizer_step.dtype == torch.bfloat16
for stat_tensor in normalizer_step._tensor_stats[OBS_STATE].values():
assert stat_tensor.dtype == torch.bfloat16