feat: add RLT algorithm

Khalil Meftah
2026-03-22 22:59:35 +01:00
parent 17f47b9cbc
commit d9371b9a34
4 changed files with 423 additions and 0 deletions
@@ -21,6 +21,7 @@ from lerobot.rl.algorithms.base import (
RLAlgorithmConfig,
TrainingStats,
)
from lerobot.rl.algorithms.rlt import RLTAlgorithm, RLTAlgorithmConfig
from lerobot.rl.algorithms.sac import SACAlgorithm, SACAlgorithmConfig
@@ -63,5 +64,7 @@ __all__ = [
"TrainingStats",
"SACAlgorithm",
"SACAlgorithmConfig",
"RLTAlgorithm",
"RLTAlgorithmConfig",
"make_algorithm",
]
@@ -0,0 +1,18 @@
# Copyright 2026 The HuggingFace Inc. team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from lerobot.rl.algorithms.rlt.configuration_rlt import RLTAlgorithmConfig
from lerobot.rl.algorithms.rlt.rlt_algorithm import RLTAlgorithm
__all__ = ["RLTAlgorithm", "RLTAlgorithmConfig"]
@@ -0,0 +1,83 @@
# Copyright 2026 The HuggingFace Inc. team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""RLT algorithm configuration."""
from __future__ import annotations
from dataclasses import dataclass
from typing import TYPE_CHECKING
import torch
from lerobot.rl.algorithms.base import RLAlgorithmConfig
if TYPE_CHECKING:
from lerobot.rl.algorithms.rlt.rlt_algorithm import RLTAlgorithm
@RLAlgorithmConfig.register_subclass("rlt")
@dataclass
class RLTAlgorithmConfig(RLAlgorithmConfig):
"""RLT-specific hyper-parameters that control the update loop."""
# ── Action chunks ──
chunk_size: int = 10
chunk_stride: int = 2
# ── Update cadence ──
utd_ratio: int = 5
policy_update_freq: int = 2
clip_grad_norm: float = 10.0
# ── Learning rates ──
actor_lr: float = 3e-4
critic_lr: float = 3e-4
rl_token_lr: float = 1e-4
# ── TD learning ──
discount: float = 0.99
tau: float = 0.005
num_critics: int = 2
# ── Policy constraint (paper Eq. 5) ──
bc_reg_coeff: float = 0.1
ref_dropout: float = 0.5
# ── Offline RL-token training ──
vla_finetune_weight: float = 0.0
@classmethod
def from_policy_config(cls, policy_cfg) -> RLTAlgorithmConfig:
"""Build from an existing ``RLTConfig`` (cfg.policy)."""
return cls(
chunk_size=policy_cfg.chunk_size,
chunk_stride=policy_cfg.chunk_stride,
utd_ratio=policy_cfg.utd_ratio,
policy_update_freq=policy_cfg.policy_update_freq,
clip_grad_norm=policy_cfg.clip_grad_norm,
actor_lr=policy_cfg.actor_lr,
critic_lr=policy_cfg.critic_lr,
rl_token_lr=policy_cfg.rl_token_lr,
discount=policy_cfg.discount,
tau=policy_cfg.tau,
num_critics=policy_cfg.num_critics,
bc_reg_coeff=policy_cfg.bc_reg_coeff,
ref_dropout=policy_cfg.ref_dropout,
vla_finetune_weight=policy_cfg.vla_finetune_weight,
)
def build_algorithm(self, policy: torch.nn.Module) -> RLTAlgorithm:
from lerobot.rl.algorithms.rlt.rlt_algorithm import RLTAlgorithm
return RLTAlgorithm(policy=policy, config=self)
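A minimal usage sketch (illustrative, not part of this commit): the config is normally derived from the policy config and then used to build the algorithm; ``cfg`` and ``policy`` are assumed to already exist.

algo_cfg = RLTAlgorithmConfig.from_policy_config(cfg.policy)
algorithm = algo_cfg.build_algorithm(policy)  # wraps the RLTPolicy in an RLTAlgorithm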
@@ -0,0 +1,319 @@
# Copyright 2026 The HuggingFace Inc. team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""RLT (RL Token) algorithm.
Implements the two-stage training from "RL Token: Bootstrapping Online RL
with Vision-Language-Action Models" (Xu et al., Physical Intelligence, 2026).
Stage 1 (offline): Train RL-token encoder/decoder via reconstruction loss.
Stage 2 (online): Train actor-critic with chunked TD, BC regularization,
reference-action pass-through, and reference-action dropout.
"""
from __future__ import annotations
import copy
from collections.abc import Iterator
from typing import Any
import torch
import torch.nn as nn
import torch.nn.functional as F # noqa: N812
from torch import Tensor
from torch.optim import Optimizer
from lerobot.policies.rlt.modeling_rlt import MLP, RLTPolicy
from lerobot.policies.utils import get_device_from_parameters
from lerobot.rl.algorithms.base import (
BatchType,
RLAlgorithm,
TrainingStats,
)
from lerobot.rl.algorithms.rlt.configuration_rlt import RLTAlgorithmConfig
from lerobot.utils.constants import ACTION
class RLTCritic(nn.Module):
"""Q-function over (state, action_chunk) pairs.
Paper Eq. 3: Q_psi(x, a_{1:C})
Training-only component — lives on the algorithm side, not in the policy.
"""
def __init__(self, state_dim: int, action_chunk_dim: int, hidden_dims: list[int]):
super().__init__()
self.net = MLP(state_dim + action_chunk_dim, hidden_dims, output_dim=1)
def forward(self, state: Tensor, action_chunk: Tensor) -> Tensor:
x = torch.cat([state, action_chunk], dim=-1)
return self.net(x)
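# Illustrative sizing (not from this commit): with a 64-dim RL state and a flattened
# chunk of 10 seven-dimensional actions, the critic MLP would map a 134-dim input to a
# single Q-value; the actual dims are read from the RLTPolicy at construction time.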
class RLTAlgorithm(RLAlgorithm):
"""RL Token: lightweight actor-critic on frozen VLA features.
Owns the ``RLTPolicy`` (RL-token encoder/decoder + actor), a critic
ensemble, and target networks. All VLA-specific logic (embedding
extraction, reference actions) lives in ``_prepare_forward_batch``.
"""
def __init__(self, policy: RLTPolicy, config: RLTAlgorithmConfig):
self.policy = policy
self.config = config
self.optimizers: dict[str, Optimizer] = {}
self._optimization_step: int = 0
self._device = get_device_from_parameters(self.policy)
self._is_online = False
self._init_critics()
self._move_to_device()
# ── Initialization ───────────────────────────────────────────────
def _init_critics(self) -> None:
state_dim = self.policy._state_dim
action_chunk_dim = self.policy._action_chunk_dim
hidden_dims = self.policy.config.critic.hidden_dims
self.critics = torch.nn.ModuleList(
[RLTCritic(state_dim, action_chunk_dim, hidden_dims) for _ in range(self.config.num_critics)]
)
self.critic_targets = torch.nn.ModuleList([copy.deepcopy(c) for c in self.critics])
for ct in self.critic_targets:
ct.requires_grad_(False)
def _move_to_device(self) -> None:
self.critics.to(self._device)
self.critic_targets.to(self._device)
# ── Offline phase (Stage 1): RL-token training ───────────────────
def supports_offline_phase(self) -> bool:
return True
def offline_update(self, batch_iterator: Iterator[BatchType]) -> TrainingStats:
"""Train RL-token encoder/decoder on demonstration data.
Paper Eq. 2: L_ro = E[ sum_i || h(d([z_rl, z_bar_{1:i-1}]))_i - z_bar_i ||^2 ]
"""
batch = next(batch_iterator)
vla_embeddings = batch["state"]["observation.vla_embeddings"].to(self._device)
z_vla = vla_embeddings.detach() # stop-gradient on VLA embeddings
z_rl = self.policy.rl_token_encoder(z_vla)
z_reconstructed = self.policy.rl_token_decoder(z_rl, z_vla)
loss_ro = F.mse_loss(z_reconstructed, z_vla)
self.optimizers["rl_token"].zero_grad()
loss_ro.backward()
torch.nn.utils.clip_grad_norm_(
list(self.policy.rl_token_encoder.parameters()) + list(self.policy.rl_token_decoder.parameters()),
max_norm=self.config.clip_grad_norm,
)
self.optimizers["rl_token"].step()
self._optimization_step += 1
return TrainingStats(losses={"loss_rl_token": loss_ro.item()})
def transition_to_online(self) -> None:
"""Freeze RL-token modules; rebuild optimizers for actor-critic only."""
self.policy.rl_token_encoder.requires_grad_(False)
self.policy.rl_token_decoder.requires_grad_(False)
self._is_online = True
self.optimizers = {
"actor": torch.optim.Adam(self.policy.actor.parameters(), lr=self.config.actor_lr),
"critic": torch.optim.Adam(self.critics.parameters(), lr=self.config.critic_lr),
}
self._optimization_step = 0
# ── Online phase (Stage 2): Actor-Critic ─────────────────────────
def update(self, batch_iterator: Iterator[BatchType]) -> TrainingStats:
"""One full RLT update step with UTD critic warm-up.
Pulls ``utd_ratio`` batches. First ``utd_ratio - 1`` are critic-only;
the last batch also updates the actor (every ``policy_update_freq`` steps).
"""
for _ in range(self.config.utd_ratio - 1):
batch = next(batch_iterator)
fb = self._prepare_forward_batch(batch)
self._critic_step(fb)
self._update_target_networks()
batch = next(batch_iterator)
fb = self._prepare_forward_batch(batch)
critic_loss = self._critic_step(fb)
stats = TrainingStats(losses={"loss_critic": critic_loss})
if self._optimization_step % self.config.policy_update_freq == 0:
actor_loss, bc_loss, q_val = self._actor_step(fb)
stats.losses["loss_actor"] = actor_loss
stats.extra["bc_loss"] = bc_loss
stats.extra["q_value_mean"] = q_val
self._update_target_networks()
self._optimization_step += 1
return stats
def _prepare_forward_batch(self, batch: BatchType) -> dict[str, Any]:
"""Convert a replay batch into algorithm-ready tensors.
        Extracts the RL token from the VLA embeddings, builds the RL state, and reads
        the reference action from ``complementary_info``.
"""
obs = batch["state"]
next_obs = batch["next_state"]
device = self._device
vla_emb = obs["observation.vla_embeddings"].to(device)
next_vla_emb = next_obs["observation.vla_embeddings"].to(device)
with torch.no_grad():
z_rl = self.policy.rl_token_encoder(vla_emb)
z_rl_next = self.policy.rl_token_encoder(next_vla_emb)
parts = [z_rl]
next_parts = [z_rl_next]
if "observation.state" in obs and self.policy._proprioception_dim > 0:
prop = obs["observation.state"].to(device)
next_prop = next_obs["observation.state"].to(device)
parts.append(prop)
next_parts.append(next_prop)
state = torch.cat(parts, dim=-1)
next_state = torch.cat(next_parts, dim=-1)
action = batch[ACTION].to(device)
reward = batch["reward"].to(device)
done = batch["done"].to(device)
ref_action = None
comp_info = batch.get("complementary_info")
if comp_info is not None and "reference_action" in comp_info:
ref_action = comp_info["reference_action"].to(device)
return {
"state": state,
"next_state": next_state,
"action": action,
"reward": reward,
"done": done,
"reference_action": ref_action,
}
def _critic_step(self, fb: dict[str, Any]) -> float:
"""Paper Eq. 3: chunked TD with clipped double-Q target."""
state = fb["state"]
next_state = fb["next_state"]
action = fb["action"]
reward = fb["reward"]
done = fb["done"]
with torch.no_grad():
ref = fb.get("reference_action")
if ref is None:
ref = torch.zeros_like(action)
next_action = self.policy.actor(next_state, ref)
target_qs = [ct(next_state, next_action) for ct in self.critic_targets]
min_target_q = torch.min(torch.cat(target_qs, dim=-1), dim=-1, keepdim=True).values
discount_chunk = self.config.discount**self.config.chunk_size
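            # With the defaults (discount=0.99, chunk_size=10) this is ~0.904: the target
            # bootstraps one full chunk, i.e. C environment steps, ahead (paper Eq. 3).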
td_target = reward.unsqueeze(-1) + (1 - done.unsqueeze(-1)) * discount_chunk * min_target_q
q_preds = [c(state, action) for c in self.critics]
loss = sum(F.mse_loss(q, td_target) for q in q_preds)
self.optimizers["critic"].zero_grad()
loss.backward()
torch.nn.utils.clip_grad_norm_(self.critics.parameters(), max_norm=self.config.clip_grad_norm)
self.optimizers["critic"].step()
return loss.item()
def _actor_step(self, fb: dict[str, Any]) -> tuple[float, float, float]:
"""Paper Eq. 5: maximize Q while staying near VLA reference.
L_pi(theta) = E[ -Q(x, a) + beta * ||a - a_tilde||^2 ]
With reference-action dropout applied to the actor's ref input.
"""
state = fb["state"]
ref = fb.get("reference_action")
if ref is None:
ref = torch.zeros(state.shape[0], self.policy._action_chunk_dim, device=self._device)
# Reference-action dropout (paper Section IV-B)
mask = (torch.rand(ref.shape[0], 1, device=self._device) > self.config.ref_dropout).float()
ref_input = ref * mask
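        # With the default ref_dropout=0.5, roughly half of the samples see a zeroed
        # reference, so the actor learns not to depend on it being available.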
action = self.policy.actor(state, ref_input)
q_value = self.critics[0](state, action)
bc_loss = F.mse_loss(action, ref)
loss = -q_value.mean() + self.config.bc_reg_coeff * bc_loss
self.optimizers["actor"].zero_grad()
loss.backward()
torch.nn.utils.clip_grad_norm_(self.policy.actor.parameters(), max_norm=self.config.clip_grad_norm)
self.optimizers["actor"].step()
return loss.item(), bc_loss.item(), q_value.mean().item()
def _update_target_networks(self) -> None:
tau = self.config.tau
for critic, target in zip(self.critics, self.critic_targets, strict=True):
for p, tp in zip(critic.parameters(), target.parameters(), strict=True):
tp.data.copy_(tau * p.data + (1 - tau) * tp.data)
# ── Optimizer management ─────────────────────────────────────────
def make_optimizers(self) -> dict[str, Optimizer]:
"""Create optimizers. Initially for RL-token (Stage 1)."""
self.optimizers = {
"rl_token": torch.optim.Adam(
list(self.policy.rl_token_encoder.parameters())
+ list(self.policy.rl_token_decoder.parameters()),
lr=self.config.rl_token_lr,
),
"actor": torch.optim.Adam(self.policy.actor.parameters(), lr=self.config.actor_lr),
"critic": torch.optim.Adam(self.critics.parameters(), lr=self.config.critic_lr),
}
return self.optimizers
def get_optimizers(self) -> dict[str, Optimizer]:
return self.optimizers
# ── Weight sync ──────────────────────────────────────────────────
def get_weights(self) -> dict[str, Any]:
"""Push actor + RL-token encoder to actors (small footprint)."""
weights = {
"actor": self.policy.actor.state_dict(),
"rl_token_encoder": self.policy.rl_token_encoder.state_dict(),
}
return {k: {kk: vv.cpu() for kk, vv in v.items()} for k, v in weights.items()}
def load_weights(self, weights: dict[str, Any], device: str | torch.device = "cpu") -> None:
if "actor" in weights:
self.policy.actor.load_state_dict({k: v.to(device) for k, v in weights["actor"].items()})
if "rl_token_encoder" in weights:
self.policy.rl_token_encoder.load_state_dict(
{k: v.to(device) for k, v in weights["rl_token_encoder"].items()}
)
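A hedged sketch of the two-stage training loop described in the module docstring, assuming replay-style batch iterators; ``offline_batches``, ``online_batches``, and the step counts are illustrative names, not part of this commit.

algorithm = RLTAlgorithmConfig.from_policy_config(cfg.policy).build_algorithm(policy)
algorithm.make_optimizers()  # Stage-1 optimizers: rl_token, actor, critic

# Stage 1 (offline): RL-token reconstruction on demonstration data (paper Eq. 2)
for _ in range(offline_steps):
    stats = algorithm.offline_update(offline_batches)

# Freeze the RL-token modules and rebuild optimizers for actor/critic only
algorithm.transition_to_online()

# Stage 2 (online): chunked TD actor-critic; each call consumes utd_ratio batches
for _ in range(online_steps):
    stats = algorithm.update(online_batches)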