mirror of
https://github.com/huggingface/lerobot.git
synced 2026-05-15 00:29:52 +00:00
refactor(record): Rename processor parameters and update processing logic
- Renamed `processor` to `preprocessor` and added `postprocessor` parameter for clarity. - Updated the `record_loop` and `predict_action` functions to utilize the new preprocessor and postprocessor, enhancing the processing flow. - Ensured compatibility with existing functionality while improving code readability.
This commit is contained in:
committed by
Steven Palma
parent
b95c219d96
commit
c4763f61a1
+14
-8
@@ -196,7 +196,8 @@ def record_loop(
|
||||
dataset: LeRobotDataset | None = None,
|
||||
teleop: Teleoperator | list[Teleoperator] | None = None,
|
||||
policy: PreTrainedPolicy | None = None,
|
||||
processor: RobotProcessor | None = None,
|
||||
preprocessor: RobotProcessor | None = None,
|
||||
postprocessor: RobotProcessor | None = None,
|
||||
control_time_s: int | None = None,
|
||||
single_task: str | None = None,
|
||||
display_data: bool = False,
|
||||
@@ -222,9 +223,10 @@ def record_loop(
|
||||
)
|
||||
|
||||
# Reset policy and processor if they are provided
|
||||
if policy is not None or processor is not None:
|
||||
if policy is not None or preprocessor is not None:
|
||||
policy.reset()
|
||||
processor.reset()
|
||||
preprocessor.reset()
|
||||
postprocessor.reset()
|
||||
|
||||
timestamp = 0
|
||||
start_episode_t = time.perf_counter()
|
||||
@@ -240,12 +242,13 @@ def record_loop(
|
||||
if policy is not None or dataset is not None:
|
||||
observation_frame = build_dataset_frame(dataset.features, observation, prefix="observation")
|
||||
|
||||
if policy is not None or processor is not None:
|
||||
if policy is not None or preprocessor is not None:
|
||||
action_values = predict_action(
|
||||
observation=observation_frame,
|
||||
policy=policy,
|
||||
device=get_safe_torch_device(policy.config.device),
|
||||
processor=processor,
|
||||
preprocessor=preprocessor,
|
||||
postprocessor=postprocessor,
|
||||
use_amp=policy.config.use_amp,
|
||||
task=single_task,
|
||||
robot_type=robot.robot_type,
|
||||
@@ -271,7 +274,7 @@ def record_loop(
|
||||
continue
|
||||
|
||||
# Action can eventually be clipped using `max_relative_target`,
|
||||
# so action actually sent is saved in the dataset.
|
||||
# so action actually sent is saved in the dataset. action = postprocessor.process(action)
|
||||
sent_action = robot.send_action(action)
|
||||
|
||||
if dataset is not None:
|
||||
@@ -332,9 +335,10 @@ def record(cfg: RecordConfig) -> LeRobotDataset:
|
||||
|
||||
# Load pretrained policy
|
||||
policy = None if cfg.policy is None else make_policy(cfg.policy, ds_meta=dataset.meta)
|
||||
processor = None
|
||||
preprocessor = None
|
||||
postprocessor = None
|
||||
if cfg.policy is not None:
|
||||
processor, _ = make_processor(
|
||||
preprocessor, postprocessor = make_processor(
|
||||
policy_cfg=cfg.policy,
|
||||
pretrained_path=cfg.policy.pretrained_path,
|
||||
dataset_stats=dataset.meta.stats,
|
||||
@@ -356,6 +360,8 @@ def record(cfg: RecordConfig) -> LeRobotDataset:
|
||||
fps=cfg.dataset.fps,
|
||||
teleop=teleop,
|
||||
policy=policy,
|
||||
preprocessor=preprocessor,
|
||||
postprocessor=postprocessor,
|
||||
dataset=dataset,
|
||||
control_time_s=cfg.dataset.episode_time_s,
|
||||
single_task=cfg.dataset.single_task,
|
||||
|
||||
@@ -31,7 +31,7 @@ from termcolor import colored
|
||||
from lerobot.datasets.lerobot_dataset import LeRobotDataset
|
||||
from lerobot.datasets.utils import DEFAULT_FEATURES
|
||||
from lerobot.policies.pretrained import PreTrainedPolicy
|
||||
from lerobot.processor import RobotProcessor
|
||||
from lerobot.processor import RobotProcessor, TransitionKey
|
||||
from lerobot.robots import Robot
|
||||
|
||||
|
||||
@@ -102,7 +102,8 @@ def predict_action(
|
||||
observation: dict[str, np.ndarray],
|
||||
policy: PreTrainedPolicy,
|
||||
device: torch.device,
|
||||
processor: RobotProcessor,
|
||||
preprocessor: RobotProcessor,
|
||||
postprocessor: RobotProcessor,
|
||||
use_amp: bool,
|
||||
task: str | None = None,
|
||||
robot_type: str | None = None,
|
||||
@@ -124,12 +125,14 @@ def predict_action(
|
||||
observation["task"] = task if task else ""
|
||||
observation["robot_type"] = robot_type if robot_type else ""
|
||||
|
||||
observation = processor(observation)
|
||||
observation = preprocessor(observation)
|
||||
|
||||
# Compute the next action with the policy
|
||||
# based on the current observation
|
||||
action = policy.select_action(observation)
|
||||
|
||||
action: torch.Tensor = postprocessor({TransitionKey.ACTION: action})[TransitionKey.ACTION]
|
||||
|
||||
# Remove batch dimension
|
||||
action = action.squeeze(0)
|
||||
|
||||
|
||||
Reference in New Issue
Block a user