This commit is contained in:
Pepijn
2026-01-09 09:41:04 +01:00
parent 6d12740c24
commit 8e430f323f
+137 -246
View File
@@ -15,22 +15,20 @@
# limitations under the License. # limitations under the License.
""" """
OpenArms Policy Evaluation with Async Inference + Interpolation OpenArms Policy Evaluation with Interpolation
Key features: Evaluates a trained policy with smooth action interpolation:
- ASYNC INFERENCE: Policy runs in background thread, never blocks robot loop - Decoupled camera capture (CAMERA_FPS) from robot control (ROBOT_FPS)
- Robot control loop runs at true ROBOT_FPS (50Hz+) - Speed multiplier to execute actions faster than training
- Interpolation between policy outputs for smooth motion - Velocity feedforward for smoother tracking
- Speed multiplier to execute faster than training - Adjustable PID gains
Example usage: Example usage:
python examples/openarms/evaluate_interpolation.py python examples/openarms/evaluate_interpolation.py
""" """
import threading
import time import time
from collections import deque from collections import deque
from dataclasses import dataclass
from pathlib import Path from pathlib import Path
import numpy as np import numpy as np
@@ -57,21 +55,23 @@ HF_EVAL_DATASET_ID = "lerobot-data-collection/three-folds-pi0_eval_interp" # TO
TASK_DESCRIPTION = "three-folds-dataset" # TODO: Replace with your task TASK_DESCRIPTION = "three-folds-dataset" # TODO: Replace with your task
# ======================== TIMING CONFIG ======================== # ======================== TIMING CONFIG ========================
CAMERA_FPS = 30 # Camera hardware limit (fixed)
POLICY_FPS = 30 # What the policy was trained with POLICY_FPS = 30 # What the policy was trained with
SPEED_MULTIPLIER = 1.0 # Execute actions faster (1.0 = normal, 1.2 = 20% faster) SPEED_MULTIPLIER = 1.2 # Execute actions faster (1.0 = normal, 1.2 = 20% faster)
ROBOT_FPS = 50 # Robot command rate (higher = smoother interpolation) ROBOT_FPS = 50 # Robot command rate (higher = smoother interpolation)
# Derived values # Derived values
EFFECTIVE_POLICY_FPS = int(POLICY_FPS * SPEED_MULTIPLIER) EFFECTIVE_POLICY_FPS = int(POLICY_FPS * SPEED_MULTIPLIER) # How fast we consume actions (36Hz at 1.2x)
NUM_EPISODES = 1 NUM_EPISODES = 1
EPISODE_TIME_SEC = 300 EPISODE_TIME_SEC = 300
RESET_TIME_SEC = 60 RESET_TIME_SEC = 60
# ======================== PID TUNING ======================== # ======================== PID TUNING ========================
CUSTOM_KP_SCALE = 0.7 # Set to None to use robot config defaults
CUSTOM_KD_SCALE = 1.3 CUSTOM_KP_SCALE = 0.7 # Scale factor for position gain (0.5-1.0, lower = smoother)
USE_VELOCITY_FEEDFORWARD = True CUSTOM_KD_SCALE = 1.3 # Scale factor for damping gain (1.0-2.0, higher = less overshoot)
USE_VELOCITY_FEEDFORWARD = True # Enable velocity feedforward for smoother tracking
# ======================== ROBOT CONFIG ======================== # ======================== ROBOT CONFIG ========================
FOLLOWER_LEFT_PORT = "can0" FOLLOWER_LEFT_PORT = "can0"
@@ -81,7 +81,7 @@ USE_LEADER_FOR_RESETS = True
LEADER_LEFT_PORT = "can2" LEADER_LEFT_PORT = "can2"
LEADER_RIGHT_PORT = "can3" LEADER_RIGHT_PORT = "can3"
CAMERA_FPS = 30 # Camera config uses CAMERA_FPS (hardware limit)
CAMERA_CONFIG = { CAMERA_CONFIG = {
"left_wrist": OpenCVCameraConfig(index_or_path="/dev/video5", width=640, height=480, fps=CAMERA_FPS), "left_wrist": OpenCVCameraConfig(index_or_path="/dev/video5", width=640, height=480, fps=CAMERA_FPS),
"right_wrist": OpenCVCameraConfig(index_or_path="/dev/video1", width=640, height=480, fps=CAMERA_FPS), "right_wrist": OpenCVCameraConfig(index_or_path="/dev/video1", width=640, height=480, fps=CAMERA_FPS),
@@ -89,170 +89,47 @@ CAMERA_CONFIG = {
} }
@dataclass
class InferenceResult:
"""Result from async inference thread."""
robot_action: dict
observation_frame: dict
obs_processed: dict
act_processed: dict
timestamp: float
inference_time_ms: float
class AsyncInferenceThread(threading.Thread):
"""Background thread for camera capture + policy inference."""
def __init__(
self,
robot,
policy,
preprocessor,
postprocessor,
robot_observation_processor,
robot_action_processor,
dataset,
task: str,
):
super().__init__(daemon=True)
self.robot = robot
self.policy = policy
self.preprocessor = preprocessor
self.postprocessor = postprocessor
self.robot_observation_processor = robot_observation_processor
self.robot_action_processor = robot_action_processor
self.dataset = dataset
self.task = task
self._lock = threading.Lock()
self._latest_result: InferenceResult | None = None
self._result_consumed = True
self._running = False
self._inference_hz_tracker = HzTracker(name="Inference", print_interval=5.0)
def get_latest_result(self) -> InferenceResult | None:
"""Get latest inference result (thread-safe). Returns None if no new result."""
with self._lock:
if self._result_consumed:
return None
result = self._latest_result
self._result_consumed = True
return result
def peek_latest_result(self) -> InferenceResult | None:
"""Peek at latest result without marking as consumed."""
with self._lock:
return self._latest_result
def stop(self):
self._running = False
def run(self):
from lerobot.scripts.lerobot_record import build_dataset_frame, make_robot_action
self._running = True
self.policy.reset()
while self._running:
try:
start = time.perf_counter()
# Capture observation
obs = self.robot.get_observation()
obs_processed = self.robot_observation_processor(obs)
observation_frame = build_dataset_frame(
self.dataset.features, obs_processed, prefix="observation"
)
# Run inference
action_values = predict_action(
observation=observation_frame,
policy=self.policy,
device=get_safe_torch_device(self.policy.config.device),
preprocessor=self.preprocessor,
postprocessor=self.postprocessor,
use_amp=self.policy.config.use_amp,
task=self.task,
robot_type=self.robot.robot_type,
)
act_processed = make_robot_action(action_values, self.dataset.features)
robot_action = self.robot_action_processor((act_processed, obs))
inference_time_ms = (time.perf_counter() - start) * 1000
# Store result
result = InferenceResult(
robot_action=robot_action,
observation_frame=observation_frame,
obs_processed=obs_processed,
act_processed=act_processed,
timestamp=time.perf_counter(),
inference_time_ms=inference_time_ms,
)
with self._lock:
self._latest_result = result
self._result_consumed = False
self._inference_hz_tracker.tick()
except Exception as e:
print(f"Inference thread error: {e}")
time.sleep(0.01)
# Print final inference stats
hz = self._inference_hz_tracker.get_avg_hz()
if hz:
print(f"Final inference Hz: {hz:.1f}")
class ActionInterpolator: class ActionInterpolator:
"""Interpolate between policy actions for smoother robot control.""" """Interpolate between policy actions for smoother robot control with velocity estimation."""
def __init__(self, robot_fps: int): def __init__(self, effective_policy_fps: int, robot_fps: int):
self.effective_policy_fps = effective_policy_fps
self.robot_fps = robot_fps self.robot_fps = robot_fps
self.substeps_per_policy_step = robot_fps / effective_policy_fps
self.prev_action: dict | None = None self.prev_action: dict | None = None
self.curr_action: dict | None = None self.curr_action: dict | None = None
self.prev_time: float = 0 self.substep = 0
self.curr_time: float = 0
self.last_interpolated: dict | None = None self.last_interpolated: dict | None = None
def update(self, new_action: dict, timestamp: float) -> None: def update(self, new_action: dict) -> None:
self.prev_action = self.curr_action self.prev_action = self.curr_action
self.prev_time = self.curr_time
self.curr_action = new_action self.curr_action = new_action
self.curr_time = timestamp self.substep = 0
def get_interpolated_action(self, current_time: float) -> tuple[dict | None, dict | None]: def get_interpolated_action(self) -> tuple[dict | None, dict | None]:
"""Returns (interpolated_position, estimated_velocity_deg_per_sec)""" """Returns (interpolated_position, estimated_velocity_deg_per_sec)"""
if self.curr_action is None: if self.curr_action is None:
return None, None return None, None
if self.prev_action is None: if self.prev_action is None:
self.last_interpolated = self.curr_action.copy() self.last_interpolated = self.curr_action.copy()
return self.curr_action, {k: 0.0 for k in self.curr_action} return self.curr_action, {k: 0.0 for k in self.curr_action}
# Time-based interpolation
dt_actions = self.curr_time - self.prev_time
if dt_actions <= 0:
dt_actions = 1.0 / 30 # Fallback
t = (current_time - self.prev_time) / dt_actions t = min(self.substep / self.substeps_per_policy_step, 1.0)
t = max(0.0, min(t, 1.5)) # Allow slight extrapolation self.substep += 1
interpolated = {} interpolated = {}
velocity = {} velocity = {}
dt_robot = 1.0 / self.robot_fps dt = 1.0 / self.robot_fps
for key in self.curr_action: for key in self.curr_action:
prev = self.prev_action.get(key, self.curr_action[key]) prev = self.prev_action.get(key, self.curr_action[key])
curr = self.curr_action[key] curr = self.curr_action[key]
interpolated[key] = prev + t * (curr - prev) interpolated[key] = prev * (1 - t) + curr * t
if self.last_interpolated is not None and key in self.last_interpolated: if self.last_interpolated is not None and key in self.last_interpolated:
velocity[key] = (interpolated[key] - self.last_interpolated[key]) / dt_robot velocity[key] = (interpolated[key] - self.last_interpolated[key]) / dt
else: else:
velocity[key] = (curr - prev) / dt_actions velocity[key] = (curr - prev) * self.effective_policy_fps
self.last_interpolated = interpolated.copy() self.last_interpolated = interpolated.copy()
return interpolated, velocity return interpolated, velocity
@@ -260,15 +137,14 @@ class ActionInterpolator:
def reset(self): def reset(self):
self.prev_action = None self.prev_action = None
self.curr_action = None self.curr_action = None
self.prev_time = 0 self.substep = 0
self.curr_time = 0
self.last_interpolated = None self.last_interpolated = None
class HzTracker: class HzTracker:
"""Track and display actual loop frequency.""" """Track and display actual loop frequency."""
def __init__(self, name: str = "Loop", window_size: int = 100, print_interval: float = 2.0): def __init__(self, name: str = "Robot", window_size: int = 100, print_interval: float = 2.0):
self.name = name self.name = name
self.timestamps = deque(maxlen=window_size) self.timestamps = deque(maxlen=window_size)
self.last_print_time = 0 self.last_print_time = 0
@@ -299,44 +175,72 @@ class HzTracker:
self.last_print_time = 0 self.last_print_time = 0
def async_eval_loop( def interpolated_eval_loop(
robot, robot,
inference_thread: AsyncInferenceThread, policy,
interpolator: ActionInterpolator, preprocessor,
robot_hz_tracker: HzTracker, postprocessor,
robot_observation_processor,
robot_action_processor,
dataset, dataset,
events, events,
robot_fps: int, interpolator: ActionInterpolator,
robot_hz_tracker: HzTracker,
camera_fps: int,
effective_policy_fps: int, effective_policy_fps: int,
robot_fps: int,
control_time_s: float, control_time_s: float,
task: str, task: str,
custom_kp: dict | None = None, kp_scale: float | None = None,
custom_kd: dict | None = None, kd_scale: float | None = None,
use_velocity_ff: bool = False, use_velocity_ff: bool = False,
): ):
""" """
Main robot control loop with async inference. Run evaluation with decoupled camera and robot control:
- Camera captures at camera_fps (hardware limit)
- Inference runs in background thread (as fast as it can) - Policy inference runs when new camera frame is available
- This loop runs at ROBOT_FPS, never blocked by inference - Actions are consumed at effective_policy_fps (sped up by SPEED_MULTIPLIER)
- Interpolates between inference results for smooth motion - Robot receives interpolated commands at robot_fps (smoothest)
""" """
from lerobot.scripts.lerobot_record import build_dataset_frame from lerobot.scripts.lerobot_record import build_dataset_frame, make_robot_action
from lerobot.utils.visualization_utils import log_rerun_data from lerobot.utils.visualization_utils import log_rerun_data
robot_dt = 1.0 / robot_fps camera_dt = 1.0 / camera_fps
policy_dt = 1.0 / effective_policy_fps policy_dt = 1.0 / effective_policy_fps
robot_dt = 1.0 / robot_fps
interpolator.reset() interpolator.reset()
robot_hz_tracker.reset() robot_hz_tracker.reset()
policy.reset()
# Build custom gains if scaling is enabled
custom_kp = None
custom_kd = None
if kp_scale is not None or kd_scale is not None:
custom_kp = {}
custom_kd = {}
for arm in ["right", "left"]:
bus = robot.bus_right if arm == "right" else robot.bus_left
for i, motor_name in enumerate(bus.motors):
full_name = f"{arm}_{motor_name}"
default_kp = robot.config.position_kp[i] if isinstance(robot.config.position_kp, list) else robot.config.position_kp
default_kd = robot.config.position_kd[i] if isinstance(robot.config.position_kd, list) else robot.config.position_kd
custom_kp[full_name] = default_kp * (kp_scale or 1.0)
custom_kd[full_name] = default_kd * (kd_scale or 1.0)
print(f"Custom gains: kp_scale={kp_scale}, kd_scale={kd_scale}")
if use_velocity_ff:
print("Velocity feedforward: enabled")
last_camera_time = -camera_dt
last_policy_action_time = -policy_dt
cached_observation = None
cached_robot_action = None
last_action_consume_time = 0
start_time = time.perf_counter() start_time = time.perf_counter()
print(f"\nAsync eval loop started:") print(f"\nStarting interpolated eval loop:")
print(f" Robot control: {robot_fps}Hz (main thread, never blocked)") print(f" Camera: {camera_fps}Hz | Policy actions consumed: {effective_policy_fps}Hz | Robot: {robot_fps}Hz")
print(f" Inference: background thread (as fast as possible)")
print(f" Action consume rate: {effective_policy_fps}Hz")
while time.perf_counter() - start_time < control_time_s: while time.perf_counter() - start_time < control_time_s:
if events["exit_early"] or events["stop_recording"]: if events["exit_early"] or events["stop_recording"]:
@@ -345,26 +249,45 @@ def async_eval_loop(
loop_start = time.perf_counter() loop_start = time.perf_counter()
elapsed = loop_start - start_time elapsed = loop_start - start_time
# Check for new inference result (non-blocking) # === CAMERA CAPTURE (at camera_fps, decoupled from robot) ===
result = inference_thread.get_latest_result() if elapsed - last_camera_time >= camera_dt:
if result is not None: obs = robot.get_observation()
# Consume action at effective_policy_fps rate obs_processed = robot_observation_processor(obs)
if elapsed - last_action_consume_time >= policy_dt: observation_frame = build_dataset_frame(dataset.features, obs_processed, prefix="observation")
interpolator.update(result.robot_action, result.timestamp)
last_action_consume_time = elapsed # Run policy inference with fresh observation
action_values = predict_action(
# Save to dataset observation=observation_frame,
if dataset is not None: policy=policy,
action_frame = build_dataset_frame( device=get_safe_torch_device(policy.config.device),
dataset.features, result.act_processed, prefix="action" preprocessor=preprocessor,
) postprocessor=postprocessor,
frame = {**result.observation_frame, **action_frame, "task": task} use_amp=policy.config.use_amp,
dataset.add_frame(frame) task=task,
log_rerun_data(observation=result.obs_processed, action=result.act_processed) robot_type=robot.robot_type,
)
act_processed = make_robot_action(action_values, dataset.features)
cached_robot_action = robot_action_processor((act_processed, obs))
cached_observation = (obs_processed, observation_frame, act_processed)
last_camera_time = elapsed
# Get interpolated action and send to robot (always runs at robot_fps) # === ACTION UPDATE (at effective_policy_fps, faster than camera if speed > 1) ===
current_time = time.perf_counter() if elapsed - last_policy_action_time >= policy_dt and cached_robot_action is not None:
smooth_action, velocity = interpolator.get_interpolated_action(current_time) interpolator.update(cached_robot_action)
last_policy_action_time = elapsed
# Save to dataset at effective policy rate
if dataset is not None and cached_observation is not None:
obs_processed, observation_frame, act_processed = cached_observation
action_frame = build_dataset_frame(dataset.features, act_processed, prefix="action")
frame = {**observation_frame, **action_frame, "task": task}
dataset.add_frame(frame)
log_rerun_data(observation=obs_processed, action=act_processed)
# === ROBOT COMMAND (at robot_fps, highest rate for smoothness) ===
smooth_action, velocity = interpolator.get_interpolated_action()
if smooth_action is not None: if smooth_action is not None:
vel_ff = velocity if use_velocity_ff else None vel_ff = velocity if use_velocity_ff else None
robot.send_action(smooth_action, custom_kp=custom_kp, custom_kd=custom_kd, velocity_feedforward=vel_ff) robot.send_action(smooth_action, custom_kp=custom_kp, custom_kd=custom_kd, velocity_feedforward=vel_ff)
@@ -379,44 +302,32 @@ def async_eval_loop(
# Print final stats # Print final stats
robot_hz = robot_hz_tracker.get_avg_hz() robot_hz = robot_hz_tracker.get_avg_hz()
if robot_hz: if robot_hz:
print(f"\nFinal robot Hz: {robot_hz:.1f}") print(f"\nFinal average robot Hz: {robot_hz:.1f}")
def build_custom_gains(robot, kp_scale: float | None, kd_scale: float | None):
"""Build custom PID gains dict."""
if kp_scale is None and kd_scale is None:
return None, None
custom_kp = {}
custom_kd = {}
for arm in ["right", "left"]:
bus = robot.bus_right if arm == "right" else robot.bus_left
for i, motor_name in enumerate(bus.motors):
full_name = f"{arm}_{motor_name}"
default_kp = robot.config.position_kp[i] if isinstance(robot.config.position_kp, list) else robot.config.position_kp
default_kd = robot.config.position_kd[i] if isinstance(robot.config.position_kd, list) else robot.config.position_kd
custom_kp[full_name] = default_kp * (kp_scale or 1.0)
custom_kd[full_name] = default_kd * (kd_scale or 1.0)
return custom_kp, custom_kd
def main(): def main():
"""Main evaluation function.""" """Main evaluation function."""
print("=" * 60) print("=" * 60)
print("OpenArms Async Inference + Interpolation Evaluation") print("OpenArms Policy Evaluation with Interpolation")
print("=" * 60) print("=" * 60)
print(f"\nModel: {HF_MODEL_ID}") print(f"\nModel: {HF_MODEL_ID}")
print(f"Dataset: {HF_EVAL_DATASET_ID}") print(f"Dataset: {HF_EVAL_DATASET_ID}")
print(f"Task: {TASK_DESCRIPTION}") print(f"Task: {TASK_DESCRIPTION}")
print(f"\n--- Timing ---") print(f"\n--- Timing ---")
print(f"Camera FPS: {CAMERA_FPS} (hardware limit)")
print(f"Policy trained at: {POLICY_FPS}Hz") print(f"Policy trained at: {POLICY_FPS}Hz")
print(f"Speed multiplier: {SPEED_MULTIPLIER}x") print(f"Speed multiplier: {SPEED_MULTIPLIER}x")
print(f"Effective policy FPS: {EFFECTIVE_POLICY_FPS}Hz") print(f"Effective policy FPS: {EFFECTIVE_POLICY_FPS}Hz (actions consumed)")
print(f"Robot FPS: {ROBOT_FPS}Hz (interpolated, non-blocking)") print(f"Robot FPS: {ROBOT_FPS}Hz (interpolated commands)")
print(f"\n--- PID Tuning ---") print(f"\n--- PID Tuning ---")
print(f"KP scale: {CUSTOM_KP_SCALE}") print(f"KP scale: {CUSTOM_KP_SCALE}")
print(f"KD scale: {CUSTOM_KD_SCALE}") print(f"KD scale: {CUSTOM_KD_SCALE}")
print(f"Velocity feedforward: {USE_VELOCITY_FEEDFORWARD}") print(f"Velocity feedforward: {USE_VELOCITY_FEEDFORWARD}")
print(f"\n--- Episodes ---")
print(f"Episodes: {NUM_EPISODES}")
print(f"Duration: {EPISODE_TIME_SEC}s per episode")
print(f"Reset time: {RESET_TIME_SEC}s")
print(f"Leader for resets: {USE_LEADER_FOR_RESETS}")
print("=" * 60) print("=" * 60)
follower_config = OpenArmsFollowerConfig( follower_config = OpenArmsFollowerConfig(
@@ -456,6 +367,8 @@ def main():
leader.bus_left.enable_torque() leader.bus_left.enable_torque()
time.sleep(0.1) time.sleep(0.1)
print(f"Leader connected with gravity compensation") print(f"Leader connected with gravity compensation")
else:
print(f"Leader connected (no gravity compensation)")
teleop_action_processor, robot_action_processor, robot_observation_processor = make_default_processors() teleop_action_processor, robot_action_processor, robot_observation_processor = make_default_processors()
@@ -488,6 +401,7 @@ def main():
leader.disconnect() leader.disconnect()
return return
# Dataset uses effective policy FPS (sped up rate)
dataset = LeRobotDataset.create( dataset = LeRobotDataset.create(
repo_id=HF_EVAL_DATASET_ID, repo_id=HF_EVAL_DATASET_ID,
fps=EFFECTIVE_POLICY_FPS, fps=EFFECTIVE_POLICY_FPS,
@@ -513,13 +427,9 @@ def main():
print(f"\nRunning evaluation...") print(f"\nRunning evaluation...")
listener, events = init_keyboard_listener() listener, events = init_keyboard_listener()
init_rerun(session_name="openarms_async_eval") init_rerun(session_name="openarms_evaluation_interp")
custom_kp, custom_kd = build_custom_gains(follower, CUSTOM_KP_SCALE, CUSTOM_KD_SCALE) interpolator = ActionInterpolator(effective_policy_fps=EFFECTIVE_POLICY_FPS, robot_fps=ROBOT_FPS)
if custom_kp:
print(f"Custom gains: kp_scale={CUSTOM_KP_SCALE}, kd_scale={CUSTOM_KD_SCALE}")
interpolator = ActionInterpolator(robot_fps=ROBOT_FPS)
robot_hz_tracker = HzTracker(name="Robot", window_size=100, print_interval=2.0) robot_hz_tracker = HzTracker(name="Robot", window_size=100, print_interval=2.0)
episode_idx = 0 episode_idx = 0
@@ -529,8 +439,7 @@ def main():
log_say(f"Evaluating episode {episode_idx + 1} of {NUM_EPISODES}") log_say(f"Evaluating episode {episode_idx + 1} of {NUM_EPISODES}")
print(f"\n--- Episode {episode_idx + 1}/{NUM_EPISODES} ---") print(f"\n--- Episode {episode_idx + 1}/{NUM_EPISODES} ---")
# Start async inference thread interpolated_eval_loop(
inference_thread = AsyncInferenceThread(
robot=follower, robot=follower,
policy=policy, policy=policy,
preprocessor=preprocessor, preprocessor=preprocessor,
@@ -538,37 +447,19 @@ def main():
robot_observation_processor=robot_observation_processor, robot_observation_processor=robot_observation_processor,
robot_action_processor=robot_action_processor, robot_action_processor=robot_action_processor,
dataset=dataset, dataset=dataset,
task=TASK_DESCRIPTION, events=events,
)
inference_thread.start()
# Wait for first inference result
print("Waiting for first inference...")
while inference_thread.peek_latest_result() is None:
time.sleep(0.01)
print("First inference complete, starting control loop")
# Run the async evaluation loop
async_eval_loop(
robot=follower,
inference_thread=inference_thread,
interpolator=interpolator, interpolator=interpolator,
robot_hz_tracker=robot_hz_tracker, robot_hz_tracker=robot_hz_tracker,
dataset=dataset, camera_fps=CAMERA_FPS,
events=events,
robot_fps=ROBOT_FPS,
effective_policy_fps=EFFECTIVE_POLICY_FPS, effective_policy_fps=EFFECTIVE_POLICY_FPS,
robot_fps=ROBOT_FPS,
control_time_s=EPISODE_TIME_SEC, control_time_s=EPISODE_TIME_SEC,
task=TASK_DESCRIPTION, task=TASK_DESCRIPTION,
custom_kp=custom_kp, kp_scale=CUSTOM_KP_SCALE,
custom_kd=custom_kd, kd_scale=CUSTOM_KD_SCALE,
use_velocity_ff=USE_VELOCITY_FEEDFORWARD, use_velocity_ff=USE_VELOCITY_FEEDFORWARD,
) )
# Stop inference thread
inference_thread.stop()
inference_thread.join(timeout=2.0)
if events["rerecord_episode"]: if events["rerecord_episode"]:
log_say("Re-recording episode") log_say("Re-recording episode")
events["rerecord_episode"] = False events["rerecord_episode"] = False
@@ -581,7 +472,6 @@ def main():
dataset.save_episode() dataset.save_episode()
episode_idx += 1 episode_idx += 1
# Reset phase
if not events["stop_recording"] and episode_idx < NUM_EPISODES: if not events["stop_recording"] and episode_idx < NUM_EPISODES:
if USE_LEADER_FOR_RESETS and leader: if USE_LEADER_FOR_RESETS and leader:
log_say("Reset the environment using leader arms") log_say("Reset the environment using leader arms")
@@ -694,3 +584,4 @@ def main():
if __name__ == "__main__": if __name__ == "__main__":
main() main()