mirror of
https://github.com/huggingface/lerobot.git
synced 2026-05-20 11:09:59 +00:00
chore(style): fix pre-commit
This commit is contained in:
@@ -37,6 +37,17 @@ from dataclasses import dataclass
|
|||||||
from pprint import pformat
|
from pprint import pformat
|
||||||
from typing import Any
|
from typing import Any
|
||||||
|
|
||||||
|
import torch
|
||||||
|
from hil_utils import (
|
||||||
|
HILDatasetConfig,
|
||||||
|
init_keyboard_listener,
|
||||||
|
make_identity_processors,
|
||||||
|
print_controls,
|
||||||
|
reset_loop,
|
||||||
|
teleop_disable_torque,
|
||||||
|
teleop_smooth_move_to,
|
||||||
|
)
|
||||||
|
|
||||||
from lerobot.cameras.opencv.configuration_opencv import OpenCVCameraConfig # noqa: F401
|
from lerobot.cameras.opencv.configuration_opencv import OpenCVCameraConfig # noqa: F401
|
||||||
from lerobot.cameras.realsense.configuration_realsense import RealSenseCameraConfig # noqa: F401
|
from lerobot.cameras.realsense.configuration_realsense import RealSenseCameraConfig # noqa: F401
|
||||||
from lerobot.configs import parser
|
from lerobot.configs import parser
|
||||||
@@ -46,8 +57,6 @@ from lerobot.datasets.lerobot_dataset import LeRobotDataset
|
|||||||
from lerobot.datasets.pipeline_features import aggregate_pipeline_dataset_features, create_initial_features
|
from lerobot.datasets.pipeline_features import aggregate_pipeline_dataset_features, create_initial_features
|
||||||
from lerobot.datasets.utils import build_dataset_frame, combine_feature_dicts
|
from lerobot.datasets.utils import build_dataset_frame, combine_feature_dicts
|
||||||
from lerobot.datasets.video_utils import VideoEncodingManager
|
from lerobot.datasets.video_utils import VideoEncodingManager
|
||||||
import torch
|
|
||||||
|
|
||||||
from lerobot.policies.factory import make_policy, make_pre_post_processors
|
from lerobot.policies.factory import make_policy, make_pre_post_processors
|
||||||
from lerobot.policies.pretrained import PreTrainedPolicy
|
from lerobot.policies.pretrained import PreTrainedPolicy
|
||||||
from lerobot.policies.rtc import ActionInterpolator
|
from lerobot.policies.rtc import ActionInterpolator
|
||||||
@@ -62,16 +71,6 @@ from lerobot.utils.robot_utils import precise_sleep
|
|||||||
from lerobot.utils.utils import get_safe_torch_device, init_logging, log_say
|
from lerobot.utils.utils import get_safe_torch_device, init_logging, log_say
|
||||||
from lerobot.utils.visualization_utils import init_rerun, log_rerun_data
|
from lerobot.utils.visualization_utils import init_rerun, log_rerun_data
|
||||||
|
|
||||||
from hil_utils import (
|
|
||||||
HILDatasetConfig,
|
|
||||||
init_keyboard_listener,
|
|
||||||
make_identity_processors,
|
|
||||||
print_controls,
|
|
||||||
reset_loop,
|
|
||||||
teleop_disable_torque,
|
|
||||||
teleop_smooth_move_to,
|
|
||||||
)
|
|
||||||
|
|
||||||
logger = logging.getLogger(__name__)
|
logger = logging.getLogger(__name__)
|
||||||
|
|
||||||
|
|
||||||
@@ -127,7 +126,7 @@ def rollout_loop(
|
|||||||
waiting_for_takeover = False
|
waiting_for_takeover = False
|
||||||
last_action: dict[str, Any] | None = None
|
last_action: dict[str, Any] | None = None
|
||||||
robot_action: dict[str, Any] = {}
|
robot_action: dict[str, Any] = {}
|
||||||
action_keys = sorted([k for k in robot.action_features.keys()])
|
action_keys = sorted(robot.action_features.keys())
|
||||||
|
|
||||||
interpolator = ActionInterpolator(multiplier=cfg.interpolation_multiplier)
|
interpolator = ActionInterpolator(multiplier=cfg.interpolation_multiplier)
|
||||||
control_interval = interpolator.get_control_interval(fps)
|
control_interval = interpolator.get_control_interval(fps)
|
||||||
@@ -147,7 +146,9 @@ def rollout_loop(
|
|||||||
# Transition to paused state
|
# Transition to paused state
|
||||||
if events["policy_paused"] and not was_paused:
|
if events["policy_paused"] and not was_paused:
|
||||||
obs = robot.get_observation()
|
obs = robot.get_observation()
|
||||||
robot_pos = {k: v for k, v in obs.items() if k.endswith(".pos") and k in robot.observation_features}
|
robot_pos = {
|
||||||
|
k: v for k, v in obs.items() if k.endswith(".pos") and k in robot.observation_features
|
||||||
|
}
|
||||||
teleop_smooth_move_to(teleop, robot_pos, duration_s=2.0, fps=50)
|
teleop_smooth_move_to(teleop, robot_pos, duration_s=2.0, fps=50)
|
||||||
events["start_next_episode"] = False
|
events["start_next_episode"] = False
|
||||||
waiting_for_takeover = True
|
waiting_for_takeover = True
|
||||||
@@ -341,6 +342,7 @@ def hil_collect(cfg: HILConfig) -> LeRobotDataset:
|
|||||||
|
|
||||||
def main():
|
def main():
|
||||||
from lerobot.utils.import_utils import register_third_party_plugins
|
from lerobot.utils.import_utils import register_third_party_plugins
|
||||||
|
|
||||||
register_third_party_plugins()
|
register_third_party_plugins()
|
||||||
hil_collect()
|
hil_collect()
|
||||||
|
|
||||||
|
|||||||
@@ -44,6 +44,15 @@ from threading import Event, Lock, Thread
|
|||||||
from typing import Any
|
from typing import Any
|
||||||
|
|
||||||
import torch
|
import torch
|
||||||
|
from hil_utils import (
|
||||||
|
HILDatasetConfig,
|
||||||
|
init_keyboard_listener,
|
||||||
|
make_identity_processors,
|
||||||
|
print_controls,
|
||||||
|
reset_loop,
|
||||||
|
teleop_disable_torque,
|
||||||
|
teleop_smooth_move_to,
|
||||||
|
)
|
||||||
|
|
||||||
from lerobot.cameras.opencv.configuration_opencv import OpenCVCameraConfig # noqa: F401
|
from lerobot.cameras.opencv.configuration_opencv import OpenCVCameraConfig # noqa: F401
|
||||||
from lerobot.cameras.realsense.configuration_realsense import RealSenseCameraConfig # noqa: F401
|
from lerobot.cameras.realsense.configuration_realsense import RealSenseCameraConfig # noqa: F401
|
||||||
@@ -64,19 +73,9 @@ from lerobot.teleoperators import Teleoperator, TeleoperatorConfig, make_teleope
|
|||||||
from lerobot.utils.constants import ACTION, OBS_STR
|
from lerobot.utils.constants import ACTION, OBS_STR
|
||||||
from lerobot.utils.control_utils import is_headless
|
from lerobot.utils.control_utils import is_headless
|
||||||
from lerobot.utils.robot_utils import precise_sleep
|
from lerobot.utils.robot_utils import precise_sleep
|
||||||
from lerobot.utils.utils import get_safe_torch_device, init_logging, log_say
|
from lerobot.utils.utils import init_logging, log_say
|
||||||
from lerobot.utils.visualization_utils import init_rerun, log_rerun_data
|
from lerobot.utils.visualization_utils import init_rerun, log_rerun_data
|
||||||
|
|
||||||
from hil_utils import (
|
|
||||||
HILDatasetConfig,
|
|
||||||
init_keyboard_listener,
|
|
||||||
make_identity_processors,
|
|
||||||
print_controls,
|
|
||||||
reset_loop,
|
|
||||||
teleop_disable_torque,
|
|
||||||
teleop_smooth_move_to,
|
|
||||||
)
|
|
||||||
|
|
||||||
logger = logging.getLogger(__name__)
|
logger = logging.getLogger(__name__)
|
||||||
|
|
||||||
|
|
||||||
@@ -239,7 +238,7 @@ def rollout_loop(
|
|||||||
was_paused = False
|
was_paused = False
|
||||||
waiting_for_takeover = False
|
waiting_for_takeover = False
|
||||||
last_action: dict[str, Any] | None = None
|
last_action: dict[str, Any] | None = None
|
||||||
action_keys = [k for k in robot.action_features.keys() if k.endswith(".pos")]
|
action_keys = [k for k in robot.action_features if k.endswith(".pos")]
|
||||||
|
|
||||||
interpolator = ActionInterpolator(multiplier=cfg.interpolation_multiplier)
|
interpolator = ActionInterpolator(multiplier=cfg.interpolation_multiplier)
|
||||||
control_interval = interpolator.get_control_interval(fps)
|
control_interval = interpolator.get_control_interval(fps)
|
||||||
@@ -261,7 +260,9 @@ def rollout_loop(
|
|||||||
if events["policy_paused"] and not was_paused:
|
if events["policy_paused"] and not was_paused:
|
||||||
policy_active.clear()
|
policy_active.clear()
|
||||||
obs = robot.get_observation()
|
obs = robot.get_observation()
|
||||||
robot_pos = {k: v for k, v in obs.items() if k.endswith(".pos") and k in robot.observation_features}
|
robot_pos = {
|
||||||
|
k: v for k, v in obs.items() if k.endswith(".pos") and k in robot.observation_features
|
||||||
|
}
|
||||||
teleop_smooth_move_to(teleop, robot_pos, duration_s=2.0, fps=50)
|
teleop_smooth_move_to(teleop, robot_pos, duration_s=2.0, fps=50)
|
||||||
events["start_next_episode"] = False
|
events["start_next_episode"] = False
|
||||||
waiting_for_takeover = True
|
waiting_for_takeover = True
|
||||||
@@ -305,7 +306,9 @@ def rollout_loop(
|
|||||||
|
|
||||||
action_tensor = interpolator.get()
|
action_tensor = interpolator.get()
|
||||||
if action_tensor is not None:
|
if action_tensor is not None:
|
||||||
robot_action = {k: action_tensor[i].item() for i, k in enumerate(action_keys) if i < len(action_tensor)}
|
robot_action = {
|
||||||
|
k: action_tensor[i].item() for i, k in enumerate(action_keys) if i < len(action_tensor)
|
||||||
|
}
|
||||||
robot.send_action(robot_action)
|
robot.send_action(robot_action)
|
||||||
last_action = robot_action
|
last_action = robot_action
|
||||||
action_frame = build_dataset_frame(dataset.features, robot_action, prefix=ACTION)
|
action_frame = build_dataset_frame(dataset.features, robot_action, prefix=ACTION)
|
||||||
@@ -418,8 +421,17 @@ def hil_rtc_collect(cfg: HILRTCConfig) -> LeRobotDataset:
|
|||||||
|
|
||||||
rtc_thread = Thread(
|
rtc_thread = Thread(
|
||||||
target=rtc_inference_thread,
|
target=rtc_inference_thread,
|
||||||
args=(policy, obs_holder, hw_features, preprocessor, postprocessor,
|
args=(
|
||||||
queue_holder, shutdown_event, policy_active, cfg),
|
policy,
|
||||||
|
obs_holder,
|
||||||
|
hw_features,
|
||||||
|
preprocessor,
|
||||||
|
postprocessor,
|
||||||
|
queue_holder,
|
||||||
|
shutdown_event,
|
||||||
|
policy_active,
|
||||||
|
cfg,
|
||||||
|
),
|
||||||
daemon=True,
|
daemon=True,
|
||||||
)
|
)
|
||||||
rtc_thread.start()
|
rtc_thread.start()
|
||||||
@@ -492,10 +504,10 @@ def hil_rtc_collect(cfg: HILRTCConfig) -> LeRobotDataset:
|
|||||||
|
|
||||||
def main():
|
def main():
|
||||||
from lerobot.utils.import_utils import register_third_party_plugins
|
from lerobot.utils.import_utils import register_third_party_plugins
|
||||||
|
|
||||||
register_third_party_plugins()
|
register_third_party_plugins()
|
||||||
hil_rtc_collect()
|
hil_rtc_collect()
|
||||||
|
|
||||||
|
|
||||||
if __name__ == "__main__":
|
if __name__ == "__main__":
|
||||||
main()
|
main()
|
||||||
|
|
||||||
|
|||||||
@@ -17,8 +17,8 @@ from lerobot.processor.converters import (
|
|||||||
transition_to_observation,
|
transition_to_observation,
|
||||||
transition_to_robot_action,
|
transition_to_robot_action,
|
||||||
)
|
)
|
||||||
from lerobot.robots import Robot, RobotConfig
|
from lerobot.robots import Robot
|
||||||
from lerobot.teleoperators import Teleoperator, TeleoperatorConfig
|
from lerobot.teleoperators import Teleoperator
|
||||||
from lerobot.utils.control_utils import is_headless
|
from lerobot.utils.control_utils import is_headless
|
||||||
from lerobot.utils.robot_utils import precise_sleep
|
from lerobot.utils.robot_utils import precise_sleep
|
||||||
|
|
||||||
@@ -151,12 +151,12 @@ def start_pedal_listener(events: dict):
|
|||||||
logger.info("[Pedal] evdev not installed - pedal support disabled")
|
logger.info("[Pedal] evdev not installed - pedal support disabled")
|
||||||
return
|
return
|
||||||
|
|
||||||
PEDAL_DEVICE = "/dev/input/by-id/usb-PCsensor_FootSwitch-event-kbd"
|
pedal_device = "/dev/input/by-id/usb-PCsensor_FootSwitch-event-kbd"
|
||||||
KEY_LEFT, KEY_RIGHT = "KEY_A", "KEY_C"
|
key_left, key_right = "KEY_A", "KEY_C"
|
||||||
|
|
||||||
def pedal_reader():
|
def pedal_reader():
|
||||||
try:
|
try:
|
||||||
dev = InputDevice(PEDAL_DEVICE)
|
dev = InputDevice(pedal_device)
|
||||||
print(f"[Pedal] Connected: {dev.name}")
|
print(f"[Pedal] Connected: {dev.name}")
|
||||||
for ev in dev.read_loop():
|
for ev in dev.read_loop():
|
||||||
if ev.type != ecodes.EV_KEY:
|
if ev.type != ecodes.EV_KEY:
|
||||||
@@ -167,17 +167,16 @@ def start_pedal_listener(events: dict):
|
|||||||
continue
|
continue
|
||||||
|
|
||||||
if events["in_reset"]:
|
if events["in_reset"]:
|
||||||
if code in [KEY_LEFT, KEY_RIGHT]:
|
if code in [key_left, key_right]:
|
||||||
events["start_next_episode"] = True
|
events["start_next_episode"] = True
|
||||||
else:
|
else:
|
||||||
if code == KEY_RIGHT:
|
if code == key_right:
|
||||||
if events["correction_active"]:
|
if events["correction_active"]:
|
||||||
events["exit_early"] = True
|
events["exit_early"] = True
|
||||||
elif not events["policy_paused"]:
|
elif not events["policy_paused"]:
|
||||||
events["policy_paused"] = True
|
events["policy_paused"] = True
|
||||||
elif code == KEY_LEFT:
|
elif code == key_left and events["policy_paused"] and not events["correction_active"]:
|
||||||
if events["policy_paused"] and not events["correction_active"]:
|
events["start_next_episode"] = True
|
||||||
events["start_next_episode"] = True
|
|
||||||
except (FileNotFoundError, PermissionError) as e:
|
except (FileNotFoundError, PermissionError) as e:
|
||||||
logger.info(f"[Pedal] {e}")
|
logger.info(f"[Pedal] {e}")
|
||||||
|
|
||||||
@@ -248,4 +247,3 @@ def print_controls(rtc: bool = False):
|
|||||||
print(" → - End episode")
|
print(" → - End episode")
|
||||||
print(" ESC - Stop and push to hub")
|
print(" ESC - Stop and push to hub")
|
||||||
print("=" * 60 + "\n")
|
print("=" * 60 + "\n")
|
||||||
|
|
||||||
|
|||||||
@@ -27,4 +27,3 @@ __all__ = [
|
|||||||
"RTCConfig",
|
"RTCConfig",
|
||||||
"RTCProcessor",
|
"RTCProcessor",
|
||||||
]
|
]
|
||||||
|
|
||||||
|
|||||||
@@ -115,4 +115,3 @@ class ActionInterpolator:
|
|||||||
Control interval in seconds (divided by multiplier).
|
Control interval in seconds (divided by multiplier).
|
||||||
"""
|
"""
|
||||||
return 1.0 / (fps * self.multiplier)
|
return 1.0 / (fps * self.multiplier)
|
||||||
|
|
||||||
|
|||||||
Reference in New Issue
Block a user