mirror of
https://github.com/huggingface/lerobot.git
synced 2026-05-15 16:49:55 +00:00
refactor(processor): clarify action types, distinguish PolicyAction, RobotAction, and EnvAction (#1908)
* refactor(processor): split action from policy, robots and environment - Updated function names to robot_action_to_transition and robot_transition_to_action across multiple files to better reflect their purpose in processing robot actions. - Adjusted references in the RobotProcessorPipeline and related components to ensure compatibility with the new naming convention. - Enhanced type annotations for action parameters to improve code readability and maintainability. * refactor(converters): rename robot_transition_to_action to transition_to_robot_action - Updated function names across multiple files to improve clarity and consistency in processing robot actions. - Adjusted references in RobotProcessorPipeline and related components to align with the new naming convention. - Simplified action handling in the AddBatchDimensionProcessorStep by removing unnecessary checks for action presence. * refactor(converters): update references to transition_to_robot_action - Renamed all instances of robot_transition_to_action to transition_to_robot_action across multiple files for consistency and clarity in the processing of robot actions. - Adjusted the RobotProcessorPipeline configurations to reflect the new naming convention, enhancing code readability. * refactor(processor): update Torch2NumpyActionProcessorStep to extend ActionProcessorStep - Changed the base class of Torch2NumpyActionProcessorStep from PolicyActionProcessorStep to ActionProcessorStep, aligning it with the current architecture of action processing. - This modification enhances the clarity of the class's role in the processing pipeline. * fix(processor): main action processor can also take EnvAction --------- Co-authored-by: Steven Palma <steven.palma@huggingface.co>
This commit is contained in:
@@ -49,7 +49,7 @@ def test_batch_to_transition_observation_grouping():
|
||||
"observation.image.top": torch.randn(1, 3, 128, 128),
|
||||
"observation.image.left": torch.randn(1, 3, 128, 128),
|
||||
"observation.state": [1, 2, 3, 4],
|
||||
"action": "action_data",
|
||||
"action": torch.tensor([0.1, 0.2, 0.3, 0.4]),
|
||||
"next.reward": 1.5,
|
||||
"next.done": True,
|
||||
"next.truncated": False,
|
||||
@@ -74,7 +74,7 @@ def test_batch_to_transition_observation_grouping():
|
||||
assert transition[TransitionKey.OBSERVATION]["observation.state"] == [1, 2, 3, 4]
|
||||
|
||||
# Check other fields
|
||||
assert transition[TransitionKey.ACTION] == "action_data"
|
||||
assert torch.allclose(transition[TransitionKey.ACTION], torch.tensor([0.1, 0.2, 0.3, 0.4]))
|
||||
assert transition[TransitionKey.REWARD] == 1.5
|
||||
assert transition[TransitionKey.DONE]
|
||||
assert not transition[TransitionKey.TRUNCATED]
|
||||
@@ -123,7 +123,7 @@ def test_transition_to_batch_observation_flattening():
|
||||
def test_no_observation_keys():
|
||||
"""Test behavior when there are no observation.* keys."""
|
||||
batch = {
|
||||
"action": "action_data",
|
||||
"action": torch.tensor([1.0, 2.0]),
|
||||
"next.reward": 2.0,
|
||||
"next.done": False,
|
||||
"next.truncated": True,
|
||||
@@ -136,7 +136,7 @@ def test_no_observation_keys():
|
||||
assert transition[TransitionKey.OBSERVATION] is None
|
||||
|
||||
# Check other fields
|
||||
assert transition[TransitionKey.ACTION] == "action_data"
|
||||
assert torch.allclose(transition[TransitionKey.ACTION], torch.tensor([1.0, 2.0]))
|
||||
assert transition[TransitionKey.REWARD] == 2.0
|
||||
assert not transition[TransitionKey.DONE]
|
||||
assert transition[TransitionKey.TRUNCATED]
|
||||
@@ -144,7 +144,7 @@ def test_no_observation_keys():
|
||||
|
||||
# Round trip should work
|
||||
reconstructed_batch = transition_to_batch(transition)
|
||||
assert reconstructed_batch["action"] == "action_data"
|
||||
assert torch.allclose(reconstructed_batch["action"], torch.tensor([1.0, 2.0]))
|
||||
assert reconstructed_batch["next.reward"] == 2.0
|
||||
assert not reconstructed_batch["next.done"]
|
||||
assert reconstructed_batch["next.truncated"]
|
||||
@@ -153,13 +153,13 @@ def test_no_observation_keys():
|
||||
|
||||
def test_minimal_batch():
|
||||
"""Test with minimal batch containing only observation.* and action."""
|
||||
batch = {"observation.state": "minimal_state", "action": "minimal_action"}
|
||||
batch = {"observation.state": "minimal_state", "action": torch.tensor([0.5])}
|
||||
|
||||
transition = batch_to_transition(batch)
|
||||
|
||||
# Check observation
|
||||
assert transition[TransitionKey.OBSERVATION] == {"observation.state": "minimal_state"}
|
||||
assert transition[TransitionKey.ACTION] == "minimal_action"
|
||||
assert torch.allclose(transition[TransitionKey.ACTION], torch.tensor([0.5]))
|
||||
|
||||
# Check defaults
|
||||
assert transition[TransitionKey.REWARD] == 0.0
|
||||
@@ -171,7 +171,7 @@ def test_minimal_batch():
|
||||
# Round trip
|
||||
reconstructed_batch = transition_to_batch(transition)
|
||||
assert reconstructed_batch["observation.state"] == "minimal_state"
|
||||
assert reconstructed_batch["action"] == "minimal_action"
|
||||
assert torch.allclose(reconstructed_batch["action"], torch.tensor([0.5]))
|
||||
assert reconstructed_batch["next.reward"] == 0.0
|
||||
assert not reconstructed_batch["next.done"]
|
||||
assert not reconstructed_batch["next.truncated"]
|
||||
|
||||
Reference in New Issue
Block a user