mirror of
https://github.com/huggingface/lerobot.git
synced 2026-05-26 05:59:52 +00:00
feat(dependecies): minimal default tag install
This commit is contained in:
@@ -12,25 +12,25 @@
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
########################################################################################
|
||||
# Utilities
|
||||
########################################################################################
|
||||
|
||||
|
||||
import logging
|
||||
import traceback
|
||||
from contextlib import nullcontext
|
||||
from copy import copy
|
||||
from functools import cache
|
||||
from typing import Any
|
||||
from typing import TYPE_CHECKING, Any
|
||||
|
||||
import numpy as np
|
||||
import torch
|
||||
from deepdiff import DeepDiff
|
||||
|
||||
from lerobot.datasets.lerobot_dataset import LeRobotDataset
|
||||
from lerobot.datasets.utils import DEFAULT_FEATURES
|
||||
from lerobot.policies.pretrained import PreTrainedPolicy
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from lerobot.datasets.lerobot_dataset import LeRobotDataset
|
||||
from lerobot.policies.utils import prepare_observation_for_inference
|
||||
from lerobot.processor import PolicyProcessorPipeline
|
||||
from lerobot.robots import Robot
|
||||
@@ -218,6 +218,13 @@ def sanity_check_dataset_robot_compatibility(
|
||||
Raises:
|
||||
ValueError: If any of the checked metadata fields do not match.
|
||||
"""
|
||||
from lerobot.utils.import_utils import require_package
|
||||
|
||||
require_package("deepdiff", extra="hardware")
|
||||
from deepdiff import DeepDiff
|
||||
|
||||
from lerobot.datasets.utils import DEFAULT_FEATURES
|
||||
|
||||
fields = [
|
||||
("robot_type", dataset.meta.robot_type, robot.robot_type),
|
||||
("fps", dataset.fps, fps),
|
||||
|
||||
@@ -69,6 +69,24 @@ def is_package_available(
|
||||
return package_exists
|
||||
|
||||
|
||||
def get_safe_default_codec():
|
||||
if importlib.util.find_spec("torchcodec"):
|
||||
return "torchcodec"
|
||||
else:
|
||||
logging.warning(
|
||||
"'torchcodec' is not available in your platform, falling back to 'pyav' as a default decoder"
|
||||
)
|
||||
return "pyav"
|
||||
|
||||
|
||||
def require_package(pkg_name: str, extra: str, import_name: str | None = None) -> None:
|
||||
"""Raise an informative ImportError if a package required by an optional feature is missing."""
|
||||
if not is_package_available(pkg_name, import_name):
|
||||
raise ImportError(
|
||||
f"'{pkg_name}' is required but not installed. Install it with: pip install 'lerobot[{extra}]'"
|
||||
)
|
||||
|
||||
|
||||
_transformers_available = is_package_available("transformers")
|
||||
_peft_available = is_package_available("peft")
|
||||
_scipy_available = is_package_available("scipy")
|
||||
|
||||
@@ -14,21 +14,64 @@
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
import json
|
||||
import warnings
|
||||
from pathlib import Path
|
||||
|
||||
import imageio
|
||||
from typing import Any
|
||||
|
||||
JsonLike = str | int | float | bool | None | list["JsonLike"] | dict[str, "JsonLike"] | tuple["JsonLike", ...]
|
||||
|
||||
|
||||
def write_video(video_path, stacked_frames, fps):
|
||||
# Filter out DeprecationWarnings raised from pkg_resources
|
||||
with warnings.catch_warnings():
|
||||
warnings.filterwarnings(
|
||||
"ignore", "pkg_resources is deprecated as an API", category=DeprecationWarning
|
||||
)
|
||||
imageio.mimsave(video_path, stacked_frames, fps=fps)
|
||||
def load_json(fpath: Path) -> Any:
|
||||
"""Load data from a JSON file.
|
||||
|
||||
Args:
|
||||
fpath (Path): Path to the JSON file.
|
||||
|
||||
Returns:
|
||||
Any: The data loaded from the JSON file.
|
||||
"""
|
||||
with open(fpath) as f:
|
||||
return json.load(f)
|
||||
|
||||
|
||||
def write_json(data: dict, fpath: Path) -> None:
|
||||
"""Write data to a JSON file.
|
||||
|
||||
Creates parent directories if they don't exist.
|
||||
|
||||
Args:
|
||||
data (dict): The dictionary to write.
|
||||
fpath (Path): The path to the output JSON file.
|
||||
"""
|
||||
fpath.parent.mkdir(exist_ok=True, parents=True)
|
||||
with open(fpath, "w") as f:
|
||||
json.dump(data, f, indent=4, ensure_ascii=False)
|
||||
|
||||
|
||||
def write_video(video_path: str | Path, stacked_frames: list, fps: int) -> None:
|
||||
"""Write a sequence of RGB frames to an MP4 video file using libx264.
|
||||
|
||||
Args:
|
||||
video_path: Output file path.
|
||||
stacked_frames: List of HWC uint8 numpy arrays (RGB).
|
||||
fps: Frames per second for the output video.
|
||||
"""
|
||||
import av
|
||||
|
||||
with av.open(str(video_path), mode="w") as container:
|
||||
height, width = stacked_frames[0].shape[:2]
|
||||
# Ensure dimensions are even for yuv420p compatibility
|
||||
height = height if height % 2 == 0 else height - 1
|
||||
width = width if width % 2 == 0 else width - 1
|
||||
stream = container.add_stream("libx264", rate=fps)
|
||||
stream.width = width
|
||||
stream.height = height
|
||||
stream.pix_fmt = "yuv420p"
|
||||
for frame_array in stacked_frames:
|
||||
frame = av.VideoFrame.from_ndarray(frame_array, format="rgb24")
|
||||
for packet in stream.encode(frame):
|
||||
container.mux(packet)
|
||||
for packet in stream.encode():
|
||||
container.mux(packet)
|
||||
|
||||
|
||||
def deserialize_json_into_object[T: JsonLike](fpath: Path, obj: T) -> T:
|
||||
|
||||
@@ -23,8 +23,8 @@ import numpy as np
|
||||
import torch
|
||||
from safetensors.torch import load_file, save_file
|
||||
|
||||
from lerobot.datasets.utils import flatten_dict, unflatten_dict
|
||||
from lerobot.utils.constants import RNG_STATE
|
||||
from lerobot.utils.utils import flatten_dict, unflatten_dict
|
||||
|
||||
|
||||
def serialize_python_rng_state() -> dict[str, torch.Tensor]:
|
||||
|
||||
@@ -19,7 +19,6 @@ from torch.optim import Optimizer
|
||||
from torch.optim.lr_scheduler import LRScheduler
|
||||
|
||||
from lerobot.configs.train import TrainPipelineConfig
|
||||
from lerobot.datasets.io_utils import load_json, write_json
|
||||
from lerobot.optim.optimizers import load_optimizer_state, save_optimizer_state
|
||||
from lerobot.optim.schedulers import load_scheduler_state, save_scheduler_state
|
||||
from lerobot.policies.pretrained import PreTrainedPolicy
|
||||
@@ -31,6 +30,7 @@ from lerobot.utils.constants import (
|
||||
TRAINING_STATE_DIR,
|
||||
TRAINING_STEP,
|
||||
)
|
||||
from lerobot.utils.io_utils import load_json, write_json
|
||||
from lerobot.utils.random_utils import load_rng_state, save_rng_state
|
||||
|
||||
|
||||
|
||||
@@ -199,6 +199,59 @@ def get_elapsed_time_in_days_hours_minutes_seconds(elapsed_time_s: float):
|
||||
return days, hours, minutes, seconds
|
||||
|
||||
|
||||
def flatten_dict(d: dict, parent_key: str = "", sep: str = "/") -> dict:
|
||||
"""Flatten a nested dictionary by joining keys with a separator.
|
||||
|
||||
Example:
|
||||
>>> dct = {"a": {"b": 1, "c": {"d": 2}}, "e": 3}
|
||||
>>> print(flatten_dict(dct))
|
||||
{'a/b': 1, 'a/c/d': 2, 'e': 3}
|
||||
|
||||
Args:
|
||||
d (dict): The dictionary to flatten.
|
||||
parent_key (str): The base key to prepend to the keys in this level.
|
||||
sep (str): The separator to use between keys.
|
||||
|
||||
Returns:
|
||||
dict: A flattened dictionary.
|
||||
"""
|
||||
items = []
|
||||
for k, v in d.items():
|
||||
new_key = f"{parent_key}{sep}{k}" if parent_key else k
|
||||
if isinstance(v, dict):
|
||||
items.extend(flatten_dict(v, new_key, sep=sep).items())
|
||||
else:
|
||||
items.append((new_key, v))
|
||||
return dict(items)
|
||||
|
||||
|
||||
def unflatten_dict(d: dict, sep: str = "/") -> dict:
|
||||
"""Unflatten a dictionary with delimited keys into a nested dictionary.
|
||||
|
||||
Example:
|
||||
>>> flat_dct = {"a/b": 1, "a/c/d": 2, "e": 3}
|
||||
>>> print(unflatten_dict(flat_dct))
|
||||
{'a': {'b': 1, 'c': {'d': 2}}, 'e': 3}
|
||||
|
||||
Args:
|
||||
d (dict): A dictionary with flattened keys.
|
||||
sep (str): The separator used in the keys.
|
||||
|
||||
Returns:
|
||||
dict: A nested dictionary.
|
||||
"""
|
||||
outdict = {}
|
||||
for key, value in d.items():
|
||||
parts = key.split(sep)
|
||||
d_inner = outdict
|
||||
for part in parts[:-1]:
|
||||
if part not in d_inner:
|
||||
d_inner[part] = {}
|
||||
d_inner = d_inner[part]
|
||||
d_inner[parts[-1]] = value
|
||||
return outdict
|
||||
|
||||
|
||||
class SuppressProgressBars:
|
||||
"""
|
||||
Context manager to suppress progress bars.
|
||||
|
||||
@@ -16,11 +16,16 @@ import numbers
|
||||
import os
|
||||
|
||||
import numpy as np
|
||||
import rerun as rr
|
||||
|
||||
from lerobot.types import RobotAction, RobotObservation
|
||||
from lerobot.utils.import_utils import require_package
|
||||
|
||||
from .constants import ACTION, ACTION_PREFIX, OBS_PREFIX, OBS_STR
|
||||
require_package("rerun-sdk", extra="viz", import_name="rerun")
|
||||
|
||||
import rerun as rr # noqa: E402
|
||||
|
||||
from lerobot.types import RobotAction, RobotObservation # noqa: E402
|
||||
|
||||
from .constants import ACTION, ACTION_PREFIX, OBS_PREFIX, OBS_STR # noqa: E402
|
||||
|
||||
|
||||
def init_rerun(
|
||||
@@ -44,6 +49,11 @@ def init_rerun(
|
||||
rr.spawn(memory_limit=memory_limit)
|
||||
|
||||
|
||||
def shutdown_rerun() -> None:
|
||||
"""Shuts down the Rerun SDK gracefully."""
|
||||
rr.rerun_shutdown()
|
||||
|
||||
|
||||
def _is_scalar(x):
|
||||
return isinstance(x, (float | numbers.Real | np.integer | np.floating)) or (
|
||||
isinstance(x, np.ndarray) and x.ndim == 0
|
||||
|
||||
Reference in New Issue
Block a user