refactor(pipeline): Clarify hook behavior and improve documentation

- Updated the RobotProcessor class to ensure hooks are strictly for observation and do not modify transitions, enhancing clarity and maintainability.
- Refactored hook registration methods to reflect the new behavior: hooks are now typed as returning `None`, and any value a hook returns is discarded rather than replacing the transition.
- Enhanced documentation to clearly outline the purpose of hooks and their execution semantics.
- Added a test verifying that hooks are not executed by the `step_through` method, while confirming they are still executed by the `__call__` method.
This commit is contained in:
Adil Zouitine
2025-07-22 10:41:22 +02:00
parent 77106697c3
commit 26cb9a24c3
2 changed files with 84 additions and 95 deletions
+49 -66
View File
@@ -265,10 +265,8 @@ class RobotProcessor(ModelHubMixin):
Hook Semantics: Hook Semantics:
- Hooks are executed sequentially in the order they were registered. There is no way to - Hooks are executed sequentially in the order they were registered. There is no way to
reorder hooks after registration without creating a new pipeline. reorder hooks after registration without creating a new pipeline.
- Hooks CAN modify transitions by returning a new transition dict. If a hook returns None, - Hooks are for observation/monitoring only and DO NOT modify transitions. They are called
the current transition remains unchanged. While this capability exists, it should be used with the step index and current transition for logging, debugging, or monitoring purposes.
with EXTREME CAUTION as it can make debugging difficult and create unexpected side effects.
IT'S ADVISED TO NOT MODIFY THE TRANSITION IN A HOOK.
- All hooks for a given type (before/after) are executed for every step, or none at all if - All hooks for a given type (before/after) are executed for every step, or none at all if
an error occurs. There is no partial execution of hooks. an error occurs. There is no partial execution of hooks.
- Hooks should generally be stateless to maintain predictable behavior. If you need stateful - Hooks should generally be stateless to maintain predictable behavior. If you need stateful
@@ -287,15 +285,10 @@ class RobotProcessor(ModelHubMixin):
default_factory=lambda: _default_transition_to_batch, repr=False default_factory=lambda: _default_transition_to_batch, repr=False
) )
# Processor-level hooks # Processor-level hooks for observation/monitoring
# A hook can optionally return a modified transition. If it returns # Hooks do not modify transitions - they are called for logging, debugging, or monitoring purposes
# ``None`` the current value is left untouched. before_step_hooks: list[Callable[[int, EnvTransition], None]] = field(default_factory=list, repr=False)
before_step_hooks: list[Callable[[int, EnvTransition], EnvTransition | None]] = field( after_step_hooks: list[Callable[[int, EnvTransition], None]] = field(default_factory=list, repr=False)
default_factory=list, repr=False
)
after_step_hooks: list[Callable[[int, EnvTransition], EnvTransition | None]] = field(
default_factory=list, repr=False
)
reset_hooks: list[Callable[[], None]] = field(default_factory=list, repr=False) reset_hooks: list[Callable[[], None]] = field(default_factory=list, repr=False)
def __call__(self, data: EnvTransition | dict[str, Any]): def __call__(self, data: EnvTransition | dict[str, Any]):
@@ -316,54 +309,34 @@ class RobotProcessor(ModelHubMixin):
Raises: Raises:
ValueError: If the transition is not a valid EnvTransition format. ValueError: If the transition is not a valid EnvTransition format.
""" """
iterator = self.step_through(data)
current_result = next(iterator) # Get initial state
# Check if data is already an EnvTransition or needs conversion # Process through all steps with hooks
if isinstance(data, dict) and not all(isinstance(k, TransitionKey) for k in data.keys()): for idx, step_result in enumerate(iterator):
# It's a batch dict, convert it # Apply before hooks
called_with_batch = True
transition = self.to_transition(data)
else:
# It's already an EnvTransition
called_with_batch = False
transition = data
# Basic validation
if not isinstance(transition, dict):
raise ValueError(f"EnvTransition must be a dictionary. Got {type(transition).__name__}")
# Hook execution subtleties:
# - Hooks are executed sequentially in the order they were registered (list order)
# - Each hook sees the potentially modified transition from the previous hook
# - If a hook returns None, the transition remains unchanged
# - All hooks for a given type (before/after) run for every step, creating a
# multiplicative effect: N steps × M hooks = N×M hook executions
# - Hook execution cannot be interrupted - they all run or none run (on error)
for idx, processor_step in enumerate(self.steps):
for hook in self.before_step_hooks: for hook in self.before_step_hooks:
updated = hook(idx, transition) _ = hook(idx, step_result)
if updated is not None:
transition = updated
transition = processor_step(transition)
# Apply after hooks
for hook in self.after_step_hooks: for hook in self.after_step_hooks:
updated = hook(idx, transition) _ = hook(idx, step_result)
if updated is not None:
transition = updated
return self.to_output(transition) if called_with_batch else transition current_result = step_result
def step_through(self, data: EnvTransition | dict[str, Any]) -> Iterable[EnvTransition | dict[str, Any]]: return current_result
"""Yield the intermediate results after each processor step.
Like __call__, this method accepts either EnvTransition dicts or batch dictionaries def _prepare_transition(self, data: EnvTransition | dict[str, Any]) -> tuple[EnvTransition, bool]:
and preserves the input format in the yielded results. """Prepare and validate transition data for processing.
Args: Args:
data: Either an EnvTransition dict or a batch dictionary to process. data: Either an EnvTransition dict or a batch dictionary to process.
Yields: Returns:
The intermediate results after each step, in the same format as the input. A tuple of (prepared_transition, called_with_batch_flag)
Raises:
ValueError: If the transition is not a valid EnvTransition format.
""" """
# Check if data is already an EnvTransition or needs conversion # Check if data is already an EnvTransition or needs conversion
if isinstance(data, dict) and not all(isinstance(k, TransitionKey) for k in data.keys()): if isinstance(data, dict) and not all(isinstance(k, TransitionKey) for k in data.keys()):
@@ -379,22 +352,32 @@ class RobotProcessor(ModelHubMixin):
if not isinstance(transition, dict): if not isinstance(transition, dict):
raise ValueError(f"EnvTransition must be a dictionary. Got {type(transition).__name__}") raise ValueError(f"EnvTransition must be a dictionary. Got {type(transition).__name__}")
return transition, called_with_batch
def step_through(self, data: EnvTransition | dict[str, Any]) -> Iterable[EnvTransition | dict[str, Any]]:
"""Yield the intermediate results after each processor step.
This is a low-level method that does NOT apply hooks. It simply executes each step
and yields the intermediate results. This allows users to debug the pipeline or
apply custom logic between steps if needed.
Like __call__, this method accepts either EnvTransition dicts or batch dictionaries
and preserves the input format in the yielded results.
Args:
data: Either an EnvTransition dict or a batch dictionary to process.
Yields:
The intermediate results after each step, in the same format as the input.
"""
transition, called_with_batch = self._prepare_transition(data)
# Yield initial state # Yield initial state
yield self.to_output(transition) if called_with_batch else transition yield self.to_output(transition) if called_with_batch else transition
for idx, processor_step in enumerate(self.steps): # Process each step WITHOUT hooks (low-level method)
for hook in self.before_step_hooks: for processor_step in self.steps:
updated = hook(idx, transition)
if updated is not None:
transition = updated
transition = processor_step(transition) transition = processor_step(transition)
for hook in self.after_step_hooks:
updated = hook(idx, transition)
if updated is not None:
transition = updated
yield self.to_output(transition) if called_with_batch else transition yield self.to_output(transition) if called_with_batch else transition
_CFG_NAME = "processor.json" _CFG_NAME = "processor.json"
@@ -654,11 +637,11 @@ class RobotProcessor(ModelHubMixin):
return RobotProcessor(self.steps[idx], self.name, self.seed) return RobotProcessor(self.steps[idx], self.name, self.seed)
return self.steps[idx] return self.steps[idx]
def register_before_step_hook(self, fn: Callable[[int, EnvTransition], EnvTransition | None]): def register_before_step_hook(self, fn: Callable[[int, EnvTransition], None]):
"""Attach fn to be executed before every processor step.""" """Attach fn to be executed before every processor step."""
self.before_step_hooks.append(fn) self.before_step_hooks.append(fn)
def unregister_before_step_hook(self, fn: Callable[[int, EnvTransition], EnvTransition | None]): def unregister_before_step_hook(self, fn: Callable[[int, EnvTransition], None]):
"""Remove a previously registered before_step hook. """Remove a previously registered before_step hook.
Args: Args:
@@ -674,11 +657,11 @@ class RobotProcessor(ModelHubMixin):
f"Hook {fn} not found in before_step_hooks. Make sure to pass the exact same function reference." f"Hook {fn} not found in before_step_hooks. Make sure to pass the exact same function reference."
) from None ) from None
def register_after_step_hook(self, fn: Callable[[int, EnvTransition], EnvTransition | None]): def register_after_step_hook(self, fn: Callable[[int, EnvTransition], None]):
"""Attach fn to be executed after every processor step.""" """Attach fn to be executed after every processor step."""
self.after_step_hooks.append(fn) self.after_step_hooks.append(fn)
def unregister_after_step_hook(self, fn: Callable[[int, EnvTransition], EnvTransition | None]): def unregister_after_step_hook(self, fn: Callable[[int, EnvTransition], None]):
"""Remove a previously registered after_step hook. """Remove a previously registered after_step hook.
Args: Args:
+35 -29
View File
@@ -263,6 +263,40 @@ def test_step_through_with_dict():
# For now, just check that we get dict outputs # For now, just check that we get dict outputs
def test_step_through_no_hooks():
    """step_through must bypass registered hooks, while __call__ must fire them."""
    pipeline = RobotProcessor([MockStep("test_step")])
    observed = []

    def record_hook(idx: int, transition: EnvTransition):
        # Observation only: note which step index triggered the hook.
        observed.append(f"hook_called_step_{idx}")

    # The same callable is attached on both sides of every step.
    pipeline.register_before_step_hook(record_hook)
    pipeline.register_after_step_hook(record_hook)

    transition = create_transition()
    outputs = [*pipeline.step_through(transition)]

    # One yield for the initial state plus one for the single step.
    assert len(outputs) == 2
    # The step ran and left its counter entry; 0 is the value expected after this first pass.
    step_data = outputs[1][TransitionKey.COMPLEMENTARY_DATA]
    assert step_data["test_step_counter"] == 0

    # step_through is the low-level path: no hook may have fired.
    assert len(observed) == 0

    # The high-level __call__ path must run the hooks.
    observed.clear()
    pipeline(transition)

    # One before-hook call and one after-hook call for the single step (index 0).
    assert len(observed) == 2
    assert observed == ["hook_called_step_0", "hook_called_step_0"]
def test_indexing(): def test_indexing():
"""Test pipeline indexing.""" """Test pipeline indexing."""
step1 = MockStep("step1") step1 = MockStep("step1")
@@ -290,11 +324,9 @@ def test_hooks():
def before_hook(idx: int, transition: EnvTransition): def before_hook(idx: int, transition: EnvTransition):
before_calls.append(idx) before_calls.append(idx)
return transition
def after_hook(idx: int, transition: EnvTransition): def after_hook(idx: int, transition: EnvTransition):
after_calls.append(idx) after_calls.append(idx)
return transition
pipeline.register_before_step_hook(before_hook) pipeline.register_before_step_hook(before_hook)
pipeline.register_after_step_hook(after_hook) pipeline.register_after_step_hook(after_hook)
@@ -306,24 +338,6 @@ def test_hooks():
assert after_calls == [0] assert after_calls == [0]
def test_hook_modification():
    """Check that a before-step hook's returned transition reaches the pipeline output."""
    pipeline = RobotProcessor([MockStep("test_step")])

    def modify_reward_hook(idx: int, transition: EnvTransition):
        # Work on a shallow copy so the caller's transition dict is left untouched.
        patched = transition.copy()
        patched[TransitionKey.REWARD] = 42.0
        return patched

    pipeline.register_before_step_hook(modify_reward_hook)

    result = pipeline(create_transition())

    # The reward written by the hook must be present in the final result.
    assert result[TransitionKey.REWARD] == 42.0
def test_reset(): def test_reset():
"""Test pipeline reset functionality.""" """Test pipeline reset functionality."""
step = MockStep("test_step") step = MockStep("test_step")
@@ -360,7 +374,6 @@ def test_unregister_hooks():
def before_hook(idx: int, transition: EnvTransition): def before_hook(idx: int, transition: EnvTransition):
before_calls.append(idx) before_calls.append(idx)
return None
pipeline.register_before_step_hook(before_hook) pipeline.register_before_step_hook(before_hook)
@@ -380,7 +393,6 @@ def test_unregister_hooks():
def after_hook(idx: int, transition: EnvTransition): def after_hook(idx: int, transition: EnvTransition):
after_calls.append(idx) after_calls.append(idx)
return None
pipeline.register_after_step_hook(after_hook) pipeline.register_after_step_hook(after_hook)
pipeline(transition) pipeline(transition)
@@ -412,7 +424,7 @@ def test_unregister_nonexistent_hook():
pipeline = RobotProcessor([MockStep()]) pipeline = RobotProcessor([MockStep()])
def some_hook(idx: int, transition: EnvTransition): def some_hook(idx: int, transition: EnvTransition):
return None pass
def reset_hook(): def reset_hook():
pass pass
@@ -438,15 +450,12 @@ def test_multiple_hooks_and_selective_unregister():
def hook1(idx: int, transition: EnvTransition): def hook1(idx: int, transition: EnvTransition):
calls_1.append(f"hook1_step{idx}") calls_1.append(f"hook1_step{idx}")
return None
def hook2(idx: int, transition: EnvTransition): def hook2(idx: int, transition: EnvTransition):
calls_2.append(f"hook2_step{idx}") calls_2.append(f"hook2_step{idx}")
return None
def hook3(idx: int, transition: EnvTransition): def hook3(idx: int, transition: EnvTransition):
calls_3.append(f"hook3_step{idx}") calls_3.append(f"hook3_step{idx}")
return None
# Register multiple hooks # Register multiple hooks
pipeline.register_before_step_hook(hook1) pipeline.register_before_step_hook(hook1)
@@ -485,15 +494,12 @@ def test_hook_execution_order_documentation():
def hook_a(idx: int, transition: EnvTransition): def hook_a(idx: int, transition: EnvTransition):
execution_order.append("A") execution_order.append("A")
return None
def hook_b(idx: int, transition: EnvTransition): def hook_b(idx: int, transition: EnvTransition):
execution_order.append("B") execution_order.append("B")
return None
def hook_c(idx: int, transition: EnvTransition): def hook_c(idx: int, transition: EnvTransition):
execution_order.append("C") execution_order.append("C")
return None
# Register in specific order: A, B, C # Register in specific order: A, B, C
pipeline.register_before_step_hook(hook_a) pipeline.register_before_step_hook(hook_a)