From c4763f61a1deb8b488e81189dc23a9570db85b12 Mon Sep 17 00:00:00 2001 From: Adil Zouitine Date: Thu, 24 Jul 2025 18:18:39 +0200 Subject: [PATCH] refactor(record): Rename processor parameters and update processing logic - Renamed `processor` to `preprocessor` and added `postprocessor` parameter for clarity. - Updated the `record_loop` and `predict_action` functions to utilize the new preprocessor and postprocessor, enhancing the processing flow. - Ensured compatibility with existing functionality while improving code readability. --- src/lerobot/record.py | 22 ++++++++++++++-------- src/lerobot/utils/control_utils.py | 9 ++++++--- 2 files changed, 20 insertions(+), 11 deletions(-) diff --git a/src/lerobot/record.py b/src/lerobot/record.py index c1572a8dc..31621ea7d 100644 --- a/src/lerobot/record.py +++ b/src/lerobot/record.py @@ -196,7 +196,8 @@ def record_loop( dataset: LeRobotDataset | None = None, teleop: Teleoperator | list[Teleoperator] | None = None, policy: PreTrainedPolicy | None = None, - processor: RobotProcessor | None = None, + preprocessor: RobotProcessor | None = None, + postprocessor: RobotProcessor | None = None, control_time_s: int | None = None, single_task: str | None = None, display_data: bool = False, @@ -222,9 +223,10 @@ def record_loop( ) # Reset policy and processor if they are provided - if policy is not None or processor is not None: + if policy is not None or preprocessor is not None: policy.reset() - processor.reset() + preprocessor.reset() + postprocessor.reset() timestamp = 0 start_episode_t = time.perf_counter() @@ -240,12 +242,13 @@ def record_loop( if policy is not None or dataset is not None: observation_frame = build_dataset_frame(dataset.features, observation, prefix="observation") - if policy is not None or processor is not None: + if policy is not None or preprocessor is not None: action_values = predict_action( observation=observation_frame, policy=policy, device=get_safe_torch_device(policy.config.device), - processor=processor, + preprocessor=preprocessor, + postprocessor=postprocessor, use_amp=policy.config.use_amp, task=single_task, robot_type=robot.robot_type, @@ -271,7 +274,7 @@ def record_loop( continue # Action can eventually be clipped using `max_relative_target`, - # so action actually sent is saved in the dataset. + # so action actually sent is saved in the dataset. action = postprocessor.process(action) sent_action = robot.send_action(action) if dataset is not None: @@ -332,9 +335,10 @@ def record(cfg: RecordConfig) -> LeRobotDataset: # Load pretrained policy policy = None if cfg.policy is None else make_policy(cfg.policy, ds_meta=dataset.meta) - processor = None + preprocessor = None + postprocessor = None if cfg.policy is not None: - processor, _ = make_processor( + preprocessor, postprocessor = make_processor( policy_cfg=cfg.policy, pretrained_path=cfg.policy.pretrained_path, dataset_stats=dataset.meta.stats, @@ -356,6 +360,8 @@ def record(cfg: RecordConfig) -> LeRobotDataset: fps=cfg.dataset.fps, teleop=teleop, policy=policy, + preprocessor=preprocessor, + postprocessor=postprocessor, dataset=dataset, control_time_s=cfg.dataset.episode_time_s, single_task=cfg.dataset.single_task, diff --git a/src/lerobot/utils/control_utils.py b/src/lerobot/utils/control_utils.py index bf811b6e3..d8c7c9d57 100644 --- a/src/lerobot/utils/control_utils.py +++ b/src/lerobot/utils/control_utils.py @@ -31,7 +31,7 @@ from termcolor import colored from lerobot.datasets.lerobot_dataset import LeRobotDataset from lerobot.datasets.utils import DEFAULT_FEATURES from lerobot.policies.pretrained import PreTrainedPolicy -from lerobot.processor import RobotProcessor +from lerobot.processor import RobotProcessor, TransitionKey from lerobot.robots import Robot @@ -102,7 +102,8 @@ def predict_action( observation: dict[str, np.ndarray], policy: PreTrainedPolicy, device: torch.device, - processor: RobotProcessor, + preprocessor: RobotProcessor, + postprocessor: RobotProcessor, use_amp: bool, task: str | None = None, robot_type: str | None = None, @@ -124,12 +125,14 @@ def predict_action( observation["task"] = task if task else "" observation["robot_type"] = robot_type if robot_type else "" - observation = processor(observation) + observation = preprocessor(observation) # Compute the next action with the policy # based on the current observation action = policy.select_action(observation) + action: torch.Tensor = postprocessor({TransitionKey.ACTION: action})[TransitionKey.ACTION] + # Remove batch dimension action = action.squeeze(0)