mirror of
https://github.com/huggingface/lerobot.git
synced 2026-05-23 04:30:10 +00:00
feat(pipeline): Add hook unregistration functionality and enhance documentation
- Implemented methods to unregister before, after, and reset hooks in the RobotProcessor class, allowing for more flexible hook management. - Enhanced documentation to clarify hook execution semantics and the implications of modifying transitions within hooks. - Added comprehensive tests to verify the correct behavior of hook registration and unregistration, including error handling for non-existent hooks.
This commit is contained in:
@@ -261,6 +261,19 @@ class RobotProcessor(ModelHubMixin):
|
|||||||
after_step_hooks: List of hooks called after each step. Each hook receives the step
|
after_step_hooks: List of hooks called after each step. Each hook receives the step
|
||||||
index and transition, and can optionally return a modified transition.
|
index and transition, and can optionally return a modified transition.
|
||||||
reset_hooks: List of hooks called during processor reset.
|
reset_hooks: List of hooks called during processor reset.
|
||||||
|
|
||||||
|
Hook Semantics:
|
||||||
|
- Hooks are executed sequentially in the order they were registered. There is no way to
|
||||||
|
reorder hooks after registration without creating a new pipeline.
|
||||||
|
- Hooks CAN modify transitions by returning a new transition dict. If a hook returns None,
|
||||||
|
the current transition remains unchanged. While this capability exists, it should be used
|
||||||
|
with EXTREME CAUTION as it can make debugging difficult and create unexpected side effects.
|
||||||
|
IT'S ADVISED TO NOT MODIFY THE TRANSITION IN A HOOK.
|
||||||
|
- All hooks for a given type (before/after) are executed for every step, or none at all if
|
||||||
|
an error occurs. There is no partial execution of hooks.
|
||||||
|
- Hooks should generally be stateless to maintain predictable behavior. If you need stateful
|
||||||
|
processing, consider implementing a proper ProcessorStep instead.
|
||||||
|
- To remove hooks, use the unregister methods. To remove steps, you must create a new pipeline.
|
||||||
"""
|
"""
|
||||||
|
|
||||||
steps: Sequence[ProcessorStep] = field(default_factory=list)
|
steps: Sequence[ProcessorStep] = field(default_factory=list)
|
||||||
@@ -318,6 +331,13 @@ class RobotProcessor(ModelHubMixin):
|
|||||||
if not isinstance(transition, dict):
|
if not isinstance(transition, dict):
|
||||||
raise ValueError(f"EnvTransition must be a dictionary. Got {type(transition).__name__}")
|
raise ValueError(f"EnvTransition must be a dictionary. Got {type(transition).__name__}")
|
||||||
|
|
||||||
|
# Hook execution subtleties:
|
||||||
|
# - Hooks are executed sequentially in the order they were registered (list order)
|
||||||
|
# - Each hook sees the potentially modified transition from the previous hook
|
||||||
|
# - If a hook returns None, the transition remains unchanged
|
||||||
|
# - All hooks for a given type (before/after) run for every step, creating a
|
||||||
|
# multiplicative effect: N steps × M hooks = N×M hook executions
|
||||||
|
# - Hook execution cannot be interrupted - they all run or none run (on error)
|
||||||
for idx, processor_step in enumerate(self.steps):
|
for idx, processor_step in enumerate(self.steps):
|
||||||
for hook in self.before_step_hooks:
|
for hook in self.before_step_hooks:
|
||||||
updated = hook(idx, transition)
|
updated = hook(idx, transition)
|
||||||
@@ -638,14 +658,62 @@ class RobotProcessor(ModelHubMixin):
|
|||||||
"""Attach fn to be executed before every processor step."""
|
"""Attach fn to be executed before every processor step."""
|
||||||
self.before_step_hooks.append(fn)
|
self.before_step_hooks.append(fn)
|
||||||
|
|
||||||
|
def unregister_before_step_hook(self, fn: Callable[[int, EnvTransition], EnvTransition | None]):
|
||||||
|
"""Remove a previously registered before_step hook.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
fn: The exact function reference that was registered. Must be the same object.
|
||||||
|
|
||||||
|
Raises:
|
||||||
|
ValueError: If the hook is not found in the registered hooks.
|
||||||
|
"""
|
||||||
|
try:
|
||||||
|
self.before_step_hooks.remove(fn)
|
||||||
|
except ValueError:
|
||||||
|
raise ValueError(
|
||||||
|
f"Hook {fn} not found in before_step_hooks. Make sure to pass the exact same function reference."
|
||||||
|
) from None
|
||||||
|
|
||||||
def register_after_step_hook(self, fn: Callable[[int, EnvTransition], EnvTransition | None]):
|
def register_after_step_hook(self, fn: Callable[[int, EnvTransition], EnvTransition | None]):
|
||||||
"""Attach fn to be executed after every processor step."""
|
"""Attach fn to be executed after every processor step."""
|
||||||
self.after_step_hooks.append(fn)
|
self.after_step_hooks.append(fn)
|
||||||
|
|
||||||
|
def unregister_after_step_hook(self, fn: Callable[[int, EnvTransition], EnvTransition | None]):
|
||||||
|
"""Remove a previously registered after_step hook.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
fn: The exact function reference that was registered. Must be the same object.
|
||||||
|
|
||||||
|
Raises:
|
||||||
|
ValueError: If the hook is not found in the registered hooks.
|
||||||
|
"""
|
||||||
|
try:
|
||||||
|
self.after_step_hooks.remove(fn)
|
||||||
|
except ValueError:
|
||||||
|
raise ValueError(
|
||||||
|
f"Hook {fn} not found in after_step_hooks. Make sure to pass the exact same function reference."
|
||||||
|
) from None
|
||||||
|
|
||||||
def register_reset_hook(self, fn: Callable[[], None]):
|
def register_reset_hook(self, fn: Callable[[], None]):
|
||||||
"""Attach fn to be executed when reset is called."""
|
"""Attach fn to be executed when reset is called."""
|
||||||
self.reset_hooks.append(fn)
|
self.reset_hooks.append(fn)
|
||||||
|
|
||||||
|
def unregister_reset_hook(self, fn: Callable[[], None]):
|
||||||
|
"""Remove a previously registered reset hook.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
fn: The exact function reference that was registered. Must be the same object.
|
||||||
|
|
||||||
|
Raises:
|
||||||
|
ValueError: If the hook is not found in the registered hooks.
|
||||||
|
"""
|
||||||
|
try:
|
||||||
|
self.reset_hooks.remove(fn)
|
||||||
|
except ValueError:
|
||||||
|
raise ValueError(
|
||||||
|
f"Hook {fn} not found in reset_hooks. Make sure to pass the exact same function reference."
|
||||||
|
) from None
|
||||||
|
|
||||||
def reset(self):
|
def reset(self):
|
||||||
"""Clear state in every step that implements ``reset()`` and fire registered hooks."""
|
"""Clear state in every step that implements ``reset()`` and fire registered hooks."""
|
||||||
for step in self.steps:
|
for step in self.steps:
|
||||||
|
|||||||
@@ -350,6 +350,177 @@ def test_reset():
|
|||||||
assert len(reset_called) == 1
|
assert len(reset_called) == 1
|
||||||
|
|
||||||
|
|
||||||
|
def test_unregister_hooks():
|
||||||
|
"""Test unregistering hooks from the pipeline."""
|
||||||
|
step = MockStep("test_step")
|
||||||
|
pipeline = RobotProcessor([step])
|
||||||
|
|
||||||
|
# Test before_step_hook
|
||||||
|
before_calls = []
|
||||||
|
|
||||||
|
def before_hook(idx: int, transition: EnvTransition):
|
||||||
|
before_calls.append(idx)
|
||||||
|
return None
|
||||||
|
|
||||||
|
pipeline.register_before_step_hook(before_hook)
|
||||||
|
|
||||||
|
# Verify hook is registered
|
||||||
|
transition = create_transition()
|
||||||
|
pipeline(transition)
|
||||||
|
assert len(before_calls) == 1
|
||||||
|
|
||||||
|
# Unregister and verify it's no longer called
|
||||||
|
pipeline.unregister_before_step_hook(before_hook)
|
||||||
|
before_calls.clear()
|
||||||
|
pipeline(transition)
|
||||||
|
assert len(before_calls) == 0
|
||||||
|
|
||||||
|
# Test after_step_hook
|
||||||
|
after_calls = []
|
||||||
|
|
||||||
|
def after_hook(idx: int, transition: EnvTransition):
|
||||||
|
after_calls.append(idx)
|
||||||
|
return None
|
||||||
|
|
||||||
|
pipeline.register_after_step_hook(after_hook)
|
||||||
|
pipeline(transition)
|
||||||
|
assert len(after_calls) == 1
|
||||||
|
|
||||||
|
pipeline.unregister_after_step_hook(after_hook)
|
||||||
|
after_calls.clear()
|
||||||
|
pipeline(transition)
|
||||||
|
assert len(after_calls) == 0
|
||||||
|
|
||||||
|
# Test reset_hook
|
||||||
|
reset_calls = []
|
||||||
|
|
||||||
|
def reset_hook():
|
||||||
|
reset_calls.append(True)
|
||||||
|
|
||||||
|
pipeline.register_reset_hook(reset_hook)
|
||||||
|
pipeline.reset()
|
||||||
|
assert len(reset_calls) == 1
|
||||||
|
|
||||||
|
pipeline.unregister_reset_hook(reset_hook)
|
||||||
|
reset_calls.clear()
|
||||||
|
pipeline.reset()
|
||||||
|
assert len(reset_calls) == 0
|
||||||
|
|
||||||
|
|
||||||
|
def test_unregister_nonexistent_hook():
|
||||||
|
"""Test error handling when unregistering hooks that don't exist."""
|
||||||
|
pipeline = RobotProcessor([MockStep()])
|
||||||
|
|
||||||
|
def some_hook(idx: int, transition: EnvTransition):
|
||||||
|
return None
|
||||||
|
|
||||||
|
def reset_hook():
|
||||||
|
pass
|
||||||
|
|
||||||
|
# Test unregistering hooks that were never registered
|
||||||
|
with pytest.raises(ValueError, match="not found in before_step_hooks"):
|
||||||
|
pipeline.unregister_before_step_hook(some_hook)
|
||||||
|
|
||||||
|
with pytest.raises(ValueError, match="not found in after_step_hooks"):
|
||||||
|
pipeline.unregister_after_step_hook(some_hook)
|
||||||
|
|
||||||
|
with pytest.raises(ValueError, match="not found in reset_hooks"):
|
||||||
|
pipeline.unregister_reset_hook(reset_hook)
|
||||||
|
|
||||||
|
|
||||||
|
def test_multiple_hooks_and_selective_unregister():
|
||||||
|
"""Test registering multiple hooks and selectively unregistering them."""
|
||||||
|
pipeline = RobotProcessor([MockStep("step1"), MockStep("step2")])
|
||||||
|
|
||||||
|
calls_1 = []
|
||||||
|
calls_2 = []
|
||||||
|
calls_3 = []
|
||||||
|
|
||||||
|
def hook1(idx: int, transition: EnvTransition):
|
||||||
|
calls_1.append(f"hook1_step{idx}")
|
||||||
|
return None
|
||||||
|
|
||||||
|
def hook2(idx: int, transition: EnvTransition):
|
||||||
|
calls_2.append(f"hook2_step{idx}")
|
||||||
|
return None
|
||||||
|
|
||||||
|
def hook3(idx: int, transition: EnvTransition):
|
||||||
|
calls_3.append(f"hook3_step{idx}")
|
||||||
|
return None
|
||||||
|
|
||||||
|
# Register multiple hooks
|
||||||
|
pipeline.register_before_step_hook(hook1)
|
||||||
|
pipeline.register_before_step_hook(hook2)
|
||||||
|
pipeline.register_before_step_hook(hook3)
|
||||||
|
|
||||||
|
# Run pipeline - all hooks should be called for both steps
|
||||||
|
transition = create_transition()
|
||||||
|
pipeline(transition)
|
||||||
|
|
||||||
|
assert calls_1 == ["hook1_step0", "hook1_step1"]
|
||||||
|
assert calls_2 == ["hook2_step0", "hook2_step1"]
|
||||||
|
assert calls_3 == ["hook3_step0", "hook3_step1"]
|
||||||
|
|
||||||
|
# Clear calls
|
||||||
|
calls_1.clear()
|
||||||
|
calls_2.clear()
|
||||||
|
calls_3.clear()
|
||||||
|
|
||||||
|
# Unregister middle hook
|
||||||
|
pipeline.unregister_before_step_hook(hook2)
|
||||||
|
|
||||||
|
# Run again - only hook1 and hook3 should be called
|
||||||
|
pipeline(transition)
|
||||||
|
|
||||||
|
assert calls_1 == ["hook1_step0", "hook1_step1"]
|
||||||
|
assert calls_2 == [] # hook2 was unregistered
|
||||||
|
assert calls_3 == ["hook3_step0", "hook3_step1"]
|
||||||
|
|
||||||
|
|
||||||
|
def test_hook_execution_order_documentation():
|
||||||
|
"""Test and document that hooks are executed sequentially in registration order."""
|
||||||
|
pipeline = RobotProcessor([MockStep("step")])
|
||||||
|
|
||||||
|
execution_order = []
|
||||||
|
|
||||||
|
def hook_a(idx: int, transition: EnvTransition):
|
||||||
|
execution_order.append("A")
|
||||||
|
return None
|
||||||
|
|
||||||
|
def hook_b(idx: int, transition: EnvTransition):
|
||||||
|
execution_order.append("B")
|
||||||
|
return None
|
||||||
|
|
||||||
|
def hook_c(idx: int, transition: EnvTransition):
|
||||||
|
execution_order.append("C")
|
||||||
|
return None
|
||||||
|
|
||||||
|
# Register in specific order: A, B, C
|
||||||
|
pipeline.register_before_step_hook(hook_a)
|
||||||
|
pipeline.register_before_step_hook(hook_b)
|
||||||
|
pipeline.register_before_step_hook(hook_c)
|
||||||
|
|
||||||
|
transition = create_transition()
|
||||||
|
pipeline(transition)
|
||||||
|
|
||||||
|
# Verify execution order matches registration order
|
||||||
|
assert execution_order == ["A", "B", "C"]
|
||||||
|
|
||||||
|
# Test that after unregistering B and re-registering it, it goes to the end
|
||||||
|
pipeline.unregister_before_step_hook(hook_b)
|
||||||
|
execution_order.clear()
|
||||||
|
|
||||||
|
pipeline(transition)
|
||||||
|
assert execution_order == ["A", "C"] # B is gone
|
||||||
|
|
||||||
|
# Re-register B - it should now be at the end
|
||||||
|
pipeline.register_before_step_hook(hook_b)
|
||||||
|
execution_order.clear()
|
||||||
|
|
||||||
|
pipeline(transition)
|
||||||
|
assert execution_order == ["A", "C", "B"] # B is now last
|
||||||
|
|
||||||
|
|
||||||
def test_profile_steps():
|
def test_profile_steps():
|
||||||
"""Test step profiling functionality."""
|
"""Test step profiling functionality."""
|
||||||
step1 = MockStep("step1")
|
step1 = MockStep("step1")
|
||||||
|
|||||||
Reference in New Issue
Block a user