diff --git a/src/lerobot/processor/batch_processor.py b/src/lerobot/processor/batch_processor.py index 1ba016b4e..a563599cd 100644 --- a/src/lerobot/processor/batch_processor.py +++ b/src/lerobot/processor/batch_processor.py @@ -34,6 +34,7 @@ from .pipeline import ( PolicyActionProcessorStep, ProcessorStep, ProcessorStepRegistry, + TransitionKey, ) @@ -227,9 +228,12 @@ class AddBatchDimensionProcessorStep(ProcessorStep): Returns: The environment transition with a batch dimension added. """ - transition = self.to_batch_action_processor(transition) - transition = self.to_batch_observation_processor(transition) - transition = self.to_batch_complementary_data_processor(transition) + if transition[TransitionKey.ACTION] is not None: + transition = self.to_batch_action_processor(transition) + if transition[TransitionKey.OBSERVATION] is not None: + transition = self.to_batch_observation_processor(transition) + if transition[TransitionKey.COMPLEMENTARY_DATA] is not None: + transition = self.to_batch_complementary_data_processor(transition) return transition def transform_features(