|
|
|
@@ -38,7 +38,7 @@ def test_state_1d_to_2d():
|
|
|
|
|
# Test observation.state
|
|
|
|
|
state_1d = torch.randn(7)
|
|
|
|
|
observation = {OBS_STATE: state_1d}
|
|
|
|
|
transition = create_transition(observation=observation)
|
|
|
|
|
transition = create_transition(observation=observation, action={})
|
|
|
|
|
|
|
|
|
|
result = processor(transition)
|
|
|
|
|
|
|
|
|
@@ -54,7 +54,7 @@ def test_env_state_1d_to_2d():
|
|
|
|
|
# Test observation.environment_state
|
|
|
|
|
env_state_1d = torch.randn(10)
|
|
|
|
|
observation = {OBS_ENV_STATE: env_state_1d}
|
|
|
|
|
transition = create_transition(observation=observation)
|
|
|
|
|
transition = create_transition(observation=observation, action={})
|
|
|
|
|
|
|
|
|
|
result = processor(transition)
|
|
|
|
|
|
|
|
|
@@ -70,7 +70,7 @@ def test_image_3d_to_4d():
|
|
|
|
|
# Test observation.image
|
|
|
|
|
image_3d = torch.randn(224, 224, 3)
|
|
|
|
|
observation = {OBS_IMAGE: image_3d}
|
|
|
|
|
transition = create_transition(observation=observation)
|
|
|
|
|
transition = create_transition(observation=observation, action={})
|
|
|
|
|
|
|
|
|
|
result = processor(transition)
|
|
|
|
|
|
|
|
|
@@ -90,7 +90,7 @@ def test_multiple_images_3d_to_4d():
|
|
|
|
|
f"{OBS_IMAGES}.camera1": image1_3d,
|
|
|
|
|
f"{OBS_IMAGES}.camera2": image2_3d,
|
|
|
|
|
}
|
|
|
|
|
transition = create_transition(observation=observation)
|
|
|
|
|
transition = create_transition(observation=observation, action={})
|
|
|
|
|
|
|
|
|
|
result = processor(transition)
|
|
|
|
|
|
|
|
|
@@ -118,7 +118,7 @@ def test_already_batched_tensors_unchanged():
|
|
|
|
|
OBS_ENV_STATE: env_state_2d,
|
|
|
|
|
OBS_IMAGE: image_4d,
|
|
|
|
|
}
|
|
|
|
|
transition = create_transition(observation=observation)
|
|
|
|
|
transition = create_transition(observation=observation, action={})
|
|
|
|
|
|
|
|
|
|
result = processor(transition)
|
|
|
|
|
|
|
|
|
@@ -142,7 +142,7 @@ def test_higher_dimensional_tensors_unchanged():
|
|
|
|
|
OBS_STATE: state_3d,
|
|
|
|
|
OBS_IMAGE: image_5d,
|
|
|
|
|
}
|
|
|
|
|
transition = create_transition(observation=observation)
|
|
|
|
|
transition = create_transition(observation=observation, action={})
|
|
|
|
|
|
|
|
|
|
result = processor(transition)
|
|
|
|
|
|
|
|
|
@@ -163,7 +163,7 @@ def test_non_tensor_values_unchanged():
|
|
|
|
|
"custom_key": 42, # Integer
|
|
|
|
|
"another_key": {"nested": "dict"}, # Dict
|
|
|
|
|
}
|
|
|
|
|
transition = create_transition(observation=observation)
|
|
|
|
|
transition = create_transition(observation=observation, action={})
|
|
|
|
|
|
|
|
|
|
result = processor(transition)
|
|
|
|
|
|
|
|
|
@@ -180,10 +180,10 @@ def test_none_observation():
|
|
|
|
|
"""Test processor handles None observation gracefully."""
|
|
|
|
|
processor = AddBatchDimensionProcessorStep()
|
|
|
|
|
|
|
|
|
|
transition = create_transition(observation=None)
|
|
|
|
|
transition = create_transition(observation={}, action={})
|
|
|
|
|
result = processor(transition)
|
|
|
|
|
|
|
|
|
|
assert result[TransitionKey.OBSERVATION] is None
|
|
|
|
|
assert result[TransitionKey.OBSERVATION] == {}
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
def test_empty_observation():
|
|
|
|
@@ -191,7 +191,7 @@ def test_empty_observation():
|
|
|
|
|
processor = AddBatchDimensionProcessorStep()
|
|
|
|
|
|
|
|
|
|
observation = {}
|
|
|
|
|
transition = create_transition(observation=observation)
|
|
|
|
|
transition = create_transition(observation=observation, action={})
|
|
|
|
|
|
|
|
|
|
result = processor(transition)
|
|
|
|
|
|
|
|
|
@@ -216,7 +216,7 @@ def test_mixed_observation():
|
|
|
|
|
"other_tensor": other_tensor,
|
|
|
|
|
"non_tensor": "string_value",
|
|
|
|
|
}
|
|
|
|
|
transition = create_transition(observation=observation)
|
|
|
|
|
transition = create_transition(observation=observation, action={})
|
|
|
|
|
|
|
|
|
|
result = processor(transition)
|
|
|
|
|
processed_obs = result[TransitionKey.OBSERVATION]
|
|
|
|
@@ -243,7 +243,7 @@ def test_integration_with_robot_processor():
|
|
|
|
|
OBS_STATE: torch.randn(7),
|
|
|
|
|
OBS_IMAGE: torch.randn(224, 224, 3),
|
|
|
|
|
}
|
|
|
|
|
transition = create_transition(observation=observation)
|
|
|
|
|
transition = create_transition(observation=observation, action={})
|
|
|
|
|
|
|
|
|
|
result = pipeline(transition)
|
|
|
|
|
processed_obs = result[TransitionKey.OBSERVATION]
|
|
|
|
@@ -299,7 +299,7 @@ def test_save_and_load_pretrained():
|
|
|
|
|
|
|
|
|
|
# Test functionality of loaded processor
|
|
|
|
|
observation = {OBS_STATE: torch.randn(5)}
|
|
|
|
|
transition = create_transition(observation=observation)
|
|
|
|
|
transition = create_transition(observation=observation, action={})
|
|
|
|
|
|
|
|
|
|
result = loaded_pipeline(transition)
|
|
|
|
|
assert result[TransitionKey.OBSERVATION][OBS_STATE].shape == (1, 5)
|
|
|
|
@@ -333,7 +333,7 @@ def test_registry_based_save_load():
|
|
|
|
|
OBS_STATE: torch.randn(3),
|
|
|
|
|
OBS_IMAGE: torch.randn(100, 100, 3),
|
|
|
|
|
}
|
|
|
|
|
transition = create_transition(observation=observation)
|
|
|
|
|
transition = create_transition(observation=observation, action={})
|
|
|
|
|
|
|
|
|
|
result = loaded_pipeline(transition)
|
|
|
|
|
processed_obs = result[TransitionKey.OBSERVATION]
|
|
|
|
@@ -355,7 +355,7 @@ def test_device_compatibility():
|
|
|
|
|
OBS_STATE: state_1d,
|
|
|
|
|
OBS_IMAGE: image_3d,
|
|
|
|
|
}
|
|
|
|
|
transition = create_transition(observation=observation)
|
|
|
|
|
transition = create_transition(observation=observation, action={})
|
|
|
|
|
|
|
|
|
|
result = processor(transition)
|
|
|
|
|
processed_obs = result[TransitionKey.OBSERVATION]
|
|
|
|
@@ -415,7 +415,7 @@ def test_edge_case_zero_dimensional_tensors():
|
|
|
|
|
OBS_STATE: scalar_tensor,
|
|
|
|
|
"scalar_value": scalar_tensor,
|
|
|
|
|
}
|
|
|
|
|
transition = create_transition(observation=observation)
|
|
|
|
|
transition = create_transition(observation=observation, action={})
|
|
|
|
|
|
|
|
|
|
result = processor(transition)
|
|
|
|
|
processed_obs = result[TransitionKey.OBSERVATION]
|
|
|
|
@@ -432,7 +432,7 @@ def test_action_1d_to_2d():
|
|
|
|
|
|
|
|
|
|
# Create 1D action tensor
|
|
|
|
|
action_1d = torch.randn(4)
|
|
|
|
|
transition = create_transition(action=action_1d)
|
|
|
|
|
transition = create_transition(observation={}, action=action_1d)
|
|
|
|
|
|
|
|
|
|
result = processor(transition)
|
|
|
|
|
|
|
|
|
@@ -450,12 +450,12 @@ def test_action_already_batched():
|
|
|
|
|
action_batched_5 = torch.randn(5, 4)
|
|
|
|
|
|
|
|
|
|
# Single batch
|
|
|
|
|
transition = create_transition(action=action_batched_1)
|
|
|
|
|
transition = create_transition(action=action_batched_1, observation={})
|
|
|
|
|
result = processor(transition)
|
|
|
|
|
assert torch.equal(result[TransitionKey.ACTION], action_batched_1)
|
|
|
|
|
|
|
|
|
|
# Multiple batch
|
|
|
|
|
transition = create_transition(action=action_batched_5)
|
|
|
|
|
transition = create_transition(action=action_batched_5, observation={})
|
|
|
|
|
result = processor(transition)
|
|
|
|
|
assert torch.equal(result[TransitionKey.ACTION], action_batched_5)
|
|
|
|
|
|
|
|
|
@@ -466,13 +466,13 @@ def test_action_higher_dimensional():
|
|
|
|
|
|
|
|
|
|
# 3D action tensor (e.g., sequence of actions)
|
|
|
|
|
action_3d = torch.randn(2, 4, 3)
|
|
|
|
|
transition = create_transition(action=action_3d)
|
|
|
|
|
transition = create_transition(action=action_3d, observation={})
|
|
|
|
|
result = processor(transition)
|
|
|
|
|
assert torch.equal(result[TransitionKey.ACTION], action_3d)
|
|
|
|
|
|
|
|
|
|
# 4D action tensor
|
|
|
|
|
action_4d = torch.randn(2, 10, 4, 3)
|
|
|
|
|
transition = create_transition(action=action_4d)
|
|
|
|
|
transition = create_transition(action=action_4d, observation={})
|
|
|
|
|
result = processor(transition)
|
|
|
|
|
assert torch.equal(result[TransitionKey.ACTION], action_4d)
|
|
|
|
|
|
|
|
|
@@ -482,7 +482,7 @@ def test_action_scalar_tensor():
|
|
|
|
|
processor = AddBatchDimensionProcessorStep()
|
|
|
|
|
|
|
|
|
|
action_scalar = torch.tensor(1.5)
|
|
|
|
|
transition = create_transition(action=action_scalar)
|
|
|
|
|
transition = create_transition(action=action_scalar, observation={})
|
|
|
|
|
result = processor(transition)
|
|
|
|
|
|
|
|
|
|
# Should remain scalar
|
|
|
|
@@ -496,25 +496,25 @@ def test_action_non_tensor():
|
|
|
|
|
|
|
|
|
|
# List action
|
|
|
|
|
action_list = [0.1, 0.2, 0.3, 0.4]
|
|
|
|
|
transition = create_transition(action=action_list)
|
|
|
|
|
transition = create_transition(action=action_list, observation={})
|
|
|
|
|
result = processor(transition)
|
|
|
|
|
assert result[TransitionKey.ACTION] == action_list
|
|
|
|
|
|
|
|
|
|
# Numpy array action (as Python object, not converted)
|
|
|
|
|
action_numpy = np.array([1, 2, 3, 4])
|
|
|
|
|
transition = create_transition(action=action_numpy)
|
|
|
|
|
transition = create_transition(action=action_numpy, observation={})
|
|
|
|
|
result = processor(transition)
|
|
|
|
|
assert np.array_equal(result[TransitionKey.ACTION], action_numpy)
|
|
|
|
|
|
|
|
|
|
# String action (edge case)
|
|
|
|
|
action_string = "forward"
|
|
|
|
|
transition = create_transition(action=action_string)
|
|
|
|
|
transition = create_transition(action=action_string, observation={})
|
|
|
|
|
result = processor(transition)
|
|
|
|
|
assert result[TransitionKey.ACTION] == action_string
|
|
|
|
|
|
|
|
|
|
# Dict action (structured action)
|
|
|
|
|
action_dict = {"linear": [0.5, 0.0], "angular": 0.2}
|
|
|
|
|
transition = create_transition(action=action_dict)
|
|
|
|
|
transition = create_transition(action=action_dict, observation={})
|
|
|
|
|
result = processor(transition)
|
|
|
|
|
assert result[TransitionKey.ACTION] == action_dict
|
|
|
|
|
|
|
|
|
@@ -523,9 +523,9 @@ def test_action_none():
|
|
|
|
|
"""Test that None action is handled correctly."""
|
|
|
|
|
processor = AddBatchDimensionProcessorStep()
|
|
|
|
|
|
|
|
|
|
transition = create_transition(action=None)
|
|
|
|
|
transition = create_transition(action={}, observation={})
|
|
|
|
|
result = processor(transition)
|
|
|
|
|
assert result[TransitionKey.ACTION] is None
|
|
|
|
|
assert result[TransitionKey.ACTION] == {}
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
def test_action_with_observation():
|
|
|
|
@@ -557,7 +557,7 @@ def test_action_different_sizes():
|
|
|
|
|
|
|
|
|
|
for size in action_sizes:
|
|
|
|
|
action = torch.randn(size)
|
|
|
|
|
transition = create_transition(action=action)
|
|
|
|
|
transition = create_transition(action=action, observation={})
|
|
|
|
|
result = processor(transition)
|
|
|
|
|
|
|
|
|
|
assert result[TransitionKey.ACTION].shape == (1, size)
|
|
|
|
@@ -571,7 +571,7 @@ def test_action_device_compatibility():
|
|
|
|
|
|
|
|
|
|
# CUDA action
|
|
|
|
|
action_cuda = torch.randn(4, device="cuda")
|
|
|
|
|
transition = create_transition(action=action_cuda)
|
|
|
|
|
transition = create_transition(action=action_cuda, observation={})
|
|
|
|
|
result = processor(transition)
|
|
|
|
|
|
|
|
|
|
assert result[TransitionKey.ACTION].shape == (1, 4)
|
|
|
|
@@ -579,7 +579,7 @@ def test_action_device_compatibility():
|
|
|
|
|
|
|
|
|
|
# CPU action
|
|
|
|
|
action_cpu = torch.randn(4, device="cpu")
|
|
|
|
|
transition = create_transition(action=action_cpu)
|
|
|
|
|
transition = create_transition(action=action_cpu, observation={})
|
|
|
|
|
result = processor(transition)
|
|
|
|
|
|
|
|
|
|
assert result[TransitionKey.ACTION].shape == (1, 4)
|
|
|
|
@@ -595,7 +595,7 @@ def test_action_dtype_preservation():
|
|
|
|
|
|
|
|
|
|
for dtype in dtypes:
|
|
|
|
|
action = torch.randn(4).to(dtype)
|
|
|
|
|
transition = create_transition(action=action)
|
|
|
|
|
transition = create_transition(action=action, observation={})
|
|
|
|
|
result = processor(transition)
|
|
|
|
|
|
|
|
|
|
assert result[TransitionKey.ACTION].dtype == dtype
|
|
|
|
@@ -608,7 +608,7 @@ def test_empty_action_tensor():
|
|
|
|
|
|
|
|
|
|
# Empty 1D tensor
|
|
|
|
|
action_empty = torch.tensor([])
|
|
|
|
|
transition = create_transition(action=action_empty)
|
|
|
|
|
transition = create_transition(action=action_empty, observation={})
|
|
|
|
|
result = processor(transition)
|
|
|
|
|
|
|
|
|
|
# Should add batch dimension even to empty tensor
|
|
|
|
@@ -616,7 +616,7 @@ def test_empty_action_tensor():
|
|
|
|
|
|
|
|
|
|
# Empty 2D tensor (already batched)
|
|
|
|
|
action_empty_2d = torch.randn(1, 0)
|
|
|
|
|
transition = create_transition(action=action_empty_2d)
|
|
|
|
|
transition = create_transition(action=action_empty_2d, observation={})
|
|
|
|
|
result = processor(transition)
|
|
|
|
|
|
|
|
|
|
# Should remain unchanged
|
|
|
|
@@ -630,7 +630,7 @@ def test_task_string_to_list():
|
|
|
|
|
|
|
|
|
|
# Create complementary data with string task
|
|
|
|
|
complementary_data = {"task": "pick_cube"}
|
|
|
|
|
transition = create_transition(complementary_data=complementary_data)
|
|
|
|
|
transition = create_transition(action={}, observation={}, complementary_data=complementary_data)
|
|
|
|
|
|
|
|
|
|
result = processor(transition)
|
|
|
|
|
|
|
|
|
@@ -647,14 +647,14 @@ def test_task_string_validation():
|
|
|
|
|
|
|
|
|
|
# Valid string task - should be converted to list
|
|
|
|
|
complementary_data = {"task": "valid_task"}
|
|
|
|
|
transition = create_transition(complementary_data=complementary_data)
|
|
|
|
|
transition = create_transition(complementary_data=complementary_data, observation={}, action={})
|
|
|
|
|
result = processor(transition)
|
|
|
|
|
processed_comp_data = result[TransitionKey.COMPLEMENTARY_DATA]
|
|
|
|
|
assert processed_comp_data["task"] == ["valid_task"]
|
|
|
|
|
|
|
|
|
|
# Valid list of strings - should remain unchanged
|
|
|
|
|
complementary_data = {"task": ["task1", "task2"]}
|
|
|
|
|
transition = create_transition(complementary_data=complementary_data)
|
|
|
|
|
transition = create_transition(complementary_data=complementary_data, observation={}, action={})
|
|
|
|
|
result = processor(transition)
|
|
|
|
|
processed_comp_data = result[TransitionKey.COMPLEMENTARY_DATA]
|
|
|
|
|
assert processed_comp_data["task"] == ["task1", "task2"]
|
|
|
|
@@ -676,7 +676,7 @@ def test_task_list_of_strings():
|
|
|
|
|
|
|
|
|
|
for task_list in test_lists:
|
|
|
|
|
complementary_data = {"task": task_list}
|
|
|
|
|
transition = create_transition(complementary_data=complementary_data)
|
|
|
|
|
transition = create_transition(complementary_data=complementary_data, observation={}, action={})
|
|
|
|
|
|
|
|
|
|
result = processor(transition)
|
|
|
|
|
|
|
|
|
@@ -690,7 +690,7 @@ def test_complementary_data_none():
|
|
|
|
|
"""Test processor handles None complementary_data gracefully."""
|
|
|
|
|
processor = AddBatchDimensionProcessorStep()
|
|
|
|
|
|
|
|
|
|
transition = create_transition(complementary_data=None)
|
|
|
|
|
transition = create_transition(complementary_data=None, action={}, observation={})
|
|
|
|
|
result = processor(transition)
|
|
|
|
|
|
|
|
|
|
assert result[TransitionKey.COMPLEMENTARY_DATA] == {}
|
|
|
|
@@ -701,7 +701,7 @@ def test_complementary_data_empty():
|
|
|
|
|
processor = AddBatchDimensionProcessorStep()
|
|
|
|
|
|
|
|
|
|
complementary_data = {}
|
|
|
|
|
transition = create_transition(complementary_data=complementary_data)
|
|
|
|
|
transition = create_transition(complementary_data=complementary_data, observation={}, action={})
|
|
|
|
|
|
|
|
|
|
result = processor(transition)
|
|
|
|
|
|
|
|
|
@@ -717,7 +717,7 @@ def test_complementary_data_no_task():
|
|
|
|
|
"timestamp": 1234567890.0,
|
|
|
|
|
"extra_info": "some data",
|
|
|
|
|
}
|
|
|
|
|
transition = create_transition(complementary_data=complementary_data)
|
|
|
|
|
transition = create_transition(complementary_data=complementary_data, observation={}, action={})
|
|
|
|
|
|
|
|
|
|
result = processor(transition)
|
|
|
|
|
|
|
|
|
@@ -736,7 +736,7 @@ def test_complementary_data_mixed():
|
|
|
|
|
"difficulty": "hard",
|
|
|
|
|
"metadata": {"scene": "kitchen"},
|
|
|
|
|
}
|
|
|
|
|
transition = create_transition(complementary_data=complementary_data)
|
|
|
|
|
transition = create_transition(complementary_data=complementary_data, observation={}, action={})
|
|
|
|
|
|
|
|
|
|
result = processor(transition)
|
|
|
|
|
|
|
|
|
@@ -803,7 +803,7 @@ def test_task_comprehensive_string_cases():
|
|
|
|
|
# Test that all string tasks get properly batched
|
|
|
|
|
for task in string_tasks:
|
|
|
|
|
complementary_data = {"task": task}
|
|
|
|
|
transition = create_transition(complementary_data=complementary_data)
|
|
|
|
|
transition = create_transition(complementary_data=complementary_data, observation={}, action={})
|
|
|
|
|
|
|
|
|
|
result = processor(transition)
|
|
|
|
|
|
|
|
|
@@ -825,7 +825,7 @@ def test_task_comprehensive_string_cases():
|
|
|
|
|
|
|
|
|
|
for task_list in list_tasks:
|
|
|
|
|
complementary_data = {"task": task_list}
|
|
|
|
|
transition = create_transition(complementary_data=complementary_data)
|
|
|
|
|
transition = create_transition(complementary_data=complementary_data, observation={}, action={})
|
|
|
|
|
|
|
|
|
|
result = processor(transition)
|
|
|
|
|
|
|
|
|
@@ -845,7 +845,7 @@ def test_task_preserves_other_keys():
|
|
|
|
|
"config": {"speed": "slow", "precision": "high"},
|
|
|
|
|
"metrics": [1.0, 2.0, 3.0],
|
|
|
|
|
}
|
|
|
|
|
transition = create_transition(complementary_data=complementary_data)
|
|
|
|
|
transition = create_transition(complementary_data=complementary_data, observation={}, action={})
|
|
|
|
|
|
|
|
|
|
result = processor(transition)
|
|
|
|
|
|
|
|
|
@@ -869,7 +869,7 @@ def test_index_scalar_to_1d():
|
|
|
|
|
# Create 0D index tensor (scalar)
|
|
|
|
|
index_0d = torch.tensor(42, dtype=torch.int64)
|
|
|
|
|
complementary_data = {"index": index_0d}
|
|
|
|
|
transition = create_transition(complementary_data=complementary_data)
|
|
|
|
|
transition = create_transition(complementary_data=complementary_data, observation={}, action={})
|
|
|
|
|
|
|
|
|
|
result = processor(transition)
|
|
|
|
|
|
|
|
|
@@ -886,7 +886,7 @@ def test_task_index_scalar_to_1d():
|
|
|
|
|
# Create 0D task_index tensor (scalar)
|
|
|
|
|
task_index_0d = torch.tensor(7, dtype=torch.int64)
|
|
|
|
|
complementary_data = {"task_index": task_index_0d}
|
|
|
|
|
transition = create_transition(complementary_data=complementary_data)
|
|
|
|
|
transition = create_transition(complementary_data=complementary_data, observation={}, action={})
|
|
|
|
|
|
|
|
|
|
result = processor(transition)
|
|
|
|
|
|
|
|
|
@@ -908,7 +908,7 @@ def test_index_and_task_index_together():
|
|
|
|
|
"task_index": task_index_0d,
|
|
|
|
|
"task": "pick_object",
|
|
|
|
|
}
|
|
|
|
|
transition = create_transition(complementary_data=complementary_data)
|
|
|
|
|
transition = create_transition(complementary_data=complementary_data, observation={}, action={})
|
|
|
|
|
|
|
|
|
|
result = processor(transition)
|
|
|
|
|
|
|
|
|
@@ -936,13 +936,13 @@ def test_index_already_batched():
|
|
|
|
|
|
|
|
|
|
# Test 1D (already batched)
|
|
|
|
|
complementary_data = {"index": index_1d}
|
|
|
|
|
transition = create_transition(complementary_data=complementary_data)
|
|
|
|
|
transition = create_transition(complementary_data=complementary_data, observation={}, action={})
|
|
|
|
|
result = processor(transition)
|
|
|
|
|
assert torch.equal(result[TransitionKey.COMPLEMENTARY_DATA]["index"], index_1d)
|
|
|
|
|
|
|
|
|
|
# Test 2D
|
|
|
|
|
complementary_data = {"index": index_2d}
|
|
|
|
|
transition = create_transition(complementary_data=complementary_data)
|
|
|
|
|
transition = create_transition(complementary_data=complementary_data, observation={}, action={})
|
|
|
|
|
result = processor(transition)
|
|
|
|
|
assert torch.equal(result[TransitionKey.COMPLEMENTARY_DATA]["index"], index_2d)
|
|
|
|
|
|
|
|
|
@@ -957,13 +957,13 @@ def test_task_index_already_batched():
|
|
|
|
|
|
|
|
|
|
# Test 1D (already batched)
|
|
|
|
|
complementary_data = {"task_index": task_index_1d}
|
|
|
|
|
transition = create_transition(complementary_data=complementary_data)
|
|
|
|
|
transition = create_transition(complementary_data=complementary_data, observation={}, action={})
|
|
|
|
|
result = processor(transition)
|
|
|
|
|
assert torch.equal(result[TransitionKey.COMPLEMENTARY_DATA]["task_index"], task_index_1d)
|
|
|
|
|
|
|
|
|
|
# Test 2D
|
|
|
|
|
complementary_data = {"task_index": task_index_2d}
|
|
|
|
|
transition = create_transition(complementary_data=complementary_data)
|
|
|
|
|
transition = create_transition(complementary_data=complementary_data, observation={}, action={})
|
|
|
|
|
result = processor(transition)
|
|
|
|
|
assert torch.equal(result[TransitionKey.COMPLEMENTARY_DATA]["task_index"], task_index_2d)
|
|
|
|
|
|
|
|
|
@@ -976,7 +976,7 @@ def test_index_non_tensor_unchanged():
|
|
|
|
|
"index": 42, # Plain int, not tensor
|
|
|
|
|
"task_index": [1, 2, 3], # List, not tensor
|
|
|
|
|
}
|
|
|
|
|
transition = create_transition(complementary_data=complementary_data)
|
|
|
|
|
transition = create_transition(complementary_data=complementary_data, observation={}, action={})
|
|
|
|
|
|
|
|
|
|
result = processor(transition)
|
|
|
|
|
|
|
|
|
@@ -999,7 +999,7 @@ def test_index_dtype_preservation():
|
|
|
|
|
"index": index_0d,
|
|
|
|
|
"task_index": task_index_0d,
|
|
|
|
|
}
|
|
|
|
|
transition = create_transition(complementary_data=complementary_data)
|
|
|
|
|
transition = create_transition(complementary_data=complementary_data, observation={}, action={})
|
|
|
|
|
|
|
|
|
|
result = processor(transition)
|
|
|
|
|
|
|
|
|
@@ -1062,7 +1062,7 @@ def test_index_device_compatibility():
|
|
|
|
|
"index": index_0d,
|
|
|
|
|
"task_index": task_index_0d,
|
|
|
|
|
}
|
|
|
|
|
transition = create_transition(complementary_data=complementary_data)
|
|
|
|
|
transition = create_transition(complementary_data=complementary_data, observation={}, action={})
|
|
|
|
|
|
|
|
|
|
result = processor(transition)
|
|
|
|
|
processed_comp_data = result[TransitionKey.COMPLEMENTARY_DATA]
|
|
|
|
@@ -1081,7 +1081,7 @@ def test_empty_index_tensor():
|
|
|
|
|
# Empty 0D tensor doesn't make sense, but test empty 1D
|
|
|
|
|
index_empty = torch.tensor([], dtype=torch.int64)
|
|
|
|
|
complementary_data = {"index": index_empty}
|
|
|
|
|
transition = create_transition(complementary_data=complementary_data)
|
|
|
|
|
transition = create_transition(complementary_data=complementary_data, observation={}, action={})
|
|
|
|
|
|
|
|
|
|
result = processor(transition)
|
|
|
|
|
|
|
|
|
@@ -1094,7 +1094,7 @@ def test_action_processing_creates_new_transition():
|
|
|
|
|
processor = AddBatchDimensionProcessorStep()
|
|
|
|
|
|
|
|
|
|
action = torch.randn(4)
|
|
|
|
|
transition = create_transition(action=action)
|
|
|
|
|
transition = create_transition(action=action, observation={})
|
|
|
|
|
|
|
|
|
|
# Store reference to original transition
|
|
|
|
|
original_transition = transition
|
|
|
|
@@ -1116,7 +1116,7 @@ def test_task_processing_creates_new_transition():
|
|
|
|
|
processor = AddBatchDimensionProcessorStep()
|
|
|
|
|
|
|
|
|
|
complementary_data = {"task": "sort_objects"}
|
|
|
|
|
transition = create_transition(complementary_data=complementary_data)
|
|
|
|
|
transition = create_transition(complementary_data=complementary_data, observation={}, action={})
|
|
|
|
|
|
|
|
|
|
# Store reference to original transition and complementary_data
|
|
|
|
|
original_transition = transition
|
|
|
|
|