mirror of
https://github.com/huggingface/lerobot.git
synced 2026-05-18 18:20:08 +00:00
test(processor): all processors use now the same create_transition (#1906)
* test(processor): all processors use now the same create_transition * test(processor): use identity instead of lambda for transition in pipelines
This commit is contained in:
@@ -26,25 +26,11 @@ from lerobot.processor import (
|
||||
RenameObservationsProcessorStep,
|
||||
TransitionKey,
|
||||
)
|
||||
from lerobot.processor.converters import create_transition, identity_transition
|
||||
from lerobot.processor.rename_processor import rename_stats
|
||||
from tests.conftest import assert_contract_is_typed
|
||||
|
||||
|
||||
def create_transition(
|
||||
observation=None, action=None, reward=None, done=None, truncated=None, info=None, complementary_data=None
|
||||
):
|
||||
"""Helper to create an EnvTransition dictionary."""
|
||||
return {
|
||||
TransitionKey.OBSERVATION: observation,
|
||||
TransitionKey.ACTION: action,
|
||||
TransitionKey.REWARD: reward,
|
||||
TransitionKey.DONE: done,
|
||||
TransitionKey.TRUNCATED: truncated,
|
||||
TransitionKey.INFO: info,
|
||||
TransitionKey.COMPLEMENTARY_DATA: complementary_data,
|
||||
}
|
||||
|
||||
|
||||
def test_basic_renaming():
|
||||
"""Test basic key renaming functionality."""
|
||||
rename_map = {
|
||||
@@ -193,7 +179,9 @@ def test_integration_with_robot_processor():
|
||||
}
|
||||
rename_processor = RenameObservationsProcessorStep(rename_map=rename_map)
|
||||
|
||||
pipeline = DataProcessorPipeline([rename_processor], to_transition=lambda x: x, to_output=lambda x: x)
|
||||
pipeline = DataProcessorPipeline(
|
||||
[rename_processor], to_transition=identity_transition, to_output=identity_transition
|
||||
)
|
||||
|
||||
observation = {
|
||||
"agent_pos": np.array([1.0, 2.0, 3.0]),
|
||||
@@ -244,7 +232,7 @@ def test_save_and_load_pretrained():
|
||||
|
||||
# Load pipeline
|
||||
loaded_pipeline = DataProcessorPipeline.from_pretrained(
|
||||
tmp_dir, to_transition=lambda x: x, to_output=lambda x: x
|
||||
tmp_dir, to_transition=identity_transition, to_output=identity_transition
|
||||
)
|
||||
|
||||
assert loaded_pipeline.name == "TestRenameProcessorStep"
|
||||
@@ -286,7 +274,9 @@ def test_registry_functionality():
|
||||
def test_registry_based_save_load():
|
||||
"""Test save/load using registry name instead of module path."""
|
||||
processor = RenameObservationsProcessorStep(rename_map={"key1": "renamed_key1"})
|
||||
pipeline = DataProcessorPipeline([processor], to_transition=lambda x: x, to_output=lambda x: x)
|
||||
pipeline = DataProcessorPipeline(
|
||||
[processor], to_transition=identity_transition, to_output=identity_transition
|
||||
)
|
||||
|
||||
with tempfile.TemporaryDirectory() as tmp_dir:
|
||||
# Save and load
|
||||
@@ -328,7 +318,7 @@ def test_chained_rename_processors():
|
||||
)
|
||||
|
||||
pipeline = DataProcessorPipeline(
|
||||
[processor1, processor2], to_transition=lambda x: x, to_output=lambda x: x
|
||||
[processor1, processor2], to_transition=identity_transition, to_output=identity_transition
|
||||
)
|
||||
|
||||
observation = {
|
||||
|
||||
Reference in New Issue
Block a user