refactor(record): Rename processor parameters and update processing logic

- Renamed `processor` to `preprocessor` and added `postprocessor` parameter for clarity.
- Updated the `record_loop` and `predict_action` functions to utilize the new preprocessor and postprocessor, enhancing the processing flow.
- Ensured compatibility with existing functionality while improving code readability.
This commit is contained in:
Adil Zouitine
2025-07-24 18:18:39 +02:00
committed by Steven Palma
parent b95c219d96
commit c4763f61a1
2 changed files with 20 additions and 11 deletions
+14 -8
View File
@@ -196,7 +196,8 @@ def record_loop(
dataset: LeRobotDataset | None = None,
teleop: Teleoperator | list[Teleoperator] | None = None,
policy: PreTrainedPolicy | None = None,
processor: RobotProcessor | None = None,
preprocessor: RobotProcessor | None = None,
postprocessor: RobotProcessor | None = None,
control_time_s: int | None = None,
single_task: str | None = None,
display_data: bool = False,
@@ -222,9 +223,10 @@ def record_loop(
)
# Reset policy and processor if they are provided
if policy is not None or processor is not None:
if policy is not None or preprocessor is not None:
policy.reset()
processor.reset()
preprocessor.reset()
postprocessor.reset()
timestamp = 0
start_episode_t = time.perf_counter()
@@ -240,12 +242,13 @@ def record_loop(
if policy is not None or dataset is not None:
observation_frame = build_dataset_frame(dataset.features, observation, prefix="observation")
if policy is not None or processor is not None:
if policy is not None or preprocessor is not None:
action_values = predict_action(
observation=observation_frame,
policy=policy,
device=get_safe_torch_device(policy.config.device),
processor=processor,
preprocessor=preprocessor,
postprocessor=postprocessor,
use_amp=policy.config.use_amp,
task=single_task,
robot_type=robot.robot_type,
@@ -271,7 +274,7 @@ def record_loop(
continue
# Action can eventually be clipped using `max_relative_target`,
# so action actually sent is saved in the dataset.
# so action actually sent is saved in the dataset. action = postprocessor.process(action)
sent_action = robot.send_action(action)
if dataset is not None:
@@ -332,9 +335,10 @@ def record(cfg: RecordConfig) -> LeRobotDataset:
# Load pretrained policy
policy = None if cfg.policy is None else make_policy(cfg.policy, ds_meta=dataset.meta)
processor = None
preprocessor = None
postprocessor = None
if cfg.policy is not None:
processor, _ = make_processor(
preprocessor, postprocessor = make_processor(
policy_cfg=cfg.policy,
pretrained_path=cfg.policy.pretrained_path,
dataset_stats=dataset.meta.stats,
@@ -356,6 +360,8 @@ def record(cfg: RecordConfig) -> LeRobotDataset:
fps=cfg.dataset.fps,
teleop=teleop,
policy=policy,
preprocessor=preprocessor,
postprocessor=postprocessor,
dataset=dataset,
control_time_s=cfg.dataset.episode_time_s,
single_task=cfg.dataset.single_task,
+6 -3
View File
@@ -31,7 +31,7 @@ from termcolor import colored
from lerobot.datasets.lerobot_dataset import LeRobotDataset
from lerobot.datasets.utils import DEFAULT_FEATURES
from lerobot.policies.pretrained import PreTrainedPolicy
from lerobot.processor import RobotProcessor
from lerobot.processor import RobotProcessor, TransitionKey
from lerobot.robots import Robot
@@ -102,7 +102,8 @@ def predict_action(
observation: dict[str, np.ndarray],
policy: PreTrainedPolicy,
device: torch.device,
processor: RobotProcessor,
preprocessor: RobotProcessor,
postprocessor: RobotProcessor,
use_amp: bool,
task: str | None = None,
robot_type: str | None = None,
@@ -124,12 +125,14 @@ def predict_action(
observation["task"] = task if task else ""
observation["robot_type"] = robot_type if robot_type else ""
observation = processor(observation)
observation = preprocessor(observation)
# Compute the next action with the policy
# based on the current observation
action = policy.select_action(observation)
action: torch.Tensor = postprocessor({TransitionKey.ACTION: action})[TransitionKey.ACTION]
# Remove batch dimension
action = action.squeeze(0)