- Added additional logging information in wandb around the timings of the policy loop and optimization loop.

- Optimized critic design that improves the performance of the learner loop by a factor of 2
- Cleaned the code and fixed style issues

- Completed the config with actor_learner_config field that contains host-ip and port elemnts that are necessary for the actor-learner servers.

Co-authored-by: Adil Zouitine <adilzouitinegm@gmail.com>
This commit is contained in:
Michel Aractingi
2025-01-29 15:50:46 +00:00
committed by AdilZouitine
parent a0a81c0c12
commit 18207d995e
6 changed files with 461 additions and 313 deletions
+164 -115
View File
@@ -1,97 +1,97 @@
import grpc
from concurrent import futures
import functools
import logging
import queue
import pickle
import torch
import torch.nn.functional as F
#!/usr/bin/env python
# Copyright 2024 The HuggingFace Inc. team.
# All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import io
import logging
import pickle
import queue
import time
from pprint import pformat
import random
from typing import Optional, Sequence, TypedDict, Callable
from threading import Lock, Thread
import grpc
# Import generated stubs
import hilserl_pb2 # type: ignore
import hilserl_pb2_grpc # type: ignore
import hydra
import torch
import torch.nn.functional as F
from torch import nn
from tqdm import tqdm
from deepdiff import DeepDiff
from omegaconf import DictConfig, OmegaConf
from threading import Thread, Lock
from lerobot.common.datasets.lerobot_dataset import LeRobotDataset
from torch import nn
# TODO: Remove the import of maniskill
from lerobot.common.datasets.factory import make_dataset
from lerobot.common.logger import Logger, log_output_dir
from lerobot.common.policies.factory import make_policy
from lerobot.common.policies.sac.modeling_sac import SACPolicy
from lerobot.common.policies.utils import get_device_from_parameters
from lerobot.common.utils.utils import (
format_big_number,
get_safe_torch_device,
init_hydra_config,
init_logging,
set_global_seed,
)
from lerobot.scripts.server.buffer import ReplayBuffer, move_transition_to_device, concatenate_batch_transitions, move_state_dict_to_device, Transition
# Import generated stubs
import hilserl_pb2
import hilserl_pb2_grpc
from lerobot.scripts.server.buffer import (
ReplayBuffer,
concatenate_batch_transitions,
move_state_dict_to_device,
move_transition_to_device,
)
logging.basicConfig(level=logging.INFO)
# TODO: Implement it in cleaner way maybe
transition_queue = queue.Queue()
interaction_message_queue = queue.Queue()
# 1) Implement the LearnerService so the Actor can send transitions here.
class LearnerServiceServicer(hilserl_pb2_grpc.LearnerServiceServicer):
# def SendTransition(self, request, context):
# """
# Actor calls this method to push a Transition -> Learner.
# """
# buffer = io.BytesIO(request.transition_bytes)
# transition = torch.load(buffer)
# transition_queue.put(transition)
# return hilserl_pb2.Empty()
def SendInteractionMessage(self, request, context):
"""
Actor calls this method to push a Transition -> Learner.
"""
content = pickle.loads(request.interaction_message_bytes)
interaction_message_queue.put(content)
return hilserl_pb2.Empty()
def stream_transitions_from_actor(port=50051):
def stream_transitions_from_actor(host="127.0.0.1", port=50051):
"""
Runs a gRPC server listening for transitions from the Actor.
Runs a gRPC client that listens for transition and interaction messages from an Actor service.
This function establishes a gRPC connection with the given `host` and `port`, then continuously
streams transition data from the `ActorServiceStub`. The received transition data is deserialized
and stored in a queue (`transition_queue`). Similarly, interaction messages are also deserialized
and stored in a separate queue (`interaction_message_queue`).
Args:
host (str, optional): The IP address or hostname of the gRPC server. Defaults to `"127.0.0.1"`.
port (int, optional): The port number on which the gRPC server is running. Defaults to `50051`.
"""
# NOTE: This is waiting for the handshake to be done
# In the future we will do it in a canonical way with a proper handshake
time.sleep(10)
channel = grpc.insecure_channel(f'127.0.0.1:{port}',
options=[('grpc.max_send_message_length', -1),
('grpc.max_receive_message_length', -1)])
channel = grpc.insecure_channel(
f"{host}:{port}",
options=[("grpc.max_send_message_length", -1), ("grpc.max_receive_message_length", -1)],
)
stub = hilserl_pb2_grpc.ActorServiceStub(channel)
for response in stub.StreamTransition(hilserl_pb2.Empty()):
if response.HasField('transition'):
if response.HasField("transition"):
buffer = io.BytesIO(response.transition.transition_bytes)
transition = torch.load(buffer)
transition_queue.put(transition)
if response.HasField('interaction_message'):
if response.HasField("interaction_message"):
content = pickle.loads(response.interaction_message.interaction_message_bytes)
interaction_message_queue.put(content)
# NOTE: Cool down the CPU, if you comment this line you will make a huge bottleneck
# TODO: LOOK TO REMOVE IT
time.sleep(0.001)
def learner_push_parameters(
policy: nn.Module, policy_lock: Lock, actor_host="127.0.0.1", actor_port=50052, seconds_between_pushes=5
):
@@ -100,10 +100,10 @@ def learner_push_parameters(
and periodically push new parameters.
"""
time.sleep(10)
# The Actor's server is presumably listening on a different port, e.g. 50052
channel = grpc.insecure_channel(f"{actor_host}:{actor_port}",
options=[('grpc.max_send_message_length', -1),
('grpc.max_receive_message_length', -1)])
channel = grpc.insecure_channel(
f"{actor_host}:{actor_port}",
options=[("grpc.max_send_message_length", -1), ("grpc.max_receive_message_length", -1)],
)
actor_stub = hilserl_pb2_grpc.ActorServiceStub(channel)
while True:
@@ -116,20 +116,19 @@ def learner_push_parameters(
params_bytes = buf.getvalue()
# Push them to the Actors "SendParameters" method
logging.info(f"[LEARNER] Pushing parameters to the Actor")
response = actor_stub.SendParameters(hilserl_pb2.Parameters(parameter_bytes=params_bytes))
logging.info("[LEARNER] Publishing parameters to the Actor")
response = actor_stub.SendParameters(hilserl_pb2.Parameters(parameter_bytes=params_bytes)) # noqa: F841
time.sleep(seconds_between_pushes)
# Checked
def add_actor_information(
def add_actor_information_and_train(
cfg,
device,
device: str,
replay_buffer: ReplayBuffer,
offline_replay_buffer: ReplayBuffer,
batch_size: int,
optimizers,
policy,
optimizers: dict[str, torch.optim.Optimizer],
policy: nn.Module,
policy_lock: Lock,
buffer_lock: Lock,
offline_buffer_lock: Lock,
@@ -137,34 +136,52 @@ def add_actor_information(
logger: Logger,
):
"""
In a real application, you might run your training loop here,
reading from the transition queue and doing gradient updates.
Handles data transfer from the actor to the learner, manages training updates,
and logs training progress in an online reinforcement learning setup.
This function continuously:
- Transfers transitions from the actor to the replay buffer.
- Logs received interaction messages.
- Ensures training begins only when the replay buffer has a sufficient number of transitions.
- Samples batches from the replay buffer and performs multiple critic updates.
- Periodically updates the actor, critic, and temperature optimizers.
- Logs training statistics, including loss values and optimization frequency.
**NOTE:**
- This function performs multiple responsibilities (data transfer, training, and logging).
It should ideally be split into smaller functions in the future.
- Due to Python's **Global Interpreter Lock (GIL)**, running separate threads for different tasks
significantly reduces performance. Instead, this function executes all operations in a single thread.
Args:
cfg: Configuration object containing hyperparameters.
device (str): The computing device (`"cpu"` or `"cuda"`).
replay_buffer (ReplayBuffer): The primary replay buffer storing online transitions.
offline_replay_buffer (ReplayBuffer): An additional buffer for offline transitions.
batch_size (int): The number of transitions to sample per training step.
optimizers (Dict[str, torch.optim.Optimizer]): A dictionary of optimizers (`"actor"`, `"critic"`, `"temperature"`).
policy (nn.Module): The reinforcement learning policy with critic, actor, and temperature parameters.
policy_lock (Lock): A threading lock to ensure safe policy updates.
buffer_lock (Lock): A threading lock to safely access the online replay buffer.
offline_buffer_lock (Lock): A threading lock to safely access the offline replay buffer.
logger_lock (Lock): A threading lock to safely log training metrics.
logger (Logger): Logger instance for tracking training progress.
"""
# NOTE: This function doesn't have a single responsibility, it should be split into multiple functions
# in the future. The reason why we did that is the GIL in Python. It's super slow the performance
# are divided by 200. So we need to have a single thread that does all the work.
start = time.time()
time.time()
optimization_step = 0
timeout_for_adding_transitions = 1
while True:
time_for_adding_transitions = time.time()
while not transition_queue.empty():
transition_list = transition_queue.get()
for transition in transition_list:
transition = move_transition_to_device(transition, device=device)
replay_buffer.add(**transition)
# logging.info(f"[LEARNER] size of replay buffer: {len(replay_buffer)}")
# logging.info(f"[LEARNER] size of transition queues: {transition_queue.qsize()}")
# logging.info(f"[LEARNER] size of replay buffer: {len(replay_buffer)}")
# logging.info(f"[LEARNER] size of transition queues: {transition }")
if len(replay_buffer) > cfg.training.online_step_before_learning:
logging.info(f"[LEARNER] size of replay buffer: {len(replay_buffer)}")
while not interaction_message_queue.empty():
interaction_message = interaction_message_queue.get()
logger.log_dict(interaction_message,mode="train",custom_step_key="interaction_step")
# logging.info(f"[LEARNER] size of interaction message queue: {interaction_message_queue.qsize()}")
logger.log_dict(interaction_message, mode="train", custom_step_key="Interaction step")
if len(replay_buffer) < cfg.training.online_step_before_learning:
continue
@@ -212,7 +229,7 @@ def add_actor_information(
loss_critic = policy.compute_loss_critic(
observations=observations,
actions=actions,
rewards=rewards,
rewards=rewards,
next_observations=next_observations,
done=done,
)
@@ -223,7 +240,6 @@ def add_actor_information(
training_infos = {}
training_infos["loss_critic"] = loss_critic.item()
if optimization_step % cfg.training.policy_update_freq == 0:
for _ in range(cfg.training.policy_update_freq):
with policy_lock:
@@ -242,18 +258,52 @@ def add_actor_information(
training_infos["loss_temperature"] = loss_temperature.item()
policy.update_target_networks()
if optimization_step % cfg.training.log_freq == 0:
logger.log_dict(training_infos, step=optimization_step, mode="train")
policy.update_target_networks()
optimization_step += 1
time_for_one_optimization_step = time.time() - time_for_one_optimization_step
frequency_for_one_optimization_step = 1 / (time_for_one_optimization_step + 1e-9)
logging.info(f"[LEARNER] Time for one optimization step: {time_for_one_optimization_step}")
logger.log_dict({"Time optimization step":time_for_one_optimization_step}, step=optimization_step, mode="train")
logging.debug(f"[LEARNER] Optimization frequency loop [Hz]: {frequency_for_one_optimization_step}")
logger.log_dict(
{"Optimization frequency loop [Hz]": frequency_for_one_optimization_step},
step=optimization_step,
mode="train",
)
optimization_step += 1
if optimization_step % cfg.training.log_freq == 0:
logging.info(f"[LEARNER] Number of optimization step: {optimization_step}")
def make_optimizers_and_scheduler(cfg, policy):
def make_optimizers_and_scheduler(cfg, policy: nn.Module):
"""
Creates and returns optimizers for the actor, critic, and temperature components of a reinforcement learning policy.
This function sets up Adam optimizers for:
- The **actor network**, ensuring that only relevant parameters are optimized.
- The **critic ensemble**, which evaluates the value function.
- The **temperature parameter**, which controls the entropy in soft actor-critic (SAC)-like methods.
It also initializes a learning rate scheduler, though currently, it is set to `None`.
**NOTE:**
- If the encoder is shared, its parameters are excluded from the actors optimization process.
- The policys log temperature (`log_alpha`) is wrapped in a list to ensure proper optimization as a standalone tensor.
Args:
cfg: Configuration object containing hyperparameters.
policy (nn.Module): The policy model containing the actor, critic, and temperature components.
Returns:
Tuple[Dict[str, torch.optim.Optimizer], Optional[torch.optim.lr_scheduler._LRScheduler]]:
A tuple containing:
- `optimizers`: A dictionary mapping component names ("actor", "critic", "temperature") to their respective Adam optimizers.
- `lr_scheduler`: Currently set to `None` but can be extended to support learning rate scheduling.
"""
optimizer_actor = torch.optim.Adam(
# NOTE: Handle the case of shared encoder where the encoder weights are not optimized with the gradient of the actor
params=policy.actor.parameters_to_optimize,
@@ -273,8 +323,6 @@ def make_optimizers_and_scheduler(cfg, policy):
return optimizers, lr_scheduler
def train(cfg: DictConfig, out_dir: str | None = None, job_name: str | None = None):
if out_dir is None:
raise NotImplementedError()
@@ -332,6 +380,7 @@ def train(cfg: DictConfig, out_dir: str | None = None, job_name: str | None = No
batch_size = cfg.training.batch_size
offline_buffer_lock = None
offline_replay_buffer = None
if cfg.dataset_repo_id is not None:
logging.info("make_dataset offline buffer")
offline_dataset = make_dataset(cfg)
@@ -342,48 +391,48 @@ def train(cfg: DictConfig, out_dir: str | None = None, job_name: str | None = No
offline_buffer_lock = Lock()
batch_size: int = batch_size // 2 # We will sample from both replay buffer
server_thread = Thread(target=stream_transitions_from_actor, args=(50051,), daemon=True)
actor_ip = cfg.actor_learner_config.actor_ip
port = cfg.actor_learner_config.port
server_thread = Thread(
target=stream_transitions_from_actor,
args=(
actor_ip,
port,
),
daemon=True,
)
server_thread.start()
# Start a background thread to process transitions from the queue
transition_thread = Thread(
target=add_actor_information,
target=add_actor_information_and_train,
daemon=True,
args=(cfg,
device,
replay_buffer,
offline_replay_buffer,
batch_size,
optimizers,
policy,
policy_lock,
buffer_lock,
offline_buffer_lock,
logger_lock,
logger),
args=(
cfg,
device,
replay_buffer,
offline_replay_buffer,
batch_size,
optimizers,
policy,
policy_lock,
buffer_lock,
offline_buffer_lock,
logger_lock,
logger,
),
)
transition_thread.start()
param_push_thread = Thread(
target=learner_push_parameters,
args=(policy, policy_lock, "127.0.0.1", 50051, 15),
# args=("127.0.0.1", 50052),
args=(policy, policy_lock, actor_ip, port, 15),
daemon=True,
)
param_push_thread.start()
# interaction_thread = Thread(
# target=add_message_interaction_to_wandb,
# daemon=True,
# args=(cfg, logger, logger_lock),
# )
# interaction_thread.start()
transition_thread.join()
# param_push_thread.join()
server_thread.join()
# interaction_thread.join()
@hydra.main(version_base="1.2", config_name="default", config_path="../../configs")