mirror of
https://github.com/huggingface/lerobot.git
synced 2026-05-23 12:40:08 +00:00
chore(rl): move rl related code to its directory at top level (#2002)
* chore(rl): move rl related code to its directory at top level * chore(style): apply pre-commit to renamed headers * test(rl): fix rl imports * docs(rl): update rl headers doc
This commit is contained in:
@@ -0,0 +1,740 @@
|
||||
#!/usr/bin/env python
|
||||
|
||||
# Copyright 2025 The HuggingFace Inc. team. All rights reserved.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
"""
|
||||
Actor server runner for distributed HILSerl robot policy training.
|
||||
|
||||
This script implements the actor component of the distributed HILSerl architecture.
|
||||
It executes the policy in the robot environment, collects experience,
|
||||
and sends transitions to the learner server for policy updates.
|
||||
|
||||
Examples of usage:
|
||||
|
||||
- Start an actor server for real robot training with human-in-the-loop intervention:
|
||||
```bash
|
||||
python -m lerobot.rl.actor --config_path src/lerobot/configs/train_config_hilserl_so100.json
|
||||
```
|
||||
|
||||
**NOTE**: The actor server requires a running learner server to connect to. Ensure the learner
|
||||
server is started before launching the actor.
|
||||
|
||||
**NOTE**: Human intervention is key to HILSerl training. Press the upper right trigger button on the
|
||||
gamepad to take control of the robot during training. Initially intervene frequently, then gradually
|
||||
reduce interventions as the policy improves.
|
||||
|
||||
**WORKFLOW**:
|
||||
1. Determine robot workspace bounds using `find_joint_limits.py`
|
||||
2. Record demonstrations with `gym_manipulator.py` in record mode
|
||||
3. Process the dataset and determine camera crops with `crop_dataset_roi.py`
|
||||
4. Start the learner server with the training configuration
|
||||
5. Start this actor server with the same configuration
|
||||
6. Use human interventions to guide policy learning
|
||||
|
||||
For more details on the complete HILSerl training workflow, see:
|
||||
https://github.com/michel-aractingi/lerobot-hilserl-guide
|
||||
"""
|
||||
|
||||
import logging
|
||||
import os
|
||||
import time
|
||||
from functools import lru_cache
|
||||
from queue import Empty
|
||||
|
||||
import grpc
|
||||
import torch
|
||||
from torch import nn
|
||||
from torch.multiprocessing import Event, Queue
|
||||
|
||||
from lerobot.cameras import opencv # noqa: F401
|
||||
from lerobot.configs import parser
|
||||
from lerobot.configs.train import TrainRLServerPipelineConfig
|
||||
from lerobot.policies.factory import make_policy
|
||||
from lerobot.policies.sac.modeling_sac import SACPolicy
|
||||
from lerobot.processor import TransitionKey
|
||||
from lerobot.robots import so100_follower # noqa: F401
|
||||
from lerobot.teleoperators import gamepad, so101_leader # noqa: F401
|
||||
from lerobot.teleoperators.utils import TeleopEvents
|
||||
from lerobot.transport import services_pb2, services_pb2_grpc
|
||||
from lerobot.transport.utils import (
|
||||
bytes_to_state_dict,
|
||||
grpc_channel_options,
|
||||
python_object_to_bytes,
|
||||
receive_bytes_in_chunks,
|
||||
send_bytes_in_chunks,
|
||||
transitions_to_bytes,
|
||||
)
|
||||
from lerobot.utils.process import ProcessSignalHandler
|
||||
from lerobot.utils.queue import get_last_item_from_queue
|
||||
from lerobot.utils.random_utils import set_seed
|
||||
from lerobot.utils.robot_utils import busy_wait
|
||||
from lerobot.utils.transition import (
|
||||
Transition,
|
||||
move_state_dict_to_device,
|
||||
move_transition_to_device,
|
||||
)
|
||||
from lerobot.utils.utils import (
|
||||
TimerManager,
|
||||
get_safe_torch_device,
|
||||
init_logging,
|
||||
)
|
||||
|
||||
from .gym_manipulator import (
|
||||
create_transition,
|
||||
make_processors,
|
||||
make_robot_env,
|
||||
step_env_and_process_transition,
|
||||
)
|
||||
|
||||
ACTOR_SHUTDOWN_TIMEOUT = 30
|
||||
|
||||
# Main entry point
|
||||
|
||||
|
||||
@parser.wrap()
|
||||
def actor_cli(cfg: TrainRLServerPipelineConfig):
|
||||
cfg.validate()
|
||||
display_pid = False
|
||||
if not use_threads(cfg):
|
||||
import torch.multiprocessing as mp
|
||||
|
||||
mp.set_start_method("spawn")
|
||||
display_pid = True
|
||||
|
||||
# Create logs directory to ensure it exists
|
||||
log_dir = os.path.join(cfg.output_dir, "logs")
|
||||
os.makedirs(log_dir, exist_ok=True)
|
||||
log_file = os.path.join(log_dir, f"actor_{cfg.job_name}.log")
|
||||
|
||||
# Initialize logging with explicit log file
|
||||
init_logging(log_file=log_file, display_pid=display_pid)
|
||||
logging.info(f"Actor logging initialized, writing to {log_file}")
|
||||
|
||||
is_threaded = use_threads(cfg)
|
||||
shutdown_event = ProcessSignalHandler(is_threaded, display_pid=display_pid).shutdown_event
|
||||
|
||||
learner_client, grpc_channel = learner_service_client(
|
||||
host=cfg.policy.actor_learner_config.learner_host,
|
||||
port=cfg.policy.actor_learner_config.learner_port,
|
||||
)
|
||||
|
||||
logging.info("[ACTOR] Establishing connection with Learner")
|
||||
if not establish_learner_connection(learner_client, shutdown_event):
|
||||
logging.error("[ACTOR] Failed to establish connection with Learner")
|
||||
return
|
||||
|
||||
if not use_threads(cfg):
|
||||
# If we use multithreading, we can reuse the channel
|
||||
grpc_channel.close()
|
||||
grpc_channel = None
|
||||
|
||||
logging.info("[ACTOR] Connection with Learner established")
|
||||
|
||||
parameters_queue = Queue()
|
||||
transitions_queue = Queue()
|
||||
interactions_queue = Queue()
|
||||
|
||||
concurrency_entity = None
|
||||
if use_threads(cfg):
|
||||
from threading import Thread
|
||||
|
||||
concurrency_entity = Thread
|
||||
else:
|
||||
from multiprocessing import Process
|
||||
|
||||
concurrency_entity = Process
|
||||
|
||||
receive_policy_process = concurrency_entity(
|
||||
target=receive_policy,
|
||||
args=(cfg, parameters_queue, shutdown_event, grpc_channel),
|
||||
daemon=True,
|
||||
)
|
||||
|
||||
transitions_process = concurrency_entity(
|
||||
target=send_transitions,
|
||||
args=(cfg, transitions_queue, shutdown_event, grpc_channel),
|
||||
daemon=True,
|
||||
)
|
||||
|
||||
interactions_process = concurrency_entity(
|
||||
target=send_interactions,
|
||||
args=(cfg, interactions_queue, shutdown_event, grpc_channel),
|
||||
daemon=True,
|
||||
)
|
||||
|
||||
transitions_process.start()
|
||||
interactions_process.start()
|
||||
receive_policy_process.start()
|
||||
|
||||
act_with_policy(
|
||||
cfg=cfg,
|
||||
shutdown_event=shutdown_event,
|
||||
parameters_queue=parameters_queue,
|
||||
transitions_queue=transitions_queue,
|
||||
interactions_queue=interactions_queue,
|
||||
)
|
||||
logging.info("[ACTOR] Policy process joined")
|
||||
|
||||
logging.info("[ACTOR] Closing queues")
|
||||
transitions_queue.close()
|
||||
interactions_queue.close()
|
||||
parameters_queue.close()
|
||||
|
||||
transitions_process.join()
|
||||
logging.info("[ACTOR] Transitions process joined")
|
||||
interactions_process.join()
|
||||
logging.info("[ACTOR] Interactions process joined")
|
||||
receive_policy_process.join()
|
||||
logging.info("[ACTOR] Receive policy process joined")
|
||||
|
||||
logging.info("[ACTOR] join queues")
|
||||
transitions_queue.cancel_join_thread()
|
||||
interactions_queue.cancel_join_thread()
|
||||
parameters_queue.cancel_join_thread()
|
||||
|
||||
logging.info("[ACTOR] queues closed")
|
||||
|
||||
|
||||
# Core algorithm functions
|
||||
|
||||
|
||||
def act_with_policy(
|
||||
cfg: TrainRLServerPipelineConfig,
|
||||
shutdown_event: any, # Event,
|
||||
parameters_queue: Queue,
|
||||
transitions_queue: Queue,
|
||||
interactions_queue: Queue,
|
||||
):
|
||||
"""
|
||||
Executes policy interaction within the environment.
|
||||
|
||||
This function rolls out the policy in the environment, collecting interaction data and pushing it to a queue for streaming to the learner.
|
||||
Once an episode is completed, updated network parameters received from the learner are retrieved from a queue and loaded into the network.
|
||||
|
||||
Args:
|
||||
cfg: Configuration settings for the interaction process.
|
||||
shutdown_event: Event to check if the process should shutdown.
|
||||
parameters_queue: Queue to receive updated network parameters from the learner.
|
||||
transitions_queue: Queue to send transitions to the learner.
|
||||
interactions_queue: Queue to send interactions to the learner.
|
||||
"""
|
||||
# Initialize logging for multiprocessing
|
||||
if not use_threads(cfg):
|
||||
log_dir = os.path.join(cfg.output_dir, "logs")
|
||||
os.makedirs(log_dir, exist_ok=True)
|
||||
log_file = os.path.join(log_dir, f"actor_policy_{os.getpid()}.log")
|
||||
init_logging(log_file=log_file, display_pid=True)
|
||||
logging.info("Actor policy process logging initialized")
|
||||
|
||||
logging.info("make_env online")
|
||||
|
||||
online_env, teleop_device = make_robot_env(cfg=cfg.env)
|
||||
env_processor, action_processor = make_processors(online_env, teleop_device, cfg.env, cfg.policy.device)
|
||||
|
||||
set_seed(cfg.seed)
|
||||
device = get_safe_torch_device(cfg.policy.device, log=True)
|
||||
|
||||
torch.backends.cudnn.benchmark = True
|
||||
torch.backends.cuda.matmul.allow_tf32 = True
|
||||
|
||||
logging.info("make_policy")
|
||||
|
||||
### Instantiate the policy in both the actor and learner processes
|
||||
### To avoid sending a SACPolicy object through the port, we create a policy instance
|
||||
### on both sides, the learner sends the updated parameters every n steps to update the actor's parameters
|
||||
policy: SACPolicy = make_policy(
|
||||
cfg=cfg.policy,
|
||||
env_cfg=cfg.env,
|
||||
)
|
||||
policy = policy.eval()
|
||||
assert isinstance(policy, nn.Module)
|
||||
|
||||
obs, info = online_env.reset()
|
||||
env_processor.reset()
|
||||
action_processor.reset()
|
||||
|
||||
# Process initial observation
|
||||
transition = create_transition(observation=obs, info=info)
|
||||
transition = env_processor(transition)
|
||||
|
||||
# NOTE: For the moment we will solely handle the case of a single environment
|
||||
sum_reward_episode = 0
|
||||
list_transition_to_send_to_learner = []
|
||||
episode_intervention = False
|
||||
# Add counters for intervention rate calculation
|
||||
episode_intervention_steps = 0
|
||||
episode_total_steps = 0
|
||||
|
||||
policy_timer = TimerManager("Policy inference", log=False)
|
||||
|
||||
for interaction_step in range(cfg.policy.online_steps):
|
||||
start_time = time.perf_counter()
|
||||
if shutdown_event.is_set():
|
||||
logging.info("[ACTOR] Shutting down act_with_policy")
|
||||
return
|
||||
|
||||
observation = {
|
||||
k: v for k, v in transition[TransitionKey.OBSERVATION].items() if k in cfg.policy.input_features
|
||||
}
|
||||
|
||||
# Time policy inference and check if it meets FPS requirement
|
||||
with policy_timer:
|
||||
# Extract observation from transition for policy
|
||||
action = policy.select_action(batch=observation)
|
||||
policy_fps = policy_timer.fps_last
|
||||
|
||||
log_policy_frequency_issue(policy_fps=policy_fps, cfg=cfg, interaction_step=interaction_step)
|
||||
|
||||
# Use the new step function
|
||||
new_transition = step_env_and_process_transition(
|
||||
env=online_env,
|
||||
transition=transition,
|
||||
action=action,
|
||||
env_processor=env_processor,
|
||||
action_processor=action_processor,
|
||||
)
|
||||
|
||||
# Extract values from processed transition
|
||||
next_observation = {
|
||||
k: v
|
||||
for k, v in new_transition[TransitionKey.OBSERVATION].items()
|
||||
if k in cfg.policy.input_features
|
||||
}
|
||||
|
||||
# Teleop action is the action that was executed in the environment
|
||||
# It is either the action from the teleop device or the action from the policy
|
||||
executed_action = new_transition[TransitionKey.COMPLEMENTARY_DATA]["teleop_action"]
|
||||
|
||||
reward = new_transition[TransitionKey.REWARD]
|
||||
done = new_transition.get(TransitionKey.DONE, False)
|
||||
truncated = new_transition.get(TransitionKey.TRUNCATED, False)
|
||||
|
||||
sum_reward_episode += float(reward)
|
||||
episode_total_steps += 1
|
||||
|
||||
# Check for intervention from transition info
|
||||
intervention_info = new_transition[TransitionKey.INFO]
|
||||
if intervention_info.get(TeleopEvents.IS_INTERVENTION, False):
|
||||
episode_intervention = True
|
||||
episode_intervention_steps += 1
|
||||
|
||||
complementary_info = {
|
||||
"discrete_penalty": torch.tensor(
|
||||
[new_transition[TransitionKey.COMPLEMENTARY_DATA].get("discrete_penalty", 0.0)]
|
||||
),
|
||||
}
|
||||
# Create transition for learner (convert to old format)
|
||||
list_transition_to_send_to_learner.append(
|
||||
Transition(
|
||||
state=observation,
|
||||
action=executed_action,
|
||||
reward=reward,
|
||||
next_state=next_observation,
|
||||
done=done,
|
||||
truncated=truncated,
|
||||
complementary_info=complementary_info,
|
||||
)
|
||||
)
|
||||
|
||||
# Update transition for next iteration
|
||||
transition = new_transition
|
||||
|
||||
if done or truncated:
|
||||
logging.info(f"[ACTOR] Global step {interaction_step}: Episode reward: {sum_reward_episode}")
|
||||
|
||||
update_policy_parameters(policy=policy, parameters_queue=parameters_queue, device=device)
|
||||
|
||||
if len(list_transition_to_send_to_learner) > 0:
|
||||
push_transitions_to_transport_queue(
|
||||
transitions=list_transition_to_send_to_learner,
|
||||
transitions_queue=transitions_queue,
|
||||
)
|
||||
list_transition_to_send_to_learner = []
|
||||
|
||||
stats = get_frequency_stats(policy_timer)
|
||||
policy_timer.reset()
|
||||
|
||||
# Calculate intervention rate
|
||||
intervention_rate = 0.0
|
||||
if episode_total_steps > 0:
|
||||
intervention_rate = episode_intervention_steps / episode_total_steps
|
||||
|
||||
# Send episodic reward to the learner
|
||||
interactions_queue.put(
|
||||
python_object_to_bytes(
|
||||
{
|
||||
"Episodic reward": sum_reward_episode,
|
||||
"Interaction step": interaction_step,
|
||||
"Episode intervention": int(episode_intervention),
|
||||
"Intervention rate": intervention_rate,
|
||||
**stats,
|
||||
}
|
||||
)
|
||||
)
|
||||
|
||||
# Reset intervention counters and environment
|
||||
sum_reward_episode = 0.0
|
||||
episode_intervention = False
|
||||
episode_intervention_steps = 0
|
||||
episode_total_steps = 0
|
||||
|
||||
# Reset environment and processors
|
||||
obs, info = online_env.reset()
|
||||
env_processor.reset()
|
||||
action_processor.reset()
|
||||
|
||||
# Process initial observation
|
||||
transition = create_transition(observation=obs, info=info)
|
||||
transition = env_processor(transition)
|
||||
|
||||
if cfg.env.fps is not None:
|
||||
dt_time = time.perf_counter() - start_time
|
||||
busy_wait(1 / cfg.env.fps - dt_time)
|
||||
|
||||
|
||||
# Communication Functions - Group all gRPC/messaging functions
|
||||
|
||||
|
||||
def establish_learner_connection(
|
||||
stub: services_pb2_grpc.LearnerServiceStub,
|
||||
shutdown_event: Event, # type: ignore
|
||||
attempts: int = 30,
|
||||
):
|
||||
"""Establish a connection with the learner.
|
||||
|
||||
Args:
|
||||
stub (services_pb2_grpc.LearnerServiceStub): The stub to use for the connection.
|
||||
shutdown_event (Event): The event to check if the connection should be established.
|
||||
attempts (int): The number of attempts to establish the connection.
|
||||
Returns:
|
||||
bool: True if the connection is established, False otherwise.
|
||||
"""
|
||||
for _ in range(attempts):
|
||||
if shutdown_event.is_set():
|
||||
logging.info("[ACTOR] Shutting down establish_learner_connection")
|
||||
return False
|
||||
|
||||
# Force a connection attempt and check state
|
||||
try:
|
||||
logging.info("[ACTOR] Send ready message to Learner")
|
||||
if stub.Ready(services_pb2.Empty()) == services_pb2.Empty():
|
||||
return True
|
||||
except grpc.RpcError as e:
|
||||
logging.error(f"[ACTOR] Waiting for Learner to be ready... {e}")
|
||||
time.sleep(2)
|
||||
return False
|
||||
|
||||
|
||||
@lru_cache(maxsize=1)
|
||||
def learner_service_client(
|
||||
host: str = "127.0.0.1",
|
||||
port: int = 50051,
|
||||
) -> tuple[services_pb2_grpc.LearnerServiceStub, grpc.Channel]:
|
||||
"""
|
||||
Returns a client for the learner service.
|
||||
|
||||
GRPC uses HTTP/2, which is a binary protocol and multiplexes requests over a single connection.
|
||||
So we need to create only one client and reuse it.
|
||||
"""
|
||||
|
||||
channel = grpc.insecure_channel(
|
||||
f"{host}:{port}",
|
||||
grpc_channel_options(),
|
||||
)
|
||||
stub = services_pb2_grpc.LearnerServiceStub(channel)
|
||||
logging.info("[ACTOR] Learner service client created")
|
||||
return stub, channel
|
||||
|
||||
|
||||
def receive_policy(
|
||||
cfg: TrainRLServerPipelineConfig,
|
||||
parameters_queue: Queue,
|
||||
shutdown_event: Event, # type: ignore
|
||||
learner_client: services_pb2_grpc.LearnerServiceStub | None = None,
|
||||
grpc_channel: grpc.Channel | None = None,
|
||||
):
|
||||
"""Receive parameters from the learner.
|
||||
|
||||
Args:
|
||||
cfg (TrainRLServerPipelineConfig): The configuration for the actor.
|
||||
parameters_queue (Queue): The queue to receive the parameters.
|
||||
shutdown_event (Event): The event to check if the process should shutdown.
|
||||
"""
|
||||
logging.info("[ACTOR] Start receiving parameters from the Learner")
|
||||
if not use_threads(cfg):
|
||||
# Create a process-specific log file
|
||||
log_dir = os.path.join(cfg.output_dir, "logs")
|
||||
os.makedirs(log_dir, exist_ok=True)
|
||||
log_file = os.path.join(log_dir, f"actor_receive_policy_{os.getpid()}.log")
|
||||
|
||||
# Initialize logging with explicit log file
|
||||
init_logging(log_file=log_file, display_pid=True)
|
||||
logging.info("Actor receive policy process logging initialized")
|
||||
|
||||
# Setup process handlers to handle shutdown signal
|
||||
# But use shutdown event from the main process
|
||||
_ = ProcessSignalHandler(use_threads=False, display_pid=True)
|
||||
|
||||
if grpc_channel is None or learner_client is None:
|
||||
learner_client, grpc_channel = learner_service_client(
|
||||
host=cfg.policy.actor_learner_config.learner_host,
|
||||
port=cfg.policy.actor_learner_config.learner_port,
|
||||
)
|
||||
|
||||
try:
|
||||
iterator = learner_client.StreamParameters(services_pb2.Empty())
|
||||
receive_bytes_in_chunks(
|
||||
iterator,
|
||||
parameters_queue,
|
||||
shutdown_event,
|
||||
log_prefix="[ACTOR] parameters",
|
||||
)
|
||||
|
||||
except grpc.RpcError as e:
|
||||
logging.error(f"[ACTOR] gRPC error: {e}")
|
||||
|
||||
if not use_threads(cfg):
|
||||
grpc_channel.close()
|
||||
logging.info("[ACTOR] Received policy loop stopped")
|
||||
|
||||
|
||||
def send_transitions(
|
||||
cfg: TrainRLServerPipelineConfig,
|
||||
transitions_queue: Queue,
|
||||
shutdown_event: any, # Event,
|
||||
learner_client: services_pb2_grpc.LearnerServiceStub | None = None,
|
||||
grpc_channel: grpc.Channel | None = None,
|
||||
) -> services_pb2.Empty:
|
||||
"""
|
||||
Sends transitions to the learner.
|
||||
|
||||
This function continuously retrieves messages from the queue and processes:
|
||||
|
||||
- Transition Data:
|
||||
- A batch of transitions (observation, action, reward, next observation) is collected.
|
||||
- Transitions are moved to the CPU and serialized using PyTorch.
|
||||
- The serialized data is wrapped in a `services_pb2.Transition` message and sent to the learner.
|
||||
"""
|
||||
|
||||
if not use_threads(cfg):
|
||||
# Create a process-specific log file
|
||||
log_dir = os.path.join(cfg.output_dir, "logs")
|
||||
os.makedirs(log_dir, exist_ok=True)
|
||||
log_file = os.path.join(log_dir, f"actor_transitions_{os.getpid()}.log")
|
||||
|
||||
# Initialize logging with explicit log file
|
||||
init_logging(log_file=log_file, display_pid=True)
|
||||
logging.info("Actor transitions process logging initialized")
|
||||
|
||||
if grpc_channel is None or learner_client is None:
|
||||
learner_client, grpc_channel = learner_service_client(
|
||||
host=cfg.policy.actor_learner_config.learner_host,
|
||||
port=cfg.policy.actor_learner_config.learner_port,
|
||||
)
|
||||
|
||||
try:
|
||||
learner_client.SendTransitions(
|
||||
transitions_stream(
|
||||
shutdown_event, transitions_queue, cfg.policy.actor_learner_config.queue_get_timeout
|
||||
)
|
||||
)
|
||||
except grpc.RpcError as e:
|
||||
logging.error(f"[ACTOR] gRPC error: {e}")
|
||||
|
||||
logging.info("[ACTOR] Finished streaming transitions")
|
||||
|
||||
if not use_threads(cfg):
|
||||
grpc_channel.close()
|
||||
logging.info("[ACTOR] Transitions process stopped")
|
||||
|
||||
|
||||
def send_interactions(
|
||||
cfg: TrainRLServerPipelineConfig,
|
||||
interactions_queue: Queue,
|
||||
shutdown_event: Event, # type: ignore
|
||||
learner_client: services_pb2_grpc.LearnerServiceStub | None = None,
|
||||
grpc_channel: grpc.Channel | None = None,
|
||||
) -> services_pb2.Empty:
|
||||
"""
|
||||
Sends interactions to the learner.
|
||||
|
||||
This function continuously retrieves messages from the queue and processes:
|
||||
|
||||
- Interaction Messages:
|
||||
- Contains useful statistics about episodic rewards and policy timings.
|
||||
- The message is serialized using `pickle` and sent to the learner.
|
||||
"""
|
||||
|
||||
if not use_threads(cfg):
|
||||
# Create a process-specific log file
|
||||
log_dir = os.path.join(cfg.output_dir, "logs")
|
||||
os.makedirs(log_dir, exist_ok=True)
|
||||
log_file = os.path.join(log_dir, f"actor_interactions_{os.getpid()}.log")
|
||||
|
||||
# Initialize logging with explicit log file
|
||||
init_logging(log_file=log_file, display_pid=True)
|
||||
logging.info("Actor interactions process logging initialized")
|
||||
|
||||
# Setup process handlers to handle shutdown signal
|
||||
# But use shutdown event from the main process
|
||||
_ = ProcessSignalHandler(use_threads=False, display_pid=True)
|
||||
|
||||
if grpc_channel is None or learner_client is None:
|
||||
learner_client, grpc_channel = learner_service_client(
|
||||
host=cfg.policy.actor_learner_config.learner_host,
|
||||
port=cfg.policy.actor_learner_config.learner_port,
|
||||
)
|
||||
|
||||
try:
|
||||
learner_client.SendInteractions(
|
||||
interactions_stream(
|
||||
shutdown_event, interactions_queue, cfg.policy.actor_learner_config.queue_get_timeout
|
||||
)
|
||||
)
|
||||
except grpc.RpcError as e:
|
||||
logging.error(f"[ACTOR] gRPC error: {e}")
|
||||
|
||||
logging.info("[ACTOR] Finished streaming interactions")
|
||||
|
||||
if not use_threads(cfg):
|
||||
grpc_channel.close()
|
||||
logging.info("[ACTOR] Interactions process stopped")
|
||||
|
||||
|
||||
def transitions_stream(shutdown_event: Event, transitions_queue: Queue, timeout: float) -> services_pb2.Empty: # type: ignore
|
||||
while not shutdown_event.is_set():
|
||||
try:
|
||||
message = transitions_queue.get(block=True, timeout=timeout)
|
||||
except Empty:
|
||||
logging.debug("[ACTOR] Transition queue is empty")
|
||||
continue
|
||||
|
||||
yield from send_bytes_in_chunks(
|
||||
message, services_pb2.Transition, log_prefix="[ACTOR] Send transitions"
|
||||
)
|
||||
|
||||
return services_pb2.Empty()
|
||||
|
||||
|
||||
def interactions_stream(
|
||||
shutdown_event: Event,
|
||||
interactions_queue: Queue,
|
||||
timeout: float, # type: ignore
|
||||
) -> services_pb2.Empty:
|
||||
while not shutdown_event.is_set():
|
||||
try:
|
||||
message = interactions_queue.get(block=True, timeout=timeout)
|
||||
except Empty:
|
||||
logging.debug("[ACTOR] Interaction queue is empty")
|
||||
continue
|
||||
|
||||
yield from send_bytes_in_chunks(
|
||||
message,
|
||||
services_pb2.InteractionMessage,
|
||||
log_prefix="[ACTOR] Send interactions",
|
||||
)
|
||||
|
||||
return services_pb2.Empty()
|
||||
|
||||
|
||||
# Policy functions
|
||||
|
||||
|
||||
def update_policy_parameters(policy: SACPolicy, parameters_queue: Queue, device):
|
||||
bytes_state_dict = get_last_item_from_queue(parameters_queue, block=False)
|
||||
if bytes_state_dict is not None:
|
||||
logging.info("[ACTOR] Load new parameters from Learner.")
|
||||
state_dicts = bytes_to_state_dict(bytes_state_dict)
|
||||
|
||||
# TODO: check encoder parameter synchronization possible issues:
|
||||
# 1. When shared_encoder=True, we're loading stale encoder params from actor's state_dict
|
||||
# instead of the updated encoder params from critic (which is optimized separately)
|
||||
# 2. When freeze_vision_encoder=True, we waste bandwidth sending/loading frozen params
|
||||
# 3. Need to handle encoder params correctly for both actor and discrete_critic
|
||||
# Potential fixes:
|
||||
# - Send critic's encoder state when shared_encoder=True
|
||||
# - Skip encoder params entirely when freeze_vision_encoder=True
|
||||
# - Ensure discrete_critic gets correct encoder state (currently uses encoder_critic)
|
||||
|
||||
# Load actor state dict
|
||||
actor_state_dict = move_state_dict_to_device(state_dicts["policy"], device=device)
|
||||
policy.actor.load_state_dict(actor_state_dict)
|
||||
|
||||
# Load discrete critic if present
|
||||
if hasattr(policy, "discrete_critic") and "discrete_critic" in state_dicts:
|
||||
discrete_critic_state_dict = move_state_dict_to_device(
|
||||
state_dicts["discrete_critic"], device=device
|
||||
)
|
||||
policy.discrete_critic.load_state_dict(discrete_critic_state_dict)
|
||||
logging.info("[ACTOR] Loaded discrete critic parameters from Learner.")
|
||||
|
||||
|
||||
# Utilities functions
|
||||
|
||||
|
||||
def push_transitions_to_transport_queue(transitions: list, transitions_queue):
|
||||
"""Send transitions to learner in smaller chunks to avoid network issues.
|
||||
|
||||
Args:
|
||||
transitions: List of transitions to send
|
||||
message_queue: Queue to send messages to learner
|
||||
chunk_size: Size of each chunk to send
|
||||
"""
|
||||
transition_to_send_to_learner = []
|
||||
for transition in transitions:
|
||||
tr = move_transition_to_device(transition=transition, device="cpu")
|
||||
for key, value in tr["state"].items():
|
||||
if torch.isnan(value).any():
|
||||
logging.warning(f"Found NaN values in transition {key}")
|
||||
|
||||
transition_to_send_to_learner.append(tr)
|
||||
|
||||
transitions_queue.put(transitions_to_bytes(transition_to_send_to_learner))
|
||||
|
||||
|
||||
def get_frequency_stats(timer: TimerManager) -> dict[str, float]:
|
||||
"""Get the frequency statistics of the policy.
|
||||
|
||||
Args:
|
||||
timer (TimerManager): The timer with collected metrics.
|
||||
|
||||
Returns:
|
||||
dict[str, float]: The frequency statistics of the policy.
|
||||
"""
|
||||
stats = {}
|
||||
if timer.count > 1:
|
||||
avg_fps = timer.fps_avg
|
||||
p90_fps = timer.fps_percentile(90)
|
||||
logging.debug(f"[ACTOR] Average policy frame rate: {avg_fps}")
|
||||
logging.debug(f"[ACTOR] Policy frame rate 90th percentile: {p90_fps}")
|
||||
stats = {
|
||||
"Policy frequency [Hz]": avg_fps,
|
||||
"Policy frequency 90th-p [Hz]": p90_fps,
|
||||
}
|
||||
return stats
|
||||
|
||||
|
||||
def log_policy_frequency_issue(policy_fps: float, cfg: TrainRLServerPipelineConfig, interaction_step: int):
|
||||
if policy_fps < cfg.env.fps:
|
||||
logging.warning(
|
||||
f"[ACTOR] Policy FPS {policy_fps:.1f} below required {cfg.env.fps} at step {interaction_step}"
|
||||
)
|
||||
|
||||
|
||||
def use_threads(cfg: TrainRLServerPipelineConfig) -> bool:
|
||||
return cfg.policy.concurrency.actor == "threads"
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
actor_cli()
|
||||
@@ -0,0 +1,313 @@
|
||||
#!/usr/bin/env python
|
||||
|
||||
# Copyright 2025 The HuggingFace Inc. team. All rights reserved.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
import argparse
|
||||
import json
|
||||
from copy import deepcopy
|
||||
from pathlib import Path
|
||||
|
||||
import cv2
|
||||
import torch
|
||||
import torchvision.transforms.functional as F # type: ignore # noqa: N812
|
||||
from tqdm import tqdm # type: ignore
|
||||
|
||||
from lerobot.datasets.lerobot_dataset import LeRobotDataset
|
||||
|
||||
|
||||
def select_rect_roi(img):
|
||||
"""
|
||||
Allows the user to draw a rectangular ROI on the image.
|
||||
|
||||
The user must click and drag to draw the rectangle.
|
||||
- While dragging, the rectangle is dynamically drawn.
|
||||
- On mouse button release, the rectangle is fixed.
|
||||
- Press 'c' to confirm the selection.
|
||||
- Press 'r' to reset the selection.
|
||||
- Press ESC to cancel.
|
||||
|
||||
Returns:
|
||||
A tuple (top, left, height, width) representing the rectangular ROI,
|
||||
or None if no valid ROI is selected.
|
||||
"""
|
||||
# Create a working copy of the image
|
||||
clone = img.copy()
|
||||
working_img = clone.copy()
|
||||
|
||||
roi = None # Will store the final ROI as (top, left, height, width)
|
||||
drawing = False
|
||||
index_x, index_y = -1, -1 # Initial click coordinates
|
||||
|
||||
def mouse_callback(event, x, y, flags, param):
|
||||
nonlocal index_x, index_y, drawing, roi, working_img
|
||||
|
||||
if event == cv2.EVENT_LBUTTONDOWN:
|
||||
# Start drawing: record starting coordinates
|
||||
drawing = True
|
||||
index_x, index_y = x, y
|
||||
|
||||
elif event == cv2.EVENT_MOUSEMOVE:
|
||||
if drawing:
|
||||
# Compute the top-left and bottom-right corners regardless of drag direction
|
||||
top = min(index_y, y)
|
||||
left = min(index_x, x)
|
||||
bottom = max(index_y, y)
|
||||
right = max(index_x, x)
|
||||
# Show a temporary image with the current rectangle drawn
|
||||
temp = working_img.copy()
|
||||
cv2.rectangle(temp, (left, top), (right, bottom), (0, 255, 0), 2)
|
||||
cv2.imshow("Select ROI", temp)
|
||||
|
||||
elif event == cv2.EVENT_LBUTTONUP:
|
||||
# Finish drawing
|
||||
drawing = False
|
||||
top = min(index_y, y)
|
||||
left = min(index_x, x)
|
||||
bottom = max(index_y, y)
|
||||
right = max(index_x, x)
|
||||
height = bottom - top
|
||||
width = right - left
|
||||
roi = (top, left, height, width) # (top, left, height, width)
|
||||
# Draw the final rectangle on the working image and display it
|
||||
working_img = clone.copy()
|
||||
cv2.rectangle(working_img, (left, top), (right, bottom), (0, 255, 0), 2)
|
||||
cv2.imshow("Select ROI", working_img)
|
||||
|
||||
# Create the window and set the callback
|
||||
cv2.namedWindow("Select ROI")
|
||||
cv2.setMouseCallback("Select ROI", mouse_callback)
|
||||
cv2.imshow("Select ROI", working_img)
|
||||
|
||||
print("Instructions for ROI selection:")
|
||||
print(" - Click and drag to draw a rectangular ROI.")
|
||||
print(" - Press 'c' to confirm the selection.")
|
||||
print(" - Press 'r' to reset and draw again.")
|
||||
print(" - Press ESC to cancel the selection.")
|
||||
|
||||
# Wait until the user confirms with 'c', resets with 'r', or cancels with ESC
|
||||
while True:
|
||||
key = cv2.waitKey(1) & 0xFF
|
||||
# Confirm ROI if one has been drawn
|
||||
if key == ord("c") and roi is not None:
|
||||
break
|
||||
# Reset: clear the ROI and restore the original image
|
||||
elif key == ord("r"):
|
||||
working_img = clone.copy()
|
||||
roi = None
|
||||
cv2.imshow("Select ROI", working_img)
|
||||
# Cancel selection for this image
|
||||
elif key == 27: # ESC key
|
||||
roi = None
|
||||
break
|
||||
|
||||
cv2.destroyWindow("Select ROI")
|
||||
return roi
|
||||
|
||||
|
||||
def select_square_roi_for_images(images: dict) -> dict:
|
||||
"""
|
||||
For each image in the provided dictionary, open a window to allow the user
|
||||
to select a rectangular ROI. Returns a dictionary mapping each key to a tuple
|
||||
(top, left, height, width) representing the ROI.
|
||||
|
||||
Parameters:
|
||||
images (dict): Dictionary where keys are identifiers and values are OpenCV images.
|
||||
|
||||
Returns:
|
||||
dict: Mapping of image keys to the selected rectangular ROI.
|
||||
"""
|
||||
selected_rois = {}
|
||||
|
||||
for key, img in images.items():
|
||||
if img is None:
|
||||
print(f"Image for key '{key}' is None, skipping.")
|
||||
continue
|
||||
|
||||
print(f"\nSelect rectangular ROI for image with key: '{key}'")
|
||||
roi = select_rect_roi(img)
|
||||
|
||||
if roi is None:
|
||||
print(f"No valid ROI selected for '{key}'.")
|
||||
else:
|
||||
selected_rois[key] = roi
|
||||
print(f"ROI for '{key}': {roi}")
|
||||
|
||||
return selected_rois
|
||||
|
||||
|
||||
def get_image_from_lerobot_dataset(dataset: LeRobotDataset):
|
||||
"""
|
||||
Find the first row in the dataset and extract the image in order to be used for the crop.
|
||||
"""
|
||||
row = dataset[0]
|
||||
image_dict = {}
|
||||
for k in row:
|
||||
if "image" in k:
|
||||
image_dict[k] = deepcopy(row[k])
|
||||
return image_dict
|
||||
|
||||
|
||||
def convert_lerobot_dataset_to_cropper_lerobot_dataset(
|
||||
original_dataset: LeRobotDataset,
|
||||
crop_params_dict: dict[str, tuple[int, int, int, int]],
|
||||
new_repo_id: str,
|
||||
new_dataset_root: str,
|
||||
resize_size: tuple[int, int] = (128, 128),
|
||||
push_to_hub: bool = False,
|
||||
task: str = "",
|
||||
) -> LeRobotDataset:
|
||||
"""
|
||||
Converts an existing LeRobotDataset by iterating over its episodes and frames,
|
||||
applying cropping and resizing to image observations, and saving a new dataset
|
||||
with the transformed data.
|
||||
|
||||
Args:
|
||||
original_dataset (LeRobotDataset): The source dataset.
|
||||
crop_params_dict (Dict[str, Tuple[int, int, int, int]]):
|
||||
A dictionary mapping observation keys to crop parameters (top, left, height, width).
|
||||
new_repo_id (str): Repository id for the new dataset.
|
||||
new_dataset_root (str): The root directory where the new dataset will be written.
|
||||
resize_size (Tuple[int, int], optional): The target size (height, width) after cropping.
|
||||
Defaults to (128, 128).
|
||||
|
||||
Returns:
|
||||
LeRobotDataset: A new LeRobotDataset where the specified image observations have been cropped
|
||||
and resized.
|
||||
"""
|
||||
# 1. Create a new (empty) LeRobotDataset for writing.
|
||||
new_dataset = LeRobotDataset.create(
|
||||
repo_id=new_repo_id,
|
||||
fps=original_dataset.fps,
|
||||
root=new_dataset_root,
|
||||
robot_type=original_dataset.meta.robot_type,
|
||||
features=original_dataset.meta.info["features"],
|
||||
use_videos=len(original_dataset.meta.video_keys) > 0,
|
||||
)
|
||||
|
||||
# Update the metadata for every image key that will be cropped:
|
||||
# (Here we simply set the shape to be the final resize_size.)
|
||||
for key in crop_params_dict:
|
||||
if key in new_dataset.meta.info["features"]:
|
||||
new_dataset.meta.info["features"][key]["shape"] = [3] + list(resize_size)
|
||||
|
||||
# TODO: Directly modify the mp4 video + meta info features, instead of recreating a dataset
|
||||
prev_episode_index = 0
|
||||
for frame_idx in tqdm(range(len(original_dataset))):
|
||||
frame = original_dataset[frame_idx]
|
||||
|
||||
# Create a copy of the frame to add to the new dataset
|
||||
new_frame = {}
|
||||
for key, value in frame.items():
|
||||
if key in ("task_index", "timestamp", "episode_index", "frame_index", "index", "task"):
|
||||
continue
|
||||
if key in ("next.done", "next.reward"):
|
||||
# if not isinstance(value, str) and len(value.shape) == 0:
|
||||
value = value.unsqueeze(0)
|
||||
|
||||
if key in crop_params_dict:
|
||||
top, left, height, width = crop_params_dict[key]
|
||||
# Apply crop then resize.
|
||||
cropped = F.crop(value, top, left, height, width)
|
||||
value = F.resize(cropped, resize_size)
|
||||
value = value.clamp(0, 1)
|
||||
if key.startswith("complementary_info") and isinstance(value, torch.Tensor) and value.dim() == 0:
|
||||
value = value.unsqueeze(0)
|
||||
new_frame[key] = value
|
||||
|
||||
new_frame["task"] = task
|
||||
new_dataset.add_frame(new_frame)
|
||||
|
||||
if frame["episode_index"].item() != prev_episode_index:
|
||||
# Save the episode
|
||||
new_dataset.save_episode()
|
||||
prev_episode_index = frame["episode_index"].item()
|
||||
|
||||
# Save the last episode
|
||||
new_dataset.save_episode()
|
||||
|
||||
if push_to_hub:
|
||||
new_dataset.push_to_hub()
|
||||
|
||||
return new_dataset
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
parser = argparse.ArgumentParser(description="Crop rectangular ROIs from a LeRobot dataset.")
|
||||
parser.add_argument(
|
||||
"--repo-id",
|
||||
type=str,
|
||||
default="lerobot",
|
||||
help="The repository id of the LeRobot dataset to process.",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--root",
|
||||
type=str,
|
||||
default=None,
|
||||
help="The root directory of the LeRobot dataset.",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--crop-params-path",
|
||||
type=str,
|
||||
default=None,
|
||||
help="The path to the JSON file containing the ROIs.",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--push-to-hub",
|
||||
action="store_true",
|
||||
help="Whether to push the new dataset to the hub.",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--task",
|
||||
type=str,
|
||||
default="",
|
||||
help="The natural language task to describe the dataset.",
|
||||
)
|
||||
args = parser.parse_args()
|
||||
|
||||
dataset = LeRobotDataset(repo_id=args.repo_id, root=args.root)
|
||||
|
||||
images = get_image_from_lerobot_dataset(dataset)
|
||||
images = {k: v.cpu().permute(1, 2, 0).numpy() for k, v in images.items()}
|
||||
images = {k: (v * 255).astype("uint8") for k, v in images.items()}
|
||||
|
||||
if args.crop_params_path is None:
|
||||
rois = select_square_roi_for_images(images)
|
||||
else:
|
||||
with open(args.crop_params_path) as f:
|
||||
rois = json.load(f)
|
||||
|
||||
# Print the selected rectangular ROIs
|
||||
print("\nSelected Rectangular Regions of Interest (top, left, height, width):")
|
||||
for key, roi in rois.items():
|
||||
print(f"{key}: {roi}")
|
||||
|
||||
new_repo_id = args.repo_id + "_cropped_resized"
|
||||
new_dataset_root = Path(str(dataset.root) + "_cropped_resized")
|
||||
|
||||
cropped_resized_dataset = convert_lerobot_dataset_to_cropper_lerobot_dataset(
|
||||
original_dataset=dataset,
|
||||
crop_params_dict=rois,
|
||||
new_repo_id=new_repo_id,
|
||||
new_dataset_root=new_dataset_root,
|
||||
resize_size=(128, 128),
|
||||
push_to_hub=args.push_to_hub,
|
||||
task=args.task,
|
||||
)
|
||||
|
||||
meta_dir = new_dataset_root / "meta"
|
||||
meta_dir.mkdir(exist_ok=True)
|
||||
|
||||
with open(meta_dir / "crop_params.json", "w") as f:
|
||||
json.dump(rois, f, indent=4)
|
||||
@@ -0,0 +1,75 @@
|
||||
# !/usr/bin/env python
|
||||
|
||||
# Copyright 2025 The HuggingFace Inc. team. All rights reserved.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
import logging
|
||||
|
||||
from lerobot.cameras import opencv # noqa: F401
|
||||
from lerobot.configs import parser
|
||||
from lerobot.configs.train import TrainRLServerPipelineConfig
|
||||
from lerobot.datasets.lerobot_dataset import LeRobotDataset
|
||||
from lerobot.policies.factory import make_policy
|
||||
from lerobot.robots import ( # noqa: F401
|
||||
RobotConfig,
|
||||
make_robot_from_config,
|
||||
so100_follower,
|
||||
)
|
||||
from lerobot.teleoperators import (
|
||||
gamepad, # noqa: F401
|
||||
so101_leader, # noqa: F401
|
||||
)
|
||||
|
||||
from .gym_manipulator import make_robot_env
|
||||
|
||||
logging.basicConfig(level=logging.INFO)
|
||||
|
||||
|
||||
def eval_policy(env, policy, n_episodes):
|
||||
sum_reward_episode = []
|
||||
for _ in range(n_episodes):
|
||||
obs, _ = env.reset()
|
||||
episode_reward = 0.0
|
||||
while True:
|
||||
action = policy.select_action(obs)
|
||||
obs, reward, terminated, truncated, _ = env.step(action)
|
||||
episode_reward += reward
|
||||
if terminated or truncated:
|
||||
break
|
||||
sum_reward_episode.append(episode_reward)
|
||||
|
||||
logging.info(f"Success after 20 steps {sum_reward_episode}")
|
||||
logging.info(f"success rate {sum(sum_reward_episode) / len(sum_reward_episode)}")
|
||||
|
||||
|
||||
@parser.wrap()
|
||||
def main(cfg: TrainRLServerPipelineConfig):
|
||||
env_cfg = cfg.env
|
||||
env = make_robot_env(env_cfg)
|
||||
dataset_cfg = cfg.dataset
|
||||
dataset = LeRobotDataset(repo_id=dataset_cfg.repo_id)
|
||||
dataset_meta = dataset.meta
|
||||
|
||||
policy = make_policy(
|
||||
cfg=cfg.policy,
|
||||
# env_cfg=cfg.env,
|
||||
ds_meta=dataset_meta,
|
||||
)
|
||||
policy.from_pretrained(env_cfg.pretrained_policy_name_or_path)
|
||||
policy.eval()
|
||||
|
||||
eval_policy(env, policy=policy, n_episodes=10)
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
main()
|
||||
@@ -0,0 +1,769 @@
|
||||
# !/usr/bin/env python
|
||||
|
||||
# Copyright 2025 The HuggingFace Inc. team. All rights reserved.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
import logging
|
||||
import time
|
||||
from dataclasses import dataclass
|
||||
from typing import Any
|
||||
|
||||
import gymnasium as gym
|
||||
import numpy as np
|
||||
import torch
|
||||
|
||||
from lerobot.cameras import opencv # noqa: F401
|
||||
from lerobot.configs import parser
|
||||
from lerobot.datasets.lerobot_dataset import LeRobotDataset
|
||||
from lerobot.envs.configs import HILSerlRobotEnvConfig
|
||||
from lerobot.model.kinematics import RobotKinematics
|
||||
from lerobot.processor import (
|
||||
AddBatchDimensionProcessorStep,
|
||||
AddTeleopActionAsComplimentaryDataStep,
|
||||
AddTeleopEventsAsInfoStep,
|
||||
DataProcessorPipeline,
|
||||
DeviceProcessorStep,
|
||||
EnvTransition,
|
||||
GripperPenaltyProcessorStep,
|
||||
ImageCropResizeProcessorStep,
|
||||
InterventionActionProcessorStep,
|
||||
JointVelocityProcessorStep,
|
||||
MapDeltaActionToRobotActionStep,
|
||||
MapTensorToDeltaActionDictStep,
|
||||
MotorCurrentProcessorStep,
|
||||
Numpy2TorchActionProcessorStep,
|
||||
RewardClassifierProcessorStep,
|
||||
RobotActionToPolicyActionProcessorStep,
|
||||
TimeLimitProcessorStep,
|
||||
Torch2NumpyActionProcessorStep,
|
||||
TransitionKey,
|
||||
VanillaObservationProcessorStep,
|
||||
create_transition,
|
||||
)
|
||||
from lerobot.processor.converters import identity_transition
|
||||
from lerobot.robots import ( # noqa: F401
|
||||
RobotConfig,
|
||||
make_robot_from_config,
|
||||
so100_follower,
|
||||
)
|
||||
from lerobot.robots.robot import Robot
|
||||
from lerobot.robots.so100_follower.robot_kinematic_processor import (
|
||||
EEBoundsAndSafety,
|
||||
EEReferenceAndDelta,
|
||||
ForwardKinematicsJointsToEEObservation,
|
||||
GripperVelocityToJoint,
|
||||
InverseKinematicsRLStep,
|
||||
)
|
||||
from lerobot.teleoperators import (
|
||||
gamepad, # noqa: F401
|
||||
keyboard, # noqa: F401
|
||||
make_teleoperator_from_config,
|
||||
so101_leader, # noqa: F401
|
||||
)
|
||||
from lerobot.teleoperators.teleoperator import Teleoperator
|
||||
from lerobot.teleoperators.utils import TeleopEvents
|
||||
from lerobot.utils.robot_utils import busy_wait
|
||||
from lerobot.utils.utils import log_say
|
||||
|
||||
logging.basicConfig(level=logging.INFO)
|
||||
|
||||
|
||||
@dataclass
|
||||
class DatasetConfig:
|
||||
"""Configuration for dataset creation and management."""
|
||||
|
||||
repo_id: str
|
||||
task: str
|
||||
root: str | None = None
|
||||
num_episodes_to_record: int = 5
|
||||
replay_episode: int | None = None
|
||||
push_to_hub: bool = False
|
||||
|
||||
|
||||
@dataclass
|
||||
class GymManipulatorConfig:
|
||||
"""Main configuration for gym manipulator environment."""
|
||||
|
||||
env: HILSerlRobotEnvConfig
|
||||
dataset: DatasetConfig
|
||||
mode: str | None = None # Either "record", "replay", None
|
||||
device: str = "cpu"
|
||||
|
||||
|
||||
def reset_follower_position(robot_arm: Robot, target_position: np.ndarray) -> None:
|
||||
"""Reset robot arm to target position using smooth trajectory."""
|
||||
current_position_dict = robot_arm.bus.sync_read("Present_Position")
|
||||
current_position = np.array(
|
||||
[current_position_dict[name] for name in current_position_dict], dtype=np.float32
|
||||
)
|
||||
trajectory = torch.from_numpy(
|
||||
np.linspace(current_position, target_position, 50)
|
||||
) # NOTE: 30 is just an arbitrary number
|
||||
for pose in trajectory:
|
||||
action_dict = dict(zip(current_position_dict, pose, strict=False))
|
||||
robot_arm.bus.sync_write("Goal_Position", action_dict)
|
||||
busy_wait(0.015)
|
||||
|
||||
|
||||
class RobotEnv(gym.Env):
|
||||
"""Gym environment for robotic control with human intervention support."""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
robot,
|
||||
use_gripper: bool = False,
|
||||
display_cameras: bool = False,
|
||||
reset_pose: list[float] | None = None,
|
||||
reset_time_s: float = 5.0,
|
||||
) -> None:
|
||||
"""Initialize robot environment with configuration options.
|
||||
|
||||
Args:
|
||||
robot: Robot interface for hardware communication.
|
||||
use_gripper: Whether to include gripper in action space.
|
||||
display_cameras: Whether to show camera feeds during execution.
|
||||
reset_pose: Joint positions for environment reset.
|
||||
reset_time_s: Time to wait during reset.
|
||||
"""
|
||||
super().__init__()
|
||||
|
||||
self.robot = robot
|
||||
self.display_cameras = display_cameras
|
||||
|
||||
# Connect to the robot if not already connected.
|
||||
if not self.robot.is_connected:
|
||||
self.robot.connect()
|
||||
|
||||
# Episode tracking.
|
||||
self.current_step = 0
|
||||
self.episode_data = None
|
||||
|
||||
self._joint_names = [f"{key}.pos" for key in self.robot.bus.motors]
|
||||
self._image_keys = self.robot.cameras.keys()
|
||||
|
||||
self.reset_pose = reset_pose
|
||||
self.reset_time_s = reset_time_s
|
||||
|
||||
self.use_gripper = use_gripper
|
||||
|
||||
self._joint_names = list(self.robot.bus.motors.keys())
|
||||
self._raw_joint_positions = None
|
||||
|
||||
self._setup_spaces()
|
||||
|
||||
def _get_observation(self) -> dict[str, Any]:
|
||||
"""Get current robot observation including joint positions and camera images."""
|
||||
obs_dict = self.robot.get_observation()
|
||||
raw_joint_joint_position = {f"{name}.pos": obs_dict[f"{name}.pos"] for name in self._joint_names}
|
||||
joint_positions = np.array([raw_joint_joint_position[f"{name}.pos"] for name in self._joint_names])
|
||||
|
||||
images = {key: obs_dict[key] for key in self._image_keys}
|
||||
|
||||
return {"agent_pos": joint_positions, "pixels": images, **raw_joint_joint_position}
|
||||
|
||||
def _setup_spaces(self) -> None:
|
||||
"""Configure observation and action spaces based on robot capabilities."""
|
||||
current_observation = self._get_observation()
|
||||
|
||||
observation_spaces = {}
|
||||
|
||||
# Define observation spaces for images and other states.
|
||||
if current_observation is not None and "pixels" in current_observation:
|
||||
prefix = "observation.images"
|
||||
observation_spaces = {
|
||||
f"{prefix}.{key}": gym.spaces.Box(
|
||||
low=0, high=255, shape=current_observation["pixels"][key].shape, dtype=np.uint8
|
||||
)
|
||||
for key in current_observation["pixels"]
|
||||
}
|
||||
|
||||
if current_observation is not None:
|
||||
agent_pos = current_observation["agent_pos"]
|
||||
observation_spaces["observation.state"] = gym.spaces.Box(
|
||||
low=0,
|
||||
high=10,
|
||||
shape=agent_pos.shape,
|
||||
dtype=np.float32,
|
||||
)
|
||||
|
||||
self.observation_space = gym.spaces.Dict(observation_spaces)
|
||||
|
||||
# Define the action space for joint positions along with setting an intervention flag.
|
||||
action_dim = 3
|
||||
bounds = {}
|
||||
bounds["min"] = -np.ones(action_dim)
|
||||
bounds["max"] = np.ones(action_dim)
|
||||
|
||||
if self.use_gripper:
|
||||
action_dim += 1
|
||||
bounds["min"] = np.concatenate([bounds["min"], [0]])
|
||||
bounds["max"] = np.concatenate([bounds["max"], [2]])
|
||||
|
||||
self.action_space = gym.spaces.Box(
|
||||
low=bounds["min"],
|
||||
high=bounds["max"],
|
||||
shape=(action_dim,),
|
||||
dtype=np.float32,
|
||||
)
|
||||
|
||||
def reset(
|
||||
self, *, seed: int | None = None, options: dict[str, Any] | None = None
|
||||
) -> tuple[dict[str, Any], dict[str, Any]]:
|
||||
"""Reset environment to initial state.
|
||||
|
||||
Args:
|
||||
seed: Random seed for reproducibility.
|
||||
options: Additional reset options.
|
||||
|
||||
Returns:
|
||||
Tuple of (observation, info) dictionaries.
|
||||
"""
|
||||
# Reset the robot
|
||||
# self.robot.reset()
|
||||
start_time = time.perf_counter()
|
||||
if self.reset_pose is not None:
|
||||
log_say("Reset the environment.", play_sounds=True)
|
||||
reset_follower_position(self.robot, np.array(self.reset_pose))
|
||||
log_say("Reset the environment done.", play_sounds=True)
|
||||
|
||||
busy_wait(self.reset_time_s - (time.perf_counter() - start_time))
|
||||
|
||||
super().reset(seed=seed, options=options)
|
||||
|
||||
# Reset episode tracking variables.
|
||||
self.current_step = 0
|
||||
self.episode_data = None
|
||||
obs = self._get_observation()
|
||||
self._raw_joint_positions = {f"{key}.pos": obs[f"{key}.pos"] for key in self._joint_names}
|
||||
return obs, {TeleopEvents.IS_INTERVENTION: False}
|
||||
|
||||
def step(self, action) -> tuple[dict[str, np.ndarray], float, bool, bool, dict[str, Any]]:
|
||||
"""Execute one environment step with given action."""
|
||||
joint_targets_dict = {f"{key}.pos": action[i] for i, key in enumerate(self.robot.bus.motors.keys())}
|
||||
|
||||
self.robot.send_action(joint_targets_dict)
|
||||
|
||||
obs = self._get_observation()
|
||||
|
||||
self._raw_joint_positions = {f"{key}.pos": obs[f"{key}.pos"] for key in self._joint_names}
|
||||
|
||||
if self.display_cameras:
|
||||
self.render()
|
||||
|
||||
self.current_step += 1
|
||||
|
||||
reward = 0.0
|
||||
terminated = False
|
||||
truncated = False
|
||||
|
||||
return (
|
||||
obs,
|
||||
reward,
|
||||
terminated,
|
||||
truncated,
|
||||
{TeleopEvents.IS_INTERVENTION: False},
|
||||
)
|
||||
|
||||
def render(self) -> None:
|
||||
"""Display robot camera feeds."""
|
||||
import cv2
|
||||
|
||||
current_observation = self._get_observation()
|
||||
if current_observation is not None:
|
||||
image_keys = [key for key in current_observation if "image" in key]
|
||||
|
||||
for key in image_keys:
|
||||
cv2.imshow(key, cv2.cvtColor(current_observation[key].numpy(), cv2.COLOR_RGB2BGR))
|
||||
cv2.waitKey(1)
|
||||
|
||||
def close(self) -> None:
|
||||
"""Close environment and disconnect robot."""
|
||||
if self.robot.is_connected:
|
||||
self.robot.disconnect()
|
||||
|
||||
def get_raw_joint_positions(self) -> dict[str, float]:
|
||||
"""Get raw joint positions."""
|
||||
return self._raw_joint_positions
|
||||
|
||||
|
||||
def make_robot_env(cfg: HILSerlRobotEnvConfig) -> tuple[gym.Env, Any]:
|
||||
"""Create robot environment from configuration.
|
||||
|
||||
Args:
|
||||
cfg: Environment configuration.
|
||||
|
||||
Returns:
|
||||
Tuple of (gym environment, teleoperator device).
|
||||
"""
|
||||
# Check if this is a GymHIL simulation environment
|
||||
if cfg.name == "gym_hil":
|
||||
assert cfg.robot is None and cfg.teleop is None, "GymHIL environment does not support robot or teleop"
|
||||
import gym_hil # noqa: F401
|
||||
|
||||
# Extract gripper settings with defaults
|
||||
use_gripper = cfg.processor.gripper.use_gripper if cfg.processor.gripper is not None else True
|
||||
gripper_penalty = cfg.processor.gripper.gripper_penalty if cfg.processor.gripper is not None else 0.0
|
||||
|
||||
env = gym.make(
|
||||
f"gym_hil/{cfg.task}",
|
||||
image_obs=True,
|
||||
render_mode="human",
|
||||
use_gripper=use_gripper,
|
||||
gripper_penalty=gripper_penalty,
|
||||
)
|
||||
|
||||
return env, None
|
||||
|
||||
# Real robot environment
|
||||
assert cfg.robot is not None, "Robot config must be provided for real robot environment"
|
||||
assert cfg.teleop is not None, "Teleop config must be provided for real robot environment"
|
||||
|
||||
robot = make_robot_from_config(cfg.robot)
|
||||
teleop_device = make_teleoperator_from_config(cfg.teleop)
|
||||
teleop_device.connect()
|
||||
|
||||
# Create base environment with safe defaults
|
||||
use_gripper = cfg.processor.gripper.use_gripper if cfg.processor.gripper is not None else True
|
||||
display_cameras = (
|
||||
cfg.processor.observation.display_cameras if cfg.processor.observation is not None else False
|
||||
)
|
||||
reset_pose = cfg.processor.reset.fixed_reset_joint_positions if cfg.processor.reset is not None else None
|
||||
|
||||
env = RobotEnv(
|
||||
robot=robot,
|
||||
use_gripper=use_gripper,
|
||||
display_cameras=display_cameras,
|
||||
reset_pose=reset_pose,
|
||||
)
|
||||
|
||||
return env, teleop_device
|
||||
|
||||
|
||||
def make_processors(
|
||||
env: gym.Env, teleop_device: Teleoperator | None, cfg: HILSerlRobotEnvConfig, device: str = "cpu"
|
||||
) -> tuple[
|
||||
DataProcessorPipeline[EnvTransition, EnvTransition], DataProcessorPipeline[EnvTransition, EnvTransition]
|
||||
]:
|
||||
"""Create environment and action processors.
|
||||
|
||||
Args:
|
||||
env: Robot environment instance.
|
||||
teleop_device: Teleoperator device for intervention.
|
||||
cfg: Processor configuration.
|
||||
device: Target device for computations.
|
||||
|
||||
Returns:
|
||||
Tuple of (environment processor, action processor).
|
||||
"""
|
||||
terminate_on_success = (
|
||||
cfg.processor.reset.terminate_on_success if cfg.processor.reset is not None else True
|
||||
)
|
||||
|
||||
if cfg.name == "gym_hil":
|
||||
action_pipeline_steps = [
|
||||
InterventionActionProcessorStep(terminate_on_success=terminate_on_success),
|
||||
Torch2NumpyActionProcessorStep(),
|
||||
]
|
||||
|
||||
env_pipeline_steps = [
|
||||
Numpy2TorchActionProcessorStep(),
|
||||
VanillaObservationProcessorStep(),
|
||||
AddBatchDimensionProcessorStep(),
|
||||
DeviceProcessorStep(device=device),
|
||||
]
|
||||
|
||||
return DataProcessorPipeline(
|
||||
steps=env_pipeline_steps, to_transition=identity_transition, to_output=identity_transition
|
||||
), DataProcessorPipeline(
|
||||
steps=action_pipeline_steps, to_transition=identity_transition, to_output=identity_transition
|
||||
)
|
||||
|
||||
# Full processor pipeline for real robot environment
|
||||
# Get robot and motor information for kinematics
|
||||
motor_names = list(env.robot.bus.motors.keys())
|
||||
|
||||
# Set up kinematics solver if inverse kinematics is configured
|
||||
kinematics_solver = None
|
||||
if cfg.processor.inverse_kinematics is not None:
|
||||
kinematics_solver = RobotKinematics(
|
||||
urdf_path=cfg.processor.inverse_kinematics.urdf_path,
|
||||
target_frame_name=cfg.processor.inverse_kinematics.target_frame_name,
|
||||
joint_names=motor_names,
|
||||
)
|
||||
|
||||
env_pipeline_steps = [VanillaObservationProcessorStep()]
|
||||
|
||||
if cfg.processor.observation is not None:
|
||||
if cfg.processor.observation.add_joint_velocity_to_observation:
|
||||
env_pipeline_steps.append(JointVelocityProcessorStep(dt=1.0 / cfg.fps))
|
||||
if cfg.processor.observation.add_current_to_observation:
|
||||
env_pipeline_steps.append(MotorCurrentProcessorStep(robot=env.robot))
|
||||
|
||||
if kinematics_solver is not None:
|
||||
env_pipeline_steps.append(
|
||||
ForwardKinematicsJointsToEEObservation(
|
||||
kinematics=kinematics_solver,
|
||||
motor_names=motor_names,
|
||||
)
|
||||
)
|
||||
|
||||
if cfg.processor.image_preprocessing is not None:
|
||||
env_pipeline_steps.append(
|
||||
ImageCropResizeProcessorStep(
|
||||
crop_params_dict=cfg.processor.image_preprocessing.crop_params_dict,
|
||||
resize_size=cfg.processor.image_preprocessing.resize_size,
|
||||
)
|
||||
)
|
||||
|
||||
# Add time limit processor if reset config exists
|
||||
if cfg.processor.reset is not None:
|
||||
env_pipeline_steps.append(
|
||||
TimeLimitProcessorStep(max_episode_steps=int(cfg.processor.reset.control_time_s * cfg.fps))
|
||||
)
|
||||
|
||||
# Add gripper penalty processor if gripper config exists and enabled
|
||||
if cfg.processor.gripper is not None and cfg.processor.gripper.use_gripper:
|
||||
env_pipeline_steps.append(
|
||||
GripperPenaltyProcessorStep(
|
||||
penalty=cfg.processor.gripper.gripper_penalty,
|
||||
max_gripper_pos=cfg.processor.max_gripper_pos,
|
||||
)
|
||||
)
|
||||
|
||||
if (
|
||||
cfg.processor.reward_classifier is not None
|
||||
and cfg.processor.reward_classifier.pretrained_path is not None
|
||||
):
|
||||
env_pipeline_steps.append(
|
||||
RewardClassifierProcessorStep(
|
||||
pretrained_path=cfg.processor.reward_classifier.pretrained_path,
|
||||
device=device,
|
||||
success_threshold=cfg.processor.reward_classifier.success_threshold,
|
||||
success_reward=cfg.processor.reward_classifier.success_reward,
|
||||
terminate_on_success=terminate_on_success,
|
||||
)
|
||||
)
|
||||
|
||||
env_pipeline_steps.append(AddBatchDimensionProcessorStep())
|
||||
env_pipeline_steps.append(DeviceProcessorStep(device=device))
|
||||
|
||||
action_pipeline_steps = [
|
||||
AddTeleopActionAsComplimentaryDataStep(teleop_device=teleop_device),
|
||||
AddTeleopEventsAsInfoStep(teleop_device=teleop_device),
|
||||
InterventionActionProcessorStep(
|
||||
use_gripper=cfg.processor.gripper.use_gripper if cfg.processor.gripper is not None else False,
|
||||
terminate_on_success=terminate_on_success,
|
||||
),
|
||||
]
|
||||
|
||||
# Replace InverseKinematicsProcessor with new kinematic processors
|
||||
if cfg.processor.inverse_kinematics is not None and kinematics_solver is not None:
|
||||
# Add EE bounds and safety processor
|
||||
inverse_kinematics_steps = [
|
||||
MapTensorToDeltaActionDictStep(
|
||||
use_gripper=cfg.processor.gripper.use_gripper if cfg.processor.gripper is not None else False
|
||||
),
|
||||
MapDeltaActionToRobotActionStep(),
|
||||
EEReferenceAndDelta(
|
||||
kinematics=kinematics_solver,
|
||||
end_effector_step_sizes=cfg.processor.inverse_kinematics.end_effector_step_sizes,
|
||||
motor_names=motor_names,
|
||||
use_latched_reference=False,
|
||||
use_ik_solution=True,
|
||||
),
|
||||
EEBoundsAndSafety(
|
||||
end_effector_bounds=cfg.processor.inverse_kinematics.end_effector_bounds,
|
||||
),
|
||||
GripperVelocityToJoint(
|
||||
clip_max=cfg.processor.max_gripper_pos,
|
||||
speed_factor=1.0,
|
||||
discrete_gripper=True,
|
||||
),
|
||||
InverseKinematicsRLStep(
|
||||
kinematics=kinematics_solver, motor_names=motor_names, initial_guess_current_joints=False
|
||||
),
|
||||
]
|
||||
action_pipeline_steps.extend(inverse_kinematics_steps)
|
||||
action_pipeline_steps.append(RobotActionToPolicyActionProcessorStep(motor_names=motor_names))
|
||||
|
||||
return DataProcessorPipeline(
|
||||
steps=env_pipeline_steps, to_transition=identity_transition, to_output=identity_transition
|
||||
), DataProcessorPipeline(
|
||||
steps=action_pipeline_steps, to_transition=identity_transition, to_output=identity_transition
|
||||
)
|
||||
|
||||
|
||||
def step_env_and_process_transition(
|
||||
env: gym.Env,
|
||||
transition: EnvTransition,
|
||||
action: torch.Tensor,
|
||||
env_processor: DataProcessorPipeline[EnvTransition, EnvTransition],
|
||||
action_processor: DataProcessorPipeline[EnvTransition, EnvTransition],
|
||||
) -> EnvTransition:
|
||||
"""
|
||||
Execute one step with processor pipeline.
|
||||
|
||||
Args:
|
||||
env: The robot environment
|
||||
transition: Current transition state
|
||||
action: Action to execute
|
||||
env_processor: Environment processor
|
||||
action_processor: Action processor
|
||||
|
||||
Returns:
|
||||
Processed transition with updated state.
|
||||
"""
|
||||
|
||||
# Create action transition
|
||||
transition[TransitionKey.ACTION] = action
|
||||
transition[TransitionKey.OBSERVATION] = (
|
||||
env.get_raw_joint_positions() if hasattr(env, "get_raw_joint_positions") else {}
|
||||
)
|
||||
processed_action_transition = action_processor(transition)
|
||||
processed_action = processed_action_transition[TransitionKey.ACTION]
|
||||
|
||||
obs, reward, terminated, truncated, info = env.step(processed_action)
|
||||
|
||||
reward = reward + processed_action_transition[TransitionKey.REWARD]
|
||||
terminated = terminated or processed_action_transition[TransitionKey.DONE]
|
||||
truncated = truncated or processed_action_transition[TransitionKey.TRUNCATED]
|
||||
complementary_data = processed_action_transition[TransitionKey.COMPLEMENTARY_DATA].copy()
|
||||
new_info = processed_action_transition[TransitionKey.INFO].copy()
|
||||
new_info.update(info)
|
||||
|
||||
new_transition = create_transition(
|
||||
observation=obs,
|
||||
action=processed_action,
|
||||
reward=reward,
|
||||
done=terminated,
|
||||
truncated=truncated,
|
||||
info=new_info,
|
||||
complementary_data=complementary_data,
|
||||
)
|
||||
new_transition = env_processor(new_transition)
|
||||
|
||||
return new_transition
|
||||
|
||||
|
||||
def control_loop(
|
||||
env: gym.Env,
|
||||
env_processor: DataProcessorPipeline[EnvTransition, EnvTransition],
|
||||
action_processor: DataProcessorPipeline[EnvTransition, EnvTransition],
|
||||
teleop_device: Teleoperator,
|
||||
cfg: GymManipulatorConfig,
|
||||
) -> None:
|
||||
"""Main control loop for robot environment interaction.
|
||||
if cfg.mode == "record": then a dataset will be created and recorded
|
||||
|
||||
Args:
|
||||
env: The robot environment
|
||||
env_processor: Environment processor
|
||||
action_processor: Action processor
|
||||
teleop_device: Teleoperator device
|
||||
cfg: gym_manipulator configuration
|
||||
"""
|
||||
dt = 1.0 / cfg.env.fps
|
||||
|
||||
print(f"Starting control loop at {cfg.env.fps} FPS")
|
||||
print("Controls:")
|
||||
print("- Use gamepad/teleop device for intervention")
|
||||
print("- When not intervening, robot will stay still")
|
||||
print("- Press Ctrl+C to exit")
|
||||
|
||||
# Reset environment and processors
|
||||
obs, info = env.reset()
|
||||
complementary_data = (
|
||||
{"raw_joint_positions": info.pop("raw_joint_positions")} if "raw_joint_positions" in info else {}
|
||||
)
|
||||
env_processor.reset()
|
||||
action_processor.reset()
|
||||
|
||||
# Process initial observation
|
||||
transition = create_transition(observation=obs, info=info, complementary_data=complementary_data)
|
||||
transition = env_processor(data=transition)
|
||||
|
||||
# Determine if gripper is used
|
||||
use_gripper = cfg.env.processor.gripper.use_gripper if cfg.env.processor.gripper is not None else True
|
||||
|
||||
dataset = None
|
||||
if cfg.mode == "record":
|
||||
action_features = teleop_device.action_features
|
||||
features = {
|
||||
"action": action_features,
|
||||
"next.reward": {"dtype": "float32", "shape": (1,), "names": None},
|
||||
"next.done": {"dtype": "bool", "shape": (1,), "names": None},
|
||||
}
|
||||
if use_gripper:
|
||||
features["complementary_info.discrete_penalty"] = {
|
||||
"dtype": "float32",
|
||||
"shape": (1,),
|
||||
"names": ["discrete_penalty"],
|
||||
}
|
||||
|
||||
for key, value in transition[TransitionKey.OBSERVATION].items():
|
||||
if key == "observation.state":
|
||||
features[key] = {
|
||||
"dtype": "float32",
|
||||
"shape": value.squeeze(0).shape,
|
||||
"names": None,
|
||||
}
|
||||
if "image" in key:
|
||||
features[key] = {
|
||||
"dtype": "video",
|
||||
"shape": value.squeeze(0).shape,
|
||||
"names": ["channels", "height", "width"],
|
||||
}
|
||||
|
||||
# Create dataset
|
||||
dataset = LeRobotDataset.create(
|
||||
cfg.dataset.repo_id,
|
||||
cfg.env.fps,
|
||||
root=cfg.dataset.root,
|
||||
use_videos=True,
|
||||
image_writer_threads=4,
|
||||
image_writer_processes=0,
|
||||
features=features,
|
||||
)
|
||||
|
||||
episode_idx = 0
|
||||
episode_step = 0
|
||||
episode_start_time = time.perf_counter()
|
||||
|
||||
while episode_idx < cfg.dataset.num_episodes_to_record:
|
||||
step_start_time = time.perf_counter()
|
||||
|
||||
# Create a neutral action (no movement)
|
||||
neutral_action = torch.tensor([0.0, 0.0, 0.0], dtype=torch.float32)
|
||||
if use_gripper:
|
||||
neutral_action = torch.cat([neutral_action, torch.tensor([1.0])]) # Gripper stay
|
||||
|
||||
# Use the new step function
|
||||
transition = step_env_and_process_transition(
|
||||
env=env,
|
||||
transition=transition,
|
||||
action=neutral_action,
|
||||
env_processor=env_processor,
|
||||
action_processor=action_processor,
|
||||
)
|
||||
terminated = transition.get(TransitionKey.DONE, False)
|
||||
truncated = transition.get(TransitionKey.TRUNCATED, False)
|
||||
|
||||
if cfg.mode == "record":
|
||||
observations = {
|
||||
k: v.squeeze(0).cpu()
|
||||
for k, v in transition[TransitionKey.OBSERVATION].items()
|
||||
if isinstance(v, torch.Tensor)
|
||||
}
|
||||
# Use teleop_action if available, otherwise use the action from the transition
|
||||
action_to_record = transition[TransitionKey.COMPLEMENTARY_DATA].get(
|
||||
"teleop_action", transition[TransitionKey.ACTION]
|
||||
)
|
||||
frame = {
|
||||
**observations,
|
||||
"action": action_to_record.cpu(),
|
||||
"next.reward": np.array([transition[TransitionKey.REWARD]], dtype=np.float32),
|
||||
"next.done": np.array([terminated or truncated], dtype=bool),
|
||||
}
|
||||
if use_gripper:
|
||||
discrete_penalty = transition[TransitionKey.COMPLEMENTARY_DATA].get("discrete_penalty", 0.0)
|
||||
frame["complementary_info.discrete_penalty"] = np.array([discrete_penalty], dtype=np.float32)
|
||||
|
||||
if dataset is not None:
|
||||
frame["task"] = cfg.dataset.task
|
||||
dataset.add_frame(frame)
|
||||
|
||||
episode_step += 1
|
||||
|
||||
# Handle episode termination
|
||||
if terminated or truncated:
|
||||
episode_time = time.perf_counter() - episode_start_time
|
||||
logging.info(
|
||||
f"Episode ended after {episode_step} steps in {episode_time:.1f}s with reward {transition[TransitionKey.REWARD]}"
|
||||
)
|
||||
episode_step = 0
|
||||
episode_idx += 1
|
||||
|
||||
if dataset is not None:
|
||||
if transition[TransitionKey.INFO].get("rerecord_episode", False):
|
||||
logging.info(f"Re-recording episode {episode_idx}")
|
||||
dataset.clear_episode_buffer()
|
||||
episode_idx -= 1
|
||||
else:
|
||||
logging.info(f"Saving episode {episode_idx}")
|
||||
dataset.save_episode()
|
||||
|
||||
# Reset for new episode
|
||||
obs, info = env.reset()
|
||||
env_processor.reset()
|
||||
action_processor.reset()
|
||||
|
||||
transition = create_transition(observation=obs, info=info)
|
||||
transition = env_processor(transition)
|
||||
|
||||
# Maintain fps timing
|
||||
busy_wait(dt - (time.perf_counter() - step_start_time))
|
||||
|
||||
if dataset is not None and cfg.dataset.push_to_hub:
|
||||
logging.info("Pushing dataset to hub")
|
||||
dataset.push_to_hub()
|
||||
|
||||
|
||||
def replay_trajectory(
|
||||
env: gym.Env, action_processor: DataProcessorPipeline, cfg: GymManipulatorConfig
|
||||
) -> None:
|
||||
"""Replay recorded trajectory on robot environment."""
|
||||
assert cfg.dataset.replay_episode is not None, "Replay episode must be provided for replay"
|
||||
|
||||
dataset = LeRobotDataset(
|
||||
cfg.dataset.repo_id,
|
||||
root=cfg.dataset.root,
|
||||
episodes=[cfg.dataset.replay_episode],
|
||||
download_videos=False,
|
||||
)
|
||||
episode_frames = dataset.hf_dataset.filter(lambda x: x["episode_index"] == cfg.dataset.replay_episode)
|
||||
actions = episode_frames.select_columns("action")
|
||||
|
||||
_, info = env.reset()
|
||||
|
||||
for action_data in actions:
|
||||
start_time = time.perf_counter()
|
||||
transition = create_transition(
|
||||
observation=env.get_raw_joint_positions() if hasattr(env, "get_raw_joint_positions") else {},
|
||||
action=action_data["action"],
|
||||
)
|
||||
transition = action_processor(transition)
|
||||
env.step(transition[TransitionKey.ACTION])
|
||||
busy_wait(1 / cfg.env.fps - (time.perf_counter() - start_time))
|
||||
|
||||
|
||||
@parser.wrap()
|
||||
def main(cfg: GymManipulatorConfig) -> None:
|
||||
"""Main entry point for gym manipulator script."""
|
||||
env, teleop_device = make_robot_env(cfg.env)
|
||||
env_processor, action_processor = make_processors(env, teleop_device, cfg.env, cfg.device)
|
||||
|
||||
print("Environment observation space:", env.observation_space)
|
||||
print("Environment action space:", env.action_space)
|
||||
print("Environment processor:", env_processor)
|
||||
print("Action processor:", action_processor)
|
||||
|
||||
if cfg.mode == "replay":
|
||||
replay_trajectory(env, action_processor, cfg)
|
||||
exit()
|
||||
|
||||
control_loop(env, env_processor, action_processor, teleop_device, cfg)
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
main()
|
||||
File diff suppressed because it is too large
Load Diff
@@ -0,0 +1,117 @@
|
||||
# !/usr/bin/env python
|
||||
|
||||
# Copyright 2025 The HuggingFace Inc. team.
|
||||
# All rights reserved.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
import logging
|
||||
import time
|
||||
from multiprocessing import Event, Queue
|
||||
|
||||
from lerobot.transport import services_pb2, services_pb2_grpc
|
||||
from lerobot.transport.utils import receive_bytes_in_chunks, send_bytes_in_chunks
|
||||
from lerobot.utils.queue import get_last_item_from_queue
|
||||
|
||||
MAX_WORKERS = 3 # Stream parameters, send transitions and interactions
|
||||
SHUTDOWN_TIMEOUT = 10
|
||||
|
||||
|
||||
class LearnerService(services_pb2_grpc.LearnerServiceServicer):
|
||||
"""
|
||||
Implementation of the LearnerService gRPC service
|
||||
This service is used to send parameters to the Actor and receive transitions and interactions from the Actor
|
||||
check transport.proto for the gRPC service definition
|
||||
"""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
shutdown_event: Event, # type: ignore
|
||||
parameters_queue: Queue,
|
||||
seconds_between_pushes: float,
|
||||
transition_queue: Queue,
|
||||
interaction_message_queue: Queue,
|
||||
queue_get_timeout: float = 0.001,
|
||||
):
|
||||
self.shutdown_event = shutdown_event
|
||||
self.parameters_queue = parameters_queue
|
||||
self.seconds_between_pushes = seconds_between_pushes
|
||||
self.transition_queue = transition_queue
|
||||
self.interaction_message_queue = interaction_message_queue
|
||||
self.queue_get_timeout = queue_get_timeout
|
||||
|
||||
def StreamParameters(self, request, context): # noqa: N802
|
||||
# TODO: authorize the request
|
||||
logging.info("[LEARNER] Received request to stream parameters from the Actor")
|
||||
|
||||
last_push_time = 0
|
||||
|
||||
while not self.shutdown_event.is_set():
|
||||
time_since_last_push = time.time() - last_push_time
|
||||
if time_since_last_push < self.seconds_between_pushes:
|
||||
self.shutdown_event.wait(self.seconds_between_pushes - time_since_last_push)
|
||||
# Continue, because we could receive a shutdown event,
|
||||
# and it's checked in the while loop
|
||||
continue
|
||||
|
||||
logging.info("[LEARNER] Push parameters to the Actor")
|
||||
buffer = get_last_item_from_queue(
|
||||
self.parameters_queue, block=True, timeout=self.queue_get_timeout
|
||||
)
|
||||
|
||||
if buffer is None:
|
||||
continue
|
||||
|
||||
yield from send_bytes_in_chunks(
|
||||
buffer,
|
||||
services_pb2.Parameters,
|
||||
log_prefix="[LEARNER] Sending parameters",
|
||||
silent=True,
|
||||
)
|
||||
|
||||
last_push_time = time.time()
|
||||
logging.info("[LEARNER] Parameters sent")
|
||||
|
||||
logging.info("[LEARNER] Stream parameters finished")
|
||||
return services_pb2.Empty()
|
||||
|
||||
def SendTransitions(self, request_iterator, _context): # noqa: N802
|
||||
# TODO: authorize the request
|
||||
logging.info("[LEARNER] Received request to receive transitions from the Actor")
|
||||
|
||||
receive_bytes_in_chunks(
|
||||
request_iterator,
|
||||
self.transition_queue,
|
||||
self.shutdown_event,
|
||||
log_prefix="[LEARNER] transitions",
|
||||
)
|
||||
|
||||
logging.debug("[LEARNER] Finished receiving transitions")
|
||||
return services_pb2.Empty()
|
||||
|
||||
def SendInteractions(self, request_iterator, _context): # noqa: N802
|
||||
# TODO: authorize the request
|
||||
logging.info("[LEARNER] Received request to receive interactions from the Actor")
|
||||
|
||||
receive_bytes_in_chunks(
|
||||
request_iterator,
|
||||
self.interaction_message_queue,
|
||||
self.shutdown_event,
|
||||
log_prefix="[LEARNER] interactions",
|
||||
)
|
||||
|
||||
logging.debug("[LEARNER] Finished receiving interactions")
|
||||
return services_pb2.Empty()
|
||||
|
||||
def Ready(self, request, context): # noqa: N802
|
||||
return services_pb2.Empty()
|
||||
Reference in New Issue
Block a user