diff --git a/src/lerobot/policies/factory.py b/src/lerobot/policies/factory.py index ea9e06a8e..9873c3a84 100644 --- a/src/lerobot/policies/factory.py +++ b/src/lerobot/policies/factory.py @@ -14,7 +14,11 @@ # See the License for the specific language governing permissions and # limitations under the License. +from __future__ import annotations + import logging +from typing import TypedDict +from typing_extensions import Unpack from torch import nn @@ -102,10 +106,17 @@ def make_policy_config(policy_type: str, **kwargs) -> PreTrainedConfig: raise ValueError(f"Policy type '{policy_type}' is not available.") +class ProcessorConfigKwargs(TypedDict, total=False): + """Keyword arguments for the processor config.""" + + preprocessor_config_filename: str | None + postprocessor_config_filename: str | None + + def make_processor( policy_cfg: PreTrainedConfig, pretrained_path: str | None = None, - **kwargs, + **kwargs: Unpack[ProcessorConfigKwargs], ) -> tuple[RobotProcessor, RobotProcessor]: """Make a processor instance for a given policy type. @@ -127,8 +138,15 @@ def make_processor( if pretrained_path: # Load a pretrained processor # TODO(azouitine): Handle this case. - return RobotProcessor.from_pretrained(source=pretrained_path), RobotProcessor( - steps=[IdentityProcessor()], name="post_processor" + return ( + RobotProcessor.from_pretrained( + source=pretrained_path, + config_filename=kwargs.get("preprocessor_config_filename", "preprocessor.json"), + ), + RobotProcessor.from_pretrained( + source=pretrained_path, + config_filename=kwargs.get("postprocessor_config_filename", "postprocessor.json"), + ), ) # Create a new processor based on policy type