mirror of
https://github.com/huggingface/lerobot.git
synced 2026-05-11 14:49:43 +00:00
Compare commits
5 Commits
| Author | SHA1 | Date | |
|---|---|---|---|
| a0dc324b81 | |||
| 1d275e2021 | |||
| 24bb2cb0ff | |||
| 1d414c07e2 | |||
| e04e3399b9 |
@@ -131,6 +131,15 @@ class _NormalizationMixin:
|
||||
if self.dtype is None:
|
||||
self.dtype = torch.float32
|
||||
self._tensor_stats = to_tensor(self.stats, device=self.device, dtype=self.dtype)
|
||||
self._reshape_visual_stats()
|
||||
|
||||
def _reshape_visual_stats(self) -> None:
|
||||
"""Reshape visual stats from ``[C]`` to ``[C, 1, 1]`` for image broadcasting."""
|
||||
for key, feature in self.features.items():
|
||||
if feature.type == FeatureType.VISUAL and key in self._tensor_stats:
|
||||
for stat_name, stat_tensor in self._tensor_stats[key].items():
|
||||
if isinstance(stat_tensor, Tensor) and stat_tensor.ndim == 1:
|
||||
self._tensor_stats[key][stat_name] = stat_tensor.reshape(-1, 1, 1)
|
||||
|
||||
def to(
|
||||
self, device: torch.device | str | None = None, dtype: torch.dtype | None = None
|
||||
@@ -149,6 +158,7 @@ class _NormalizationMixin:
|
||||
if dtype is not None:
|
||||
self.dtype = dtype
|
||||
self._tensor_stats = to_tensor(self.stats, device=self.device, dtype=self.dtype)
|
||||
self._reshape_visual_stats()
|
||||
return self
|
||||
|
||||
def state_dict(self) -> dict[str, Tensor]:
|
||||
@@ -198,6 +208,7 @@ class _NormalizationMixin:
|
||||
# Don't load from state_dict, keep the explicitly provided stats
|
||||
# But ensure _tensor_stats is properly initialized
|
||||
self._tensor_stats = to_tensor(self.stats, device=self.device, dtype=self.dtype) # type: ignore[assignment]
|
||||
self._reshape_visual_stats()
|
||||
return
|
||||
|
||||
# Normal behavior: load stats from state_dict
|
||||
@@ -209,6 +220,8 @@ class _NormalizationMixin:
|
||||
dtype=torch.float32, device=self.device
|
||||
)
|
||||
|
||||
self._reshape_visual_stats()
|
||||
|
||||
# Reconstruct the original stats dict from tensor stats for compatibility with to() method
|
||||
# and other functions that rely on self.stats
|
||||
self.stats = {}
|
||||
|
||||
@@ -62,6 +62,7 @@ from lerobot.configs import parser
|
||||
from lerobot.configs.train import TrainRLServerPipelineConfig
|
||||
from lerobot.policies.factory import make_policy
|
||||
from lerobot.policies.sac.modeling_sac import SACPolicy
|
||||
from lerobot.policies.sac.processor_sac import make_sac_pre_post_processors
|
||||
from lerobot.rl.process import ProcessSignalHandler
|
||||
from lerobot.rl.queue import get_last_item_from_queue
|
||||
from lerobot.robots import so_follower # noqa: F401
|
||||
@@ -258,6 +259,11 @@ def act_with_policy(
|
||||
policy = policy.eval()
|
||||
assert isinstance(policy, nn.Module)
|
||||
|
||||
preprocessor, postprocessor = make_sac_pre_post_processors(
|
||||
config=cfg.policy,
|
||||
dataset_stats=cfg.policy.dataset_stats,
|
||||
)
|
||||
|
||||
obs, info = online_env.reset()
|
||||
env_processor.reset()
|
||||
action_processor.reset()
|
||||
@@ -289,7 +295,9 @@ def act_with_policy(
|
||||
# Time policy inference and check if it meets FPS requirement
|
||||
with policy_timer:
|
||||
# Extract observation from transition for policy
|
||||
action = policy.select_action(batch=observation)
|
||||
normalized_observation = preprocessor.process_observation(observation)
|
||||
action = policy.select_action(batch=normalized_observation)
|
||||
# action = postprocessor.process_action(action)
|
||||
policy_fps = policy_timer.fps_last
|
||||
|
||||
log_policy_frequency_issue(policy_fps=policy_fps, cfg=cfg, interaction_step=interaction_step)
|
||||
|
||||
@@ -66,6 +66,7 @@ from lerobot.datasets.factory import make_dataset
|
||||
from lerobot.datasets.lerobot_dataset import LeRobotDataset
|
||||
from lerobot.policies.factory import make_policy
|
||||
from lerobot.policies.sac.modeling_sac import SACPolicy
|
||||
from lerobot.policies.sac.processor_sac import make_sac_pre_post_processors
|
||||
from lerobot.rl.buffer import ReplayBuffer, concatenate_batch_transitions
|
||||
from lerobot.rl.process import ProcessSignalHandler
|
||||
from lerobot.rl.wandb_utils import WandBLogger
|
||||
@@ -313,6 +314,11 @@ def add_actor_information_and_train(
|
||||
|
||||
assert isinstance(policy, nn.Module)
|
||||
|
||||
preprocessor, _ = make_sac_pre_post_processors(
|
||||
config=cfg.policy,
|
||||
dataset_stats=cfg.policy.dataset_stats,
|
||||
)
|
||||
|
||||
policy.train()
|
||||
|
||||
push_actor_policy_to_queue(parameters_queue=parameters_queue, policy=policy)
|
||||
@@ -408,6 +414,9 @@ def add_actor_information_and_train(
|
||||
done = batch["done"]
|
||||
check_nan_in_transition(observations=observations, actions=actions, next_state=next_observations)
|
||||
|
||||
observations = preprocessor.process_observation(observations)
|
||||
next_observations = preprocessor.process_observation(next_observations)
|
||||
|
||||
observation_features, next_observation_features = get_observation_features(
|
||||
policy=policy, observations=observations, next_observations=next_observations
|
||||
)
|
||||
@@ -467,6 +476,9 @@ def add_actor_information_and_train(
|
||||
|
||||
check_nan_in_transition(observations=observations, actions=actions, next_state=next_observations)
|
||||
|
||||
observations = preprocessor.process_observation(observations)
|
||||
next_observations = preprocessor.process_observation(next_observations)
|
||||
|
||||
observation_features, next_observation_features = get_observation_features(
|
||||
policy=policy, observations=observations, next_observations=next_observations
|
||||
)
|
||||
|
||||
@@ -23,65 +23,46 @@ class InputController:
|
||||
"""Base class for input controllers that generate motion deltas."""
|
||||
|
||||
def __init__(self, x_step_size=1.0, y_step_size=1.0, z_step_size=1.0):
|
||||
"""
|
||||
Initialize the controller.
|
||||
|
||||
Args:
|
||||
x_step_size: Base movement step size in meters
|
||||
y_step_size: Base movement step size in meters
|
||||
z_step_size: Base movement step size in meters
|
||||
"""
|
||||
self.x_step_size = x_step_size
|
||||
self.y_step_size = y_step_size
|
||||
self.z_step_size = z_step_size
|
||||
self.running = True
|
||||
self.episode_end_status = None # None, "success", or "failure"
|
||||
self.episode_end_status = None
|
||||
self.intervention_flag = False
|
||||
self.open_gripper_command = False
|
||||
self.close_gripper_command = False
|
||||
|
||||
def start(self):
|
||||
"""Start the controller and initialize resources."""
|
||||
pass
|
||||
|
||||
def stop(self):
|
||||
"""Stop the controller and release resources."""
|
||||
pass
|
||||
|
||||
def reset(self):
|
||||
pass
|
||||
|
||||
def get_deltas(self):
|
||||
"""Get the current movement deltas (dx, dy, dz) in meters."""
|
||||
return 0.0, 0.0, 0.0
|
||||
|
||||
def update(self):
|
||||
"""Update controller state - call this once per frame."""
|
||||
pass
|
||||
|
||||
def __enter__(self):
|
||||
"""Support for use in 'with' statements."""
|
||||
self.start()
|
||||
return self
|
||||
|
||||
def __exit__(self, exc_type, exc_val, exc_tb):
|
||||
"""Ensure resources are released when exiting 'with' block."""
|
||||
self.stop()
|
||||
|
||||
def get_episode_end_status(self):
|
||||
"""
|
||||
Get the current episode end status.
|
||||
|
||||
Returns:
|
||||
None if episode should continue, "success" or "failure" otherwise
|
||||
"""
|
||||
status = self.episode_end_status
|
||||
self.episode_end_status = None # Reset after reading
|
||||
self.episode_end_status = None
|
||||
return status
|
||||
|
||||
def should_intervene(self):
|
||||
"""Return True if intervention flag was set."""
|
||||
return self.intervention_flag
|
||||
|
||||
def gripper_command(self):
|
||||
"""Return the current gripper command."""
|
||||
if self.open_gripper_command == self.close_gripper_command:
|
||||
return "stay"
|
||||
elif self.open_gripper_command:
|
||||
@@ -102,14 +83,14 @@ class KeyboardController(InputController):
|
||||
"backward_y": False,
|
||||
"forward_z": False,
|
||||
"backward_z": False,
|
||||
"quit": False,
|
||||
"success": False,
|
||||
"failure": False,
|
||||
"intervention": False,
|
||||
"rerecord": False,
|
||||
}
|
||||
self.listener = None
|
||||
|
||||
def start(self):
|
||||
"""Start the keyboard listener."""
|
||||
from pynput import keyboard
|
||||
|
||||
def on_press(key):
|
||||
@@ -126,16 +107,21 @@ class KeyboardController(InputController):
|
||||
self.key_states["backward_z"] = True
|
||||
elif key == keyboard.Key.shift_r:
|
||||
self.key_states["forward_z"] = True
|
||||
elif key == keyboard.Key.esc:
|
||||
self.key_states["quit"] = True
|
||||
self.running = False
|
||||
return False
|
||||
elif key == keyboard.Key.ctrl_r:
|
||||
self.open_gripper_command = True
|
||||
elif key == keyboard.Key.ctrl_l:
|
||||
self.close_gripper_command = True
|
||||
elif key == keyboard.Key.enter:
|
||||
self.key_states["success"] = True
|
||||
self.episode_end_status = TeleopEvents.SUCCESS
|
||||
elif key == keyboard.Key.backspace:
|
||||
elif key == keyboard.Key.esc:
|
||||
self.key_states["failure"] = True
|
||||
self.episode_end_status = TeleopEvents.FAILURE
|
||||
elif key == keyboard.Key.space:
|
||||
self.key_states["intervention"] = not self.key_states["intervention"]
|
||||
elif hasattr(key, "char") and key.char == "r":
|
||||
self.key_states["rerecord"] = True
|
||||
self.episode_end_status = TeleopEvents.RERECORD_EPISODE
|
||||
except AttributeError:
|
||||
pass
|
||||
|
||||
@@ -153,10 +139,10 @@ class KeyboardController(InputController):
|
||||
self.key_states["backward_z"] = False
|
||||
elif key == keyboard.Key.shift_r:
|
||||
self.key_states["forward_z"] = False
|
||||
elif key == keyboard.Key.enter:
|
||||
self.key_states["success"] = False
|
||||
elif key == keyboard.Key.backspace:
|
||||
self.key_states["failure"] = False
|
||||
elif key == keyboard.Key.ctrl_r:
|
||||
self.open_gripper_command = False
|
||||
elif key == keyboard.Key.ctrl_l:
|
||||
self.close_gripper_command = False
|
||||
except AttributeError:
|
||||
pass
|
||||
|
||||
@@ -165,18 +151,18 @@ class KeyboardController(InputController):
|
||||
|
||||
print("Keyboard controls:")
|
||||
print(" Arrow keys: Move in X-Y plane")
|
||||
print(" Shift and Shift_R: Move in Z axis")
|
||||
print(" Shift / Shift_R: Move in Z axis")
|
||||
print(" Ctrl_R / Ctrl_L: Open / Close gripper")
|
||||
print(" Space: Toggle intervention")
|
||||
print(" Enter: End episode with SUCCESS")
|
||||
print(" Backspace: End episode with FAILURE")
|
||||
print(" ESC: Exit")
|
||||
print(" Esc: End episode with FAILURE")
|
||||
print(" R: Rerecord episode")
|
||||
|
||||
def stop(self):
|
||||
"""Stop the keyboard listener."""
|
||||
if self.listener and self.listener.is_alive():
|
||||
self.listener.stop()
|
||||
|
||||
def get_deltas(self):
|
||||
"""Get the current movement deltas from keyboard state."""
|
||||
delta_x = delta_y = delta_z = 0.0
|
||||
|
||||
if self.key_states["forward_x"]:
|
||||
@@ -194,18 +180,58 @@ class KeyboardController(InputController):
|
||||
|
||||
return delta_x, delta_y, delta_z
|
||||
|
||||
def should_intervene(self):
|
||||
return self.key_states["intervention"]
|
||||
|
||||
def reset(self):
|
||||
for key in self.key_states:
|
||||
self.key_states[key] = False
|
||||
|
||||
|
||||
class GamepadController(InputController):
|
||||
"""Generate motion deltas from gamepad input."""
|
||||
"""Generate motion deltas from gamepad input using pygame.
|
||||
|
||||
Matches gym-hil button/axis conventions for Linux gamepads, including
|
||||
Xbox mappings.
|
||||
"""
|
||||
|
||||
# Face buttons (same across most controllers on Linux)
|
||||
BUTTON_A = 0
|
||||
BUTTON_B = 1
|
||||
BUTTON_X = 2
|
||||
BUTTON_Y = 3
|
||||
BUTTON_LB = 4
|
||||
BUTTON_RB = 5
|
||||
# Stick axes
|
||||
AXIS_LEFT_X = 0
|
||||
AXIS_LEFT_Y = 1
|
||||
AXIS_RIGHT_X = 2
|
||||
AXIS_RIGHT_Y = 3
|
||||
|
||||
# Default trigger buttons
|
||||
BUTTON_LT = 6
|
||||
BUTTON_RT = 7
|
||||
|
||||
# Xbox (gym-hil mapping on Linux)
|
||||
XBOX_BUTTON_LT = 9
|
||||
XBOX_BUTTON_RT = 10
|
||||
|
||||
def __init__(self, x_step_size=1.0, y_step_size=1.0, z_step_size=1.0, deadzone=0.1):
|
||||
super().__init__(x_step_size, y_step_size, z_step_size)
|
||||
self.deadzone = deadzone
|
||||
self.joystick = None
|
||||
self.intervention_flag = False
|
||||
self.is_xbox = False
|
||||
self._xbox360_profile = False
|
||||
self._invert_left_x = False
|
||||
self._invert_left_y = True
|
||||
self._invert_right_y = True
|
||||
|
||||
def _detect_xbox(self, name):
|
||||
name_lower = name.lower()
|
||||
return any(tag in name_lower for tag in ["xbox", "microsoft", "x-box"])
|
||||
|
||||
def start(self):
|
||||
"""Initialize pygame and the gamepad."""
|
||||
import pygame
|
||||
|
||||
pygame.init()
|
||||
@@ -218,18 +244,35 @@ class GamepadController(InputController):
|
||||
|
||||
self.joystick = pygame.joystick.Joystick(0)
|
||||
self.joystick.init()
|
||||
logging.info(f"Initialized gamepad: {self.joystick.get_name()}")
|
||||
joystick_name = self.joystick.get_name()
|
||||
self.is_xbox = self._detect_xbox(joystick_name)
|
||||
self._xbox360_profile = joystick_name == "Xbox 360 Controller"
|
||||
if self._xbox360_profile:
|
||||
# gym-hil "Xbox 360 Controller" profile
|
||||
self.AXIS_RIGHT_X = 3
|
||||
self.AXIS_RIGHT_Y = 4
|
||||
self.BUTTON_LT = self.XBOX_BUTTON_LT
|
||||
self.BUTTON_RT = self.XBOX_BUTTON_RT
|
||||
self._invert_left_x = True
|
||||
else:
|
||||
# gym-hil default profile
|
||||
self.AXIS_RIGHT_X = 2
|
||||
self.AXIS_RIGHT_Y = 3
|
||||
self.BUTTON_LT = 6
|
||||
self.BUTTON_RT = 7
|
||||
self._invert_left_x = False
|
||||
logging.info(f"Initialized gamepad: {joystick_name} (xbox={self.is_xbox})")
|
||||
|
||||
print("Gamepad controls:")
|
||||
print(" Left analog stick: Move in X-Y plane")
|
||||
print(" Right analog stick (vertical): Move in Z axis")
|
||||
print(" B/Circle button: Exit")
|
||||
print(" Y/Triangle button: End episode with SUCCESS")
|
||||
print(" A/Cross button: End episode with FAILURE")
|
||||
print(" X/Square button: Rerecord episode")
|
||||
print(" RB: Intervention toggle")
|
||||
print(" LT / RT: Close / Open gripper")
|
||||
print(" Y: End episode with SUCCESS")
|
||||
print(" A: End episode with FAILURE")
|
||||
print(" X: Rerecord episode")
|
||||
|
||||
def stop(self):
|
||||
"""Clean up pygame resources."""
|
||||
import pygame
|
||||
|
||||
if pygame.joystick.get_init():
|
||||
@@ -239,67 +282,56 @@ class GamepadController(InputController):
|
||||
pygame.quit()
|
||||
|
||||
def update(self):
|
||||
"""Process pygame events to get fresh gamepad readings."""
|
||||
import pygame
|
||||
|
||||
for event in pygame.event.get():
|
||||
if event.type == pygame.JOYBUTTONDOWN:
|
||||
if event.button == 3:
|
||||
if event.button == self.BUTTON_Y:
|
||||
self.episode_end_status = TeleopEvents.SUCCESS
|
||||
# A button (1) for failure
|
||||
elif event.button == 1:
|
||||
elif event.button == self.BUTTON_A:
|
||||
self.episode_end_status = TeleopEvents.FAILURE
|
||||
# X button (0) for rerecord
|
||||
elif event.button == 0:
|
||||
elif event.button == self.BUTTON_X:
|
||||
self.episode_end_status = TeleopEvents.RERECORD_EPISODE
|
||||
|
||||
# RB button (6) for closing gripper
|
||||
elif event.button == 6:
|
||||
elif event.button == self.BUTTON_LT:
|
||||
self.close_gripper_command = True
|
||||
|
||||
# LT button (7) for opening gripper
|
||||
elif event.button == 7:
|
||||
elif event.button == self.BUTTON_RT:
|
||||
self.open_gripper_command = True
|
||||
|
||||
# Reset episode status on button release
|
||||
elif event.type == pygame.JOYBUTTONUP:
|
||||
if event.button in [0, 2, 3]:
|
||||
if event.button in [self.BUTTON_Y, self.BUTTON_A, self.BUTTON_X]:
|
||||
self.episode_end_status = None
|
||||
|
||||
elif event.button == 6:
|
||||
elif event.button == self.BUTTON_LT:
|
||||
self.close_gripper_command = False
|
||||
|
||||
elif event.button == 7:
|
||||
elif event.button == self.BUTTON_RT:
|
||||
self.open_gripper_command = False
|
||||
|
||||
# Check for RB button (typically button 5) for intervention flag
|
||||
if self.joystick.get_button(5):
|
||||
if self.joystick.get_button(self.BUTTON_RB):
|
||||
self.intervention_flag = True
|
||||
else:
|
||||
self.intervention_flag = False
|
||||
|
||||
def get_deltas(self):
|
||||
"""Get the current movement deltas from gamepad state."""
|
||||
import pygame
|
||||
|
||||
try:
|
||||
# Read joystick axes
|
||||
# Left stick X and Y (typically axes 0 and 1)
|
||||
y_input = self.joystick.get_axis(0) # Up/Down (often inverted)
|
||||
x_input = self.joystick.get_axis(1) # Left/Right
|
||||
x_input = self.joystick.get_axis(self.AXIS_LEFT_X)
|
||||
y_input = self.joystick.get_axis(self.AXIS_LEFT_Y)
|
||||
z_input = self.joystick.get_axis(self.AXIS_RIGHT_Y)
|
||||
|
||||
# Right stick Y (typically axis 3 or 4)
|
||||
z_input = self.joystick.get_axis(3) # Up/Down for Z
|
||||
|
||||
# Apply deadzone to avoid drift
|
||||
x_input = 0 if abs(x_input) < self.deadzone else x_input
|
||||
y_input = 0 if abs(y_input) < self.deadzone else y_input
|
||||
z_input = 0 if abs(z_input) < self.deadzone else z_input
|
||||
|
||||
# Calculate deltas (note: may need to invert axes depending on controller)
|
||||
delta_x = -x_input * self.x_step_size # Forward/backward
|
||||
delta_y = -y_input * self.y_step_size # Left/right
|
||||
delta_z = -z_input * self.z_step_size # Up/down
|
||||
if self._invert_left_x:
|
||||
x_input = -x_input
|
||||
if self._invert_left_y:
|
||||
y_input = -y_input
|
||||
if self._invert_right_y:
|
||||
z_input = -z_input
|
||||
|
||||
delta_x = y_input * self.y_step_size
|
||||
delta_y = x_input * self.x_step_size
|
||||
delta_z = z_input * self.z_step_size
|
||||
|
||||
return delta_x, delta_y, delta_z
|
||||
|
||||
@@ -309,7 +341,15 @@ class GamepadController(InputController):
|
||||
|
||||
|
||||
class GamepadControllerHID(InputController):
|
||||
"""Generate motion deltas from gamepad input using HIDAPI."""
|
||||
"""Generate motion deltas from gamepad input using HIDAPI.
|
||||
|
||||
Supports auto-detection of controller type for correct HID report parsing.
|
||||
Currently supported: Logitech RumblePad 2, 8BitDo Ultimate 2C Wireless.
|
||||
"""
|
||||
|
||||
CONTROLLER_LOGITECH = "logitech"
|
||||
CONTROLLER_8BITDO = "8bitdo"
|
||||
CONTROLLER_UNKNOWN = "unknown"
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
@@ -318,36 +358,26 @@ class GamepadControllerHID(InputController):
|
||||
z_step_size=1.0,
|
||||
deadzone=0.1,
|
||||
):
|
||||
"""
|
||||
Initialize the HID gamepad controller.
|
||||
|
||||
Args:
|
||||
step_size: Base movement step size in meters
|
||||
z_scale: Scaling factor for Z-axis movement
|
||||
deadzone: Joystick deadzone to prevent drift
|
||||
"""
|
||||
super().__init__(x_step_size, y_step_size, z_step_size)
|
||||
self.deadzone = deadzone
|
||||
self.device = None
|
||||
self.device_info = None
|
||||
self.controller_type = self.CONTROLLER_UNKNOWN
|
||||
|
||||
# Movement values (normalized from -1.0 to 1.0)
|
||||
self.left_x = 0.0
|
||||
self.left_y = 0.0
|
||||
self.right_x = 0.0
|
||||
self.right_y = 0.0
|
||||
|
||||
# Button states
|
||||
self.buttons = {}
|
||||
|
||||
def find_device(self):
|
||||
"""Look for the gamepad device by vendor and product ID."""
|
||||
import hid
|
||||
|
||||
devices = hid.enumerate()
|
||||
for device in devices:
|
||||
device_name = device["product_string"]
|
||||
if any(controller in device_name for controller in ["Logitech", "Xbox", "PS4", "PS5"]):
|
||||
if any(controller in device_name for controller in ["Logitech", "Xbox", "PS4", "PS5", "8BitDo"]):
|
||||
return device
|
||||
|
||||
logging.error(
|
||||
@@ -355,8 +385,15 @@ class GamepadControllerHID(InputController):
|
||||
)
|
||||
return None
|
||||
|
||||
def _detect_controller_type(self, product_string):
|
||||
product = product_string.lower() if product_string else ""
|
||||
if "8bitdo" in product:
|
||||
return self.CONTROLLER_8BITDO
|
||||
elif "logitech" in product:
|
||||
return self.CONTROLLER_LOGITECH
|
||||
return self.CONTROLLER_UNKNOWN
|
||||
|
||||
def start(self):
|
||||
"""Connect to the gamepad using HIDAPI."""
|
||||
import hid
|
||||
|
||||
self.device_info = self.find_device()
|
||||
@@ -374,12 +411,22 @@ class GamepadControllerHID(InputController):
|
||||
product = self.device.get_product_string()
|
||||
logging.info(f"Connected to {manufacturer} {product}")
|
||||
|
||||
logging.info("Gamepad controls (HID mode):")
|
||||
logging.info(" Left analog stick: Move in X-Y plane")
|
||||
logging.info(" Right analog stick: Move in Z axis (vertical)")
|
||||
logging.info(" Button 1/B/Circle: Exit")
|
||||
logging.info(" Button 2/A/Cross: End episode with SUCCESS")
|
||||
logging.info(" Button 3/X/Square: End episode with FAILURE")
|
||||
self.controller_type = self._detect_controller_type(product)
|
||||
logging.info(f"Detected controller type: {self.controller_type}")
|
||||
|
||||
print("Gamepad controls (HID mode):")
|
||||
print(" Left analog stick: Move in X-Y plane")
|
||||
print(" Right analog stick: Move in Z axis (vertical)")
|
||||
print(" RB: Intervention toggle")
|
||||
if self.controller_type == self.CONTROLLER_8BITDO:
|
||||
print(" L3 (left stick click): Close gripper")
|
||||
print(" R3 (right stick click): Open gripper")
|
||||
else:
|
||||
print(" LT: Close gripper")
|
||||
print(" RT: Open gripper")
|
||||
print(" Y: End episode with SUCCESS")
|
||||
print(" X: End episode with FAILURE")
|
||||
print(" A: Rerecord episode")
|
||||
|
||||
except OSError as e:
|
||||
logging.error(f"Error opening gamepad: {e}")
|
||||
@@ -387,74 +434,124 @@ class GamepadControllerHID(InputController):
|
||||
self.running = False
|
||||
|
||||
def stop(self):
|
||||
"""Close the HID device connection."""
|
||||
if self.device:
|
||||
self.device.close()
|
||||
self.device = None
|
||||
|
||||
def update(self):
|
||||
"""
|
||||
Read and process the latest gamepad data.
|
||||
Due to an issue with the HIDAPI, we need to read the read the device several times in order to get a stable reading
|
||||
"""
|
||||
"""Read the device several times to drain the HID buffer and get a stable reading."""
|
||||
for _ in range(10):
|
||||
self._update()
|
||||
|
||||
def _update(self):
|
||||
"""Read and process the latest gamepad data."""
|
||||
if not self.device or not self.running:
|
||||
return
|
||||
|
||||
try:
|
||||
# Read data from the gamepad
|
||||
data = self.device.read(64)
|
||||
# Interpret gamepad data - this will vary by controller model
|
||||
# These offsets are for the Logitech RumblePad 2
|
||||
if data and len(data) >= 8:
|
||||
# Normalize joystick values from 0-255 to -1.0-1.0
|
||||
self.left_y = (data[1] - 128) / 128.0
|
||||
self.left_x = (data[2] - 128) / 128.0
|
||||
self.right_x = (data[3] - 128) / 128.0
|
||||
self.right_y = (data[4] - 128) / 128.0
|
||||
if not data:
|
||||
return
|
||||
|
||||
# Apply deadzone
|
||||
self.left_y = 0 if abs(self.left_y) < self.deadzone else self.left_y
|
||||
self.left_x = 0 if abs(self.left_x) < self.deadzone else self.left_x
|
||||
self.right_x = 0 if abs(self.right_x) < self.deadzone else self.right_x
|
||||
self.right_y = 0 if abs(self.right_y) < self.deadzone else self.right_y
|
||||
|
||||
# Parse button states (byte 5 in the Logitech RumblePad 2)
|
||||
buttons = data[5]
|
||||
|
||||
# Check if RB is pressed then the intervention flag should be set
|
||||
self.intervention_flag = data[6] in [2, 6, 10, 14]
|
||||
|
||||
# Check if RT is pressed
|
||||
self.open_gripper_command = data[6] in [8, 10, 12]
|
||||
|
||||
# Check if LT is pressed
|
||||
self.close_gripper_command = data[6] in [4, 6, 12]
|
||||
|
||||
# Check if Y/Triangle button (bit 7) is pressed for saving
|
||||
# Check if X/Square button (bit 5) is pressed for failure
|
||||
# Check if A/Cross button (bit 4) is pressed for rerecording
|
||||
if buttons & 1 << 7:
|
||||
self.episode_end_status = TeleopEvents.SUCCESS
|
||||
elif buttons & 1 << 5:
|
||||
self.episode_end_status = TeleopEvents.FAILURE
|
||||
elif buttons & 1 << 4:
|
||||
self.episode_end_status = TeleopEvents.RERECORD_EPISODE
|
||||
else:
|
||||
self.episode_end_status = None
|
||||
if self.controller_type == self.CONTROLLER_8BITDO:
|
||||
self._parse_8bitdo(data)
|
||||
else:
|
||||
self._parse_logitech(data)
|
||||
|
||||
except OSError as e:
|
||||
logging.error(f"Error reading from gamepad: {e}")
|
||||
|
||||
def _apply_deadzone(self):
|
||||
self.left_x = 0 if abs(self.left_x) < self.deadzone else self.left_x
|
||||
self.left_y = 0 if abs(self.left_y) < self.deadzone else self.left_y
|
||||
self.right_x = 0 if abs(self.right_x) < self.deadzone else self.right_x
|
||||
self.right_y = 0 if abs(self.right_y) < self.deadzone else self.right_y
|
||||
|
||||
def _parse_8bitdo(self, data):
|
||||
"""Parse HID report from 8BitDo Ultimate 2C Wireless (Bluetooth on macOS).
|
||||
|
||||
11-byte report layout:
|
||||
byte[0]: Report ID (0x01)
|
||||
byte[1]: D-pad hat switch (0=N, 2=E, 5=S, 6=W, 15=neutral)
|
||||
byte[2]: Left Stick X (0=left, 127=center, 255=right)
|
||||
byte[3]: Left Stick Y (0=up, 127=center, 255=down)
|
||||
byte[4]: Right Stick X (inverted: 255=left, 0=right)
|
||||
byte[5]: Right Stick Y (0=up, 127=center, 255=down)
|
||||
byte[6]: RT analog trigger (0-255)
|
||||
byte[7]: LT analog trigger (0-255)
|
||||
byte[8]: Buttons -- bit0=A, bit1=B, bit3=X, bit4=Y, bit6=LB, bit7=RB
|
||||
byte[9]: System -- bit0=LT(digital), bit1=RT(digital), bit3=Select,
|
||||
bit4=Start, bit5=L3, bit6=R3
|
||||
byte[10]: Unused
|
||||
"""
|
||||
if len(data) < 11:
|
||||
return
|
||||
|
||||
self.left_x = (data[2] - 127) / 128.0
|
||||
self.left_y = (data[3] - 127) / 128.0
|
||||
self.right_x = -(data[4] - 127) / 128.0
|
||||
self.right_y = (data[5] - 127) / 128.0
|
||||
|
||||
self._apply_deadzone()
|
||||
|
||||
buttons = data[8]
|
||||
|
||||
# RB (bit 7) = intervention
|
||||
self.intervention_flag = bool(buttons & 0x80)
|
||||
|
||||
# Stick clicks for gripper: R3 (byte[9] bit6) = open, L3 (byte[9] bit5) = close
|
||||
system = data[9]
|
||||
self.open_gripper_command = bool(system & 0x40) # R3
|
||||
self.close_gripper_command = bool(system & 0x20) # L3
|
||||
|
||||
# Y (bit 4) = success, X (bit 3) = failure, A (bit 0) = rerecord
|
||||
if buttons & 0x10:
|
||||
self.episode_end_status = TeleopEvents.SUCCESS
|
||||
elif buttons & 0x08:
|
||||
self.episode_end_status = TeleopEvents.FAILURE
|
||||
elif buttons & 0x01:
|
||||
self.episode_end_status = TeleopEvents.RERECORD_EPISODE
|
||||
else:
|
||||
self.episode_end_status = None
|
||||
|
||||
def _parse_logitech(self, data):
|
||||
"""Parse HID report from Logitech RumblePad 2 (and similar Logitech gamepads).
|
||||
|
||||
Report layout (8+ bytes):
|
||||
byte[1]: Left Stick X (0-255, center=128)
|
||||
byte[2]: Left Stick Y (0-255, center=128)
|
||||
byte[3]: Right Stick X (0-255, center=128)
|
||||
byte[4]: Right Stick Y (0-255, center=128)
|
||||
byte[5]: Face buttons bitmask
|
||||
byte[6]: Shoulder/trigger buttons bitmask
|
||||
"""
|
||||
if len(data) < 8:
|
||||
return
|
||||
|
||||
self.left_x = (data[1] - 128) / 128.0
|
||||
self.left_y = (data[2] - 128) / 128.0
|
||||
self.right_x = (data[3] - 128) / 128.0
|
||||
self.right_y = (data[4] - 128) / 128.0
|
||||
|
||||
self._apply_deadzone()
|
||||
|
||||
buttons = data[5]
|
||||
|
||||
self.intervention_flag = data[6] in [2, 6, 10, 14]
|
||||
self.open_gripper_command = data[6] in [8, 10, 12]
|
||||
self.close_gripper_command = data[6] in [4, 6, 12]
|
||||
|
||||
if buttons & 1 << 7:
|
||||
self.episode_end_status = TeleopEvents.SUCCESS
|
||||
elif buttons & 1 << 5:
|
||||
self.episode_end_status = TeleopEvents.FAILURE
|
||||
elif buttons & 1 << 4:
|
||||
self.episode_end_status = TeleopEvents.RERECORD_EPISODE
|
||||
else:
|
||||
self.episode_end_status = None
|
||||
|
||||
def get_deltas(self):
|
||||
"""Get the current movement deltas from gamepad state."""
|
||||
# Calculate deltas - invert as needed based on controller orientation
|
||||
delta_x = -self.left_x * self.x_step_size # Forward/backward
|
||||
delta_y = -self.left_y * self.y_step_size # Left/right
|
||||
delta_z = -self.right_y * self.z_step_size # Up/down
|
||||
delta_x = -self.left_y * self.x_step_size
|
||||
delta_y = -self.left_x * self.y_step_size
|
||||
delta_z = -self.right_y * self.z_step_size
|
||||
|
||||
return delta_x, delta_y, delta_z
|
||||
|
||||
Reference in New Issue
Block a user