refactor(pipeline): Introduce ComplementaryDataProcessor for handling complementary data in transitions

This commit is contained in:
Adil Zouitine
2025-07-09 19:20:43 +02:00
parent f7bb3e2d90
commit 35612c61e1
4 changed files with 41 additions and 9 deletions
@@ -13,8 +13,6 @@
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from __future__ import annotations
from dataclasses import dataclass
from typing import Any
@@ -13,8 +13,6 @@
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from __future__ import annotations
from dataclasses import dataclass, field
from typing import Any
+41 -3
View File
@@ -187,9 +187,10 @@ def _default_batch_to_transition(batch: dict[str, Any]) -> EnvTransition: # noq
observation_keys = {k: v for k, v in batch.items() if k.startswith("observation.")}
observation = observation_keys if observation_keys else None
# Extract padding keys for complementary data
# Extract padding and task keys for complementary data
pad_keys = {k: v for k, v in batch.items() if "_is_pad" in k}
complementary_data = pad_keys if pad_keys else {}
task_key = {"task": batch["task"]} if "task" in batch else {}
complementary_data = {**pad_keys, **task_key} if pad_keys or task_key else {}
return (
observation,
@@ -225,11 +226,14 @@ def _default_transition_to_batch(transition: EnvTransition) -> dict[str, Any]:
"info": info,
}
# Add padding data from complementary_data
# Add padding and task data from complementary_data
if complementary_data:
pad_data = {k: v for k, v in complementary_data.items() if "_is_pad" in k}
batch.update(pad_data)
if "task" in complementary_data:
batch["task"] = complementary_data["task"]
# Handle observation - flatten dict to observation.* keys if it's a dict
if isinstance(observation, dict):
batch.update(observation)
@@ -947,6 +951,40 @@ class InfoProcessor:
return transition
class ComplementaryDataProcessor:
    """Base processor that touches only the complementary data of a transition.

    ``__call__`` reads the complementary-data slot of the incoming transition,
    passes it through :meth:`complementary_data`, and returns a new transition
    tuple carrying the result. Subclasses customise the behaviour by overriding
    :meth:`complementary_data`; they do not need to implement ``__call__``.
    """

    def complementary_data(self, complementary_data):
        """Transform the complementary data.

        The base implementation is a pass-through.

        Args:
            complementary_data: The complementary data taken from the transition.

        Returns:
            The complementary data to store back into the transition.
        """
        return complementary_data

    def __call__(self, transition: EnvTransition) -> EnvTransition:
        """Run :meth:`complementary_data` on the transition and rebuild the tuple.

        Args:
            transition: The transition whose complementary data is processed.

        Returns:
            A new tuple with the other six fields copied from ``transition``
            and the processed complementary data as the final element.
        """
        processed = self.complementary_data(transition[TransitionIndex.COMPLEMENTARY_DATA])

        # Fields that are forwarded untouched, in the order they appear in the
        # rebuilt transition tuple.
        passthrough = (
            TransitionIndex.OBSERVATION,
            TransitionIndex.ACTION,
            TransitionIndex.REWARD,
            TransitionIndex.DONE,
            TransitionIndex.TRUNCATED,
            TransitionIndex.INFO,
        )
        return (*(transition[slot] for slot in passthrough), processed)
class IdentityProcessor:
"""Identity processor that does nothing."""
@@ -13,8 +13,6 @@
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from __future__ import annotations
from dataclasses import dataclass, field
from typing import Any