mirror of
https://github.com/huggingface/lerobot.git
synced 2026-07-03 16:17:15 +00:00
Merge branch 'main' into feature/add-multitask-dit
This commit is contained in:
@@ -935,17 +935,30 @@ class LeRobotDataset(torch.utils.data.Dataset):
|
||||
else:
|
||||
return get_hf_features_from_features(self.features)
|
||||
|
||||
def _get_query_indices(self, idx: int, ep_idx: int) -> tuple[dict[str, list[int | bool]]]:
|
||||
def _get_query_indices(
|
||||
self, abs_idx: int, ep_idx: int
|
||||
) -> tuple[dict[str, list[int]], dict[str, torch.Tensor]]:
|
||||
"""Compute query indices for delta timestamps.
|
||||
|
||||
Args:
|
||||
abs_idx: The absolute index in the full dataset (not the relative index in filtered episodes).
|
||||
ep_idx: The episode index.
|
||||
|
||||
Returns:
|
||||
A tuple of (query_indices, padding) where:
|
||||
- query_indices: Dict mapping keys to lists of absolute indices to query
|
||||
- padding: Dict mapping "{key}_is_pad" to boolean tensors indicating padded positions
|
||||
"""
|
||||
ep = self.meta.episodes[ep_idx]
|
||||
ep_start = ep["dataset_from_index"]
|
||||
ep_end = ep["dataset_to_index"]
|
||||
query_indices = {
|
||||
key: [max(ep_start, min(ep_end - 1, idx + delta)) for delta in delta_idx]
|
||||
key: [max(ep_start, min(ep_end - 1, abs_idx + delta)) for delta in delta_idx]
|
||||
for key, delta_idx in self.delta_indices.items()
|
||||
}
|
||||
padding = { # Pad values outside of current episode range
|
||||
f"{key}_is_pad": torch.BoolTensor(
|
||||
[(idx + delta < ep_start) | (idx + delta >= ep_end) for delta in delta_idx]
|
||||
[(abs_idx + delta < ep_start) | (abs_idx + delta >= ep_end) for delta in delta_idx]
|
||||
)
|
||||
for key, delta_idx in self.delta_indices.items()
|
||||
}
|
||||
@@ -1037,10 +1050,12 @@ class LeRobotDataset(torch.utils.data.Dataset):
|
||||
self._ensure_hf_dataset_loaded()
|
||||
item = self.hf_dataset[idx]
|
||||
ep_idx = item["episode_index"].item()
|
||||
# Use the absolute index from the dataset for delta timestamp calculations
|
||||
abs_idx = item["index"].item()
|
||||
|
||||
query_indices = None
|
||||
if self.delta_indices is not None:
|
||||
query_indices, padding = self._get_query_indices(idx, ep_idx)
|
||||
query_indices, padding = self._get_query_indices(abs_idx, ep_idx)
|
||||
query_result = self._query_hf_dataset(query_indices)
|
||||
item = {**item, **padding}
|
||||
for key, val in query_result.items():
|
||||
@@ -1498,7 +1513,7 @@ class LeRobotDataset(torch.utils.data.Dataset):
|
||||
episode_index = self.episode_buffer["episode_index"]
|
||||
if isinstance(episode_index, np.ndarray):
|
||||
episode_index = episode_index.item() if episode_index.size == 1 else episode_index[0]
|
||||
for cam_key in self.meta.camera_keys:
|
||||
for cam_key in self.meta.image_keys:
|
||||
img_dir = self._get_image_file_dir(episode_index, cam_key)
|
||||
if img_dir.is_dir():
|
||||
shutil.rmtree(img_dir)
|
||||
|
||||
@@ -32,7 +32,7 @@ import serial
|
||||
from deepdiff import DeepDiff
|
||||
from tqdm import tqdm
|
||||
|
||||
from lerobot.utils.errors import DeviceAlreadyConnectedError, DeviceNotConnectedError
|
||||
from lerobot.utils.decorators import check_if_already_connected, check_if_not_connected
|
||||
from lerobot.utils.utils import enter_pressed, move_cursor_up
|
||||
|
||||
NameOrID: TypeAlias = str | int
|
||||
@@ -411,6 +411,7 @@ class MotorsBus(abc.ABC):
|
||||
"""bool: `True` if the underlying serial port is open."""
|
||||
return self.port_handler.is_open
|
||||
|
||||
@check_if_already_connected
|
||||
def connect(self, handshake: bool = True) -> None:
|
||||
"""Open the serial port and initialise communication.
|
||||
|
||||
@@ -422,10 +423,6 @@ class MotorsBus(abc.ABC):
|
||||
DeviceAlreadyConnectedError: The port is already open.
|
||||
ConnectionError: The underlying SDK failed to open the port or the handshake did not succeed.
|
||||
"""
|
||||
if self.is_connected:
|
||||
raise DeviceAlreadyConnectedError(
|
||||
f"{self.__class__.__name__}('{self.port}') is already connected. Do not call `{self.__class__.__name__}.connect()` twice."
|
||||
)
|
||||
|
||||
self._connect(handshake)
|
||||
self.set_timeout()
|
||||
@@ -447,6 +444,7 @@ class MotorsBus(abc.ABC):
|
||||
def _handshake(self) -> None:
|
||||
pass
|
||||
|
||||
@check_if_not_connected
|
||||
def disconnect(self, disable_torque: bool = True) -> None:
|
||||
"""Close the serial port (optionally disabling torque first).
|
||||
|
||||
@@ -455,10 +453,6 @@ class MotorsBus(abc.ABC):
|
||||
closing the port. This can prevent damaging motors if they are left applying resisting torque
|
||||
after disconnect.
|
||||
"""
|
||||
if not self.is_connected:
|
||||
raise DeviceNotConnectedError(
|
||||
f"{self.__class__.__name__}('{self.port}') is not connected. Try running `{self.__class__.__name__}.connect()` first."
|
||||
)
|
||||
|
||||
if disable_torque:
|
||||
self.port_handler.clearPort()
|
||||
@@ -907,6 +901,7 @@ class MotorsBus(abc.ABC):
|
||||
"""
|
||||
pass
|
||||
|
||||
@check_if_not_connected
|
||||
def read(
|
||||
self,
|
||||
data_name: str,
|
||||
@@ -927,10 +922,6 @@ class MotorsBus(abc.ABC):
|
||||
Returns:
|
||||
Value: Raw or normalised value depending on *normalize*.
|
||||
"""
|
||||
if not self.is_connected:
|
||||
raise DeviceNotConnectedError(
|
||||
f"{self.__class__.__name__}('{self.port}') is not connected. You need to run `{self.__class__.__name__}.connect()`."
|
||||
)
|
||||
|
||||
id_ = self.motors[motor].id
|
||||
model = self.motors[motor].model
|
||||
@@ -981,6 +972,7 @@ class MotorsBus(abc.ABC):
|
||||
|
||||
return value, comm, error
|
||||
|
||||
@check_if_not_connected
|
||||
def write(
|
||||
self, data_name: str, motor: str, value: Value, *, normalize: bool = True, num_retry: int = 0
|
||||
) -> None:
|
||||
@@ -999,10 +991,6 @@ class MotorsBus(abc.ABC):
|
||||
normalize (bool, optional): Enable or disable normalisation. Defaults to `True`.
|
||||
num_retry (int, optional): Retry attempts. Defaults to `0`.
|
||||
"""
|
||||
if not self.is_connected:
|
||||
raise DeviceNotConnectedError(
|
||||
f"{self.__class__.__name__}('{self.port}') is not connected. You need to run `{self.__class__.__name__}.connect()`."
|
||||
)
|
||||
|
||||
id_ = self.motors[motor].id
|
||||
model = self.motors[motor].model
|
||||
@@ -1044,6 +1032,7 @@ class MotorsBus(abc.ABC):
|
||||
|
||||
return comm, error
|
||||
|
||||
@check_if_not_connected
|
||||
def sync_read(
|
||||
self,
|
||||
data_name: str,
|
||||
@@ -1063,10 +1052,6 @@ class MotorsBus(abc.ABC):
|
||||
Returns:
|
||||
dict[str, Value]: Mapping *motor name → value*.
|
||||
"""
|
||||
if not self.is_connected:
|
||||
raise DeviceNotConnectedError(
|
||||
f"{self.__class__.__name__}('{self.port}') is not connected. You need to run `{self.__class__.__name__}.connect()`."
|
||||
)
|
||||
|
||||
self._assert_protocol_is_compatible("sync_read")
|
||||
|
||||
@@ -1139,6 +1124,7 @@ class MotorsBus(abc.ABC):
|
||||
# for id_ in motor_ids:
|
||||
# value = self.sync_reader.getData(id_, address, length)
|
||||
|
||||
@check_if_not_connected
|
||||
def sync_write(
|
||||
self,
|
||||
data_name: str,
|
||||
@@ -1160,10 +1146,6 @@ class MotorsBus(abc.ABC):
|
||||
normalize (bool, optional): If `True` (default) convert values from the user range to raw units.
|
||||
num_retry (int, optional): Retry attempts. Defaults to `0`.
|
||||
"""
|
||||
if not self.is_connected:
|
||||
raise DeviceNotConnectedError(
|
||||
f"{self.__class__.__name__}('{self.port}') is not connected. You need to run `{self.__class__.__name__}.connect()`."
|
||||
)
|
||||
|
||||
ids_values = self._get_ids_values_dict(values)
|
||||
models = [self._id_to_model(id_) for id_ in ids_values]
|
||||
|
||||
@@ -1297,3 +1297,14 @@ class PI0Policy(PreTrainedPolicy):
|
||||
loss = losses.mean()
|
||||
loss_dict["loss"] = loss.item()
|
||||
return loss, loss_dict
|
||||
|
||||
def _get_default_peft_targets(self) -> dict[str, any]:
|
||||
"""Return default PEFT target modules for PI0 fine-tuning."""
|
||||
common_projections = (
|
||||
"state_proj|action_in_proj|action_out_proj|action_time_mlp_in|action_time_mlp_out"
|
||||
)
|
||||
target_modules = rf"(.*\.gemma_expert\..*\.self_attn\.(q|v)_proj|model\.({common_projections}))"
|
||||
return {
|
||||
"target_modules": target_modules,
|
||||
"modules_to_save": [],
|
||||
}
|
||||
|
||||
@@ -1270,3 +1270,14 @@ class PI05Policy(PreTrainedPolicy):
|
||||
loss = losses.mean()
|
||||
loss_dict["loss"] = loss.item()
|
||||
return loss, loss_dict
|
||||
|
||||
def _get_default_peft_targets(self) -> dict[str, any]:
|
||||
"""Return default PEFT target modules for PI0.5 fine-tuning."""
|
||||
common_projections = (
|
||||
"state_proj|action_in_proj|action_out_proj|action_time_mlp_in|action_time_mlp_out"
|
||||
)
|
||||
target_modules = rf"(.*\.gemma_expert\..*\.self_attn\.(q|v)_proj|model\.({common_projections}))"
|
||||
return {
|
||||
"target_modules": target_modules,
|
||||
"modules_to_save": [],
|
||||
}
|
||||
|
||||
@@ -13,6 +13,7 @@
|
||||
# limitations under the License.
|
||||
import abc
|
||||
import builtins
|
||||
import dataclasses
|
||||
import logging
|
||||
import os
|
||||
from importlib.resources import files
|
||||
@@ -265,3 +266,166 @@ class PreTrainedPolicy(nn.Module, HubMixin, abc.ABC):
|
||||
card = ModelCard.from_template(card_data, template_str=template_card)
|
||||
card.validate()
|
||||
return card
|
||||
|
||||
def wrap_with_peft(
|
||||
self,
|
||||
peft_config=None,
|
||||
peft_cli_overrides: dict | None = None,
|
||||
) -> "PreTrainedPolicy":
|
||||
"""
|
||||
Wrap this policy with PEFT adapters for parameter-efficient fine-tuning.
|
||||
|
||||
This method is the single entry point for PEFT integration. Subclasses should
|
||||
override `_get_default_peft_targets()` to provide default target modules, and
|
||||
`_validate_peft_config()` for policy-specific validation.
|
||||
|
||||
Args:
|
||||
peft_config: Optional PEFT adapter configuration (e.g., LoraConfig).
|
||||
If provided, used directly (with CLI overrides applied).
|
||||
peft_cli_overrides: Optional dict of CLI overrides (method_type, target_modules, r, etc.)
|
||||
These are merged with policy defaults to build the final config.
|
||||
"""
|
||||
from peft import get_peft_model
|
||||
|
||||
# If user provided a complete config, use it directly (with overrides)
|
||||
if peft_config is not None:
|
||||
final_config = peft_config
|
||||
if peft_cli_overrides:
|
||||
final_config = self._apply_peft_cli_overrides(final_config, peft_cli_overrides)
|
||||
else:
|
||||
# Build config from defaults + CLI overrides
|
||||
final_config = self._build_peft_config(peft_cli_overrides or {})
|
||||
|
||||
# Validate the configuration
|
||||
self._validate_peft_config(final_config)
|
||||
|
||||
# Freeze base parameters, only adapter params will be trained
|
||||
for p in self.parameters():
|
||||
p.requires_grad_(False)
|
||||
|
||||
# Store pretrained path for PEFT's base_model_name_or_path
|
||||
if self.config.pretrained_path:
|
||||
self.name_or_path = str(self.config.pretrained_path)
|
||||
|
||||
# Wrap with PEFT
|
||||
peft_model = get_peft_model(self, final_config)
|
||||
|
||||
# Mark config as using PEFT for proper loading later
|
||||
peft_model.config.use_peft = True
|
||||
|
||||
logging.info(f"Wrapped {self.name} with PEFT ({type(final_config).__name__})")
|
||||
return peft_model
|
||||
|
||||
def _get_default_peft_targets(self) -> dict[str, any] | None:
|
||||
"""
|
||||
Return default PEFT target modules for this policy.
|
||||
|
||||
Override this in subclasses to provide policy-specific defaults. These defaults
|
||||
are PEFT-method agnostic - they only specify which modules to target.
|
||||
|
||||
"""
|
||||
return None
|
||||
|
||||
def _validate_peft_config(self, peft_config) -> None:
|
||||
"""
|
||||
Validate the PEFT configuration for this policy.
|
||||
|
||||
Override this in subclasses to add policy-specific validation or warnings.
|
||||
The default implementation checks that a pretrained_path exists.
|
||||
|
||||
Args:
|
||||
peft_config: The PEFT configuration to validate.
|
||||
|
||||
Raises:
|
||||
ValueError: If the configuration is invalid.
|
||||
"""
|
||||
if not self.config.pretrained_path:
|
||||
raise ValueError(
|
||||
"Training from scratch using PEFT is unlikely to yield good results. "
|
||||
"Supply a `policy.pretrained_path` to fine-tune an existing model."
|
||||
)
|
||||
|
||||
def _preprocess_peft_cli_overrides(self, cli_overrides: dict, peft_method_type) -> dict:
|
||||
"""
|
||||
Preprocess CLI overrides: rename keys and handle method-specific init_type.
|
||||
|
||||
Args:
|
||||
cli_overrides: Dict of CLI options (will be copied, not mutated).
|
||||
peft_method_type: The PeftType enum value for the PEFT method.
|
||||
|
||||
Returns:
|
||||
Preprocessed dict with renamed keys and init_type mapped to method-specific key.
|
||||
"""
|
||||
from peft import PeftType
|
||||
|
||||
cli_overrides = cli_overrides.copy()
|
||||
|
||||
# Handle the full_training_modules -> modules_to_save rename
|
||||
if "full_training_modules" in cli_overrides:
|
||||
cli_overrides["modules_to_save"] = cli_overrides.pop("full_training_modules")
|
||||
|
||||
# Remove method_type as it's handled separately
|
||||
cli_overrides.pop("method_type", None)
|
||||
|
||||
# Handle init_type specially based on PEFT method
|
||||
init_type = cli_overrides.pop("init_type", None)
|
||||
if init_type is not None:
|
||||
if peft_method_type == PeftType.LORA:
|
||||
cli_overrides["init_lora_weights"] = init_type
|
||||
elif peft_method_type == PeftType.MISS:
|
||||
cli_overrides["init_weights"] = init_type
|
||||
else:
|
||||
raise ValueError(f"Init type '{init_type}' unknown for PEFT method {peft_method_type}.")
|
||||
|
||||
return cli_overrides
|
||||
|
||||
def _build_peft_config(self, cli_overrides: dict):
|
||||
"""Build a PEFT config from policy defaults and CLI overrides."""
|
||||
from peft import PEFT_TYPE_TO_CONFIG_MAPPING, PeftType
|
||||
|
||||
# Determine PEFT method type (default to LORA)
|
||||
method_type_str = cli_overrides.get("method_type") or "lora"
|
||||
peft_method_type = PeftType[method_type_str.upper()]
|
||||
peft_config_cls = PEFT_TYPE_TO_CONFIG_MAPPING[peft_method_type]
|
||||
|
||||
# Preprocess CLI overrides
|
||||
cli_overrides = self._preprocess_peft_cli_overrides(cli_overrides, peft_method_type)
|
||||
|
||||
# Start with policy defaults, apply CLI overrides
|
||||
config_dict = dict(self._get_default_peft_targets() or {})
|
||||
for key, value in cli_overrides.items():
|
||||
if value is not None:
|
||||
config_dict[key] = value
|
||||
|
||||
# Ensure we have target_modules
|
||||
if not config_dict.get("target_modules"):
|
||||
raise ValueError(
|
||||
f"Policy '{self.name}' does not define default target_modules. "
|
||||
"Please pass --peft.target_modules explicitly."
|
||||
)
|
||||
|
||||
return peft_config_cls(**config_dict)
|
||||
|
||||
def _apply_peft_cli_overrides(self, peft_config, cli_overrides: dict):
|
||||
"""Apply CLI overrides to an existing PEFT config."""
|
||||
from peft import PEFT_TYPE_TO_CONFIG_MAPPING, PeftType
|
||||
|
||||
# Get method type from existing config or CLI override
|
||||
method_type_str = cli_overrides.get("method_type")
|
||||
if method_type_str:
|
||||
peft_method_type = PeftType[method_type_str.upper()]
|
||||
peft_config_cls = PEFT_TYPE_TO_CONFIG_MAPPING[peft_method_type]
|
||||
else:
|
||||
peft_method_type = PeftType(peft_config.peft_type)
|
||||
peft_config_cls = type(peft_config)
|
||||
|
||||
# Preprocess CLI overrides
|
||||
cli_overrides = self._preprocess_peft_cli_overrides(cli_overrides, peft_method_type)
|
||||
|
||||
# Start with existing config, apply CLI overrides
|
||||
config_dict = {k: v for k, v in dataclasses.asdict(peft_config).items() if not k.startswith("_")}
|
||||
for key, value in cli_overrides.items():
|
||||
if value is not None:
|
||||
config_dict[key] = value
|
||||
|
||||
return peft_config_cls(**config_dict)
|
||||
|
||||
@@ -480,6 +480,28 @@ class SmolVLAPolicy(PreTrainedPolicy):
|
||||
actions = pad_vector(batch[ACTION], self.config.max_action_dim)
|
||||
return actions
|
||||
|
||||
def _get_default_peft_targets(self) -> dict[str, any]:
|
||||
"""Return default PEFT target modules for SmolVLA fine-tuning."""
|
||||
common_projections = (
|
||||
"state_proj|action_in_proj|action_out_proj|action_time_mlp_in|action_time_mlp_out"
|
||||
)
|
||||
target_modules = rf"(model\.vlm_with_expert\.lm_expert\..*\.(q|v)_proj|model\.({common_projections}))"
|
||||
return {
|
||||
"target_modules": target_modules,
|
||||
"modules_to_save": [],
|
||||
}
|
||||
|
||||
def _validate_peft_config(self, peft_config) -> None:
|
||||
"""Validate PEFT configuration for SmolVLA."""
|
||||
super()._validate_peft_config(peft_config)
|
||||
if not self.config.load_vlm_weights:
|
||||
import logging
|
||||
|
||||
logging.warning(
|
||||
"Training SmolVLA from scratch using PEFT. This is unlikely to yield good results. "
|
||||
"Set `load_vlm_weights=True` to fine-tune the existing policy."
|
||||
)
|
||||
|
||||
|
||||
def pad_tensor(tensor, max_len, pad_value=0):
|
||||
"""
|
||||
|
||||
@@ -24,7 +24,8 @@ import numpy as np
|
||||
import requests
|
||||
|
||||
from lerobot.processor import RobotAction, RobotObservation
|
||||
from lerobot.utils.errors import DeviceAlreadyConnectedError, DeviceNotConnectedError
|
||||
from lerobot.utils.decorators import check_if_already_connected, check_if_not_connected
|
||||
from lerobot.utils.errors import DeviceNotConnectedError
|
||||
|
||||
from ..robot import Robot
|
||||
from .config_earthrover_mini_plus import EarthRoverMiniPlusConfig
|
||||
@@ -99,6 +100,7 @@ class EarthRoverMiniPlus(Robot):
|
||||
"""Check if robot is connected to SDK."""
|
||||
return self._is_connected
|
||||
|
||||
@check_if_already_connected
|
||||
def connect(self, calibrate: bool = True) -> None:
|
||||
"""Connect to robot via Frodobots SDK.
|
||||
|
||||
@@ -109,8 +111,6 @@ class EarthRoverMiniPlus(Robot):
|
||||
DeviceAlreadyConnectedError: If robot is already connected
|
||||
DeviceNotConnectedError: If cannot connect to SDK server
|
||||
"""
|
||||
if self._is_connected:
|
||||
raise DeviceAlreadyConnectedError(f"{self.name} is already connected")
|
||||
|
||||
# Verify SDK is running and accessible
|
||||
try:
|
||||
@@ -197,6 +197,7 @@ class EarthRoverMiniPlus(Robot):
|
||||
ACTION_ANGULAR_VEL: float,
|
||||
}
|
||||
|
||||
@check_if_not_connected
|
||||
def get_observation(self) -> RobotObservation:
|
||||
"""Get current robot observation from SDK.
|
||||
|
||||
@@ -223,8 +224,6 @@ class EarthRoverMiniPlus(Robot):
|
||||
Robot telemetry is retrieved from /data endpoint.
|
||||
All SDK values are normalized to appropriate ranges for dataset recording.
|
||||
"""
|
||||
if not self._is_connected:
|
||||
raise DeviceNotConnectedError(f"{self.name} is not connected")
|
||||
|
||||
observation = {}
|
||||
|
||||
@@ -255,6 +254,7 @@ class EarthRoverMiniPlus(Robot):
|
||||
|
||||
return observation
|
||||
|
||||
@check_if_not_connected
|
||||
def send_action(self, action: RobotAction) -> RobotAction:
|
||||
"""Send action to robot via SDK.
|
||||
|
||||
@@ -272,8 +272,6 @@ class EarthRoverMiniPlus(Robot):
|
||||
Actions are sent to SDK via POST /control endpoint.
|
||||
SDK expects commands in range [-1, 1].
|
||||
"""
|
||||
if not self._is_connected:
|
||||
raise DeviceNotConnectedError(f"{self.name} is not connected")
|
||||
|
||||
# Extract action values and convert to float
|
||||
linear = float(action.get(ACTION_LINEAR_VEL, 0.0))
|
||||
@@ -291,6 +289,7 @@ class EarthRoverMiniPlus(Robot):
|
||||
ACTION_ANGULAR_VEL: angular,
|
||||
}
|
||||
|
||||
@check_if_not_connected
|
||||
def disconnect(self) -> None:
|
||||
"""Disconnect from robot.
|
||||
|
||||
@@ -299,8 +298,6 @@ class EarthRoverMiniPlus(Robot):
|
||||
Raises:
|
||||
DeviceNotConnectedError: If robot is not connected
|
||||
"""
|
||||
if not self._is_connected:
|
||||
raise DeviceNotConnectedError(f"{self.name} is not connected")
|
||||
|
||||
# Stop the robot before disconnecting
|
||||
try:
|
||||
|
||||
@@ -25,7 +25,7 @@ from lerobot.motors.feetech import (
|
||||
FeetechMotorsBus,
|
||||
)
|
||||
from lerobot.processor import RobotAction, RobotObservation
|
||||
from lerobot.utils.errors import DeviceAlreadyConnectedError, DeviceNotConnectedError
|
||||
from lerobot.utils.decorators import check_if_already_connected, check_if_not_connected
|
||||
|
||||
from ..robot import Robot
|
||||
from ..utils import ensure_safe_goal_position
|
||||
@@ -82,13 +82,12 @@ class HopeJrArm(Robot):
|
||||
def is_connected(self) -> bool:
|
||||
return self.bus.is_connected and all(cam.is_connected for cam in self.cameras.values())
|
||||
|
||||
@check_if_already_connected
|
||||
def connect(self, calibrate: bool = True) -> None:
|
||||
"""
|
||||
We assume that at connection time, arm is in a rest position,
|
||||
and torque can be safely disabled to run calibration.
|
||||
"""
|
||||
if self.is_connected:
|
||||
raise DeviceAlreadyConnectedError(f"{self} already connected")
|
||||
|
||||
self.bus.connect(handshake=False)
|
||||
if not self.is_calibrated and calibrate:
|
||||
@@ -128,10 +127,8 @@ class HopeJrArm(Robot):
|
||||
self.bus.setup_motor(motor)
|
||||
print(f"'{motor}' motor id set to {self.bus.motors[motor].id}")
|
||||
|
||||
@check_if_not_connected
|
||||
def get_observation(self) -> RobotObservation:
|
||||
if not self.is_connected:
|
||||
raise DeviceNotConnectedError(f"{self} is not connected.")
|
||||
|
||||
# Read arm position
|
||||
start = time.perf_counter()
|
||||
obs_dict = self.bus.sync_read("Present_Position", self.other_motors)
|
||||
@@ -149,10 +146,8 @@ class HopeJrArm(Robot):
|
||||
|
||||
return obs_dict
|
||||
|
||||
@check_if_not_connected
|
||||
def send_action(self, action: RobotAction) -> RobotAction:
|
||||
if not self.is_connected:
|
||||
raise DeviceNotConnectedError(f"{self} is not connected.")
|
||||
|
||||
goal_pos = {key.removesuffix(".pos"): val for key, val in action.items() if key.endswith(".pos")}
|
||||
|
||||
# Cap goal position when too far away from present position.
|
||||
@@ -165,10 +160,8 @@ class HopeJrArm(Robot):
|
||||
self.bus.sync_write("Goal_Position", goal_pos)
|
||||
return {f"{motor}.pos": val for motor, val in goal_pos.items()}
|
||||
|
||||
@check_if_not_connected
|
||||
def disconnect(self):
|
||||
if not self.is_connected:
|
||||
raise DeviceNotConnectedError(f"{self} is not connected.")
|
||||
|
||||
self.bus.disconnect(self.config.disable_torque_on_disconnect)
|
||||
for cam in self.cameras.values():
|
||||
cam.disconnect()
|
||||
|
||||
@@ -25,7 +25,7 @@ from lerobot.motors.feetech import (
|
||||
FeetechMotorsBus,
|
||||
)
|
||||
from lerobot.processor import RobotAction, RobotObservation
|
||||
from lerobot.utils.errors import DeviceAlreadyConnectedError, DeviceNotConnectedError
|
||||
from lerobot.utils.decorators import check_if_already_connected, check_if_not_connected
|
||||
|
||||
from ..robot import Robot
|
||||
from .config_hope_jr import HopeJrHandConfig
|
||||
@@ -118,10 +118,8 @@ class HopeJrHand(Robot):
|
||||
def is_connected(self) -> bool:
|
||||
return self.bus.is_connected and all(cam.is_connected for cam in self.cameras.values())
|
||||
|
||||
@check_if_already_connected
|
||||
def connect(self, calibrate: bool = True) -> None:
|
||||
if self.is_connected:
|
||||
raise DeviceAlreadyConnectedError(f"{self} already connected")
|
||||
|
||||
self.bus.connect()
|
||||
if not self.is_calibrated and calibrate:
|
||||
self.calibrate()
|
||||
@@ -159,10 +157,8 @@ class HopeJrHand(Robot):
|
||||
self.bus.setup_motor(motor)
|
||||
print(f"'{motor}' motor id set to {self.bus.motors[motor].id}")
|
||||
|
||||
@check_if_not_connected
|
||||
def get_observation(self) -> RobotObservation:
|
||||
if not self.is_connected:
|
||||
raise DeviceNotConnectedError(f"{self} is not connected.")
|
||||
|
||||
obs_dict = {}
|
||||
|
||||
# Read hand position
|
||||
@@ -181,18 +177,14 @@ class HopeJrHand(Robot):
|
||||
|
||||
return obs_dict
|
||||
|
||||
@check_if_not_connected
|
||||
def send_action(self, action: RobotAction) -> RobotAction:
|
||||
if not self.is_connected:
|
||||
raise DeviceNotConnectedError(f"{self} is not connected.")
|
||||
|
||||
goal_pos = {key.removesuffix(".pos"): val for key, val in action.items() if key.endswith(".pos")}
|
||||
self.bus.sync_write("Goal_Position", goal_pos)
|
||||
return action
|
||||
|
||||
@check_if_not_connected
|
||||
def disconnect(self):
|
||||
if not self.is_connected:
|
||||
raise DeviceNotConnectedError(f"{self} is not connected.")
|
||||
|
||||
self.bus.disconnect(self.config.disable_torque_on_disconnect)
|
||||
for cam in self.cameras.values():
|
||||
cam.disconnect()
|
||||
|
||||
@@ -25,7 +25,7 @@ from lerobot.motors.dynamixel import (
|
||||
OperatingMode,
|
||||
)
|
||||
from lerobot.processor import RobotAction, RobotObservation
|
||||
from lerobot.utils.errors import DeviceAlreadyConnectedError, DeviceNotConnectedError
|
||||
from lerobot.utils.decorators import check_if_already_connected, check_if_not_connected
|
||||
|
||||
from ..robot import Robot
|
||||
from ..utils import ensure_safe_goal_position
|
||||
@@ -84,13 +84,12 @@ class KochFollower(Robot):
|
||||
def is_connected(self) -> bool:
|
||||
return self.bus.is_connected and all(cam.is_connected for cam in self.cameras.values())
|
||||
|
||||
@check_if_already_connected
|
||||
def connect(self, calibrate: bool = True) -> None:
|
||||
"""
|
||||
We assume that at connection time, arm is in a rest position,
|
||||
and torque can be safely disabled to run calibration.
|
||||
"""
|
||||
if self.is_connected:
|
||||
raise DeviceAlreadyConnectedError(f"{self} already connected")
|
||||
|
||||
self.bus.connect()
|
||||
if not self.is_calibrated and calibrate:
|
||||
@@ -182,10 +181,8 @@ class KochFollower(Robot):
|
||||
self.bus.setup_motor(motor)
|
||||
print(f"'{motor}' motor id set to {self.bus.motors[motor].id}")
|
||||
|
||||
@check_if_not_connected
|
||||
def get_observation(self) -> RobotObservation:
|
||||
if not self.is_connected:
|
||||
raise DeviceNotConnectedError(f"{self} is not connected.")
|
||||
|
||||
# Read arm position
|
||||
start = time.perf_counter()
|
||||
obs_dict = self.bus.sync_read("Present_Position")
|
||||
@@ -202,6 +199,7 @@ class KochFollower(Robot):
|
||||
|
||||
return obs_dict
|
||||
|
||||
@check_if_not_connected
|
||||
def send_action(self, action: RobotAction) -> RobotAction:
|
||||
"""Command arm to move to a target joint configuration.
|
||||
|
||||
@@ -215,8 +213,6 @@ class KochFollower(Robot):
|
||||
Returns:
|
||||
RobotAction: The action sent to the motors, potentially clipped.
|
||||
"""
|
||||
if not self.is_connected:
|
||||
raise DeviceNotConnectedError(f"{self} is not connected.")
|
||||
|
||||
goal_pos = {key.removesuffix(".pos"): val for key, val in action.items() if key.endswith(".pos")}
|
||||
|
||||
@@ -231,10 +227,8 @@ class KochFollower(Robot):
|
||||
self.bus.sync_write("Goal_Position", goal_pos)
|
||||
return {f"{motor}.pos": val for motor, val in goal_pos.items()}
|
||||
|
||||
@check_if_not_connected
|
||||
def disconnect(self):
|
||||
if not self.is_connected:
|
||||
raise DeviceNotConnectedError(f"{self} is not connected.")
|
||||
|
||||
self.bus.disconnect(self.config.disable_torque_on_disconnect)
|
||||
for cam in self.cameras.values():
|
||||
cam.disconnect()
|
||||
|
||||
@@ -29,7 +29,7 @@ from lerobot.motors.feetech import (
|
||||
OperatingMode,
|
||||
)
|
||||
from lerobot.processor import RobotAction, RobotObservation
|
||||
from lerobot.utils.errors import DeviceAlreadyConnectedError, DeviceNotConnectedError
|
||||
from lerobot.utils.decorators import check_if_already_connected, check_if_not_connected
|
||||
|
||||
from ..robot import Robot
|
||||
from ..utils import ensure_safe_goal_position
|
||||
@@ -109,10 +109,8 @@ class LeKiwi(Robot):
|
||||
def is_connected(self) -> bool:
|
||||
return self.bus.is_connected and all(cam.is_connected for cam in self.cameras.values())
|
||||
|
||||
@check_if_already_connected
|
||||
def connect(self, calibrate: bool = True) -> None:
|
||||
if self.is_connected:
|
||||
raise DeviceAlreadyConnectedError(f"{self} already connected")
|
||||
|
||||
self.bus.connect()
|
||||
if not self.is_calibrated and calibrate:
|
||||
logger.info(
|
||||
@@ -339,10 +337,8 @@ class LeKiwi(Robot):
|
||||
"theta.vel": theta,
|
||||
} # m/s and deg/s
|
||||
|
||||
@check_if_not_connected
|
||||
def get_observation(self) -> RobotObservation:
|
||||
if not self.is_connected:
|
||||
raise DeviceNotConnectedError(f"{self} is not connected.")
|
||||
|
||||
# Read actuators position for arm and vel for base
|
||||
start = time.perf_counter()
|
||||
arm_pos = self.bus.sync_read("Present_Position", self.arm_motors)
|
||||
@@ -370,6 +366,7 @@ class LeKiwi(Robot):
|
||||
|
||||
return obs_dict
|
||||
|
||||
@check_if_not_connected
|
||||
def send_action(self, action: RobotAction) -> RobotAction:
|
||||
"""Command lekiwi to move to a target joint configuration.
|
||||
|
||||
@@ -383,8 +380,6 @@ class LeKiwi(Robot):
|
||||
Returns:
|
||||
RobotAction: the action sent to the motors, potentially clipped.
|
||||
"""
|
||||
if not self.is_connected:
|
||||
raise DeviceNotConnectedError(f"{self} is not connected.")
|
||||
|
||||
arm_goal_pos = {k: v for k, v in action.items() if k.endswith(".pos")}
|
||||
base_goal_vel = {k: v for k, v in action.items() if k.endswith(".vel")}
|
||||
@@ -412,10 +407,8 @@ class LeKiwi(Robot):
|
||||
self.bus.sync_write("Goal_Velocity", dict.fromkeys(self.base_motors, 0), num_retry=5)
|
||||
logger.info("Base motors stopped")
|
||||
|
||||
@check_if_not_connected
|
||||
def disconnect(self):
|
||||
if not self.is_connected:
|
||||
raise DeviceNotConnectedError(f"{self} is not connected.")
|
||||
|
||||
self.stop_base()
|
||||
self.bus.disconnect(self.config.disable_torque_on_disconnect)
|
||||
for cam in self.cameras.values():
|
||||
|
||||
@@ -24,7 +24,8 @@ import numpy as np
|
||||
|
||||
from lerobot.processor import RobotAction, RobotObservation
|
||||
from lerobot.utils.constants import ACTION, OBS_STATE
|
||||
from lerobot.utils.errors import DeviceAlreadyConnectedError, DeviceNotConnectedError
|
||||
from lerobot.utils.decorators import check_if_already_connected, check_if_not_connected
|
||||
from lerobot.utils.errors import DeviceNotConnectedError
|
||||
|
||||
from ..robot import Robot
|
||||
from .config_lekiwi import LeKiwiClientConfig
|
||||
@@ -112,14 +113,10 @@ class LeKiwiClient(Robot):
|
||||
def is_calibrated(self) -> bool:
|
||||
pass
|
||||
|
||||
@check_if_already_connected
|
||||
def connect(self) -> None:
|
||||
"""Establishes ZMQ sockets with the remote mobile robot"""
|
||||
|
||||
if self._is_connected:
|
||||
raise DeviceAlreadyConnectedError(
|
||||
"LeKiwi Daemon is already connected. Do not run `robot.connect()` twice."
|
||||
)
|
||||
|
||||
zmq = self._zmq
|
||||
self.zmq_context = zmq.Context()
|
||||
self.zmq_cmd_socket = self.zmq_context.socket(zmq.PUSH)
|
||||
@@ -252,14 +249,13 @@ class LeKiwiClient(Robot):
|
||||
|
||||
return new_frames, new_state
|
||||
|
||||
@check_if_not_connected
|
||||
def get_observation(self) -> RobotObservation:
|
||||
"""
|
||||
Capture observations from the remote robot: current follower arm positions,
|
||||
present wheel speeds (converted to body-frame velocities: x, y, theta),
|
||||
and a camera frame. Receives over ZMQ, translate to body-frame vel
|
||||
"""
|
||||
if not self._is_connected:
|
||||
raise DeviceNotConnectedError("LeKiwiClient is not connected. You need to run `robot.connect()`.")
|
||||
|
||||
frames, obs_dict = self._get_data()
|
||||
|
||||
@@ -307,6 +303,7 @@ class LeKiwiClient(Robot):
|
||||
def configure(self):
|
||||
pass
|
||||
|
||||
@check_if_not_connected
|
||||
def send_action(self, action: RobotAction) -> RobotAction:
|
||||
"""Command lekiwi to move to a target joint configuration. Translates to motor space + sends over ZMQ
|
||||
|
||||
@@ -318,10 +315,6 @@ class LeKiwiClient(Robot):
|
||||
Returns:
|
||||
np.ndarray: the action sent to the motors, potentially clipped.
|
||||
"""
|
||||
if not self._is_connected:
|
||||
raise DeviceNotConnectedError(
|
||||
"ManipulatorRobot is not connected. You need to run `robot.connect()`."
|
||||
)
|
||||
|
||||
self.zmq_cmd_socket.send_string(json.dumps(action)) # action is in motor space
|
||||
|
||||
@@ -332,13 +325,10 @@ class LeKiwiClient(Robot):
|
||||
action_sent[ACTION] = actions
|
||||
return action_sent
|
||||
|
||||
@check_if_not_connected
|
||||
def disconnect(self):
|
||||
"""Cleans ZMQ comms"""
|
||||
|
||||
if not self._is_connected:
|
||||
raise DeviceNotConnectedError(
|
||||
"LeKiwi is not connected. You need to run `robot.connect()` before disconnecting."
|
||||
)
|
||||
self.zmq_observation_socket.close()
|
||||
self.zmq_cmd_socket.close()
|
||||
self.zmq_context.term()
|
||||
|
||||
@@ -26,7 +26,7 @@ from lerobot.motors.dynamixel import (
|
||||
OperatingMode,
|
||||
)
|
||||
from lerobot.processor import RobotAction, RobotObservation
|
||||
from lerobot.utils.errors import DeviceAlreadyConnectedError, DeviceNotConnectedError
|
||||
from lerobot.utils.decorators import check_if_already_connected, check_if_not_connected
|
||||
|
||||
from ..robot import Robot
|
||||
from ..utils import ensure_safe_goal_position
|
||||
@@ -84,6 +84,7 @@ class OmxFollower(Robot):
|
||||
def is_connected(self) -> bool:
|
||||
return self.bus.is_connected and all(cam.is_connected for cam in self.cameras.values())
|
||||
|
||||
@check_if_already_connected
|
||||
def connect(self, calibrate: bool = True) -> None:
|
||||
"""
|
||||
For OMX robots that come pre-calibrated:
|
||||
@@ -91,8 +92,6 @@ class OmxFollower(Robot):
|
||||
- This allows using pre-calibrated robots without manual calibration
|
||||
- If no calibration file exists, use factory default values (homing_offset=0, range_min=0, range_max=4095)
|
||||
"""
|
||||
if self.is_connected:
|
||||
raise DeviceAlreadyConnectedError(f"{self} already connected")
|
||||
|
||||
self.bus.connect()
|
||||
if not self.is_calibrated and calibrate:
|
||||
@@ -165,10 +164,8 @@ class OmxFollower(Robot):
|
||||
self.bus.setup_motor(motor)
|
||||
print(f"'{motor}' motor id set to {self.bus.motors[motor].id}")
|
||||
|
||||
@check_if_not_connected
|
||||
def get_observation(self) -> RobotObservation:
|
||||
if not self.is_connected:
|
||||
raise DeviceNotConnectedError(f"{self} is not connected.")
|
||||
|
||||
# Read arm position
|
||||
start = time.perf_counter()
|
||||
obs_dict = self.bus.sync_read("Present_Position")
|
||||
@@ -185,6 +182,7 @@ class OmxFollower(Robot):
|
||||
|
||||
return obs_dict
|
||||
|
||||
@check_if_not_connected
|
||||
def send_action(self, action: RobotAction) -> RobotAction:
|
||||
"""Command arm to move to a target joint configuration.
|
||||
|
||||
@@ -198,8 +196,6 @@ class OmxFollower(Robot):
|
||||
Returns:
|
||||
RobotAction: The action sent to the motors, potentially clipped.
|
||||
"""
|
||||
if not self.is_connected:
|
||||
raise DeviceNotConnectedError(f"{self} is not connected.")
|
||||
|
||||
goal_pos = {key.removesuffix(".pos"): val for key, val in action.items() if key.endswith(".pos")}
|
||||
|
||||
@@ -214,10 +210,8 @@ class OmxFollower(Robot):
|
||||
self.bus.sync_write("Goal_Position", goal_pos)
|
||||
return {f"{motor}.pos": val for motor, val in goal_pos.items()}
|
||||
|
||||
@check_if_not_connected
|
||||
def disconnect(self):
|
||||
if not self.is_connected:
|
||||
raise DeviceNotConnectedError(f"{self} is not connected.")
|
||||
|
||||
self.bus.disconnect(self.config.disable_torque_on_disconnect)
|
||||
for cam in self.cameras.values():
|
||||
cam.disconnect()
|
||||
|
||||
@@ -26,7 +26,7 @@ from lerobot.motors.feetech import (
|
||||
OperatingMode,
|
||||
)
|
||||
from lerobot.processor import RobotAction, RobotObservation
|
||||
from lerobot.utils.errors import DeviceAlreadyConnectedError, DeviceNotConnectedError
|
||||
from lerobot.utils.decorators import check_if_already_connected, check_if_not_connected
|
||||
|
||||
from ..robot import Robot
|
||||
from ..utils import ensure_safe_goal_position
|
||||
@@ -85,13 +85,12 @@ class SOFollower(Robot):
|
||||
def is_connected(self) -> bool:
|
||||
return self.bus.is_connected and all(cam.is_connected for cam in self.cameras.values())
|
||||
|
||||
@check_if_already_connected
|
||||
def connect(self, calibrate: bool = True) -> None:
|
||||
"""
|
||||
We assume that at connection time, arm is in a rest position,
|
||||
and torque can be safely disabled to run calibration.
|
||||
"""
|
||||
if self.is_connected:
|
||||
raise DeviceAlreadyConnectedError(f"{self} already connected")
|
||||
|
||||
self.bus.connect()
|
||||
if not self.is_calibrated and calibrate:
|
||||
@@ -176,10 +175,8 @@ class SOFollower(Robot):
|
||||
self.bus.setup_motor(motor)
|
||||
print(f"'{motor}' motor id set to {self.bus.motors[motor].id}")
|
||||
|
||||
@check_if_not_connected
|
||||
def get_observation(self) -> RobotObservation:
|
||||
if not self.is_connected:
|
||||
raise DeviceNotConnectedError(f"{self} is not connected.")
|
||||
|
||||
# Read arm position
|
||||
start = time.perf_counter()
|
||||
obs_dict = self.bus.sync_read("Present_Position")
|
||||
@@ -196,6 +193,7 @@ class SOFollower(Robot):
|
||||
|
||||
return obs_dict
|
||||
|
||||
@check_if_not_connected
|
||||
def send_action(self, action: RobotAction) -> RobotAction:
|
||||
"""Command arm to move to a target joint configuration.
|
||||
|
||||
@@ -209,8 +207,6 @@ class SOFollower(Robot):
|
||||
Returns:
|
||||
RobotAction: the action sent to the motors, potentially clipped.
|
||||
"""
|
||||
if not self.is_connected:
|
||||
raise DeviceNotConnectedError(f"{self} is not connected.")
|
||||
|
||||
goal_pos = {key.removesuffix(".pos"): val for key, val in action.items() if key.endswith(".pos")}
|
||||
|
||||
@@ -225,10 +221,8 @@ class SOFollower(Robot):
|
||||
self.bus.sync_write("Goal_Position", goal_pos)
|
||||
return {f"{motor}.pos": val for motor, val in goal_pos.items()}
|
||||
|
||||
@check_if_not_connected
|
||||
def disconnect(self):
|
||||
if not self.is_connected:
|
||||
raise DeviceNotConnectedError(f"{self} is not connected.")
|
||||
|
||||
self.bus.disconnect(self.config.disable_torque_on_disconnect)
|
||||
for cam in self.cameras.values():
|
||||
cam.disconnect()
|
||||
|
||||
@@ -148,92 +148,6 @@ def update_policy(
|
||||
return train_metrics, output_dict
|
||||
|
||||
|
||||
def get_default_peft_configuration(policy_type):
|
||||
"""Build a basic PEFT configuration for the given policy type assuming that we train a policy from a checkpoint."""
|
||||
|
||||
common_projections = "state_proj|action_in_proj|action_out_proj|action_time_mlp_in|action_time_mlp_out"
|
||||
|
||||
if policy_type == "smolvla":
|
||||
return {
|
||||
"target_modules": rf"(model\.vlm_with_expert\.lm_expert\..*\.(q|v)_proj|model\.({common_projections}))",
|
||||
"modules_to_save": [],
|
||||
}
|
||||
elif policy_type in ("pi0", "pi05"):
|
||||
return {
|
||||
"target_modules": rf"(.*\.gemma_expert\..*\.self_attn.(q|v)_proj|model\.({common_projections}))",
|
||||
"modules_to_save": [],
|
||||
}
|
||||
|
||||
return {"modules_to_save": None}
|
||||
|
||||
|
||||
def wrap_policy_in_peft_model(cfg, policy):
|
||||
from peft import PEFT_TYPE_TO_CONFIG_MAPPING, PeftType, get_peft_model
|
||||
|
||||
# Disable all gradients because we'll only train the parameters selected by the PEFT method.
|
||||
# Layers that should receive gradients anyway need to be listed in `modules_to_save`.
|
||||
for p in policy.parameters():
|
||||
p.requires_grad_(False)
|
||||
|
||||
if not cfg.policy.pretrained_path:
|
||||
raise ValueError(
|
||||
"Training from scratch using PEFT. This is unlikely to yield good results. "
|
||||
"Supply a `policy.path` to fine-tune an existing model."
|
||||
)
|
||||
|
||||
if cfg.policy.type == "smolvla" and not cfg.policy.load_vlm_weights:
|
||||
logging.warning(
|
||||
"Training SmolVLA from scratch using PEFT. This is unlikely to yield good results. Set "
|
||||
"`load_vlm_weights=True` to fine-tune the existing policy."
|
||||
)
|
||||
|
||||
peft_config_policy = get_default_peft_configuration(cfg.policy.type)
|
||||
peft_config_cli = dataclasses.asdict(cfg.peft) if cfg.peft else {}
|
||||
peft_config_cli["modules_to_save"] = peft_config_cli["full_training_modules"] # compatibility with PEFT
|
||||
peft_method_type = PeftType[peft_config_cli["method_type"].upper()]
|
||||
peft_config_cls = PEFT_TYPE_TO_CONFIG_MAPPING[peft_method_type]
|
||||
|
||||
# Handle specific CLI overrides
|
||||
for key in ["target_modules", "modules_to_save", "r"]:
|
||||
if peft_config_cli[key] is not None:
|
||||
peft_config_policy[key] = peft_config_cli[key]
|
||||
|
||||
if "target_modules" not in peft_config_policy:
|
||||
raise ValueError(
|
||||
f"There is no default `target_modules` value for policy {cfg.policy.type}. Please pass it manually."
|
||||
)
|
||||
|
||||
# Init method depends on the used PEFT method, your specific PEFT method
|
||||
# might not be considered here, in that case an error is raised.
|
||||
if peft_config_cli["init_type"] is not None:
|
||||
if peft_method_type == "LORA":
|
||||
peft_config_policy["init_lora_weights"] = peft_config_cli["init_type"]
|
||||
elif peft_method_type == "MISS":
|
||||
peft_config_policy["init_weights"] = peft_config_cli["init_type"]
|
||||
else:
|
||||
raise ValueError(
|
||||
f"Init type {peft_config_cli['init_type']} unknown for PEFT method {peft_method_type}."
|
||||
)
|
||||
|
||||
# PEFT uses this attribute to set adapter_config.base_name_or_path which we use for loading the
|
||||
# correct base model in `make_policy` since in a PEFT loading setting we only get the path to the
|
||||
# adapter, not the base model.
|
||||
if policy.config.pretrained_path:
|
||||
policy.name_or_path = str(policy.config.pretrained_path)
|
||||
|
||||
# Finally wrap the policy in a PEFT model
|
||||
policy = get_peft_model(
|
||||
policy,
|
||||
peft_config_cls(**peft_config_policy),
|
||||
)
|
||||
|
||||
# Make sure that the config is tagged as using PEFT so that the loading code can take the
|
||||
# appropriate steps to use the adapter weights and the PEFT config instead of the full model weights.
|
||||
policy.config.use_peft = True
|
||||
|
||||
return policy
|
||||
|
||||
|
||||
@parser.wrap()
|
||||
def train(cfg: TrainPipelineConfig, accelerator: Accelerator | None = None):
|
||||
"""
|
||||
@@ -326,7 +240,9 @@ def train(cfg: TrainPipelineConfig, accelerator: Accelerator | None = None):
|
||||
|
||||
if cfg.peft is not None:
|
||||
logging.info("Using PEFT! Wrapping model.")
|
||||
policy = wrap_policy_in_peft_model(cfg, policy)
|
||||
# Convert CLI peft config to dict for overrides
|
||||
peft_cli_overrides = dataclasses.asdict(cfg.peft)
|
||||
policy = policy.wrap_with_peft(peft_cli_overrides=peft_cli_overrides)
|
||||
|
||||
# Wait for all processes to finish policy creation before continuing
|
||||
accelerator.wait_for_everyone()
|
||||
|
||||
@@ -18,7 +18,7 @@ import logging
|
||||
from functools import cached_property
|
||||
|
||||
from lerobot.teleoperators.so_leader import SOLeaderTeleopConfig
|
||||
from lerobot.utils.errors import DeviceNotConnectedError
|
||||
from lerobot.utils.decorators import check_if_not_connected
|
||||
|
||||
from ..so_leader import SOLeader
|
||||
from ..teleoperator import Teleoperator
|
||||
@@ -92,10 +92,8 @@ class BiSOLeader(Teleoperator):
|
||||
self.left_arm.setup_motors()
|
||||
self.right_arm.setup_motors()
|
||||
|
||||
@check_if_not_connected
|
||||
def get_action(self) -> dict[str, float]:
|
||||
if not self.is_connected:
|
||||
raise DeviceNotConnectedError(f"{self} is not connected.")
|
||||
|
||||
action_dict = {}
|
||||
|
||||
# Add "left_" prefix
|
||||
|
||||
@@ -21,7 +21,7 @@ from typing import Any
|
||||
import numpy as np
|
||||
|
||||
from lerobot.processor import RobotAction
|
||||
from lerobot.utils.errors import DeviceNotConnectedError
|
||||
from lerobot.utils.decorators import check_if_not_connected
|
||||
|
||||
from ..teleoperator import Teleoperator
|
||||
from ..utils import TeleopEvents
|
||||
@@ -86,10 +86,8 @@ class GamepadTeleop(Teleoperator):
|
||||
self.gamepad = Gamepad()
|
||||
self.gamepad.start()
|
||||
|
||||
@check_if_not_connected
|
||||
def get_action(self) -> RobotAction:
|
||||
if not self.is_connected:
|
||||
raise DeviceNotConnectedError(f"{self} is not connected.")
|
||||
|
||||
# Update the controller to get fresh inputs
|
||||
self.gamepad.update()
|
||||
|
||||
|
||||
@@ -22,7 +22,7 @@ from pprint import pformat
|
||||
import serial
|
||||
|
||||
from lerobot.motors.motors_bus import MotorCalibration, MotorNormMode
|
||||
from lerobot.utils.errors import DeviceAlreadyConnectedError, DeviceNotConnectedError
|
||||
from lerobot.utils.decorators import check_if_already_connected, check_if_not_connected
|
||||
from lerobot.utils.utils import enter_pressed, move_cursor_up
|
||||
|
||||
from ..teleoperator import Teleoperator
|
||||
@@ -93,10 +93,8 @@ class HomunculusArm(Teleoperator):
|
||||
with self.serial_lock:
|
||||
return self.serial.is_open and self.thread.is_alive()
|
||||
|
||||
@check_if_already_connected
|
||||
def connect(self, calibrate: bool = True) -> None:
|
||||
if self.is_connected:
|
||||
raise DeviceAlreadyConnectedError(f"{self} already connected")
|
||||
|
||||
if not self.serial.is_open:
|
||||
self.serial.open()
|
||||
self.thread.start()
|
||||
@@ -299,20 +297,16 @@ class HomunculusArm(Teleoperator):
|
||||
except Exception as e:
|
||||
logger.debug(f"Error reading frame in background thread for {self}: {e}")
|
||||
|
||||
@check_if_not_connected
|
||||
def get_action(self) -> dict[str, float]:
|
||||
if not self.is_connected:
|
||||
raise DeviceNotConnectedError(f"{self} is not connected.")
|
||||
|
||||
joint_positions = self._read()
|
||||
return {f"{joint}.pos": pos for joint, pos in joint_positions.items()}
|
||||
|
||||
def send_feedback(self, feedback: dict[str, float]) -> None:
|
||||
raise NotImplementedError
|
||||
|
||||
@check_if_not_connected
|
||||
def disconnect(self) -> None:
|
||||
if not self.is_connected:
|
||||
DeviceNotConnectedError(f"{self} is not connected.")
|
||||
|
||||
self.stop_event.set()
|
||||
self.thread.join(timeout=1)
|
||||
self.serial.close()
|
||||
|
||||
@@ -24,7 +24,7 @@ import serial
|
||||
from lerobot.motors import MotorCalibration
|
||||
from lerobot.motors.motors_bus import MotorNormMode
|
||||
from lerobot.teleoperators.homunculus.joints_translation import homunculus_glove_to_hope_jr_hand
|
||||
from lerobot.utils.errors import DeviceAlreadyConnectedError, DeviceNotConnectedError
|
||||
from lerobot.utils.decorators import check_if_already_connected, check_if_not_connected
|
||||
from lerobot.utils.utils import enter_pressed, move_cursor_up
|
||||
|
||||
from ..teleoperator import Teleoperator
|
||||
@@ -119,10 +119,8 @@ class HomunculusGlove(Teleoperator):
|
||||
with self.serial_lock:
|
||||
return self.serial.is_open and self.thread.is_alive()
|
||||
|
||||
@check_if_already_connected
|
||||
def connect(self, calibrate: bool = True) -> None:
|
||||
if self.is_connected:
|
||||
raise DeviceAlreadyConnectedError(f"{self} already connected")
|
||||
|
||||
if not self.serial.is_open:
|
||||
self.serial.open()
|
||||
self.thread.start()
|
||||
@@ -325,10 +323,8 @@ class HomunculusGlove(Teleoperator):
|
||||
except Exception as e:
|
||||
logger.debug(f"Error reading frame in background thread for {self}: {e}")
|
||||
|
||||
@check_if_not_connected
|
||||
def get_action(self) -> dict[str, float]:
|
||||
if not self.is_connected:
|
||||
raise DeviceNotConnectedError(f"{self} is not connected.")
|
||||
|
||||
joint_positions = self._read()
|
||||
return homunculus_glove_to_hope_jr_hand(
|
||||
{f"{joint}.pos": pos for joint, pos in joint_positions.items()}
|
||||
@@ -337,10 +333,8 @@ class HomunculusGlove(Teleoperator):
|
||||
def send_feedback(self, feedback: dict[str, float]) -> None:
|
||||
raise NotImplementedError
|
||||
|
||||
@check_if_not_connected
|
||||
def disconnect(self) -> None:
|
||||
if not self.is_connected:
|
||||
DeviceNotConnectedError(f"{self} is not connected.")
|
||||
|
||||
self.stop_event.set()
|
||||
self.thread.join(timeout=1)
|
||||
self.serial.close()
|
||||
|
||||
@@ -22,7 +22,7 @@ from queue import Queue
|
||||
from typing import Any
|
||||
|
||||
from lerobot.processor import RobotAction
|
||||
from lerobot.utils.errors import DeviceAlreadyConnectedError, DeviceNotConnectedError
|
||||
from lerobot.utils.decorators import check_if_already_connected, check_if_not_connected
|
||||
|
||||
from ..teleoperator import Teleoperator
|
||||
from ..utils import TeleopEvents
|
||||
@@ -86,12 +86,8 @@ class KeyboardTeleop(Teleoperator):
|
||||
def is_calibrated(self) -> bool:
|
||||
pass
|
||||
|
||||
@check_if_already_connected
|
||||
def connect(self) -> None:
|
||||
if self.is_connected:
|
||||
raise DeviceAlreadyConnectedError(
|
||||
"Keyboard is already connected. Do not run `robot.connect()` twice."
|
||||
)
|
||||
|
||||
if PYNPUT_AVAILABLE:
|
||||
logging.info("pynput is available - enabling local keyboard listener.")
|
||||
self.listener = keyboard.Listener(
|
||||
@@ -125,14 +121,10 @@ class KeyboardTeleop(Teleoperator):
|
||||
def configure(self):
|
||||
pass
|
||||
|
||||
@check_if_not_connected
|
||||
def get_action(self) -> RobotAction:
|
||||
before_read_t = time.perf_counter()
|
||||
|
||||
if not self.is_connected:
|
||||
raise DeviceNotConnectedError(
|
||||
"KeyboardTeleop is not connected. You need to run `connect()` before `get_action()`."
|
||||
)
|
||||
|
||||
self._drain_pressed_keys()
|
||||
|
||||
# Generate action based on current key states
|
||||
@@ -144,11 +136,8 @@ class KeyboardTeleop(Teleoperator):
|
||||
def send_feedback(self, feedback: dict[str, Any]) -> None:
|
||||
pass
|
||||
|
||||
@check_if_not_connected
|
||||
def disconnect(self) -> None:
|
||||
if not self.is_connected:
|
||||
raise DeviceNotConnectedError(
|
||||
"KeyboardTeleop is not connected. You need to run `robot.connect()` before `disconnect()`."
|
||||
)
|
||||
if self.listener is not None:
|
||||
self.listener.stop()
|
||||
|
||||
@@ -182,12 +171,8 @@ class KeyboardEndEffectorTeleop(KeyboardTeleop):
|
||||
"names": {"delta_x": 0, "delta_y": 1, "delta_z": 2},
|
||||
}
|
||||
|
||||
@check_if_not_connected
|
||||
def get_action(self) -> RobotAction:
|
||||
if not self.is_connected:
|
||||
raise DeviceNotConnectedError(
|
||||
"KeyboardTeleop is not connected. You need to run `connect()` before `get_action()`."
|
||||
)
|
||||
|
||||
self._drain_pressed_keys()
|
||||
delta_x = 0.0
|
||||
delta_y = 0.0
|
||||
@@ -375,6 +360,7 @@ class KeyboardRoverTeleop(KeyboardTeleop):
|
||||
# Only remove key if it's being released
|
||||
self.current_pressed.pop(key_char, None)
|
||||
|
||||
@check_if_not_connected
|
||||
def get_action(self) -> RobotAction:
|
||||
"""
|
||||
Get the current action based on pressed keys.
|
||||
@@ -384,11 +370,6 @@ class KeyboardRoverTeleop(KeyboardTeleop):
|
||||
"""
|
||||
before_read_t = time.perf_counter()
|
||||
|
||||
if not self.is_connected:
|
||||
raise DeviceNotConnectedError(
|
||||
"KeyboardRoverTeleop is not connected. You need to run `connect()` before `get_action()`."
|
||||
)
|
||||
|
||||
self._drain_pressed_keys()
|
||||
|
||||
linear_velocity = 0.0
|
||||
|
||||
@@ -23,7 +23,7 @@ from lerobot.motors.dynamixel import (
|
||||
DynamixelMotorsBus,
|
||||
OperatingMode,
|
||||
)
|
||||
from lerobot.utils.errors import DeviceAlreadyConnectedError, DeviceNotConnectedError
|
||||
from lerobot.utils.decorators import check_if_already_connected, check_if_not_connected
|
||||
|
||||
from ..teleoperator import Teleoperator
|
||||
from .config_koch_leader import KochLeaderConfig
|
||||
@@ -69,10 +69,8 @@ class KochLeader(Teleoperator):
|
||||
def is_connected(self) -> bool:
|
||||
return self.bus.is_connected
|
||||
|
||||
@check_if_already_connected
|
||||
def connect(self, calibrate: bool = True) -> None:
|
||||
if self.is_connected:
|
||||
raise DeviceAlreadyConnectedError(f"{self} already connected")
|
||||
|
||||
self.bus.connect()
|
||||
if not self.is_calibrated and calibrate:
|
||||
logger.info(
|
||||
@@ -161,10 +159,8 @@ class KochLeader(Teleoperator):
|
||||
self.bus.setup_motor(motor)
|
||||
print(f"'{motor}' motor id set to {self.bus.motors[motor].id}")
|
||||
|
||||
@check_if_not_connected
|
||||
def get_action(self) -> dict[str, float]:
|
||||
if not self.is_connected:
|
||||
raise DeviceNotConnectedError(f"{self} is not connected.")
|
||||
|
||||
start = time.perf_counter()
|
||||
action = self.bus.sync_read("Present_Position")
|
||||
action = {f"{motor}.pos": val for motor, val in action.items()}
|
||||
@@ -176,9 +172,7 @@ class KochLeader(Teleoperator):
|
||||
# TODO(rcadene, aliberts): Implement force feedback
|
||||
raise NotImplementedError
|
||||
|
||||
@check_if_not_connected
|
||||
def disconnect(self) -> None:
|
||||
if not self.is_connected:
|
||||
raise DeviceNotConnectedError(f"{self} is not connected.")
|
||||
|
||||
self.bus.disconnect()
|
||||
logger.info(f"{self} disconnected.")
|
||||
|
||||
@@ -23,7 +23,7 @@ from lerobot.motors.dynamixel import (
|
||||
DynamixelMotorsBus,
|
||||
OperatingMode,
|
||||
)
|
||||
from lerobot.utils.errors import DeviceAlreadyConnectedError, DeviceNotConnectedError
|
||||
from lerobot.utils.decorators import check_if_already_connected, check_if_not_connected
|
||||
|
||||
from ..teleoperator import Teleoperator
|
||||
from .config_omx_leader import OmxLeaderConfig
|
||||
@@ -68,10 +68,8 @@ class OmxLeader(Teleoperator):
|
||||
def is_connected(self) -> bool:
|
||||
return self.bus.is_connected
|
||||
|
||||
@check_if_already_connected
|
||||
def connect(self, calibrate: bool = True) -> None:
|
||||
if self.is_connected:
|
||||
raise DeviceAlreadyConnectedError(f"{self} already connected")
|
||||
|
||||
self.bus.connect()
|
||||
if not self.is_calibrated and calibrate:
|
||||
logger.info(
|
||||
@@ -142,10 +140,8 @@ class OmxLeader(Teleoperator):
|
||||
self.bus.setup_motor(motor)
|
||||
print(f"'{motor}' motor id set to {self.bus.motors[motor].id}")
|
||||
|
||||
@check_if_not_connected
|
||||
def get_action(self) -> dict[str, float]:
|
||||
if not self.is_connected:
|
||||
raise DeviceNotConnectedError(f"{self} is not connected.")
|
||||
|
||||
start = time.perf_counter()
|
||||
action = self.bus.sync_read("Present_Position")
|
||||
action = {f"{motor}.pos": val for motor, val in action.items()}
|
||||
@@ -157,9 +153,7 @@ class OmxLeader(Teleoperator):
|
||||
# TODO(rcadene, aliberts): Implement force feedback
|
||||
raise NotImplementedError
|
||||
|
||||
@check_if_not_connected
|
||||
def disconnect(self) -> None:
|
||||
if not self.is_connected:
|
||||
raise DeviceNotConnectedError(f"{self} is not connected.")
|
||||
|
||||
self.bus.disconnect()
|
||||
logger.info(f"{self} disconnected.")
|
||||
|
||||
@@ -28,7 +28,7 @@ from teleop import Teleop
|
||||
|
||||
from lerobot.teleoperators.phone.config_phone import PhoneConfig, PhoneOS
|
||||
from lerobot.teleoperators.teleoperator import Teleoperator
|
||||
from lerobot.utils.errors import DeviceAlreadyConnectedError, DeviceNotConnectedError
|
||||
from lerobot.utils.decorators import check_if_already_connected, check_if_not_connected
|
||||
from lerobot.utils.rotation import Rotation
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
@@ -81,10 +81,8 @@ class IOSPhone(BasePhone, Teleoperator):
|
||||
def is_connected(self) -> bool:
|
||||
return self._group is not None
|
||||
|
||||
@check_if_already_connected
|
||||
def connect(self) -> None:
|
||||
if self.is_connected:
|
||||
raise DeviceAlreadyConnectedError(f"{self} already connected")
|
||||
|
||||
logger.info("Connecting to IPhone, make sure to open the HEBI Mobile I/O app.")
|
||||
lookup = hebi.Lookup()
|
||||
time.sleep(2.0)
|
||||
@@ -164,10 +162,8 @@ class IOSPhone(BasePhone, Teleoperator):
|
||||
pos = ar_pos - rot.apply(self.config.camera_offset)
|
||||
return True, pos, rot, pose
|
||||
|
||||
@check_if_not_connected
|
||||
def get_action(self) -> dict:
|
||||
if not self.is_connected:
|
||||
raise DeviceNotConnectedError(f"{self} is not connected.")
|
||||
|
||||
has_pose, raw_position, raw_rotation, fb_pose = self._read_current_pose()
|
||||
if not has_pose or not self.is_calibrated:
|
||||
return {}
|
||||
@@ -207,10 +203,8 @@ class IOSPhone(BasePhone, Teleoperator):
|
||||
"phone.enabled": self._enabled,
|
||||
}
|
||||
|
||||
@check_if_not_connected
|
||||
def disconnect(self) -> None:
|
||||
if not self.is_connected:
|
||||
raise DeviceNotConnectedError(f"{self} is not connected.")
|
||||
|
||||
self._group = None
|
||||
|
||||
|
||||
@@ -230,10 +224,8 @@ class AndroidPhone(BasePhone, Teleoperator):
|
||||
def is_connected(self) -> bool:
|
||||
return self._teleop is not None
|
||||
|
||||
@check_if_already_connected
|
||||
def connect(self) -> None:
|
||||
if self.is_connected:
|
||||
raise DeviceAlreadyConnectedError(f"{self} already connected")
|
||||
|
||||
logger.info("Starting teleop stream for Android...")
|
||||
self._teleop = Teleop()
|
||||
self._teleop.subscribe(self._android_callback)
|
||||
@@ -321,10 +313,8 @@ class AndroidPhone(BasePhone, Teleoperator):
|
||||
self._latest_pose = pose
|
||||
self._latest_message = message
|
||||
|
||||
@check_if_not_connected
|
||||
def get_action(self) -> dict:
|
||||
if not self.is_connected:
|
||||
raise DeviceNotConnectedError(f"{self} is not connected.")
|
||||
|
||||
ok, raw_pos, raw_rot, pose = self._read_current_pose()
|
||||
if not ok or not self.is_calibrated:
|
||||
return {}
|
||||
@@ -356,10 +346,8 @@ class AndroidPhone(BasePhone, Teleoperator):
|
||||
"phone.enabled": self._enabled,
|
||||
}
|
||||
|
||||
@check_if_not_connected
|
||||
def disconnect(self) -> None:
|
||||
if not self.is_connected:
|
||||
raise DeviceNotConnectedError(f"{self} is not connected.")
|
||||
|
||||
self._teleop = None
|
||||
if self._teleop_thread and self._teleop_thread.is_alive():
|
||||
self._teleop_thread.join(timeout=1.0)
|
||||
|
||||
@@ -26,7 +26,8 @@ if TYPE_CHECKING or _reachy2_sdk_available:
|
||||
else:
|
||||
ReachySDK = None
|
||||
|
||||
from lerobot.utils.errors import DeviceAlreadyConnectedError, DeviceNotConnectedError
|
||||
from lerobot.utils.decorators import check_if_already_connected, check_if_not_connected
|
||||
from lerobot.utils.errors import DeviceNotConnectedError
|
||||
|
||||
from ..teleoperator import Teleoperator
|
||||
from .config_reachy2_teleoperator import Reachy2TeleoperatorConfig
|
||||
@@ -126,10 +127,8 @@ class Reachy2Teleoperator(Teleoperator):
|
||||
def is_connected(self) -> bool:
|
||||
return self.reachy.is_connected() if self.reachy is not None else False
|
||||
|
||||
@check_if_already_connected
|
||||
def connect(self, calibrate: bool = True) -> None:
|
||||
if self.is_connected:
|
||||
raise DeviceAlreadyConnectedError(f"{self} already connected")
|
||||
|
||||
self.reachy = ReachySDK(self.config.ip_address)
|
||||
|
||||
if not self.is_connected:
|
||||
@@ -146,12 +145,10 @@ class Reachy2Teleoperator(Teleoperator):
|
||||
def configure(self) -> None:
|
||||
pass
|
||||
|
||||
@check_if_not_connected
|
||||
def get_action(self) -> dict[str, float]:
|
||||
start = time.perf_counter()
|
||||
|
||||
if not self.is_connected:
|
||||
raise DeviceNotConnectedError(f"{self} is not connected.")
|
||||
|
||||
joint_action: dict[str, float] = {}
|
||||
vel_action: dict[str, float] = {}
|
||||
|
||||
|
||||
@@ -23,7 +23,7 @@ from lerobot.motors.feetech import (
|
||||
FeetechMotorsBus,
|
||||
OperatingMode,
|
||||
)
|
||||
from lerobot.utils.errors import DeviceAlreadyConnectedError, DeviceNotConnectedError
|
||||
from lerobot.utils.decorators import check_if_already_connected, check_if_not_connected
|
||||
|
||||
from ..teleoperator import Teleoperator
|
||||
from .config_so_leader import SOLeaderTeleopConfig
|
||||
@@ -66,10 +66,8 @@ class SOLeader(Teleoperator):
|
||||
def is_connected(self) -> bool:
|
||||
return self.bus.is_connected
|
||||
|
||||
@check_if_already_connected
|
||||
def connect(self, calibrate: bool = True) -> None:
|
||||
if self.is_connected:
|
||||
raise DeviceAlreadyConnectedError(f"{self} already connected")
|
||||
|
||||
self.bus.connect()
|
||||
if not self.is_calibrated and calibrate:
|
||||
logger.info(
|
||||
@@ -139,10 +137,8 @@ class SOLeader(Teleoperator):
|
||||
self.bus.setup_motor(motor)
|
||||
print(f"'{motor}' motor id set to {self.bus.motors[motor].id}")
|
||||
|
||||
@check_if_not_connected
|
||||
def get_action(self) -> dict[str, float]:
|
||||
if not self.is_connected:
|
||||
raise DeviceNotConnectedError(f"{self} is not connected.")
|
||||
|
||||
start = time.perf_counter()
|
||||
action = self.bus.sync_read("Present_Position")
|
||||
action = {f"{motor}.pos": val for motor, val in action.items()}
|
||||
@@ -154,10 +150,8 @@ class SOLeader(Teleoperator):
|
||||
# TODO: Implement force feedback
|
||||
raise NotImplementedError
|
||||
|
||||
@check_if_not_connected
|
||||
def disconnect(self) -> None:
|
||||
if not self.is_connected:
|
||||
DeviceNotConnectedError(f"{self} is not connected.")
|
||||
|
||||
self.bus.disconnect()
|
||||
logger.info(f"{self} disconnected.")
|
||||
|
||||
|
||||
@@ -0,0 +1,41 @@
|
||||
#!/usr/bin/env python
|
||||
|
||||
# Copyright 2026 The HuggingFace Inc. team. All rights reserved.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
from functools import wraps
|
||||
|
||||
from lerobot.utils.errors import DeviceAlreadyConnectedError, DeviceNotConnectedError
|
||||
|
||||
|
||||
def check_if_not_connected(func):
|
||||
@wraps(func)
|
||||
def wrapper(self, *args, **kwargs):
|
||||
if not self.is_connected:
|
||||
raise DeviceNotConnectedError(
|
||||
f"{self.__class__.__name__} is not connected. Run `.connect()` first."
|
||||
)
|
||||
return func(self, *args, **kwargs)
|
||||
|
||||
return wrapper
|
||||
|
||||
|
||||
def check_if_already_connected(func):
|
||||
@wraps(func)
|
||||
def wrapper(self, *args, **kwargs):
|
||||
if self.is_connected:
|
||||
raise DeviceAlreadyConnectedError(f"{self.__class__.__name__} is already connected.")
|
||||
return func(self, *args, **kwargs)
|
||||
|
||||
return wrapper
|
||||
@@ -352,6 +352,65 @@ def test_image_array_to_pil_image_wrong_range_float_0_255():
|
||||
image_array_to_pil_image(image)
|
||||
|
||||
|
||||
def test_tmp_image_deletion(tmp_path, empty_lerobot_dataset_factory):
|
||||
"""Verify temporary image directories are removed for image features after saving episode."""
|
||||
# Image feature: images should be deleted after saving episode
|
||||
image_key = "image"
|
||||
features_image = {
|
||||
image_key: {"dtype": "image", "shape": DUMMY_CHW, "names": ["channels", "height", "width"]}
|
||||
}
|
||||
ds_img = empty_lerobot_dataset_factory(root=tmp_path / "img", features=features_image)
|
||||
ds_img.add_frame({"image": np.random.rand(*DUMMY_CHW), "task": "Dummy task"})
|
||||
ds_img.save_episode()
|
||||
img_dir = ds_img._get_image_file_dir(0, image_key)
|
||||
assert not img_dir.exists(), "Temporary image directory should be removed for image features"
|
||||
|
||||
|
||||
def test_tmp_video_deletion(tmp_path, empty_lerobot_dataset_factory):
|
||||
"""Verify temporary image directories are removed for video encoding when `batch_encoding_size == 1`."""
|
||||
# Video feature: when batch_encoding_size == 1 temporary images should be deleted
|
||||
vid_key = "video"
|
||||
features_video = {
|
||||
vid_key: {"dtype": "video", "shape": DUMMY_CHW, "names": ["channels", "height", "width"]}
|
||||
}
|
||||
|
||||
ds_vid = empty_lerobot_dataset_factory(root=tmp_path / "vid", features=features_video)
|
||||
ds_vid.batch_encoding_size = 1
|
||||
ds_vid.add_frame({vid_key: np.random.rand(*DUMMY_CHW), "task": "Dummy task"})
|
||||
ds_vid.save_episode()
|
||||
vid_img_dir = ds_vid._get_image_file_dir(0, vid_key)
|
||||
assert not vid_img_dir.exists(), (
|
||||
"Temporary image directory should be removed when batch_encoding_size == 1"
|
||||
)
|
||||
|
||||
|
||||
def test_tmp_mixed_deletion(tmp_path, empty_lerobot_dataset_factory):
|
||||
"""Verify temporary image directories are removed appropriately when both image and video features are present."""
|
||||
image_key = "image"
|
||||
vid_key = "video"
|
||||
features_mixed = {
|
||||
image_key: {"dtype": "image", "shape": DUMMY_CHW, "names": ["channels", "height", "width"]},
|
||||
vid_key: {"dtype": "video", "shape": DUMMY_HWC, "names": ["height", "width", "channels"]},
|
||||
}
|
||||
ds_mixed = empty_lerobot_dataset_factory(
|
||||
root=tmp_path / "mixed", features=features_mixed, batch_encoding_size=2
|
||||
)
|
||||
ds_mixed.add_frame(
|
||||
{
|
||||
"image": np.random.rand(*DUMMY_CHW),
|
||||
"video": np.random.rand(*DUMMY_HWC),
|
||||
"task": "Dummy task",
|
||||
}
|
||||
)
|
||||
ds_mixed.save_episode()
|
||||
img_dir = ds_mixed._get_image_file_dir(0, image_key)
|
||||
vid_img_dir = ds_mixed._get_image_file_dir(0, vid_key)
|
||||
assert not img_dir.exists(), "Temporary image directory should be removed for image features"
|
||||
assert vid_img_dir.exists(), (
|
||||
"Temporary image directory should not be removed for video features when batch_encoding_size == 2"
|
||||
)
|
||||
|
||||
|
||||
# TODO(aliberts):
|
||||
# - [ ] test various attributes & state from init and create
|
||||
# - [ ] test init with episodes and check num_frames
|
||||
@@ -1392,3 +1451,202 @@ def test_valid_video_codecs_constant():
|
||||
assert "hevc" in VALID_VIDEO_CODECS
|
||||
assert "libsvtav1" in VALID_VIDEO_CODECS
|
||||
assert len(VALID_VIDEO_CODECS) == 3
|
||||
|
||||
|
||||
def test_delta_timestamps_with_episodes_filter(tmp_path, empty_lerobot_dataset_factory):
|
||||
"""Regression test for bug where delta_timestamps incorrectly marked all frames as padded when using episodes filter.
|
||||
|
||||
The bug occurred because _get_query_indices was using the relative index (idx) in the filtered dataset
|
||||
instead of the absolute index when comparing against episode boundaries (ep_start, ep_end).
|
||||
"""
|
||||
features = {
|
||||
"observation.state": {"dtype": "float32", "shape": (2,), "names": ["x", "y"]},
|
||||
"action": {"dtype": "float32", "shape": (2,), "names": ["vx", "vy"]},
|
||||
}
|
||||
|
||||
dataset = empty_lerobot_dataset_factory(root=tmp_path / "test", features=features, use_videos=False)
|
||||
|
||||
# Create 3 episodes with 10 frames each
|
||||
frames_per_episode = 10
|
||||
for ep_idx in range(3):
|
||||
for frame_idx in range(frames_per_episode):
|
||||
dataset.add_frame(
|
||||
{
|
||||
"observation.state": torch.tensor([ep_idx, frame_idx], dtype=torch.float32),
|
||||
"action": torch.randn(2),
|
||||
"task": f"task_{ep_idx}",
|
||||
}
|
||||
)
|
||||
dataset.save_episode()
|
||||
dataset.finalize()
|
||||
|
||||
# Load only episode 1 (middle episode) with delta_timestamps
|
||||
delta_ts = {"observation.state": [0.0]} # Just the current frame
|
||||
filtered_dataset = LeRobotDataset(
|
||||
dataset.repo_id,
|
||||
root=dataset.root,
|
||||
episodes=[1],
|
||||
delta_timestamps=delta_ts,
|
||||
)
|
||||
|
||||
# Verify the filtered dataset has the correct length
|
||||
assert len(filtered_dataset) == frames_per_episode
|
||||
|
||||
# Check that no frames are marked as padded (since delta=0 should always be valid)
|
||||
for idx in range(len(filtered_dataset)):
|
||||
frame = filtered_dataset[idx]
|
||||
assert frame["observation.state_is_pad"].item() is False, f"Frame {idx} incorrectly marked as padded"
|
||||
# Verify we're getting data from episode 1
|
||||
assert frame["episode_index"].item() == 1
|
||||
|
||||
|
||||
def test_delta_timestamps_padding_at_episode_boundaries(tmp_path, empty_lerobot_dataset_factory):
|
||||
"""Test that delta_timestamps correctly marks padding at episode boundaries when using episodes filter."""
|
||||
features = {
|
||||
"observation.state": {"dtype": "float32", "shape": (2,), "names": ["x", "y"]},
|
||||
"action": {"dtype": "float32", "shape": (2,), "names": ["vx", "vy"]},
|
||||
}
|
||||
|
||||
dataset = empty_lerobot_dataset_factory(
|
||||
root=tmp_path / "test", features=features, use_videos=False, fps=10
|
||||
)
|
||||
|
||||
# Create 3 episodes with 5 frames each
|
||||
frames_per_episode = 5
|
||||
for ep_idx in range(3):
|
||||
for frame_idx in range(frames_per_episode):
|
||||
dataset.add_frame(
|
||||
{
|
||||
"observation.state": torch.tensor([ep_idx, frame_idx], dtype=torch.float32),
|
||||
"action": torch.randn(2),
|
||||
"task": f"task_{ep_idx}",
|
||||
}
|
||||
)
|
||||
dataset.save_episode()
|
||||
dataset.finalize()
|
||||
|
||||
# Load only episode 1 with delta_timestamps that go beyond episode boundaries
|
||||
# fps=10, so 0.1s = 1 frame offset
|
||||
delta_ts = {"observation.state": [-0.2, -0.1, 0.0, 0.1, 0.2]} # -2, -1, 0, +1, +2 frames
|
||||
filtered_dataset = LeRobotDataset(
|
||||
dataset.repo_id,
|
||||
root=dataset.root,
|
||||
episodes=[1],
|
||||
delta_timestamps=delta_ts,
|
||||
tolerance_s=0.04, # Slightly less than half a frame at 10fps
|
||||
)
|
||||
|
||||
assert len(filtered_dataset) == frames_per_episode
|
||||
|
||||
# Check padding at the start of the episode (first frame)
|
||||
first_frame = filtered_dataset[0]
|
||||
is_pad = first_frame["observation.state_is_pad"].tolist()
|
||||
# At frame 0 of episode 1: delta -2 and -1 should be padded, 0, +1, +2 should not
|
||||
assert is_pad == [True, True, False, False, False], f"First frame padding incorrect: {is_pad}"
|
||||
|
||||
# Check middle frame (no padding expected)
|
||||
mid_frame = filtered_dataset[2]
|
||||
is_pad = mid_frame["observation.state_is_pad"].tolist()
|
||||
assert is_pad == [False, False, False, False, False], f"Middle frame padding incorrect: {is_pad}"
|
||||
|
||||
# Check padding at the end of the episode (last frame)
|
||||
last_frame = filtered_dataset[4]
|
||||
is_pad = last_frame["observation.state_is_pad"].tolist()
|
||||
# At frame 4 of episode 1: delta -2, -1, 0 should not be padded, +1, +2 should be
|
||||
assert is_pad == [False, False, False, True, True], f"Last frame padding incorrect: {is_pad}"
|
||||
|
||||
|
||||
def test_delta_timestamps_multiple_episodes_filter(tmp_path, empty_lerobot_dataset_factory):
|
||||
"""Test delta_timestamps with multiple non-consecutive episodes selected."""
|
||||
features = {
|
||||
"observation.state": {"dtype": "float32", "shape": (2,), "names": ["x", "y"]},
|
||||
}
|
||||
|
||||
dataset = empty_lerobot_dataset_factory(
|
||||
root=tmp_path / "test", features=features, use_videos=False, fps=10
|
||||
)
|
||||
|
||||
# Create 5 episodes with 5 frames each
|
||||
frames_per_episode = 5
|
||||
for ep_idx in range(5):
|
||||
for frame_idx in range(frames_per_episode):
|
||||
dataset.add_frame(
|
||||
{
|
||||
"observation.state": torch.tensor([ep_idx, frame_idx], dtype=torch.float32),
|
||||
"task": f"task_{ep_idx}",
|
||||
}
|
||||
)
|
||||
dataset.save_episode()
|
||||
dataset.finalize()
|
||||
|
||||
# Load episodes 1 and 3 (non-consecutive)
|
||||
delta_ts = {"observation.state": [0.0]}
|
||||
filtered_dataset = LeRobotDataset(
|
||||
dataset.repo_id,
|
||||
root=dataset.root,
|
||||
episodes=[1, 3],
|
||||
delta_timestamps=delta_ts,
|
||||
)
|
||||
|
||||
assert len(filtered_dataset) == 2 * frames_per_episode
|
||||
|
||||
# All frames should have valid (non-padded) data for delta=0
|
||||
for idx in range(len(filtered_dataset)):
|
||||
frame = filtered_dataset[idx]
|
||||
assert frame["observation.state_is_pad"].item() is False
|
||||
|
||||
# Verify we're getting the correct episodes
|
||||
episode_indices = [filtered_dataset[i]["episode_index"].item() for i in range(len(filtered_dataset))]
|
||||
expected_episodes = [1] * frames_per_episode + [3] * frames_per_episode
|
||||
assert episode_indices == expected_episodes
|
||||
|
||||
|
||||
def test_delta_timestamps_query_returns_correct_values(tmp_path, empty_lerobot_dataset_factory):
|
||||
"""Test that delta_timestamps returns the correct observation values, not just correct padding."""
|
||||
features = {
|
||||
"observation.state": {"dtype": "float32", "shape": (1,), "names": ["x"]},
|
||||
}
|
||||
|
||||
dataset = empty_lerobot_dataset_factory(
|
||||
root=tmp_path / "test", features=features, use_videos=False, fps=10
|
||||
)
|
||||
|
||||
# Create 2 episodes with known values
|
||||
# Episode 0: frames with values 0, 1, 2, 3, 4
|
||||
# Episode 1: frames with values 10, 11, 12, 13, 14
|
||||
frames_per_episode = 5
|
||||
for ep_idx in range(2):
|
||||
for frame_idx in range(frames_per_episode):
|
||||
value = ep_idx * 10 + frame_idx
|
||||
dataset.add_frame(
|
||||
{
|
||||
"observation.state": torch.tensor([value], dtype=torch.float32),
|
||||
"task": f"task_{ep_idx}",
|
||||
}
|
||||
)
|
||||
dataset.save_episode()
|
||||
dataset.finalize()
|
||||
|
||||
# Load episode 1 with delta that looks at previous frame
|
||||
delta_ts = {"observation.state": [-0.1, 0.0]} # Previous frame and current frame
|
||||
filtered_dataset = LeRobotDataset(
|
||||
dataset.repo_id,
|
||||
root=dataset.root,
|
||||
episodes=[1],
|
||||
delta_timestamps=delta_ts,
|
||||
tolerance_s=0.04,
|
||||
)
|
||||
|
||||
# Check frame 2 of episode 1 (which has absolute index 7, value 12)
|
||||
frame = filtered_dataset[2]
|
||||
state_values = frame["observation.state"].tolist()
|
||||
# Should get [11, 12] - the previous and current values within episode 1
|
||||
assert state_values == [11.0, 12.0], f"Expected [11.0, 12.0], got {state_values}"
|
||||
|
||||
# Check first frame - previous frame should be clamped to episode start (padded)
|
||||
first_frame = filtered_dataset[0]
|
||||
state_values = first_frame["observation.state"].tolist()
|
||||
is_pad = first_frame["observation.state_is_pad"].tolist()
|
||||
# Previous frame is outside episode, so it's clamped to first frame and marked as padded
|
||||
assert state_values == [10.0, 10.0], f"Expected [10.0, 10.0], got {state_values}"
|
||||
assert is_pad == [True, False], f"Expected [True, False], got {is_pad}"
|
||||
|
||||
@@ -22,7 +22,7 @@ from lerobot.cameras import CameraConfig, make_cameras_from_configs
|
||||
from lerobot.motors.motors_bus import Motor, MotorNormMode
|
||||
from lerobot.processor import RobotAction, RobotObservation
|
||||
from lerobot.robots import Robot, RobotConfig
|
||||
from lerobot.utils.errors import DeviceAlreadyConnectedError, DeviceNotConnectedError
|
||||
from lerobot.utils.decorators import check_if_already_connected, check_if_not_connected
|
||||
from tests.mocks.mock_motors_bus import MockMotorsBus
|
||||
|
||||
|
||||
@@ -98,10 +98,8 @@ class MockRobot(Robot):
|
||||
def is_connected(self) -> bool:
|
||||
return self._is_connected
|
||||
|
||||
@check_if_already_connected
|
||||
def connect(self, calibrate: bool = True) -> None:
|
||||
if self.is_connected:
|
||||
raise DeviceAlreadyConnectedError(f"{self} already connected")
|
||||
|
||||
self._is_connected = True
|
||||
if calibrate:
|
||||
self.calibrate()
|
||||
@@ -110,19 +108,15 @@ class MockRobot(Robot):
|
||||
def is_calibrated(self) -> bool:
|
||||
return self._is_calibrated
|
||||
|
||||
@check_if_not_connected
|
||||
def calibrate(self) -> None:
|
||||
if not self.is_connected:
|
||||
raise DeviceNotConnectedError(f"{self} is not connected.")
|
||||
|
||||
self._is_calibrated = True
|
||||
|
||||
def configure(self) -> None:
|
||||
pass
|
||||
|
||||
@check_if_not_connected
|
||||
def get_observation(self) -> RobotObservation:
|
||||
if not self.is_connected:
|
||||
raise DeviceNotConnectedError(f"{self} is not connected.")
|
||||
|
||||
if self.config.random_values:
|
||||
return {f"{motor}.pos": random.uniform(-100, 100) for motor in self.motors}
|
||||
else:
|
||||
@@ -130,14 +124,10 @@ class MockRobot(Robot):
|
||||
f"{motor}.pos": val for motor, val in zip(self.motors, self.config.static_values, strict=True)
|
||||
}
|
||||
|
||||
@check_if_not_connected
|
||||
def send_action(self, action: RobotAction) -> RobotAction:
|
||||
if not self.is_connected:
|
||||
raise DeviceNotConnectedError(f"{self} is not connected.")
|
||||
|
||||
return action
|
||||
|
||||
@check_if_not_connected
|
||||
def disconnect(self) -> None:
|
||||
if not self.is_connected:
|
||||
raise DeviceNotConnectedError(f"{self} is not connected.")
|
||||
|
||||
self._is_connected = False
|
||||
|
||||
@@ -21,7 +21,7 @@ from typing import Any
|
||||
|
||||
from lerobot.processor import RobotAction
|
||||
from lerobot.teleoperators import Teleoperator, TeleoperatorConfig
|
||||
from lerobot.utils.errors import DeviceAlreadyConnectedError, DeviceNotConnectedError
|
||||
from lerobot.utils.decorators import check_if_already_connected, check_if_not_connected
|
||||
|
||||
|
||||
@TeleoperatorConfig.register_subclass("mock_teleop")
|
||||
@@ -68,10 +68,8 @@ class MockTeleop(Teleoperator):
|
||||
def is_connected(self) -> bool:
|
||||
return self._is_connected
|
||||
|
||||
@check_if_already_connected
|
||||
def connect(self, calibrate: bool = True) -> None:
|
||||
if self.is_connected:
|
||||
raise DeviceAlreadyConnectedError(f"{self} already connected")
|
||||
|
||||
self._is_connected = True
|
||||
if calibrate:
|
||||
self.calibrate()
|
||||
@@ -80,19 +78,15 @@ class MockTeleop(Teleoperator):
|
||||
def is_calibrated(self) -> bool:
|
||||
return self._is_calibrated
|
||||
|
||||
@check_if_not_connected
|
||||
def calibrate(self) -> None:
|
||||
if not self.is_connected:
|
||||
raise DeviceNotConnectedError(f"{self} is not connected.")
|
||||
|
||||
self._is_calibrated = True
|
||||
|
||||
def configure(self) -> None:
|
||||
pass
|
||||
|
||||
@check_if_not_connected
|
||||
def get_action(self) -> RobotAction:
|
||||
if not self.is_connected:
|
||||
raise DeviceNotConnectedError(f"{self} is not connected.")
|
||||
|
||||
if self.config.random_values:
|
||||
return {f"{motor}.pos": random.uniform(-100, 100) for motor in self.motors}
|
||||
else:
|
||||
@@ -100,12 +94,9 @@ class MockTeleop(Teleoperator):
|
||||
f"{motor}.pos": val for motor, val in zip(self.motors, self.config.static_values, strict=True)
|
||||
}
|
||||
|
||||
def send_feedback(self, feedback: dict[str, Any]) -> None:
|
||||
if not self.is_connected:
|
||||
raise DeviceNotConnectedError(f"{self} is not connected.")
|
||||
@check_if_not_connected
|
||||
def send_feedback(self, feedback: dict[str, Any]) -> None: ...
|
||||
|
||||
@check_if_not_connected
|
||||
def disconnect(self) -> None:
|
||||
if not self.is_connected:
|
||||
raise DeviceNotConnectedError(f"{self} is not connected.")
|
||||
|
||||
self._is_connected = False
|
||||
|
||||
Reference in New Issue
Block a user