docs(pipeline): Clarify transition handling and hook behavior

- Updated documentation to specify that hooks always receive transitions in EnvTransition format, ensuring consistent behavior across input formats.
- Refactored the step_through method to yield only EnvTransition objects, regardless of the input format, and updated related tests to reflect this change.
- Enhanced test assertions to verify the structure of results and the correctness of processing steps.
This commit is contained in:
Adil Zouitine
2025-08-02 14:51:52 +02:00
parent 2c4e888c7f
commit 41959389b6
2 changed files with 46 additions and 23 deletions
+30 -18
View File
@@ -284,6 +284,9 @@ class RobotProcessor(ModelHubMixin):
- Hooks should generally be stateless to maintain predictable behavior. If you need stateful
processing, consider implementing a proper ProcessorStep instead.
- To remove hooks, use the unregister methods. To remove steps, you must create a new pipeline.
- Hooks ALWAYS receive transitions in EnvTransition format, regardless of the input format
passed to __call__. This ensures consistent hook behavior whether processing batch dicts
or EnvTransition objects.
""" """
steps: Sequence[ProcessorStep] = field(default_factory=list) steps: Sequence[ProcessorStep] = field(default_factory=list)
@@ -321,22 +324,30 @@ class RobotProcessor(ModelHubMixin):
Raises:
ValueError: If the transition is not a valid EnvTransition format.
""" """
iterator = self.step_through(data) # Check if we need to convert back to batch format at the end
current_result = next(iterator) # Get initial state _, called_with_batch = self._prepare_transition(data)
# Process through all steps with hooks # Use step_through to get the iterator
for idx, step_result in enumerate(iterator): step_iterator = self.step_through(data)
# Apply before hooks
# Get initial state (before any steps)
current_transition = next(step_iterator)
# Process each step with hooks
for idx, next_transition in enumerate(step_iterator):
# Apply before hooks with current state (before step execution)
for hook in self.before_step_hooks: for hook in self.before_step_hooks:
_ = hook(idx, step_result) hook(idx, current_transition)
# Apply after hooks # Move to next state (after step execution)
current_transition = next_transition
# Apply after hooks with updated state
for hook in self.after_step_hooks: for hook in self.after_step_hooks:
_ = hook(idx, step_result) hook(idx, current_transition)
current_result = step_result # Convert back to original format if needed
return self.to_output(current_transition) if called_with_batch else current_transition
return current_result
def _prepare_transition(self, data: EnvTransition | dict[str, Any]) -> tuple[EnvTransition, bool]: def _prepare_transition(self, data: EnvTransition | dict[str, Any]) -> tuple[EnvTransition, bool]:
"""Prepare and validate transition data for processing. """Prepare and validate transition data for processing.
@@ -366,31 +377,32 @@ class RobotProcessor(ModelHubMixin):
return transition, called_with_batch return transition, called_with_batch
def step_through(self, data: EnvTransition | dict[str, Any]) -> Iterable[EnvTransition | dict[str, Any]]: def step_through(self, data: EnvTransition | dict[str, Any]) -> Iterable[EnvTransition]:
"""Yield the intermediate results after each processor step. """Yield the intermediate results after each processor step.
This is a low-level method that does NOT apply hooks. It simply executes each step
and yields the intermediate results. This allows users to debug the pipeline or
apply custom logic between steps if needed.
Like __call__, this method accepts either EnvTransition dicts or batch dictionaries Note: This method always yields EnvTransition objects regardless of input format.
and preserves the input format in the yielded results. If you need the results in the original input format, you'll need to convert them
using `to_output()`.
Args:
data: Either an EnvTransition dict or a batch dictionary to process.
Yields:
The intermediate results after each step, in the same format as the input. The intermediate EnvTransition results after each step.
""" """
transition, called_with_batch = self._prepare_transition(data) transition, _ = self._prepare_transition(data)
# Yield initial state # Yield initial state
yield self.to_output(transition) if called_with_batch else transition yield transition
# Process each step WITHOUT hooks (low-level method) # Process each step WITHOUT hooks (low-level method)
for processor_step in self.steps: for processor_step in self.steps:
transition = processor_step(transition) transition = processor_step(transition)
yield self.to_output(transition) if called_with_batch else transition yield transition
def _save_pretrained(self, destination_path: str, **kwargs): def _save_pretrained(self, destination_path: str, **kwargs):
"""Internal save method for ModelHubMixin compatibility.""" """Internal save method for ModelHubMixin compatibility."""
+16 -5
View File
@@ -268,14 +268,25 @@ def test_step_through_with_dict():
assert len(results) == 3 # Original + 2 steps assert len(results) == 3 # Original + 2 steps
# Ensure all results are dicts (same format as input) # Ensure all results are EnvTransition dicts (regardless of input format)
for result in results: for result in results:
assert isinstance(result, dict) assert isinstance(result, dict)
# Check that keys are TransitionKey enums or at least valid transition keys
for key in result:
assert key in [
TransitionKey.OBSERVATION,
TransitionKey.ACTION,
TransitionKey.REWARD,
TransitionKey.DONE,
TransitionKey.TRUNCATED,
TransitionKey.INFO,
TransitionKey.COMPLEMENTARY_DATA,
]
# Check that the processing worked - the complementary data from steps # Check that the processing worked - verify step counters in complementary_data
# should show up in the info or complementary_data fields when converted back to dict assert results[1].get(TransitionKey.COMPLEMENTARY_DATA, {}).get("step1_counter") == 0
# Note: This depends on how _default_transition_to_batch handles complementary_data assert results[2].get(TransitionKey.COMPLEMENTARY_DATA, {}).get("step1_counter") == 0
# For now, just check that we get dict outputs assert results[2].get(TransitionKey.COMPLEMENTARY_DATA, {}).get("step2_counter") == 0
def test_step_through_no_hooks(): def test_step_through_no_hooks():