feat(pipeline): Enhance step_through method to support both tuple and dict inputs

This commit is contained in:
Adil Zouitine
2025-07-08 13:14:58 +02:00
parent e9f7f5127b
commit fa26290e8c
2 changed files with 74 additions and 6 deletions
+39 -5
View File
@@ -329,12 +329,46 @@ class RobotProcessor(ModelHubMixin):
return self.to_output(transition) if called_with_batch else transition return self.to_output(transition) if called_with_batch else transition
def step_through(self, transition: EnvTransition) -> Iterable[EnvTransition]: def step_through(self, data: EnvTransition | dict[str, Any]) -> Iterable[EnvTransition | dict[str, Any]]:
"""Yield the intermediate Transition instances after each processor step.""" """Yield the intermediate results after each processor step.
yield transition
for processor_step in self.steps: Like __call__, this method accepts either EnvTransition tuples or batch dictionaries
and preserves the input format in the yielded results.
Args:
data: Either an EnvTransition tuple or a batch dictionary to process.
Yields:
The intermediate results after each step, in the same format as the input.
"""
called_with_batch = isinstance(data, dict)
transition = self.to_transition(data) if called_with_batch else data
# Basic validation with helpful error message for tuple input
if not isinstance(transition, tuple) or len(transition) != 7:
raise ValueError(
"EnvTransition must be a 7-tuple of (observation, action, reward, done, "
"truncated, info, complementary_data). "
f"Got {type(transition).__name__} with length {len(transition) if hasattr(transition, '__len__') else 'unknown'}."
)
# Yield initial state
yield self.to_output(transition) if called_with_batch else transition
for idx, processor_step in enumerate(self.steps):
for hook in self.before_step_hooks:
updated = hook(idx, transition)
if updated is not None:
transition = updated
transition = processor_step(transition) transition = processor_step(transition)
yield transition
for hook in self.after_step_hooks:
updated = hook(idx, transition)
if updated is not None:
transition = updated
yield self.to_output(transition) if called_with_batch else transition
_CFG_NAME = "processor.json" _CFG_NAME = "processor.json"
+35 -1
View File
@@ -194,7 +194,7 @@ def test_invalid_transition_format():
def test_step_through(): def test_step_through():
"""Test step_through method.""" """Test step_through method with tuple input."""
step1 = MockStep("step1") step1 = MockStep("step1")
step2 = MockStep("step2") step2 = MockStep("step2")
pipeline = RobotProcessor([step1, step2]) pipeline = RobotProcessor([step1, step2])
@@ -208,6 +208,40 @@ def test_step_through():
assert "step1_counter" in results[1][6] # After step1 assert "step1_counter" in results[1][6] # After step1
assert "step2_counter" in results[2][6] # After step2 assert "step2_counter" in results[2][6] # After step2
# Ensure all results are tuples (same format as input)
for result in results:
assert isinstance(result, tuple)
assert len(result) == 7
def test_step_through_with_dict():
"""Test step_through method with dict input."""
step1 = MockStep("step1")
step2 = MockStep("step2")
pipeline = RobotProcessor([step1, step2])
batch = {
"observation.image": None,
"action": None,
"next.reward": 0.0,
"next.done": False,
"next.truncated": False,
"info": {},
}
results = list(pipeline.step_through(batch))
assert len(results) == 3 # Original + 2 steps
# Ensure all results are dicts (same format as input)
for result in results:
assert isinstance(result, dict)
# Check that the processing worked - the complementary data from steps
# should show up in the info or complementary_data fields when converted back to dict
# Note: This depends on how _default_transition_to_batch handles complementary_data
# For now, just check that we get dict outputs
def test_indexing(): def test_indexing():
"""Test pipeline indexing.""" """Test pipeline indexing."""