docs(pipeline): Clarify transition handling and hook behavior

- Updated documentation to specify that hooks always receive transitions in EnvTransition format, ensuring consistent behavior across input formats.
- Refactored the step_through method to yield only EnvTransition objects, regardless of the input format, and updated related tests to reflect this change.
- Enhanced test assertions to verify the structure of results and the correctness of processing steps.
This commit is contained in:
Adil Zouitine
2025-08-02 14:51:52 +02:00
parent 2c4e888c7f
commit 41959389b6
2 changed files with 46 additions and 23 deletions
+30 -18
View File
@@ -284,6 +284,9 @@ class RobotProcessor(ModelHubMixin):
- Hooks should generally be stateless to maintain predictable behavior. If you need stateful
processing, consider implementing a proper ProcessorStep instead.
- To remove hooks, use the unregister methods. To remove steps, you must create a new pipeline.
- Hooks ALWAYS receive transitions in EnvTransition format, regardless of the input format
passed to __call__. This ensures consistent hook behavior whether processing batch dicts
or EnvTransition objects.
"""
steps: Sequence[ProcessorStep] = field(default_factory=list)
@@ -321,22 +324,30 @@ class RobotProcessor(ModelHubMixin):
Raises:
ValueError: If the transition is not a valid EnvTransition format.
"""
iterator = self.step_through(data)
current_result = next(iterator) # Get initial state
# Check if we need to convert back to batch format at the end
_, called_with_batch = self._prepare_transition(data)
# Process through all steps with hooks
for idx, step_result in enumerate(iterator):
# Apply before hooks
# Use step_through to get the iterator
step_iterator = self.step_through(data)
# Get initial state (before any steps)
current_transition = next(step_iterator)
# Process each step with hooks
for idx, next_transition in enumerate(step_iterator):
# Apply before hooks with current state (before step execution)
for hook in self.before_step_hooks:
_ = hook(idx, step_result)
hook(idx, current_transition)
# Apply after hooks
# Move to next state (after step execution)
current_transition = next_transition
# Apply after hooks with updated state
for hook in self.after_step_hooks:
_ = hook(idx, step_result)
hook(idx, current_transition)
current_result = step_result
return current_result
# Convert back to original format if needed
return self.to_output(current_transition) if called_with_batch else current_transition
def _prepare_transition(self, data: EnvTransition | dict[str, Any]) -> tuple[EnvTransition, bool]:
"""Prepare and validate transition data for processing.
@@ -366,31 +377,32 @@ class RobotProcessor(ModelHubMixin):
return transition, called_with_batch
def step_through(self, data: EnvTransition | dict[str, Any]) -> Iterable[EnvTransition | dict[str, Any]]:
def step_through(self, data: EnvTransition | dict[str, Any]) -> Iterable[EnvTransition]:
"""Yield the intermediate results after each processor step.
This is a low-level method that does NOT apply hooks. It simply executes each step
and yields the intermediate results. This allows users to debug the pipeline or
apply custom logic between steps if needed.
Like __call__, this method accepts either EnvTransition dicts or batch dictionaries
and preserves the input format in the yielded results.
Note: This method always yields EnvTransition objects regardless of input format.
If you need the results in the original input format, you'll need to convert them
using `to_output()`.
Args:
data: Either an EnvTransition dict or a batch dictionary to process.
Yields:
The intermediate results after each step, in the same format as the input.
The intermediate EnvTransition results after each step.
"""
transition, called_with_batch = self._prepare_transition(data)
transition, _ = self._prepare_transition(data)
# Yield initial state
yield self.to_output(transition) if called_with_batch else transition
yield transition
# Process each step WITHOUT hooks (low-level method)
for processor_step in self.steps:
transition = processor_step(transition)
yield self.to_output(transition) if called_with_batch else transition
yield transition
def _save_pretrained(self, destination_path: str, **kwargs):
"""Internal save method for ModelHubMixin compatibility."""
+16 -5
View File
@@ -268,14 +268,25 @@ def test_step_through_with_dict():
assert len(results) == 3 # Original + 2 steps
# Ensure all results are dicts (same format as input)
# Ensure all results are EnvTransition dicts (regardless of input format)
for result in results:
assert isinstance(result, dict)
# Check that keys are TransitionKey enums or at least valid transition keys
for key in result:
assert key in [
TransitionKey.OBSERVATION,
TransitionKey.ACTION,
TransitionKey.REWARD,
TransitionKey.DONE,
TransitionKey.TRUNCATED,
TransitionKey.INFO,
TransitionKey.COMPLEMENTARY_DATA,
]
# Check that the processing worked - the complementary data from steps
# should show up in the info or complementary_data fields when converted back to dict
# Note: This depends on how _default_transition_to_batch handles complementary_data
# For now, just check that we get dict outputs
# Check that the processing worked - verify step counters in complementary_data
assert results[1].get(TransitionKey.COMPLEMENTARY_DATA, {}).get("step1_counter") == 0
assert results[2].get(TransitionKey.COMPLEMENTARY_DATA, {}).get("step1_counter") == 0
assert results[2].get(TransitionKey.COMPLEMENTARY_DATA, {}).get("step2_counter") == 0
def test_step_through_no_hooks():