mirror of
https://github.com/huggingface/lerobot.git
synced 2026-05-17 09:39:47 +00:00
refactor(processor): clarify action types, distinguish PolicyAction, RobotAction, and EnvAction (#1908)
* refactor(processor): split action from policy, robots and environment
  - Updated function names to robot_action_to_transition and robot_transition_to_action across multiple files to better reflect their purpose in processing robot actions.
  - Adjusted references in the RobotProcessorPipeline and related components to ensure compatibility with the new naming convention.
  - Enhanced type annotations for action parameters to improve code readability and maintainability.
* refactor(converters): rename robot_transition_to_action to transition_to_robot_action
  - Updated function names across multiple files to improve clarity and consistency in processing robot actions.
  - Adjusted references in RobotProcessorPipeline and related components to align with the new naming convention.
  - Simplified action handling in the AddBatchDimensionProcessorStep by removing unnecessary checks for action presence.
* refactor(converters): update references to transition_to_robot_action
  - Renamed all instances of robot_transition_to_action to transition_to_robot_action across multiple files for consistency and clarity in the processing of robot actions.
  - Adjusted the RobotProcessorPipeline configurations to reflect the new naming convention, enhancing code readability.
* refactor(processor): update Torch2NumpyActionProcessorStep to extend ActionProcessorStep
  - Changed the base class of Torch2NumpyActionProcessorStep from PolicyActionProcessorStep to ActionProcessorStep, aligning it with the current architecture of action processing.
  - This modification enhances the clarity of the class's role in the processing pipeline.
* fix(processor): main action processor can take also EnvAction

---------

Co-authored-by: Steven Palma <steven.palma@huggingface.co>
This commit is contained in:
@@ -49,7 +49,7 @@ def test_batch_to_transition_observation_grouping():
|
||||
"observation.image.top": torch.randn(1, 3, 128, 128),
|
||||
"observation.image.left": torch.randn(1, 3, 128, 128),
|
||||
"observation.state": [1, 2, 3, 4],
|
||||
"action": "action_data",
|
||||
"action": torch.tensor([0.1, 0.2, 0.3, 0.4]),
|
||||
"next.reward": 1.5,
|
||||
"next.done": True,
|
||||
"next.truncated": False,
|
||||
@@ -74,7 +74,7 @@ def test_batch_to_transition_observation_grouping():
|
||||
assert transition[TransitionKey.OBSERVATION]["observation.state"] == [1, 2, 3, 4]
|
||||
|
||||
# Check other fields
|
||||
assert transition[TransitionKey.ACTION] == "action_data"
|
||||
assert torch.allclose(transition[TransitionKey.ACTION], torch.tensor([0.1, 0.2, 0.3, 0.4]))
|
||||
assert transition[TransitionKey.REWARD] == 1.5
|
||||
assert transition[TransitionKey.DONE]
|
||||
assert not transition[TransitionKey.TRUNCATED]
|
||||
@@ -123,7 +123,7 @@ def test_transition_to_batch_observation_flattening():
|
||||
def test_no_observation_keys():
|
||||
"""Test behavior when there are no observation.* keys."""
|
||||
batch = {
|
||||
"action": "action_data",
|
||||
"action": torch.tensor([1.0, 2.0]),
|
||||
"next.reward": 2.0,
|
||||
"next.done": False,
|
||||
"next.truncated": True,
|
||||
@@ -136,7 +136,7 @@ def test_no_observation_keys():
|
||||
assert transition[TransitionKey.OBSERVATION] is None
|
||||
|
||||
# Check other fields
|
||||
assert transition[TransitionKey.ACTION] == "action_data"
|
||||
assert torch.allclose(transition[TransitionKey.ACTION], torch.tensor([1.0, 2.0]))
|
||||
assert transition[TransitionKey.REWARD] == 2.0
|
||||
assert not transition[TransitionKey.DONE]
|
||||
assert transition[TransitionKey.TRUNCATED]
|
||||
@@ -144,7 +144,7 @@ def test_no_observation_keys():
|
||||
|
||||
# Round trip should work
|
||||
reconstructed_batch = transition_to_batch(transition)
|
||||
assert reconstructed_batch["action"] == "action_data"
|
||||
assert torch.allclose(reconstructed_batch["action"], torch.tensor([1.0, 2.0]))
|
||||
assert reconstructed_batch["next.reward"] == 2.0
|
||||
assert not reconstructed_batch["next.done"]
|
||||
assert reconstructed_batch["next.truncated"]
|
||||
@@ -153,13 +153,13 @@ def test_no_observation_keys():
|
||||
|
||||
def test_minimal_batch():
|
||||
"""Test with minimal batch containing only observation.* and action."""
|
||||
batch = {"observation.state": "minimal_state", "action": "minimal_action"}
|
||||
batch = {"observation.state": "minimal_state", "action": torch.tensor([0.5])}
|
||||
|
||||
transition = batch_to_transition(batch)
|
||||
|
||||
# Check observation
|
||||
assert transition[TransitionKey.OBSERVATION] == {"observation.state": "minimal_state"}
|
||||
assert transition[TransitionKey.ACTION] == "minimal_action"
|
||||
assert torch.allclose(transition[TransitionKey.ACTION], torch.tensor([0.5]))
|
||||
|
||||
# Check defaults
|
||||
assert transition[TransitionKey.REWARD] == 0.0
|
||||
@@ -171,7 +171,7 @@ def test_minimal_batch():
|
||||
# Round trip
|
||||
reconstructed_batch = transition_to_batch(transition)
|
||||
assert reconstructed_batch["observation.state"] == "minimal_state"
|
||||
assert reconstructed_batch["action"] == "minimal_action"
|
||||
assert torch.allclose(reconstructed_batch["action"], torch.tensor([0.5]))
|
||||
assert reconstructed_batch["next.reward"] == 0.0
|
||||
assert not reconstructed_batch["next.done"]
|
||||
assert not reconstructed_batch["next.truncated"]
|
||||
|
||||
@@ -38,7 +38,7 @@ def test_state_1d_to_2d():
|
||||
# Test observation.state
|
||||
state_1d = torch.randn(7)
|
||||
observation = {OBS_STATE: state_1d}
|
||||
transition = create_transition(observation=observation, action={})
|
||||
transition = create_transition(observation=observation, action=torch.empty(0))
|
||||
|
||||
result = processor(transition)
|
||||
|
||||
@@ -54,7 +54,7 @@ def test_env_state_1d_to_2d():
|
||||
# Test observation.environment_state
|
||||
env_state_1d = torch.randn(10)
|
||||
observation = {OBS_ENV_STATE: env_state_1d}
|
||||
transition = create_transition(observation=observation, action={})
|
||||
transition = create_transition(observation=observation, action=torch.empty(0))
|
||||
|
||||
result = processor(transition)
|
||||
|
||||
@@ -70,7 +70,7 @@ def test_image_3d_to_4d():
|
||||
# Test observation.image
|
||||
image_3d = torch.randn(224, 224, 3)
|
||||
observation = {OBS_IMAGE: image_3d}
|
||||
transition = create_transition(observation=observation, action={})
|
||||
transition = create_transition(observation=observation, action=torch.empty(0))
|
||||
|
||||
result = processor(transition)
|
||||
|
||||
@@ -90,7 +90,7 @@ def test_multiple_images_3d_to_4d():
|
||||
f"{OBS_IMAGES}.camera1": image1_3d,
|
||||
f"{OBS_IMAGES}.camera2": image2_3d,
|
||||
}
|
||||
transition = create_transition(observation=observation, action={})
|
||||
transition = create_transition(observation=observation, action=torch.empty(0))
|
||||
|
||||
result = processor(transition)
|
||||
|
||||
@@ -118,7 +118,7 @@ def test_already_batched_tensors_unchanged():
|
||||
OBS_ENV_STATE: env_state_2d,
|
||||
OBS_IMAGE: image_4d,
|
||||
}
|
||||
transition = create_transition(observation=observation, action={})
|
||||
transition = create_transition(observation=observation, action=torch.empty(0))
|
||||
|
||||
result = processor(transition)
|
||||
|
||||
@@ -142,7 +142,7 @@ def test_higher_dimensional_tensors_unchanged():
|
||||
OBS_STATE: state_3d,
|
||||
OBS_IMAGE: image_5d,
|
||||
}
|
||||
transition = create_transition(observation=observation, action={})
|
||||
transition = create_transition(observation=observation, action=torch.empty(0))
|
||||
|
||||
result = processor(transition)
|
||||
|
||||
@@ -163,7 +163,7 @@ def test_non_tensor_values_unchanged():
|
||||
"custom_key": 42, # Integer
|
||||
"another_key": {"nested": "dict"}, # Dict
|
||||
}
|
||||
transition = create_transition(observation=observation, action={})
|
||||
transition = create_transition(observation=observation, action=torch.empty(0))
|
||||
|
||||
result = processor(transition)
|
||||
|
||||
@@ -180,7 +180,7 @@ def test_none_observation():
|
||||
"""Test processor handles None observation gracefully."""
|
||||
processor = AddBatchDimensionProcessorStep()
|
||||
|
||||
transition = create_transition(observation={}, action={})
|
||||
transition = create_transition(observation={}, action=torch.empty(0))
|
||||
result = processor(transition)
|
||||
|
||||
assert result[TransitionKey.OBSERVATION] == {}
|
||||
@@ -191,7 +191,7 @@ def test_empty_observation():
|
||||
processor = AddBatchDimensionProcessorStep()
|
||||
|
||||
observation = {}
|
||||
transition = create_transition(observation=observation, action={})
|
||||
transition = create_transition(observation=observation, action=torch.empty(0))
|
||||
|
||||
result = processor(transition)
|
||||
|
||||
@@ -216,7 +216,7 @@ def test_mixed_observation():
|
||||
"other_tensor": other_tensor,
|
||||
"non_tensor": "string_value",
|
||||
}
|
||||
transition = create_transition(observation=observation, action={})
|
||||
transition = create_transition(observation=observation, action=torch.empty(0))
|
||||
|
||||
result = processor(transition)
|
||||
processed_obs = result[TransitionKey.OBSERVATION]
|
||||
@@ -243,7 +243,7 @@ def test_integration_with_robot_processor():
|
||||
OBS_STATE: torch.randn(7),
|
||||
OBS_IMAGE: torch.randn(224, 224, 3),
|
||||
}
|
||||
transition = create_transition(observation=observation, action={})
|
||||
transition = create_transition(observation=observation, action=torch.empty(0))
|
||||
|
||||
result = pipeline(transition)
|
||||
processed_obs = result[TransitionKey.OBSERVATION]
|
||||
@@ -299,7 +299,7 @@ def test_save_and_load_pretrained():
|
||||
|
||||
# Test functionality of loaded processor
|
||||
observation = {OBS_STATE: torch.randn(5)}
|
||||
transition = create_transition(observation=observation, action={})
|
||||
transition = create_transition(observation=observation, action=torch.empty(0))
|
||||
|
||||
result = loaded_pipeline(transition)
|
||||
assert result[TransitionKey.OBSERVATION][OBS_STATE].shape == (1, 5)
|
||||
@@ -333,7 +333,7 @@ def test_registry_based_save_load():
|
||||
OBS_STATE: torch.randn(3),
|
||||
OBS_IMAGE: torch.randn(100, 100, 3),
|
||||
}
|
||||
transition = create_transition(observation=observation, action={})
|
||||
transition = create_transition(observation=observation, action=torch.empty(0))
|
||||
|
||||
result = loaded_pipeline(transition)
|
||||
processed_obs = result[TransitionKey.OBSERVATION]
|
||||
@@ -355,7 +355,7 @@ def test_device_compatibility():
|
||||
OBS_STATE: state_1d,
|
||||
OBS_IMAGE: image_3d,
|
||||
}
|
||||
transition = create_transition(observation=observation, action={})
|
||||
transition = create_transition(observation=observation, action=torch.empty(0))
|
||||
|
||||
result = processor(transition)
|
||||
processed_obs = result[TransitionKey.OBSERVATION]
|
||||
@@ -415,7 +415,7 @@ def test_edge_case_zero_dimensional_tensors():
|
||||
OBS_STATE: scalar_tensor,
|
||||
"scalar_value": scalar_tensor,
|
||||
}
|
||||
transition = create_transition(observation=observation, action={})
|
||||
transition = create_transition(observation=observation, action=torch.empty(0))
|
||||
|
||||
result = processor(transition)
|
||||
processed_obs = result[TransitionKey.OBSERVATION]
|
||||
@@ -490,42 +490,43 @@ def test_action_scalar_tensor():
|
||||
assert torch.equal(result[TransitionKey.ACTION], action_scalar)
|
||||
|
||||
|
||||
def test_action_non_tensor():
|
||||
"""Test that non-tensor actions remain unchanged."""
|
||||
def test_action_non_tensor_raises_error():
|
||||
"""Test that non-tensor actions raise ValueError for PolicyAction processors."""
|
||||
processor = AddBatchDimensionProcessorStep()
|
||||
|
||||
# List action
|
||||
# List action should raise error
|
||||
action_list = [0.1, 0.2, 0.3, 0.4]
|
||||
transition = create_transition(action=action_list, observation={})
|
||||
result = processor(transition)
|
||||
assert result[TransitionKey.ACTION] == action_list
|
||||
transition = create_transition(action=action_list)
|
||||
with pytest.raises(ValueError, match="Action should be a PolicyAction type"):
|
||||
processor(transition)
|
||||
|
||||
# Numpy array action (as Python object, not converted)
|
||||
# Numpy array action should raise error
|
||||
action_numpy = np.array([1, 2, 3, 4])
|
||||
transition = create_transition(action=action_numpy, observation={})
|
||||
result = processor(transition)
|
||||
assert np.array_equal(result[TransitionKey.ACTION], action_numpy)
|
||||
transition = create_transition(action=action_numpy)
|
||||
with pytest.raises(ValueError, match="Action should be a PolicyAction type"):
|
||||
processor(transition)
|
||||
|
||||
# String action (edge case)
|
||||
# String action should raise error
|
||||
action_string = "forward"
|
||||
transition = create_transition(action=action_string, observation={})
|
||||
result = processor(transition)
|
||||
assert result[TransitionKey.ACTION] == action_string
|
||||
transition = create_transition(action=action_string)
|
||||
with pytest.raises(ValueError, match="Action should be a PolicyAction type"):
|
||||
processor(transition)
|
||||
|
||||
# Dict action (structured action)
|
||||
# Dict action should raise error
|
||||
action_dict = {"linear": [0.5, 0.0], "angular": 0.2}
|
||||
transition = create_transition(action=action_dict, observation={})
|
||||
result = processor(transition)
|
||||
assert result[TransitionKey.ACTION] == action_dict
|
||||
transition = create_transition(action=action_dict)
|
||||
with pytest.raises(ValueError, match="Action should be a PolicyAction type"):
|
||||
processor(transition)
|
||||
|
||||
|
||||
def test_action_none():
|
||||
"""Test that None action is handled correctly."""
|
||||
"""Test that empty action tensor is handled correctly."""
|
||||
processor = AddBatchDimensionProcessorStep()
|
||||
|
||||
transition = create_transition(action={}, observation={})
|
||||
transition = create_transition(action=torch.empty(0), observation={})
|
||||
result = processor(transition)
|
||||
assert result[TransitionKey.ACTION] == {}
|
||||
# Empty 1D tensor becomes empty 2D tensor with batch dimension
|
||||
assert result[TransitionKey.ACTION].shape == (1, 0)
|
||||
|
||||
|
||||
def test_action_with_observation():
|
||||
@@ -630,7 +631,9 @@ def test_task_string_to_list():
|
||||
|
||||
# Create complementary data with string task
|
||||
complementary_data = {"task": "pick_cube"}
|
||||
transition = create_transition(action={}, observation={}, complementary_data=complementary_data)
|
||||
transition = create_transition(
|
||||
action=torch.empty(0), observation={}, complementary_data=complementary_data
|
||||
)
|
||||
|
||||
result = processor(transition)
|
||||
|
||||
@@ -647,14 +650,18 @@ def test_task_string_validation():
|
||||
|
||||
# Valid string task - should be converted to list
|
||||
complementary_data = {"task": "valid_task"}
|
||||
transition = create_transition(complementary_data=complementary_data, observation={}, action={})
|
||||
transition = create_transition(
|
||||
complementary_data=complementary_data, observation={}, action=torch.empty(0)
|
||||
)
|
||||
result = processor(transition)
|
||||
processed_comp_data = result[TransitionKey.COMPLEMENTARY_DATA]
|
||||
assert processed_comp_data["task"] == ["valid_task"]
|
||||
|
||||
# Valid list of strings - should remain unchanged
|
||||
complementary_data = {"task": ["task1", "task2"]}
|
||||
transition = create_transition(complementary_data=complementary_data, observation={}, action={})
|
||||
transition = create_transition(
|
||||
complementary_data=complementary_data, observation={}, action=torch.empty(0)
|
||||
)
|
||||
result = processor(transition)
|
||||
processed_comp_data = result[TransitionKey.COMPLEMENTARY_DATA]
|
||||
assert processed_comp_data["task"] == ["task1", "task2"]
|
||||
@@ -676,7 +683,9 @@ def test_task_list_of_strings():
|
||||
|
||||
for task_list in test_lists:
|
||||
complementary_data = {"task": task_list}
|
||||
transition = create_transition(complementary_data=complementary_data, observation={}, action={})
|
||||
transition = create_transition(
|
||||
complementary_data=complementary_data, observation={}, action=torch.empty(0)
|
||||
)
|
||||
|
||||
result = processor(transition)
|
||||
|
||||
@@ -690,7 +699,7 @@ def test_complementary_data_none():
|
||||
"""Test processor handles None complementary_data gracefully."""
|
||||
processor = AddBatchDimensionProcessorStep()
|
||||
|
||||
transition = create_transition(complementary_data=None, action={}, observation={})
|
||||
transition = create_transition(complementary_data=None, action=torch.empty(0), observation={})
|
||||
result = processor(transition)
|
||||
|
||||
assert result[TransitionKey.COMPLEMENTARY_DATA] == {}
|
||||
@@ -701,7 +710,9 @@ def test_complementary_data_empty():
|
||||
processor = AddBatchDimensionProcessorStep()
|
||||
|
||||
complementary_data = {}
|
||||
transition = create_transition(complementary_data=complementary_data, observation={}, action={})
|
||||
transition = create_transition(
|
||||
complementary_data=complementary_data, observation={}, action=torch.empty(0)
|
||||
)
|
||||
|
||||
result = processor(transition)
|
||||
|
||||
@@ -717,7 +728,9 @@ def test_complementary_data_no_task():
|
||||
"timestamp": 1234567890.0,
|
||||
"extra_info": "some data",
|
||||
}
|
||||
transition = create_transition(complementary_data=complementary_data, observation={}, action={})
|
||||
transition = create_transition(
|
||||
complementary_data=complementary_data, observation={}, action=torch.empty(0)
|
||||
)
|
||||
|
||||
result = processor(transition)
|
||||
|
||||
@@ -736,7 +749,9 @@ def test_complementary_data_mixed():
|
||||
"difficulty": "hard",
|
||||
"metadata": {"scene": "kitchen"},
|
||||
}
|
||||
transition = create_transition(complementary_data=complementary_data, observation={}, action={})
|
||||
transition = create_transition(
|
||||
complementary_data=complementary_data, observation={}, action=torch.empty(0)
|
||||
)
|
||||
|
||||
result = processor(transition)
|
||||
|
||||
@@ -803,7 +818,9 @@ def test_task_comprehensive_string_cases():
|
||||
# Test that all string tasks get properly batched
|
||||
for task in string_tasks:
|
||||
complementary_data = {"task": task}
|
||||
transition = create_transition(complementary_data=complementary_data, observation={}, action={})
|
||||
transition = create_transition(
|
||||
complementary_data=complementary_data, observation={}, action=torch.empty(0)
|
||||
)
|
||||
|
||||
result = processor(transition)
|
||||
|
||||
@@ -825,7 +842,9 @@ def test_task_comprehensive_string_cases():
|
||||
|
||||
for task_list in list_tasks:
|
||||
complementary_data = {"task": task_list}
|
||||
transition = create_transition(complementary_data=complementary_data, observation={}, action={})
|
||||
transition = create_transition(
|
||||
complementary_data=complementary_data, observation={}, action=torch.empty(0)
|
||||
)
|
||||
|
||||
result = processor(transition)
|
||||
|
||||
@@ -845,7 +864,9 @@ def test_task_preserves_other_keys():
|
||||
"config": {"speed": "slow", "precision": "high"},
|
||||
"metrics": [1.0, 2.0, 3.0],
|
||||
}
|
||||
transition = create_transition(complementary_data=complementary_data, observation={}, action={})
|
||||
transition = create_transition(
|
||||
complementary_data=complementary_data, observation={}, action=torch.empty(0)
|
||||
)
|
||||
|
||||
result = processor(transition)
|
||||
|
||||
@@ -869,7 +890,9 @@ def test_index_scalar_to_1d():
|
||||
# Create 0D index tensor (scalar)
|
||||
index_0d = torch.tensor(42, dtype=torch.int64)
|
||||
complementary_data = {"index": index_0d}
|
||||
transition = create_transition(complementary_data=complementary_data, observation={}, action={})
|
||||
transition = create_transition(
|
||||
complementary_data=complementary_data, observation={}, action=torch.empty(0)
|
||||
)
|
||||
|
||||
result = processor(transition)
|
||||
|
||||
@@ -886,7 +909,9 @@ def test_task_index_scalar_to_1d():
|
||||
# Create 0D task_index tensor (scalar)
|
||||
task_index_0d = torch.tensor(7, dtype=torch.int64)
|
||||
complementary_data = {"task_index": task_index_0d}
|
||||
transition = create_transition(complementary_data=complementary_data, observation={}, action={})
|
||||
transition = create_transition(
|
||||
complementary_data=complementary_data, observation={}, action=torch.empty(0)
|
||||
)
|
||||
|
||||
result = processor(transition)
|
||||
|
||||
@@ -908,7 +933,9 @@ def test_index_and_task_index_together():
|
||||
"task_index": task_index_0d,
|
||||
"task": "pick_object",
|
||||
}
|
||||
transition = create_transition(complementary_data=complementary_data, observation={}, action={})
|
||||
transition = create_transition(
|
||||
complementary_data=complementary_data, observation={}, action=torch.empty(0)
|
||||
)
|
||||
|
||||
result = processor(transition)
|
||||
|
||||
@@ -936,13 +963,17 @@ def test_index_already_batched():
|
||||
|
||||
# Test 1D (already batched)
|
||||
complementary_data = {"index": index_1d}
|
||||
transition = create_transition(complementary_data=complementary_data, observation={}, action={})
|
||||
transition = create_transition(
|
||||
complementary_data=complementary_data, observation={}, action=torch.empty(0)
|
||||
)
|
||||
result = processor(transition)
|
||||
assert torch.equal(result[TransitionKey.COMPLEMENTARY_DATA]["index"], index_1d)
|
||||
|
||||
# Test 2D
|
||||
complementary_data = {"index": index_2d}
|
||||
transition = create_transition(complementary_data=complementary_data, observation={}, action={})
|
||||
transition = create_transition(
|
||||
complementary_data=complementary_data, observation={}, action=torch.empty(0)
|
||||
)
|
||||
result = processor(transition)
|
||||
assert torch.equal(result[TransitionKey.COMPLEMENTARY_DATA]["index"], index_2d)
|
||||
|
||||
@@ -957,13 +988,17 @@ def test_task_index_already_batched():
|
||||
|
||||
# Test 1D (already batched)
|
||||
complementary_data = {"task_index": task_index_1d}
|
||||
transition = create_transition(complementary_data=complementary_data, observation={}, action={})
|
||||
transition = create_transition(
|
||||
complementary_data=complementary_data, observation={}, action=torch.empty(0)
|
||||
)
|
||||
result = processor(transition)
|
||||
assert torch.equal(result[TransitionKey.COMPLEMENTARY_DATA]["task_index"], task_index_1d)
|
||||
|
||||
# Test 2D
|
||||
complementary_data = {"task_index": task_index_2d}
|
||||
transition = create_transition(complementary_data=complementary_data, observation={}, action={})
|
||||
transition = create_transition(
|
||||
complementary_data=complementary_data, observation={}, action=torch.empty(0)
|
||||
)
|
||||
result = processor(transition)
|
||||
assert torch.equal(result[TransitionKey.COMPLEMENTARY_DATA]["task_index"], task_index_2d)
|
||||
|
||||
@@ -976,7 +1011,9 @@ def test_index_non_tensor_unchanged():
|
||||
"index": 42, # Plain int, not tensor
|
||||
"task_index": [1, 2, 3], # List, not tensor
|
||||
}
|
||||
transition = create_transition(complementary_data=complementary_data, observation={}, action={})
|
||||
transition = create_transition(
|
||||
complementary_data=complementary_data, observation={}, action=torch.empty(0)
|
||||
)
|
||||
|
||||
result = processor(transition)
|
||||
|
||||
@@ -999,7 +1036,9 @@ def test_index_dtype_preservation():
|
||||
"index": index_0d,
|
||||
"task_index": task_index_0d,
|
||||
}
|
||||
transition = create_transition(complementary_data=complementary_data, observation={}, action={})
|
||||
transition = create_transition(
|
||||
complementary_data=complementary_data, observation={}, action=torch.empty(0)
|
||||
)
|
||||
|
||||
result = processor(transition)
|
||||
|
||||
@@ -1062,7 +1101,9 @@ def test_index_device_compatibility():
|
||||
"index": index_0d,
|
||||
"task_index": task_index_0d,
|
||||
}
|
||||
transition = create_transition(complementary_data=complementary_data, observation={}, action={})
|
||||
transition = create_transition(
|
||||
complementary_data=complementary_data, observation={}, action=torch.empty(0)
|
||||
)
|
||||
|
||||
result = processor(transition)
|
||||
processed_comp_data = result[TransitionKey.COMPLEMENTARY_DATA]
|
||||
@@ -1081,7 +1122,9 @@ def test_empty_index_tensor():
|
||||
# Empty 0D tensor doesn't make sense, but test empty 1D
|
||||
index_empty = torch.tensor([], dtype=torch.int64)
|
||||
complementary_data = {"index": index_empty}
|
||||
transition = create_transition(complementary_data=complementary_data, observation={}, action={})
|
||||
transition = create_transition(
|
||||
complementary_data=complementary_data, observation={}, action=torch.empty(0)
|
||||
)
|
||||
|
||||
result = processor(transition)
|
||||
|
||||
@@ -1116,7 +1159,9 @@ def test_task_processing_creates_new_transition():
|
||||
processor = AddBatchDimensionProcessorStep()
|
||||
|
||||
complementary_data = {"task": "sort_objects"}
|
||||
transition = create_transition(complementary_data=complementary_data, observation={}, action={})
|
||||
transition = create_transition(
|
||||
complementary_data=complementary_data, observation={}, action=torch.empty(0)
|
||||
)
|
||||
|
||||
# Store reference to original transition and complementary_data
|
||||
original_transition = transition
|
||||
|
||||
@@ -329,14 +329,14 @@ def test_min_max_unnormalization(action_stats_min_max):
|
||||
assert torch.allclose(unnormalized_action, expected)
|
||||
|
||||
|
||||
def test_numpy_action_input(action_stats_mean_std):
|
||||
def test_tensor_action_input(action_stats_mean_std):
|
||||
features = _create_action_features()
|
||||
norm_map = _create_action_norm_map_mean_std()
|
||||
unnormalizer = UnnormalizerProcessorStep(
|
||||
features=features, norm_map=norm_map, stats={"action": action_stats_mean_std}
|
||||
)
|
||||
|
||||
normalized_action = np.array([1.0, -0.5, 2.0], dtype=np.float32)
|
||||
normalized_action = torch.tensor([1.0, -0.5, 2.0], dtype=torch.float32)
|
||||
transition = create_transition(action=normalized_action)
|
||||
|
||||
unnormalized_transition = unnormalizer(transition)
|
||||
|
||||
@@ -371,12 +371,12 @@ def test_sac_processor_edge_cases():
|
||||
assert processed[TransitionKey.OBSERVATION] == {}
|
||||
assert processed[TransitionKey.ACTION].shape == (1, 5)
|
||||
|
||||
# Test with None action
|
||||
transition = create_transition(observation={OBS_STATE: torch.randn(10)}, action={})
|
||||
# Test with zero action (representing "null" action)
|
||||
transition = create_transition(observation={OBS_STATE: torch.randn(10)}, action=torch.zeros(5))
|
||||
processed = preprocessor(transition)
|
||||
assert processed[TransitionKey.OBSERVATION][OBS_STATE].shape == (1, 10)
|
||||
# When action is None, it may still be present with None value
|
||||
assert TransitionKey.ACTION not in processed or processed[TransitionKey.ACTION] is None
|
||||
# Action should be present and batched, even if it's zeros
|
||||
assert processed[TransitionKey.ACTION].shape == (1, 5)
|
||||
|
||||
|
||||
@pytest.mark.skipif(not torch.cuda.is_available(), reason="CUDA not available")
|
||||
|
||||
Reference in New Issue
Block a user