Compare commits

...

10 Commits

Author SHA1 Message Date
Khalil Meftah 519234a5d8 feat: add offline training in learner 2026-03-22 23:00:07 +01:00
Khalil Meftah d9371b9a34 feat: add RLT algorithm 2026-03-22 22:59:35 +01:00
Khalil Meftah 17f47b9cbc feat: add RLT policy RL-token encoder-decoder and actor 2026-03-22 22:57:43 +01:00
Khalil Meftah 05395c8b10 Add offline phase hooks to RLAlgorithm base 2026-03-22 22:52:56 +01:00
Khalil Meftah f495054321 disable processor in actor for sac/hilserl 2026-03-19 13:42:46 +01:00
Khalil Meftah 2345c779ee disable processor for sac/hilserl 2026-03-19 13:12:21 +01:00
Khalil Meftah aaf8576411 chore: rename losses 2026-03-19 12:36:02 +01:00
Khalil Meftah d3e6f14d4f fix: move algorithm-owned modules to the policy device 2026-03-18 15:27:41 +01:00
Khalil Meftah 1f5487eea8 refactor: decouple policy from algorithm 2026-03-11 16:49:14 +01:00
Khalil Meftah 8d50be9faa refactor: RL stack refactoring — RLAlgorithm, RLTrainer, DataMixer, and SAC restructuring
- Add RLAlgorithm base class and RLAlgorithmConfig with draccus.ChoiceRegistry
- Add RLTrainer for unified training orchestration with iterator pattern
- Add DataMixer and OnlineOfflineMixer for online/offline data mixing
- Restructure SAC algorithm with batch iterator and factory pattern
- Add observation normalization pre/post processors
- Add comprehensive tests for all new components
2026-03-03 16:50:00 +01:00
27 changed files with 3131 additions and 898 deletions
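The pieces introduced in the last commit compose into a learner loop along these lines (a sketch using only names from this diff; `policy`, `cfg`, `data_mixer`, and the step counts are placeholders, and the actual RLTrainer wiring is not shown here):

from lerobot.rl.algorithms import make_algorithm

algorithm = make_algorithm(policy, cfg.policy, algorithm_name=cfg.algorithm)
algorithm.make_optimizers()  # learner side only; actors skip this
batches = algorithm.configure_data_iterator(data_mixer, batch_size=256)

if algorithm.supports_offline_phase():  # e.g. RLT Stage 1 pretraining
    for _ in range(offline_steps):
        algorithm.offline_update(batches)
    algorithm.transition_to_online()

for _ in range(online_steps):
    stats = algorithm.update(batches)  # the algorithm pulls as many batches as it needs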
+17 -14
@@ -4,7 +4,6 @@ from pathlib import Path
from queue import Empty, Full
import torch
import torch.optim as optim
from lerobot.datasets.lerobot_dataset import LeRobotDataset
from lerobot.datasets.utils import hw_to_dataset_features
@@ -12,6 +11,7 @@ from lerobot.envs.configs import HILSerlProcessorConfig, HILSerlRobotEnvConfig
from lerobot.policies.sac.configuration_sac import SACConfig
from lerobot.policies.sac.modeling_sac import SACPolicy
from lerobot.policies.sac.reward_model.modeling_classifier import Classifier
from lerobot.rl.algorithms.sac import SACAlgorithm, SACAlgorithmConfig
from lerobot.rl.buffer import ReplayBuffer
from lerobot.rl.gym_manipulator import make_robot_env
from lerobot.robots.so_follower import SO100FollowerConfig
@@ -40,8 +40,9 @@ def run_learner(
policy_learner.train()
policy_learner.to(device)
# Create Adam optimizer from scratch - simple and clean
optimizer = optim.Adam(policy_learner.parameters(), lr=lr)
algo_config = SACAlgorithmConfig.from_policy_config(policy_learner.config)
algorithm = SACAlgorithm(policy=policy_learner, config=algo_config)
algorithm.make_optimizers()
print(f"[LEARNER] Online buffer capacity: {online_buffer.capacity}")
print(f"[LEARNER] Offline buffer capacity: {offline_buffer.capacity}")
@@ -83,24 +84,26 @@ def run_learner(
else:
batch[key] = online_batch[key]
loss, _ = policy_learner.forward(batch)
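# The algorithm pulls batches via the iterator protocol; this toy iterator re-yields the same mixed batch.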
def batch_iter(b=batch):
while True:
yield b
optimizer.zero_grad()
loss.backward()
optimizer.step()
stats = algorithm.update(batch_iter())
training_step += 1
if training_step % LOG_EVERY == 0:
log_dict = stats.to_log_dict()
print(
f"[LEARNER] Training step {training_step}, Loss: {loss.item():.4f}, "
f"[LEARNER] Training step {training_step}, "
f"critic_loss: {log_dict.get('critic', 'N/A'):.4f}, "
f"Buffers: Online={len(online_buffer)}, Offline={len(offline_buffer)}"
)
# Send updated parameters to actor every 10 training steps
if training_step % SEND_EVERY == 0:
try:
state_dict = {k: v.cpu() for k, v in policy_learner.state_dict().items()}
parameters_queue.put_nowait(state_dict)
weights = algorithm.get_weights()
parameters_queue.put_nowait(weights)
print("[LEARNER] Sent updated parameters to actor")
except Full:
# Queue is full because the actor has not consumed previous parameters (should happen rarely)
@@ -144,15 +147,15 @@ def run_actor(
while step < MAX_STEPS_PER_EPISODE and not shutdown_event.is_set():
try:
new_params = parameters_queue.get_nowait()
policy_actor.load_state_dict(new_params)
new_weights = parameters_queue.get_nowait()
policy_actor.load_state_dict(new_weights)
print("[ACTOR] Updated policy parameters from learner")
except Empty:  # No new parameters available from the learner yet; keep acting with the current weights
pass
# Get action from policy
# Get action from policy (returns full action: continuous + discrete)
policy_obs = make_policy_obs(obs, device=device)
action_tensor = policy_actor.select_action(policy_obs) # predicts a single action
action_tensor = policy_actor.select_action(policy_obs)
action = action_tensor.squeeze(0).cpu().numpy()
# Step environment
+12
@@ -211,3 +211,15 @@ class TrainRLServerPipelineConfig(TrainPipelineConfig):
# NOTE: In RL, we don't need an offline dataset
# TODO: Make `TrainPipelineConfig.dataset` optional
dataset: DatasetConfig | None = None  # type: ignore[assignment]  # because the parent class has made its type non-optional
# Algorithm name registered in RLAlgorithmConfig registry
algorithm: str = "sac"
# Data mixer strategy name. Currently supports "online_offline"
mixer: str = "online_offline"
# Fraction sampled from online replay when using OnlineOfflineMixer
online_ratio: float = 0.5
# RL trainer iterator
async_prefetch: bool = True
queue_size: int = 2
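For illustration, a minimal sketch of the new knobs on the pipeline config (inherited TrainPipelineConfig fields elided; `"rlt"` assumes the registry entry added later in this diff):

cfg = TrainRLServerPipelineConfig(
    algorithm="rlt",         # key in the RLAlgorithmConfig registry ("sac" is the default)
    mixer="online_offline",  # data mixer strategy
    online_ratio=0.5,        # half of each batch sampled from the online replay buffer
    async_prefetch=True,     # prefetch batches on a background thread
    queue_size=2,            # prefetch queue depth
)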
+18
@@ -0,0 +1,18 @@
# Copyright 2026 The HuggingFace Inc. team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from lerobot.policies.rlt.configuration_rlt import RLTConfig
from lerobot.policies.rlt.modeling_rlt import RLTPolicy
__all__ = ["RLTConfig", "RLTPolicy"]
@@ -0,0 +1,156 @@
# Copyright 2026 The HuggingFace Inc. team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""RLT (RL Token) policy configuration.
Reference: "RL Token: Bootstrapping Online RL with Vision-Language-Action Models"
(Xu et al., Physical Intelligence, 2026)
"""
from __future__ import annotations
from dataclasses import dataclass, field
from lerobot.configs.policies import PreTrainedConfig
from lerobot.configs.types import NormalizationMode
from lerobot.policies.sac.configuration_sac import ActorLearnerConfig, ConcurrencyConfig
from lerobot.utils.constants import ACTION, OBS_IMAGE, OBS_STATE
@dataclass
class RLTokenConfig:
"""Configuration for the RL-token encoder/decoder transformer."""
input_dim: int = 2048
rl_token_dim: int = 2048
num_encoder_layers: int = 2
num_decoder_layers: int = 2
num_heads: int = 8
ff_dim: int = 2048
dropout: float = 0.0
@dataclass
class RLTActorConfig:
"""Configuration for the lightweight RL actor MLP."""
hidden_dims: list[int] = field(default_factory=lambda: [256, 256])
std: float = 0.1
@dataclass
class RLTCriticConfig:
"""Configuration for the RLT critic MLP."""
hidden_dims: list[int] = field(default_factory=lambda: [256, 256])
@PreTrainedConfig.register_subclass("rlt")
@dataclass
class RLTConfig(PreTrainedConfig):
"""Configuration for the RLT (RL Token) policy.
RLT adds an RL-token encoder/decoder to a frozen VLA backbone, then trains
a lightweight actor-critic head using the RL token as state representation.
The frozen VLA also provides reference action chunks that the actor refines.
"""
normalization_mapping: dict[str, NormalizationMode] = field(
default_factory=lambda: {
"VISUAL": NormalizationMode.MEAN_STD,
"STATE": NormalizationMode.MIN_MAX,
"ACTION": NormalizationMode.MIN_MAX,
}
)
dataset_stats: dict[str, dict[str, list[float]]] | None = field(
default_factory=lambda: {
OBS_IMAGE: {
"mean": [0.485, 0.456, 0.406],
"std": [0.229, 0.224, 0.225],
},
OBS_STATE: {"min": [0.0], "max": [1.0]},
ACTION: {"min": [0.0], "max": [1.0]},
}
)
# ── Device ──
device: str = "cuda"
storage_device: str = "cpu"
# ── VLA backbone ──
vla_checkpoint: str | None = None
# ── RL-token ──
rl_token: RLTokenConfig = field(default_factory=RLTokenConfig)
# ── Actor / Critic heads ──
actor: RLTActorConfig = field(default_factory=RLTActorConfig)
critic: RLTCriticConfig = field(default_factory=RLTCriticConfig)
# ── Action chunks ──
chunk_size: int = 10
vla_chunk_size: int = 50
# ── Training parameters ──
online_steps: int = 50000
offline_steps: int = 5000
online_buffer_capacity: int = 100000
offline_buffer_capacity: int = 100000
online_step_before_learning: int = 500
warmup_steps: int = 500
async_prefetch: bool = False
# ── Algorithm hyperparameters ──
utd_ratio: int = 5
policy_update_freq: int = 2
discount: float = 0.99
critic_lr: float = 3e-4
actor_lr: float = 3e-4
rl_token_lr: float = 1e-4
tau: float = 0.005
clip_grad_norm: float = 10.0
num_critics: int = 2
bc_reg_coeff: float = 0.1
ref_dropout: float = 0.5
chunk_stride: int = 2
vla_finetune_weight: float = 0.0
# ── Distributed ──
actor_learner_config: ActorLearnerConfig = field(default_factory=ActorLearnerConfig)
concurrency: ConcurrencyConfig = field(default_factory=ConcurrencyConfig)
def __post_init__(self):
super().__post_init__()
def get_optimizer_preset(self):
return None
def get_scheduler_preset(self):
return None
def validate_features(self) -> None:
if ACTION not in self.output_features:
raise ValueError("You must provide 'action' in the output features")
@property
def observation_delta_indices(self) -> list | None:
return None
@property
def action_delta_indices(self) -> list | None:
return None
@property
def reward_delta_indices(self) -> None:
return None
+318
@@ -0,0 +1,318 @@
# Copyright 2026 The HuggingFace Inc. team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""RLT (RL Token) policy networks.
Reference: "RL Token: Bootstrapping Online RL with Vision-Language-Action Models"
(Xu et al., Physical Intelligence, 2026)
Architecture:
- RLTokenEncoder: compresses VLA token embeddings into a single compact RL token
- RLTokenDecoder: reconstructs VLA embeddings from the RL token (Stage 1 training only)
- RLTActor: refines VLA reference action chunks conditioned on (z_rl, proprioception, ref_action)
- RLTCritic: Q(x, action_chunk) where x = (z_rl, proprioception)
- RLTPolicy: bundles RL-token modules + actor into a PreTrainedPolicy for inference
"""
from __future__ import annotations
import math
import torch
import torch.nn as nn
from torch import Tensor
from lerobot.policies.pretrained import PreTrainedPolicy
from lerobot.policies.rlt.configuration_rlt import RLTConfig
# ── Building blocks ──────────────────────────────────────────────────
class MLP(nn.Module):
"""Simple feedforward network with ReLU activations."""
def __init__(self, input_dim: int, hidden_dims: list[int], output_dim: int):
super().__init__()
layers: list[nn.Module] = []
prev = input_dim
for h in hidden_dims:
layers.append(nn.Linear(prev, h))
layers.append(nn.ReLU())
prev = h
layers.append(nn.Linear(prev, output_dim))
self.net = nn.Sequential(*layers)
def forward(self, x: Tensor) -> Tensor:
return self.net(x)
# ── RL Token Encoder ─────────────────────────────────────────────────
class RLTokenEncoder(nn.Module):
"""Compress VLA token embeddings into a single RL token via a small transformer.
Appends a learnable ``e_rl`` embedding to the VLA token sequence, processes
through transformer encoder layers, and returns the output at the ``e_rl``
position as the RL token ``z_rl``.
Paper Eq. 1: z_rl = g_phi([z_{1:M}, e_rl])_{M+1}
"""
def __init__(
self,
input_dim: int,
rl_token_dim: int,
num_layers: int,
num_heads: int,
ff_dim: int,
dropout: float = 0.0,
):
super().__init__()
self.rl_token_dim = rl_token_dim
self.e_rl = nn.Parameter(torch.randn(1, 1, input_dim) * 0.02)
if input_dim != rl_token_dim:
self.input_proj = nn.Linear(input_dim, rl_token_dim)
else:
self.input_proj = nn.Identity()
encoder_layer = nn.TransformerEncoderLayer(
d_model=rl_token_dim,
nhead=num_heads,
dim_feedforward=ff_dim,
dropout=dropout,
batch_first=True,
)
self.transformer = nn.TransformerEncoder(encoder_layer, num_layers=num_layers)
def forward(self, z_vla: Tensor) -> Tensor:
"""
Args:
z_vla: VLA token embeddings, shape ``(B, M, D)``.
Returns:
RL token ``z_rl``, shape ``(B, rl_token_dim)``.
"""
batch_size = z_vla.shape[0]
e_rl = self.e_rl.expand(batch_size, -1, -1)
seq = torch.cat([z_vla, e_rl], dim=1) # (B, M+1, D)
seq = self.input_proj(seq)
out = self.transformer(seq)
z_rl = out[:, -1, :] # output at e_rl position
return z_rl
# ── RL Token Decoder ─────────────────────────────────────────────────
class RLTokenDecoder(nn.Module):
"""Autoregressively reconstruct VLA embeddings from z_rl.
Used only during Stage 1 (offline RL-token training).
Paper Eq. 2: L_ro = E[sum_i || h(d([z_rl, z_bar_{1:i-1}]))_i - z_bar_i ||^2]
"""
def __init__(
self,
rl_token_dim: int,
output_dim: int,
num_layers: int,
num_heads: int,
ff_dim: int,
dropout: float = 0.0,
):
super().__init__()
self.output_dim = output_dim
if rl_token_dim != output_dim:
self.rl_proj = nn.Linear(rl_token_dim, output_dim)
else:
self.rl_proj = nn.Identity()
decoder_layer = nn.TransformerDecoderLayer(
d_model=output_dim,
nhead=num_heads,
dim_feedforward=ff_dim,
dropout=dropout,
batch_first=True,
)
self.transformer = nn.TransformerDecoder(decoder_layer, num_layers=num_layers)
self.output_head = nn.Linear(output_dim, output_dim)
def forward(self, z_rl: Tensor, z_vla_stopped: Tensor) -> Tensor:
"""
Args:
z_rl: RL token, shape ``(B, D_rl)``.
z_vla_stopped: Stop-gradient VLA embeddings, shape ``(B, M, D)``.
Returns:
Reconstructed embeddings, shape ``(B, M, D)``.
"""
seq_len = z_vla_stopped.shape[1]
z_rl_proj = self.rl_proj(z_rl).unsqueeze(1)
target = torch.cat([z_rl_proj, z_vla_stopped[:, :-1, :]], dim=1)
causal_mask = nn.Transformer.generate_square_subsequent_mask(seq_len, device=z_rl.device)
decoded = self.transformer(
tgt=target,
memory=z_rl_proj,
tgt_mask=causal_mask,
)
return self.output_head(decoded) # (B, M, D)
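A shape check tying Eq. 1 and Eq. 2 together (sizes are made up; imports assume this module's public names):

import torch
import torch.nn.functional as F
from lerobot.policies.rlt.modeling_rlt import RLTokenDecoder, RLTokenEncoder

enc = RLTokenEncoder(input_dim=2048, rl_token_dim=2048, num_layers=2, num_heads=8, ff_dim=2048)
dec = RLTokenDecoder(rl_token_dim=2048, output_dim=2048, num_layers=2, num_heads=8, ff_dim=2048)
z_vla = torch.randn(4, 16, 2048)             # (B=4, M=16 VLA tokens, D=2048)
z_rl = enc(z_vla)                            # Eq. 1: one RL token per sample, (4, 2048)
recon = dec(z_rl, z_vla.detach())            # teacher-forced reconstruction, (4, 16, 2048)
loss_ro = F.mse_loss(recon, z_vla.detach())  # Eq. 2 objective used in Stage 1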
# ── Actor ────────────────────────────────────────────────────────────
class RLTActor(nn.Module):
"""Lightweight actor that refines VLA reference action chunks.
Paper Eq. 4: pi_theta(a_{1:C} | x, a_tilde_{1:C}) = N(mu_theta(x, a_tilde), sigma^2 I)
The actor is conditioned on both the RL state and the VLA's proposed action
chunk, acting as a "VLA-guided action editor".
"""
def __init__(self, state_dim: int, action_chunk_dim: int, hidden_dims: list[int], std: float = 0.1):
super().__init__()
input_dim = state_dim + action_chunk_dim
self.net = MLP(input_dim, hidden_dims, action_chunk_dim)
self.log_std = math.log(std)
def forward(self, state: Tensor, ref_action_chunk: Tensor) -> Tensor:
"""Return the mean action chunk.
Args:
state: RL state ``x = (z_rl, proprioception)``, shape ``(B, state_dim)``.
ref_action_chunk: Flattened VLA reference chunk, shape ``(B, C*d)``.
Returns:
Refined action chunk (mean), shape ``(B, C*d)``.
"""
x = torch.cat([state, ref_action_chunk], dim=-1)
return self.net(x)
def sample(self, state: Tensor, ref_action_chunk: Tensor) -> tuple[Tensor, Tensor]:
"""Sample an action and return (action, log_prob)."""
mean = self.forward(state, ref_action_chunk)
std = math.exp(self.log_std)
noise = torch.randn_like(mean) * std
action = mean + noise
log_prob = -0.5 * (noise / std).pow(2).sum(dim=-1) - mean.shape[-1] * math.log(
std * math.sqrt(2 * math.pi)
)
return action, log_prob
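And the fixed-variance Gaussian head from Eq. 4 in action (dimensions illustrative: C=10 chunks of a 7-dim action, 14-dim proprioception):

import torch
from lerobot.policies.rlt.modeling_rlt import RLTActor

actor = RLTActor(state_dim=2048 + 14, action_chunk_dim=10 * 7, hidden_dims=[256, 256], std=0.1)
state = torch.randn(4, 2048 + 14)            # x = (z_rl, proprioception)
ref = torch.randn(4, 10 * 7)                 # flattened VLA reference chunk
action, log_prob = actor.sample(state, ref)  # action: (4, 70), log_prob: (4,)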
# ── Policy (inference bundle) ────────────────────────────────────────
class RLTPolicy(PreTrainedPolicy):
"""RLT policy — bundles the RL-token encoder and actor for inference.
The frozen VLA backbone is **not** part of this module; it is loaded
separately and its embeddings / reference actions are passed in via the
observation dict (populated by the actor process or a preprocessor).
During training, the :class:`RLTAlgorithm` holds the critic, target networks,
and optimizers. This class only contains what is needed for ``select_action``.
"""
name = "rlt"
config_class = RLTConfig
def __init__(self, config: RLTConfig, dataset_stats=None):
super().__init__(config, dataset_stats)
action_dim = config.output_features["action"].shape[0]
action_chunk_dim = config.chunk_size * action_dim
prop_feature = config.input_features.get("observation.state", None)
proprioception_dim = prop_feature.shape[0] if prop_feature is not None else 0
state_dim = config.rl_token.rl_token_dim + proprioception_dim
# RL-token encoder (frozen after Stage 1)
self.rl_token_encoder = RLTokenEncoder(
input_dim=config.rl_token.input_dim,
rl_token_dim=config.rl_token.rl_token_dim,
num_layers=config.rl_token.num_encoder_layers,
num_heads=config.rl_token.num_heads,
ff_dim=config.rl_token.ff_dim,
dropout=config.rl_token.dropout,
)
# RL-token decoder (used only during Stage 1 training)
self.rl_token_decoder = RLTokenDecoder(
rl_token_dim=config.rl_token.rl_token_dim,
output_dim=config.rl_token.input_dim,
num_layers=config.rl_token.num_decoder_layers,
num_heads=config.rl_token.num_heads,
ff_dim=config.rl_token.ff_dim,
dropout=config.rl_token.dropout,
)
# Actor MLP
self.actor = RLTActor(
state_dim=state_dim,
action_chunk_dim=action_chunk_dim,
hidden_dims=config.actor.hidden_dims,
std=config.actor.std,
)
self._action_dim = action_dim
self._action_chunk_dim = action_chunk_dim
self._state_dim = state_dim
self._proprioception_dim = proprioception_dim
@torch.no_grad()
def select_action(self, batch: dict[str, Tensor]) -> Tensor:
"""Select a refined action chunk given an observation.
Expects the observation dict to contain:
- ``"observation.vla_embeddings"``: VLA internal token embeddings ``(M, D)``
- ``"observation.reference_action"``: VLA reference chunk ``(C*d,)``
- ``"observation.state"`` (optional): proprioceptive state ``(P,)``
Returns:
Action chunk tensor of shape ``(C*d,)``.
"""
self.eval()
vla_emb = batch["observation.vla_embeddings"]
if vla_emb.dim() == 2:
vla_emb = vla_emb.unsqueeze(0)
z_rl = self.rl_token_encoder(vla_emb) # (1, D_rl)
parts = [z_rl]
if "observation.state" in batch and self._proprioception_dim > 0:
prop = batch["observation.state"]
if prop.dim() == 1:
prop = prop.unsqueeze(0)
parts.append(prop)
state = torch.cat(parts, dim=-1)
ref = batch["observation.reference_action"]
if ref.dim() == 1:
ref = ref.unsqueeze(0)
action = self.actor(state, ref)
return action.squeeze(0)
def reset(self):
pass
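For reference, a hypothetical inference call matching the documented keys (assumes `policy` is an instantiated RLTPolicy with a 70-dim action chunk and 14-dim proprioception):

import torch

batch = {
    "observation.vla_embeddings": torch.randn(16, 2048),  # (M, D) tokens from the frozen VLA
    "observation.reference_action": torch.randn(70),      # flattened reference chunk (C*d,)
    "observation.state": torch.randn(14),                 # optional proprioceptive state (P,)
}
action_chunk = policy.select_action(batch)                # refined chunk, shape (C*d,)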
+24 -374
@@ -15,16 +15,11 @@
# See the License for the specific language governing permissions and
# limitations under the License.
import math
from collections.abc import Callable
from dataclasses import asdict
from typing import Literal
import einops
import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F # noqa: N812
from torch import Tensor
from torch.distributions import MultivariateNormal, TanhTransform, Transform, TransformedDistribution
@@ -52,20 +47,13 @@ class SACPolicy(
# Determine action dimension and initialize all components
continuous_action_dim = config.output_features[ACTION].shape[0]
self._init_encoders()
self._init_critics(continuous_action_dim)
self.encoder = SACObservationEncoder(config)
self._init_actor(continuous_action_dim)
self._init_temperature()
self._init_discrete_critic()
def get_optim_params(self) -> dict:
optim_params = {
"actor": [
p
for n, p in self.actor.named_parameters()
if not n.startswith("encoder") or not self.shared_encoder
],
"critic": self.critic_ensemble.parameters(),
"temperature": self.log_alpha,
"actor": [self.actor.parameters()],
}
if self.config.num_discrete_actions is not None:
optim_params["discrete_critic"] = self.discrete_critic.parameters()
@@ -83,10 +71,9 @@ class SACPolicy(
@torch.no_grad()
def select_action(self, batch: dict[str, Tensor]) -> Tensor:
"""Select action for inference/evaluation"""
observations_features = None
if self.shared_encoder and self.actor.encoder.has_images:
observations_features = self.actor.encoder.get_cached_image_features(batch)
if self.encoder.has_images:
observations_features = self.encoder.get_cached_image_features(batch)
actions, _, _ = self.actor(batch, observations_features)
@@ -97,372 +84,35 @@ class SACPolicy(
return actions
def critic_forward(
self,
observations: dict[str, Tensor],
actions: Tensor,
use_target: bool = False,
observation_features: Tensor | None = None,
) -> Tensor:
"""Forward pass through a critic network ensemble
Args:
observations: Dictionary of observations
actions: Action tensor
use_target: If True, use target critics, otherwise use ensemble critics
Returns:
Tensor of Q-values from all critics
"""
critics = self.critic_target if use_target else self.critic_ensemble
q_values = critics(observations, actions, observation_features)
return q_values
def discrete_critic_forward(
self, observations, use_target=False, observation_features=None
) -> torch.Tensor:
"""Forward pass through a discrete critic network
Args:
observations: Dictionary of observations
use_target: If True, use target critics, otherwise use ensemble critics
observation_features: Optional pre-computed observation features to avoid recomputing encoder output
Returns:
Tensor of Q-values from the discrete critic network
"""
discrete_critic = self.discrete_critic_target if use_target else self.discrete_critic
q_values = discrete_critic(observations, observation_features)
return q_values
def forward(
self,
batch: dict[str, Tensor | dict[str, Tensor]],
model: Literal["actor", "critic", "temperature", "discrete_critic"] = "critic",
) -> dict[str, Tensor]:
"""Compute the loss for the given model
"""Actor forward pass."""
observations = batch.get("state", batch)
observation_features = batch.get("observation_feature") if isinstance(batch, dict) else None
actions, log_probs, means = self.actor(observations, observation_features)
return {"action": actions, "log_prob": log_probs, "action_mean": means}
Args:
batch: Dictionary containing:
- action: Action tensor
- reward: Reward tensor
- state: Observations tensor dict
- next_state: Next observations tensor dict
- done: Done mask tensor
- observation_feature: Optional pre-computed observation features
- next_observation_feature: Optional pre-computed next observation features
model: Which model to compute the loss for ("actor", "critic", "discrete_critic", or "temperature")
Returns:
The computed loss tensor
"""
# Extract common components from batch
actions: Tensor = batch[ACTION]
observations: dict[str, Tensor] = batch["state"]
observation_features: Tensor = batch.get("observation_feature")
if model == "critic":
# Extract critic-specific components
rewards: Tensor = batch["reward"]
next_observations: dict[str, Tensor] = batch["next_state"]
done: Tensor = batch["done"]
next_observation_features: Tensor = batch.get("next_observation_feature")
loss_critic = self.compute_loss_critic(
observations=observations,
actions=actions,
rewards=rewards,
next_observations=next_observations,
done=done,
observation_features=observation_features,
next_observation_features=next_observation_features,
)
return {"loss_critic": loss_critic}
if model == "discrete_critic" and self.config.num_discrete_actions is not None:
# Extract critic-specific components
rewards: Tensor = batch["reward"]
next_observations: dict[str, Tensor] = batch["next_state"]
done: Tensor = batch["done"]
next_observation_features: Tensor = batch.get("next_observation_feature")
complementary_info = batch.get("complementary_info")
loss_discrete_critic = self.compute_loss_discrete_critic(
observations=observations,
actions=actions,
rewards=rewards,
next_observations=next_observations,
done=done,
observation_features=observation_features,
next_observation_features=next_observation_features,
complementary_info=complementary_info,
)
return {"loss_discrete_critic": loss_discrete_critic}
if model == "actor":
return {
"loss_actor": self.compute_loss_actor(
observations=observations,
observation_features=observation_features,
)
}
if model == "temperature":
return {
"loss_temperature": self.compute_loss_temperature(
observations=observations,
observation_features=observation_features,
)
}
raise ValueError(f"Unknown model type: {model}")
def update_target_networks(self):
"""Update target networks with exponential moving average"""
for target_param, param in zip(
self.critic_target.parameters(),
self.critic_ensemble.parameters(),
strict=True,
):
target_param.data.copy_(
param.data * self.config.critic_target_update_weight
+ target_param.data * (1.0 - self.config.critic_target_update_weight)
)
if self.config.num_discrete_actions is not None:
for target_param, param in zip(
self.discrete_critic_target.parameters(),
self.discrete_critic.parameters(),
strict=True,
):
target_param.data.copy_(
param.data * self.config.critic_target_update_weight
+ target_param.data * (1.0 - self.config.critic_target_update_weight)
)
@property
def temperature(self) -> float:
"""Return the current temperature value, always in sync with log_alpha."""
return self.log_alpha.exp().item()
def compute_loss_critic(
self,
observations,
actions,
rewards,
next_observations,
done,
observation_features: Tensor | None = None,
next_observation_features: Tensor | None = None,
) -> Tensor:
with torch.no_grad():
next_action_preds, next_log_probs, _ = self.actor(next_observations, next_observation_features)
# 2- compute q targets
q_targets = self.critic_forward(
observations=next_observations,
actions=next_action_preds,
use_target=True,
observation_features=next_observation_features,
)
# Subsample critics to prevent overfitting when using a high UTD (update-to-data) ratio
# TODO: Get indices before forward pass to avoid unnecessary computation
if self.config.num_subsample_critics is not None:
indices = torch.randperm(self.config.num_critics)
indices = indices[: self.config.num_subsample_critics]
q_targets = q_targets[indices]
# critics subsample size
min_q, _ = q_targets.min(dim=0) # Get values from min operation
if self.config.use_backup_entropy:
min_q = min_q - (self.temperature * next_log_probs)
td_target = rewards + (1 - done) * self.config.discount * min_q
# 3- compute predicted qs
if self.config.num_discrete_actions is not None:
# NOTE: We only want to keep the continuous action part
# In the buffer we have the full action space (continuous + discrete)
# We need to split them before concatenating them in the critic forward
actions: Tensor = actions[:, :DISCRETE_DIMENSION_INDEX]
q_preds = self.critic_forward(
observations=observations,
actions=actions,
use_target=False,
observation_features=observation_features,
)
# 4- Calculate loss
# Compute state-action value loss (TD loss) for all of the Q functions in the ensemble.
td_target_duplicate = einops.repeat(td_target, "b -> e b", e=q_preds.shape[0])
# Compute the mean loss over the batch for each critic, then sum across critics for the final loss
critics_loss = (
F.mse_loss(
input=q_preds,
target=td_target_duplicate,
reduction="none",
).mean(dim=1)
).sum()
return critics_loss
def compute_loss_discrete_critic(
self,
observations,
actions,
rewards,
next_observations,
done,
observation_features=None,
next_observation_features=None,
complementary_info=None,
):
# NOTE: We only want to keep the discrete action part
# In the buffer we have the full action space (continuous + discrete)
# We need to split them before concatenating them in the critic forward
actions_discrete: Tensor = actions[:, DISCRETE_DIMENSION_INDEX:].clone()
actions_discrete = torch.round(actions_discrete)
actions_discrete = actions_discrete.long()
discrete_penalties: Tensor | None = None
if complementary_info is not None:
discrete_penalties: Tensor | None = complementary_info.get("discrete_penalty")
with torch.no_grad():
# For DQN, select actions using online network, evaluate with target network
next_discrete_qs = self.discrete_critic_forward(
next_observations, use_target=False, observation_features=next_observation_features
)
best_next_discrete_action = torch.argmax(next_discrete_qs, dim=-1, keepdim=True)
# Get target Q-values from target network
target_next_discrete_qs = self.discrete_critic_forward(
observations=next_observations,
use_target=True,
observation_features=next_observation_features,
)
# Use gather to select Q-values for best actions
target_next_discrete_q = torch.gather(
target_next_discrete_qs, dim=1, index=best_next_discrete_action
).squeeze(-1)
# Compute target Q-value with Bellman equation
rewards_discrete = rewards
if discrete_penalties is not None:
rewards_discrete = rewards + discrete_penalties
target_discrete_q = rewards_discrete + (1 - done) * self.config.discount * target_next_discrete_q
# Get predicted Q-values for current observations
predicted_discrete_qs = self.discrete_critic_forward(
observations=observations, use_target=False, observation_features=observation_features
)
# Use gather to select Q-values for taken actions
predicted_discrete_q = torch.gather(predicted_discrete_qs, dim=1, index=actions_discrete).squeeze(-1)
# Compute MSE loss between predicted and target Q-values
discrete_critic_loss = F.mse_loss(input=predicted_discrete_q, target=target_discrete_q)
return discrete_critic_loss
def compute_loss_temperature(self, observations, observation_features: Tensor | None = None) -> Tensor:
"""Compute the temperature loss"""
# calculate temperature loss
with torch.no_grad():
_, log_probs, _ = self.actor(observations, observation_features)
temperature_loss = (-self.log_alpha.exp() * (log_probs + self.target_entropy)).mean()
return temperature_loss
def compute_loss_actor(
self,
observations,
observation_features: Tensor | None = None,
) -> Tensor:
actions_pi, log_probs, _ = self.actor(observations, observation_features)
q_preds = self.critic_forward(
observations=observations,
actions=actions_pi,
use_target=False,
observation_features=observation_features,
)
min_q_preds = q_preds.min(dim=0)[0]
actor_loss = ((self.temperature * log_probs) - min_q_preds).mean()
return actor_loss
def _init_encoders(self):
"""Initialize shared or separate encoders for actor and critic."""
self.shared_encoder = self.config.shared_encoder
self.encoder_critic = SACObservationEncoder(self.config)
self.encoder_actor = (
self.encoder_critic if self.shared_encoder else SACObservationEncoder(self.config)
)
def _init_critics(self, continuous_action_dim):
"""Build critic ensemble, targets, and optional discrete critic."""
heads = [
CriticHead(
input_dim=self.encoder_critic.output_dim + continuous_action_dim,
**asdict(self.config.critic_network_kwargs),
)
for _ in range(self.config.num_critics)
]
self.critic_ensemble = CriticEnsemble(encoder=self.encoder_critic, ensemble=heads)
target_heads = [
CriticHead(
input_dim=self.encoder_critic.output_dim + continuous_action_dim,
**asdict(self.config.critic_network_kwargs),
)
for _ in range(self.config.num_critics)
]
self.critic_target = CriticEnsemble(encoder=self.encoder_critic, ensemble=target_heads)
self.critic_target.load_state_dict(self.critic_ensemble.state_dict())
if self.config.use_torch_compile:
self.critic_ensemble = torch.compile(self.critic_ensemble)
self.critic_target = torch.compile(self.critic_target)
if self.config.num_discrete_actions is not None:
self._init_discrete_critics()
def _init_discrete_critics(self):
"""Build discrete discrete critic ensemble and target networks."""
self.discrete_critic = DiscreteCritic(
encoder=self.encoder_critic,
input_dim=self.encoder_critic.output_dim,
output_dim=self.config.num_discrete_actions,
**asdict(self.config.discrete_critic_network_kwargs),
)
self.discrete_critic_target = DiscreteCritic(
encoder=self.encoder_critic,
input_dim=self.encoder_critic.output_dim,
output_dim=self.config.num_discrete_actions,
**asdict(self.config.discrete_critic_network_kwargs),
)
# TODO: (maractingi, azouitine) Compile the discrete critic
self.discrete_critic_target.load_state_dict(self.discrete_critic.state_dict())
def _init_actor(self, continuous_action_dim):
"""Initialize policy actor network and default target entropy."""
# NOTE: The actor selects only the continuous action part
def _init_actor(self, continuous_action_dim: int) -> None:
self.actor = Policy(
encoder=self.encoder_actor,
network=MLP(input_dim=self.encoder_actor.output_dim, **asdict(self.config.actor_network_kwargs)),
encoder=self.encoder,
network=MLP(input_dim=self.encoder.output_dim, **asdict(self.config.actor_network_kwargs)),
action_dim=continuous_action_dim,
encoder_is_shared=self.shared_encoder,
encoder_is_shared=False,
**asdict(self.config.policy_kwargs),
)
self.target_entropy = self.config.target_entropy
if self.target_entropy is None:
dim = continuous_action_dim + (1 if self.config.num_discrete_actions is not None else 0)
self.target_entropy = -np.prod(dim) / 2
def _init_temperature(self) -> None:
"""Set up temperature parameter (log_alpha)."""
temp_init = self.config.temperature_init
self.log_alpha = nn.Parameter(torch.tensor([math.log(temp_init)]))
def _init_discrete_critic(self) -> None:
if self.config.num_discrete_actions is None:
self.discrete_critic = None
return
self.discrete_critic = DiscreteCritic(
encoder=self.encoder,
input_dim=self.encoder.output_dim,
output_dim=self.config.num_discrete_actions,
**asdict(self.config.discrete_critic_network_kwargs),
)
class SACObservationEncoder(nn.Module):
@@ -131,6 +131,15 @@ class _NormalizationMixin:
if self.dtype is None:
self.dtype = torch.float32
self._tensor_stats = to_tensor(self.stats, device=self.device, dtype=self.dtype)
self._reshape_visual_stats()
def _reshape_visual_stats(self) -> None:
"""Reshape visual stats from ``[C]`` to ``[C, 1, 1]`` for image broadcasting."""
for key, feature in self.features.items():
if feature.type == FeatureType.VISUAL and key in self._tensor_stats:
for stat_name, stat_tensor in self._tensor_stats[key].items():
if isinstance(stat_tensor, Tensor) and stat_tensor.ndim == 1:
self._tensor_stats[key][stat_name] = stat_tensor.reshape(-1, 1, 1)
def to(
self, device: torch.device | str | None = None, dtype: torch.dtype | None = None
@@ -149,6 +158,7 @@ class _NormalizationMixin:
if dtype is not None:
self.dtype = dtype
self._tensor_stats = to_tensor(self.stats, device=self.device, dtype=self.dtype)
self._reshape_visual_stats()
return self
def state_dict(self) -> dict[str, Tensor]:
@@ -198,6 +208,7 @@ class _NormalizationMixin:
# Don't load from state_dict, keep the explicitly provided stats
# But ensure _tensor_stats is properly initialized
self._tensor_stats = to_tensor(self.stats, device=self.device, dtype=self.dtype) # type: ignore[assignment]
self._reshape_visual_stats()
return
# Normal behavior: load stats from state_dict
@@ -208,6 +219,7 @@ class _NormalizationMixin:
self._tensor_stats.setdefault(key, {})[stat_name] = tensor.to(
dtype=torch.float32, device=self.device
)
self._reshape_visual_stats()
# Reconstruct the original stats dict from tensor stats for compatibility with to() method
# and other functions that rely on self.stats
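The reshape exists so per-channel image stats broadcast over (B, C, H, W) tensors; a minimal illustration:

import torch

mean = torch.tensor([0.485, 0.456, 0.406])      # visual stat stored as shape (3,)
images = torch.rand(8, 3, 224, 224)             # (B, C, H, W)
normalized = images - mean.reshape(-1, 1, 1)    # (3, 1, 1) broadcasts over H and W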
+13
@@ -0,0 +1,13 @@
# Copyright 2026 The HuggingFace Inc. team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
+9 -19
@@ -61,7 +61,7 @@ from lerobot.cameras import opencv # noqa: F401
from lerobot.configs import parser
from lerobot.configs.train import TrainRLServerPipelineConfig
from lerobot.policies.factory import make_policy
from lerobot.policies.sac.modeling_sac import SACPolicy
from lerobot.policies.pretrained import PreTrainedPolicy
from lerobot.processor import TransitionKey
from lerobot.rl.process import ProcessSignalHandler
from lerobot.rl.queue import get_last_item_from_queue
@@ -248,16 +248,16 @@ def act_with_policy(
logging.info("make_policy")
### Instantiate the policy in both the actor and learner processes
### To avoid sending a policy object over the connection, we create a policy instance
### on both sides; the learner sends updated parameters every n steps to refresh the actor's copy
policy: SACPolicy = make_policy(
policy = make_policy(
cfg=cfg.policy,
env_cfg=cfg.env,
)
policy = policy.eval()
assert isinstance(policy, nn.Module)
# TODO: Re-enable processor pipeline once refactoring is validated against main
# preprocessor, postprocessor = None, None
obs, info = online_env.reset()
env_processor.reset()
action_processor.reset()
@@ -288,7 +288,6 @@ def act_with_policy(
# Time policy inference and check if it meets FPS requirement
with policy_timer:
# Extract observation from transition for policy
action = policy.select_action(batch=observation)
policy_fps = policy_timer.fps_last
@@ -649,12 +648,12 @@ def interactions_stream(
# Policy functions
def update_policy_parameters(policy: SACPolicy, parameters_queue: Queue, device):
def update_policy_parameters(policy: PreTrainedPolicy, parameters_queue: Queue, device):
"""Load the latest policy weights from the learner."""
bytes_state_dict = get_last_item_from_queue(parameters_queue, block=False)
if bytes_state_dict is not None:
logging.info("[ACTOR] Load new parameters from Learner.")
state_dicts = bytes_to_state_dict(bytes_state_dict)
# TODO: check encoder parameter synchronization possible issues:
# 1. When shared_encoder=True, we're loading stale encoder params from actor's state_dict
# instead of the updated encoder params from critic (which is optimized separately)
@@ -664,18 +663,9 @@ def update_policy_parameters(policy: SACPolicy, parameters_queue: Queue, device)
# - Send critic's encoder state when shared_encoder=True
# - Skip encoder params entirely when freeze_vision_encoder=True
# - Ensure discrete_critic gets correct encoder state (currently uses encoder_critic)
# Load actor state dict
actor_state_dict = move_state_dict_to_device(state_dicts["policy"], device=device)
policy.actor.load_state_dict(actor_state_dict)
# Load discrete critic if present
if hasattr(policy, "discrete_critic") and "discrete_critic" in state_dicts:
discrete_critic_state_dict = move_state_dict_to_device(
state_dicts["discrete_critic"], device=device
)
policy.discrete_critic.load_state_dict(discrete_critic_state_dict)
logging.info("[ACTOR] Loaded discrete critic parameters from Learner.")
state_dicts = move_state_dict_to_device(state_dicts, device=device)
policy.load_state_dict(state_dicts)
# Utilities functions
+70
@@ -0,0 +1,70 @@
# Copyright 2026 The HuggingFace Inc. team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from __future__ import annotations
import torch
from lerobot.rl.algorithms.base import (
RLAlgorithm,
RLAlgorithmConfig,
TrainingStats,
)
from lerobot.rl.algorithms.rlt import RLTAlgorithm, RLTAlgorithmConfig
from lerobot.rl.algorithms.sac import SACAlgorithm, SACAlgorithmConfig
def make_algorithm(
policy: torch.nn.Module,
policy_cfg,
*,
algorithm_name: str,
) -> RLAlgorithm:
"""Construct an :class:`RLAlgorithm` from a policy and its config.
Algorithm selection is explicit via ``algorithm_name`` (from
``cfg.algorithm``).
This is fully registry-driven — adding a new algorithm only requires
registering an ``RLAlgorithmConfig`` subclass; no changes here.
The returned algorithm has **no optimizers** yet. On the learner side,
call ``algorithm.make_optimizers()`` afterwards to create them. On the
actor side (inference-only), leave them empty.
Args:
policy: Instantiated policy (e.g. ``SACPolicy``).
policy_cfg: The policy's ``PreTrainedConfig`` with the hyper-parameters
expected by the algorithm config's ``from_policy_config`` class-method.
algorithm_name: Algorithm registry key to instantiate.
"""
known = RLAlgorithmConfig.get_known_choices()
if algorithm_name not in known:
raise ValueError(f"No RLAlgorithmConfig registered for '{algorithm_name}'. Known: {list(known)}")
config_cls = RLAlgorithmConfig.get_choice_class(algorithm_name)
algo_config = config_cls.from_policy_config(policy_cfg)
return algo_config.build_algorithm(policy)
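A sketch of the intended call sites under this contract (`policy` and `cfg` assumed in scope):

algorithm = make_algorithm(policy, cfg.policy, algorithm_name=cfg.algorithm)
algorithm.make_optimizers()  # learner only; actors leave optimizers empty and use load_weights() instead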
__all__ = [
"RLAlgorithm",
"RLAlgorithmConfig",
"TrainingStats",
"SACAlgorithm",
"SACAlgorithmConfig",
"RLTAlgorithm",
"RLTAlgorithmConfig",
"make_algorithm",
]
+183
@@ -0,0 +1,183 @@
# Copyright 2026 The HuggingFace Inc. team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Base classes for RL algorithms.
Defines the abstract interface that every algorithm must implement, a registry
for algorithm configs, and a dataclass for training statistics.
"""
from __future__ import annotations
import abc
from collections.abc import Iterator
from dataclasses import dataclass, field
from typing import TYPE_CHECKING, Any
import draccus
import torch
from torch import Tensor
from torch.optim import Optimizer
if TYPE_CHECKING:
from lerobot.rl.data_sources.data_mixer import DataMixer
BatchType = dict[str, Any]
@dataclass
class TrainingStats:
"""Returned by ``algorithm.update()`` for logging and checkpointing."""
# Generic containers for all algorithms
losses: dict[str, float] = field(default_factory=dict)
grad_norms: dict[str, float] = field(default_factory=dict)
extra: dict[str, float] = field(default_factory=dict)
def to_log_dict(self) -> dict[str, float]:
"""Flatten all stats into a single dict for logging."""
d: dict[str, float] = {}
for name, val in self.losses.items():
d[name] = val
for name, val in self.grad_norms.items():
d[f"{name}_grad_norm"] = val
for name, val in self.extra.items():
d[name] = val
return d
@dataclass
class RLAlgorithmConfig(draccus.ChoiceRegistry):
"""Registry for algorithm configs."""
def build_algorithm(self, policy: torch.nn.Module) -> RLAlgorithm:
"""Construct the :class:`RLAlgorithm` for this config.
Must be overridden by every registered config subclass.
"""
raise NotImplementedError(f"{type(self).__name__} must implement build_algorithm()")
@classmethod
def from_policy_config(cls, policy_cfg: Any) -> RLAlgorithmConfig:
"""Build an algorithm config from a policy config.
Must be overridden by every registered config subclass.
"""
raise NotImplementedError(f"{cls.__name__} must implement from_policy_config()")
class RLAlgorithm(abc.ABC):
"""Base for all RL algorithms."""
@abc.abstractmethod
def update(self, batch_iterator: Iterator[BatchType]) -> TrainingStats:
"""One complete training step.
The algorithm calls ``next(batch_iterator)`` as many times as it
needs (e.g. ``utd_ratio`` times for SAC) to obtain fresh batches.
The iterator is owned by the trainer; the algorithm just consumes
from it.
"""
...
def supports_offline_phase(self) -> bool:
"""Whether this algorithm has an offline pretraining phase.
Algorithms like RLT (RL-token training) or ConRFT (Cal-QL pretraining)
return ``True`` here. The learner checks this before the main online
loop and routes to :meth:`offline_update` accordingly.
"""
return False
def offline_update(self, batch_iterator: Iterator[BatchType]) -> TrainingStats:
"""One offline training step (called before any online collection).
Only called when :meth:`supports_offline_phase` returns ``True``.
Uses the same iterator protocol as :meth:`update`.
"""
raise NotImplementedError(
f"{type(self).__name__} does not implement offline_update(). "
"Either override this method or return False from supports_offline_phase()."
)
def transition_to_online(self) -> None: # noqa: B027
"""Called once when switching from offline to online phase.
Use this to freeze modules trained offline, rebuild optimizers for the
online phase, reset step counters, etc.
Default is a no-op; subclasses override when they have an offline phase.
"""
def configure_data_iterator(
self,
data_mixer: DataMixer,
batch_size: int,
*,
async_prefetch: bool = True,
queue_size: int = 2,
) -> Iterator[BatchType]:
"""Create the data iterator this algorithm needs.
The default implementation uses the standard ``data_mixer.get_iterator()``.
Algorithms that need specialised sampling should override this method.
"""
return data_mixer.get_iterator(
batch_size=batch_size,
async_prefetch=async_prefetch,
queue_size=queue_size,
)
def make_optimizers(self) -> dict[str, Optimizer]:
"""Create, store, and return the optimizers needed for training.
Called on the **learner** side after construction. Subclasses must
override this with algorithm-specific optimizer setup.
"""
return {}
def get_optimizers(self) -> dict[str, Optimizer]:
"""Return optimizers for checkpointing / external scheduling."""
return {}
@property
def optimization_step(self) -> int:
"""Current learner optimization step.
Part of the stable contract for checkpoint/resume. Algorithms can
either use this default storage or override for custom behavior.
"""
return getattr(self, "_optimization_step", 0)
@optimization_step.setter
def optimization_step(self, value: int) -> None:
self._optimization_step = int(value)
def get_weights(self) -> dict[str, Any]:
"""Policy state-dict to push to actors."""
return {}
@abc.abstractmethod
def load_weights(self, weights: dict[str, Any], device: str | torch.device = "cpu") -> None:
"""Load policy state-dict received from the learner (inverse of ``get_weights``)."""
@torch.no_grad()
def get_observation_features(
self, observations: Tensor, next_observations: Tensor
) -> tuple[Tensor | None, Tensor | None]:
"""Pre-compute observation features (e.g. frozen encoder cache).
Returns ``(None, None)`` when caching is not applicable.
"""
return None, None
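To make the iterator contract concrete, a hypothetical minimal subclass (registry decorator and config omitted; assumes the policy's forward returns a dict with a "loss" entry):

import torch
from lerobot.rl.algorithms.base import RLAlgorithm, TrainingStats

class ToyAlgorithm(RLAlgorithm):
    def __init__(self, policy, lr=3e-4):
        self.policy = policy
        self.optimizer = torch.optim.Adam(policy.parameters(), lr=lr)

    def update(self, batch_iterator):
        batch = next(batch_iterator)  # pull exactly as many batches as needed
        loss = self.policy.forward(batch)["loss"]
        self.optimizer.zero_grad()
        loss.backward()
        self.optimizer.step()
        self.optimization_step += 1
        return TrainingStats(losses={"loss": loss.item()})

    def get_weights(self):
        return {k: v.cpu() for k, v in self.policy.state_dict().items()}

    def load_weights(self, weights, device="cpu"):
        self.policy.load_state_dict(weights)
        self.policy.to(device)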
+18
@@ -0,0 +1,18 @@
# Copyright 2026 The HuggingFace Inc. team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from lerobot.rl.algorithms.rlt.configuration_rlt import RLTAlgorithmConfig
from lerobot.rl.algorithms.rlt.rlt_algorithm import RLTAlgorithm
__all__ = ["RLTAlgorithm", "RLTAlgorithmConfig"]
@@ -0,0 +1,83 @@
# Copyright 2026 The HuggingFace Inc. team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""RLT algorithm configuration."""
from __future__ import annotations
from dataclasses import dataclass
from typing import TYPE_CHECKING
import torch
from lerobot.rl.algorithms.base import RLAlgorithmConfig
if TYPE_CHECKING:
from lerobot.rl.algorithms.rlt.rlt_algorithm import RLTAlgorithm
@RLAlgorithmConfig.register_subclass("rlt")
@dataclass
class RLTAlgorithmConfig(RLAlgorithmConfig):
"""RLT-specific hyper-parameters that control the update loop."""
# ── Action chunks ──
chunk_size: int = 10
chunk_stride: int = 2
# ── Update cadence ──
utd_ratio: int = 5
policy_update_freq: int = 2
clip_grad_norm: float = 10.0
# ── Learning rates ──
actor_lr: float = 3e-4
critic_lr: float = 3e-4
rl_token_lr: float = 1e-4
# ── TD learning ──
discount: float = 0.99
tau: float = 0.005
num_critics: int = 2
# ── Policy constraint (paper Eq. 5) ──
bc_reg_coeff: float = 0.1
ref_dropout: float = 0.5
# ── Offline RL-token training ──
vla_finetune_weight: float = 0.0
@classmethod
def from_policy_config(cls, policy_cfg) -> RLTAlgorithmConfig:
"""Build from an existing ``RLTConfig`` (cfg.policy)."""
return cls(
chunk_size=policy_cfg.chunk_size,
chunk_stride=policy_cfg.chunk_stride,
utd_ratio=policy_cfg.utd_ratio,
policy_update_freq=policy_cfg.policy_update_freq,
clip_grad_norm=policy_cfg.clip_grad_norm,
actor_lr=policy_cfg.actor_lr,
critic_lr=policy_cfg.critic_lr,
rl_token_lr=policy_cfg.rl_token_lr,
discount=policy_cfg.discount,
tau=policy_cfg.tau,
num_critics=policy_cfg.num_critics,
bc_reg_coeff=policy_cfg.bc_reg_coeff,
ref_dropout=policy_cfg.ref_dropout,
vla_finetune_weight=policy_cfg.vla_finetune_weight,
)
def build_algorithm(self, policy: torch.nn.Module) -> RLTAlgorithm:
from lerobot.rl.algorithms.rlt.rlt_algorithm import RLTAlgorithm
return RLTAlgorithm(policy=policy, config=self)
@@ -0,0 +1,319 @@
# Copyright 2026 The HuggingFace Inc. team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""RLT (RL Token) algorithm.
Implements the two-stage training from "RL Token: Bootstrapping Online RL
with Vision-Language-Action Models" (Xu et al., Physical Intelligence, 2026).
Stage 1 (offline): Train RL-token encoder/decoder via reconstruction loss.
Stage 2 (online): Train actor-critic with chunked TD, BC regularization,
reference-action pass-through, and reference-action dropout.
"""
from __future__ import annotations
import copy
from collections.abc import Iterator
from typing import Any
import torch
import torch.nn as nn
import torch.nn.functional as F # noqa: N812
from torch import Tensor
from torch.optim import Optimizer
from lerobot.policies.rlt.modeling_rlt import MLP, RLTPolicy
from lerobot.policies.utils import get_device_from_parameters
from lerobot.rl.algorithms.base import (
BatchType,
RLAlgorithm,
TrainingStats,
)
from lerobot.rl.algorithms.rlt.configuration_rlt import RLTAlgorithmConfig
from lerobot.utils.constants import ACTION
class RLTCritic(nn.Module):
"""Q-function over (state, action_chunk) pairs.
Paper Eq. 3: Q_psi(x, a_{1:C})
Training-only component — lives on the algorithm side, not in the policy.
"""
def __init__(self, state_dim: int, action_chunk_dim: int, hidden_dims: list[int]):
super().__init__()
self.net = MLP(state_dim + action_chunk_dim, hidden_dims, output_dim=1)
def forward(self, state: Tensor, action_chunk: Tensor) -> Tensor:
x = torch.cat([state, action_chunk], dim=-1)
return self.net(x)
class RLTAlgorithm(RLAlgorithm):
"""RL Token: lightweight actor-critic on frozen VLA features.
Owns the ``RLTPolicy`` (RL-token encoder/decoder + actor), a critic
ensemble, and target networks. All VLA-specific logic (embedding
extraction, reference actions) lives in ``_prepare_forward_batch``.
"""
def __init__(self, policy: RLTPolicy, config: RLTAlgorithmConfig):
self.policy = policy
self.config = config
self.optimizers: dict[str, Optimizer] = {}
self._optimization_step: int = 0
self._device = get_device_from_parameters(self.policy)
self._is_online = False
self._init_critics()
self._move_to_device()
# ── Initialization ───────────────────────────────────────────────
def _init_critics(self) -> None:
state_dim = self.policy._state_dim
action_chunk_dim = self.policy._action_chunk_dim
hidden_dims = self.policy.config.critic.hidden_dims
self.critics = torch.nn.ModuleList(
[RLTCritic(state_dim, action_chunk_dim, hidden_dims) for _ in range(self.config.num_critics)]
)
self.critic_targets = torch.nn.ModuleList([copy.deepcopy(c) for c in self.critics])
for ct in self.critic_targets:
ct.requires_grad_(False)
def _move_to_device(self) -> None:
self.critics.to(self._device)
self.critic_targets.to(self._device)
# ── Offline phase (Stage 1): RL-token training ───────────────────
def supports_offline_phase(self) -> bool:
return True
def offline_update(self, batch_iterator: Iterator[BatchType]) -> TrainingStats:
"""Train RL-token encoder/decoder on demonstration data.
Paper Eq. 2: L_ro = E[ sum_i || h(d([z_rl, z_bar_{1:i-1}]))_i - z_bar_i ||^2 ]
"""
batch = next(batch_iterator)
vla_embeddings = batch["state"]["observation.vla_embeddings"].to(self._device)
z_vla = vla_embeddings.detach() # stop-gradient on VLA embeddings
z_rl = self.policy.rl_token_encoder(z_vla)
z_reconstructed = self.policy.rl_token_decoder(z_rl, z_vla)
loss_ro = F.mse_loss(z_reconstructed, z_vla)
self.optimizers["rl_token"].zero_grad()
loss_ro.backward()
torch.nn.utils.clip_grad_norm_(
list(self.policy.rl_token_encoder.parameters()) + list(self.policy.rl_token_decoder.parameters()),
max_norm=self.config.clip_grad_norm,
)
self.optimizers["rl_token"].step()
self._optimization_step += 1
return TrainingStats(losses={"loss_rl_token": loss_ro.item()})
def transition_to_online(self) -> None:
"""Freeze RL-token modules; rebuild optimizers for actor-critic only."""
self.policy.rl_token_encoder.requires_grad_(False)
self.policy.rl_token_decoder.requires_grad_(False)
self._is_online = True
self.optimizers = {
"actor": torch.optim.Adam(self.policy.actor.parameters(), lr=self.config.actor_lr),
"critic": torch.optim.Adam(self.critics.parameters(), lr=self.config.critic_lr),
}
self._optimization_step = 0
# ── Online phase (Stage 2): Actor-Critic ─────────────────────────
def update(self, batch_iterator: Iterator[BatchType]) -> TrainingStats:
"""One full RLT update step with UTD critic warm-up.
Pulls ``utd_ratio`` batches. First ``utd_ratio - 1`` are critic-only;
the last batch also updates the actor (every ``policy_update_freq`` steps).
"""
for _ in range(self.config.utd_ratio - 1):
batch = next(batch_iterator)
fb = self._prepare_forward_batch(batch)
self._critic_step(fb)
self._update_target_networks()
batch = next(batch_iterator)
fb = self._prepare_forward_batch(batch)
critic_loss = self._critic_step(fb)
stats = TrainingStats(losses={"loss_critic": critic_loss})
if self._optimization_step % self.config.policy_update_freq == 0:
actor_loss, bc_loss, q_val = self._actor_step(fb)
stats.losses["loss_actor"] = actor_loss
stats.extra["bc_loss"] = bc_loss
stats.extra["q_value_mean"] = q_val
self._update_target_networks()
self._optimization_step += 1
return stats
def _prepare_forward_batch(self, batch: BatchType) -> dict[str, Any]:
"""Convert a replay batch into algorithm-ready tensors.
Extracts RL-token from VLA embeddings, builds RL state, reads
reference action from complementary_info.
"""
obs = batch["state"]
next_obs = batch["next_state"]
device = self._device
vla_emb = obs["observation.vla_embeddings"].to(device)
next_vla_emb = next_obs["observation.vla_embeddings"].to(device)
with torch.no_grad():
z_rl = self.policy.rl_token_encoder(vla_emb)
z_rl_next = self.policy.rl_token_encoder(next_vla_emb)
parts = [z_rl]
next_parts = [z_rl_next]
if "observation.state" in obs and self.policy._proprioception_dim > 0:
prop = obs["observation.state"].to(device)
next_prop = next_obs["observation.state"].to(device)
parts.append(prop)
next_parts.append(next_prop)
state = torch.cat(parts, dim=-1)
next_state = torch.cat(next_parts, dim=-1)
action = batch[ACTION].to(device)
reward = batch["reward"].to(device)
done = batch["done"].to(device)
ref_action = None
comp_info = batch.get("complementary_info")
if comp_info is not None and "reference_action" in comp_info:
ref_action = comp_info["reference_action"].to(device)
return {
"state": state,
"next_state": next_state,
"action": action,
"reward": reward,
"done": done,
"reference_action": ref_action,
}
def _critic_step(self, fb: dict[str, Any]) -> float:
"""Paper Eq. 3: chunked TD with clipped double-Q target."""
state = fb["state"]
next_state = fb["next_state"]
action = fb["action"]
reward = fb["reward"]
done = fb["done"]
with torch.no_grad():
ref = fb.get("reference_action")
if ref is None:
ref = torch.zeros_like(action)
next_action = self.policy.actor(next_state, ref)
target_qs = [ct(next_state, next_action) for ct in self.critic_targets]
min_target_q = torch.min(torch.cat(target_qs, dim=-1), dim=-1, keepdim=True).values
discount_chunk = self.config.discount**self.config.chunk_size
td_target = reward.unsqueeze(-1) + (1 - done.unsqueeze(-1)) * discount_chunk * min_target_q
q_preds = [c(state, action) for c in self.critics]
loss = sum(F.mse_loss(q, td_target) for q in q_preds)
self.optimizers["critic"].zero_grad()
loss.backward()
torch.nn.utils.clip_grad_norm_(self.critics.parameters(), max_norm=self.config.clip_grad_norm)
self.optimizers["critic"].step()
return loss.item()
def _actor_step(self, fb: dict[str, Any]) -> tuple[float, float, float]:
"""Paper Eq. 5: maximize Q while staying near VLA reference.
L_pi(theta) = E[ -Q(x, a) + beta * ||a - a_tilde||^2 ]
Reference-action dropout is applied to the actor's ``ref`` input.
"""
state = fb["state"]
ref = fb.get("reference_action")
if ref is None:
ref = torch.zeros(state.shape[0], self.policy._action_chunk_dim, device=self._device)
# Reference-action dropout (paper Section IV-B)
mask = (torch.rand(ref.shape[0], 1, device=self._device) > self.config.ref_dropout).float()
ref_input = ref * mask
action = self.policy.actor(state, ref_input)
q_value = self.critics[0](state, action)
bc_loss = F.mse_loss(action, ref)
loss = -q_value.mean() + self.config.bc_reg_coeff * bc_loss
self.optimizers["actor"].zero_grad()
loss.backward()
torch.nn.utils.clip_grad_norm_(self.policy.actor.parameters(), max_norm=self.config.clip_grad_norm)
self.optimizers["actor"].step()
return loss.item(), bc_loss.item(), q_value.mean().item()
def _update_target_networks(self) -> None:
tau = self.config.tau
for critic, target in zip(self.critics, self.critic_targets, strict=True):
for p, tp in zip(critic.parameters(), target.parameters(), strict=True):
tp.data.copy_(tau * p.data + (1 - tau) * tp.data)
# ── Optimizer management ─────────────────────────────────────────
def make_optimizers(self) -> dict[str, Optimizer]:
"""Create optimizers. Initially for RL-token (Stage 1)."""
self.optimizers = {
"rl_token": torch.optim.Adam(
list(self.policy.rl_token_encoder.parameters())
+ list(self.policy.rl_token_decoder.parameters()),
lr=self.config.rl_token_lr,
),
"actor": torch.optim.Adam(self.policy.actor.parameters(), lr=self.config.actor_lr),
"critic": torch.optim.Adam(self.critics.parameters(), lr=self.config.critic_lr),
}
return self.optimizers
def get_optimizers(self) -> dict[str, Optimizer]:
return self.optimizers
# ── Weight sync ──────────────────────────────────────────────────
def get_weights(self) -> dict[str, Any]:
"""Push actor + RL-token encoder to actors (small footprint)."""
weights = {
"actor": self.policy.actor.state_dict(),
"rl_token_encoder": self.policy.rl_token_encoder.state_dict(),
}
return {k: {kk: vv.cpu() for kk, vv in v.items()} for k, v in weights.items()}
def load_weights(self, weights: dict[str, Any], device: str | torch.device = "cpu") -> None:
if "actor" in weights:
self.policy.actor.load_state_dict({k: v.to(device) for k, v in weights["actor"].items()})
if "rl_token_encoder" in weights:
self.policy.rl_token_encoder.load_state_dict(
{k: v.to(device) for k, v in weights["rl_token_encoder"].items()}
)
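A minimal driver for the two-stage lifecycle above may help tie the methods together. This is a sketch only: the enclosing class name (written here as RLTAlgorithm) is not visible in this hunk, and policy, config, offline_steps, online_steps, and the two batch iterators are assumptions.
# Hypothetical driver; RLTAlgorithm, policy, config, and iterators are assumptions.
algo = RLTAlgorithm(policy=policy, config=config)
algo.make_optimizers()  # Stage 1 optimizers, including "rl_token"

if algo.supports_offline_phase():  # True for RLT
    for _ in range(offline_steps):  # Stage 1: RL-token reconstruction (Eq. 2)
        stats = algo.offline_update(demo_iterator)
    algo.transition_to_online()  # freeze RL-token, rebuild actor/critic optimizers

for _ in range(online_steps):  # Stage 2: chunked actor-critic (Eqs. 3 and 5)
    stats = algo.update(replay_iterator)  # pulls utd_ratio batches internally
Note the chunked backup in _critic_step: the TD target uses discount ** chunk_size, so with discount = 0.99 and a 10-step action chunk the effective factor is 0.99 ** 10 ≈ 0.904 per chunk rather than 0.99 per environment step.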
+18
@@ -0,0 +1,18 @@
# Copyright 2026 The HuggingFace Inc. team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from lerobot.rl.algorithms.sac.configuration_sac import SACAlgorithmConfig
from lerobot.rl.algorithms.sac.sac_algorithm import SACAlgorithm
__all__ = ["SACAlgorithm", "SACAlgorithmConfig"]
@@ -0,0 +1,81 @@
# Copyright 2026 The HuggingFace Inc. team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""SAC algorithm configuration."""
from __future__ import annotations
from dataclasses import dataclass, field
from typing import TYPE_CHECKING
import torch
from lerobot.policies.sac.configuration_sac import CriticNetworkConfig
from lerobot.rl.algorithms.base import RLAlgorithmConfig
if TYPE_CHECKING:
from lerobot.rl.algorithms.sac.sac_algorithm import SACAlgorithm
@RLAlgorithmConfig.register_subclass("sac")
@dataclass
class SACAlgorithmConfig(RLAlgorithmConfig):
"""SAC-specific hyper-parameters that control the update loop."""
utd_ratio: int = 1
policy_update_freq: int = 1
clip_grad_norm: float = 40.0
actor_lr: float = 3e-4
critic_lr: float = 3e-4
temperature_lr: float = 3e-4
discount: float = 0.99
temperature_init: float = 1.0
target_entropy: float | None = None
use_backup_entropy: bool = True
critic_target_update_weight: float = 0.005
num_critics: int = 2
num_subsample_critics: int | None = None
num_discrete_actions: int | None = None
shared_encoder: bool = True
critic_network_kwargs: CriticNetworkConfig = field(default_factory=CriticNetworkConfig)
discrete_critic_network_kwargs: CriticNetworkConfig = field(default_factory=CriticNetworkConfig)
use_torch_compile: bool = True
@classmethod
def from_policy_config(cls, policy_cfg) -> SACAlgorithmConfig:
"""Build from an existing ``SACConfig`` (cfg.policy) for backwards compat."""
return cls(
utd_ratio=policy_cfg.utd_ratio,
policy_update_freq=policy_cfg.policy_update_freq,
clip_grad_norm=policy_cfg.grad_clip_norm,
actor_lr=policy_cfg.actor_lr,
critic_lr=policy_cfg.critic_lr,
temperature_lr=policy_cfg.temperature_lr,
discount=policy_cfg.discount,
temperature_init=policy_cfg.temperature_init,
target_entropy=policy_cfg.target_entropy,
use_backup_entropy=policy_cfg.use_backup_entropy,
critic_target_update_weight=policy_cfg.critic_target_update_weight,
num_critics=policy_cfg.num_critics,
num_subsample_critics=policy_cfg.num_subsample_critics,
num_discrete_actions=policy_cfg.num_discrete_actions,
shared_encoder=policy_cfg.shared_encoder,
critic_network_kwargs=policy_cfg.critic_network_kwargs,
discrete_critic_network_kwargs=policy_cfg.discrete_critic_network_kwargs,
use_torch_compile=policy_cfg.use_torch_compile,
)
def build_algorithm(self, policy: torch.nn.Module) -> SACAlgorithm:
from lerobot.rl.algorithms.sac.sac_algorithm import SACAlgorithm
return SACAlgorithm(policy=policy, config=self)
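As a usage sketch (assuming an already-constructed SACPolicy named policy), this config bridges the legacy policy hyper-parameters to the new algorithm object:
# Sketch: derive the algorithm config from the policy config, then build.
algo_config = SACAlgorithmConfig.from_policy_config(policy.config)
algorithm = algo_config.build_algorithm(policy)  # returns a SACAlgorithm
optimizers = algorithm.make_optimizers()
Because of the register_subclass("sac") decorator, the same config can also be selected by name through the draccus choice registry instead of being built in code.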
@@ -0,0 +1,409 @@
# Copyright 2026 The HuggingFace Inc. team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""SAC (Soft Actor-Critic) algorithm.
This module encapsulates all SAC-specific training logic (critic, actor,
temperature, and discrete-critic updates) behind the ``RLAlgorithm`` interface.
"""
from __future__ import annotations
import math
from collections.abc import Iterator
from dataclasses import asdict
from typing import Any
import einops
import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F # noqa: N812
from torch import Tensor
from torch.optim import Optimizer
from lerobot.policies.sac.modeling_sac import (
DISCRETE_DIMENSION_INDEX,
CriticEnsemble,
CriticHead,
DiscreteCritic,
SACObservationEncoder,
SACPolicy,
)
from lerobot.policies.utils import get_device_from_parameters
from lerobot.rl.algorithms.base import (
BatchType,
RLAlgorithm,
TrainingStats,
)
from lerobot.rl.algorithms.sac.configuration_sac import SACAlgorithmConfig
from lerobot.utils.constants import ACTION
from lerobot.utils.transition import move_state_dict_to_device
class SACAlgorithm(RLAlgorithm):
"""Soft Actor-Critic with optional discrete-critic head.
Owns the ``SACPolicy`` and its optimizers. All loss methods call
``self.policy(batch_dict)`` rather than reaching into ``self.policy.actor``
directly, so any policy that returns ``{"action", "log_prob"}`` from its
``forward()`` is compatible.
"""
def __init__(
self,
policy: SACPolicy,
config: SACAlgorithmConfig,
):
self.policy = policy
self.config = config
self.optimizers: dict[str, Optimizer] = {}
self._optimization_step: int = 0
self._device = get_device_from_parameters(self.policy)
self._init_critic_encoder()
self._init_critics()
self._init_temperature()
self._move_to_device()
def _init_critic_encoder(self) -> None:
"""Build or share the encoder used by critics."""
if self.config.shared_encoder:
self.critic_encoder = self.policy.encoder
self.policy.actor.encoder_is_shared = True
else:
self.critic_encoder = SACObservationEncoder(self.policy.config)
def _init_critics(self) -> None:
"""Build critic ensemble, targets, and optional discrete critic."""
action_dim = self.policy.config.output_features[ACTION].shape[0]
input_dim = self.critic_encoder.output_dim + action_dim
heads = [
CriticHead(input_dim=input_dim, **asdict(self.config.critic_network_kwargs))
for _ in range(self.config.num_critics)
]
self.critic_ensemble = CriticEnsemble(encoder=self.critic_encoder, ensemble=heads)
target_heads = [
CriticHead(input_dim=input_dim, **asdict(self.config.critic_network_kwargs))
for _ in range(self.config.num_critics)
]
self.critic_target = CriticEnsemble(encoder=self.critic_encoder, ensemble=target_heads)
self.critic_target.load_state_dict(self.critic_ensemble.state_dict())
if self.config.use_torch_compile:
self.critic_ensemble = torch.compile(self.critic_ensemble)
self.critic_target = torch.compile(self.critic_target)
if self.config.num_discrete_actions is not None:
self._init_discrete_critic_target()
def _init_discrete_critic_target(self) -> None:
"""Build only the target discrete critic."""
input_dim = self.critic_encoder.output_dim
self.discrete_critic_target = DiscreteCritic(
encoder=self.critic_encoder,
input_dim=input_dim,
output_dim=self.config.num_discrete_actions,
**asdict(self.config.discrete_critic_network_kwargs),
)
# TODO: (kmeftah) Compile the discrete critic
self.discrete_critic_target.load_state_dict(self.policy.discrete_critic.state_dict())
def _init_temperature(self) -> None:
"""Set up temperature parameter (log_alpha) and default target entropy."""
temp_init = self.config.temperature_init
self.log_alpha = nn.Parameter(torch.tensor([math.log(temp_init)]))
action_dim = self.policy.config.output_features[ACTION].shape[0]
self.target_entropy = self.config.target_entropy
if self.target_entropy is None:
dim = action_dim + (1 if self.config.num_discrete_actions is not None else 0)
self.target_entropy = -np.prod(dim) / 2
def _move_to_device(self) -> None:
"""Move algorithm-owned modules to the policy device."""
self.critic_ensemble.to(self._device)
self.critic_target.to(self._device)
self.log_alpha = nn.Parameter(self.log_alpha.data.to(self._device))
if hasattr(self, "discrete_critic_target"):
self.discrete_critic_target.to(self._device)
@property
def temperature(self) -> float:
return self.log_alpha.exp().item()
def update(self, batch_iterator: Iterator[BatchType]) -> TrainingStats:
"""Run one full SAC update with UTD critic warm-up.
Pulls ``utd_ratio`` batches from ``batch_iterator``. The first
``utd_ratio - 1`` batches are used for critic-only warm-up steps;
the last batch drives the full update (critic + actor + temperature).
"""
for _ in range(self.config.utd_ratio - 1):
batch = next(batch_iterator)
forward_batch = self._prepare_forward_batch(batch)
loss_critic = self._compute_loss_critic(forward_batch)
self.optimizers["critic"].zero_grad()
loss_critic.backward()
torch.nn.utils.clip_grad_norm_(
self.critic_ensemble.parameters(),
max_norm=self.config.clip_grad_norm,
)
self.optimizers["critic"].step()
if self.config.num_discrete_actions is not None:
loss_discrete = self._compute_loss_discrete_critic(forward_batch)
self.optimizers["discrete_critic"].zero_grad()
loss_discrete.backward()
torch.nn.utils.clip_grad_norm_(
self.policy.discrete_critic.parameters(),
max_norm=self.config.clip_grad_norm,
)
self.optimizers["discrete_critic"].step()
self._update_target_networks()
batch = next(batch_iterator)
forward_batch = self._prepare_forward_batch(batch)
loss_critic = self._compute_loss_critic(forward_batch)
self.optimizers["critic"].zero_grad()
loss_critic.backward()
critic_grad_norm = torch.nn.utils.clip_grad_norm_(
self.critic_ensemble.parameters(),
max_norm=self.config.clip_grad_norm,
).item()
self.optimizers["critic"].step()
critic_loss_val = loss_critic.item()
stats = TrainingStats(
losses={"loss_critic": critic_loss_val},
grad_norms={"critic": critic_grad_norm},
)
if self.config.num_discrete_actions is not None:
loss_discrete = self._compute_loss_discrete_critic(forward_batch)
self.optimizers["discrete_critic"].zero_grad()
loss_discrete.backward()
dc_grad = torch.nn.utils.clip_grad_norm_(
self.policy.discrete_critic.parameters(),
max_norm=self.config.clip_grad_norm,
).item()
self.optimizers["discrete_critic"].step()
stats.losses["loss_discrete_critic"] = loss_discrete.item()
stats.grad_norms["discrete_critic"] = dc_grad
if self._optimization_step % self.config.policy_update_freq == 0:
for _ in range(self.config.policy_update_freq):
actor_loss = self._compute_loss_actor(forward_batch)
self.optimizers["actor"].zero_grad()
actor_loss.backward()
actor_grad = torch.nn.utils.clip_grad_norm_(
self.policy.actor.parameters(),
max_norm=self.config.clip_grad_norm,
).item()
self.optimizers["actor"].step()
temp_loss = self._compute_loss_temperature(forward_batch)
self.optimizers["temperature"].zero_grad()
temp_loss.backward()
temp_grad = torch.nn.utils.clip_grad_norm_(
[self.log_alpha],
max_norm=self.config.clip_grad_norm,
).item()
self.optimizers["temperature"].step()
stats.losses["loss_actor"] = actor_loss.item()
stats.losses["loss_temperature"] = temp_loss.item()
stats.grad_norms["actor"] = actor_grad
stats.grad_norms["temperature"] = temp_grad
stats.extra["temperature"] = self.temperature
self._update_target_networks()
self._optimization_step += 1
return stats
def _compute_loss_critic(self, batch: dict[str, Any]) -> Tensor:
observations = batch["state"]
actions = batch[ACTION]
rewards = batch["reward"]
next_observations = batch["next_state"]
done = batch["done"]
obs_features = batch.get("observation_feature")
next_obs_features = batch.get("next_observation_feature")
with torch.no_grad():
next_output = self.policy({"state": next_observations, "observation_feature": next_obs_features})
next_actions = next_output["action"]
next_log_probs = next_output["log_prob"]
q_targets = self.critic_target(next_observations, next_actions, next_obs_features)
if self.config.num_subsample_critics is not None:
indices = torch.randperm(self.config.num_critics)
indices = indices[: self.config.num_subsample_critics]
q_targets = q_targets[indices]
min_q, _ = q_targets.min(dim=0)
if self.config.use_backup_entropy:
min_q = min_q - (self.temperature * next_log_probs)
td_target = rewards + (1 - done) * self.config.discount * min_q
if self.config.num_discrete_actions is not None:
actions = actions[:, :DISCRETE_DIMENSION_INDEX]
q_preds = self.critic_ensemble(observations, actions, obs_features)
td_target_dup = einops.repeat(td_target, "b -> e b", e=q_preds.shape[0])
critics_loss = (F.mse_loss(input=q_preds, target=td_target_dup, reduction="none").mean(dim=1)).sum()
return critics_loss
def _compute_loss_discrete_critic(self, batch: dict[str, Any]) -> Tensor:
observations = batch["state"]
actions = batch[ACTION]
rewards = batch["reward"]
next_observations = batch["next_state"]
done = batch["done"]
obs_features = batch.get("observation_feature")
next_obs_features = batch.get("next_observation_feature")
complementary_info = batch.get("complementary_info")
actions_discrete: Tensor = actions[:, DISCRETE_DIMENSION_INDEX:].clone()
actions_discrete = torch.round(actions_discrete).long()
discrete_penalties: Tensor | None = None
if complementary_info is not None:
discrete_penalties = complementary_info.get("discrete_penalty")
with torch.no_grad():
next_discrete_qs = self.policy.discrete_critic(next_observations, next_obs_features)
best_next_action = torch.argmax(next_discrete_qs, dim=-1, keepdim=True)
target_next_qs = self.discrete_critic_target(next_observations, next_obs_features)
target_next_q = torch.gather(target_next_qs, dim=1, index=best_next_action).squeeze(-1)
rewards_disc = rewards
if discrete_penalties is not None:
rewards_disc = rewards + discrete_penalties
target_q = rewards_disc + (1 - done) * self.config.discount * target_next_q
predicted_qs = self.policy.discrete_critic(observations, obs_features)
predicted_q = torch.gather(predicted_qs, dim=1, index=actions_discrete).squeeze(-1)
return F.mse_loss(input=predicted_q, target=target_q)
def _compute_loss_actor(self, batch: dict[str, Any]) -> Tensor:
observations = batch["state"]
obs_features = batch.get("observation_feature")
output = self.policy({"state": observations, "observation_feature": obs_features})
actions_pi = output["action"]
log_probs = output["log_prob"]
q_preds = self.critic_ensemble(observations, actions_pi, obs_features)
min_q = q_preds.min(dim=0)[0]
return ((self.temperature * log_probs) - min_q).mean()
def _compute_loss_temperature(self, batch: dict[str, Any]) -> Tensor:
observations = batch["state"]
obs_features = batch.get("observation_feature")
with torch.no_grad():
output = self.policy({"state": observations, "observation_feature": obs_features})
log_probs = output["log_prob"]
return (-self.log_alpha.exp() * (log_probs + self.target_entropy)).mean()
def _update_target_networks(self) -> None:
tau = self.config.critic_target_update_weight
for target_p, p in zip(
self.critic_target.parameters(), self.critic_ensemble.parameters(), strict=True
):
target_p.data.copy_(p.data * tau + target_p.data * (1.0 - tau))
if self.config.num_discrete_actions is not None:
for target_p, p in zip(
self.discrete_critic_target.parameters(),
self.policy.discrete_critic.parameters(),
strict=True,
):
target_p.data.copy_(p.data * tau + target_p.data * (1.0 - tau))
def _prepare_forward_batch(self, batch: BatchType) -> dict[str, Any]:
"""Build the dict expected by loss computation from a sampled batch."""
observations = batch["state"]
next_observations = batch["next_state"]
observation_features, next_observation_features = self.get_observation_features(
observations, next_observations
)
forward_batch: dict[str, Any] = {
ACTION: batch[ACTION],
"reward": batch["reward"],
"state": observations,
"next_state": next_observations,
"done": batch["done"],
"observation_feature": observation_features,
"next_observation_feature": next_observation_features,
}
if "complementary_info" in batch:
forward_batch["complementary_info"] = batch["complementary_info"]
return forward_batch
def make_optimizers(self) -> dict[str, Optimizer]:
"""Create Adam optimizers for the SAC components and store them."""
actor_params = [
p
for n, p in self.policy.actor.named_parameters()
if not self.config.shared_encoder or not n.startswith("encoder")
]
self.optimizers = {
"actor": torch.optim.Adam(actor_params, lr=self.config.actor_lr),
"critic": torch.optim.Adam(self.critic_ensemble.parameters(), lr=self.config.critic_lr),
"temperature": torch.optim.Adam([self.log_alpha], lr=self.config.temperature_lr),
}
if self.config.num_discrete_actions is not None:
self.optimizers["discrete_critic"] = torch.optim.Adam(
self.policy.discrete_critic.parameters(), lr=self.config.critic_lr
)
return self.optimizers
def get_optimizers(self) -> dict[str, Optimizer]:
return self.optimizers
def get_weights(self) -> dict[str, Any]:
"""Policy state-dict to push to actors (includes actor + discrete critic)."""
return move_state_dict_to_device(self.policy.state_dict(), device="cpu")
def load_weights(self, weights: dict[str, Any], device: str | torch.device = "cpu") -> None:
"""Load policy state-dict received from the learner."""
state = move_state_dict_to_device(weights, device=device)
self.policy.load_state_dict(state)
@torch.no_grad()
def get_observation_features(
self, observations: Tensor, next_observations: Tensor
) -> tuple[Tensor | None, Tensor | None]:
if not self.config.shared_encoder:
return None, None
if self.policy.config.vision_encoder_name is None or not self.policy.config.freeze_vision_encoder:
return None, None
if not self.policy.encoder.has_images:
return None, None
observation_features = self.policy.encoder.get_cached_image_features(observations)
next_observation_features = self.policy.encoder.get_cached_image_features(next_observations)
return observation_features, next_observation_features
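End to end, one learner step through this class looks roughly like the sketch below. replay_buffer is an assumption here; its get_iterator signature mirrors the one used by the learner elsewhere in this PR.
# Sketch: a single SAC update driven by a replay-buffer iterator.
algo_config = SACAlgorithmConfig.from_policy_config(policy.config)
algorithm = SACAlgorithm(policy=policy, config=algo_config)
algorithm.make_optimizers()

iterator = replay_buffer.get_iterator(batch_size=256, async_prefetch=True, queue_size=2)
stats = algorithm.update(iterator)  # consumes utd_ratio batches: warm-ups + full update
print(stats.losses["loss_critic"], stats.grad_norms["critic"])
Each update also applies the Polyak step in _update_target_networks: with the default critic_target_update_weight of 0.005, every target parameter moves 0.5% of the way toward its online counterpart per call.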
+17
@@ -0,0 +1,17 @@
# Copyright 2026 The HuggingFace Inc. team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from lerobot.rl.data_sources.data_mixer import BatchType, DataMixer, OnlineOfflineMixer
__all__ = ["BatchType", "DataMixer", "OnlineOfflineMixer"]
+94
@@ -0,0 +1,94 @@
# Copyright 2026 The HuggingFace Inc. team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from __future__ import annotations
import abc
from typing import Any
from lerobot.rl.buffer import ReplayBuffer, concatenate_batch_transitions
BatchType = dict[str, Any]
class DataMixer(abc.ABC):
"""Abstract interface for all data mixing strategies.
Subclasses must implement ``sample(batch_size)`` and may override
``get_iterator`` for specialised iteration.
"""
@abc.abstractmethod
def sample(self, batch_size: int) -> BatchType:
"""Draw one batch of ``batch_size`` transitions."""
...
def get_iterator(
self,
batch_size: int,
async_prefetch: bool = True,
queue_size: int = 2,
):
"""Infinite iterator that yields batches.
The default implementation repeatedly calls ``self.sample()``.
Subclasses with underlying buffer iterators (async prefetch)
should override this for better throughput.
"""
while True:
yield self.sample(batch_size)
class OnlineOfflineMixer(DataMixer):
"""Mixes transitions from an online and an optional offline replay buffer.
When both buffers are present, each batch is constructed by sampling
``max(1, floor(batch_size * online_ratio))`` transitions from the online
buffer and the remainder from the offline buffer, then concatenating.
When no offline buffer is given, every batch is drawn from the online buffer.
"""
def __init__(
self,
online_buffer: ReplayBuffer,
offline_buffer: ReplayBuffer | None = None,
online_ratio: float = 1.0,
):
if not 0.0 <= online_ratio <= 1.0:
raise ValueError(f"online_ratio must be in [0, 1], got {online_ratio}")
self.online_buffer = online_buffer
self.offline_buffer = offline_buffer
self.online_ratio = online_ratio
def sample(self, batch_size: int) -> BatchType:
if self.offline_buffer is None:
return self.online_buffer.sample(batch_size)
n_online = max(1, int(batch_size * self.online_ratio))
n_offline = batch_size - n_online
online_batch = self.online_buffer.sample(n_online)
offline_batch = self.offline_buffer.sample(n_offline)
return concatenate_batch_transitions(online_batch, offline_batch)
def get_iterator(
self,
batch_size: int,
async_prefetch: bool = True,
queue_size: int = 2,
):
"""Yield batches from online/offline mixed sampling."""
while True:
yield self.sample(batch_size)
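A quick worked example of the split arithmetic in sample: with batch_size=256 and online_ratio=0.5, n_online = max(1, int(256 * 0.5)) = 128 and n_offline = 128. The buffer objects below are assumptions.
# Sketch: 50/50 online/offline mixing.
mixer = OnlineOfflineMixer(
    online_buffer=replay_buffer,
    offline_buffer=offline_replay_buffer,
    online_ratio=0.5,
)
batch = mixer.sample(256)  # 128 online + 128 offline transitions, concatenated

for batch in mixer.get_iterator(batch_size=256):
    break  # infinite stream; consume as many batches as needed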
+91 -283
@@ -65,9 +65,11 @@ from lerobot.configs.train import TrainRLServerPipelineConfig
from lerobot.datasets.factory import make_dataset
from lerobot.datasets.lerobot_dataset import LeRobotDataset
from lerobot.policies.factory import make_policy
from lerobot.policies.sac.modeling_sac import SACPolicy
from lerobot.rl.buffer import ReplayBuffer, concatenate_batch_transitions
from lerobot.rl.algorithms import make_algorithm
from lerobot.rl.buffer import ReplayBuffer
from lerobot.rl.data_sources import OnlineOfflineMixer
from lerobot.rl.process import ProcessSignalHandler
from lerobot.rl.trainer import RLTrainer
from lerobot.rl.wandb_utils import WandBLogger
from lerobot.robots import so_follower # noqa: F401
from lerobot.teleoperators import gamepad, so_leader # noqa: F401
@@ -93,7 +95,7 @@ from lerobot.utils.train_utils import (
save_checkpoint,
update_last_checkpoint,
)
from lerobot.utils.transition import move_state_dict_to_device, move_transition_to_device
from lerobot.utils.transition import move_transition_to_device
from lerobot.utils.utils import (
format_big_number,
get_safe_torch_device,
@@ -264,8 +266,8 @@ def add_actor_information_and_train(
- Transfers transitions from the actor to the replay buffer.
- Logs received interaction messages.
- Ensures training begins only when the replay buffer has a sufficient number of transitions.
- Samples batches from the replay buffer and performs multiple critic updates.
- Periodically updates the actor, critic, and temperature optimizers.
- Delegates training updates to an ``RLAlgorithm`` (currently ``SACAlgorithm``).
- Periodically pushes updated weights to actors.
- Logs training statistics, including loss values and optimization frequency.
NOTE: This function doesn't have a single responsibility; it should be split into multiple functions.
@@ -284,17 +286,15 @@ def add_actor_information_and_train(
# of 7%
device = get_safe_torch_device(try_device=cfg.policy.device, log=True)
storage_device = get_safe_torch_device(try_device=cfg.policy.storage_device)
clip_grad_norm_value = cfg.policy.grad_clip_norm
online_step_before_learning = cfg.policy.online_step_before_learning
utd_ratio = cfg.policy.utd_ratio
fps = cfg.env.fps
log_freq = cfg.log_freq
save_freq = cfg.save_freq
policy_update_freq = cfg.policy.policy_update_freq
policy_parameters_push_frequency = cfg.policy.actor_learner_config.policy_parameters_push_frequency
saving_checkpoint = cfg.save_checkpoint
online_steps = cfg.policy.online_steps
async_prefetch = cfg.policy.async_prefetch
async_prefetch = cfg.async_prefetch
queue_size = cfg.queue_size
# Initialize logging for multiprocessing
if not use_threads(cfg):
@@ -306,7 +306,7 @@ def add_actor_information_and_train(
logging.info("Initializing policy")
policy: SACPolicy = make_policy(
policy = make_policy(
cfg=cfg.policy,
env_cfg=cfg.env,
)
@@ -315,19 +315,24 @@ def add_actor_information_and_train(
policy.train()
push_actor_policy_to_queue(parameters_queue=parameters_queue, policy=policy)
algorithm = make_algorithm(
policy=policy,
policy_cfg=cfg.policy,
algorithm_name=cfg.algorithm,
)
# TODO: Re-enable processor pipeline once refactoring is validated against main
preprocessor, postprocessor = None, None
# Push initial policy weights to actors (same path as periodic push)
state_bytes = state_to_bytes(algorithm.get_weights())
parameters_queue.put(state_bytes)
last_time_policy_pushed = time.time()
optimizers, lr_scheduler = make_optimizers_and_scheduler(cfg=cfg, policy=policy)
# If we are resuming, we need to load the training state
resume_optimization_step, resume_interaction_step = load_training_state(cfg=cfg, optimizers=optimizers)
log_training_info(cfg=cfg, policy=policy)
replay_buffer = initialize_replay_buffer(cfg, device, storage_device)
batch_size = cfg.batch_size
total_batch_size = cfg.batch_size
offline_replay_buffer = None
if cfg.dataset is not None:
@@ -336,20 +341,70 @@ def add_actor_information_and_train(
device=device,
storage_device=storage_device,
)
batch_size: int = batch_size // 2 # We will sample from both replay buffers
# DataMixer: online-only or online/offline 50-50 mix
data_mixer = OnlineOfflineMixer(
online_buffer=replay_buffer,
offline_buffer=offline_replay_buffer,
online_ratio=cfg.online_ratio,
)
# RLTrainer owns the iterator, preprocessor, and creates optimizers.
trainer = RLTrainer(
algorithm=algorithm,
data_mixer=data_mixer,
batch_size=total_batch_size,
preprocessor=preprocessor,
action_dim=cfg.policy.output_features["action"].shape[0],
async_prefetch=async_prefetch,
queue_size=queue_size,
)
# If we are resuming, we need to load the training state
optimizers = algorithm.get_optimizers()
resume_optimization_step, resume_interaction_step = load_training_state(cfg=cfg, optimizers=optimizers)
logging.info("Starting learner thread")
interaction_message = None
optimization_step = resume_optimization_step if resume_optimization_step is not None else 0
algorithm.optimization_step = optimization_step
interaction_step_shift = resume_interaction_step if resume_interaction_step is not None else 0
dataset_repo_id = None
if cfg.dataset is not None:
dataset_repo_id = cfg.dataset.repo_id
# Initialize iterators
online_iterator = None
offline_iterator = None
# ── Offline phase (e.g. RLT RL-token training, ConRFT Cal-QL pretraining) ──
offline_steps = getattr(cfg.policy, "offline_steps", 0)
if algorithm.supports_offline_phase() and offline_steps > 0 and offline_replay_buffer is not None:
logging.info(f"[LEARNER] Starting offline phase ({offline_steps} steps)")
offline_mixer = OnlineOfflineMixer(
online_buffer=offline_replay_buffer,
offline_buffer=None,
online_ratio=1.0,
)
offline_iterator = algorithm.configure_data_iterator(
data_mixer=offline_mixer,
batch_size=total_batch_size,
async_prefetch=async_prefetch,
queue_size=queue_size,
)
for step in range(offline_steps):
if shutdown_event is not None and shutdown_event.is_set():
logging.info("[LEARNER] Shutdown during offline phase. Exiting...")
return
stats = algorithm.offline_update(offline_iterator)
if step % log_freq == 0:
logging.info(f"[LEARNER] Offline step {step}/{offline_steps}: {stats.to_log_dict()}")
if wandb_logger:
log_dict = stats.to_log_dict()
log_dict["offline_step"] = step
wandb_logger.log_dict(d=log_dict, mode="train", custom_step_key="offline_step")
algorithm.transition_to_online()
optimizers = algorithm.get_optimizers()
logging.info("[LEARNER] Offline phase complete, transitioned to online")
# NOTE: THIS IS THE MAIN LOOP OF THE LEARNER
while True:
@@ -380,180 +435,22 @@ def add_actor_information_and_train(
if len(replay_buffer) < online_step_before_learning:
continue
if online_iterator is None:
online_iterator = replay_buffer.get_iterator(
batch_size=batch_size, async_prefetch=async_prefetch, queue_size=2
)
if offline_replay_buffer is not None and offline_iterator is None:
offline_iterator = offline_replay_buffer.get_iterator(
batch_size=batch_size, async_prefetch=async_prefetch, queue_size=2
)
time_for_one_optimization_step = time.time()
for _ in range(utd_ratio - 1):
# Sample from the iterators
batch = next(online_iterator)
if dataset_repo_id is not None:
batch_offline = next(offline_iterator)
batch = concatenate_batch_transitions(
left_batch_transitions=batch, right_batch_transition=batch_offline
)
actions = batch[ACTION]
rewards = batch["reward"]
observations = batch["state"]
next_observations = batch["next_state"]
done = batch["done"]
check_nan_in_transition(observations=observations, actions=actions, next_state=next_observations)
observation_features, next_observation_features = get_observation_features(
policy=policy, observations=observations, next_observations=next_observations
)
# Create a batch dictionary with all required elements for the forward method
forward_batch = {
ACTION: actions,
"reward": rewards,
"state": observations,
"next_state": next_observations,
"done": done,
"observation_feature": observation_features,
"next_observation_feature": next_observation_features,
"complementary_info": batch["complementary_info"],
}
# Use the forward method for critic loss
critic_output = policy.forward(forward_batch, model="critic")
# Main critic optimization
loss_critic = critic_output["loss_critic"]
optimizers["critic"].zero_grad()
loss_critic.backward()
critic_grad_norm = torch.nn.utils.clip_grad_norm_(
parameters=policy.critic_ensemble.parameters(), max_norm=clip_grad_norm_value
)
optimizers["critic"].step()
# Discrete critic optimization (if available)
if policy.config.num_discrete_actions is not None:
discrete_critic_output = policy.forward(forward_batch, model="discrete_critic")
loss_discrete_critic = discrete_critic_output["loss_discrete_critic"]
optimizers["discrete_critic"].zero_grad()
loss_discrete_critic.backward()
discrete_critic_grad_norm = torch.nn.utils.clip_grad_norm_(
parameters=policy.discrete_critic.parameters(), max_norm=clip_grad_norm_value
)
optimizers["discrete_critic"].step()
# Update target networks (main and discrete)
policy.update_target_networks()
# Sample for the last update in the UTD ratio
batch = next(online_iterator)
if dataset_repo_id is not None:
batch_offline = next(offline_iterator)
batch = concatenate_batch_transitions(
left_batch_transitions=batch, right_batch_transition=batch_offline
)
actions = batch[ACTION]
rewards = batch["reward"]
observations = batch["state"]
next_observations = batch["next_state"]
done = batch["done"]
check_nan_in_transition(observations=observations, actions=actions, next_state=next_observations)
observation_features, next_observation_features = get_observation_features(
policy=policy, observations=observations, next_observations=next_observations
)
# Create a batch dictionary with all required elements for the forward method
forward_batch = {
ACTION: actions,
"reward": rewards,
"state": observations,
"next_state": next_observations,
"done": done,
"observation_feature": observation_features,
"next_observation_feature": next_observation_features,
}
critic_output = policy.forward(forward_batch, model="critic")
loss_critic = critic_output["loss_critic"]
optimizers["critic"].zero_grad()
loss_critic.backward()
critic_grad_norm = torch.nn.utils.clip_grad_norm_(
parameters=policy.critic_ensemble.parameters(), max_norm=clip_grad_norm_value
).item()
optimizers["critic"].step()
# Initialize training info dictionary
training_infos = {
"loss_critic": loss_critic.item(),
"critic_grad_norm": critic_grad_norm,
}
# Discrete critic optimization (if available)
if policy.config.num_discrete_actions is not None:
discrete_critic_output = policy.forward(forward_batch, model="discrete_critic")
loss_discrete_critic = discrete_critic_output["loss_discrete_critic"]
optimizers["discrete_critic"].zero_grad()
loss_discrete_critic.backward()
discrete_critic_grad_norm = torch.nn.utils.clip_grad_norm_(
parameters=policy.discrete_critic.parameters(), max_norm=clip_grad_norm_value
).item()
optimizers["discrete_critic"].step()
# Add discrete critic info to training info
training_infos["loss_discrete_critic"] = loss_discrete_critic.item()
training_infos["discrete_critic_grad_norm"] = discrete_critic_grad_norm
# Actor and temperature optimization (at specified frequency)
if optimization_step % policy_update_freq == 0:
for _ in range(policy_update_freq):
# Actor optimization
actor_output = policy.forward(forward_batch, model="actor")
loss_actor = actor_output["loss_actor"]
optimizers["actor"].zero_grad()
loss_actor.backward()
actor_grad_norm = torch.nn.utils.clip_grad_norm_(
parameters=policy.actor.parameters(), max_norm=clip_grad_norm_value
).item()
optimizers["actor"].step()
# Add actor info to training info
training_infos["loss_actor"] = loss_actor.item()
training_infos["actor_grad_norm"] = actor_grad_norm
# Temperature optimization
temperature_output = policy.forward(forward_batch, model="temperature")
loss_temperature = temperature_output["loss_temperature"]
optimizers["temperature"].zero_grad()
loss_temperature.backward()
temp_grad_norm = torch.nn.utils.clip_grad_norm_(
parameters=[policy.log_alpha], max_norm=clip_grad_norm_value
).item()
optimizers["temperature"].step()
# Add temperature info to training info
training_infos["loss_temperature"] = loss_temperature.item()
training_infos["temperature_grad_norm"] = temp_grad_norm
training_infos["temperature"] = policy.temperature
# One training step (trainer owns data_mixer iterator; algorithm owns UTD loop)
stats = trainer.training_step()
# Push policy to actors if needed
if time.time() - last_time_policy_pushed > policy_parameters_push_frequency:
push_actor_policy_to_queue(parameters_queue=parameters_queue, policy=policy)
state_dicts = algorithm.get_weights()
state_bytes = state_to_bytes(state_dicts)
parameters_queue.put(state_bytes)
last_time_policy_pushed = time.time()
# Update target networks (main and discrete)
policy.update_target_networks()
training_infos = stats.to_log_dict()
# Log training metrics at specified intervals
optimization_step = algorithm.optimization_step
if optimization_step % log_freq == 0:
training_infos["replay_buffer_size"] = len(replay_buffer)
if offline_replay_buffer is not None:
@@ -581,7 +478,6 @@ def add_actor_information_and_train(
custom_step_key="Optimization step",
)
optimization_step += 1
if optimization_step % log_freq == 0:
logging.info(f"[LEARNER] Number of optimization step: {optimization_step}")
@@ -598,6 +494,8 @@ def add_actor_information_and_train(
offline_replay_buffer=offline_replay_buffer,
dataset_repo_id=dataset_repo_id,
fps=fps,
preprocessor=preprocessor,
postprocessor=postprocessor,
)
@@ -682,6 +580,8 @@ def save_training_checkpoint(
offline_replay_buffer: ReplayBuffer | None = None,
dataset_repo_id: str | None = None,
fps: int = 30,
preprocessor=None,
postprocessor=None,
) -> None:
"""
Save training checkpoint and associated data.
@@ -705,6 +605,8 @@ def save_training_checkpoint(
offline_replay_buffer: Optional offline replay buffer to save
dataset_repo_id: Repository ID for dataset
fps: Frames per second for dataset
preprocessor: Optional preprocessor pipeline to save
postprocessor: Optional postprocessor pipeline to save
"""
logging.info(f"Checkpoint policy after step {optimization_step}")
_num_digits = max(6, len(str(online_steps)))
@@ -721,6 +623,8 @@ def save_training_checkpoint(
policy=policy,
optimizer=optimizers,
scheduler=None,
preprocessor=preprocessor,
postprocessor=postprocessor,
)
# Save interaction step manually
@@ -758,58 +662,6 @@ def save_training_checkpoint(
logging.info("Resume training")
def make_optimizers_and_scheduler(cfg: TrainRLServerPipelineConfig, policy: nn.Module):
"""
Creates and returns optimizers for the actor, critic, and temperature components of a reinforcement learning policy.
This function sets up Adam optimizers for:
- The **actor network**, ensuring that only relevant parameters are optimized.
- The **critic ensemble**, which evaluates the value function.
- The **temperature parameter**, which controls the entropy in soft actor-critic (SAC)-like methods.
It also initializes a learning rate scheduler, though it is currently set to `None`.
NOTE:
- If the encoder is shared, its parameters are excluded from the actor's optimization process.
- The policy's log temperature (`log_alpha`) is wrapped in a list to ensure proper optimization as a standalone tensor.
Args:
cfg: Configuration object containing hyperparameters.
policy (nn.Module): The policy model containing the actor, critic, and temperature components.
Returns:
Tuple[Dict[str, torch.optim.Optimizer], Optional[torch.optim.lr_scheduler._LRScheduler]]:
A tuple containing:
- `optimizers`: A dictionary mapping component names ("actor", "critic", "temperature") to their respective Adam optimizers.
- `lr_scheduler`: Currently set to `None` but can be extended to support learning rate scheduling.
"""
optimizer_actor = torch.optim.Adam(
params=[
p
for n, p in policy.actor.named_parameters()
if not policy.config.shared_encoder or not n.startswith("encoder")
],
lr=cfg.policy.actor_lr,
)
optimizer_critic = torch.optim.Adam(params=policy.critic_ensemble.parameters(), lr=cfg.policy.critic_lr)
if cfg.policy.num_discrete_actions is not None:
optimizer_discrete_critic = torch.optim.Adam(
params=policy.discrete_critic.parameters(), lr=cfg.policy.critic_lr
)
optimizer_temperature = torch.optim.Adam(params=[policy.log_alpha], lr=cfg.policy.critic_lr)
lr_scheduler = None
optimizers = {
"actor": optimizer_actor,
"critic": optimizer_critic,
"temperature": optimizer_temperature,
}
if cfg.policy.num_discrete_actions is not None:
optimizers["discrete_critic"] = optimizer_discrete_critic
return optimizers, lr_scheduler
# Training setup functions
@@ -1014,33 +866,6 @@ def initialize_offline_replay_buffer(
# Utilities/Helpers functions
def get_observation_features(
policy: SACPolicy, observations: torch.Tensor, next_observations: torch.Tensor
) -> tuple[torch.Tensor | None, torch.Tensor | None]:
"""
Get observation features from the policy encoder. It acts as a cache for the
observation features: when the encoder is frozen, the features do not change,
so we can save compute by caching them.
Args:
policy: The policy model
observations: The current observations
next_observations: The next observations
Returns:
tuple: observation_features, next_observation_features
"""
if policy.config.vision_encoder_name is None or not policy.config.freeze_vision_encoder:
return None, None
with torch.no_grad():
observation_features = policy.actor.encoder.get_cached_image_features(observations)
next_observation_features = policy.actor.encoder.get_cached_image_features(next_observations)
return observation_features, next_observation_features
def use_threads(cfg: TrainRLServerPipelineConfig) -> bool:
return cfg.policy.concurrency.learner == "threads"
@@ -1091,23 +916,6 @@ def check_nan_in_transition(
return nan_detected
def push_actor_policy_to_queue(parameters_queue: Queue, policy: nn.Module):
logging.debug("[LEARNER] Pushing actor policy to the queue")
# Create a dictionary to hold all the state dicts
state_dicts = {"policy": move_state_dict_to_device(policy.actor.state_dict(), device="cpu")}
# Add discrete critic if it exists
if hasattr(policy, "discrete_critic") and policy.discrete_critic is not None:
state_dicts["discrete_critic"] = move_state_dict_to_device(
policy.discrete_critic.state_dict(), device="cpu"
)
logging.debug("[LEARNER] Including discrete critic in state dict push")
state_bytes = state_to_bytes(state_dicts)
parameters_queue.put(state_bytes)
def process_interaction_message(
message, interaction_step_shift: int, wandb_logger: WandBLogger | None = None
):
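Condensed, the new learner data path reads as follows. Every name appears in the diff above; treat this as a sketch rather than the actual control flow, which also handles transitions, logging, checkpoints, and the offline phase.
# Sketch of the refactored learner wiring.
algorithm = make_algorithm(policy=policy, policy_cfg=cfg.policy, algorithm_name=cfg.algorithm)
mixer = OnlineOfflineMixer(
    online_buffer=replay_buffer,
    offline_buffer=offline_replay_buffer,
    online_ratio=cfg.online_ratio,
)
trainer = RLTrainer(
    algorithm=algorithm,
    data_mixer=mixer,
    batch_size=cfg.batch_size,
    action_dim=cfg.policy.output_features["action"].shape[0],
)
stats = trainer.training_step()  # one algorithm.update(...) under the hood
parameters_queue.put(state_to_bytes(algorithm.get_weights()))  # push to actors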
+132
@@ -0,0 +1,132 @@
# Copyright 2026 The HuggingFace Inc. team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from __future__ import annotations
from collections.abc import Iterator
from typing import Any
import torch
from lerobot.rl.algorithms.base import (
BatchType,
RLAlgorithm,
TrainingStats,
)
from lerobot.rl.data_sources.data_mixer import DataMixer
from lerobot.utils.constants import ACTION
def preprocess_rl_batch(preprocessor: Any, batch: BatchType, *, action_dim: int | None = None) -> BatchType:
"""Apply a policy preprocessor to an RL batch."""
observations = batch["state"]
next_observations = batch["next_state"]
actions = batch[ACTION]
extra_action = None
if action_dim is not None and actions.shape[-1] > action_dim:
extra_action = actions[..., action_dim:]
actions = actions[..., :action_dim]
obs_action = {**observations, ACTION: actions}
obs_action = preprocessor(obs_action)
batch["state"] = {k: v for k, v in obs_action.items() if k.startswith("observation.")}
batch[ACTION] = obs_action[ACTION]
if extra_action is not None:
batch[ACTION] = torch.cat([batch[ACTION], extra_action], dim=-1)
next_obs = {**next_observations}
next_obs = preprocessor(next_obs)
batch["next_state"] = {k: v for k, v in next_obs.items() if k.startswith("observation.")}
return batch
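A toy illustration of the action split above, with made-up shapes: six continuous dims plus one trailing discrete dim, where only the continuous slice would pass through the preprocessor.
import torch

actions = torch.randn(4, 7)  # (B, 7): 6 continuous dims + 1 discrete dim
action_dim = 6
extra = actions[..., action_dim:]  # bypasses the preprocessor entirely
cont = actions[..., :action_dim]  # this slice would be normalized
restored = torch.cat([cont, extra], dim=-1)
assert torch.equal(restored, actions)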
class _PreprocessedIterator:
"""Iterator wrapper that preprocesses each sampled RL batch."""
__slots__ = ("_raw", "_preprocessor", "_action_dim")
def __init__(
self, raw_iterator: Iterator[BatchType], preprocessor: Any, action_dim: int | None = None
) -> None:
self._raw = raw_iterator
self._preprocessor = preprocessor
self._action_dim = action_dim
def __iter__(self) -> _PreprocessedIterator:
return self
def __next__(self) -> BatchType:
batch = next(self._raw)
return preprocess_rl_batch(self._preprocessor, batch, action_dim=self._action_dim)
class RLTrainer:
"""Unified training step orchestrator.
Holds the algorithm, a DataMixer, and an optional preprocessor.
"""
def __init__(
self,
algorithm: RLAlgorithm,
data_mixer: DataMixer,
batch_size: int,
*,
preprocessor: Any | None = None,
action_dim: int | None = None,
async_prefetch: bool = True,
queue_size: int = 2,
):
self.algorithm = algorithm
self.data_mixer = data_mixer
self.batch_size = batch_size
self._preprocessor = preprocessor
self._action_dim = action_dim
self.async_prefetch = async_prefetch
self.queue_size = queue_size
self._iterator: Iterator[BatchType] | None = None
self.algorithm.make_optimizers()
def _build_data_iterator(self) -> Iterator[BatchType]:
"""Create a fresh algorithm-configured iterator (optionally preprocessed)."""
raw = self.algorithm.configure_data_iterator(
data_mixer=self.data_mixer,
batch_size=self.batch_size,
async_prefetch=self.async_prefetch,
queue_size=self.queue_size,
)
if self._preprocessor is not None:
return _PreprocessedIterator(raw, self._preprocessor, self._action_dim)
return raw
def reset_data_iterator(self) -> None:
"""Discard the current iterator so it will be rebuilt lazily next step."""
self._iterator = None
def set_data_mixer(self, data_mixer: DataMixer, *, reset: bool = True) -> None:
"""Swap the active data mixer, optionally resetting the iterator."""
self.data_mixer = data_mixer
if reset:
self.reset_data_iterator()
def training_step(self) -> TrainingStats:
"""Run one training step (algorithm-agnostic)."""
if self._iterator is None:
self._iterator = self._build_data_iterator()
return self.algorithm.update(self._iterator)
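Usage sketch for the trainer, assuming an algorithm and two mixers already exist. Swapping the mixer mid-run is the intended way to change data sources without touching the algorithm:
trainer = RLTrainer(algorithm=algorithm, data_mixer=online_only_mixer, batch_size=256)
stats = trainer.training_step()  # lazily builds the data iterator on first call
trainer.set_data_mixer(mixed_sampler)  # e.g. once an offline dataset becomes available
stats = trainer.training_step()  # the iterator is rebuilt from the new mixer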
+1
@@ -95,6 +95,7 @@ def save_checkpoint(
optimizer (Optimizer | None, optional): The optimizer to save the state from. Defaults to None.
scheduler (LRScheduler | None, optional): The scheduler to save the state from. Defaults to None.
preprocessor: The preprocessor/pipeline to save. Defaults to None.
postprocessor: The postprocessor/pipeline to save. Defaults to None.
"""
pretrained_dir = checkpoint_dir / PRETRAINED_MODEL_DIR
policy.save_pretrained(pretrained_dir)
+188 -207
@@ -14,8 +14,6 @@
# See the License for the specific language governing permissions and
# limitations under the License.
import math
import pytest
import torch
from torch import Tensor, nn
@@ -23,6 +21,7 @@ from torch import Tensor, nn
from lerobot.configs.types import FeatureType, PolicyFeature
from lerobot.policies.sac.configuration_sac import SACConfig
from lerobot.policies.sac.modeling_sac import MLP, SACPolicy
from lerobot.rl.algorithms.sac import SACAlgorithm, SACAlgorithmConfig
from lerobot.utils.constants import ACTION, OBS_IMAGE, OBS_STATE
from lerobot.utils.random_utils import seeded_context, set_seed
@@ -138,41 +137,6 @@ def create_observation_batch_with_visual_input(batch_size: int = 8, state_dim: i
}
def make_optimizers(policy: SACPolicy, has_discrete_action: bool = False) -> dict[str, torch.optim.Optimizer]:
"""Create optimizers for the SAC policy."""
optimizer_actor = torch.optim.Adam(
# Handle the case of shared encoder where the encoder weights are not optimized with the actor gradient
params=[
p
for n, p in policy.actor.named_parameters()
if not policy.config.shared_encoder or not n.startswith("encoder")
],
lr=policy.config.actor_lr,
)
optimizer_critic = torch.optim.Adam(
params=policy.critic_ensemble.parameters(),
lr=policy.config.critic_lr,
)
optimizer_temperature = torch.optim.Adam(
params=[policy.log_alpha],
lr=policy.config.critic_lr,
)
optimizers = {
"actor": optimizer_actor,
"critic": optimizer_critic,
"temperature": optimizer_temperature,
}
if has_discrete_action:
optimizers["discrete_critic"] = torch.optim.Adam(
params=policy.discrete_critic.parameters(),
lr=policy.config.critic_lr,
)
return optimizers
def create_default_config(
state_dim: int, continuous_action_dim: int, has_discrete_action: bool = False
) -> SACConfig:
@@ -212,7 +176,6 @@ def create_config_with_visual_input(
"std": torch.randn(3, 1, 1),
}
# Let's make tests a little bit faster
config.state_encoder_hidden_dim = 32
config.latent_dim = 32
@@ -220,75 +183,112 @@ def create_config_with_visual_input(
return config
@pytest.mark.parametrize("batch_size,state_dim,action_dim", [(2, 6, 6), (1, 10, 10)])
def test_sac_policy_with_default_config(batch_size: int, state_dim: int, action_dim: int):
batch = create_default_train_batch(batch_size=batch_size, action_dim=action_dim, state_dim=state_dim)
config = create_default_config(state_dim=state_dim, continuous_action_dim=action_dim)
def _make_algorithm(config: SACConfig) -> tuple[SACAlgorithm, SACPolicy]:
"""Helper to create policy + algorithm pair for tests that need critics."""
policy = SACPolicy(config=config)
policy.train()
algo_config = SACAlgorithmConfig.from_policy_config(config)
algorithm = SACAlgorithm(policy=policy, config=algo_config)
algorithm.make_optimizers()
return algorithm, policy
optimizers = make_optimizers(policy)
critic_loss = policy.forward(batch, model="critic")["loss_critic"]
assert critic_loss.item() is not None
assert critic_loss.shape == ()
critic_loss.backward()
optimizers["critic"].step()
actor_loss = policy.forward(batch, model="actor")["loss_actor"]
assert actor_loss.item() is not None
assert actor_loss.shape == ()
actor_loss.backward()
optimizers["actor"].step()
temperature_loss = policy.forward(batch, model="temperature")["loss_temperature"]
assert temperature_loss.item() is not None
assert temperature_loss.shape == ()
temperature_loss.backward()
optimizers["temperature"].step()
@pytest.mark.parametrize("batch_size,state_dim,action_dim", [(2, 6, 6), (1, 10, 10)])
def test_sac_policy_select_action(batch_size: int, state_dim: int, action_dim: int):
config = create_default_config(state_dim=state_dim, continuous_action_dim=action_dim)
policy = SACPolicy(config=config)
policy.eval()
with torch.no_grad():
observation_batch = create_observation_batch(batch_size=batch_size, state_dim=state_dim)
selected_action = policy.select_action(observation_batch)
assert selected_action.shape == (batch_size, action_dim)
# squeeze(0) removes batch dim when batch_size==1
assert selected_action.shape[-1] == action_dim
def test_sac_policy_select_action_with_discrete():
"""select_action should return continuous + discrete actions."""
config = create_default_config(state_dim=10, continuous_action_dim=6)
config.num_discrete_actions = 3
policy = SACPolicy(config=config)
policy.eval()
with torch.no_grad():
observation_batch = create_observation_batch(batch_size=1, state_dim=10)
# Squeeze to unbatched (single observation)
observation_batch = {k: v.squeeze(0) for k, v in observation_batch.items()}
selected_action = policy.select_action(observation_batch)
assert selected_action.shape[-1] == 7 # 6 continuous + 1 discrete
@pytest.mark.parametrize("batch_size,state_dim,action_dim", [(2, 6, 6), (1, 10, 10)])
def test_sac_policy_with_visual_input(batch_size: int, state_dim: int, action_dim: int):
config = create_config_with_visual_input(state_dim=state_dim, continuous_action_dim=action_dim)
def test_sac_policy_forward(batch_size: int, state_dim: int, action_dim: int):
config = create_default_config(state_dim=state_dim, continuous_action_dim=action_dim)
policy = SACPolicy(config=config)
policy.eval()
batch = create_default_train_batch(batch_size=batch_size, action_dim=action_dim, state_dim=state_dim)
with torch.no_grad():
output = policy.forward(batch)
assert "action" in output
assert "log_prob" in output
assert "action_mean" in output
assert output["action"].shape == (batch_size, action_dim)
@pytest.mark.parametrize("batch_size,state_dim,action_dim", [(2, 6, 6), (1, 10, 10)])
def test_sac_training_through_algorithm(batch_size: int, state_dim: int, action_dim: int):
config = create_default_config(state_dim=state_dim, continuous_action_dim=action_dim)
algorithm, policy = _make_algorithm(config)
batch = create_default_train_batch(batch_size=batch_size, action_dim=action_dim, state_dim=state_dim)
forward_batch = algorithm._prepare_forward_batch(batch)
critic_loss = algorithm._compute_loss_critic(forward_batch)
assert critic_loss.item() is not None
assert critic_loss.shape == ()
algorithm.optimizers["critic"].zero_grad()
critic_loss.backward()
algorithm.optimizers["critic"].step()
actor_loss = algorithm._compute_loss_actor(forward_batch)
assert actor_loss.item() is not None
assert actor_loss.shape == ()
algorithm.optimizers["actor"].zero_grad()
actor_loss.backward()
algorithm.optimizers["actor"].step()
temp_loss = algorithm._compute_loss_temperature(forward_batch)
assert temp_loss.item() is not None
assert temp_loss.shape == ()
algorithm.optimizers["temperature"].zero_grad()
temp_loss.backward()
algorithm.optimizers["temperature"].step()
@pytest.mark.parametrize("batch_size,state_dim,action_dim", [(2, 6, 6), (1, 10, 10)])
def test_sac_training_with_visual_input(batch_size: int, state_dim: int, action_dim: int):
config = create_config_with_visual_input(state_dim=state_dim, continuous_action_dim=action_dim)
algorithm, policy = _make_algorithm(config)
batch = create_train_batch_with_visual_input(
batch_size=batch_size, state_dim=state_dim, action_dim=action_dim
)
forward_batch = algorithm._prepare_forward_batch(batch)
policy.train()
critic_loss = algorithm._compute_loss_critic(forward_batch)
assert critic_loss.item() is not None
assert critic_loss.shape == ()
algorithm.optimizers["critic"].zero_grad()
critic_loss.backward()
algorithm.optimizers["critic"].step()
optimizers = make_optimizers(policy)
critic_loss = policy.forward(batch, model="critic")["loss_critic"]
assert critic_loss.item() is not None
assert critic_loss.shape == ()
critic_loss.backward()
optimizers["critic"].step()
actor_loss = policy.forward(batch, model="actor")["loss_actor"]
actor_loss = algorithm._compute_loss_actor(forward_batch)
assert actor_loss.item() is not None
assert actor_loss.shape == ()
algorithm.optimizers["actor"].zero_grad()
actor_loss.backward()
optimizers["actor"].step()
temperature_loss = policy.forward(batch, model="temperature")["loss_temperature"]
assert temperature_loss.item() is not None
assert temperature_loss.shape == ()
temperature_loss.backward()
optimizers["temperature"].step()
algorithm.optimizers["actor"].step()
policy.eval()
with torch.no_grad():
@@ -296,207 +296,181 @@ def test_sac_policy_with_visual_input(batch_size: int, state_dim: int, action_dim: int):
batch_size=batch_size, state_dim=state_dim
)
selected_action = policy.select_action(observation_batch)
assert selected_action.shape[-1] == action_dim
# Let's check best candidates for pretrained encoders
@pytest.mark.parametrize(
"batch_size,state_dim,action_dim,vision_encoder_name",
[(1, 6, 6, "helper2424/resnet10"), (1, 6, 6, "facebook/convnext-base-224")],
)
@pytest.mark.skipif(not TRANSFORMERS_AVAILABLE, reason="Transformers are not installed")
def test_sac_training_with_pretrained_encoder(
batch_size: int, state_dim: int, action_dim: int, vision_encoder_name: str
):
config = create_config_with_visual_input(state_dim=state_dim, continuous_action_dim=action_dim)
config.vision_encoder_name = vision_encoder_name
algorithm, policy = _make_algorithm(config)
batch = create_train_batch_with_visual_input(
batch_size=batch_size, state_dim=state_dim, action_dim=action_dim
)
forward_batch = algorithm._prepare_forward_batch(batch)
critic_loss = algorithm._compute_loss_critic(forward_batch)
assert critic_loss.item() is not None
assert critic_loss.shape == ()
algorithm.optimizers["critic"].zero_grad()
critic_loss.backward()
algorithm.optimizers["critic"].step()
actor_loss = algorithm._compute_loss_actor(forward_batch)
assert actor_loss.item() is not None
assert actor_loss.shape == ()
def test_sac_training_with_shared_encoder():
batch_size = 2
action_dim = 10
state_dim = 10
config = create_config_with_visual_input(state_dim=state_dim, continuous_action_dim=action_dim)
config.shared_encoder = True
algorithm, policy = _make_algorithm(config)
batch = create_train_batch_with_visual_input(
batch_size=batch_size, state_dim=state_dim, action_dim=action_dim
)
forward_batch = algorithm._prepare_forward_batch(batch)
policy.train()
critic_loss = algorithm._compute_loss_critic(forward_batch)
assert critic_loss.shape == ()
algorithm.optimizers["critic"].zero_grad()
critic_loss.backward()
algorithm.optimizers["critic"].step()
actor_loss = algorithm._compute_loss_actor(forward_batch)
assert actor_loss.shape == ()
algorithm.optimizers["actor"].zero_grad()
actor_loss.backward()
optimizers["actor"].step()
algorithm.optimizers["actor"].step()
def test_sac_training_with_discrete_critic():
batch_size = 2
continuous_action_dim = 9
full_action_dim = continuous_action_dim + 1 # the last action is discrete
state_dim = 10
config = create_config_with_visual_input(
state_dim=state_dim, continuous_action_dim=continuous_action_dim, has_discrete_action=True
)
num_discrete_actions = 5
config.num_discrete_actions = num_discrete_actions
algorithm, policy = _make_algorithm(config)
batch = create_train_batch_with_visual_input(
batch_size=batch_size, state_dim=state_dim, action_dim=full_action_dim
)
forward_batch = algorithm._prepare_forward_batch(batch)
policy.train()
critic_loss = algorithm._compute_loss_critic(forward_batch)
assert critic_loss.shape == ()
algorithm.optimizers["critic"].zero_grad()
critic_loss.backward()
algorithm.optimizers["critic"].step()
discrete_critic_loss = algorithm._compute_loss_discrete_critic(forward_batch)
assert discrete_critic_loss.shape == ()
algorithm.optimizers["discrete_critic"].zero_grad()
discrete_critic_loss.backward()
optimizers["discrete_critic"].step()
algorithm.optimizers["discrete_critic"].step()
actor_loss = algorithm._compute_loss_actor(forward_batch)
assert actor_loss.shape == ()
algorithm.optimizers["actor"].zero_grad()
actor_loss.backward()
optimizers["actor"].step()
algorithm.optimizers["actor"].step()
policy.eval()
with torch.no_grad():
observation_batch = create_observation_batch_with_visual_input(
batch_size=batch_size, state_dim=state_dim
)
selected_action = policy.select_action(observation_batch)
assert selected_action.shape == (batch_size, full_action_dim)
discrete_actions = selected_action[:, -1].long()
discrete_action_values = set(discrete_actions.tolist())
assert all(action in range(num_discrete_actions) for action in discrete_action_values), (
f"Discrete action {discrete_action_values} is not in range({num_discrete_actions})"
)
# Policy.select_action now handles both continuous + discrete
selected_action = policy.select_action({k: v.squeeze(0) for k, v in observation_batch.items()})
assert selected_action.shape[-1] == continuous_action_dim + 1
def test_sac_algorithm_target_entropy():
config = create_default_config(continuous_action_dim=10, state_dim=10)
_, policy = _make_algorithm(config)
algo_config = SACAlgorithmConfig.from_policy_config(config)
algorithm = SACAlgorithm(policy=policy, config=algo_config)
assert algorithm.target_entropy == -5.0
def test_sac_algorithm_target_entropy_with_discrete_action():
config = create_config_with_visual_input(state_dim=10, continuous_action_dim=6, has_discrete_action=True)
config.num_discrete_actions = 5
algo_config = SACAlgorithmConfig.from_policy_config(config)
policy = SACPolicy(config=config)
algorithm = SACAlgorithm(policy=policy, config=algo_config)
assert algorithm.target_entropy == -3.5
def test_sac_algorithm_temperature():
"""Test that the temperature property is always in sync with log_alpha."""
import math
config = create_default_config(continuous_action_dim=10, state_dim=10)
algo_config = SACAlgorithmConfig.from_policy_config(config)
policy = SACPolicy(config=config)
algorithm = SACAlgorithm(policy=policy, config=algo_config)
assert algorithm.temperature == pytest.approx(1.0)
algorithm.log_alpha.data = torch.tensor([math.log(0.1)])
assert algorithm.temperature == pytest.approx(0.1)
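# The check above relies on the usual SAC convention that temperature == exp(log_alpha),
# so writing log(0.1) into log_alpha must read back as 0.1 with no extra sync step.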
def test_sac_algorithm_update_target_network():
config = create_default_config(state_dim=10, continuous_action_dim=6)
config.critic_target_update_weight = 1.0
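# critic_target_update_weight is the soft-update coefficient (tau); with tau == 1.0 the
# target update degenerates to a hard copy, which is what the assert below relies on.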
algo_config = SACAlgorithmConfig.from_policy_config(config)
policy = SACPolicy(config=config)
policy.train()
algorithm = SACAlgorithm(policy=policy, config=algo_config)
for p in algorithm.critic_ensemble.parameters():
p.data = torch.ones_like(p.data)
algorithm._update_target_networks()
for p in algorithm.critic_target.parameters():
assert torch.allclose(p.data, torch.ones_like(p.data))
@pytest.mark.parametrize("num_critics", [1, 3])
def test_sac_algorithm_with_critics_number_of_heads(num_critics: int):
batch_size = 2
action_dim = 10
state_dim = 10
config = create_config_with_visual_input(state_dim=state_dim, continuous_action_dim=action_dim)
config.num_critics = num_critics
algorithm, policy = _make_algorithm(config)
assert len(algorithm.critic_ensemble.critics) == num_critics
batch = create_train_batch_with_visual_input(
batch_size=batch_size, state_dim=state_dim, action_dim=action_dim
)
forward_batch = algorithm._prepare_forward_batch(batch)
policy.train()
critic_loss = algorithm._compute_loss_critic(forward_batch)
assert critic_loss.shape == ()
algorithm.optimizers["critic"].zero_grad()
critic_loss.backward()
algorithm.optimizers["critic"].step()
def test_sac_policy_save_and_load(tmp_path):
"""Test that the policy can be saved and loaded from pretrained."""
root = tmp_path / "test_sac_save_and_load"
state_dim = 10
@@ -510,34 +484,41 @@ def test_sac_policy_save_and_load(tmp_path):
loaded_policy = SACPolicy.from_pretrained(root, config=config)
loaded_policy.eval()
batch = create_default_train_batch(batch_size=1, state_dim=10, action_dim=10)
assert policy.state_dict().keys() == loaded_policy.state_dict().keys()
for k in policy.state_dict():
assert torch.allclose(policy.state_dict()[k], loaded_policy.state_dict()[k], atol=1e-6)
with torch.no_grad():
with seeded_context(12):
# Collect policy values before saving
observation_batch = create_observation_batch(batch_size=batch_size, state_dim=state_dim)
actions = policy.select_action(observation_batch)
with seeded_context(12):
# Collect policy values after loading
loaded_observation_batch = create_observation_batch(batch_size=batch_size, state_dim=state_dim)
loaded_actions = loaded_policy.select_action(loaded_observation_batch)
# Compare values before and after saving and loading
# They should be the same
assert torch.allclose(actions, loaded_actions)
def test_sac_policy_save_and_load_with_discrete_critic(tmp_path):
"""Discrete critic should be saved/loaded as part of the policy."""
root = tmp_path / "test_sac_save_and_load_discrete"
state_dim = 10
action_dim = 6
config = create_default_config(state_dim=state_dim, continuous_action_dim=action_dim)
config.num_discrete_actions = 3
policy = SACPolicy(config=config)
policy.eval()
policy.save_pretrained(root)
loaded_policy = SACPolicy.from_pretrained(root, config=config)
loaded_policy.eval()
assert loaded_policy.discrete_critic is not None
dc_keys = [k for k in loaded_policy.state_dict() if k.startswith("discrete_critic.")]
assert len(dc_keys) > 0
for k in policy.state_dict():
assert torch.allclose(policy.state_dict()[k], loaded_policy.state_dict()[k], atol=1e-6)
+171 -1
View File
@@ -23,8 +23,9 @@ import torch
from torch.multiprocessing import Event, Queue
from lerobot.configs.train import TrainRLServerPipelineConfig
from lerobot.configs.types import FeatureType, PolicyFeature
from lerobot.policies.sac.configuration_sac import SACConfig
from lerobot.utils.constants import ACTION, OBS_STATE, OBS_STR
from lerobot.utils.transition import Transition
from tests.utils import require_package
@@ -296,3 +297,172 @@ def test_end_to_end_parameters_flow(cfg, data_size):
assert received_params.keys() == input_params.keys()
for key in input_params:
assert torch.allclose(received_params[key], input_params[key])
# ---------------------------------------------------------------------------
# Regression test: learner algorithm integration (no gRPC required)
# ---------------------------------------------------------------------------
def test_learner_algorithm_wiring():
"""Verify that make_algorithm constructs an SACAlgorithm from config,
make_optimizers() creates the right optimizers, update() works, and
get_weights() output is serializable."""
from lerobot.policies.sac.modeling_sac import SACPolicy
from lerobot.rl.algorithms import make_algorithm
from lerobot.rl.algorithms.sac import SACAlgorithm
from lerobot.transport.utils import state_to_bytes
state_dim = 10
action_dim = 6
sac_cfg = SACConfig(
input_features={OBS_STATE: PolicyFeature(type=FeatureType.STATE, shape=(state_dim,))},
output_features={ACTION: PolicyFeature(type=FeatureType.ACTION, shape=(action_dim,))},
dataset_stats={
OBS_STATE: {"min": [0.0] * state_dim, "max": [1.0] * state_dim},
ACTION: {"min": [0.0] * action_dim, "max": [1.0] * action_dim},
},
use_torch_compile=False,
)
sac_cfg.validate_features()
policy = SACPolicy(config=sac_cfg)
policy.train()
algorithm = make_algorithm(policy=policy, policy_cfg=sac_cfg, algorithm_name="sac")
assert isinstance(algorithm, SACAlgorithm)
optimizers = algorithm.make_optimizers()
assert "actor" in optimizers
assert "critic" in optimizers
assert "temperature" in optimizers
batch_size = 4
def batch_iterator():
while True:
yield {
ACTION: torch.randn(batch_size, action_dim),
"reward": torch.randn(batch_size),
"state": {OBS_STATE: torch.randn(batch_size, state_dim)},
"next_state": {OBS_STATE: torch.randn(batch_size, state_dim)},
"done": torch.zeros(batch_size),
"complementary_info": {},
}
stats = algorithm.update(batch_iterator())
assert "critic" in stats.losses
# get_weights -> state_to_bytes round-trip
weights = algorithm.get_weights()
assert len(weights) > 0
serialized = state_to_bytes(weights)
assert isinstance(serialized, bytes)
assert len(serialized) > 0
# RLTrainer with DataMixer
from lerobot.rl.buffer import ReplayBuffer
from lerobot.rl.data_sources import OnlineOfflineMixer
from lerobot.rl.trainer import RLTrainer
replay_buffer = ReplayBuffer(
capacity=50,
device="cpu",
state_keys=[OBS_STATE],
storage_device="cpu",
use_drq=False,
)
for _ in range(50):
replay_buffer.add(
state={OBS_STATE: torch.randn(state_dim)},
action=torch.randn(action_dim),
reward=1.0,
next_state={OBS_STATE: torch.randn(state_dim)},
done=False,
truncated=False,
)
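# With offline_buffer=None the mixer is expected to fall back to online-only sampling,
# matching the behavior covered by the OnlineOfflineMixer tests in this change.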
data_mixer = OnlineOfflineMixer(online_buffer=replay_buffer, offline_buffer=None)
trainer = RLTrainer(
algorithm=algorithm,
data_mixer=data_mixer,
batch_size=batch_size,
async_prefetch=False,
)
trainer_stats = trainer.training_step()
assert "critic" in trainer_stats.losses
def test_initial_and_periodic_weight_push_consistency():
"""Both initial and periodic weight pushes should use algorithm.get_weights()
and produce identical structures."""
from lerobot.policies.sac.modeling_sac import SACPolicy
from lerobot.rl.algorithms import make_algorithm
from lerobot.transport.utils import bytes_to_state_dict, state_to_bytes
state_dim = 10
action_dim = 6
sac_cfg = SACConfig(
input_features={OBS_STATE: PolicyFeature(type=FeatureType.STATE, shape=(state_dim,))},
output_features={ACTION: PolicyFeature(type=FeatureType.ACTION, shape=(action_dim,))},
dataset_stats={
OBS_STATE: {"min": [0.0] * state_dim, "max": [1.0] * state_dim},
ACTION: {"min": [0.0] * action_dim, "max": [1.0] * action_dim},
},
use_torch_compile=False,
)
sac_cfg.validate_features()
policy = SACPolicy(config=sac_cfg)
policy.train()
algorithm = make_algorithm(policy=policy, policy_cfg=sac_cfg, algorithm_name="sac")
algorithm.make_optimizers()
# Simulate initial push (same code path the learner now uses)
initial_weights = algorithm.get_weights()
initial_bytes = state_to_bytes(initial_weights)
# Simulate periodic push
periodic_weights = algorithm.get_weights()
periodic_bytes = state_to_bytes(periodic_weights)
initial_decoded = bytes_to_state_dict(initial_bytes)
periodic_decoded = bytes_to_state_dict(periodic_bytes)
assert initial_decoded.keys() == periodic_decoded.keys()
def test_actor_side_algorithm_select_action_and_load_weights():
"""Simulate actor: create algorithm without optimizers, select_action, load_weights."""
from lerobot.policies.sac.modeling_sac import SACPolicy
from lerobot.rl.algorithms import make_algorithm
from lerobot.rl.algorithms.sac import SACAlgorithm
state_dim = 10
action_dim = 6
sac_cfg = SACConfig(
input_features={OBS_STATE: PolicyFeature(type=FeatureType.STATE, shape=(state_dim,))},
output_features={ACTION: PolicyFeature(type=FeatureType.ACTION, shape=(action_dim,))},
dataset_stats={
OBS_STATE: {"min": [0.0] * state_dim, "max": [1.0] * state_dim},
ACTION: {"min": [0.0] * action_dim, "max": [1.0] * action_dim},
},
use_torch_compile=False,
)
sac_cfg.validate_features()
# Actor side: no optimizers
policy = SACPolicy(config=sac_cfg)
policy.eval()
algorithm = make_algorithm(policy=policy, policy_cfg=sac_cfg, algorithm_name="sac")
assert isinstance(algorithm, SACAlgorithm)
assert algorithm.optimizers == {}
# select_action should work
obs = {OBS_STATE: torch.randn(state_dim)}
action = policy.select_action(obs)
assert action.shape == (action_dim,)
# Simulate receiving weights from learner
fake_weights = algorithm.get_weights()
algorithm.load_weights(fake_weights, device="cpu")
+85
View File
@@ -0,0 +1,85 @@
# Copyright 2025 The HuggingFace Inc. team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Tests for RL data mixing (DataMixer, OnlineOfflineMixer)."""
import torch
from lerobot.rl.buffer import ReplayBuffer
from lerobot.rl.data_sources import OnlineOfflineMixer
from lerobot.utils.constants import OBS_STATE
def _make_buffer(capacity: int = 100, state_dim: int = 4) -> ReplayBuffer:
buf = ReplayBuffer(
capacity=capacity,
device="cpu",
state_keys=[OBS_STATE],
storage_device="cpu",
use_drq=False,
)
for i in range(capacity):
buf.add(
state={OBS_STATE: torch.randn(state_dim)},
action=torch.randn(2),
reward=1.0,
next_state={OBS_STATE: torch.randn(state_dim)},
done=bool(i % 10 == 9),
truncated=False,
)
return buf
def test_online_only_mixer_sample():
"""OnlineOfflineMixer with no offline buffer returns online-only batches."""
buf = _make_buffer(capacity=50)
mixer = OnlineOfflineMixer(online_buffer=buf, offline_buffer=None, online_ratio=0.5)
batch = mixer.sample(batch_size=8)
assert batch["state"][OBS_STATE].shape[0] == 8
assert batch["action"].shape[0] == 8
assert batch["reward"].shape[0] == 8
def test_online_only_mixer_ratio_one():
"""OnlineOfflineMixer with online_ratio=1.0 and no offline is equivalent to online-only."""
buf = _make_buffer(capacity=50)
mixer = OnlineOfflineMixer(online_buffer=buf, offline_buffer=None, online_ratio=1.0)
batch = mixer.sample(batch_size=10)
assert batch["state"][OBS_STATE].shape[0] == 10
def test_online_offline_mixer_sample():
"""OnlineOfflineMixer with two buffers returns concatenated batches."""
online = _make_buffer(capacity=50)
offline = _make_buffer(capacity=50)
mixer = OnlineOfflineMixer(
online_buffer=online,
offline_buffer=offline,
online_ratio=0.5,
)
batch = mixer.sample(batch_size=10)
assert batch["state"][OBS_STATE].shape[0] == 10
assert batch["action"].shape[0] == 10
# With online_ratio=0.5, roughly half the batch comes from each buffer; only the total batch size is asserted here.
assert batch["reward"].shape[0] == 10
def test_online_offline_mixer_iterator():
"""get_iterator yields batches of the requested size."""
buf = _make_buffer(capacity=50)
mixer = OnlineOfflineMixer(online_buffer=buf, offline_buffer=None)
it = mixer.get_iterator(batch_size=4, async_prefetch=False)
batch1 = next(it)
batch2 = next(it)
assert batch1["state"][OBS_STATE].shape[0] == 4
assert batch2["state"][OBS_STATE].shape[0] == 4
+477
View File
@@ -0,0 +1,477 @@
#!/usr/bin/env python
# Copyright 2025 The HuggingFace Inc. team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Tests for the RL algorithm abstraction and SACAlgorithm implementation."""
import pytest
import torch
from lerobot.configs.types import FeatureType, PolicyFeature
from lerobot.policies.sac.configuration_sac import SACConfig
from lerobot.policies.sac.modeling_sac import SACPolicy
from lerobot.rl.algorithms import make_algorithm
from lerobot.rl.algorithms.base import RLAlgorithmConfig, TrainingStats
from lerobot.rl.algorithms.sac import SACAlgorithm, SACAlgorithmConfig
from lerobot.utils.constants import ACTION, OBS_IMAGE, OBS_STATE
from lerobot.utils.random_utils import set_seed
# ---------------------------------------------------------------------------
# Helpers (reuse patterns from tests/policies/test_sac_policy.py)
# ---------------------------------------------------------------------------
@pytest.fixture(autouse=True)
def set_random_seed():
set_seed(42)
def _make_sac_config(
state_dim: int = 10,
action_dim: int = 6,
num_discrete_actions: int | None = None,
utd_ratio: int = 1,
policy_update_freq: int = 1,
with_images: bool = False,
) -> SACConfig:
config = SACConfig(
input_features={OBS_STATE: PolicyFeature(type=FeatureType.STATE, shape=(state_dim,))},
output_features={ACTION: PolicyFeature(type=FeatureType.ACTION, shape=(action_dim,))},
dataset_stats={
OBS_STATE: {"min": [0.0] * state_dim, "max": [1.0] * state_dim},
ACTION: {"min": [0.0] * action_dim, "max": [1.0] * action_dim},
},
utd_ratio=utd_ratio,
policy_update_freq=policy_update_freq,
num_discrete_actions=num_discrete_actions,
use_torch_compile=False,
)
if with_images:
config.input_features[OBS_IMAGE] = PolicyFeature(type=FeatureType.VISUAL, shape=(3, 84, 84))
config.dataset_stats[OBS_IMAGE] = {
"mean": torch.randn(3, 1, 1).tolist(),
"std": torch.randn(3, 1, 1).abs().tolist(),
}
config.latent_dim = 32
config.state_encoder_hidden_dim = 32
config.validate_features()
return config
def _make_algorithm(
state_dim: int = 10,
action_dim: int = 6,
utd_ratio: int = 1,
policy_update_freq: int = 1,
num_discrete_actions: int | None = None,
with_images: bool = False,
) -> tuple[SACAlgorithm, SACPolicy]:
sac_cfg = _make_sac_config(
state_dim=state_dim,
action_dim=action_dim,
utd_ratio=utd_ratio,
policy_update_freq=policy_update_freq,
num_discrete_actions=num_discrete_actions,
with_images=with_images,
)
policy = SACPolicy(config=sac_cfg)
policy.train()
algo_config = SACAlgorithmConfig.from_policy_config(sac_cfg)
algorithm = SACAlgorithm(policy=policy, config=algo_config)
algorithm.make_optimizers()
return algorithm, policy
def _make_batch(
batch_size: int = 4,
state_dim: int = 10,
action_dim: int = 6,
with_images: bool = False,
) -> dict:
obs = {OBS_STATE: torch.randn(batch_size, state_dim)}
next_obs = {OBS_STATE: torch.randn(batch_size, state_dim)}
if with_images:
obs[OBS_IMAGE] = torch.randn(batch_size, 3, 84, 84)
next_obs[OBS_IMAGE] = torch.randn(batch_size, 3, 84, 84)
return {
ACTION: torch.randn(batch_size, action_dim),
"reward": torch.randn(batch_size),
"state": obs,
"next_state": next_obs,
"done": torch.zeros(batch_size),
"complementary_info": {},
}
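# The dict above mirrors the transition layout the learner consumes: nested
# "state"/"next_state" observation dicts keyed by feature name, plus flat
# ACTION, "reward", and "done" tensors.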
def _batch_iterator(**batch_kwargs):
"""Infinite iterator that yields fresh batches (mirrors a real DataMixer iterator)."""
while True:
yield _make_batch(**batch_kwargs)
# ===========================================================================
# Registry / config tests
# ===========================================================================
def test_sac_algorithm_config_registered():
"""SACAlgorithmConfig should be discoverable through the registry."""
assert "sac" in RLAlgorithmConfig.get_known_choices()
cls = RLAlgorithmConfig.get_choice_class("sac")
assert cls is SACAlgorithmConfig
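# For context, registration is assumed to follow the usual draccus.ChoiceRegistry
# pattern (hypothetical sketch, not copied from the source):
#
#     @RLAlgorithmConfig.register_subclass("sac")
#     @dataclass
#     class SACAlgorithmConfig(RLAlgorithmConfig):
#         ...
#
# which is what makes "sac" appear in get_known_choices().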
def test_sac_algorithm_config_from_policy_config():
"""from_policy_config should copy relevant fields."""
sac_cfg = _make_sac_config(utd_ratio=4, policy_update_freq=2)
algo_cfg = SACAlgorithmConfig.from_policy_config(sac_cfg)
assert algo_cfg.utd_ratio == 4
assert algo_cfg.policy_update_freq == 2
assert algo_cfg.clip_grad_norm == sac_cfg.grad_clip_norm
# ===========================================================================
# TrainingStats tests
# ===========================================================================
def test_training_stats_defaults():
stats = TrainingStats()
assert stats.losses == {}
assert stats.grad_norms == {}
assert stats.extra == {}
# ===========================================================================
# get_weights
# ===========================================================================
def test_get_weights_returns_policy_state_dict():
algorithm, policy = _make_algorithm()
weights = algorithm.get_weights()
for key in policy.state_dict():
assert key in weights
assert torch.equal(weights[key].cpu(), policy.state_dict()[key].cpu())
def test_get_weights_includes_discrete_critic_when_present():
algorithm, policy = _make_algorithm(num_discrete_actions=3, action_dim=6)
weights = algorithm.get_weights()
dc_keys = [k for k in weights if k.startswith("discrete_critic.")]
assert len(dc_keys) > 0
def test_get_weights_excludes_discrete_critic_when_absent():
algorithm, _ = _make_algorithm()
weights = algorithm.get_weights()
dc_keys = [k for k in weights if k.startswith("discrete_critic.")]
assert len(dc_keys) == 0
def test_get_weights_are_on_cpu():
algorithm, _ = _make_algorithm()
weights = algorithm.get_weights()
for key, tensor in weights.items():
assert tensor.device == torch.device("cpu"), f"{key} is not on CPU"
# ===========================================================================
# select_action (lives on the policy, not the algorithm)
# ===========================================================================
def test_select_action_returns_correct_shape():
action_dim = 6
_, policy = _make_algorithm(state_dim=10, action_dim=action_dim)
policy.eval()
obs = {OBS_STATE: torch.randn(10)}
action = policy.select_action(obs)
assert action.shape == (action_dim,)
def test_select_action_with_discrete_critic():
continuous_dim = 5
_, policy = _make_algorithm(state_dim=10, action_dim=continuous_dim, num_discrete_actions=3)
policy.eval()
obs = {OBS_STATE: torch.randn(10)}
action = policy.select_action(obs)
assert action.shape == (continuous_dim + 1,)
# ===========================================================================
# update (single batch, utd_ratio=1)
# ===========================================================================
def test_update_returns_training_stats():
algorithm, _ = _make_algorithm()
stats = algorithm.update(_batch_iterator())
assert isinstance(stats, TrainingStats)
assert "critic" in stats.losses
assert isinstance(stats.losses["critic"], float)
def test_update_populates_actor_and_temperature_losses():
"""With policy_update_freq=1 and step 0, actor/temperature should be updated."""
algorithm, _ = _make_algorithm(policy_update_freq=1)
stats = algorithm.update(_batch_iterator())
assert "actor" in stats.losses
assert "temperature" in stats.losses
assert "temperature" in stats.extra
@pytest.mark.parametrize("policy_update_freq", [2, 3])
def test_update_skips_actor_at_non_update_steps(policy_update_freq):
"""Actor/temperature should only update when optimization_step % freq == 0."""
algorithm, _ = _make_algorithm(policy_update_freq=policy_update_freq)
it = _batch_iterator()
# Step 0: should update actor
stats_0 = algorithm.update(it)
assert "actor" in stats_0.losses
# Step 1: should NOT update actor
stats_1 = algorithm.update(it)
assert "actor" not in stats_1.losses
def test_update_increments_optimization_step():
algorithm, _ = _make_algorithm()
it = _batch_iterator()
assert algorithm.optimization_step == 0
algorithm.update(it)
assert algorithm.optimization_step == 1
algorithm.update(it)
assert algorithm.optimization_step == 2
def test_update_with_discrete_critic():
algorithm, _ = _make_algorithm(num_discrete_actions=3, action_dim=6)
stats = algorithm.update(_batch_iterator(action_dim=7)) # continuous + 1 discrete
assert "discrete_critic" in stats.losses
assert "discrete_critic" in stats.grad_norms
# ===========================================================================
# update with UTD ratio > 1
# ===========================================================================
@pytest.mark.parametrize("utd_ratio", [2, 4])
def test_update_with_utd_ratio(utd_ratio):
algorithm, _ = _make_algorithm(utd_ratio=utd_ratio)
stats = algorithm.update(_batch_iterator())
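# A single update() call pulls utd_ratio batches internally but still counts as one
# optimization step, which the asserts below depend on.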
assert isinstance(stats, TrainingStats)
assert "critic" in stats.losses
assert algorithm.optimization_step == 1
def test_update_utd_ratio_pulls_utd_batches():
"""next(batch_iterator) should be called exactly utd_ratio times."""
utd_ratio = 3
algorithm, _ = _make_algorithm(utd_ratio=utd_ratio)
call_count = 0
def counting_iterator():
nonlocal call_count
while True:
call_count += 1
yield _make_batch()
algorithm.update(counting_iterator())
assert call_count == utd_ratio
def test_update_utd_ratio_3_critic_warmup_changes_weights():
"""With utd_ratio=3, critic weights should change after update (3 critic steps)."""
algorithm, policy = _make_algorithm(utd_ratio=3)
critic_params_before = {n: p.clone() for n, p in algorithm.critic_ensemble.named_parameters()}
algorithm.update(_batch_iterator())
changed = False
for n, p in algorithm.critic_ensemble.named_parameters():
if not torch.equal(p, critic_params_before[n]):
changed = True
break
assert changed, "Critic weights should have changed after UTD update"
# ===========================================================================
# get_observation_features
# ===========================================================================
def test_get_observation_features_returns_none_without_frozen_encoder():
algorithm, _ = _make_algorithm(with_images=False)
obs = {OBS_STATE: torch.randn(4, 10)}
next_obs = {OBS_STATE: torch.randn(4, 10)}
feat, next_feat = algorithm.get_observation_features(obs, next_obs)
assert feat is None
assert next_feat is None
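# Feature pre-computation is assumed to apply only when a frozen pretrained encoder is
# configured; with state-only inputs there is nothing to cache, hence None/None.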
# ===========================================================================
# optimization_step setter
# ===========================================================================
def test_optimization_step_can_be_set_for_resume():
algorithm, _ = _make_algorithm()
algorithm.optimization_step = 100
assert algorithm.optimization_step == 100
# ===========================================================================
# make_algorithm factory
# ===========================================================================
def test_make_algorithm_returns_sac_for_sac_policy():
sac_cfg = _make_sac_config()
policy = SACPolicy(config=sac_cfg)
algorithm = make_algorithm(policy=policy, policy_cfg=sac_cfg, algorithm_name="sac")
assert isinstance(algorithm, SACAlgorithm)
assert algorithm.optimizers == {}
def test_make_optimizers_creates_expected_keys():
"""make_optimizers() should populate the algorithm with Adam optimizers."""
sac_cfg = _make_sac_config()
policy = SACPolicy(config=sac_cfg)
algorithm = make_algorithm(policy=policy, policy_cfg=sac_cfg, algorithm_name="sac")
optimizers = algorithm.make_optimizers()
assert "actor" in optimizers
assert "critic" in optimizers
assert "temperature" in optimizers
assert all(isinstance(v, torch.optim.Adam) for v in optimizers.values())
assert algorithm.get_optimizers() is optimizers
def test_actor_side_no_optimizers():
"""Actor-side usage: no optimizers needed, make_optimizers is not called."""
sac_cfg = _make_sac_config()
policy = SACPolicy(config=sac_cfg)
algorithm = make_algorithm(policy=policy, policy_cfg=sac_cfg, algorithm_name="sac")
assert isinstance(algorithm, SACAlgorithm)
assert algorithm.optimizers == {}
def test_make_algorithm_copies_config_fields():
sac_cfg = _make_sac_config(utd_ratio=5, policy_update_freq=3)
policy = SACPolicy(config=sac_cfg)
algorithm = make_algorithm(policy=policy, policy_cfg=sac_cfg, algorithm_name="sac")
assert algorithm.config.utd_ratio == 5
assert algorithm.config.policy_update_freq == 3
def test_make_algorithm_raises_for_unknown_type():
class FakeConfig:
type = "unknown_algo"
with pytest.raises(ValueError, match="No RLAlgorithmConfig"):
make_algorithm(policy=None, policy_cfg=FakeConfig(), algorithm_name="unknown_algo")
# ===========================================================================
# load_weights (round-trip with get_weights)
# ===========================================================================
def test_load_weights_round_trip():
"""get_weights -> load_weights should restore identical parameters on a fresh policy."""
algo_src, _ = _make_algorithm(state_dim=10, action_dim=6)
algo_src.update(_batch_iterator())
sac_cfg = _make_sac_config(state_dim=10, action_dim=6)
policy_dst = SACPolicy(config=sac_cfg)
algo_dst = SACAlgorithm(policy=policy_dst, config=algo_src.config)
weights = algo_src.get_weights()
algo_dst.load_weights(weights, device="cpu")
for key in weights:
assert torch.equal(
algo_dst.policy.state_dict()[key].cpu(),
weights[key].cpu(),
), f"Policy param '{key}' mismatch after load_weights"
def test_load_weights_round_trip_with_discrete_critic():
algo_src, _ = _make_algorithm(num_discrete_actions=3, action_dim=6)
algo_src.update(_batch_iterator(action_dim=7))
sac_cfg = _make_sac_config(num_discrete_actions=3, action_dim=6)
policy_dst = SACPolicy(config=sac_cfg)
algo_dst = SACAlgorithm(policy=policy_dst, config=algo_src.config)
weights = algo_src.get_weights()
algo_dst.load_weights(weights, device="cpu")
dc_keys = [k for k in weights if k.startswith("discrete_critic.")]
assert len(dc_keys) > 0
for key in dc_keys:
assert torch.equal(
algo_dst.policy.state_dict()[key].cpu(),
weights[key].cpu(),
), f"Discrete critic param '{key}' mismatch after load_weights"
def test_load_weights_ignores_missing_discrete_critic():
"""load_weights should not fail when weights lack discrete_critic on a non-discrete policy."""
algorithm, _ = _make_algorithm()
weights = algorithm.get_weights()
algorithm.load_weights(weights, device="cpu")
# ===========================================================================
# TrainingStats generic losses dict
# ===========================================================================
def test_training_stats_generic_losses():
stats = TrainingStats(
losses={"loss_bc": 0.5, "loss_q": 1.2},
extra={"temperature": 0.1},
)
assert stats.losses["loss_bc"] == 0.5
assert stats.losses["loss_q"] == 1.2
assert stats.extra["temperature"] == 0.1
# ===========================================================================
# Registry-driven build_algorithm
# ===========================================================================
def test_build_algorithm_via_config():
"""SACAlgorithmConfig.build_algorithm should produce a working SACAlgorithm."""
sac_cfg = _make_sac_config(utd_ratio=2)
algo_config = SACAlgorithmConfig.from_policy_config(sac_cfg)
policy = SACPolicy(config=sac_cfg)
algorithm = algo_config.build_algorithm(policy)
assert isinstance(algorithm, SACAlgorithm)
assert algorithm.config.utd_ratio == 2
def test_make_algorithm_uses_build_algorithm():
"""make_algorithm should delegate to config.build_algorithm (no hardcoded if/else)."""
sac_cfg = _make_sac_config()
policy = SACPolicy(config=sac_cfg)
algorithm = make_algorithm(policy=policy, policy_cfg=sac_cfg, algorithm_name="sac")
assert isinstance(algorithm, SACAlgorithm)
+115
View File
@@ -0,0 +1,115 @@
#!/usr/bin/env python
# Copyright 2026 The HuggingFace Inc. team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import torch
from torch import Tensor
from lerobot.rl.algorithms.base import RLAlgorithm, TrainingStats
from lerobot.rl.trainer import RLTrainer
from lerobot.utils.constants import ACTION, OBS_STATE
class _CountingAlgorithm(RLAlgorithm):
def __init__(self):
self.configure_calls = 0
self.update_calls = 0
def select_action(self, observation: dict[str, Tensor]) -> Tensor:
return torch.zeros(1)
def configure_data_iterator(
self,
data_mixer,
batch_size: int,
*,
async_prefetch: bool = True,
queue_size: int = 2,
):
self.configure_calls += 1
return data_mixer.get_iterator(
batch_size=batch_size,
async_prefetch=async_prefetch,
queue_size=queue_size,
)
def make_optimizers(self):
return {}
def update(self, batch_iterator):
self.update_calls += 1
_ = next(batch_iterator)
return TrainingStats(losses={"dummy": 1.0})
def load_weights(self, weights, device="cpu") -> None:
_ = (weights, device)
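# _CountingAlgorithm implements just enough of the RLAlgorithm surface (select_action,
# configure_data_iterator, make_optimizers, update, load_weights) to let the tests
# below observe how often the trainer builds iterators and calls update().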
class _SimpleMixer:
def get_iterator(self, batch_size: int, async_prefetch: bool = True, queue_size: int = 2):
_ = (async_prefetch, queue_size)
while True:
yield {
"state": {OBS_STATE: torch.randn(batch_size, 3)},
ACTION: torch.randn(batch_size, 2),
"reward": torch.randn(batch_size),
"next_state": {OBS_STATE: torch.randn(batch_size, 3)},
"done": torch.zeros(batch_size),
"truncated": torch.zeros(batch_size),
"complementary_info": None,
}
def test_trainer_lazy_iterator_lifecycle_and_reset():
algo = _CountingAlgorithm()
mixer = _SimpleMixer()
trainer = RLTrainer(algorithm=algo, data_mixer=mixer, batch_size=4, async_prefetch=False)
# First call builds iterator once.
trainer.training_step()
assert algo.configure_calls == 1
assert algo.update_calls == 1
# Second call reuses existing iterator.
trainer.training_step()
assert algo.configure_calls == 1
assert algo.update_calls == 2
# Explicit reset forces lazy rebuild on next step.
trainer.reset_data_iterator()
trainer.training_step()
assert algo.configure_calls == 2
assert algo.update_calls == 3
def test_trainer_set_data_mixer_resets_by_default():
algo = _CountingAlgorithm()
mixer_a = _SimpleMixer()
mixer_b = _SimpleMixer()
trainer = RLTrainer(algorithm=algo, data_mixer=mixer_a, batch_size=2, async_prefetch=False)
trainer.training_step()
assert algo.configure_calls == 1
trainer.set_data_mixer(mixer_b, reset=True)
trainer.training_step()
assert algo.configure_calls == 2
def test_algorithm_optimization_step_contract_defaults():
algo = _CountingAlgorithm()
assert algo.optimization_step == 0
algo.optimization_step = 11
assert algo.optimization_step == 11