From 1ac6a6d3fe62e63b516266f87c8a9909261710e1 Mon Sep 17 00:00:00 2001 From: Adil Zouitine Date: Thu, 4 Sep 2025 17:01:53 +0200 Subject: [PATCH] refactor(constants): rename preprocessor and postprocessor constants for clarity (#1868) - Updated constant names from PREPROCESSOR_DEFAULT_NAME and POSTPROCESSOR_DEFAULT_NAME to POLICY_PREPROCESSOR_DEFAULT_NAME and POLICY_POSTPROCESSOR_DEFAULT_NAME for better context. - Adjusted references across multiple files to use the new constant names, ensuring consistency in the codebase. --- src/lerobot/constants.py | 4 ++-- src/lerobot/policies/act/processor_act.py | 6 +++--- src/lerobot/policies/diffusion/processor_diffusion.py | 6 +++--- src/lerobot/policies/factory.py | 9 +++++++-- src/lerobot/policies/pi0/processor_pi0.py | 6 +++--- src/lerobot/policies/pi0fast/processor_pi0fast.py | 6 +++--- src/lerobot/policies/sac/processor_sac.py | 6 +++--- src/lerobot/policies/smolvla/processor_smolvla.py | 6 +++--- src/lerobot/policies/tdmpc/processor_tdmpc.py | 6 +++--- src/lerobot/policies/vqbet/processor_vqbet.py | 6 +++--- src/lerobot/scripts/train.py | 8 +++++--- 11 files changed, 38 insertions(+), 31 deletions(-) diff --git a/src/lerobot/constants.py b/src/lerobot/constants.py index 683c5ff0e..fad5c0b77 100644 --- a/src/lerobot/constants.py +++ b/src/lerobot/constants.py @@ -45,8 +45,8 @@ OPTIMIZER_STATE = "optimizer_state.safetensors" OPTIMIZER_PARAM_GROUPS = "optimizer_param_groups.json" SCHEDULER_STATE = "scheduler_state.json" -PREPROCESSOR_DEFAULT_NAME = "robot_preprocessor" -POSTPROCESSOR_DEFAULT_NAME = "robot_postprocessor" +POLICY_PREPROCESSOR_DEFAULT_NAME = "policy_preprocessor" +POLICY_POSTPROCESSOR_DEFAULT_NAME = "policy_postprocessor" if "LEROBOT_HOME" in os.environ: raise ValueError( diff --git a/src/lerobot/policies/act/processor_act.py b/src/lerobot/policies/act/processor_act.py index 3bf208184..698740ce8 100644 --- a/src/lerobot/policies/act/processor_act.py +++ b/src/lerobot/policies/act/processor_act.py @@ -15,7 +15,7 @@ # limitations under the License. import torch -from lerobot.constants import POSTPROCESSOR_DEFAULT_NAME, PREPROCESSOR_DEFAULT_NAME +from lerobot.constants import POLICY_POSTPROCESSOR_DEFAULT_NAME, POLICY_PREPROCESSOR_DEFAULT_NAME from lerobot.policies.act.configuration_act import ACTConfig from lerobot.processor import ( AddBatchDimensionProcessorStep, @@ -59,12 +59,12 @@ def make_act_pre_post_processors( return ( PolicyProcessorPipeline( steps=input_steps, - name=PREPROCESSOR_DEFAULT_NAME, + name=POLICY_PREPROCESSOR_DEFAULT_NAME, **preprocessor_kwargs, ), PolicyProcessorPipeline( steps=output_steps, - name=POSTPROCESSOR_DEFAULT_NAME, + name=POLICY_POSTPROCESSOR_DEFAULT_NAME, **postprocessor_kwargs, ), ) diff --git a/src/lerobot/policies/diffusion/processor_diffusion.py b/src/lerobot/policies/diffusion/processor_diffusion.py index 19ea30b79..9914cd0c1 100644 --- a/src/lerobot/policies/diffusion/processor_diffusion.py +++ b/src/lerobot/policies/diffusion/processor_diffusion.py @@ -16,7 +16,7 @@ # limitations under the License. import torch -from lerobot.constants import POSTPROCESSOR_DEFAULT_NAME, PREPROCESSOR_DEFAULT_NAME +from lerobot.constants import POLICY_POSTPROCESSOR_DEFAULT_NAME, POLICY_PREPROCESSOR_DEFAULT_NAME from lerobot.policies.diffusion.configuration_diffusion import DiffusionConfig from lerobot.processor import ( AddBatchDimensionProcessorStep, @@ -59,12 +59,12 @@ def make_diffusion_pre_post_processors( return ( PolicyProcessorPipeline( steps=input_steps, - name=PREPROCESSOR_DEFAULT_NAME, + name=POLICY_PREPROCESSOR_DEFAULT_NAME, **preprocessor_kwargs, ), PolicyProcessorPipeline( steps=output_steps, - name=POSTPROCESSOR_DEFAULT_NAME, + name=POLICY_POSTPROCESSOR_DEFAULT_NAME, **postprocessor_kwargs, ), ) diff --git a/src/lerobot/policies/factory.py b/src/lerobot/policies/factory.py index 125b526ee..c251210b3 100644 --- a/src/lerobot/policies/factory.py +++ b/src/lerobot/policies/factory.py @@ -24,6 +24,7 @@ from typing_extensions import Unpack from lerobot.configs.policies import PreTrainedConfig from lerobot.configs.types import FeatureType +from lerobot.constants import POLICY_POSTPROCESSOR_DEFAULT_NAME, POLICY_PREPROCESSOR_DEFAULT_NAME from lerobot.datasets.lerobot_dataset import LeRobotDatasetMetadata from lerobot.datasets.utils import dataset_to_policy_features from lerobot.envs.configs import EnvConfig @@ -148,14 +149,18 @@ def make_pre_post_processors( return ( PolicyProcessorPipeline.from_pretrained( pretrained_model_name_or_path=pretrained_path, - config_filename=kwargs.get("preprocessor_config_filename", "robot_preprocessor.json"), + config_filename=kwargs.get( + "preprocessor_config_filename", f"{POLICY_PREPROCESSOR_DEFAULT_NAME}.json" + ), overrides=kwargs.get("preprocessor_overrides", {}), to_transition=preprocessor_kwargs.get("to_transition"), to_output=preprocessor_kwargs.get("to_output"), ), PolicyProcessorPipeline.from_pretrained( pretrained_model_name_or_path=pretrained_path, - config_filename=kwargs.get("postprocessor_config_filename", "robot_postprocessor.json"), + config_filename=kwargs.get( + "postprocessor_config_filename", f"{POLICY_POSTPROCESSOR_DEFAULT_NAME}.json" + ), overrides=kwargs.get("postprocessor_overrides", {}), to_transition=postprocessor_kwargs.get("to_transition"), to_output=postprocessor_kwargs.get("to_output"), diff --git a/src/lerobot/policies/pi0/processor_pi0.py b/src/lerobot/policies/pi0/processor_pi0.py index d2a675c74..86cd76f84 100644 --- a/src/lerobot/policies/pi0/processor_pi0.py +++ b/src/lerobot/policies/pi0/processor_pi0.py @@ -18,7 +18,7 @@ import torch from lerobot.configs.types import PolicyFeature -from lerobot.constants import POSTPROCESSOR_DEFAULT_NAME, PREPROCESSOR_DEFAULT_NAME +from lerobot.constants import POLICY_POSTPROCESSOR_DEFAULT_NAME, POLICY_PREPROCESSOR_DEFAULT_NAME from lerobot.policies.pi0.configuration_pi0 import PI0Config from lerobot.processor import ( AddBatchDimensionProcessorStep, @@ -107,12 +107,12 @@ def make_pi0_pre_post_processors( return ( PolicyProcessorPipeline( steps=input_steps, - name=PREPROCESSOR_DEFAULT_NAME, + name=POLICY_PREPROCESSOR_DEFAULT_NAME, **preprocessor_kwargs, ), PolicyProcessorPipeline( steps=output_steps, - name=POSTPROCESSOR_DEFAULT_NAME, + name=POLICY_POSTPROCESSOR_DEFAULT_NAME, **postprocessor_kwargs, ), ) diff --git a/src/lerobot/policies/pi0fast/processor_pi0fast.py b/src/lerobot/policies/pi0fast/processor_pi0fast.py index cc5eb4fdc..c815c8379 100644 --- a/src/lerobot/policies/pi0fast/processor_pi0fast.py +++ b/src/lerobot/policies/pi0fast/processor_pi0fast.py @@ -16,7 +16,7 @@ import torch -from lerobot.constants import POSTPROCESSOR_DEFAULT_NAME, PREPROCESSOR_DEFAULT_NAME +from lerobot.constants import POLICY_POSTPROCESSOR_DEFAULT_NAME, POLICY_PREPROCESSOR_DEFAULT_NAME from lerobot.policies.pi0.configuration_pi0 import PI0Config from lerobot.processor import ( AddBatchDimensionProcessorStep, @@ -59,12 +59,12 @@ def make_pi0fast_pre_post_processors( return ( PolicyProcessorPipeline( steps=input_steps, - name=PREPROCESSOR_DEFAULT_NAME, + name=POLICY_PREPROCESSOR_DEFAULT_NAME, **preprocessor_kwargs, ), PolicyProcessorPipeline( steps=output_steps, - name=POSTPROCESSOR_DEFAULT_NAME, + name=POLICY_POSTPROCESSOR_DEFAULT_NAME, **postprocessor_kwargs, ), ) diff --git a/src/lerobot/policies/sac/processor_sac.py b/src/lerobot/policies/sac/processor_sac.py index 14dba9151..4f0f8c5a3 100644 --- a/src/lerobot/policies/sac/processor_sac.py +++ b/src/lerobot/policies/sac/processor_sac.py @@ -17,7 +17,7 @@ import torch -from lerobot.constants import POSTPROCESSOR_DEFAULT_NAME, PREPROCESSOR_DEFAULT_NAME +from lerobot.constants import POLICY_POSTPROCESSOR_DEFAULT_NAME, POLICY_PREPROCESSOR_DEFAULT_NAME from lerobot.policies.sac.configuration_sac import SACConfig from lerobot.processor import ( AddBatchDimensionProcessorStep, @@ -60,12 +60,12 @@ def make_sac_pre_post_processors( return ( PolicyProcessorPipeline( steps=input_steps, - name=PREPROCESSOR_DEFAULT_NAME, + name=POLICY_PREPROCESSOR_DEFAULT_NAME, **preprocessor_kwargs, ), PolicyProcessorPipeline( steps=output_steps, - name=POSTPROCESSOR_DEFAULT_NAME, + name=POLICY_POSTPROCESSOR_DEFAULT_NAME, **postprocessor_kwargs, ), ) diff --git a/src/lerobot/policies/smolvla/processor_smolvla.py b/src/lerobot/policies/smolvla/processor_smolvla.py index 9a0fb067f..2123efb50 100644 --- a/src/lerobot/policies/smolvla/processor_smolvla.py +++ b/src/lerobot/policies/smolvla/processor_smolvla.py @@ -17,7 +17,7 @@ import torch from lerobot.configs.types import PolicyFeature -from lerobot.constants import POSTPROCESSOR_DEFAULT_NAME, PREPROCESSOR_DEFAULT_NAME +from lerobot.constants import POLICY_POSTPROCESSOR_DEFAULT_NAME, POLICY_PREPROCESSOR_DEFAULT_NAME from lerobot.policies.smolvla.configuration_smolvla import SmolVLAConfig from lerobot.processor import ( AddBatchDimensionProcessorStep, @@ -70,12 +70,12 @@ def make_smolvla_pre_post_processors( return ( PolicyProcessorPipeline( steps=input_steps, - name=PREPROCESSOR_DEFAULT_NAME, + name=POLICY_PREPROCESSOR_DEFAULT_NAME, **preprocessor_kwargs, ), PolicyProcessorPipeline( steps=output_steps, - name=POSTPROCESSOR_DEFAULT_NAME, + name=POLICY_POSTPROCESSOR_DEFAULT_NAME, **postprocessor_kwargs, ), ) diff --git a/src/lerobot/policies/tdmpc/processor_tdmpc.py b/src/lerobot/policies/tdmpc/processor_tdmpc.py index 3085db9cf..28a66fc3e 100644 --- a/src/lerobot/policies/tdmpc/processor_tdmpc.py +++ b/src/lerobot/policies/tdmpc/processor_tdmpc.py @@ -16,7 +16,7 @@ # limitations under the License. import torch -from lerobot.constants import POSTPROCESSOR_DEFAULT_NAME, PREPROCESSOR_DEFAULT_NAME +from lerobot.constants import POLICY_POSTPROCESSOR_DEFAULT_NAME, POLICY_PREPROCESSOR_DEFAULT_NAME from lerobot.policies.tdmpc.configuration_tdmpc import TDMPCConfig from lerobot.processor import ( AddBatchDimensionProcessorStep, @@ -59,12 +59,12 @@ def make_tdmpc_pre_post_processors( return ( PolicyProcessorPipeline( steps=input_steps, - name=PREPROCESSOR_DEFAULT_NAME, + name=POLICY_PREPROCESSOR_DEFAULT_NAME, **preprocessor_kwargs, ), PolicyProcessorPipeline( steps=output_steps, - name=POSTPROCESSOR_DEFAULT_NAME, + name=POLICY_POSTPROCESSOR_DEFAULT_NAME, **postprocessor_kwargs, ), ) diff --git a/src/lerobot/policies/vqbet/processor_vqbet.py b/src/lerobot/policies/vqbet/processor_vqbet.py index ea60cff54..7743b6ee0 100644 --- a/src/lerobot/policies/vqbet/processor_vqbet.py +++ b/src/lerobot/policies/vqbet/processor_vqbet.py @@ -17,7 +17,7 @@ # limitations under the License. import torch -from lerobot.constants import POSTPROCESSOR_DEFAULT_NAME, PREPROCESSOR_DEFAULT_NAME +from lerobot.constants import POLICY_POSTPROCESSOR_DEFAULT_NAME, POLICY_PREPROCESSOR_DEFAULT_NAME from lerobot.policies.vqbet.configuration_vqbet import VQBeTConfig from lerobot.processor import ( AddBatchDimensionProcessorStep, @@ -60,12 +60,12 @@ def make_vqbet_pre_post_processors( return ( PolicyProcessorPipeline( steps=input_steps, - name=PREPROCESSOR_DEFAULT_NAME, + name=POLICY_PREPROCESSOR_DEFAULT_NAME, **preprocessor_kwargs, ), PolicyProcessorPipeline( steps=output_steps, - name=POSTPROCESSOR_DEFAULT_NAME, + name=POLICY_POSTPROCESSOR_DEFAULT_NAME, **postprocessor_kwargs, ), ) diff --git a/src/lerobot/scripts/train.py b/src/lerobot/scripts/train.py index 68361fe14..d0240b427 100644 --- a/src/lerobot/scripts/train.py +++ b/src/lerobot/scripts/train.py @@ -26,7 +26,7 @@ from torch.optim import Optimizer from lerobot.configs import parser from lerobot.configs.train import TrainPipelineConfig -from lerobot.constants import POSTPROCESSOR_DEFAULT_NAME, PREPROCESSOR_DEFAULT_NAME +from lerobot.constants import POLICY_POSTPROCESSOR_DEFAULT_NAME, POLICY_PREPROCESSOR_DEFAULT_NAME from lerobot.datasets.factory import make_dataset from lerobot.datasets.sampler import EpisodeAwareSampler from lerobot.datasets.utils import cycle @@ -153,9 +153,11 @@ def train(cfg: TrainPipelineConfig): if cfg.resume: step, optimizer, lr_scheduler = load_training_state(cfg.checkpoint_path, optimizer, lr_scheduler) - preprocessor.from_pretrained(cfg.checkpoint_path, config_filename=f"{PREPROCESSOR_DEFAULT_NAME}.json") + preprocessor.from_pretrained( + cfg.policy.pretrained_path, config_filename=f"{POLICY_PREPROCESSOR_DEFAULT_NAME}.json" + ) postprocessor.from_pretrained( - cfg.checkpoint_path, config_filename=f"{POSTPROCESSOR_DEFAULT_NAME}.json" + cfg.policy.pretrained_path, config_filename=f"{POLICY_POSTPROCESSOR_DEFAULT_NAME}.json" ) num_learnable_params = sum(p.numel() for p in policy.parameters() if p.requires_grad)