fix(processor): use subprocessors in AddBatchDimensionProcessorStep only if we have the ingredients

This commit is contained in:
Steven Palma
2025-09-10 23:52:58 +02:00
parent cda44e5a52
commit 014486999e
+7 -3
View File
@@ -34,6 +34,7 @@ from .pipeline import (
PolicyActionProcessorStep,
ProcessorStep,
ProcessorStepRegistry,
TransitionKey,
)
@@ -227,9 +228,12 @@ class AddBatchDimensionProcessorStep(ProcessorStep):
Returns:
The environment transition with a batch dimension added.
"""
transition = self.to_batch_action_processor(transition)
transition = self.to_batch_observation_processor(transition)
transition = self.to_batch_complementary_data_processor(transition)
if transition[TransitionKey.ACTION] is not None:
transition = self.to_batch_action_processor(transition)
if transition[TransitionKey.OBSERVATION] is not None:
transition = self.to_batch_observation_processor(transition)
if transition[TransitionKey.COMPLEMENTARY_DATA] is not None:
transition = self.to_batch_complementary_data_processor(transition)
return transition
def transform_features(