From 023b8f3466eeefe574bd0e7dfb7b8f4d6c16e033 Mon Sep 17 00:00:00 2001 From: Adil Zouitine Date: Thu, 24 Jul 2025 17:25:19 +0200 Subject: [PATCH] feat(factory): Enhance make_processor to support preprocessor and postprocessor configuration - Introduced ProcessorConfigKwargs TypedDict for better type safety in processor configuration. - Updated make_processor to accept preprocessor and postprocessor configuration filenames, improving flexibility in processor instantiation. - Refactored the loading of pretrained processors to utilize the new configuration options. --- src/lerobot/policies/factory.py | 24 +++++++++++++++++++++--- 1 file changed, 21 insertions(+), 3 deletions(-) diff --git a/src/lerobot/policies/factory.py b/src/lerobot/policies/factory.py index ea9e06a8e..9873c3a84 100644 --- a/src/lerobot/policies/factory.py +++ b/src/lerobot/policies/factory.py @@ -14,7 +14,11 @@ # See the License for the specific language governing permissions and # limitations under the License. +from __future__ import annotations + import logging +from typing import TypedDict +from typing_extensions import Unpack from torch import nn @@ -102,10 +106,17 @@ def make_policy_config(policy_type: str, **kwargs) -> PreTrainedConfig: raise ValueError(f"Policy type '{policy_type}' is not available.") +class ProcessorConfigKwargs(TypedDict, total=False): + """Keyword arguments for the processor config.""" + + preprocessor_config_filename: str | None + postprocessor_config_filename: str | None + + def make_processor( policy_cfg: PreTrainedConfig, pretrained_path: str | None = None, - **kwargs, + **kwargs: Unpack[ProcessorConfigKwargs], ) -> tuple[RobotProcessor, RobotProcessor]: """Make a processor instance for a given policy type. @@ -127,8 +138,15 @@ def make_processor( if pretrained_path: # Load a pretrained processor # TODO(azouitine): Handle this case. - return RobotProcessor.from_pretrained(source=pretrained_path), RobotProcessor( - steps=[IdentityProcessor()], name="post_processor" + return ( + RobotProcessor.from_pretrained( + source=pretrained_path, + config_filename=kwargs.get("preprocessor_config_filename", "preprocessor.json"), + ), + RobotProcessor.from_pretrained( + source=pretrained_path, + config_filename=kwargs.get("postprocessor_config_filename", "postprocessor.json"), + ), ) # Create a new processor based on policy type