mirror of
https://github.com/huggingface/lerobot.git
synced 2026-05-14 08:09:45 +00:00
feat(dependencies): minimal default tag install (#3362)
This commit is contained in:
@@ -22,7 +22,9 @@ from unittest.mock import patch
|
||||
|
||||
import pytest
|
||||
|
||||
from lerobot.rl.process import ProcessSignalHandler
|
||||
pytest.importorskip("grpc")
|
||||
|
||||
from lerobot.rl.process import ProcessSignalHandler # noqa: E402
|
||||
|
||||
|
||||
# Fixture to reset shutdown_event_counter and original signal handlers before and after each test
|
||||
|
||||
@@ -18,12 +18,16 @@ import sys
|
||||
from collections.abc import Callable
|
||||
|
||||
import pytest
|
||||
import torch
|
||||
|
||||
from lerobot.datasets.lerobot_dataset import LeRobotDataset
|
||||
from lerobot.rl.buffer import BatchTransition, ReplayBuffer, random_crop_vectorized
|
||||
from lerobot.utils.constants import ACTION, DONE, OBS_IMAGE, OBS_STATE, OBS_STR, REWARD
|
||||
from tests.fixtures.constants import DUMMY_REPO_ID
|
||||
pytest.importorskip("grpc")
|
||||
pytest.importorskip("datasets", reason="datasets is required (install lerobot[dataset])")
|
||||
|
||||
import torch # noqa: E402
|
||||
|
||||
from lerobot.datasets.lerobot_dataset import LeRobotDataset # noqa: E402
|
||||
from lerobot.rl.buffer import BatchTransition, ReplayBuffer, random_crop_vectorized # noqa: E402
|
||||
from lerobot.utils.constants import ACTION, DONE, OBS_IMAGE, OBS_STATE, OBS_STR, REWARD # noqa: E402
|
||||
from tests.fixtures.constants import DUMMY_REPO_ID # noqa: E402
|
||||
|
||||
|
||||
def state_dims() -> list[str]:
|
||||
|
||||
@@ -17,6 +17,16 @@
|
||||
from pathlib import Path
|
||||
from unittest.mock import Mock, patch
|
||||
|
||||
from lerobot.common.train_utils import (
|
||||
get_step_checkpoint_dir,
|
||||
get_step_identifier,
|
||||
load_training_state,
|
||||
load_training_step,
|
||||
save_checkpoint,
|
||||
save_training_state,
|
||||
save_training_step,
|
||||
update_last_checkpoint,
|
||||
)
|
||||
from lerobot.utils.constants import (
|
||||
CHECKPOINTS_DIR,
|
||||
LAST_CHECKPOINT_LINK,
|
||||
@@ -27,16 +37,6 @@ from lerobot.utils.constants import (
|
||||
TRAINING_STATE_DIR,
|
||||
TRAINING_STEP,
|
||||
)
|
||||
from lerobot.utils.train_utils import (
|
||||
get_step_checkpoint_dir,
|
||||
get_step_identifier,
|
||||
load_training_state,
|
||||
load_training_step,
|
||||
save_checkpoint,
|
||||
save_training_state,
|
||||
save_training_step,
|
||||
update_last_checkpoint,
|
||||
)
|
||||
|
||||
|
||||
def test_get_step_identifier():
|
||||
@@ -72,7 +72,7 @@ def test_update_last_checkpoint(tmp_path):
|
||||
assert last_checkpoint.resolve() == checkpoint
|
||||
|
||||
|
||||
@patch("lerobot.utils.train_utils.save_training_state")
|
||||
@patch("lerobot.common.train_utils.save_training_state")
|
||||
def test_save_checkpoint(mock_save_training_state, tmp_path, optimizer):
|
||||
policy = Mock()
|
||||
cfg = Mock()
|
||||
@@ -82,7 +82,7 @@ def test_save_checkpoint(mock_save_training_state, tmp_path, optimizer):
|
||||
mock_save_training_state.assert_called_once()
|
||||
|
||||
|
||||
@patch("lerobot.utils.train_utils.save_training_state")
|
||||
@patch("lerobot.common.train_utils.save_training_state")
|
||||
def test_save_checkpoint_peft(mock_save_training_state, tmp_path, optimizer):
|
||||
policy = Mock()
|
||||
policy.config = Mock()
|
||||
|
||||
@@ -21,6 +21,8 @@ from types import SimpleNamespace
|
||||
import numpy as np
|
||||
import pytest
|
||||
|
||||
pytest.importorskip("rerun", reason="rerun-sdk is required (install lerobot[viz])")
|
||||
|
||||
from lerobot.types import TransitionKey
|
||||
from lerobot.utils.constants import OBS_STATE
|
||||
|
||||
@@ -48,6 +50,9 @@ def mock_rerun(monkeypatch):
|
||||
calls.append((key, obj, kwargs))
|
||||
|
||||
dummy_rr = SimpleNamespace(
|
||||
__name__="rerun",
|
||||
__package__="rerun",
|
||||
__spec__=SimpleNamespace(name="rerun", submodule_search_locations=None),
|
||||
Scalars=DummyScalar,
|
||||
Image=DummyImage,
|
||||
log=dummy_log,
|
||||
|
||||
Reference in New Issue
Block a user