chore(rl): move rl related code to its directory at top level (#2002)

* chore(rl): move rl related code to its directory at top level

* chore(style): apply pre-commit to renamed headers

* test(rl): fix rl imports

* docs(rl): update rl headers doc
This commit is contained in:
Steven Palma
2025-09-23 16:32:34 +02:00
committed by GitHub
parent 9d0cf64da6
commit d6a32e9742
12 changed files with 44 additions and 41 deletions
+740
View File
@@ -0,0 +1,740 @@
#!/usr/bin/env python
# Copyright 2025 The HuggingFace Inc. team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""
Actor server runner for distributed HILSerl robot policy training.
This script implements the actor component of the distributed HILSerl architecture.
It executes the policy in the robot environment, collects experience,
and sends transitions to the learner server for policy updates.
Examples of usage:
- Start an actor server for real robot training with human-in-the-loop intervention:
```bash
python -m lerobot.rl.actor --config_path src/lerobot/configs/train_config_hilserl_so100.json
```
**NOTE**: The actor server requires a running learner server to connect to. Ensure the learner
server is started before launching the actor.
**NOTE**: Human intervention is key to HILSerl training. Press the upper right trigger button on the
gamepad to take control of the robot during training. Initially intervene frequently, then gradually
reduce interventions as the policy improves.
**WORKFLOW**:
1. Determine robot workspace bounds using `find_joint_limits.py`
2. Record demonstrations with `gym_manipulator.py` in record mode
3. Process the dataset and determine camera crops with `crop_dataset_roi.py`
4. Start the learner server with the training configuration
5. Start this actor server with the same configuration
6. Use human interventions to guide policy learning
For more details on the complete HILSerl training workflow, see:
https://github.com/michel-aractingi/lerobot-hilserl-guide
"""
import logging
import os
import time
from functools import lru_cache
from queue import Empty
import grpc
import torch
from torch import nn
from torch.multiprocessing import Event, Queue
from lerobot.cameras import opencv # noqa: F401
from lerobot.configs import parser
from lerobot.configs.train import TrainRLServerPipelineConfig
from lerobot.policies.factory import make_policy
from lerobot.policies.sac.modeling_sac import SACPolicy
from lerobot.processor import TransitionKey
from lerobot.robots import so100_follower # noqa: F401
from lerobot.teleoperators import gamepad, so101_leader # noqa: F401
from lerobot.teleoperators.utils import TeleopEvents
from lerobot.transport import services_pb2, services_pb2_grpc
from lerobot.transport.utils import (
bytes_to_state_dict,
grpc_channel_options,
python_object_to_bytes,
receive_bytes_in_chunks,
send_bytes_in_chunks,
transitions_to_bytes,
)
from lerobot.utils.process import ProcessSignalHandler
from lerobot.utils.queue import get_last_item_from_queue
from lerobot.utils.random_utils import set_seed
from lerobot.utils.robot_utils import busy_wait
from lerobot.utils.transition import (
Transition,
move_state_dict_to_device,
move_transition_to_device,
)
from lerobot.utils.utils import (
TimerManager,
get_safe_torch_device,
init_logging,
)
from .gym_manipulator import (
create_transition,
make_processors,
make_robot_env,
step_env_and_process_transition,
)
ACTOR_SHUTDOWN_TIMEOUT = 30
# Main entry point
@parser.wrap()
def actor_cli(cfg: TrainRLServerPipelineConfig):
cfg.validate()
display_pid = False
if not use_threads(cfg):
import torch.multiprocessing as mp
mp.set_start_method("spawn")
display_pid = True
# Create logs directory to ensure it exists
log_dir = os.path.join(cfg.output_dir, "logs")
os.makedirs(log_dir, exist_ok=True)
log_file = os.path.join(log_dir, f"actor_{cfg.job_name}.log")
# Initialize logging with explicit log file
init_logging(log_file=log_file, display_pid=display_pid)
logging.info(f"Actor logging initialized, writing to {log_file}")
is_threaded = use_threads(cfg)
shutdown_event = ProcessSignalHandler(is_threaded, display_pid=display_pid).shutdown_event
learner_client, grpc_channel = learner_service_client(
host=cfg.policy.actor_learner_config.learner_host,
port=cfg.policy.actor_learner_config.learner_port,
)
logging.info("[ACTOR] Establishing connection with Learner")
if not establish_learner_connection(learner_client, shutdown_event):
logging.error("[ACTOR] Failed to establish connection with Learner")
return
if not use_threads(cfg):
# If we use multithreading, we can reuse the channel
grpc_channel.close()
grpc_channel = None
logging.info("[ACTOR] Connection with Learner established")
parameters_queue = Queue()
transitions_queue = Queue()
interactions_queue = Queue()
concurrency_entity = None
if use_threads(cfg):
from threading import Thread
concurrency_entity = Thread
else:
from multiprocessing import Process
concurrency_entity = Process
receive_policy_process = concurrency_entity(
target=receive_policy,
args=(cfg, parameters_queue, shutdown_event, grpc_channel),
daemon=True,
)
transitions_process = concurrency_entity(
target=send_transitions,
args=(cfg, transitions_queue, shutdown_event, grpc_channel),
daemon=True,
)
interactions_process = concurrency_entity(
target=send_interactions,
args=(cfg, interactions_queue, shutdown_event, grpc_channel),
daemon=True,
)
transitions_process.start()
interactions_process.start()
receive_policy_process.start()
act_with_policy(
cfg=cfg,
shutdown_event=shutdown_event,
parameters_queue=parameters_queue,
transitions_queue=transitions_queue,
interactions_queue=interactions_queue,
)
logging.info("[ACTOR] Policy process joined")
logging.info("[ACTOR] Closing queues")
transitions_queue.close()
interactions_queue.close()
parameters_queue.close()
transitions_process.join()
logging.info("[ACTOR] Transitions process joined")
interactions_process.join()
logging.info("[ACTOR] Interactions process joined")
receive_policy_process.join()
logging.info("[ACTOR] Receive policy process joined")
logging.info("[ACTOR] join queues")
transitions_queue.cancel_join_thread()
interactions_queue.cancel_join_thread()
parameters_queue.cancel_join_thread()
logging.info("[ACTOR] queues closed")
# Core algorithm functions
def act_with_policy(
cfg: TrainRLServerPipelineConfig,
shutdown_event: any, # Event,
parameters_queue: Queue,
transitions_queue: Queue,
interactions_queue: Queue,
):
"""
Executes policy interaction within the environment.
This function rolls out the policy in the environment, collecting interaction data and pushing it to a queue for streaming to the learner.
Once an episode is completed, updated network parameters received from the learner are retrieved from a queue and loaded into the network.
Args:
cfg: Configuration settings for the interaction process.
shutdown_event: Event to check if the process should shutdown.
parameters_queue: Queue to receive updated network parameters from the learner.
transitions_queue: Queue to send transitions to the learner.
interactions_queue: Queue to send interactions to the learner.
"""
# Initialize logging for multiprocessing
if not use_threads(cfg):
log_dir = os.path.join(cfg.output_dir, "logs")
os.makedirs(log_dir, exist_ok=True)
log_file = os.path.join(log_dir, f"actor_policy_{os.getpid()}.log")
init_logging(log_file=log_file, display_pid=True)
logging.info("Actor policy process logging initialized")
logging.info("make_env online")
online_env, teleop_device = make_robot_env(cfg=cfg.env)
env_processor, action_processor = make_processors(online_env, teleop_device, cfg.env, cfg.policy.device)
set_seed(cfg.seed)
device = get_safe_torch_device(cfg.policy.device, log=True)
torch.backends.cudnn.benchmark = True
torch.backends.cuda.matmul.allow_tf32 = True
logging.info("make_policy")
### Instantiate the policy in both the actor and learner processes
### To avoid sending a SACPolicy object through the port, we create a policy instance
### on both sides, the learner sends the updated parameters every n steps to update the actor's parameters
policy: SACPolicy = make_policy(
cfg=cfg.policy,
env_cfg=cfg.env,
)
policy = policy.eval()
assert isinstance(policy, nn.Module)
obs, info = online_env.reset()
env_processor.reset()
action_processor.reset()
# Process initial observation
transition = create_transition(observation=obs, info=info)
transition = env_processor(transition)
# NOTE: For the moment we will solely handle the case of a single environment
sum_reward_episode = 0
list_transition_to_send_to_learner = []
episode_intervention = False
# Add counters for intervention rate calculation
episode_intervention_steps = 0
episode_total_steps = 0
policy_timer = TimerManager("Policy inference", log=False)
for interaction_step in range(cfg.policy.online_steps):
start_time = time.perf_counter()
if shutdown_event.is_set():
logging.info("[ACTOR] Shutting down act_with_policy")
return
observation = {
k: v for k, v in transition[TransitionKey.OBSERVATION].items() if k in cfg.policy.input_features
}
# Time policy inference and check if it meets FPS requirement
with policy_timer:
# Extract observation from transition for policy
action = policy.select_action(batch=observation)
policy_fps = policy_timer.fps_last
log_policy_frequency_issue(policy_fps=policy_fps, cfg=cfg, interaction_step=interaction_step)
# Use the new step function
new_transition = step_env_and_process_transition(
env=online_env,
transition=transition,
action=action,
env_processor=env_processor,
action_processor=action_processor,
)
# Extract values from processed transition
next_observation = {
k: v
for k, v in new_transition[TransitionKey.OBSERVATION].items()
if k in cfg.policy.input_features
}
# Teleop action is the action that was executed in the environment
# It is either the action from the teleop device or the action from the policy
executed_action = new_transition[TransitionKey.COMPLEMENTARY_DATA]["teleop_action"]
reward = new_transition[TransitionKey.REWARD]
done = new_transition.get(TransitionKey.DONE, False)
truncated = new_transition.get(TransitionKey.TRUNCATED, False)
sum_reward_episode += float(reward)
episode_total_steps += 1
# Check for intervention from transition info
intervention_info = new_transition[TransitionKey.INFO]
if intervention_info.get(TeleopEvents.IS_INTERVENTION, False):
episode_intervention = True
episode_intervention_steps += 1
complementary_info = {
"discrete_penalty": torch.tensor(
[new_transition[TransitionKey.COMPLEMENTARY_DATA].get("discrete_penalty", 0.0)]
),
}
# Create transition for learner (convert to old format)
list_transition_to_send_to_learner.append(
Transition(
state=observation,
action=executed_action,
reward=reward,
next_state=next_observation,
done=done,
truncated=truncated,
complementary_info=complementary_info,
)
)
# Update transition for next iteration
transition = new_transition
if done or truncated:
logging.info(f"[ACTOR] Global step {interaction_step}: Episode reward: {sum_reward_episode}")
update_policy_parameters(policy=policy, parameters_queue=parameters_queue, device=device)
if len(list_transition_to_send_to_learner) > 0:
push_transitions_to_transport_queue(
transitions=list_transition_to_send_to_learner,
transitions_queue=transitions_queue,
)
list_transition_to_send_to_learner = []
stats = get_frequency_stats(policy_timer)
policy_timer.reset()
# Calculate intervention rate
intervention_rate = 0.0
if episode_total_steps > 0:
intervention_rate = episode_intervention_steps / episode_total_steps
# Send episodic reward to the learner
interactions_queue.put(
python_object_to_bytes(
{
"Episodic reward": sum_reward_episode,
"Interaction step": interaction_step,
"Episode intervention": int(episode_intervention),
"Intervention rate": intervention_rate,
**stats,
}
)
)
# Reset intervention counters and environment
sum_reward_episode = 0.0
episode_intervention = False
episode_intervention_steps = 0
episode_total_steps = 0
# Reset environment and processors
obs, info = online_env.reset()
env_processor.reset()
action_processor.reset()
# Process initial observation
transition = create_transition(observation=obs, info=info)
transition = env_processor(transition)
if cfg.env.fps is not None:
dt_time = time.perf_counter() - start_time
busy_wait(1 / cfg.env.fps - dt_time)
# Communication Functions - Group all gRPC/messaging functions
def establish_learner_connection(
stub: services_pb2_grpc.LearnerServiceStub,
shutdown_event: Event, # type: ignore
attempts: int = 30,
):
"""Establish a connection with the learner.
Args:
stub (services_pb2_grpc.LearnerServiceStub): The stub to use for the connection.
shutdown_event (Event): The event to check if the connection should be established.
attempts (int): The number of attempts to establish the connection.
Returns:
bool: True if the connection is established, False otherwise.
"""
for _ in range(attempts):
if shutdown_event.is_set():
logging.info("[ACTOR] Shutting down establish_learner_connection")
return False
# Force a connection attempt and check state
try:
logging.info("[ACTOR] Send ready message to Learner")
if stub.Ready(services_pb2.Empty()) == services_pb2.Empty():
return True
except grpc.RpcError as e:
logging.error(f"[ACTOR] Waiting for Learner to be ready... {e}")
time.sleep(2)
return False
@lru_cache(maxsize=1)
def learner_service_client(
host: str = "127.0.0.1",
port: int = 50051,
) -> tuple[services_pb2_grpc.LearnerServiceStub, grpc.Channel]:
"""
Returns a client for the learner service.
GRPC uses HTTP/2, which is a binary protocol and multiplexes requests over a single connection.
So we need to create only one client and reuse it.
"""
channel = grpc.insecure_channel(
f"{host}:{port}",
grpc_channel_options(),
)
stub = services_pb2_grpc.LearnerServiceStub(channel)
logging.info("[ACTOR] Learner service client created")
return stub, channel
def receive_policy(
cfg: TrainRLServerPipelineConfig,
parameters_queue: Queue,
shutdown_event: Event, # type: ignore
learner_client: services_pb2_grpc.LearnerServiceStub | None = None,
grpc_channel: grpc.Channel | None = None,
):
"""Receive parameters from the learner.
Args:
cfg (TrainRLServerPipelineConfig): The configuration for the actor.
parameters_queue (Queue): The queue to receive the parameters.
shutdown_event (Event): The event to check if the process should shutdown.
"""
logging.info("[ACTOR] Start receiving parameters from the Learner")
if not use_threads(cfg):
# Create a process-specific log file
log_dir = os.path.join(cfg.output_dir, "logs")
os.makedirs(log_dir, exist_ok=True)
log_file = os.path.join(log_dir, f"actor_receive_policy_{os.getpid()}.log")
# Initialize logging with explicit log file
init_logging(log_file=log_file, display_pid=True)
logging.info("Actor receive policy process logging initialized")
# Setup process handlers to handle shutdown signal
# But use shutdown event from the main process
_ = ProcessSignalHandler(use_threads=False, display_pid=True)
if grpc_channel is None or learner_client is None:
learner_client, grpc_channel = learner_service_client(
host=cfg.policy.actor_learner_config.learner_host,
port=cfg.policy.actor_learner_config.learner_port,
)
try:
iterator = learner_client.StreamParameters(services_pb2.Empty())
receive_bytes_in_chunks(
iterator,
parameters_queue,
shutdown_event,
log_prefix="[ACTOR] parameters",
)
except grpc.RpcError as e:
logging.error(f"[ACTOR] gRPC error: {e}")
if not use_threads(cfg):
grpc_channel.close()
logging.info("[ACTOR] Received policy loop stopped")
def send_transitions(
cfg: TrainRLServerPipelineConfig,
transitions_queue: Queue,
shutdown_event: any, # Event,
learner_client: services_pb2_grpc.LearnerServiceStub | None = None,
grpc_channel: grpc.Channel | None = None,
) -> services_pb2.Empty:
"""
Sends transitions to the learner.
This function continuously retrieves messages from the queue and processes:
- Transition Data:
- A batch of transitions (observation, action, reward, next observation) is collected.
- Transitions are moved to the CPU and serialized using PyTorch.
- The serialized data is wrapped in a `services_pb2.Transition` message and sent to the learner.
"""
if not use_threads(cfg):
# Create a process-specific log file
log_dir = os.path.join(cfg.output_dir, "logs")
os.makedirs(log_dir, exist_ok=True)
log_file = os.path.join(log_dir, f"actor_transitions_{os.getpid()}.log")
# Initialize logging with explicit log file
init_logging(log_file=log_file, display_pid=True)
logging.info("Actor transitions process logging initialized")
if grpc_channel is None or learner_client is None:
learner_client, grpc_channel = learner_service_client(
host=cfg.policy.actor_learner_config.learner_host,
port=cfg.policy.actor_learner_config.learner_port,
)
try:
learner_client.SendTransitions(
transitions_stream(
shutdown_event, transitions_queue, cfg.policy.actor_learner_config.queue_get_timeout
)
)
except grpc.RpcError as e:
logging.error(f"[ACTOR] gRPC error: {e}")
logging.info("[ACTOR] Finished streaming transitions")
if not use_threads(cfg):
grpc_channel.close()
logging.info("[ACTOR] Transitions process stopped")
def send_interactions(
cfg: TrainRLServerPipelineConfig,
interactions_queue: Queue,
shutdown_event: Event, # type: ignore
learner_client: services_pb2_grpc.LearnerServiceStub | None = None,
grpc_channel: grpc.Channel | None = None,
) -> services_pb2.Empty:
"""
Sends interactions to the learner.
This function continuously retrieves messages from the queue and processes:
- Interaction Messages:
- Contains useful statistics about episodic rewards and policy timings.
- The message is serialized using `pickle` and sent to the learner.
"""
if not use_threads(cfg):
# Create a process-specific log file
log_dir = os.path.join(cfg.output_dir, "logs")
os.makedirs(log_dir, exist_ok=True)
log_file = os.path.join(log_dir, f"actor_interactions_{os.getpid()}.log")
# Initialize logging with explicit log file
init_logging(log_file=log_file, display_pid=True)
logging.info("Actor interactions process logging initialized")
# Setup process handlers to handle shutdown signal
# But use shutdown event from the main process
_ = ProcessSignalHandler(use_threads=False, display_pid=True)
if grpc_channel is None or learner_client is None:
learner_client, grpc_channel = learner_service_client(
host=cfg.policy.actor_learner_config.learner_host,
port=cfg.policy.actor_learner_config.learner_port,
)
try:
learner_client.SendInteractions(
interactions_stream(
shutdown_event, interactions_queue, cfg.policy.actor_learner_config.queue_get_timeout
)
)
except grpc.RpcError as e:
logging.error(f"[ACTOR] gRPC error: {e}")
logging.info("[ACTOR] Finished streaming interactions")
if not use_threads(cfg):
grpc_channel.close()
logging.info("[ACTOR] Interactions process stopped")
def transitions_stream(shutdown_event: Event, transitions_queue: Queue, timeout: float) -> services_pb2.Empty: # type: ignore
while not shutdown_event.is_set():
try:
message = transitions_queue.get(block=True, timeout=timeout)
except Empty:
logging.debug("[ACTOR] Transition queue is empty")
continue
yield from send_bytes_in_chunks(
message, services_pb2.Transition, log_prefix="[ACTOR] Send transitions"
)
return services_pb2.Empty()
def interactions_stream(
shutdown_event: Event,
interactions_queue: Queue,
timeout: float, # type: ignore
) -> services_pb2.Empty:
while not shutdown_event.is_set():
try:
message = interactions_queue.get(block=True, timeout=timeout)
except Empty:
logging.debug("[ACTOR] Interaction queue is empty")
continue
yield from send_bytes_in_chunks(
message,
services_pb2.InteractionMessage,
log_prefix="[ACTOR] Send interactions",
)
return services_pb2.Empty()
# Policy functions
def update_policy_parameters(policy: SACPolicy, parameters_queue: Queue, device):
bytes_state_dict = get_last_item_from_queue(parameters_queue, block=False)
if bytes_state_dict is not None:
logging.info("[ACTOR] Load new parameters from Learner.")
state_dicts = bytes_to_state_dict(bytes_state_dict)
# TODO: check encoder parameter synchronization possible issues:
# 1. When shared_encoder=True, we're loading stale encoder params from actor's state_dict
# instead of the updated encoder params from critic (which is optimized separately)
# 2. When freeze_vision_encoder=True, we waste bandwidth sending/loading frozen params
# 3. Need to handle encoder params correctly for both actor and discrete_critic
# Potential fixes:
# - Send critic's encoder state when shared_encoder=True
# - Skip encoder params entirely when freeze_vision_encoder=True
# - Ensure discrete_critic gets correct encoder state (currently uses encoder_critic)
# Load actor state dict
actor_state_dict = move_state_dict_to_device(state_dicts["policy"], device=device)
policy.actor.load_state_dict(actor_state_dict)
# Load discrete critic if present
if hasattr(policy, "discrete_critic") and "discrete_critic" in state_dicts:
discrete_critic_state_dict = move_state_dict_to_device(
state_dicts["discrete_critic"], device=device
)
policy.discrete_critic.load_state_dict(discrete_critic_state_dict)
logging.info("[ACTOR] Loaded discrete critic parameters from Learner.")
# Utilities functions
def push_transitions_to_transport_queue(transitions: list, transitions_queue):
"""Send transitions to learner in smaller chunks to avoid network issues.
Args:
transitions: List of transitions to send
message_queue: Queue to send messages to learner
chunk_size: Size of each chunk to send
"""
transition_to_send_to_learner = []
for transition in transitions:
tr = move_transition_to_device(transition=transition, device="cpu")
for key, value in tr["state"].items():
if torch.isnan(value).any():
logging.warning(f"Found NaN values in transition {key}")
transition_to_send_to_learner.append(tr)
transitions_queue.put(transitions_to_bytes(transition_to_send_to_learner))
def get_frequency_stats(timer: TimerManager) -> dict[str, float]:
"""Get the frequency statistics of the policy.
Args:
timer (TimerManager): The timer with collected metrics.
Returns:
dict[str, float]: The frequency statistics of the policy.
"""
stats = {}
if timer.count > 1:
avg_fps = timer.fps_avg
p90_fps = timer.fps_percentile(90)
logging.debug(f"[ACTOR] Average policy frame rate: {avg_fps}")
logging.debug(f"[ACTOR] Policy frame rate 90th percentile: {p90_fps}")
stats = {
"Policy frequency [Hz]": avg_fps,
"Policy frequency 90th-p [Hz]": p90_fps,
}
return stats
def log_policy_frequency_issue(policy_fps: float, cfg: TrainRLServerPipelineConfig, interaction_step: int):
if policy_fps < cfg.env.fps:
logging.warning(
f"[ACTOR] Policy FPS {policy_fps:.1f} below required {cfg.env.fps} at step {interaction_step}"
)
def use_threads(cfg: TrainRLServerPipelineConfig) -> bool:
return cfg.policy.concurrency.actor == "threads"
if __name__ == "__main__":
actor_cli()
+313
View File
@@ -0,0 +1,313 @@
#!/usr/bin/env python
# Copyright 2025 The HuggingFace Inc. team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import argparse
import json
from copy import deepcopy
from pathlib import Path
import cv2
import torch
import torchvision.transforms.functional as F # type: ignore # noqa: N812
from tqdm import tqdm # type: ignore
from lerobot.datasets.lerobot_dataset import LeRobotDataset
def select_rect_roi(img):
"""
Allows the user to draw a rectangular ROI on the image.
The user must click and drag to draw the rectangle.
- While dragging, the rectangle is dynamically drawn.
- On mouse button release, the rectangle is fixed.
- Press 'c' to confirm the selection.
- Press 'r' to reset the selection.
- Press ESC to cancel.
Returns:
A tuple (top, left, height, width) representing the rectangular ROI,
or None if no valid ROI is selected.
"""
# Create a working copy of the image
clone = img.copy()
working_img = clone.copy()
roi = None # Will store the final ROI as (top, left, height, width)
drawing = False
index_x, index_y = -1, -1 # Initial click coordinates
def mouse_callback(event, x, y, flags, param):
nonlocal index_x, index_y, drawing, roi, working_img
if event == cv2.EVENT_LBUTTONDOWN:
# Start drawing: record starting coordinates
drawing = True
index_x, index_y = x, y
elif event == cv2.EVENT_MOUSEMOVE:
if drawing:
# Compute the top-left and bottom-right corners regardless of drag direction
top = min(index_y, y)
left = min(index_x, x)
bottom = max(index_y, y)
right = max(index_x, x)
# Show a temporary image with the current rectangle drawn
temp = working_img.copy()
cv2.rectangle(temp, (left, top), (right, bottom), (0, 255, 0), 2)
cv2.imshow("Select ROI", temp)
elif event == cv2.EVENT_LBUTTONUP:
# Finish drawing
drawing = False
top = min(index_y, y)
left = min(index_x, x)
bottom = max(index_y, y)
right = max(index_x, x)
height = bottom - top
width = right - left
roi = (top, left, height, width) # (top, left, height, width)
# Draw the final rectangle on the working image and display it
working_img = clone.copy()
cv2.rectangle(working_img, (left, top), (right, bottom), (0, 255, 0), 2)
cv2.imshow("Select ROI", working_img)
# Create the window and set the callback
cv2.namedWindow("Select ROI")
cv2.setMouseCallback("Select ROI", mouse_callback)
cv2.imshow("Select ROI", working_img)
print("Instructions for ROI selection:")
print(" - Click and drag to draw a rectangular ROI.")
print(" - Press 'c' to confirm the selection.")
print(" - Press 'r' to reset and draw again.")
print(" - Press ESC to cancel the selection.")
# Wait until the user confirms with 'c', resets with 'r', or cancels with ESC
while True:
key = cv2.waitKey(1) & 0xFF
# Confirm ROI if one has been drawn
if key == ord("c") and roi is not None:
break
# Reset: clear the ROI and restore the original image
elif key == ord("r"):
working_img = clone.copy()
roi = None
cv2.imshow("Select ROI", working_img)
# Cancel selection for this image
elif key == 27: # ESC key
roi = None
break
cv2.destroyWindow("Select ROI")
return roi
def select_square_roi_for_images(images: dict) -> dict:
"""
For each image in the provided dictionary, open a window to allow the user
to select a rectangular ROI. Returns a dictionary mapping each key to a tuple
(top, left, height, width) representing the ROI.
Parameters:
images (dict): Dictionary where keys are identifiers and values are OpenCV images.
Returns:
dict: Mapping of image keys to the selected rectangular ROI.
"""
selected_rois = {}
for key, img in images.items():
if img is None:
print(f"Image for key '{key}' is None, skipping.")
continue
print(f"\nSelect rectangular ROI for image with key: '{key}'")
roi = select_rect_roi(img)
if roi is None:
print(f"No valid ROI selected for '{key}'.")
else:
selected_rois[key] = roi
print(f"ROI for '{key}': {roi}")
return selected_rois
def get_image_from_lerobot_dataset(dataset: LeRobotDataset):
"""
Find the first row in the dataset and extract the image in order to be used for the crop.
"""
row = dataset[0]
image_dict = {}
for k in row:
if "image" in k:
image_dict[k] = deepcopy(row[k])
return image_dict
def convert_lerobot_dataset_to_cropper_lerobot_dataset(
original_dataset: LeRobotDataset,
crop_params_dict: dict[str, tuple[int, int, int, int]],
new_repo_id: str,
new_dataset_root: str,
resize_size: tuple[int, int] = (128, 128),
push_to_hub: bool = False,
task: str = "",
) -> LeRobotDataset:
"""
Converts an existing LeRobotDataset by iterating over its episodes and frames,
applying cropping and resizing to image observations, and saving a new dataset
with the transformed data.
Args:
original_dataset (LeRobotDataset): The source dataset.
crop_params_dict (Dict[str, Tuple[int, int, int, int]]):
A dictionary mapping observation keys to crop parameters (top, left, height, width).
new_repo_id (str): Repository id for the new dataset.
new_dataset_root (str): The root directory where the new dataset will be written.
resize_size (Tuple[int, int], optional): The target size (height, width) after cropping.
Defaults to (128, 128).
Returns:
LeRobotDataset: A new LeRobotDataset where the specified image observations have been cropped
and resized.
"""
# 1. Create a new (empty) LeRobotDataset for writing.
new_dataset = LeRobotDataset.create(
repo_id=new_repo_id,
fps=original_dataset.fps,
root=new_dataset_root,
robot_type=original_dataset.meta.robot_type,
features=original_dataset.meta.info["features"],
use_videos=len(original_dataset.meta.video_keys) > 0,
)
# Update the metadata for every image key that will be cropped:
# (Here we simply set the shape to be the final resize_size.)
for key in crop_params_dict:
if key in new_dataset.meta.info["features"]:
new_dataset.meta.info["features"][key]["shape"] = [3] + list(resize_size)
# TODO: Directly modify the mp4 video + meta info features, instead of recreating a dataset
prev_episode_index = 0
for frame_idx in tqdm(range(len(original_dataset))):
frame = original_dataset[frame_idx]
# Create a copy of the frame to add to the new dataset
new_frame = {}
for key, value in frame.items():
if key in ("task_index", "timestamp", "episode_index", "frame_index", "index", "task"):
continue
if key in ("next.done", "next.reward"):
# if not isinstance(value, str) and len(value.shape) == 0:
value = value.unsqueeze(0)
if key in crop_params_dict:
top, left, height, width = crop_params_dict[key]
# Apply crop then resize.
cropped = F.crop(value, top, left, height, width)
value = F.resize(cropped, resize_size)
value = value.clamp(0, 1)
if key.startswith("complementary_info") and isinstance(value, torch.Tensor) and value.dim() == 0:
value = value.unsqueeze(0)
new_frame[key] = value
new_frame["task"] = task
new_dataset.add_frame(new_frame)
if frame["episode_index"].item() != prev_episode_index:
# Save the episode
new_dataset.save_episode()
prev_episode_index = frame["episode_index"].item()
# Save the last episode
new_dataset.save_episode()
if push_to_hub:
new_dataset.push_to_hub()
return new_dataset
if __name__ == "__main__":
parser = argparse.ArgumentParser(description="Crop rectangular ROIs from a LeRobot dataset.")
parser.add_argument(
"--repo-id",
type=str,
default="lerobot",
help="The repository id of the LeRobot dataset to process.",
)
parser.add_argument(
"--root",
type=str,
default=None,
help="The root directory of the LeRobot dataset.",
)
parser.add_argument(
"--crop-params-path",
type=str,
default=None,
help="The path to the JSON file containing the ROIs.",
)
parser.add_argument(
"--push-to-hub",
action="store_true",
help="Whether to push the new dataset to the hub.",
)
parser.add_argument(
"--task",
type=str,
default="",
help="The natural language task to describe the dataset.",
)
args = parser.parse_args()
dataset = LeRobotDataset(repo_id=args.repo_id, root=args.root)
images = get_image_from_lerobot_dataset(dataset)
images = {k: v.cpu().permute(1, 2, 0).numpy() for k, v in images.items()}
images = {k: (v * 255).astype("uint8") for k, v in images.items()}
if args.crop_params_path is None:
rois = select_square_roi_for_images(images)
else:
with open(args.crop_params_path) as f:
rois = json.load(f)
# Print the selected rectangular ROIs
print("\nSelected Rectangular Regions of Interest (top, left, height, width):")
for key, roi in rois.items():
print(f"{key}: {roi}")
new_repo_id = args.repo_id + "_cropped_resized"
new_dataset_root = Path(str(dataset.root) + "_cropped_resized")
cropped_resized_dataset = convert_lerobot_dataset_to_cropper_lerobot_dataset(
original_dataset=dataset,
crop_params_dict=rois,
new_repo_id=new_repo_id,
new_dataset_root=new_dataset_root,
resize_size=(128, 128),
push_to_hub=args.push_to_hub,
task=args.task,
)
meta_dir = new_dataset_root / "meta"
meta_dir.mkdir(exist_ok=True)
with open(meta_dir / "crop_params.json", "w") as f:
json.dump(rois, f, indent=4)
+75
View File
@@ -0,0 +1,75 @@
# !/usr/bin/env python
# Copyright 2025 The HuggingFace Inc. team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import logging
from lerobot.cameras import opencv # noqa: F401
from lerobot.configs import parser
from lerobot.configs.train import TrainRLServerPipelineConfig
from lerobot.datasets.lerobot_dataset import LeRobotDataset
from lerobot.policies.factory import make_policy
from lerobot.robots import ( # noqa: F401
RobotConfig,
make_robot_from_config,
so100_follower,
)
from lerobot.teleoperators import (
gamepad, # noqa: F401
so101_leader, # noqa: F401
)
from .gym_manipulator import make_robot_env
logging.basicConfig(level=logging.INFO)
def eval_policy(env, policy, n_episodes):
sum_reward_episode = []
for _ in range(n_episodes):
obs, _ = env.reset()
episode_reward = 0.0
while True:
action = policy.select_action(obs)
obs, reward, terminated, truncated, _ = env.step(action)
episode_reward += reward
if terminated or truncated:
break
sum_reward_episode.append(episode_reward)
logging.info(f"Success after 20 steps {sum_reward_episode}")
logging.info(f"success rate {sum(sum_reward_episode) / len(sum_reward_episode)}")
@parser.wrap()
def main(cfg: TrainRLServerPipelineConfig):
env_cfg = cfg.env
env = make_robot_env(env_cfg)
dataset_cfg = cfg.dataset
dataset = LeRobotDataset(repo_id=dataset_cfg.repo_id)
dataset_meta = dataset.meta
policy = make_policy(
cfg=cfg.policy,
# env_cfg=cfg.env,
ds_meta=dataset_meta,
)
policy.from_pretrained(env_cfg.pretrained_policy_name_or_path)
policy.eval()
eval_policy(env, policy=policy, n_episodes=10)
if __name__ == "__main__":
main()
+769
View File
@@ -0,0 +1,769 @@
# !/usr/bin/env python
# Copyright 2025 The HuggingFace Inc. team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import logging
import time
from dataclasses import dataclass
from typing import Any
import gymnasium as gym
import numpy as np
import torch
from lerobot.cameras import opencv # noqa: F401
from lerobot.configs import parser
from lerobot.datasets.lerobot_dataset import LeRobotDataset
from lerobot.envs.configs import HILSerlRobotEnvConfig
from lerobot.model.kinematics import RobotKinematics
from lerobot.processor import (
AddBatchDimensionProcessorStep,
AddTeleopActionAsComplimentaryDataStep,
AddTeleopEventsAsInfoStep,
DataProcessorPipeline,
DeviceProcessorStep,
EnvTransition,
GripperPenaltyProcessorStep,
ImageCropResizeProcessorStep,
InterventionActionProcessorStep,
JointVelocityProcessorStep,
MapDeltaActionToRobotActionStep,
MapTensorToDeltaActionDictStep,
MotorCurrentProcessorStep,
Numpy2TorchActionProcessorStep,
RewardClassifierProcessorStep,
RobotActionToPolicyActionProcessorStep,
TimeLimitProcessorStep,
Torch2NumpyActionProcessorStep,
TransitionKey,
VanillaObservationProcessorStep,
create_transition,
)
from lerobot.processor.converters import identity_transition
from lerobot.robots import ( # noqa: F401
RobotConfig,
make_robot_from_config,
so100_follower,
)
from lerobot.robots.robot import Robot
from lerobot.robots.so100_follower.robot_kinematic_processor import (
EEBoundsAndSafety,
EEReferenceAndDelta,
ForwardKinematicsJointsToEEObservation,
GripperVelocityToJoint,
InverseKinematicsRLStep,
)
from lerobot.teleoperators import (
gamepad, # noqa: F401
keyboard, # noqa: F401
make_teleoperator_from_config,
so101_leader, # noqa: F401
)
from lerobot.teleoperators.teleoperator import Teleoperator
from lerobot.teleoperators.utils import TeleopEvents
from lerobot.utils.robot_utils import busy_wait
from lerobot.utils.utils import log_say
logging.basicConfig(level=logging.INFO)
@dataclass
class DatasetConfig:
"""Configuration for dataset creation and management."""
repo_id: str
task: str
root: str | None = None
num_episodes_to_record: int = 5
replay_episode: int | None = None
push_to_hub: bool = False
@dataclass
class GymManipulatorConfig:
"""Main configuration for gym manipulator environment."""
env: HILSerlRobotEnvConfig
dataset: DatasetConfig
mode: str | None = None # Either "record", "replay", None
device: str = "cpu"
def reset_follower_position(robot_arm: Robot, target_position: np.ndarray) -> None:
"""Reset robot arm to target position using smooth trajectory."""
current_position_dict = robot_arm.bus.sync_read("Present_Position")
current_position = np.array(
[current_position_dict[name] for name in current_position_dict], dtype=np.float32
)
trajectory = torch.from_numpy(
np.linspace(current_position, target_position, 50)
) # NOTE: 30 is just an arbitrary number
for pose in trajectory:
action_dict = dict(zip(current_position_dict, pose, strict=False))
robot_arm.bus.sync_write("Goal_Position", action_dict)
busy_wait(0.015)
class RobotEnv(gym.Env):
"""Gym environment for robotic control with human intervention support."""
def __init__(
self,
robot,
use_gripper: bool = False,
display_cameras: bool = False,
reset_pose: list[float] | None = None,
reset_time_s: float = 5.0,
) -> None:
"""Initialize robot environment with configuration options.
Args:
robot: Robot interface for hardware communication.
use_gripper: Whether to include gripper in action space.
display_cameras: Whether to show camera feeds during execution.
reset_pose: Joint positions for environment reset.
reset_time_s: Time to wait during reset.
"""
super().__init__()
self.robot = robot
self.display_cameras = display_cameras
# Connect to the robot if not already connected.
if not self.robot.is_connected:
self.robot.connect()
# Episode tracking.
self.current_step = 0
self.episode_data = None
self._joint_names = [f"{key}.pos" for key in self.robot.bus.motors]
self._image_keys = self.robot.cameras.keys()
self.reset_pose = reset_pose
self.reset_time_s = reset_time_s
self.use_gripper = use_gripper
self._joint_names = list(self.robot.bus.motors.keys())
self._raw_joint_positions = None
self._setup_spaces()
def _get_observation(self) -> dict[str, Any]:
"""Get current robot observation including joint positions and camera images."""
obs_dict = self.robot.get_observation()
raw_joint_joint_position = {f"{name}.pos": obs_dict[f"{name}.pos"] for name in self._joint_names}
joint_positions = np.array([raw_joint_joint_position[f"{name}.pos"] for name in self._joint_names])
images = {key: obs_dict[key] for key in self._image_keys}
return {"agent_pos": joint_positions, "pixels": images, **raw_joint_joint_position}
def _setup_spaces(self) -> None:
"""Configure observation and action spaces based on robot capabilities."""
current_observation = self._get_observation()
observation_spaces = {}
# Define observation spaces for images and other states.
if current_observation is not None and "pixels" in current_observation:
prefix = "observation.images"
observation_spaces = {
f"{prefix}.{key}": gym.spaces.Box(
low=0, high=255, shape=current_observation["pixels"][key].shape, dtype=np.uint8
)
for key in current_observation["pixels"]
}
if current_observation is not None:
agent_pos = current_observation["agent_pos"]
observation_spaces["observation.state"] = gym.spaces.Box(
low=0,
high=10,
shape=agent_pos.shape,
dtype=np.float32,
)
self.observation_space = gym.spaces.Dict(observation_spaces)
# Define the action space for joint positions along with setting an intervention flag.
action_dim = 3
bounds = {}
bounds["min"] = -np.ones(action_dim)
bounds["max"] = np.ones(action_dim)
if self.use_gripper:
action_dim += 1
bounds["min"] = np.concatenate([bounds["min"], [0]])
bounds["max"] = np.concatenate([bounds["max"], [2]])
self.action_space = gym.spaces.Box(
low=bounds["min"],
high=bounds["max"],
shape=(action_dim,),
dtype=np.float32,
)
def reset(
self, *, seed: int | None = None, options: dict[str, Any] | None = None
) -> tuple[dict[str, Any], dict[str, Any]]:
"""Reset environment to initial state.
Args:
seed: Random seed for reproducibility.
options: Additional reset options.
Returns:
Tuple of (observation, info) dictionaries.
"""
# Reset the robot
# self.robot.reset()
start_time = time.perf_counter()
if self.reset_pose is not None:
log_say("Reset the environment.", play_sounds=True)
reset_follower_position(self.robot, np.array(self.reset_pose))
log_say("Reset the environment done.", play_sounds=True)
busy_wait(self.reset_time_s - (time.perf_counter() - start_time))
super().reset(seed=seed, options=options)
# Reset episode tracking variables.
self.current_step = 0
self.episode_data = None
obs = self._get_observation()
self._raw_joint_positions = {f"{key}.pos": obs[f"{key}.pos"] for key in self._joint_names}
return obs, {TeleopEvents.IS_INTERVENTION: False}
def step(self, action) -> tuple[dict[str, np.ndarray], float, bool, bool, dict[str, Any]]:
"""Execute one environment step with given action."""
joint_targets_dict = {f"{key}.pos": action[i] for i, key in enumerate(self.robot.bus.motors.keys())}
self.robot.send_action(joint_targets_dict)
obs = self._get_observation()
self._raw_joint_positions = {f"{key}.pos": obs[f"{key}.pos"] for key in self._joint_names}
if self.display_cameras:
self.render()
self.current_step += 1
reward = 0.0
terminated = False
truncated = False
return (
obs,
reward,
terminated,
truncated,
{TeleopEvents.IS_INTERVENTION: False},
)
def render(self) -> None:
"""Display robot camera feeds."""
import cv2
current_observation = self._get_observation()
if current_observation is not None:
image_keys = [key for key in current_observation if "image" in key]
for key in image_keys:
cv2.imshow(key, cv2.cvtColor(current_observation[key].numpy(), cv2.COLOR_RGB2BGR))
cv2.waitKey(1)
def close(self) -> None:
"""Close environment and disconnect robot."""
if self.robot.is_connected:
self.robot.disconnect()
def get_raw_joint_positions(self) -> dict[str, float]:
"""Get raw joint positions."""
return self._raw_joint_positions
def make_robot_env(cfg: HILSerlRobotEnvConfig) -> tuple[gym.Env, Any]:
"""Create robot environment from configuration.
Args:
cfg: Environment configuration.
Returns:
Tuple of (gym environment, teleoperator device).
"""
# Check if this is a GymHIL simulation environment
if cfg.name == "gym_hil":
assert cfg.robot is None and cfg.teleop is None, "GymHIL environment does not support robot or teleop"
import gym_hil # noqa: F401
# Extract gripper settings with defaults
use_gripper = cfg.processor.gripper.use_gripper if cfg.processor.gripper is not None else True
gripper_penalty = cfg.processor.gripper.gripper_penalty if cfg.processor.gripper is not None else 0.0
env = gym.make(
f"gym_hil/{cfg.task}",
image_obs=True,
render_mode="human",
use_gripper=use_gripper,
gripper_penalty=gripper_penalty,
)
return env, None
# Real robot environment
assert cfg.robot is not None, "Robot config must be provided for real robot environment"
assert cfg.teleop is not None, "Teleop config must be provided for real robot environment"
robot = make_robot_from_config(cfg.robot)
teleop_device = make_teleoperator_from_config(cfg.teleop)
teleop_device.connect()
# Create base environment with safe defaults
use_gripper = cfg.processor.gripper.use_gripper if cfg.processor.gripper is not None else True
display_cameras = (
cfg.processor.observation.display_cameras if cfg.processor.observation is not None else False
)
reset_pose = cfg.processor.reset.fixed_reset_joint_positions if cfg.processor.reset is not None else None
env = RobotEnv(
robot=robot,
use_gripper=use_gripper,
display_cameras=display_cameras,
reset_pose=reset_pose,
)
return env, teleop_device
def make_processors(
env: gym.Env, teleop_device: Teleoperator | None, cfg: HILSerlRobotEnvConfig, device: str = "cpu"
) -> tuple[
DataProcessorPipeline[EnvTransition, EnvTransition], DataProcessorPipeline[EnvTransition, EnvTransition]
]:
"""Create environment and action processors.
Args:
env: Robot environment instance.
teleop_device: Teleoperator device for intervention.
cfg: Processor configuration.
device: Target device for computations.
Returns:
Tuple of (environment processor, action processor).
"""
terminate_on_success = (
cfg.processor.reset.terminate_on_success if cfg.processor.reset is not None else True
)
if cfg.name == "gym_hil":
action_pipeline_steps = [
InterventionActionProcessorStep(terminate_on_success=terminate_on_success),
Torch2NumpyActionProcessorStep(),
]
env_pipeline_steps = [
Numpy2TorchActionProcessorStep(),
VanillaObservationProcessorStep(),
AddBatchDimensionProcessorStep(),
DeviceProcessorStep(device=device),
]
return DataProcessorPipeline(
steps=env_pipeline_steps, to_transition=identity_transition, to_output=identity_transition
), DataProcessorPipeline(
steps=action_pipeline_steps, to_transition=identity_transition, to_output=identity_transition
)
# Full processor pipeline for real robot environment
# Get robot and motor information for kinematics
motor_names = list(env.robot.bus.motors.keys())
# Set up kinematics solver if inverse kinematics is configured
kinematics_solver = None
if cfg.processor.inverse_kinematics is not None:
kinematics_solver = RobotKinematics(
urdf_path=cfg.processor.inverse_kinematics.urdf_path,
target_frame_name=cfg.processor.inverse_kinematics.target_frame_name,
joint_names=motor_names,
)
env_pipeline_steps = [VanillaObservationProcessorStep()]
if cfg.processor.observation is not None:
if cfg.processor.observation.add_joint_velocity_to_observation:
env_pipeline_steps.append(JointVelocityProcessorStep(dt=1.0 / cfg.fps))
if cfg.processor.observation.add_current_to_observation:
env_pipeline_steps.append(MotorCurrentProcessorStep(robot=env.robot))
if kinematics_solver is not None:
env_pipeline_steps.append(
ForwardKinematicsJointsToEEObservation(
kinematics=kinematics_solver,
motor_names=motor_names,
)
)
if cfg.processor.image_preprocessing is not None:
env_pipeline_steps.append(
ImageCropResizeProcessorStep(
crop_params_dict=cfg.processor.image_preprocessing.crop_params_dict,
resize_size=cfg.processor.image_preprocessing.resize_size,
)
)
# Add time limit processor if reset config exists
if cfg.processor.reset is not None:
env_pipeline_steps.append(
TimeLimitProcessorStep(max_episode_steps=int(cfg.processor.reset.control_time_s * cfg.fps))
)
# Add gripper penalty processor if gripper config exists and enabled
if cfg.processor.gripper is not None and cfg.processor.gripper.use_gripper:
env_pipeline_steps.append(
GripperPenaltyProcessorStep(
penalty=cfg.processor.gripper.gripper_penalty,
max_gripper_pos=cfg.processor.max_gripper_pos,
)
)
if (
cfg.processor.reward_classifier is not None
and cfg.processor.reward_classifier.pretrained_path is not None
):
env_pipeline_steps.append(
RewardClassifierProcessorStep(
pretrained_path=cfg.processor.reward_classifier.pretrained_path,
device=device,
success_threshold=cfg.processor.reward_classifier.success_threshold,
success_reward=cfg.processor.reward_classifier.success_reward,
terminate_on_success=terminate_on_success,
)
)
env_pipeline_steps.append(AddBatchDimensionProcessorStep())
env_pipeline_steps.append(DeviceProcessorStep(device=device))
action_pipeline_steps = [
AddTeleopActionAsComplimentaryDataStep(teleop_device=teleop_device),
AddTeleopEventsAsInfoStep(teleop_device=teleop_device),
InterventionActionProcessorStep(
use_gripper=cfg.processor.gripper.use_gripper if cfg.processor.gripper is not None else False,
terminate_on_success=terminate_on_success,
),
]
# Replace InverseKinematicsProcessor with new kinematic processors
if cfg.processor.inverse_kinematics is not None and kinematics_solver is not None:
# Add EE bounds and safety processor
inverse_kinematics_steps = [
MapTensorToDeltaActionDictStep(
use_gripper=cfg.processor.gripper.use_gripper if cfg.processor.gripper is not None else False
),
MapDeltaActionToRobotActionStep(),
EEReferenceAndDelta(
kinematics=kinematics_solver,
end_effector_step_sizes=cfg.processor.inverse_kinematics.end_effector_step_sizes,
motor_names=motor_names,
use_latched_reference=False,
use_ik_solution=True,
),
EEBoundsAndSafety(
end_effector_bounds=cfg.processor.inverse_kinematics.end_effector_bounds,
),
GripperVelocityToJoint(
clip_max=cfg.processor.max_gripper_pos,
speed_factor=1.0,
discrete_gripper=True,
),
InverseKinematicsRLStep(
kinematics=kinematics_solver, motor_names=motor_names, initial_guess_current_joints=False
),
]
action_pipeline_steps.extend(inverse_kinematics_steps)
action_pipeline_steps.append(RobotActionToPolicyActionProcessorStep(motor_names=motor_names))
return DataProcessorPipeline(
steps=env_pipeline_steps, to_transition=identity_transition, to_output=identity_transition
), DataProcessorPipeline(
steps=action_pipeline_steps, to_transition=identity_transition, to_output=identity_transition
)
def step_env_and_process_transition(
env: gym.Env,
transition: EnvTransition,
action: torch.Tensor,
env_processor: DataProcessorPipeline[EnvTransition, EnvTransition],
action_processor: DataProcessorPipeline[EnvTransition, EnvTransition],
) -> EnvTransition:
"""
Execute one step with processor pipeline.
Args:
env: The robot environment
transition: Current transition state
action: Action to execute
env_processor: Environment processor
action_processor: Action processor
Returns:
Processed transition with updated state.
"""
# Create action transition
transition[TransitionKey.ACTION] = action
transition[TransitionKey.OBSERVATION] = (
env.get_raw_joint_positions() if hasattr(env, "get_raw_joint_positions") else {}
)
processed_action_transition = action_processor(transition)
processed_action = processed_action_transition[TransitionKey.ACTION]
obs, reward, terminated, truncated, info = env.step(processed_action)
reward = reward + processed_action_transition[TransitionKey.REWARD]
terminated = terminated or processed_action_transition[TransitionKey.DONE]
truncated = truncated or processed_action_transition[TransitionKey.TRUNCATED]
complementary_data = processed_action_transition[TransitionKey.COMPLEMENTARY_DATA].copy()
new_info = processed_action_transition[TransitionKey.INFO].copy()
new_info.update(info)
new_transition = create_transition(
observation=obs,
action=processed_action,
reward=reward,
done=terminated,
truncated=truncated,
info=new_info,
complementary_data=complementary_data,
)
new_transition = env_processor(new_transition)
return new_transition
def control_loop(
env: gym.Env,
env_processor: DataProcessorPipeline[EnvTransition, EnvTransition],
action_processor: DataProcessorPipeline[EnvTransition, EnvTransition],
teleop_device: Teleoperator,
cfg: GymManipulatorConfig,
) -> None:
"""Main control loop for robot environment interaction.
if cfg.mode == "record": then a dataset will be created and recorded
Args:
env: The robot environment
env_processor: Environment processor
action_processor: Action processor
teleop_device: Teleoperator device
cfg: gym_manipulator configuration
"""
dt = 1.0 / cfg.env.fps
print(f"Starting control loop at {cfg.env.fps} FPS")
print("Controls:")
print("- Use gamepad/teleop device for intervention")
print("- When not intervening, robot will stay still")
print("- Press Ctrl+C to exit")
# Reset environment and processors
obs, info = env.reset()
complementary_data = (
{"raw_joint_positions": info.pop("raw_joint_positions")} if "raw_joint_positions" in info else {}
)
env_processor.reset()
action_processor.reset()
# Process initial observation
transition = create_transition(observation=obs, info=info, complementary_data=complementary_data)
transition = env_processor(data=transition)
# Determine if gripper is used
use_gripper = cfg.env.processor.gripper.use_gripper if cfg.env.processor.gripper is not None else True
dataset = None
if cfg.mode == "record":
action_features = teleop_device.action_features
features = {
"action": action_features,
"next.reward": {"dtype": "float32", "shape": (1,), "names": None},
"next.done": {"dtype": "bool", "shape": (1,), "names": None},
}
if use_gripper:
features["complementary_info.discrete_penalty"] = {
"dtype": "float32",
"shape": (1,),
"names": ["discrete_penalty"],
}
for key, value in transition[TransitionKey.OBSERVATION].items():
if key == "observation.state":
features[key] = {
"dtype": "float32",
"shape": value.squeeze(0).shape,
"names": None,
}
if "image" in key:
features[key] = {
"dtype": "video",
"shape": value.squeeze(0).shape,
"names": ["channels", "height", "width"],
}
# Create dataset
dataset = LeRobotDataset.create(
cfg.dataset.repo_id,
cfg.env.fps,
root=cfg.dataset.root,
use_videos=True,
image_writer_threads=4,
image_writer_processes=0,
features=features,
)
episode_idx = 0
episode_step = 0
episode_start_time = time.perf_counter()
while episode_idx < cfg.dataset.num_episodes_to_record:
step_start_time = time.perf_counter()
# Create a neutral action (no movement)
neutral_action = torch.tensor([0.0, 0.0, 0.0], dtype=torch.float32)
if use_gripper:
neutral_action = torch.cat([neutral_action, torch.tensor([1.0])]) # Gripper stay
# Use the new step function
transition = step_env_and_process_transition(
env=env,
transition=transition,
action=neutral_action,
env_processor=env_processor,
action_processor=action_processor,
)
terminated = transition.get(TransitionKey.DONE, False)
truncated = transition.get(TransitionKey.TRUNCATED, False)
if cfg.mode == "record":
observations = {
k: v.squeeze(0).cpu()
for k, v in transition[TransitionKey.OBSERVATION].items()
if isinstance(v, torch.Tensor)
}
# Use teleop_action if available, otherwise use the action from the transition
action_to_record = transition[TransitionKey.COMPLEMENTARY_DATA].get(
"teleop_action", transition[TransitionKey.ACTION]
)
frame = {
**observations,
"action": action_to_record.cpu(),
"next.reward": np.array([transition[TransitionKey.REWARD]], dtype=np.float32),
"next.done": np.array([terminated or truncated], dtype=bool),
}
if use_gripper:
discrete_penalty = transition[TransitionKey.COMPLEMENTARY_DATA].get("discrete_penalty", 0.0)
frame["complementary_info.discrete_penalty"] = np.array([discrete_penalty], dtype=np.float32)
if dataset is not None:
frame["task"] = cfg.dataset.task
dataset.add_frame(frame)
episode_step += 1
# Handle episode termination
if terminated or truncated:
episode_time = time.perf_counter() - episode_start_time
logging.info(
f"Episode ended after {episode_step} steps in {episode_time:.1f}s with reward {transition[TransitionKey.REWARD]}"
)
episode_step = 0
episode_idx += 1
if dataset is not None:
if transition[TransitionKey.INFO].get("rerecord_episode", False):
logging.info(f"Re-recording episode {episode_idx}")
dataset.clear_episode_buffer()
episode_idx -= 1
else:
logging.info(f"Saving episode {episode_idx}")
dataset.save_episode()
# Reset for new episode
obs, info = env.reset()
env_processor.reset()
action_processor.reset()
transition = create_transition(observation=obs, info=info)
transition = env_processor(transition)
# Maintain fps timing
busy_wait(dt - (time.perf_counter() - step_start_time))
if dataset is not None and cfg.dataset.push_to_hub:
logging.info("Pushing dataset to hub")
dataset.push_to_hub()
def replay_trajectory(
env: gym.Env, action_processor: DataProcessorPipeline, cfg: GymManipulatorConfig
) -> None:
"""Replay recorded trajectory on robot environment."""
assert cfg.dataset.replay_episode is not None, "Replay episode must be provided for replay"
dataset = LeRobotDataset(
cfg.dataset.repo_id,
root=cfg.dataset.root,
episodes=[cfg.dataset.replay_episode],
download_videos=False,
)
episode_frames = dataset.hf_dataset.filter(lambda x: x["episode_index"] == cfg.dataset.replay_episode)
actions = episode_frames.select_columns("action")
_, info = env.reset()
for action_data in actions:
start_time = time.perf_counter()
transition = create_transition(
observation=env.get_raw_joint_positions() if hasattr(env, "get_raw_joint_positions") else {},
action=action_data["action"],
)
transition = action_processor(transition)
env.step(transition[TransitionKey.ACTION])
busy_wait(1 / cfg.env.fps - (time.perf_counter() - start_time))
@parser.wrap()
def main(cfg: GymManipulatorConfig) -> None:
"""Main entry point for gym manipulator script."""
env, teleop_device = make_robot_env(cfg.env)
env_processor, action_processor = make_processors(env, teleop_device, cfg.env, cfg.device)
print("Environment observation space:", env.observation_space)
print("Environment action space:", env.action_space)
print("Environment processor:", env_processor)
print("Action processor:", action_processor)
if cfg.mode == "replay":
replay_trajectory(env, action_processor, cfg)
exit()
control_loop(env, env_processor, action_processor, teleop_device, cfg)
if __name__ == "__main__":
main()
File diff suppressed because it is too large Load Diff
+117
View File
@@ -0,0 +1,117 @@
# !/usr/bin/env python
# Copyright 2025 The HuggingFace Inc. team.
# All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import logging
import time
from multiprocessing import Event, Queue
from lerobot.transport import services_pb2, services_pb2_grpc
from lerobot.transport.utils import receive_bytes_in_chunks, send_bytes_in_chunks
from lerobot.utils.queue import get_last_item_from_queue
MAX_WORKERS = 3 # Stream parameters, send transitions and interactions
SHUTDOWN_TIMEOUT = 10
class LearnerService(services_pb2_grpc.LearnerServiceServicer):
"""
Implementation of the LearnerService gRPC service
This service is used to send parameters to the Actor and receive transitions and interactions from the Actor
check transport.proto for the gRPC service definition
"""
def __init__(
self,
shutdown_event: Event, # type: ignore
parameters_queue: Queue,
seconds_between_pushes: float,
transition_queue: Queue,
interaction_message_queue: Queue,
queue_get_timeout: float = 0.001,
):
self.shutdown_event = shutdown_event
self.parameters_queue = parameters_queue
self.seconds_between_pushes = seconds_between_pushes
self.transition_queue = transition_queue
self.interaction_message_queue = interaction_message_queue
self.queue_get_timeout = queue_get_timeout
def StreamParameters(self, request, context): # noqa: N802
# TODO: authorize the request
logging.info("[LEARNER] Received request to stream parameters from the Actor")
last_push_time = 0
while not self.shutdown_event.is_set():
time_since_last_push = time.time() - last_push_time
if time_since_last_push < self.seconds_between_pushes:
self.shutdown_event.wait(self.seconds_between_pushes - time_since_last_push)
# Continue, because we could receive a shutdown event,
# and it's checked in the while loop
continue
logging.info("[LEARNER] Push parameters to the Actor")
buffer = get_last_item_from_queue(
self.parameters_queue, block=True, timeout=self.queue_get_timeout
)
if buffer is None:
continue
yield from send_bytes_in_chunks(
buffer,
services_pb2.Parameters,
log_prefix="[LEARNER] Sending parameters",
silent=True,
)
last_push_time = time.time()
logging.info("[LEARNER] Parameters sent")
logging.info("[LEARNER] Stream parameters finished")
return services_pb2.Empty()
def SendTransitions(self, request_iterator, _context): # noqa: N802
# TODO: authorize the request
logging.info("[LEARNER] Received request to receive transitions from the Actor")
receive_bytes_in_chunks(
request_iterator,
self.transition_queue,
self.shutdown_event,
log_prefix="[LEARNER] transitions",
)
logging.debug("[LEARNER] Finished receiving transitions")
return services_pb2.Empty()
def SendInteractions(self, request_iterator, _context): # noqa: N802
# TODO: authorize the request
logging.info("[LEARNER] Received request to receive interactions from the Actor")
receive_bytes_in_chunks(
request_iterator,
self.interaction_message_queue,
self.shutdown_event,
log_prefix="[LEARNER] interactions",
)
logging.debug("[LEARNER] Finished receiving interactions")
return services_pb2.Empty()
def Ready(self, request, context): # noqa: N802
return services_pb2.Empty()