mirror of
https://github.com/huggingface/lerobot.git
synced 2026-05-26 14:09:47 +00:00
feat(pipeline): Enhance step_through method to support both tuple and dict inputs
This commit is contained in:
@@ -329,12 +329,46 @@ class RobotProcessor(ModelHubMixin):
|
||||
|
||||
return self.to_output(transition) if called_with_batch else transition
|
||||
|
||||
def step_through(self, transition: EnvTransition) -> Iterable[EnvTransition]:
|
||||
"""Yield the intermediate Transition instances after each processor step."""
|
||||
yield transition
|
||||
for processor_step in self.steps:
|
||||
def step_through(self, data: EnvTransition | dict[str, Any]) -> Iterable[EnvTransition | dict[str, Any]]:
|
||||
"""Yield the intermediate results after each processor step.
|
||||
|
||||
Like __call__, this method accepts either EnvTransition tuples or batch dictionaries
|
||||
and preserves the input format in the yielded results.
|
||||
|
||||
Args:
|
||||
data: Either an EnvTransition tuple or a batch dictionary to process.
|
||||
|
||||
Yields:
|
||||
The intermediate results after each step, in the same format as the input.
|
||||
"""
|
||||
called_with_batch = isinstance(data, dict)
|
||||
transition = self.to_transition(data) if called_with_batch else data
|
||||
|
||||
# Basic validation with helpful error message for tuple input
|
||||
if not isinstance(transition, tuple) or len(transition) != 7:
|
||||
raise ValueError(
|
||||
"EnvTransition must be a 7-tuple of (observation, action, reward, done, "
|
||||
"truncated, info, complementary_data). "
|
||||
f"Got {type(transition).__name__} with length {len(transition) if hasattr(transition, '__len__') else 'unknown'}."
|
||||
)
|
||||
|
||||
# Yield initial state
|
||||
yield self.to_output(transition) if called_with_batch else transition
|
||||
|
||||
for idx, processor_step in enumerate(self.steps):
|
||||
for hook in self.before_step_hooks:
|
||||
updated = hook(idx, transition)
|
||||
if updated is not None:
|
||||
transition = updated
|
||||
|
||||
transition = processor_step(transition)
|
||||
yield transition
|
||||
|
||||
for hook in self.after_step_hooks:
|
||||
updated = hook(idx, transition)
|
||||
if updated is not None:
|
||||
transition = updated
|
||||
|
||||
yield self.to_output(transition) if called_with_batch else transition
|
||||
|
||||
_CFG_NAME = "processor.json"
|
||||
|
||||
|
||||
@@ -194,7 +194,7 @@ def test_invalid_transition_format():
|
||||
|
||||
|
||||
def test_step_through():
|
||||
"""Test step_through method."""
|
||||
"""Test step_through method with tuple input."""
|
||||
step1 = MockStep("step1")
|
||||
step2 = MockStep("step2")
|
||||
pipeline = RobotProcessor([step1, step2])
|
||||
@@ -208,6 +208,40 @@ def test_step_through():
|
||||
assert "step1_counter" in results[1][6] # After step1
|
||||
assert "step2_counter" in results[2][6] # After step2
|
||||
|
||||
# Ensure all results are tuples (same format as input)
|
||||
for result in results:
|
||||
assert isinstance(result, tuple)
|
||||
assert len(result) == 7
|
||||
|
||||
|
||||
def test_step_through_with_dict():
|
||||
"""Test step_through method with dict input."""
|
||||
step1 = MockStep("step1")
|
||||
step2 = MockStep("step2")
|
||||
pipeline = RobotProcessor([step1, step2])
|
||||
|
||||
batch = {
|
||||
"observation.image": None,
|
||||
"action": None,
|
||||
"next.reward": 0.0,
|
||||
"next.done": False,
|
||||
"next.truncated": False,
|
||||
"info": {},
|
||||
}
|
||||
|
||||
results = list(pipeline.step_through(batch))
|
||||
|
||||
assert len(results) == 3 # Original + 2 steps
|
||||
|
||||
# Ensure all results are dicts (same format as input)
|
||||
for result in results:
|
||||
assert isinstance(result, dict)
|
||||
|
||||
# Check that the processing worked - the complementary data from steps
|
||||
# should show up in the info or complementary_data fields when converted back to dict
|
||||
# Note: This depends on how _default_transition_to_batch handles complementary_data
|
||||
# For now, just check that we get dict outputs
|
||||
|
||||
|
||||
def test_indexing():
|
||||
"""Test pipeline indexing."""
|
||||
|
||||
Reference in New Issue
Block a user