mirror of
https://github.com/huggingface/lerobot.git
synced 2026-05-16 17:20:05 +00:00
Add RaC with RTC
This commit is contained in:
@@ -0,0 +1,663 @@
|
||||
#!/usr/bin/env python
|
||||
"""
|
||||
RaC (Recovery and Correction) Data Collection with RTC for OpenArms Robot.
|
||||
|
||||
Combines RaC paradigm with Real-Time Chunking (RTC) for smooth policy execution.
|
||||
RTC enables large flow-matching policies (Pi0, Pi0.5, SmolVLA) to produce reactive
|
||||
motion despite high inference latency by asynchronously generating action chunks.
|
||||
|
||||
The workflow:
|
||||
1. Policy runs via RTC (async action generation) - teleop is idle
|
||||
2. Press SPACE/right pedal to pause - teleop moves to match robot position
|
||||
3. Press 'c'/left pedal to take control - human provides RECOVERY + CORRECTION
|
||||
4. Press →/right pedal to end episode (save and continue to next)
|
||||
|
||||
Controls:
|
||||
SPACE/Right pedal - Pause policy
|
||||
c/Left pedal - Take control (start correction)
|
||||
→/Right pedal - End episode (when in correction mode)
|
||||
ESC - Stop recording and push to hub
|
||||
|
||||
Usage:
|
||||
python examples/rac/rac_data_collection_openarms_rtc.py \
|
||||
--policy.path=outputs/train/my_policy/checkpoints/last/pretrained_model \
|
||||
--dataset.repo_id=my_user/rac_rtc_dataset \
|
||||
--dataset.single_task="Pick up the cube"
|
||||
"""
|
||||
|
||||
import logging
|
||||
import math
|
||||
import sys
|
||||
import time
|
||||
import traceback
|
||||
from dataclasses import dataclass, field
|
||||
from pathlib import Path
|
||||
from threading import Event, Lock, Thread
|
||||
|
||||
import torch
|
||||
from torch import Tensor
|
||||
|
||||
from lerobot.cameras.opencv.configuration_opencv import OpenCVCameraConfig
|
||||
from lerobot.configs import parser
|
||||
from lerobot.configs.policies import PreTrainedConfig
|
||||
from lerobot.configs.types import RTCAttentionSchedule
|
||||
from lerobot.datasets.lerobot_dataset import LeRobotDataset
|
||||
from lerobot.datasets.pipeline_features import aggregate_pipeline_dataset_features, create_initial_features
|
||||
from lerobot.datasets.utils import build_dataset_frame, combine_feature_dicts, hw_to_dataset_features
|
||||
from lerobot.policies.factory import get_policy_class, make_pre_post_processors
|
||||
from lerobot.policies.rtc.action_queue import ActionQueue
|
||||
from lerobot.policies.rtc.configuration_rtc import RTCConfig
|
||||
from lerobot.policies.rtc.latency_tracker import LatencyTracker
|
||||
from lerobot.processor import make_default_processors
|
||||
from lerobot.robots.openarms.config_openarms_follower import OpenArmsFollowerConfig
|
||||
from lerobot.robots.openarms.openarms_follower import OpenArmsFollower
|
||||
from lerobot.teleoperators import make_teleoperator_from_config
|
||||
from lerobot.teleoperators.openarms_mini.config_openarms_mini import OpenArmsMiniConfig
|
||||
from lerobot.utils.control_utils import is_headless
|
||||
from lerobot.utils.hub import HubMixin
|
||||
from lerobot.utils.robot_utils import precise_sleep
|
||||
from lerobot.utils.utils import init_logging, log_say
|
||||
|
||||
logging.basicConfig(level=logging.INFO)
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
class RobotWrapper:
|
||||
"""Thread-safe wrapper for robot operations."""
|
||||
|
||||
def __init__(self, robot: OpenArmsFollower):
|
||||
self.robot = robot
|
||||
self.lock = Lock()
|
||||
|
||||
def get_observation(self) -> dict[str, Tensor]:
|
||||
with self.lock:
|
||||
return self.robot.get_observation()
|
||||
|
||||
def send_action(self, action: dict) -> None:
|
||||
with self.lock:
|
||||
self.robot.send_action(action)
|
||||
|
||||
@property
|
||||
def observation_features(self) -> dict:
|
||||
return self.robot.observation_features
|
||||
|
||||
@property
|
||||
def action_features(self) -> dict:
|
||||
return self.robot.action_features
|
||||
|
||||
@property
|
||||
def name(self) -> str:
|
||||
return self.robot.name
|
||||
|
||||
DEFAULT_FPS = 30
|
||||
DEFAULT_EPISODE_TIME_SEC = 120
|
||||
DEFAULT_NUM_EPISODES = 50
|
||||
|
||||
DEFAULT_CAMERA_CONFIG = {
|
||||
"left_wrist": OpenCVCameraConfig(index_or_path="/dev/video5", width=640, height=480, fps=DEFAULT_FPS),
|
||||
"right_wrist": OpenCVCameraConfig(index_or_path="/dev/video1", width=640, height=480, fps=DEFAULT_FPS),
|
||||
"base": OpenCVCameraConfig(index_or_path="/dev/video3", width=640, height=480, fps=DEFAULT_FPS),
|
||||
}
|
||||
|
||||
@dataclass
|
||||
class RaCRTCConfig(HubMixin):
|
||||
"""Configuration for RaC data collection with RTC."""
|
||||
|
||||
policy: PreTrainedConfig | None = None
|
||||
|
||||
rtc: RTCConfig = field(
|
||||
default_factory=lambda: RTCConfig(
|
||||
enabled=True,
|
||||
execution_horizon=10,
|
||||
max_guidance_weight=10.0,
|
||||
prefix_attention_schedule=RTCAttentionSchedule.EXP,
|
||||
)
|
||||
)
|
||||
|
||||
dataset_repo_id: str = "lerobot/rac_rtc_openarms"
|
||||
task: str = "task"
|
||||
|
||||
fps: float = DEFAULT_FPS
|
||||
episode_time_sec: float = DEFAULT_EPISODE_TIME_SEC
|
||||
num_episodes: int = DEFAULT_NUM_EPISODES
|
||||
|
||||
follower_left_port: str = "can0"
|
||||
follower_right_port: str = "can1"
|
||||
teleop_left_port: str = "/dev/ttyUSB1"
|
||||
teleop_right_port: str = "/dev/ttyUSB0"
|
||||
|
||||
device: str = "cuda"
|
||||
action_queue_size_to_get_new_actions: int = 30
|
||||
|
||||
interpolation: bool = False
|
||||
push_to_hub: bool = True
|
||||
|
||||
def __post_init__(self):
|
||||
policy_path = parser.get_path_arg("policy")
|
||||
if policy_path:
|
||||
cli_overrides = parser.get_cli_overrides("policy")
|
||||
self.policy = PreTrainedConfig.from_pretrained(policy_path, cli_overrides=cli_overrides)
|
||||
self.policy.pretrained_path = policy_path
|
||||
if self.policy is None:
|
||||
raise ValueError("policy.path is required")
|
||||
|
||||
@classmethod
|
||||
def __get_path_fields__(cls) -> list[str]:
|
||||
return ["policy"]
|
||||
|
||||
|
||||
def init_keyboard_listener(events: dict):
|
||||
"""Initialize keyboard listener with RaC controls."""
|
||||
if is_headless():
|
||||
logger.warning("Headless environment - keyboard controls unavailable")
|
||||
return None
|
||||
|
||||
from pynput import keyboard
|
||||
|
||||
def on_press(key):
|
||||
try:
|
||||
if events["in_reset"]:
|
||||
if key == keyboard.Key.space or key == keyboard.Key.right:
|
||||
events["start_next_episode"] = True
|
||||
elif hasattr(key, "char") and key.char == "c":
|
||||
events["start_next_episode"] = True
|
||||
elif key == keyboard.Key.esc:
|
||||
events["stop_recording"] = True
|
||||
events["start_next_episode"] = True
|
||||
else:
|
||||
if key == keyboard.Key.space:
|
||||
if not events["policy_paused"] and not events["correction_active"]:
|
||||
print("\n[RaC] ⏸ PAUSED - Press 'c' to take control")
|
||||
events["policy_paused"] = True
|
||||
elif hasattr(key, "char") and key.char == "c":
|
||||
if events["policy_paused"] and not events["correction_active"]:
|
||||
print("\n[RaC] ▶ Taking control...")
|
||||
events["start_correction"] = True
|
||||
elif key == keyboard.Key.right:
|
||||
print("\n[RaC] → End episode")
|
||||
events["exit_early"] = True
|
||||
elif key == keyboard.Key.left:
|
||||
print("\n[RaC] ← Re-record episode")
|
||||
events["rerecord_episode"] = True
|
||||
events["exit_early"] = True
|
||||
elif key == keyboard.Key.esc:
|
||||
events["stop_recording"] = True
|
||||
events["exit_early"] = True
|
||||
except Exception as e:
|
||||
print(f"Key error: {e}")
|
||||
|
||||
listener = keyboard.Listener(on_press=on_press)
|
||||
listener.start()
|
||||
start_pedal_listener(events)
|
||||
return listener
|
||||
|
||||
|
||||
def start_pedal_listener(events: dict):
|
||||
"""Start foot pedal listener if available."""
|
||||
import threading
|
||||
|
||||
try:
|
||||
from evdev import InputDevice, ecodes
|
||||
except ImportError:
|
||||
return
|
||||
|
||||
PEDAL_DEVICE = "/dev/input/by-id/usb-PCsensor_FootSwitch-event-kbd"
|
||||
KEY_LEFT = "KEY_A"
|
||||
KEY_RIGHT = "KEY_C"
|
||||
|
||||
def pedal_reader():
|
||||
try:
|
||||
dev = InputDevice(PEDAL_DEVICE)
|
||||
print(f"[Pedal] Connected: {dev.name}")
|
||||
|
||||
for ev in dev.read_loop():
|
||||
if ev.type != ecodes.EV_KEY:
|
||||
continue
|
||||
|
||||
from evdev import categorize
|
||||
|
||||
key = categorize(ev)
|
||||
code = key.keycode
|
||||
if isinstance(code, (list, tuple)):
|
||||
code = code[0]
|
||||
|
||||
if key.keystate != 1:
|
||||
continue
|
||||
|
||||
if events["in_reset"]:
|
||||
if code in [KEY_LEFT, KEY_RIGHT]:
|
||||
events["start_next_episode"] = True
|
||||
else:
|
||||
if code == KEY_RIGHT:
|
||||
if events["correction_active"]:
|
||||
events["exit_early"] = True
|
||||
elif not events["policy_paused"]:
|
||||
print("\n[Pedal] ⏸ PAUSED")
|
||||
events["policy_paused"] = True
|
||||
elif code == KEY_LEFT:
|
||||
if events["policy_paused"] and not events["correction_active"]:
|
||||
print("\n[Pedal] ▶ Taking control...")
|
||||
events["start_correction"] = True
|
||||
|
||||
except FileNotFoundError:
|
||||
pass
|
||||
except PermissionError:
|
||||
logger.warning(f"[Pedal] Permission denied for {PEDAL_DEVICE}")
|
||||
except Exception:
|
||||
pass
|
||||
|
||||
thread = threading.Thread(target=pedal_reader, daemon=True)
|
||||
thread.start()
|
||||
|
||||
|
||||
def get_actions_thread(
|
||||
policy,
|
||||
robot: RobotWrapper,
|
||||
robot_observation_processor,
|
||||
action_queue: ActionQueue,
|
||||
shutdown_event: Event,
|
||||
cfg: RaCRTCConfig,
|
||||
policy_active: Event,
|
||||
):
|
||||
"""Thread for async action generation via RTC."""
|
||||
try:
|
||||
logger.info("[GET_ACTIONS] Starting RTC action generation thread")
|
||||
|
||||
latency_tracker = LatencyTracker()
|
||||
time_per_chunk = 1.0 / cfg.fps
|
||||
|
||||
hw_features = hw_to_dataset_features(robot.observation_features, "observation")
|
||||
policy_device = policy.config.device
|
||||
|
||||
preprocessor, postprocessor = make_pre_post_processors(
|
||||
policy_cfg=cfg.policy,
|
||||
pretrained_path=cfg.policy.pretrained_path,
|
||||
dataset_stats=None,
|
||||
preprocessor_overrides={"device_processor": {"device": cfg.device}},
|
||||
)
|
||||
|
||||
get_actions_threshold = cfg.action_queue_size_to_get_new_actions
|
||||
if not cfg.rtc.enabled:
|
||||
get_actions_threshold = 0
|
||||
|
||||
while not shutdown_event.is_set():
|
||||
if not policy_active.is_set():
|
||||
time.sleep(0.01)
|
||||
continue
|
||||
|
||||
if action_queue.qsize() <= get_actions_threshold:
|
||||
current_time = time.perf_counter()
|
||||
action_index_before_inference = action_queue.get_action_index()
|
||||
prev_actions = action_queue.get_left_over()
|
||||
|
||||
inference_latency = latency_tracker.max()
|
||||
inference_delay = math.ceil(inference_latency / time_per_chunk) if inference_latency else 0
|
||||
|
||||
obs = robot.get_observation()
|
||||
obs_processed = robot_observation_processor(obs)
|
||||
|
||||
obs_with_policy_features = build_dataset_frame(
|
||||
hw_features, obs_processed, prefix="observation"
|
||||
)
|
||||
|
||||
for name in obs_with_policy_features:
|
||||
obs_with_policy_features[name] = torch.from_numpy(obs_with_policy_features[name])
|
||||
if "image" in name:
|
||||
obs_with_policy_features[name] = (
|
||||
obs_with_policy_features[name].type(torch.float32) / 255
|
||||
)
|
||||
obs_with_policy_features[name] = (
|
||||
obs_with_policy_features[name].permute(2, 0, 1).contiguous()
|
||||
)
|
||||
obs_with_policy_features[name] = obs_with_policy_features[name].unsqueeze(0)
|
||||
obs_with_policy_features[name] = obs_with_policy_features[name].to(policy_device)
|
||||
|
||||
obs_with_policy_features["task"] = [cfg.task]
|
||||
obs_with_policy_features["robot_type"] = robot.name
|
||||
|
||||
preprocessed_obs = preprocessor(obs_with_policy_features)
|
||||
|
||||
actions = policy.predict_action_chunk(
|
||||
preprocessed_obs,
|
||||
inference_delay=inference_delay,
|
||||
prev_chunk_left_over=prev_actions,
|
||||
)
|
||||
|
||||
original_actions = actions.squeeze(0).clone()
|
||||
postprocessed_actions = postprocessor(actions).squeeze(0)
|
||||
|
||||
new_latency = time.perf_counter() - current_time
|
||||
new_delay = math.ceil(new_latency / time_per_chunk)
|
||||
latency_tracker.add(new_latency)
|
||||
|
||||
action_queue.merge(
|
||||
original_actions, postprocessed_actions, new_delay, action_index_before_inference
|
||||
)
|
||||
else:
|
||||
time.sleep(0.01)
|
||||
|
||||
logger.info("[GET_ACTIONS] Shutting down")
|
||||
except Exception as e:
|
||||
logger.error(f"[GET_ACTIONS] Fatal exception: {e}")
|
||||
logger.error(traceback.format_exc())
|
||||
shutdown_event.set()
|
||||
sys.exit(1)
|
||||
|
||||
|
||||
def move_robot_to_zero(robot, duration_s: float = 2.0, fps: int = 50):
|
||||
"""Smoothly move robot to zero position."""
|
||||
obs = robot.get_observation()
|
||||
current_pos = {k: v for k, v in obs.items() if k.endswith(".pos")}
|
||||
target_pos = {k: 0.0 for k in current_pos}
|
||||
|
||||
print(f"[RaC] Moving robot to zero ({duration_s}s)...")
|
||||
steps = int(duration_s * fps)
|
||||
for step in range(steps + 1):
|
||||
t = step / steps
|
||||
interp_pos = {k: current_pos[k] * (1 - t) + target_pos[k] * t for k in current_pos}
|
||||
robot.send_action(interp_pos)
|
||||
time.sleep(1 / fps)
|
||||
|
||||
|
||||
@parser.wrap()
|
||||
def main(cfg: RaCRTCConfig):
|
||||
"""Main RaC + RTC data collection."""
|
||||
init_logging()
|
||||
|
||||
print("=" * 65)
|
||||
print(" RaC Data Collection with RTC - OpenArms")
|
||||
print("=" * 65)
|
||||
print(f" Policy: {cfg.policy.pretrained_path}")
|
||||
print(f" Dataset: {cfg.dataset_repo_id}")
|
||||
print(f" Task: {cfg.task}")
|
||||
print(f" Policy Hz: {cfg.fps}")
|
||||
print(f" Robot Hz: {cfg.fps * 2 if cfg.interpolation else cfg.fps}")
|
||||
print(f" Interpolation: {cfg.interpolation}")
|
||||
print(f" RTC Enabled: {cfg.rtc.enabled}")
|
||||
print("=" * 65)
|
||||
|
||||
events = {
|
||||
"exit_early": False,
|
||||
"rerecord_episode": False,
|
||||
"stop_recording": False,
|
||||
"policy_paused": False,
|
||||
"correction_active": False,
|
||||
"start_correction": False,
|
||||
"in_reset": False,
|
||||
"start_next_episode": False,
|
||||
}
|
||||
|
||||
shutdown_event = Event()
|
||||
policy_active = Event()
|
||||
|
||||
follower_config = OpenArmsFollowerConfig(
|
||||
port_left=cfg.follower_left_port,
|
||||
port_right=cfg.follower_right_port,
|
||||
can_interface="socketcan",
|
||||
id="openarms_follower",
|
||||
disable_torque_on_disconnect=True,
|
||||
max_relative_target=10.0,
|
||||
cameras=DEFAULT_CAMERA_CONFIG,
|
||||
)
|
||||
|
||||
follower = OpenArmsFollower(follower_config)
|
||||
follower.connect(calibrate=False)
|
||||
robot = RobotWrapper(follower)
|
||||
logger.info("Robot connected")
|
||||
|
||||
teleop_config = OpenArmsMiniConfig(
|
||||
port_right=cfg.teleop_right_port,
|
||||
port_left=cfg.teleop_left_port,
|
||||
)
|
||||
teleop = make_teleoperator_from_config(teleop_config)
|
||||
teleop.connect()
|
||||
teleop.disable_torque()
|
||||
logger.info("Teleop connected")
|
||||
|
||||
teleop_proc, robot_proc, obs_proc = make_default_processors()
|
||||
action_features_hw = {k: v for k, v in follower.action_features.items() if k.endswith(".pos")}
|
||||
|
||||
dataset_features = combine_feature_dicts(
|
||||
aggregate_pipeline_dataset_features(
|
||||
pipeline=teleop_proc,
|
||||
initial_features=create_initial_features(action=action_features_hw),
|
||||
use_videos=True,
|
||||
),
|
||||
aggregate_pipeline_dataset_features(
|
||||
pipeline=obs_proc,
|
||||
initial_features=create_initial_features(observation=follower.observation_features),
|
||||
use_videos=True,
|
||||
),
|
||||
)
|
||||
|
||||
dataset = LeRobotDataset.create(
|
||||
repo_id=cfg.dataset_repo_id,
|
||||
fps=int(cfg.fps),
|
||||
features=dataset_features,
|
||||
robot_type=follower.name,
|
||||
use_videos=True,
|
||||
image_writer_processes=0,
|
||||
image_writer_threads=12,
|
||||
)
|
||||
dataset_lock = Lock()
|
||||
|
||||
policy_class = get_policy_class(cfg.policy.type)
|
||||
config = PreTrainedConfig.from_pretrained(cfg.policy.pretrained_path)
|
||||
policy = policy_class.from_pretrained(cfg.policy.pretrained_path, config=config)
|
||||
policy.config.rtc_config = cfg.rtc
|
||||
policy.init_rtc_processor()
|
||||
policy = policy.to(cfg.device)
|
||||
policy.eval()
|
||||
logger.info(f"Policy loaded: {policy.name}")
|
||||
|
||||
action_queue = ActionQueue(cfg.rtc)
|
||||
|
||||
get_actions_t = Thread(
|
||||
target=get_actions_thread,
|
||||
args=(policy, robot, obs_proc, action_queue, shutdown_event, cfg, policy_active),
|
||||
daemon=True,
|
||||
name="GetActions",
|
||||
)
|
||||
get_actions_t.start()
|
||||
|
||||
listener = init_keyboard_listener(events)
|
||||
|
||||
print("\n Controls:")
|
||||
print(" SPACE/Right pedal - Pause policy")
|
||||
print(" c/Left pedal - Take control")
|
||||
print(" →/Right pedal - End episode (in correction mode)")
|
||||
print(" ESC - Stop and push to hub")
|
||||
print("=" * 65 + "\n")
|
||||
|
||||
action_keys = [k for k in robot.action_features.keys() if k.endswith(".pos")]
|
||||
|
||||
if cfg.interpolation:
|
||||
interp_factor = 2
|
||||
robot_interval = 1.0 / (cfg.fps * interp_factor)
|
||||
else:
|
||||
interp_factor = 1
|
||||
robot_interval = 1.0 / cfg.fps
|
||||
|
||||
try:
|
||||
recorded = 0
|
||||
while recorded < cfg.num_episodes and not events["stop_recording"]:
|
||||
log_say(f"RaC episode {recorded + 1}", play_sounds=True)
|
||||
|
||||
move_robot_to_zero(follower, duration_s=2.0)
|
||||
|
||||
action_queue = ActionQueue(cfg.rtc)
|
||||
events["policy_paused"] = False
|
||||
events["correction_active"] = False
|
||||
events["start_correction"] = False
|
||||
events["exit_early"] = False
|
||||
|
||||
frame_buffer = []
|
||||
prev_action: Tensor | None = None
|
||||
interpolated_actions: list[Tensor] = []
|
||||
interp_idx = 0
|
||||
|
||||
robot_send_count = 0
|
||||
policy_consume_count = 0
|
||||
last_hz_print = time.perf_counter()
|
||||
|
||||
policy_active.set()
|
||||
episode_start = time.perf_counter()
|
||||
|
||||
while (time.perf_counter() - episode_start) < cfg.episode_time_sec:
|
||||
loop_start = time.perf_counter()
|
||||
|
||||
if events["exit_early"]:
|
||||
break
|
||||
|
||||
if events["start_correction"] and not events["correction_active"]:
|
||||
policy_active.clear()
|
||||
print("[RaC] Moving teleop to robot position...")
|
||||
obs = robot.get_observation()
|
||||
robot_pos = {k: v for k, v in obs.items() if k.endswith(".pos") and k in robot.observation_features}
|
||||
teleop.smooth_move_to(robot_pos, duration_s=2.0, fps=50)
|
||||
teleop.disable_torque()
|
||||
events["correction_active"] = True
|
||||
events["start_correction"] = False
|
||||
print("[RaC] Correction mode - you have control")
|
||||
|
||||
obs = robot.get_observation()
|
||||
obs_filtered = {k: v for k, v in obs.items() if k in robot.observation_features}
|
||||
obs_frame = build_dataset_frame(dataset_features, obs_filtered, prefix="observation")
|
||||
|
||||
if events["correction_active"]:
|
||||
robot_action = teleop.get_action()
|
||||
for key in robot_action:
|
||||
if "gripper" in key:
|
||||
robot_action[key] = -0.65 * robot_action[key]
|
||||
robot.send_action(robot_action)
|
||||
|
||||
action_frame = build_dataset_frame(dataset_features, robot_action, prefix="action")
|
||||
frame = {**obs_frame, **action_frame, "task": cfg.task}
|
||||
frame_buffer.append(frame)
|
||||
|
||||
elif events["policy_paused"]:
|
||||
pass
|
||||
|
||||
else:
|
||||
if interp_idx >= len(interpolated_actions):
|
||||
new_action = action_queue.get()
|
||||
if new_action is not None:
|
||||
current_action = new_action.cpu()
|
||||
policy_consume_count += 1
|
||||
|
||||
if cfg.interpolation and prev_action is not None:
|
||||
mid = prev_action + 0.5 * (current_action - prev_action)
|
||||
interpolated_actions = [mid, current_action]
|
||||
else:
|
||||
interpolated_actions = [current_action]
|
||||
|
||||
prev_action = current_action
|
||||
interp_idx = 0
|
||||
|
||||
if interp_idx < len(interpolated_actions):
|
||||
action_to_send = interpolated_actions[interp_idx]
|
||||
interp_idx += 1
|
||||
|
||||
action_dict = {}
|
||||
for i, key in enumerate(action_keys):
|
||||
if i < len(action_to_send):
|
||||
action_dict[key] = action_to_send[i].item()
|
||||
|
||||
action_processed = robot_proc((action_dict, None))
|
||||
robot.send_action(action_processed)
|
||||
robot_send_count += 1
|
||||
|
||||
action_frame = build_dataset_frame(dataset_features, action_dict, prefix="action")
|
||||
frame = {**obs_frame, **action_frame, "task": cfg.task}
|
||||
frame_buffer.append(frame)
|
||||
|
||||
now = time.perf_counter()
|
||||
if now - last_hz_print >= 5.0:
|
||||
elapsed = now - last_hz_print
|
||||
actual_robot_hz = robot_send_count / elapsed if elapsed > 0 else 0
|
||||
actual_policy_hz = policy_consume_count / elapsed if elapsed > 0 else 0
|
||||
logger.info(f"[ACTOR] Actual Hz - Robot: {actual_robot_hz:.1f}, Policy: {actual_policy_hz:.1f}")
|
||||
robot_send_count = 0
|
||||
policy_consume_count = 0
|
||||
last_hz_print = now
|
||||
|
||||
dt = time.perf_counter() - loop_start
|
||||
sleep_time = max(0, robot_interval - dt - 0.001)
|
||||
if sleep_time > 0:
|
||||
precise_sleep(sleep_time)
|
||||
|
||||
policy_active.clear()
|
||||
|
||||
if events["rerecord_episode"]:
|
||||
log_say("Re-recording", play_sounds=True)
|
||||
events["rerecord_episode"] = False
|
||||
continue
|
||||
|
||||
for frame in frame_buffer:
|
||||
dataset.add_frame(frame)
|
||||
|
||||
with dataset_lock:
|
||||
if dataset.episode_buffer is not None and dataset.episode_buffer.get("size", 0) > 0:
|
||||
logger.info(f"Saving episode ({dataset.episode_buffer['size']} frames)")
|
||||
dataset.save_episode()
|
||||
|
||||
recorded += 1
|
||||
|
||||
if recorded < cfg.num_episodes and not events["stop_recording"]:
|
||||
events["in_reset"] = True
|
||||
events["start_next_episode"] = False
|
||||
print("\n[RaC] RESET - Press any key/pedal to enable teleop")
|
||||
|
||||
obs = robot.get_observation()
|
||||
robot_pos = {k: v for k, v in obs.items() if k.endswith(".pos") and k in robot.observation_features}
|
||||
teleop.smooth_move_to(robot_pos, duration_s=2.0, fps=50)
|
||||
|
||||
while not events["start_next_episode"] and not events["stop_recording"]:
|
||||
precise_sleep(0.05)
|
||||
|
||||
if events["stop_recording"]:
|
||||
break
|
||||
|
||||
events["start_next_episode"] = False
|
||||
teleop.disable_torque()
|
||||
print("[RaC] Teleop enabled - move to start, then press key/pedal")
|
||||
|
||||
while not events["start_next_episode"] and not events["stop_recording"]:
|
||||
action = teleop.get_action()
|
||||
for key in action:
|
||||
if "gripper" in key:
|
||||
action[key] = -0.65 * action[key]
|
||||
robot.send_action(action)
|
||||
precise_sleep(1 / cfg.fps)
|
||||
|
||||
events["in_reset"] = False
|
||||
events["start_next_episode"] = False
|
||||
|
||||
log_say("Recording complete", play_sounds=True, blocking=True)
|
||||
|
||||
except KeyboardInterrupt:
|
||||
logger.info("Interrupted by user")
|
||||
|
||||
finally:
|
||||
shutdown_event.set()
|
||||
policy_active.clear()
|
||||
|
||||
if get_actions_t.is_alive():
|
||||
get_actions_t.join(timeout=5.0)
|
||||
|
||||
follower.disconnect()
|
||||
teleop.disconnect()
|
||||
|
||||
if listener:
|
||||
listener.stop()
|
||||
|
||||
dataset.finalize()
|
||||
if cfg.push_to_hub:
|
||||
logger.info("Pushing to hub...")
|
||||
dataset.push_to_hub(private=True)
|
||||
|
||||
logger.info("Done")
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
main()
|
||||
|
||||
Reference in New Issue
Block a user