[Port HIL-SERL] Adjust Actor-Learner architecture & clean up dependency management for HIL-SERL (#722)

This commit is contained in:
Eugene Mironov
2025-02-21 16:29:00 +07:00
committed by AdilZouitine
parent 150def839c
commit d48161da1b
17 changed files with 1949 additions and 475 deletions
+129 -113
@@ -14,19 +14,19 @@
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
-import io
import logging
-import pickle
import queue
import shutil
import time
from pprint import pformat
from threading import Lock, Thread
+import signal
+from threading import Event
+from concurrent.futures import ThreadPoolExecutor
import grpc
# Import generated stubs
import hilserl_pb2 # type: ignore
import hilserl_pb2_grpc # type: ignore
import hydra
import torch
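Note: hilserl_pb2 and hilserl_pb2_grpc are stubs generated by protoc. A sketch of regenerating them with grpcio-tools, assuming the service definition sits in a file named hilserl.proto (the actual path is not part of this diff):

    from grpc_tools import protoc

    # Writes hilserl_pb2.py and hilserl_pb2_grpc.py to the current directory.
    protoc.main(
        [
            "protoc",  # argv[0] placeholder
            "-I.",
            "--python_out=.",
            "--grpc_python_out=.",
            "hilserl.proto",  # assumed filename; the .proto is not shown here
        ]
    )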
@@ -55,10 +55,11 @@ from lerobot.common.utils.utils import (
from lerobot.scripts.server.buffer import (
ReplayBuffer,
concatenate_batch_transitions,
move_state_dict_to_device,
move_transition_to_device,
)
+from lerobot.scripts.server import learner_service
logging.basicConfig(level=logging.INFO)
transition_queue = queue.Queue()
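These module-level queues are the hand-off point between the gRPC service thread and the training thread; queue.Queue is thread-safe, so no extra locking is needed. A minimal sketch of the pattern, with a toy payload standing in for a real transition:

    import queue
    import threading

    q: queue.Queue = queue.Queue()

    def service_side():
        q.put({"reward": 1.0})  # stand-in for a deserialized transition

    t = threading.Thread(target=service_side)
    t.start()
    print(q.get(timeout=1))  # training side blocks until data arrives
    t.join()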
@@ -77,9 +78,13 @@ def handle_resume_logic(cfg: DictConfig, out_dir: str) -> DictConfig:
# if resume == True
checkpoint_dir = Logger.get_last_checkpoint_dir(out_dir)
if not checkpoint_dir.exists():
raise RuntimeError(f"No model checkpoint found in {checkpoint_dir} for resume=True")
raise RuntimeError(
f"No model checkpoint found in {checkpoint_dir} for resume=True"
)
-checkpoint_cfg_path = str(Logger.get_last_pretrained_model_dir(out_dir) / "config.yaml")
+checkpoint_cfg_path = str(
+    Logger.get_last_pretrained_model_dir(out_dir) / "config.yaml"
+)
logging.info(
colored(
"Resume=True detected, resuming previous run",
@@ -112,7 +117,9 @@ def load_training_state(
if not cfg.resume:
return None, None
-training_state = torch.load(logger.last_checkpoint_dir / logger.training_state_file_name)
+training_state = torch.load(
+    logger.last_checkpoint_dir / logger.training_state_file_name
+)
if isinstance(training_state["optimizer"], dict):
assert set(training_state["optimizer"].keys()) == set(optimizers.keys())
@@ -126,7 +133,9 @@ def load_training_state(
def log_training_info(cfg: DictConfig, out_dir: str, policy: nn.Module) -> None:
-num_learnable_params = sum(p.numel() for p in policy.parameters() if p.requires_grad)
+num_learnable_params = sum(
+    p.numel() for p in policy.parameters() if p.requires_grad
+)
num_total_params = sum(p.numel() for p in policy.parameters())
log_output_dir(out_dir)
@@ -136,7 +145,9 @@ def log_training_info(cfg: DictConfig, out_dir: str, policy: nn.Module) -> None:
logging.info(f"{num_total_params=} ({format_big_number(num_total_params)})")
-def initialize_replay_buffer(cfg: DictConfig, logger: Logger, device: str) -> ReplayBuffer:
+def initialize_replay_buffer(
+    cfg: DictConfig, logger: Logger, device: str
+) -> ReplayBuffer:
if not cfg.resume:
return ReplayBuffer(
capacity=cfg.training.online_buffer_capacity,
@@ -146,7 +157,9 @@ def initialize_replay_buffer(cfg: DictConfig, logger: Logger, device: str) -> Re
)
dataset = LeRobotDataset(
-repo_id=cfg.dataset_repo_id, local_files_only=True, root=logger.log_dir / "dataset"
+repo_id=cfg.dataset_repo_id,
+local_files_only=True,
+root=logger.log_dir / "dataset",
)
return ReplayBuffer.from_lerobot_dataset(
lerobot_dataset=dataset,
@@ -168,18 +181,10 @@ def start_learner_threads(
logger: Logger,
resume_optimization_step: int | None = None,
resume_interaction_step: int | None = None,
+shutdown_event: Event | None = None,
) -> None:
-actor_ip = cfg.actor_learner_config.actor_ip
-port = cfg.actor_learner_config.port
-server_thread = Thread(
-    target=stream_transitions_from_actor,
-    args=(
-        actor_ip,
-        port,
-    ),
-    daemon=True,
-)
+host = cfg.actor_learner_config.learner_host
+port = cfg.actor_learner_config.learner_port
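After this change the learner owns the listening socket: actors connect to learner_host:learner_port rather than the learner dialing out to actor_ip. A sketch of how those settings are read, with illustrative values in place of the real Hydra config:

    from omegaconf import OmegaConf

    # Illustrative values; the real ones come from the Hydra config files.
    cfg = OmegaConf.create(
        {
            "actor_learner_config": {
                "learner_host": "0.0.0.0",
                "learner_port": 50051,
                "policy_parameters_push_frequency": 15,
            }
        }
    )
    host = cfg.actor_learner_config.learner_host
    port = cfg.actor_learner_config.learner_port
    print(f"learner listens on {host}:{port}")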
transition_thread = Thread(
target=add_actor_information_and_train,
@@ -196,95 +201,56 @@ def start_learner_threads(
logger,
resume_optimization_step,
resume_interaction_step,
+shutdown_event,
),
)
-param_push_thread = Thread(
-    target=learner_push_parameters,
-    args=(policy, policy_lock, actor_ip, port, 15),
-    daemon=True,
-)
-server_thread.start()
transition_thread.start()
-param_push_thread.start()
-param_push_thread.join()
+service = learner_service.LearnerService(
+    shutdown_event,
+    policy,
+    policy_lock,
+    cfg.actor_learner_config.policy_parameters_push_frequency,
+    transition_queue,
+    interaction_message_queue,
+)
+server = start_learner_server(service, host, port)
+shutdown_event.wait()
+server.stop(learner_service.STUTDOWN_TIMEOUT)
+logging.info("[LEARNER] gRPC server stopped")
transition_thread.join()
-server_thread.join()
+logging.info("[LEARNER] Transition thread stopped")
-def stream_transitions_from_actor(host="127.0.0.1", port=50051):
-    """
-    Runs a gRPC client that listens for transition and interaction messages from an Actor service.
-    This function establishes a gRPC connection with the given `host` and `port`, then continuously
-    streams transition data from the `ActorServiceStub`. The received transition data is deserialized
-    and stored in a queue (`transition_queue`). Similarly, interaction messages are also deserialized
-    and stored in a separate queue (`interaction_message_queue`).
-    Args:
-        host (str, optional): The IP address or hostname of the gRPC server. Defaults to `"127.0.0.1"`.
-        port (int, optional): The port number on which the gRPC server is running. Defaults to `50051`.
-    """
-    # NOTE: This is waiting for the handshake to be done
-    # In the future we will do it in a canonical way with a proper handshake
-    time.sleep(10)
-    channel = grpc.insecure_channel(
-        f"{host}:{port}",
-        options=[("grpc.max_send_message_length", -1), ("grpc.max_receive_message_length", -1)],
+def start_learner_server(
+    service: learner_service.LearnerService,
+    host="0.0.0.0",
+    port=50051,
+) -> grpc.server:
+    server = grpc.server(
+        ThreadPoolExecutor(max_workers=learner_service.MAX_WORKERS),
+        options=[
+            ("grpc.max_receive_message_length", learner_service.MAX_MESSAGE_SIZE),
+            ("grpc.max_send_message_length", learner_service.MAX_MESSAGE_SIZE),
+        ],
    )
-    stub = hilserl_pb2_grpc.ActorServiceStub(channel)
-    for response in stub.StreamTransition(hilserl_pb2.Empty()):
-        if response.HasField("transition"):
-            buffer = io.BytesIO(response.transition.transition_bytes)
-            transition = torch.load(buffer)
-            transition_queue.put(transition)
-        if response.HasField("interaction_message"):
-            content = pickle.loads(response.interaction_message.interaction_message_bytes)
-            interaction_message_queue.put(content)
+    hilserl_pb2_grpc.add_LearnerServiceServicer_to_server(
+        service,
+        server,
+    )
+    server.add_insecure_port(f"{host}:{port}")
+    server.start()
+    logging.info("[LEARNER] gRPC server started")
+    return server
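For reference, the grpc.server lifecycle used above can be exercised on its own. A minimal sketch with no servicer registered (a real setup would first call hilserl_pb2_grpc.add_LearnerServiceServicer_to_server, as start_learner_server does):

    from concurrent.futures import ThreadPoolExecutor
    from threading import Event

    import grpc

    shutdown_event = Event()
    server = grpc.server(ThreadPoolExecutor(max_workers=2))
    server.add_insecure_port("127.0.0.1:50051")
    server.start()        # non-blocking; handlers run on the pool threads
    shutdown_event.set()  # in the real script a signal handler does this
    shutdown_event.wait()
    server.stop(grace=10)  # let in-flight RPCs finish, up to 10 seconds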
-def learner_push_parameters(
-    policy: nn.Module, policy_lock: Lock, actor_host="127.0.0.1", actor_port=50052, seconds_between_pushes=5
+def check_nan_in_transition(
+    observations: torch.Tensor, actions: torch.Tensor, next_state: torch.Tensor
):
-    """
-    As a client, connect to the Actor's gRPC server (ActorService)
-    and periodically push new parameters.
-    """
-    time.sleep(10)
-    channel = grpc.insecure_channel(
-        f"{actor_host}:{actor_port}",
-        options=[("grpc.max_send_message_length", -1), ("grpc.max_receive_message_length", -1)],
-    )
-    actor_stub = hilserl_pb2_grpc.ActorServiceStub(channel)
-    while True:
-        with policy_lock:
-            params_dict = policy.actor.state_dict()
-        if policy.config.vision_encoder_name is not None:
-            if policy.config.freeze_vision_encoder:
-                params_dict: dict[str, torch.Tensor] = {
-                    k: v for k, v in params_dict.items() if not k.startswith("encoder.")
-                }
-            else:
-                raise NotImplementedError(
-                    "Vision encoder is not frozen, we need to send the full model over the network which requires chunking the model."
-                )
-        params_dict = move_state_dict_to_device(params_dict, device="cpu")
-        # Serialize
-        buf = io.BytesIO()
-        torch.save(params_dict, buf)
-        params_bytes = buf.getvalue()
-        # Push them to the Actor's "SendParameters" method
-        logging.info("[LEARNER] Publishing parameters to the Actor")
-        response = actor_stub.SendParameters(hilserl_pb2.Parameters(parameter_bytes=params_bytes))  # noqa: F841
-        time.sleep(seconds_between_pushes)
-def check_nan_in_transition(observations: torch.Tensor, actions: torch.Tensor, next_state: torch.Tensor):
for k in observations:
if torch.isnan(observations[k]).any():
logging.error(f"observations[{k}] contains NaN values")
@@ -307,6 +273,7 @@ def add_actor_information_and_train(
logger: Logger,
resume_optimization_step: int | None = None,
resume_interaction_step: int | None = None,
+shutdown_event: Event | None = None,
):
"""
Handles data transfer from the actor to the learner, manages training updates,
@@ -338,6 +305,7 @@ def add_actor_information_and_train(
logger (Logger): Logger instance for tracking training progress.
resume_optimization_step (int | None): In the case of resume training, start from the last optimization step reached.
resume_interaction_step (int | None): In the case of resume training, shift the interaction step with the last saved step in order to not break logging.
+shutdown_event (Event | None): Event to signal shutdown.
"""
# NOTE: This function doesn't have a single responsibility, it should be split into multiple functions
# in the future. The reason why we did that is the GIL in Python. It's super slow the performance
@@ -345,9 +313,17 @@ def add_actor_information_and_train(
time.time()
logging.info("Starting learner thread")
interaction_message, transition = None, None
-optimization_step = resume_optimization_step if resume_optimization_step is not None else 0
-interaction_step_shift = resume_interaction_step if resume_interaction_step is not None else 0
+optimization_step = (
+    resume_optimization_step if resume_optimization_step is not None else 0
+)
+interaction_step_shift = (
+    resume_interaction_step if resume_interaction_step is not None else 0
+)
while True:
+if shutdown_event is not None and shutdown_event.is_set():
+    logging.info("[LEARNER] Shutdown signal received. Exiting...")
+    break
while not transition_queue.empty():
transition_list = transition_queue.get()
for transition in transition_list:
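The training loop drains these queues non-blockingly and checks the shutdown event once per pass. Schematically (toy payload, single pass):

    import queue
    from threading import Event

    shutdown_event = Event()
    transition_queue: queue.Queue = queue.Queue()
    transition_queue.put(["transition-0", "transition-1"])  # toy batch

    while True:
        if shutdown_event.is_set():
            break
        # Non-blocking drain: consume everything the actor has sent so far.
        while not transition_queue.empty():
            for transition in transition_queue.get():
                print("storing", transition)
        shutdown_event.set()  # one pass is enough for this sketch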
@@ -361,7 +337,9 @@ def add_actor_information_and_train(
interaction_message = interaction_message_queue.get()
# If cfg.resume, shift the interaction step with the last checkpointed step in order to not break the logging
interaction_message["Interaction step"] += interaction_step_shift
-logger.log_dict(interaction_message, mode="train", custom_step_key="Interaction step")
+logger.log_dict(
+    interaction_message, mode="train", custom_step_key="Interaction step"
+)
# logging.info(f"Interaction message: {interaction_message}")
if len(replay_buffer) < cfg.training.online_step_before_learning:
@@ -383,7 +361,9 @@ def add_actor_information_and_train(
observations = batch["state"]
next_observations = batch["next_state"]
done = batch["done"]
-check_nan_in_transition(observations=observations, actions=actions, next_state=next_observations)
+check_nan_in_transition(
+    observations=observations, actions=actions, next_state=next_observations
+)
with policy_lock:
loss_critic = policy.compute_loss_critic(
@@ -411,7 +391,9 @@ def add_actor_information_and_train(
next_observations = batch["next_state"]
done = batch["done"]
-check_nan_in_transition(observations=observations, actions=actions, next_state=next_observations)
+check_nan_in_transition(
+    observations=observations, actions=actions, next_state=next_observations
+)
with policy_lock:
loss_critic = policy.compute_loss_critic(
@@ -439,7 +421,9 @@ def add_actor_information_and_train(
training_infos["loss_actor"] = loss_actor.item()
-loss_temperature = policy.compute_loss_temperature(observations=observations)
+loss_temperature = policy.compute_loss_temperature(
+    observations=observations
+)
optimizers["temperature"].zero_grad()
loss_temperature.backward()
optimizers["temperature"].step()
@@ -453,9 +437,13 @@ def add_actor_information_and_train(
# logging.info(f"Training infos: {training_infos}")
time_for_one_optimization_step = time.time() - time_for_one_optimization_step
-frequency_for_one_optimization_step = 1 / (time_for_one_optimization_step + 1e-9)
+frequency_for_one_optimization_step = 1 / (
+    time_for_one_optimization_step + 1e-9
+)
-logging.info(f"[LEARNER] Optimization frequency loop [Hz]: {frequency_for_one_optimization_step}")
+logging.info(
+    f"[LEARNER] Optimization frequency loop [Hz]: {frequency_for_one_optimization_step}"
+)
logger.log_dict(
{
@@ -471,7 +459,8 @@ def add_actor_information_and_train(
logging.info(f"[LEARNER] Number of optimization step: {optimization_step}")
if cfg.training.save_checkpoint and (
-optimization_step % cfg.training.save_freq == 0 or optimization_step == cfg.training.online_steps
+optimization_step % cfg.training.save_freq == 0
+or optimization_step == cfg.training.online_steps
):
logging.info(f"Checkpoint policy after step {optimization_step}")
# Note: Save with step as the identifier, and format it to have at least 6 digits but more if
@@ -479,7 +468,9 @@ def add_actor_information_and_train(
_num_digits = max(6, len(str(cfg.training.online_steps)))
step_identifier = f"{optimization_step:0{_num_digits}d}"
interaction_step = (
interaction_message["Interaction step"] if interaction_message is not None else 0
interaction_message["Interaction step"]
if interaction_message is not None
else 0
)
logger.save_checkpoint(
optimization_step,
@@ -538,7 +529,9 @@ def make_optimizers_and_scheduler(cfg, policy: nn.Module):
optimizer_critic = torch.optim.Adam(
params=policy.critic_ensemble.parameters(), lr=policy.config.critic_lr
)
-optimizer_temperature = torch.optim.Adam(params=[policy.log_alpha], lr=policy.config.critic_lr)
+optimizer_temperature = torch.optim.Adam(
+    params=[policy.log_alpha], lr=policy.config.critic_lr
+)
lr_scheduler = None
optimizers = {
"actor": optimizer_actor,
@@ -580,14 +573,18 @@ def train(cfg: DictConfig, out_dir: str | None = None, job_name: str | None = No
# dataset_stats=offline_dataset.meta.stats if not cfg.resume else None,
# Hack: But if we do online training, we do not need dataset_stats
dataset_stats=None,
-pretrained_policy_name_or_path=str(logger.last_pretrained_model_dir) if cfg.resume else None,
+pretrained_policy_name_or_path=str(logger.last_pretrained_model_dir)
+if cfg.resume
+else None,
)
# compile policy
policy = torch.compile(policy)
assert isinstance(policy, nn.Module)
optimizers, lr_scheduler = make_optimizers_and_scheduler(cfg, policy)
-resume_optimization_step, resume_interaction_step = load_training_state(cfg, logger, optimizers)
+resume_optimization_step, resume_interaction_step = load_training_state(
+    cfg, logger, optimizers
+)
log_training_info(cfg, out_dir, policy)
@@ -599,7 +596,11 @@ def train(cfg: DictConfig, out_dir: str | None = None, job_name: str | None = No
logging.info("make_dataset offline buffer")
offline_dataset = make_dataset(cfg)
logging.info("Convertion to a offline replay buffer")
-active_action_dims = [i for i, mask in enumerate(cfg.env.wrapper.joint_masking_action_space) if mask]
+active_action_dims = [
+    i
+    for i, mask in enumerate(cfg.env.wrapper.joint_masking_action_space)
+    if mask
+]
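The comprehension above turns the boolean joint mask into a list of active indices. For a hypothetical 6-dimensional action space:

    joint_masking_action_space = [True, True, False, True, False, True]
    active_action_dims = [
        i for i, mask in enumerate(joint_masking_action_space) if mask
    ]
    print(active_action_dims)  # [0, 1, 3, 5]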
offline_replay_buffer = ReplayBuffer.from_lerobot_dataset(
offline_dataset,
device=device,
@@ -609,6 +610,20 @@ def train(cfg: DictConfig, out_dir: str | None = None, job_name: str | None = No
)
batch_size: int = batch_size // 2  # We will sample from both replay buffers
+shutdown_event = Event()
+
+def signal_handler(signum, frame):
+    print(
+        f"\nReceived signal {signal.Signals(signum).name}. Initiating learner shutdown..."
+    )
+    shutdown_event.set()
+
+# Register signal handlers
+signal.signal(signal.SIGINT, signal_handler)  # Ctrl+C
+signal.signal(signal.SIGTERM, signal_handler)  # Termination request
+signal.signal(signal.SIGHUP, signal_handler)  # Terminal closed
+signal.signal(signal.SIGQUIT, signal_handler)  # Ctrl+\
+
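Note that signal.signal() may only be called from the main thread, which is why registration happens in train() before the worker threads start. A self-contained sketch of the same Event-based shutdown on a Unix-like system:

    import os
    import signal
    from threading import Event

    shutdown_event = Event()

    def signal_handler(signum, frame):
        # Keep handlers minimal: record the request and return.
        shutdown_event.set()

    signal.signal(signal.SIGTERM, signal_handler)
    os.kill(os.getpid(), signal.SIGTERM)  # simulate an external termination
    shutdown_event.wait(timeout=1)
    print("shutdown requested:", shutdown_event.is_set())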
start_learner_threads(
cfg,
device,
@@ -621,6 +636,7 @@ def train(cfg: DictConfig, out_dir: str | None = None, job_name: str | None = No
logger,
resume_optimization_step,
resume_interaction_step,
+shutdown_event,
)