refactor(pipeline): Clarify hook behavior and improve documentation

- Updated the RobotProcessor class to ensure hooks are strictly for observation and do not modify transitions, enhancing clarity and maintainability.
- Refactored hook registration methods to reflect the new behavior, ensuring they accept only functions that do not return modified transitions.
- Enhanced documentation to clearly outline the purpose of hooks and their execution semantics.
- Added tests to verify that hooks are not executed during the step_through method while ensuring they function correctly during the __call__ method.
This commit is contained in:
Adil Zouitine
2025-07-22 10:41:22 +02:00
parent 77106697c3
commit 26cb9a24c3
2 changed files with 84 additions and 95 deletions
+35 -29
View File
@@ -263,6 +263,40 @@ def test_step_through_with_dict():
# For now, just check that we get dict outputs
def test_step_through_no_hooks():
"""Test that step_through doesn't execute hooks."""
step = MockStep("test_step")
pipeline = RobotProcessor([step])
hook_calls = []
def tracking_hook(idx: int, transition: EnvTransition):
hook_calls.append(f"hook_called_step_{idx}")
# Register hooks
pipeline.register_before_step_hook(tracking_hook)
pipeline.register_after_step_hook(tracking_hook)
# Use step_through
transition = create_transition()
results = list(pipeline.step_through(transition))
# Verify step was executed (counter should increment)
assert len(results) == 2 # Initial + 1 step
assert results[1][TransitionKey.COMPLEMENTARY_DATA]["test_step_counter"] == 0
# Verify hooks were NOT called
assert len(hook_calls) == 0
# Now use __call__ to verify hooks ARE called there
hook_calls.clear()
pipeline(transition)
# Verify hooks were called (before and after for 1 step = 2 calls)
assert len(hook_calls) == 2
assert hook_calls == ["hook_called_step_0", "hook_called_step_0"]
def test_indexing():
"""Test pipeline indexing."""
step1 = MockStep("step1")
@@ -290,11 +324,9 @@ def test_hooks():
def before_hook(idx: int, transition: EnvTransition):
before_calls.append(idx)
return transition
def after_hook(idx: int, transition: EnvTransition):
after_calls.append(idx)
return transition
pipeline.register_before_step_hook(before_hook)
pipeline.register_after_step_hook(after_hook)
@@ -306,24 +338,6 @@ def test_hooks():
assert after_calls == [0]
def test_hook_modification():
"""Test that hooks can modify transitions."""
step = MockStep("test_step")
pipeline = RobotProcessor([step])
def modify_reward_hook(idx: int, transition: EnvTransition):
new_transition = transition.copy()
new_transition[TransitionKey.REWARD] = 42.0
return new_transition
pipeline.register_before_step_hook(modify_reward_hook)
transition = create_transition()
result = pipeline(transition)
assert result[TransitionKey.REWARD] == 42.0 # reward modified by hook
def test_reset():
"""Test pipeline reset functionality."""
step = MockStep("test_step")
@@ -360,7 +374,6 @@ def test_unregister_hooks():
def before_hook(idx: int, transition: EnvTransition):
before_calls.append(idx)
return None
pipeline.register_before_step_hook(before_hook)
@@ -380,7 +393,6 @@ def test_unregister_hooks():
def after_hook(idx: int, transition: EnvTransition):
after_calls.append(idx)
return None
pipeline.register_after_step_hook(after_hook)
pipeline(transition)
@@ -412,7 +424,7 @@ def test_unregister_nonexistent_hook():
pipeline = RobotProcessor([MockStep()])
def some_hook(idx: int, transition: EnvTransition):
return None
pass
def reset_hook():
pass
@@ -438,15 +450,12 @@ def test_multiple_hooks_and_selective_unregister():
def hook1(idx: int, transition: EnvTransition):
calls_1.append(f"hook1_step{idx}")
return None
def hook2(idx: int, transition: EnvTransition):
calls_2.append(f"hook2_step{idx}")
return None
def hook3(idx: int, transition: EnvTransition):
calls_3.append(f"hook3_step{idx}")
return None
# Register multiple hooks
pipeline.register_before_step_hook(hook1)
@@ -485,15 +494,12 @@ def test_hook_execution_order_documentation():
def hook_a(idx: int, transition: EnvTransition):
execution_order.append("A")
return None
def hook_b(idx: int, transition: EnvTransition):
execution_order.append("B")
return None
def hook_c(idx: int, transition: EnvTransition):
execution_order.append("C")
return None
# Register in specific order: A, B, C
pipeline.register_before_step_hook(hook_a)