mirror of
https://github.com/huggingface/lerobot.git
synced 2026-05-18 10:10:08 +00:00
refactor(pipeline): Clarify hook behavior and improve documentation
- Updated the RobotProcessor class to ensure hooks are strictly for observation and do not modify transitions, enhancing clarity and maintainability. - Refactored hook registration methods to reflect the new behavior, ensuring they accept only functions that do not return modified transitions. - Enhanced documentation to clearly outline the purpose of hooks and their execution semantics. - Added tests to verify that hooks are not executed during the step_through method while ensuring they function correctly during the __call__ method.
This commit is contained in:
@@ -263,6 +263,40 @@ def test_step_through_with_dict():
|
||||
# For now, just check that we get dict outputs
|
||||
|
||||
|
||||
def test_step_through_no_hooks():
|
||||
"""Test that step_through doesn't execute hooks."""
|
||||
step = MockStep("test_step")
|
||||
pipeline = RobotProcessor([step])
|
||||
|
||||
hook_calls = []
|
||||
|
||||
def tracking_hook(idx: int, transition: EnvTransition):
|
||||
hook_calls.append(f"hook_called_step_{idx}")
|
||||
|
||||
# Register hooks
|
||||
pipeline.register_before_step_hook(tracking_hook)
|
||||
pipeline.register_after_step_hook(tracking_hook)
|
||||
|
||||
# Use step_through
|
||||
transition = create_transition()
|
||||
results = list(pipeline.step_through(transition))
|
||||
|
||||
# Verify step was executed (counter should increment)
|
||||
assert len(results) == 2 # Initial + 1 step
|
||||
assert results[1][TransitionKey.COMPLEMENTARY_DATA]["test_step_counter"] == 0
|
||||
|
||||
# Verify hooks were NOT called
|
||||
assert len(hook_calls) == 0
|
||||
|
||||
# Now use __call__ to verify hooks ARE called there
|
||||
hook_calls.clear()
|
||||
pipeline(transition)
|
||||
|
||||
# Verify hooks were called (before and after for 1 step = 2 calls)
|
||||
assert len(hook_calls) == 2
|
||||
assert hook_calls == ["hook_called_step_0", "hook_called_step_0"]
|
||||
|
||||
|
||||
def test_indexing():
|
||||
"""Test pipeline indexing."""
|
||||
step1 = MockStep("step1")
|
||||
@@ -290,11 +324,9 @@ def test_hooks():
|
||||
|
||||
def before_hook(idx: int, transition: EnvTransition):
|
||||
before_calls.append(idx)
|
||||
return transition
|
||||
|
||||
def after_hook(idx: int, transition: EnvTransition):
|
||||
after_calls.append(idx)
|
||||
return transition
|
||||
|
||||
pipeline.register_before_step_hook(before_hook)
|
||||
pipeline.register_after_step_hook(after_hook)
|
||||
@@ -306,24 +338,6 @@ def test_hooks():
|
||||
assert after_calls == [0]
|
||||
|
||||
|
||||
def test_hook_modification():
|
||||
"""Test that hooks can modify transitions."""
|
||||
step = MockStep("test_step")
|
||||
pipeline = RobotProcessor([step])
|
||||
|
||||
def modify_reward_hook(idx: int, transition: EnvTransition):
|
||||
new_transition = transition.copy()
|
||||
new_transition[TransitionKey.REWARD] = 42.0
|
||||
return new_transition
|
||||
|
||||
pipeline.register_before_step_hook(modify_reward_hook)
|
||||
|
||||
transition = create_transition()
|
||||
result = pipeline(transition)
|
||||
|
||||
assert result[TransitionKey.REWARD] == 42.0 # reward modified by hook
|
||||
|
||||
|
||||
def test_reset():
|
||||
"""Test pipeline reset functionality."""
|
||||
step = MockStep("test_step")
|
||||
@@ -360,7 +374,6 @@ def test_unregister_hooks():
|
||||
|
||||
def before_hook(idx: int, transition: EnvTransition):
|
||||
before_calls.append(idx)
|
||||
return None
|
||||
|
||||
pipeline.register_before_step_hook(before_hook)
|
||||
|
||||
@@ -380,7 +393,6 @@ def test_unregister_hooks():
|
||||
|
||||
def after_hook(idx: int, transition: EnvTransition):
|
||||
after_calls.append(idx)
|
||||
return None
|
||||
|
||||
pipeline.register_after_step_hook(after_hook)
|
||||
pipeline(transition)
|
||||
@@ -412,7 +424,7 @@ def test_unregister_nonexistent_hook():
|
||||
pipeline = RobotProcessor([MockStep()])
|
||||
|
||||
def some_hook(idx: int, transition: EnvTransition):
|
||||
return None
|
||||
pass
|
||||
|
||||
def reset_hook():
|
||||
pass
|
||||
@@ -438,15 +450,12 @@ def test_multiple_hooks_and_selective_unregister():
|
||||
|
||||
def hook1(idx: int, transition: EnvTransition):
|
||||
calls_1.append(f"hook1_step{idx}")
|
||||
return None
|
||||
|
||||
def hook2(idx: int, transition: EnvTransition):
|
||||
calls_2.append(f"hook2_step{idx}")
|
||||
return None
|
||||
|
||||
def hook3(idx: int, transition: EnvTransition):
|
||||
calls_3.append(f"hook3_step{idx}")
|
||||
return None
|
||||
|
||||
# Register multiple hooks
|
||||
pipeline.register_before_step_hook(hook1)
|
||||
@@ -485,15 +494,12 @@ def test_hook_execution_order_documentation():
|
||||
|
||||
def hook_a(idx: int, transition: EnvTransition):
|
||||
execution_order.append("A")
|
||||
return None
|
||||
|
||||
def hook_b(idx: int, transition: EnvTransition):
|
||||
execution_order.append("B")
|
||||
return None
|
||||
|
||||
def hook_c(idx: int, transition: EnvTransition):
|
||||
execution_order.append("C")
|
||||
return None
|
||||
|
||||
# Register in specific order: A, B, C
|
||||
pipeline.register_before_step_hook(hook_a)
|
||||
|
||||
Reference in New Issue
Block a user