refactor import fixes

This commit is contained in:
Steven Palma
2026-04-11 18:02:59 +02:00
parent d626964119
commit af0d72bd42
69 changed files with 306 additions and 339 deletions
+2 -2
View File
@@ -201,7 +201,7 @@ from lerobot.datasets.utils import hw_to_dataset_features
from lerobot.robots.so_follower import SO100Follower, SO100FollowerConfig
from lerobot.teleoperators.so_leader.config_so100_leader import SO100LeaderConfig
from lerobot.teleoperators.so_leader.so100_leader import SO100Leader
from lerobot.utils.control_utils import init_keyboard_listener
from lerobot.common.control_utils import init_keyboard_listener
from lerobot.utils.utils import log_say
from lerobot.utils.visualization_utils import init_rerun
from lerobot.scripts.lerobot_record import record_loop
@@ -540,7 +540,7 @@ from lerobot.policies.factory import make_pre_post_processors
from lerobot.robots.so_follower.config_so100_follower import SO100FollowerConfig
from lerobot.robots.so_follower.so100_follower import SO100Follower
from lerobot.scripts.lerobot_record import record_loop
from lerobot.utils.control_utils import init_keyboard_listener
from lerobot.common.control_utils import init_keyboard_listener
from lerobot.utils.utils import log_say
from lerobot.utils.visualization_utils import init_rerun
+1 -6
View File
@@ -31,17 +31,12 @@ from pprint import pprint
import torch
from huggingface_hub import HfApi
import lerobot
from lerobot.datasets.dataset_metadata import LeRobotDatasetMetadata
from lerobot.datasets.lerobot_dataset import LeRobotDataset
def main():
# We ported a number of existing datasets ourselves, use this to see the list:
print("List of available datasets:")
pprint(lerobot.available_datasets)
# You can also browse through the datasets created/ported by the community on the hub using the hub api:
# Browse datasets created/ported by the community on the hub using the hub api:
hub_api = HfApi()
repo_ids = [info.id for info in hub_api.list_datasets(task_categories="robotics", tags=["LeRobot"])]
pprint(repo_ids)
+2 -2
View File
@@ -114,9 +114,9 @@ from hil_utils import (
from lerobot.cameras.opencv.configuration_opencv import OpenCVCameraConfig # noqa: F401
from lerobot.cameras.realsense.configuration_realsense import RealSenseCameraConfig # noqa: F401
from lerobot.common.control_utils import is_headless, predict_action
from lerobot.configs import parser
from lerobot.configs.policies import PreTrainedConfig
from lerobot.datasets.feature_utils import build_dataset_frame, combine_feature_dicts, hw_to_dataset_features
from lerobot.datasets.image_writer import safe_stop_image_writer
from lerobot.datasets.lerobot_dataset import LeRobotDataset
from lerobot.datasets.pipeline_features import aggregate_pipeline_dataset_features, create_initial_features
@@ -141,8 +141,8 @@ from lerobot.teleoperators import Teleoperator, TeleoperatorConfig, make_teleope
from lerobot.teleoperators.openarm_mini.config_openarm_mini import OpenArmMiniConfig # noqa: F401
from lerobot.teleoperators.so_leader.config_so_leader import SOLeaderTeleopConfig # noqa: F401
from lerobot.utils.constants import ACTION, OBS_STATE, OBS_STR
from lerobot.utils.control_utils import is_headless, predict_action
from lerobot.utils.device_utils import get_safe_torch_device
from lerobot.utils.feature_utils import build_dataset_frame, combine_feature_dicts, hw_to_dataset_features
from lerobot.utils.robot_utils import precise_sleep
from lerobot.utils.utils import init_logging, log_say
from lerobot.utils.visualization_utils import init_rerun, log_rerun_data
+1 -1
View File
@@ -19,6 +19,7 @@ import time
from dataclasses import dataclass, field
from pathlib import Path
from lerobot.common.control_utils import is_headless
from lerobot.processor import (
IdentityProcessorStep,
RobotAction,
@@ -33,7 +34,6 @@ from lerobot.processor.converters import (
)
from lerobot.robots import Robot
from lerobot.teleoperators import Teleoperator
from lerobot.utils.control_utils import is_headless
from lerobot.utils.robot_utils import precise_sleep
logger = logging.getLogger(__name__)
+2 -2
View File
@@ -14,7 +14,7 @@
# See the License for the specific language governing permissions and
# limitations under the License.
from lerobot.datasets.feature_utils import hw_to_dataset_features
from lerobot.common.control_utils import init_keyboard_listener
from lerobot.datasets.lerobot_dataset import LeRobotDataset
from lerobot.policies.act.modeling_act import ACTPolicy
from lerobot.policies.factory import make_pre_post_processors
@@ -22,7 +22,7 @@ from lerobot.processor import make_default_processors
from lerobot.robots.lekiwi import LeKiwiClient, LeKiwiClientConfig
from lerobot.scripts.lerobot_record import record_loop
from lerobot.utils.constants import ACTION, OBS_STR
from lerobot.utils.control_utils import init_keyboard_listener
from lerobot.utils.feature_utils import hw_to_dataset_features
from lerobot.utils.utils import log_say
from lerobot.utils.visualization_utils import init_rerun
+2 -2
View File
@@ -14,7 +14,7 @@
# See the License for the specific language governing permissions and
# limitations under the License.
from lerobot.datasets.feature_utils import hw_to_dataset_features
from lerobot.common.control_utils import init_keyboard_listener
from lerobot.datasets.lerobot_dataset import LeRobotDataset
from lerobot.processor import make_default_processors
from lerobot.robots.lekiwi.config_lekiwi import LeKiwiClientConfig
@@ -23,7 +23,7 @@ from lerobot.scripts.lerobot_record import record_loop
from lerobot.teleoperators.keyboard import KeyboardTeleop, KeyboardTeleopConfig
from lerobot.teleoperators.so_leader import SO100Leader, SO100LeaderConfig
from lerobot.utils.constants import ACTION, OBS_STR
from lerobot.utils.control_utils import init_keyboard_listener
from lerobot.utils.feature_utils import hw_to_dataset_features
from lerobot.utils.utils import log_say
from lerobot.utils.visualization_utils import init_rerun
+2 -2
View File
@@ -15,8 +15,8 @@
# limitations under the License.
from lerobot.cameras.opencv.configuration_opencv import OpenCVCameraConfig
from lerobot.common.control_utils import init_keyboard_listener
from lerobot.configs.types import FeatureType, PolicyFeature
from lerobot.datasets.feature_utils import combine_feature_dicts
from lerobot.datasets.lerobot_dataset import LeRobotDataset
from lerobot.datasets.pipeline_features import aggregate_pipeline_dataset_features, create_initial_features
from lerobot.model.kinematics import RobotKinematics
@@ -39,7 +39,7 @@ from lerobot.robots.so_follower.robot_kinematic_processor import (
)
from lerobot.scripts.lerobot_record import record_loop
from lerobot.types import RobotAction, RobotObservation
from lerobot.utils.control_utils import init_keyboard_listener
from lerobot.utils.feature_utils import combine_feature_dicts
from lerobot.utils.utils import log_say
from lerobot.utils.visualization_utils import init_rerun
+2 -2
View File
@@ -15,7 +15,7 @@
# limitations under the License.
from lerobot.cameras.opencv.configuration_opencv import OpenCVCameraConfig
from lerobot.datasets.feature_utils import combine_feature_dicts
from lerobot.common.control_utils import init_keyboard_listener
from lerobot.datasets.lerobot_dataset import LeRobotDataset
from lerobot.datasets.pipeline_features import aggregate_pipeline_dataset_features, create_initial_features
from lerobot.model.kinematics import RobotKinematics
@@ -39,7 +39,7 @@ from lerobot.teleoperators.phone.config_phone import PhoneConfig, PhoneOS
from lerobot.teleoperators.phone.phone_processor import MapPhoneActionToRobotAction
from lerobot.teleoperators.phone.teleop_phone import Phone
from lerobot.types import RobotAction, RobotObservation
from lerobot.utils.control_utils import init_keyboard_listener
from lerobot.utils.feature_utils import combine_feature_dicts
from lerobot.utils.utils import log_say
from lerobot.utils.visualization_utils import init_rerun
+1 -1
View File
@@ -107,7 +107,6 @@ from lerobot.cameras.zmq.configuration_zmq import ZMQCameraConfig # noqa: F401
from lerobot.configs import parser
from lerobot.configs.policies import PreTrainedConfig
from lerobot.configs.types import RTCAttentionSchedule
from lerobot.datasets.feature_utils import build_dataset_frame, hw_to_dataset_features
from lerobot.policies.factory import get_policy_class, make_pre_post_processors
from lerobot.policies.rtc import ActionInterpolator, ActionQueue, LatencyTracker, RTCConfig
from lerobot.processor import (
@@ -133,6 +132,7 @@ from lerobot.robots import ( # noqa: F401
)
from lerobot.robots.utils import make_robot_from_config
from lerobot.utils.constants import OBS_IMAGES, OBS_STATE
from lerobot.utils.feature_utils import build_dataset_frame, hw_to_dataset_features
from lerobot.utils.hub import HubMixin
from lerobot.utils.utils import init_logging
+2 -2
View File
@@ -15,8 +15,8 @@
# limitations under the License.
from lerobot.cameras.opencv.configuration_opencv import OpenCVCameraConfig
from lerobot.common.control_utils import init_keyboard_listener
from lerobot.configs.types import FeatureType, PolicyFeature
from lerobot.datasets.feature_utils import combine_feature_dicts
from lerobot.datasets.lerobot_dataset import LeRobotDataset
from lerobot.datasets.pipeline_features import aggregate_pipeline_dataset_features, create_initial_features
from lerobot.model.kinematics import RobotKinematics
@@ -39,7 +39,7 @@ from lerobot.robots.so_follower.robot_kinematic_processor import (
)
from lerobot.scripts.lerobot_record import record_loop
from lerobot.types import RobotAction, RobotObservation
from lerobot.utils.control_utils import init_keyboard_listener
from lerobot.utils.feature_utils import combine_feature_dicts
from lerobot.utils.utils import log_say
from lerobot.utils.visualization_utils import init_rerun
+2 -2
View File
@@ -16,7 +16,7 @@
from lerobot.cameras.opencv.configuration_opencv import OpenCVCameraConfig
from lerobot.datasets.feature_utils import combine_feature_dicts
from lerobot.common.control_utils import init_keyboard_listener
from lerobot.datasets.lerobot_dataset import LeRobotDataset
from lerobot.datasets.pipeline_features import aggregate_pipeline_dataset_features, create_initial_features
from lerobot.model.kinematics import RobotKinematics
@@ -36,7 +36,7 @@ from lerobot.robots.so_follower.robot_kinematic_processor import (
from lerobot.scripts.lerobot_record import record_loop
from lerobot.teleoperators.so_leader import SO100Leader, SO100LeaderConfig
from lerobot.types import RobotAction, RobotObservation
from lerobot.utils.control_utils import init_keyboard_listener
from lerobot.utils.feature_utils import combine_feature_dicts
from lerobot.utils.utils import log_say
from lerobot.utils.visualization_utils import init_rerun
+1 -1
View File
@@ -20,11 +20,11 @@ import torch
from lerobot.configs.types import FeatureType
from lerobot.datasets.dataset_metadata import LeRobotDatasetMetadata
from lerobot.datasets.feature_utils import dataset_to_policy_features
from lerobot.datasets.lerobot_dataset import LeRobotDataset
from lerobot.policies.diffusion.configuration_diffusion import DiffusionConfig
from lerobot.policies.diffusion.modeling_diffusion import DiffusionPolicy
from lerobot.policies.factory import make_pre_post_processors
from lerobot.utils.feature_utils import dataset_to_policy_features
def main():
+1 -1
View File
@@ -21,12 +21,12 @@ import torch
from lerobot.configs.types import FeatureType
from lerobot.datasets.dataset_metadata import LeRobotDatasetMetadata
from lerobot.datasets.feature_utils import dataset_to_policy_features
from lerobot.datasets.streaming_dataset import StreamingLeRobotDataset
from lerobot.policies.act.configuration_act import ACTConfig
from lerobot.policies.act.modeling_act import ACTPolicy
from lerobot.policies.factory import make_pre_post_processors
from lerobot.utils.constants import ACTION
from lerobot.utils.feature_utils import dataset_to_policy_features
def main():
@@ -6,11 +6,11 @@ import torch
from lerobot.configs.types import FeatureType
from lerobot.datasets.dataset_metadata import LeRobotDatasetMetadata
from lerobot.datasets.feature_utils import dataset_to_policy_features
from lerobot.datasets.lerobot_dataset import LeRobotDataset
from lerobot.policies.act.configuration_act import ACTConfig
from lerobot.policies.act.modeling_act import ACTPolicy
from lerobot.policies.factory import make_pre_post_processors
from lerobot.utils.feature_utils import dataset_to_policy_features
def make_delta_timestamps(delta_indices: list[int] | None, fps: int) -> list[float]:
@@ -6,11 +6,11 @@ import torch
from lerobot.configs.types import FeatureType
from lerobot.datasets.dataset_metadata import LeRobotDatasetMetadata
from lerobot.datasets.feature_utils import dataset_to_policy_features
from lerobot.datasets.lerobot_dataset import LeRobotDataset
from lerobot.policies.diffusion.configuration_diffusion import DiffusionConfig
from lerobot.policies.diffusion.modeling_diffusion import DiffusionPolicy
from lerobot.policies.factory import make_pre_post_processors
from lerobot.utils.feature_utils import dataset_to_policy_features
def make_delta_timestamps(delta_indices: list[int] | None, fps: int) -> list[float]:
+1 -1
View File
@@ -1,11 +1,11 @@
import torch
from lerobot.cameras.opencv.configuration_opencv import OpenCVCameraConfig
from lerobot.datasets.feature_utils import hw_to_dataset_features
from lerobot.policies.factory import make_pre_post_processors
from lerobot.policies.pi0.modeling_pi0 import PI0Policy
from lerobot.policies.utils import build_inference_frame, make_robot_action
from lerobot.robots.so_follower import SO100Follower, SO100FollowerConfig
from lerobot.utils.feature_utils import hw_to_dataset_features
MAX_EPISODES = 5
MAX_STEPS_PER_EPISODE = 20
+1 -1
View File
@@ -6,7 +6,6 @@ from queue import Empty, Full
import torch
import torch.optim as optim
from lerobot.datasets.feature_utils import hw_to_dataset_features
from lerobot.datasets.lerobot_dataset import LeRobotDataset
from lerobot.envs.configs import HILSerlProcessorConfig, HILSerlRobotEnvConfig
from lerobot.policies.sac.configuration_sac import SACConfig
@@ -17,6 +16,7 @@ from lerobot.rl.gym_manipulator import make_robot_env
from lerobot.robots.so_follower import SO100FollowerConfig
from lerobot.teleoperators.so_leader import SO100LeaderConfig
from lerobot.teleoperators.utils import TeleopEvents
from lerobot.utils.feature_utils import hw_to_dataset_features
LOG_EVERY = 10
SEND_EVERY = 10
@@ -1,11 +1,11 @@
import torch
from lerobot.cameras.opencv.configuration_opencv import OpenCVCameraConfig
from lerobot.datasets.feature_utils import hw_to_dataset_features
from lerobot.policies.factory import make_pre_post_processors
from lerobot.policies.smolvla.modeling_smolvla import SmolVLAPolicy
from lerobot.policies.utils import build_inference_frame, make_robot_action
from lerobot.robots.so_follower import SO100Follower, SO100FollowerConfig
from lerobot.utils.feature_utils import hw_to_dataset_features
MAX_EPISODES = 5
MAX_STEPS_PER_EPISODE = 20
+4 -1
View File
@@ -115,7 +115,8 @@ build = [
# ── User-facing composite extras (map to CLI scripts) ─────
# lerobot-record, lerobot-replay, lerobot-calibrate, lerobot-teleoperate, etc.
robot = ["lerobot[dataset]", "lerobot[hardware]", "lerobot[viz]"]
# lerobot-eval -- base evaluation framework. You also need the policy's extra (e.g., lerobot[pi]).
# lerobot-eval -- base evaluation framework. You also need the policy's extra (e.g., lerobot[pi])
# and the environment's extra (e.g., lerobot[pusht]) if evaluating in simulation.
evaluation = ["lerobot[av-dep]"]
# lerobot-dataset-viz, lerobot-imgtransform-viz
dataset_viz = ["lerobot[dataset]", "lerobot[viz]"]
@@ -307,6 +308,8 @@ ignore = [
"src/lerobot/policies/smolvla/smolvlm_with_expert.py" = ["E402"]
"src/lerobot/policies/xvla/modeling_xvla.py" = ["E402"]
"src/lerobot/policies/sarm/processor_sarm.py" = ["E402"]
"src/lerobot/scripts/convert_dataset_v21_to_v30.py" = ["E402"]
"src/lerobot/teleoperators/unitree_g1/exo_serial.py" = ["E402"]
"src/lerobot/policies/wall_x/**" = ["N801", "N812", "SIM102", "SIM108", "SIM210", "SIM211", "B006", "B007", "SIM118"] # Supprese these as they are coming from original Qwen2_5_vl code TODO(pepijn): refactor original
[tool.ruff.lint.isort]
+14 -178
View File
@@ -13,188 +13,24 @@
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""
This file contains lists of available environments, dataset and policies to reflect the current state of LeRobot library.
We do not want to import all the dependencies, but instead we keep it lightweight to ensure fast access to these variables.
LeRobot -- PyTorch library for real-world robotics.
Example:
```python
import lerobot
print(lerobot.available_envs)
print(lerobot.available_tasks_per_env)
print(lerobot.available_datasets)
print(lerobot.available_datasets_per_env)
print(lerobot.available_real_world_datasets)
print(lerobot.available_policies)
print(lerobot.available_policies_per_env)
print(lerobot.available_robots)
print(lerobot.available_cameras)
print(lerobot.available_motors)
```
Provides datasets, pretrained policies, and tools for training, evaluation,
data collection, and robot control. Integrates with Hugging Face Hub for
model and dataset sharing.
When implementing a new dataset loadable with LeRobotDataset follow these steps:
- Update `available_datasets_per_env` in `lerobot/__init__.py`
The base install is intentionally lightweight. Feature-specific dependencies
are gated behind optional extras::
When implementing a new environment (e.g. `gym_aloha`), follow these steps:
- Update `available_tasks_per_env` and `available_datasets_per_env` in `lerobot/__init__.py`
When implementing a new policy class (e.g. `DiffusionPolicy`) follow these steps:
- Update `available_policies` and `available_policies_per_env`, in `lerobot/__init__.py`
- Set the required `name` class attribute.
- Update variables in `tests/test_available.py` by importing your new Policy class
pip install 'lerobot[dataset]' # dataset loading & creation
pip install 'lerobot[training]' # training loop + wandb
pip install 'lerobot[hardware]' # real robot control
pip install 'lerobot[robot]' # dataset + hardware + viz (recording)
pip install 'lerobot[all]' # everything
"""
import itertools
from lerobot.__version__ import __version__
from lerobot.__version__ import __version__ # noqa: F401
# TODO(rcadene): Improve policies and envs. As of now, an item in `available_policies`
# refers to a yaml file AND a modeling name. Same for `available_envs` which refers to
# a yaml file AND a environment name. The difference should be more obvious.
available_tasks_per_env = {
"aloha": [
"AlohaInsertion-v0",
"AlohaTransferCube-v0",
],
"pusht": ["PushT-v0"],
}
available_envs = list(available_tasks_per_env.keys())
available_datasets_per_env = {
"aloha": [
"lerobot/aloha_sim_insertion_human",
"lerobot/aloha_sim_insertion_scripted",
"lerobot/aloha_sim_transfer_cube_human",
"lerobot/aloha_sim_transfer_cube_scripted",
"lerobot/aloha_sim_insertion_human_image",
"lerobot/aloha_sim_insertion_scripted_image",
"lerobot/aloha_sim_transfer_cube_human_image",
"lerobot/aloha_sim_transfer_cube_scripted_image",
],
# TODO(alexander-soare): Add "lerobot/pusht_keypoints". Right now we can't because this is too tightly
# coupled with tests.
"pusht": ["lerobot/pusht", "lerobot/pusht_image"],
}
available_real_world_datasets = [
"lerobot/aloha_mobile_cabinet",
"lerobot/aloha_mobile_chair",
"lerobot/aloha_mobile_elevator",
"lerobot/aloha_mobile_shrimp",
"lerobot/aloha_mobile_wash_pan",
"lerobot/aloha_mobile_wipe_wine",
"lerobot/aloha_static_battery",
"lerobot/aloha_static_candy",
"lerobot/aloha_static_coffee",
"lerobot/aloha_static_coffee_new",
"lerobot/aloha_static_cups_open",
"lerobot/aloha_static_fork_pick_up",
"lerobot/aloha_static_pingpong_test",
"lerobot/aloha_static_pro_pencil",
"lerobot/aloha_static_screw_driver",
"lerobot/aloha_static_tape",
"lerobot/aloha_static_thread_velcro",
"lerobot/aloha_static_towel",
"lerobot/aloha_static_vinh_cup",
"lerobot/aloha_static_vinh_cup_left",
"lerobot/aloha_static_ziploc_slide",
"lerobot/umi_cup_in_the_wild",
"lerobot/unitreeh1_fold_clothes",
"lerobot/unitreeh1_rearrange_objects",
"lerobot/unitreeh1_two_robot_greeting",
"lerobot/unitreeh1_warehouse",
"lerobot/nyu_rot_dataset",
"lerobot/utokyo_saytap",
"lerobot/imperialcollege_sawyer_wrist_cam",
"lerobot/utokyo_xarm_bimanual",
"lerobot/tokyo_u_lsmo",
"lerobot/utokyo_pr2_opening_fridge",
"lerobot/cmu_franka_exploration_dataset",
"lerobot/cmu_stretch",
"lerobot/asu_table_top",
"lerobot/utokyo_pr2_tabletop_manipulation",
"lerobot/utokyo_xarm_pick_and_place",
"lerobot/ucsd_kitchen_dataset",
"lerobot/austin_buds_dataset",
"lerobot/dlr_sara_grid_clamp",
"lerobot/conq_hose_manipulation",
"lerobot/columbia_cairlab_pusht_real",
"lerobot/dlr_sara_pour",
"lerobot/dlr_edan_shared_control",
"lerobot/ucsd_pick_and_place_dataset",
"lerobot/berkeley_cable_routing",
"lerobot/nyu_franka_play_dataset",
"lerobot/austin_sirius_dataset",
"lerobot/cmu_play_fusion",
"lerobot/berkeley_gnm_sac_son",
"lerobot/nyu_door_opening_surprising_effectiveness",
"lerobot/berkeley_fanuc_manipulation",
"lerobot/jaco_play",
"lerobot/viola",
"lerobot/kaist_nonprehensile",
"lerobot/berkeley_mvp",
"lerobot/uiuc_d3field",
"lerobot/berkeley_gnm_recon",
"lerobot/austin_sailor_dataset",
"lerobot/utaustin_mutex",
"lerobot/roboturk",
"lerobot/stanford_hydra_dataset",
"lerobot/berkeley_autolab_ur5",
"lerobot/stanford_robocook",
"lerobot/toto",
"lerobot/fmb",
"lerobot/droid_100",
"lerobot/berkeley_rpt",
"lerobot/stanford_kuka_multimodal_dataset",
"lerobot/iamlab_cmu_pickup_insert",
"lerobot/taco_play",
"lerobot/berkeley_gnm_cory_hall",
"lerobot/usc_cloth_sim",
]
available_datasets = sorted(
set(itertools.chain(*available_datasets_per_env.values(), available_real_world_datasets))
)
# lists all available policies from `lerobot/policies`
available_policies = ["act", "diffusion", "tdmpc", "vqbet"]
# lists all available robots from `lerobot/robots`
available_robots = [
"koch",
"koch_bimanual",
"aloha",
"so100",
"so101",
]
# lists all available cameras from `lerobot/cameras`
available_cameras = [
"opencv",
"intelrealsense",
]
# lists all available motors from `lerobot/motors`
available_motors = [
"dynamixel",
"feetech",
]
# keys and values refer to yaml files
available_policies_per_env = {
"aloha": ["act"],
"pusht": ["diffusion", "vqbet"],
"koch_real": ["act_koch_real"],
"aloha_real": ["act_aloha_real"],
}
env_task_pairs = [(env, task) for env, tasks in available_tasks_per_env.items() for task in tasks]
env_dataset_pairs = [
(env, dataset) for env, datasets in available_datasets_per_env.items() for dataset in datasets
]
env_dataset_policy_triplets = [
(env, dataset, policy)
for env, datasets in available_datasets_per_env.items()
for dataset in datasets
for policy in available_policies_per_env[env]
]
__all__ = ["__version__"]
@@ -14,3 +14,5 @@
from .configuration_reachy2_camera import Reachy2CameraConfig
from .reachy2_camera import Reachy2Camera
__all__ = ["Reachy2Camera", "Reachy2CameraConfig"]
@@ -14,3 +14,5 @@
from .camera_realsense import RealSenseCamera
from .configuration_realsense import RealSenseCameraConfig
__all__ = ["RealSenseCamera", "RealSenseCameraConfig"]
+8
View File
@@ -0,0 +1,8 @@
"""
Cross-cutting modules that bridge multiple lerobot packages.
Unlike ``lerobot.utils`` (which must remain dependency-free), modules here
are allowed to import from ``lerobot.policies``, ``lerobot.processor``,
``lerobot.configs``, etc. They are deliberately NOT re-exported from the
top-level ``lerobot`` package.
"""
@@ -217,12 +217,12 @@ def sanity_check_dataset_robot_compatibility(
Raises:
ValueError: If any of the checked metadata fields do not match.
"""
from .import_utils import require_package
from lerobot.utils.import_utils import require_package
require_package("deepdiff", extra="hardware")
from deepdiff import DeepDiff
from .constants import DEFAULT_FEATURES
from lerobot.utils.constants import DEFAULT_FEATURES
fields = [
("robot_type", dataset.meta.robot_type, robot.robot_type),
@@ -27,16 +27,15 @@ from lerobot.optim import (
)
from lerobot.policies import PreTrainedPolicy
from lerobot.processor import PolicyProcessorPipeline
from .constants import (
from lerobot.utils.constants import (
CHECKPOINTS_DIR,
LAST_CHECKPOINT_LINK,
PRETRAINED_MODEL_DIR,
TRAINING_STATE_DIR,
TRAINING_STEP,
)
from .io_utils import load_json, write_json
from .random_utils import load_rng_state, save_rng_state
from lerobot.utils.io_utils import load_json, write_json
from lerobot.utils.random_utils import load_rng_state, save_rng_state
def get_step_identifier(step: int, total_steps: int) -> str:
+1
View File
@@ -18,6 +18,7 @@
from lerobot.utils.import_utils import require_package
require_package("datasets", extra="dataset")
require_package("av", extra="dataset")
from .compute_stats import aggregate_stats, get_feature_stats
from .dataset_metadata import CODEBASE_VERSION, LeRobotDatasetMetadata
+4 -1
View File
@@ -263,7 +263,10 @@ class VideoDecoderCache:
if importlib.util.find_spec("torchcodec"):
from torchcodec.decoders import VideoDecoder
else:
raise ImportError("torchcodec is required but not available.")
raise ImportError(
"'torchcodec' is required but not installed. "
"Install it with: pip install 'lerobot[dataset]' (or uv pip install 'lerobot[dataset]')"
)
video_path = str(video_path)
+3 -1
View File
@@ -15,4 +15,6 @@
# limitations under the License.
from .damiao import DamiaoMotorsBus
from .tables import *
from .tables import * # noqa: F403 — hardware constant tables
__all__ = ["DamiaoMotorsBus"]
+3 -1
View File
@@ -19,4 +19,6 @@ from lerobot.utils.import_utils import require_package
require_package("dynamixel-sdk", extra="dynamixel", import_name="dynamixel_sdk")
from .dynamixel import DriveMode, DynamixelMotorsBus, OperatingMode, TorqueMode
from .tables import *
from .tables import * # noqa: F403 — hardware constant tables
__all__ = ["DriveMode", "DynamixelMotorsBus", "OperatingMode", "TorqueMode"]
+3 -1
View File
@@ -19,4 +19,6 @@ from lerobot.utils.import_utils import require_package
require_package("feetech-servo-sdk", extra="feetech", import_name="scservo_sdk")
from .feetech import DriveMode, FeetechMotorsBus, OperatingMode, TorqueMode
from .tables import *
from .tables import * # noqa: F403 — hardware constant tables
__all__ = ["DriveMode", "FeetechMotorsBus", "OperatingMode", "TorqueMode"]
+3 -1
View File
@@ -15,4 +15,6 @@
# limitations under the License.
from .robstride import RobstrideMotorsBus
from .tables import *
from .tables import * # noqa: F403 — hardware constant tables
__all__ = ["RobstrideMotorsBus"]
+5
View File
@@ -0,0 +1,5 @@
from .configuration_act import ACTConfig
from .modeling_act import ACTPolicy
from .processor_act import make_act_pre_post_processors
__all__ = ["ACTConfig", "ACTPolicy", "make_act_pre_post_processors"]
@@ -0,0 +1,5 @@
from .configuration_diffusion import DiffusionConfig
from .modeling_diffusion import DiffusionPolicy
from .processor_diffusion import make_diffusion_pre_post_processors
__all__ = ["DiffusionConfig", "DiffusionPolicy", "make_diffusion_pre_post_processors"]
-4
View File
@@ -14,10 +14,6 @@
# See the License for the specific language governing permissions and
# limitations under the License.
from lerobot.utils.import_utils import require_package
require_package("transformers", extra="groot")
from .configuration_groot import GrootConfig
from .modeling_groot import GrootPolicy
from .processor_groot import make_groot_pre_post_processors
+5
View File
@@ -0,0 +1,5 @@
from .configuration_sac import SACConfig
from .modeling_sac import SACPolicy
from .processor_sac import make_sac_pre_post_processors
__all__ = ["SACConfig", "SACPolicy", "make_sac_pre_post_processors"]
@@ -0,0 +1,5 @@
from .configuration_classifier import RewardClassifierConfig
from .modeling_classifier import Classifier
from .processor_classifier import make_classifier_processor
__all__ = ["RewardClassifierConfig", "Classifier", "make_classifier_processor"]
+7
View File
@@ -0,0 +1,7 @@
from .configuration_smolvla import SmolVLAConfig
from .processor_smolvla import (
SmolVLANewLineProcessor,
make_smolvla_pre_post_processors,
)
__all__ = ["SmolVLAConfig", "SmolVLANewLineProcessor", "make_smolvla_pre_post_processors"]
+5
View File
@@ -0,0 +1,5 @@
from .configuration_tdmpc import TDMPCConfig
from .modeling_tdmpc import TDMPCPolicy
from .processor_tdmpc import make_tdmpc_pre_post_processors
__all__ = ["TDMPCConfig", "TDMPCPolicy", "make_tdmpc_pre_post_processors"]
+5
View File
@@ -0,0 +1,5 @@
from .configuration_vqbet import VQBeTConfig
from .modeling_vqbet import VQBeTPolicy
from .processor_vqbet import make_vqbet_pre_post_processors
__all__ = ["VQBeTConfig", "VQBeTPolicy", "make_vqbet_pre_post_processors"]
+2 -1
View File
@@ -15,5 +15,6 @@
# limitations under the License.
from .configuration_wall_x import WallXConfig
from .processor_wall_x import make_wall_x_pre_post_processors
__all__ = ["WallXConfig"]
__all__ = ["WallXConfig", "make_wall_x_pre_post_processors"]
+8 -1
View File
@@ -1,6 +1,13 @@
# register the processor steps
from .configuration_xvla import XVLAConfig
from .processor_xvla import (
XVLAAddDomainIdProcessorStep,
XVLAImageNetNormalizeProcessorStep,
XVLAImageToFloatProcessorStep,
)
__all__ = [
"XVLAConfig",
"XVLAAddDomainIdProcessorStep",
"XVLAImageNetNormalizeProcessorStep",
"XVLAImageToFloatProcessorStep",
]
+6 -6
View File
@@ -60,6 +60,12 @@ from torch.multiprocessing import Queue
from torch.optim.optimizer import Optimizer
from lerobot.cameras import opencv # noqa: F401
from lerobot.common.train_utils import (
get_step_checkpoint_dir,
load_training_state as utils_load_training_state,
save_checkpoint,
update_last_checkpoint,
)
from lerobot.configs import parser
from lerobot.configs.train import TrainRLServerPipelineConfig
from lerobot.datasets import LeRobotDataset, make_dataset
@@ -84,12 +90,6 @@ from lerobot.utils.constants import (
)
from lerobot.utils.device_utils import get_safe_torch_device
from lerobot.utils.random_utils import set_seed
from lerobot.utils.train_utils import (
get_step_checkpoint_dir,
load_training_state as utils_load_training_state,
save_checkpoint,
update_last_checkpoint,
)
from lerobot.utils.transition import move_state_dict_to_device, move_transition_to_device
from lerobot.utils.utils import (
format_big_number,
@@ -16,3 +16,5 @@
from .bi_so_follower import BiSOFollower
from .config_bi_so_follower import BiSOFollowerConfig
__all__ = ["BiSOFollower", "BiSOFollowerConfig"]
+2
View File
@@ -17,3 +17,5 @@
from .config_hope_jr import HopeJrArmConfig, HopeJrHandConfig
from .hope_jr_arm import HopeJrArm
from .hope_jr_hand import HopeJrHand
__all__ = ["HopeJrArm", "HopeJrArmConfig", "HopeJrHand", "HopeJrHandConfig"]
@@ -16,3 +16,5 @@
from .config_koch_follower import KochFollowerConfig
from .koch_follower import KochFollower
__all__ = ["KochFollower", "KochFollowerConfig"]
+2 -4
View File
@@ -14,10 +14,8 @@
# See the License for the specific language governing permissions and
# limitations under the License.
from lerobot.utils.import_utils import require_package
require_package("pyzmq", extra="lekiwi", import_name="zmq")
from .config_lekiwi import LeKiwiClientConfig, LeKiwiConfig
from .lekiwi import LeKiwi
from .lekiwi_client import LeKiwiClient
__all__ = ["LeKiwi", "LeKiwiClient", "LeKiwiClientConfig", "LeKiwiConfig"]
@@ -19,3 +19,5 @@
from .config_omx_follower import OmxFollowerConfig
from .omx_follower import OmxFollower
__all__ = ["OmxFollower", "OmxFollowerConfig"]
+10
View File
@@ -23,3 +23,13 @@ from .robot_reachy2 import (
REACHY2_VEL,
Reachy2Robot,
)
__all__ = [
"REACHY2_ANTENNAS_JOINTS",
"REACHY2_L_ARM_JOINTS",
"REACHY2_NECK_JOINTS",
"REACHY2_R_ARM_JOINTS",
"REACHY2_VEL",
"Reachy2Robot",
"Reachy2RobotConfig",
]
@@ -21,3 +21,13 @@ from .config_so_follower import (
SOFollowerRobotConfig,
)
from .so_follower import SO100Follower, SO101Follower, SOFollower
__all__ = [
"SO100Follower",
"SO100FollowerConfig",
"SO101Follower",
"SO101FollowerConfig",
"SOFollower",
"SOFollowerConfig",
"SOFollowerRobotConfig",
]
@@ -51,6 +51,10 @@ import shutil
from pathlib import Path
from typing import Any
from lerobot.utils.import_utils import require_package
require_package("jsonlines", extra="dataset")
import jsonlines
import pandas as pd
import pyarrow as pa
+7 -7
View File
@@ -81,6 +81,13 @@ from lerobot.cameras.opencv import OpenCVCameraConfig # noqa: F401
from lerobot.cameras.reachy2_camera import Reachy2CameraConfig # noqa: F401
from lerobot.cameras.realsense import RealSenseCameraConfig # noqa: F401
from lerobot.cameras.zmq import ZMQCameraConfig # noqa: F401
from lerobot.common.control_utils import (
init_keyboard_listener,
is_headless,
predict_action,
sanity_check_dataset_name,
sanity_check_dataset_robot_compatibility,
)
from lerobot.configs import PreTrainedConfig, parser
from lerobot.datasets import (
LeRobotDataset,
@@ -137,13 +144,6 @@ from lerobot.teleoperators import ( # noqa: F401
)
from lerobot.teleoperators.keyboard.teleop_keyboard import KeyboardTeleop
from lerobot.utils.constants import ACTION, OBS_STR
from lerobot.utils.control_utils import (
init_keyboard_listener,
is_headless,
predict_action,
sanity_check_dataset_name,
sanity_check_dataset_robot_compatibility,
)
from lerobot.utils.device_utils import get_safe_torch_device
from lerobot.utils.feature_utils import build_dataset_frame, combine_feature_dicts
from lerobot.utils.import_utils import register_third_party_plugins
+7 -7
View File
@@ -28,6 +28,13 @@ from termcolor import colored
from torch.optim import Optimizer
from tqdm import tqdm
from lerobot.common.train_utils import (
get_step_checkpoint_dir,
get_step_identifier,
load_training_state,
save_checkpoint,
update_last_checkpoint,
)
from lerobot.configs import parser
from lerobot.configs.train import TrainPipelineConfig
from lerobot.datasets import EpisodeAwareSampler, make_dataset
@@ -38,13 +45,6 @@ from lerobot.rl.wandb_utils import WandBLogger
from lerobot.utils.import_utils import register_third_party_plugins
from lerobot.utils.logging_utils import AverageMeter, MetricsTracker
from lerobot.utils.random_utils import set_seed
from lerobot.utils.train_utils import (
get_step_checkpoint_dir,
get_step_identifier,
load_training_state,
save_checkpoint,
update_last_checkpoint,
)
from lerobot.utils.utils import (
cycle,
format_big_number,
@@ -15,3 +15,5 @@
# limitations under the License.
from .bi_so_leader import BiSOLeader, BiSOLeaderConfig
__all__ = ["BiSOLeader", "BiSOLeaderConfig"]
@@ -16,3 +16,5 @@
from .configuration_gamepad import GamepadTeleopConfig
from .teleop_gamepad import GamepadTeleop
__all__ = ["GamepadTeleop", "GamepadTeleopConfig"]
@@ -18,3 +18,11 @@ from .config_homunculus import HomunculusArmConfig, HomunculusGloveConfig
from .homunculus_arm import HomunculusArm
from .homunculus_glove import HomunculusGlove
from .joints_translation import homunculus_glove_to_hope_jr_hand
__all__ = [
"HomunculusArm",
"HomunculusArmConfig",
"HomunculusGlove",
"HomunculusGloveConfig",
"homunculus_glove_to_hope_jr_hand",
]
@@ -23,6 +23,7 @@ from typing import Any
from lerobot.types import RobotAction
from lerobot.utils.decorators import check_if_already_connected, check_if_not_connected
from lerobot.utils.import_utils import is_package_available
from ..teleoperator import Teleoperator
from ..utils import TeleopEvents
@@ -32,18 +33,16 @@ from .configuration_keyboard import (
KeyboardTeleopConfig,
)
PYNPUT_AVAILABLE = True
try:
PYNPUT_AVAILABLE = is_package_available("pynput")
keyboard = None
if PYNPUT_AVAILABLE:
try:
if ("DISPLAY" not in os.environ) and ("linux" in sys.platform):
logging.info("No DISPLAY set. Skipping pynput import.")
raise ImportError("pynput blocked intentionally due to no display.")
from pynput import keyboard
except ImportError:
keyboard = None
PYNPUT_AVAILABLE = False
except Exception as e:
keyboard = None
else:
from pynput import keyboard
except Exception as e:
PYNPUT_AVAILABLE = False
logging.info(f"Could not import pynput: {e}")
@@ -16,3 +16,5 @@
from .config_koch_leader import KochLeaderConfig
from .koch_leader import KochLeader
__all__ = ["KochLeader", "KochLeaderConfig"]
@@ -16,3 +16,5 @@
from .config_omx_leader import OmxLeaderConfig
from .omx_leader import OmxLeader
__all__ = ["OmxLeader", "OmxLeaderConfig"]
@@ -16,3 +16,5 @@
from .config_phone import PhoneConfig
from .teleop_phone import Phone
__all__ = ["Phone", "PhoneConfig"]
@@ -23,3 +23,13 @@ from .reachy2_teleoperator import (
REACHY2_VEL,
Reachy2Teleoperator,
)
__all__ = [
"REACHY2_ANTENNAS_JOINTS",
"REACHY2_L_ARM_JOINTS",
"REACHY2_NECK_JOINTS",
"REACHY2_R_ARM_JOINTS",
"REACHY2_VEL",
"Reachy2Teleoperator",
"Reachy2TeleoperatorConfig",
]
@@ -21,3 +21,13 @@ from .config_so_leader import (
SOLeaderTeleopConfig,
)
from .so_leader import SO100Leader, SO101Leader, SOLeader
__all__ = [
"SO100Leader",
"SO100LeaderConfig",
"SO101Leader",
"SO101LeaderConfig",
"SOLeader",
"SOLeaderConfig",
"SOLeaderTeleopConfig",
]
@@ -19,6 +19,10 @@ import logging
from dataclasses import dataclass
from pathlib import Path
from lerobot.utils.import_utils import require_package
require_package("pyserial", extra="hardware", import_name="serial")
import serial
from .exo_calib import ExoskeletonCalibration, exo_raw_to_angles, run_exo_calibration
+3 -4
View File
@@ -15,10 +15,9 @@
"""
Public API for lightweight, base-dependency-only utilities.
Heavy utility modules (train_utils, control_utils, visualization_utils)
are intentionally NOT re-exported here to avoid pulling in optional
dependencies. Import them directly, e.g.:
``from lerobot.utils.train_utils import save_checkpoint``
Heavy cross-cutting modules (train_utils, control_utils) have been moved
to ``lerobot.common``. ``visualization_utils`` remains here but is
intentionally NOT re-exported to avoid pulling in optional dependencies.
"""
from .constants import (
+20 -6
View File
@@ -26,7 +26,6 @@ from PIL import Image
from safetensors.torch import load_file
from torchvision.transforms import v2
import lerobot
from lerobot.configs.default import DatasetConfig
from lerobot.configs.train import TrainPipelineConfig
from lerobot.datasets import make_dataset
@@ -494,13 +493,28 @@ def test_tmp_mixed_deletion(tmp_path, empty_lerobot_dataset_factory):
# - [ ] remove old tests
ENV_DATASET_POLICY_TRIPLETS = [
("aloha", dataset, "act")
for dataset in [
"lerobot/aloha_sim_insertion_human",
"lerobot/aloha_sim_insertion_scripted",
"lerobot/aloha_sim_transfer_cube_human",
"lerobot/aloha_sim_transfer_cube_scripted",
"lerobot/aloha_sim_insertion_human_image",
"lerobot/aloha_sim_insertion_scripted_image",
"lerobot/aloha_sim_transfer_cube_human_image",
"lerobot/aloha_sim_transfer_cube_scripted_image",
]
] + [
("pusht", dataset, policy)
for dataset in ["lerobot/pusht", "lerobot/pusht_image"]
for policy in ["diffusion", "vqbet"]
]
@pytest.mark.parametrize(
"env_name, repo_id, policy_name",
# Single dataset
lerobot.env_dataset_policy_triplets,
# Multi-dataset
# TODO after fix multidataset
# + [("aloha", ["lerobot/aloha_sim_insertion_human", "lerobot/aloha_sim_transfer_cube_human"], "act")],
ENV_DATASET_POLICY_TRIPLETS,
)
def test_factory(env_name, repo_id, policy_name):
"""
+9 -3
View File
@@ -23,7 +23,6 @@ import torch
from gymnasium.envs.registration import register, registry as gym_registry
from gymnasium.utils.env_checker import check_env
import lerobot
from lerobot.configs.types import PolicyFeature
from lerobot.envs.configs import EnvConfig
from lerobot.envs.factory import make_env, make_env_config
@@ -36,9 +35,16 @@ from tests.utils import require_env
OBS_TYPES = ["state", "pixels", "pixels_agent_pos"]
ENV_TASK_PAIRS = [
("aloha", "AlohaInsertion-v0"),
("aloha", "AlohaTransferCube-v0"),
("pusht", "PushT-v0"),
]
AVAILABLE_ENVS = ["aloha", "pusht"]
@pytest.mark.parametrize("obs_type", OBS_TYPES)
@pytest.mark.parametrize("env_name, env_task", lerobot.env_task_pairs)
@pytest.mark.parametrize("env_name, env_task", ENV_TASK_PAIRS)
@require_env
def test_env(env_name, env_task, obs_type):
if env_name == "aloha" and obs_type == "state":
@@ -51,7 +57,7 @@ def test_env(env_name, env_task, obs_type):
env.close()
@pytest.mark.parametrize("env_name", lerobot.available_envs)
@pytest.mark.parametrize("env_name", AVAILABLE_ENVS)
@require_env
def test_factory(env_name):
cfg = make_env_config(env_name)
+5 -4
View File
@@ -23,7 +23,6 @@ import torch
from packaging import version
from safetensors.torch import load_file
from lerobot import available_policies
from lerobot.configs.default import DatasetConfig
from lerobot.configs.train import TrainPipelineConfig
from lerobot.configs.types import FeatureType, PolicyFeature
@@ -49,6 +48,8 @@ from lerobot.utils.utils import cycle
from tests.artifacts.policies.save_policy_to_safetensors import get_policy_stats
from tests.utils import DEVICE, require_cpu, require_env, require_x86_64_kernel
AVAILABLE_POLICIES = ["act", "diffusion", "tdmpc", "vqbet"]
@pytest.fixture
def dummy_dataset_metadata(lerobot_dataset_metadata_factory, info_factory, tmp_path):
@@ -84,7 +85,7 @@ def dummy_dataset_metadata(lerobot_dataset_metadata_factory, info_factory, tmp_p
return ds_meta
@pytest.mark.parametrize("policy_name", available_policies)
@pytest.mark.parametrize("policy_name", AVAILABLE_POLICIES)
def test_get_policy_and_config_classes(policy_name: str):
"""Check that the correct policy and config classes are returned."""
policy_cls = get_policy_class(policy_name)
@@ -255,7 +256,7 @@ def test_act_backbone_lr():
assert len(optimizer.param_groups[1]["params"]) == 20
@pytest.mark.parametrize("policy_name", available_policies)
@pytest.mark.parametrize("policy_name", AVAILABLE_POLICIES)
def test_policy_defaults(dummy_dataset_metadata, policy_name: str):
"""Check that the policy can be instantiated with defaults."""
policy_cls = get_policy_class(policy_name)
@@ -268,7 +269,7 @@ def test_policy_defaults(dummy_dataset_metadata, policy_name: str):
policy_cls(policy_cfg)
@pytest.mark.parametrize("policy_name", available_policies)
@pytest.mark.parametrize("policy_name", AVAILABLE_POLICIES)
def test_save_and_load_pretrained(dummy_dataset_metadata, tmp_path, policy_name: str):
policy_cls = get_policy_class(policy_name)
policy_cfg = make_policy_config(policy_name)
+4 -41
View File
@@ -13,48 +13,11 @@
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import importlib
import gymnasium as gym
import pytest
import lerobot
from lerobot.policies.act.modeling_act import ACTPolicy
from lerobot.policies.diffusion.modeling_diffusion import DiffusionPolicy
from lerobot.policies.tdmpc.modeling_tdmpc import TDMPCPolicy
from lerobot.policies.vqbet.modeling_vqbet import VQBeTPolicy
from tests.utils import require_env
@pytest.mark.parametrize("env_name, task_name", lerobot.env_task_pairs)
@require_env
def test_available_env_task(env_name: str, task_name: list):
"""
This test verifies that all environments listed in `lerobot/__init__.py` can
be successfully imported if they're installed — and that their
`available_tasks_per_env` are valid.
"""
package_name = f"gym_{env_name}"
importlib.import_module(package_name)
gym_handle = f"{package_name}/{task_name}"
assert gym_handle in gym.envs.registry, gym_handle
def test_available_policies():
"""
This test verifies that the class attribute `name` for all policies is
consistent with those listed in `lerobot/__init__.py`.
"""
policy_classes = [ACTPolicy, DiffusionPolicy, TDMPCPolicy, VQBeTPolicy]
policies = [pol_cls.name for pol_cls in policy_classes]
assert set(policies) == set(lerobot.available_policies), policies
def test_print():
print(lerobot.available_envs)
print(lerobot.available_tasks_per_env)
print(lerobot.available_datasets)
print(lerobot.available_datasets_per_env)
print(lerobot.available_real_world_datasets)
print(lerobot.available_policies)
print(lerobot.available_policies_per_env)
def test_version():
"""Verify the package exposes a version string."""
assert isinstance(lerobot.__version__, str)
assert len(lerobot.__version__) > 0
+19 -4
View File
@@ -20,22 +20,37 @@ from functools import wraps
import pytest
import torch
from lerobot import available_cameras, available_motors, available_robots
from lerobot.utils.device_utils import auto_select_torch_device
from lerobot.utils.import_utils import is_package_available
DEVICE = os.environ.get("LEROBOT_TEST_DEVICE", str(auto_select_torch_device()))
AVAILABLE_ROBOTS = [
"koch",
"koch_bimanual",
"aloha",
"so100",
"so101",
]
AVAILABLE_CAMERAS = [
"opencv",
"intelrealsense",
]
AVAILABLE_MOTORS = [
"dynamixel",
"feetech",
]
TEST_ROBOT_TYPES = []
for robot_type in available_robots:
for robot_type in AVAILABLE_ROBOTS:
TEST_ROBOT_TYPES += [(robot_type, True), (robot_type, False)]
TEST_CAMERA_TYPES = []
for camera_type in available_cameras:
for camera_type in AVAILABLE_CAMERAS:
TEST_CAMERA_TYPES += [(camera_type, True), (camera_type, False)]
TEST_MOTOR_TYPES = []
for motor_type in available_motors:
for motor_type in AVAILABLE_MOTORS:
TEST_MOTOR_TYPES += [(motor_type, True), (motor_type, False)]
# Camera indices used for connecting physical cameras
+12 -12
View File
@@ -17,6 +17,16 @@
from pathlib import Path
from unittest.mock import Mock, patch
from lerobot.common.train_utils import (
get_step_checkpoint_dir,
get_step_identifier,
load_training_state,
load_training_step,
save_checkpoint,
save_training_state,
save_training_step,
update_last_checkpoint,
)
from lerobot.utils.constants import (
CHECKPOINTS_DIR,
LAST_CHECKPOINT_LINK,
@@ -27,16 +37,6 @@ from lerobot.utils.constants import (
TRAINING_STATE_DIR,
TRAINING_STEP,
)
from lerobot.utils.train_utils import (
get_step_checkpoint_dir,
get_step_identifier,
load_training_state,
load_training_step,
save_checkpoint,
save_training_state,
save_training_step,
update_last_checkpoint,
)
def test_get_step_identifier():
@@ -72,7 +72,7 @@ def test_update_last_checkpoint(tmp_path):
assert last_checkpoint.resolve() == checkpoint
@patch("lerobot.utils.train_utils.save_training_state")
@patch("lerobot.common.train_utils.save_training_state")
def test_save_checkpoint(mock_save_training_state, tmp_path, optimizer):
policy = Mock()
cfg = Mock()
@@ -82,7 +82,7 @@ def test_save_checkpoint(mock_save_training_state, tmp_path, optimizer):
mock_save_training_state.assert_called_once()
@patch("lerobot.utils.train_utils.save_training_state")
@patch("lerobot.common.train_utils.save_training_state")
def test_save_checkpoint_peft(mock_save_training_state, tmp_path, optimizer):
policy = Mock()
policy.config = Mock()