big imports refactor

This commit is contained in:
Steven Palma
2026-04-11 15:03:24 +02:00
parent 964acd0151
commit d626964119
183 changed files with 892 additions and 558 deletions
+66
View File
@@ -0,0 +1,66 @@
# Copyright 2024 The HuggingFace Inc. team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""
Public API for lightweight, base-dependency-only utilities.
Heavy utility modules (train_utils, control_utils, visualization_utils)
are intentionally NOT re-exported here to avoid pulling in optional
dependencies. Import them directly, e.g.:
``from lerobot.utils.train_utils import save_checkpoint``
"""
from .constants import (
ACTION,
DEFAULT_FEATURES,
DONE,
IMAGENET_STATS,
OBS_ENV_STATE,
OBS_IMAGE,
OBS_IMAGES,
OBS_STATE,
OBS_STR,
REWARD,
)
from .decorators import check_if_already_connected, check_if_not_connected
from .device_utils import auto_select_torch_device, get_safe_torch_device, is_torch_device_available
from .errors import DeviceAlreadyConnectedError, DeviceNotConnectedError
from .import_utils import is_package_available, require_package
__all__ = [
# Constants
"ACTION",
"DEFAULT_FEATURES",
"DONE",
"IMAGENET_STATS",
"OBS_ENV_STATE",
"OBS_IMAGE",
"OBS_IMAGES",
"OBS_STATE",
"OBS_STR",
"REWARD",
# Device utilities
"auto_select_torch_device",
"get_safe_torch_device",
"is_torch_device_available",
# Import guards
"is_package_available",
"require_package",
# Decorators
"check_if_already_connected",
"check_if_not_connected",
# Errors
"DeviceAlreadyConnectedError",
"DeviceNotConnectedError",
]
+4 -5
View File
@@ -27,11 +27,10 @@ from typing import TYPE_CHECKING, Any
import numpy as np
import torch
from lerobot.policies.pretrained import PreTrainedPolicy
from lerobot.policies import PreTrainedPolicy, prepare_observation_for_inference
if TYPE_CHECKING:
from lerobot.datasets.lerobot_dataset import LeRobotDataset
from lerobot.policies.utils import prepare_observation_for_inference
from lerobot.datasets import LeRobotDataset
from lerobot.processor import PolicyProcessorPipeline
from lerobot.robots import Robot
from lerobot.types import PolicyAction
@@ -218,12 +217,12 @@ def sanity_check_dataset_robot_compatibility(
Raises:
ValueError: If any of the checked metadata fields do not match.
"""
from lerobot.utils.import_utils import require_package
from .import_utils import require_package
require_package("deepdiff", extra="hardware")
from deepdiff import DeepDiff
from lerobot.utils.constants import DEFAULT_FEATURES
from .constants import DEFAULT_FEATURES
fields = [
("robot_type", dataset.meta.robot_type, robot.robot_type),
+1 -1
View File
@@ -16,7 +16,7 @@
from functools import wraps
from lerobot.utils.errors import DeviceAlreadyConnectedError, DeviceNotConnectedError
from .errors import DeviceAlreadyConnectedError, DeviceNotConnectedError
def check_if_not_connected(func):
+3 -2
View File
@@ -25,8 +25,9 @@ from typing import Any
import numpy as np
from lerobot.configs.types import FeatureType, PolicyFeature
from lerobot.utils.constants import ACTION, DEFAULT_FEATURES, OBS_ENV_STATE, OBS_STR
from lerobot.configs import FeatureType, PolicyFeature
from .constants import ACTION, DEFAULT_FEATURES, OBS_ENV_STATE, OBS_STR
def _validate_feature_names(features: dict[str, dict]) -> None:
+1 -1
View File
@@ -58,7 +58,7 @@ def write_video(video_path: str | Path, stacked_frames: list, fps: int) -> None:
stacked_frames: List of HWC uint8 numpy arrays (RGB).
fps: Frames per second for the output video.
"""
from lerobot.utils.import_utils import require_package
from .import_utils import require_package
require_package("av", extra="av-dep")
import av
+1 -1
View File
@@ -16,7 +16,7 @@
from collections.abc import Callable
from typing import Any
from lerobot.utils.utils import format_big_number
from .utils import format_big_number
class AverageMeter:
+2 -2
View File
@@ -23,8 +23,8 @@ import numpy as np
import torch
from safetensors.torch import load_file, save_file
from lerobot.utils.constants import RNG_STATE
from lerobot.utils.utils import flatten_dict, unflatten_dict
from .constants import RNG_STATE
from .utils import flatten_dict, unflatten_dict
def serialize_python_rng_state() -> dict[str, torch.Tensor]:
+11 -6
View File
@@ -19,19 +19,24 @@ from torch.optim import Optimizer
from torch.optim.lr_scheduler import LRScheduler
from lerobot.configs.train import TrainPipelineConfig
from lerobot.optim.optimizers import load_optimizer_state, save_optimizer_state
from lerobot.optim.schedulers import load_scheduler_state, save_scheduler_state
from lerobot.policies.pretrained import PreTrainedPolicy
from lerobot.optim import (
load_optimizer_state,
load_scheduler_state,
save_optimizer_state,
save_scheduler_state,
)
from lerobot.policies import PreTrainedPolicy
from lerobot.processor import PolicyProcessorPipeline
from lerobot.utils.constants import (
from .constants import (
CHECKPOINTS_DIR,
LAST_CHECKPOINT_LINK,
PRETRAINED_MODEL_DIR,
TRAINING_STATE_DIR,
TRAINING_STEP,
)
from lerobot.utils.io_utils import load_json, write_json
from lerobot.utils.random_utils import load_rng_state, save_rng_state
from .io_utils import load_json, write_json
from .random_utils import load_rng_state, save_rng_state
def get_step_identifier(step: int, total_steps: int) -> str:
+1 -1
View File
@@ -18,7 +18,7 @@ from typing import TypedDict
import torch
from lerobot.utils.constants import ACTION
from .constants import ACTION
class Transition(TypedDict):
+1 -1
View File
@@ -18,9 +18,9 @@ import os
import numpy as np
from lerobot.types import RobotAction, RobotObservation
from lerobot.utils.import_utils import require_package
from .constants import ACTION, ACTION_PREFIX, OBS_PREFIX, OBS_STR
from .import_utils import require_package
def init_rerun(