mirror of
https://github.com/huggingface/lerobot.git
synced 2026-05-26 14:09:47 +00:00
docs(pipeline): Clarify transition handling and hook behavior
- Updated documentation to specify that hooks always receive transitions in EnvTransition format, ensuring consistent behavior across input formats. - Refactored the step_through method to yield only EnvTransition objects, regardless of the input format, and updated related tests to reflect this change. - Enhanced test assertions to verify the structure of results and the correctness of processing steps.
This commit is contained in:
@@ -284,6 +284,9 @@ class RobotProcessor(ModelHubMixin):
|
||||
- Hooks should generally be stateless to maintain predictable behavior. If you need stateful
|
||||
processing, consider implementing a proper ProcessorStep instead.
|
||||
- To remove hooks, use the unregister methods. To remove steps, you must create a new pipeline.
|
||||
- Hooks ALWAYS receive transitions in EnvTransition format, regardless of the input format
|
||||
passed to __call__. This ensures consistent hook behavior whether processing batch dicts
|
||||
or EnvTransition objects.
|
||||
"""
|
||||
|
||||
steps: Sequence[ProcessorStep] = field(default_factory=list)
|
||||
@@ -321,22 +324,30 @@ class RobotProcessor(ModelHubMixin):
|
||||
Raises:
|
||||
ValueError: If the transition is not a valid EnvTransition format.
|
||||
"""
|
||||
iterator = self.step_through(data)
|
||||
current_result = next(iterator) # Get initial state
|
||||
# Check if we need to convert back to batch format at the end
|
||||
_, called_with_batch = self._prepare_transition(data)
|
||||
|
||||
# Process through all steps with hooks
|
||||
for idx, step_result in enumerate(iterator):
|
||||
# Apply before hooks
|
||||
# Use step_through to get the iterator
|
||||
step_iterator = self.step_through(data)
|
||||
|
||||
# Get initial state (before any steps)
|
||||
current_transition = next(step_iterator)
|
||||
|
||||
# Process each step with hooks
|
||||
for idx, next_transition in enumerate(step_iterator):
|
||||
# Apply before hooks with current state (before step execution)
|
||||
for hook in self.before_step_hooks:
|
||||
_ = hook(idx, step_result)
|
||||
hook(idx, current_transition)
|
||||
|
||||
# Apply after hooks
|
||||
# Move to next state (after step execution)
|
||||
current_transition = next_transition
|
||||
|
||||
# Apply after hooks with updated state
|
||||
for hook in self.after_step_hooks:
|
||||
_ = hook(idx, step_result)
|
||||
hook(idx, current_transition)
|
||||
|
||||
current_result = step_result
|
||||
|
||||
return current_result
|
||||
# Convert back to original format if needed
|
||||
return self.to_output(current_transition) if called_with_batch else current_transition
|
||||
|
||||
def _prepare_transition(self, data: EnvTransition | dict[str, Any]) -> tuple[EnvTransition, bool]:
|
||||
"""Prepare and validate transition data for processing.
|
||||
@@ -366,31 +377,32 @@ class RobotProcessor(ModelHubMixin):
|
||||
|
||||
return transition, called_with_batch
|
||||
|
||||
def step_through(self, data: EnvTransition | dict[str, Any]) -> Iterable[EnvTransition | dict[str, Any]]:
|
||||
def step_through(self, data: EnvTransition | dict[str, Any]) -> Iterable[EnvTransition]:
|
||||
"""Yield the intermediate results after each processor step.
|
||||
|
||||
This is a low-level method that does NOT apply hooks. It simply executes each step
|
||||
and yields the intermediate results. This allows users to debug the pipeline or
|
||||
apply custom logic between steps if needed.
|
||||
|
||||
Like __call__, this method accepts either EnvTransition dicts or batch dictionaries
|
||||
and preserves the input format in the yielded results.
|
||||
Note: This method always yields EnvTransition objects regardless of input format.
|
||||
If you need the results in the original input format, you'll need to convert them
|
||||
using `to_output()`.
|
||||
|
||||
Args:
|
||||
data: Either an EnvTransition dict or a batch dictionary to process.
|
||||
|
||||
Yields:
|
||||
The intermediate results after each step, in the same format as the input.
|
||||
The intermediate EnvTransition results after each step.
|
||||
"""
|
||||
transition, called_with_batch = self._prepare_transition(data)
|
||||
transition, _ = self._prepare_transition(data)
|
||||
|
||||
# Yield initial state
|
||||
yield self.to_output(transition) if called_with_batch else transition
|
||||
yield transition
|
||||
|
||||
# Process each step WITHOUT hooks (low-level method)
|
||||
for processor_step in self.steps:
|
||||
transition = processor_step(transition)
|
||||
yield self.to_output(transition) if called_with_batch else transition
|
||||
yield transition
|
||||
|
||||
def _save_pretrained(self, destination_path: str, **kwargs):
|
||||
"""Internal save method for ModelHubMixin compatibility."""
|
||||
|
||||
@@ -268,14 +268,25 @@ def test_step_through_with_dict():
|
||||
|
||||
assert len(results) == 3 # Original + 2 steps
|
||||
|
||||
# Ensure all results are dicts (same format as input)
|
||||
# Ensure all results are EnvTransition dicts (regardless of input format)
|
||||
for result in results:
|
||||
assert isinstance(result, dict)
|
||||
# Check that keys are TransitionKey enums or at least valid transition keys
|
||||
for key in result:
|
||||
assert key in [
|
||||
TransitionKey.OBSERVATION,
|
||||
TransitionKey.ACTION,
|
||||
TransitionKey.REWARD,
|
||||
TransitionKey.DONE,
|
||||
TransitionKey.TRUNCATED,
|
||||
TransitionKey.INFO,
|
||||
TransitionKey.COMPLEMENTARY_DATA,
|
||||
]
|
||||
|
||||
# Check that the processing worked - the complementary data from steps
|
||||
# should show up in the info or complementary_data fields when converted back to dict
|
||||
# Note: This depends on how _default_transition_to_batch handles complementary_data
|
||||
# For now, just check that we get dict outputs
|
||||
# Check that the processing worked - verify step counters in complementary_data
|
||||
assert results[1].get(TransitionKey.COMPLEMENTARY_DATA, {}).get("step1_counter") == 0
|
||||
assert results[2].get(TransitionKey.COMPLEMENTARY_DATA, {}).get("step1_counter") == 0
|
||||
assert results[2].get(TransitionKey.COMPLEMENTARY_DATA, {}).get("step2_counter") == 0
|
||||
|
||||
|
||||
def test_step_through_no_hooks():
|
||||
|
||||
Reference in New Issue
Block a user