diff --git a/examples/rac/hil_data_collection.py b/examples/rac/hil_data_collection.py index 4555eafc4..c668c21f6 100644 --- a/examples/rac/hil_data_collection.py +++ b/examples/rac/hil_data_collection.py @@ -37,6 +37,17 @@ from dataclasses import dataclass from pprint import pformat from typing import Any +import torch +from hil_utils import ( + HILDatasetConfig, + init_keyboard_listener, + make_identity_processors, + print_controls, + reset_loop, + teleop_disable_torque, + teleop_smooth_move_to, +) + from lerobot.cameras.opencv.configuration_opencv import OpenCVCameraConfig # noqa: F401 from lerobot.cameras.realsense.configuration_realsense import RealSenseCameraConfig # noqa: F401 from lerobot.configs import parser @@ -46,8 +57,6 @@ from lerobot.datasets.lerobot_dataset import LeRobotDataset from lerobot.datasets.pipeline_features import aggregate_pipeline_dataset_features, create_initial_features from lerobot.datasets.utils import build_dataset_frame, combine_feature_dicts from lerobot.datasets.video_utils import VideoEncodingManager -import torch - from lerobot.policies.factory import make_policy, make_pre_post_processors from lerobot.policies.pretrained import PreTrainedPolicy from lerobot.policies.rtc import ActionInterpolator @@ -62,16 +71,6 @@ from lerobot.utils.robot_utils import precise_sleep from lerobot.utils.utils import get_safe_torch_device, init_logging, log_say from lerobot.utils.visualization_utils import init_rerun, log_rerun_data -from hil_utils import ( - HILDatasetConfig, - init_keyboard_listener, - make_identity_processors, - print_controls, - reset_loop, - teleop_disable_torque, - teleop_smooth_move_to, -) - logger = logging.getLogger(__name__) @@ -127,7 +126,7 @@ def rollout_loop( waiting_for_takeover = False last_action: dict[str, Any] | None = None robot_action: dict[str, Any] = {} - action_keys = sorted([k for k in robot.action_features.keys()]) + action_keys = sorted(robot.action_features.keys()) interpolator = ActionInterpolator(multiplier=cfg.interpolation_multiplier) control_interval = interpolator.get_control_interval(fps) @@ -147,7 +146,9 @@ def rollout_loop( # Transition to paused state if events["policy_paused"] and not was_paused: obs = robot.get_observation() - robot_pos = {k: v for k, v in obs.items() if k.endswith(".pos") and k in robot.observation_features} + robot_pos = { + k: v for k, v in obs.items() if k.endswith(".pos") and k in robot.observation_features + } teleop_smooth_move_to(teleop, robot_pos, duration_s=2.0, fps=50) events["start_next_episode"] = False waiting_for_takeover = True @@ -341,6 +342,7 @@ def hil_collect(cfg: HILConfig) -> LeRobotDataset: def main(): from lerobot.utils.import_utils import register_third_party_plugins + register_third_party_plugins() hil_collect() diff --git a/examples/rac/hil_data_collection_rtc.py b/examples/rac/hil_data_collection_rtc.py index 85390a35c..6368f325e 100644 --- a/examples/rac/hil_data_collection_rtc.py +++ b/examples/rac/hil_data_collection_rtc.py @@ -44,6 +44,15 @@ from threading import Event, Lock, Thread from typing import Any import torch +from hil_utils import ( + HILDatasetConfig, + init_keyboard_listener, + make_identity_processors, + print_controls, + reset_loop, + teleop_disable_torque, + teleop_smooth_move_to, +) from lerobot.cameras.opencv.configuration_opencv import OpenCVCameraConfig # noqa: F401 from lerobot.cameras.realsense.configuration_realsense import RealSenseCameraConfig # noqa: F401 @@ -64,19 +73,9 @@ from lerobot.teleoperators import Teleoperator, TeleoperatorConfig, make_teleope from lerobot.utils.constants import ACTION, OBS_STR from lerobot.utils.control_utils import is_headless from lerobot.utils.robot_utils import precise_sleep -from lerobot.utils.utils import get_safe_torch_device, init_logging, log_say +from lerobot.utils.utils import init_logging, log_say from lerobot.utils.visualization_utils import init_rerun, log_rerun_data -from hil_utils import ( - HILDatasetConfig, - init_keyboard_listener, - make_identity_processors, - print_controls, - reset_loop, - teleop_disable_torque, - teleop_smooth_move_to, -) - logger = logging.getLogger(__name__) @@ -239,7 +238,7 @@ def rollout_loop( was_paused = False waiting_for_takeover = False last_action: dict[str, Any] | None = None - action_keys = [k for k in robot.action_features.keys() if k.endswith(".pos")] + action_keys = [k for k in robot.action_features if k.endswith(".pos")] interpolator = ActionInterpolator(multiplier=cfg.interpolation_multiplier) control_interval = interpolator.get_control_interval(fps) @@ -261,7 +260,9 @@ def rollout_loop( if events["policy_paused"] and not was_paused: policy_active.clear() obs = robot.get_observation() - robot_pos = {k: v for k, v in obs.items() if k.endswith(".pos") and k in robot.observation_features} + robot_pos = { + k: v for k, v in obs.items() if k.endswith(".pos") and k in robot.observation_features + } teleop_smooth_move_to(teleop, robot_pos, duration_s=2.0, fps=50) events["start_next_episode"] = False waiting_for_takeover = True @@ -305,7 +306,9 @@ def rollout_loop( action_tensor = interpolator.get() if action_tensor is not None: - robot_action = {k: action_tensor[i].item() for i, k in enumerate(action_keys) if i < len(action_tensor)} + robot_action = { + k: action_tensor[i].item() for i, k in enumerate(action_keys) if i < len(action_tensor) + } robot.send_action(robot_action) last_action = robot_action action_frame = build_dataset_frame(dataset.features, robot_action, prefix=ACTION) @@ -418,8 +421,17 @@ def hil_rtc_collect(cfg: HILRTCConfig) -> LeRobotDataset: rtc_thread = Thread( target=rtc_inference_thread, - args=(policy, obs_holder, hw_features, preprocessor, postprocessor, - queue_holder, shutdown_event, policy_active, cfg), + args=( + policy, + obs_holder, + hw_features, + preprocessor, + postprocessor, + queue_holder, + shutdown_event, + policy_active, + cfg, + ), daemon=True, ) rtc_thread.start() @@ -492,10 +504,10 @@ def hil_rtc_collect(cfg: HILRTCConfig) -> LeRobotDataset: def main(): from lerobot.utils.import_utils import register_third_party_plugins + register_third_party_plugins() hil_rtc_collect() if __name__ == "__main__": main() - diff --git a/examples/rac/hil_utils.py b/examples/rac/hil_utils.py index b1b65f7ef..1e71ec71e 100644 --- a/examples/rac/hil_utils.py +++ b/examples/rac/hil_utils.py @@ -17,8 +17,8 @@ from lerobot.processor.converters import ( transition_to_observation, transition_to_robot_action, ) -from lerobot.robots import Robot, RobotConfig -from lerobot.teleoperators import Teleoperator, TeleoperatorConfig +from lerobot.robots import Robot +from lerobot.teleoperators import Teleoperator from lerobot.utils.control_utils import is_headless from lerobot.utils.robot_utils import precise_sleep @@ -151,12 +151,12 @@ def start_pedal_listener(events: dict): logger.info("[Pedal] evdev not installed - pedal support disabled") return - PEDAL_DEVICE = "/dev/input/by-id/usb-PCsensor_FootSwitch-event-kbd" - KEY_LEFT, KEY_RIGHT = "KEY_A", "KEY_C" + pedal_device = "/dev/input/by-id/usb-PCsensor_FootSwitch-event-kbd" + key_left, key_right = "KEY_A", "KEY_C" def pedal_reader(): try: - dev = InputDevice(PEDAL_DEVICE) + dev = InputDevice(pedal_device) print(f"[Pedal] Connected: {dev.name}") for ev in dev.read_loop(): if ev.type != ecodes.EV_KEY: @@ -167,17 +167,16 @@ def start_pedal_listener(events: dict): continue if events["in_reset"]: - if code in [KEY_LEFT, KEY_RIGHT]: + if code in [key_left, key_right]: events["start_next_episode"] = True else: - if code == KEY_RIGHT: + if code == key_right: if events["correction_active"]: events["exit_early"] = True elif not events["policy_paused"]: events["policy_paused"] = True - elif code == KEY_LEFT: - if events["policy_paused"] and not events["correction_active"]: - events["start_next_episode"] = True + elif code == key_left and events["policy_paused"] and not events["correction_active"]: + events["start_next_episode"] = True except (FileNotFoundError, PermissionError) as e: logger.info(f"[Pedal] {e}") @@ -248,4 +247,3 @@ def print_controls(rtc: bool = False): print(" → - End episode") print(" ESC - Stop and push to hub") print("=" * 60 + "\n") - diff --git a/src/lerobot/policies/rtc/__init__.py b/src/lerobot/policies/rtc/__init__.py index 9d1620b6e..ac7b72ef7 100644 --- a/src/lerobot/policies/rtc/__init__.py +++ b/src/lerobot/policies/rtc/__init__.py @@ -27,4 +27,3 @@ __all__ = [ "RTCConfig", "RTCProcessor", ] - diff --git a/src/lerobot/policies/rtc/action_interpolator.py b/src/lerobot/policies/rtc/action_interpolator.py index 50be16e08..969054236 100644 --- a/src/lerobot/policies/rtc/action_interpolator.py +++ b/src/lerobot/policies/rtc/action_interpolator.py @@ -115,4 +115,3 @@ class ActionInterpolator: Control interval in seconds (divided by multiplier). """ return 1.0 / (fps * self.multiplier) -