mirror of
https://github.com/huggingface/lerobot.git
synced 2026-05-23 12:40:08 +00:00
docs(pipeline): Clarify transition handling and hook behavior
- Updated documentation to specify that hooks always receive transitions in EnvTransition format, ensuring consistent behavior across input formats. - Refactored the step_through method to yield only EnvTransition objects, regardless of the input format, and updated related tests to reflect this change. - Enhanced test assertions to verify the structure of results and the correctness of processing steps.
This commit is contained in:
@@ -284,6 +284,9 @@ class RobotProcessor(ModelHubMixin):
|
|||||||
- Hooks should generally be stateless to maintain predictable behavior. If you need stateful
|
- Hooks should generally be stateless to maintain predictable behavior. If you need stateful
|
||||||
processing, consider implementing a proper ProcessorStep instead.
|
processing, consider implementing a proper ProcessorStep instead.
|
||||||
- To remove hooks, use the unregister methods. To remove steps, you must create a new pipeline.
|
- To remove hooks, use the unregister methods. To remove steps, you must create a new pipeline.
|
||||||
|
- Hooks ALWAYS receive transitions in EnvTransition format, regardless of the input format
|
||||||
|
passed to __call__. This ensures consistent hook behavior whether processing batch dicts
|
||||||
|
or EnvTransition objects.
|
||||||
"""
|
"""
|
||||||
|
|
||||||
steps: Sequence[ProcessorStep] = field(default_factory=list)
|
steps: Sequence[ProcessorStep] = field(default_factory=list)
|
||||||
@@ -321,22 +324,30 @@ class RobotProcessor(ModelHubMixin):
|
|||||||
Raises:
|
Raises:
|
||||||
ValueError: If the transition is not a valid EnvTransition format.
|
ValueError: If the transition is not a valid EnvTransition format.
|
||||||
"""
|
"""
|
||||||
iterator = self.step_through(data)
|
# Check if we need to convert back to batch format at the end
|
||||||
current_result = next(iterator) # Get initial state
|
_, called_with_batch = self._prepare_transition(data)
|
||||||
|
|
||||||
# Process through all steps with hooks
|
# Use step_through to get the iterator
|
||||||
for idx, step_result in enumerate(iterator):
|
step_iterator = self.step_through(data)
|
||||||
# Apply before hooks
|
|
||||||
|
# Get initial state (before any steps)
|
||||||
|
current_transition = next(step_iterator)
|
||||||
|
|
||||||
|
# Process each step with hooks
|
||||||
|
for idx, next_transition in enumerate(step_iterator):
|
||||||
|
# Apply before hooks with current state (before step execution)
|
||||||
for hook in self.before_step_hooks:
|
for hook in self.before_step_hooks:
|
||||||
_ = hook(idx, step_result)
|
hook(idx, current_transition)
|
||||||
|
|
||||||
# Apply after hooks
|
# Move to next state (after step execution)
|
||||||
|
current_transition = next_transition
|
||||||
|
|
||||||
|
# Apply after hooks with updated state
|
||||||
for hook in self.after_step_hooks:
|
for hook in self.after_step_hooks:
|
||||||
_ = hook(idx, step_result)
|
hook(idx, current_transition)
|
||||||
|
|
||||||
current_result = step_result
|
# Convert back to original format if needed
|
||||||
|
return self.to_output(current_transition) if called_with_batch else current_transition
|
||||||
return current_result
|
|
||||||
|
|
||||||
def _prepare_transition(self, data: EnvTransition | dict[str, Any]) -> tuple[EnvTransition, bool]:
|
def _prepare_transition(self, data: EnvTransition | dict[str, Any]) -> tuple[EnvTransition, bool]:
|
||||||
"""Prepare and validate transition data for processing.
|
"""Prepare and validate transition data for processing.
|
||||||
@@ -366,31 +377,32 @@ class RobotProcessor(ModelHubMixin):
|
|||||||
|
|
||||||
return transition, called_with_batch
|
return transition, called_with_batch
|
||||||
|
|
||||||
def step_through(self, data: EnvTransition | dict[str, Any]) -> Iterable[EnvTransition | dict[str, Any]]:
|
def step_through(self, data: EnvTransition | dict[str, Any]) -> Iterable[EnvTransition]:
|
||||||
"""Yield the intermediate results after each processor step.
|
"""Yield the intermediate results after each processor step.
|
||||||
|
|
||||||
This is a low-level method that does NOT apply hooks. It simply executes each step
|
This is a low-level method that does NOT apply hooks. It simply executes each step
|
||||||
and yields the intermediate results. This allows users to debug the pipeline or
|
and yields the intermediate results. This allows users to debug the pipeline or
|
||||||
apply custom logic between steps if needed.
|
apply custom logic between steps if needed.
|
||||||
|
|
||||||
Like __call__, this method accepts either EnvTransition dicts or batch dictionaries
|
Note: This method always yields EnvTransition objects regardless of input format.
|
||||||
and preserves the input format in the yielded results.
|
If you need the results in the original input format, you'll need to convert them
|
||||||
|
using `to_output()`.
|
||||||
|
|
||||||
Args:
|
Args:
|
||||||
data: Either an EnvTransition dict or a batch dictionary to process.
|
data: Either an EnvTransition dict or a batch dictionary to process.
|
||||||
|
|
||||||
Yields:
|
Yields:
|
||||||
The intermediate results after each step, in the same format as the input.
|
The intermediate EnvTransition results after each step.
|
||||||
"""
|
"""
|
||||||
transition, called_with_batch = self._prepare_transition(data)
|
transition, _ = self._prepare_transition(data)
|
||||||
|
|
||||||
# Yield initial state
|
# Yield initial state
|
||||||
yield self.to_output(transition) if called_with_batch else transition
|
yield transition
|
||||||
|
|
||||||
# Process each step WITHOUT hooks (low-level method)
|
# Process each step WITHOUT hooks (low-level method)
|
||||||
for processor_step in self.steps:
|
for processor_step in self.steps:
|
||||||
transition = processor_step(transition)
|
transition = processor_step(transition)
|
||||||
yield self.to_output(transition) if called_with_batch else transition
|
yield transition
|
||||||
|
|
||||||
def _save_pretrained(self, destination_path: str, **kwargs):
|
def _save_pretrained(self, destination_path: str, **kwargs):
|
||||||
"""Internal save method for ModelHubMixin compatibility."""
|
"""Internal save method for ModelHubMixin compatibility."""
|
||||||
|
|||||||
@@ -268,14 +268,25 @@ def test_step_through_with_dict():
|
|||||||
|
|
||||||
assert len(results) == 3 # Original + 2 steps
|
assert len(results) == 3 # Original + 2 steps
|
||||||
|
|
||||||
# Ensure all results are dicts (same format as input)
|
# Ensure all results are EnvTransition dicts (regardless of input format)
|
||||||
for result in results:
|
for result in results:
|
||||||
assert isinstance(result, dict)
|
assert isinstance(result, dict)
|
||||||
|
# Check that keys are TransitionKey enums or at least valid transition keys
|
||||||
|
for key in result:
|
||||||
|
assert key in [
|
||||||
|
TransitionKey.OBSERVATION,
|
||||||
|
TransitionKey.ACTION,
|
||||||
|
TransitionKey.REWARD,
|
||||||
|
TransitionKey.DONE,
|
||||||
|
TransitionKey.TRUNCATED,
|
||||||
|
TransitionKey.INFO,
|
||||||
|
TransitionKey.COMPLEMENTARY_DATA,
|
||||||
|
]
|
||||||
|
|
||||||
# Check that the processing worked - the complementary data from steps
|
# Check that the processing worked - verify step counters in complementary_data
|
||||||
# should show up in the info or complementary_data fields when converted back to dict
|
assert results[1].get(TransitionKey.COMPLEMENTARY_DATA, {}).get("step1_counter") == 0
|
||||||
# Note: This depends on how _default_transition_to_batch handles complementary_data
|
assert results[2].get(TransitionKey.COMPLEMENTARY_DATA, {}).get("step1_counter") == 0
|
||||||
# For now, just check that we get dict outputs
|
assert results[2].get(TransitionKey.COMPLEMENTARY_DATA, {}).get("step2_counter") == 0
|
||||||
|
|
||||||
|
|
||||||
def test_step_through_no_hooks():
|
def test_step_through_no_hooks():
|
||||||
|
|||||||
Reference in New Issue
Block a user