mirror of
https://github.com/huggingface/lerobot.git
synced 2026-05-23 20:50:02 +00:00
refactor(record): Rename processor parameters and update processing logic
- Renamed `processor` to `preprocessor` and added `postprocessor` parameter for clarity. - Updated the `record_loop` and `predict_action` functions to utilize the new preprocessor and postprocessor, enhancing the processing flow. - Ensured compatibility with existing functionality while improving code readability.
This commit is contained in:
committed by
Steven Palma
parent
b95c219d96
commit
c4763f61a1
+14
-8
@@ -196,7 +196,8 @@ def record_loop(
|
|||||||
dataset: LeRobotDataset | None = None,
|
dataset: LeRobotDataset | None = None,
|
||||||
teleop: Teleoperator | list[Teleoperator] | None = None,
|
teleop: Teleoperator | list[Teleoperator] | None = None,
|
||||||
policy: PreTrainedPolicy | None = None,
|
policy: PreTrainedPolicy | None = None,
|
||||||
processor: RobotProcessor | None = None,
|
preprocessor: RobotProcessor | None = None,
|
||||||
|
postprocessor: RobotProcessor | None = None,
|
||||||
control_time_s: int | None = None,
|
control_time_s: int | None = None,
|
||||||
single_task: str | None = None,
|
single_task: str | None = None,
|
||||||
display_data: bool = False,
|
display_data: bool = False,
|
||||||
@@ -222,9 +223,10 @@ def record_loop(
|
|||||||
)
|
)
|
||||||
|
|
||||||
# Reset policy and processor if they are provided
|
# Reset policy and processor if they are provided
|
||||||
if policy is not None or processor is not None:
|
if policy is not None or preprocessor is not None:
|
||||||
policy.reset()
|
policy.reset()
|
||||||
processor.reset()
|
preprocessor.reset()
|
||||||
|
postprocessor.reset()
|
||||||
|
|
||||||
timestamp = 0
|
timestamp = 0
|
||||||
start_episode_t = time.perf_counter()
|
start_episode_t = time.perf_counter()
|
||||||
@@ -240,12 +242,13 @@ def record_loop(
|
|||||||
if policy is not None or dataset is not None:
|
if policy is not None or dataset is not None:
|
||||||
observation_frame = build_dataset_frame(dataset.features, observation, prefix="observation")
|
observation_frame = build_dataset_frame(dataset.features, observation, prefix="observation")
|
||||||
|
|
||||||
if policy is not None or processor is not None:
|
if policy is not None or preprocessor is not None:
|
||||||
action_values = predict_action(
|
action_values = predict_action(
|
||||||
observation=observation_frame,
|
observation=observation_frame,
|
||||||
policy=policy,
|
policy=policy,
|
||||||
device=get_safe_torch_device(policy.config.device),
|
device=get_safe_torch_device(policy.config.device),
|
||||||
processor=processor,
|
preprocessor=preprocessor,
|
||||||
|
postprocessor=postprocessor,
|
||||||
use_amp=policy.config.use_amp,
|
use_amp=policy.config.use_amp,
|
||||||
task=single_task,
|
task=single_task,
|
||||||
robot_type=robot.robot_type,
|
robot_type=robot.robot_type,
|
||||||
@@ -271,7 +274,7 @@ def record_loop(
|
|||||||
continue
|
continue
|
||||||
|
|
||||||
# Action can eventually be clipped using `max_relative_target`,
|
# Action can eventually be clipped using `max_relative_target`,
|
||||||
# so action actually sent is saved in the dataset.
|
# so action actually sent is saved in the dataset. action = postprocessor.process(action)
|
||||||
sent_action = robot.send_action(action)
|
sent_action = robot.send_action(action)
|
||||||
|
|
||||||
if dataset is not None:
|
if dataset is not None:
|
||||||
@@ -332,9 +335,10 @@ def record(cfg: RecordConfig) -> LeRobotDataset:
|
|||||||
|
|
||||||
# Load pretrained policy
|
# Load pretrained policy
|
||||||
policy = None if cfg.policy is None else make_policy(cfg.policy, ds_meta=dataset.meta)
|
policy = None if cfg.policy is None else make_policy(cfg.policy, ds_meta=dataset.meta)
|
||||||
processor = None
|
preprocessor = None
|
||||||
|
postprocessor = None
|
||||||
if cfg.policy is not None:
|
if cfg.policy is not None:
|
||||||
processor, _ = make_processor(
|
preprocessor, postprocessor = make_processor(
|
||||||
policy_cfg=cfg.policy,
|
policy_cfg=cfg.policy,
|
||||||
pretrained_path=cfg.policy.pretrained_path,
|
pretrained_path=cfg.policy.pretrained_path,
|
||||||
dataset_stats=dataset.meta.stats,
|
dataset_stats=dataset.meta.stats,
|
||||||
@@ -356,6 +360,8 @@ def record(cfg: RecordConfig) -> LeRobotDataset:
|
|||||||
fps=cfg.dataset.fps,
|
fps=cfg.dataset.fps,
|
||||||
teleop=teleop,
|
teleop=teleop,
|
||||||
policy=policy,
|
policy=policy,
|
||||||
|
preprocessor=preprocessor,
|
||||||
|
postprocessor=postprocessor,
|
||||||
dataset=dataset,
|
dataset=dataset,
|
||||||
control_time_s=cfg.dataset.episode_time_s,
|
control_time_s=cfg.dataset.episode_time_s,
|
||||||
single_task=cfg.dataset.single_task,
|
single_task=cfg.dataset.single_task,
|
||||||
|
|||||||
@@ -31,7 +31,7 @@ from termcolor import colored
|
|||||||
from lerobot.datasets.lerobot_dataset import LeRobotDataset
|
from lerobot.datasets.lerobot_dataset import LeRobotDataset
|
||||||
from lerobot.datasets.utils import DEFAULT_FEATURES
|
from lerobot.datasets.utils import DEFAULT_FEATURES
|
||||||
from lerobot.policies.pretrained import PreTrainedPolicy
|
from lerobot.policies.pretrained import PreTrainedPolicy
|
||||||
from lerobot.processor import RobotProcessor
|
from lerobot.processor import RobotProcessor, TransitionKey
|
||||||
from lerobot.robots import Robot
|
from lerobot.robots import Robot
|
||||||
|
|
||||||
|
|
||||||
@@ -102,7 +102,8 @@ def predict_action(
|
|||||||
observation: dict[str, np.ndarray],
|
observation: dict[str, np.ndarray],
|
||||||
policy: PreTrainedPolicy,
|
policy: PreTrainedPolicy,
|
||||||
device: torch.device,
|
device: torch.device,
|
||||||
processor: RobotProcessor,
|
preprocessor: RobotProcessor,
|
||||||
|
postprocessor: RobotProcessor,
|
||||||
use_amp: bool,
|
use_amp: bool,
|
||||||
task: str | None = None,
|
task: str | None = None,
|
||||||
robot_type: str | None = None,
|
robot_type: str | None = None,
|
||||||
@@ -124,12 +125,14 @@ def predict_action(
|
|||||||
observation["task"] = task if task else ""
|
observation["task"] = task if task else ""
|
||||||
observation["robot_type"] = robot_type if robot_type else ""
|
observation["robot_type"] = robot_type if robot_type else ""
|
||||||
|
|
||||||
observation = processor(observation)
|
observation = preprocessor(observation)
|
||||||
|
|
||||||
# Compute the next action with the policy
|
# Compute the next action with the policy
|
||||||
# based on the current observation
|
# based on the current observation
|
||||||
action = policy.select_action(observation)
|
action = policy.select_action(observation)
|
||||||
|
|
||||||
|
action: torch.Tensor = postprocessor({TransitionKey.ACTION: action})[TransitionKey.ACTION]
|
||||||
|
|
||||||
# Remove batch dimension
|
# Remove batch dimension
|
||||||
action = action.squeeze(0)
|
action = action.squeeze(0)
|
||||||
|
|
||||||
|
|||||||
Reference in New Issue
Block a user