mirror of
https://github.com/huggingface/lerobot.git
synced 2026-05-21 03:30:10 +00:00
[pre-commit.ci] auto fixes from pre-commit.com hooks
for more information, see https://pre-commit.ci
This commit is contained in:
@@ -49,7 +49,6 @@ from lerobot.common.datasets.utils import (
|
|||||||
embed_images,
|
embed_images,
|
||||||
flatten_dict,
|
flatten_dict,
|
||||||
get_delta_indices,
|
get_delta_indices,
|
||||||
get_features_from_robot,
|
|
||||||
get_hf_dataset_size_in_mb,
|
get_hf_dataset_size_in_mb,
|
||||||
get_hf_features_from_features,
|
get_hf_features_from_features,
|
||||||
get_parquet_file_size_in_mb,
|
get_parquet_file_size_in_mb,
|
||||||
@@ -80,7 +79,6 @@ from lerobot.common.datasets.video_utils import (
|
|||||||
get_safe_default_codec,
|
get_safe_default_codec,
|
||||||
get_video_info,
|
get_video_info,
|
||||||
)
|
)
|
||||||
from lerobot.common.robots.utils import Robot
|
|
||||||
|
|
||||||
CODEBASE_VERSION = "v3.0"
|
CODEBASE_VERSION = "v3.0"
|
||||||
|
|
||||||
|
|||||||
@@ -840,7 +840,7 @@ def validate_frame(frame: dict, features: dict):
|
|||||||
|
|
||||||
# Remove task from actual_features for regular feature validation
|
# Remove task from actual_features for regular feature validation
|
||||||
actual_features_for_validation = actual_features - {"task"}
|
actual_features_for_validation = actual_features - {"task"}
|
||||||
|
|
||||||
error_message = validate_features_presence(actual_features_for_validation, expected_features)
|
error_message = validate_features_presence(actual_features_for_validation, expected_features)
|
||||||
|
|
||||||
common_features = actual_features_for_validation & expected_features
|
common_features = actual_features_for_validation & expected_features
|
||||||
|
|||||||
@@ -34,15 +34,15 @@ from lerobot.common.datasets.lerobot_dataset import (
|
|||||||
)
|
)
|
||||||
from lerobot.common.datasets.utils import (
|
from lerobot.common.datasets.utils import (
|
||||||
create_branch,
|
create_branch,
|
||||||
|
hw_to_dataset_features,
|
||||||
)
|
)
|
||||||
from lerobot.common.envs.factory import make_env_config
|
from lerobot.common.envs.factory import make_env_config
|
||||||
from lerobot.common.policies.factory import make_policy_config
|
from lerobot.common.policies.factory import make_policy_config
|
||||||
from lerobot.common.robots import make_robot_from_config
|
from lerobot.common.robots import make_robot_from_config
|
||||||
from lerobot.common.datasets.utils import hw_to_dataset_features
|
|
||||||
from tests.mocks.mock_robot import MockRobotConfig
|
|
||||||
from lerobot.configs.default import DatasetConfig
|
from lerobot.configs.default import DatasetConfig
|
||||||
from lerobot.configs.train import TrainPipelineConfig
|
from lerobot.configs.train import TrainPipelineConfig
|
||||||
from tests.fixtures.constants import DUMMY_CHW, DUMMY_HWC, DUMMY_REPO_ID
|
from tests.fixtures.constants import DUMMY_CHW, DUMMY_HWC, DUMMY_REPO_ID
|
||||||
|
from tests.mocks.mock_robot import MockRobotConfig
|
||||||
from tests.utils import require_x86_64_kernel
|
from tests.utils import require_x86_64_kernel
|
||||||
|
|
||||||
|
|
||||||
@@ -73,7 +73,9 @@ def test_same_attributes_defined(tmp_path, lerobot_dataset_factory):
|
|||||||
obs_features = hw_to_dataset_features(robot.observation_features, "observation", True)
|
obs_features = hw_to_dataset_features(robot.observation_features, "observation", True)
|
||||||
dataset_features = {**action_features, **obs_features}
|
dataset_features = {**action_features, **obs_features}
|
||||||
root_create = tmp_path / "create"
|
root_create = tmp_path / "create"
|
||||||
dataset_create = LeRobotDataset.create(repo_id=DUMMY_REPO_ID, fps=30, features=dataset_features, root=root_create)
|
dataset_create = LeRobotDataset.create(
|
||||||
|
repo_id=DUMMY_REPO_ID, fps=30, features=dataset_features, root=root_create
|
||||||
|
)
|
||||||
|
|
||||||
root_init = tmp_path / "init"
|
root_init = tmp_path / "init"
|
||||||
dataset_init = lerobot_dataset_factory(root=root_init, total_episodes=1, total_frames=1)
|
dataset_init = lerobot_dataset_factory(root=root_init, total_episodes=1, total_frames=1)
|
||||||
|
|||||||
Vendored
+6
-2
@@ -474,7 +474,9 @@ def lerobot_dataset_metadata_factory(
|
|||||||
)
|
)
|
||||||
with (
|
with (
|
||||||
patch("lerobot.common.datasets.lerobot_dataset.get_safe_version") as mock_get_safe_version_patch,
|
patch("lerobot.common.datasets.lerobot_dataset.get_safe_version") as mock_get_safe_version_patch,
|
||||||
patch("lerobot.common.datasets.lerobot_dataset.snapshot_download") as mock_snapshot_download_patch,
|
patch(
|
||||||
|
"lerobot.common.datasets.lerobot_dataset.snapshot_download"
|
||||||
|
) as mock_snapshot_download_patch,
|
||||||
):
|
):
|
||||||
mock_get_safe_version_patch.side_effect = lambda repo_id, version: version
|
mock_get_safe_version_patch.side_effect = lambda repo_id, version: version
|
||||||
mock_snapshot_download_patch.side_effect = mock_snapshot_download
|
mock_snapshot_download_patch.side_effect = mock_snapshot_download
|
||||||
@@ -556,7 +558,9 @@ def lerobot_dataset_factory(
|
|||||||
with (
|
with (
|
||||||
patch("lerobot.common.datasets.lerobot_dataset.LeRobotDatasetMetadata") as mock_metadata_patch,
|
patch("lerobot.common.datasets.lerobot_dataset.LeRobotDatasetMetadata") as mock_metadata_patch,
|
||||||
patch("lerobot.common.datasets.lerobot_dataset.get_safe_version") as mock_get_safe_version_patch,
|
patch("lerobot.common.datasets.lerobot_dataset.get_safe_version") as mock_get_safe_version_patch,
|
||||||
patch("lerobot.common.datasets.lerobot_dataset.snapshot_download") as mock_snapshot_download_patch,
|
patch(
|
||||||
|
"lerobot.common.datasets.lerobot_dataset.snapshot_download"
|
||||||
|
) as mock_snapshot_download_patch,
|
||||||
):
|
):
|
||||||
mock_metadata_patch.return_value = mock_metadata
|
mock_metadata_patch.return_value = mock_metadata
|
||||||
mock_get_safe_version_patch.side_effect = lambda repo_id, version: version
|
mock_get_safe_version_patch.side_effect = lambda repo_id, version: version
|
||||||
|
|||||||
Reference in New Issue
Block a user