mirror of
https://github.com/huggingface/lerobot.git
synced 2026-05-21 11:39:50 +00:00
refactor(constants): rename preprocessor and postprocessor constants for clarity (#1868)
- Updated constant names from PREPROCESSOR_DEFAULT_NAME and POSTPROCESSOR_DEFAULT_NAME to POLICY_PREPROCESSOR_DEFAULT_NAME and POLICY_POSTPROCESSOR_DEFAULT_NAME for better context. - Adjusted references across multiple files to use the new constant names, ensuring consistency in the codebase.
This commit is contained in:
@@ -45,8 +45,8 @@ OPTIMIZER_STATE = "optimizer_state.safetensors"
|
|||||||
OPTIMIZER_PARAM_GROUPS = "optimizer_param_groups.json"
|
OPTIMIZER_PARAM_GROUPS = "optimizer_param_groups.json"
|
||||||
SCHEDULER_STATE = "scheduler_state.json"
|
SCHEDULER_STATE = "scheduler_state.json"
|
||||||
|
|
||||||
PREPROCESSOR_DEFAULT_NAME = "robot_preprocessor"
|
POLICY_PREPROCESSOR_DEFAULT_NAME = "policy_preprocessor"
|
||||||
POSTPROCESSOR_DEFAULT_NAME = "robot_postprocessor"
|
POLICY_POSTPROCESSOR_DEFAULT_NAME = "policy_postprocessor"
|
||||||
|
|
||||||
if "LEROBOT_HOME" in os.environ:
|
if "LEROBOT_HOME" in os.environ:
|
||||||
raise ValueError(
|
raise ValueError(
|
||||||
|
|||||||
@@ -15,7 +15,7 @@
|
|||||||
# limitations under the License.
|
# limitations under the License.
|
||||||
import torch
|
import torch
|
||||||
|
|
||||||
from lerobot.constants import POSTPROCESSOR_DEFAULT_NAME, PREPROCESSOR_DEFAULT_NAME
|
from lerobot.constants import POLICY_POSTPROCESSOR_DEFAULT_NAME, POLICY_PREPROCESSOR_DEFAULT_NAME
|
||||||
from lerobot.policies.act.configuration_act import ACTConfig
|
from lerobot.policies.act.configuration_act import ACTConfig
|
||||||
from lerobot.processor import (
|
from lerobot.processor import (
|
||||||
AddBatchDimensionProcessorStep,
|
AddBatchDimensionProcessorStep,
|
||||||
@@ -59,12 +59,12 @@ def make_act_pre_post_processors(
|
|||||||
return (
|
return (
|
||||||
PolicyProcessorPipeline(
|
PolicyProcessorPipeline(
|
||||||
steps=input_steps,
|
steps=input_steps,
|
||||||
name=PREPROCESSOR_DEFAULT_NAME,
|
name=POLICY_PREPROCESSOR_DEFAULT_NAME,
|
||||||
**preprocessor_kwargs,
|
**preprocessor_kwargs,
|
||||||
),
|
),
|
||||||
PolicyProcessorPipeline(
|
PolicyProcessorPipeline(
|
||||||
steps=output_steps,
|
steps=output_steps,
|
||||||
name=POSTPROCESSOR_DEFAULT_NAME,
|
name=POLICY_POSTPROCESSOR_DEFAULT_NAME,
|
||||||
**postprocessor_kwargs,
|
**postprocessor_kwargs,
|
||||||
),
|
),
|
||||||
)
|
)
|
||||||
|
|||||||
@@ -16,7 +16,7 @@
|
|||||||
# limitations under the License.
|
# limitations under the License.
|
||||||
import torch
|
import torch
|
||||||
|
|
||||||
from lerobot.constants import POSTPROCESSOR_DEFAULT_NAME, PREPROCESSOR_DEFAULT_NAME
|
from lerobot.constants import POLICY_POSTPROCESSOR_DEFAULT_NAME, POLICY_PREPROCESSOR_DEFAULT_NAME
|
||||||
from lerobot.policies.diffusion.configuration_diffusion import DiffusionConfig
|
from lerobot.policies.diffusion.configuration_diffusion import DiffusionConfig
|
||||||
from lerobot.processor import (
|
from lerobot.processor import (
|
||||||
AddBatchDimensionProcessorStep,
|
AddBatchDimensionProcessorStep,
|
||||||
@@ -59,12 +59,12 @@ def make_diffusion_pre_post_processors(
|
|||||||
return (
|
return (
|
||||||
PolicyProcessorPipeline(
|
PolicyProcessorPipeline(
|
||||||
steps=input_steps,
|
steps=input_steps,
|
||||||
name=PREPROCESSOR_DEFAULT_NAME,
|
name=POLICY_PREPROCESSOR_DEFAULT_NAME,
|
||||||
**preprocessor_kwargs,
|
**preprocessor_kwargs,
|
||||||
),
|
),
|
||||||
PolicyProcessorPipeline(
|
PolicyProcessorPipeline(
|
||||||
steps=output_steps,
|
steps=output_steps,
|
||||||
name=POSTPROCESSOR_DEFAULT_NAME,
|
name=POLICY_POSTPROCESSOR_DEFAULT_NAME,
|
||||||
**postprocessor_kwargs,
|
**postprocessor_kwargs,
|
||||||
),
|
),
|
||||||
)
|
)
|
||||||
|
|||||||
@@ -24,6 +24,7 @@ from typing_extensions import Unpack
|
|||||||
|
|
||||||
from lerobot.configs.policies import PreTrainedConfig
|
from lerobot.configs.policies import PreTrainedConfig
|
||||||
from lerobot.configs.types import FeatureType
|
from lerobot.configs.types import FeatureType
|
||||||
|
from lerobot.constants import POLICY_POSTPROCESSOR_DEFAULT_NAME, POLICY_PREPROCESSOR_DEFAULT_NAME
|
||||||
from lerobot.datasets.lerobot_dataset import LeRobotDatasetMetadata
|
from lerobot.datasets.lerobot_dataset import LeRobotDatasetMetadata
|
||||||
from lerobot.datasets.utils import dataset_to_policy_features
|
from lerobot.datasets.utils import dataset_to_policy_features
|
||||||
from lerobot.envs.configs import EnvConfig
|
from lerobot.envs.configs import EnvConfig
|
||||||
@@ -148,14 +149,18 @@ def make_pre_post_processors(
|
|||||||
return (
|
return (
|
||||||
PolicyProcessorPipeline.from_pretrained(
|
PolicyProcessorPipeline.from_pretrained(
|
||||||
pretrained_model_name_or_path=pretrained_path,
|
pretrained_model_name_or_path=pretrained_path,
|
||||||
config_filename=kwargs.get("preprocessor_config_filename", "robot_preprocessor.json"),
|
config_filename=kwargs.get(
|
||||||
|
"preprocessor_config_filename", f"{POLICY_PREPROCESSOR_DEFAULT_NAME}.json"
|
||||||
|
),
|
||||||
overrides=kwargs.get("preprocessor_overrides", {}),
|
overrides=kwargs.get("preprocessor_overrides", {}),
|
||||||
to_transition=preprocessor_kwargs.get("to_transition"),
|
to_transition=preprocessor_kwargs.get("to_transition"),
|
||||||
to_output=preprocessor_kwargs.get("to_output"),
|
to_output=preprocessor_kwargs.get("to_output"),
|
||||||
),
|
),
|
||||||
PolicyProcessorPipeline.from_pretrained(
|
PolicyProcessorPipeline.from_pretrained(
|
||||||
pretrained_model_name_or_path=pretrained_path,
|
pretrained_model_name_or_path=pretrained_path,
|
||||||
config_filename=kwargs.get("postprocessor_config_filename", "robot_postprocessor.json"),
|
config_filename=kwargs.get(
|
||||||
|
"postprocessor_config_filename", f"{POLICY_POSTPROCESSOR_DEFAULT_NAME}.json"
|
||||||
|
),
|
||||||
overrides=kwargs.get("postprocessor_overrides", {}),
|
overrides=kwargs.get("postprocessor_overrides", {}),
|
||||||
to_transition=postprocessor_kwargs.get("to_transition"),
|
to_transition=postprocessor_kwargs.get("to_transition"),
|
||||||
to_output=postprocessor_kwargs.get("to_output"),
|
to_output=postprocessor_kwargs.get("to_output"),
|
||||||
|
|||||||
@@ -18,7 +18,7 @@
|
|||||||
import torch
|
import torch
|
||||||
|
|
||||||
from lerobot.configs.types import PolicyFeature
|
from lerobot.configs.types import PolicyFeature
|
||||||
from lerobot.constants import POSTPROCESSOR_DEFAULT_NAME, PREPROCESSOR_DEFAULT_NAME
|
from lerobot.constants import POLICY_POSTPROCESSOR_DEFAULT_NAME, POLICY_PREPROCESSOR_DEFAULT_NAME
|
||||||
from lerobot.policies.pi0.configuration_pi0 import PI0Config
|
from lerobot.policies.pi0.configuration_pi0 import PI0Config
|
||||||
from lerobot.processor import (
|
from lerobot.processor import (
|
||||||
AddBatchDimensionProcessorStep,
|
AddBatchDimensionProcessorStep,
|
||||||
@@ -107,12 +107,12 @@ def make_pi0_pre_post_processors(
|
|||||||
return (
|
return (
|
||||||
PolicyProcessorPipeline(
|
PolicyProcessorPipeline(
|
||||||
steps=input_steps,
|
steps=input_steps,
|
||||||
name=PREPROCESSOR_DEFAULT_NAME,
|
name=POLICY_PREPROCESSOR_DEFAULT_NAME,
|
||||||
**preprocessor_kwargs,
|
**preprocessor_kwargs,
|
||||||
),
|
),
|
||||||
PolicyProcessorPipeline(
|
PolicyProcessorPipeline(
|
||||||
steps=output_steps,
|
steps=output_steps,
|
||||||
name=POSTPROCESSOR_DEFAULT_NAME,
|
name=POLICY_POSTPROCESSOR_DEFAULT_NAME,
|
||||||
**postprocessor_kwargs,
|
**postprocessor_kwargs,
|
||||||
),
|
),
|
||||||
)
|
)
|
||||||
|
|||||||
@@ -16,7 +16,7 @@
|
|||||||
|
|
||||||
import torch
|
import torch
|
||||||
|
|
||||||
from lerobot.constants import POSTPROCESSOR_DEFAULT_NAME, PREPROCESSOR_DEFAULT_NAME
|
from lerobot.constants import POLICY_POSTPROCESSOR_DEFAULT_NAME, POLICY_PREPROCESSOR_DEFAULT_NAME
|
||||||
from lerobot.policies.pi0.configuration_pi0 import PI0Config
|
from lerobot.policies.pi0.configuration_pi0 import PI0Config
|
||||||
from lerobot.processor import (
|
from lerobot.processor import (
|
||||||
AddBatchDimensionProcessorStep,
|
AddBatchDimensionProcessorStep,
|
||||||
@@ -59,12 +59,12 @@ def make_pi0fast_pre_post_processors(
|
|||||||
return (
|
return (
|
||||||
PolicyProcessorPipeline(
|
PolicyProcessorPipeline(
|
||||||
steps=input_steps,
|
steps=input_steps,
|
||||||
name=PREPROCESSOR_DEFAULT_NAME,
|
name=POLICY_PREPROCESSOR_DEFAULT_NAME,
|
||||||
**preprocessor_kwargs,
|
**preprocessor_kwargs,
|
||||||
),
|
),
|
||||||
PolicyProcessorPipeline(
|
PolicyProcessorPipeline(
|
||||||
steps=output_steps,
|
steps=output_steps,
|
||||||
name=POSTPROCESSOR_DEFAULT_NAME,
|
name=POLICY_POSTPROCESSOR_DEFAULT_NAME,
|
||||||
**postprocessor_kwargs,
|
**postprocessor_kwargs,
|
||||||
),
|
),
|
||||||
)
|
)
|
||||||
|
|||||||
@@ -17,7 +17,7 @@
|
|||||||
|
|
||||||
import torch
|
import torch
|
||||||
|
|
||||||
from lerobot.constants import POSTPROCESSOR_DEFAULT_NAME, PREPROCESSOR_DEFAULT_NAME
|
from lerobot.constants import POLICY_POSTPROCESSOR_DEFAULT_NAME, POLICY_PREPROCESSOR_DEFAULT_NAME
|
||||||
from lerobot.policies.sac.configuration_sac import SACConfig
|
from lerobot.policies.sac.configuration_sac import SACConfig
|
||||||
from lerobot.processor import (
|
from lerobot.processor import (
|
||||||
AddBatchDimensionProcessorStep,
|
AddBatchDimensionProcessorStep,
|
||||||
@@ -60,12 +60,12 @@ def make_sac_pre_post_processors(
|
|||||||
return (
|
return (
|
||||||
PolicyProcessorPipeline(
|
PolicyProcessorPipeline(
|
||||||
steps=input_steps,
|
steps=input_steps,
|
||||||
name=PREPROCESSOR_DEFAULT_NAME,
|
name=POLICY_PREPROCESSOR_DEFAULT_NAME,
|
||||||
**preprocessor_kwargs,
|
**preprocessor_kwargs,
|
||||||
),
|
),
|
||||||
PolicyProcessorPipeline(
|
PolicyProcessorPipeline(
|
||||||
steps=output_steps,
|
steps=output_steps,
|
||||||
name=POSTPROCESSOR_DEFAULT_NAME,
|
name=POLICY_POSTPROCESSOR_DEFAULT_NAME,
|
||||||
**postprocessor_kwargs,
|
**postprocessor_kwargs,
|
||||||
),
|
),
|
||||||
)
|
)
|
||||||
|
|||||||
@@ -17,7 +17,7 @@
|
|||||||
import torch
|
import torch
|
||||||
|
|
||||||
from lerobot.configs.types import PolicyFeature
|
from lerobot.configs.types import PolicyFeature
|
||||||
from lerobot.constants import POSTPROCESSOR_DEFAULT_NAME, PREPROCESSOR_DEFAULT_NAME
|
from lerobot.constants import POLICY_POSTPROCESSOR_DEFAULT_NAME, POLICY_PREPROCESSOR_DEFAULT_NAME
|
||||||
from lerobot.policies.smolvla.configuration_smolvla import SmolVLAConfig
|
from lerobot.policies.smolvla.configuration_smolvla import SmolVLAConfig
|
||||||
from lerobot.processor import (
|
from lerobot.processor import (
|
||||||
AddBatchDimensionProcessorStep,
|
AddBatchDimensionProcessorStep,
|
||||||
@@ -70,12 +70,12 @@ def make_smolvla_pre_post_processors(
|
|||||||
return (
|
return (
|
||||||
PolicyProcessorPipeline(
|
PolicyProcessorPipeline(
|
||||||
steps=input_steps,
|
steps=input_steps,
|
||||||
name=PREPROCESSOR_DEFAULT_NAME,
|
name=POLICY_PREPROCESSOR_DEFAULT_NAME,
|
||||||
**preprocessor_kwargs,
|
**preprocessor_kwargs,
|
||||||
),
|
),
|
||||||
PolicyProcessorPipeline(
|
PolicyProcessorPipeline(
|
||||||
steps=output_steps,
|
steps=output_steps,
|
||||||
name=POSTPROCESSOR_DEFAULT_NAME,
|
name=POLICY_POSTPROCESSOR_DEFAULT_NAME,
|
||||||
**postprocessor_kwargs,
|
**postprocessor_kwargs,
|
||||||
),
|
),
|
||||||
)
|
)
|
||||||
|
|||||||
@@ -16,7 +16,7 @@
|
|||||||
# limitations under the License.
|
# limitations under the License.
|
||||||
import torch
|
import torch
|
||||||
|
|
||||||
from lerobot.constants import POSTPROCESSOR_DEFAULT_NAME, PREPROCESSOR_DEFAULT_NAME
|
from lerobot.constants import POLICY_POSTPROCESSOR_DEFAULT_NAME, POLICY_PREPROCESSOR_DEFAULT_NAME
|
||||||
from lerobot.policies.tdmpc.configuration_tdmpc import TDMPCConfig
|
from lerobot.policies.tdmpc.configuration_tdmpc import TDMPCConfig
|
||||||
from lerobot.processor import (
|
from lerobot.processor import (
|
||||||
AddBatchDimensionProcessorStep,
|
AddBatchDimensionProcessorStep,
|
||||||
@@ -59,12 +59,12 @@ def make_tdmpc_pre_post_processors(
|
|||||||
return (
|
return (
|
||||||
PolicyProcessorPipeline(
|
PolicyProcessorPipeline(
|
||||||
steps=input_steps,
|
steps=input_steps,
|
||||||
name=PREPROCESSOR_DEFAULT_NAME,
|
name=POLICY_PREPROCESSOR_DEFAULT_NAME,
|
||||||
**preprocessor_kwargs,
|
**preprocessor_kwargs,
|
||||||
),
|
),
|
||||||
PolicyProcessorPipeline(
|
PolicyProcessorPipeline(
|
||||||
steps=output_steps,
|
steps=output_steps,
|
||||||
name=POSTPROCESSOR_DEFAULT_NAME,
|
name=POLICY_POSTPROCESSOR_DEFAULT_NAME,
|
||||||
**postprocessor_kwargs,
|
**postprocessor_kwargs,
|
||||||
),
|
),
|
||||||
)
|
)
|
||||||
|
|||||||
@@ -17,7 +17,7 @@
|
|||||||
# limitations under the License.
|
# limitations under the License.
|
||||||
import torch
|
import torch
|
||||||
|
|
||||||
from lerobot.constants import POSTPROCESSOR_DEFAULT_NAME, PREPROCESSOR_DEFAULT_NAME
|
from lerobot.constants import POLICY_POSTPROCESSOR_DEFAULT_NAME, POLICY_PREPROCESSOR_DEFAULT_NAME
|
||||||
from lerobot.policies.vqbet.configuration_vqbet import VQBeTConfig
|
from lerobot.policies.vqbet.configuration_vqbet import VQBeTConfig
|
||||||
from lerobot.processor import (
|
from lerobot.processor import (
|
||||||
AddBatchDimensionProcessorStep,
|
AddBatchDimensionProcessorStep,
|
||||||
@@ -60,12 +60,12 @@ def make_vqbet_pre_post_processors(
|
|||||||
return (
|
return (
|
||||||
PolicyProcessorPipeline(
|
PolicyProcessorPipeline(
|
||||||
steps=input_steps,
|
steps=input_steps,
|
||||||
name=PREPROCESSOR_DEFAULT_NAME,
|
name=POLICY_PREPROCESSOR_DEFAULT_NAME,
|
||||||
**preprocessor_kwargs,
|
**preprocessor_kwargs,
|
||||||
),
|
),
|
||||||
PolicyProcessorPipeline(
|
PolicyProcessorPipeline(
|
||||||
steps=output_steps,
|
steps=output_steps,
|
||||||
name=POSTPROCESSOR_DEFAULT_NAME,
|
name=POLICY_POSTPROCESSOR_DEFAULT_NAME,
|
||||||
**postprocessor_kwargs,
|
**postprocessor_kwargs,
|
||||||
),
|
),
|
||||||
)
|
)
|
||||||
|
|||||||
@@ -26,7 +26,7 @@ from torch.optim import Optimizer
|
|||||||
|
|
||||||
from lerobot.configs import parser
|
from lerobot.configs import parser
|
||||||
from lerobot.configs.train import TrainPipelineConfig
|
from lerobot.configs.train import TrainPipelineConfig
|
||||||
from lerobot.constants import POSTPROCESSOR_DEFAULT_NAME, PREPROCESSOR_DEFAULT_NAME
|
from lerobot.constants import POLICY_POSTPROCESSOR_DEFAULT_NAME, POLICY_PREPROCESSOR_DEFAULT_NAME
|
||||||
from lerobot.datasets.factory import make_dataset
|
from lerobot.datasets.factory import make_dataset
|
||||||
from lerobot.datasets.sampler import EpisodeAwareSampler
|
from lerobot.datasets.sampler import EpisodeAwareSampler
|
||||||
from lerobot.datasets.utils import cycle
|
from lerobot.datasets.utils import cycle
|
||||||
@@ -153,9 +153,11 @@ def train(cfg: TrainPipelineConfig):
|
|||||||
|
|
||||||
if cfg.resume:
|
if cfg.resume:
|
||||||
step, optimizer, lr_scheduler = load_training_state(cfg.checkpoint_path, optimizer, lr_scheduler)
|
step, optimizer, lr_scheduler = load_training_state(cfg.checkpoint_path, optimizer, lr_scheduler)
|
||||||
preprocessor.from_pretrained(cfg.checkpoint_path, config_filename=f"{PREPROCESSOR_DEFAULT_NAME}.json")
|
preprocessor.from_pretrained(
|
||||||
|
cfg.policy.pretrained_path, config_filename=f"{POLICY_PREPROCESSOR_DEFAULT_NAME}.json"
|
||||||
|
)
|
||||||
postprocessor.from_pretrained(
|
postprocessor.from_pretrained(
|
||||||
cfg.checkpoint_path, config_filename=f"{POSTPROCESSOR_DEFAULT_NAME}.json"
|
cfg.policy.pretrained_path, config_filename=f"{POLICY_POSTPROCESSOR_DEFAULT_NAME}.json"
|
||||||
)
|
)
|
||||||
|
|
||||||
num_learnable_params = sum(p.numel() for p in policy.parameters() if p.requires_grad)
|
num_learnable_params = sum(p.numel() for p in policy.parameters() if p.requires_grad)
|
||||||
|
|||||||
Reference in New Issue
Block a user