From d9371b9a34881d24a0f213cd86534df7ca7d71e9 Mon Sep 17 00:00:00 2001 From: Khalil Meftah Date: Sun, 22 Mar 2026 22:59:35 +0100 Subject: [PATCH] feat: add RLT algorithm --- src/lerobot/rl/algorithms/__init__.py | 3 + src/lerobot/rl/algorithms/rlt/__init__.py | 18 + .../rl/algorithms/rlt/configuration_rlt.py | 83 +++++ .../rl/algorithms/rlt/rlt_algorithm.py | 319 ++++++++++++++++++ 4 files changed, 423 insertions(+) create mode 100644 src/lerobot/rl/algorithms/rlt/__init__.py create mode 100644 src/lerobot/rl/algorithms/rlt/configuration_rlt.py create mode 100644 src/lerobot/rl/algorithms/rlt/rlt_algorithm.py diff --git a/src/lerobot/rl/algorithms/__init__.py b/src/lerobot/rl/algorithms/__init__.py index 271d983c4..895d619d3 100644 --- a/src/lerobot/rl/algorithms/__init__.py +++ b/src/lerobot/rl/algorithms/__init__.py @@ -21,6 +21,7 @@ from lerobot.rl.algorithms.base import ( RLAlgorithmConfig, TrainingStats, ) +from lerobot.rl.algorithms.rlt import RLTAlgorithm, RLTAlgorithmConfig from lerobot.rl.algorithms.sac import SACAlgorithm, SACAlgorithmConfig @@ -63,5 +64,7 @@ __all__ = [ "TrainingStats", "SACAlgorithm", "SACAlgorithmConfig", + "RLTAlgorithm", + "RLTAlgorithmConfig", "make_algorithm", ] diff --git a/src/lerobot/rl/algorithms/rlt/__init__.py b/src/lerobot/rl/algorithms/rlt/__init__.py new file mode 100644 index 000000000..95196f468 --- /dev/null +++ b/src/lerobot/rl/algorithms/rlt/__init__.py @@ -0,0 +1,18 @@ +# Copyright 2026 The HuggingFace Inc. team. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +from lerobot.rl.algorithms.rlt.configuration_rlt import RLTAlgorithmConfig +from lerobot.rl.algorithms.rlt.rlt_algorithm import RLTAlgorithm + +__all__ = ["RLTAlgorithm", "RLTAlgorithmConfig"] diff --git a/src/lerobot/rl/algorithms/rlt/configuration_rlt.py b/src/lerobot/rl/algorithms/rlt/configuration_rlt.py new file mode 100644 index 000000000..80a7734c4 --- /dev/null +++ b/src/lerobot/rl/algorithms/rlt/configuration_rlt.py @@ -0,0 +1,83 @@ +# Copyright 2026 The HuggingFace Inc. team. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+"""RLT algorithm configuration.""" + +from __future__ import annotations + +from dataclasses import dataclass +from typing import TYPE_CHECKING + +import torch + +from lerobot.rl.algorithms.base import RLAlgorithmConfig + +if TYPE_CHECKING: + from lerobot.rl.algorithms.rlt.rlt_algorithm import RLTAlgorithm + + +@RLAlgorithmConfig.register_subclass("rlt") +@dataclass +class RLTAlgorithmConfig(RLAlgorithmConfig): + """RLT-specific hyper-parameters that control the update loop.""" + + # ── Action chunks ── + chunk_size: int = 10 + chunk_stride: int = 2 + + # ── Update cadence ── + utd_ratio: int = 5 + policy_update_freq: int = 2 + clip_grad_norm: float = 10.0 + + # ── Learning rates ── + actor_lr: float = 3e-4 + critic_lr: float = 3e-4 + rl_token_lr: float = 1e-4 + + # ── TD learning ── + discount: float = 0.99 + tau: float = 0.005 + num_critics: int = 2 + + # ── Policy constraint (paper Eq. 5) ── + bc_reg_coeff: float = 0.1 + ref_dropout: float = 0.5 + + # ── Offline RL-token training ── + vla_finetune_weight: float = 0.0 + + @classmethod + def from_policy_config(cls, policy_cfg) -> RLTAlgorithmConfig: + """Build from an existing ``RLTConfig`` (cfg.policy).""" + return cls( + chunk_size=policy_cfg.chunk_size, + chunk_stride=policy_cfg.chunk_stride, + utd_ratio=policy_cfg.utd_ratio, + policy_update_freq=policy_cfg.policy_update_freq, + clip_grad_norm=policy_cfg.clip_grad_norm, + actor_lr=policy_cfg.actor_lr, + critic_lr=policy_cfg.critic_lr, + rl_token_lr=policy_cfg.rl_token_lr, + discount=policy_cfg.discount, + tau=policy_cfg.tau, + num_critics=policy_cfg.num_critics, + bc_reg_coeff=policy_cfg.bc_reg_coeff, + ref_dropout=policy_cfg.ref_dropout, + vla_finetune_weight=policy_cfg.vla_finetune_weight, + ) + + def build_algorithm(self, policy: torch.nn.Module) -> RLTAlgorithm: + from lerobot.rl.algorithms.rlt.rlt_algorithm import RLTAlgorithm + + return RLTAlgorithm(policy=policy, config=self) diff --git a/src/lerobot/rl/algorithms/rlt/rlt_algorithm.py b/src/lerobot/rl/algorithms/rlt/rlt_algorithm.py new file mode 100644 index 000000000..ebba3fe8a --- /dev/null +++ b/src/lerobot/rl/algorithms/rlt/rlt_algorithm.py @@ -0,0 +1,319 @@ +# Copyright 2026 The HuggingFace Inc. team. All rights reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +"""RLT (RL Token) algorithm. + +Implements the two-stage training from "RL Token: Bootstrapping Online RL +with Vision-Language-Action Models" (Xu et al., Physical Intelligence, 2026). + +Stage 1 (offline): Train RL-token encoder/decoder via reconstruction loss. +Stage 2 (online): Train actor-critic with chunked TD, BC regularization, + reference-action pass-through, and reference-action dropout. 
+""" + +from __future__ import annotations + +import copy +from collections.abc import Iterator +from typing import Any + +import torch +import torch.nn as nn +import torch.nn.functional as F # noqa: N812 +from torch import Tensor +from torch.optim import Optimizer + +from lerobot.policies.rlt.modeling_rlt import MLP, RLTPolicy +from lerobot.policies.utils import get_device_from_parameters +from lerobot.rl.algorithms.base import ( + BatchType, + RLAlgorithm, + TrainingStats, +) +from lerobot.rl.algorithms.rlt.configuration_rlt import RLTAlgorithmConfig +from lerobot.utils.constants import ACTION + + +class RLTCritic(nn.Module): + """Q-function over (state, action_chunk) pairs. + + Paper Eq. 3: Q_psi(x, a_{1:C}) + + Training-only component — lives on the algorithm side, not in the policy. + """ + + def __init__(self, state_dim: int, action_chunk_dim: int, hidden_dims: list[int]): + super().__init__() + self.net = MLP(state_dim + action_chunk_dim, hidden_dims, output_dim=1) + + def forward(self, state: Tensor, action_chunk: Tensor) -> Tensor: + x = torch.cat([state, action_chunk], dim=-1) + return self.net(x) + + +class RLTAlgorithm(RLAlgorithm): + """RL Token: lightweight actor-critic on frozen VLA features. + + Owns the ``RLTPolicy`` (RL-token encoder/decoder + actor), a critic + ensemble, and target networks. All VLA-specific logic (embedding + extraction, reference actions) lives in ``_prepare_forward_batch``. + """ + + def __init__(self, policy: RLTPolicy, config: RLTAlgorithmConfig): + self.policy = policy + self.config = config + self.optimizers: dict[str, Optimizer] = {} + self._optimization_step: int = 0 + self._device = get_device_from_parameters(self.policy) + self._is_online = False + + self._init_critics() + self._move_to_device() + + # ── Initialization ─────────────────────────────────────────────── + + def _init_critics(self) -> None: + state_dim = self.policy._state_dim + action_chunk_dim = self.policy._action_chunk_dim + hidden_dims = self.policy.config.critic.hidden_dims + + self.critics = torch.nn.ModuleList( + [RLTCritic(state_dim, action_chunk_dim, hidden_dims) for _ in range(self.config.num_critics)] + ) + self.critic_targets = torch.nn.ModuleList([copy.deepcopy(c) for c in self.critics]) + for ct in self.critic_targets: + ct.requires_grad_(False) + + def _move_to_device(self) -> None: + self.critics.to(self._device) + self.critic_targets.to(self._device) + + # ── Offline phase (Stage 1): RL-token training ─────────────────── + + def supports_offline_phase(self) -> bool: + return True + + def offline_update(self, batch_iterator: Iterator[BatchType]) -> TrainingStats: + """Train RL-token encoder/decoder on demonstration data. + + Paper Eq. 
2: L_ro = E[ sum_i || h(d([z_rl, z_bar_{1:i-1}]))_i - z_bar_i ||^2 ] + """ + batch = next(batch_iterator) + + vla_embeddings = batch["state"]["observation.vla_embeddings"].to(self._device) + z_vla = vla_embeddings.detach() # stop-gradient on VLA embeddings + + z_rl = self.policy.rl_token_encoder(z_vla) + z_reconstructed = self.policy.rl_token_decoder(z_rl, z_vla) + + loss_ro = F.mse_loss(z_reconstructed, z_vla) + + self.optimizers["rl_token"].zero_grad() + loss_ro.backward() + torch.nn.utils.clip_grad_norm_( + list(self.policy.rl_token_encoder.parameters()) + list(self.policy.rl_token_decoder.parameters()), + max_norm=self.config.clip_grad_norm, + ) + self.optimizers["rl_token"].step() + + self._optimization_step += 1 + return TrainingStats(losses={"loss_rl_token": loss_ro.item()}) + + def transition_to_online(self) -> None: + """Freeze RL-token modules; rebuild optimizers for actor-critic only.""" + self.policy.rl_token_encoder.requires_grad_(False) + self.policy.rl_token_decoder.requires_grad_(False) + self._is_online = True + + self.optimizers = { + "actor": torch.optim.Adam(self.policy.actor.parameters(), lr=self.config.actor_lr), + "critic": torch.optim.Adam(self.critics.parameters(), lr=self.config.critic_lr), + } + self._optimization_step = 0 + + # ── Online phase (Stage 2): Actor-Critic ───────────────────────── + + def update(self, batch_iterator: Iterator[BatchType]) -> TrainingStats: + """One full RLT update step with UTD critic warm-up. + + Pulls ``utd_ratio`` batches. First ``utd_ratio - 1`` are critic-only; + the last batch also updates the actor (every ``policy_update_freq`` steps). + """ + for _ in range(self.config.utd_ratio - 1): + batch = next(batch_iterator) + fb = self._prepare_forward_batch(batch) + self._critic_step(fb) + self._update_target_networks() + + batch = next(batch_iterator) + fb = self._prepare_forward_batch(batch) + critic_loss = self._critic_step(fb) + + stats = TrainingStats(losses={"loss_critic": critic_loss}) + + if self._optimization_step % self.config.policy_update_freq == 0: + actor_loss, bc_loss, q_val = self._actor_step(fb) + stats.losses["loss_actor"] = actor_loss + stats.extra["bc_loss"] = bc_loss + stats.extra["q_value_mean"] = q_val + + self._update_target_networks() + self._optimization_step += 1 + return stats + + def _prepare_forward_batch(self, batch: BatchType) -> dict[str, Any]: + """Convert a replay batch into algorithm-ready tensors. + + Extracts RL-token from VLA embeddings, builds RL state, reads + reference action from complementary_info. 
+ """ + obs = batch["state"] + next_obs = batch["next_state"] + device = self._device + + vla_emb = obs["observation.vla_embeddings"].to(device) + next_vla_emb = next_obs["observation.vla_embeddings"].to(device) + + with torch.no_grad(): + z_rl = self.policy.rl_token_encoder(vla_emb) + z_rl_next = self.policy.rl_token_encoder(next_vla_emb) + + parts = [z_rl] + next_parts = [z_rl_next] + if "observation.state" in obs and self.policy._proprioception_dim > 0: + prop = obs["observation.state"].to(device) + next_prop = next_obs["observation.state"].to(device) + parts.append(prop) + next_parts.append(next_prop) + + state = torch.cat(parts, dim=-1) + next_state = torch.cat(next_parts, dim=-1) + + action = batch[ACTION].to(device) + reward = batch["reward"].to(device) + done = batch["done"].to(device) + + ref_action = None + comp_info = batch.get("complementary_info") + if comp_info is not None and "reference_action" in comp_info: + ref_action = comp_info["reference_action"].to(device) + + return { + "state": state, + "next_state": next_state, + "action": action, + "reward": reward, + "done": done, + "reference_action": ref_action, + } + + def _critic_step(self, fb: dict[str, Any]) -> float: + """Paper Eq. 3: chunked TD with clipped double-Q target.""" + state = fb["state"] + next_state = fb["next_state"] + action = fb["action"] + reward = fb["reward"] + done = fb["done"] + + with torch.no_grad(): + ref = fb.get("reference_action") + if ref is None: + ref = torch.zeros_like(action) + next_action = self.policy.actor(next_state, ref) + + target_qs = [ct(next_state, next_action) for ct in self.critic_targets] + min_target_q = torch.min(torch.cat(target_qs, dim=-1), dim=-1, keepdim=True).values + + discount_chunk = self.config.discount**self.config.chunk_size + td_target = reward.unsqueeze(-1) + (1 - done.unsqueeze(-1)) * discount_chunk * min_target_q + + q_preds = [c(state, action) for c in self.critics] + loss = sum(F.mse_loss(q, td_target) for q in q_preds) + + self.optimizers["critic"].zero_grad() + loss.backward() + torch.nn.utils.clip_grad_norm_(self.critics.parameters(), max_norm=self.config.clip_grad_norm) + self.optimizers["critic"].step() + return loss.item() + + def _actor_step(self, fb: dict[str, Any]) -> tuple[float, float, float]: + """Paper Eq. 5: maximize Q while staying near VLA reference. + + L_pi(theta) = E[ -Q(x, a) + beta * ||a - a_tilde||^2 ] + With reference-action dropout applied to the actor's ref input. 
+ """ + state = fb["state"] + ref = fb.get("reference_action") + if ref is None: + ref = torch.zeros(state.shape[0], self.policy._action_chunk_dim, device=self._device) + + # Reference-action dropout (paper Section IV-B) + mask = (torch.rand(ref.shape[0], 1, device=self._device) > self.config.ref_dropout).float() + ref_input = ref * mask + + action = self.policy.actor(state, ref_input) + + q_value = self.critics[0](state, action) + + bc_loss = F.mse_loss(action, ref) + + loss = -q_value.mean() + self.config.bc_reg_coeff * bc_loss + + self.optimizers["actor"].zero_grad() + loss.backward() + torch.nn.utils.clip_grad_norm_(self.policy.actor.parameters(), max_norm=self.config.clip_grad_norm) + self.optimizers["actor"].step() + + return loss.item(), bc_loss.item(), q_value.mean().item() + + def _update_target_networks(self) -> None: + tau = self.config.tau + for critic, target in zip(self.critics, self.critic_targets, strict=True): + for p, tp in zip(critic.parameters(), target.parameters(), strict=True): + tp.data.copy_(tau * p.data + (1 - tau) * tp.data) + + # ── Optimizer management ───────────────────────────────────────── + + def make_optimizers(self) -> dict[str, Optimizer]: + """Create optimizers. Initially for RL-token (Stage 1).""" + self.optimizers = { + "rl_token": torch.optim.Adam( + list(self.policy.rl_token_encoder.parameters()) + + list(self.policy.rl_token_decoder.parameters()), + lr=self.config.rl_token_lr, + ), + "actor": torch.optim.Adam(self.policy.actor.parameters(), lr=self.config.actor_lr), + "critic": torch.optim.Adam(self.critics.parameters(), lr=self.config.critic_lr), + } + return self.optimizers + + def get_optimizers(self) -> dict[str, Optimizer]: + return self.optimizers + + # ── Weight sync ────────────────────────────────────────────────── + + def get_weights(self) -> dict[str, Any]: + """Push actor + RL-token encoder to actors (small footprint).""" + weights = { + "actor": self.policy.actor.state_dict(), + "rl_token_encoder": self.policy.rl_token_encoder.state_dict(), + } + return {k: {kk: vv.cpu() for kk, vv in v.items()} for k, v in weights.items()} + + def load_weights(self, weights: dict[str, Any], device: str | torch.device = "cpu") -> None: + if "actor" in weights: + self.policy.actor.load_state_dict({k: v.to(device) for k, v in weights["actor"].items()}) + if "rl_token_encoder" in weights: + self.policy.rl_token_encoder.load_state_dict( + {k: v.to(device) for k, v in weights["rl_token_encoder"].items()} + )
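
Usage sketch (not part of the patch): the snippet below shows the two-stage call sequence this change exposes. The policy object, step counts, and batch iterators (`policy`, `policy_cfg`, `num_offline_steps`, `num_online_steps`, `offline_batches`, `online_batches`) are illustrative placeholders supplied by the surrounding training loop, not names introduced by this patch.

    from lerobot.rl.algorithms import RLTAlgorithmConfig

    # Assumed for illustration: `policy_cfg` is the RLT policy config (cfg.policy)
    # and `policy` is an already-constructed RLTPolicy instance.
    algo_cfg = RLTAlgorithmConfig.from_policy_config(policy_cfg)
    algo = algo_cfg.build_algorithm(policy)
    algo.make_optimizers()

    # Stage 1 (offline): RL-token reconstruction on demonstration batches.
    for _ in range(num_offline_steps):
        stats = algo.offline_update(offline_batches)

    # Stage 2 (online): freezes the RL-token modules and rebuilds the actor/critic
    # optimizers, then runs chunked-TD actor-critic updates; each update() call
    # pulls `utd_ratio` batches from the iterator.
    algo.transition_to_online()
    for _ in range(num_online_steps):
        stats = algo.update(online_batches)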