mirror of
https://github.com/huggingface/lerobot.git
synced 2026-05-25 21:50:03 +00:00
big imports refactor
This commit is contained in:
@@ -0,0 +1,66 @@
|
||||
# Copyright 2024 The HuggingFace Inc. team. All rights reserved.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
"""
|
||||
Public API for lightweight, base-dependency-only utilities.
|
||||
|
||||
Heavy utility modules (train_utils, control_utils, visualization_utils)
|
||||
are intentionally NOT re-exported here to avoid pulling in optional
|
||||
dependencies. Import them directly, e.g.:
|
||||
``from lerobot.utils.train_utils import save_checkpoint``
|
||||
"""
|
||||
|
||||
from .constants import (
|
||||
ACTION,
|
||||
DEFAULT_FEATURES,
|
||||
DONE,
|
||||
IMAGENET_STATS,
|
||||
OBS_ENV_STATE,
|
||||
OBS_IMAGE,
|
||||
OBS_IMAGES,
|
||||
OBS_STATE,
|
||||
OBS_STR,
|
||||
REWARD,
|
||||
)
|
||||
from .decorators import check_if_already_connected, check_if_not_connected
|
||||
from .device_utils import auto_select_torch_device, get_safe_torch_device, is_torch_device_available
|
||||
from .errors import DeviceAlreadyConnectedError, DeviceNotConnectedError
|
||||
from .import_utils import is_package_available, require_package
|
||||
|
||||
__all__ = [
|
||||
# Constants
|
||||
"ACTION",
|
||||
"DEFAULT_FEATURES",
|
||||
"DONE",
|
||||
"IMAGENET_STATS",
|
||||
"OBS_ENV_STATE",
|
||||
"OBS_IMAGE",
|
||||
"OBS_IMAGES",
|
||||
"OBS_STATE",
|
||||
"OBS_STR",
|
||||
"REWARD",
|
||||
# Device utilities
|
||||
"auto_select_torch_device",
|
||||
"get_safe_torch_device",
|
||||
"is_torch_device_available",
|
||||
# Import guards
|
||||
"is_package_available",
|
||||
"require_package",
|
||||
# Decorators
|
||||
"check_if_already_connected",
|
||||
"check_if_not_connected",
|
||||
# Errors
|
||||
"DeviceAlreadyConnectedError",
|
||||
"DeviceNotConnectedError",
|
||||
]
|
||||
@@ -27,11 +27,10 @@ from typing import TYPE_CHECKING, Any
|
||||
import numpy as np
|
||||
import torch
|
||||
|
||||
from lerobot.policies.pretrained import PreTrainedPolicy
|
||||
from lerobot.policies import PreTrainedPolicy, prepare_observation_for_inference
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from lerobot.datasets.lerobot_dataset import LeRobotDataset
|
||||
from lerobot.policies.utils import prepare_observation_for_inference
|
||||
from lerobot.datasets import LeRobotDataset
|
||||
from lerobot.processor import PolicyProcessorPipeline
|
||||
from lerobot.robots import Robot
|
||||
from lerobot.types import PolicyAction
|
||||
@@ -218,12 +217,12 @@ def sanity_check_dataset_robot_compatibility(
|
||||
Raises:
|
||||
ValueError: If any of the checked metadata fields do not match.
|
||||
"""
|
||||
from lerobot.utils.import_utils import require_package
|
||||
from .import_utils import require_package
|
||||
|
||||
require_package("deepdiff", extra="hardware")
|
||||
from deepdiff import DeepDiff
|
||||
|
||||
from lerobot.utils.constants import DEFAULT_FEATURES
|
||||
from .constants import DEFAULT_FEATURES
|
||||
|
||||
fields = [
|
||||
("robot_type", dataset.meta.robot_type, robot.robot_type),
|
||||
|
||||
@@ -16,7 +16,7 @@
|
||||
|
||||
from functools import wraps
|
||||
|
||||
from lerobot.utils.errors import DeviceAlreadyConnectedError, DeviceNotConnectedError
|
||||
from .errors import DeviceAlreadyConnectedError, DeviceNotConnectedError
|
||||
|
||||
|
||||
def check_if_not_connected(func):
|
||||
|
||||
@@ -25,8 +25,9 @@ from typing import Any
|
||||
|
||||
import numpy as np
|
||||
|
||||
from lerobot.configs.types import FeatureType, PolicyFeature
|
||||
from lerobot.utils.constants import ACTION, DEFAULT_FEATURES, OBS_ENV_STATE, OBS_STR
|
||||
from lerobot.configs import FeatureType, PolicyFeature
|
||||
|
||||
from .constants import ACTION, DEFAULT_FEATURES, OBS_ENV_STATE, OBS_STR
|
||||
|
||||
|
||||
def _validate_feature_names(features: dict[str, dict]) -> None:
|
||||
|
||||
@@ -58,7 +58,7 @@ def write_video(video_path: str | Path, stacked_frames: list, fps: int) -> None:
|
||||
stacked_frames: List of HWC uint8 numpy arrays (RGB).
|
||||
fps: Frames per second for the output video.
|
||||
"""
|
||||
from lerobot.utils.import_utils import require_package
|
||||
from .import_utils import require_package
|
||||
|
||||
require_package("av", extra="av-dep")
|
||||
import av
|
||||
|
||||
@@ -16,7 +16,7 @@
|
||||
from collections.abc import Callable
|
||||
from typing import Any
|
||||
|
||||
from lerobot.utils.utils import format_big_number
|
||||
from .utils import format_big_number
|
||||
|
||||
|
||||
class AverageMeter:
|
||||
|
||||
@@ -23,8 +23,8 @@ import numpy as np
|
||||
import torch
|
||||
from safetensors.torch import load_file, save_file
|
||||
|
||||
from lerobot.utils.constants import RNG_STATE
|
||||
from lerobot.utils.utils import flatten_dict, unflatten_dict
|
||||
from .constants import RNG_STATE
|
||||
from .utils import flatten_dict, unflatten_dict
|
||||
|
||||
|
||||
def serialize_python_rng_state() -> dict[str, torch.Tensor]:
|
||||
|
||||
@@ -19,19 +19,24 @@ from torch.optim import Optimizer
|
||||
from torch.optim.lr_scheduler import LRScheduler
|
||||
|
||||
from lerobot.configs.train import TrainPipelineConfig
|
||||
from lerobot.optim.optimizers import load_optimizer_state, save_optimizer_state
|
||||
from lerobot.optim.schedulers import load_scheduler_state, save_scheduler_state
|
||||
from lerobot.policies.pretrained import PreTrainedPolicy
|
||||
from lerobot.optim import (
|
||||
load_optimizer_state,
|
||||
load_scheduler_state,
|
||||
save_optimizer_state,
|
||||
save_scheduler_state,
|
||||
)
|
||||
from lerobot.policies import PreTrainedPolicy
|
||||
from lerobot.processor import PolicyProcessorPipeline
|
||||
from lerobot.utils.constants import (
|
||||
|
||||
from .constants import (
|
||||
CHECKPOINTS_DIR,
|
||||
LAST_CHECKPOINT_LINK,
|
||||
PRETRAINED_MODEL_DIR,
|
||||
TRAINING_STATE_DIR,
|
||||
TRAINING_STEP,
|
||||
)
|
||||
from lerobot.utils.io_utils import load_json, write_json
|
||||
from lerobot.utils.random_utils import load_rng_state, save_rng_state
|
||||
from .io_utils import load_json, write_json
|
||||
from .random_utils import load_rng_state, save_rng_state
|
||||
|
||||
|
||||
def get_step_identifier(step: int, total_steps: int) -> str:
|
||||
|
||||
@@ -18,7 +18,7 @@ from typing import TypedDict
|
||||
|
||||
import torch
|
||||
|
||||
from lerobot.utils.constants import ACTION
|
||||
from .constants import ACTION
|
||||
|
||||
|
||||
class Transition(TypedDict):
|
||||
|
||||
@@ -18,9 +18,9 @@ import os
|
||||
import numpy as np
|
||||
|
||||
from lerobot.types import RobotAction, RobotObservation
|
||||
from lerobot.utils.import_utils import require_package
|
||||
|
||||
from .constants import ACTION, ACTION_PREFIX, OBS_PREFIX, OBS_STR
|
||||
from .import_utils import require_package
|
||||
|
||||
|
||||
def init_rerun(
|
||||
|
||||
Reference in New Issue
Block a user