mirror of
https://github.com/huggingface/lerobot.git
synced 2026-05-15 16:49:55 +00:00
refactor(processor): improve processor pipeline typing with generic type (#1810)
* refactor(processor): introduce generic type for to_output - Always return `TOutput` - Remove `_prepare_transition`, so `__call__` now always returns `TOutput` - Update tests accordingly - This refactor paves the way for adding settings for `to_transition` and `to_output` in `make_processor` and the post-processor * refactor(processor): consolidate ProcessorKwargs usage across policies - Removed the ProcessorTypes module and integrated ProcessorKwargs directly into the processor pipeline. - Updated multiple policy files to utilize the new ProcessorKwargs structure for preprocessor and postprocessor arguments. - Simplified the handling of processor kwargs by initializing them to empty dictionaries when not provided.
This commit is contained in:
@@ -245,7 +245,7 @@ def test_mixed_observation():
|
||||
def test_integration_with_robot_processor():
|
||||
"""Test ToBatchProcessor integration with RobotProcessor."""
|
||||
to_batch_processor = ToBatchProcessor()
|
||||
pipeline = RobotProcessor([to_batch_processor])
|
||||
pipeline = RobotProcessor([to_batch_processor], to_transition=lambda x: x, to_output=lambda x: x)
|
||||
|
||||
# Create unbatched observation
|
||||
observation = {
|
||||
@@ -285,7 +285,9 @@ def test_serialization_methods():
|
||||
def test_save_and_load_pretrained():
|
||||
"""Test saving and loading ToBatchProcessor with RobotProcessor."""
|
||||
processor = ToBatchProcessor()
|
||||
pipeline = RobotProcessor([processor], name="BatchPipeline")
|
||||
pipeline = RobotProcessor(
|
||||
[processor], name="BatchPipeline", to_transition=lambda x: x, to_output=lambda x: x
|
||||
)
|
||||
|
||||
with tempfile.TemporaryDirectory() as tmp_dir:
|
||||
# Save pipeline
|
||||
@@ -296,7 +298,9 @@ def test_save_and_load_pretrained():
|
||||
assert config_path.exists()
|
||||
|
||||
# Load pipeline
|
||||
loaded_pipeline = RobotProcessor.from_pretrained(tmp_dir)
|
||||
loaded_pipeline = RobotProcessor.from_pretrained(
|
||||
tmp_dir, to_transition=lambda x: x, to_output=lambda x: x
|
||||
)
|
||||
|
||||
assert loaded_pipeline.name == "BatchPipeline"
|
||||
assert len(loaded_pipeline) == 1
|
||||
@@ -323,11 +327,13 @@ def test_registry_functionality():
|
||||
def test_registry_based_save_load():
|
||||
"""Test saving and loading using registry name."""
|
||||
processor = ToBatchProcessor()
|
||||
pipeline = RobotProcessor([processor])
|
||||
pipeline = RobotProcessor([processor], to_transition=lambda x: x, to_output=lambda x: x)
|
||||
|
||||
with tempfile.TemporaryDirectory() as tmp_dir:
|
||||
pipeline.save_pretrained(tmp_dir)
|
||||
loaded_pipeline = RobotProcessor.from_pretrained(tmp_dir)
|
||||
loaded_pipeline = RobotProcessor.from_pretrained(
|
||||
tmp_dir, to_transition=lambda x: x, to_output=lambda x: x
|
||||
)
|
||||
|
||||
# Verify the loaded processor works
|
||||
observation = {
|
||||
|
||||
Reference in New Issue
Block a user