mirror of
https://github.com/huggingface/lerobot.git
synced 2026-05-22 12:09:42 +00:00
Package folder structure (#1417)
* Move files * Replace imports & paths * Update relative paths * Update doc symlinks * Update instructions paths * Fix imports * Update grpc files * Update more instructions * Downgrade grpc-tools * Update manifest * Update more paths * Update config paths * Update CI paths * Update bandit exclusions * Remove walkthrough section
This commit is contained in:
@@ -0,0 +1,190 @@
|
||||
# Copyright 2024 The HuggingFace Inc. team. All rights reserved.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
import abc
|
||||
import logging
|
||||
import os
|
||||
from dataclasses import dataclass, field
|
||||
from pathlib import Path
|
||||
from typing import Type, TypeVar
|
||||
|
||||
import draccus
|
||||
from huggingface_hub import hf_hub_download
|
||||
from huggingface_hub.constants import CONFIG_NAME
|
||||
from huggingface_hub.errors import HfHubHTTPError
|
||||
|
||||
from lerobot.configs.types import FeatureType, NormalizationMode, PolicyFeature
|
||||
from lerobot.optim.optimizers import OptimizerConfig
|
||||
from lerobot.optim.schedulers import LRSchedulerConfig
|
||||
from lerobot.utils.hub import HubMixin
|
||||
from lerobot.utils.utils import auto_select_torch_device, is_amp_available, is_torch_device_available
|
||||
|
||||
# Generic variable that is either PreTrainedConfig or a subclass thereof
|
||||
T = TypeVar("T", bound="PreTrainedConfig")
|
||||
|
||||
|
||||
@dataclass
|
||||
class PreTrainedConfig(draccus.ChoiceRegistry, HubMixin, abc.ABC):
|
||||
"""
|
||||
Base configuration class for policy models.
|
||||
|
||||
Args:
|
||||
n_obs_steps: Number of environment steps worth of observations to pass to the policy (takes the
|
||||
current step and additional steps going back).
|
||||
input_shapes: A dictionary defining the shapes of the input data for the policy.
|
||||
output_shapes: A dictionary defining the shapes of the output data for the policy.
|
||||
input_normalization_modes: A dictionary with key representing the modality and the value specifies the
|
||||
normalization mode to apply.
|
||||
output_normalization_modes: Similar dictionary as `input_normalization_modes`, but to unnormalize to
|
||||
the original scale.
|
||||
"""
|
||||
|
||||
n_obs_steps: int = 1
|
||||
normalization_mapping: dict[str, NormalizationMode] = field(default_factory=dict)
|
||||
|
||||
input_features: dict[str, PolicyFeature] = field(default_factory=dict)
|
||||
output_features: dict[str, PolicyFeature] = field(default_factory=dict)
|
||||
|
||||
device: str | None = None # cuda | cpu | mp
|
||||
# `use_amp` determines whether to use Automatic Mixed Precision (AMP) for training and evaluation. With AMP,
|
||||
# automatic gradient scaling is used.
|
||||
use_amp: bool = False
|
||||
|
||||
push_to_hub: bool = True
|
||||
repo_id: str | None = None
|
||||
|
||||
# Upload on private repository on the Hugging Face hub.
|
||||
private: bool | None = None
|
||||
# Add tags to your policy on the hub.
|
||||
tags: list[str] | None = None
|
||||
# Add tags to your policy on the hub.
|
||||
license: str | None = None
|
||||
|
||||
def __post_init__(self):
|
||||
self.pretrained_path = None
|
||||
if not self.device or not is_torch_device_available(self.device):
|
||||
auto_device = auto_select_torch_device()
|
||||
logging.warning(f"Device '{self.device}' is not available. Switching to '{auto_device}'.")
|
||||
self.device = auto_device.type
|
||||
|
||||
# Automatically deactivate AMP if necessary
|
||||
if self.use_amp and not is_amp_available(self.device):
|
||||
logging.warning(
|
||||
f"Automatic Mixed Precision (amp) is not available on device '{self.device}'. Deactivating AMP."
|
||||
)
|
||||
self.use_amp = False
|
||||
|
||||
@property
|
||||
def type(self) -> str:
|
||||
return self.get_choice_name(self.__class__)
|
||||
|
||||
@property
|
||||
@abc.abstractmethod
|
||||
def observation_delta_indices(self) -> list | None:
|
||||
raise NotImplementedError
|
||||
|
||||
@property
|
||||
@abc.abstractmethod
|
||||
def action_delta_indices(self) -> list | None:
|
||||
raise NotImplementedError
|
||||
|
||||
@property
|
||||
@abc.abstractmethod
|
||||
def reward_delta_indices(self) -> list | None:
|
||||
raise NotImplementedError
|
||||
|
||||
@abc.abstractmethod
|
||||
def get_optimizer_preset(self) -> OptimizerConfig:
|
||||
raise NotImplementedError
|
||||
|
||||
@abc.abstractmethod
|
||||
def get_scheduler_preset(self) -> LRSchedulerConfig | None:
|
||||
raise NotImplementedError
|
||||
|
||||
@abc.abstractmethod
|
||||
def validate_features(self) -> None:
|
||||
raise NotImplementedError
|
||||
|
||||
@property
|
||||
def robot_state_feature(self) -> PolicyFeature | None:
|
||||
for _, ft in self.input_features.items():
|
||||
if ft.type is FeatureType.STATE:
|
||||
return ft
|
||||
return None
|
||||
|
||||
@property
|
||||
def env_state_feature(self) -> PolicyFeature | None:
|
||||
for _, ft in self.input_features.items():
|
||||
if ft.type is FeatureType.ENV:
|
||||
return ft
|
||||
return None
|
||||
|
||||
@property
|
||||
def image_features(self) -> dict[str, PolicyFeature]:
|
||||
return {key: ft for key, ft in self.input_features.items() if ft.type is FeatureType.VISUAL}
|
||||
|
||||
@property
|
||||
def action_feature(self) -> PolicyFeature | None:
|
||||
for _, ft in self.output_features.items():
|
||||
if ft.type is FeatureType.ACTION:
|
||||
return ft
|
||||
return None
|
||||
|
||||
def _save_pretrained(self, save_directory: Path) -> None:
|
||||
with open(save_directory / CONFIG_NAME, "w") as f, draccus.config_type("json"):
|
||||
draccus.dump(self, f, indent=4)
|
||||
|
||||
@classmethod
|
||||
def from_pretrained(
|
||||
cls: Type[T],
|
||||
pretrained_name_or_path: str | Path,
|
||||
*,
|
||||
force_download: bool = False,
|
||||
resume_download: bool = None,
|
||||
proxies: dict | None = None,
|
||||
token: str | bool | None = None,
|
||||
cache_dir: str | Path | None = None,
|
||||
local_files_only: bool = False,
|
||||
revision: str | None = None,
|
||||
**policy_kwargs,
|
||||
) -> T:
|
||||
model_id = str(pretrained_name_or_path)
|
||||
config_file: str | None = None
|
||||
if Path(model_id).is_dir():
|
||||
if CONFIG_NAME in os.listdir(model_id):
|
||||
config_file = os.path.join(model_id, CONFIG_NAME)
|
||||
else:
|
||||
print(f"{CONFIG_NAME} not found in {Path(model_id).resolve()}")
|
||||
else:
|
||||
try:
|
||||
config_file = hf_hub_download(
|
||||
repo_id=model_id,
|
||||
filename=CONFIG_NAME,
|
||||
revision=revision,
|
||||
cache_dir=cache_dir,
|
||||
force_download=force_download,
|
||||
proxies=proxies,
|
||||
resume_download=resume_download,
|
||||
token=token,
|
||||
local_files_only=local_files_only,
|
||||
)
|
||||
except HfHubHTTPError as e:
|
||||
raise FileNotFoundError(
|
||||
f"{CONFIG_NAME} not found on the HuggingFace Hub in {model_id}"
|
||||
) from e
|
||||
|
||||
# HACK: this is very ugly, ideally we'd like to be able to do that natively with draccus
|
||||
# something like --policy.path (in addition to --policy.type)
|
||||
cli_overrides = policy_kwargs.pop("cli_overrides", [])
|
||||
with draccus.config_type("json"):
|
||||
return draccus.parse(cls, config_file, args=cli_overrides)
|
||||
Reference in New Issue
Block a user