diff --git a/examples/lekiwi/rollout.py b/examples/lekiwi/rollout.py new file mode 100644 index 000000000..b785b6fcb --- /dev/null +++ b/examples/lekiwi/rollout.py @@ -0,0 +1,78 @@ +# !/usr/bin/env python + +# Copyright 2025 The HuggingFace Inc. team. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +"""Run a trained policy on LeKiwi without recording (base rollout). + +Uses the rollout engine's :class:`BaseStrategy` (autonomous execution, +no dataset) with :class:`SyncInferenceConfig` (inline policy call per +control tick). For a CLI entry point with the same capabilities plus +recording, upload, and human-in-the-loop variants, see ``lerobot-rollout``. +""" + +from lerobot.configs import PreTrainedConfig +from lerobot.robots.lekiwi import LeKiwiClientConfig +from lerobot.rollout.configs import BaseStrategyConfig, RolloutConfig +from lerobot.rollout.context import build_rollout_context +from lerobot.rollout.inference import SyncInferenceConfig +from lerobot.rollout.strategies.base import BaseStrategy +from lerobot.utils.process import ProcessSignalHandler +from lerobot.utils.utils import init_logging + +FPS = 30 +DURATION_SEC = 60 +TASK_DESCRIPTION = "My task description" +HF_MODEL_ID = "/" + + +def main(): + init_logging() + + # Robot: LeKiwi client — make sure lekiwi_host is already running on the robot. + robot_config = LeKiwiClientConfig(remote_ip="172.18.134.136", id="lekiwi") + + # Policy: load the pretrained config. ``pretrained_path`` is read downstream + # by ``build_rollout_context`` to reload the full model. + policy_config = PreTrainedConfig.from_pretrained(HF_MODEL_ID) + policy_config.pretrained_path = HF_MODEL_ID + + # Assemble the rollout config: base strategy (no recording) + sync inference. + cfg = RolloutConfig( + robot=robot_config, + policy=policy_config, + strategy=BaseStrategyConfig(), + inference=SyncInferenceConfig(), + fps=FPS, + duration=DURATION_SEC, + task=TASK_DESCRIPTION, + ) + + # Graceful Ctrl-C: the strategy loop exits when shutdown_event is set. + signal_handler = ProcessSignalHandler(use_threads=True) + + # Build the context (connects robot, loads policy, wires the inference strategy). + # No custom processors here — LeKiwi runs on raw joint features. + ctx = build_rollout_context(cfg, signal_handler.shutdown_event) + + strategy = BaseStrategy(cfg.strategy) + try: + strategy.setup(ctx) + strategy.run(ctx) + finally: + strategy.teardown(ctx) + + +if __name__ == "__main__": + main() diff --git a/examples/phone_to_so100/record.py b/examples/phone_to_so100/record.py index d0fb1c1f1..87b8e49fd 100644 --- a/examples/phone_to_so100/record.py +++ b/examples/phone_to_so100/record.py @@ -16,14 +16,29 @@ from lerobot.cameras.opencv import OpenCVCameraConfig from lerobot.common.control_utils import init_keyboard_listener -from lerobot.datasets import LeRobotDataset -from lerobot.processor import make_default_processors +from lerobot.datasets import LeRobotDataset, aggregate_pipeline_dataset_features, create_initial_features +from lerobot.model.kinematics import RobotKinematics +from lerobot.processor import ( + RobotProcessorPipeline, + observation_to_transition, + robot_action_observation_to_transition, + transition_to_observation, + transition_to_robot_action, +) from lerobot.robots.so_follower import SO100Follower, SO100FollowerConfig +from lerobot.robots.so_follower.robot_kinematic_processor import ( + EEBoundsAndSafety, + EEReferenceAndDelta, + ForwardKinematicsJointsToEE, + GripperVelocityToJoint, + InverseKinematicsEEToJoints, +) from lerobot.scripts.lerobot_record import record_loop from lerobot.teleoperators.phone import Phone, PhoneConfig from lerobot.teleoperators.phone.config_phone import PhoneOS -from lerobot.utils.constants import ACTION, OBS_STR -from lerobot.utils.feature_utils import hw_to_dataset_features +from lerobot.teleoperators.phone.phone_processor import MapPhoneActionToRobotAction +from lerobot.types import RobotAction, RobotObservation +from lerobot.utils.feature_utils import combine_feature_dicts from lerobot.utils.utils import log_say from lerobot.utils.visualization_utils import init_rerun @@ -50,16 +65,77 @@ def main(): robot = SO100Follower(robot_config) phone = Phone(teleop_config) - # Configure the dataset features - action_features = hw_to_dataset_features(robot.action_features, ACTION) - obs_features = hw_to_dataset_features(robot.observation_features, OBS_STR) - dataset_features = {**action_features, **obs_features} + # NOTE: It is highly recommended to use the urdf in the SO-ARM100 repo: + # https://github.com/TheRobotStudio/SO-ARM100/blob/main/Simulation/SO101/so101_new_calib.urdf + kinematics_solver = RobotKinematics( + urdf_path="./SO101/so101_new_calib.urdf", + target_frame_name="gripper_frame_link", + joint_names=list(robot.bus.motors.keys()), + ) - # Create the dataset + # Build pipeline to convert phone action to EE action (with gripper velocity mapped to joint). + phone_to_robot_ee_pose_processor = RobotProcessorPipeline[ + tuple[RobotAction, RobotObservation], RobotAction + ]( + steps=[ + MapPhoneActionToRobotAction(platform=teleop_config.phone_os), + EEReferenceAndDelta( + kinematics=kinematics_solver, + end_effector_step_sizes={"x": 0.5, "y": 0.5, "z": 0.5}, + motor_names=list(robot.bus.motors.keys()), + use_latched_reference=True, + ), + EEBoundsAndSafety( + end_effector_bounds={"min": [-1.0, -1.0, -1.0], "max": [1.0, 1.0, 1.0]}, + max_ee_step_m=0.20, + ), + GripperVelocityToJoint(speed_factor=20.0), + ], + to_transition=robot_action_observation_to_transition, + to_output=transition_to_robot_action, + ) + + # Build pipeline to convert EE action to joints action (IK). + robot_ee_to_joints_processor = RobotProcessorPipeline[tuple[RobotAction, RobotObservation], RobotAction]( + steps=[ + InverseKinematicsEEToJoints( + kinematics=kinematics_solver, + motor_names=list(robot.bus.motors.keys()), + initial_guess_current_joints=True, + ), + ], + to_transition=robot_action_observation_to_transition, + to_output=transition_to_robot_action, + ) + + # Build pipeline to convert joint observation to EE observation (FK). + robot_joints_to_ee_pose = RobotProcessorPipeline[RobotObservation, RobotObservation]( + steps=[ + ForwardKinematicsJointsToEE( + kinematics=kinematics_solver, motor_names=list(robot.bus.motors.keys()) + ) + ], + to_transition=observation_to_transition, + to_output=transition_to_observation, + ) + + # Create the dataset, deriving features from the pipelines so the on-disk schema + # matches exactly what the pipelines produce at runtime. dataset = LeRobotDataset.create( repo_id=HF_REPO_ID, fps=FPS, - features=dataset_features, + features=combine_feature_dicts( + aggregate_pipeline_dataset_features( + pipeline=phone_to_robot_ee_pose_processor, + initial_features=create_initial_features(action=phone.action_features), + use_videos=True, + ), + aggregate_pipeline_dataset_features( + pipeline=robot_joints_to_ee_pose, + initial_features=create_initial_features(observation=robot.observation_features), + use_videos=True, + ), + ), robot_type=robot.name, use_videos=True, image_writer_threads=4, @@ -77,10 +153,6 @@ def main(): if not robot.is_connected or not phone.is_connected: raise ValueError("Robot or teleop is not connected!") - teleop_action_processor, robot_action_processor, robot_observation_processor = ( - make_default_processors() - ) - print("Starting record loop. Move your phone to teleoperate the robot...") episode_idx = 0 while episode_idx < NUM_EPISODES and not events["stop_recording"]: @@ -91,9 +163,9 @@ def main(): robot=robot, events=events, fps=FPS, - teleop_action_processor=teleop_action_processor, - robot_action_processor=robot_action_processor, - robot_observation_processor=robot_observation_processor, + teleop_action_processor=phone_to_robot_ee_pose_processor, + robot_action_processor=robot_ee_to_joints_processor, + robot_observation_processor=robot_joints_to_ee_pose, teleop=phone, dataset=dataset, control_time_s=EPISODE_TIME_SEC, @@ -110,9 +182,9 @@ def main(): robot=robot, events=events, fps=FPS, - teleop_action_processor=teleop_action_processor, - robot_action_processor=robot_action_processor, - robot_observation_processor=robot_observation_processor, + teleop_action_processor=phone_to_robot_ee_pose_processor, + robot_action_processor=robot_ee_to_joints_processor, + robot_observation_processor=robot_joints_to_ee_pose, teleop=phone, control_time_s=RESET_TIME_SEC, single_task=TASK_DESCRIPTION, diff --git a/examples/phone_to_so100/rollout.py b/examples/phone_to_so100/rollout.py new file mode 100644 index 000000000..2a17aa4d8 --- /dev/null +++ b/examples/phone_to_so100/rollout.py @@ -0,0 +1,127 @@ +# !/usr/bin/env python + +# Copyright 2025 The HuggingFace Inc. team. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +"""Run a trained EE-space policy on SO100 (phone-trained) without recording. + +Mirrors ``examples/so100_to_so100_EE/rollout.py`` — the model was trained +with phone teleoperation in EE space, so at deployment we only need the +joint↔EE conversion on the robot side; the phone is not used. + +Uses :class:`BaseStrategy` (no recording) + :class:`SyncInferenceConfig` +(inline policy call). For recording during rollout, switch to Sentry, +Highlight, or DAgger via ``lerobot-rollout --strategy.type=...``. +""" + +from lerobot.cameras.opencv import OpenCVCameraConfig +from lerobot.configs import PreTrainedConfig +from lerobot.model.kinematics import RobotKinematics +from lerobot.processor import ( + RobotProcessorPipeline, + observation_to_transition, + robot_action_observation_to_transition, + transition_to_observation, + transition_to_robot_action, +) +from lerobot.robots.so_follower import SO100Follower, SO100FollowerConfig +from lerobot.robots.so_follower.robot_kinematic_processor import ( + ForwardKinematicsJointsToEE, + InverseKinematicsEEToJoints, +) +from lerobot.rollout.configs import BaseStrategyConfig, RolloutConfig +from lerobot.rollout.context import build_rollout_context +from lerobot.rollout.inference import SyncInferenceConfig +from lerobot.rollout.strategies.base import BaseStrategy +from lerobot.types import RobotAction, RobotObservation +from lerobot.utils.process import ProcessSignalHandler +from lerobot.utils.utils import init_logging + +FPS = 30 +DURATION_SEC = 60 +TASK_DESCRIPTION = "My task description" +HF_MODEL_ID = "/" + + +def main(): + init_logging() + + camera_config = {"front": OpenCVCameraConfig(index_or_path=0, width=640, height=480, fps=FPS)} + robot_config = SO100FollowerConfig( + port="/dev/tty.usbmodem58760434471", + id="my_awesome_follower_arm", + cameras=camera_config, + use_degrees=True, + ) + + # Peek at motor names once to build the kinematic solver. + temp_robot = SO100Follower(robot_config) + motor_names = list(temp_robot.bus.motors.keys()) + + kinematics_solver = RobotKinematics( + urdf_path="./SO101/so101_new_calib.urdf", + target_frame_name="gripper_frame_link", + joint_names=motor_names, + ) + + robot_joints_to_ee_pose_processor = RobotProcessorPipeline[RobotObservation, RobotObservation]( + steps=[ForwardKinematicsJointsToEE(kinematics=kinematics_solver, motor_names=motor_names)], + to_transition=observation_to_transition, + to_output=transition_to_observation, + ) + + robot_ee_to_joints_processor = RobotProcessorPipeline[tuple[RobotAction, RobotObservation], RobotAction]( + steps=[ + InverseKinematicsEEToJoints( + kinematics=kinematics_solver, + motor_names=motor_names, + initial_guess_current_joints=True, + ), + ], + to_transition=robot_action_observation_to_transition, + to_output=transition_to_robot_action, + ) + + policy_config = PreTrainedConfig.from_pretrained(HF_MODEL_ID) + policy_config.pretrained_path = HF_MODEL_ID + + cfg = RolloutConfig( + robot=robot_config, + policy=policy_config, + strategy=BaseStrategyConfig(), + inference=SyncInferenceConfig(), + fps=FPS, + duration=DURATION_SEC, + task=TASK_DESCRIPTION, + ) + + signal_handler = ProcessSignalHandler(use_threads=True) + + ctx = build_rollout_context( + cfg, + signal_handler.shutdown_event, + robot_action_processor=robot_ee_to_joints_processor, + robot_observation_processor=robot_joints_to_ee_pose_processor, + ) + + strategy = BaseStrategy(cfg.strategy) + try: + strategy.setup(ctx) + strategy.run(ctx) + finally: + strategy.teardown(ctx) + + +if __name__ == "__main__": + main() diff --git a/examples/so100_to_so100_EE/record.py b/examples/so100_to_so100_EE/record.py index b849ac4de..a0b92da3b 100644 --- a/examples/so100_to_so100_EE/record.py +++ b/examples/so100_to_so100_EE/record.py @@ -17,13 +17,25 @@ from lerobot.cameras.opencv import OpenCVCameraConfig from lerobot.common.control_utils import init_keyboard_listener -from lerobot.datasets import LeRobotDataset -from lerobot.processor import make_default_processors +from lerobot.datasets import LeRobotDataset, aggregate_pipeline_dataset_features, create_initial_features +from lerobot.model.kinematics import RobotKinematics +from lerobot.processor import ( + RobotProcessorPipeline, + observation_to_transition, + robot_action_observation_to_transition, + transition_to_observation, + transition_to_robot_action, +) from lerobot.robots.so_follower import SO100Follower, SO100FollowerConfig +from lerobot.robots.so_follower.robot_kinematic_processor import ( + EEBoundsAndSafety, + ForwardKinematicsJointsToEE, + InverseKinematicsEEToJoints, +) from lerobot.scripts.lerobot_record import record_loop from lerobot.teleoperators.so_leader import SO100Leader, SO100LeaderConfig -from lerobot.utils.constants import ACTION, OBS_STR -from lerobot.utils.feature_utils import hw_to_dataset_features +from lerobot.types import RobotAction, RobotObservation +from lerobot.utils.feature_utils import combine_feature_dicts from lerobot.utils.utils import log_say from lerobot.utils.visualization_utils import init_rerun @@ -50,16 +62,75 @@ def main(): follower = SO100Follower(follower_config) leader = SO100Leader(leader_config) - # Configure the dataset features - action_features = hw_to_dataset_features(follower.action_features, ACTION) - obs_features = hw_to_dataset_features(follower.observation_features, OBS_STR) - dataset_features = {**action_features, **obs_features} + # NOTE: It is highly recommended to use the urdf in the SO-ARM100 repo: + # https://github.com/TheRobotStudio/SO-ARM100/blob/main/Simulation/SO101/so101_new_calib.urdf + follower_kinematics_solver = RobotKinematics( + urdf_path="./SO101/so101_new_calib.urdf", + target_frame_name="gripper_frame_link", + joint_names=list(follower.bus.motors.keys()), + ) + leader_kinematics_solver = RobotKinematics( + urdf_path="./SO101/so101_new_calib.urdf", + target_frame_name="gripper_frame_link", + joint_names=list(leader.bus.motors.keys()), + ) - # Create the dataset + # Build pipeline to convert follower joints to EE observation. + follower_joints_to_ee = RobotProcessorPipeline[RobotObservation, RobotObservation]( + steps=[ + ForwardKinematicsJointsToEE( + kinematics=follower_kinematics_solver, motor_names=list(follower.bus.motors.keys()) + ), + ], + to_transition=observation_to_transition, + to_output=transition_to_observation, + ) + + # Build pipeline to convert leader joints to EE action. + leader_joints_to_ee = RobotProcessorPipeline[tuple[RobotAction, RobotObservation], RobotAction]( + steps=[ + ForwardKinematicsJointsToEE( + kinematics=leader_kinematics_solver, motor_names=list(leader.bus.motors.keys()) + ), + ], + to_transition=robot_action_observation_to_transition, + to_output=transition_to_robot_action, + ) + + # Build pipeline to convert EE action to follower joints (with safety bounds). + ee_to_follower_joints = RobotProcessorPipeline[tuple[RobotAction, RobotObservation], RobotAction]( + steps=[ + EEBoundsAndSafety( + end_effector_bounds={"min": [-1.0, -1.0, -1.0], "max": [1.0, 1.0, 1.0]}, + max_ee_step_m=0.10, + ), + InverseKinematicsEEToJoints( + kinematics=follower_kinematics_solver, + motor_names=list(follower.bus.motors.keys()), + initial_guess_current_joints=True, + ), + ], + to_transition=robot_action_observation_to_transition, + to_output=transition_to_robot_action, + ) + + # Create the dataset, deriving features from the pipelines so the on-disk schema + # matches exactly what the pipelines produce at runtime. dataset = LeRobotDataset.create( repo_id=HF_REPO_ID, fps=FPS, - features=dataset_features, + features=combine_feature_dicts( + aggregate_pipeline_dataset_features( + pipeline=leader_joints_to_ee, + initial_features=create_initial_features(action=leader.action_features), + use_videos=True, + ), + aggregate_pipeline_dataset_features( + pipeline=follower_joints_to_ee, + initial_features=create_initial_features(observation=follower.observation_features), + use_videos=True, + ), + ), robot_type=follower.name, use_videos=True, image_writer_threads=4, @@ -71,16 +142,12 @@ def main(): # Initialize the keyboard listener and rerun visualization listener, events = init_keyboard_listener() - init_rerun(session_name="recording_phone") + init_rerun(session_name="recording_so100_ee") try: if not leader.is_connected or not follower.is_connected: raise ValueError("Robot or teleop is not connected!") - teleop_action_processor, robot_action_processor, robot_observation_processor = ( - make_default_processors() - ) - print("Starting record loop...") episode_idx = 0 while episode_idx < NUM_EPISODES and not events["stop_recording"]: @@ -91,9 +158,9 @@ def main(): robot=follower, events=events, fps=FPS, - teleop_action_processor=teleop_action_processor, - robot_action_processor=robot_action_processor, - robot_observation_processor=robot_observation_processor, + teleop_action_processor=leader_joints_to_ee, + robot_action_processor=ee_to_follower_joints, + robot_observation_processor=follower_joints_to_ee, teleop=leader, dataset=dataset, control_time_s=EPISODE_TIME_SEC, @@ -110,9 +177,9 @@ def main(): robot=follower, events=events, fps=FPS, - teleop_action_processor=teleop_action_processor, - robot_action_processor=robot_action_processor, - robot_observation_processor=robot_observation_processor, + teleop_action_processor=leader_joints_to_ee, + robot_action_processor=ee_to_follower_joints, + robot_observation_processor=follower_joints_to_ee, teleop=leader, control_time_s=RESET_TIME_SEC, single_task=TASK_DESCRIPTION, diff --git a/examples/so100_to_so100_EE/rollout.py b/examples/so100_to_so100_EE/rollout.py new file mode 100644 index 000000000..95339029d --- /dev/null +++ b/examples/so100_to_so100_EE/rollout.py @@ -0,0 +1,135 @@ +# !/usr/bin/env python + +# Copyright 2025 The HuggingFace Inc. team. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +"""Run a trained EE-space policy on SO100 without recording (base rollout). + +Uses the rollout engine's :class:`BaseStrategy` (autonomous execution, +no dataset) with :class:`SyncInferenceConfig` (inline policy call per +control tick). The custom observation/action processors convert between +joint space (robot hardware) and end-effector space (policy I/O) via +forward/inverse kinematics. +""" + +from lerobot.cameras.opencv import OpenCVCameraConfig +from lerobot.configs import PreTrainedConfig +from lerobot.model.kinematics import RobotKinematics +from lerobot.processor import ( + RobotProcessorPipeline, + observation_to_transition, + robot_action_observation_to_transition, + transition_to_observation, + transition_to_robot_action, +) +from lerobot.robots.so_follower import SO100Follower, SO100FollowerConfig +from lerobot.robots.so_follower.robot_kinematic_processor import ( + ForwardKinematicsJointsToEE, + InverseKinematicsEEToJoints, +) +from lerobot.rollout.configs import BaseStrategyConfig, RolloutConfig +from lerobot.rollout.context import build_rollout_context +from lerobot.rollout.inference import SyncInferenceConfig +from lerobot.rollout.strategies.base import BaseStrategy +from lerobot.types import RobotAction, RobotObservation +from lerobot.utils.process import ProcessSignalHandler +from lerobot.utils.utils import init_logging + +FPS = 30 +DURATION_SEC = 60 +TASK_DESCRIPTION = "My task description" +HF_MODEL_ID = "/" + + +def main(): + init_logging() + + # Robot configuration — the rollout engine will connect it inside build_rollout_context. + camera_config = {"front": OpenCVCameraConfig(index_or_path=0, width=640, height=480, fps=FPS)} + robot_config = SO100FollowerConfig( + port="/dev/tty.usbmodem5A460814411", + id="my_awesome_follower_arm", + cameras=camera_config, + use_degrees=True, + ) + + # Kinematic solver: we need the motor-name list, so peek at the robot once. + # (The rollout engine owns the connected instance; we only use this for introspection.) + temp_robot = SO100Follower(robot_config) + motor_names = list(temp_robot.bus.motors.keys()) + + # NOTE: It is highly recommended to use the urdf in the SO-ARM100 repo: + # https://github.com/TheRobotStudio/SO-ARM100/blob/main/Simulation/SO101/so101_new_calib.urdf + kinematics_solver = RobotKinematics( + urdf_path="./SO101/so101_new_calib.urdf", + target_frame_name="gripper_frame_link", + joint_names=motor_names, + ) + + # Joint-space observation → EE-space observation (consumed by the policy). + robot_joints_to_ee_pose_processor = RobotProcessorPipeline[RobotObservation, RobotObservation]( + steps=[ForwardKinematicsJointsToEE(kinematics=kinematics_solver, motor_names=motor_names)], + to_transition=observation_to_transition, + to_output=transition_to_observation, + ) + + # EE-space action (produced by the policy) → joint-space action (sent to robot). + robot_ee_to_joints_processor = RobotProcessorPipeline[tuple[RobotAction, RobotObservation], RobotAction]( + steps=[ + InverseKinematicsEEToJoints( + kinematics=kinematics_solver, + motor_names=motor_names, + initial_guess_current_joints=True, + ), + ], + to_transition=robot_action_observation_to_transition, + to_output=transition_to_robot_action, + ) + + # Policy config (full model is loaded inside build_rollout_context). + policy_config = PreTrainedConfig.from_pretrained(HF_MODEL_ID) + policy_config.pretrained_path = HF_MODEL_ID + + cfg = RolloutConfig( + robot=robot_config, + policy=policy_config, + strategy=BaseStrategyConfig(), + inference=SyncInferenceConfig(), + fps=FPS, + duration=DURATION_SEC, + task=TASK_DESCRIPTION, + ) + + signal_handler = ProcessSignalHandler(use_threads=True) + + # Pass the EE kinematic processors via kwargs; the defaults (identity) would + # otherwise skip the joint↔EE conversion and the policy would receive the + # wrong observation/action space. + ctx = build_rollout_context( + cfg, + signal_handler.shutdown_event, + robot_action_processor=robot_ee_to_joints_processor, + robot_observation_processor=robot_joints_to_ee_pose_processor, + ) + + strategy = BaseStrategy(cfg.strategy) + try: + strategy.setup(ctx) + strategy.run(ctx) + finally: + strategy.teardown(ctx) + + +if __name__ == "__main__": + main() diff --git a/src/lerobot/rl/actor.py b/src/lerobot/rl/actor.py index 588adffac..eab527250 100644 --- a/src/lerobot/rl/actor.py +++ b/src/lerobot/rl/actor.py @@ -76,6 +76,7 @@ from lerobot.transport.utils import ( ) from lerobot.types import TransitionKey from lerobot.utils.device_utils import get_safe_torch_device +from lerobot.utils.process import ProcessSignalHandler from lerobot.utils.random_utils import set_seed from lerobot.utils.robot_utils import precise_sleep from lerobot.utils.transition import ( @@ -94,7 +95,6 @@ from .gym_manipulator import ( make_robot_env, step_env_and_process_transition, ) -from .process import ProcessSignalHandler from .queue import get_last_item_from_queue # Main entry point diff --git a/src/lerobot/rl/learner.py b/src/lerobot/rl/learner.py index d1207421b..14542576d 100644 --- a/src/lerobot/rl/learner.py +++ b/src/lerobot/rl/learner.py @@ -90,6 +90,7 @@ from lerobot.utils.constants import ( TRAINING_STATE_DIR, ) from lerobot.utils.device_utils import get_safe_torch_device +from lerobot.utils.process import ProcessSignalHandler from lerobot.utils.random_utils import set_seed from lerobot.utils.transition import move_state_dict_to_device, move_transition_to_device from lerobot.utils.utils import ( @@ -99,7 +100,6 @@ from lerobot.utils.utils import ( from .buffer import ReplayBuffer, concatenate_batch_transitions from .learner_service import MAX_WORKERS, SHUTDOWN_TIMEOUT, LearnerService -from .process import ProcessSignalHandler @parser.wrap() diff --git a/src/lerobot/rollout/__init__.py b/src/lerobot/rollout/__init__.py index 9b83d7107..f0f5bf140 100644 --- a/src/lerobot/rollout/__init__.py +++ b/src/lerobot/rollout/__init__.py @@ -24,7 +24,15 @@ from .configs import ( SentryStrategyConfig, ) from .context import RolloutContext, build_rollout_context -from .inference import InferenceEngine +from .inference import ( + InferenceStrategy, + InferenceStrategyConfig, + RTCInferenceConfig, + RTCInferenceStrategy, + SyncInferenceConfig, + SyncInferenceStrategy, + create_inference_strategy, +) from .ring_buffer import RolloutRingBuffer from .robot_wrapper import ThreadSafeRobot from .strategies import RolloutStrategy, create_strategy @@ -33,7 +41,10 @@ __all__ = [ "BaseStrategyConfig", "DAggerStrategyConfig", "HighlightStrategyConfig", - "InferenceEngine", + "InferenceStrategy", + "InferenceStrategyConfig", + "RTCInferenceConfig", + "RTCInferenceStrategy", "RolloutConfig", "RolloutContext", "RolloutDatasetConfig", @@ -41,7 +52,10 @@ __all__ = [ "RolloutStrategy", "RolloutStrategyConfig", "SentryStrategyConfig", + "SyncInferenceConfig", + "SyncInferenceStrategy", "ThreadSafeRobot", "build_rollout_context", + "create_inference_strategy", "create_strategy", ] diff --git a/src/lerobot/rollout/configs.py b/src/lerobot/rollout/configs.py index 994c49289..d3da043b5 100644 --- a/src/lerobot/rollout/configs.py +++ b/src/lerobot/rollout/configs.py @@ -24,10 +24,11 @@ from pathlib import Path import draccus from lerobot.configs import PreTrainedConfig, parser -from lerobot.policies.rtc import RTCConfig from lerobot.robots.config import RobotConfig from lerobot.teleoperators.config import TeleoperatorConfig +from .inference import InferenceStrategyConfig, SyncInferenceConfig + logger = logging.getLogger(__name__) @@ -92,6 +93,11 @@ class DAggerStrategyConfig(RolloutStrategyConfig): Alternates between autonomous policy execution and human intervention. Intervention frames are tagged with ``intervention=True``. + + When ``record_autonomous=True`` (default) both autonomous and correction + frames are recorded — this requires streaming encoding so the policy + loop never blocks on disk I/O. Set to ``False`` to record only the + human-correction windows; encoding can then happen between phases. """ episode_time_s: float = 120.0 @@ -100,6 +106,7 @@ class DAggerStrategyConfig(RolloutStrategyConfig): calibrate: bool = False log_hz: bool = True hz_log_interval_s: float = 2.0 + record_autonomous: bool = True # --------------------------------------------------------------------------- @@ -153,8 +160,8 @@ class RolloutConfig: # Strategy (polymorphic: --strategy.type=base|sentry|highlight|dagger) strategy: RolloutStrategyConfig = field(default_factory=BaseStrategyConfig) - # Inference backend - rtc: RTCConfig = field(default_factory=RTCConfig) + # Inference backend (polymorphic: --inference.type=sync|rtc) + inference: InferenceStrategyConfig = field(default_factory=SyncInferenceConfig) # Dataset (required for sentry, highlight, dagger; None for base) dataset: RolloutDatasetConfig | None = None @@ -211,6 +218,25 @@ class RolloutConfig: logger.warning("Sentry mode forces streaming_encoding=True") self.dataset.streaming_encoding = True + # Highlight writes frames while the policy is still running, so streaming is mandatory. + if ( + isinstance(self.strategy, HighlightStrategyConfig) + and self.dataset is not None + and not self.dataset.streaming_encoding + ): + logger.warning("Highlight mode forces streaming_encoding=True") + self.dataset.streaming_encoding = True + + # DAgger: streaming is mandatory only when the autonomous phase is also recorded. + if ( + isinstance(self.strategy, DAggerStrategyConfig) + and self.strategy.record_autonomous + and self.dataset is not None + and not self.dataset.streaming_encoding + ): + logger.warning("DAgger with record_autonomous=True forces streaming_encoding=True") + self.dataset.streaming_encoding = True + @classmethod def __get_path_fields__(cls) -> list[str]: return ["policy"] diff --git a/src/lerobot/rollout/context.py b/src/lerobot/rollout/context.py index f54d4248b..01c05270d 100644 --- a/src/lerobot/rollout/context.py +++ b/src/lerobot/rollout/context.py @@ -12,10 +12,16 @@ # See the License for the specific language governing permissions and # limitations under the License. -"""Rollout context: shared state created once before strategy dispatch.""" +"""Rollout context: shared state created once before strategy dispatch. + +Grouped into five topical sub-contexts — :class:`RuntimeContext`, +:class:`HardwareContext`, :class:`PolicyContext`, :class:`ProcessorContext`, +and :class:`DatasetContext` — assembled into :class:`RolloutContext`. +""" from __future__ import annotations +import datetime as _dt import logging from dataclasses import dataclass, field from threading import Event @@ -38,11 +44,16 @@ from lerobot.processor import ( make_default_processors, rename_stats, ) -from lerobot.robots import Robot, make_robot_from_config +from lerobot.robots import make_robot_from_config from lerobot.teleoperators import Teleoperator, make_teleoperator_from_config from lerobot.utils.feature_utils import combine_feature_dicts, hw_to_dataset_features -from .configs import BaseStrategyConfig, DAggerStrategyConfig, RolloutConfig +from .configs import BaseStrategyConfig, DAggerStrategyConfig, RolloutConfig, SentryStrategyConfig +from .inference import ( + InferenceStrategy, + RTCInferenceConfig, + create_inference_strategy, +) from .robot_wrapper import ThreadSafeRobot logger = logging.getLogger(__name__) @@ -68,71 +79,108 @@ def _resolve_action_key_order( return policy_action_names +# --------------------------------------------------------------------------- +# Sub-contexts +# --------------------------------------------------------------------------- + + +@dataclass +class RuntimeContext: + """Runtime knobs shared with every strategy.""" + + cfg: RolloutConfig + shutdown_event: Event + + +@dataclass +class HardwareContext: + """Connected hardware. + + The raw robot is available via ``robot_wrapper.inner`` when needed + (e.g. for disconnect); strategies should otherwise go through the + thread-safe wrapper. + """ + + robot_wrapper: ThreadSafeRobot + teleop: Teleoperator | None + + +@dataclass +class PolicyContext: + """Loaded policy and its inference strategy.""" + + policy: PreTrainedPolicy + preprocessor: PolicyProcessorPipeline + postprocessor: PolicyProcessorPipeline + inference: InferenceStrategy + + +@dataclass +class ProcessorContext: + """Robot-side pipelines (run outside the policy).""" + + teleop_action_processor: RobotProcessorPipeline[tuple[RobotAction, RobotObservation], RobotAction] + robot_action_processor: RobotProcessorPipeline[tuple[RobotAction, RobotObservation], RobotAction] + robot_observation_processor: RobotProcessorPipeline[RobotObservation, RobotObservation] + + +@dataclass +class DatasetContext: + """Dataset and feature bookkeeping.""" + + dataset: LeRobotDataset | None + dataset_features: dict = field(default_factory=dict) + hw_features: dict = field(default_factory=dict) + ordered_action_keys: list[str] = field(default_factory=list) + + @dataclass class RolloutContext: - """Bundle of shared resources passed to every rollout strategy. + """Bundle of sub-contexts passed to every rollout strategy. Built once by :func:`build_rollout_context` before strategy dispatch. """ - cfg: RolloutConfig - robot: Robot - robot_wrapper: ThreadSafeRobot - teleop: Teleoperator | None - policy: PreTrainedPolicy - preprocessor: PolicyProcessorPipeline - postprocessor: PolicyProcessorPipeline - teleop_action_processor: RobotProcessorPipeline[tuple[RobotAction, RobotObservation], RobotAction] - robot_action_processor: RobotProcessorPipeline[tuple[RobotAction, RobotObservation], RobotAction] - robot_observation_processor: RobotProcessorPipeline[RobotObservation, RobotObservation] - dataset: LeRobotDataset | None - shutdown_event: Event = field(default_factory=Event) - dataset_features: dict = field(default_factory=dict) - action_keys: list[str] = field(default_factory=list) - ordered_action_keys: list[str] = field(default_factory=list) - hw_features: dict = field(default_factory=dict) + runtime: RuntimeContext + hardware: HardwareContext + policy: PolicyContext + processors: ProcessorContext + data: DatasetContext -def build_rollout_context(cfg: RolloutConfig, shutdown_event: Event) -> RolloutContext: - """Wire up hardware, policy, processors, and dataset from config. +# --------------------------------------------------------------------------- +# Build +# --------------------------------------------------------------------------- - This function performs all the one-time setup that every strategy - needs, keeping the strategy implementations lean. + +def build_rollout_context( + cfg: RolloutConfig, + shutdown_event: Event, + teleop_action_processor: RobotProcessorPipeline | None = None, + robot_action_processor: RobotProcessorPipeline | None = None, + robot_observation_processor: RobotProcessorPipeline | None = None, +) -> RolloutContext: + """Wire up policy, processors, hardware, dataset, and inference strategy. + + The order is policy-first / hardware-last so a bad ``--policy.path`` + fails fast without touching the robot. """ - # --- Hardware --- - robot = make_robot_from_config(cfg.robot) - robot.connect() - robot_wrapper = ThreadSafeRobot(robot) + is_rtc = isinstance(cfg.inference, RTCInferenceConfig) - teleop = None - if cfg.teleop is not None: - teleop = make_teleoperator_from_config(cfg.teleop) - teleop.connect() - - # --- Processors --- - teleop_action_processor, robot_action_processor, robot_observation_processor = make_default_processors() - - # --- Policy --- - # Use cfg.policy directly (already loaded in RolloutConfig.__post_init__) - # instead of reloading from disk. + # --- 1. Policy (heavy I/O, but no hardware yet) ------------------- policy_config = cfg.policy - use_rtc = cfg.rtc.enabled policy_class = get_policy_class(policy_config.type) - # Reload config from pretrained path for full model parameters full_config = PreTrainedConfig.from_pretrained(cfg.policy.pretrained_path) - # Merge any CLI overrides from cfg.policy into full_config for attr in ("device", "use_amp"): if hasattr(cfg.policy, attr) and hasattr(full_config, attr): cli_val = getattr(cfg.policy, attr) if cli_val is not None: setattr(full_config, attr, cli_val) - # Set compile_model for pi0/pi05 if hasattr(full_config, "compile_model"): full_config.compile_model = cfg.use_torch_compile - # Handle PEFT models if full_config.use_peft: from peft import PeftConfig, PeftModel @@ -145,16 +193,14 @@ def build_rollout_context(cfg: RolloutConfig, shutdown_event: Event) -> RolloutC else: policy = policy_class.from_pretrained(cfg.policy.pretrained_path, config=full_config) - # Enable RTC on the policy - if use_rtc: - policy.config.rtc_config = cfg.rtc + if is_rtc: + policy.config.rtc_config = cfg.inference.rtc if hasattr(policy, "init_rtc_processor"): policy.init_rtc_processor() policy = policy.to(cfg.device) policy.eval() - # Apply torch.compile if requested (skip pi0/pi05 — they handle their own) if cfg.use_torch_compile and policy.type not in ("pi0", "pi05"): try: if hasattr(torch, "compile"): @@ -168,18 +214,34 @@ def build_rollout_context(cfg: RolloutConfig, shutdown_event: Event) -> RolloutC except Exception as e: logger.warning("Failed to apply torch.compile: %s", e) - # --- Observation features --- - # Hardware-level features: camera features are tuples (H, W, C), state - # features are the ``float`` type. This is the canonical pattern used - # throughout the codebase (see feature_utils.py:hw_to_dataset_features). + # --- 2. Robot-side processors (user-supplied or defaults) -------- + if ( + teleop_action_processor is None + or robot_action_processor is None + or robot_observation_processor is None + ): + _t, _r, _o = make_default_processors() + teleop_action_processor = teleop_action_processor or _t + robot_action_processor = robot_action_processor or _r + robot_observation_processor = robot_observation_processor or _o + + # --- 3. Hardware (heaviest side-effect, deferred) ----------------- + robot = make_robot_from_config(cfg.robot) + robot.connect() + robot_wrapper = ThreadSafeRobot(robot) + + teleop = None + if cfg.teleop is not None: + teleop = make_teleoperator_from_config(cfg.teleop) + teleop.connect() + + # --- 4. Features + action-key reconciliation --------------------- all_obs_features = robot.observation_features observation_features_hw = { k: v for k, v in all_obs_features.items() if v is float or isinstance(v, tuple) } - action_features_hw = robot.action_features - # Build dataset features dataset_features = combine_feature_dicts( aggregate_pipeline_dataset_features( pipeline=teleop_action_processor, @@ -192,22 +254,22 @@ def build_rollout_context(cfg: RolloutConfig, shutdown_event: Event) -> RolloutC use_videos=cfg.dataset.video if cfg.dataset else True, ), ) - hw_features = hw_to_dataset_features(observation_features_hw, "observation") - - # Action keys - action_keys = list(robot.action_features.keys()) - - # Ordered action keys (reconcile policy vs dataset ordering) + raw_action_keys = list(robot.action_features.keys()) policy_action_names = getattr(policy_config, "action_feature_names", None) ordered_action_keys = _resolve_action_key_order( list(policy_action_names) if policy_action_names else None, - action_keys, + raw_action_keys, ) - # --- Dataset --- + # --- 5. Dataset (Sentry gets a unique per-run suffix) ------------- dataset = None if cfg.dataset is not None and not isinstance(cfg.strategy, BaseStrategyConfig): + if not cfg.resume and isinstance(cfg.strategy, SentryStrategyConfig) and cfg.dataset.repo_id: + suffix = _dt.datetime.now(_dt.UTC).strftime("%Y%m%dT%H%M%SZ") + cfg.dataset.repo_id = f"{cfg.dataset.repo_id}-{suffix}" + logger.info("Sentry mode: using run-suffixed repo_id=%s", cfg.dataset.repo_id) + if cfg.resume: dataset = LeRobotDataset.resume( cfg.dataset.repo_id, @@ -222,10 +284,9 @@ def build_rollout_context(cfg: RolloutConfig, shutdown_event: Event) -> RolloutC * len(robot.cameras if hasattr(robot, "cameras") else []), ) else: - # Add intervention column for DAgger strategy if isinstance(cfg.strategy, DAggerStrategyConfig): dataset_features["intervention"] = { - "dtype": "int64", + "dtype": "bool", "shape": (1,), "names": None, } @@ -247,7 +308,7 @@ def build_rollout_context(cfg: RolloutConfig, shutdown_event: Event) -> RolloutC encoder_threads=cfg.dataset.encoder_threads, ) - # --- Pre/post processors --- + # --- 6. Policy pre/post processors (needs dataset stats if any) --- dataset_stats = None if dataset is not None: dataset_stats = rename_stats( @@ -265,21 +326,44 @@ def build_rollout_context(cfg: RolloutConfig, shutdown_event: Event) -> RolloutC }, ) - return RolloutContext( - cfg=cfg, - robot=robot, - robot_wrapper=robot_wrapper, - teleop=teleop, + # --- 7. Inference strategy (needs policy + pre/post + hardware) -- + task_str = cfg.dataset.single_task if cfg.dataset else cfg.task + inference_strategy = create_inference_strategy( + cfg.inference, policy=policy, preprocessor=preprocessor, postprocessor=postprocessor, - teleop_action_processor=teleop_action_processor, - robot_action_processor=robot_action_processor, - robot_observation_processor=robot_observation_processor, - dataset=dataset, - shutdown_event=shutdown_event, - dataset_features=dataset_features, - action_keys=action_keys, - ordered_action_keys=ordered_action_keys, + robot_wrapper=robot_wrapper, hw_features=hw_features, + dataset_features=dataset_features, + ordered_action_keys=ordered_action_keys, + task=task_str, + fps=cfg.fps, + device=cfg.device, + use_torch_compile=cfg.use_torch_compile, + compile_warmup_inferences=cfg.compile_warmup_inferences, + shutdown_event=shutdown_event, + ) + + # --- 8. Assemble --------------------------------------------------- + return RolloutContext( + runtime=RuntimeContext(cfg=cfg, shutdown_event=shutdown_event), + hardware=HardwareContext(robot_wrapper=robot_wrapper, teleop=teleop), + policy=PolicyContext( + policy=policy, + preprocessor=preprocessor, + postprocessor=postprocessor, + inference=inference_strategy, + ), + processors=ProcessorContext( + teleop_action_processor=teleop_action_processor, + robot_action_processor=robot_action_processor, + robot_observation_processor=robot_observation_processor, + ), + data=DatasetContext( + dataset=dataset, + dataset_features=dataset_features, + hw_features=hw_features, + ordered_action_keys=ordered_action_keys, + ), ) diff --git a/src/lerobot/rollout/inference/__init__.py b/src/lerobot/rollout/inference/__init__.py new file mode 100644 index 000000000..b85801de9 --- /dev/null +++ b/src/lerobot/rollout/inference/__init__.py @@ -0,0 +1,39 @@ +# Copyright 2025 The HuggingFace Inc. team. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +"""Inference strategy package — backend-agnostic action production. + +Concrete strategies (sync, RTC, …) expose the same small interface so +rollout strategies never branch on the inference backend. +""" + +from .base import InferenceStrategy +from .factory import ( + InferenceStrategyConfig, + RTCInferenceConfig, + SyncInferenceConfig, + create_inference_strategy, +) +from .rtc import RTCInferenceStrategy +from .sync import SyncInferenceStrategy + +__all__ = [ + "InferenceStrategy", + "InferenceStrategyConfig", + "RTCInferenceConfig", + "RTCInferenceStrategy", + "SyncInferenceConfig", + "SyncInferenceStrategy", + "create_inference_strategy", +] diff --git a/src/lerobot/rollout/inference/base.py b/src/lerobot/rollout/inference/base.py new file mode 100644 index 000000000..9ef51845e --- /dev/null +++ b/src/lerobot/rollout/inference/base.py @@ -0,0 +1,88 @@ +# Copyright 2025 The HuggingFace Inc. team. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +"""Inference strategy ABC. + +Rollout strategies consume actions through this small interface so they +do not need to know whether inference is synchronous, runs in a +background thread (RTC), or comes from an external source. +""" + +from __future__ import annotations + +import abc + +import torch + + +class InferenceStrategy(abc.ABC): + """Abstract backend for producing actions during rollout. + + Subclasses decide whether inference happens inline, in a background + thread, or externally. The contract is minimal so new backends can + be added without touching rollout strategies. + + Lifecycle + --------- + ``start`` — prepare the backend (e.g. launch a background thread). + ``stop`` — shut the backend down cleanly. + ``reset`` — clear episode-scoped state (policy hidden state, queues…). + + Action production + ----------------- + ``get_action(obs_frame)`` — return the next action tensor, or + ``None`` if none is available (e.g. async queue empty). Sync + backends always compute from ``obs_frame``; async backends may + ignore it (they get observations via ``notify_observation``). + + Optional hooks + -------------- + ``notify_observation`` / ``pause`` / ``resume`` have a no-op default + so rollout strategies can invoke them unconditionally. + """ + + @abc.abstractmethod + def start(self) -> None: + """Initialise the backend.""" + + @abc.abstractmethod + def stop(self) -> None: + """Tear the backend down.""" + + @abc.abstractmethod + def reset(self) -> None: + """Clear episode-scoped state.""" + + @abc.abstractmethod + def get_action(self, obs_frame: dict | None) -> torch.Tensor | None: + """Return the next action tensor, or ``None`` if unavailable.""" + + def notify_observation(self, obs: dict) -> None: # noqa: B027 + """Publish the latest processed observation. Default: no-op.""" + + def pause(self) -> None: # noqa: B027 + """Pause background inference. Default: no-op.""" + + def resume(self) -> None: # noqa: B027 + """Resume background inference. Default: no-op.""" + + @property + def ready(self) -> bool: + """True once the backend can produce actions (e.g. warmup done).""" + return True + + @property + def failed(self) -> bool: + """True if an unrecoverable error occurred in the backend.""" + return False diff --git a/src/lerobot/rollout/inference/factory.py b/src/lerobot/rollout/inference/factory.py new file mode 100644 index 000000000..f27065ebf --- /dev/null +++ b/src/lerobot/rollout/inference/factory.py @@ -0,0 +1,125 @@ +# Copyright 2025 The HuggingFace Inc. team. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +"""Inference strategy configs and factory. + +Selection is explicit via ``--inference.type=sync|rtc``. Adding a new +backend requires registering its config subclass and dispatching it in +:func:`create_inference_strategy`. +""" + +from __future__ import annotations + +import abc +import logging +from dataclasses import dataclass, field +from threading import Event + +import draccus + +from lerobot.policies.pretrained import PreTrainedPolicy +from lerobot.policies.rtc.configuration_rtc import RTCConfig +from lerobot.processor import PolicyProcessorPipeline + +from ..robot_wrapper import ThreadSafeRobot +from .base import InferenceStrategy +from .rtc import RTCInferenceStrategy +from .sync import SyncInferenceStrategy + +logger = logging.getLogger(__name__) + + +# --------------------------------------------------------------------------- +# Configs +# --------------------------------------------------------------------------- + + +@dataclass +class InferenceStrategyConfig(draccus.ChoiceRegistry, abc.ABC): + """Abstract base for inference backend configuration. + + Use ``--inference.type=`` on the CLI to select a backend. + """ + + @property + def type(self) -> str: + return self.get_choice_name(self.__class__) + + +@InferenceStrategyConfig.register_subclass("sync") +@dataclass +class SyncInferenceConfig(InferenceStrategyConfig): + """Inline synchronous inference (one policy call per control tick).""" + + +@InferenceStrategyConfig.register_subclass("rtc") +@dataclass +class RTCInferenceConfig(InferenceStrategyConfig): + """Real-Time Chunking: async policy inference in a background thread.""" + + rtc: RTCConfig = field(default_factory=RTCConfig) + queue_threshold: int = 30 + + +# --------------------------------------------------------------------------- +# Factory +# --------------------------------------------------------------------------- + + +def create_inference_strategy( + config: InferenceStrategyConfig, + *, + policy: PreTrainedPolicy, + preprocessor: PolicyProcessorPipeline, + postprocessor: PolicyProcessorPipeline, + robot_wrapper: ThreadSafeRobot, + hw_features: dict, + dataset_features: dict, + ordered_action_keys: list[str], + task: str, + fps: float, + device: str | None, + use_torch_compile: bool = False, + compile_warmup_inferences: int = 2, + shutdown_event: Event | None = None, +) -> InferenceStrategy: + """Instantiate the appropriate inference strategy from a config object.""" + if isinstance(config, SyncInferenceConfig): + return SyncInferenceStrategy( + policy=policy, + preprocessor=preprocessor, + postprocessor=postprocessor, + dataset_features=dataset_features, + ordered_action_keys=ordered_action_keys, + task=task, + device=device, + robot_type=robot_wrapper.robot_type, + ) + if isinstance(config, RTCInferenceConfig): + return RTCInferenceStrategy( + policy=policy, + preprocessor=preprocessor, + postprocessor=postprocessor, + robot_wrapper=robot_wrapper, + rtc_config=config.rtc, + hw_features=hw_features, + task=task, + fps=fps, + device=device, + use_torch_compile=use_torch_compile, + compile_warmup_inferences=compile_warmup_inferences, + rtc_queue_threshold=config.queue_threshold, + shutdown_event=shutdown_event, + ) + raise ValueError(f"Unknown inference strategy type: {type(config).__name__}") diff --git a/src/lerobot/rollout/inference.py b/src/lerobot/rollout/inference/rtc.py similarity index 73% rename from src/lerobot/rollout/inference.py rename to src/lerobot/rollout/inference/rtc.py index 01f9ca81f..a331f8636 100644 --- a/src/lerobot/rollout/inference.py +++ b/src/lerobot/rollout/inference/rtc.py @@ -12,11 +12,12 @@ # See the License for the specific language governing permissions and # limitations under the License. -"""Unified inference engine supporting both synchronous and RTC backends. +"""Real-Time Chunking inference strategy. -The :class:`InferenceEngine` abstracts whether prediction happens inline -(sync) or in a background thread (RTC), so rollout strategies don't need -to branch on the inference backend. +A background thread produces action chunks asynchronously via +:meth:`policy.predict_action_chunk`. The main control loop polls +``get_action`` for the next ready action; observations flow the other +way via ``notify_observation``. """ from __future__ import annotations @@ -25,7 +26,6 @@ import logging import math import time import traceback -from copy import copy from threading import Event, Lock, Thread from typing import Any @@ -46,7 +46,8 @@ from lerobot.processor import ( from lerobot.utils.constants import OBS_STATE from lerobot.utils.feature_utils import build_dataset_frame -from .robot_wrapper import ThreadSafeRobot +from ..robot_wrapper import ThreadSafeRobot +from .base import InferenceStrategy logger = logging.getLogger(__name__) @@ -94,42 +95,17 @@ def _normalize_prev_actions_length(prev_actions: torch.Tensor, target_steps: int # --------------------------------------------------------------------------- -# InferenceEngine +# RTCInferenceStrategy # --------------------------------------------------------------------------- -class InferenceEngine: - """Abstracts sync vs. RTC (async) inference for rollout strategies. +class RTCInferenceStrategy(InferenceStrategy): + """Async RTC inference: a background thread produces action chunks. - Parameters - ---------- - policy: - The loaded policy (already on device, in eval mode, with RTC - processor initialised if applicable). - preprocessor / postprocessor: - Policy processor pipelines. - robot_wrapper: - Thread-safe robot wrapper. - rtc_config: - RTC configuration. If ``rtc_config.enabled`` is False, the - engine operates in synchronous mode. - hw_features: - Dataset-level feature dict built from ``hw_to_dataset_features``. - action_keys: - Ordered list of action feature names. - task: - Task description string. - fps: - Control loop frequency. - device: - Torch device string. - use_torch_compile: - Whether torch.compile warmup is needed. - compile_warmup_inferences: - Number of warmup inferences before live rollout. - rtc_queue_threshold: - Maximum RTC action queue size before the background thread - pauses generation. Prevents unbounded queue growth. + ``get_action`` pops the next action from the shared queue (or + returns ``None`` if the queue is empty). The main loop should call + ``notify_observation`` every tick and ``pause``/``resume`` around + human-intervention phases. """ def __init__( @@ -140,7 +116,6 @@ class InferenceEngine: robot_wrapper: ThreadSafeRobot, rtc_config: RTCConfig, hw_features: dict, - action_keys: list[str], task: str, fps: float, device: str | None, @@ -155,7 +130,6 @@ class InferenceEngine: self._robot = robot_wrapper self._rtc_config = rtc_config self._hw_features = hw_features - self._action_keys = action_keys self._task = task self._fps = fps self._device = device or "cpu" @@ -163,8 +137,6 @@ class InferenceEngine: self._compile_warmup_inferences = compile_warmup_inferences self._rtc_queue_threshold = rtc_queue_threshold - # RTC state - self._use_rtc = rtc_config.enabled self._action_queue: ActionQueue | None = None self._obs_holder: dict[str, Any] = {} self._obs_lock = Lock() @@ -178,7 +150,7 @@ class InferenceEngine: if not self._use_torch_compile: self._compile_warmup_done.set() - # Processor introspection for relative-action re-anchoring + # Processor introspection for relative-action re-anchoring. self._relative_step = next( (s for s in preprocessor.steps if isinstance(s, RelativeActionsProcessorStep) and s.enabled), None, @@ -203,38 +175,33 @@ class InferenceEngine: # ------------------------------------------------------------------ @property - def is_rtc(self) -> bool: - return self._use_rtc + def ready(self) -> bool: + return self._compile_warmup_done.is_set() + + @property + def failed(self) -> bool: + """True if the RTC background thread exited due to an unrecoverable error.""" + return self._rtc_error.is_set() @property def action_queue(self) -> ActionQueue | None: return self._action_queue - @property - def compile_warmup_done(self) -> Event: - return self._compile_warmup_done - - @property - def rtc_failed(self) -> bool: - """True if the RTC background thread exited due to an unrecoverable error.""" - return self._rtc_error.is_set() - def start(self) -> None: - """Start the inference engine. Launches the RTC background thread if enabled.""" - if self._use_rtc: - self._action_queue = ActionQueue(self._rtc_config) - self._obs_holder = { - "obs": None, - "robot_type": self._robot.robot_type, - } - self._shutdown_event.clear() - self._rtc_thread = Thread( - target=self._rtc_loop, - daemon=True, - name="RTCInference", - ) - self._rtc_thread.start() - logger.info("RTC inference thread started") + """Launch the RTC background thread.""" + self._action_queue = ActionQueue(self._rtc_config) + self._obs_holder = { + "obs": None, + "robot_type": self._robot.robot_type, + } + self._shutdown_event.clear() + self._rtc_thread = Thread( + target=self._rtc_loop, + daemon=True, + name="RTCInference", + ) + self._rtc_thread.start() + logger.info("RTC inference thread started") def stop(self) -> None: """Signal the RTC thread to stop and wait for it.""" @@ -245,67 +212,32 @@ class InferenceEngine: self._rtc_thread = None def pause(self) -> None: - """Pause the RTC background thread (used during DAgger takeover).""" self._policy_active.clear() def resume(self) -> None: - """Resume the RTC background thread.""" self._policy_active.set() def reset(self) -> None: - """Reset policy, processors, and action queue between episodes.""" self._policy.reset() self._preprocessor.reset() self._postprocessor.reset() - if self._use_rtc and self._action_queue is not None: + if self._action_queue is not None: self._action_queue.clear() # ------------------------------------------------------------------ - # Sync inference + # Action production (called from main thread) # ------------------------------------------------------------------ - def get_action_sync(self, obs_frame: dict) -> torch.Tensor: - """Run synchronous single-step inference. - - Parameters - ---------- - obs_frame: - Observation dict with numpy arrays (output of ``build_dataset_frame``). - - Returns - ------- - Action tensor on CPU. - """ - observation = copy(obs_frame) - policy_device = torch.device(self._device) - with ( - torch.inference_mode(), - torch.autocast(device_type=policy_device.type) - if policy_device.type == "cuda" and self._policy.config.use_amp - else torch.inference_mode(), - ): - observation = prepare_observation_for_inference( - observation, policy_device, self._task, self._robot.robot_type - ) - observation = self._preprocessor(observation) - action = self._policy.select_action(observation) - action = self._postprocessor(action) - return action.squeeze(0).cpu() - - # ------------------------------------------------------------------ - # RTC: action consumption (called from main thread) - # ------------------------------------------------------------------ - - def consume_rtc_action(self) -> torch.Tensor | None: - """Pop the next action from the RTC action queue. Returns None if empty.""" + def get_action(self, obs_frame: dict | None) -> torch.Tensor | None: + """Pop the next action from the RTC queue (ignores ``obs_frame``).""" if self._action_queue is None: return None return self._action_queue.get() - def update_observation(self, obs_filtered: dict) -> None: - """Push the latest observation to the shared holder for the RTC thread.""" + def notify_observation(self, obs: dict) -> None: + """Publish the latest observation for the RTC thread to consume.""" with self._obs_lock: - self._obs_holder["obs"] = obs_filtered + self._obs_holder["obs"] = obs # ------------------------------------------------------------------ # RTC: background inference thread @@ -342,17 +274,14 @@ class InferenceEngine: latency = latency_tracker.max() delay = math.ceil(latency / time_per_chunk) if latency else 0 - # Build observation batch using the same pipeline as sync inference obs_batch = build_dataset_frame(self._hw_features, obs, prefix="observation") obs_batch = prepare_observation_for_inference( obs_batch, policy_device, self._task, self._robot.robot_type ) - # predict_action_chunk expects batched task format obs_batch["task"] = [self._task] preprocessed = self._preprocessor(obs_batch) - # Re-anchor leftover for relative-action policies if prev_actions is not None and self._relative_step is not None: state_tensor = preprocessed.get(OBS_STATE) if state_tensor is not None: diff --git a/src/lerobot/rollout/inference/sync.py b/src/lerobot/rollout/inference/sync.py new file mode 100644 index 000000000..da2a8a8ae --- /dev/null +++ b/src/lerobot/rollout/inference/sync.py @@ -0,0 +1,94 @@ +# Copyright 2025 The HuggingFace Inc. team. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +"""Synchronous inference strategy: inline policy call per control tick.""" + +from __future__ import annotations + +import logging +from copy import copy + +import torch + +from lerobot.policies.pretrained import PreTrainedPolicy +from lerobot.policies.utils import make_robot_action, prepare_observation_for_inference +from lerobot.processor import PolicyProcessorPipeline + +from .base import InferenceStrategy + +logger = logging.getLogger(__name__) + + +class SyncInferenceStrategy(InferenceStrategy): + """Inline synchronous inference: compute one action per call. + + ``get_action`` runs the full policy pipeline (pre/post-processor + + ``select_action``) on the given observation frame and returns a + CPU action tensor reordered to match the dataset action keys. + """ + + def __init__( + self, + policy: PreTrainedPolicy, + preprocessor: PolicyProcessorPipeline, + postprocessor: PolicyProcessorPipeline, + dataset_features: dict, + ordered_action_keys: list[str], + task: str, + device: str | None, + robot_type: str, + ) -> None: + self._policy = policy + self._preprocessor = preprocessor + self._postprocessor = postprocessor + self._dataset_features = dataset_features + self._ordered_action_keys = ordered_action_keys + self._task = task + self._device = device or "cpu" + self._robot_type = robot_type + + def start(self) -> None: + """No background resources to start.""" + + def stop(self) -> None: + """No background resources to stop.""" + + def reset(self) -> None: + self._policy.reset() + self._preprocessor.reset() + self._postprocessor.reset() + + def get_action(self, obs_frame: dict | None) -> torch.Tensor | None: + if obs_frame is None: + return None + observation = copy(obs_frame) + policy_device = torch.device(self._device) + with ( + torch.inference_mode(), + torch.autocast(device_type=policy_device.type) + if policy_device.type == "cuda" and self._policy.config.use_amp + else torch.inference_mode(), + ): + observation = prepare_observation_for_inference( + observation, policy_device, self._task, self._robot_type + ) + observation = self._preprocessor(observation) + action = self._policy.select_action(observation) + action = self._postprocessor(action) + action_tensor = action.squeeze(0).cpu() + + # Reorder to match dataset action ordering so the caller can treat + # the returned tensor uniformly across backends. + action_dict = make_robot_action(action_tensor, self._dataset_features) + return torch.tensor([action_dict[k] for k in self._ordered_action_keys]) diff --git a/src/lerobot/rollout/strategies/__init__.py b/src/lerobot/rollout/strategies/__init__.py index 446bc7155..1a4943e79 100644 --- a/src/lerobot/rollout/strategies/__init__.py +++ b/src/lerobot/rollout/strategies/__init__.py @@ -14,11 +14,11 @@ """Rollout strategies — public API re-exports.""" -from .core import RolloutStrategy, infer_action +from .core import RolloutStrategy, send_next_action from .factory import create_strategy __all__ = [ "RolloutStrategy", "create_strategy", - "infer_action", + "send_next_action", ] diff --git a/src/lerobot/rollout/strategies/base.py b/src/lerobot/rollout/strategies/base.py index 9e1e16a29..b0714a297 100644 --- a/src/lerobot/rollout/strategies/base.py +++ b/src/lerobot/rollout/strategies/base.py @@ -22,7 +22,7 @@ import time from lerobot.utils.robot_utils import precise_sleep from ..context import RolloutContext -from .core import RolloutStrategy, infer_action +from .core import RolloutStrategy, send_next_action logger = logging.getLogger(__name__) @@ -30,45 +30,42 @@ logger = logging.getLogger(__name__) class BaseStrategy(RolloutStrategy): """Autonomous policy rollout with no data recording. - Supports both synchronous and RTC inference backends via the - :class:`InferenceEngine`. All actions flow through the - ``robot_action_processor`` pipeline before reaching the robot. + All actions flow through the ``robot_action_processor`` pipeline + before reaching the robot. """ def setup(self, ctx: RolloutContext) -> None: self._init_engine(ctx) - logger.info("Base strategy ready (rtc=%s)", self._engine.is_rtc) + logger.info("Base strategy ready") def run(self, ctx: RolloutContext) -> None: engine = self._engine - cfg = ctx.cfg - robot = ctx.robot_wrapper + cfg = ctx.runtime.cfg + robot = ctx.hardware.robot_wrapper interpolator = self._interpolator control_interval = interpolator.get_control_interval(cfg.fps) - ordered_keys = ctx.ordered_action_keys + ordered_keys = ctx.data.ordered_action_keys start_time = time.perf_counter() + engine.resume() - if engine.is_rtc: - engine.resume() - - while not ctx.shutdown_event.is_set(): + while not ctx.runtime.shutdown_event.is_set(): loop_start = time.perf_counter() if cfg.duration > 0 and (time.perf_counter() - start_time) >= cfg.duration: break obs = robot.get_observation() - obs_processed = ctx.robot_observation_processor(obs) - - if engine.is_rtc: - engine.update_observation(obs_processed) + obs_processed = ctx.processors.robot_observation_processor(obs) + engine.notify_observation(obs_processed) if self._handle_warmup(cfg.use_torch_compile, loop_start, control_interval): continue - infer_action(engine, obs_processed, obs, ctx, interpolator, ordered_keys, ctx.dataset_features) + send_next_action( + engine, obs_processed, obs, ctx, interpolator, ordered_keys, ctx.data.dataset_features + ) dt = time.perf_counter() - loop_start if (sleep_t := control_interval - dt) > 0: diff --git a/src/lerobot/rollout/strategies/core.py b/src/lerobot/rollout/strategies/core.py index 38d4835c5..4ae8e0196 100644 --- a/src/lerobot/rollout/strategies/core.py +++ b/src/lerobot/rollout/strategies/core.py @@ -12,7 +12,7 @@ # See the License for the specific language governing permissions and # limitations under the License. -"""Rollout strategy ABC and shared inference helper.""" +"""Rollout strategy ABC and shared action-dispatch helper.""" from __future__ import annotations @@ -20,20 +20,16 @@ import abc import time from typing import TYPE_CHECKING -import torch - from lerobot.policies.rtc import ActionInterpolator -from lerobot.policies.utils import make_robot_action from lerobot.utils.constants import OBS_STR from lerobot.utils.feature_utils import build_dataset_frame from lerobot.utils.robot_utils import precise_sleep -from ..inference import InferenceEngine +from ..inference import InferenceStrategy if TYPE_CHECKING: from ..configs import RolloutStrategyConfig from ..context import RolloutContext - from ..inference import InferenceEngine class RolloutStrategy(abc.ABC): @@ -46,33 +42,18 @@ class RolloutStrategy(abc.ABC): def __init__(self, config: RolloutStrategyConfig) -> None: self.config = config - self._engine: InferenceEngine | None = None + self._engine: InferenceStrategy | None = None self._interpolator: ActionInterpolator | None = None self._warmup_flushed: bool = False def _init_engine(self, ctx: RolloutContext) -> None: - """Create and start the inference engine and action interpolator. + """Attach the inference strategy + interpolator and start the backend. - Call this from ``setup()`` to avoid duplicating the engine - construction across every strategy. + Call this from ``setup()`` so strategies share identical setup + without duplicating code. """ - - self._interpolator = ActionInterpolator(multiplier=ctx.cfg.interpolation_multiplier) - self._engine = InferenceEngine( - policy=ctx.policy, - preprocessor=ctx.preprocessor, - postprocessor=ctx.postprocessor, - robot_wrapper=ctx.robot_wrapper, - rtc_config=ctx.cfg.rtc, - hw_features=ctx.hw_features, - action_keys=ctx.action_keys, - task=ctx.cfg.task, - fps=ctx.cfg.fps, - device=ctx.cfg.device, - use_torch_compile=ctx.cfg.use_torch_compile, - compile_warmup_inferences=ctx.cfg.compile_warmup_inferences, - shutdown_event=ctx.shutdown_event, - ) + self._interpolator = ActionInterpolator(multiplier=ctx.runtime.cfg.interpolation_multiplier) + self._engine = ctx.policy.inference self._engine.start() self._warmup_flushed = False @@ -87,7 +68,7 @@ class RolloutStrategy(abc.ABC): interpolator = self._interpolator if not use_torch_compile: return False - if not engine.compile_warmup_done.is_set(): + if not engine.ready: dt = time.perf_counter() - loop_start if (sleep_t := control_interval - dt) > 0: precise_sleep(sleep_t) @@ -96,18 +77,19 @@ class RolloutStrategy(abc.ABC): engine.reset() interpolator.reset() self._warmup_flushed = True - if engine.is_rtc: - engine.resume() + engine.resume() return False def _teardown_hardware(self, ctx: RolloutContext) -> None: """Stop the inference engine and disconnect hardware.""" if self._engine is not None: self._engine.stop() - if ctx.robot.is_connected: - ctx.robot.disconnect() - if ctx.teleop is not None and ctx.teleop.is_connected: - ctx.teleop.disconnect() + robot = ctx.hardware.robot_wrapper.inner + if robot.is_connected: + robot.disconnect() + teleop = ctx.hardware.teleop + if teleop is not None and teleop.is_connected: + teleop.disconnect() @abc.abstractmethod def setup(self, ctx: RolloutContext) -> None: @@ -123,12 +105,12 @@ class RolloutStrategy(abc.ABC): # --------------------------------------------------------------------------- -# Shared inference helper +# Shared action-dispatch helper # --------------------------------------------------------------------------- -def infer_action( - engine: InferenceEngine, +def send_next_action( + engine: InferenceStrategy, obs_processed: dict, obs_raw: dict, ctx: RolloutContext, @@ -136,53 +118,27 @@ def infer_action( ordered_keys: list[str], features: dict, ) -> dict | None: - """Run one policy inference step and send the resulting action to the robot. + """Dispatch the next action to the robot. - Handles both sync and RTC backends. Uses the interpolator for smooth - control at higher-than-inference rates (works with any multiplier, - including 1 where it acts as a pass-through). + Pulls the next action tensor from the inference strategy, feeds the + interpolator, and sends the interpolated action through the + ``robot_action_processor`` to the robot. Works identically for + sync and async backends — the strategy never needs to branch. - Parameters - ---------- - engine: - The inference engine (sync or RTC). - obs_processed: - Observation dict after ``robot_observation_processor``. - obs_raw: - Raw observation dict (needed by ``robot_action_processor``). - ctx: - Rollout context. - interpolator: - Action interpolator for Nx control rate. - ordered_keys: - Ordered action feature names (policy-to-robot mapping). - features: - Feature specification dict for ``build_dataset_frame`` / - ``make_robot_action``. Use ``dataset.features`` when recording, - ``ctx.dataset_features`` otherwise. - - Returns - ------- - Action dict sent to the robot, or ``None`` if no action was - available (empty RTC queue, interpolator buffer not ready). + Returns the action dict that was sent, or ``None`` if no action was + ready (e.g. empty async queue, interpolator not yet primed). """ - if engine.is_rtc: - if interpolator.needs_new_action(): - action_tensor = engine.consume_rtc_action() - if action_tensor is not None: - interpolator.add(action_tensor.cpu()) - else: - if interpolator.needs_new_action(): - obs_frame = build_dataset_frame(features, obs_processed, prefix=OBS_STR) - action_tensor = engine.get_action_sync(obs_frame) - action_dict = make_robot_action(action_tensor, features) - action_t = torch.tensor([action_dict[k] for k in ordered_keys]) - interpolator.add(action_t) + if interpolator.needs_new_action(): + obs_frame = build_dataset_frame(features, obs_processed, prefix=OBS_STR) + action_tensor = engine.get_action(obs_frame) + if action_tensor is not None: + interpolator.add(action_tensor.cpu()) interp = interpolator.get() - if interp is not None: - action_dict = {k: interp[i].item() for i, k in enumerate(ordered_keys) if i < len(interp)} - processed = ctx.robot_action_processor((action_dict, obs_raw)) - ctx.robot_wrapper.send_action(processed) - return action_dict - return None + if interp is None: + return None + + action_dict = {k: interp[i].item() for i, k in enumerate(ordered_keys) if i < len(interp)} + processed = ctx.processors.robot_action_processor((action_dict, obs_raw)) + ctx.hardware.robot_wrapper.send_action(processed) + return action_dict diff --git a/src/lerobot/rollout/strategies/dagger.py b/src/lerobot/rollout/strategies/dagger.py index 95fdaf6a2..f2c7ae522 100644 --- a/src/lerobot/rollout/strategies/dagger.py +++ b/src/lerobot/rollout/strategies/dagger.py @@ -44,13 +44,14 @@ from lerobot.processor import RobotProcessorPipeline from lerobot.teleoperators import Teleoperator from lerobot.utils.constants import ACTION, OBS_STR from lerobot.utils.feature_utils import build_dataset_frame +from lerobot.utils.pedal import start_pedal_listener from lerobot.utils.robot_utils import precise_sleep from lerobot.utils.utils import log_say from ..configs import DAggerStrategyConfig from ..context import RolloutContext from ..robot_wrapper import ThreadSafeRobot -from . import RolloutStrategy, infer_action +from .core import RolloutStrategy, send_next_action logger = logging.getLogger(__name__) @@ -80,9 +81,8 @@ _DAGGER_TRANSITIONS: dict[tuple[DAggerPhase, str], DAggerPhase] = { class DAggerEvents: """Thread-safe container for DAgger keyboard/pedal events. - Replaces the previous plain dict with a lock-protected phase enum - and edge-triggered transition requests. The keyboard/pedal threads - write transition requests; the main loop consumes them. + The keyboard/pedal threads write transition requests; the main loop + consumes them. """ def __init__(self) -> None: @@ -122,11 +122,7 @@ class DAggerEvents: self._pending_transition = event def consume_transition(self) -> tuple[DAggerPhase, DAggerPhase] | None: - """Consume a pending transition (called from main loop). - - Returns ``(old_phase, new_phase)`` if a valid transition was - pending, or ``None`` if there is nothing to process. - """ + """Consume a pending transition (called from main loop).""" with self._lock: if self._pending_transition is None: return None @@ -149,7 +145,7 @@ class DAggerEvents: # --------------------------------------------------------------------------- -# Teleoperator helpers (extracted from examples/hil/hil_utils.py) +# Teleoperator helpers # --------------------------------------------------------------------------- @@ -199,11 +195,7 @@ def _reset_loop( teleop_action_processor: RobotProcessorPipeline, robot_action_processor: RobotProcessorPipeline, ) -> None: - """Reset period where the human repositions the environment. - - All teleop actions flow through the processor pipelines to ensure - correct behavior for EE-space robots. - """ + """Reset period where the human repositions the environment.""" logger.info("RESET — press any key to enable teleoperation") events.in_reset = True @@ -250,7 +242,6 @@ def _init_dagger_keyboard(events: DAggerEvents): def on_press(key): try: - # During the reset phase, only accept episode-start or stop if events.in_reset: if ( key in [keyboard.Key.space, keyboard.Key.right] @@ -263,7 +254,6 @@ def _init_dagger_keyboard(events: DAggerEvents): events.start_next_episode = True return - # Phase-aware transition requests phase = events.phase if key == keyboard.Key.space and phase == DAggerPhase.AUTONOMOUS: logger.info("PAUSED — press 'c' to take control or 'p' to resume policy") @@ -283,7 +273,6 @@ def _init_dagger_keyboard(events: DAggerEvents): logger.info("Resuming policy...") events.request_transition("resume") - # Episode-level controls (valid in any phase) elif key == keyboard.Key.right: logger.info("End episode") events.exit_early = True @@ -303,49 +292,27 @@ def _init_dagger_keyboard(events: DAggerEvents): return listener -def _start_pedal_listener(events: DAggerEvents) -> None: - """Start foot pedal listener thread if evdev is available.""" - import threading +_DAGGER_PEDAL_KEYS = ("KEY_A", "KEY_C") - try: - from evdev import InputDevice, categorize, ecodes - except ImportError: - return - pedal_device = "/dev/input/by-id/usb-PCsensor_FootSwitch-event-kbd" +def _dagger_pedal_callback(events: DAggerEvents): + """Build the pedal key-press handler for DAgger's state machine.""" - def pedal_reader(): - try: - dev = InputDevice(pedal_device) - logger.info("Pedal connected: %s", dev.name) - for ev in dev.read_loop(): - if ev.type != ecodes.EV_KEY: - continue - key = categorize(ev) - code = key.keycode - if isinstance(code, (list, tuple)): - code = code[0] - if key.keystate != 1: - continue - if events.in_reset: - if code in ["KEY_A", "KEY_C"]: - events.start_next_episode = True - else: - if code not in ["KEY_A", "KEY_C"]: - continue - phase = events.phase - if phase == DAggerPhase.CORRECTING: - events.request_transition("resume") - elif phase == DAggerPhase.PAUSED: - events.request_transition("takeover") - elif phase == DAggerPhase.AUTONOMOUS: - events.request_transition("pause") - except (FileNotFoundError, PermissionError): - pass - except Exception as e: - logger.warning("Pedal error: %s", e) + def on_press(code: str) -> None: + if code not in _DAGGER_PEDAL_KEYS: + return + if events.in_reset: + events.start_next_episode = True + return + phase = events.phase + if phase == DAggerPhase.CORRECTING: + events.request_transition("resume") + elif phase == DAggerPhase.PAUSED: + events.request_transition("takeover") + elif phase == DAggerPhase.AUTONOMOUS: + events.request_transition("pause") - threading.Thread(target=pedal_reader, daemon=True).start() + return on_press # --------------------------------------------------------------------------- @@ -356,19 +323,14 @@ def _start_pedal_listener(events: DAggerEvents) -> None: class DAggerStrategy(RolloutStrategy): """Human-in-the-Loop data collection with intervention tagging. - Uses a formal state machine (see :class:`DAggerPhase`) for phase - transitions, eliminating impossible states:: + State machine:: AUTONOMOUS --(SPACE)--> PAUSED --(c)--> CORRECTING --(p)--> AUTONOMOUS --(p)--> AUTONOMOUS - Supports both synchronous and RTC inference backends. - All actions (policy and teleop) flow through the appropriate - processor pipelines, supporting EE-space recording. - - Intervention frames are tagged with ``intervention=1`` (int64) in - the dataset to allow downstream BC training to distinguish - autonomous from human-corrected data. + Intervention frames are tagged with ``intervention=True`` (bool) in + the dataset; autonomous frames with ``intervention=False``. When + ``record_autonomous=False`` only corrections are recorded. """ config: DAggerStrategyConfig @@ -382,20 +344,20 @@ class DAggerStrategy(RolloutStrategy): self._init_engine(ctx) self._listener = _init_dagger_keyboard(self._events) - _start_pedal_listener(self._events) + start_pedal_listener(_dagger_pedal_callback(self._events)) logger.info( - "DAgger strategy ready (rtc=%s, episodes=%d, episode_time=%.0fs)", - self._engine.is_rtc, + "DAgger strategy ready (episodes=%d, episode_time=%.0fs, record_autonomous=%s)", self.config.num_episodes, self.config.episode_time_s, + self.config.record_autonomous, ) logger.info("Controls: SPACE=pause, c=take control, p=resume, ->=end, <-=redo, ESC=stop") def run(self, ctx: RolloutContext) -> None: - dataset = ctx.dataset + dataset = ctx.data.dataset events = self._events - teleop = ctx.teleop + teleop = ctx.hardware.teleop with VideoEncodingManager(dataset): try: @@ -417,12 +379,12 @@ class DAggerStrategy(RolloutStrategy): if recorded < self.config.num_episodes and not events.stop_recording: _reset_loop( - ctx.robot_wrapper, + ctx.hardware.robot_wrapper, teleop, events, - int(ctx.cfg.fps), - ctx.teleop_action_processor, - ctx.robot_action_processor, + int(ctx.runtime.cfg.fps), + ctx.processors.teleop_action_processor, + ctx.processors.robot_action_processor, ) finally: @@ -435,12 +397,12 @@ class DAggerStrategy(RolloutStrategy): if self._listener is not None and not is_headless(): self._listener.stop() - if ctx.dataset is not None: - ctx.dataset.finalize() - if ctx.cfg.dataset and ctx.cfg.dataset.push_to_hub: - ctx.dataset.push_to_hub( - tags=ctx.cfg.dataset.tags, - private=ctx.cfg.dataset.private, + if ctx.data.dataset is not None: + ctx.data.dataset.finalize() + if ctx.runtime.cfg.dataset and ctx.runtime.cfg.dataset.push_to_hub: + ctx.data.dataset.push_to_hub( + tags=ctx.runtime.cfg.dataset.tags, + private=ctx.runtime.cfg.dataset.private, ) self._teardown_hardware(ctx) @@ -453,18 +415,19 @@ class DAggerStrategy(RolloutStrategy): def _run_episode(self, ctx: RolloutContext) -> None: """Run a single DAgger episode with the HIL state machine.""" engine = self._engine - cfg = ctx.cfg - robot = ctx.robot_wrapper - teleop = ctx.teleop - dataset = ctx.dataset + cfg = ctx.runtime.cfg + robot = ctx.hardware.robot_wrapper + teleop = ctx.hardware.teleop + dataset = ctx.data.dataset events = self._events interpolator = self._interpolator control_interval = interpolator.get_control_interval(cfg.fps) stream_online = bool(cfg.dataset.streaming_encoding) if cfg.dataset else False record_stride = max(1, cfg.interpolation_multiplier) + record_autonomous = self.config.record_autonomous - ordered_keys = ctx.ordered_action_keys + ordered_keys = ctx.data.ordered_action_keys features = dataset.features engine.reset() @@ -480,8 +443,7 @@ class DAggerStrategy(RolloutStrategy): record_tick = 0 start_t = time.perf_counter() - if engine.is_rtc: - engine.resume() + engine.resume() while timestamp < self.config.episode_time_s: loop_start = time.perf_counter() @@ -490,7 +452,6 @@ class DAggerStrategy(RolloutStrategy): events.exit_early = False break - # --- Process pending phase transition --- transition = events.consume_transition() if transition is not None: old_phase, new_phase = transition @@ -499,16 +460,15 @@ class DAggerStrategy(RolloutStrategy): phase = events.phase - # --- Get observation --- obs = robot.get_observation() - obs_processed = ctx.robot_observation_processor(obs) + obs_processed = ctx.processors.robot_observation_processor(obs) obs_frame = build_dataset_frame(features, obs_processed, prefix=OBS_STR) # --- CORRECTING: human teleop control --- if phase == DAggerPhase.CORRECTING: teleop_action = teleop.get_action() - processed_teleop = ctx.teleop_action_processor((teleop_action, obs)) - robot_action_to_send = ctx.robot_action_processor((processed_teleop, obs)) + processed_teleop = ctx.processors.teleop_action_processor((teleop_action, obs)) + robot_action_to_send = ctx.processors.robot_action_processor((processed_teleop, obs)) robot.send_action(robot_action_to_send) action_frame = build_dataset_frame(features, processed_teleop, prefix=ACTION) if record_tick % record_stride == 0: @@ -516,7 +476,7 @@ class DAggerStrategy(RolloutStrategy): **obs_frame, **action_frame, "task": task_str, - "intervention": np.array([1], dtype=np.int64), + "intervention": np.array([True], dtype=bool), } if stream_online: dataset.add_frame(frame) @@ -531,26 +491,25 @@ class DAggerStrategy(RolloutStrategy): # --- AUTONOMOUS: policy control --- else: - if engine.is_rtc: - engine.update_observation(obs_processed) + engine.notify_observation(obs_processed) if self._handle_warmup(cfg.use_torch_compile, loop_start, control_interval): timestamp = time.perf_counter() - start_t continue - action_dict = infer_action( + action_dict = send_next_action( engine, obs_processed, obs, ctx, interpolator, ordered_keys, features ) if action_dict is not None: - last_action = ctx.robot_action_processor((action_dict, obs)) + last_action = ctx.processors.robot_action_processor((action_dict, obs)) action_frame = build_dataset_frame(features, action_dict, prefix=ACTION) - if record_tick % record_stride == 0: + if record_autonomous and record_tick % record_stride == 0: frame = { **obs_frame, **action_frame, "task": task_str, - "intervention": np.array([0], dtype=np.int64), + "intervention": np.array([False], dtype=bool), } if stream_online: dataset.add_frame(frame) @@ -563,9 +522,8 @@ class DAggerStrategy(RolloutStrategy): precise_sleep(sleep_t) timestamp = time.perf_counter() - start_t - # End of episode: flush any buffered frames - if engine.is_rtc: - engine.pause() + # End of episode: pause engine, disable teleop, flush buffer + engine.pause() _teleop_disable_torque(teleop) if not stream_online: @@ -587,9 +545,7 @@ class DAggerStrategy(RolloutStrategy): ) -> None: """Execute side-effects for a validated phase transition.""" if old_phase == DAggerPhase.AUTONOMOUS and new_phase == DAggerPhase.PAUSED: - # Pause engine + align teleop to robot position - if engine.is_rtc: - engine.pause() + engine.pause() obs = robot.get_observation() robot_pos = { k: v for k, v in obs.items() if k.endswith(".pos") and k in robot.observation_features @@ -598,14 +554,10 @@ class DAggerStrategy(RolloutStrategy): interpolator.reset() elif new_phase == DAggerPhase.CORRECTING: - # Enable human teleop control _teleop_disable_torque(teleop) - if engine.is_rtc: - engine.reset() + engine.reset() elif new_phase == DAggerPhase.AUTONOMOUS: - # Resume policy from pause or correction interpolator.reset() engine.reset() - if engine.is_rtc: - engine.resume() + engine.resume() diff --git a/src/lerobot/rollout/strategies/factory.py b/src/lerobot/rollout/strategies/factory.py index 9c43ea2af..0705ca3d0 100644 --- a/src/lerobot/rollout/strategies/factory.py +++ b/src/lerobot/rollout/strategies/factory.py @@ -18,37 +18,28 @@ from __future__ import annotations from typing import TYPE_CHECKING +from .base import BaseStrategy from .core import RolloutStrategy +from .dagger import DAggerStrategy +from .highlight import HighlightStrategy +from .sentry import SentryStrategy if TYPE_CHECKING: from lerobot.rollout.configs import RolloutStrategyConfig -def _lazy_strategy_map() -> dict[str, type[RolloutStrategy]]: - """Build the strategy type-name → class mapping with lazy imports.""" - from .base import BaseStrategy - from .dagger import DAggerStrategy - from .highlight import HighlightStrategy - from .sentry import SentryStrategy - - return { - "base": BaseStrategy, - "sentry": SentryStrategy, - "highlight": HighlightStrategy, - "dagger": DAggerStrategy, - } - - def create_strategy(config: RolloutStrategyConfig) -> RolloutStrategy: """Instantiate the appropriate strategy from a config object. - Uses ``config.type`` (the name registered via ``draccus.ChoiceRegistry``) - to look up the strategy class, so adding a new strategy only requires - registering its config subclass and adding one entry to - ``_lazy_strategy_map``. + Dispatches on ``config.type`` (the name registered via + ``draccus.ChoiceRegistry``). """ - strategy_map = _lazy_strategy_map() - strategy_cls = strategy_map.get(config.type) - if strategy_cls is None: - raise ValueError(f"Unknown strategy type '{config.type}'. Available: {sorted(strategy_map.keys())}") - return strategy_cls(config) + if config.type == "base": + return BaseStrategy(config) + if config.type == "sentry": + return SentryStrategy(config) + if config.type == "highlight": + return HighlightStrategy(config) + if config.type == "dagger": + return DAggerStrategy(config) + raise ValueError(f"Unknown strategy type '{config.type}'. Available: base, sentry, highlight, dagger") diff --git a/src/lerobot/rollout/strategies/highlight.py b/src/lerobot/rollout/strategies/highlight.py index 5982c6207..82789930b 100644 --- a/src/lerobot/rollout/strategies/highlight.py +++ b/src/lerobot/rollout/strategies/highlight.py @@ -30,7 +30,7 @@ from lerobot.utils.robot_utils import precise_sleep from ..configs import HighlightStrategyConfig from ..context import RolloutContext from ..ring_buffer import RolloutRingBuffer -from . import RolloutStrategy, infer_action +from .core import RolloutStrategy, send_next_action logger = logging.getLogger(__name__) @@ -45,6 +45,9 @@ class HighlightStrategy(RolloutStrategy): 2. Live recording continues until the save key is pressed again. 3. The episode is saved and the ring buffer resumes capturing. + Requires ``streaming_encoding=True`` (enforced in config validation) + so that ``dataset.add_frame`` is a non-blocking queue put — draining + 900 frames stays sub-ms per frame. """ config: HighlightStrategyConfig @@ -63,10 +66,10 @@ class HighlightStrategy(RolloutStrategy): self._ring = RolloutRingBuffer( max_seconds=self.config.ring_buffer_seconds, max_memory_mb=self.config.ring_buffer_max_memory_mb, - fps=ctx.cfg.fps, + fps=ctx.runtime.cfg.fps, ) - self._shutdown_event = ctx.shutdown_event + self._shutdown_event = ctx.runtime.shutdown_event self._setup_keyboard() logger.info( "Highlight strategy ready (buffer=%.0fs, key='%s')", @@ -76,74 +79,67 @@ class HighlightStrategy(RolloutStrategy): def run(self, ctx: RolloutContext) -> None: engine = self._engine - cfg = ctx.cfg - robot = ctx.robot_wrapper - dataset = ctx.dataset + cfg = ctx.runtime.cfg + robot = ctx.hardware.robot_wrapper + dataset = ctx.data.dataset ring = self._ring interpolator = self._interpolator control_interval = interpolator.get_control_interval(cfg.fps) - ordered_keys = ctx.ordered_action_keys + ordered_keys = ctx.data.ordered_action_keys features = dataset.features - if engine.is_rtc: - engine.resume() + engine.resume() start_time = time.perf_counter() task_str = cfg.dataset.single_task if cfg.dataset else cfg.task with VideoEncodingManager(dataset): try: - while not ctx.shutdown_event.is_set(): + while not ctx.runtime.shutdown_event.is_set(): loop_start = time.perf_counter() if cfg.duration > 0 and (time.perf_counter() - start_time) >= cfg.duration: break obs = robot.get_observation() - obs_processed = ctx.robot_observation_processor(obs) - - if engine.is_rtc: - engine.update_observation(obs_processed) + obs_processed = ctx.processors.robot_observation_processor(obs) + engine.notify_observation(obs_processed) if self._handle_warmup(cfg.use_torch_compile, loop_start, control_interval): continue - action_dict = infer_action( + action_dict = send_next_action( engine, obs_processed, obs, ctx, interpolator, ordered_keys, features ) - # Build frame for ring buffer / live recording if action_dict is not None: obs_frame = build_dataset_frame(features, obs_processed, prefix=OBS_STR) action_frame = build_dataset_frame(features, action_dict, prefix=ACTION) frame = {**obs_frame, **action_frame, "task": task_str} - # Handle save key toggle if self._save_requested.is_set(): self._save_requested.clear() if not self._recording_live.is_set(): logger.info( - "Flushing ring buffer (%d frames) + starting live recording", len(ring) + "Flushing ring buffer (%d frames) + starting live recording", + len(ring), ) for buffered_frame in ring.drain(): dataset.add_frame(buffered_frame) self._recording_live.set() else: - # Save current frame as the last frame of the episode dataset.add_frame(frame) dataset.save_episode() logger.info("Episode saved") self._recording_live.clear() engine.reset() interpolator.reset() - if engine.is_rtc: - engine.resume() + engine.resume() if self._recording_live.is_set(): dataset.add_frame(frame) else: - # Current frame goes into the ring buffer for next potential save. ring.append(frame) dt = time.perf_counter() - loop_start @@ -159,12 +155,12 @@ class HighlightStrategy(RolloutStrategy): if self._listener is not None: self._listener.stop() - if ctx.dataset is not None: - ctx.dataset.finalize() - if ctx.cfg.dataset and ctx.cfg.dataset.push_to_hub: - ctx.dataset.push_to_hub( - tags=ctx.cfg.dataset.tags, - private=ctx.cfg.dataset.private, + if ctx.data.dataset is not None: + ctx.data.dataset.finalize() + if ctx.runtime.cfg.dataset and ctx.runtime.cfg.dataset.push_to_hub: + ctx.data.dataset.push_to_hub( + tags=ctx.runtime.cfg.dataset.tags, + private=ctx.runtime.cfg.dataset.private, ) self._teardown_hardware(ctx) @@ -172,7 +168,6 @@ class HighlightStrategy(RolloutStrategy): def _setup_keyboard(self) -> None: """Set up keyboard listener for the save key.""" - if is_headless(): logger.warning("Headless environment — highlight save key unavailable") return diff --git a/src/lerobot/rollout/strategies/sentry.py b/src/lerobot/rollout/strategies/sentry.py index 024584a9e..bc6a8948c 100644 --- a/src/lerobot/rollout/strategies/sentry.py +++ b/src/lerobot/rollout/strategies/sentry.py @@ -19,7 +19,8 @@ from __future__ import annotations import contextlib import logging import time -from threading import Event, Lock, Thread +from concurrent.futures import Future, ThreadPoolExecutor +from threading import Event, Lock from lerobot.datasets import VideoEncodingManager from lerobot.utils.constants import ACTION, OBS_STR @@ -28,7 +29,7 @@ from lerobot.utils.robot_utils import precise_sleep from ..configs import SentryStrategyConfig from ..context import RolloutContext -from . import RolloutStrategy, infer_action +from .core import RolloutStrategy, send_next_action logger = logging.getLogger(__name__) @@ -36,32 +37,30 @@ logger = logging.getLogger(__name__) class SentryStrategy(RolloutStrategy): """Continuous autonomous rollout with always-on recording. - Episodes are auto-rotated every ``episode_duration_s`` seconds. - The dataset is pushed to Hub in the background every - ``upload_every_n_episodes`` episodes. + Episodes are auto-rotated every ``episode_duration_s`` seconds. The + dataset is pushed to the Hub via a bounded single-worker executor so + no push is ever silently dropped and exactly one push runs at a time. + + Policy state (hidden state, RTC queue) intentionally persists across + episode boundaries — Sentry slices one continuous rollout, the robot + does not reset between slices. Requires ``streaming_encoding=True`` (enforced in config validation) to prevent disk I/O from blocking the control loop. - - All actions flow through ``robot_observation_processor`` (observations) - and ``robot_action_processor`` (actions) before reaching the robot, - supporting EE-space recording with joint-space robots. - - **Thread safety:** A lock (``_episode_lock``) serialises - ``save_episode`` and ``push_to_hub`` calls so the background push - thread never reads an episode that is still being finalised. """ config: SentryStrategyConfig def __init__(self, config: SentryStrategyConfig): super().__init__(config) - self._push_thread: Thread | None = None + self._push_executor: ThreadPoolExecutor | None = None + self._pending_push: Future | None = None self._needs_push = Event() self._episode_lock = Lock() def setup(self, ctx: RolloutContext) -> None: self._init_engine(ctx) + self._push_executor = ThreadPoolExecutor(max_workers=1, thread_name_prefix="sentry-push") logger.info( "Sentry strategy ready (episode_duration=%.0fs, upload_every=%d eps)", self.config.episode_duration_s, @@ -70,17 +69,16 @@ class SentryStrategy(RolloutStrategy): def run(self, ctx: RolloutContext) -> None: engine = self._engine - cfg = ctx.cfg - robot = ctx.robot_wrapper - dataset = ctx.dataset + cfg = ctx.runtime.cfg + robot = ctx.hardware.robot_wrapper + dataset = ctx.data.dataset interpolator = self._interpolator control_interval = interpolator.get_control_interval(cfg.fps) - ordered_keys = ctx.ordered_action_keys + ordered_keys = ctx.data.ordered_action_keys features = dataset.features - if engine.is_rtc: - engine.resume() + engine.resume() start_time = time.perf_counter() episode_start = time.perf_counter() @@ -89,33 +87,29 @@ class SentryStrategy(RolloutStrategy): with VideoEncodingManager(dataset): try: - while not ctx.shutdown_event.is_set(): + while not ctx.runtime.shutdown_event.is_set(): loop_start = time.perf_counter() if cfg.duration > 0 and (time.perf_counter() - start_time) >= cfg.duration: break obs = robot.get_observation() - obs_processed = ctx.robot_observation_processor(obs) - - if engine.is_rtc: - engine.update_observation(obs_processed) + obs_processed = ctx.processors.robot_observation_processor(obs) + engine.notify_observation(obs_processed) if self._handle_warmup(cfg.use_torch_compile, loop_start, control_interval): continue - action_dict = infer_action( + action_dict = send_next_action( engine, obs_processed, obs, ctx, interpolator, ordered_keys, features ) - # Record frame if action_dict is not None: obs_frame = build_dataset_frame(features, obs_processed, prefix=OBS_STR) action_frame = build_dataset_frame(features, action_dict, prefix=ACTION) frame = {**obs_frame, **action_frame, "task": task_str} dataset.add_frame(frame) - # Auto-rotate episodes elapsed = time.perf_counter() - episode_start if elapsed >= self.config.episode_duration_s: with self._episode_lock: @@ -129,10 +123,6 @@ class SentryStrategy(RolloutStrategy): episodes_since_push = 0 episode_start = time.perf_counter() - engine.reset() - interpolator.reset() - if engine.is_rtc: - engine.resume() dt = time.perf_counter() - loop_start if (sleep_t := control_interval - dt) > 0: @@ -145,32 +135,34 @@ class SentryStrategy(RolloutStrategy): self._needs_push.set() def teardown(self, ctx: RolloutContext) -> None: - # Wait for any in-flight background push - if self._push_thread is not None and self._push_thread.is_alive(): - self._push_thread.join(timeout=60) + # Flush any queued/running push cleanly. + if self._push_executor is not None: + self._push_executor.shutdown(wait=True) + self._push_executor = None - if ctx.dataset is not None: - ctx.dataset.finalize() - # Only push if there are unsaved changes since last background push - if self._needs_push.is_set() and ctx.cfg.dataset and ctx.cfg.dataset.push_to_hub: - ctx.dataset.push_to_hub( - tags=ctx.cfg.dataset.tags, - private=ctx.cfg.dataset.private, + if ctx.data.dataset is not None: + ctx.data.dataset.finalize() + if self._needs_push.is_set() and ctx.runtime.cfg.dataset and ctx.runtime.cfg.dataset.push_to_hub: + ctx.data.dataset.push_to_hub( + tags=ctx.runtime.cfg.dataset.tags, + private=ctx.runtime.cfg.dataset.private, ) self._teardown_hardware(ctx) logger.info("Sentry strategy teardown complete") def _background_push(self, dataset, cfg) -> None: - """Push dataset to hub in a background thread (non-blocking). + """Queue a Hub push on the single-worker executor. - Acquires ``_episode_lock`` during the push to prevent - ``save_episode`` from finalising a new episode mid-upload. + The executor's max_workers=1 guarantees at most one push runs at + a time; submitted tasks are queued rather than dropped. """ - if self._push_thread is not None and self._push_thread.is_alive(): - logger.info("Previous push still in progress, skipping") + if self._push_executor is None: return + if self._pending_push is not None and not self._pending_push.done(): + logger.info("Previous push still in progress; queueing next") + def _push(): try: with self._episode_lock: @@ -183,5 +175,4 @@ class SentryStrategy(RolloutStrategy): except Exception as e: logger.error("Background push failed: %s", e) - self._push_thread = Thread(target=_push, daemon=True) - self._push_thread.start() + self._pending_push = self._push_executor.submit(_push) diff --git a/src/lerobot/scripts/lerobot_record.py b/src/lerobot/scripts/lerobot_record.py index 9bf5996fd..fc4b5779c 100644 --- a/src/lerobot/scripts/lerobot_record.py +++ b/src/lerobot/scripts/lerobot_record.py @@ -379,7 +379,12 @@ def record_loop( @parser.wrap() -def record(cfg: RecordConfig) -> LeRobotDataset: +def record( + cfg: RecordConfig, + teleop_action_processor: RobotProcessorPipeline | None = None, + robot_action_processor: RobotProcessorPipeline | None = None, + robot_observation_processor: RobotProcessorPipeline | None = None, +) -> LeRobotDataset: init_logging() logging.info(pformat(asdict(cfg))) if cfg.display_data: @@ -393,7 +398,16 @@ def record(cfg: RecordConfig) -> LeRobotDataset: robot = make_robot_from_config(cfg.robot) teleop = make_teleoperator_from_config(cfg.teleop) if cfg.teleop is not None else None - teleop_action_processor, robot_action_processor, robot_observation_processor = make_default_processors() + # Fall back to identity pipelines when the caller doesn't supply processors. + if ( + teleop_action_processor is None + or robot_action_processor is None + or robot_observation_processor is None + ): + _t, _r, _o = make_default_processors() + teleop_action_processor = teleop_action_processor or _t + robot_action_processor = robot_action_processor or _r + robot_observation_processor = robot_observation_processor or _o dataset_features = combine_feature_dicts( aggregate_pipeline_dataset_features( diff --git a/src/lerobot/scripts/lerobot_rollout.py b/src/lerobot/scripts/lerobot_rollout.py index a4ce7ef46..0fd6da45a 100644 --- a/src/lerobot/scripts/lerobot_rollout.py +++ b/src/lerobot/scripts/lerobot_rollout.py @@ -37,7 +37,7 @@ Usage examples:: lerobot-rollout \\ --strategy.type=base \\ --policy.path=lerobot/pi0_base \\ - --rtc.enabled=true --rtc.execution_horizon=10 \\ + --inference.type=rtc --inference.rtc.execution_horizon=10 \\ --robot.type=so100_follower \\ --task="pick up cube" --duration=60 @@ -47,7 +47,7 @@ Usage examples:: --strategy.episode_duration_s=120 \\ --strategy.upload_every_n_episodes=5 \\ --policy.path=lerobot/pi0_base \\ - --rtc.enabled=true \\ + --inference.type=rtc \\ --robot.type=so100_follower \\ --dataset.repo_id=user/sentry-data \\ --dataset.single_task="patrol" --duration=3600 @@ -68,7 +68,6 @@ from lerobot.cameras.opencv import OpenCVCameraConfig # noqa: F401 from lerobot.cameras.realsense import RealSenseCameraConfig # noqa: F401 from lerobot.cameras.zmq import ZMQCameraConfig # noqa: F401 from lerobot.configs import parser -from lerobot.rl.process import ProcessSignalHandler from lerobot.robots import ( # noqa: F401 bi_openarm_follower, bi_so_follower, @@ -89,6 +88,7 @@ from lerobot.teleoperators import ( # noqa: F401 unitree_g1 as unitree_g1_teleop, ) from lerobot.utils.import_utils import register_third_party_plugins +from lerobot.utils.process import ProcessSignalHandler from lerobot.utils.utils import init_logging logger = logging.getLogger(__name__) diff --git a/src/lerobot/utils/pedal.py b/src/lerobot/utils/pedal.py new file mode 100644 index 000000000..88f3db1f9 --- /dev/null +++ b/src/lerobot/utils/pedal.py @@ -0,0 +1,83 @@ +# Copyright 2025 The HuggingFace Inc. team. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +"""Generic foot pedal listener using evdev. + +Callers supply a callback receiving the pressed key code (e.g. ``"KEY_A"``) +and an optional device path. The listener runs in a daemon thread and +silently no-ops when :mod:`evdev` is not installed or the device is +unavailable. Strategy-specific key mapping logic lives in the caller. +""" + +from __future__ import annotations + +import logging +import threading +from collections.abc import Callable + +logger = logging.getLogger(__name__) + +DEFAULT_PEDAL_DEVICE = "/dev/input/by-id/usb-PCsensor_FootSwitch-event-kbd" + + +def start_pedal_listener( + on_press: Callable[[str], None], + device_path: str = DEFAULT_PEDAL_DEVICE, +) -> threading.Thread | None: + """Spawn a daemon thread that forwards pedal key-press codes to ``on_press``. + + Parameters + ---------- + on_press: + Callback invoked with the pressed key code string (e.g. ``"KEY_A"``) + on each pedal press event. The callback runs in the listener thread + and must be thread-safe. + device_path: + Linux input device path (e.g. ``/dev/input/by-id/...``). + + Returns + ------- + The started daemon :class:`threading.Thread`, or ``None`` when + :mod:`evdev` is not installed (optional dependency; silent no-op). + """ + try: + from evdev import InputDevice, categorize, ecodes + except ImportError: + return None + + def pedal_reader() -> None: + try: + dev = InputDevice(device_path) + logger.info("Pedal connected: %s", dev.name) + for ev in dev.read_loop(): + if ev.type != ecodes.EV_KEY: + continue + key = categorize(ev) + code = key.keycode + if isinstance(code, (list, tuple)): + code = code[0] + if key.keystate != 1: # only key-down events + continue + try: + on_press(code) + except Exception as cb_err: # pragma: no cover - defensive + logger.warning("Pedal callback error: %s", cb_err) + except (FileNotFoundError, PermissionError): + pass + except Exception as e: + logger.warning("Pedal error: %s", e) + + thread = threading.Thread(target=pedal_reader, daemon=True, name="PedalListener") + thread.start() + return thread diff --git a/src/lerobot/rl/process.py b/src/lerobot/utils/process.py similarity index 100% rename from src/lerobot/rl/process.py rename to src/lerobot/utils/process.py diff --git a/tests/utils/test_process.py b/tests/utils/test_process.py index ce56db173..65b24aac4 100644 --- a/tests/utils/test_process.py +++ b/tests/utils/test_process.py @@ -24,7 +24,7 @@ import pytest pytest.importorskip("grpc") -from lerobot.rl.process import ProcessSignalHandler # noqa: E402 +from lerobot.utils.process import ProcessSignalHandler # noqa: E402 # Fixture to reset shutdown_event_counter and original signal handlers before and after each test