fix(train.py) push postprocessor with preprocessor

- Add preprocesser policy overrides for device and rename_map
- Add rename_map to DatasetRecordConfig (record.py)
This commit is contained in:
Michel Aractingi
2025-08-06 13:00:18 +02:00
committed by Steven Palma
parent 28ef6fcd14
commit 2805ae347c
2 changed files with 14 additions and 4 deletions
+9 -2
View File
@@ -59,7 +59,7 @@ python -m lerobot.record \
import logging import logging
import time import time
from dataclasses import asdict, dataclass from dataclasses import asdict, dataclass, field
from pathlib import Path from pathlib import Path
from pprint import pformat from pprint import pformat
@@ -77,6 +77,7 @@ from lerobot.datasets.video_utils import VideoEncodingManager
from lerobot.policies.factory import make_policy, make_processor from lerobot.policies.factory import make_policy, make_processor
from lerobot.policies.pretrained import PreTrainedPolicy from lerobot.policies.pretrained import PreTrainedPolicy
from lerobot.processor import RobotProcessor from lerobot.processor import RobotProcessor
from lerobot.processor.normalize_processor import rename_stats
from lerobot.robots import ( # noqa: F401 from lerobot.robots import ( # noqa: F401
Robot, Robot,
RobotConfig, RobotConfig,
@@ -150,6 +151,8 @@ class DatasetRecordConfig:
# Number of episodes to record before batch encoding videos # Number of episodes to record before batch encoding videos
# Set to 1 for immediate encoding (default behavior), or higher for batched encoding # Set to 1 for immediate encoding (default behavior), or higher for batched encoding
video_encoding_batch_size: int = 1 video_encoding_batch_size: int = 1
# Rename map for the observation to override the image and state keys
rename_map: dict[str, str] = field(default_factory=dict)
def __post_init__(self): def __post_init__(self):
if self.single_task is None: if self.single_task is None:
@@ -341,7 +344,11 @@ def record(cfg: RecordConfig) -> LeRobotDataset:
preprocessor, postprocessor = make_processor( preprocessor, postprocessor = make_processor(
policy_cfg=cfg.policy, policy_cfg=cfg.policy,
pretrained_path=cfg.policy.pretrained_path, pretrained_path=cfg.policy.pretrained_path,
dataset_stats=dataset.meta.stats, dataset_stats=rename_stats(dataset.meta.stats, cfg.dataset.rename_map),
preprocessor_overrides={
"device_processor": {"device": cfg.policy.device},
"rename_processor": {"rename_map": cfg.dataset.rename_map},
},
) )
robot.connect() robot.connect()
+5 -2
View File
@@ -140,7 +140,7 @@ def train(cfg: TrainPipelineConfig):
cfg=cfg.policy, cfg=cfg.policy,
ds_meta=dataset.meta, ds_meta=dataset.meta,
) )
preprocessor, _ = make_processor( preprocessor, postprocessor = make_processor(
policy_cfg=cfg.policy, pretrained_path=cfg.policy.pretrained_path, dataset_stats=dataset.meta.stats policy_cfg=cfg.policy, pretrained_path=cfg.policy.pretrained_path, dataset_stats=dataset.meta.stats
) )
@@ -288,7 +288,10 @@ def train(cfg: TrainPipelineConfig):
if cfg.policy.push_to_hub: if cfg.policy.push_to_hub:
policy.push_model_to_hub(cfg) policy.push_model_to_hub(cfg)
preprocessor.push_to_hub(cfg.policy.repo_id) if preprocessor:
preprocessor.push_to_hub(cfg.policy.repo_id)
if postprocessor:
postprocessor.push_to_hub(cfg.policy.repo_id)
def main(): def main():