mirror of
https://github.com/huggingface/lerobot.git
synced 2026-05-15 16:49:55 +00:00
chore(processor): add Step suffix to all processors (#1854)
* refactor(processor): rename MapDeltaActionToRobotAction and MapTensorToDeltaActionDict for consistency * refactor(processor): rename DeviceProcessor to DeviceProcessorStep for consistency across modules * refactor(processor): rename Torch2NumpyActionProcessor to Torch2NumpyActionProcessorStep for consistency * refactor(processor): rename Numpy2TorchActionProcessor to Numpy2TorchActionProcessorStep for consistency * refactor(processor): rename AddTeleopActionAsComplimentaryData to AddTeleopActionAsComplimentaryDataStep for consistency * refactor(processor): rename ImageCropResizeProcessor and AddTeleopEventsAsInfo for consistency * refactor(processor): rename TimeLimitProcessor to TimeLimitProcessorStep for consistency * refactor(processor): rename GripperPenaltyProcessor to GripperPenaltyProcessorStep for consistency * refactor(processor): rename InterventionActionProcessor to InterventionActionProcessorStep for consistency * refactor(processor): rename RewardClassifierProcessor to RewardClassifierProcessorStep for consistency * refactor(processor): rename JointVelocityProcessor to JointVelocityProcessorStep for consistency * refactor(processor): rename MotorCurrentProcessor to MotorCurrentProcessorStep for consistency * refactor(processor): rename NormalizerProcessor and UnnormalizerProcessor to NormalizerProcessorStep and UnnormalizerProcessorStep for consistency * refactor(processor): rename VanillaObservationProcessor to VanillaObservationProcessorStep for consistency * refactor(processor): rename RenameProcessor to RenameProcessorStep for consistency * refactor(processor): rename TokenizerProcessor to TokenizerProcessorStep for consistency * refactor(processor): rename ToBatchProcessor to AddBatchDimensionProcessorStep for consistency * refactor(processor): update config file name in test for RenameProcessorStep consistency
This commit is contained in:
@@ -22,7 +22,12 @@ import pytest
|
||||
import torch
|
||||
|
||||
from lerobot.constants import OBS_ENV_STATE, OBS_IMAGE, OBS_IMAGES, OBS_STATE
|
||||
from lerobot.processor import DataProcessorPipeline, ProcessorStepRegistry, ToBatchProcessor, TransitionKey
|
||||
from lerobot.processor import (
|
||||
AddBatchDimensionProcessorStep,
|
||||
DataProcessorPipeline,
|
||||
ProcessorStepRegistry,
|
||||
TransitionKey,
|
||||
)
|
||||
|
||||
|
||||
def create_transition(
|
||||
@@ -42,7 +47,7 @@ def create_transition(
|
||||
|
||||
def test_state_1d_to_2d():
|
||||
"""Test that 1D state tensors get unsqueezed to 2D."""
|
||||
processor = ToBatchProcessor()
|
||||
processor = AddBatchDimensionProcessorStep()
|
||||
|
||||
# Test observation.state
|
||||
state_1d = torch.randn(7)
|
||||
@@ -58,7 +63,7 @@ def test_state_1d_to_2d():
|
||||
|
||||
def test_env_state_1d_to_2d():
|
||||
"""Test that 1D environment state tensors get unsqueezed to 2D."""
|
||||
processor = ToBatchProcessor()
|
||||
processor = AddBatchDimensionProcessorStep()
|
||||
|
||||
# Test observation.environment_state
|
||||
env_state_1d = torch.randn(10)
|
||||
@@ -74,7 +79,7 @@ def test_env_state_1d_to_2d():
|
||||
|
||||
def test_image_3d_to_4d():
|
||||
"""Test that 3D image tensors get unsqueezed to 4D."""
|
||||
processor = ToBatchProcessor()
|
||||
processor = AddBatchDimensionProcessorStep()
|
||||
|
||||
# Test observation.image
|
||||
image_3d = torch.randn(224, 224, 3)
|
||||
@@ -90,7 +95,7 @@ def test_image_3d_to_4d():
|
||||
|
||||
def test_multiple_images_3d_to_4d():
|
||||
"""Test that 3D image tensors in observation.images.* get unsqueezed to 4D."""
|
||||
processor = ToBatchProcessor()
|
||||
processor = AddBatchDimensionProcessorStep()
|
||||
|
||||
# Test observation.images.camera1 and observation.images.camera2
|
||||
image1_3d = torch.randn(64, 64, 3)
|
||||
@@ -115,7 +120,7 @@ def test_multiple_images_3d_to_4d():
|
||||
|
||||
def test_already_batched_tensors_unchanged():
|
||||
"""Test that already batched tensors remain unchanged."""
|
||||
processor = ToBatchProcessor()
|
||||
processor = AddBatchDimensionProcessorStep()
|
||||
|
||||
# Create already batched tensors
|
||||
state_2d = torch.randn(1, 7)
|
||||
@@ -141,7 +146,7 @@ def test_already_batched_tensors_unchanged():
|
||||
|
||||
def test_higher_dimensional_tensors_unchanged():
|
||||
"""Test that tensors with more dimensions than expected remain unchanged."""
|
||||
processor = ToBatchProcessor()
|
||||
processor = AddBatchDimensionProcessorStep()
|
||||
|
||||
# Create tensors with more dimensions
|
||||
state_3d = torch.randn(2, 7, 5) # More than 1D
|
||||
@@ -164,7 +169,7 @@ def test_higher_dimensional_tensors_unchanged():
|
||||
|
||||
def test_non_tensor_values_unchanged():
|
||||
"""Test that non-tensor values in observations remain unchanged."""
|
||||
processor = ToBatchProcessor()
|
||||
processor = AddBatchDimensionProcessorStep()
|
||||
|
||||
observation = {
|
||||
OBS_STATE: [1, 2, 3], # List, not tensor
|
||||
@@ -187,7 +192,7 @@ def test_non_tensor_values_unchanged():
|
||||
|
||||
def test_none_observation():
|
||||
"""Test processor handles None observation gracefully."""
|
||||
processor = ToBatchProcessor()
|
||||
processor = AddBatchDimensionProcessorStep()
|
||||
|
||||
transition = create_transition(observation=None)
|
||||
result = processor(transition)
|
||||
@@ -197,7 +202,7 @@ def test_none_observation():
|
||||
|
||||
def test_empty_observation():
|
||||
"""Test processor handles empty observation dict."""
|
||||
processor = ToBatchProcessor()
|
||||
processor = AddBatchDimensionProcessorStep()
|
||||
|
||||
observation = {}
|
||||
transition = create_transition(observation=observation)
|
||||
@@ -209,7 +214,7 @@ def test_empty_observation():
|
||||
|
||||
def test_mixed_observation():
|
||||
"""Test processor with mixed observation containing various types and dimensions."""
|
||||
processor = ToBatchProcessor()
|
||||
processor = AddBatchDimensionProcessorStep()
|
||||
|
||||
state_1d = torch.randn(5)
|
||||
env_state_2d = torch.randn(1, 8) # Already batched
|
||||
@@ -241,8 +246,8 @@ def test_mixed_observation():
|
||||
|
||||
|
||||
def test_integration_with_robot_processor():
|
||||
"""Test ToBatchProcessor integration with RobotProcessor."""
|
||||
to_batch_processor = ToBatchProcessor()
|
||||
"""Test AddBatchDimensionProcessorStep integration with RobotProcessor."""
|
||||
to_batch_processor = AddBatchDimensionProcessorStep()
|
||||
pipeline = DataProcessorPipeline([to_batch_processor], to_transition=lambda x: x, to_output=lambda x: x)
|
||||
|
||||
# Create unbatched observation
|
||||
@@ -261,7 +266,7 @@ def test_integration_with_robot_processor():
|
||||
|
||||
def test_serialization_methods():
|
||||
"""Test get_config, state_dict, load_state_dict, and reset methods."""
|
||||
processor = ToBatchProcessor()
|
||||
processor = AddBatchDimensionProcessorStep()
|
||||
|
||||
# Test get_config
|
||||
config = processor.get_config()
|
||||
@@ -281,8 +286,8 @@ def test_serialization_methods():
|
||||
|
||||
|
||||
def test_save_and_load_pretrained():
|
||||
"""Test saving and loading ToBatchProcessor with RobotProcessor."""
|
||||
processor = ToBatchProcessor()
|
||||
"""Test saving and loading AddBatchDimensionProcessorStep with RobotProcessor."""
|
||||
processor = AddBatchDimensionProcessorStep()
|
||||
pipeline = DataProcessorPipeline(
|
||||
[processor], name="BatchPipeline", to_transition=lambda x: x, to_output=lambda x: x
|
||||
)
|
||||
@@ -302,7 +307,7 @@ def test_save_and_load_pretrained():
|
||||
|
||||
assert loaded_pipeline.name == "BatchPipeline"
|
||||
assert len(loaded_pipeline) == 1
|
||||
assert isinstance(loaded_pipeline.steps[0], ToBatchProcessor)
|
||||
assert isinstance(loaded_pipeline.steps[0], AddBatchDimensionProcessorStep)
|
||||
|
||||
# Test functionality of loaded processor
|
||||
observation = {OBS_STATE: torch.randn(5)}
|
||||
@@ -313,10 +318,10 @@ def test_save_and_load_pretrained():
|
||||
|
||||
|
||||
def test_registry_functionality():
|
||||
"""Test that ToBatchProcessor is properly registered."""
|
||||
"""Test that AddBatchDimensionProcessorStep is properly registered."""
|
||||
# Check that the processor is registered
|
||||
registered_class = ProcessorStepRegistry.get("to_batch_processor")
|
||||
assert registered_class is ToBatchProcessor
|
||||
assert registered_class is AddBatchDimensionProcessorStep
|
||||
|
||||
# Check that it's in the list of registered processors
|
||||
assert "to_batch_processor" in ProcessorStepRegistry.list()
|
||||
@@ -324,7 +329,7 @@ def test_registry_functionality():
|
||||
|
||||
def test_registry_based_save_load():
|
||||
"""Test saving and loading using registry name."""
|
||||
processor = ToBatchProcessor()
|
||||
processor = AddBatchDimensionProcessorStep()
|
||||
pipeline = DataProcessorPipeline([processor], to_transition=lambda x: x, to_output=lambda x: x)
|
||||
|
||||
with tempfile.TemporaryDirectory() as tmp_dir:
|
||||
@@ -350,7 +355,7 @@ def test_registry_based_save_load():
|
||||
@pytest.mark.skipif(not torch.cuda.is_available(), reason="CUDA not available")
|
||||
def test_device_compatibility():
|
||||
"""Test processor works with tensors on different devices."""
|
||||
processor = ToBatchProcessor()
|
||||
processor = AddBatchDimensionProcessorStep()
|
||||
|
||||
# Create tensors on GPU
|
||||
state_1d = torch.randn(7, device="cuda")
|
||||
@@ -374,7 +379,7 @@ def test_device_compatibility():
|
||||
|
||||
def test_processor_preserves_other_transition_keys():
|
||||
"""Test that processor only modifies observation and preserves other transition keys."""
|
||||
processor = ToBatchProcessor()
|
||||
processor = AddBatchDimensionProcessorStep()
|
||||
|
||||
action = torch.randn(5)
|
||||
reward = 1.5
|
||||
@@ -411,7 +416,7 @@ def test_processor_preserves_other_transition_keys():
|
||||
|
||||
def test_edge_case_zero_dimensional_tensors():
|
||||
"""Test processor handles 0D tensors (scalars) correctly."""
|
||||
processor = ToBatchProcessor()
|
||||
processor = AddBatchDimensionProcessorStep()
|
||||
|
||||
# 0D tensors should not be modified
|
||||
scalar_tensor = torch.tensor(42.0)
|
||||
@@ -433,7 +438,7 @@ def test_edge_case_zero_dimensional_tensors():
|
||||
# Action-specific tests
|
||||
def test_action_1d_to_2d():
|
||||
"""Test that 1D action tensors get batch dimension added."""
|
||||
processor = ToBatchProcessor()
|
||||
processor = AddBatchDimensionProcessorStep()
|
||||
|
||||
# Create 1D action tensor
|
||||
action_1d = torch.randn(4)
|
||||
@@ -448,7 +453,7 @@ def test_action_1d_to_2d():
|
||||
|
||||
def test_action_already_batched():
|
||||
"""Test that already batched action tensors remain unchanged."""
|
||||
processor = ToBatchProcessor()
|
||||
processor = AddBatchDimensionProcessorStep()
|
||||
|
||||
# Test various batch sizes
|
||||
action_batched_1 = torch.randn(1, 4)
|
||||
@@ -467,7 +472,7 @@ def test_action_already_batched():
|
||||
|
||||
def test_action_higher_dimensional():
|
||||
"""Test that higher dimensional action tensors remain unchanged."""
|
||||
processor = ToBatchProcessor()
|
||||
processor = AddBatchDimensionProcessorStep()
|
||||
|
||||
# 3D action tensor (e.g., sequence of actions)
|
||||
action_3d = torch.randn(2, 4, 3)
|
||||
@@ -484,7 +489,7 @@ def test_action_higher_dimensional():
|
||||
|
||||
def test_action_scalar_tensor():
|
||||
"""Test that scalar (0D) action tensors remain unchanged."""
|
||||
processor = ToBatchProcessor()
|
||||
processor = AddBatchDimensionProcessorStep()
|
||||
|
||||
action_scalar = torch.tensor(1.5)
|
||||
transition = create_transition(action=action_scalar)
|
||||
@@ -497,7 +502,7 @@ def test_action_scalar_tensor():
|
||||
|
||||
def test_action_non_tensor():
|
||||
"""Test that non-tensor actions remain unchanged."""
|
||||
processor = ToBatchProcessor()
|
||||
processor = AddBatchDimensionProcessorStep()
|
||||
|
||||
# List action
|
||||
action_list = [0.1, 0.2, 0.3, 0.4]
|
||||
@@ -526,7 +531,7 @@ def test_action_non_tensor():
|
||||
|
||||
def test_action_none():
|
||||
"""Test that None action is handled correctly."""
|
||||
processor = ToBatchProcessor()
|
||||
processor = AddBatchDimensionProcessorStep()
|
||||
|
||||
transition = create_transition(action=None)
|
||||
result = processor(transition)
|
||||
@@ -535,7 +540,7 @@ def test_action_none():
|
||||
|
||||
def test_action_with_observation():
|
||||
"""Test action processing together with observation processing."""
|
||||
processor = ToBatchProcessor()
|
||||
processor = AddBatchDimensionProcessorStep()
|
||||
|
||||
# Both need batching
|
||||
observation = {
|
||||
@@ -555,7 +560,7 @@ def test_action_with_observation():
|
||||
|
||||
def test_action_different_sizes():
|
||||
"""Test action processing with various action dimensions."""
|
||||
processor = ToBatchProcessor()
|
||||
processor = AddBatchDimensionProcessorStep()
|
||||
|
||||
# Different action sizes (robot with different DOF)
|
||||
action_sizes = [1, 2, 4, 7, 10, 20]
|
||||
@@ -572,7 +577,7 @@ def test_action_different_sizes():
|
||||
@pytest.mark.skipif(not torch.cuda.is_available(), reason="CUDA not available")
|
||||
def test_action_device_compatibility():
|
||||
"""Test action processing on different devices."""
|
||||
processor = ToBatchProcessor()
|
||||
processor = AddBatchDimensionProcessorStep()
|
||||
|
||||
# CUDA action
|
||||
action_cuda = torch.randn(4, device="cuda")
|
||||
@@ -593,7 +598,7 @@ def test_action_device_compatibility():
|
||||
|
||||
def test_action_dtype_preservation():
|
||||
"""Test that action dtype is preserved during processing."""
|
||||
processor = ToBatchProcessor()
|
||||
processor = AddBatchDimensionProcessorStep()
|
||||
|
||||
# Different dtypes
|
||||
dtypes = [torch.float32, torch.float64, torch.int32, torch.int64]
|
||||
@@ -609,7 +614,7 @@ def test_action_dtype_preservation():
|
||||
|
||||
def test_empty_action_tensor():
|
||||
"""Test handling of empty action tensors."""
|
||||
processor = ToBatchProcessor()
|
||||
processor = AddBatchDimensionProcessorStep()
|
||||
|
||||
# Empty 1D tensor
|
||||
action_empty = torch.tensor([])
|
||||
@@ -631,7 +636,7 @@ def test_empty_action_tensor():
|
||||
# Task-specific tests
|
||||
def test_task_string_to_list():
|
||||
"""Test that string tasks get wrapped in lists to add batch dimension."""
|
||||
processor = ToBatchProcessor()
|
||||
processor = AddBatchDimensionProcessorStep()
|
||||
|
||||
# Create complementary data with string task
|
||||
complementary_data = {"task": "pick_cube"}
|
||||
@@ -648,7 +653,7 @@ def test_task_string_to_list():
|
||||
|
||||
def test_task_string_validation():
|
||||
"""Test that only string and list of strings are valid task values."""
|
||||
processor = ToBatchProcessor()
|
||||
processor = AddBatchDimensionProcessorStep()
|
||||
|
||||
# Valid string task - should be converted to list
|
||||
complementary_data = {"task": "valid_task"}
|
||||
@@ -667,7 +672,7 @@ def test_task_string_validation():
|
||||
|
||||
def test_task_list_of_strings():
|
||||
"""Test that lists of strings remain unchanged (already batched)."""
|
||||
processor = ToBatchProcessor()
|
||||
processor = AddBatchDimensionProcessorStep()
|
||||
|
||||
# Test various list of strings
|
||||
test_lists = [
|
||||
@@ -693,7 +698,7 @@ def test_task_list_of_strings():
|
||||
|
||||
def test_complementary_data_none():
|
||||
"""Test processor handles None complementary_data gracefully."""
|
||||
processor = ToBatchProcessor()
|
||||
processor = AddBatchDimensionProcessorStep()
|
||||
|
||||
transition = create_transition(complementary_data=None)
|
||||
result = processor(transition)
|
||||
@@ -703,7 +708,7 @@ def test_complementary_data_none():
|
||||
|
||||
def test_complementary_data_empty():
|
||||
"""Test processor handles empty complementary_data dict."""
|
||||
processor = ToBatchProcessor()
|
||||
processor = AddBatchDimensionProcessorStep()
|
||||
|
||||
complementary_data = {}
|
||||
transition = create_transition(complementary_data=complementary_data)
|
||||
@@ -715,7 +720,7 @@ def test_complementary_data_empty():
|
||||
|
||||
def test_complementary_data_no_task():
|
||||
"""Test processor handles complementary_data without task field."""
|
||||
processor = ToBatchProcessor()
|
||||
processor = AddBatchDimensionProcessorStep()
|
||||
|
||||
complementary_data = {
|
||||
"episode_id": 123,
|
||||
@@ -733,7 +738,7 @@ def test_complementary_data_no_task():
|
||||
|
||||
def test_complementary_data_mixed():
|
||||
"""Test processor with mixed complementary_data containing task and other fields."""
|
||||
processor = ToBatchProcessor()
|
||||
processor = AddBatchDimensionProcessorStep()
|
||||
|
||||
complementary_data = {
|
||||
"task": "stack_blocks",
|
||||
@@ -758,7 +763,7 @@ def test_complementary_data_mixed():
|
||||
|
||||
def test_task_with_observation_and_action():
|
||||
"""Test task processing together with observation and action processing."""
|
||||
processor = ToBatchProcessor()
|
||||
processor = AddBatchDimensionProcessorStep()
|
||||
|
||||
# All components need batching
|
||||
observation = {
|
||||
@@ -783,7 +788,7 @@ def test_task_with_observation_and_action():
|
||||
|
||||
def test_task_comprehensive_string_cases():
|
||||
"""Test task processing with comprehensive string cases and edge cases."""
|
||||
processor = ToBatchProcessor()
|
||||
processor = AddBatchDimensionProcessorStep()
|
||||
|
||||
# Test various string formats
|
||||
string_tasks = [
|
||||
@@ -841,7 +846,7 @@ def test_task_comprehensive_string_cases():
|
||||
|
||||
def test_task_preserves_other_keys():
|
||||
"""Test that task processing preserves other keys in complementary_data."""
|
||||
processor = ToBatchProcessor()
|
||||
processor = AddBatchDimensionProcessorStep()
|
||||
|
||||
complementary_data = {
|
||||
"task": "clean_table",
|
||||
@@ -869,7 +874,7 @@ def test_task_preserves_other_keys():
|
||||
# Index and task_index specific tests
|
||||
def test_index_scalar_to_1d():
|
||||
"""Test that 0D index tensor gets unsqueezed to 1D."""
|
||||
processor = ToBatchProcessor()
|
||||
processor = AddBatchDimensionProcessorStep()
|
||||
|
||||
# Create 0D index tensor (scalar)
|
||||
index_0d = torch.tensor(42, dtype=torch.int64)
|
||||
@@ -886,7 +891,7 @@ def test_index_scalar_to_1d():
|
||||
|
||||
def test_task_index_scalar_to_1d():
|
||||
"""Test that 0D task_index tensor gets unsqueezed to 1D."""
|
||||
processor = ToBatchProcessor()
|
||||
processor = AddBatchDimensionProcessorStep()
|
||||
|
||||
# Create 0D task_index tensor (scalar)
|
||||
task_index_0d = torch.tensor(7, dtype=torch.int64)
|
||||
@@ -903,7 +908,7 @@ def test_task_index_scalar_to_1d():
|
||||
|
||||
def test_index_and_task_index_together():
|
||||
"""Test processing both index and task_index together."""
|
||||
processor = ToBatchProcessor()
|
||||
processor = AddBatchDimensionProcessorStep()
|
||||
|
||||
# Create 0D tensors for both
|
||||
index_0d = torch.tensor(100, dtype=torch.int64)
|
||||
@@ -933,7 +938,7 @@ def test_index_and_task_index_together():
|
||||
|
||||
def test_index_already_batched():
|
||||
"""Test that already batched index tensors remain unchanged."""
|
||||
processor = ToBatchProcessor()
|
||||
processor = AddBatchDimensionProcessorStep()
|
||||
|
||||
# Create already batched tensors
|
||||
index_1d = torch.tensor([42], dtype=torch.int64)
|
||||
@@ -954,7 +959,7 @@ def test_index_already_batched():
|
||||
|
||||
def test_task_index_already_batched():
|
||||
"""Test that already batched task_index tensors remain unchanged."""
|
||||
processor = ToBatchProcessor()
|
||||
processor = AddBatchDimensionProcessorStep()
|
||||
|
||||
# Create already batched tensors
|
||||
task_index_1d = torch.tensor([7], dtype=torch.int64)
|
||||
@@ -975,7 +980,7 @@ def test_task_index_already_batched():
|
||||
|
||||
def test_index_non_tensor_unchanged():
|
||||
"""Test that non-tensor index values remain unchanged."""
|
||||
processor = ToBatchProcessor()
|
||||
processor = AddBatchDimensionProcessorStep()
|
||||
|
||||
complementary_data = {
|
||||
"index": 42, # Plain int, not tensor
|
||||
@@ -992,7 +997,7 @@ def test_index_non_tensor_unchanged():
|
||||
|
||||
def test_index_dtype_preservation():
|
||||
"""Test that index and task_index dtype is preserved during processing."""
|
||||
processor = ToBatchProcessor()
|
||||
processor = AddBatchDimensionProcessorStep()
|
||||
|
||||
# Test different dtypes
|
||||
dtypes = [torch.int32, torch.int64, torch.long]
|
||||
@@ -1015,7 +1020,7 @@ def test_index_dtype_preservation():
|
||||
|
||||
def test_index_with_full_transition():
|
||||
"""Test index/task_index processing with full transition data."""
|
||||
processor = ToBatchProcessor()
|
||||
processor = AddBatchDimensionProcessorStep()
|
||||
|
||||
# Create full transition with all components
|
||||
observation = {
|
||||
@@ -1057,7 +1062,7 @@ def test_index_with_full_transition():
|
||||
@pytest.mark.skipif(not torch.cuda.is_available(), reason="CUDA not available")
|
||||
def test_index_device_compatibility():
|
||||
"""Test processor works with index/task_index tensors on different devices."""
|
||||
processor = ToBatchProcessor()
|
||||
processor = AddBatchDimensionProcessorStep()
|
||||
|
||||
# Create tensors on GPU
|
||||
index_0d = torch.tensor(42, dtype=torch.int64, device="cuda")
|
||||
@@ -1081,7 +1086,7 @@ def test_index_device_compatibility():
|
||||
|
||||
def test_empty_index_tensor():
|
||||
"""Test handling of empty index tensors."""
|
||||
processor = ToBatchProcessor()
|
||||
processor = AddBatchDimensionProcessorStep()
|
||||
|
||||
# Empty 0D tensor doesn't make sense, but test empty 1D
|
||||
index_empty = torch.tensor([], dtype=torch.int64)
|
||||
@@ -1096,7 +1101,7 @@ def test_empty_index_tensor():
|
||||
|
||||
def test_action_processing_creates_new_transition():
|
||||
"""Test that the processor creates a new transition object with correctly processed action."""
|
||||
processor = ToBatchProcessor()
|
||||
processor = AddBatchDimensionProcessorStep()
|
||||
|
||||
action = torch.randn(4)
|
||||
transition = create_transition(action=action)
|
||||
@@ -1118,7 +1123,7 @@ def test_action_processing_creates_new_transition():
|
||||
|
||||
def test_task_processing_creates_new_transition():
|
||||
"""Test that the processor creates a new transition object with correctly processed task."""
|
||||
processor = ToBatchProcessor()
|
||||
processor = AddBatchDimensionProcessorStep()
|
||||
|
||||
complementary_data = {"task": "sort_objects"}
|
||||
transition = create_transition(complementary_data=complementary_data)
|
||||
|
||||
Reference in New Issue
Block a user