refactor(record): Rename processor parameters and update processing logic

- Renamed `processor` to `preprocessor` and added `postprocessor` parameter for clarity.
- Updated the `record_loop` and `predict_action` functions to utilize the new preprocessor and postprocessor, enhancing the processing flow.
- Ensured compatibility with existing functionality while improving code readability.
This commit is contained in:
Adil Zouitine
2025-07-24 18:18:39 +02:00
committed by Steven Palma
parent b95c219d96
commit c4763f61a1
2 changed files with 20 additions and 11 deletions
+14 -8
View File
@@ -196,7 +196,8 @@ def record_loop(
dataset: LeRobotDataset | None = None, dataset: LeRobotDataset | None = None,
teleop: Teleoperator | list[Teleoperator] | None = None, teleop: Teleoperator | list[Teleoperator] | None = None,
policy: PreTrainedPolicy | None = None, policy: PreTrainedPolicy | None = None,
processor: RobotProcessor | None = None, preprocessor: RobotProcessor | None = None,
postprocessor: RobotProcessor | None = None,
control_time_s: int | None = None, control_time_s: int | None = None,
single_task: str | None = None, single_task: str | None = None,
display_data: bool = False, display_data: bool = False,
@@ -222,9 +223,10 @@ def record_loop(
) )
# Reset policy and processor if they are provided # Reset policy and processor if they are provided
if policy is not None or processor is not None: if policy is not None or preprocessor is not None:
policy.reset() policy.reset()
processor.reset() preprocessor.reset()
postprocessor.reset()
timestamp = 0 timestamp = 0
start_episode_t = time.perf_counter() start_episode_t = time.perf_counter()
@@ -240,12 +242,13 @@ def record_loop(
if policy is not None or dataset is not None: if policy is not None or dataset is not None:
observation_frame = build_dataset_frame(dataset.features, observation, prefix="observation") observation_frame = build_dataset_frame(dataset.features, observation, prefix="observation")
if policy is not None or processor is not None: if policy is not None or preprocessor is not None:
action_values = predict_action( action_values = predict_action(
observation=observation_frame, observation=observation_frame,
policy=policy, policy=policy,
device=get_safe_torch_device(policy.config.device), device=get_safe_torch_device(policy.config.device),
processor=processor, preprocessor=preprocessor,
postprocessor=postprocessor,
use_amp=policy.config.use_amp, use_amp=policy.config.use_amp,
task=single_task, task=single_task,
robot_type=robot.robot_type, robot_type=robot.robot_type,
@@ -271,7 +274,7 @@ def record_loop(
continue continue
# Action can eventually be clipped using `max_relative_target`, # Action can eventually be clipped using `max_relative_target`,
# so action actually sent is saved in the dataset. # so action actually sent is saved in the dataset. action = postprocessor.process(action)
sent_action = robot.send_action(action) sent_action = robot.send_action(action)
if dataset is not None: if dataset is not None:
@@ -332,9 +335,10 @@ def record(cfg: RecordConfig) -> LeRobotDataset:
# Load pretrained policy # Load pretrained policy
policy = None if cfg.policy is None else make_policy(cfg.policy, ds_meta=dataset.meta) policy = None if cfg.policy is None else make_policy(cfg.policy, ds_meta=dataset.meta)
processor = None preprocessor = None
postprocessor = None
if cfg.policy is not None: if cfg.policy is not None:
processor, _ = make_processor( preprocessor, postprocessor = make_processor(
policy_cfg=cfg.policy, policy_cfg=cfg.policy,
pretrained_path=cfg.policy.pretrained_path, pretrained_path=cfg.policy.pretrained_path,
dataset_stats=dataset.meta.stats, dataset_stats=dataset.meta.stats,
@@ -356,6 +360,8 @@ def record(cfg: RecordConfig) -> LeRobotDataset:
fps=cfg.dataset.fps, fps=cfg.dataset.fps,
teleop=teleop, teleop=teleop,
policy=policy, policy=policy,
preprocessor=preprocessor,
postprocessor=postprocessor,
dataset=dataset, dataset=dataset,
control_time_s=cfg.dataset.episode_time_s, control_time_s=cfg.dataset.episode_time_s,
single_task=cfg.dataset.single_task, single_task=cfg.dataset.single_task,
+6 -3
View File
@@ -31,7 +31,7 @@ from termcolor import colored
from lerobot.datasets.lerobot_dataset import LeRobotDataset from lerobot.datasets.lerobot_dataset import LeRobotDataset
from lerobot.datasets.utils import DEFAULT_FEATURES from lerobot.datasets.utils import DEFAULT_FEATURES
from lerobot.policies.pretrained import PreTrainedPolicy from lerobot.policies.pretrained import PreTrainedPolicy
from lerobot.processor import RobotProcessor from lerobot.processor import RobotProcessor, TransitionKey
from lerobot.robots import Robot from lerobot.robots import Robot
@@ -102,7 +102,8 @@ def predict_action(
observation: dict[str, np.ndarray], observation: dict[str, np.ndarray],
policy: PreTrainedPolicy, policy: PreTrainedPolicy,
device: torch.device, device: torch.device,
processor: RobotProcessor, preprocessor: RobotProcessor,
postprocessor: RobotProcessor,
use_amp: bool, use_amp: bool,
task: str | None = None, task: str | None = None,
robot_type: str | None = None, robot_type: str | None = None,
@@ -124,12 +125,14 @@ def predict_action(
observation["task"] = task if task else "" observation["task"] = task if task else ""
observation["robot_type"] = robot_type if robot_type else "" observation["robot_type"] = robot_type if robot_type else ""
observation = processor(observation) observation = preprocessor(observation)
# Compute the next action with the policy # Compute the next action with the policy
# based on the current observation # based on the current observation
action = policy.select_action(observation) action = policy.select_action(observation)
action: torch.Tensor = postprocessor({TransitionKey.ACTION: action})[TransitionKey.ACTION]
# Remove batch dimension # Remove batch dimension
action = action.squeeze(0) action = action.squeeze(0)