fix(train.py) push postprocessor with preprocessor

- Add preprocesser policy overrides for device and rename_map
- Add rename_map to DatasetRecordConfig (record.py)
This commit is contained in:
Michel Aractingi
2025-08-06 13:00:18 +02:00
committed by Steven Palma
parent 28ef6fcd14
commit 2805ae347c
2 changed files with 14 additions and 4 deletions
+9 -2
View File
@@ -59,7 +59,7 @@ python -m lerobot.record \
import logging
import time
from dataclasses import asdict, dataclass
from dataclasses import asdict, dataclass, field
from pathlib import Path
from pprint import pformat
@@ -77,6 +77,7 @@ from lerobot.datasets.video_utils import VideoEncodingManager
from lerobot.policies.factory import make_policy, make_processor
from lerobot.policies.pretrained import PreTrainedPolicy
from lerobot.processor import RobotProcessor
from lerobot.processor.normalize_processor import rename_stats
from lerobot.robots import ( # noqa: F401
Robot,
RobotConfig,
@@ -150,6 +151,8 @@ class DatasetRecordConfig:
# Number of episodes to record before batch encoding videos
# Set to 1 for immediate encoding (default behavior), or higher for batched encoding
video_encoding_batch_size: int = 1
# Rename map for the observation to override the image and state keys
rename_map: dict[str, str] = field(default_factory=dict)
def __post_init__(self):
if self.single_task is None:
@@ -341,7 +344,11 @@ def record(cfg: RecordConfig) -> LeRobotDataset:
preprocessor, postprocessor = make_processor(
policy_cfg=cfg.policy,
pretrained_path=cfg.policy.pretrained_path,
dataset_stats=dataset.meta.stats,
dataset_stats=rename_stats(dataset.meta.stats, cfg.dataset.rename_map),
preprocessor_overrides={
"device_processor": {"device": cfg.policy.device},
"rename_processor": {"rename_map": cfg.dataset.rename_map},
},
)
robot.connect()
+5 -2
View File
@@ -140,7 +140,7 @@ def train(cfg: TrainPipelineConfig):
cfg=cfg.policy,
ds_meta=dataset.meta,
)
preprocessor, _ = make_processor(
preprocessor, postprocessor = make_processor(
policy_cfg=cfg.policy, pretrained_path=cfg.policy.pretrained_path, dataset_stats=dataset.meta.stats
)
@@ -288,7 +288,10 @@ def train(cfg: TrainPipelineConfig):
if cfg.policy.push_to_hub:
policy.push_model_to_hub(cfg)
preprocessor.push_to_hub(cfg.policy.repo_id)
if preprocessor:
preprocessor.push_to_hub(cfg.policy.repo_id)
if postprocessor:
postprocessor.push_to_hub(cfg.policy.repo_id)
def main():