mirror of
https://github.com/huggingface/lerobot.git
synced 2026-05-16 09:09:48 +00:00
feat(pipeline): Add hook unregistration functionality and enhance documentation
- Implemented methods to unregister before, after, and reset hooks in the RobotProcessor class, allowing for more flexible hook management. - Enhanced documentation to clarify hook execution semantics and the implications of modifying transitions within hooks. - Added comprehensive tests to verify the correct behavior of hook registration and unregistration, including error handling for non-existent hooks.
This commit is contained in:
@@ -350,6 +350,177 @@ def test_reset():
|
||||
assert len(reset_called) == 1
|
||||
|
||||
|
||||
def test_unregister_hooks():
|
||||
"""Test unregistering hooks from the pipeline."""
|
||||
step = MockStep("test_step")
|
||||
pipeline = RobotProcessor([step])
|
||||
|
||||
# Test before_step_hook
|
||||
before_calls = []
|
||||
|
||||
def before_hook(idx: int, transition: EnvTransition):
|
||||
before_calls.append(idx)
|
||||
return None
|
||||
|
||||
pipeline.register_before_step_hook(before_hook)
|
||||
|
||||
# Verify hook is registered
|
||||
transition = create_transition()
|
||||
pipeline(transition)
|
||||
assert len(before_calls) == 1
|
||||
|
||||
# Unregister and verify it's no longer called
|
||||
pipeline.unregister_before_step_hook(before_hook)
|
||||
before_calls.clear()
|
||||
pipeline(transition)
|
||||
assert len(before_calls) == 0
|
||||
|
||||
# Test after_step_hook
|
||||
after_calls = []
|
||||
|
||||
def after_hook(idx: int, transition: EnvTransition):
|
||||
after_calls.append(idx)
|
||||
return None
|
||||
|
||||
pipeline.register_after_step_hook(after_hook)
|
||||
pipeline(transition)
|
||||
assert len(after_calls) == 1
|
||||
|
||||
pipeline.unregister_after_step_hook(after_hook)
|
||||
after_calls.clear()
|
||||
pipeline(transition)
|
||||
assert len(after_calls) == 0
|
||||
|
||||
# Test reset_hook
|
||||
reset_calls = []
|
||||
|
||||
def reset_hook():
|
||||
reset_calls.append(True)
|
||||
|
||||
pipeline.register_reset_hook(reset_hook)
|
||||
pipeline.reset()
|
||||
assert len(reset_calls) == 1
|
||||
|
||||
pipeline.unregister_reset_hook(reset_hook)
|
||||
reset_calls.clear()
|
||||
pipeline.reset()
|
||||
assert len(reset_calls) == 0
|
||||
|
||||
|
||||
def test_unregister_nonexistent_hook():
|
||||
"""Test error handling when unregistering hooks that don't exist."""
|
||||
pipeline = RobotProcessor([MockStep()])
|
||||
|
||||
def some_hook(idx: int, transition: EnvTransition):
|
||||
return None
|
||||
|
||||
def reset_hook():
|
||||
pass
|
||||
|
||||
# Test unregistering hooks that were never registered
|
||||
with pytest.raises(ValueError, match="not found in before_step_hooks"):
|
||||
pipeline.unregister_before_step_hook(some_hook)
|
||||
|
||||
with pytest.raises(ValueError, match="not found in after_step_hooks"):
|
||||
pipeline.unregister_after_step_hook(some_hook)
|
||||
|
||||
with pytest.raises(ValueError, match="not found in reset_hooks"):
|
||||
pipeline.unregister_reset_hook(reset_hook)
|
||||
|
||||
|
||||
def test_multiple_hooks_and_selective_unregister():
|
||||
"""Test registering multiple hooks and selectively unregistering them."""
|
||||
pipeline = RobotProcessor([MockStep("step1"), MockStep("step2")])
|
||||
|
||||
calls_1 = []
|
||||
calls_2 = []
|
||||
calls_3 = []
|
||||
|
||||
def hook1(idx: int, transition: EnvTransition):
|
||||
calls_1.append(f"hook1_step{idx}")
|
||||
return None
|
||||
|
||||
def hook2(idx: int, transition: EnvTransition):
|
||||
calls_2.append(f"hook2_step{idx}")
|
||||
return None
|
||||
|
||||
def hook3(idx: int, transition: EnvTransition):
|
||||
calls_3.append(f"hook3_step{idx}")
|
||||
return None
|
||||
|
||||
# Register multiple hooks
|
||||
pipeline.register_before_step_hook(hook1)
|
||||
pipeline.register_before_step_hook(hook2)
|
||||
pipeline.register_before_step_hook(hook3)
|
||||
|
||||
# Run pipeline - all hooks should be called for both steps
|
||||
transition = create_transition()
|
||||
pipeline(transition)
|
||||
|
||||
assert calls_1 == ["hook1_step0", "hook1_step1"]
|
||||
assert calls_2 == ["hook2_step0", "hook2_step1"]
|
||||
assert calls_3 == ["hook3_step0", "hook3_step1"]
|
||||
|
||||
# Clear calls
|
||||
calls_1.clear()
|
||||
calls_2.clear()
|
||||
calls_3.clear()
|
||||
|
||||
# Unregister middle hook
|
||||
pipeline.unregister_before_step_hook(hook2)
|
||||
|
||||
# Run again - only hook1 and hook3 should be called
|
||||
pipeline(transition)
|
||||
|
||||
assert calls_1 == ["hook1_step0", "hook1_step1"]
|
||||
assert calls_2 == [] # hook2 was unregistered
|
||||
assert calls_3 == ["hook3_step0", "hook3_step1"]
|
||||
|
||||
|
||||
def test_hook_execution_order_documentation():
|
||||
"""Test and document that hooks are executed sequentially in registration order."""
|
||||
pipeline = RobotProcessor([MockStep("step")])
|
||||
|
||||
execution_order = []
|
||||
|
||||
def hook_a(idx: int, transition: EnvTransition):
|
||||
execution_order.append("A")
|
||||
return None
|
||||
|
||||
def hook_b(idx: int, transition: EnvTransition):
|
||||
execution_order.append("B")
|
||||
return None
|
||||
|
||||
def hook_c(idx: int, transition: EnvTransition):
|
||||
execution_order.append("C")
|
||||
return None
|
||||
|
||||
# Register in specific order: A, B, C
|
||||
pipeline.register_before_step_hook(hook_a)
|
||||
pipeline.register_before_step_hook(hook_b)
|
||||
pipeline.register_before_step_hook(hook_c)
|
||||
|
||||
transition = create_transition()
|
||||
pipeline(transition)
|
||||
|
||||
# Verify execution order matches registration order
|
||||
assert execution_order == ["A", "B", "C"]
|
||||
|
||||
# Test that after unregistering B and re-registering it, it goes to the end
|
||||
pipeline.unregister_before_step_hook(hook_b)
|
||||
execution_order.clear()
|
||||
|
||||
pipeline(transition)
|
||||
assert execution_order == ["A", "C"] # B is gone
|
||||
|
||||
# Re-register B - it should now be at the end
|
||||
pipeline.register_before_step_hook(hook_b)
|
||||
execution_order.clear()
|
||||
|
||||
pipeline(transition)
|
||||
assert execution_order == ["A", "C", "B"] # B is now last
|
||||
|
||||
|
||||
def test_profile_steps():
|
||||
"""Test step profiling functionality."""
|
||||
step1 = MockStep("step1")
|
||||
|
||||
Reference in New Issue
Block a user