From c227107f609d983362343d6957807e9010c389db Mon Sep 17 00:00:00 2001 From: Adil Zouitine Date: Fri, 4 Jul 2025 13:07:58 +0200 Subject: [PATCH] feat (device processor): Implement device processor --- src/lerobot/processor/__init__.py | 3 ++ src/lerobot/processor/device_processor.py | 62 +++++++++++++++++++++++ 2 files changed, 65 insertions(+) create mode 100644 src/lerobot/processor/device_processor.py diff --git a/src/lerobot/processor/__init__.py b/src/lerobot/processor/__init__.py index 1b104199b..3f0267eae 100644 --- a/src/lerobot/processor/__init__.py +++ b/src/lerobot/processor/__init__.py @@ -13,6 +13,8 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. + +from .device_processor import DeviceProcessor from .normalize_processor import NormalizerProcessor, UnnormalizerProcessor from .observation_processor import ( ImageProcessor, @@ -34,6 +36,7 @@ from .rename_processor import RenameProcessor __all__ = [ "ActionProcessor", + "DeviceProcessor", "DoneProcessor", "EnvTransition", "ImageProcessor", diff --git a/src/lerobot/processor/device_processor.py b/src/lerobot/processor/device_processor.py new file mode 100644 index 000000000..0ff6ef9da --- /dev/null +++ b/src/lerobot/processor/device_processor.py @@ -0,0 +1,62 @@ +#!/usr/bin/env python + +# Copyright 2025 The HuggingFace Inc. team. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +from __future__ import annotations + +from dataclasses import dataclass +from typing import Any + +import torch + +from lerobot.processor.pipeline import EnvTransition, TransitionIndex + + +@dataclass +class DeviceProcessor: + """Processes transitions by moving tensors to the specified device. + + This processor ensures that all tensors in the transition are moved to the + specified device (CPU or GPU) before they are returned. + """ + + device: str = "cpu" + + def __call__(self, transition: EnvTransition) -> EnvTransition: + observation: dict[str, torch.Tensor] = transition[TransitionIndex.OBSERVATION] + action = transition[TransitionIndex.ACTION] + reward = transition[TransitionIndex.REWARD] + done = transition[TransitionIndex.DONE] + truncated = transition[TransitionIndex.TRUNCATED] + info = transition[TransitionIndex.INFO] + complementary_data = transition[TransitionIndex.COMPLEMENTARY_DATA] + + if observation is not None: + observation = {k: v.to(self.device) for k, v in observation.items()} + if action is not None: + action = action.to(self.device) + + return ( + observation, + action, + reward, + done, + truncated, + info, + complementary_data, + ) + + def get_config(self) -> dict[str, Any]: + """Return configuration for serialization.""" + return {"device": self.device}