mirror of
https://github.com/huggingface/lerobot.git
synced 2026-06-14 06:49:55 +00:00
Compare commits
8 Commits
| Author | SHA1 | Date | |
|---|---|---|---|
| 099d2cb34e | |||
| b9c4dd3d12 | |||
| aa0d3e6608 | |||
| 350d01b74d | |||
| 41166b39fb | |||
| 79c6821407 | |||
| 507083249f | |||
| bd22407d93 |
@@ -57,11 +57,11 @@ The `lerobot-rollout --strategy.type=dagger` mode requires **teleoperators with
|
||||
|
||||
**Compatible teleoperators:**
|
||||
|
||||
- `openarm_mini` - OpenArm Mini
|
||||
- `bi_openarm_mini` - Bimanual OpenArm Mini
|
||||
- `so_leader` - SO100 / SO101 leader arm
|
||||
|
||||
> [!IMPORTANT]
|
||||
> The provided commands default to `bi_openarm_follower` + `openarm_mini`.
|
||||
> The provided commands default to `bi_openarm_follower` + `bi_openarm_mini`.
|
||||
> `so_follower` + `so_leader` configs are also registered and can be used via CLI flags.
|
||||
|
||||
---
|
||||
@@ -104,9 +104,9 @@ lerobot-rollout --strategy.type=dagger \
|
||||
--robot.right_arm_config.port=can0 \
|
||||
--robot.right_arm_config.side=right \
|
||||
--robot.cameras='{left_wrist: {type: opencv, index_or_path: "/dev/video0", width: 1280, height: 720, fps: 30}, right_wrist: {type: opencv, index_or_path: "/dev/video4", width: 1280, height: 720, fps: 30}, base: {type: opencv, index_or_path: "/dev/video2", width: 640, height: 480, fps: 30}}' \
|
||||
--teleop.type=openarm_mini \
|
||||
--teleop.port_left=/dev/ttyACM0 \
|
||||
--teleop.port_right=/dev/ttyACM1 \
|
||||
--teleop.type=bi_openarm_mini \
|
||||
--teleop.left_arm_config.port=/dev/ttyACM0 \
|
||||
--teleop.right_arm_config.port=/dev/ttyACM1 \
|
||||
--policy.path=outputs/pretrain/checkpoints/last/pretrained_model \
|
||||
--dataset.repo_id=your-username/rollout_hil_dataset \
|
||||
--dataset.single_task="Fold the T-shirt properly" \
|
||||
@@ -131,9 +131,9 @@ lerobot-rollout --strategy.type=dagger \
|
||||
--robot.right_arm_config.port=can0 \
|
||||
--robot.right_arm_config.side=right \
|
||||
--robot.cameras='{left_wrist: {type: opencv, index_or_path: "/dev/video0", width: 1280, height: 720, fps: 30}, right_wrist: {type: opencv, index_or_path: "/dev/video4", width: 1280, height: 720, fps: 30}, base: {type: opencv, index_or_path: "/dev/video2", width: 640, height: 480, fps: 30}}' \
|
||||
--teleop.type=openarm_mini \
|
||||
--teleop.port_left=/dev/ttyACM0 \
|
||||
--teleop.port_right=/dev/ttyACM1 \
|
||||
--teleop.type=bi_openarm_mini \
|
||||
--teleop.left_arm_config.port=/dev/ttyACM0 \
|
||||
--teleop.right_arm_config.port=/dev/ttyACM1 \
|
||||
--policy.path=outputs/pretrain/checkpoints/last/pretrained_model \
|
||||
--dataset.repo_id=your-username/rollout_hil_rtc_dataset \
|
||||
--dataset.single_task="Fold the T-shirt properly" \
|
||||
|
||||
@@ -117,7 +117,7 @@ lerobot-rollout \
|
||||
--strategy.num_episodes=20 \
|
||||
--policy.path=outputs/pretrain/checkpoints/last/pretrained_model \
|
||||
--robot.type=bi_openarm_follower \
|
||||
--teleop.type=openarm_mini \
|
||||
--teleop.type=bi_openarm_mini \
|
||||
--dataset.repo_id=${HF_USER}/rollout_hil_data \
|
||||
--dataset.single_task="Fold the T-shirt"
|
||||
```
|
||||
|
||||
+3
-5
@@ -214,10 +214,9 @@ groot = [
|
||||
sarm = ["lerobot[transformers-dep]", "pydantic>=2.0.0,<3.0.0", "faker>=33.0.0,<35.0.0", "lerobot[matplotlib-dep]", "lerobot[qwen-vl-utils-dep]"]
|
||||
robometer = ["lerobot[transformers-dep]", "lerobot[qwen-vl-utils-dep]", "lerobot[peft-dep]"]
|
||||
topreward = ["lerobot[transformers-dep]"]
|
||||
recap = ["lerobot[transformers-dep]"]
|
||||
xvla = ["lerobot[transformers-dep]"]
|
||||
eo1 = ["lerobot[transformers-dep]", "lerobot[qwen-vl-utils-dep]"]
|
||||
hilserl = ["lerobot[transformers-dep]", "lerobot[dataset]", "gym-hil>=0.1.13,<0.2.0", "lerobot[grpcio-dep]", "lerobot[placo-dep]"]
|
||||
hilserl = ["lerobot[transformers-dep]", "lerobot[dataset]", "gym-hil>=0.1.14,<0.2.0", "lerobot[grpcio-dep]", "lerobot[placo-dep]"]
|
||||
vla_jepa = ["lerobot[transformers-dep]", "lerobot[diffusers-dep]", "lerobot[qwen-vl-utils-dep]"]
|
||||
|
||||
# Features
|
||||
@@ -232,9 +231,9 @@ video_benchmark = ["scikit-image>=0.23.2,<0.26.0", "pandas>=2.2.2,<2.4.0"]
|
||||
|
||||
# Simulation
|
||||
# NOTE: Explicitly listing scipy helps flatten the dependecy tree.
|
||||
aloha = ["lerobot[dataset]", "gym-aloha>=0.1.2,<0.2.0", "lerobot[scipy-dep]"]
|
||||
aloha = ["lerobot[dataset]", "gym-aloha>=0.1.4,<0.2.0", "lerobot[scipy-dep]"]
|
||||
pusht = ["lerobot[dataset]", "gym-pusht>=0.1.5,<0.2.0", "pymunk>=6.6.0,<7.0.0"] # TODO: Fix pymunk version in gym-pusht instead
|
||||
libero = ["lerobot[dataset]", "lerobot[transformers-dep]", "hf-libero>=0.1.3,<0.2.0; sys_platform == 'linux'", "lerobot[scipy-dep]"]
|
||||
libero = ["lerobot[dataset]", "lerobot[transformers-dep]", "hf-libero>=0.1.4,<0.2.0; sys_platform == 'linux'", "lerobot[scipy-dep]"]
|
||||
metaworld = ["lerobot[dataset]", "metaworld==3.0.0", "lerobot[scipy-dep]"]
|
||||
# NOTE: vlabench is NOT exposed as a `lerobot` extra. Its only distribution
|
||||
# is the OpenMOSS/VLABench GitHub repo (package name `VLABench`, no PyPI
|
||||
@@ -297,7 +296,6 @@ all = [
|
||||
"lerobot[sarm]",
|
||||
"lerobot[robometer]",
|
||||
"lerobot[topreward]",
|
||||
"lerobot[recap]",
|
||||
"lerobot[peft]",
|
||||
# "lerobot[unitree_g1]", TODO: Unitree requires specific installation instructions for unitree_sdk2
|
||||
]
|
||||
|
||||
@@ -30,6 +30,7 @@ class EpisodeAwareSampler:
|
||||
drop_n_first_frames: int = 0,
|
||||
drop_n_last_frames: int = 0,
|
||||
shuffle: bool = False,
|
||||
generator: torch.Generator | None = None,
|
||||
):
|
||||
"""Sampler that optionally incorporates episode boundary information.
|
||||
|
||||
@@ -41,6 +42,10 @@ class EpisodeAwareSampler:
|
||||
drop_n_first_frames: Number of frames to drop from the start of each episode.
|
||||
drop_n_last_frames: Number of frames to drop from the end of each episode.
|
||||
shuffle: Whether to shuffle the indices.
|
||||
generator: Generator used for shuffling. Exposing this attribute (even when None) lets
|
||||
`accelerate` register it as the synchronized RNG in distributed training, so
|
||||
every rank draws the same permutation and batch shards stay disjoint. When
|
||||
None, shuffling falls back to the global torch RNG.
|
||||
"""
|
||||
if drop_n_first_frames < 0:
|
||||
raise ValueError(f"drop_n_first_frames must be >= 0, got {drop_n_first_frames}")
|
||||
@@ -73,10 +78,11 @@ class EpisodeAwareSampler:
|
||||
|
||||
self.indices = indices
|
||||
self.shuffle = shuffle
|
||||
self.generator = generator
|
||||
|
||||
def __iter__(self) -> Iterator[int]:
|
||||
if self.shuffle:
|
||||
for i in torch.randperm(len(self.indices)):
|
||||
for i in torch.randperm(len(self.indices), generator=self.generator):
|
||||
yield self.indices[i]
|
||||
else:
|
||||
for i in self.indices:
|
||||
|
||||
@@ -13,9 +13,6 @@
|
||||
# limitations under the License.
|
||||
|
||||
from .classifier.configuration_classifier import RewardClassifierConfig as RewardClassifierConfig
|
||||
from .distributional_value_function.configuration_distributional_value_function import (
|
||||
DistributionalVFConfig as DistributionalVFConfig,
|
||||
)
|
||||
from .factory import (
|
||||
get_reward_model_class as get_reward_model_class,
|
||||
make_reward_model as make_reward_model,
|
||||
@@ -29,7 +26,6 @@ from .topreward.configuration_topreward import TOPRewardConfig as TOPRewardConfi
|
||||
|
||||
__all__ = [
|
||||
# Configuration classes
|
||||
"DistributionalVFConfig",
|
||||
"RewardClassifierConfig",
|
||||
"RobometerConfig",
|
||||
"SARMConfig",
|
||||
|
||||
-108
@@ -1,108 +0,0 @@
|
||||
# Copyright 2025 The HuggingFace Inc. team. All rights reserved.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
"""Configuration for RECAP's distributional value function.
|
||||
|
||||
Paper: "π*0.6: a VLA That Learns From Experience" (Physical Intelligence, 2025)
|
||||
https://pi.website/blog/pistar06
|
||||
|
||||
Implements the distributional value function V^{pi_ref}(o_t, l) from Section IV-A.
|
||||
Architecture: the paper uses a 670M-parameter Gemma 3 VLM (the actor is 4B Gemma 3).
|
||||
We match that scale on PaliGemma (PI05's Gemma 2B backbone) by truncating to 6 Gemma
|
||||
LM layers and 13 SigLIP vision layers (~670M params), with a [CLS] token and linear
|
||||
head predicting a categorical distribution over B=201 discrete value bins in [-1, 0].
|
||||
|
||||
Training: cross-entropy on HL-Gauss soft targets (or Dirac delta projection),
|
||||
with optional one-hot targets for terminal states; MC returns normalized per task.
|
||||
Weights initialized from a pre-trained PI05 actor checkpoint.
|
||||
"""
|
||||
|
||||
from dataclasses import dataclass, field
|
||||
|
||||
from lerobot.configs import FeatureType, NormalizationMode
|
||||
from lerobot.configs.rewards import RewardModelConfig
|
||||
from lerobot.optim import AdamWConfig, CosineDecayWithWarmupSchedulerConfig
|
||||
|
||||
|
||||
@RewardModelConfig.register_subclass("distributional_value_function")
|
||||
@dataclass
|
||||
class DistributionalVFConfig(RewardModelConfig):
|
||||
"""Configuration for RECAP's distributional value function.
|
||||
|
||||
The value function predicts V^{pi_ref}(o_t, l) as a distribution over B discrete
|
||||
bins spanning [value_support_min, value_support_max]. It is trained with cross-entropy
|
||||
on HL-Gauss soft targets or Dirac delta projection, derived from Monte Carlo returns
|
||||
(Eq. 1 in the paper).
|
||||
|
||||
Architecture: the paper value function is a 670M Gemma 3 VLM; the actor is 4B Gemma 3.
|
||||
We use truncated PaliGemma (``num_hidden_layers=6``, ``num_vision_layers=13``) to reach
|
||||
about 670M params and initialize from the PI05 actor checkpoint.
|
||||
"""
|
||||
|
||||
# Backbone
|
||||
paligemma_variant: str = "gemma_2b"
|
||||
num_hidden_layers: int = 6
|
||||
num_vision_layers: int = 13
|
||||
|
||||
# Distributional head
|
||||
num_value_bins: int = 201
|
||||
value_support_min: float = -1.0
|
||||
value_support_max: float = 0.0
|
||||
hl_gauss_sigma_ratio: float = 5.0
|
||||
|
||||
# Target distribution method: "hl_gauss" (default, soft) or "dirac_delta" (C51, hard)
|
||||
target_method: str = "hl_gauss"
|
||||
|
||||
# Whether to use one-hot targets for terminal states (exact return, no smoothing).
|
||||
# When False, terminal states use the same target method as non-terminal states.
|
||||
use_one_hot_terminal: bool = True
|
||||
|
||||
# Image
|
||||
image_resolution: tuple[int, int] = (224, 224)
|
||||
|
||||
# Tokenizer
|
||||
tokenizer_max_length: int = 64
|
||||
|
||||
# Init from actor (required for first training: provides SigLIP vision tower + Gemma embeddings).
|
||||
# Pass a PI05 checkpoint path or Hub repo_id here.
|
||||
# After training, load the value function with RewardModel.from_pretrained() instead.
|
||||
init_from_actor_path: str = ""
|
||||
|
||||
# Normalization
|
||||
normalization_mapping: dict[str, NormalizationMode] = field(
|
||||
default_factory=lambda: {
|
||||
"VISUAL": NormalizationMode.IDENTITY,
|
||||
"STATE": NormalizationMode.IDENTITY,
|
||||
}
|
||||
)
|
||||
|
||||
def get_optimizer_preset(self) -> AdamWConfig:
|
||||
return AdamWConfig(
|
||||
lr=3e-4,
|
||||
weight_decay=1e-4,
|
||||
grad_clip_norm=1.0,
|
||||
)
|
||||
|
||||
def get_scheduler_preset(self) -> CosineDecayWithWarmupSchedulerConfig:
|
||||
return CosineDecayWithWarmupSchedulerConfig(
|
||||
num_warmup_steps=500,
|
||||
num_decay_steps=50000,
|
||||
)
|
||||
|
||||
def validate_features(self) -> None:
|
||||
if not self.input_features:
|
||||
return
|
||||
has_image = any(ft.type == FeatureType.VISUAL for ft in self.input_features.values())
|
||||
if not has_image:
|
||||
raise ValueError("DistributionalVFConfig requires at least one VISUAL input feature.")
|
||||
-567
@@ -1,567 +0,0 @@
|
||||
# Copyright 2025 The HuggingFace Inc. team. All rights reserved.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
"""Modeling for RECAP's distributional value function.
|
||||
|
||||
Paper: "π*0.6: a VLA That Learns From Experience" (Physical Intelligence, 2025)
|
||||
https://pi.website/blog/pistar06
|
||||
|
||||
Implements the distributional value function V^{pi_ref}(o_t, l) from Section IV-A.
|
||||
Architecture: the paper uses a 670M-parameter Gemma 3 VLM (the actor is 4B Gemma 3).
|
||||
We match that scale on PaliGemma (PI05's Gemma 2B backbone) by truncating to 6 Gemma
|
||||
LM layers and 13 SigLIP vision layers (~670M params), with a [CLS] token and linear
|
||||
head predicting a categorical distribution over B=201 discrete value bins in [-1, 0].
|
||||
|
||||
Inputs: single image observation + task text prompt ("Task: {task}.")
|
||||
Outputs: softmax distribution over value bins; expected value E[V] for inference.
|
||||
Training: cross-entropy on HL-Gauss soft targets (or Dirac delta projection),
|
||||
with optional one-hot targets for terminal states; MC returns normalized per task.
|
||||
|
||||
Weight initialization: vision tower, multi-modal projector, token embeddings, and
|
||||
the first N transformer layers are copied from a pre-trained PI05 actor checkpoint.
|
||||
"""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import math
|
||||
from typing import TYPE_CHECKING, Any
|
||||
|
||||
import torch
|
||||
import torch.nn.functional as F # noqa: N812
|
||||
from torch import Tensor, nn
|
||||
|
||||
from lerobot.rewards.pretrained import PreTrainedRewardModel
|
||||
from lerobot.utils.import_utils import _transformers_available, require_package
|
||||
|
||||
from .configuration_distributional_value_function import DistributionalVFConfig
|
||||
|
||||
if TYPE_CHECKING or _transformers_available:
|
||||
from transformers.models.auto import CONFIG_MAPPING
|
||||
from transformers.models.gemma import modeling_gemma
|
||||
|
||||
from lerobot.policies.pi_gemma import (
|
||||
PaliGemmaForConditionalGenerationWithPiGemma,
|
||||
PiGemmaRMSNorm,
|
||||
_gated_residual,
|
||||
_get_pi_gemma_decoder_layer_base,
|
||||
)
|
||||
else:
|
||||
CONFIG_MAPPING = None
|
||||
modeling_gemma = None
|
||||
PaliGemmaForConditionalGenerationWithPiGemma = None
|
||||
PiGemmaRMSNorm = None
|
||||
_gated_residual = None
|
||||
_get_pi_gemma_decoder_layer_base = None
|
||||
|
||||
PALIGEMMA_VOCAB_SIZE = 257152
|
||||
|
||||
|
||||
class DistributionalVFRewardModel(PreTrainedRewardModel):
|
||||
"""Distributional value function model for RECAP.
|
||||
|
||||
Predicts V^{pi_ref}(o_t, l) as a categorical distribution over B bins (default 201).
|
||||
Trained with cross-entropy on HL-Gauss or Dirac delta targets centered on
|
||||
per-task normalized Monte Carlo returns.
|
||||
|
||||
Architecture: truncated PaliGemma (``num_hidden_layers=6``, ``num_vision_layers=13``),
|
||||
causal attention, [CLS] token, and Linear(D, num_bins) value head.
|
||||
The expected value is E[V] = sum(softmax(logits) * bin_centers).
|
||||
"""
|
||||
|
||||
name = "distributional_value_function"
|
||||
config_class = DistributionalVFConfig
|
||||
|
||||
def __init__(self, config: DistributionalVFConfig, **kwargs) -> None:
|
||||
require_package("transformers", extra="recap")
|
||||
super().__init__(config)
|
||||
self.config = config
|
||||
|
||||
from transformers.models.gemma.modeling_gemma import GemmaRotaryEmbedding
|
||||
|
||||
from lerobot.policies.pi05.modeling_pi05 import get_gemma_config
|
||||
|
||||
# Get base dimensions from the paligemma variant (OpenPI config format)
|
||||
base_config = get_gemma_config(config.paligemma_variant)
|
||||
hidden_dim = base_config.width
|
||||
mlp_dim = base_config.mlp_dim
|
||||
num_layers = config.num_hidden_layers
|
||||
|
||||
# HuggingFace GemmaConfig for transformer layers
|
||||
gemma_config = CONFIG_MAPPING["gemma"](
|
||||
head_dim=base_config.head_dim,
|
||||
hidden_size=hidden_dim,
|
||||
intermediate_size=mlp_dim,
|
||||
num_attention_heads=base_config.num_heads,
|
||||
num_hidden_layers=num_layers,
|
||||
num_key_value_heads=base_config.num_kv_heads,
|
||||
vocab_size=PALIGEMMA_VOCAB_SIZE,
|
||||
hidden_activation="gelu_pytorch_tanh",
|
||||
)
|
||||
self.gemma_config = gemma_config
|
||||
self.hidden_dim = hidden_dim
|
||||
self.num_value_bins = config.num_value_bins
|
||||
|
||||
# Single learned [CLS] token for value prediction
|
||||
self.cls_embedding = nn.Parameter(torch.randn(1, 1, hidden_dim) * 0.02)
|
||||
|
||||
# Value projection head: Linear(hidden_dim, num_bins)
|
||||
self.value_head = nn.Linear(in_features=hidden_dim, out_features=config.num_value_bins)
|
||||
|
||||
# Transformer layers (overwritten by _initialize_from_actor on first run)
|
||||
self.rotary_emb = GemmaRotaryEmbedding(gemma_config)
|
||||
pi_gemma_decoder_layer_base = _get_pi_gemma_decoder_layer_base()
|
||||
self.layers = nn.ModuleList(
|
||||
[pi_gemma_decoder_layer_base(gemma_config, layer_idx=i) for i in range(num_layers)]
|
||||
)
|
||||
self.norm = PiGemmaRMSNorm(hidden_dim, eps=gemma_config.rms_norm_eps)
|
||||
|
||||
# Vision tower + projector + token embedding (overwritten by _initialize_from_actor on first run)
|
||||
# PaliGemmaConfig wraps both vision and text configs into a single model
|
||||
paligemma_config = CONFIG_MAPPING["paligemma"]()
|
||||
paligemma_config.text_config = gemma_config
|
||||
paligemma_config.vision_config.image_size = config.image_resolution[0]
|
||||
paligemma_config.vision_config.intermediate_size = 4304
|
||||
paligemma_config.vision_config.projection_dim = 2048
|
||||
paligemma_config.vision_config.projector_hidden_act = "gelu_fast"
|
||||
|
||||
paligemma_full = PaliGemmaForConditionalGenerationWithPiGemma(config=paligemma_config)
|
||||
self.vision_tower = paligemma_full.model.vision_tower
|
||||
self.multi_modal_projector = paligemma_full.model.multi_modal_projector
|
||||
self.token_embedding = paligemma_full.model.language_model.embed_tokens
|
||||
del paligemma_full
|
||||
|
||||
# Truncate vision tower to num_vision_layers
|
||||
if hasattr(self.vision_tower, "vision_model") and hasattr(self.vision_tower.vision_model, "encoder"):
|
||||
vision_encoder = self.vision_tower.vision_model.encoder
|
||||
vision_encoder.layers = vision_encoder.layers[: config.num_vision_layers]
|
||||
|
||||
# Bin support: evenly spaced centers from value_support_min to value_support_max
|
||||
bin_centers = torch.linspace(config.value_support_min, config.value_support_max, self.num_value_bins)
|
||||
self.register_buffer("bin_centers", bin_centers, persistent=False)
|
||||
bin_width = (config.value_support_max - config.value_support_min) / (self.num_value_bins - 1)
|
||||
self.hl_gauss_sigma = float(config.hl_gauss_sigma_ratio * bin_width)
|
||||
|
||||
# Overwrite with pre-trained PI05 actor weights (first training run only)
|
||||
if config.init_from_actor_path:
|
||||
self._initialize_from_actor()
|
||||
|
||||
def _initialize_from_actor(self) -> None:
|
||||
"""Overwrite weights from a pre-trained PI05 actor checkpoint.
|
||||
|
||||
Called on first training run only (when init_from_actor_path is set).
|
||||
"""
|
||||
from lerobot.policies.pi05.modeling_pi05 import PI05Policy
|
||||
|
||||
actor_policy = PI05Policy.from_pretrained(self.config.init_from_actor_path)
|
||||
actor_model = actor_policy.model
|
||||
|
||||
paligemma_model = actor_model.paligemma_with_expert.paligemma
|
||||
source_language_model = paligemma_model.model.language_model
|
||||
|
||||
# Transformer components
|
||||
self.rotary_emb.load_state_dict(source_language_model.rotary_emb.state_dict())
|
||||
num_layers = self.gemma_config.num_hidden_layers
|
||||
for i in range(num_layers):
|
||||
self.layers[i].load_state_dict(source_language_model.layers[i].state_dict())
|
||||
self.norm.load_state_dict(source_language_model.norm.state_dict())
|
||||
|
||||
# Vision tower (truncate source first, then copy)
|
||||
source_vision_tower = paligemma_model.model.vision_tower
|
||||
if hasattr(source_vision_tower, "vision_model") and hasattr(
|
||||
source_vision_tower.vision_model, "encoder"
|
||||
):
|
||||
source_encoder = source_vision_tower.vision_model.encoder
|
||||
source_encoder.layers = source_encoder.layers[: self.config.num_vision_layers]
|
||||
self.vision_tower.load_state_dict(source_vision_tower.state_dict())
|
||||
|
||||
# Multi-modal projector
|
||||
self.multi_modal_projector.load_state_dict(paligemma_model.model.multi_modal_projector.state_dict())
|
||||
|
||||
# Token embedding table
|
||||
self.token_embedding.load_state_dict(paligemma_model.model.language_model.embed_tokens.state_dict())
|
||||
|
||||
del actor_policy
|
||||
|
||||
def embed_image(self, image: Tensor) -> Tensor:
|
||||
"""Embed images using the value function's SigLIP vision tower.
|
||||
|
||||
Args:
|
||||
image: [batch_size, channels, height, width] preprocessed images in [-1, 1].
|
||||
|
||||
Returns:
|
||||
[batch_size, num_patches, hidden_dim] projected image features.
|
||||
"""
|
||||
out_dtype = image.dtype
|
||||
if image.dtype != torch.float32:
|
||||
image = image.to(torch.float32)
|
||||
|
||||
image_outputs = self.vision_tower(image, return_dict=True)
|
||||
image_features = self.multi_modal_projector(image_outputs.last_hidden_state)
|
||||
image_features = image_features / (self.hidden_dim**0.5)
|
||||
|
||||
if image_features.dtype != out_dtype:
|
||||
image_features = image_features.to(out_dtype)
|
||||
return image_features
|
||||
|
||||
def embed_text(self, token_ids: Tensor) -> Tensor:
|
||||
"""Embed text token IDs using the value function's token embedding table.
|
||||
|
||||
Args:
|
||||
token_ids: [batch_size, seq_len] integer token IDs
|
||||
|
||||
Returns:
|
||||
[batch_size, seq_len, hidden_dim] text embeddings
|
||||
"""
|
||||
return self.token_embedding(token_ids)
|
||||
|
||||
def _get_cls_embedding(self, batch_size: int) -> Tensor:
|
||||
"""Get [CLS] token embedding expanded to batch size.
|
||||
|
||||
Args:
|
||||
batch_size: number of samples in the batch.
|
||||
|
||||
Returns:
|
||||
[batch_size, 1, hidden_dim] learned [CLS] embedding.
|
||||
"""
|
||||
return self.cls_embedding.expand(batch_size, -1, -1)
|
||||
|
||||
def forward_value(
|
||||
self, vision_features: Tensor, text_embeddings: Tensor, text_padding_mask: Tensor
|
||||
) -> dict[str, Tensor]:
|
||||
"""Core forward pass through the distributional value function.
|
||||
|
||||
Args:
|
||||
vision_features: [batch_size, num_patches, hidden_dim]
|
||||
text_embeddings: [batch_size, seq_len, hidden_dim]
|
||||
text_padding_mask: [batch_size, seq_len] boolean mask for text tokens
|
||||
|
||||
Returns:
|
||||
logits: [batch_size, num_value_bins]
|
||||
probs: [batch_size, num_value_bins]
|
||||
value: [batch_size, 1]
|
||||
"""
|
||||
from lerobot.utils.constants import OPENPI_ATTENTION_MASK_VALUE
|
||||
|
||||
batch_size = text_embeddings.shape[0]
|
||||
device = text_embeddings.device
|
||||
|
||||
# Build sequence: [vision, text, CLS]
|
||||
cls_embedding = self._get_cls_embedding(batch_size)
|
||||
hidden_states = torch.cat([vision_features, text_embeddings, cls_embedding], dim=1)
|
||||
|
||||
# Build causal attention mask
|
||||
vision_len = vision_features.shape[1]
|
||||
vision_padding_mask = torch.ones(batch_size, vision_len, dtype=torch.bool, device=device)
|
||||
cls_padding_mask = torch.ones(batch_size, 1, dtype=torch.bool, device=device)
|
||||
full_padding_mask = torch.cat([vision_padding_mask, text_padding_mask, cls_padding_mask], dim=1)
|
||||
|
||||
full_seq_len = full_padding_mask.shape[1]
|
||||
|
||||
# Causal mask
|
||||
causal_mask = torch.tril(torch.ones(full_seq_len, full_seq_len, device=device, dtype=torch.bool))
|
||||
# Combine causal mask with padding mask
|
||||
padding_mask_4d = full_padding_mask[:, None, None, :].expand(
|
||||
batch_size, 1, full_seq_len, full_seq_len
|
||||
)
|
||||
attention_mask = causal_mask[None, None, :, :] & padding_mask_4d
|
||||
attention_mask = torch.where(attention_mask, 0.0, OPENPI_ATTENTION_MASK_VALUE)
|
||||
|
||||
position_ids = torch.cumsum(full_padding_mask.long(), dim=1) - 1
|
||||
cos, sin = self.rotary_emb(hidden_states, position_ids)
|
||||
|
||||
for layer in self.layers:
|
||||
norm_output = layer.input_layernorm(hidden_states, cond=None)
|
||||
if isinstance(norm_output, tuple):
|
||||
hidden_states_normed, gate = norm_output
|
||||
else:
|
||||
hidden_states_normed, gate = norm_output, None
|
||||
|
||||
input_shape = hidden_states_normed.shape[:-1]
|
||||
hidden_shape = (*input_shape, -1, layer.self_attn.head_dim)
|
||||
|
||||
query_states = layer.self_attn.q_proj(hidden_states_normed).view(hidden_shape).transpose(1, 2)
|
||||
key_states = layer.self_attn.k_proj(hidden_states_normed).view(hidden_shape).transpose(1, 2)
|
||||
value_states = layer.self_attn.v_proj(hidden_states_normed).view(hidden_shape).transpose(1, 2)
|
||||
|
||||
query_states, key_states = modeling_gemma.apply_rotary_pos_emb(
|
||||
query_states, key_states, cos, sin, unsqueeze_dim=1
|
||||
)
|
||||
|
||||
attention_output, _ = modeling_gemma.eager_attention_forward(
|
||||
layer.self_attn,
|
||||
query_states,
|
||||
key_states,
|
||||
value_states,
|
||||
attention_mask,
|
||||
layer.self_attn.scaling,
|
||||
)
|
||||
|
||||
attention_output = attention_output.reshape(batch_size, -1, self.gemma_config.hidden_size)
|
||||
if attention_output.dtype != layer.self_attn.o_proj.weight.dtype:
|
||||
attention_output = attention_output.to(layer.self_attn.o_proj.weight.dtype)
|
||||
projected_attention = layer.self_attn.o_proj(attention_output)
|
||||
|
||||
if gate is not None:
|
||||
projected_attention = _gated_residual(hidden_states, projected_attention, gate)
|
||||
else:
|
||||
projected_attention = hidden_states + projected_attention
|
||||
|
||||
after_attention_residual = projected_attention.clone()
|
||||
|
||||
norm_output = layer.post_attention_layernorm(projected_attention, cond=None)
|
||||
if isinstance(norm_output, tuple):
|
||||
mlp_input, gate = norm_output
|
||||
else:
|
||||
mlp_input, gate = norm_output, None
|
||||
|
||||
mlp_output = layer.mlp(mlp_input)
|
||||
|
||||
if gate is not None:
|
||||
hidden_states = _gated_residual(after_attention_residual, mlp_output, gate)
|
||||
else:
|
||||
hidden_states = after_attention_residual + mlp_output
|
||||
|
||||
hidden_states = self.norm(hidden_states)
|
||||
if isinstance(hidden_states, tuple):
|
||||
hidden_states = hidden_states[0]
|
||||
|
||||
# Extract [CLS] token (last position in the sequence)
|
||||
cls_hidden_state = hidden_states[:, -1, :] # [batch_size, hidden_dim]
|
||||
|
||||
# Value head: Linear(hidden_dim, num_bins) -> logits
|
||||
value_logits = self.value_head(cls_hidden_state) # [batch_size, num_value_bins]
|
||||
value_probs = F.softmax(value_logits, dim=-1)
|
||||
predicted_value = (value_probs * self.bin_centers.to(dtype=value_probs.dtype)).sum(
|
||||
dim=-1, keepdim=True
|
||||
)
|
||||
|
||||
return {"logits": value_logits, "probs": value_probs, "value": predicted_value}
|
||||
|
||||
def hl_gauss_target(self, target_value: Tensor) -> Tensor:
|
||||
"""HL-Gauss soft target distribution.
|
||||
|
||||
Places a Gaussian N(target, sigma^2) over the bin support and computes
|
||||
per-bin probabilities as CDF differences at bin edges, normalized to sum to 1.
|
||||
|
||||
Reference: Farebrother et al. 2024, "Stop Regressing: Training Value
|
||||
Functions via Classification for Scalable Deep RL", Section 3.1.
|
||||
arXiv:2403.03950
|
||||
|
||||
Args:
|
||||
target_value: [batch_size] or [batch_size, 1] target values.
|
||||
|
||||
Returns:
|
||||
[batch_size, num_value_bins] target probability distribution.
|
||||
"""
|
||||
if target_value.ndim == 2:
|
||||
target_value = target_value.squeeze(-1)
|
||||
target_value = target_value.to(dtype=self.bin_centers.dtype)
|
||||
|
||||
# Bin edges: half a bin-width outside the first/last center
|
||||
bin_width = (self.config.value_support_max - self.config.value_support_min) / (
|
||||
self.num_value_bins - 1
|
||||
)
|
||||
support_edges = torch.linspace(
|
||||
self.config.value_support_min - bin_width / 2,
|
||||
self.config.value_support_max + bin_width / 2,
|
||||
self.num_value_bins + 1,
|
||||
device=target_value.device,
|
||||
dtype=target_value.dtype,
|
||||
)
|
||||
|
||||
# CDF of N(target, sigma^2) evaluated at each edge
|
||||
cdf_at_edges = 0.5 * (
|
||||
1.0
|
||||
+ torch.erf(
|
||||
(support_edges.unsqueeze(0) - target_value.unsqueeze(-1))
|
||||
/ (self.hl_gauss_sigma * math.sqrt(2))
|
||||
)
|
||||
) # [batch_size, num_bins + 1]
|
||||
|
||||
# Normalize: z = cdf(max_edge) - cdf(min_edge)
|
||||
normalization_constant = (cdf_at_edges[:, -1] - cdf_at_edges[:, 0]).unsqueeze(-1).clamp(min=1e-10)
|
||||
|
||||
# Bin probabilities = differences of consecutive CDF values, normalized
|
||||
bin_probabilities = (cdf_at_edges[:, 1:] - cdf_at_edges[:, :-1]) / normalization_constant
|
||||
|
||||
return bin_probabilities
|
||||
|
||||
def dirac_delta_target(self, target_value: Tensor) -> Tensor:
|
||||
"""Dirac delta (C51) projection: split probability between two nearest bins.
|
||||
|
||||
Standard distributional RL projection from Bellemare et al. 2017.
|
||||
"A Distributional Perspective on Reinforcement Learning"
|
||||
arXiv:1707.06887
|
||||
|
||||
Args:
|
||||
target_value: [batch_size] or [batch_size, 1] target values.
|
||||
|
||||
Returns:
|
||||
[batch_size, num_value_bins] target probability distribution.
|
||||
"""
|
||||
if target_value.ndim == 2:
|
||||
target_value = target_value.squeeze(-1)
|
||||
target_value = target_value.clamp(self.config.value_support_min, self.config.value_support_max)
|
||||
target_value = target_value.to(dtype=self.bin_centers.dtype)
|
||||
|
||||
bin_width = self.bin_centers[1] - self.bin_centers[0]
|
||||
normalized_position = (target_value - self.config.value_support_min) / bin_width
|
||||
lower_bin_idx = normalized_position.floor().long().clamp(0, self.num_value_bins - 1)
|
||||
upper_bin_idx = normalized_position.ceil().long().clamp(0, self.num_value_bins - 1)
|
||||
|
||||
weight_upper = normalized_position - lower_bin_idx.float()
|
||||
weight_lower = upper_bin_idx.float() - normalized_position
|
||||
|
||||
same_bin = lower_bin_idx == upper_bin_idx
|
||||
weight_upper = torch.where(same_bin, torch.zeros_like(weight_upper), weight_upper)
|
||||
weight_lower = torch.where(same_bin, torch.ones_like(weight_lower), weight_lower)
|
||||
|
||||
batch_size = target_value.shape[0]
|
||||
target_distribution = torch.zeros(batch_size, self.num_value_bins, device=target_value.device)
|
||||
batch_indices = torch.arange(batch_size, device=target_value.device)
|
||||
target_distribution[batch_indices, lower_bin_idx] += weight_lower
|
||||
target_distribution[batch_indices, upper_bin_idx] += weight_upper
|
||||
|
||||
return target_distribution
|
||||
|
||||
def one_hot_target(self, target_value: Tensor) -> Tensor:
|
||||
"""One-hot target for terminal states (exact return, no smoothing).
|
||||
|
||||
Args:
|
||||
target_value: [batch_size] or [batch_size, 1] target values.
|
||||
|
||||
Returns:
|
||||
[batch_size, num_value_bins] one-hot distribution at the nearest bin.
|
||||
"""
|
||||
if target_value.ndim == 2:
|
||||
target_value = target_value.squeeze(-1)
|
||||
target_value = target_value.to(dtype=self.bin_centers.dtype)
|
||||
nearest_bin_idx = torch.argmin(
|
||||
torch.abs(self.bin_centers.unsqueeze(0) - target_value.unsqueeze(-1)), dim=-1
|
||||
)
|
||||
return F.one_hot(nearest_bin_idx, num_classes=self.num_value_bins).to(dtype=self.bin_centers.dtype)
|
||||
|
||||
def compute_target_distribution(
|
||||
self,
|
||||
target_value: Tensor,
|
||||
is_terminal: Tensor,
|
||||
method: str = "hl_gauss",
|
||||
use_one_hot_terminal: bool = True,
|
||||
) -> Tensor:
|
||||
"""Compute target distribution using configured method.
|
||||
|
||||
Args:
|
||||
target_value: [batch_size] scalar return targets
|
||||
is_terminal: [batch_size] boolean terminal flags
|
||||
method: "hl_gauss" or "dirac_delta"
|
||||
use_one_hot_terminal: if True, terminal states get one-hot targets
|
||||
(exact return, no smoothing). If False, all states use the same method.
|
||||
|
||||
Returns:
|
||||
[batch_size, num_value_bins] target probability distribution
|
||||
"""
|
||||
if method == "hl_gauss":
|
||||
base_distribution = self.hl_gauss_target(target_value)
|
||||
elif method == "dirac_delta":
|
||||
base_distribution = self.dirac_delta_target(target_value)
|
||||
else:
|
||||
raise ValueError(f"Unknown target method: {method}. Use 'hl_gauss' or 'dirac_delta'.")
|
||||
|
||||
if not use_one_hot_terminal:
|
||||
return base_distribution
|
||||
|
||||
terminal_distribution = self.one_hot_target(target_value)
|
||||
|
||||
return torch.where(is_terminal[:, None].bool(), terminal_distribution, base_distribution)
|
||||
|
||||
def forward(self, batch: dict[str, Tensor]) -> tuple[Tensor, dict[str, Any]]:
|
||||
"""Training forward pass — computes cross-entropy loss against MC return targets.
|
||||
|
||||
The batch is expected to be preprocessed by the processor pipeline.
|
||||
Keys expected in batch:
|
||||
- observation.images.*: [B, C, H, W] preprocessed images
|
||||
- observation.language_tokens: [B, seq_len] tokenized task prompt
|
||||
- observation.language_attention_mask: [B, seq_len] padding mask
|
||||
- mc_return: [B] normalized Monte Carlo return targets in (-1, 0)
|
||||
- is_terminal: [B] boolean terminal flags
|
||||
|
||||
Returns:
|
||||
(loss, output_dict) where loss is scalar cross-entropy
|
||||
"""
|
||||
from lerobot.utils.constants import OBS_IMAGES, OBS_LANGUAGE_ATTENTION_MASK, OBS_LANGUAGE_TOKENS
|
||||
|
||||
# Get first image key from batch
|
||||
image_keys = [k for k in batch if k.startswith(f"{OBS_IMAGES}.") or k == OBS_IMAGES]
|
||||
if not image_keys:
|
||||
raise KeyError(f"No image keys found in batch. Expected keys starting with '{OBS_IMAGES}.'")
|
||||
images = batch[image_keys[0]]
|
||||
|
||||
token_ids = batch[OBS_LANGUAGE_TOKENS]
|
||||
text_padding_mask = batch[OBS_LANGUAGE_ATTENTION_MASK].bool()
|
||||
mc_return = batch["mc_return"]
|
||||
is_terminal = batch["is_terminal"]
|
||||
|
||||
# Embed observations
|
||||
vision_features = self.embed_image(images)
|
||||
text_embeddings = self.embed_text(token_ids)
|
||||
|
||||
# Forward through value function transformer
|
||||
vf_output = self.forward_value(vision_features, text_embeddings, text_padding_mask)
|
||||
value_logits = vf_output["logits"]
|
||||
predicted_value = vf_output["value"]
|
||||
|
||||
# Compute target distribution
|
||||
target_distribution = self.compute_target_distribution(
|
||||
mc_return,
|
||||
is_terminal,
|
||||
method=self.config.target_method,
|
||||
use_one_hot_terminal=self.config.use_one_hot_terminal,
|
||||
)
|
||||
|
||||
# Cross-entropy loss (Eq. 1 in pi*0.6 paper)
|
||||
log_probs = F.log_softmax(value_logits, dim=-1)
|
||||
loss = -(target_distribution * log_probs).sum(dim=-1).mean()
|
||||
|
||||
output_dict = {
|
||||
"loss": loss.item(),
|
||||
"predicted_value_mean": predicted_value.mean().item(),
|
||||
"mc_return_mean": mc_return.mean().item(),
|
||||
}
|
||||
|
||||
return loss, output_dict
|
||||
|
||||
def compute_reward(self, batch: dict[str, Tensor]) -> Tensor:
|
||||
"""Compute V(s) for a batch of observations. Used for advantage scoring.
|
||||
|
||||
Args:
|
||||
batch: preprocessed batch with images and tokenized text
|
||||
|
||||
Returns:
|
||||
[batch_size] tensor of predicted values V(s)
|
||||
"""
|
||||
from lerobot.utils.constants import OBS_IMAGES, OBS_LANGUAGE_ATTENTION_MASK, OBS_LANGUAGE_TOKENS
|
||||
|
||||
image_keys = [k for k in batch if k.startswith(f"{OBS_IMAGES}.") or k == OBS_IMAGES]
|
||||
if not image_keys:
|
||||
raise KeyError(f"No image keys found in batch. Expected keys starting with '{OBS_IMAGES}.'")
|
||||
images = batch[image_keys[0]]
|
||||
|
||||
token_ids = batch[OBS_LANGUAGE_TOKENS]
|
||||
text_padding_mask = batch[OBS_LANGUAGE_ATTENTION_MASK].bool()
|
||||
|
||||
vision_features = self.embed_image(images)
|
||||
text_embeddings = self.embed_text(token_ids)
|
||||
|
||||
vf_output = self.forward_value(vision_features, text_embeddings, text_padding_mask)
|
||||
return vf_output["value"].squeeze(-1) # [batch_size]
|
||||
-235
@@ -1,235 +0,0 @@
|
||||
# Copyright 2025 The HuggingFace Inc. team. All rights reserved.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
"""Processor for RECAP's distributional value function.
|
||||
|
||||
Paper: "π*0.6: a VLA That Learns From Experience" (Physical Intelligence, 2025)
|
||||
https://pi.website/blog/pistar06
|
||||
|
||||
Prepares inputs for V^{pi_ref}(o_t, l): single image observation and task text only.
|
||||
1. Image preprocessing (resize-with-pad + normalize to [-1, 1]) for SigLIP
|
||||
2. Task prompt formatting ("Task: {task}.") and tokenization via PaliGemma tokenizer
|
||||
|
||||
Training targets (mc_return, is_terminal) are NOT routed through the processor.
|
||||
They are dataset columns read directly from the batch in the model's forward().
|
||||
"""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
from dataclasses import dataclass
|
||||
from typing import Any
|
||||
|
||||
import torch
|
||||
from torch import Tensor
|
||||
|
||||
from lerobot.configs import FeatureType, PipelineFeatureType, PolicyFeature
|
||||
from lerobot.processor import (
|
||||
AddBatchDimensionProcessorStep,
|
||||
DeviceProcessorStep,
|
||||
NormalizerProcessorStep,
|
||||
PolicyAction,
|
||||
PolicyProcessorPipeline,
|
||||
ProcessorStep,
|
||||
ProcessorStepRegistry,
|
||||
RenameObservationsProcessorStep,
|
||||
TokenizerProcessorStep,
|
||||
batch_to_transition,
|
||||
policy_action_to_transition,
|
||||
transition_to_batch,
|
||||
)
|
||||
from lerobot.processor.converters import to_tensor
|
||||
from lerobot.types import EnvTransition, TransitionKey
|
||||
from lerobot.utils.constants import (
|
||||
OBS_IMAGES,
|
||||
POLICY_POSTPROCESSOR_DEFAULT_NAME,
|
||||
POLICY_PREPROCESSOR_DEFAULT_NAME,
|
||||
)
|
||||
|
||||
from .configuration_distributional_value_function import DistributionalVFConfig
|
||||
|
||||
PALIGEMMA_TOKENIZER_NAME = "google/paligemma-3b-pt-224"
|
||||
|
||||
|
||||
@ProcessorStepRegistry.register(name="distributional_vf_prepare_task_prompt")
|
||||
@dataclass
|
||||
class DistributionalVFPrepareTaskPromptStep(ProcessorStep):
|
||||
"""Format the task string for the distributional value function.
|
||||
|
||||
The value function receives only visual observations and task text.
|
||||
Builds prompt: "Task: {task}."
|
||||
"""
|
||||
|
||||
task_key: str = "task"
|
||||
|
||||
def __call__(self, transition: EnvTransition) -> EnvTransition:
|
||||
transition = transition.copy()
|
||||
|
||||
complementary_data = transition.get(TransitionKey.COMPLEMENTARY_DATA, {})
|
||||
tasks = complementary_data.get(self.task_key)
|
||||
if tasks is None:
|
||||
raise ValueError("No task found in complementary data")
|
||||
|
||||
if isinstance(tasks, str):
|
||||
tasks = [tasks]
|
||||
|
||||
full_prompts = []
|
||||
for task in tasks:
|
||||
cleaned_text = task.strip().replace("_", " ").replace("\n", " ")
|
||||
full_prompts.append(f"Task: {cleaned_text}.")
|
||||
|
||||
new_complementary_data = dict(complementary_data)
|
||||
new_complementary_data[self.task_key] = full_prompts
|
||||
transition[TransitionKey.COMPLEMENTARY_DATA] = new_complementary_data
|
||||
return transition
|
||||
|
||||
def transform_features(
|
||||
self, features: dict[PipelineFeatureType, dict[str, PolicyFeature]]
|
||||
) -> dict[PipelineFeatureType, dict[str, PolicyFeature]]:
|
||||
return features
|
||||
|
||||
def get_config(self) -> dict[str, Any]:
|
||||
return {"task_key": self.task_key}
|
||||
|
||||
|
||||
@ProcessorStepRegistry.register(name="distributional_vf_image_preprocessor")
|
||||
@dataclass
|
||||
class DistributionalVFImagePreprocessorStep(ProcessorStep):
|
||||
"""Resize and normalize images for the value function's SigLIP vision tower.
|
||||
|
||||
Expects float images in [0, 1].
|
||||
- Resize-with-pad to ``image_resolution`` (preserves aspect ratio)
|
||||
- Scale to [-1, 1] for SigLIP
|
||||
"""
|
||||
|
||||
image_resolution: tuple[int, int] = (224, 224)
|
||||
image_keys: tuple[str, ...] | None = None
|
||||
|
||||
def __call__(self, transition: EnvTransition) -> EnvTransition:
|
||||
from lerobot.policies.pi05.modeling_pi05 import resize_with_pad_torch
|
||||
|
||||
observation = transition.get(TransitionKey.OBSERVATION)
|
||||
if not isinstance(observation, dict):
|
||||
raise ValueError("DistributionalVFImagePreprocessorStep requires an observation dict")
|
||||
|
||||
image_keys = self.image_keys or tuple(
|
||||
key for key in observation if key == OBS_IMAGES or key.startswith(f"{OBS_IMAGES}.")
|
||||
)
|
||||
if not image_keys:
|
||||
raise KeyError(
|
||||
f"Distributional value function expected image keys under {OBS_IMAGES!r} in observation"
|
||||
)
|
||||
|
||||
new_observation = dict(observation)
|
||||
for image_key in image_keys:
|
||||
image = new_observation[image_key]
|
||||
if not isinstance(image, Tensor):
|
||||
image = to_tensor(image)
|
||||
if image.dtype != torch.float32:
|
||||
image = image.to(torch.float32)
|
||||
|
||||
is_channels_first = image.ndim == 4 and image.shape[1] == 3
|
||||
if is_channels_first:
|
||||
image = image.permute(0, 2, 3, 1)
|
||||
|
||||
if image.shape[1:3] != self.image_resolution:
|
||||
image = resize_with_pad_torch(image, *self.image_resolution)
|
||||
|
||||
image = image * 2.0 - 1.0
|
||||
|
||||
if is_channels_first:
|
||||
image = image.permute(0, 3, 1, 2)
|
||||
|
||||
new_observation[image_key] = image
|
||||
|
||||
new_transition = transition.copy()
|
||||
new_transition[TransitionKey.OBSERVATION] = new_observation
|
||||
return new_transition
|
||||
|
||||
def transform_features(
|
||||
self, features: dict[PipelineFeatureType, dict[str, PolicyFeature]]
|
||||
) -> dict[PipelineFeatureType, dict[str, PolicyFeature]]:
|
||||
return features
|
||||
|
||||
def get_config(self) -> dict[str, Any]:
|
||||
return {
|
||||
"image_resolution": self.image_resolution,
|
||||
"image_keys": list(self.image_keys) if self.image_keys is not None else None,
|
||||
}
|
||||
|
||||
|
||||
def _visual_image_keys(config: DistributionalVFConfig) -> tuple[str, ...]:
|
||||
return tuple(
|
||||
feature_name
|
||||
for feature_name, feature in config.input_features.items()
|
||||
if feature.type == FeatureType.VISUAL
|
||||
)
|
||||
|
||||
|
||||
def make_distributional_vf_pre_post_processors(
|
||||
config: DistributionalVFConfig,
|
||||
dataset_stats: dict[str, dict[str, torch.Tensor]] | None = None,
|
||||
) -> tuple[
|
||||
PolicyProcessorPipeline[dict[str, Any], dict[str, Any]],
|
||||
PolicyProcessorPipeline[PolicyAction, PolicyAction],
|
||||
]:
|
||||
"""Create pre/post processors for the distributional value function.
|
||||
|
||||
Preprocessor steps:
|
||||
1. Rename observations (no-op by default)
|
||||
2. Add a batch dimension
|
||||
3. Normalize features (images use identity, so they stay in [0, 1])
|
||||
4. Format task prompt: "Task: {task}."
|
||||
5. Tokenize with the PaliGemma tokenizer
|
||||
6. Resize-with-pad and scale images to [-1, 1] for SigLIP
|
||||
7. Move tensors to the configured device
|
||||
|
||||
Training targets (mc_return, is_terminal) are not processed here.
|
||||
The model reads them directly from the batch in forward().
|
||||
|
||||
The postprocessor is a no-op because the value function does not need
|
||||
action postprocessing.
|
||||
"""
|
||||
image_keys = _visual_image_keys(config)
|
||||
|
||||
preprocessor = PolicyProcessorPipeline[dict[str, Any], dict[str, Any]](
|
||||
steps=[
|
||||
RenameObservationsProcessorStep(rename_map={}),
|
||||
AddBatchDimensionProcessorStep(),
|
||||
NormalizerProcessorStep(
|
||||
features={**config.input_features, **config.output_features},
|
||||
norm_map=config.normalization_mapping,
|
||||
stats=dataset_stats,
|
||||
),
|
||||
DistributionalVFPrepareTaskPromptStep(),
|
||||
TokenizerProcessorStep(
|
||||
tokenizer_name=PALIGEMMA_TOKENIZER_NAME,
|
||||
max_length=config.tokenizer_max_length,
|
||||
padding_side="right",
|
||||
padding="max_length",
|
||||
),
|
||||
DistributionalVFImagePreprocessorStep(
|
||||
image_resolution=config.image_resolution,
|
||||
image_keys=image_keys or None,
|
||||
),
|
||||
DeviceProcessorStep(device=config.device or "cpu"),
|
||||
],
|
||||
name=POLICY_PREPROCESSOR_DEFAULT_NAME,
|
||||
to_transition=batch_to_transition,
|
||||
to_output=transition_to_batch,
|
||||
)
|
||||
postprocessor = PolicyProcessorPipeline(
|
||||
name=POLICY_POSTPROCESSOR_DEFAULT_NAME,
|
||||
to_transition=policy_action_to_transition,
|
||||
)
|
||||
return preprocessor, postprocessor
|
||||
@@ -24,7 +24,6 @@ from lerobot.configs.rewards import RewardModelConfig
|
||||
from lerobot.processor import PolicyAction, PolicyProcessorPipeline
|
||||
|
||||
from .classifier.configuration_classifier import RewardClassifierConfig
|
||||
from .distributional_value_function.configuration_distributional_value_function import DistributionalVFConfig
|
||||
from .pretrained import PreTrainedRewardModel
|
||||
from .robometer.configuration_robometer import RobometerConfig
|
||||
from .sarm.configuration_sarm import SARMConfig
|
||||
@@ -64,12 +63,6 @@ def get_reward_model_class(name: str) -> type[PreTrainedRewardModel]:
|
||||
from lerobot.rewards.topreward.modeling_topreward import TOPRewardModel
|
||||
|
||||
return TOPRewardModel
|
||||
elif name == "distributional_value_function":
|
||||
from lerobot.rewards.distributional_value_function.modeling_distributional_value_function import (
|
||||
DistributionalVFRewardModel,
|
||||
)
|
||||
|
||||
return DistributionalVFRewardModel
|
||||
else:
|
||||
try:
|
||||
return _get_reward_model_cls_from_name(name=name)
|
||||
@@ -103,8 +96,6 @@ def make_reward_model_config(reward_type: str, **kwargs) -> RewardModelConfig:
|
||||
return RobometerConfig(**kwargs)
|
||||
elif reward_type == "topreward":
|
||||
return TOPRewardConfig(**kwargs)
|
||||
elif reward_type == "distributional_value_function":
|
||||
return DistributionalVFConfig(**kwargs)
|
||||
else:
|
||||
try:
|
||||
config_cls = RewardModelConfig.get_choice_class(reward_type)
|
||||
@@ -200,16 +191,6 @@ def make_reward_pre_post_processors(
|
||||
dataset_stats=kwargs.get("dataset_stats"),
|
||||
)
|
||||
|
||||
elif isinstance(reward_cfg, DistributionalVFConfig):
|
||||
from lerobot.rewards.distributional_value_function.processor_distributional_value_function import (
|
||||
make_distributional_vf_pre_post_processors,
|
||||
)
|
||||
|
||||
return make_distributional_vf_pre_post_processors(
|
||||
config=reward_cfg,
|
||||
dataset_stats=kwargs.get("dataset_stats"),
|
||||
)
|
||||
|
||||
else:
|
||||
try:
|
||||
processors = _make_processors_from_reward_model_config(
|
||||
|
||||
@@ -18,7 +18,8 @@ import logging
|
||||
from functools import cached_property
|
||||
|
||||
from lerobot.types import RobotAction, RobotObservation
|
||||
from lerobot.utils.decorators import check_if_already_connected, check_if_not_connected
|
||||
from lerobot.utils.bimanual import BimanualMixin
|
||||
from lerobot.utils.decorators import check_if_not_connected
|
||||
|
||||
from ..openarm_follower import OpenArmFollower, OpenArmFollowerConfig
|
||||
from ..robot import Robot
|
||||
@@ -27,7 +28,7 @@ from .config_bi_openarm_follower import BiOpenArmFollowerConfig
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
class BiOpenArmFollower(Robot):
|
||||
class BiOpenArmFollower(BimanualMixin, Robot):
|
||||
"""
|
||||
Bimanual OpenArm Follower Arms
|
||||
"""
|
||||
@@ -39,15 +40,17 @@ class BiOpenArmFollower(Robot):
|
||||
super().__init__(config)
|
||||
self.config = config
|
||||
|
||||
# Top-level cameras are distributed evenly: each arm's OpenArmFollower
|
||||
# will only open the cameras assigned to it. Per-arm cameras are used
|
||||
# as fallback when top-level cameras are empty.
|
||||
if config.cameras:
|
||||
left_cameras = config.cameras
|
||||
right_cameras = {}
|
||||
else:
|
||||
left_cameras = config.left_arm_config.cameras
|
||||
right_cameras = config.right_arm_config.cameras
|
||||
# Top-level cameras are opened by `left_arm` for convenience, but their
|
||||
# keys stay unprefixed in observations (tracked via `_top_level_cam_keys`).
|
||||
self._top_level_cam_keys = set(config.cameras)
|
||||
_collisions = self._top_level_cam_keys & set(
|
||||
config.left_arm_config.cameras
|
||||
) | self._top_level_cam_keys & set(config.right_arm_config.cameras)
|
||||
if _collisions:
|
||||
raise ValueError(
|
||||
f"Top-level camera names collide with per-arm camera names: {sorted(_collisions)}"
|
||||
)
|
||||
left_arm_cameras = {**config.left_arm_config.cameras, **config.cameras}
|
||||
|
||||
left_arm_config = OpenArmFollowerConfig(
|
||||
id=f"{config.id}_left" if config.id else None,
|
||||
@@ -56,7 +59,7 @@ class BiOpenArmFollower(Robot):
|
||||
disable_torque_on_disconnect=config.left_arm_config.disable_torque_on_disconnect,
|
||||
use_velocity_and_torque=config.left_arm_config.use_velocity_and_torque,
|
||||
max_relative_target=config.left_arm_config.max_relative_target,
|
||||
cameras=left_cameras,
|
||||
cameras=left_arm_cameras,
|
||||
side=config.left_arm_config.side,
|
||||
can_interface=config.left_arm_config.can_interface,
|
||||
use_can_fd=config.left_arm_config.use_can_fd,
|
||||
@@ -75,7 +78,7 @@ class BiOpenArmFollower(Robot):
|
||||
disable_torque_on_disconnect=config.right_arm_config.disable_torque_on_disconnect,
|
||||
use_velocity_and_torque=config.right_arm_config.use_velocity_and_torque,
|
||||
max_relative_target=config.right_arm_config.max_relative_target,
|
||||
cameras=right_cameras,
|
||||
cameras=config.right_arm_config.cameras,
|
||||
side=config.right_arm_config.side,
|
||||
can_interface=config.right_arm_config.can_interface,
|
||||
use_can_fd=config.right_arm_config.use_can_fd,
|
||||
@@ -95,22 +98,19 @@ class BiOpenArmFollower(Robot):
|
||||
|
||||
@property
|
||||
def _motors_ft(self) -> dict[str, type]:
|
||||
left_arm_motors_ft = self.left_arm._motors_ft
|
||||
right_arm_motors_ft = self.right_arm._motors_ft
|
||||
|
||||
# Right first, then left — matches the teleoperator (OpenArmMini) ordering
|
||||
# and the dataset feature names recorded during data collection.
|
||||
return {
|
||||
**{f"right_{k}": v for k, v in right_arm_motors_ft.items()},
|
||||
**{f"left_{k}": v for k, v in left_arm_motors_ft.items()},
|
||||
**{f"left_{k}": v for k, v in self.left_arm._motors_ft.items()},
|
||||
**{f"right_{k}": v for k, v in self.right_arm._motors_ft.items()},
|
||||
}
|
||||
|
||||
@property
|
||||
def _cameras_ft(self) -> dict[str, tuple]:
|
||||
# Cameras already have unique user-chosen names (e.g. "left_wrist", "base",
|
||||
# "right_wrist"), so we merge them directly — unlike motors which need the
|
||||
# left_/right_ prefix to disambiguate identical per-arm joint names.
|
||||
return {**self.left_arm._cameras_ft, **self.right_arm._cameras_ft}
|
||||
out: dict[str, tuple] = {}
|
||||
for k, v in self.left_arm._cameras_ft.items():
|
||||
out[k if k in self._top_level_cam_keys else f"left_{k}"] = v
|
||||
for k, v in self.right_arm._cameras_ft.items():
|
||||
out[f"right_{k}"] = v
|
||||
return out
|
||||
|
||||
@cached_property
|
||||
def observation_features(self) -> dict[str, type | tuple]:
|
||||
@@ -120,27 +120,6 @@ class BiOpenArmFollower(Robot):
|
||||
def action_features(self) -> dict[str, type]:
|
||||
return self._motors_ft
|
||||
|
||||
@property
|
||||
def is_connected(self) -> bool:
|
||||
return self.left_arm.is_connected and self.right_arm.is_connected
|
||||
|
||||
@check_if_already_connected
|
||||
def connect(self, calibrate: bool = True) -> None:
|
||||
self.left_arm.connect(calibrate)
|
||||
self.right_arm.connect(calibrate)
|
||||
|
||||
@property
|
||||
def is_calibrated(self) -> bool:
|
||||
return self.left_arm.is_calibrated and self.right_arm.is_calibrated
|
||||
|
||||
def calibrate(self) -> None:
|
||||
self.left_arm.calibrate()
|
||||
self.right_arm.calibrate()
|
||||
|
||||
def configure(self) -> None:
|
||||
self.left_arm.configure()
|
||||
self.right_arm.configure()
|
||||
|
||||
def setup_motors(self) -> None:
|
||||
raise NotImplementedError(
|
||||
"Motor ID configuration is typically done via manufacturer tools for CAN motors."
|
||||
@@ -148,21 +127,15 @@ class BiOpenArmFollower(Robot):
|
||||
|
||||
@check_if_not_connected
|
||||
def get_observation(self) -> RobotObservation:
|
||||
obs_dict = {}
|
||||
obs_dict: RobotObservation = {}
|
||||
|
||||
# Camera keys that should NOT get the arm prefix (they already have unique names)
|
||||
left_cam_keys = set(self.left_arm.cameras.keys())
|
||||
right_cam_keys = set(self.right_arm.cameras.keys())
|
||||
# Add "left_" prefix to per-arm keys; keep top-level camera keys unprefixed.
|
||||
for key, value in self.left_arm.get_observation().items():
|
||||
obs_dict[key if key in self._top_level_cam_keys else f"left_{key}"] = value
|
||||
|
||||
# Right first, then left — matches the teleoperator (OpenArmMini) ordering
|
||||
# and the dataset feature names recorded during data collection.
|
||||
right_obs = self.right_arm.get_observation()
|
||||
for key, value in right_obs.items():
|
||||
obs_dict[key if key in right_cam_keys else f"right_{key}"] = value
|
||||
|
||||
left_obs = self.left_arm.get_observation()
|
||||
for key, value in left_obs.items():
|
||||
obs_dict[key if key in left_cam_keys else f"left_{key}"] = value
|
||||
# Add "right_" prefix
|
||||
for key, value in self.right_arm.get_observation().items():
|
||||
obs_dict[f"right_{key}"] = value
|
||||
|
||||
return obs_dict
|
||||
|
||||
@@ -189,9 +162,4 @@ class BiOpenArmFollower(Robot):
|
||||
prefixed_sent_action_left = {f"left_{key}": value for key, value in sent_action_left.items()}
|
||||
prefixed_sent_action_right = {f"right_{key}": value for key, value in sent_action_right.items()}
|
||||
|
||||
return {**prefixed_sent_action_right, **prefixed_sent_action_left}
|
||||
|
||||
@check_if_not_connected
|
||||
def disconnect(self):
|
||||
self.left_arm.disconnect()
|
||||
self.right_arm.disconnect()
|
||||
return {**prefixed_sent_action_left, **prefixed_sent_action_right}
|
||||
|
||||
@@ -32,5 +32,7 @@ class BiOpenArmFollowerConfig(RobotConfig):
|
||||
left_arm_config: OpenArmFollowerConfigBase
|
||||
right_arm_config: OpenArmFollowerConfigBase
|
||||
|
||||
# Top-level cameras shared across both arms.
|
||||
# Top-level cameras not attached to a specific side. Keys are kept as-is in
|
||||
# observations (no `left_`/`right_` prefix). Per-arm cameras (declared on
|
||||
# `{left,right}_arm_config.cameras`) are prefixed.
|
||||
cameras: dict[str, CameraConfig] = field(default_factory=dict)
|
||||
|
||||
@@ -18,7 +18,8 @@ import logging
|
||||
from functools import cached_property
|
||||
|
||||
from lerobot.types import RobotAction, RobotObservation
|
||||
from lerobot.utils.decorators import check_if_already_connected, check_if_not_connected
|
||||
from lerobot.utils.bimanual import BimanualMixin
|
||||
from lerobot.utils.decorators import check_if_not_connected
|
||||
|
||||
from ..rebot_b601_follower import RebotB601Follower, RebotB601FollowerRobotConfig
|
||||
from ..robot import Robot
|
||||
@@ -27,7 +28,7 @@ from .config_bi_rebot_b601_follower import BiRebotB601FollowerConfig
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
class BiRebotB601Follower(Robot):
|
||||
class BiRebotB601Follower(BimanualMixin, Robot):
|
||||
"""Bimanual Seeed Studio reBot B601-DM follower.
|
||||
|
||||
Composes two single-arm :class:`RebotB601Follower` instances. Observation and
|
||||
@@ -41,6 +42,18 @@ class BiRebotB601Follower(Robot):
|
||||
super().__init__(config)
|
||||
self.config = config
|
||||
|
||||
# Top-level cameras are opened by `left_arm` for convenience, but their
|
||||
# keys stay unprefixed in observations (tracked via `_top_level_cam_keys`).
|
||||
self._top_level_cam_keys = set(config.cameras)
|
||||
_collisions = self._top_level_cam_keys & set(
|
||||
config.left_arm_config.cameras
|
||||
) | self._top_level_cam_keys & set(config.right_arm_config.cameras)
|
||||
if _collisions:
|
||||
raise ValueError(
|
||||
f"Top-level camera names collide with per-arm camera names: {sorted(_collisions)}"
|
||||
)
|
||||
left_arm_cameras = {**config.left_arm_config.cameras, **config.cameras}
|
||||
|
||||
left_arm_config = RebotB601FollowerRobotConfig(
|
||||
id=f"{config.id}_left" if config.id else None,
|
||||
calibration_dir=config.calibration_dir,
|
||||
@@ -49,7 +62,7 @@ class BiRebotB601Follower(Robot):
|
||||
dm_serial_baud=config.left_arm_config.dm_serial_baud,
|
||||
disable_torque_on_disconnect=config.left_arm_config.disable_torque_on_disconnect,
|
||||
max_relative_target=config.left_arm_config.max_relative_target,
|
||||
cameras=config.left_arm_config.cameras,
|
||||
cameras=left_arm_cameras,
|
||||
motor_can_ids=config.left_arm_config.motor_can_ids,
|
||||
pos_vel_velocity=config.left_arm_config.pos_vel_velocity,
|
||||
gripper_torque_ratio=config.left_arm_config.gripper_torque_ratio,
|
||||
@@ -86,10 +99,12 @@ class BiRebotB601Follower(Robot):
|
||||
|
||||
@property
|
||||
def _cameras_ft(self) -> dict[str, tuple]:
|
||||
return {
|
||||
**{f"left_{k}": v for k, v in self.left_arm._cameras_ft.items()},
|
||||
**{f"right_{k}": v for k, v in self.right_arm._cameras_ft.items()},
|
||||
}
|
||||
out: dict[str, tuple] = {}
|
||||
for k, v in self.left_arm._cameras_ft.items():
|
||||
out[k if k in self._top_level_cam_keys else f"left_{k}"] = v
|
||||
for k, v in self.right_arm._cameras_ft.items():
|
||||
out[f"right_{k}"] = v
|
||||
return out
|
||||
|
||||
@cached_property
|
||||
def observation_features(self) -> dict[str, type | tuple]:
|
||||
@@ -99,32 +114,13 @@ class BiRebotB601Follower(Robot):
|
||||
def action_features(self) -> dict[str, type]:
|
||||
return self._motors_ft
|
||||
|
||||
@property
|
||||
def is_connected(self) -> bool:
|
||||
return self.left_arm.is_connected and self.right_arm.is_connected
|
||||
|
||||
@check_if_already_connected
|
||||
def connect(self, calibrate: bool = True) -> None:
|
||||
self.left_arm.connect(calibrate)
|
||||
self.right_arm.connect(calibrate)
|
||||
|
||||
@property
|
||||
def is_calibrated(self) -> bool:
|
||||
return self.left_arm.is_calibrated and self.right_arm.is_calibrated
|
||||
|
||||
def calibrate(self) -> None:
|
||||
self.left_arm.calibrate()
|
||||
self.right_arm.calibrate()
|
||||
|
||||
def configure(self) -> None:
|
||||
self.left_arm.configure()
|
||||
self.right_arm.configure()
|
||||
|
||||
@check_if_not_connected
|
||||
def get_observation(self) -> RobotObservation:
|
||||
obs_dict = {}
|
||||
obs_dict.update({f"left_{k}": v for k, v in self.left_arm.get_observation().items()})
|
||||
obs_dict.update({f"right_{k}": v for k, v in self.right_arm.get_observation().items()})
|
||||
obs_dict: RobotObservation = {}
|
||||
for k, v in self.left_arm.get_observation().items():
|
||||
obs_dict[k if k in self._top_level_cam_keys else f"left_{k}"] = v
|
||||
for k, v in self.right_arm.get_observation().items():
|
||||
obs_dict[f"right_{k}"] = v
|
||||
return obs_dict
|
||||
|
||||
@check_if_not_connected
|
||||
@@ -143,8 +139,3 @@ class BiRebotB601Follower(Robot):
|
||||
**{f"left_{k}": v for k, v in sent_action_left.items()},
|
||||
**{f"right_{k}": v for k, v in sent_action_right.items()},
|
||||
}
|
||||
|
||||
@check_if_not_connected
|
||||
def disconnect(self) -> None:
|
||||
self.left_arm.disconnect()
|
||||
self.right_arm.disconnect()
|
||||
|
||||
@@ -14,7 +14,9 @@
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
from dataclasses import dataclass
|
||||
from dataclasses import dataclass, field
|
||||
|
||||
from lerobot.cameras import CameraConfig
|
||||
|
||||
from ..config import RobotConfig
|
||||
from ..rebot_b601_follower import RebotB601FollowerConfig
|
||||
@@ -27,3 +29,8 @@ class BiRebotB601FollowerConfig(RobotConfig):
|
||||
|
||||
left_arm_config: RebotB601FollowerConfig
|
||||
right_arm_config: RebotB601FollowerConfig
|
||||
|
||||
# Top-level cameras not attached to a specific side. Keys are kept as-is in
|
||||
# observations (no `left_`/`right_` prefix). Per-arm cameras (declared on
|
||||
# `{left,right}_arm_config.cameras`) are prefixed.
|
||||
cameras: dict[str, CameraConfig] = field(default_factory=dict)
|
||||
|
||||
@@ -18,7 +18,8 @@ import logging
|
||||
from functools import cached_property
|
||||
|
||||
from lerobot.types import RobotAction, RobotObservation
|
||||
from lerobot.utils.decorators import check_if_already_connected, check_if_not_connected
|
||||
from lerobot.utils.bimanual import BimanualMixin
|
||||
from lerobot.utils.decorators import check_if_not_connected
|
||||
|
||||
from ..robot import Robot
|
||||
from ..so_follower import SOFollower, SOFollowerRobotConfig
|
||||
@@ -27,7 +28,7 @@ from .config_bi_so_follower import BiSOFollowerConfig
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
class BiSOFollower(Robot):
|
||||
class BiSOFollower(BimanualMixin, Robot):
|
||||
"""
|
||||
[Bimanual SO Follower Arms](https://github.com/TheRobotStudio/SO-ARM100) designed by TheRobotStudio
|
||||
"""
|
||||
@@ -39,6 +40,18 @@ class BiSOFollower(Robot):
|
||||
super().__init__(config)
|
||||
self.config = config
|
||||
|
||||
# Top-level cameras are opened by `left_arm` for convenience, but their
|
||||
# keys stay unprefixed in observations (tracked via `_top_level_cam_keys`).
|
||||
self._top_level_cam_keys = set(config.cameras)
|
||||
_collisions = self._top_level_cam_keys & set(
|
||||
config.left_arm_config.cameras
|
||||
) | self._top_level_cam_keys & set(config.right_arm_config.cameras)
|
||||
if _collisions:
|
||||
raise ValueError(
|
||||
f"Top-level camera names collide with per-arm camera names: {sorted(_collisions)}"
|
||||
)
|
||||
left_arm_cameras = {**config.left_arm_config.cameras, **config.cameras}
|
||||
|
||||
left_arm_config = SOFollowerRobotConfig(
|
||||
id=f"{config.id}_left" if config.id else None,
|
||||
calibration_dir=config.calibration_dir,
|
||||
@@ -46,7 +59,7 @@ class BiSOFollower(Robot):
|
||||
disable_torque_on_disconnect=config.left_arm_config.disable_torque_on_disconnect,
|
||||
max_relative_target=config.left_arm_config.max_relative_target,
|
||||
use_degrees=config.left_arm_config.use_degrees,
|
||||
cameras=config.left_arm_config.cameras,
|
||||
cameras=left_arm_cameras,
|
||||
)
|
||||
|
||||
right_arm_config = SOFollowerRobotConfig(
|
||||
@@ -77,13 +90,12 @@ class BiSOFollower(Robot):
|
||||
|
||||
@property
|
||||
def _cameras_ft(self) -> dict[str, tuple]:
|
||||
left_arm_cameras_ft = self.left_arm._cameras_ft
|
||||
right_arm_cameras_ft = self.right_arm._cameras_ft
|
||||
|
||||
return {
|
||||
**{f"left_{k}": v for k, v in left_arm_cameras_ft.items()},
|
||||
**{f"right_{k}": v for k, v in right_arm_cameras_ft.items()},
|
||||
}
|
||||
out: dict[str, tuple] = {}
|
||||
for k, v in self.left_arm._cameras_ft.items():
|
||||
out[k if k in self._top_level_cam_keys else f"left_{k}"] = v
|
||||
for k, v in self.right_arm._cameras_ft.items():
|
||||
out[f"right_{k}"] = v
|
||||
return out
|
||||
|
||||
@cached_property
|
||||
def observation_features(self) -> dict[str, type | tuple]:
|
||||
@@ -93,42 +105,21 @@ class BiSOFollower(Robot):
|
||||
def action_features(self) -> dict[str, type]:
|
||||
return self._motors_ft
|
||||
|
||||
@property
|
||||
def is_connected(self) -> bool:
|
||||
return self.left_arm.is_connected and self.right_arm.is_connected
|
||||
|
||||
@check_if_already_connected
|
||||
def connect(self, calibrate: bool = True) -> None:
|
||||
self.left_arm.connect(calibrate)
|
||||
self.right_arm.connect(calibrate)
|
||||
|
||||
@property
|
||||
def is_calibrated(self) -> bool:
|
||||
return self.left_arm.is_calibrated and self.right_arm.is_calibrated
|
||||
|
||||
def calibrate(self) -> None:
|
||||
self.left_arm.calibrate()
|
||||
self.right_arm.calibrate()
|
||||
|
||||
def configure(self) -> None:
|
||||
self.left_arm.configure()
|
||||
self.right_arm.configure()
|
||||
|
||||
def setup_motors(self) -> None:
|
||||
self.left_arm.setup_motors()
|
||||
self.right_arm.setup_motors()
|
||||
|
||||
@check_if_not_connected
|
||||
def get_observation(self) -> RobotObservation:
|
||||
obs_dict = {}
|
||||
obs_dict: RobotObservation = {}
|
||||
|
||||
# Add "left_" prefix
|
||||
left_obs = self.left_arm.get_observation()
|
||||
obs_dict.update({f"left_{key}": value for key, value in left_obs.items()})
|
||||
# Add "left_" prefix to per-arm keys; keep top-level camera keys unprefixed.
|
||||
for key, value in self.left_arm.get_observation().items():
|
||||
obs_dict[key if key in self._top_level_cam_keys else f"left_{key}"] = value
|
||||
|
||||
# Add "right_" prefix
|
||||
right_obs = self.right_arm.get_observation()
|
||||
obs_dict.update({f"right_{key}": value for key, value in right_obs.items()})
|
||||
for key, value in self.right_arm.get_observation().items():
|
||||
obs_dict[f"right_{key}"] = value
|
||||
|
||||
return obs_dict
|
||||
|
||||
@@ -151,8 +142,3 @@ class BiSOFollower(Robot):
|
||||
prefixed_sent_action_right = {f"right_{key}": value for key, value in sent_action_right.items()}
|
||||
|
||||
return {**prefixed_sent_action_left, **prefixed_sent_action_right}
|
||||
|
||||
@check_if_not_connected
|
||||
def disconnect(self):
|
||||
self.left_arm.disconnect()
|
||||
self.right_arm.disconnect()
|
||||
|
||||
@@ -14,7 +14,9 @@
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
from dataclasses import dataclass
|
||||
from dataclasses import dataclass, field
|
||||
|
||||
from lerobot.cameras import CameraConfig
|
||||
|
||||
from ..config import RobotConfig
|
||||
from ..so_follower import SOFollowerConfig
|
||||
@@ -27,3 +29,8 @@ class BiSOFollowerConfig(RobotConfig):
|
||||
|
||||
left_arm_config: SOFollowerConfig
|
||||
right_arm_config: SOFollowerConfig
|
||||
|
||||
# Top-level cameras not attached to a specific side. Keys are kept as-is in
|
||||
# observations (no `left_`/`right_` prefix). Per-arm cameras (declared on
|
||||
# `{left,right}_arm_config.cameras`) are prefixed.
|
||||
cameras: dict[str, CameraConfig] = field(default_factory=dict)
|
||||
|
||||
@@ -54,6 +54,7 @@ from lerobot.teleoperators import ( # noqa: F401
|
||||
Teleoperator,
|
||||
TeleoperatorConfig,
|
||||
bi_openarm_leader,
|
||||
bi_openarm_mini,
|
||||
bi_rebot_102_leader,
|
||||
bi_so_leader,
|
||||
homunculus,
|
||||
|
||||
@@ -57,6 +57,7 @@ from lerobot.robots import ( # noqa: F401
|
||||
from lerobot.teleoperators import ( # noqa: F401
|
||||
TeleoperatorConfig,
|
||||
bi_openarm_leader,
|
||||
bi_openarm_mini,
|
||||
bi_rebot_102_leader,
|
||||
bi_so_leader,
|
||||
gamepad,
|
||||
|
||||
@@ -137,6 +137,7 @@ from lerobot.teleoperators import ( # noqa: F401
|
||||
Teleoperator,
|
||||
TeleoperatorConfig,
|
||||
bi_openarm_leader,
|
||||
bi_openarm_mini,
|
||||
bi_rebot_102_leader,
|
||||
bi_so_leader,
|
||||
homunculus,
|
||||
|
||||
@@ -174,6 +174,7 @@ from lerobot.teleoperators import ( # noqa: F401
|
||||
Teleoperator,
|
||||
TeleoperatorConfig,
|
||||
bi_openarm_leader,
|
||||
bi_openarm_mini,
|
||||
bi_rebot_102_leader,
|
||||
bi_so_leader,
|
||||
homunculus,
|
||||
|
||||
@@ -41,6 +41,7 @@ from lerobot.robots import ( # noqa: F401
|
||||
)
|
||||
from lerobot.teleoperators import ( # noqa: F401
|
||||
TeleoperatorConfig,
|
||||
bi_openarm_mini,
|
||||
bi_rebot_102_leader,
|
||||
bi_so_leader,
|
||||
koch_leader,
|
||||
|
||||
@@ -89,6 +89,7 @@ from lerobot.teleoperators import ( # noqa: F401
|
||||
Teleoperator,
|
||||
TeleoperatorConfig,
|
||||
bi_openarm_leader,
|
||||
bi_openarm_mini,
|
||||
bi_rebot_102_leader,
|
||||
bi_so_leader,
|
||||
gamepad,
|
||||
|
||||
@@ -232,15 +232,18 @@ def train(cfg: TrainPipelineConfig, accelerator: "Accelerator | None" = None):
|
||||
torch.backends.cudnn.benchmark = True
|
||||
torch.backends.cuda.matmul.allow_tf32 = True
|
||||
|
||||
# Dataset loading synchronization: main process downloads first to avoid race conditions
|
||||
if is_main_process:
|
||||
logging.info("Creating dataset")
|
||||
# Dataset loading synchronization: each node's local main process downloads first to avoid
|
||||
# race conditions (the global main process only exists on node 0, so gating on it would let
|
||||
# all ranks of the other nodes download and build the Arrow cache concurrently).
|
||||
if accelerator.is_local_main_process:
|
||||
if is_main_process:
|
||||
logging.info("Creating dataset")
|
||||
dataset = make_dataset(cfg)
|
||||
|
||||
accelerator.wait_for_everyone()
|
||||
|
||||
# Now all other processes can safely load the dataset
|
||||
if not is_main_process:
|
||||
# Now all other processes can safely load the dataset from the local cache
|
||||
if not accelerator.is_local_main_process:
|
||||
dataset = make_dataset(cfg)
|
||||
|
||||
# Create environment used for evaluating checkpoints during training on simulation data.
|
||||
@@ -386,12 +389,19 @@ def train(cfg: TrainPipelineConfig, accelerator: "Accelerator | None" = None):
|
||||
# create dataloader for offline training
|
||||
if hasattr(active_cfg, "drop_n_last_frames"):
|
||||
shuffle = False
|
||||
# A dedicated generator (rather than the global torch RNG) lets accelerator.prepare
|
||||
# synchronize the shuffle permutation across ranks, keeping batch shards disjoint even
|
||||
# when ranks consume the global RNG asymmetrically (e.g. eval on the main process only).
|
||||
sampler_generator = torch.Generator()
|
||||
if cfg.seed is not None:
|
||||
sampler_generator.manual_seed(cfg.seed)
|
||||
sampler = EpisodeAwareSampler(
|
||||
dataset.meta.episodes["dataset_from_index"],
|
||||
dataset.meta.episodes["dataset_to_index"],
|
||||
episode_indices_to_use=dataset.episodes,
|
||||
drop_n_last_frames=active_cfg.drop_n_last_frames,
|
||||
shuffle=True,
|
||||
generator=sampler_generator,
|
||||
)
|
||||
else:
|
||||
shuffle = True
|
||||
|
||||
@@ -18,7 +18,8 @@ import logging
|
||||
from functools import cached_property
|
||||
|
||||
from lerobot.types import RobotAction
|
||||
from lerobot.utils.decorators import check_if_already_connected, check_if_not_connected
|
||||
from lerobot.utils.bimanual import BimanualMixin
|
||||
from lerobot.utils.decorators import check_if_not_connected
|
||||
|
||||
from ..openarm_leader import OpenArmLeader, OpenArmLeaderConfig
|
||||
from ..teleoperator import Teleoperator
|
||||
@@ -27,7 +28,7 @@ from .config_bi_openarm_leader import BiOpenArmLeaderConfig
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
class BiOpenArmLeader(Teleoperator):
|
||||
class BiOpenArmLeader(BimanualMixin, Teleoperator):
|
||||
"""
|
||||
Bimanual OpenArm Leader Arms
|
||||
"""
|
||||
@@ -86,27 +87,6 @@ class BiOpenArmLeader(Teleoperator):
|
||||
def feedback_features(self) -> dict[str, type]:
|
||||
return {}
|
||||
|
||||
@property
|
||||
def is_connected(self) -> bool:
|
||||
return self.left_arm.is_connected and self.right_arm.is_connected
|
||||
|
||||
@check_if_already_connected
|
||||
def connect(self, calibrate: bool = True) -> None:
|
||||
self.left_arm.connect(calibrate)
|
||||
self.right_arm.connect(calibrate)
|
||||
|
||||
@property
|
||||
def is_calibrated(self) -> bool:
|
||||
return self.left_arm.is_calibrated and self.right_arm.is_calibrated
|
||||
|
||||
def calibrate(self) -> None:
|
||||
self.left_arm.calibrate()
|
||||
self.right_arm.calibrate()
|
||||
|
||||
def configure(self) -> None:
|
||||
self.left_arm.configure()
|
||||
self.right_arm.configure()
|
||||
|
||||
def setup_motors(self) -> None:
|
||||
raise NotImplementedError(
|
||||
"Motor ID configuration is typically done via manufacturer tools for CAN motors."
|
||||
@@ -129,8 +109,3 @@ class BiOpenArmLeader(Teleoperator):
|
||||
def send_feedback(self, feedback: dict[str, float]) -> None:
|
||||
# TODO: Implement force feedback
|
||||
raise NotImplementedError
|
||||
|
||||
@check_if_not_connected
|
||||
def disconnect(self) -> None:
|
||||
self.left_arm.disconnect()
|
||||
self.right_arm.disconnect()
|
||||
|
||||
@@ -23,7 +23,7 @@ from ..openarm_leader import OpenArmLeaderConfigBase
|
||||
@TeleoperatorConfig.register_subclass("bi_openarm_leader")
|
||||
@dataclass
|
||||
class BiOpenArmLeaderConfig(TeleoperatorConfig):
|
||||
"""Configuration class for Bi OpenArm Follower robots."""
|
||||
"""Configuration class for Bi OpenArm Leader teleoperators."""
|
||||
|
||||
left_arm_config: OpenArmLeaderConfigBase
|
||||
right_arm_config: OpenArmLeaderConfigBase
|
||||
|
||||
+6
-9
@@ -1,4 +1,6 @@
|
||||
# Copyright 2025 The HuggingFace Inc. team. All rights reserved.
|
||||
#!/usr/bin/env python
|
||||
|
||||
# Copyright 2026 The HuggingFace Inc. team. All rights reserved.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
@@ -12,12 +14,7 @@
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
from .configuration_distributional_value_function import DistributionalVFConfig
|
||||
from .modeling_distributional_value_function import DistributionalVFRewardModel
|
||||
from .processor_distributional_value_function import make_distributional_vf_pre_post_processors
|
||||
from .bi_openarm_mini import BiOpenArmMini
|
||||
from .config_bi_openarm_mini import BiOpenArmMiniConfig
|
||||
|
||||
__all__ = [
|
||||
"DistributionalVFConfig",
|
||||
"DistributionalVFRewardModel",
|
||||
"make_distributional_vf_pre_post_processors",
|
||||
]
|
||||
__all__ = ["BiOpenArmMini", "BiOpenArmMiniConfig"]
|
||||
@@ -0,0 +1,101 @@
|
||||
#!/usr/bin/env python
|
||||
|
||||
# Copyright 2026 The HuggingFace Inc. team. All rights reserved.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
import logging
|
||||
from functools import cached_property
|
||||
|
||||
from lerobot.types import RobotAction
|
||||
from lerobot.utils.bimanual import BimanualMixin
|
||||
from lerobot.utils.decorators import check_if_not_connected
|
||||
|
||||
from ..openarm_mini import OpenArmMini, OpenArmMiniConfig
|
||||
from ..teleoperator import Teleoperator
|
||||
from .config_bi_openarm_mini import BiOpenArmMiniConfig
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
class BiOpenArmMini(BimanualMixin, Teleoperator):
|
||||
"""Bimanual OpenArm Mini teleoperator.
|
||||
|
||||
Composes two single-arm :class:`OpenArmMini` instances. Action and feedback
|
||||
keys of each arm are namespaced with a ``left_`` / ``right_`` prefix, so a
|
||||
bimanual leader can teleoperate a bimanual OpenArm follower.
|
||||
"""
|
||||
|
||||
config_class = BiOpenArmMiniConfig
|
||||
name = "bi_openarm_mini"
|
||||
|
||||
def __init__(self, config: BiOpenArmMiniConfig):
|
||||
super().__init__(config)
|
||||
self.config = config
|
||||
|
||||
# `side` is forced to match left/right regardless of what the user passed
|
||||
# on the per-arm base config — the bimanual wrapper owns the side semantics.
|
||||
left_arm_config = OpenArmMiniConfig(
|
||||
id=f"{config.id}_left" if config.id else None,
|
||||
calibration_dir=config.calibration_dir,
|
||||
port=config.left_arm_config.port,
|
||||
side="left",
|
||||
use_degrees=config.left_arm_config.use_degrees,
|
||||
)
|
||||
|
||||
right_arm_config = OpenArmMiniConfig(
|
||||
id=f"{config.id}_right" if config.id else None,
|
||||
calibration_dir=config.calibration_dir,
|
||||
port=config.right_arm_config.port,
|
||||
side="right",
|
||||
use_degrees=config.right_arm_config.use_degrees,
|
||||
)
|
||||
|
||||
self.left_arm = OpenArmMini(left_arm_config)
|
||||
self.right_arm = OpenArmMini(right_arm_config)
|
||||
|
||||
@cached_property
|
||||
def action_features(self) -> dict[str, type]:
|
||||
return {
|
||||
**{f"left_{k}": v for k, v in self.left_arm.action_features.items()},
|
||||
**{f"right_{k}": v for k, v in self.right_arm.action_features.items()},
|
||||
}
|
||||
|
||||
@cached_property
|
||||
def feedback_features(self) -> dict[str, type]:
|
||||
return {
|
||||
**{f"left_{k}": v for k, v in self.left_arm.feedback_features.items()},
|
||||
**{f"right_{k}": v for k, v in self.right_arm.feedback_features.items()},
|
||||
}
|
||||
|
||||
def setup_motors(self) -> None:
|
||||
self.left_arm.setup_motors()
|
||||
self.right_arm.setup_motors()
|
||||
|
||||
@check_if_not_connected
|
||||
def get_action(self) -> RobotAction:
|
||||
action: RobotAction = {}
|
||||
for k, v in self.left_arm.get_action().items():
|
||||
action[f"left_{k}"] = v
|
||||
for k, v in self.right_arm.get_action().items():
|
||||
action[f"right_{k}"] = v
|
||||
return action
|
||||
|
||||
@check_if_not_connected
|
||||
def send_feedback(self, feedback: dict[str, float]) -> None:
|
||||
left_fb = {k.removeprefix("left_"): v for k, v in feedback.items() if k.startswith("left_")}
|
||||
right_fb = {k.removeprefix("right_"): v for k, v in feedback.items() if k.startswith("right_")}
|
||||
if left_fb:
|
||||
self.left_arm.send_feedback(left_fb)
|
||||
if right_fb:
|
||||
self.right_arm.send_feedback(right_fb)
|
||||
@@ -0,0 +1,29 @@
|
||||
#!/usr/bin/env python
|
||||
|
||||
# Copyright 2026 The HuggingFace Inc. team. All rights reserved.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
from dataclasses import dataclass
|
||||
|
||||
from ..config import TeleoperatorConfig
|
||||
from ..openarm_mini import OpenArmMiniConfigBase
|
||||
|
||||
|
||||
@TeleoperatorConfig.register_subclass("bi_openarm_mini")
|
||||
@dataclass
|
||||
class BiOpenArmMiniConfig(TeleoperatorConfig):
|
||||
"""Configuration class for Bi OpenArm Mini teleoperators."""
|
||||
|
||||
left_arm_config: OpenArmMiniConfigBase
|
||||
right_arm_config: OpenArmMiniConfigBase
|
||||
@@ -14,7 +14,7 @@
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
from .bi_rebot_102_leader import BiRebotArm102Leader
|
||||
from .config_bi_rebot_102_leader import BiRebotArm102LeaderConfig
|
||||
from .bi_rebot_102_leader import BiRebot102Leader
|
||||
from .config_bi_rebot_102_leader import BiRebot102LeaderConfig
|
||||
|
||||
__all__ = ["BiRebotArm102Leader", "BiRebotArm102LeaderConfig"]
|
||||
__all__ = ["BiRebot102Leader", "BiRebot102LeaderConfig"]
|
||||
|
||||
@@ -18,16 +18,17 @@ import logging
|
||||
from functools import cached_property
|
||||
|
||||
from lerobot.types import RobotAction
|
||||
from lerobot.utils.decorators import check_if_already_connected, check_if_not_connected
|
||||
from lerobot.utils.bimanual import BimanualMixin
|
||||
from lerobot.utils.decorators import check_if_not_connected
|
||||
|
||||
from ..rebot_102_leader import RebotArm102Leader, RebotArm102LeaderTeleopConfig
|
||||
from ..teleoperator import Teleoperator
|
||||
from .config_bi_rebot_102_leader import BiRebotArm102LeaderConfig
|
||||
from .config_bi_rebot_102_leader import BiRebot102LeaderConfig
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
class BiRebotArm102Leader(Teleoperator):
|
||||
class BiRebot102Leader(BimanualMixin, Teleoperator):
|
||||
"""Bimanual Seeed Studio StarArm102 / reBot Arm 102 leader.
|
||||
|
||||
Composes two single-arm :class:`RebotArm102Leader` instances. Action keys of
|
||||
@@ -35,10 +36,10 @@ class BiRebotArm102Leader(Teleoperator):
|
||||
leader can teleoperate a bimanual reBot B601 follower.
|
||||
"""
|
||||
|
||||
config_class = BiRebotArm102LeaderConfig
|
||||
config_class = BiRebot102LeaderConfig
|
||||
name = "bi_rebot_102_leader"
|
||||
|
||||
def __init__(self, config: BiRebotArm102LeaderConfig):
|
||||
def __init__(self, config: BiRebot102LeaderConfig):
|
||||
super().__init__(config)
|
||||
self.config = config
|
||||
|
||||
@@ -76,27 +77,6 @@ class BiRebotArm102Leader(Teleoperator):
|
||||
def feedback_features(self) -> dict[str, type]:
|
||||
return {}
|
||||
|
||||
@property
|
||||
def is_connected(self) -> bool:
|
||||
return self.left_arm.is_connected and self.right_arm.is_connected
|
||||
|
||||
@check_if_already_connected
|
||||
def connect(self, calibrate: bool = True) -> None:
|
||||
self.left_arm.connect(calibrate)
|
||||
self.right_arm.connect(calibrate)
|
||||
|
||||
@property
|
||||
def is_calibrated(self) -> bool:
|
||||
return self.left_arm.is_calibrated and self.right_arm.is_calibrated
|
||||
|
||||
def calibrate(self) -> None:
|
||||
self.left_arm.calibrate()
|
||||
self.right_arm.calibrate()
|
||||
|
||||
def configure(self) -> None:
|
||||
self.left_arm.configure()
|
||||
self.right_arm.configure()
|
||||
|
||||
@check_if_not_connected
|
||||
def get_action(self) -> RobotAction:
|
||||
action_dict = {}
|
||||
@@ -106,8 +86,3 @@ class BiRebotArm102Leader(Teleoperator):
|
||||
|
||||
def send_feedback(self, feedback: dict[str, float]) -> None:
|
||||
raise NotImplementedError("Feedback is not implemented for the reBot Arm 102 leader.")
|
||||
|
||||
@check_if_not_connected
|
||||
def disconnect(self) -> None:
|
||||
self.left_arm.disconnect()
|
||||
self.right_arm.disconnect()
|
||||
|
||||
@@ -22,7 +22,7 @@ from ..rebot_102_leader import RebotArm102LeaderConfig
|
||||
|
||||
@TeleoperatorConfig.register_subclass("bi_rebot_102_leader")
|
||||
@dataclass
|
||||
class BiRebotArm102LeaderConfig(TeleoperatorConfig):
|
||||
class BiRebot102LeaderConfig(TeleoperatorConfig):
|
||||
"""Configuration class for the bimanual reBot Arm 102 leader teleoperator."""
|
||||
|
||||
left_arm_config: RebotArm102LeaderConfig
|
||||
|
||||
@@ -17,7 +17,9 @@
|
||||
import logging
|
||||
from functools import cached_property
|
||||
|
||||
from lerobot.utils.decorators import check_if_already_connected, check_if_not_connected
|
||||
from lerobot.types import RobotAction
|
||||
from lerobot.utils.bimanual import BimanualMixin
|
||||
from lerobot.utils.decorators import check_if_not_connected
|
||||
|
||||
from ..so_leader import SOLeader, SOLeaderTeleopConfig
|
||||
from ..teleoperator import Teleoperator
|
||||
@@ -26,7 +28,7 @@ from .config_bi_so_leader import BiSOLeaderConfig
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
class BiSOLeader(Teleoperator):
|
||||
class BiSOLeader(BimanualMixin, Teleoperator):
|
||||
"""
|
||||
[Bimanual SO Leader Arms](https://github.com/TheRobotStudio/SO-ARM100) designed by TheRobotStudio
|
||||
"""
|
||||
@@ -67,33 +69,12 @@ class BiSOLeader(Teleoperator):
|
||||
def feedback_features(self) -> dict[str, type]:
|
||||
return {}
|
||||
|
||||
@property
|
||||
def is_connected(self) -> bool:
|
||||
return self.left_arm.is_connected and self.right_arm.is_connected
|
||||
|
||||
@check_if_already_connected
|
||||
def connect(self, calibrate: bool = True) -> None:
|
||||
self.left_arm.connect(calibrate)
|
||||
self.right_arm.connect(calibrate)
|
||||
|
||||
@property
|
||||
def is_calibrated(self) -> bool:
|
||||
return self.left_arm.is_calibrated and self.right_arm.is_calibrated
|
||||
|
||||
def calibrate(self) -> None:
|
||||
self.left_arm.calibrate()
|
||||
self.right_arm.calibrate()
|
||||
|
||||
def configure(self) -> None:
|
||||
self.left_arm.configure()
|
||||
self.right_arm.configure()
|
||||
|
||||
def setup_motors(self) -> None:
|
||||
self.left_arm.setup_motors()
|
||||
self.right_arm.setup_motors()
|
||||
|
||||
@check_if_not_connected
|
||||
def get_action(self) -> dict[str, float]:
|
||||
def get_action(self) -> RobotAction:
|
||||
action_dict = {}
|
||||
|
||||
# Add "left_" prefix
|
||||
@@ -109,8 +90,3 @@ class BiSOLeader(Teleoperator):
|
||||
def send_feedback(self, feedback: dict[str, float]) -> None:
|
||||
# TODO: Implement force feedback
|
||||
raise NotImplementedError
|
||||
|
||||
@check_if_not_connected
|
||||
def disconnect(self) -> None:
|
||||
self.left_arm.disconnect()
|
||||
self.right_arm.disconnect()
|
||||
|
||||
@@ -1,6 +1,6 @@
|
||||
#!/usr/bin/env python
|
||||
|
||||
# Copyright 2025 The HuggingFace Inc. team. All rights reserved.
|
||||
# Copyright 2026 The HuggingFace Inc. team. All rights reserved.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
@@ -14,7 +14,7 @@
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
from .config_openarm_mini import OpenArmMiniConfig
|
||||
from .config_openarm_mini import OpenArmMiniConfig, OpenArmMiniConfigBase
|
||||
from .openarm_mini import OpenArmMini
|
||||
|
||||
__all__ = ["OpenArmMini", "OpenArmMiniConfig"]
|
||||
__all__ = ["OpenArmMini", "OpenArmMiniConfig", "OpenArmMiniConfigBase"]
|
||||
|
||||
@@ -19,12 +19,21 @@ from dataclasses import dataclass
|
||||
from ..config import TeleoperatorConfig
|
||||
|
||||
|
||||
@TeleoperatorConfig.register_subclass("openarm_mini")
|
||||
@dataclass
|
||||
class OpenArmMiniConfig(TeleoperatorConfig):
|
||||
"""Configuration for OpenArm Mini teleoperator with Feetech motors (dual arms)."""
|
||||
class OpenArmMiniConfigBase:
|
||||
"""Base configuration for the OpenArm Mini teleoperator (Feetech STS3215, 7DOF + gripper)."""
|
||||
|
||||
port_right: str = "/dev/ttyUSB0"
|
||||
port_left: str = "/dev/ttyUSB1"
|
||||
# Serial port for the Feetech bus (e.g., "/dev/ttyUSB0").
|
||||
port: str
|
||||
|
||||
# Side of the arm: "left" or "right". Controls per-joint direction flips applied
|
||||
# during readout. If `None`, no flipping is applied.
|
||||
side: str | None = None
|
||||
|
||||
use_degrees: bool = True
|
||||
|
||||
|
||||
@TeleoperatorConfig.register_subclass("openarm_mini")
|
||||
@dataclass
|
||||
class OpenArmMiniConfig(TeleoperatorConfig, OpenArmMiniConfigBase):
|
||||
pass
|
||||
|
||||
@@ -31,22 +31,22 @@ from .config_openarm_mini import OpenArmMiniConfig
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
# Motors whose direction is inverted during readout
|
||||
RIGHT_MOTORS_TO_FLIP = ["joint_1", "joint_2", "joint_3", "joint_4", "joint_5", "joint_7"]
|
||||
LEFT_MOTORS_TO_FLIP = ["joint_1", "joint_3", "joint_4", "joint_5", "joint_6", "joint_7"]
|
||||
# Per-side motor direction flips applied during readout.
|
||||
SIDE_MOTORS_TO_FLIP: dict[str, list[str]] = {
|
||||
"left": ["joint_1", "joint_3", "joint_4", "joint_5", "joint_6", "joint_7"],
|
||||
"right": ["joint_1", "joint_2", "joint_3", "joint_4", "joint_5", "joint_7"],
|
||||
}
|
||||
|
||||
# Leader joint 6 maps to follower joint 7 and vice versa
|
||||
# Leader joint 6 ↔ follower joint 7 (symmetric — its own inverse).
|
||||
JOINT_REMAP = {"joint_6": "joint_7", "joint_7": "joint_6"}
|
||||
JOINT_REMAP_REVERSE = {"joint_7": "joint_6", "joint_6": "joint_7"}
|
||||
|
||||
GRIPPER_TELEOP_TO_DEGREES = -0.65
|
||||
|
||||
|
||||
class OpenArmMini(Teleoperator):
|
||||
"""
|
||||
OpenArm Mini Teleoperator with dual Feetech-based arms (8 motors per arm).
|
||||
"""OpenArm Mini single-arm teleoperator (Feetech STS3215, 7DOF + gripper).
|
||||
|
||||
Each arm has 7 joints plus a gripper, using Feetech STS3215 servos.
|
||||
For the bimanual setup, see :class:`BiOpenArmMini` which composes two of these.
|
||||
"""
|
||||
|
||||
config_class = OpenArmMiniConfig
|
||||
@@ -56,9 +56,12 @@ class OpenArmMini(Teleoperator):
|
||||
super().__init__(config)
|
||||
self.config = config
|
||||
|
||||
if config.side is not None and config.side not in SIDE_MOTORS_TO_FLIP:
|
||||
raise ValueError(f"Invalid side '{config.side}'; expected 'left', 'right', or None.")
|
||||
self._motors_to_flip: list[str] = SIDE_MOTORS_TO_FLIP.get(config.side, []) if config.side else []
|
||||
|
||||
norm_mode_body = MotorNormMode.DEGREES
|
||||
|
||||
motors_right = {
|
||||
motors = {
|
||||
"joint_1": Motor(1, "sts3215", norm_mode_body),
|
||||
"joint_2": Motor(2, "sts3215", norm_mode_body),
|
||||
"joint_3": Motor(3, "sts3215", norm_mode_body),
|
||||
@@ -69,46 +72,15 @@ class OpenArmMini(Teleoperator):
|
||||
"gripper": Motor(8, "sts3215", MotorNormMode.RANGE_0_100),
|
||||
}
|
||||
|
||||
motors_left = {
|
||||
"joint_1": Motor(1, "sts3215", norm_mode_body),
|
||||
"joint_2": Motor(2, "sts3215", norm_mode_body),
|
||||
"joint_3": Motor(3, "sts3215", norm_mode_body),
|
||||
"joint_4": Motor(4, "sts3215", norm_mode_body),
|
||||
"joint_5": Motor(5, "sts3215", norm_mode_body),
|
||||
"joint_6": Motor(6, "sts3215", norm_mode_body),
|
||||
"joint_7": Motor(7, "sts3215", norm_mode_body),
|
||||
"gripper": Motor(8, "sts3215", MotorNormMode.RANGE_0_100),
|
||||
}
|
||||
|
||||
cal_right = {
|
||||
k.replace("right_", ""): v for k, v in (self.calibration or {}).items() if k.startswith("right_")
|
||||
}
|
||||
cal_left = {
|
||||
k.replace("left_", ""): v for k, v in (self.calibration or {}).items() if k.startswith("left_")
|
||||
}
|
||||
|
||||
self.bus_right = FeetechMotorsBus(
|
||||
port=self.config.port_right,
|
||||
motors=motors_right,
|
||||
calibration=cal_right,
|
||||
)
|
||||
|
||||
self.bus_left = FeetechMotorsBus(
|
||||
port=self.config.port_left,
|
||||
motors=motors_left,
|
||||
calibration=cal_left,
|
||||
self.bus = FeetechMotorsBus(
|
||||
port=self.config.port,
|
||||
motors=motors,
|
||||
calibration=self.calibration,
|
||||
)
|
||||
|
||||
@property
|
||||
def action_features(self) -> dict[str, type]:
|
||||
# Right first, then left — matches the robot (BiOpenArmFollower) ordering
|
||||
# and the dataset feature names recorded during data collection.
|
||||
features: dict[str, type] = {}
|
||||
for motor in self.bus_right.motors:
|
||||
features[f"right_{motor}.pos"] = float
|
||||
for motor in self.bus_left.motors:
|
||||
features[f"left_{motor}.pos"] = float
|
||||
return features
|
||||
return {f"{motor}.pos": float for motor in self.bus.motors}
|
||||
|
||||
@property
|
||||
def feedback_features(self) -> dict[str, type]:
|
||||
@@ -116,14 +88,12 @@ class OpenArmMini(Teleoperator):
|
||||
|
||||
@property
|
||||
def is_connected(self) -> bool:
|
||||
return self.bus_right.is_connected and self.bus_left.is_connected
|
||||
return self.bus.is_connected
|
||||
|
||||
@check_if_already_connected
|
||||
def connect(self, calibrate: bool = True) -> None:
|
||||
logger.info(f"Connecting right arm on {self.config.port_right}...")
|
||||
self.bus_right.connect()
|
||||
logger.info(f"Connecting left arm on {self.config.port_left}...")
|
||||
self.bus_left.connect()
|
||||
logger.info(f"Connecting arm on {self.config.port}...")
|
||||
self.bus.connect()
|
||||
|
||||
if calibrate:
|
||||
self.calibrate()
|
||||
@@ -133,14 +103,14 @@ class OpenArmMini(Teleoperator):
|
||||
|
||||
@property
|
||||
def is_calibrated(self) -> bool:
|
||||
return self.bus_right.is_calibrated and self.bus_left.is_calibrated
|
||||
return self.bus.is_calibrated
|
||||
|
||||
def calibrate(self) -> None:
|
||||
"""
|
||||
Run calibration procedure for OpenArm Mini.
|
||||
Run calibration procedure for a single OpenArm Mini arm.
|
||||
|
||||
1. Disable torque
|
||||
2. Ask user to position arms in hanging position with grippers closed
|
||||
2. Ask user to position arm in hanging position with gripper closed
|
||||
3. Set this as zero position via half-turn homing
|
||||
4. Interactive gripper calibration (open/close positions)
|
||||
5. Save calibration
|
||||
@@ -152,70 +122,51 @@ class OpenArmMini(Teleoperator):
|
||||
)
|
||||
if user_input.strip().lower() != "c":
|
||||
logger.info(f"Using existing calibration for {self.id}")
|
||||
cal_right = {
|
||||
k.replace("right_", ""): v for k, v in self.calibration.items() if k.startswith("right_")
|
||||
}
|
||||
cal_left = {
|
||||
k.replace("left_", ""): v for k, v in self.calibration.items() if k.startswith("left_")
|
||||
}
|
||||
self.bus_right.write_calibration(cal_right)
|
||||
self.bus_left.write_calibration(cal_left)
|
||||
self.bus.write_calibration(self.calibration)
|
||||
return
|
||||
|
||||
logger.info(f"\nRunning calibration for {self}")
|
||||
|
||||
self._calibrate_arm("right", self.bus_right)
|
||||
self._calibrate_arm("left", self.bus_left)
|
||||
self.bus.disable_torque()
|
||||
|
||||
self._save_calibration()
|
||||
print(f"\nCalibration complete and saved to {self.calibration_fpath}")
|
||||
logger.info("Setting Phase to 12 for all motors...")
|
||||
for motor in self.bus.motors:
|
||||
self.bus.write("Phase", motor, 12)
|
||||
|
||||
def _calibrate_arm(self, arm_name: str, bus: FeetechMotorsBus) -> None:
|
||||
"""Calibrate a single arm with Feetech motors."""
|
||||
logger.info(f"\n=== Calibrating {arm_name.upper()} arm ===")
|
||||
|
||||
bus.disable_torque()
|
||||
|
||||
logger.info(f"Setting Phase to 12 for all motors in {arm_name.upper()} arm...")
|
||||
for motor in bus.motors:
|
||||
bus.write("Phase", motor, 12)
|
||||
|
||||
for motor in bus.motors:
|
||||
bus.write("Operating_Mode", motor, OperatingMode.POSITION.value)
|
||||
for motor in self.bus.motors:
|
||||
self.bus.write("Operating_Mode", motor, OperatingMode.POSITION.value)
|
||||
|
||||
input(
|
||||
f"\nCalibration: Zero Position ({arm_name.upper()} arm)\n"
|
||||
"\nCalibration: Zero Position\n"
|
||||
"Position the arm in the following configuration:\n"
|
||||
" - Arm hanging straight down\n"
|
||||
" - Gripper closed\n"
|
||||
"Press ENTER when ready..."
|
||||
)
|
||||
|
||||
homing_offsets = bus.set_half_turn_homings()
|
||||
logger.info(f"{arm_name.capitalize()} arm zero position set.")
|
||||
homing_offsets = self.bus.set_half_turn_homings()
|
||||
logger.info("Arm zero position set.")
|
||||
|
||||
print(f"\nSetting motor ranges for {arm_name.upper()} arm\n")
|
||||
print("\nSetting motor ranges\n")
|
||||
|
||||
if self.calibration is None:
|
||||
self.calibration = {}
|
||||
|
||||
motor_resolution = bus.model_resolution_table[list(bus.motors.values())[0].model]
|
||||
motor_resolution = self.bus.model_resolution_table[list(self.bus.motors.values())[0].model]
|
||||
max_res = motor_resolution - 1
|
||||
|
||||
for motor_name, motor in bus.motors.items():
|
||||
prefixed_name = f"{arm_name}_{motor_name}"
|
||||
|
||||
for motor_name, motor in self.bus.motors.items():
|
||||
if motor_name == "gripper":
|
||||
input(
|
||||
f"\nGripper Calibration ({arm_name.upper()} arm)\n"
|
||||
f"Step 1: CLOSE the gripper fully\n"
|
||||
f"Press ENTER when gripper is closed..."
|
||||
"\nGripper Calibration\n"
|
||||
"Step 1: CLOSE the gripper fully\n"
|
||||
"Press ENTER when gripper is closed..."
|
||||
)
|
||||
closed_pos = bus.read("Present_Position", motor_name, normalize=False)
|
||||
closed_pos = self.bus.read("Present_Position", motor_name, normalize=False)
|
||||
logger.info(f" Gripper closed position recorded: {closed_pos}")
|
||||
|
||||
input("\nStep 2: OPEN the gripper fully\nPress ENTER when gripper is fully open...")
|
||||
open_pos = bus.read("Present_Position", motor_name, normalize=False)
|
||||
open_pos = self.bus.read("Present_Position", motor_name, normalize=False)
|
||||
logger.info(f" Gripper open position recorded: {open_pos}")
|
||||
|
||||
if closed_pos < open_pos:
|
||||
@@ -228,16 +179,16 @@ class OpenArmMini(Teleoperator):
|
||||
drive_mode = 1
|
||||
|
||||
logger.info(
|
||||
f" {prefixed_name}: range set to [{range_min}, {range_max}] "
|
||||
f" {motor_name}: range set to [{range_min}, {range_max}] "
|
||||
f"(0=closed, 100=open, drive_mode={drive_mode})"
|
||||
)
|
||||
else:
|
||||
range_min = 0
|
||||
range_max = max_res
|
||||
drive_mode = 0
|
||||
logger.info(f" {prefixed_name}: range set to [0, {max_res}] (full motor range)")
|
||||
logger.info(f" {motor_name}: range set to [0, {max_res}] (full motor range)")
|
||||
|
||||
self.calibration[prefixed_name] = MotorCalibration(
|
||||
self.calibration[motor_name] = MotorCalibration(
|
||||
id=motor.id,
|
||||
drive_mode=drive_mode,
|
||||
homing_offset=homing_offsets[motor_name],
|
||||
@@ -245,108 +196,68 @@ class OpenArmMini(Teleoperator):
|
||||
range_max=range_max,
|
||||
)
|
||||
|
||||
cal_for_bus = {
|
||||
k.replace(f"{arm_name}_", ""): v
|
||||
for k, v in self.calibration.items()
|
||||
if k.startswith(f"{arm_name}_")
|
||||
}
|
||||
bus.write_calibration(cal_for_bus)
|
||||
self.bus.write_calibration(self.calibration)
|
||||
self._save_calibration()
|
||||
print(f"\nCalibration complete and saved to {self.calibration_fpath}")
|
||||
|
||||
def configure(self) -> None:
|
||||
self.bus_right.disable_torque()
|
||||
self.bus_right.configure_motors()
|
||||
for motor in self.bus_right.motors:
|
||||
self.bus_right.write("Operating_Mode", motor, OperatingMode.POSITION.value)
|
||||
|
||||
self.bus_left.disable_torque()
|
||||
self.bus_left.configure_motors()
|
||||
for motor in self.bus_left.motors:
|
||||
self.bus_left.write("Operating_Mode", motor, OperatingMode.POSITION.value)
|
||||
self.bus.disable_torque()
|
||||
self.bus.configure_motors()
|
||||
for motor in self.bus.motors:
|
||||
self.bus.write("Operating_Mode", motor, OperatingMode.POSITION.value)
|
||||
|
||||
def setup_motors(self) -> None:
|
||||
print("\nSetting up RIGHT arm motors...")
|
||||
for motor in reversed(self.bus_right.motors):
|
||||
input(f"Connect the controller board to the RIGHT '{motor}' motor only and press enter.")
|
||||
self.bus_right.setup_motor(motor)
|
||||
print(f"RIGHT '{motor}' motor id set to {self.bus_right.motors[motor].id}")
|
||||
|
||||
print("\nSetting up LEFT arm motors...")
|
||||
for motor in reversed(self.bus_left.motors):
|
||||
input(f"Connect the controller board to the LEFT '{motor}' motor only and press enter.")
|
||||
self.bus_left.setup_motor(motor)
|
||||
print(f"LEFT '{motor}' motor id set to {self.bus_left.motors[motor].id}")
|
||||
for motor in reversed(self.bus.motors):
|
||||
input(f"Connect the controller board to the '{motor}' motor only and press enter.")
|
||||
self.bus.setup_motor(motor)
|
||||
print(f"'{motor}' motor id set to {self.bus.motors[motor].id}")
|
||||
|
||||
@check_if_not_connected
|
||||
def get_action(self) -> RobotAction:
|
||||
"""Get current action from both arms (read positions from all motors)."""
|
||||
"""Get current action (read positions from all motors)."""
|
||||
start = time.perf_counter()
|
||||
|
||||
right_positions = self.bus_right.sync_read("Present_Position")
|
||||
left_positions = self.bus_left.sync_read("Present_Position")
|
||||
positions = self.bus.sync_read("Present_Position")
|
||||
|
||||
# Right first, then left — matches the robot (BiOpenArmFollower) ordering
|
||||
# and the dataset feature names recorded during data collection.
|
||||
# Joint 6↔7 remap: leader joint_6 → follower joint_7 and vice versa.
|
||||
# Per-side direction flip is applied based on the configured `side`.
|
||||
action: dict[str, Any] = {}
|
||||
for motor, val in right_positions.items():
|
||||
for motor, val in positions.items():
|
||||
target = JOINT_REMAP.get(motor, motor)
|
||||
if motor == "gripper":
|
||||
# Convert gripper from teleop 0-100 to openarms degrees: 0→0°, 100→-65°
|
||||
action[f"right_{target}.pos"] = val * GRIPPER_TELEOP_TO_DEGREES
|
||||
action[f"{target}.pos"] = val * GRIPPER_TELEOP_TO_DEGREES
|
||||
else:
|
||||
action[f"right_{target}.pos"] = -val if motor in RIGHT_MOTORS_TO_FLIP else val
|
||||
for motor, val in left_positions.items():
|
||||
target = JOINT_REMAP.get(motor, motor)
|
||||
if motor == "gripper":
|
||||
action[f"left_{target}.pos"] = val * GRIPPER_TELEOP_TO_DEGREES
|
||||
else:
|
||||
action[f"left_{target}.pos"] = -val if motor in LEFT_MOTORS_TO_FLIP else val
|
||||
action[f"{target}.pos"] = -val if motor in self._motors_to_flip else val
|
||||
|
||||
dt_ms = (time.perf_counter() - start) * 1e3
|
||||
logger.debug(f"{self} read action: {dt_ms:.1f}ms")
|
||||
return action
|
||||
|
||||
def enable_torque(self) -> None:
|
||||
"""Enable torque on both arms for position control."""
|
||||
self.bus_right.enable_torque()
|
||||
self.bus_left.enable_torque()
|
||||
self.bus.enable_torque()
|
||||
|
||||
def disable_torque(self) -> None:
|
||||
"""Disable torque on both arms for free movement."""
|
||||
self.bus_right.disable_torque()
|
||||
self.bus_left.disable_torque()
|
||||
self.bus.disable_torque()
|
||||
|
||||
def write_goal_positions(self, positions: dict[str, float]) -> None:
|
||||
"""Write goal positions to motors (inverse of get_action flip/gripper/remap logic)."""
|
||||
right_goals: dict[str, float] = {}
|
||||
left_goals: dict[str, float] = {}
|
||||
|
||||
goals: dict[str, float] = {}
|
||||
for key, val in positions.items():
|
||||
if not key.endswith(".pos"):
|
||||
continue
|
||||
motor_name = key.removesuffix(".pos")
|
||||
if motor_name.startswith("right_"):
|
||||
base = motor_name.removeprefix("right_")
|
||||
# Reverse remap: follower joint_7 → leader joint_6 and vice versa
|
||||
target = JOINT_REMAP_REVERSE.get(base, base)
|
||||
if base == "gripper":
|
||||
# Convert robot degrees to teleop 0-100: 0°→0, -65°→100
|
||||
right_goals[target] = val / GRIPPER_TELEOP_TO_DEGREES
|
||||
else:
|
||||
# Un-flip using the ORIGINAL motor name (target = leader motor)
|
||||
right_goals[target] = -val if target in RIGHT_MOTORS_TO_FLIP else val
|
||||
elif motor_name.startswith("left_"):
|
||||
base = motor_name.removeprefix("left_")
|
||||
target = JOINT_REMAP_REVERSE.get(base, base)
|
||||
if base == "gripper":
|
||||
left_goals[target] = val / GRIPPER_TELEOP_TO_DEGREES
|
||||
else:
|
||||
left_goals[target] = -val if target in LEFT_MOTORS_TO_FLIP else val
|
||||
base = key.removesuffix(".pos")
|
||||
# JOINT_REMAP is symmetric (its own inverse).
|
||||
target = JOINT_REMAP.get(base, base)
|
||||
if base == "gripper":
|
||||
# Convert robot degrees to teleop 0-100: 0°→0, -65°→100
|
||||
goals[target] = val / GRIPPER_TELEOP_TO_DEGREES
|
||||
else:
|
||||
# Un-flip using the ORIGINAL motor name (target = leader motor)
|
||||
goals[target] = -val if target in self._motors_to_flip else val
|
||||
|
||||
if right_goals:
|
||||
self.bus_right.sync_write("Goal_Position", right_goals)
|
||||
if left_goals:
|
||||
self.bus_left.sync_write("Goal_Position", left_goals)
|
||||
if goals:
|
||||
self.bus.sync_write("Goal_Position", goals)
|
||||
|
||||
@check_if_not_connected
|
||||
def send_feedback(self, feedback: dict[str, float]) -> None:
|
||||
@@ -354,6 +265,5 @@ class OpenArmMini(Teleoperator):
|
||||
|
||||
@check_if_not_connected
|
||||
def disconnect(self) -> None:
|
||||
self.bus_right.disconnect()
|
||||
self.bus_left.disconnect()
|
||||
self.bus.disconnect()
|
||||
logger.info(f"{self} disconnected.")
|
||||
|
||||
@@ -99,14 +99,18 @@ def make_teleoperator_from_config(config: TeleoperatorConfig) -> "Teleoperator":
|
||||
from .openarm_mini import OpenArmMini
|
||||
|
||||
return OpenArmMini(config)
|
||||
elif config.type == "bi_openarm_mini":
|
||||
from .bi_openarm_mini import BiOpenArmMini
|
||||
|
||||
return BiOpenArmMini(config)
|
||||
elif config.type == "rebot_102_leader":
|
||||
from .rebot_102_leader import RebotArm102Leader
|
||||
|
||||
return RebotArm102Leader(config)
|
||||
elif config.type == "bi_rebot_102_leader":
|
||||
from .bi_rebot_102_leader import BiRebotArm102Leader
|
||||
from .bi_rebot_102_leader import BiRebot102Leader
|
||||
|
||||
return BiRebotArm102Leader(config)
|
||||
return BiRebot102Leader(config)
|
||||
else:
|
||||
try:
|
||||
return cast("Teleoperator", make_device_from_device_class(config))
|
||||
|
||||
@@ -0,0 +1,63 @@
|
||||
#!/usr/bin/env python
|
||||
|
||||
# Copyright 2026 The HuggingFace Inc. team. All rights reserved.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
from typing import Any
|
||||
|
||||
from lerobot.utils.decorators import check_if_already_connected, check_if_not_connected
|
||||
|
||||
|
||||
class BimanualMixin:
|
||||
"""Lifecycle delegation for bimanual robots and teleoperators.
|
||||
|
||||
Concrete subclasses must populate ``self.left_arm`` and ``self.right_arm`` in
|
||||
their own ``__init__``. They retain ownership of feature dicts and the
|
||||
data-routing methods (``get_action`` / ``send_action`` / ``get_observation`` /
|
||||
``send_feedback``), which vary per-embodiment.
|
||||
|
||||
Inherit before the ``Robot`` / ``Teleoperator`` base so the mixin's methods
|
||||
take precedence in the MRO::
|
||||
|
||||
class BiFooFollower(BimanualMixin, Robot): ...
|
||||
"""
|
||||
|
||||
left_arm: Any
|
||||
right_arm: Any
|
||||
|
||||
@property
|
||||
def is_connected(self) -> bool:
|
||||
return self.left_arm.is_connected and self.right_arm.is_connected
|
||||
|
||||
@property
|
||||
def is_calibrated(self) -> bool:
|
||||
return self.left_arm.is_calibrated and self.right_arm.is_calibrated
|
||||
|
||||
@check_if_already_connected
|
||||
def connect(self, calibrate: bool = True) -> None:
|
||||
self.left_arm.connect(calibrate)
|
||||
self.right_arm.connect(calibrate)
|
||||
|
||||
def calibrate(self) -> None:
|
||||
self.left_arm.calibrate()
|
||||
self.right_arm.calibrate()
|
||||
|
||||
def configure(self) -> None:
|
||||
self.left_arm.configure()
|
||||
self.right_arm.configure()
|
||||
|
||||
@check_if_not_connected
|
||||
def disconnect(self) -> None:
|
||||
self.left_arm.disconnect()
|
||||
self.right_arm.disconnect()
|
||||
@@ -114,6 +114,30 @@ def test_shuffle():
|
||||
assert set(sampler) == {0, 1, 2, 3, 4, 5}
|
||||
|
||||
|
||||
def test_shuffle_with_generator_is_deterministic():
|
||||
# Two samplers shuffling with same-seed generators must yield identical permutations.
|
||||
# This is what keeps batch shards disjoint across ranks in distributed training, where
|
||||
# accelerate synchronizes the sampler's generator state instead of the global torch RNG.
|
||||
sampler_a = EpisodeAwareSampler([0], [6], shuffle=True, generator=torch.Generator().manual_seed(42))
|
||||
sampler_b = EpisodeAwareSampler([0], [6], shuffle=True, generator=torch.Generator().manual_seed(42))
|
||||
assert list(sampler_a) == list(sampler_b)
|
||||
|
||||
# Desyncing the global RNG must not affect the permutation.
|
||||
sampler_c = EpisodeAwareSampler([0], [6], shuffle=True, generator=torch.Generator().manual_seed(42))
|
||||
order_before = list(sampler_c)
|
||||
sampler_c.generator.manual_seed(42)
|
||||
torch.randperm(1000) # consume global RNG, as rank-asymmetric code (e.g. eval) would
|
||||
assert list(sampler_c) == order_before
|
||||
|
||||
|
||||
def test_generator_attribute_defaults_to_none():
|
||||
# accelerate detects synchronizable samplers via `hasattr(sampler, "generator")`,
|
||||
# so the attribute must exist even when no generator is passed.
|
||||
sampler = EpisodeAwareSampler([0], [6], shuffle=True)
|
||||
assert sampler.generator is None
|
||||
assert set(sampler) == {0, 1, 2, 3, 4, 5}
|
||||
|
||||
|
||||
def test_negative_drop_first_frames_raises():
|
||||
with pytest.raises(ValueError, match="drop_n_first_frames must be >= 0"):
|
||||
EpisodeAwareSampler([0], [10], drop_n_first_frames=-1)
|
||||
|
||||
@@ -1,518 +0,0 @@
|
||||
# Copyright 2025 The HuggingFace Inc. team. All rights reserved.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
"""Tests for RECAP's distributional value function."""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import pytest
|
||||
import torch
|
||||
|
||||
from lerobot.configs.rewards import RewardModelConfig
|
||||
from lerobot.configs.types import FeatureType, NormalizationMode, PolicyFeature
|
||||
from lerobot.rewards.distributional_value_function.configuration_distributional_value_function import (
|
||||
DistributionalVFConfig,
|
||||
)
|
||||
from lerobot.types import TransitionKey
|
||||
from lerobot.utils.constants import OBS_IMAGES
|
||||
from tests.utils import skip_if_package_missing
|
||||
|
||||
BATCH_SIZE = 4
|
||||
NUM_BINS = 201
|
||||
IMAGE_KEY = f"{OBS_IMAGES}.top"
|
||||
|
||||
|
||||
def _make_config(**overrides) -> DistributionalVFConfig:
|
||||
defaults = {
|
||||
"init_from_actor_path": "",
|
||||
"device": "cpu",
|
||||
"image_resolution": (224, 224),
|
||||
}
|
||||
defaults.update(overrides)
|
||||
config = DistributionalVFConfig(**defaults)
|
||||
config.input_features = {
|
||||
IMAGE_KEY: PolicyFeature(type=FeatureType.VISUAL, shape=(3, 224, 224)),
|
||||
}
|
||||
config.output_features = {}
|
||||
config.normalization_mapping = {
|
||||
"VISUAL": NormalizationMode.IDENTITY,
|
||||
}
|
||||
return config
|
||||
|
||||
|
||||
def _make_model():
|
||||
from lerobot.rewards.distributional_value_function.modeling_distributional_value_function import (
|
||||
DistributionalVFRewardModel,
|
||||
)
|
||||
|
||||
return DistributionalVFRewardModel(_make_config())
|
||||
|
||||
|
||||
def _make_batch(batch_size: int = BATCH_SIZE, device: str = "cpu") -> dict[str, torch.Tensor]:
|
||||
from lerobot.utils.constants import OBS_LANGUAGE_ATTENTION_MASK, OBS_LANGUAGE_TOKENS
|
||||
|
||||
return {
|
||||
IMAGE_KEY: torch.rand(batch_size, 3, 224, 224, device=device),
|
||||
OBS_LANGUAGE_TOKENS: torch.randint(0, 1000, (batch_size, 16), device=device),
|
||||
OBS_LANGUAGE_ATTENTION_MASK: torch.ones(batch_size, 16, dtype=torch.bool, device=device),
|
||||
"mc_return": torch.rand(batch_size, device=device) * -1.0,
|
||||
"is_terminal": torch.zeros(batch_size, dtype=torch.bool, device=device),
|
||||
}
|
||||
|
||||
|
||||
def test_config_registered_in_reward_model_registry():
|
||||
"""DistributionalVFConfig is discoverable via RewardModelConfig registry."""
|
||||
known = RewardModelConfig.get_known_choices()
|
||||
assert "distributional_value_function" in known
|
||||
|
||||
|
||||
def test_factory_returns_correct_class():
|
||||
"""get_reward_model_class returns DistributionalVFRewardModel."""
|
||||
from lerobot.rewards.factory import get_reward_model_class
|
||||
|
||||
cls = get_reward_model_class("distributional_value_function")
|
||||
from lerobot.rewards.distributional_value_function.modeling_distributional_value_function import (
|
||||
DistributionalVFRewardModel,
|
||||
)
|
||||
|
||||
assert cls is DistributionalVFRewardModel
|
||||
|
||||
|
||||
def test_make_reward_model_config_factory():
|
||||
"""make_reward_model_config creates DistributionalVFConfig with overrides."""
|
||||
from lerobot.rewards.factory import make_reward_model_config
|
||||
|
||||
config = make_reward_model_config("distributional_value_function", num_value_bins=101)
|
||||
assert isinstance(config, DistributionalVFConfig)
|
||||
assert config.num_value_bins == 101
|
||||
|
||||
|
||||
@skip_if_package_missing("transformers")
|
||||
def test_hl_gauss_sums_to_one():
|
||||
"""HL-Gauss target distribution sums to 1 for each sample."""
|
||||
model = _make_model()
|
||||
targets = torch.tensor([-0.5, -0.1, -0.9, -0.0])
|
||||
dist = model.hl_gauss_target(targets)
|
||||
|
||||
assert dist.shape == (4, NUM_BINS)
|
||||
torch.testing.assert_close(dist.sum(dim=-1), torch.ones(4), atol=1e-5, rtol=0)
|
||||
|
||||
|
||||
@skip_if_package_missing("transformers")
|
||||
def test_hl_gauss_non_negative():
|
||||
"""HL-Gauss target probabilities are all non-negative."""
|
||||
model = _make_model()
|
||||
targets = torch.linspace(-1.0, 0.0, 10)
|
||||
dist = model.hl_gauss_target(targets)
|
||||
|
||||
assert (dist >= 0).all()
|
||||
|
||||
|
||||
@skip_if_package_missing("transformers")
|
||||
def test_hl_gauss_expected_value_matches():
|
||||
"""E[V] under HL-Gauss distribution matches the target value."""
|
||||
model = _make_model()
|
||||
targets = torch.tensor([-0.5, -0.1, -0.9])
|
||||
dist = model.hl_gauss_target(targets)
|
||||
expected = (dist * model.bin_centers).sum(dim=-1)
|
||||
|
||||
torch.testing.assert_close(expected, targets, atol=1e-4, rtol=0)
|
||||
|
||||
|
||||
@skip_if_package_missing("transformers")
|
||||
def test_hl_gauss_handles_2d_input():
|
||||
"""HL-Gauss handles [batch_size, 1] shaped inputs correctly."""
|
||||
model = _make_model()
|
||||
targets = torch.tensor([-0.5, -0.3]).unsqueeze(-1)
|
||||
dist = model.hl_gauss_target(targets)
|
||||
|
||||
assert dist.shape == (2, NUM_BINS)
|
||||
torch.testing.assert_close(dist.sum(dim=-1), torch.ones(2), atol=1e-5, rtol=0)
|
||||
|
||||
|
||||
@skip_if_package_missing("transformers")
|
||||
def test_dirac_delta_sums_to_one():
|
||||
"""Dirac delta target distribution sums to 1 for each sample."""
|
||||
model = _make_model()
|
||||
targets = torch.tensor([-0.5, -0.1, -0.9, -1.0, 0.0])
|
||||
dist = model.dirac_delta_target(targets)
|
||||
|
||||
assert dist.shape == (5, NUM_BINS)
|
||||
torch.testing.assert_close(dist.sum(dim=-1), torch.ones(5), atol=1e-6, rtol=0)
|
||||
|
||||
|
||||
@skip_if_package_missing("transformers")
|
||||
def test_dirac_delta_at_most_two_nonzero():
|
||||
"""Dirac delta places probability on at most two adjacent bins."""
|
||||
model = _make_model()
|
||||
targets = torch.tensor([-0.7523, -0.0013])
|
||||
dist = model.dirac_delta_target(targets)
|
||||
|
||||
for i in range(2):
|
||||
assert (dist[i] > 0).sum() <= 2
|
||||
|
||||
|
||||
@skip_if_package_missing("transformers")
|
||||
def test_dirac_delta_expected_value_matches():
|
||||
"""E[V] under Dirac delta distribution matches the target value."""
|
||||
model = _make_model()
|
||||
targets = torch.tensor([-0.5, -0.1, -0.9])
|
||||
dist = model.dirac_delta_target(targets)
|
||||
expected = (dist * model.bin_centers).sum(dim=-1)
|
||||
|
||||
torch.testing.assert_close(expected, targets, atol=1e-5, rtol=0)
|
||||
|
||||
|
||||
@skip_if_package_missing("transformers")
|
||||
def test_dirac_delta_boundary_values_clamped():
|
||||
"""Values outside support are clamped to boundary bins."""
|
||||
model = _make_model()
|
||||
targets = torch.tensor([-1.5, 0.5])
|
||||
dist = model.dirac_delta_target(targets)
|
||||
|
||||
torch.testing.assert_close(dist.sum(dim=-1), torch.ones(2), atol=1e-6, rtol=0)
|
||||
assert dist[0, 0] == 1.0
|
||||
assert dist[1, -1] == 1.0
|
||||
|
||||
|
||||
@skip_if_package_missing("transformers")
|
||||
def test_one_hot_single_nonzero():
|
||||
"""One-hot target has exactly one non-zero bin per sample."""
|
||||
model = _make_model()
|
||||
targets = torch.tensor([-0.5, -0.1, -1.0, 0.0])
|
||||
dist = model.one_hot_target(targets)
|
||||
|
||||
assert dist.shape == (4, NUM_BINS)
|
||||
for i in range(4):
|
||||
assert (dist[i] > 0).sum() == 1
|
||||
assert dist[i].sum() == 1.0
|
||||
|
||||
|
||||
@skip_if_package_missing("transformers")
|
||||
def test_one_hot_nearest_bin():
|
||||
"""One-hot target activates the bin closest to the target value."""
|
||||
model = _make_model()
|
||||
targets = torch.tensor([-0.5])
|
||||
dist = model.one_hot_target(targets)
|
||||
|
||||
hot_idx = dist[0].argmax()
|
||||
assert model.bin_centers[hot_idx].item() == pytest.approx(-0.5, abs=0.003)
|
||||
|
||||
|
||||
@skip_if_package_missing("transformers")
|
||||
def test_terminal_gets_one_hot():
|
||||
"""Terminal states receive one-hot targets; non-terminal get HL-Gauss."""
|
||||
model = _make_model()
|
||||
targets = torch.tensor([-0.5, -0.3, -0.7, -0.9])
|
||||
is_terminal = torch.tensor([False, True, False, True])
|
||||
|
||||
dist = model.compute_target_distribution(
|
||||
targets, is_terminal, method="hl_gauss", use_one_hot_terminal=True
|
||||
)
|
||||
|
||||
for i in range(4):
|
||||
assert dist[i].sum().item() == pytest.approx(1.0, abs=1e-5)
|
||||
assert (dist[1] > 0).sum() == 1
|
||||
assert (dist[3] > 0).sum() == 1
|
||||
assert (dist[0] > 0).sum() > 2
|
||||
assert (dist[2] > 0).sum() > 2
|
||||
|
||||
|
||||
@skip_if_package_missing("transformers")
|
||||
def test_no_terminal_override_when_disabled():
|
||||
"""When use_one_hot_terminal=False, terminal states use the base method."""
|
||||
model = _make_model()
|
||||
targets = torch.tensor([-0.5, -0.3])
|
||||
is_terminal = torch.tensor([False, True])
|
||||
|
||||
dist = model.compute_target_distribution(
|
||||
targets, is_terminal, method="hl_gauss", use_one_hot_terminal=False
|
||||
)
|
||||
|
||||
assert (dist[1] > 0).sum() > 2
|
||||
|
||||
|
||||
@skip_if_package_missing("transformers")
|
||||
def test_model_has_expected_components():
|
||||
"""Model scaffold contains all architectural components."""
|
||||
model = _make_model()
|
||||
|
||||
assert hasattr(model, "vision_tower")
|
||||
assert hasattr(model, "multi_modal_projector")
|
||||
assert hasattr(model, "token_embedding")
|
||||
assert hasattr(model, "layers")
|
||||
assert hasattr(model, "value_head")
|
||||
assert hasattr(model, "cls_embedding")
|
||||
assert hasattr(model, "norm")
|
||||
assert hasattr(model, "rotary_emb")
|
||||
assert hasattr(model, "bin_centers")
|
||||
|
||||
|
||||
@skip_if_package_missing("transformers")
|
||||
def test_model_bin_centers_shape():
|
||||
"""Bin centers buffer has shape (num_value_bins,)."""
|
||||
model = _make_model()
|
||||
assert model.bin_centers.shape == (NUM_BINS,)
|
||||
|
||||
|
||||
@skip_if_package_missing("transformers")
|
||||
def test_model_layer_count():
|
||||
"""Transformer has num_hidden_layers (6) layers."""
|
||||
model = _make_model()
|
||||
assert len(model.layers) == 6
|
||||
|
||||
|
||||
@skip_if_package_missing("transformers")
|
||||
def test_model_value_head_output_dim():
|
||||
"""Value head outputs num_value_bins logits."""
|
||||
model = _make_model()
|
||||
assert model.value_head.out_features == NUM_BINS
|
||||
|
||||
|
||||
@skip_if_package_missing("transformers")
|
||||
def test_forward_returns_loss_and_dict():
|
||||
"""Forward pass returns a finite scalar loss and output dict with expected keys."""
|
||||
model = _make_model()
|
||||
batch = _make_batch()
|
||||
|
||||
loss, output_dict = model.forward(batch)
|
||||
|
||||
assert loss.shape == ()
|
||||
assert torch.isfinite(loss)
|
||||
assert "loss" in output_dict
|
||||
assert "predicted_value_mean" in output_dict
|
||||
assert "mc_return_mean" in output_dict
|
||||
|
||||
|
||||
@skip_if_package_missing("transformers")
|
||||
def test_forward_loss_is_positive():
|
||||
"""Cross-entropy loss is strictly positive for random weights."""
|
||||
model = _make_model()
|
||||
batch = _make_batch()
|
||||
|
||||
loss, _ = model.forward(batch)
|
||||
|
||||
assert loss.item() > 0
|
||||
|
||||
|
||||
@skip_if_package_missing("transformers")
|
||||
def test_compute_reward_returns_correct_shape():
|
||||
"""compute_reward returns [batch_size] tensor of finite float32 values."""
|
||||
model = _make_model()
|
||||
model.eval()
|
||||
batch = _make_batch(batch_size=3)
|
||||
|
||||
with torch.no_grad():
|
||||
values = model.compute_reward(batch)
|
||||
|
||||
assert values.shape == (3,)
|
||||
assert values.dtype == torch.float32
|
||||
assert torch.isfinite(values).all()
|
||||
|
||||
|
||||
@skip_if_package_missing("transformers")
|
||||
def test_compute_reward_values_in_support_range():
|
||||
"""Predicted values lie within [value_support_min, value_support_max]."""
|
||||
model = _make_model()
|
||||
model.eval()
|
||||
batch = _make_batch(batch_size=8)
|
||||
|
||||
with torch.no_grad():
|
||||
values = model.compute_reward(batch)
|
||||
|
||||
assert (values >= -1.0 - 0.01).all()
|
||||
assert (values <= 0.0 + 0.01).all()
|
||||
|
||||
|
||||
@skip_if_package_missing("transformers")
|
||||
def test_processor_pipeline_produces_expected_keys():
|
||||
"""Full preprocessor pipeline produces tokenized text and processed images."""
|
||||
from lerobot.rewards.distributional_value_function.processor_distributional_value_function import (
|
||||
make_distributional_vf_pre_post_processors,
|
||||
)
|
||||
from lerobot.utils.constants import OBS_LANGUAGE_ATTENTION_MASK, OBS_LANGUAGE_TOKENS
|
||||
|
||||
config = _make_config()
|
||||
preprocessor, _ = make_distributional_vf_pre_post_processors(config)
|
||||
|
||||
raw_batch = {
|
||||
IMAGE_KEY: torch.rand(3, 224, 224),
|
||||
"task": "pick up the cup",
|
||||
}
|
||||
|
||||
processed = preprocessor(raw_batch)
|
||||
|
||||
assert OBS_LANGUAGE_TOKENS in processed
|
||||
assert OBS_LANGUAGE_ATTENTION_MASK in processed
|
||||
assert IMAGE_KEY in processed
|
||||
|
||||
|
||||
@skip_if_package_missing("transformers")
|
||||
def test_gradient_flows_through_value_head():
|
||||
"""Backprop produces non-zero gradients on the value head."""
|
||||
model = _make_model()
|
||||
model.train()
|
||||
batch = _make_batch()
|
||||
|
||||
loss, _ = model.forward(batch)
|
||||
loss.backward()
|
||||
|
||||
assert model.value_head.weight.grad is not None
|
||||
assert not torch.all(model.value_head.weight.grad == 0)
|
||||
|
||||
|
||||
@skip_if_package_missing("transformers")
|
||||
def test_gradient_flows_through_cls_embedding():
|
||||
"""Backprop produces non-zero gradients on the learned [CLS] embedding."""
|
||||
model = _make_model()
|
||||
model.train()
|
||||
batch = _make_batch()
|
||||
|
||||
loss, _ = model.forward(batch)
|
||||
loss.backward()
|
||||
|
||||
assert model.cls_embedding.grad is not None
|
||||
assert not torch.all(model.cls_embedding.grad == 0)
|
||||
|
||||
|
||||
def test_config_requires_visual_feature():
|
||||
"""validate_features raises if no VISUAL feature is present."""
|
||||
config = DistributionalVFConfig(init_from_actor_path="")
|
||||
config.input_features = {
|
||||
"observation.state": PolicyFeature(type=FeatureType.STATE, shape=(14,)),
|
||||
}
|
||||
|
||||
with pytest.raises(ValueError, match="VISUAL"):
|
||||
config.validate_features()
|
||||
|
||||
|
||||
def test_config_passes_with_visual_feature():
|
||||
"""validate_features succeeds when a VISUAL feature is present."""
|
||||
config = _make_config()
|
||||
config.validate_features()
|
||||
|
||||
|
||||
@skip_if_package_missing("transformers")
|
||||
def test_save_load_pretrained_roundtrip(tmp_path):
|
||||
"""Saved model can be loaded back with identical weights."""
|
||||
from lerobot.rewards.distributional_value_function.modeling_distributional_value_function import (
|
||||
DistributionalVFRewardModel,
|
||||
)
|
||||
|
||||
model = _make_model()
|
||||
model._save_pretrained(tmp_path)
|
||||
|
||||
loaded = DistributionalVFRewardModel.from_pretrained(str(tmp_path))
|
||||
|
||||
orig_sd = model.state_dict()
|
||||
loaded_sd = loaded.state_dict()
|
||||
|
||||
assert set(orig_sd.keys()) == set(loaded_sd.keys())
|
||||
for key in orig_sd:
|
||||
torch.testing.assert_close(orig_sd[key], loaded_sd[key], msg=f"Mismatch in {key}")
|
||||
|
||||
|
||||
@skip_if_package_missing("transformers")
|
||||
def test_image_preprocessor_normalizes_to_minus_one_one():
|
||||
"""Image preprocessor scales [0, 1] float input to [-1, 1] for SigLIP."""
|
||||
from lerobot.rewards.distributional_value_function.processor_distributional_value_function import (
|
||||
DistributionalVFImagePreprocessorStep,
|
||||
)
|
||||
|
||||
step = DistributionalVFImagePreprocessorStep(image_resolution=(224, 224), image_keys=(IMAGE_KEY,))
|
||||
|
||||
transition = {
|
||||
TransitionKey.OBSERVATION: {
|
||||
IMAGE_KEY: torch.rand(1, 224, 224, 3),
|
||||
},
|
||||
}
|
||||
|
||||
result = step(transition)
|
||||
image = result[TransitionKey.OBSERVATION][IMAGE_KEY]
|
||||
|
||||
assert image.min() >= -1.0 - 1e-5
|
||||
assert image.max() <= 1.0 + 1e-5
|
||||
|
||||
|
||||
@skip_if_package_missing("transformers")
|
||||
def test_image_preprocessor_resizes_with_pad():
|
||||
"""Image preprocessor resizes non-square images to target resolution."""
|
||||
from lerobot.rewards.distributional_value_function.processor_distributional_value_function import (
|
||||
DistributionalVFImagePreprocessorStep,
|
||||
)
|
||||
|
||||
step = DistributionalVFImagePreprocessorStep(image_resolution=(224, 224), image_keys=(IMAGE_KEY,))
|
||||
|
||||
transition = {
|
||||
TransitionKey.OBSERVATION: {
|
||||
IMAGE_KEY: torch.rand(1, 480, 640, 3),
|
||||
},
|
||||
}
|
||||
|
||||
result = step(transition)
|
||||
image = result[TransitionKey.OBSERVATION][IMAGE_KEY]
|
||||
|
||||
assert image.shape[1:3] == (224, 224)
|
||||
|
||||
|
||||
def test_task_prompt_formats_correctly():
|
||||
"""Task prompt step converts underscored task to 'Task: {text}.' format."""
|
||||
from lerobot.rewards.distributional_value_function.processor_distributional_value_function import (
|
||||
DistributionalVFPrepareTaskPromptStep,
|
||||
)
|
||||
|
||||
step = DistributionalVFPrepareTaskPromptStep()
|
||||
|
||||
transition = {
|
||||
TransitionKey.COMPLEMENTARY_DATA: {"task": ["pick_up_the_cup"]},
|
||||
}
|
||||
|
||||
result = step(transition)
|
||||
prompt = result[TransitionKey.COMPLEMENTARY_DATA]["task"][0]
|
||||
|
||||
assert prompt == "Task: pick up the cup."
|
||||
|
||||
|
||||
def test_task_prompt_handles_string_input():
|
||||
"""Task prompt step accepts a plain string (not just a list)."""
|
||||
from lerobot.rewards.distributional_value_function.processor_distributional_value_function import (
|
||||
DistributionalVFPrepareTaskPromptStep,
|
||||
)
|
||||
|
||||
step = DistributionalVFPrepareTaskPromptStep()
|
||||
|
||||
transition = {
|
||||
TransitionKey.COMPLEMENTARY_DATA: {"task": "open_drawer"},
|
||||
}
|
||||
|
||||
result = step(transition)
|
||||
prompt = result[TransitionKey.COMPLEMENTARY_DATA]["task"][0]
|
||||
|
||||
assert prompt == "Task: open drawer."
|
||||
|
||||
|
||||
def test_task_prompt_raises_on_missing_task():
|
||||
"""Task prompt step raises ValueError when task key is absent."""
|
||||
from lerobot.rewards.distributional_value_function.processor_distributional_value_function import (
|
||||
DistributionalVFPrepareTaskPromptStep,
|
||||
)
|
||||
|
||||
step = DistributionalVFPrepareTaskPromptStep()
|
||||
|
||||
transition = {
|
||||
TransitionKey.COMPLEMENTARY_DATA: {},
|
||||
}
|
||||
|
||||
with pytest.raises(ValueError, match="No task found"):
|
||||
step(transition)
|
||||
@@ -18,7 +18,7 @@ from unittest.mock import MagicMock, patch
|
||||
|
||||
import pytest
|
||||
|
||||
from lerobot.teleoperators.bi_rebot_102_leader import BiRebotArm102Leader, BiRebotArm102LeaderConfig
|
||||
from lerobot.teleoperators.bi_rebot_102_leader import BiRebot102Leader, BiRebot102LeaderConfig
|
||||
from lerobot.teleoperators.rebot_102_leader import (
|
||||
RebotArm102Leader,
|
||||
RebotArm102LeaderConfig,
|
||||
@@ -91,11 +91,11 @@ def test_send_feedback_not_implemented(leader):
|
||||
|
||||
def test_bimanual_prefixes_features():
|
||||
with patch(f"{_MODULE}.require_package", lambda *a, **kw: None):
|
||||
cfg = BiRebotArm102LeaderConfig(
|
||||
cfg = BiRebot102LeaderConfig(
|
||||
left_arm_config=RebotArm102LeaderConfig(port="/dev/null0"),
|
||||
right_arm_config=RebotArm102LeaderConfig(port="/dev/null1"),
|
||||
)
|
||||
teleop = BiRebotArm102Leader(cfg)
|
||||
teleop = BiRebot102Leader(cfg)
|
||||
assert any(k.startswith("left_") for k in teleop.action_features)
|
||||
assert any(k.startswith("right_") for k in teleop.action_features)
|
||||
assert "left_gripper.pos" in teleop.action_features
|
||||
|
||||
Reference in New Issue
Block a user