diff --git a/pyproject.toml b/pyproject.toml index 5bc4d630a..3d036afa8 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -75,9 +75,13 @@ dependencies = [ # tightly-coupled modules. Candidate for a future refactor to decouple envs from the core. "gymnasium>=1.1.1,<2.0.0", + # Serialization & checkpointing + "safetensors>=0.4.3,<1.0.0", + # Lightweight utilities "packaging>=24.2,<26.0", "termcolor>=2.4.0,<4.0.0", + "tqdm>=4.66.0,<5.0.0", ] # Optional dependencies @@ -166,14 +170,13 @@ wallx = [ "lerobot[qwen-vl-utils-dep]", ] pi = ["lerobot[transformers-dep]", "lerobot[scipy-dep]"] -smolvla = ["lerobot[transformers-dep]", "num2words>=0.5.14,<0.6.0", "accelerate>=1.7.0,<2.0.0", "safetensors>=0.4.3,<1.0.0"] +smolvla = ["lerobot[transformers-dep]", "num2words>=0.5.14,<0.6.0", "accelerate>=1.7.0,<2.0.0"] multi_task_dit = ["lerobot[transformers-dep]"] groot = [ "lerobot[transformers-dep]", "lerobot[peft]", "dm-tree>=0.1.8,<1.0.0", "timm>=1.0.0,<1.1.0", - "safetensors>=0.4.3,<1.0.0", "Pillow>=10.0.0,<13.0.0", "decord>=0.6.0,<1.0.0; (platform_machine == 'AMD64' or platform_machine == 'x86_64')", "ninja>=1.11.1,<2.0.0", diff --git a/src/lerobot/utils/import_utils.py b/src/lerobot/utils/import_utils.py index 5d6e3ab78..e85f265f9 100644 --- a/src/lerobot/utils/import_utils.py +++ b/src/lerobot/utils/import_utils.py @@ -80,9 +80,15 @@ def get_safe_default_codec(): return "pyav" +_require_package_cache: dict[str, bool] = {} + + def require_package(pkg_name: str, extra: str, import_name: str | None = None) -> None: """Raise an informative ImportError if a package required by an optional feature is missing.""" - if not is_package_available(pkg_name, import_name): + cache_key = import_name or pkg_name + if cache_key not in _require_package_cache: + _require_package_cache[cache_key] = is_package_available(pkg_name, import_name) + if not _require_package_cache[cache_key]: raise ImportError( f"'{pkg_name}' is required but not installed. Install it with: pip install 'lerobot[{extra}]'" ) diff --git a/src/lerobot/utils/io_utils.py b/src/lerobot/utils/io_utils.py index 5a584541f..554d341aa 100644 --- a/src/lerobot/utils/io_utils.py +++ b/src/lerobot/utils/io_utils.py @@ -58,6 +58,9 @@ def write_video(video_path: str | Path, stacked_frames: list, fps: int) -> None: stacked_frames: List of HWC uint8 numpy arrays (RGB). fps: Frames per second for the output video. """ + from lerobot.utils.import_utils import require_package + + require_package("av", extra="av-dep") import av with av.open(str(video_path), mode="w") as container: diff --git a/src/lerobot/utils/utils.py b/src/lerobot/utils/utils.py index 48f2590f4..d8b06a56b 100644 --- a/src/lerobot/utils/utils.py +++ b/src/lerobot/utils/utils.py @@ -287,14 +287,23 @@ class SuppressProgressBars: """ def __enter__(self): - from datasets.utils.logging import disable_progress_bar + try: + from datasets.utils.logging import disable_progress_bar - disable_progress_bar() + disable_progress_bar() + except ImportError: + logging.getLogger(__name__).info( + "SuppressProgressBars is a no-op because 'datasets' is not installed. " + "Install it with: pip install 'lerobot[dataset]'" + ) def __exit__(self, exc_type, exc_val, exc_tb): - from datasets.utils.logging import enable_progress_bar + try: + from datasets.utils.logging import enable_progress_bar - enable_progress_bar() + enable_progress_bar() + except ImportError: + pass class TimerManager: diff --git a/src/lerobot/utils/visualization_utils.py b/src/lerobot/utils/visualization_utils.py index 2c24ff1d8..2fe61ff9f 100644 --- a/src/lerobot/utils/visualization_utils.py +++ b/src/lerobot/utils/visualization_utils.py @@ -17,15 +17,10 @@ import os import numpy as np +from lerobot.types import RobotAction, RobotObservation from lerobot.utils.import_utils import require_package -require_package("rerun-sdk", extra="viz", import_name="rerun") - -import rerun as rr # noqa: E402 - -from lerobot.types import RobotAction, RobotObservation # noqa: E402 - -from .constants import ACTION, ACTION_PREFIX, OBS_PREFIX, OBS_STR # noqa: E402 +from .constants import ACTION, ACTION_PREFIX, OBS_PREFIX, OBS_STR def init_rerun( @@ -39,6 +34,10 @@ def init_rerun( ip: Optional IP for connecting to a Rerun server. port: Optional port for connecting to a Rerun server. """ + + require_package("rerun-sdk", extra="viz", import_name="rerun") + import rerun as rr + batch_size = os.getenv("RERUN_FLUSH_NUM_BYTES", "8000") os.environ["RERUN_FLUSH_NUM_BYTES"] = batch_size rr.init(session_name) @@ -51,6 +50,10 @@ def init_rerun( def shutdown_rerun() -> None: """Shuts down the Rerun SDK gracefully.""" + + require_package("rerun-sdk", extra="viz", import_name="rerun") + import rerun as rr + rr.rerun_shutdown() @@ -83,6 +86,10 @@ def log_rerun_data( action: An optional dictionary containing action data to log. compress_images: Whether to compress images before logging to save bandwidth & memory in exchange for cpu and quality. """ + + require_package("rerun-sdk", extra="viz", import_name="rerun") + import rerun as rr + if observation: for k, v in observation.items(): if v is None: diff --git a/uv.lock b/uv.lock index 95218df06..7c7bb3518 100644 --- a/uv.lock +++ b/uv.lock @@ -2200,9 +2200,11 @@ dependencies = [ { name = "numpy" }, { name = "opencv-python-headless" }, { name = "packaging" }, + { name = "safetensors" }, { name = "termcolor" }, { name = "torch" }, { name = "torchvision" }, + { name = "tqdm" }, ] [package.optional-dependencies] @@ -2251,7 +2253,6 @@ all = [ { name = "qwen-vl-utils" }, { name = "reachy2-sdk" }, { name = "rerun-sdk" }, - { name = "safetensors" }, { name = "scikit-image" }, { name = "scipy" }, { name = "setuptools" }, @@ -2341,7 +2342,6 @@ groot = [ { name = "ninja" }, { name = "peft" }, { name = "pillow" }, - { name = "safetensors" }, { name = "timm" }, { name = "transformers" }, ] @@ -2466,7 +2466,6 @@ scipy-dep = [ smolvla = [ { name = "accelerate" }, { name = "num2words" }, - { name = "safetensors" }, { name = "transformers" }, ] test = [ @@ -2674,8 +2673,7 @@ requires-dist = [ { name = "qwen-vl-utils", marker = "extra == 'qwen-vl-utils-dep'", specifier = ">=0.0.11,<0.1.0" }, { name = "reachy2-sdk", marker = "extra == 'reachy2'", specifier = ">=1.0.15,<1.1.0" }, { name = "rerun-sdk", marker = "extra == 'viz'", specifier = ">=0.24.0,<0.27.0" }, - { name = "safetensors", marker = "extra == 'groot'", specifier = ">=0.4.3,<1.0.0" }, - { name = "safetensors", marker = "extra == 'smolvla'", specifier = ">=0.4.3,<1.0.0" }, + { name = "safetensors", specifier = ">=0.4.3,<1.0.0" }, { name = "scikit-image", marker = "extra == 'video-benchmark'", specifier = ">=0.23.2,<0.26.0" }, { name = "scipy", marker = "extra == 'all'", specifier = ">=1.14.0,<2.0.0" }, { name = "scipy", marker = "extra == 'scipy-dep'", specifier = ">=1.14.0,<2.0.0" }, @@ -2687,6 +2685,7 @@ requires-dist = [ { name = "torchcodec", marker = "(platform_machine != 'aarch64' and platform_machine != 'arm64' and platform_machine != 'armv7l' and sys_platform == 'linux' and extra == 'dataset') or (platform_machine != 'x86_64' and sys_platform == 'darwin' and extra == 'dataset') or (sys_platform != 'darwin' and sys_platform != 'linux' and sys_platform != 'win32' and extra == 'dataset')", specifier = ">=0.3.0,<0.11.0" }, { name = "torchdiffeq", marker = "extra == 'wallx'", specifier = ">=0.2.4,<0.3.0" }, { name = "torchvision", specifier = ">=0.22.0,<0.26.0" }, + { name = "tqdm", specifier = ">=4.66.0,<5.0.0" }, { name = "transformers", marker = "extra == 'transformers-dep'", specifier = "==5.3.0" }, { name = "wandb", marker = "extra == 'training'", specifier = ">=0.24.0,<0.25.0" }, ]