mirror of
https://github.com/huggingface/lerobot.git
synced 2026-05-15 16:49:55 +00:00
feat(dependencies): minimal default tag install (#3362)
This commit is contained in:
@@ -17,6 +17,16 @@
|
||||
from pathlib import Path
|
||||
from unittest.mock import Mock, patch
|
||||
|
||||
from lerobot.common.train_utils import (
|
||||
get_step_checkpoint_dir,
|
||||
get_step_identifier,
|
||||
load_training_state,
|
||||
load_training_step,
|
||||
save_checkpoint,
|
||||
save_training_state,
|
||||
save_training_step,
|
||||
update_last_checkpoint,
|
||||
)
|
||||
from lerobot.utils.constants import (
|
||||
CHECKPOINTS_DIR,
|
||||
LAST_CHECKPOINT_LINK,
|
||||
@@ -27,16 +37,6 @@ from lerobot.utils.constants import (
|
||||
TRAINING_STATE_DIR,
|
||||
TRAINING_STEP,
|
||||
)
|
||||
from lerobot.utils.train_utils import (
|
||||
get_step_checkpoint_dir,
|
||||
get_step_identifier,
|
||||
load_training_state,
|
||||
load_training_step,
|
||||
save_checkpoint,
|
||||
save_training_state,
|
||||
save_training_step,
|
||||
update_last_checkpoint,
|
||||
)
|
||||
|
||||
|
||||
def test_get_step_identifier():
|
||||
@@ -72,7 +72,7 @@ def test_update_last_checkpoint(tmp_path):
|
||||
assert last_checkpoint.resolve() == checkpoint
|
||||
|
||||
|
||||
@patch("lerobot.utils.train_utils.save_training_state")
|
||||
@patch("lerobot.common.train_utils.save_training_state")
|
||||
def test_save_checkpoint(mock_save_training_state, tmp_path, optimizer):
|
||||
policy = Mock()
|
||||
cfg = Mock()
|
||||
@@ -82,7 +82,7 @@ def test_save_checkpoint(mock_save_training_state, tmp_path, optimizer):
|
||||
mock_save_training_state.assert_called_once()
|
||||
|
||||
|
||||
@patch("lerobot.utils.train_utils.save_training_state")
|
||||
@patch("lerobot.common.train_utils.save_training_state")
|
||||
def test_save_checkpoint_peft(mock_save_training_state, tmp_path, optimizer):
|
||||
policy = Mock()
|
||||
policy.config = Mock()
|
||||
|
||||
Reference in New Issue
Block a user