mirror of
https://github.com/huggingface/lerobot.git
synced 2026-05-26 22:20:06 +00:00
refactor: minor improvements
This commit is contained in:
@@ -80,9 +80,15 @@ def get_safe_default_codec():
|
||||
return "pyav"
|
||||
|
||||
|
||||
_require_package_cache: dict[str, bool] = {}
|
||||
|
||||
|
||||
def require_package(pkg_name: str, extra: str, import_name: str | None = None) -> None:
|
||||
"""Raise an informative ImportError if a package required by an optional feature is missing."""
|
||||
if not is_package_available(pkg_name, import_name):
|
||||
cache_key = import_name or pkg_name
|
||||
if cache_key not in _require_package_cache:
|
||||
_require_package_cache[cache_key] = is_package_available(pkg_name, import_name)
|
||||
if not _require_package_cache[cache_key]:
|
||||
raise ImportError(
|
||||
f"'{pkg_name}' is required but not installed. Install it with: pip install 'lerobot[{extra}]'"
|
||||
)
|
||||
|
||||
@@ -58,6 +58,9 @@ def write_video(video_path: str | Path, stacked_frames: list, fps: int) -> None:
|
||||
stacked_frames: List of HWC uint8 numpy arrays (RGB).
|
||||
fps: Frames per second for the output video.
|
||||
"""
|
||||
from lerobot.utils.import_utils import require_package
|
||||
|
||||
require_package("av", extra="av-dep")
|
||||
import av
|
||||
|
||||
with av.open(str(video_path), mode="w") as container:
|
||||
|
||||
@@ -287,14 +287,23 @@ class SuppressProgressBars:
|
||||
"""
|
||||
|
||||
def __enter__(self):
|
||||
from datasets.utils.logging import disable_progress_bar
|
||||
try:
|
||||
from datasets.utils.logging import disable_progress_bar
|
||||
|
||||
disable_progress_bar()
|
||||
disable_progress_bar()
|
||||
except ImportError:
|
||||
logging.getLogger(__name__).info(
|
||||
"SuppressProgressBars is a no-op because 'datasets' is not installed. "
|
||||
"Install it with: pip install 'lerobot[dataset]'"
|
||||
)
|
||||
|
||||
def __exit__(self, exc_type, exc_val, exc_tb):
|
||||
from datasets.utils.logging import enable_progress_bar
|
||||
try:
|
||||
from datasets.utils.logging import enable_progress_bar
|
||||
|
||||
enable_progress_bar()
|
||||
enable_progress_bar()
|
||||
except ImportError:
|
||||
pass
|
||||
|
||||
|
||||
class TimerManager:
|
||||
|
||||
@@ -17,15 +17,10 @@ import os
|
||||
|
||||
import numpy as np
|
||||
|
||||
from lerobot.types import RobotAction, RobotObservation
|
||||
from lerobot.utils.import_utils import require_package
|
||||
|
||||
require_package("rerun-sdk", extra="viz", import_name="rerun")
|
||||
|
||||
import rerun as rr # noqa: E402
|
||||
|
||||
from lerobot.types import RobotAction, RobotObservation # noqa: E402
|
||||
|
||||
from .constants import ACTION, ACTION_PREFIX, OBS_PREFIX, OBS_STR # noqa: E402
|
||||
from .constants import ACTION, ACTION_PREFIX, OBS_PREFIX, OBS_STR
|
||||
|
||||
|
||||
def init_rerun(
|
||||
@@ -39,6 +34,10 @@ def init_rerun(
|
||||
ip: Optional IP for connecting to a Rerun server.
|
||||
port: Optional port for connecting to a Rerun server.
|
||||
"""
|
||||
|
||||
require_package("rerun-sdk", extra="viz", import_name="rerun")
|
||||
import rerun as rr
|
||||
|
||||
batch_size = os.getenv("RERUN_FLUSH_NUM_BYTES", "8000")
|
||||
os.environ["RERUN_FLUSH_NUM_BYTES"] = batch_size
|
||||
rr.init(session_name)
|
||||
@@ -51,6 +50,10 @@ def init_rerun(
|
||||
|
||||
def shutdown_rerun() -> None:
|
||||
"""Shuts down the Rerun SDK gracefully."""
|
||||
|
||||
require_package("rerun-sdk", extra="viz", import_name="rerun")
|
||||
import rerun as rr
|
||||
|
||||
rr.rerun_shutdown()
|
||||
|
||||
|
||||
@@ -83,6 +86,10 @@ def log_rerun_data(
|
||||
action: An optional dictionary containing action data to log.
|
||||
compress_images: Whether to compress images before logging to save bandwidth & memory in exchange for cpu and quality.
|
||||
"""
|
||||
|
||||
require_package("rerun-sdk", extra="viz", import_name="rerun")
|
||||
import rerun as rr
|
||||
|
||||
if observation:
|
||||
for k, v in observation.items():
|
||||
if v is None:
|
||||
|
||||
Reference in New Issue
Block a user