refactor: minor improvements

This commit is contained in:
Steven Palma
2026-04-10 18:31:07 +02:00
parent 882a6b0965
commit 4c39981908
6 changed files with 46 additions and 19 deletions
+5 -2
View File
@@ -75,9 +75,13 @@ dependencies = [
# tightly-coupled modules. Candidate for a future refactor to decouple envs from the core. # tightly-coupled modules. Candidate for a future refactor to decouple envs from the core.
"gymnasium>=1.1.1,<2.0.0", "gymnasium>=1.1.1,<2.0.0",
# Serialization & checkpointing
"safetensors>=0.4.3,<1.0.0",
# Lightweight utilities # Lightweight utilities
"packaging>=24.2,<26.0", "packaging>=24.2,<26.0",
"termcolor>=2.4.0,<4.0.0", "termcolor>=2.4.0,<4.0.0",
"tqdm>=4.66.0,<5.0.0",
] ]
# Optional dependencies # Optional dependencies
@@ -166,14 +170,13 @@ wallx = [
"lerobot[qwen-vl-utils-dep]", "lerobot[qwen-vl-utils-dep]",
] ]
pi = ["lerobot[transformers-dep]", "lerobot[scipy-dep]"] pi = ["lerobot[transformers-dep]", "lerobot[scipy-dep]"]
smolvla = ["lerobot[transformers-dep]", "num2words>=0.5.14,<0.6.0", "accelerate>=1.7.0,<2.0.0", "safetensors>=0.4.3,<1.0.0"] smolvla = ["lerobot[transformers-dep]", "num2words>=0.5.14,<0.6.0", "accelerate>=1.7.0,<2.0.0"]
multi_task_dit = ["lerobot[transformers-dep]"] multi_task_dit = ["lerobot[transformers-dep]"]
groot = [ groot = [
"lerobot[transformers-dep]", "lerobot[transformers-dep]",
"lerobot[peft]", "lerobot[peft]",
"dm-tree>=0.1.8,<1.0.0", "dm-tree>=0.1.8,<1.0.0",
"timm>=1.0.0,<1.1.0", "timm>=1.0.0,<1.1.0",
"safetensors>=0.4.3,<1.0.0",
"Pillow>=10.0.0,<13.0.0", "Pillow>=10.0.0,<13.0.0",
"decord>=0.6.0,<1.0.0; (platform_machine == 'AMD64' or platform_machine == 'x86_64')", "decord>=0.6.0,<1.0.0; (platform_machine == 'AMD64' or platform_machine == 'x86_64')",
"ninja>=1.11.1,<2.0.0", "ninja>=1.11.1,<2.0.0",
+7 -1
View File
@@ -80,9 +80,15 @@ def get_safe_default_codec():
return "pyav" return "pyav"
_require_package_cache: dict[str, bool] = {}
def require_package(pkg_name: str, extra: str, import_name: str | None = None) -> None: def require_package(pkg_name: str, extra: str, import_name: str | None = None) -> None:
"""Raise an informative ImportError if a package required by an optional feature is missing.""" """Raise an informative ImportError if a package required by an optional feature is missing."""
if not is_package_available(pkg_name, import_name): cache_key = import_name or pkg_name
if cache_key not in _require_package_cache:
_require_package_cache[cache_key] = is_package_available(pkg_name, import_name)
if not _require_package_cache[cache_key]:
raise ImportError( raise ImportError(
f"'{pkg_name}' is required but not installed. Install it with: pip install 'lerobot[{extra}]'" f"'{pkg_name}' is required but not installed. Install it with: pip install 'lerobot[{extra}]'"
) )
+3
View File
@@ -58,6 +58,9 @@ def write_video(video_path: str | Path, stacked_frames: list, fps: int) -> None:
stacked_frames: List of HWC uint8 numpy arrays (RGB). stacked_frames: List of HWC uint8 numpy arrays (RGB).
fps: Frames per second for the output video. fps: Frames per second for the output video.
""" """
from lerobot.utils.import_utils import require_package
require_package("av", extra="av-dep")
import av import av
with av.open(str(video_path), mode="w") as container: with av.open(str(video_path), mode="w") as container:
+13 -4
View File
@@ -287,14 +287,23 @@ class SuppressProgressBars:
""" """
def __enter__(self): def __enter__(self):
from datasets.utils.logging import disable_progress_bar try:
from datasets.utils.logging import disable_progress_bar
disable_progress_bar() disable_progress_bar()
except ImportError:
logging.getLogger(__name__).info(
"SuppressProgressBars is a no-op because 'datasets' is not installed. "
"Install it with: pip install 'lerobot[dataset]'"
)
def __exit__(self, exc_type, exc_val, exc_tb): def __exit__(self, exc_type, exc_val, exc_tb):
from datasets.utils.logging import enable_progress_bar try:
from datasets.utils.logging import enable_progress_bar
enable_progress_bar() enable_progress_bar()
except ImportError:
pass
class TimerManager: class TimerManager:
+14 -7
View File
@@ -17,15 +17,10 @@ import os
import numpy as np import numpy as np
from lerobot.types import RobotAction, RobotObservation
from lerobot.utils.import_utils import require_package from lerobot.utils.import_utils import require_package
require_package("rerun-sdk", extra="viz", import_name="rerun") from .constants import ACTION, ACTION_PREFIX, OBS_PREFIX, OBS_STR
import rerun as rr # noqa: E402
from lerobot.types import RobotAction, RobotObservation # noqa: E402
from .constants import ACTION, ACTION_PREFIX, OBS_PREFIX, OBS_STR # noqa: E402
def init_rerun( def init_rerun(
@@ -39,6 +34,10 @@ def init_rerun(
ip: Optional IP for connecting to a Rerun server. ip: Optional IP for connecting to a Rerun server.
port: Optional port for connecting to a Rerun server. port: Optional port for connecting to a Rerun server.
""" """
require_package("rerun-sdk", extra="viz", import_name="rerun")
import rerun as rr
batch_size = os.getenv("RERUN_FLUSH_NUM_BYTES", "8000") batch_size = os.getenv("RERUN_FLUSH_NUM_BYTES", "8000")
os.environ["RERUN_FLUSH_NUM_BYTES"] = batch_size os.environ["RERUN_FLUSH_NUM_BYTES"] = batch_size
rr.init(session_name) rr.init(session_name)
@@ -51,6 +50,10 @@ def init_rerun(
def shutdown_rerun() -> None: def shutdown_rerun() -> None:
"""Shuts down the Rerun SDK gracefully.""" """Shuts down the Rerun SDK gracefully."""
require_package("rerun-sdk", extra="viz", import_name="rerun")
import rerun as rr
rr.rerun_shutdown() rr.rerun_shutdown()
@@ -83,6 +86,10 @@ def log_rerun_data(
action: An optional dictionary containing action data to log. action: An optional dictionary containing action data to log.
compress_images: Whether to compress images before logging to save bandwidth & memory in exchange for cpu and quality. compress_images: Whether to compress images before logging to save bandwidth & memory in exchange for cpu and quality.
""" """
require_package("rerun-sdk", extra="viz", import_name="rerun")
import rerun as rr
if observation: if observation:
for k, v in observation.items(): for k, v in observation.items():
if v is None: if v is None:
Generated
+4 -5
View File
@@ -2200,9 +2200,11 @@ dependencies = [
{ name = "numpy" }, { name = "numpy" },
{ name = "opencv-python-headless" }, { name = "opencv-python-headless" },
{ name = "packaging" }, { name = "packaging" },
{ name = "safetensors" },
{ name = "termcolor" }, { name = "termcolor" },
{ name = "torch" }, { name = "torch" },
{ name = "torchvision" }, { name = "torchvision" },
{ name = "tqdm" },
] ]
[package.optional-dependencies] [package.optional-dependencies]
@@ -2251,7 +2253,6 @@ all = [
{ name = "qwen-vl-utils" }, { name = "qwen-vl-utils" },
{ name = "reachy2-sdk" }, { name = "reachy2-sdk" },
{ name = "rerun-sdk" }, { name = "rerun-sdk" },
{ name = "safetensors" },
{ name = "scikit-image" }, { name = "scikit-image" },
{ name = "scipy" }, { name = "scipy" },
{ name = "setuptools" }, { name = "setuptools" },
@@ -2341,7 +2342,6 @@ groot = [
{ name = "ninja" }, { name = "ninja" },
{ name = "peft" }, { name = "peft" },
{ name = "pillow" }, { name = "pillow" },
{ name = "safetensors" },
{ name = "timm" }, { name = "timm" },
{ name = "transformers" }, { name = "transformers" },
] ]
@@ -2466,7 +2466,6 @@ scipy-dep = [
smolvla = [ smolvla = [
{ name = "accelerate" }, { name = "accelerate" },
{ name = "num2words" }, { name = "num2words" },
{ name = "safetensors" },
{ name = "transformers" }, { name = "transformers" },
] ]
test = [ test = [
@@ -2674,8 +2673,7 @@ requires-dist = [
{ name = "qwen-vl-utils", marker = "extra == 'qwen-vl-utils-dep'", specifier = ">=0.0.11,<0.1.0" }, { name = "qwen-vl-utils", marker = "extra == 'qwen-vl-utils-dep'", specifier = ">=0.0.11,<0.1.0" },
{ name = "reachy2-sdk", marker = "extra == 'reachy2'", specifier = ">=1.0.15,<1.1.0" }, { name = "reachy2-sdk", marker = "extra == 'reachy2'", specifier = ">=1.0.15,<1.1.0" },
{ name = "rerun-sdk", marker = "extra == 'viz'", specifier = ">=0.24.0,<0.27.0" }, { name = "rerun-sdk", marker = "extra == 'viz'", specifier = ">=0.24.0,<0.27.0" },
{ name = "safetensors", marker = "extra == 'groot'", specifier = ">=0.4.3,<1.0.0" }, { name = "safetensors", specifier = ">=0.4.3,<1.0.0" },
{ name = "safetensors", marker = "extra == 'smolvla'", specifier = ">=0.4.3,<1.0.0" },
{ name = "scikit-image", marker = "extra == 'video-benchmark'", specifier = ">=0.23.2,<0.26.0" }, { name = "scikit-image", marker = "extra == 'video-benchmark'", specifier = ">=0.23.2,<0.26.0" },
{ name = "scipy", marker = "extra == 'all'", specifier = ">=1.14.0,<2.0.0" }, { name = "scipy", marker = "extra == 'all'", specifier = ">=1.14.0,<2.0.0" },
{ name = "scipy", marker = "extra == 'scipy-dep'", specifier = ">=1.14.0,<2.0.0" }, { name = "scipy", marker = "extra == 'scipy-dep'", specifier = ">=1.14.0,<2.0.0" },
@@ -2687,6 +2685,7 @@ requires-dist = [
{ name = "torchcodec", marker = "(platform_machine != 'aarch64' and platform_machine != 'arm64' and platform_machine != 'armv7l' and sys_platform == 'linux' and extra == 'dataset') or (platform_machine != 'x86_64' and sys_platform == 'darwin' and extra == 'dataset') or (sys_platform != 'darwin' and sys_platform != 'linux' and sys_platform != 'win32' and extra == 'dataset')", specifier = ">=0.3.0,<0.11.0" }, { name = "torchcodec", marker = "(platform_machine != 'aarch64' and platform_machine != 'arm64' and platform_machine != 'armv7l' and sys_platform == 'linux' and extra == 'dataset') or (platform_machine != 'x86_64' and sys_platform == 'darwin' and extra == 'dataset') or (sys_platform != 'darwin' and sys_platform != 'linux' and sys_platform != 'win32' and extra == 'dataset')", specifier = ">=0.3.0,<0.11.0" },
{ name = "torchdiffeq", marker = "extra == 'wallx'", specifier = ">=0.2.4,<0.3.0" }, { name = "torchdiffeq", marker = "extra == 'wallx'", specifier = ">=0.2.4,<0.3.0" },
{ name = "torchvision", specifier = ">=0.22.0,<0.26.0" }, { name = "torchvision", specifier = ">=0.22.0,<0.26.0" },
{ name = "tqdm", specifier = ">=4.66.0,<5.0.0" },
{ name = "transformers", marker = "extra == 'transformers-dep'", specifier = "==5.3.0" }, { name = "transformers", marker = "extra == 'transformers-dep'", specifier = "==5.3.0" },
{ name = "wandb", marker = "extra == 'training'", specifier = ">=0.24.0,<0.25.0" }, { name = "wandb", marker = "extra == 'training'", specifier = ">=0.24.0,<0.25.0" },
] ]