mirror of
https://github.com/huggingface/lerobot.git
synced 2026-05-21 19:49:49 +00:00
refactor: minor improvements
This commit is contained in:
+5
-2
@@ -75,9 +75,13 @@ dependencies = [
|
|||||||
# tightly-coupled modules. Candidate for a future refactor to decouple envs from the core.
|
# tightly-coupled modules. Candidate for a future refactor to decouple envs from the core.
|
||||||
"gymnasium>=1.1.1,<2.0.0",
|
"gymnasium>=1.1.1,<2.0.0",
|
||||||
|
|
||||||
|
# Serialization & checkpointing
|
||||||
|
"safetensors>=0.4.3,<1.0.0",
|
||||||
|
|
||||||
# Lightweight utilities
|
# Lightweight utilities
|
||||||
"packaging>=24.2,<26.0",
|
"packaging>=24.2,<26.0",
|
||||||
"termcolor>=2.4.0,<4.0.0",
|
"termcolor>=2.4.0,<4.0.0",
|
||||||
|
"tqdm>=4.66.0,<5.0.0",
|
||||||
]
|
]
|
||||||
|
|
||||||
# Optional dependencies
|
# Optional dependencies
|
||||||
@@ -166,14 +170,13 @@ wallx = [
|
|||||||
"lerobot[qwen-vl-utils-dep]",
|
"lerobot[qwen-vl-utils-dep]",
|
||||||
]
|
]
|
||||||
pi = ["lerobot[transformers-dep]", "lerobot[scipy-dep]"]
|
pi = ["lerobot[transformers-dep]", "lerobot[scipy-dep]"]
|
||||||
smolvla = ["lerobot[transformers-dep]", "num2words>=0.5.14,<0.6.0", "accelerate>=1.7.0,<2.0.0", "safetensors>=0.4.3,<1.0.0"]
|
smolvla = ["lerobot[transformers-dep]", "num2words>=0.5.14,<0.6.0", "accelerate>=1.7.0,<2.0.0"]
|
||||||
multi_task_dit = ["lerobot[transformers-dep]"]
|
multi_task_dit = ["lerobot[transformers-dep]"]
|
||||||
groot = [
|
groot = [
|
||||||
"lerobot[transformers-dep]",
|
"lerobot[transformers-dep]",
|
||||||
"lerobot[peft]",
|
"lerobot[peft]",
|
||||||
"dm-tree>=0.1.8,<1.0.0",
|
"dm-tree>=0.1.8,<1.0.0",
|
||||||
"timm>=1.0.0,<1.1.0",
|
"timm>=1.0.0,<1.1.0",
|
||||||
"safetensors>=0.4.3,<1.0.0",
|
|
||||||
"Pillow>=10.0.0,<13.0.0",
|
"Pillow>=10.0.0,<13.0.0",
|
||||||
"decord>=0.6.0,<1.0.0; (platform_machine == 'AMD64' or platform_machine == 'x86_64')",
|
"decord>=0.6.0,<1.0.0; (platform_machine == 'AMD64' or platform_machine == 'x86_64')",
|
||||||
"ninja>=1.11.1,<2.0.0",
|
"ninja>=1.11.1,<2.0.0",
|
||||||
|
|||||||
@@ -80,9 +80,15 @@ def get_safe_default_codec():
|
|||||||
return "pyav"
|
return "pyav"
|
||||||
|
|
||||||
|
|
||||||
|
_require_package_cache: dict[str, bool] = {}
|
||||||
|
|
||||||
|
|
||||||
def require_package(pkg_name: str, extra: str, import_name: str | None = None) -> None:
|
def require_package(pkg_name: str, extra: str, import_name: str | None = None) -> None:
|
||||||
"""Raise an informative ImportError if a package required by an optional feature is missing."""
|
"""Raise an informative ImportError if a package required by an optional feature is missing."""
|
||||||
if not is_package_available(pkg_name, import_name):
|
cache_key = import_name or pkg_name
|
||||||
|
if cache_key not in _require_package_cache:
|
||||||
|
_require_package_cache[cache_key] = is_package_available(pkg_name, import_name)
|
||||||
|
if not _require_package_cache[cache_key]:
|
||||||
raise ImportError(
|
raise ImportError(
|
||||||
f"'{pkg_name}' is required but not installed. Install it with: pip install 'lerobot[{extra}]'"
|
f"'{pkg_name}' is required but not installed. Install it with: pip install 'lerobot[{extra}]'"
|
||||||
)
|
)
|
||||||
|
|||||||
@@ -58,6 +58,9 @@ def write_video(video_path: str | Path, stacked_frames: list, fps: int) -> None:
|
|||||||
stacked_frames: List of HWC uint8 numpy arrays (RGB).
|
stacked_frames: List of HWC uint8 numpy arrays (RGB).
|
||||||
fps: Frames per second for the output video.
|
fps: Frames per second for the output video.
|
||||||
"""
|
"""
|
||||||
|
from lerobot.utils.import_utils import require_package
|
||||||
|
|
||||||
|
require_package("av", extra="av-dep")
|
||||||
import av
|
import av
|
||||||
|
|
||||||
with av.open(str(video_path), mode="w") as container:
|
with av.open(str(video_path), mode="w") as container:
|
||||||
|
|||||||
@@ -287,14 +287,23 @@ class SuppressProgressBars:
|
|||||||
"""
|
"""
|
||||||
|
|
||||||
def __enter__(self):
|
def __enter__(self):
|
||||||
from datasets.utils.logging import disable_progress_bar
|
try:
|
||||||
|
from datasets.utils.logging import disable_progress_bar
|
||||||
|
|
||||||
disable_progress_bar()
|
disable_progress_bar()
|
||||||
|
except ImportError:
|
||||||
|
logging.getLogger(__name__).info(
|
||||||
|
"SuppressProgressBars is a no-op because 'datasets' is not installed. "
|
||||||
|
"Install it with: pip install 'lerobot[dataset]'"
|
||||||
|
)
|
||||||
|
|
||||||
def __exit__(self, exc_type, exc_val, exc_tb):
|
def __exit__(self, exc_type, exc_val, exc_tb):
|
||||||
from datasets.utils.logging import enable_progress_bar
|
try:
|
||||||
|
from datasets.utils.logging import enable_progress_bar
|
||||||
|
|
||||||
enable_progress_bar()
|
enable_progress_bar()
|
||||||
|
except ImportError:
|
||||||
|
pass
|
||||||
|
|
||||||
|
|
||||||
class TimerManager:
|
class TimerManager:
|
||||||
|
|||||||
@@ -17,15 +17,10 @@ import os
|
|||||||
|
|
||||||
import numpy as np
|
import numpy as np
|
||||||
|
|
||||||
|
from lerobot.types import RobotAction, RobotObservation
|
||||||
from lerobot.utils.import_utils import require_package
|
from lerobot.utils.import_utils import require_package
|
||||||
|
|
||||||
require_package("rerun-sdk", extra="viz", import_name="rerun")
|
from .constants import ACTION, ACTION_PREFIX, OBS_PREFIX, OBS_STR
|
||||||
|
|
||||||
import rerun as rr # noqa: E402
|
|
||||||
|
|
||||||
from lerobot.types import RobotAction, RobotObservation # noqa: E402
|
|
||||||
|
|
||||||
from .constants import ACTION, ACTION_PREFIX, OBS_PREFIX, OBS_STR # noqa: E402
|
|
||||||
|
|
||||||
|
|
||||||
def init_rerun(
|
def init_rerun(
|
||||||
@@ -39,6 +34,10 @@ def init_rerun(
|
|||||||
ip: Optional IP for connecting to a Rerun server.
|
ip: Optional IP for connecting to a Rerun server.
|
||||||
port: Optional port for connecting to a Rerun server.
|
port: Optional port for connecting to a Rerun server.
|
||||||
"""
|
"""
|
||||||
|
|
||||||
|
require_package("rerun-sdk", extra="viz", import_name="rerun")
|
||||||
|
import rerun as rr
|
||||||
|
|
||||||
batch_size = os.getenv("RERUN_FLUSH_NUM_BYTES", "8000")
|
batch_size = os.getenv("RERUN_FLUSH_NUM_BYTES", "8000")
|
||||||
os.environ["RERUN_FLUSH_NUM_BYTES"] = batch_size
|
os.environ["RERUN_FLUSH_NUM_BYTES"] = batch_size
|
||||||
rr.init(session_name)
|
rr.init(session_name)
|
||||||
@@ -51,6 +50,10 @@ def init_rerun(
|
|||||||
|
|
||||||
def shutdown_rerun() -> None:
|
def shutdown_rerun() -> None:
|
||||||
"""Shuts down the Rerun SDK gracefully."""
|
"""Shuts down the Rerun SDK gracefully."""
|
||||||
|
|
||||||
|
require_package("rerun-sdk", extra="viz", import_name="rerun")
|
||||||
|
import rerun as rr
|
||||||
|
|
||||||
rr.rerun_shutdown()
|
rr.rerun_shutdown()
|
||||||
|
|
||||||
|
|
||||||
@@ -83,6 +86,10 @@ def log_rerun_data(
|
|||||||
action: An optional dictionary containing action data to log.
|
action: An optional dictionary containing action data to log.
|
||||||
compress_images: Whether to compress images before logging to save bandwidth & memory in exchange for cpu and quality.
|
compress_images: Whether to compress images before logging to save bandwidth & memory in exchange for cpu and quality.
|
||||||
"""
|
"""
|
||||||
|
|
||||||
|
require_package("rerun-sdk", extra="viz", import_name="rerun")
|
||||||
|
import rerun as rr
|
||||||
|
|
||||||
if observation:
|
if observation:
|
||||||
for k, v in observation.items():
|
for k, v in observation.items():
|
||||||
if v is None:
|
if v is None:
|
||||||
|
|||||||
@@ -2200,9 +2200,11 @@ dependencies = [
|
|||||||
{ name = "numpy" },
|
{ name = "numpy" },
|
||||||
{ name = "opencv-python-headless" },
|
{ name = "opencv-python-headless" },
|
||||||
{ name = "packaging" },
|
{ name = "packaging" },
|
||||||
|
{ name = "safetensors" },
|
||||||
{ name = "termcolor" },
|
{ name = "termcolor" },
|
||||||
{ name = "torch" },
|
{ name = "torch" },
|
||||||
{ name = "torchvision" },
|
{ name = "torchvision" },
|
||||||
|
{ name = "tqdm" },
|
||||||
]
|
]
|
||||||
|
|
||||||
[package.optional-dependencies]
|
[package.optional-dependencies]
|
||||||
@@ -2251,7 +2253,6 @@ all = [
|
|||||||
{ name = "qwen-vl-utils" },
|
{ name = "qwen-vl-utils" },
|
||||||
{ name = "reachy2-sdk" },
|
{ name = "reachy2-sdk" },
|
||||||
{ name = "rerun-sdk" },
|
{ name = "rerun-sdk" },
|
||||||
{ name = "safetensors" },
|
|
||||||
{ name = "scikit-image" },
|
{ name = "scikit-image" },
|
||||||
{ name = "scipy" },
|
{ name = "scipy" },
|
||||||
{ name = "setuptools" },
|
{ name = "setuptools" },
|
||||||
@@ -2341,7 +2342,6 @@ groot = [
|
|||||||
{ name = "ninja" },
|
{ name = "ninja" },
|
||||||
{ name = "peft" },
|
{ name = "peft" },
|
||||||
{ name = "pillow" },
|
{ name = "pillow" },
|
||||||
{ name = "safetensors" },
|
|
||||||
{ name = "timm" },
|
{ name = "timm" },
|
||||||
{ name = "transformers" },
|
{ name = "transformers" },
|
||||||
]
|
]
|
||||||
@@ -2466,7 +2466,6 @@ scipy-dep = [
|
|||||||
smolvla = [
|
smolvla = [
|
||||||
{ name = "accelerate" },
|
{ name = "accelerate" },
|
||||||
{ name = "num2words" },
|
{ name = "num2words" },
|
||||||
{ name = "safetensors" },
|
|
||||||
{ name = "transformers" },
|
{ name = "transformers" },
|
||||||
]
|
]
|
||||||
test = [
|
test = [
|
||||||
@@ -2674,8 +2673,7 @@ requires-dist = [
|
|||||||
{ name = "qwen-vl-utils", marker = "extra == 'qwen-vl-utils-dep'", specifier = ">=0.0.11,<0.1.0" },
|
{ name = "qwen-vl-utils", marker = "extra == 'qwen-vl-utils-dep'", specifier = ">=0.0.11,<0.1.0" },
|
||||||
{ name = "reachy2-sdk", marker = "extra == 'reachy2'", specifier = ">=1.0.15,<1.1.0" },
|
{ name = "reachy2-sdk", marker = "extra == 'reachy2'", specifier = ">=1.0.15,<1.1.0" },
|
||||||
{ name = "rerun-sdk", marker = "extra == 'viz'", specifier = ">=0.24.0,<0.27.0" },
|
{ name = "rerun-sdk", marker = "extra == 'viz'", specifier = ">=0.24.0,<0.27.0" },
|
||||||
{ name = "safetensors", marker = "extra == 'groot'", specifier = ">=0.4.3,<1.0.0" },
|
{ name = "safetensors", specifier = ">=0.4.3,<1.0.0" },
|
||||||
{ name = "safetensors", marker = "extra == 'smolvla'", specifier = ">=0.4.3,<1.0.0" },
|
|
||||||
{ name = "scikit-image", marker = "extra == 'video-benchmark'", specifier = ">=0.23.2,<0.26.0" },
|
{ name = "scikit-image", marker = "extra == 'video-benchmark'", specifier = ">=0.23.2,<0.26.0" },
|
||||||
{ name = "scipy", marker = "extra == 'all'", specifier = ">=1.14.0,<2.0.0" },
|
{ name = "scipy", marker = "extra == 'all'", specifier = ">=1.14.0,<2.0.0" },
|
||||||
{ name = "scipy", marker = "extra == 'scipy-dep'", specifier = ">=1.14.0,<2.0.0" },
|
{ name = "scipy", marker = "extra == 'scipy-dep'", specifier = ">=1.14.0,<2.0.0" },
|
||||||
@@ -2687,6 +2685,7 @@ requires-dist = [
|
|||||||
{ name = "torchcodec", marker = "(platform_machine != 'aarch64' and platform_machine != 'arm64' and platform_machine != 'armv7l' and sys_platform == 'linux' and extra == 'dataset') or (platform_machine != 'x86_64' and sys_platform == 'darwin' and extra == 'dataset') or (sys_platform != 'darwin' and sys_platform != 'linux' and sys_platform != 'win32' and extra == 'dataset')", specifier = ">=0.3.0,<0.11.0" },
|
{ name = "torchcodec", marker = "(platform_machine != 'aarch64' and platform_machine != 'arm64' and platform_machine != 'armv7l' and sys_platform == 'linux' and extra == 'dataset') or (platform_machine != 'x86_64' and sys_platform == 'darwin' and extra == 'dataset') or (sys_platform != 'darwin' and sys_platform != 'linux' and sys_platform != 'win32' and extra == 'dataset')", specifier = ">=0.3.0,<0.11.0" },
|
||||||
{ name = "torchdiffeq", marker = "extra == 'wallx'", specifier = ">=0.2.4,<0.3.0" },
|
{ name = "torchdiffeq", marker = "extra == 'wallx'", specifier = ">=0.2.4,<0.3.0" },
|
||||||
{ name = "torchvision", specifier = ">=0.22.0,<0.26.0" },
|
{ name = "torchvision", specifier = ">=0.22.0,<0.26.0" },
|
||||||
|
{ name = "tqdm", specifier = ">=4.66.0,<5.0.0" },
|
||||||
{ name = "transformers", marker = "extra == 'transformers-dep'", specifier = "==5.3.0" },
|
{ name = "transformers", marker = "extra == 'transformers-dep'", specifier = "==5.3.0" },
|
||||||
{ name = "wandb", marker = "extra == 'training'", specifier = ">=0.24.0,<0.25.0" },
|
{ name = "wandb", marker = "extra == 'training'", specifier = ">=0.24.0,<0.25.0" },
|
||||||
]
|
]
|
||||||
|
|||||||
Reference in New Issue
Block a user