refactor(pipeline): Clarify hook behavior and improve documentation

- Updated the RobotProcessor class to ensure hooks are strictly for observation and do not modify transitions, enhancing clarity and maintainability.
- Updated the type annotations of the hook registration and unregistration methods to `Callable[[int, EnvTransition], None]`, reflecting that any value returned by a hook is now ignored.
- Enhanced documentation to clearly outline the purpose of hooks and their execution semantics.
- Added a test verifying that hooks are not executed by the step_through method but are still executed by the __call__ method, and removed the test that relied on hooks modifying transitions.
This commit is contained in:
Adil Zouitine
2025-07-22 10:41:22 +02:00
parent 77106697c3
commit 26cb9a24c3
2 changed files with 84 additions and 95 deletions
+49 -66
View File
@@ -265,10 +265,8 @@ class RobotProcessor(ModelHubMixin):
Hook Semantics:
- Hooks are executed sequentially in the order they were registered. There is no way to
reorder hooks after registration without creating a new pipeline.
- Hooks CAN modify transitions by returning a new transition dict. If a hook returns None,
the current transition remains unchanged. While this capability exists, it should be used
with EXTREME CAUTION as it can make debugging difficult and create unexpected side effects.
IT'S ADVISED TO NOT MODIFY THE TRANSITION IN A HOOK.
- Hooks are for observation/monitoring only and DO NOT modify transitions. They are called
with the step index and current transition for logging, debugging, or monitoring purposes.
- All hooks for a given type (before/after) are executed for every step, or none at all if
an error occurs. There is no partial execution of hooks.
- Hooks should generally be stateless to maintain predictable behavior. If you need stateful
@@ -287,15 +285,10 @@ class RobotProcessor(ModelHubMixin):
default_factory=lambda: _default_transition_to_batch, repr=False
)
# Processor-level hooks
# A hook can optionally return a modified transition. If it returns
# ``None`` the current value is left untouched.
before_step_hooks: list[Callable[[int, EnvTransition], EnvTransition | None]] = field(
default_factory=list, repr=False
)
after_step_hooks: list[Callable[[int, EnvTransition], EnvTransition | None]] = field(
default_factory=list, repr=False
)
# Processor-level hooks for observation/monitoring
# Hooks do not modify transitions - they are called for logging, debugging, or monitoring purposes
before_step_hooks: list[Callable[[int, EnvTransition], None]] = field(default_factory=list, repr=False)
after_step_hooks: list[Callable[[int, EnvTransition], None]] = field(default_factory=list, repr=False)
reset_hooks: list[Callable[[], None]] = field(default_factory=list, repr=False)
def __call__(self, data: EnvTransition | dict[str, Any]):
@@ -316,54 +309,34 @@ class RobotProcessor(ModelHubMixin):
Raises:
ValueError: If the transition is not a valid EnvTransition format.
"""
iterator = self.step_through(data)
current_result = next(iterator) # Get initial state
# Check if data is already an EnvTransition or needs conversion
if isinstance(data, dict) and not all(isinstance(k, TransitionKey) for k in data.keys()):
# It's a batch dict, convert it
called_with_batch = True
transition = self.to_transition(data)
else:
# It's already an EnvTransition
called_with_batch = False
transition = data
# Basic validation
if not isinstance(transition, dict):
raise ValueError(f"EnvTransition must be a dictionary. Got {type(transition).__name__}")
# Hook execution subtleties:
# - Hooks are executed sequentially in the order they were registered (list order)
# - Each hook sees the potentially modified transition from the previous hook
# - If a hook returns None, the transition remains unchanged
# - All hooks for a given type (before/after) run for every step, creating a
# multiplicative effect: N steps × M hooks = N×M hook executions
# - Hook execution cannot be interrupted - they all run or none run (on error)
for idx, processor_step in enumerate(self.steps):
# Process through all steps with hooks
for idx, step_result in enumerate(iterator):
# Apply before hooks
for hook in self.before_step_hooks:
updated = hook(idx, transition)
if updated is not None:
transition = updated
transition = processor_step(transition)
_ = hook(idx, step_result)
# Apply after hooks
for hook in self.after_step_hooks:
updated = hook(idx, transition)
if updated is not None:
transition = updated
_ = hook(idx, step_result)
return self.to_output(transition) if called_with_batch else transition
current_result = step_result
def step_through(self, data: EnvTransition | dict[str, Any]) -> Iterable[EnvTransition | dict[str, Any]]:
"""Yield the intermediate results after each processor step.
return current_result
Like __call__, this method accepts either EnvTransition dicts or batch dictionaries
and preserves the input format in the yielded results.
def _prepare_transition(self, data: EnvTransition | dict[str, Any]) -> tuple[EnvTransition, bool]:
"""Prepare and validate transition data for processing.
Args:
data: Either an EnvTransition dict or a batch dictionary to process.
Yields:
The intermediate results after each step, in the same format as the input.
Returns:
A tuple of (prepared_transition, called_with_batch_flag)
Raises:
ValueError: If the transition is not a valid EnvTransition format.
"""
# Check if data is already an EnvTransition or needs conversion
if isinstance(data, dict) and not all(isinstance(k, TransitionKey) for k in data.keys()):
@@ -379,22 +352,32 @@ class RobotProcessor(ModelHubMixin):
if not isinstance(transition, dict):
raise ValueError(f"EnvTransition must be a dictionary. Got {type(transition).__name__}")
return transition, called_with_batch
def step_through(self, data: EnvTransition | dict[str, Any]) -> Iterable[EnvTransition | dict[str, Any]]:
"""Yield the intermediate results after each processor step.
This is a low-level method that does NOT apply hooks. It simply executes each step
and yields the intermediate results. This allows users to debug the pipeline or
apply custom logic between steps if needed.
Like __call__, this method accepts either EnvTransition dicts or batch dictionaries
and preserves the input format in the yielded results.
Args:
data: Either an EnvTransition dict or a batch dictionary to process.
Yields:
The intermediate results after each step, in the same format as the input.
"""
transition, called_with_batch = self._prepare_transition(data)
# Yield initial state
yield self.to_output(transition) if called_with_batch else transition
for idx, processor_step in enumerate(self.steps):
for hook in self.before_step_hooks:
updated = hook(idx, transition)
if updated is not None:
transition = updated
# Process each step WITHOUT hooks (low-level method)
for processor_step in self.steps:
transition = processor_step(transition)
for hook in self.after_step_hooks:
updated = hook(idx, transition)
if updated is not None:
transition = updated
yield self.to_output(transition) if called_with_batch else transition
_CFG_NAME = "processor.json"
@@ -654,11 +637,11 @@ class RobotProcessor(ModelHubMixin):
return RobotProcessor(self.steps[idx], self.name, self.seed)
return self.steps[idx]
def register_before_step_hook(self, fn: Callable[[int, EnvTransition], EnvTransition | None]):
def register_before_step_hook(self, fn: Callable[[int, EnvTransition], None]):
"""Attach fn to be executed before every processor step."""
self.before_step_hooks.append(fn)
def unregister_before_step_hook(self, fn: Callable[[int, EnvTransition], EnvTransition | None]):
def unregister_before_step_hook(self, fn: Callable[[int, EnvTransition], None]):
"""Remove a previously registered before_step hook.
Args:
@@ -674,11 +657,11 @@ class RobotProcessor(ModelHubMixin):
f"Hook {fn} not found in before_step_hooks. Make sure to pass the exact same function reference."
) from None
def register_after_step_hook(self, fn: Callable[[int, EnvTransition], EnvTransition | None]):
def register_after_step_hook(self, fn: Callable[[int, EnvTransition], None]):
"""Attach fn to be executed after every processor step."""
self.after_step_hooks.append(fn)
def unregister_after_step_hook(self, fn: Callable[[int, EnvTransition], EnvTransition | None]):
def unregister_after_step_hook(self, fn: Callable[[int, EnvTransition], None]):
"""Remove a previously registered after_step hook.
Args:
+35 -29
View File
@@ -263,6 +263,40 @@ def test_step_through_with_dict():
# For now, just check that we get dict outputs
def test_step_through_no_hooks():
    """Check that step_through bypasses hooks while __call__ still fires them."""
    mock_step = MockStep("test_step")
    pipeline = RobotProcessor([mock_step])

    observed = []

    def record_call(idx: int, transition: EnvTransition):
        # Pure observer: remember which step index the hook was invoked for.
        observed.append(f"hook_called_step_{idx}")

    # The same observer is attached on both sides of every step.
    for register in (pipeline.register_before_step_hook, pipeline.register_after_step_hook):
        register(record_call)

    # Drive the pipeline through the low-level generator API.
    transition = create_transition()
    outputs = list(pipeline.step_through(transition))

    # Two entries are expected: the initial state plus the output of the single step.
    assert len(outputs) == 2
    complementary = outputs[1][TransitionKey.COMPLEMENTARY_DATA]
    assert complementary["test_step_counter"] == 0

    # step_through must not have triggered any hook.
    assert not observed

    # The regular call path, in contrast, must invoke the hooks.
    observed.clear()
    pipeline(transition)

    # One before-hook call and one after-hook call, both for step index 0.
    assert len(observed) == 2
    assert observed == ["hook_called_step_0"] * 2
def test_indexing():
"""Test pipeline indexing."""
step1 = MockStep("step1")
@@ -290,11 +324,9 @@ def test_hooks():
def before_hook(idx: int, transition: EnvTransition):
before_calls.append(idx)
return transition
def after_hook(idx: int, transition: EnvTransition):
after_calls.append(idx)
return transition
pipeline.register_before_step_hook(before_hook)
pipeline.register_after_step_hook(after_hook)
@@ -306,24 +338,6 @@ def test_hooks():
assert after_calls == [0]
def test_hook_modification():
"""Test that hooks can modify transitions."""
step = MockStep("test_step")
pipeline = RobotProcessor([step])
def modify_reward_hook(idx: int, transition: EnvTransition):
new_transition = transition.copy()
new_transition[TransitionKey.REWARD] = 42.0
return new_transition
pipeline.register_before_step_hook(modify_reward_hook)
transition = create_transition()
result = pipeline(transition)
assert result[TransitionKey.REWARD] == 42.0 # reward modified by hook
def test_reset():
"""Test pipeline reset functionality."""
step = MockStep("test_step")
@@ -360,7 +374,6 @@ def test_unregister_hooks():
def before_hook(idx: int, transition: EnvTransition):
before_calls.append(idx)
return None
pipeline.register_before_step_hook(before_hook)
@@ -380,7 +393,6 @@ def test_unregister_hooks():
def after_hook(idx: int, transition: EnvTransition):
after_calls.append(idx)
return None
pipeline.register_after_step_hook(after_hook)
pipeline(transition)
@@ -412,7 +424,7 @@ def test_unregister_nonexistent_hook():
pipeline = RobotProcessor([MockStep()])
def some_hook(idx: int, transition: EnvTransition):
return None
pass
def reset_hook():
pass
@@ -438,15 +450,12 @@ def test_multiple_hooks_and_selective_unregister():
def hook1(idx: int, transition: EnvTransition):
calls_1.append(f"hook1_step{idx}")
return None
def hook2(idx: int, transition: EnvTransition):
calls_2.append(f"hook2_step{idx}")
return None
def hook3(idx: int, transition: EnvTransition):
calls_3.append(f"hook3_step{idx}")
return None
# Register multiple hooks
pipeline.register_before_step_hook(hook1)
@@ -485,15 +494,12 @@ def test_hook_execution_order_documentation():
def hook_a(idx: int, transition: EnvTransition):
execution_order.append("A")
return None
def hook_b(idx: int, transition: EnvTransition):
execution_order.append("B")
return None
def hook_c(idx: int, transition: EnvTransition):
execution_order.append("C")
return None
# Register in specific order: A, B, C
pipeline.register_before_step_hook(hook_a)