diff --git a/src/lerobot/processor/pipeline.py b/src/lerobot/processor/pipeline.py
index ae9e03099..d2beed3c3 100644
--- a/src/lerobot/processor/pipeline.py
+++ b/src/lerobot/processor/pipeline.py
@@ -329,12 +329,46 @@ class RobotProcessor(ModelHubMixin):
 
         return self.to_output(transition) if called_with_batch else transition
 
-    def step_through(self, transition: EnvTransition) -> Iterable[EnvTransition]:
-        """Yield the intermediate Transition instances after each processor step."""
-        yield transition
-        for processor_step in self.steps:
+    def step_through(self, data: EnvTransition | dict[str, Any]) -> Iterable[EnvTransition | dict[str, Any]]:
+        """Yield the intermediate results after each processor step.
+
+        Like __call__, this method accepts either EnvTransition tuples or batch dictionaries
+        and preserves the input format in the yielded results.
+
+        Args:
+            data: Either an EnvTransition tuple or a batch dictionary to process.
+
+        Yields:
+            The intermediate results after each step, in the same format as the input.
+        """
+        called_with_batch = isinstance(data, dict)
+        transition = self.to_transition(data) if called_with_batch else data
+
+        # Basic validation with helpful error message for tuple input
+        if not isinstance(transition, tuple) or len(transition) != 7:
+            raise ValueError(
+                "EnvTransition must be a 7-tuple of (observation, action, reward, done, "
+                "truncated, info, complementary_data). "
+                f"Got {type(transition).__name__} with length {len(transition) if hasattr(transition, '__len__') else 'unknown'}."
+            )
+
+        # Yield initial state
+        yield self.to_output(transition) if called_with_batch else transition
+
+        for idx, processor_step in enumerate(self.steps):
+            for hook in self.before_step_hooks:
+                updated = hook(idx, transition)
+                if updated is not None:
+                    transition = updated
+
             transition = processor_step(transition)
-            yield transition
+
+            for hook in self.after_step_hooks:
+                updated = hook(idx, transition)
+                if updated is not None:
+                    transition = updated
+
+            yield self.to_output(transition) if called_with_batch else transition
 
 
 _CFG_NAME = "processor.json"
diff --git a/tests/processor/test_pipeline.py b/tests/processor/test_pipeline.py
index 280913e49..b5952b412 100644
--- a/tests/processor/test_pipeline.py
+++ b/tests/processor/test_pipeline.py
@@ -194,7 +194,7 @@ def test_invalid_transition_format():
 
 
 def test_step_through():
-    """Test step_through method."""
+    """Test step_through method with tuple input."""
     step1 = MockStep("step1")
     step2 = MockStep("step2")
     pipeline = RobotProcessor([step1, step2])
@@ -208,6 +208,40 @@
     assert "step1_counter" in results[1][6]  # After step1
     assert "step2_counter" in results[2][6]  # After step2
 
+    # Ensure all results are tuples (same format as input)
+    for result in results:
+        assert isinstance(result, tuple)
+        assert len(result) == 7
+
+
+def test_step_through_with_dict():
+    """Test step_through method with dict input."""
+    step1 = MockStep("step1")
+    step2 = MockStep("step2")
+    pipeline = RobotProcessor([step1, step2])
+
+    batch = {
+        "observation.image": None,
+        "action": None,
+        "next.reward": 0.0,
+        "next.done": False,
+        "next.truncated": False,
+        "info": {},
+    }
+
+    results = list(pipeline.step_through(batch))
+
+    assert len(results) == 3  # Original + 2 steps
+
+    # Ensure all results are dicts (same format as input)
+    for result in results:
+        assert isinstance(result, dict)
+
+    # Check that the processing worked - the complementary data from steps
+    # should show up in the info or complementary_data fields when converted back to dict
+    # Note: This depends on how _default_transition_to_batch handles complementary_data
+    # For now, just check that we get dict outputs
+
 
 def test_indexing():
     """Test pipeline indexing."""