Mirror of https://github.com/huggingface/lerobot.git
Synced 2026-05-14 16:19:45 +00:00

Compare commits
7 Commits

| SHA1 |
|---|
| 34454748f4 |
| 8e5763c5ab |
| 388d4518ba |
| 232dbe4176 |
| 10c2e2fc87 |
| 5e74f06b20 |
| 07931b1101 |
@@ -486,8 +486,8 @@ class LeRobotDataset(torch.utils.data.Dataset):
         self.episode_data_index = get_episode_data_index(self.meta.episodes, self.episodes)

         # Check timestamps
-        timestamps = torch.stack(self.hf_dataset["timestamp"]).numpy()
-        episode_indices = torch.stack(self.hf_dataset["episode_index"]).numpy()
+        timestamps = torch.tensor(self.hf_dataset["timestamp"]).numpy()
+        episode_indices = torch.tensor(self.hf_dataset["episode_index"]).numpy()
         ep_data_index_np = {k: t.numpy() for k, t in self.episode_data_index.items()}
         check_timestamps_sync(timestamps, episode_indices, ep_data_index_np, self.fps, self.tolerance_s)

@@ -667,7 +667,7 @@ class LeRobotDataset(torch.utils.data.Dataset):
         for key in self.meta.video_keys:
             if query_indices is not None and key in query_indices:
                 timestamps = self.hf_dataset.select(query_indices[key])["timestamp"]
-                query_timestamps[key] = torch.stack(timestamps).tolist()
+                query_timestamps[key] = torch.tensor(timestamps).tolist()
             else:
                 query_timestamps[key] = [current_ts]

@@ -675,7 +675,7 @@ class LeRobotDataset(torch.utils.data.Dataset):

     def _query_hf_dataset(self, query_indices: dict[str, list[int]]) -> dict:
         return {
-            key: torch.stack(self.hf_dataset.select(q_idx)[key])
+            key: torch.tensor(self.hf_dataset.select(q_idx)[key])
             for key, q_idx in query_indices.items()
             if key not in self.meta.video_keys
         }
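The three hunks above make the same change: `torch.stack` becomes `torch.tensor` when materializing a column of the underlying `datasets` table. A plausible reading (not stated in the commits) is that without `.with_format("torch")` the column arrives as a plain Python list of scalars, which `torch.stack` rejects since it expects a sequence of tensors. A minimal sketch of the difference:

```python
import torch

# A column read from an HF dataset without torch formatting comes back as
# a plain Python list of scalars (values here are illustrative).
timestamps = [0.0, 0.033, 0.066]

# torch.stack expects a sequence of tensors and raises TypeError on floats:
# torch.stack(timestamps)  # TypeError

# torch.tensor accepts the raw list directly.
ts = torch.tensor(timestamps)
print(ts.numpy())  # [0.    0.033 0.066]
```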
@@ -632,7 +632,7 @@ def cycle(iterable):
     iterator = iter(iterable)


-def create_branch(repo_id, *, branch: str, repo_type: str | None = None) -> None:
+def create_branch(repo_id, *, branch: str, repo_type: str | None = None, revision: str | None = None) -> None:
     """Create a branch on a existing Hugging Face repo. Delete the branch if it already
     exists before creating it.
     """
@@ -644,7 +644,7 @@ def create_branch(repo_id, *, branch: str, repo_type: str | None = None) -> None
     if ref in refs:
         api.delete_branch(repo_id, repo_type=repo_type, branch=branch)

-    api.create_branch(repo_id, repo_type=repo_type, branch=branch)
+    api.create_branch(repo_id, repo_type=repo_type, branch=branch, revision=revision)


 def create_lerobot_dataset_card(
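The new `revision` parameter is forwarded to `huggingface_hub`'s `HfApi.create_branch`, which accepts a `revision` argument naming the commit, tag, or branch the new branch should start from instead of the current head. A usage sketch; the repo id and tag are placeholders:

```python
from huggingface_hub import HfApi

api = HfApi()
# Start the branch from the "v1.6" tag rather than the head of main, so a
# test branch reproduces the exact v1.x state of the dataset.
api.create_branch(
    "user/my_dataset",  # placeholder repo id
    repo_type="dataset",
    branch="test_branch",
    revision="v1.6",  # placeholder tag
)
```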
@@ -105,6 +105,7 @@ import filecmp
 import json
 import logging
 import math
+import re
 import shutil
 import subprocess
 import tempfile
@@ -119,6 +120,7 @@ from huggingface_hub import HfApi
 from huggingface_hub.errors import EntryNotFoundError, HfHubHTTPError
 from safetensors.torch import load_file

+from lerobot.datasets.backward_compatibility import CompatibilityError
 from lerobot.datasets.utils import (
     DEFAULT_CHUNK_SIZE,
     DEFAULT_PARQUET_PATH,
@@ -130,6 +132,7 @@ from lerobot.datasets.utils import (
     create_branch,
     create_lerobot_dataset_card,
     flatten_dict,
+    get_repo_versions,
     get_safe_version,
     load_json,
     unflatten_dict,
@@ -205,7 +208,7 @@ def convert_stats_to_json(v1_dir: Path, v2_dir: Path) -> None:
 def get_features_from_hf_dataset(
     dataset: Dataset, robot_config: RobotConfig | None = None
 ) -> dict[str, list]:
-    robot_config = parse_robot_config(robot_config)
+    robot_config = parse_robot_config(robot_config) if robot_config else None
     features = {}
     for key, ft in dataset.features.items():
         if isinstance(ft, datasets.Value):
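The guard lets the conversion run without a robot config: `parse_robot_config` is now only called when one was supplied, and `robot_config` is otherwise left as `None`. A self-contained sketch of the pattern, with stand-ins for lerobot's types (the real parser's behavior on `None` is my assumption):

```python
from dataclasses import dataclass

@dataclass
class RobotConfig:  # stand-in for lerobot's RobotConfig
    name: str

def parse_robot_config(cfg: RobotConfig) -> dict:
    # Stand-in parser: reads attributes off the config object, which is
    # why an unconditional call with None would raise AttributeError.
    return {"robot": cfg.name}

cfg = None
parsed = parse_robot_config(cfg) if cfg else None
print(parsed)  # None — no crash when no robot config is provided
```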
@@ -325,7 +328,19 @@ def move_videos(
     video_files = [str(f.relative_to(work_dir)) for f in work_dir.glob("videos*/*/*/*.mp4")]
     videos_moved = True  # Videos have already been moved

-    assert len(video_files) == total_episodes * len(video_keys)
+    expected_count = total_episodes * len(video_keys)
+    if len(video_files) != expected_count:
+        print(
+            f"Warning: expected {expected_count} video files "
+            f"({total_episodes} episodes x {len(video_keys)} keys), "
+            f"found {len(video_files)}. Keeping only videos matching existing episodes."
+        )
+        episode_pattern = re.compile(r"episode_(\d+)")
+        valid_episodes = set(range(total_episodes))
+        video_files = [
+            f for f in video_files
+            if (m := episode_pattern.search(f)) and int(m.group(1)) in valid_episodes
+        ]

     lfs_untracked_videos = _get_lfs_untracked_videos(work_dir, video_files)

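Instead of asserting an exact file count, the new code warns and keeps only videos whose `episode_(\d+)` index falls in the expected range, using an assignment expression (`:=`) inside the list comprehension. A runnable sketch with hypothetical paths (the real layout comes from `work_dir.glob("videos*/*/*/*.mp4")`):

```python
import re

video_files = [
    "videos/chunk-000/observation.images.top/episode_000000.mp4",
    "videos/chunk-000/observation.images.top/episode_000001.mp4",
    "videos/chunk-000/observation.images.top/episode_000099.mp4",  # stray extra file
]
total_episodes = 2

episode_pattern = re.compile(r"episode_(\d+)")
valid_episodes = set(range(total_episodes))
video_files = [
    f for f in video_files
    if (m := episode_pattern.search(f)) and int(m.group(1)) in valid_episodes
]
print(video_files)  # episode_000099 is dropped; episodes 0 and 1 remain
```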
@@ -442,8 +457,16 @@ def convert_dataset(
     test_branch: str | None = None,
     **card_kwargs,
 ):
-    v1 = get_safe_version(repo_id, V16)
-    v1x_dir = local_dir / V16 / repo_id
+    try:
+        v1 = get_safe_version(repo_id, V16)
+    except CompatibilityError:
+        hub_versions = get_repo_versions(repo_id)
+        v1x_versions = [v for v in hub_versions if v.major == 1]
+        if not v1x_versions:
+            raise
+        v1 = f"v{max(v1x_versions)}"
+        logging.warning(f"v1.6 not found for {repo_id}, falling back to {v1}")
+    v1x_dir = local_dir / v1 / repo_id
     v20_dir = local_dir / V20 / repo_id
     v1x_dir.mkdir(parents=True, exist_ok=True)
     v20_dir.mkdir(parents=True, exist_ok=True)
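The fallback assumes `get_repo_versions` returns `packaging.version.Version` objects, which expose `.major`, order correctly under `max`, and stringify without the leading `v`. Just the selection logic, with hypothetical tags:

```python
from packaging.version import Version

hub_versions = [Version("1.4"), Version("1.5"), Version("2.0")]  # hypothetical tags

v1x_versions = [v for v in hub_versions if v.major == 1]
if v1x_versions:
    v1 = f"v{max(v1x_versions)}"
    print(v1)  # v1.5 — the newest v1.x tag becomes the conversion source
```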
@@ -455,7 +478,7 @@ def convert_dataset(
     branch = "main"
     if test_branch:
         branch = test_branch
-        create_branch(repo_id=repo_id, branch=test_branch, repo_type="dataset")
+        create_branch(repo_id=repo_id, branch=test_branch, repo_type="dataset", revision=v1)

     metadata_v1 = load_json(v1x_dir / V1_INFO_PATH)
     dataset = datasets.load_dataset("parquet", data_dir=v1x_dir / "data", split="train")
@@ -564,6 +587,12 @@ def convert_dataset(
         "features": features,
     }
     write_json(metadata_v2_0, v20_dir / INFO_PATH)
+
+    info = load_json(v20_dir / INFO_PATH)
+    if "language_instruction" in info.get("features", {}):
+        del info["features"]["language_instruction"]
+    write_json(info, v20_dir / INFO_PATH)
+
     convert_stats_to_json(v1x_dir, v20_dir)
     card = create_lerobot_dataset_card(tags=repo_tags, dataset_info=metadata_v2_0, **card_kwargs)
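The cleanup re-reads the info file that was just written and strips a deprecated `language_instruction` feature; `info.get("features", {})` keeps the deletion safe when the block is absent. An equivalent standalone round trip (file name simplified):

```python
import json
from pathlib import Path

info_path = Path("info.json")  # stands in for v20_dir / INFO_PATH
info_path.write_text(json.dumps({"features": {"language_instruction": {}, "timestamp": {}}}))

info = json.loads(info_path.read_text())
if "language_instruction" in info.get("features", {}):
    del info["features"]["language_instruction"]
info_path.write_text(json.dumps(info, indent=4))

print(list(json.loads(info_path.read_text())["features"]))  # ['timestamp']
```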
@@ -677,6 +706,8 @@ def main():

     if args.robot is not None:
         robot_config = make_robot_config(args.robot)
+    else:
+        robot_config = None

     del args.robot

@@ -85,7 +85,7 @@ def convert_dataset(
         path_in_repo=STATS_PATH, repo_id=dataset.repo_id, revision=branch, repo_type="dataset"
     )

-    hub_api.create_tag(repo_id, tag=CODEBASE_VERSION, revision=branch, repo_type="dataset")
+    # hub_api.create_tag(repo_id, tag=CODEBASE_VERSION, revision=branch, repo_type="dataset")


 if __name__ == "__main__":
@@ -45,6 +45,8 @@ def convert_episode_stats(dataset: LeRobotDataset, ep_idx: int):

         axes_to_reduce = (0, 2, 3) if ft["dtype"] in ["image", "video"] else 0
         keepdims = True if ft["dtype"] in ["image", "video"] else ep_ft_data.ndim == 1
+        if ft["dtype"] in ["image", "video"] and ep_ft_data.ndim == 3:
+            ep_ft_data = np.expand_dims(ep_ft_data, axis=0)
         ep_stats[key] = get_feature_stats(ep_ft_data, axis=axes_to_reduce, keepdims=keepdims)

         if ft["dtype"] in ["image", "video"]:  # remove batch dim
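The inserted check promotes a single frame of shape (C, H, W) to a one-frame batch (1, C, H, W) so that reducing over axes (0, 2, 3) is valid; the later "remove batch dim" branch strips it again. A shape-only sketch:

```python
import numpy as np

ep_ft_data = np.random.rand(3, 96, 128)  # one RGB frame: (C, H, W)

if ep_ft_data.ndim == 3:
    ep_ft_data = np.expand_dims(ep_ft_data, axis=0)  # -> (1, 3, 96, 128)

# Per-channel reduction over batch, height, and width, keeping dims as the
# stats code does for image/video features.
mean = ep_ft_data.mean(axis=(0, 2, 3), keepdims=True)
print(mean.shape)  # (1, 3, 1, 1)
```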