diff --git a/src/lerobot/processor/device_processor.py b/src/lerobot/processor/device_processor.py index 7dcd8abda..feb5eb72b 100644 --- a/src/lerobot/processor/device_processor.py +++ b/src/lerobot/processor/device_processor.py @@ -13,8 +13,6 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. -from __future__ import annotations - from dataclasses import dataclass from typing import Any diff --git a/src/lerobot/processor/observation_processor.py b/src/lerobot/processor/observation_processor.py index 0d3718c43..c2c240d32 100644 --- a/src/lerobot/processor/observation_processor.py +++ b/src/lerobot/processor/observation_processor.py @@ -13,8 +13,6 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. -from __future__ import annotations - from dataclasses import dataclass, field from typing import Any diff --git a/src/lerobot/processor/pipeline.py b/src/lerobot/processor/pipeline.py index 62fa732aa..545704463 100644 --- a/src/lerobot/processor/pipeline.py +++ b/src/lerobot/processor/pipeline.py @@ -187,9 +187,10 @@ def _default_batch_to_transition(batch: dict[str, Any]) -> EnvTransition: # noq observation_keys = {k: v for k, v in batch.items() if k.startswith("observation.")} observation = observation_keys if observation_keys else None - # Extract padding keys for complementary data + # Extract padding and task keys for complementary data pad_keys = {k: v for k, v in batch.items() if "_is_pad" in k} - complementary_data = pad_keys if pad_keys else {} + task_key = {"task": batch["task"]} if "task" in batch else {} + complementary_data = {**pad_keys, **task_key} if pad_keys or task_key else {} return ( observation, @@ -225,11 +226,14 @@ def _default_transition_to_batch(transition: EnvTransition) -> dict[str, Any]: "info": info, } - # Add padding data from complementary_data + # Add padding and task data from complementary_data if complementary_data: pad_data = {k: v for k, v in complementary_data.items() if "_is_pad" in k} batch.update(pad_data) + if "task" in complementary_data: + batch["task"] = complementary_data["task"] + # Handle observation - flatten dict to observation.* keys if it's a dict if isinstance(observation, dict): batch.update(observation) @@ -947,6 +951,40 @@ class InfoProcessor: return transition +class ComplementaryDataProcessor: + """Base class for processors that modify only the complementary data of a transition. + + Subclasses should override the `complementary_data` method to implement custom complementary data processing. + This class handles the boilerplate of extracting and reinserting the processed complementary data + into the transition tuple, eliminating the need to implement the `__call__` method in subclasses. + """ + + def complementary_data(self, complementary_data): + """Process the complementary data. + + Args: + complementary_data: The complementary data to process + + Returns: + The processed complementary data + """ + return complementary_data + + def __call__(self, transition: EnvTransition) -> EnvTransition: + complementary_data = transition[TransitionIndex.COMPLEMENTARY_DATA] + complementary_data = self.complementary_data(complementary_data) + transition = ( + transition[TransitionIndex.OBSERVATION], + transition[TransitionIndex.ACTION], + transition[TransitionIndex.REWARD], + transition[TransitionIndex.DONE], + transition[TransitionIndex.TRUNCATED], + transition[TransitionIndex.INFO], + complementary_data, + ) + return transition + + class IdentityProcessor: """Identity processor that does nothing.""" diff --git a/src/lerobot/processor/rename_processor.py b/src/lerobot/processor/rename_processor.py index 8e6b1a03f..0eb3d0b98 100644 --- a/src/lerobot/processor/rename_processor.py +++ b/src/lerobot/processor/rename_processor.py @@ -13,8 +13,6 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. -from __future__ import annotations - from dataclasses import dataclass, field from typing import Any