refactor: minor improvements

This commit is contained in:
Steven Palma
2026-04-10 18:31:07 +02:00
parent 882a6b0965
commit 4c39981908
6 changed files with 46 additions and 19 deletions
+7 -1
View File
@@ -80,9 +80,15 @@ def get_safe_default_codec():
return "pyav"
_require_package_cache: dict[str, bool] = {}
def require_package(pkg_name: str, extra: str, import_name: str | None = None) -> None:
"""Raise an informative ImportError if a package required by an optional feature is missing."""
if not is_package_available(pkg_name, import_name):
cache_key = import_name or pkg_name
if cache_key not in _require_package_cache:
_require_package_cache[cache_key] = is_package_available(pkg_name, import_name)
if not _require_package_cache[cache_key]:
raise ImportError(
f"'{pkg_name}' is required but not installed. Install it with: pip install 'lerobot[{extra}]'"
)
+3
View File
@@ -58,6 +58,9 @@ def write_video(video_path: str | Path, stacked_frames: list, fps: int) -> None:
stacked_frames: List of HWC uint8 numpy arrays (RGB).
fps: Frames per second for the output video.
"""
from lerobot.utils.import_utils import require_package
require_package("av", extra="av-dep")
import av
with av.open(str(video_path), mode="w") as container:
+13 -4
View File
@@ -287,14 +287,23 @@ class SuppressProgressBars:
"""
def __enter__(self):
from datasets.utils.logging import disable_progress_bar
try:
from datasets.utils.logging import disable_progress_bar
disable_progress_bar()
disable_progress_bar()
except ImportError:
logging.getLogger(__name__).info(
"SuppressProgressBars is a no-op because 'datasets' is not installed. "
"Install it with: pip install 'lerobot[dataset]'"
)
def __exit__(self, exc_type, exc_val, exc_tb):
from datasets.utils.logging import enable_progress_bar
try:
from datasets.utils.logging import enable_progress_bar
enable_progress_bar()
enable_progress_bar()
except ImportError:
pass
class TimerManager:
+14 -7
View File
@@ -17,15 +17,10 @@ import os
import numpy as np
from lerobot.types import RobotAction, RobotObservation
from lerobot.utils.import_utils import require_package
require_package("rerun-sdk", extra="viz", import_name="rerun")
import rerun as rr # noqa: E402
from lerobot.types import RobotAction, RobotObservation # noqa: E402
from .constants import ACTION, ACTION_PREFIX, OBS_PREFIX, OBS_STR # noqa: E402
from .constants import ACTION, ACTION_PREFIX, OBS_PREFIX, OBS_STR
def init_rerun(
@@ -39,6 +34,10 @@ def init_rerun(
ip: Optional IP for connecting to a Rerun server.
port: Optional port for connecting to a Rerun server.
"""
require_package("rerun-sdk", extra="viz", import_name="rerun")
import rerun as rr
batch_size = os.getenv("RERUN_FLUSH_NUM_BYTES", "8000")
os.environ["RERUN_FLUSH_NUM_BYTES"] = batch_size
rr.init(session_name)
@@ -51,6 +50,10 @@ def init_rerun(
def shutdown_rerun() -> None:
"""Shuts down the Rerun SDK gracefully."""
require_package("rerun-sdk", extra="viz", import_name="rerun")
import rerun as rr
rr.rerun_shutdown()
@@ -83,6 +86,10 @@ def log_rerun_data(
action: An optional dictionary containing action data to log.
compress_images: Whether to compress images before logging to save bandwidth & memory in exchange for cpu and quality.
"""
require_package("rerun-sdk", extra="viz", import_name="rerun")
import rerun as rr
if observation:
for k, v in observation.items():
if v is None: