feat(pipeline): Add hook unregistration functionality and enhance documentation

- Implemented methods to unregister before, after, and reset hooks in the RobotProcessor class, allowing for more flexible hook management.
- Enhanced documentation to clarify hook execution semantics and the implications of modifying transitions within hooks.
- Added comprehensive tests to verify the correct behavior of hook registration and unregistration, including error handling for non-existent hooks.
This commit is contained in:
Adil Zouitine
2025-07-21 18:56:01 +02:00
parent 75bc44c166
commit 77106697c3
2 changed files with 239 additions and 0 deletions
+68
View File
@@ -261,6 +261,19 @@ class RobotProcessor(ModelHubMixin):
after_step_hooks: List of hooks called after each step. Each hook receives the step after_step_hooks: List of hooks called after each step. Each hook receives the step
index and transition, and can optionally return a modified transition. index and transition, and can optionally return a modified transition.
reset_hooks: List of hooks called during processor reset. reset_hooks: List of hooks called during processor reset.
Hook Semantics:
- Hooks are executed sequentially in the order they were registered. There is no way to
reorder hooks after registration without creating a new pipeline.
- Hooks CAN modify transitions by returning a new transition dict. If a hook returns None,
the current transition remains unchanged. While this capability exists, it should be used
with EXTREME CAUTION as it can make debugging difficult and create unexpected side effects.
IT'S ADVISED TO NOT MODIFY THE TRANSITION IN A HOOK.
- All hooks for a given type (before/after) are executed for every step, or none at all if
an error occurs. There is no partial execution of hooks.
- Hooks should generally be stateless to maintain predictable behavior. If you need stateful
processing, consider implementing a proper ProcessorStep instead.
- To remove hooks, use the unregister methods. To remove steps, you must create a new pipeline.
""" """
steps: Sequence[ProcessorStep] = field(default_factory=list) steps: Sequence[ProcessorStep] = field(default_factory=list)
@@ -318,6 +331,13 @@ class RobotProcessor(ModelHubMixin):
if not isinstance(transition, dict): if not isinstance(transition, dict):
raise ValueError(f"EnvTransition must be a dictionary. Got {type(transition).__name__}") raise ValueError(f"EnvTransition must be a dictionary. Got {type(transition).__name__}")
# Hook execution subtleties:
# - Hooks are executed sequentially in the order they were registered (list order)
# - Each hook sees the potentially modified transition from the previous hook
# - If a hook returns None, the transition remains unchanged
# - All hooks for a given type (before/after) run for every step, creating a
# multiplicative effect: N steps × M hooks = N×M hook executions
# - Hook execution cannot be interrupted - they all run or none run (on error)
for idx, processor_step in enumerate(self.steps): for idx, processor_step in enumerate(self.steps):
for hook in self.before_step_hooks: for hook in self.before_step_hooks:
updated = hook(idx, transition) updated = hook(idx, transition)
@@ -638,14 +658,62 @@ class RobotProcessor(ModelHubMixin):
"""Attach fn to be executed before every processor step.""" """Attach fn to be executed before every processor step."""
self.before_step_hooks.append(fn) self.before_step_hooks.append(fn)
def unregister_before_step_hook(self, fn: Callable[[int, EnvTransition], EnvTransition | None]):
"""Remove a previously registered before_step hook.
Args:
fn: The exact function reference that was registered. Must be the same object.
Raises:
ValueError: If the hook is not found in the registered hooks.
"""
try:
self.before_step_hooks.remove(fn)
except ValueError:
raise ValueError(
f"Hook {fn} not found in before_step_hooks. Make sure to pass the exact same function reference."
) from None
def register_after_step_hook(self, fn: Callable[[int, EnvTransition], EnvTransition | None]): def register_after_step_hook(self, fn: Callable[[int, EnvTransition], EnvTransition | None]):
"""Attach fn to be executed after every processor step.""" """Attach fn to be executed after every processor step."""
self.after_step_hooks.append(fn) self.after_step_hooks.append(fn)
def unregister_after_step_hook(self, fn: Callable[[int, EnvTransition], EnvTransition | None]):
"""Remove a previously registered after_step hook.
Args:
fn: The exact function reference that was registered. Must be the same object.
Raises:
ValueError: If the hook is not found in the registered hooks.
"""
try:
self.after_step_hooks.remove(fn)
except ValueError:
raise ValueError(
f"Hook {fn} not found in after_step_hooks. Make sure to pass the exact same function reference."
) from None
def register_reset_hook(self, fn: Callable[[], None]): def register_reset_hook(self, fn: Callable[[], None]):
"""Attach fn to be executed when reset is called.""" """Attach fn to be executed when reset is called."""
self.reset_hooks.append(fn) self.reset_hooks.append(fn)
def unregister_reset_hook(self, fn: Callable[[], None]):
"""Remove a previously registered reset hook.
Args:
fn: The exact function reference that was registered. Must be the same object.
Raises:
ValueError: If the hook is not found in the registered hooks.
"""
try:
self.reset_hooks.remove(fn)
except ValueError:
raise ValueError(
f"Hook {fn} not found in reset_hooks. Make sure to pass the exact same function reference."
) from None
def reset(self): def reset(self):
"""Clear state in every step that implements ``reset()`` and fire registered hooks.""" """Clear state in every step that implements ``reset()`` and fire registered hooks."""
for step in self.steps: for step in self.steps:
+171
View File
@@ -350,6 +350,177 @@ def test_reset():
assert len(reset_called) == 1 assert len(reset_called) == 1
def test_unregister_hooks():
"""Test unregistering hooks from the pipeline."""
step = MockStep("test_step")
pipeline = RobotProcessor([step])
# Test before_step_hook
before_calls = []
def before_hook(idx: int, transition: EnvTransition):
before_calls.append(idx)
return None
pipeline.register_before_step_hook(before_hook)
# Verify hook is registered
transition = create_transition()
pipeline(transition)
assert len(before_calls) == 1
# Unregister and verify it's no longer called
pipeline.unregister_before_step_hook(before_hook)
before_calls.clear()
pipeline(transition)
assert len(before_calls) == 0
# Test after_step_hook
after_calls = []
def after_hook(idx: int, transition: EnvTransition):
after_calls.append(idx)
return None
pipeline.register_after_step_hook(after_hook)
pipeline(transition)
assert len(after_calls) == 1
pipeline.unregister_after_step_hook(after_hook)
after_calls.clear()
pipeline(transition)
assert len(after_calls) == 0
# Test reset_hook
reset_calls = []
def reset_hook():
reset_calls.append(True)
pipeline.register_reset_hook(reset_hook)
pipeline.reset()
assert len(reset_calls) == 1
pipeline.unregister_reset_hook(reset_hook)
reset_calls.clear()
pipeline.reset()
assert len(reset_calls) == 0
def test_unregister_nonexistent_hook():
"""Test error handling when unregistering hooks that don't exist."""
pipeline = RobotProcessor([MockStep()])
def some_hook(idx: int, transition: EnvTransition):
return None
def reset_hook():
pass
# Test unregistering hooks that were never registered
with pytest.raises(ValueError, match="not found in before_step_hooks"):
pipeline.unregister_before_step_hook(some_hook)
with pytest.raises(ValueError, match="not found in after_step_hooks"):
pipeline.unregister_after_step_hook(some_hook)
with pytest.raises(ValueError, match="not found in reset_hooks"):
pipeline.unregister_reset_hook(reset_hook)
def test_multiple_hooks_and_selective_unregister():
"""Test registering multiple hooks and selectively unregistering them."""
pipeline = RobotProcessor([MockStep("step1"), MockStep("step2")])
calls_1 = []
calls_2 = []
calls_3 = []
def hook1(idx: int, transition: EnvTransition):
calls_1.append(f"hook1_step{idx}")
return None
def hook2(idx: int, transition: EnvTransition):
calls_2.append(f"hook2_step{idx}")
return None
def hook3(idx: int, transition: EnvTransition):
calls_3.append(f"hook3_step{idx}")
return None
# Register multiple hooks
pipeline.register_before_step_hook(hook1)
pipeline.register_before_step_hook(hook2)
pipeline.register_before_step_hook(hook3)
# Run pipeline - all hooks should be called for both steps
transition = create_transition()
pipeline(transition)
assert calls_1 == ["hook1_step0", "hook1_step1"]
assert calls_2 == ["hook2_step0", "hook2_step1"]
assert calls_3 == ["hook3_step0", "hook3_step1"]
# Clear calls
calls_1.clear()
calls_2.clear()
calls_3.clear()
# Unregister middle hook
pipeline.unregister_before_step_hook(hook2)
# Run again - only hook1 and hook3 should be called
pipeline(transition)
assert calls_1 == ["hook1_step0", "hook1_step1"]
assert calls_2 == [] # hook2 was unregistered
assert calls_3 == ["hook3_step0", "hook3_step1"]
def test_hook_execution_order_documentation():
"""Test and document that hooks are executed sequentially in registration order."""
pipeline = RobotProcessor([MockStep("step")])
execution_order = []
def hook_a(idx: int, transition: EnvTransition):
execution_order.append("A")
return None
def hook_b(idx: int, transition: EnvTransition):
execution_order.append("B")
return None
def hook_c(idx: int, transition: EnvTransition):
execution_order.append("C")
return None
# Register in specific order: A, B, C
pipeline.register_before_step_hook(hook_a)
pipeline.register_before_step_hook(hook_b)
pipeline.register_before_step_hook(hook_c)
transition = create_transition()
pipeline(transition)
# Verify execution order matches registration order
assert execution_order == ["A", "B", "C"]
# Test that after unregistering B and re-registering it, it goes to the end
pipeline.unregister_before_step_hook(hook_b)
execution_order.clear()
pipeline(transition)
assert execution_order == ["A", "C"] # B is gone
# Re-register B - it should now be at the end
pipeline.register_before_step_hook(hook_b)
execution_order.clear()
pipeline(transition)
assert execution_order == ["A", "C", "B"] # B is now last
def test_profile_steps(): def test_profile_steps():
"""Test step profiling functionality.""" """Test step profiling functionality."""
step1 = MockStep("step1") step1 = MockStep("step1")