From 77106697c3f0015e5d81b08da340b2f0442fbf77 Mon Sep 17 00:00:00 2001 From: Adil Zouitine Date: Mon, 21 Jul 2025 18:56:01 +0200 Subject: [PATCH] feat(pipeline): Add hook unregistration functionality and enhance documentation - Implemented methods to unregister before, after, and reset hooks in the RobotProcessor class, allowing for more flexible hook management. - Enhanced documentation to clarify hook execution semantics and the implications of modifying transitions within hooks. - Added comprehensive tests to verify the correct behavior of hook registration and unregistration, including error handling for non-existent hooks. --- src/lerobot/processor/pipeline.py | 68 ++++++++++++ tests/processor/test_pipeline.py | 171 ++++++++++++++++++++++++++++++ 2 files changed, 239 insertions(+) diff --git a/src/lerobot/processor/pipeline.py b/src/lerobot/processor/pipeline.py index 323a6066c..34504ed96 100644 --- a/src/lerobot/processor/pipeline.py +++ b/src/lerobot/processor/pipeline.py @@ -261,6 +261,19 @@ class RobotProcessor(ModelHubMixin): after_step_hooks: List of hooks called after each step. Each hook receives the step index and transition, and can optionally return a modified transition. reset_hooks: List of hooks called during processor reset. + + Hook Semantics: + - Hooks are executed sequentially in the order they were registered. There is no way to + reorder hooks after registration without creating a new pipeline. + - Hooks CAN modify transitions by returning a new transition dict. If a hook returns None, + the current transition remains unchanged. While this capability exists, it should be used + with EXTREME CAUTION as it can make debugging difficult and create unexpected side effects. + IT'S ADVISED TO NOT MODIFY THE TRANSITION IN A HOOK. + - All hooks for a given type (before/after) are executed for every step, or none at all if + an error occurs. There is no partial execution of hooks. + - Hooks should generally be stateless to maintain predictable behavior. If you need stateful + processing, consider implementing a proper ProcessorStep instead. + - To remove hooks, use the unregister methods. To remove steps, you must create a new pipeline. """ steps: Sequence[ProcessorStep] = field(default_factory=list) @@ -318,6 +331,13 @@ class RobotProcessor(ModelHubMixin): if not isinstance(transition, dict): raise ValueError(f"EnvTransition must be a dictionary. Got {type(transition).__name__}") + # Hook execution subtleties: + # - Hooks are executed sequentially in the order they were registered (list order) + # - Each hook sees the potentially modified transition from the previous hook + # - If a hook returns None, the transition remains unchanged + # - All hooks for a given type (before/after) run for every step, creating a + # multiplicative effect: N steps × M hooks = N×M hook executions + # - Hook execution cannot be interrupted - they all run or none run (on error) for idx, processor_step in enumerate(self.steps): for hook in self.before_step_hooks: updated = hook(idx, transition) @@ -638,14 +658,62 @@ class RobotProcessor(ModelHubMixin): """Attach fn to be executed before every processor step.""" self.before_step_hooks.append(fn) + def unregister_before_step_hook(self, fn: Callable[[int, EnvTransition], EnvTransition | None]): + """Remove a previously registered before_step hook. + + Args: + fn: The exact function reference that was registered. Must be the same object. + + Raises: + ValueError: If the hook is not found in the registered hooks. + """ + try: + self.before_step_hooks.remove(fn) + except ValueError: + raise ValueError( + f"Hook {fn} not found in before_step_hooks. Make sure to pass the exact same function reference." + ) from None + def register_after_step_hook(self, fn: Callable[[int, EnvTransition], EnvTransition | None]): """Attach fn to be executed after every processor step.""" self.after_step_hooks.append(fn) + def unregister_after_step_hook(self, fn: Callable[[int, EnvTransition], EnvTransition | None]): + """Remove a previously registered after_step hook. + + Args: + fn: The exact function reference that was registered. Must be the same object. + + Raises: + ValueError: If the hook is not found in the registered hooks. + """ + try: + self.after_step_hooks.remove(fn) + except ValueError: + raise ValueError( + f"Hook {fn} not found in after_step_hooks. Make sure to pass the exact same function reference." + ) from None + def register_reset_hook(self, fn: Callable[[], None]): """Attach fn to be executed when reset is called.""" self.reset_hooks.append(fn) + def unregister_reset_hook(self, fn: Callable[[], None]): + """Remove a previously registered reset hook. + + Args: + fn: The exact function reference that was registered. Must be the same object. + + Raises: + ValueError: If the hook is not found in the registered hooks. + """ + try: + self.reset_hooks.remove(fn) + except ValueError: + raise ValueError( + f"Hook {fn} not found in reset_hooks. Make sure to pass the exact same function reference." + ) from None + def reset(self): """Clear state in every step that implements ``reset()`` and fire registered hooks.""" for step in self.steps: diff --git a/tests/processor/test_pipeline.py b/tests/processor/test_pipeline.py index a21e229dd..f9f6237ff 100644 --- a/tests/processor/test_pipeline.py +++ b/tests/processor/test_pipeline.py @@ -350,6 +350,177 @@ def test_reset(): assert len(reset_called) == 1 +def test_unregister_hooks(): + """Test unregistering hooks from the pipeline.""" + step = MockStep("test_step") + pipeline = RobotProcessor([step]) + + # Test before_step_hook + before_calls = [] + + def before_hook(idx: int, transition: EnvTransition): + before_calls.append(idx) + return None + + pipeline.register_before_step_hook(before_hook) + + # Verify hook is registered + transition = create_transition() + pipeline(transition) + assert len(before_calls) == 1 + + # Unregister and verify it's no longer called + pipeline.unregister_before_step_hook(before_hook) + before_calls.clear() + pipeline(transition) + assert len(before_calls) == 0 + + # Test after_step_hook + after_calls = [] + + def after_hook(idx: int, transition: EnvTransition): + after_calls.append(idx) + return None + + pipeline.register_after_step_hook(after_hook) + pipeline(transition) + assert len(after_calls) == 1 + + pipeline.unregister_after_step_hook(after_hook) + after_calls.clear() + pipeline(transition) + assert len(after_calls) == 0 + + # Test reset_hook + reset_calls = [] + + def reset_hook(): + reset_calls.append(True) + + pipeline.register_reset_hook(reset_hook) + pipeline.reset() + assert len(reset_calls) == 1 + + pipeline.unregister_reset_hook(reset_hook) + reset_calls.clear() + pipeline.reset() + assert len(reset_calls) == 0 + + +def test_unregister_nonexistent_hook(): + """Test error handling when unregistering hooks that don't exist.""" + pipeline = RobotProcessor([MockStep()]) + + def some_hook(idx: int, transition: EnvTransition): + return None + + def reset_hook(): + pass + + # Test unregistering hooks that were never registered + with pytest.raises(ValueError, match="not found in before_step_hooks"): + pipeline.unregister_before_step_hook(some_hook) + + with pytest.raises(ValueError, match="not found in after_step_hooks"): + pipeline.unregister_after_step_hook(some_hook) + + with pytest.raises(ValueError, match="not found in reset_hooks"): + pipeline.unregister_reset_hook(reset_hook) + + +def test_multiple_hooks_and_selective_unregister(): + """Test registering multiple hooks and selectively unregistering them.""" + pipeline = RobotProcessor([MockStep("step1"), MockStep("step2")]) + + calls_1 = [] + calls_2 = [] + calls_3 = [] + + def hook1(idx: int, transition: EnvTransition): + calls_1.append(f"hook1_step{idx}") + return None + + def hook2(idx: int, transition: EnvTransition): + calls_2.append(f"hook2_step{idx}") + return None + + def hook3(idx: int, transition: EnvTransition): + calls_3.append(f"hook3_step{idx}") + return None + + # Register multiple hooks + pipeline.register_before_step_hook(hook1) + pipeline.register_before_step_hook(hook2) + pipeline.register_before_step_hook(hook3) + + # Run pipeline - all hooks should be called for both steps + transition = create_transition() + pipeline(transition) + + assert calls_1 == ["hook1_step0", "hook1_step1"] + assert calls_2 == ["hook2_step0", "hook2_step1"] + assert calls_3 == ["hook3_step0", "hook3_step1"] + + # Clear calls + calls_1.clear() + calls_2.clear() + calls_3.clear() + + # Unregister middle hook + pipeline.unregister_before_step_hook(hook2) + + # Run again - only hook1 and hook3 should be called + pipeline(transition) + + assert calls_1 == ["hook1_step0", "hook1_step1"] + assert calls_2 == [] # hook2 was unregistered + assert calls_3 == ["hook3_step0", "hook3_step1"] + + +def test_hook_execution_order_documentation(): + """Test and document that hooks are executed sequentially in registration order.""" + pipeline = RobotProcessor([MockStep("step")]) + + execution_order = [] + + def hook_a(idx: int, transition: EnvTransition): + execution_order.append("A") + return None + + def hook_b(idx: int, transition: EnvTransition): + execution_order.append("B") + return None + + def hook_c(idx: int, transition: EnvTransition): + execution_order.append("C") + return None + + # Register in specific order: A, B, C + pipeline.register_before_step_hook(hook_a) + pipeline.register_before_step_hook(hook_b) + pipeline.register_before_step_hook(hook_c) + + transition = create_transition() + pipeline(transition) + + # Verify execution order matches registration order + assert execution_order == ["A", "B", "C"] + + # Test that after unregistering B and re-registering it, it goes to the end + pipeline.unregister_before_step_hook(hook_b) + execution_order.clear() + + pipeline(transition) + assert execution_order == ["A", "C"] # B is gone + + # Re-register B - it should now be at the end + pipeline.register_before_step_hook(hook_b) + execution_order.clear() + + pipeline(transition) + assert execution_order == ["A", "C", "B"] # B is now last + + def test_profile_steps(): """Test step profiling functionality.""" step1 = MockStep("step1")