From 35c5d432558ab7446f517bc0d607fdc3e648bcdf Mon Sep 17 00:00:00 2001 From: AdilZouitine Date: Mon, 11 Aug 2025 18:00:25 +0200 Subject: [PATCH] chore(processor): Add default names for preprocessor and postprocessor in constants - Introduced `PREPROCESSOR_DEFAULT_NAME` and `POSTPROCESSOR_DEFAULT_NAME` constants for consistent naming across various processor implementations. - Updated processor creation in multiple policy files to utilize these constants, enhancing code readability and maintainability. - Modified the training script to load and save the preprocessor and postprocessor using the new constants. --- src/lerobot/constants.py | 3 +++ src/lerobot/policies/act/processor_act.py | 5 +++-- .../policies/diffusion/processor_diffusion.py | 5 +++-- src/lerobot/policies/pi0/processor_pi0.py | 5 +++-- src/lerobot/policies/pi0fast/processor_pi0fast.py | 5 +++-- src/lerobot/policies/sac/processor_sac.py | 5 +++-- src/lerobot/policies/smolvla/processor_smolvla.py | 5 +++-- src/lerobot/policies/tdmpc/processor_tdmpc.py | 5 +++-- src/lerobot/policies/vqbet/processor_vqbet.py | 5 +++-- src/lerobot/scripts/train.py | 15 ++++++++++----- src/lerobot/utils/train_utils.py | 6 +++++- 11 files changed, 42 insertions(+), 22 deletions(-) diff --git a/src/lerobot/constants.py b/src/lerobot/constants.py index a502a9570..98e1813c4 100644 --- a/src/lerobot/constants.py +++ b/src/lerobot/constants.py @@ -40,6 +40,9 @@ OPTIMIZER_STATE = "optimizer_state.safetensors" OPTIMIZER_PARAM_GROUPS = "optimizer_param_groups.json" SCHEDULER_STATE = "scheduler_state.json" +PREPROCESSOR_DEFAULT_NAME = "robot_preprocessor" +POSTPROCESSOR_DEFAULT_NAME = "robot_postprocessor" + if "LEROBOT_HOME" in os.environ: raise ValueError( f"You have a 'LEROBOT_HOME' environment variable set to '{os.getenv('LEROBOT_HOME')}'.\n" diff --git a/src/lerobot/policies/act/processor_act.py b/src/lerobot/policies/act/processor_act.py index f3dc046f6..de004f676 100644 --- a/src/lerobot/policies/act/processor_act.py +++ b/src/lerobot/policies/act/processor_act.py @@ -15,6 +15,7 @@ # limitations under the License. import torch +from lerobot.constants import POSTPROCESSOR_DEFAULT_NAME, PREPROCESSOR_DEFAULT_NAME from lerobot.policies.act.configuration_act import ACTConfig from lerobot.processor import ( DeviceProcessor, @@ -45,6 +46,6 @@ def make_act_processor( features=config.output_features, norm_map=config.normalization_mapping, stats=dataset_stats ), ] - return RobotProcessor(steps=input_steps, name="robot_preprocessor"), RobotProcessor( - steps=output_steps, name="robot_postprocessor" + return RobotProcessor(steps=input_steps, name=PREPROCESSOR_DEFAULT_NAME), RobotProcessor( + steps=output_steps, name=POSTPROCESSOR_DEFAULT_NAME ) diff --git a/src/lerobot/policies/diffusion/processor_diffusion.py b/src/lerobot/policies/diffusion/processor_diffusion.py index 40002c3ed..89491d022 100644 --- a/src/lerobot/policies/diffusion/processor_diffusion.py +++ b/src/lerobot/policies/diffusion/processor_diffusion.py @@ -16,6 +16,7 @@ # limitations under the License. import torch +from lerobot.constants import POSTPROCESSOR_DEFAULT_NAME, PREPROCESSOR_DEFAULT_NAME from lerobot.policies.diffusion.configuration_diffusion import DiffusionConfig from lerobot.processor import ( DeviceProcessor, @@ -46,6 +47,6 @@ def make_diffusion_processor( features=config.output_features, norm_map=config.normalization_mapping, stats=dataset_stats ), ] - return RobotProcessor(steps=input_steps, name="robot_preprocessor"), RobotProcessor( - steps=output_steps, name="robot_postprocessor" + return RobotProcessor(steps=input_steps, name=PREPROCESSOR_DEFAULT_NAME), RobotProcessor( + steps=output_steps, name=POSTPROCESSOR_DEFAULT_NAME ) diff --git a/src/lerobot/policies/pi0/processor_pi0.py b/src/lerobot/policies/pi0/processor_pi0.py index 4c411dd66..3629f1071 100644 --- a/src/lerobot/policies/pi0/processor_pi0.py +++ b/src/lerobot/policies/pi0/processor_pi0.py @@ -19,6 +19,7 @@ from typing import Any import torch from lerobot.configs.types import PolicyFeature +from lerobot.constants import POSTPROCESSOR_DEFAULT_NAME, PREPROCESSOR_DEFAULT_NAME from lerobot.policies.pi0.configuration_pi0 import PI0Config from lerobot.processor import ( DeviceProcessor, @@ -115,6 +116,6 @@ def make_pi0_processor( ), ] - return RobotProcessor(steps=input_steps, name="robot_preprocessor"), RobotProcessor( - steps=output_steps, name="robot_postprocessor" + return RobotProcessor(steps=input_steps, name=PREPROCESSOR_DEFAULT_NAME), RobotProcessor( + steps=output_steps, name=POSTPROCESSOR_DEFAULT_NAME ) diff --git a/src/lerobot/policies/pi0fast/processor_pi0fast.py b/src/lerobot/policies/pi0fast/processor_pi0fast.py index 135f6d383..eccfeb44f 100644 --- a/src/lerobot/policies/pi0fast/processor_pi0fast.py +++ b/src/lerobot/policies/pi0fast/processor_pi0fast.py @@ -16,6 +16,7 @@ import torch +from lerobot.constants import POSTPROCESSOR_DEFAULT_NAME, PREPROCESSOR_DEFAULT_NAME from lerobot.policies.pi0.configuration_pi0 import PI0Config from lerobot.processor import ( DeviceProcessor, @@ -46,6 +47,6 @@ def make_pi0fast_processor( features=config.output_features, norm_map=config.normalization_mapping, stats=dataset_stats ), ] - return RobotProcessor(steps=input_steps, name="robot_preprocessor"), RobotProcessor( - steps=output_steps, name="robot_postprocessor" + return RobotProcessor(steps=input_steps, name=PREPROCESSOR_DEFAULT_NAME), RobotProcessor( + steps=output_steps, name=POSTPROCESSOR_DEFAULT_NAME ) diff --git a/src/lerobot/policies/sac/processor_sac.py b/src/lerobot/policies/sac/processor_sac.py index 54bbd7f3b..14a976bbe 100644 --- a/src/lerobot/policies/sac/processor_sac.py +++ b/src/lerobot/policies/sac/processor_sac.py @@ -17,6 +17,7 @@ import torch +from lerobot.constants import POSTPROCESSOR_DEFAULT_NAME, PREPROCESSOR_DEFAULT_NAME from lerobot.policies.sac.configuration_sac import SACConfig from lerobot.processor import ( DeviceProcessor, @@ -47,6 +48,6 @@ def make_sac_processor( features=config.output_features, norm_map=config.normalization_mapping, stats=dataset_stats ), ] - return RobotProcessor(steps=input_steps, name="robot_preprocessor"), RobotProcessor( - steps=output_steps, name="robot_postprocessor" + return RobotProcessor(steps=input_steps, name=PREPROCESSOR_DEFAULT_NAME), RobotProcessor( + steps=output_steps, name=POSTPROCESSOR_DEFAULT_NAME ) diff --git a/src/lerobot/policies/smolvla/processor_smolvla.py b/src/lerobot/policies/smolvla/processor_smolvla.py index 2c0221f9e..231f2969e 100644 --- a/src/lerobot/policies/smolvla/processor_smolvla.py +++ b/src/lerobot/policies/smolvla/processor_smolvla.py @@ -18,6 +18,7 @@ from typing import Any import torch from lerobot.configs.types import PolicyFeature +from lerobot.constants import POSTPROCESSOR_DEFAULT_NAME, PREPROCESSOR_DEFAULT_NAME from lerobot.policies.smolvla.configuration_smolvla import SmolVLAConfig from lerobot.processor import ( DeviceProcessor, @@ -57,8 +58,8 @@ def make_smolvla_processor( features=config.output_features, norm_map=config.normalization_mapping, stats=dataset_stats ), ] - return RobotProcessor(steps=input_steps, name="robot_preprocessor"), RobotProcessor( - steps=output_steps, name="robot_postprocessor" + return RobotProcessor(steps=input_steps, name=PREPROCESSOR_DEFAULT_NAME), RobotProcessor( + steps=output_steps, name=POSTPROCESSOR_DEFAULT_NAME ) diff --git a/src/lerobot/policies/tdmpc/processor_tdmpc.py b/src/lerobot/policies/tdmpc/processor_tdmpc.py index 157a12733..553aa1d04 100644 --- a/src/lerobot/policies/tdmpc/processor_tdmpc.py +++ b/src/lerobot/policies/tdmpc/processor_tdmpc.py @@ -16,6 +16,7 @@ # limitations under the License. import torch +from lerobot.constants import POSTPROCESSOR_DEFAULT_NAME, PREPROCESSOR_DEFAULT_NAME from lerobot.policies.tdmpc.configuration_tdmpc import TDMPCConfig from lerobot.processor import ( DeviceProcessor, @@ -46,6 +47,6 @@ def make_tdmpc_processor( features=config.output_features, norm_map=config.normalization_mapping, stats=dataset_stats ), ] - return RobotProcessor(steps=input_steps, name="robot_preprocessor"), RobotProcessor( - steps=output_steps, name="robot_postprocessor" + return RobotProcessor(steps=input_steps, name=PREPROCESSOR_DEFAULT_NAME), RobotProcessor( + steps=output_steps, name=POSTPROCESSOR_DEFAULT_NAME ) diff --git a/src/lerobot/policies/vqbet/processor_vqbet.py b/src/lerobot/policies/vqbet/processor_vqbet.py index e3e95bd09..cfcd31ddd 100644 --- a/src/lerobot/policies/vqbet/processor_vqbet.py +++ b/src/lerobot/policies/vqbet/processor_vqbet.py @@ -17,6 +17,7 @@ # limitations under the License. import torch +from lerobot.constants import POSTPROCESSOR_DEFAULT_NAME, PREPROCESSOR_DEFAULT_NAME from lerobot.policies.vqbet.configuration_vqbet import VQBeTConfig from lerobot.processor import ( DeviceProcessor, @@ -47,6 +48,6 @@ def make_vqbet_processor( features=config.output_features, norm_map=config.normalization_mapping, stats=dataset_stats ), ] - return RobotProcessor(steps=input_steps, name="robot_preprocessor"), RobotProcessor( - steps=output_steps, name="robot_postprocessor" + return RobotProcessor(steps=input_steps, name=PREPROCESSOR_DEFAULT_NAME), RobotProcessor( + steps=output_steps, name=POSTPROCESSOR_DEFAULT_NAME ) diff --git a/src/lerobot/scripts/train.py b/src/lerobot/scripts/train.py index d0202246d..57eb0db60 100644 --- a/src/lerobot/scripts/train.py +++ b/src/lerobot/scripts/train.py @@ -26,6 +26,7 @@ from torch.optim import Optimizer from lerobot.configs import parser from lerobot.configs.train import TrainPipelineConfig +from lerobot.constants import POSTPROCESSOR_DEFAULT_NAME, PREPROCESSOR_DEFAULT_NAME from lerobot.datasets.factory import make_dataset from lerobot.datasets.sampler import EpisodeAwareSampler from lerobot.datasets.utils import cycle @@ -152,6 +153,10 @@ def train(cfg: TrainPipelineConfig): if cfg.resume: step, optimizer, lr_scheduler = load_training_state(cfg.checkpoint_path, optimizer, lr_scheduler) + preprocessor.from_pretrained(cfg.checkpoint_path, config_filename=f"{PREPROCESSOR_DEFAULT_NAME}.json") + postprocessor.from_pretrained( + cfg.checkpoint_path, config_filename=f"{POSTPROCESSOR_DEFAULT_NAME}.json" + ) num_learnable_params = sum(p.numel() for p in policy.parameters() if p.requires_grad) num_total_params = sum(p.numel() for p in policy.parameters()) @@ -240,7 +245,9 @@ def train(cfg: TrainPipelineConfig): if cfg.save_checkpoint and is_saving_step: logging.info(f"Checkpoint policy after step {step}") checkpoint_dir = get_step_checkpoint_dir(cfg.output_dir, cfg.steps, step) - save_checkpoint(checkpoint_dir, step, cfg, policy, optimizer, lr_scheduler, preprocessor) + save_checkpoint( + checkpoint_dir, step, cfg, policy, optimizer, lr_scheduler, preprocessor, postprocessor + ) update_last_checkpoint(checkpoint_dir) if wandb_logger: wandb_logger.log_policy(checkpoint_dir) @@ -284,10 +291,8 @@ def train(cfg: TrainPipelineConfig): if cfg.policy.push_to_hub: policy.push_model_to_hub(cfg) - if preprocessor: - preprocessor.push_to_hub(cfg.policy.repo_id) - if postprocessor: - postprocessor.push_to_hub(cfg.policy.repo_id) + preprocessor.push_to_hub(cfg.policy.repo_id) + postprocessor.push_to_hub(cfg.policy.repo_id) def main(): diff --git a/src/lerobot/utils/train_utils.py b/src/lerobot/utils/train_utils.py index 430323794..1067fa619 100644 --- a/src/lerobot/utils/train_utils.py +++ b/src/lerobot/utils/train_utils.py @@ -32,6 +32,7 @@ from lerobot.datasets.utils import load_json, write_json from lerobot.optim.optimizers import load_optimizer_state, save_optimizer_state from lerobot.optim.schedulers import load_scheduler_state, save_scheduler_state from lerobot.policies.pretrained import PreTrainedPolicy +from lerobot.processor.pipeline import RobotProcessor from lerobot.utils.random_utils import load_rng_state, save_rng_state @@ -74,7 +75,8 @@ def save_checkpoint( policy: PreTrainedPolicy, optimizer: Optimizer, scheduler: LRScheduler | None = None, - preprocessor=None, + preprocessor: RobotProcessor | None = None, + postprocessor: RobotProcessor | None = None, ) -> None: """This function creates the following directory structure: @@ -105,6 +107,8 @@ def save_checkpoint( cfg.save_pretrained(pretrained_dir) if preprocessor is not None: preprocessor.save_pretrained(pretrained_dir) + if postprocessor is not None: + postprocessor.save_pretrained(pretrained_dir) save_training_state(checkpoint_dir, step, optimizer, scheduler)