feat(factory): Enhance make_processor to support preprocessor and postprocessor configuration

- Introduced ProcessorConfigKwargs TypedDict for better type safety in processor configuration.
- Updated make_processor to accept preprocessor and postprocessor configuration filenames, improving flexibility in processor instantiation.
- Refactored the loading of pretrained processors to utilize the new configuration options.
This commit is contained in:
Adil Zouitine
2025-07-24 17:25:19 +02:00
committed by Steven Palma
parent 1cad87ebd2
commit 023b8f3466
+21 -3
View File
@@ -14,7 +14,11 @@
# See the License for the specific language governing permissions and
# limitations under the License.
from __future__ import annotations
import logging
from typing import TypedDict
from typing_extensions import Unpack
from torch import nn
@@ -102,10 +106,17 @@ def make_policy_config(policy_type: str, **kwargs) -> PreTrainedConfig:
raise ValueError(f"Policy type '{policy_type}' is not available.")
class ProcessorConfigKwargs(TypedDict, total=False):
"""Keyword arguments for the processor config."""
preprocessor_config_filename: str | None
postprocessor_config_filename: str | None
def make_processor(
policy_cfg: PreTrainedConfig,
pretrained_path: str | None = None,
**kwargs,
**kwargs: Unpack[ProcessorConfigKwargs],
) -> tuple[RobotProcessor, RobotProcessor]:
"""Make a processor instance for a given policy type.
@@ -127,8 +138,15 @@ def make_processor(
if pretrained_path:
# Load a pretrained processor
# TODO(azouitine): Handle this case.
return RobotProcessor.from_pretrained(source=pretrained_path), RobotProcessor(
steps=[IdentityProcessor()], name="post_processor"
return (
RobotProcessor.from_pretrained(
source=pretrained_path,
config_filename=kwargs.get("preprocessor_config_filename", "preprocessor.json"),
),
RobotProcessor.from_pretrained(
source=pretrained_path,
config_filename=kwargs.get("postprocessor_config_filename", "postprocessor.json"),
),
)
# Create a new processor based on policy type