mirror of
https://github.com/huggingface/lerobot.git
synced 2026-05-19 10:40:04 +00:00
fix(train.py) push postprocessor with preprocessor
- Add preprocesser policy overrides for device and rename_map - Add rename_map to DatasetRecordConfig (record.py)
This commit is contained in:
committed by
Steven Palma
parent
28ef6fcd14
commit
2805ae347c
@@ -59,7 +59,7 @@ python -m lerobot.record \
|
||||
|
||||
import logging
|
||||
import time
|
||||
from dataclasses import asdict, dataclass
|
||||
from dataclasses import asdict, dataclass, field
|
||||
from pathlib import Path
|
||||
from pprint import pformat
|
||||
|
||||
@@ -77,6 +77,7 @@ from lerobot.datasets.video_utils import VideoEncodingManager
|
||||
from lerobot.policies.factory import make_policy, make_processor
|
||||
from lerobot.policies.pretrained import PreTrainedPolicy
|
||||
from lerobot.processor import RobotProcessor
|
||||
from lerobot.processor.normalize_processor import rename_stats
|
||||
from lerobot.robots import ( # noqa: F401
|
||||
Robot,
|
||||
RobotConfig,
|
||||
@@ -150,6 +151,8 @@ class DatasetRecordConfig:
|
||||
# Number of episodes to record before batch encoding videos
|
||||
# Set to 1 for immediate encoding (default behavior), or higher for batched encoding
|
||||
video_encoding_batch_size: int = 1
|
||||
# Rename map for the observation to override the image and state keys
|
||||
rename_map: dict[str, str] = field(default_factory=dict)
|
||||
|
||||
def __post_init__(self):
|
||||
if self.single_task is None:
|
||||
@@ -341,7 +344,11 @@ def record(cfg: RecordConfig) -> LeRobotDataset:
|
||||
preprocessor, postprocessor = make_processor(
|
||||
policy_cfg=cfg.policy,
|
||||
pretrained_path=cfg.policy.pretrained_path,
|
||||
dataset_stats=dataset.meta.stats,
|
||||
dataset_stats=rename_stats(dataset.meta.stats, cfg.dataset.rename_map),
|
||||
preprocessor_overrides={
|
||||
"device_processor": {"device": cfg.policy.device},
|
||||
"rename_processor": {"rename_map": cfg.dataset.rename_map},
|
||||
},
|
||||
)
|
||||
|
||||
robot.connect()
|
||||
|
||||
@@ -140,7 +140,7 @@ def train(cfg: TrainPipelineConfig):
|
||||
cfg=cfg.policy,
|
||||
ds_meta=dataset.meta,
|
||||
)
|
||||
preprocessor, _ = make_processor(
|
||||
preprocessor, postprocessor = make_processor(
|
||||
policy_cfg=cfg.policy, pretrained_path=cfg.policy.pretrained_path, dataset_stats=dataset.meta.stats
|
||||
)
|
||||
|
||||
@@ -288,7 +288,10 @@ def train(cfg: TrainPipelineConfig):
|
||||
|
||||
if cfg.policy.push_to_hub:
|
||||
policy.push_model_to_hub(cfg)
|
||||
preprocessor.push_to_hub(cfg.policy.repo_id)
|
||||
if preprocessor:
|
||||
preprocessor.push_to_hub(cfg.policy.repo_id)
|
||||
if postprocessor:
|
||||
postprocessor.push_to_hub(cfg.policy.repo_id)
|
||||
|
||||
|
||||
def main():
|
||||
|
||||
Reference in New Issue
Block a user