mirror of
https://github.com/huggingface/lerobot.git
synced 2026-05-15 08:39:49 +00:00
refactor(processor): improve processor pipeline typing with generic type (#1810)
* refactor(processor): introduce generic type for to_output - Always return `TOutput` - Remove `_prepare_transition`, so `__call__` now always returns `TOutput` - Update tests accordingly - This refactor paves the way for adding settings for `to_transition` and `to_output` in `make_processor` and the post-processor * refactor(processor): consolidate ProcessorKwargs usage across policies - Removed the ProcessorTypes module and integrated ProcessorKwargs directly into the processor pipeline. - Updated multiple policy files to utilize the new ProcessorKwargs structure for preprocessor and postprocessor arguments. - Simplified the handling of processor kwargs by initializing them to empty dictionaries when not provided.
This commit is contained in:
@@ -78,7 +78,12 @@ def test_make_sac_processor_basic():
|
||||
config = create_default_config()
|
||||
stats = create_default_stats()
|
||||
|
||||
preprocessor, postprocessor = make_sac_pre_post_processors(config, stats)
|
||||
preprocessor, postprocessor = make_sac_pre_post_processors(
|
||||
config,
|
||||
stats,
|
||||
preprocessor_kwargs={"to_transition": lambda x: x, "to_output": lambda x: x},
|
||||
postprocessor_kwargs={"to_transition": lambda x: x, "to_output": lambda x: x},
|
||||
)
|
||||
|
||||
# Check processor names
|
||||
assert preprocessor.name == "robot_preprocessor"
|
||||
@@ -102,7 +107,12 @@ def test_sac_processor_normalization_modes():
|
||||
config = create_default_config()
|
||||
stats = create_default_stats()
|
||||
|
||||
preprocessor, postprocessor = make_sac_pre_post_processors(config, stats)
|
||||
preprocessor, postprocessor = make_sac_pre_post_processors(
|
||||
config,
|
||||
stats,
|
||||
preprocessor_kwargs={"to_transition": lambda x: x, "to_output": lambda x: x},
|
||||
postprocessor_kwargs={"to_transition": lambda x: x, "to_output": lambda x: x},
|
||||
)
|
||||
|
||||
# Create test data
|
||||
observation = {OBS_STATE: torch.randn(10) * 2} # Larger values to test normalization
|
||||
@@ -133,7 +143,12 @@ def test_sac_processor_cuda():
|
||||
config.device = "cuda"
|
||||
stats = create_default_stats()
|
||||
|
||||
preprocessor, postprocessor = make_sac_pre_post_processors(config, stats)
|
||||
preprocessor, postprocessor = make_sac_pre_post_processors(
|
||||
config,
|
||||
stats,
|
||||
preprocessor_kwargs={"to_transition": lambda x: x, "to_output": lambda x: x},
|
||||
postprocessor_kwargs={"to_transition": lambda x: x, "to_output": lambda x: x},
|
||||
)
|
||||
|
||||
# Create CPU data
|
||||
observation = {OBS_STATE: torch.randn(10)}
|
||||
@@ -162,7 +177,12 @@ def test_sac_processor_accelerate_scenario():
|
||||
config.device = "cuda:0"
|
||||
stats = create_default_stats()
|
||||
|
||||
preprocessor, postprocessor = make_sac_pre_post_processors(config, stats)
|
||||
preprocessor, postprocessor = make_sac_pre_post_processors(
|
||||
config,
|
||||
stats,
|
||||
preprocessor_kwargs={"to_transition": lambda x: x, "to_output": lambda x: x},
|
||||
postprocessor_kwargs={"to_transition": lambda x: x, "to_output": lambda x: x},
|
||||
)
|
||||
|
||||
# Simulate Accelerate: data already on GPU
|
||||
device = torch.device("cuda:0")
|
||||
@@ -185,7 +205,12 @@ def test_sac_processor_multi_gpu():
|
||||
config.device = "cuda:0"
|
||||
stats = create_default_stats()
|
||||
|
||||
preprocessor, postprocessor = make_sac_pre_post_processors(config, stats)
|
||||
preprocessor, postprocessor = make_sac_pre_post_processors(
|
||||
config,
|
||||
stats,
|
||||
preprocessor_kwargs={"to_transition": lambda x: x, "to_output": lambda x: x},
|
||||
postprocessor_kwargs={"to_transition": lambda x: x, "to_output": lambda x: x},
|
||||
)
|
||||
|
||||
# Simulate data on different GPU
|
||||
device = torch.device("cuda:1")
|
||||
@@ -205,7 +230,22 @@ def test_sac_processor_without_stats():
|
||||
"""Test SAC processor creation without dataset statistics."""
|
||||
config = create_default_config()
|
||||
|
||||
preprocessor, postprocessor = make_sac_pre_post_processors(config, dataset_stats=None)
|
||||
# Get the steps from the factory function
|
||||
factory_preprocessor, factory_postprocessor = make_sac_pre_post_processors(config, dataset_stats=None)
|
||||
|
||||
# Create new processors with EnvTransition input/output
|
||||
preprocessor = RobotProcessor(
|
||||
factory_preprocessor.steps,
|
||||
name=factory_preprocessor.name,
|
||||
to_transition=lambda x: x,
|
||||
to_output=lambda x: x,
|
||||
)
|
||||
postprocessor = RobotProcessor(
|
||||
factory_postprocessor.steps,
|
||||
name=factory_postprocessor.name,
|
||||
to_transition=lambda x: x,
|
||||
to_output=lambda x: x,
|
||||
)
|
||||
|
||||
# Should still create processors
|
||||
assert preprocessor is not None
|
||||
@@ -225,14 +265,21 @@ def test_sac_processor_save_and_load():
|
||||
config = create_default_config()
|
||||
stats = create_default_stats()
|
||||
|
||||
preprocessor, postprocessor = make_sac_pre_post_processors(config, stats)
|
||||
preprocessor, postprocessor = make_sac_pre_post_processors(
|
||||
config,
|
||||
stats,
|
||||
preprocessor_kwargs={"to_transition": lambda x: x, "to_output": lambda x: x},
|
||||
postprocessor_kwargs={"to_transition": lambda x: x, "to_output": lambda x: x},
|
||||
)
|
||||
|
||||
with tempfile.TemporaryDirectory() as tmpdir:
|
||||
# Save preprocessor
|
||||
preprocessor.save_pretrained(tmpdir)
|
||||
|
||||
# Load preprocessor
|
||||
loaded_preprocessor = RobotProcessor.from_pretrained(tmpdir)
|
||||
loaded_preprocessor = RobotProcessor.from_pretrained(
|
||||
tmpdir, to_transition=lambda x: x, to_output=lambda x: x
|
||||
)
|
||||
|
||||
# Test that loaded processor works
|
||||
observation = {OBS_STATE: torch.randn(10)}
|
||||
@@ -252,7 +299,12 @@ def test_sac_processor_mixed_precision():
|
||||
stats = create_default_stats()
|
||||
|
||||
# Create processor
|
||||
preprocessor, postprocessor = make_sac_pre_post_processors(config, stats)
|
||||
preprocessor, postprocessor = make_sac_pre_post_processors(
|
||||
config,
|
||||
stats,
|
||||
preprocessor_kwargs={"to_transition": lambda x: x, "to_output": lambda x: x},
|
||||
postprocessor_kwargs={"to_transition": lambda x: x, "to_output": lambda x: x},
|
||||
)
|
||||
|
||||
# Replace DeviceProcessor with one that uses float16
|
||||
for i, step in enumerate(preprocessor.steps):
|
||||
@@ -277,7 +329,12 @@ def test_sac_processor_batch_data():
|
||||
config = create_default_config()
|
||||
stats = create_default_stats()
|
||||
|
||||
preprocessor, postprocessor = make_sac_pre_post_processors(config, stats)
|
||||
preprocessor, postprocessor = make_sac_pre_post_processors(
|
||||
config,
|
||||
stats,
|
||||
preprocessor_kwargs={"to_transition": lambda x: x, "to_output": lambda x: x},
|
||||
postprocessor_kwargs={"to_transition": lambda x: x, "to_output": lambda x: x},
|
||||
)
|
||||
|
||||
# Test with batched data
|
||||
batch_size = 32
|
||||
@@ -298,7 +355,12 @@ def test_sac_processor_edge_cases():
|
||||
config = create_default_config()
|
||||
stats = create_default_stats()
|
||||
|
||||
preprocessor, postprocessor = make_sac_pre_post_processors(config, stats)
|
||||
preprocessor, postprocessor = make_sac_pre_post_processors(
|
||||
config,
|
||||
stats,
|
||||
preprocessor_kwargs={"to_transition": lambda x: x, "to_output": lambda x: x},
|
||||
postprocessor_kwargs={"to_transition": lambda x: x, "to_output": lambda x: x},
|
||||
)
|
||||
|
||||
# Test with empty observation
|
||||
transition = create_transition(observation={}, action=torch.randn(5))
|
||||
|
||||
Reference in New Issue
Block a user