From 4b24f942253d4d0ec7dc1a65a4653109b468f0d3 Mon Sep 17 00:00:00 2001 From: AdilZouitine Date: Thu, 17 Jul 2025 13:10:15 +0200 Subject: [PATCH] feat(record): Integrate RobotProcessor into recording loop and update policy handling - Added support for RobotProcessor in the record_loop function to enhance data processing capabilities. - Updated the logic to reset both policy and processor when provided, ensuring proper state management. - Modified action prediction to utilize the processor, improving the overall functionality of the recording process. - Adjusted the save_checkpoint function to include preprocessor state saving, enhancing checkpointing capabilities. --- src/lerobot/configs/policies.py | 3 +-- src/lerobot/policies/factory.py | 6 ++++-- src/lerobot/record.py | 27 +++++++++++++++++++-------- src/lerobot/scripts/train.py | 2 +- src/lerobot/utils/control_utils.py | 4 ++++ src/lerobot/utils/train_utils.py | 8 +++++++- 6 files changed, 36 insertions(+), 14 deletions(-) diff --git a/src/lerobot/configs/policies.py b/src/lerobot/configs/policies.py index f5fa727cf..7532f0612 100644 --- a/src/lerobot/configs/policies.py +++ b/src/lerobot/configs/policies.py @@ -26,7 +26,7 @@ from huggingface_hub import hf_hub_download from huggingface_hub.constants import CONFIG_NAME from huggingface_hub.errors import HfHubHTTPError -from lerobot.configs.types import FeatureType, NormalizationMode, PolicyFeature +from lerobot.configs.types import FeatureType, PolicyFeature from lerobot.constants import ACTION, OBS_STATE from lerobot.optim.optimizers import OptimizerConfig from lerobot.optim.schedulers import LRSchedulerConfig @@ -53,7 +53,6 @@ class PreTrainedConfig(draccus.ChoiceRegistry, HubMixin, abc.ABC): """ n_obs_steps: int = 1 - normalization_mapping: dict[str, NormalizationMode] = field(default_factory=dict) input_features: dict[str, PolicyFeature] = field(default_factory=dict) output_features: dict[str, PolicyFeature] = field(default_factory=dict) diff --git a/src/lerobot/policies/factory.py b/src/lerobot/policies/factory.py index 9f250d72b..490658a09 100644 --- a/src/lerobot/policies/factory.py +++ b/src/lerobot/policies/factory.py @@ -35,7 +35,7 @@ from lerobot.policies.sac.reward_model.configuration_classifier import RewardCla from lerobot.policies.smolvla.configuration_smolvla import SmolVLAConfig from lerobot.policies.tdmpc.configuration_tdmpc import TDMPCConfig from lerobot.policies.vqbet.configuration_vqbet import VQBeTConfig -from lerobot.processor.pipeline import EnvTransition, RobotProcessor, TransitionIndex +from lerobot.processor.pipeline import EnvTransition, IdentityProcessor, RobotProcessor, TransitionIndex def get_policy_class(name: str) -> PreTrainedPolicy: @@ -128,7 +128,9 @@ def make_processor( if pretrained_path: # Load a pretrained processor # TODO(azouitine): Handle this case. - raise NotImplementedError("Loading a pretrained processor is not implemented.") + return RobotProcessor.from_pretrained(source=pretrained_path), RobotProcessor( + steps=[IdentityProcessor()], name="post_processor" + ) # Create a new processor based on policy type if policy_cfg.type == "tdmpc": diff --git a/src/lerobot/record.py b/src/lerobot/record.py index 575fcb94d..c1572a8dc 100644 --- a/src/lerobot/record.py +++ b/src/lerobot/record.py @@ -74,8 +74,9 @@ from lerobot.datasets.image_writer import safe_stop_image_writer from lerobot.datasets.lerobot_dataset import LeRobotDataset from lerobot.datasets.utils import build_dataset_frame, hw_to_dataset_features from lerobot.datasets.video_utils import VideoEncodingManager -from lerobot.policies.factory import make_policy +from lerobot.policies.factory import make_policy, make_processor from lerobot.policies.pretrained import PreTrainedPolicy +from lerobot.processor import RobotProcessor from lerobot.robots import ( # noqa: F401 Robot, RobotConfig, @@ -195,6 +196,7 @@ def record_loop( dataset: LeRobotDataset | None = None, teleop: Teleoperator | list[Teleoperator] | None = None, policy: PreTrainedPolicy | None = None, + processor: RobotProcessor | None = None, control_time_s: int | None = None, single_task: str | None = None, display_data: bool = False, @@ -219,9 +221,10 @@ def record_loop( "For multi-teleop, the list must contain exactly one KeyboardTeleop and one arm teleoperator. Currently only supported for LeKiwi robot." ) - # if policy is given it needs cleaning up - if policy is not None: + # Reset policy and processor if they are provided + if policy is not None or processor is not None: policy.reset() + processor.reset() timestamp = 0 start_episode_t = time.perf_counter() @@ -237,12 +240,13 @@ def record_loop( if policy is not None or dataset is not None: observation_frame = build_dataset_frame(dataset.features, observation, prefix="observation") - if policy is not None: + if policy is not None or processor is not None: action_values = predict_action( - observation_frame, - policy, - get_safe_torch_device(policy.config.device), - policy.config.use_amp, + observation=observation_frame, + policy=policy, + device=get_safe_torch_device(policy.config.device), + processor=processor, + use_amp=policy.config.use_amp, task=single_task, robot_type=robot.robot_type, ) @@ -328,6 +332,13 @@ def record(cfg: RecordConfig) -> LeRobotDataset: # Load pretrained policy policy = None if cfg.policy is None else make_policy(cfg.policy, ds_meta=dataset.meta) + processor = None + if cfg.policy is not None: + processor, _ = make_processor( + policy_cfg=cfg.policy, + pretrained_path=cfg.policy.pretrained_path, + dataset_stats=dataset.meta.stats, + ) robot.connect() if teleop is not None: diff --git a/src/lerobot/scripts/train.py b/src/lerobot/scripts/train.py index 5a69b4cad..e980595c1 100644 --- a/src/lerobot/scripts/train.py +++ b/src/lerobot/scripts/train.py @@ -244,7 +244,7 @@ def train(cfg: TrainPipelineConfig): if cfg.save_checkpoint and is_saving_step: logging.info(f"Checkpoint policy after step {step}") checkpoint_dir = get_step_checkpoint_dir(cfg.output_dir, cfg.steps, step) - save_checkpoint(checkpoint_dir, step, cfg, policy, optimizer, lr_scheduler) + save_checkpoint(checkpoint_dir, step, cfg, policy, optimizer, lr_scheduler, preprocessor) update_last_checkpoint(checkpoint_dir) if wandb_logger: wandb_logger.log_policy(checkpoint_dir) diff --git a/src/lerobot/utils/control_utils.py b/src/lerobot/utils/control_utils.py index 4bcc241da..bf811b6e3 100644 --- a/src/lerobot/utils/control_utils.py +++ b/src/lerobot/utils/control_utils.py @@ -31,6 +31,7 @@ from termcolor import colored from lerobot.datasets.lerobot_dataset import LeRobotDataset from lerobot.datasets.utils import DEFAULT_FEATURES from lerobot.policies.pretrained import PreTrainedPolicy +from lerobot.processor import RobotProcessor from lerobot.robots import Robot @@ -101,6 +102,7 @@ def predict_action( observation: dict[str, np.ndarray], policy: PreTrainedPolicy, device: torch.device, + processor: RobotProcessor, use_amp: bool, task: str | None = None, robot_type: str | None = None, @@ -122,6 +124,8 @@ def predict_action( observation["task"] = task if task else "" observation["robot_type"] = robot_type if robot_type else "" + observation = processor(observation) + # Compute the next action with the policy # based on the current observation action = policy.select_action(observation) diff --git a/src/lerobot/utils/train_utils.py b/src/lerobot/utils/train_utils.py index 2859fe057..430323794 100644 --- a/src/lerobot/utils/train_utils.py +++ b/src/lerobot/utils/train_utils.py @@ -74,6 +74,7 @@ def save_checkpoint( policy: PreTrainedPolicy, optimizer: Optimizer, scheduler: LRScheduler | None = None, + preprocessor=None, ) -> None: """This function creates the following directory structure: @@ -81,7 +82,9 @@ def save_checkpoint( ├── pretrained_model/ │ ├── config.json # policy config │ ├── model.safetensors # policy weights - │ └── train_config.json # train config + │ ├── train_config.json # train config + │ ├── processor.json # processor config (if preprocessor provided) + │ └── step_*.safetensors # processor state files (if any) └── training_state/ ├── optimizer_param_groups.json # optimizer param groups ├── optimizer_state.safetensors # optimizer state @@ -95,10 +98,13 @@ def save_checkpoint( policy (PreTrainedPolicy): The policy to save. optimizer (Optimizer | None, optional): The optimizer to save the state from. Defaults to None. scheduler (LRScheduler | None, optional): The scheduler to save the state from. Defaults to None. + preprocessor: The preprocessor/pipeline to save. Defaults to None. """ pretrained_dir = checkpoint_dir / PRETRAINED_MODEL_DIR policy.save_pretrained(pretrained_dir) cfg.save_pretrained(pretrained_dir) + if preprocessor is not None: + preprocessor.save_pretrained(pretrained_dir) save_training_state(checkpoint_dir, step, optimizer, scheduler)