From 26cb9a24c3b1d0ee04d255a91c731b92622498c9 Mon Sep 17 00:00:00 2001 From: Adil Zouitine Date: Tue, 22 Jul 2025 10:41:22 +0200 Subject: [PATCH] refactor(pipeline): Clarify hook behavior and improve documentation - Updated the RobotProcessor class to ensure hooks are strictly for observation and do not modify transitions, enhancing clarity and maintainability. - Refactored hook registration methods to reflect the new behavior, ensuring they accept only functions that do not return modified transitions. - Enhanced documentation to clearly outline the purpose of hooks and their execution semantics. - Added tests to verify that hooks are not executed during the step_through method while ensuring they function correctly during the __call__ method. --- src/lerobot/processor/pipeline.py | 115 +++++++++++++----------------- tests/processor/test_pipeline.py | 64 +++++++++-------- 2 files changed, 84 insertions(+), 95 deletions(-) diff --git a/src/lerobot/processor/pipeline.py b/src/lerobot/processor/pipeline.py index 34504ed96..846118266 100644 --- a/src/lerobot/processor/pipeline.py +++ b/src/lerobot/processor/pipeline.py @@ -265,10 +265,8 @@ class RobotProcessor(ModelHubMixin): Hook Semantics: - Hooks are executed sequentially in the order they were registered. There is no way to reorder hooks after registration without creating a new pipeline. - - Hooks CAN modify transitions by returning a new transition dict. If a hook returns None, - the current transition remains unchanged. While this capability exists, it should be used - with EXTREME CAUTION as it can make debugging difficult and create unexpected side effects. - IT'S ADVISED TO NOT MODIFY THE TRANSITION IN A HOOK. + - Hooks are for observation/monitoring only and DO NOT modify transitions. They are called + with the step index and current transition for logging, debugging, or monitoring purposes. - All hooks for a given type (before/after) are executed for every step, or none at all if an error occurs. There is no partial execution of hooks. - Hooks should generally be stateless to maintain predictable behavior. If you need stateful @@ -287,15 +285,10 @@ class RobotProcessor(ModelHubMixin): default_factory=lambda: _default_transition_to_batch, repr=False ) - # Processor-level hooks - # A hook can optionally return a modified transition. If it returns - # ``None`` the current value is left untouched. - before_step_hooks: list[Callable[[int, EnvTransition], EnvTransition | None]] = field( - default_factory=list, repr=False - ) - after_step_hooks: list[Callable[[int, EnvTransition], EnvTransition | None]] = field( - default_factory=list, repr=False - ) + # Processor-level hooks for observation/monitoring + # Hooks do not modify transitions - they are called for logging, debugging, or monitoring purposes + before_step_hooks: list[Callable[[int, EnvTransition], None]] = field(default_factory=list, repr=False) + after_step_hooks: list[Callable[[int, EnvTransition], None]] = field(default_factory=list, repr=False) reset_hooks: list[Callable[[], None]] = field(default_factory=list, repr=False) def __call__(self, data: EnvTransition | dict[str, Any]): @@ -316,54 +309,34 @@ class RobotProcessor(ModelHubMixin): Raises: ValueError: If the transition is not a valid EnvTransition format. """ + iterator = self.step_through(data) + current_result = next(iterator) # Get initial state - # Check if data is already an EnvTransition or needs conversion - if isinstance(data, dict) and not all(isinstance(k, TransitionKey) for k in data.keys()): - # It's a batch dict, convert it - called_with_batch = True - transition = self.to_transition(data) - else: - # It's already an EnvTransition - called_with_batch = False - transition = data - - # Basic validation - if not isinstance(transition, dict): - raise ValueError(f"EnvTransition must be a dictionary. Got {type(transition).__name__}") - - # Hook execution subtleties: - # - Hooks are executed sequentially in the order they were registered (list order) - # - Each hook sees the potentially modified transition from the previous hook - # - If a hook returns None, the transition remains unchanged - # - All hooks for a given type (before/after) run for every step, creating a - # multiplicative effect: N steps × M hooks = N×M hook executions - # - Hook execution cannot be interrupted - they all run or none run (on error) - for idx, processor_step in enumerate(self.steps): + # Process through all steps with hooks + for idx, step_result in enumerate(iterator): + # Apply before hooks for hook in self.before_step_hooks: - updated = hook(idx, transition) - if updated is not None: - transition = updated - - transition = processor_step(transition) + _ = hook(idx, step_result) + # Apply after hooks for hook in self.after_step_hooks: - updated = hook(idx, transition) - if updated is not None: - transition = updated + _ = hook(idx, step_result) - return self.to_output(transition) if called_with_batch else transition + current_result = step_result - def step_through(self, data: EnvTransition | dict[str, Any]) -> Iterable[EnvTransition | dict[str, Any]]: - """Yield the intermediate results after each processor step. + return current_result - Like __call__, this method accepts either EnvTransition dicts or batch dictionaries - and preserves the input format in the yielded results. + def _prepare_transition(self, data: EnvTransition | dict[str, Any]) -> tuple[EnvTransition, bool]: + """Prepare and validate transition data for processing. Args: data: Either an EnvTransition dict or a batch dictionary to process. - Yields: - The intermediate results after each step, in the same format as the input. + Returns: + A tuple of (prepared_transition, called_with_batch_flag) + + Raises: + ValueError: If the transition is not a valid EnvTransition format. """ # Check if data is already an EnvTransition or needs conversion if isinstance(data, dict) and not all(isinstance(k, TransitionKey) for k in data.keys()): @@ -379,22 +352,32 @@ class RobotProcessor(ModelHubMixin): if not isinstance(transition, dict): raise ValueError(f"EnvTransition must be a dictionary. Got {type(transition).__name__}") + return transition, called_with_batch + + def step_through(self, data: EnvTransition | dict[str, Any]) -> Iterable[EnvTransition | dict[str, Any]]: + """Yield the intermediate results after each processor step. + + This is a low-level method that does NOT apply hooks. It simply executes each step + and yields the intermediate results. This allows users to debug the pipeline or + apply custom logic between steps if needed. + + Like __call__, this method accepts either EnvTransition dicts or batch dictionaries + and preserves the input format in the yielded results. + + Args: + data: Either an EnvTransition dict or a batch dictionary to process. + + Yields: + The intermediate results after each step, in the same format as the input. + """ + transition, called_with_batch = self._prepare_transition(data) + # Yield initial state yield self.to_output(transition) if called_with_batch else transition - for idx, processor_step in enumerate(self.steps): - for hook in self.before_step_hooks: - updated = hook(idx, transition) - if updated is not None: - transition = updated - + # Process each step WITHOUT hooks (low-level method) + for processor_step in self.steps: transition = processor_step(transition) - - for hook in self.after_step_hooks: - updated = hook(idx, transition) - if updated is not None: - transition = updated - yield self.to_output(transition) if called_with_batch else transition _CFG_NAME = "processor.json" @@ -654,11 +637,11 @@ class RobotProcessor(ModelHubMixin): return RobotProcessor(self.steps[idx], self.name, self.seed) return self.steps[idx] - def register_before_step_hook(self, fn: Callable[[int, EnvTransition], EnvTransition | None]): + def register_before_step_hook(self, fn: Callable[[int, EnvTransition], None]): """Attach fn to be executed before every processor step.""" self.before_step_hooks.append(fn) - def unregister_before_step_hook(self, fn: Callable[[int, EnvTransition], EnvTransition | None]): + def unregister_before_step_hook(self, fn: Callable[[int, EnvTransition], None]): """Remove a previously registered before_step hook. Args: @@ -674,11 +657,11 @@ class RobotProcessor(ModelHubMixin): f"Hook {fn} not found in before_step_hooks. Make sure to pass the exact same function reference." ) from None - def register_after_step_hook(self, fn: Callable[[int, EnvTransition], EnvTransition | None]): + def register_after_step_hook(self, fn: Callable[[int, EnvTransition], None]): """Attach fn to be executed after every processor step.""" self.after_step_hooks.append(fn) - def unregister_after_step_hook(self, fn: Callable[[int, EnvTransition], EnvTransition | None]): + def unregister_after_step_hook(self, fn: Callable[[int, EnvTransition], None]): """Remove a previously registered after_step hook. Args: diff --git a/tests/processor/test_pipeline.py b/tests/processor/test_pipeline.py index f9f6237ff..7a595fcff 100644 --- a/tests/processor/test_pipeline.py +++ b/tests/processor/test_pipeline.py @@ -263,6 +263,40 @@ def test_step_through_with_dict(): # For now, just check that we get dict outputs +def test_step_through_no_hooks(): + """Test that step_through doesn't execute hooks.""" + step = MockStep("test_step") + pipeline = RobotProcessor([step]) + + hook_calls = [] + + def tracking_hook(idx: int, transition: EnvTransition): + hook_calls.append(f"hook_called_step_{idx}") + + # Register hooks + pipeline.register_before_step_hook(tracking_hook) + pipeline.register_after_step_hook(tracking_hook) + + # Use step_through + transition = create_transition() + results = list(pipeline.step_through(transition)) + + # Verify step was executed (counter should increment) + assert len(results) == 2 # Initial + 1 step + assert results[1][TransitionKey.COMPLEMENTARY_DATA]["test_step_counter"] == 0 + + # Verify hooks were NOT called + assert len(hook_calls) == 0 + + # Now use __call__ to verify hooks ARE called there + hook_calls.clear() + pipeline(transition) + + # Verify hooks were called (before and after for 1 step = 2 calls) + assert len(hook_calls) == 2 + assert hook_calls == ["hook_called_step_0", "hook_called_step_0"] + + def test_indexing(): """Test pipeline indexing.""" step1 = MockStep("step1") @@ -290,11 +324,9 @@ def test_hooks(): def before_hook(idx: int, transition: EnvTransition): before_calls.append(idx) - return transition def after_hook(idx: int, transition: EnvTransition): after_calls.append(idx) - return transition pipeline.register_before_step_hook(before_hook) pipeline.register_after_step_hook(after_hook) @@ -306,24 +338,6 @@ def test_hooks(): assert after_calls == [0] -def test_hook_modification(): - """Test that hooks can modify transitions.""" - step = MockStep("test_step") - pipeline = RobotProcessor([step]) - - def modify_reward_hook(idx: int, transition: EnvTransition): - new_transition = transition.copy() - new_transition[TransitionKey.REWARD] = 42.0 - return new_transition - - pipeline.register_before_step_hook(modify_reward_hook) - - transition = create_transition() - result = pipeline(transition) - - assert result[TransitionKey.REWARD] == 42.0 # reward modified by hook - - def test_reset(): """Test pipeline reset functionality.""" step = MockStep("test_step") @@ -360,7 +374,6 @@ def test_unregister_hooks(): def before_hook(idx: int, transition: EnvTransition): before_calls.append(idx) - return None pipeline.register_before_step_hook(before_hook) @@ -380,7 +393,6 @@ def test_unregister_hooks(): def after_hook(idx: int, transition: EnvTransition): after_calls.append(idx) - return None pipeline.register_after_step_hook(after_hook) pipeline(transition) @@ -412,7 +424,7 @@ def test_unregister_nonexistent_hook(): pipeline = RobotProcessor([MockStep()]) def some_hook(idx: int, transition: EnvTransition): - return None + pass def reset_hook(): pass @@ -438,15 +450,12 @@ def test_multiple_hooks_and_selective_unregister(): def hook1(idx: int, transition: EnvTransition): calls_1.append(f"hook1_step{idx}") - return None def hook2(idx: int, transition: EnvTransition): calls_2.append(f"hook2_step{idx}") - return None def hook3(idx: int, transition: EnvTransition): calls_3.append(f"hook3_step{idx}") - return None # Register multiple hooks pipeline.register_before_step_hook(hook1) @@ -485,15 +494,12 @@ def test_hook_execution_order_documentation(): def hook_a(idx: int, transition: EnvTransition): execution_order.append("A") - return None def hook_b(idx: int, transition: EnvTransition): execution_order.append("B") - return None def hook_c(idx: int, transition: EnvTransition): execution_order.append("C") - return None # Register in specific order: A, B, C pipeline.register_before_step_hook(hook_a)