mirror of
https://github.com/huggingface/lerobot.git
synced 2026-05-21 11:39:50 +00:00
feat(pipeline): Enhance step_through method to support both tuple and dict inputs
This commit is contained in:
@@ -329,12 +329,46 @@ class RobotProcessor(ModelHubMixin):
|
|||||||
|
|
||||||
return self.to_output(transition) if called_with_batch else transition
|
return self.to_output(transition) if called_with_batch else transition
|
||||||
|
|
||||||
def step_through(self, transition: EnvTransition) -> Iterable[EnvTransition]:
|
def step_through(self, data: EnvTransition | dict[str, Any]) -> Iterable[EnvTransition | dict[str, Any]]:
|
||||||
"""Yield the intermediate Transition instances after each processor step."""
|
"""Yield the intermediate results after each processor step.
|
||||||
yield transition
|
|
||||||
for processor_step in self.steps:
|
Like __call__, this method accepts either EnvTransition tuples or batch dictionaries
|
||||||
|
and preserves the input format in the yielded results.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
data: Either an EnvTransition tuple or a batch dictionary to process.
|
||||||
|
|
||||||
|
Yields:
|
||||||
|
The intermediate results after each step, in the same format as the input.
|
||||||
|
"""
|
||||||
|
called_with_batch = isinstance(data, dict)
|
||||||
|
transition = self.to_transition(data) if called_with_batch else data
|
||||||
|
|
||||||
|
# Basic validation with helpful error message for tuple input
|
||||||
|
if not isinstance(transition, tuple) or len(transition) != 7:
|
||||||
|
raise ValueError(
|
||||||
|
"EnvTransition must be a 7-tuple of (observation, action, reward, done, "
|
||||||
|
"truncated, info, complementary_data). "
|
||||||
|
f"Got {type(transition).__name__} with length {len(transition) if hasattr(transition, '__len__') else 'unknown'}."
|
||||||
|
)
|
||||||
|
|
||||||
|
# Yield initial state
|
||||||
|
yield self.to_output(transition) if called_with_batch else transition
|
||||||
|
|
||||||
|
for idx, processor_step in enumerate(self.steps):
|
||||||
|
for hook in self.before_step_hooks:
|
||||||
|
updated = hook(idx, transition)
|
||||||
|
if updated is not None:
|
||||||
|
transition = updated
|
||||||
|
|
||||||
transition = processor_step(transition)
|
transition = processor_step(transition)
|
||||||
yield transition
|
|
||||||
|
for hook in self.after_step_hooks:
|
||||||
|
updated = hook(idx, transition)
|
||||||
|
if updated is not None:
|
||||||
|
transition = updated
|
||||||
|
|
||||||
|
yield self.to_output(transition) if called_with_batch else transition
|
||||||
|
|
||||||
_CFG_NAME = "processor.json"
|
_CFG_NAME = "processor.json"
|
||||||
|
|
||||||
|
|||||||
@@ -194,7 +194,7 @@ def test_invalid_transition_format():
|
|||||||
|
|
||||||
|
|
||||||
def test_step_through():
|
def test_step_through():
|
||||||
"""Test step_through method."""
|
"""Test step_through method with tuple input."""
|
||||||
step1 = MockStep("step1")
|
step1 = MockStep("step1")
|
||||||
step2 = MockStep("step2")
|
step2 = MockStep("step2")
|
||||||
pipeline = RobotProcessor([step1, step2])
|
pipeline = RobotProcessor([step1, step2])
|
||||||
@@ -208,6 +208,40 @@ def test_step_through():
|
|||||||
assert "step1_counter" in results[1][6] # After step1
|
assert "step1_counter" in results[1][6] # After step1
|
||||||
assert "step2_counter" in results[2][6] # After step2
|
assert "step2_counter" in results[2][6] # After step2
|
||||||
|
|
||||||
|
# Ensure all results are tuples (same format as input)
|
||||||
|
for result in results:
|
||||||
|
assert isinstance(result, tuple)
|
||||||
|
assert len(result) == 7
|
||||||
|
|
||||||
|
|
||||||
|
def test_step_through_with_dict():
|
||||||
|
"""Test step_through method with dict input."""
|
||||||
|
step1 = MockStep("step1")
|
||||||
|
step2 = MockStep("step2")
|
||||||
|
pipeline = RobotProcessor([step1, step2])
|
||||||
|
|
||||||
|
batch = {
|
||||||
|
"observation.image": None,
|
||||||
|
"action": None,
|
||||||
|
"next.reward": 0.0,
|
||||||
|
"next.done": False,
|
||||||
|
"next.truncated": False,
|
||||||
|
"info": {},
|
||||||
|
}
|
||||||
|
|
||||||
|
results = list(pipeline.step_through(batch))
|
||||||
|
|
||||||
|
assert len(results) == 3 # Original + 2 steps
|
||||||
|
|
||||||
|
# Ensure all results are dicts (same format as input)
|
||||||
|
for result in results:
|
||||||
|
assert isinstance(result, dict)
|
||||||
|
|
||||||
|
# Check that the processing worked - the complementary data from steps
|
||||||
|
# should show up in the info or complementary_data fields when converted back to dict
|
||||||
|
# Note: This depends on how _default_transition_to_batch handles complementary_data
|
||||||
|
# For now, just check that we get dict outputs
|
||||||
|
|
||||||
|
|
||||||
def test_indexing():
|
def test_indexing():
|
||||||
"""Test pipeline indexing."""
|
"""Test pipeline indexing."""
|
||||||
|
|||||||
Reference in New Issue
Block a user