refactor(processor): improve processor pipeline typing with generic type (#1810)

* refactor(processor): introduce generic type for to_output

- Always return `TOutput`
- Remove `_prepare_transition`, so `__call__` now always returns `TOutput`
- Update tests accordingly
- This refactor paves the way for adding settings for `to_transition` and `to_output` in `make_processor` and the post-processor

* refactor(processor): consolidate ProcessorKwargs usage across policies

- Remove the ProcessorTypes module and integrate ProcessorKwargs directly into the processor pipeline.
- Update multiple policy files to use the new ProcessorKwargs structure for preprocessor and postprocessor arguments.
- Simplify the handling of processor kwargs by initializing them to empty dictionaries when not provided.
This commit is contained in:
Adil Zouitine
2025-09-02 12:57:14 +02:00
committed by GitHub
parent 08fb310eaa
commit d32b76cc66
26 changed files with 847 additions and 220 deletions
+73 -11
View File
@@ -81,7 +81,12 @@ def test_make_tdmpc_processor_basic():
config = create_default_config()
stats = create_default_stats()
preprocessor, postprocessor = make_tdmpc_pre_post_processors(config, stats)
preprocessor, postprocessor = make_tdmpc_pre_post_processors(
config,
stats,
preprocessor_kwargs={"to_transition": lambda x: x, "to_output": lambda x: x},
postprocessor_kwargs={"to_transition": lambda x: x, "to_output": lambda x: x},
)
# Check processor names
assert preprocessor.name == "robot_preprocessor"
@@ -105,7 +110,12 @@ def test_tdmpc_processor_normalization():
config = create_default_config()
stats = create_default_stats()
preprocessor, postprocessor = make_tdmpc_pre_post_processors(config, stats)
preprocessor, postprocessor = make_tdmpc_pre_post_processors(
config,
stats,
preprocessor_kwargs={"to_transition": lambda x: x, "to_output": lambda x: x},
postprocessor_kwargs={"to_transition": lambda x: x, "to_output": lambda x: x},
)
# Create test data
observation = {
@@ -138,7 +148,12 @@ def test_tdmpc_processor_cuda():
config.device = "cuda"
stats = create_default_stats()
preprocessor, postprocessor = make_tdmpc_pre_post_processors(config, stats)
preprocessor, postprocessor = make_tdmpc_pre_post_processors(
config,
stats,
preprocessor_kwargs={"to_transition": lambda x: x, "to_output": lambda x: x},
postprocessor_kwargs={"to_transition": lambda x: x, "to_output": lambda x: x},
)
# Create CPU data
observation = {
@@ -171,7 +186,12 @@ def test_tdmpc_processor_accelerate_scenario():
config.device = "cuda:0"
stats = create_default_stats()
preprocessor, postprocessor = make_tdmpc_pre_post_processors(config, stats)
preprocessor, postprocessor = make_tdmpc_pre_post_processors(
config,
stats,
preprocessor_kwargs={"to_transition": lambda x: x, "to_output": lambda x: x},
postprocessor_kwargs={"to_transition": lambda x: x, "to_output": lambda x: x},
)
# Simulate Accelerate: data already on GPU
device = torch.device("cuda:0")
@@ -198,7 +218,12 @@ def test_tdmpc_processor_multi_gpu():
config.device = "cuda:0"
stats = create_default_stats()
preprocessor, postprocessor = make_tdmpc_pre_post_processors(config, stats)
preprocessor, postprocessor = make_tdmpc_pre_post_processors(
config,
stats,
preprocessor_kwargs={"to_transition": lambda x: x, "to_output": lambda x: x},
postprocessor_kwargs={"to_transition": lambda x: x, "to_output": lambda x: x},
)
# Simulate data on different GPU
device = torch.device("cuda:1")
@@ -222,7 +247,22 @@ def test_tdmpc_processor_without_stats():
"""Test TDMPC processor creation without dataset statistics."""
config = create_default_config()
preprocessor, postprocessor = make_tdmpc_pre_post_processors(config, dataset_stats=None)
# Get the steps from the factory function
factory_preprocessor, factory_postprocessor = make_tdmpc_pre_post_processors(config, dataset_stats=None)
# Create new processors with EnvTransition input/output
preprocessor = RobotProcessor(
factory_preprocessor.steps,
name=factory_preprocessor.name,
to_transition=lambda x: x,
to_output=lambda x: x,
)
postprocessor = RobotProcessor(
factory_postprocessor.steps,
name=factory_postprocessor.name,
to_transition=lambda x: x,
to_output=lambda x: x,
)
# Should still create processors
assert preprocessor is not None
@@ -245,14 +285,21 @@ def test_tdmpc_processor_save_and_load():
config = create_default_config()
stats = create_default_stats()
preprocessor, postprocessor = make_tdmpc_pre_post_processors(config, stats)
preprocessor, postprocessor = make_tdmpc_pre_post_processors(
config,
stats,
preprocessor_kwargs={"to_transition": lambda x: x, "to_output": lambda x: x},
postprocessor_kwargs={"to_transition": lambda x: x, "to_output": lambda x: x},
)
with tempfile.TemporaryDirectory() as tmpdir:
# Save preprocessor
preprocessor.save_pretrained(tmpdir)
# Load preprocessor
loaded_preprocessor = RobotProcessor.from_pretrained(tmpdir)
loaded_preprocessor = RobotProcessor.from_pretrained(
tmpdir, to_transition=lambda x: x, to_output=lambda x: x
)
# Test that loaded processor works
observation = {
@@ -276,7 +323,12 @@ def test_tdmpc_processor_mixed_precision():
stats = create_default_stats()
# Create processor
preprocessor, postprocessor = make_tdmpc_pre_post_processors(config, stats)
preprocessor, postprocessor = make_tdmpc_pre_post_processors(
config,
stats,
preprocessor_kwargs={"to_transition": lambda x: x, "to_output": lambda x: x},
postprocessor_kwargs={"to_transition": lambda x: x, "to_output": lambda x: x},
)
# Replace DeviceProcessor with one that uses float16
for i, step in enumerate(preprocessor.steps):
@@ -305,7 +357,12 @@ def test_tdmpc_processor_batch_data():
config = create_default_config()
stats = create_default_stats()
preprocessor, postprocessor = make_tdmpc_pre_post_processors(config, stats)
preprocessor, postprocessor = make_tdmpc_pre_post_processors(
config,
stats,
preprocessor_kwargs={"to_transition": lambda x: x, "to_output": lambda x: x},
postprocessor_kwargs={"to_transition": lambda x: x, "to_output": lambda x: x},
)
# Test with batched data
batch_size = 64
@@ -330,7 +387,12 @@ def test_tdmpc_processor_edge_cases():
config = create_default_config()
stats = create_default_stats()
preprocessor, postprocessor = make_tdmpc_pre_post_processors(config, stats)
preprocessor, postprocessor = make_tdmpc_pre_post_processors(
config,
stats,
preprocessor_kwargs={"to_transition": lambda x: x, "to_output": lambda x: x},
postprocessor_kwargs={"to_transition": lambda x: x, "to_output": lambda x: x},
)
# Test with only state observation (no image)
observation = {OBS_STATE: torch.randn(12)}