mirror of
https://github.com/huggingface/lerobot.git
synced 2026-05-21 03:30:10 +00:00
fix(train.py) push postprocessor with preprocessor
- Add preprocesser policy overrides for device and rename_map - Add rename_map to DatasetRecordConfig (record.py)
This commit is contained in:
committed by
Steven Palma
parent
28ef6fcd14
commit
2805ae347c
@@ -59,7 +59,7 @@ python -m lerobot.record \
|
|||||||
|
|
||||||
import logging
|
import logging
|
||||||
import time
|
import time
|
||||||
from dataclasses import asdict, dataclass
|
from dataclasses import asdict, dataclass, field
|
||||||
from pathlib import Path
|
from pathlib import Path
|
||||||
from pprint import pformat
|
from pprint import pformat
|
||||||
|
|
||||||
@@ -77,6 +77,7 @@ from lerobot.datasets.video_utils import VideoEncodingManager
|
|||||||
from lerobot.policies.factory import make_policy, make_processor
|
from lerobot.policies.factory import make_policy, make_processor
|
||||||
from lerobot.policies.pretrained import PreTrainedPolicy
|
from lerobot.policies.pretrained import PreTrainedPolicy
|
||||||
from lerobot.processor import RobotProcessor
|
from lerobot.processor import RobotProcessor
|
||||||
|
from lerobot.processor.normalize_processor import rename_stats
|
||||||
from lerobot.robots import ( # noqa: F401
|
from lerobot.robots import ( # noqa: F401
|
||||||
Robot,
|
Robot,
|
||||||
RobotConfig,
|
RobotConfig,
|
||||||
@@ -150,6 +151,8 @@ class DatasetRecordConfig:
|
|||||||
# Number of episodes to record before batch encoding videos
|
# Number of episodes to record before batch encoding videos
|
||||||
# Set to 1 for immediate encoding (default behavior), or higher for batched encoding
|
# Set to 1 for immediate encoding (default behavior), or higher for batched encoding
|
||||||
video_encoding_batch_size: int = 1
|
video_encoding_batch_size: int = 1
|
||||||
|
# Rename map for the observation to override the image and state keys
|
||||||
|
rename_map: dict[str, str] = field(default_factory=dict)
|
||||||
|
|
||||||
def __post_init__(self):
|
def __post_init__(self):
|
||||||
if self.single_task is None:
|
if self.single_task is None:
|
||||||
@@ -341,7 +344,11 @@ def record(cfg: RecordConfig) -> LeRobotDataset:
|
|||||||
preprocessor, postprocessor = make_processor(
|
preprocessor, postprocessor = make_processor(
|
||||||
policy_cfg=cfg.policy,
|
policy_cfg=cfg.policy,
|
||||||
pretrained_path=cfg.policy.pretrained_path,
|
pretrained_path=cfg.policy.pretrained_path,
|
||||||
dataset_stats=dataset.meta.stats,
|
dataset_stats=rename_stats(dataset.meta.stats, cfg.dataset.rename_map),
|
||||||
|
preprocessor_overrides={
|
||||||
|
"device_processor": {"device": cfg.policy.device},
|
||||||
|
"rename_processor": {"rename_map": cfg.dataset.rename_map},
|
||||||
|
},
|
||||||
)
|
)
|
||||||
|
|
||||||
robot.connect()
|
robot.connect()
|
||||||
|
|||||||
@@ -140,7 +140,7 @@ def train(cfg: TrainPipelineConfig):
|
|||||||
cfg=cfg.policy,
|
cfg=cfg.policy,
|
||||||
ds_meta=dataset.meta,
|
ds_meta=dataset.meta,
|
||||||
)
|
)
|
||||||
preprocessor, _ = make_processor(
|
preprocessor, postprocessor = make_processor(
|
||||||
policy_cfg=cfg.policy, pretrained_path=cfg.policy.pretrained_path, dataset_stats=dataset.meta.stats
|
policy_cfg=cfg.policy, pretrained_path=cfg.policy.pretrained_path, dataset_stats=dataset.meta.stats
|
||||||
)
|
)
|
||||||
|
|
||||||
@@ -288,7 +288,10 @@ def train(cfg: TrainPipelineConfig):
|
|||||||
|
|
||||||
if cfg.policy.push_to_hub:
|
if cfg.policy.push_to_hub:
|
||||||
policy.push_model_to_hub(cfg)
|
policy.push_model_to_hub(cfg)
|
||||||
preprocessor.push_to_hub(cfg.policy.repo_id)
|
if preprocessor:
|
||||||
|
preprocessor.push_to_hub(cfg.policy.repo_id)
|
||||||
|
if postprocessor:
|
||||||
|
postprocessor.push_to_hub(cfg.policy.repo_id)
|
||||||
|
|
||||||
|
|
||||||
def main():
|
def main():
|
||||||
|
|||||||
Reference in New Issue
Block a user