Mirror of https://github.com/huggingface/lerobot.git, synced 2026-05-16 17:20:05 +00:00

commit d48161da1b, parent 150def839c, committed by AdilZouitine

[Port HIL-SERL] Adjust Actor-Learner architecture & clean up dependency management for HIL-SERL (#722)
@@ -14,19 +14,19 @@
 # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 # See the License for the specific language governing permissions and
 # limitations under the License.
-import io
 import logging
-import pickle
 import queue
 import shutil
 import time
 from pprint import pformat
 from threading import Lock, Thread
+import signal
+from threading import Event
+from concurrent.futures import ThreadPoolExecutor
 
 import grpc
 
 # Import generated stubs
-import hilserl_pb2  # type: ignore
 import hilserl_pb2_grpc  # type: ignore
 import hydra
 import torch
@@ -55,10 +55,11 @@ from lerobot.common.utils.utils import (
 from lerobot.scripts.server.buffer import (
     ReplayBuffer,
     concatenate_batch_transitions,
-    move_state_dict_to_device,
     move_transition_to_device,
 )
 
+from lerobot.scripts.server import learner_service
+
 logging.basicConfig(level=logging.INFO)
 
 transition_queue = queue.Queue()
@@ -77,9 +78,13 @@ def handle_resume_logic(cfg: DictConfig, out_dir: str) -> DictConfig:
     # if resume == True
     checkpoint_dir = Logger.get_last_checkpoint_dir(out_dir)
     if not checkpoint_dir.exists():
-        raise RuntimeError(f"No model checkpoint found in {checkpoint_dir} for resume=True")
+        raise RuntimeError(
+            f"No model checkpoint found in {checkpoint_dir} for resume=True"
+        )
 
-    checkpoint_cfg_path = str(Logger.get_last_pretrained_model_dir(out_dir) / "config.yaml")
+    checkpoint_cfg_path = str(
+        Logger.get_last_pretrained_model_dir(out_dir) / "config.yaml"
+    )
     logging.info(
         colored(
             "Resume=True detected, resuming previous run",
@@ -112,7 +117,9 @@ def load_training_state(
     if not cfg.resume:
         return None, None
 
-    training_state = torch.load(logger.last_checkpoint_dir / logger.training_state_file_name)
+    training_state = torch.load(
+        logger.last_checkpoint_dir / logger.training_state_file_name
+    )
 
     if isinstance(training_state["optimizer"], dict):
         assert set(training_state["optimizer"].keys()) == set(optimizers.keys())
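Note: a minimal sketch (not the library's code) of the training-state round trip this resume path expects; the modules and keys below are hypothetical stand-ins for the real policy and optimizers dict, serialized in memory instead of to logger.training_state_file_name:

    import io

    import torch
    import torch.nn as nn

    # Hypothetical stand-ins for the real policy and optimizers dict.
    net = nn.Linear(2, 2)
    optimizers = {"actor": torch.optim.Adam(net.parameters())}

    # Save: one optimizer state dict per name, as the assert above expects.
    training_state = {
        "optimizer": {name: opt.state_dict() for name, opt in optimizers.items()},
        "step": 10,
    }
    buf = io.BytesIO()
    torch.save(training_state, buf)

    # Load: the keys of the saved dict must match the live optimizers dict.
    buf.seek(0)
    restored = torch.load(buf)
    assert set(restored["optimizer"].keys()) == set(optimizers.keys())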
@@ -126,7 +133,9 @@ def load_training_state(
 
 
 def log_training_info(cfg: DictConfig, out_dir: str, policy: nn.Module) -> None:
-    num_learnable_params = sum(p.numel() for p in policy.parameters() if p.requires_grad)
+    num_learnable_params = sum(
+        p.numel() for p in policy.parameters() if p.requires_grad
+    )
     num_total_params = sum(p.numel() for p in policy.parameters())
 
     log_output_dir(out_dir)
@@ -136,7 +145,9 @@ def log_training_info(cfg: DictConfig, out_dir: str, policy: nn.Module) -> None:
     logging.info(f"{num_total_params=} ({format_big_number(num_total_params)})")
 
 
-def initialize_replay_buffer(cfg: DictConfig, logger: Logger, device: str) -> ReplayBuffer:
+def initialize_replay_buffer(
+    cfg: DictConfig, logger: Logger, device: str
+) -> ReplayBuffer:
     if not cfg.resume:
         return ReplayBuffer(
             capacity=cfg.training.online_buffer_capacity,
@@ -146,7 +157,9 @@ def initialize_replay_buffer(cfg: DictConfig, logger: Logger, device: str) -> Re
         )
 
     dataset = LeRobotDataset(
-        repo_id=cfg.dataset_repo_id, local_files_only=True, root=logger.log_dir / "dataset"
+        repo_id=cfg.dataset_repo_id,
+        local_files_only=True,
+        root=logger.log_dir / "dataset",
     )
     return ReplayBuffer.from_lerobot_dataset(
         lerobot_dataset=dataset,
@@ -168,18 +181,10 @@ def start_learner_threads(
     logger: Logger,
     resume_optimization_step: int | None = None,
     resume_interaction_step: int | None = None,
+    shutdown_event: Event | None = None,
 ) -> None:
-    actor_ip = cfg.actor_learner_config.actor_ip
-    port = cfg.actor_learner_config.port
-
-    server_thread = Thread(
-        target=stream_transitions_from_actor,
-        args=(
-            actor_ip,
-            port,
-        ),
-        daemon=True,
-    )
+    host = cfg.actor_learner_config.learner_host
+    port = cfg.actor_learner_config.learner_port
 
     transition_thread = Thread(
         target=add_actor_information_and_train,
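Note: the connection direction flips in this hunk. The learner no longer dials the actor's address and spawns a client thread; it binds learner_host/learner_port and serves gRPC itself (see start_learner_server in the next hunk). A minimal, self-contained sketch of that bring-up pattern, with placeholder values standing in for the real learner_service.MAX_WORKERS and learner_service.MAX_MESSAGE_SIZE constants, and the servicer registration elided because the generated hilserl stubs are not available here:

    from concurrent.futures import ThreadPoolExecutor

    import grpc

    MAX_WORKERS = 4                      # placeholder for learner_service.MAX_WORKERS
    MAX_MESSAGE_SIZE = 32 * 1024 * 1024  # placeholder for learner_service.MAX_MESSAGE_SIZE

    server = grpc.server(
        ThreadPoolExecutor(max_workers=MAX_WORKERS),
        options=[
            ("grpc.max_receive_message_length", MAX_MESSAGE_SIZE),
            ("grpc.max_send_message_length", MAX_MESSAGE_SIZE),
        ],
    )
    # The real code registers the LearnerService servicer here before starting.
    server.add_insecure_port("0.0.0.0:50051")
    server.start()
    server.stop(grace=2.0)  # drain in-flight RPCs for up to 2 s, then terminate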
@@ -196,95 +201,56 @@ def start_learner_threads(
             logger,
             resume_optimization_step,
             resume_interaction_step,
+            shutdown_event,
         ),
     )
 
-    param_push_thread = Thread(
-        target=learner_push_parameters,
-        args=(policy, policy_lock, actor_ip, port, 15),
-        daemon=True,
-    )
-
-    server_thread.start()
     transition_thread.start()
-    param_push_thread.start()
-    param_push_thread.join()
+
+    service = learner_service.LearnerService(
+        shutdown_event,
+        policy,
+        policy_lock,
+        cfg.actor_learner_config.policy_parameters_push_frequency,
+        transition_queue,
+        interaction_message_queue,
+    )
+    server = start_learner_server(service, host, port)
+
+    shutdown_event.wait()
+    server.stop(learner_service.STUTDOWN_TIMEOUT)
+    logging.info("[LEARNER] gRPC server stopped")
+
     transition_thread.join()
-    server_thread.join()
+    logging.info("[LEARNER] Transition thread stopped")
 
 
-def stream_transitions_from_actor(host="127.0.0.1", port=50051):
-    """
-    Runs a gRPC client that listens for transition and interaction messages from an Actor service.
-
-    This function establishes a gRPC connection with the given `host` and `port`, then continuously
-    streams transition data from the `ActorServiceStub`. The received transition data is deserialized
-    and stored in a queue (`transition_queue`). Similarly, interaction messages are also deserialized
-    and stored in a separate queue (`interaction_message_queue`).
-
-    Args:
-        host (str, optional): The IP address or hostname of the gRPC server. Defaults to `"127.0.0.1"`.
-        port (int, optional): The port number on which the gRPC server is running. Defaults to `50051`.
-
-    """
-    # NOTE: This is waiting for the handshake to be done
-    # In the future we will do it in a canonical way with a proper handshake
-    time.sleep(10)
-    channel = grpc.insecure_channel(
-        f"{host}:{port}",
-        options=[("grpc.max_send_message_length", -1), ("grpc.max_receive_message_length", -1)],
+def start_learner_server(
+    service: learner_service.LearnerService,
+    host="0.0.0.0",
+    port=50051,
+) -> grpc.server:
+    server = grpc.server(
+        ThreadPoolExecutor(max_workers=learner_service.MAX_WORKERS),
+        options=[
+            ("grpc.max_receive_message_length", learner_service.MAX_MESSAGE_SIZE),
+            ("grpc.max_send_message_length", learner_service.MAX_MESSAGE_SIZE),
+        ],
     )
-    stub = hilserl_pb2_grpc.ActorServiceStub(channel)
-    for response in stub.StreamTransition(hilserl_pb2.Empty()):
-        if response.HasField("transition"):
-            buffer = io.BytesIO(response.transition.transition_bytes)
-            transition = torch.load(buffer)
-            transition_queue.put(transition)
-        if response.HasField("interaction_message"):
-            content = pickle.loads(response.interaction_message.interaction_message_bytes)
-            interaction_message_queue.put(content)
+    hilserl_pb2_grpc.add_LearnerServiceServicer_to_server(
+        service,
+        server,
+    )
+    server.add_insecure_port(f"{host}:{port}")
+    server.start()
+    logging.info("[LEARNER] gRPC server started")
+
+    return server
 
 
-def learner_push_parameters(
-    policy: nn.Module, policy_lock: Lock, actor_host="127.0.0.1", actor_port=50052, seconds_between_pushes=5
+def check_nan_in_transition(
+    observations: torch.Tensor, actions: torch.Tensor, next_state: torch.Tensor
 ):
-    """
-    As a client, connect to the Actor's gRPC server (ActorService)
-    and periodically push new parameters.
-    """
-    time.sleep(10)
-    channel = grpc.insecure_channel(
-        f"{actor_host}:{actor_port}",
-        options=[("grpc.max_send_message_length", -1), ("grpc.max_receive_message_length", -1)],
-    )
-    actor_stub = hilserl_pb2_grpc.ActorServiceStub(channel)
-
-    while True:
-        with policy_lock:
-            params_dict = policy.actor.state_dict()
-            if policy.config.vision_encoder_name is not None:
-                if policy.config.freeze_vision_encoder:
-                    params_dict: dict[str, torch.Tensor] = {
-                        k: v for k, v in params_dict.items() if not k.startswith("encoder.")
-                    }
-                else:
-                    raise NotImplementedError(
-                        "Vision encoder is not frozen, we need to send the full model over the network which requires chunking the model."
-                    )
-
-        params_dict = move_state_dict_to_device(params_dict, device="cpu")
-        # Serialize
-        buf = io.BytesIO()
-        torch.save(params_dict, buf)
-        params_bytes = buf.getvalue()
-
-        # Push them to the Actor's "SendParameters" method
-        logging.info("[LEARNER] Publishing parameters to the Actor")
-        response = actor_stub.SendParameters(hilserl_pb2.Parameters(parameter_bytes=params_bytes))  # noqa: F841
-        time.sleep(seconds_between_pushes)
-
-
-def check_nan_in_transition(observations: torch.Tensor, actions: torch.Tensor, next_state: torch.Tensor):
     for k in observations:
         if torch.isnan(observations[k]).any():
             logging.error(f"observations[{k}] contains NaN values")
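Note: the deleted client loop documents the wire format that presumably still applies under the new LearnerService: parameters travel as torch-serialized bytes, with frozen-encoder weights filtered out before shipping. A minimal sketch of that round trip, using a hypothetical two-part model in place of policy.actor:

    import io

    import torch
    import torch.nn as nn

    # Hypothetical model with an "encoder." key prefix, standing in for the SAC actor.
    model = nn.Sequential()
    model.add_module("encoder", nn.Linear(8, 8))
    model.add_module("head", nn.Linear(8, 2))

    # Drop frozen encoder weights, as the old push loop did.
    params_dict = {
        k: v for k, v in model.state_dict().items() if not k.startswith("encoder.")
    }
    params_dict = {k: v.cpu() for k, v in params_dict.items()}  # move_state_dict_to_device stand-in

    # Serialize to bytes (what went into Parameters(parameter_bytes=...))...
    buf = io.BytesIO()
    torch.save(params_dict, buf)
    params_bytes = buf.getvalue()

    # ...and deserialize on the receiving side, mirroring torch.load(io.BytesIO(...)).
    restored = torch.load(io.BytesIO(params_bytes))
    assert set(restored) == {"head.weight", "head.bias"}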
@@ -307,6 +273,7 @@ def add_actor_information_and_train(
     logger: Logger,
     resume_optimization_step: int | None = None,
     resume_interaction_step: int | None = None,
+    shutdown_event: Event | None = None,
 ):
     """
     Handles data transfer from the actor to the learner, manages training updates,
@@ -338,6 +305,7 @@ def add_actor_information_and_train(
         logger (Logger): Logger instance for tracking training progress.
         resume_optimization_step (int | None): In the case of resume training, start from the last optimization step reached.
         resume_interaction_step (int | None): In the case of resume training, shift the interaction step with the last saved step in order to not break logging.
+        shutdown_event (Event | None): Event to signal shutdown.
     """
     # NOTE: This function doesn't have a single responsibility; it should be split into multiple functions
     # in the future. We kept it as one because of Python's GIL, which makes multi-threaded performance very slow
@@ -345,9 +313,17 @@ def add_actor_information_and_train(
     time.time()
     logging.info("Starting learner thread")
     interaction_message, transition = None, None
-    optimization_step = resume_optimization_step if resume_optimization_step is not None else 0
-    interaction_step_shift = resume_interaction_step if resume_interaction_step is not None else 0
+    optimization_step = (
+        resume_optimization_step if resume_optimization_step is not None else 0
+    )
+    interaction_step_shift = (
+        resume_interaction_step if resume_interaction_step is not None else 0
+    )
     while True:
+        if shutdown_event is not None and shutdown_event.is_set():
+            logging.info("[LEARNER] Shutdown signal received. Exiting...")
+            break
+
         while not transition_queue.empty():
             transition_list = transition_queue.get()
             for transition in transition_list:
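Note: the new loop head is the standard cooperative-shutdown pattern: the thread polls a threading.Event once per iteration and exits on its own rather than being killed. A minimal standalone sketch of the same pattern:

    import threading
    import time

    shutdown_event = threading.Event()

    def worker():
        while True:
            if shutdown_event.is_set():
                print("[WORKER] Shutdown signal received. Exiting...")
                break
            time.sleep(0.05)  # stand-in for one optimization step

    t = threading.Thread(target=worker)
    t.start()
    time.sleep(0.2)
    shutdown_event.set()  # in train() this is done by the signal handler
    t.join()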
@@ -361,7 +337,9 @@ def add_actor_information_and_train(
             interaction_message = interaction_message_queue.get()
             # If cfg.resume, shift the interaction step with the last checkpointed step in order to not break the logging
             interaction_message["Interaction step"] += interaction_step_shift
-            logger.log_dict(interaction_message, mode="train", custom_step_key="Interaction step")
+            logger.log_dict(
+                interaction_message, mode="train", custom_step_key="Interaction step"
+            )
             # logging.info(f"Interaction message: {interaction_message}")
 
         if len(replay_buffer) < cfg.training.online_step_before_learning:
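Note: a worked example of the step shift applied above. After a resume, incoming interaction counts restart on the actor side, so the learner offsets them by the last checkpointed step to keep the logged curve continuous (numbers here are made up):

    interaction_step_shift = 1200          # last checkpointed step (hypothetical)
    interaction_message = {"Interaction step": 5}
    interaction_message["Interaction step"] += interaction_step_shift
    print(interaction_message)             # {'Interaction step': 1205}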
@@ -383,7 +361,9 @@ def add_actor_information_and_train(
         observations = batch["state"]
         next_observations = batch["next_state"]
         done = batch["done"]
-        check_nan_in_transition(observations=observations, actions=actions, next_state=next_observations)
+        check_nan_in_transition(
+            observations=observations, actions=actions, next_state=next_observations
+        )
 
         with policy_lock:
             loss_critic = policy.compute_loss_critic(
@@ -411,7 +391,9 @@ def add_actor_information_and_train(
             next_observations = batch["next_state"]
             done = batch["done"]
 
-            check_nan_in_transition(observations=observations, actions=actions, next_state=next_observations)
+            check_nan_in_transition(
+                observations=observations, actions=actions, next_state=next_observations
+            )
 
             with policy_lock:
                 loss_critic = policy.compute_loss_critic(
@@ -439,7 +421,9 @@ def add_actor_information_and_train(
 
                 training_infos["loss_actor"] = loss_actor.item()
 
-            loss_temperature = policy.compute_loss_temperature(observations=observations)
+            loss_temperature = policy.compute_loss_temperature(
+                observations=observations
+            )
             optimizers["temperature"].zero_grad()
             loss_temperature.backward()
             optimizers["temperature"].step()
@@ -453,9 +437,13 @@ def add_actor_information_and_train(
             # logging.info(f"Training infos: {training_infos}")
 
             time_for_one_optimization_step = time.time() - time_for_one_optimization_step
-            frequency_for_one_optimization_step = 1 / (time_for_one_optimization_step + 1e-9)
+            frequency_for_one_optimization_step = 1 / (
+                time_for_one_optimization_step + 1e-9
+            )
 
-            logging.info(f"[LEARNER] Optimization frequency loop [Hz]: {frequency_for_one_optimization_step}")
+            logging.info(
+                f"[LEARNER] Optimization frequency loop [Hz]: {frequency_for_one_optimization_step}"
+            )
 
             logger.log_dict(
                 {
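Note: the 1e-9 term guards the division when a step takes approximately zero time; for realistic step times it is negligible. A worked example with a made-up step duration:

    time_for_one_optimization_step = 0.02  # seconds, example value
    frequency = 1 / (time_for_one_optimization_step + 1e-9)
    print(f"{frequency:.1f} Hz")  # 50.0 Hz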
@@ -471,7 +459,8 @@ def add_actor_information_and_train(
             logging.info(f"[LEARNER] Number of optimization step: {optimization_step}")
 
             if cfg.training.save_checkpoint and (
-                optimization_step % cfg.training.save_freq == 0 or optimization_step == cfg.training.online_steps
+                optimization_step % cfg.training.save_freq == 0
+                or optimization_step == cfg.training.online_steps
             ):
                 logging.info(f"Checkpoint policy after step {optimization_step}")
                 # Note: Save with step as the identifier, and format it to have at least 6 digits but more if
@@ -479,7 +468,9 @@ def add_actor_information_and_train(
                 _num_digits = max(6, len(str(cfg.training.online_steps)))
                 step_identifier = f"{optimization_step:0{_num_digits}d}"
                 interaction_step = (
-                    interaction_message["Interaction step"] if interaction_message is not None else 0
+                    interaction_message["Interaction step"]
+                    if interaction_message is not None
+                    else 0
                 )
                 logger.save_checkpoint(
                     optimization_step,
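Note: a worked example of the checkpoint identifier built above: at least six digits, widened if online_steps needs more (values are illustrative):

    online_steps = 2_000_000
    _num_digits = max(6, len(str(online_steps)))  # 7 digits here
    step_identifier = f"{53:0{_num_digits}d}"
    print(step_identifier)  # 0000053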
@@ -538,7 +529,9 @@ def make_optimizers_and_scheduler(cfg, policy: nn.Module):
     optimizer_critic = torch.optim.Adam(
         params=policy.critic_ensemble.parameters(), lr=policy.config.critic_lr
     )
-    optimizer_temperature = torch.optim.Adam(params=[policy.log_alpha], lr=policy.config.critic_lr)
+    optimizer_temperature = torch.optim.Adam(
+        params=[policy.log_alpha], lr=policy.config.critic_lr
+    )
     lr_scheduler = None
     optimizers = {
         "actor": optimizer_actor,
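Note: the returned dict keeps one Adam instance per trainable component, so each SAC loss steps only its own parameters. A minimal sketch with stand-in modules (the real actor, critic ensemble, and log_alpha live on the policy):

    import torch
    import torch.nn as nn

    actor, critic = nn.Linear(4, 2), nn.Linear(6, 1)  # stand-ins
    log_alpha = torch.zeros(1, requires_grad=True)    # temperature parameter

    optimizers = {
        "actor": torch.optim.Adam(actor.parameters(), lr=3e-4),
        "critic": torch.optim.Adam(critic.parameters(), lr=3e-4),
        "temperature": torch.optim.Adam([log_alpha], lr=3e-4),
    }

    # Usage mirrors the training loop above: each loss drives one optimizer.
    loss_temperature = (log_alpha * 0.5).sum()  # dummy loss for illustration
    optimizers["temperature"].zero_grad()
    loss_temperature.backward()
    optimizers["temperature"].step()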
@@ -580,14 +573,18 @@ def train(cfg: DictConfig, out_dir: str | None = None, job_name: str | None = No
         # dataset_stats=offline_dataset.meta.stats if not cfg.resume else None,
         # Hack: But if we do online training, we do not need dataset_stats
         dataset_stats=None,
-        pretrained_policy_name_or_path=str(logger.last_pretrained_model_dir) if cfg.resume else None,
+        pretrained_policy_name_or_path=str(logger.last_pretrained_model_dir)
+        if cfg.resume
+        else None,
     )
     # compile policy
     policy = torch.compile(policy)
     assert isinstance(policy, nn.Module)
 
     optimizers, lr_scheduler = make_optimizers_and_scheduler(cfg, policy)
-    resume_optimization_step, resume_interaction_step = load_training_state(cfg, logger, optimizers)
+    resume_optimization_step, resume_interaction_step = load_training_state(
+        cfg, logger, optimizers
+    )
 
     log_training_info(cfg, out_dir, policy)
 
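Note: the assert after torch.compile holds because compilation wraps the module rather than replacing it; the wrapper is still an nn.Module:

    import torch
    import torch.nn as nn

    policy = nn.Linear(4, 4)
    policy = torch.compile(policy)        # returns an OptimizedModule wrapper
    assert isinstance(policy, nn.Module)  # the wrapper subclasses nn.Module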
@@ -599,7 +596,11 @@ def train(cfg: DictConfig, out_dir: str | None = None, job_name: str | None = No
         logging.info("make_dataset offline buffer")
         offline_dataset = make_dataset(cfg)
         logging.info("Conversion to an offline replay buffer")
-        active_action_dims = [i for i, mask in enumerate(cfg.env.wrapper.joint_masking_action_space) if mask]
+        active_action_dims = [
+            i
+            for i, mask in enumerate(cfg.env.wrapper.joint_masking_action_space)
+            if mask
+        ]
         offline_replay_buffer = ReplayBuffer.from_lerobot_dataset(
             offline_dataset,
             device=device,
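Note: a worked example of the joint-masking comprehension reformatted above, with a made-up mask; it keeps the indices of action dimensions whose mask entry is truthy:

    joint_masking_action_space = [1, 0, 1, 1, 0]  # hypothetical mask
    active_action_dims = [
        i
        for i, mask in enumerate(joint_masking_action_space)
        if mask
    ]
    print(active_action_dims)  # [0, 2, 3]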
@@ -609,6 +610,20 @@ def train(cfg: DictConfig, out_dir: str | None = None, job_name: str | None = No
         )
         batch_size: int = batch_size // 2  # We will sample from both replay buffer
 
+    shutdown_event = Event()
+
+    def signal_handler(signum, frame):
+        print(
+            f"\nReceived signal {signal.Signals(signum).name}. Initiating learner shutdown..."
+        )
+        shutdown_event.set()
+
+    # Register signal handlers
+    signal.signal(signal.SIGINT, signal_handler)  # Ctrl+C
+    signal.signal(signal.SIGTERM, signal_handler)  # Termination request
+    signal.signal(signal.SIGHUP, signal_handler)  # Terminal closed
+    signal.signal(signal.SIGQUIT, signal_handler)  # Ctrl+\
+
     start_learner_threads(
         cfg,
         device,
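Note: these handlers are the other half of the shutdown path sketched earlier: OS signals flip the same Event that the learner threads poll. A minimal standalone sketch (SIGHUP and SIGQUIT are Unix-only, so only the portable pair is registered here):

    import signal
    import threading

    shutdown_event = threading.Event()

    def signal_handler(signum, frame):
        print(f"\nReceived signal {signal.Signals(signum).name}. Initiating shutdown...")
        shutdown_event.set()

    signal.signal(signal.SIGINT, signal_handler)   # Ctrl+C
    signal.signal(signal.SIGTERM, signal_handler)  # termination request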
@@ -621,6 +636,7 @@ def train(cfg: DictConfig, out_dir: str | None = None, job_name: str | None = No
         logger,
         resume_optimization_step,
         resume_interaction_step,
+        shutdown_event,
     )
 
 