mirror of
https://github.com/huggingface/lerobot.git
synced 2026-05-16 17:20:05 +00:00
chore(style): fix pre-commit
This commit is contained in:
@@ -37,6 +37,17 @@ from dataclasses import dataclass
|
||||
from pprint import pformat
|
||||
from typing import Any
|
||||
|
||||
import torch
|
||||
from hil_utils import (
|
||||
HILDatasetConfig,
|
||||
init_keyboard_listener,
|
||||
make_identity_processors,
|
||||
print_controls,
|
||||
reset_loop,
|
||||
teleop_disable_torque,
|
||||
teleop_smooth_move_to,
|
||||
)
|
||||
|
||||
from lerobot.cameras.opencv.configuration_opencv import OpenCVCameraConfig # noqa: F401
|
||||
from lerobot.cameras.realsense.configuration_realsense import RealSenseCameraConfig # noqa: F401
|
||||
from lerobot.configs import parser
|
||||
@@ -46,8 +57,6 @@ from lerobot.datasets.lerobot_dataset import LeRobotDataset
|
||||
from lerobot.datasets.pipeline_features import aggregate_pipeline_dataset_features, create_initial_features
|
||||
from lerobot.datasets.utils import build_dataset_frame, combine_feature_dicts
|
||||
from lerobot.datasets.video_utils import VideoEncodingManager
|
||||
import torch
|
||||
|
||||
from lerobot.policies.factory import make_policy, make_pre_post_processors
|
||||
from lerobot.policies.pretrained import PreTrainedPolicy
|
||||
from lerobot.policies.rtc import ActionInterpolator
|
||||
@@ -62,16 +71,6 @@ from lerobot.utils.robot_utils import precise_sleep
|
||||
from lerobot.utils.utils import get_safe_torch_device, init_logging, log_say
|
||||
from lerobot.utils.visualization_utils import init_rerun, log_rerun_data
|
||||
|
||||
from hil_utils import (
|
||||
HILDatasetConfig,
|
||||
init_keyboard_listener,
|
||||
make_identity_processors,
|
||||
print_controls,
|
||||
reset_loop,
|
||||
teleop_disable_torque,
|
||||
teleop_smooth_move_to,
|
||||
)
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
@@ -127,7 +126,7 @@ def rollout_loop(
|
||||
waiting_for_takeover = False
|
||||
last_action: dict[str, Any] | None = None
|
||||
robot_action: dict[str, Any] = {}
|
||||
action_keys = sorted([k for k in robot.action_features.keys()])
|
||||
action_keys = sorted(robot.action_features.keys())
|
||||
|
||||
interpolator = ActionInterpolator(multiplier=cfg.interpolation_multiplier)
|
||||
control_interval = interpolator.get_control_interval(fps)
|
||||
@@ -147,7 +146,9 @@ def rollout_loop(
|
||||
# Transition to paused state
|
||||
if events["policy_paused"] and not was_paused:
|
||||
obs = robot.get_observation()
|
||||
robot_pos = {k: v for k, v in obs.items() if k.endswith(".pos") and k in robot.observation_features}
|
||||
robot_pos = {
|
||||
k: v for k, v in obs.items() if k.endswith(".pos") and k in robot.observation_features
|
||||
}
|
||||
teleop_smooth_move_to(teleop, robot_pos, duration_s=2.0, fps=50)
|
||||
events["start_next_episode"] = False
|
||||
waiting_for_takeover = True
|
||||
@@ -341,6 +342,7 @@ def hil_collect(cfg: HILConfig) -> LeRobotDataset:
|
||||
|
||||
def main():
|
||||
from lerobot.utils.import_utils import register_third_party_plugins
|
||||
|
||||
register_third_party_plugins()
|
||||
hil_collect()
|
||||
|
||||
|
||||
@@ -44,6 +44,15 @@ from threading import Event, Lock, Thread
|
||||
from typing import Any
|
||||
|
||||
import torch
|
||||
from hil_utils import (
|
||||
HILDatasetConfig,
|
||||
init_keyboard_listener,
|
||||
make_identity_processors,
|
||||
print_controls,
|
||||
reset_loop,
|
||||
teleop_disable_torque,
|
||||
teleop_smooth_move_to,
|
||||
)
|
||||
|
||||
from lerobot.cameras.opencv.configuration_opencv import OpenCVCameraConfig # noqa: F401
|
||||
from lerobot.cameras.realsense.configuration_realsense import RealSenseCameraConfig # noqa: F401
|
||||
@@ -64,19 +73,9 @@ from lerobot.teleoperators import Teleoperator, TeleoperatorConfig, make_teleope
|
||||
from lerobot.utils.constants import ACTION, OBS_STR
|
||||
from lerobot.utils.control_utils import is_headless
|
||||
from lerobot.utils.robot_utils import precise_sleep
|
||||
from lerobot.utils.utils import get_safe_torch_device, init_logging, log_say
|
||||
from lerobot.utils.utils import init_logging, log_say
|
||||
from lerobot.utils.visualization_utils import init_rerun, log_rerun_data
|
||||
|
||||
from hil_utils import (
|
||||
HILDatasetConfig,
|
||||
init_keyboard_listener,
|
||||
make_identity_processors,
|
||||
print_controls,
|
||||
reset_loop,
|
||||
teleop_disable_torque,
|
||||
teleop_smooth_move_to,
|
||||
)
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
@@ -239,7 +238,7 @@ def rollout_loop(
|
||||
was_paused = False
|
||||
waiting_for_takeover = False
|
||||
last_action: dict[str, Any] | None = None
|
||||
action_keys = [k for k in robot.action_features.keys() if k.endswith(".pos")]
|
||||
action_keys = [k for k in robot.action_features if k.endswith(".pos")]
|
||||
|
||||
interpolator = ActionInterpolator(multiplier=cfg.interpolation_multiplier)
|
||||
control_interval = interpolator.get_control_interval(fps)
|
||||
@@ -261,7 +260,9 @@ def rollout_loop(
|
||||
if events["policy_paused"] and not was_paused:
|
||||
policy_active.clear()
|
||||
obs = robot.get_observation()
|
||||
robot_pos = {k: v for k, v in obs.items() if k.endswith(".pos") and k in robot.observation_features}
|
||||
robot_pos = {
|
||||
k: v for k, v in obs.items() if k.endswith(".pos") and k in robot.observation_features
|
||||
}
|
||||
teleop_smooth_move_to(teleop, robot_pos, duration_s=2.0, fps=50)
|
||||
events["start_next_episode"] = False
|
||||
waiting_for_takeover = True
|
||||
@@ -305,7 +306,9 @@ def rollout_loop(
|
||||
|
||||
action_tensor = interpolator.get()
|
||||
if action_tensor is not None:
|
||||
robot_action = {k: action_tensor[i].item() for i, k in enumerate(action_keys) if i < len(action_tensor)}
|
||||
robot_action = {
|
||||
k: action_tensor[i].item() for i, k in enumerate(action_keys) if i < len(action_tensor)
|
||||
}
|
||||
robot.send_action(robot_action)
|
||||
last_action = robot_action
|
||||
action_frame = build_dataset_frame(dataset.features, robot_action, prefix=ACTION)
|
||||
@@ -418,8 +421,17 @@ def hil_rtc_collect(cfg: HILRTCConfig) -> LeRobotDataset:
|
||||
|
||||
rtc_thread = Thread(
|
||||
target=rtc_inference_thread,
|
||||
args=(policy, obs_holder, hw_features, preprocessor, postprocessor,
|
||||
queue_holder, shutdown_event, policy_active, cfg),
|
||||
args=(
|
||||
policy,
|
||||
obs_holder,
|
||||
hw_features,
|
||||
preprocessor,
|
||||
postprocessor,
|
||||
queue_holder,
|
||||
shutdown_event,
|
||||
policy_active,
|
||||
cfg,
|
||||
),
|
||||
daemon=True,
|
||||
)
|
||||
rtc_thread.start()
|
||||
@@ -492,10 +504,10 @@ def hil_rtc_collect(cfg: HILRTCConfig) -> LeRobotDataset:
|
||||
|
||||
def main():
|
||||
from lerobot.utils.import_utils import register_third_party_plugins
|
||||
|
||||
register_third_party_plugins()
|
||||
hil_rtc_collect()
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
main()
|
||||
|
||||
|
||||
@@ -17,8 +17,8 @@ from lerobot.processor.converters import (
|
||||
transition_to_observation,
|
||||
transition_to_robot_action,
|
||||
)
|
||||
from lerobot.robots import Robot, RobotConfig
|
||||
from lerobot.teleoperators import Teleoperator, TeleoperatorConfig
|
||||
from lerobot.robots import Robot
|
||||
from lerobot.teleoperators import Teleoperator
|
||||
from lerobot.utils.control_utils import is_headless
|
||||
from lerobot.utils.robot_utils import precise_sleep
|
||||
|
||||
@@ -151,12 +151,12 @@ def start_pedal_listener(events: dict):
|
||||
logger.info("[Pedal] evdev not installed - pedal support disabled")
|
||||
return
|
||||
|
||||
PEDAL_DEVICE = "/dev/input/by-id/usb-PCsensor_FootSwitch-event-kbd"
|
||||
KEY_LEFT, KEY_RIGHT = "KEY_A", "KEY_C"
|
||||
pedal_device = "/dev/input/by-id/usb-PCsensor_FootSwitch-event-kbd"
|
||||
key_left, key_right = "KEY_A", "KEY_C"
|
||||
|
||||
def pedal_reader():
|
||||
try:
|
||||
dev = InputDevice(PEDAL_DEVICE)
|
||||
dev = InputDevice(pedal_device)
|
||||
print(f"[Pedal] Connected: {dev.name}")
|
||||
for ev in dev.read_loop():
|
||||
if ev.type != ecodes.EV_KEY:
|
||||
@@ -167,17 +167,16 @@ def start_pedal_listener(events: dict):
|
||||
continue
|
||||
|
||||
if events["in_reset"]:
|
||||
if code in [KEY_LEFT, KEY_RIGHT]:
|
||||
if code in [key_left, key_right]:
|
||||
events["start_next_episode"] = True
|
||||
else:
|
||||
if code == KEY_RIGHT:
|
||||
if code == key_right:
|
||||
if events["correction_active"]:
|
||||
events["exit_early"] = True
|
||||
elif not events["policy_paused"]:
|
||||
events["policy_paused"] = True
|
||||
elif code == KEY_LEFT:
|
||||
if events["policy_paused"] and not events["correction_active"]:
|
||||
events["start_next_episode"] = True
|
||||
elif code == key_left and events["policy_paused"] and not events["correction_active"]:
|
||||
events["start_next_episode"] = True
|
||||
except (FileNotFoundError, PermissionError) as e:
|
||||
logger.info(f"[Pedal] {e}")
|
||||
|
||||
@@ -248,4 +247,3 @@ def print_controls(rtc: bool = False):
|
||||
print(" → - End episode")
|
||||
print(" ESC - Stop and push to hub")
|
||||
print("=" * 60 + "\n")
|
||||
|
||||
|
||||
Reference in New Issue
Block a user