diff --git a/src/lerobot/processor/pipeline.py b/src/lerobot/processor/pipeline.py index f945f367b..9db6d2f25 100644 --- a/src/lerobot/processor/pipeline.py +++ b/src/lerobot/processor/pipeline.py @@ -284,6 +284,9 @@ class RobotProcessor(ModelHubMixin): - Hooks should generally be stateless to maintain predictable behavior. If you need stateful processing, consider implementing a proper ProcessorStep instead. - To remove hooks, use the unregister methods. To remove steps, you must create a new pipeline. + - Hooks ALWAYS receive transitions in EnvTransition format, regardless of the input format + passed to __call__. This ensures consistent hook behavior whether processing batch dicts + or EnvTransition objects. """ steps: Sequence[ProcessorStep] = field(default_factory=list) @@ -321,22 +324,30 @@ class RobotProcessor(ModelHubMixin): Raises: ValueError: If the transition is not a valid EnvTransition format. """ - iterator = self.step_through(data) - current_result = next(iterator) # Get initial state + # Check if we need to convert back to batch format at the end + _, called_with_batch = self._prepare_transition(data) - # Process through all steps with hooks - for idx, step_result in enumerate(iterator): - # Apply before hooks + # Use step_through to get the iterator + step_iterator = self.step_through(data) + + # Get initial state (before any steps) + current_transition = next(step_iterator) + + # Process each step with hooks + for idx, next_transition in enumerate(step_iterator): + # Apply before hooks with the pre-step transition object (NOTE: step_through has already executed this step by the time it yields, so in-place mutations are visible here) for hook in self.before_step_hooks: - _ = hook(idx, step_result) + hook(idx, current_transition) - # Apply after hooks + # Move to next state (after step execution) + current_transition = next_transition + + # Apply after hooks with updated state for hook in self.after_step_hooks: - _ = hook(idx, step_result) + hook(idx, current_transition) - current_result = step_result - - return current_result + # Convert back to 
original format if needed + return self.to_output(current_transition) if called_with_batch else current_transition def _prepare_transition(self, data: EnvTransition | dict[str, Any]) -> tuple[EnvTransition, bool]: """Prepare and validate transition data for processing. @@ -366,31 +377,32 @@ class RobotProcessor(ModelHubMixin): return transition, called_with_batch - def step_through(self, data: EnvTransition | dict[str, Any]) -> Iterable[EnvTransition | dict[str, Any]]: + def step_through(self, data: EnvTransition | dict[str, Any]) -> Iterable[EnvTransition]: """Yield the intermediate results after each processor step. This is a low-level method that does NOT apply hooks. It simply executes each step and yields the intermediate results. This allows users to debug the pipeline or apply custom logic between steps if needed. - Like __call__, this method accepts either EnvTransition dicts or batch dictionaries - and preserves the input format in the yielded results. + Note: This method always yields EnvTransition objects regardless of input format. + If you need the results in the original input format, you'll need to convert them + using `to_output()`. Args: data: Either an EnvTransition dict or a batch dictionary to process. Yields: - The intermediate results after each step, in the same format as the input. + The intermediate EnvTransition results after each step. 
""" - transition, called_with_batch = self._prepare_transition(data) + transition, _ = self._prepare_transition(data) # Yield initial state - yield self.to_output(transition) if called_with_batch else transition + yield transition # Process each step WITHOUT hooks (low-level method) for processor_step in self.steps: transition = processor_step(transition) - yield self.to_output(transition) if called_with_batch else transition + yield transition def _save_pretrained(self, destination_path: str, **kwargs): """Internal save method for ModelHubMixin compatibility.""" diff --git a/tests/processor/test_pipeline.py b/tests/processor/test_pipeline.py index 8c12e9167..a5f39b7ef 100644 --- a/tests/processor/test_pipeline.py +++ b/tests/processor/test_pipeline.py @@ -268,14 +268,25 @@ def test_step_through_with_dict(): assert len(results) == 3 # Original + 2 steps - # Ensure all results are dicts (same format as input) + # Ensure all results are EnvTransition dicts (regardless of input format) for result in results: assert isinstance(result, dict) + # Check that keys are TransitionKey enums or at least valid transition keys + for key in result: + assert key in [ + TransitionKey.OBSERVATION, + TransitionKey.ACTION, + TransitionKey.REWARD, + TransitionKey.DONE, + TransitionKey.TRUNCATED, + TransitionKey.INFO, + TransitionKey.COMPLEMENTARY_DATA, + ] - # Check that the processing worked - the complementary data from steps - # should show up in the info or complementary_data fields when converted back to dict - # Note: This depends on how _default_transition_to_batch handles complementary_data - # For now, just check that we get dict outputs + # Check that the processing worked - verify step counters in complementary_data + assert results[1].get(TransitionKey.COMPLEMENTARY_DATA, {}).get("step1_counter") == 0 + assert results[2].get(TransitionKey.COMPLEMENTARY_DATA, {}).get("step1_counter") == 0 + assert results[2].get(TransitionKey.COMPLEMENTARY_DATA, {}).get("step2_counter") == 0 
def test_step_through_no_hooks():