mirror of
https://github.com/huggingface/lerobot.git
synced 2026-05-19 10:40:04 +00:00
feat(factory): Enhance make_processor to support preprocessor and postprocessor configuration
- Introduced ProcessorConfigKwargs TypedDict for better type safety in processor configuration. - Updated make_processor to accept preprocessor and postprocessor configuration filenames, improving flexibility in processor instantiation. - Refactored the loading of pretrained processors to utilize the new configuration options.
This commit is contained in:
committed by
Steven Palma
parent
1cad87ebd2
commit
023b8f3466
@@ -14,7 +14,11 @@
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import logging
|
||||
from typing import TypedDict
|
||||
from typing_extensions import Unpack
|
||||
|
||||
from torch import nn
|
||||
|
||||
@@ -102,10 +106,17 @@ def make_policy_config(policy_type: str, **kwargs) -> PreTrainedConfig:
|
||||
raise ValueError(f"Policy type '{policy_type}' is not available.")
|
||||
|
||||
|
||||
class ProcessorConfigKwargs(TypedDict, total=False):
|
||||
"""Keyword arguments for the processor config."""
|
||||
|
||||
preprocessor_config_filename: str | None
|
||||
postprocessor_config_filename: str | None
|
||||
|
||||
|
||||
def make_processor(
|
||||
policy_cfg: PreTrainedConfig,
|
||||
pretrained_path: str | None = None,
|
||||
**kwargs,
|
||||
**kwargs: Unpack[ProcessorConfigKwargs],
|
||||
) -> tuple[RobotProcessor, RobotProcessor]:
|
||||
"""Make a processor instance for a given policy type.
|
||||
|
||||
@@ -127,8 +138,15 @@ def make_processor(
|
||||
if pretrained_path:
|
||||
# Load a pretrained processor
|
||||
# TODO(azouitine): Handle this case.
|
||||
return RobotProcessor.from_pretrained(source=pretrained_path), RobotProcessor(
|
||||
steps=[IdentityProcessor()], name="post_processor"
|
||||
return (
|
||||
RobotProcessor.from_pretrained(
|
||||
source=pretrained_path,
|
||||
config_filename=kwargs.get("preprocessor_config_filename", "preprocessor.json"),
|
||||
),
|
||||
RobotProcessor.from_pretrained(
|
||||
source=pretrained_path,
|
||||
config_filename=kwargs.get("postprocessor_config_filename", "postprocessor.json"),
|
||||
),
|
||||
)
|
||||
|
||||
# Create a new processor based on policy type
|
||||
|
||||
Reference in New Issue
Block a user