mirror of
https://github.com/huggingface/lerobot.git
synced 2026-05-24 13:09:43 +00:00
add cmd arg
This commit is contained in:
@@ -20,6 +20,13 @@ Controls:
|
|||||||
|
|
||||||
Usage:
|
Usage:
|
||||||
python examples/rac/rac_data_collection_openarms_rtc.py \
|
python examples/rac/rac_data_collection_openarms_rtc.py \
|
||||||
|
--robot.type=openarms_follower \
|
||||||
|
--robot.port_right=can0 \
|
||||||
|
--robot.port_left=can1 \
|
||||||
|
--robot.cameras="{ left_wrist: {type: opencv, index_or_path: 0}, right_wrist: {type: opencv, index_or_path: 2}}" \
|
||||||
|
--teleop.type=openarms_mini \
|
||||||
|
--teleop.port_right=/dev/ttyUSB0 \
|
||||||
|
--teleop.port_left=/dev/ttyUSB1 \
|
||||||
--policy.path=outputs/train/my_policy/checkpoints/last/pretrained_model \
|
--policy.path=outputs/train/my_policy/checkpoints/last/pretrained_model \
|
||||||
--dataset.repo_id=my_user/rac_rtc_dataset \
|
--dataset.repo_id=my_user/rac_rtc_dataset \
|
||||||
--dataset.single_task="Pick up the cube"
|
--dataset.single_task="Pick up the cube"
|
||||||
@@ -37,7 +44,8 @@ from threading import Event, Lock, Thread
|
|||||||
import torch
|
import torch
|
||||||
from torch import Tensor
|
from torch import Tensor
|
||||||
|
|
||||||
from lerobot.cameras.opencv.configuration_opencv import OpenCVCameraConfig
|
from lerobot.cameras.opencv.configuration_opencv import OpenCVCameraConfig # noqa: F401
|
||||||
|
from lerobot.cameras.realsense.configuration_realsense import RealSenseCameraConfig # noqa: F401
|
||||||
from lerobot.configs import parser
|
from lerobot.configs import parser
|
||||||
from lerobot.configs.policies import PreTrainedConfig
|
from lerobot.configs.policies import PreTrainedConfig
|
||||||
from lerobot.configs.types import RTCAttentionSchedule
|
from lerobot.configs.types import RTCAttentionSchedule
|
||||||
@@ -49,10 +57,10 @@ from lerobot.policies.rtc.action_queue import ActionQueue
|
|||||||
from lerobot.policies.rtc.configuration_rtc import RTCConfig
|
from lerobot.policies.rtc.configuration_rtc import RTCConfig
|
||||||
from lerobot.policies.rtc.latency_tracker import LatencyTracker
|
from lerobot.policies.rtc.latency_tracker import LatencyTracker
|
||||||
from lerobot.processor import make_default_processors
|
from lerobot.processor import make_default_processors
|
||||||
from lerobot.robots.openarms.config_openarms_follower import OpenArmsFollowerConfig
|
from lerobot.robots import Robot, RobotConfig, make_robot_from_config
|
||||||
from lerobot.robots.openarms.openarms_follower import OpenArmsFollower
|
from lerobot.robots.openarms.config_openarms_follower import OpenArmsFollowerConfig # noqa: F401
|
||||||
from lerobot.teleoperators import make_teleoperator_from_config
|
from lerobot.teleoperators import TeleoperatorConfig, make_teleoperator_from_config
|
||||||
from lerobot.teleoperators.openarms_mini.config_openarms_mini import OpenArmsMiniConfig
|
from lerobot.teleoperators.openarms_mini.config_openarms_mini import OpenArmsMiniConfig # noqa: F401
|
||||||
from lerobot.utils.control_utils import is_headless
|
from lerobot.utils.control_utils import is_headless
|
||||||
from lerobot.utils.hub import HubMixin
|
from lerobot.utils.hub import HubMixin
|
||||||
from lerobot.utils.robot_utils import precise_sleep
|
from lerobot.utils.robot_utils import precise_sleep
|
||||||
@@ -64,7 +72,7 @@ logger = logging.getLogger(__name__)
|
|||||||
class RobotWrapper:
|
class RobotWrapper:
|
||||||
"""Thread-safe wrapper for robot operations."""
|
"""Thread-safe wrapper for robot operations."""
|
||||||
|
|
||||||
def __init__(self, robot: OpenArmsFollower):
|
def __init__(self, robot: Robot):
|
||||||
self.robot = robot
|
self.robot = robot
|
||||||
self.lock = Lock()
|
self.lock = Lock()
|
||||||
|
|
||||||
@@ -88,48 +96,44 @@ class RobotWrapper:
|
|||||||
def name(self) -> str:
|
def name(self) -> str:
|
||||||
return self.robot.name
|
return self.robot.name
|
||||||
|
|
||||||
DEFAULT_FPS = 30
|
@dataclass
|
||||||
DEFAULT_EPISODE_TIME_SEC = 120
|
class RaCRTCDatasetConfig:
|
||||||
DEFAULT_NUM_EPISODES = 50
|
"""Dataset configuration for RaC + RTC."""
|
||||||
|
repo_id: str = "lerobot/rac_rtc_openarms"
|
||||||
|
single_task: str = "task"
|
||||||
|
root: str | Path | None = None
|
||||||
|
fps: int = 30
|
||||||
|
episode_time_s: float = 120
|
||||||
|
num_episodes: int = 50
|
||||||
|
video: bool = True
|
||||||
|
push_to_hub: bool = True
|
||||||
|
private: bool = True
|
||||||
|
num_image_writer_processes: int = 0
|
||||||
|
num_image_writer_threads_per_camera: int = 4
|
||||||
|
|
||||||
DEFAULT_CAMERA_CONFIG = {
|
|
||||||
"left_wrist": OpenCVCameraConfig(index_or_path="/dev/video5", width=640, height=480, fps=DEFAULT_FPS),
|
|
||||||
"right_wrist": OpenCVCameraConfig(index_or_path="/dev/video1", width=640, height=480, fps=DEFAULT_FPS),
|
|
||||||
"base": OpenCVCameraConfig(index_or_path="/dev/video3", width=640, height=480, fps=DEFAULT_FPS),
|
|
||||||
}
|
|
||||||
|
|
||||||
@dataclass
|
@dataclass
|
||||||
class RaCRTCConfig(HubMixin):
|
class RaCRTCConfig(HubMixin):
|
||||||
"""Configuration for RaC data collection with RTC."""
|
"""Configuration for RaC data collection with RTC."""
|
||||||
|
|
||||||
|
robot: RobotConfig
|
||||||
|
teleop: TeleoperatorConfig
|
||||||
|
dataset: RaCRTCDatasetConfig
|
||||||
policy: PreTrainedConfig | None = None
|
policy: PreTrainedConfig | None = None
|
||||||
|
|
||||||
rtc: RTCConfig = field(
|
rtc: RTCConfig = field(
|
||||||
default_factory=lambda: RTCConfig(
|
default_factory=lambda: RTCConfig(
|
||||||
enabled=True,
|
enabled=True,
|
||||||
execution_horizon=10,
|
execution_horizon=20,
|
||||||
max_guidance_weight=10.0,
|
max_guidance_weight=5.0,
|
||||||
prefix_attention_schedule=RTCAttentionSchedule.EXP,
|
prefix_attention_schedule=RTCAttentionSchedule.LINEAR,
|
||||||
)
|
)
|
||||||
)
|
)
|
||||||
|
|
||||||
dataset_repo_id: str = "lerobot/rac_rtc_openarms"
|
|
||||||
task: str = "task"
|
|
||||||
|
|
||||||
fps: float = DEFAULT_FPS
|
|
||||||
episode_time_sec: float = DEFAULT_EPISODE_TIME_SEC
|
|
||||||
num_episodes: int = DEFAULT_NUM_EPISODES
|
|
||||||
|
|
||||||
follower_left_port: str = "can0"
|
|
||||||
follower_right_port: str = "can1"
|
|
||||||
teleop_left_port: str = "/dev/ttyUSB1"
|
|
||||||
teleop_right_port: str = "/dev/ttyUSB0"
|
|
||||||
|
|
||||||
device: str = "cuda"
|
device: str = "cuda"
|
||||||
action_queue_size_to_get_new_actions: int = 30
|
action_queue_size_to_get_new_actions: int = 30
|
||||||
|
interpolation: bool = True
|
||||||
interpolation: bool = False
|
play_sounds: bool = True
|
||||||
push_to_hub: bool = True
|
|
||||||
|
|
||||||
def __post_init__(self):
|
def __post_init__(self):
|
||||||
policy_path = parser.get_path_arg("policy")
|
policy_path = parser.get_path_arg("policy")
|
||||||
@@ -257,13 +261,14 @@ def get_actions_thread(
|
|||||||
shutdown_event: Event,
|
shutdown_event: Event,
|
||||||
cfg: RaCRTCConfig,
|
cfg: RaCRTCConfig,
|
||||||
policy_active: Event,
|
policy_active: Event,
|
||||||
|
fps: int,
|
||||||
):
|
):
|
||||||
"""Thread for async action generation via RTC."""
|
"""Thread for async action generation via RTC."""
|
||||||
try:
|
try:
|
||||||
logger.info("[GET_ACTIONS] Starting RTC action generation thread")
|
logger.info("[GET_ACTIONS] Starting RTC action generation thread")
|
||||||
|
|
||||||
latency_tracker = LatencyTracker()
|
latency_tracker = LatencyTracker()
|
||||||
time_per_chunk = 1.0 / cfg.fps
|
time_per_chunk = 1.0 / fps
|
||||||
|
|
||||||
hw_features = hw_to_dataset_features(robot.observation_features, "observation")
|
hw_features = hw_to_dataset_features(robot.observation_features, "observation")
|
||||||
policy_device = policy.config.device
|
policy_device = policy.config.device
|
||||||
@@ -311,7 +316,7 @@ def get_actions_thread(
|
|||||||
obs_with_policy_features[name] = obs_with_policy_features[name].unsqueeze(0)
|
obs_with_policy_features[name] = obs_with_policy_features[name].unsqueeze(0)
|
||||||
obs_with_policy_features[name] = obs_with_policy_features[name].to(policy_device)
|
obs_with_policy_features[name] = obs_with_policy_features[name].to(policy_device)
|
||||||
|
|
||||||
obs_with_policy_features["task"] = [cfg.task]
|
obs_with_policy_features["task"] = [cfg.dataset.single_task]
|
||||||
obs_with_policy_features["robot_type"] = robot.name
|
obs_with_policy_features["robot_type"] = robot.name
|
||||||
|
|
||||||
preprocessed_obs = preprocessor(obs_with_policy_features)
|
preprocessed_obs = preprocessor(obs_with_policy_features)
|
||||||
@@ -363,14 +368,16 @@ def main(cfg: RaCRTCConfig):
|
|||||||
"""Main RaC + RTC data collection."""
|
"""Main RaC + RTC data collection."""
|
||||||
init_logging()
|
init_logging()
|
||||||
|
|
||||||
|
fps = cfg.dataset.fps
|
||||||
|
|
||||||
print("=" * 65)
|
print("=" * 65)
|
||||||
print(" RaC Data Collection with RTC - OpenArms")
|
print(" RaC Data Collection with RTC - OpenArms")
|
||||||
print("=" * 65)
|
print("=" * 65)
|
||||||
print(f" Policy: {cfg.policy.pretrained_path}")
|
print(f" Policy: {cfg.policy.pretrained_path}")
|
||||||
print(f" Dataset: {cfg.dataset_repo_id}")
|
print(f" Dataset: {cfg.dataset.repo_id}")
|
||||||
print(f" Task: {cfg.task}")
|
print(f" Task: {cfg.dataset.single_task}")
|
||||||
print(f" Policy Hz: {cfg.fps}")
|
print(f" Policy Hz: {fps}")
|
||||||
print(f" Robot Hz: {cfg.fps * 2 if cfg.interpolation else cfg.fps}")
|
print(f" Robot Hz: {fps * 2 if cfg.interpolation else fps}")
|
||||||
print(f" Interpolation: {cfg.interpolation}")
|
print(f" Interpolation: {cfg.interpolation}")
|
||||||
print(f" RTC Enabled: {cfg.rtc.enabled}")
|
print(f" RTC Enabled: {cfg.rtc.enabled}")
|
||||||
print("=" * 65)
|
print("=" * 65)
|
||||||
@@ -389,26 +396,12 @@ def main(cfg: RaCRTCConfig):
|
|||||||
shutdown_event = Event()
|
shutdown_event = Event()
|
||||||
policy_active = Event()
|
policy_active = Event()
|
||||||
|
|
||||||
follower_config = OpenArmsFollowerConfig(
|
follower = make_robot_from_config(cfg.robot)
|
||||||
port_left=cfg.follower_left_port,
|
follower.connect()
|
||||||
port_right=cfg.follower_right_port,
|
|
||||||
can_interface="socketcan",
|
|
||||||
id="openarms_follower",
|
|
||||||
disable_torque_on_disconnect=True,
|
|
||||||
max_relative_target=10.0,
|
|
||||||
cameras=DEFAULT_CAMERA_CONFIG,
|
|
||||||
)
|
|
||||||
|
|
||||||
follower = OpenArmsFollower(follower_config)
|
|
||||||
follower.connect(calibrate=False)
|
|
||||||
robot = RobotWrapper(follower)
|
robot = RobotWrapper(follower)
|
||||||
logger.info("Robot connected")
|
logger.info("Robot connected")
|
||||||
|
|
||||||
teleop_config = OpenArmsMiniConfig(
|
teleop = make_teleoperator_from_config(cfg.teleop)
|
||||||
port_right=cfg.teleop_right_port,
|
|
||||||
port_left=cfg.teleop_left_port,
|
|
||||||
)
|
|
||||||
teleop = make_teleoperator_from_config(teleop_config)
|
|
||||||
teleop.connect()
|
teleop.connect()
|
||||||
teleop.disable_torque()
|
teleop.disable_torque()
|
||||||
logger.info("Teleop connected")
|
logger.info("Teleop connected")
|
||||||
@@ -420,23 +413,25 @@ def main(cfg: RaCRTCConfig):
|
|||||||
aggregate_pipeline_dataset_features(
|
aggregate_pipeline_dataset_features(
|
||||||
pipeline=teleop_proc,
|
pipeline=teleop_proc,
|
||||||
initial_features=create_initial_features(action=action_features_hw),
|
initial_features=create_initial_features(action=action_features_hw),
|
||||||
use_videos=True,
|
use_videos=cfg.dataset.video,
|
||||||
),
|
),
|
||||||
aggregate_pipeline_dataset_features(
|
aggregate_pipeline_dataset_features(
|
||||||
pipeline=obs_proc,
|
pipeline=obs_proc,
|
||||||
initial_features=create_initial_features(observation=follower.observation_features),
|
initial_features=create_initial_features(observation=follower.observation_features),
|
||||||
use_videos=True,
|
use_videos=cfg.dataset.video,
|
||||||
),
|
),
|
||||||
)
|
)
|
||||||
|
|
||||||
dataset = LeRobotDataset.create(
|
dataset = LeRobotDataset.create(
|
||||||
repo_id=cfg.dataset_repo_id,
|
repo_id=cfg.dataset.repo_id,
|
||||||
fps=int(cfg.fps),
|
fps=fps,
|
||||||
|
root=cfg.dataset.root,
|
||||||
features=dataset_features,
|
features=dataset_features,
|
||||||
robot_type=follower.name,
|
robot_type=follower.name,
|
||||||
use_videos=True,
|
use_videos=cfg.dataset.video,
|
||||||
image_writer_processes=0,
|
image_writer_processes=cfg.dataset.num_image_writer_processes,
|
||||||
image_writer_threads=12,
|
image_writer_threads=cfg.dataset.num_image_writer_threads_per_camera
|
||||||
|
* len(follower.cameras if hasattr(follower, "cameras") else []),
|
||||||
)
|
)
|
||||||
dataset_lock = Lock()
|
dataset_lock = Lock()
|
||||||
|
|
||||||
@@ -453,7 +448,7 @@ def main(cfg: RaCRTCConfig):
|
|||||||
|
|
||||||
get_actions_t = Thread(
|
get_actions_t = Thread(
|
||||||
target=get_actions_thread,
|
target=get_actions_thread,
|
||||||
args=(policy, robot, obs_proc, action_queue, shutdown_event, cfg, policy_active),
|
args=(policy, robot, obs_proc, action_queue, shutdown_event, cfg, policy_active, fps),
|
||||||
daemon=True,
|
daemon=True,
|
||||||
name="GetActions",
|
name="GetActions",
|
||||||
)
|
)
|
||||||
@@ -472,15 +467,15 @@ def main(cfg: RaCRTCConfig):
|
|||||||
|
|
||||||
if cfg.interpolation:
|
if cfg.interpolation:
|
||||||
interp_factor = 2
|
interp_factor = 2
|
||||||
robot_interval = 1.0 / (cfg.fps * interp_factor)
|
robot_interval = 1.0 / (fps * interp_factor)
|
||||||
else:
|
else:
|
||||||
interp_factor = 1
|
interp_factor = 1
|
||||||
robot_interval = 1.0 / cfg.fps
|
robot_interval = 1.0 / fps
|
||||||
|
|
||||||
try:
|
try:
|
||||||
recorded = 0
|
recorded = 0
|
||||||
while recorded < cfg.num_episodes and not events["stop_recording"]:
|
while recorded < cfg.dataset.num_episodes and not events["stop_recording"]:
|
||||||
log_say(f"RaC episode {recorded + 1}", play_sounds=True)
|
log_say(f"RaC episode {recorded + 1}", play_sounds=cfg.play_sounds)
|
||||||
|
|
||||||
move_robot_to_zero(follower, duration_s=2.0)
|
move_robot_to_zero(follower, duration_s=2.0)
|
||||||
|
|
||||||
@@ -502,7 +497,7 @@ def main(cfg: RaCRTCConfig):
|
|||||||
policy_active.set()
|
policy_active.set()
|
||||||
episode_start = time.perf_counter()
|
episode_start = time.perf_counter()
|
||||||
|
|
||||||
while (time.perf_counter() - episode_start) < cfg.episode_time_sec:
|
while (time.perf_counter() - episode_start) < cfg.dataset.episode_time_s:
|
||||||
loop_start = time.perf_counter()
|
loop_start = time.perf_counter()
|
||||||
|
|
||||||
if events["exit_early"]:
|
if events["exit_early"]:
|
||||||
@@ -531,7 +526,7 @@ def main(cfg: RaCRTCConfig):
|
|||||||
robot.send_action(robot_action)
|
robot.send_action(robot_action)
|
||||||
|
|
||||||
action_frame = build_dataset_frame(dataset_features, robot_action, prefix="action")
|
action_frame = build_dataset_frame(dataset_features, robot_action, prefix="action")
|
||||||
frame = {**obs_frame, **action_frame, "task": cfg.task}
|
frame = {**obs_frame, **action_frame, "task": cfg.dataset.single_task}
|
||||||
frame_buffer.append(frame)
|
frame_buffer.append(frame)
|
||||||
|
|
||||||
elif events["policy_paused"]:
|
elif events["policy_paused"]:
|
||||||
@@ -567,7 +562,7 @@ def main(cfg: RaCRTCConfig):
|
|||||||
robot_send_count += 1
|
robot_send_count += 1
|
||||||
|
|
||||||
action_frame = build_dataset_frame(dataset_features, action_dict, prefix="action")
|
action_frame = build_dataset_frame(dataset_features, action_dict, prefix="action")
|
||||||
frame = {**obs_frame, **action_frame, "task": cfg.task}
|
frame = {**obs_frame, **action_frame, "task": cfg.dataset.single_task}
|
||||||
frame_buffer.append(frame)
|
frame_buffer.append(frame)
|
||||||
|
|
||||||
now = time.perf_counter()
|
now = time.perf_counter()
|
||||||
@@ -588,7 +583,7 @@ def main(cfg: RaCRTCConfig):
|
|||||||
policy_active.clear()
|
policy_active.clear()
|
||||||
|
|
||||||
if events["rerecord_episode"]:
|
if events["rerecord_episode"]:
|
||||||
log_say("Re-recording", play_sounds=True)
|
log_say("Re-recording", play_sounds=cfg.play_sounds)
|
||||||
events["rerecord_episode"] = False
|
events["rerecord_episode"] = False
|
||||||
continue
|
continue
|
||||||
|
|
||||||
@@ -627,12 +622,12 @@ def main(cfg: RaCRTCConfig):
|
|||||||
if "gripper" in key:
|
if "gripper" in key:
|
||||||
action[key] = -0.65 * action[key]
|
action[key] = -0.65 * action[key]
|
||||||
robot.send_action(action)
|
robot.send_action(action)
|
||||||
precise_sleep(1 / cfg.fps)
|
precise_sleep(1 / fps)
|
||||||
|
|
||||||
events["in_reset"] = False
|
events["in_reset"] = False
|
||||||
events["start_next_episode"] = False
|
events["start_next_episode"] = False
|
||||||
|
|
||||||
log_say("Recording complete", play_sounds=True, blocking=True)
|
log_say("Recording complete", play_sounds=cfg.play_sounds, blocking=True)
|
||||||
|
|
||||||
except KeyboardInterrupt:
|
except KeyboardInterrupt:
|
||||||
logger.info("Interrupted by user")
|
logger.info("Interrupted by user")
|
||||||
@@ -651,9 +646,9 @@ def main(cfg: RaCRTCConfig):
|
|||||||
listener.stop()
|
listener.stop()
|
||||||
|
|
||||||
dataset.finalize()
|
dataset.finalize()
|
||||||
if cfg.push_to_hub:
|
if cfg.dataset.push_to_hub:
|
||||||
logger.info("Pushing to hub...")
|
logger.info("Pushing to hub...")
|
||||||
dataset.push_to_hub(private=True)
|
dataset.push_to_hub(private=cfg.dataset.private)
|
||||||
|
|
||||||
logger.info("Done")
|
logger.info("Done")
|
||||||
|
|
||||||
|
|||||||
Reference in New Issue
Block a user