mirror of
https://github.com/huggingface/lerobot.git
synced 2026-05-16 09:09:48 +00:00
refactor(pipeline): Transition from tuple to dictionary format for EnvTransition
- Updated the EnvTransition structure to use a dictionary format instead of a tuple, enhancing readability and maintainability. - Replaced instances of TransitionIndex with TransitionKey for accessing transition components. - Adjusted related processing functions and tests to accommodate the new dictionary format, ensuring consistent handling of transitions across the codebase.
This commit is contained in:
@@ -13,14 +13,28 @@
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
import tempfile
|
||||
from pathlib import Path
|
||||
|
||||
import numpy as np
|
||||
import torch
|
||||
|
||||
from lerobot.processor import ProcessorStepRegistry, RenameProcessor, RobotProcessor, TransitionIndex
|
||||
from lerobot.processor import ProcessorStepRegistry, RenameProcessor, RobotProcessor, TransitionKey
|
||||
|
||||
|
||||
def create_transition(
|
||||
observation=None, action=None, reward=None, done=None, truncated=None, info=None, complementary_data=None
|
||||
):
|
||||
"""Helper to create an EnvTransition dictionary."""
|
||||
return {
|
||||
TransitionKey.OBSERVATION: observation,
|
||||
TransitionKey.ACTION: action,
|
||||
TransitionKey.REWARD: reward,
|
||||
TransitionKey.DONE: done,
|
||||
TransitionKey.TRUNCATED: truncated,
|
||||
TransitionKey.INFO: info,
|
||||
TransitionKey.COMPLEMENTARY_DATA: complementary_data,
|
||||
}
|
||||
|
||||
|
||||
def test_basic_renaming():
|
||||
@@ -36,10 +50,10 @@ def test_basic_renaming():
|
||||
"old_key2": np.array([3.0, 4.0]),
|
||||
"unchanged_key": "keep_me",
|
||||
}
|
||||
transition = (observation, None, None, None, None, None, None)
|
||||
transition = create_transition(observation=observation)
|
||||
|
||||
result = processor(transition)
|
||||
processed_obs = result[TransitionIndex.OBSERVATION]
|
||||
processed_obs = result[TransitionKey.OBSERVATION]
|
||||
|
||||
# Check renamed keys
|
||||
assert "new_key1" in processed_obs
|
||||
@@ -63,10 +77,10 @@ def test_empty_rename_map():
|
||||
"key1": torch.tensor([1.0]),
|
||||
"key2": "value2",
|
||||
}
|
||||
transition = (observation, None, None, None, None, None, None)
|
||||
transition = create_transition(observation=observation)
|
||||
|
||||
result = processor(transition)
|
||||
processed_obs = result[TransitionIndex.OBSERVATION]
|
||||
processed_obs = result[TransitionKey.OBSERVATION]
|
||||
|
||||
# All keys should be unchanged
|
||||
assert processed_obs.keys() == observation.keys()
|
||||
@@ -78,7 +92,7 @@ def test_none_observation():
|
||||
"""Test processor with None observation."""
|
||||
processor = RenameProcessor(rename_map={"old": "new"})
|
||||
|
||||
transition = (None, None, None, None, None, None, None)
|
||||
transition = create_transition()
|
||||
result = processor(transition)
|
||||
|
||||
# Should return transition unchanged
|
||||
@@ -98,10 +112,10 @@ def test_overlapping_rename():
|
||||
"b": 2,
|
||||
"x": 3,
|
||||
}
|
||||
transition = (observation, None, None, None, None, None, None)
|
||||
transition = create_transition(observation=observation)
|
||||
|
||||
result = processor(transition)
|
||||
processed_obs = result[TransitionIndex.OBSERVATION]
|
||||
processed_obs = result[TransitionKey.OBSERVATION]
|
||||
|
||||
# Check that renaming happens correctly
|
||||
assert "a" not in processed_obs
|
||||
@@ -124,10 +138,10 @@ def test_partial_rename():
|
||||
"reward": 1.0,
|
||||
"info": {"episode": 1},
|
||||
}
|
||||
transition = (observation, None, None, None, None, None, None)
|
||||
transition = create_transition(observation=observation)
|
||||
|
||||
result = processor(transition)
|
||||
processed_obs = result[TransitionIndex.OBSERVATION]
|
||||
processed_obs = result[TransitionKey.OBSERVATION]
|
||||
|
||||
# Check renamed keys
|
||||
assert "observation.proprio_state" in processed_obs
|
||||
@@ -178,10 +192,12 @@ def test_integration_with_robot_processor():
|
||||
"pixels": np.zeros((32, 32, 3), dtype=np.uint8),
|
||||
"other_data": "preserve_me",
|
||||
}
|
||||
transition = (observation, None, 0.5, False, False, {}, {})
|
||||
transition = create_transition(
|
||||
observation=observation, reward=0.5, done=False, truncated=False, info={}, complementary_data={}
|
||||
)
|
||||
|
||||
result = pipeline(transition)
|
||||
processed_obs = result[TransitionIndex.OBSERVATION]
|
||||
processed_obs = result[TransitionKey.OBSERVATION]
|
||||
|
||||
# Check renaming worked through pipeline
|
||||
assert "observation.state" in processed_obs
|
||||
@@ -191,8 +207,8 @@ def test_integration_with_robot_processor():
|
||||
assert processed_obs["other_data"] == "preserve_me"
|
||||
|
||||
# Check other transition elements unchanged
|
||||
assert result[TransitionIndex.REWARD] == 0.5
|
||||
assert result[TransitionIndex.DONE] is False
|
||||
assert result[TransitionKey.REWARD] == 0.5
|
||||
assert result[TransitionKey.DONE] is False
|
||||
|
||||
|
||||
def test_save_and_load_pretrained():
|
||||
@@ -229,10 +245,10 @@ def test_save_and_load_pretrained():
|
||||
|
||||
# Test functionality after loading
|
||||
observation = {"old_state": [1, 2, 3], "old_image": "image_data"}
|
||||
transition = (observation, None, None, None, None, None, None)
|
||||
transition = create_transition(observation=observation)
|
||||
|
||||
result = loaded_pipeline(transition)
|
||||
processed_obs = result[TransitionIndex.OBSERVATION]
|
||||
processed_obs = result[TransitionKey.OBSERVATION]
|
||||
|
||||
assert "observation.state" in processed_obs
|
||||
assert "observation.image" in processed_obs
|
||||
@@ -306,17 +322,17 @@ def test_chained_rename_processors():
|
||||
"img": "image_data",
|
||||
"extra": "keep_me",
|
||||
}
|
||||
transition = (observation, None, None, None, None, None, None)
|
||||
transition = create_transition(observation=observation)
|
||||
|
||||
# Step through to see intermediate results
|
||||
results = list(pipeline.step_through(transition))
|
||||
|
||||
# After first processor
|
||||
assert "agent_position" in results[1][TransitionIndex.OBSERVATION]
|
||||
assert "camera_image" in results[1][TransitionIndex.OBSERVATION]
|
||||
assert "agent_position" in results[1][TransitionKey.OBSERVATION]
|
||||
assert "camera_image" in results[1][TransitionKey.OBSERVATION]
|
||||
|
||||
# After second processor
|
||||
final_obs = results[2][TransitionIndex.OBSERVATION]
|
||||
final_obs = results[2][TransitionKey.OBSERVATION]
|
||||
assert "observation.state" in final_obs
|
||||
assert "observation.image" in final_obs
|
||||
assert final_obs["extra"] == "keep_me"
|
||||
@@ -343,10 +359,10 @@ def test_nested_observation_rename():
|
||||
"observation.proprio": torch.randn(7),
|
||||
"observation.gripper": torch.tensor([0.0]), # Not renamed
|
||||
}
|
||||
transition = (observation, None, None, None, None, None, None)
|
||||
transition = create_transition(observation=observation)
|
||||
|
||||
result = processor(transition)
|
||||
processed_obs = result[TransitionIndex.OBSERVATION]
|
||||
processed_obs = result[TransitionKey.OBSERVATION]
|
||||
|
||||
# Check renames
|
||||
assert "observation.camera.left_view" in processed_obs
|
||||
@@ -378,10 +394,10 @@ def test_value_types_preserved():
|
||||
"old_dict": {"nested": "value"},
|
||||
"old_list": [1, 2, 3],
|
||||
}
|
||||
transition = (observation, None, None, None, None, None, None)
|
||||
transition = create_transition(observation=observation)
|
||||
|
||||
result = processor(transition)
|
||||
processed_obs = result[TransitionIndex.OBSERVATION]
|
||||
processed_obs = result[TransitionKey.OBSERVATION]
|
||||
|
||||
# Check that values and types are preserved
|
||||
assert torch.equal(processed_obs["new_tensor"], tensor_value)
|
||||
|
||||
Reference in New Issue
Block a user