mirror of
https://github.com/huggingface/lerobot.git
synced 2026-05-15 16:49:55 +00:00
refactor(pipeline): Clarify hook behavior and improve documentation
- Updated the RobotProcessor class so that hooks are strictly for observation and do not modify transitions, improving clarity and maintainability.
- Refactored the hook registration methods to reflect the new behavior: they now accept only functions that do not return a modified transition.
- Enhanced the documentation to clearly describe the purpose of hooks and their execution semantics.
- Added tests verifying that hooks are not executed by the step_through method and that they still run correctly via the __call__ method.
This commit is contained in:
@@ -265,10 +265,8 @@ class RobotProcessor(ModelHubMixin):
|
||||
Hook Semantics:
|
||||
- Hooks are executed sequentially in the order they were registered. There is no way to
|
||||
reorder hooks after registration without creating a new pipeline.
|
||||
- Hooks CAN modify transitions by returning a new transition dict. If a hook returns None,
|
||||
the current transition remains unchanged. While this capability exists, it should be used
|
||||
with EXTREME CAUTION as it can make debugging difficult and create unexpected side effects.
|
||||
IT'S ADVISED TO NOT MODIFY THE TRANSITION IN A HOOK.
|
||||
- Hooks are for observation/monitoring only and DO NOT modify transitions. They are called
|
||||
with the step index and current transition for logging, debugging, or monitoring purposes.
|
||||
- All hooks for a given type (before/after) are executed for every step, or none at all if
|
||||
an error occurs. There is no partial execution of hooks.
|
||||
- Hooks should generally be stateless to maintain predictable behavior. If you need stateful
|
||||
@@ -287,15 +285,10 @@ class RobotProcessor(ModelHubMixin):
|
||||
default_factory=lambda: _default_transition_to_batch, repr=False
|
||||
)
|
||||
|
||||
# Processor-level hooks
|
||||
# A hook can optionally return a modified transition. If it returns
|
||||
# ``None`` the current value is left untouched.
|
||||
before_step_hooks: list[Callable[[int, EnvTransition], EnvTransition | None]] = field(
|
||||
default_factory=list, repr=False
|
||||
)
|
||||
after_step_hooks: list[Callable[[int, EnvTransition], EnvTransition | None]] = field(
|
||||
default_factory=list, repr=False
|
||||
)
|
||||
# Processor-level hooks for observation/monitoring
|
||||
# Hooks do not modify transitions - they are called for logging, debugging, or monitoring purposes
|
||||
before_step_hooks: list[Callable[[int, EnvTransition], None]] = field(default_factory=list, repr=False)
|
||||
after_step_hooks: list[Callable[[int, EnvTransition], None]] = field(default_factory=list, repr=False)
|
||||
reset_hooks: list[Callable[[], None]] = field(default_factory=list, repr=False)
|
||||
|
||||
def __call__(self, data: EnvTransition | dict[str, Any]):
|
||||
@@ -316,54 +309,34 @@ class RobotProcessor(ModelHubMixin):
|
||||
Raises:
|
||||
ValueError: If the transition is not a valid EnvTransition format.
|
||||
"""
|
||||
iterator = self.step_through(data)
|
||||
current_result = next(iterator) # Get initial state
|
||||
|
||||
# Check if data is already an EnvTransition or needs conversion
|
||||
if isinstance(data, dict) and not all(isinstance(k, TransitionKey) for k in data.keys()):
|
||||
# It's a batch dict, convert it
|
||||
called_with_batch = True
|
||||
transition = self.to_transition(data)
|
||||
else:
|
||||
# It's already an EnvTransition
|
||||
called_with_batch = False
|
||||
transition = data
|
||||
|
||||
# Basic validation
|
||||
if not isinstance(transition, dict):
|
||||
raise ValueError(f"EnvTransition must be a dictionary. Got {type(transition).__name__}")
|
||||
|
||||
# Hook execution subtleties:
|
||||
# - Hooks are executed sequentially in the order they were registered (list order)
|
||||
# - Each hook sees the potentially modified transition from the previous hook
|
||||
# - If a hook returns None, the transition remains unchanged
|
||||
# - All hooks for a given type (before/after) run for every step, creating a
|
||||
# multiplicative effect: N steps × M hooks = N×M hook executions
|
||||
# - Hook execution cannot be interrupted - they all run or none run (on error)
|
||||
for idx, processor_step in enumerate(self.steps):
|
||||
# Process through all steps with hooks
|
||||
for idx, step_result in enumerate(iterator):
|
||||
# Apply before hooks
|
||||
for hook in self.before_step_hooks:
|
||||
updated = hook(idx, transition)
|
||||
if updated is not None:
|
||||
transition = updated
|
||||
|
||||
transition = processor_step(transition)
|
||||
_ = hook(idx, step_result)
|
||||
|
||||
# Apply after hooks
|
||||
for hook in self.after_step_hooks:
|
||||
updated = hook(idx, transition)
|
||||
if updated is not None:
|
||||
transition = updated
|
||||
_ = hook(idx, step_result)
|
||||
|
||||
return self.to_output(transition) if called_with_batch else transition
|
||||
current_result = step_result
|
||||
|
||||
def step_through(self, data: EnvTransition | dict[str, Any]) -> Iterable[EnvTransition | dict[str, Any]]:
|
||||
"""Yield the intermediate results after each processor step.
|
||||
return current_result
|
||||
|
||||
Like __call__, this method accepts either EnvTransition dicts or batch dictionaries
|
||||
and preserves the input format in the yielded results.
|
||||
def _prepare_transition(self, data: EnvTransition | dict[str, Any]) -> tuple[EnvTransition, bool]:
|
||||
"""Prepare and validate transition data for processing.
|
||||
|
||||
Args:
|
||||
data: Either an EnvTransition dict or a batch dictionary to process.
|
||||
|
||||
Yields:
|
||||
The intermediate results after each step, in the same format as the input.
|
||||
Returns:
|
||||
A tuple of (prepared_transition, called_with_batch_flag)
|
||||
|
||||
Raises:
|
||||
ValueError: If the transition is not a valid EnvTransition format.
|
||||
"""
|
||||
# Check if data is already an EnvTransition or needs conversion
|
||||
if isinstance(data, dict) and not all(isinstance(k, TransitionKey) for k in data.keys()):
|
||||
@@ -379,22 +352,32 @@ class RobotProcessor(ModelHubMixin):
|
||||
if not isinstance(transition, dict):
|
||||
raise ValueError(f"EnvTransition must be a dictionary. Got {type(transition).__name__}")
|
||||
|
||||
return transition, called_with_batch
|
||||
|
||||
def step_through(self, data: EnvTransition | dict[str, Any]) -> Iterable[EnvTransition | dict[str, Any]]:
|
||||
"""Yield the intermediate results after each processor step.
|
||||
|
||||
This is a low-level method that does NOT apply hooks. It simply executes each step
|
||||
and yields the intermediate results. This allows users to debug the pipeline or
|
||||
apply custom logic between steps if needed.
|
||||
|
||||
Like __call__, this method accepts either EnvTransition dicts or batch dictionaries
|
||||
and preserves the input format in the yielded results.
|
||||
|
||||
Args:
|
||||
data: Either an EnvTransition dict or a batch dictionary to process.
|
||||
|
||||
Yields:
|
||||
The intermediate results after each step, in the same format as the input.
|
||||
"""
|
||||
transition, called_with_batch = self._prepare_transition(data)
|
||||
|
||||
# Yield initial state
|
||||
yield self.to_output(transition) if called_with_batch else transition
|
||||
|
||||
for idx, processor_step in enumerate(self.steps):
|
||||
for hook in self.before_step_hooks:
|
||||
updated = hook(idx, transition)
|
||||
if updated is not None:
|
||||
transition = updated
|
||||
|
||||
# Process each step WITHOUT hooks (low-level method)
|
||||
for processor_step in self.steps:
|
||||
transition = processor_step(transition)
|
||||
|
||||
for hook in self.after_step_hooks:
|
||||
updated = hook(idx, transition)
|
||||
if updated is not None:
|
||||
transition = updated
|
||||
|
||||
yield self.to_output(transition) if called_with_batch else transition
|
||||
|
||||
_CFG_NAME = "processor.json"
|
||||
@@ -654,11 +637,11 @@ class RobotProcessor(ModelHubMixin):
|
||||
return RobotProcessor(self.steps[idx], self.name, self.seed)
|
||||
return self.steps[idx]
|
||||
|
||||
def register_before_step_hook(self, fn: Callable[[int, EnvTransition], EnvTransition | None]):
|
||||
def register_before_step_hook(self, fn: Callable[[int, EnvTransition], None]):
|
||||
"""Attach fn to be executed before every processor step."""
|
||||
self.before_step_hooks.append(fn)
|
||||
|
||||
def unregister_before_step_hook(self, fn: Callable[[int, EnvTransition], EnvTransition | None]):
|
||||
def unregister_before_step_hook(self, fn: Callable[[int, EnvTransition], None]):
|
||||
"""Remove a previously registered before_step hook.
|
||||
|
||||
Args:
|
||||
@@ -674,11 +657,11 @@ class RobotProcessor(ModelHubMixin):
|
||||
f"Hook {fn} not found in before_step_hooks. Make sure to pass the exact same function reference."
|
||||
) from None
|
||||
|
||||
def register_after_step_hook(self, fn: Callable[[int, EnvTransition], EnvTransition | None]):
|
||||
def register_after_step_hook(self, fn: Callable[[int, EnvTransition], None]):
|
||||
"""Attach fn to be executed after every processor step."""
|
||||
self.after_step_hooks.append(fn)
|
||||
|
||||
def unregister_after_step_hook(self, fn: Callable[[int, EnvTransition], EnvTransition | None]):
|
||||
def unregister_after_step_hook(self, fn: Callable[[int, EnvTransition], None]):
|
||||
"""Remove a previously registered after_step hook.
|
||||
|
||||
Args:
|
||||
|
||||
@@ -263,6 +263,40 @@ def test_step_through_with_dict():
|
||||
# For now, just check that we get dict outputs
|
||||
|
||||
|
||||
def test_step_through_no_hooks():
|
||||
"""Test that step_through doesn't execute hooks."""
|
||||
step = MockStep("test_step")
|
||||
pipeline = RobotProcessor([step])
|
||||
|
||||
hook_calls = []
|
||||
|
||||
def tracking_hook(idx: int, transition: EnvTransition):
|
||||
hook_calls.append(f"hook_called_step_{idx}")
|
||||
|
||||
# Register hooks
|
||||
pipeline.register_before_step_hook(tracking_hook)
|
||||
pipeline.register_after_step_hook(tracking_hook)
|
||||
|
||||
# Use step_through
|
||||
transition = create_transition()
|
||||
results = list(pipeline.step_through(transition))
|
||||
|
||||
# Verify step was executed (counter should increment)
|
||||
assert len(results) == 2 # Initial + 1 step
|
||||
assert results[1][TransitionKey.COMPLEMENTARY_DATA]["test_step_counter"] == 0
|
||||
|
||||
# Verify hooks were NOT called
|
||||
assert len(hook_calls) == 0
|
||||
|
||||
# Now use __call__ to verify hooks ARE called there
|
||||
hook_calls.clear()
|
||||
pipeline(transition)
|
||||
|
||||
# Verify hooks were called (before and after for 1 step = 2 calls)
|
||||
assert len(hook_calls) == 2
|
||||
assert hook_calls == ["hook_called_step_0", "hook_called_step_0"]
|
||||
|
||||
|
||||
def test_indexing():
|
||||
"""Test pipeline indexing."""
|
||||
step1 = MockStep("step1")
|
||||
@@ -290,11 +324,9 @@ def test_hooks():
|
||||
|
||||
def before_hook(idx: int, transition: EnvTransition):
|
||||
before_calls.append(idx)
|
||||
return transition
|
||||
|
||||
def after_hook(idx: int, transition: EnvTransition):
|
||||
after_calls.append(idx)
|
||||
return transition
|
||||
|
||||
pipeline.register_before_step_hook(before_hook)
|
||||
pipeline.register_after_step_hook(after_hook)
|
||||
@@ -306,24 +338,6 @@ def test_hooks():
|
||||
assert after_calls == [0]
|
||||
|
||||
|
||||
def test_hook_modification():
|
||||
"""Test that hooks can modify transitions."""
|
||||
step = MockStep("test_step")
|
||||
pipeline = RobotProcessor([step])
|
||||
|
||||
def modify_reward_hook(idx: int, transition: EnvTransition):
|
||||
new_transition = transition.copy()
|
||||
new_transition[TransitionKey.REWARD] = 42.0
|
||||
return new_transition
|
||||
|
||||
pipeline.register_before_step_hook(modify_reward_hook)
|
||||
|
||||
transition = create_transition()
|
||||
result = pipeline(transition)
|
||||
|
||||
assert result[TransitionKey.REWARD] == 42.0 # reward modified by hook
|
||||
|
||||
|
||||
def test_reset():
|
||||
"""Test pipeline reset functionality."""
|
||||
step = MockStep("test_step")
|
||||
@@ -360,7 +374,6 @@ def test_unregister_hooks():
|
||||
|
||||
def before_hook(idx: int, transition: EnvTransition):
|
||||
before_calls.append(idx)
|
||||
return None
|
||||
|
||||
pipeline.register_before_step_hook(before_hook)
|
||||
|
||||
@@ -380,7 +393,6 @@ def test_unregister_hooks():
|
||||
|
||||
def after_hook(idx: int, transition: EnvTransition):
|
||||
after_calls.append(idx)
|
||||
return None
|
||||
|
||||
pipeline.register_after_step_hook(after_hook)
|
||||
pipeline(transition)
|
||||
@@ -412,7 +424,7 @@ def test_unregister_nonexistent_hook():
|
||||
pipeline = RobotProcessor([MockStep()])
|
||||
|
||||
def some_hook(idx: int, transition: EnvTransition):
|
||||
return None
|
||||
pass
|
||||
|
||||
def reset_hook():
|
||||
pass
|
||||
@@ -438,15 +450,12 @@ def test_multiple_hooks_and_selective_unregister():
|
||||
|
||||
def hook1(idx: int, transition: EnvTransition):
|
||||
calls_1.append(f"hook1_step{idx}")
|
||||
return None
|
||||
|
||||
def hook2(idx: int, transition: EnvTransition):
|
||||
calls_2.append(f"hook2_step{idx}")
|
||||
return None
|
||||
|
||||
def hook3(idx: int, transition: EnvTransition):
|
||||
calls_3.append(f"hook3_step{idx}")
|
||||
return None
|
||||
|
||||
# Register multiple hooks
|
||||
pipeline.register_before_step_hook(hook1)
|
||||
@@ -485,15 +494,12 @@ def test_hook_execution_order_documentation():
|
||||
|
||||
def hook_a(idx: int, transition: EnvTransition):
|
||||
execution_order.append("A")
|
||||
return None
|
||||
|
||||
def hook_b(idx: int, transition: EnvTransition):
|
||||
execution_order.append("B")
|
||||
return None
|
||||
|
||||
def hook_c(idx: int, transition: EnvTransition):
|
||||
execution_order.append("C")
|
||||
return None
|
||||
|
||||
# Register in specific order: A, B, C
|
||||
pipeline.register_before_step_hook(hook_a)
|
||||
|
||||
Reference in New Issue
Block a user