diff --git a/src/lerobot/record.py b/src/lerobot/record.py index 31621ea7d..e73c76384 100644 --- a/src/lerobot/record.py +++ b/src/lerobot/record.py @@ -59,7 +59,7 @@ python -m lerobot.record \ import logging import time -from dataclasses import asdict, dataclass +from dataclasses import asdict, dataclass, field from pathlib import Path from pprint import pformat @@ -77,6 +77,7 @@ from lerobot.datasets.video_utils import VideoEncodingManager from lerobot.policies.factory import make_policy, make_processor from lerobot.policies.pretrained import PreTrainedPolicy from lerobot.processor import RobotProcessor +from lerobot.processor.normalize_processor import rename_stats from lerobot.robots import ( # noqa: F401 Robot, RobotConfig, @@ -150,6 +151,8 @@ class DatasetRecordConfig: # Number of episodes to record before batch encoding videos # Set to 1 for immediate encoding (default behavior), or higher for batched encoding video_encoding_batch_size: int = 1 + # Rename map for the observation to override the image and state keys + rename_map: dict[str, str] = field(default_factory=dict) def __post_init__(self): if self.single_task is None: @@ -341,7 +344,11 @@ def record(cfg: RecordConfig) -> LeRobotDataset: preprocessor, postprocessor = make_processor( policy_cfg=cfg.policy, pretrained_path=cfg.policy.pretrained_path, - dataset_stats=dataset.meta.stats, + dataset_stats=rename_stats(dataset.meta.stats, cfg.dataset.rename_map), + preprocessor_overrides={ + "device_processor": {"device": cfg.policy.device}, + "rename_processor": {"rename_map": cfg.dataset.rename_map}, + }, ) robot.connect() diff --git a/src/lerobot/scripts/train.py b/src/lerobot/scripts/train.py index e980595c1..5150009de 100644 --- a/src/lerobot/scripts/train.py +++ b/src/lerobot/scripts/train.py @@ -140,7 +140,7 @@ def train(cfg: TrainPipelineConfig): cfg=cfg.policy, ds_meta=dataset.meta, ) - preprocessor, _ = make_processor( + preprocessor, postprocessor = make_processor( policy_cfg=cfg.policy, pretrained_path=cfg.policy.pretrained_path, dataset_stats=dataset.meta.stats ) @@ -288,7 +288,10 @@ def train(cfg: TrainPipelineConfig): if cfg.policy.push_to_hub: policy.push_model_to_hub(cfg) - preprocessor.push_to_hub(cfg.policy.repo_id) + if preprocessor: + preprocessor.push_to_hub(cfg.policy.repo_id) + if postprocessor: + postprocessor.push_to_hub(cfg.policy.repo_id) def main():