mirror of
https://github.com/huggingface/lerobot.git
synced 2026-05-20 11:09:59 +00:00
feat(factory): Enhance make_processor to support preprocessor and postprocessor configuration
- Introduced ProcessorConfigKwargs TypedDict for better type safety in processor configuration. - Updated make_processor to accept preprocessor and postprocessor configuration filenames, improving flexibility in processor instantiation. - Refactored the loading of pretrained processors to utilize the new configuration options.
This commit is contained in:
committed by
Steven Palma
parent
1cad87ebd2
commit
023b8f3466
@@ -14,7 +14,11 @@
|
|||||||
# See the License for the specific language governing permissions and
|
# See the License for the specific language governing permissions and
|
||||||
# limitations under the License.
|
# limitations under the License.
|
||||||
|
|
||||||
|
from __future__ import annotations
|
||||||
|
|
||||||
import logging
|
import logging
|
||||||
|
from typing import TypedDict
|
||||||
|
from typing_extensions import Unpack
|
||||||
|
|
||||||
from torch import nn
|
from torch import nn
|
||||||
|
|
||||||
@@ -102,10 +106,17 @@ def make_policy_config(policy_type: str, **kwargs) -> PreTrainedConfig:
|
|||||||
raise ValueError(f"Policy type '{policy_type}' is not available.")
|
raise ValueError(f"Policy type '{policy_type}' is not available.")
|
||||||
|
|
||||||
|
|
||||||
|
class ProcessorConfigKwargs(TypedDict, total=False):
|
||||||
|
"""Keyword arguments for the processor config."""
|
||||||
|
|
||||||
|
preprocessor_config_filename: str | None
|
||||||
|
postprocessor_config_filename: str | None
|
||||||
|
|
||||||
|
|
||||||
def make_processor(
|
def make_processor(
|
||||||
policy_cfg: PreTrainedConfig,
|
policy_cfg: PreTrainedConfig,
|
||||||
pretrained_path: str | None = None,
|
pretrained_path: str | None = None,
|
||||||
**kwargs,
|
**kwargs: Unpack[ProcessorConfigKwargs],
|
||||||
) -> tuple[RobotProcessor, RobotProcessor]:
|
) -> tuple[RobotProcessor, RobotProcessor]:
|
||||||
"""Make a processor instance for a given policy type.
|
"""Make a processor instance for a given policy type.
|
||||||
|
|
||||||
@@ -127,8 +138,15 @@ def make_processor(
|
|||||||
if pretrained_path:
|
if pretrained_path:
|
||||||
# Load a pretrained processor
|
# Load a pretrained processor
|
||||||
# TODO(azouitine): Handle this case.
|
# TODO(azouitine): Handle this case.
|
||||||
return RobotProcessor.from_pretrained(source=pretrained_path), RobotProcessor(
|
return (
|
||||||
steps=[IdentityProcessor()], name="post_processor"
|
RobotProcessor.from_pretrained(
|
||||||
|
source=pretrained_path,
|
||||||
|
config_filename=kwargs.get("preprocessor_config_filename", "preprocessor.json"),
|
||||||
|
),
|
||||||
|
RobotProcessor.from_pretrained(
|
||||||
|
source=pretrained_path,
|
||||||
|
config_filename=kwargs.get("postprocessor_config_filename", "postprocessor.json"),
|
||||||
|
),
|
||||||
)
|
)
|
||||||
|
|
||||||
# Create a new processor based on policy type
|
# Create a new processor based on policy type
|
||||||
|
|||||||
Reference in New Issue
Block a user