mirror of
https://github.com/huggingface/lerobot.git
synced 2026-05-20 11:09:59 +00:00
Merge branch 'main' into feat/audio_dataset
This commit is contained in:
Vendored
+7
-6
@@ -26,7 +26,10 @@ import pytest
|
||||
import torch
|
||||
from datasets import Dataset
|
||||
|
||||
from lerobot.datasets.lerobot_dataset import CODEBASE_VERSION, LeRobotDataset, LeRobotDatasetMetadata
|
||||
from lerobot.datasets.dataset_metadata import CODEBASE_VERSION, LeRobotDatasetMetadata
|
||||
from lerobot.datasets.feature_utils import get_hf_features_from_features
|
||||
from lerobot.datasets.io_utils import hf_transform_to_torch
|
||||
from lerobot.datasets.lerobot_dataset import LeRobotDataset
|
||||
from lerobot.datasets.utils import (
|
||||
DEFAULT_AUDIO_PATH,
|
||||
DEFAULT_CHUNK_SIZE,
|
||||
@@ -36,8 +39,6 @@ from lerobot.datasets.utils import (
|
||||
DEFAULT_VIDEO_FILE_SIZE_IN_MB,
|
||||
DEFAULT_VIDEO_PATH,
|
||||
flatten_dict,
|
||||
get_hf_features_from_features,
|
||||
hf_transform_to_torch,
|
||||
)
|
||||
from lerobot.datasets.video_utils import encode_video_frames
|
||||
from tests.fixtures.constants import (
|
||||
@@ -239,7 +240,7 @@ def tasks_factory():
|
||||
def _create_tasks(total_tasks: int = 3) -> pd.DataFrame:
|
||||
ids = list(range(total_tasks))
|
||||
tasks = [f"Perform action {i}." for i in ids]
|
||||
df = pd.DataFrame({"task_index": ids}, index=tasks)
|
||||
df = pd.DataFrame({"task_index": ids}, index=pd.Index(tasks, name="task"))
|
||||
return df
|
||||
|
||||
return _create_tasks
|
||||
@@ -470,8 +471,8 @@ def lerobot_dataset_metadata_factory(
|
||||
episodes=episodes,
|
||||
)
|
||||
with (
|
||||
patch("lerobot.datasets.lerobot_dataset.get_safe_version") as mock_get_safe_version_patch,
|
||||
patch("lerobot.datasets.lerobot_dataset.snapshot_download") as mock_snapshot_download_patch,
|
||||
patch("lerobot.datasets.dataset_metadata.get_safe_version") as mock_get_safe_version_patch,
|
||||
patch("lerobot.datasets.dataset_metadata.snapshot_download") as mock_snapshot_download_patch,
|
||||
):
|
||||
mock_get_safe_version_patch.side_effect = lambda repo_id, version: version
|
||||
mock_snapshot_download_patch.side_effect = mock_snapshot_download
|
||||
|
||||
Vendored
+7
-5
@@ -20,17 +20,19 @@ import pandas as pd
|
||||
import pytest
|
||||
from datasets import Dataset
|
||||
|
||||
from lerobot.datasets.utils import (
|
||||
DEFAULT_CHUNK_SIZE,
|
||||
DEFAULT_DATA_FILE_SIZE_IN_MB,
|
||||
DEFAULT_DATA_PATH,
|
||||
from lerobot.datasets.io_utils import (
|
||||
get_hf_dataset_size_in_mb,
|
||||
update_chunk_file_indices,
|
||||
write_episodes,
|
||||
write_info,
|
||||
write_stats,
|
||||
write_tasks,
|
||||
)
|
||||
from lerobot.datasets.utils import (
|
||||
DEFAULT_CHUNK_SIZE,
|
||||
DEFAULT_DATA_FILE_SIZE_IN_MB,
|
||||
DEFAULT_DATA_PATH,
|
||||
update_chunk_file_indices,
|
||||
)
|
||||
|
||||
|
||||
def write_hf_dataset(
|
||||
|
||||
Reference in New Issue
Block a user