mirror of
https://github.com/huggingface/lerobot.git
synced 2026-05-11 14:49:43 +00:00
Compare commits
7 Commits
| Author | SHA1 | Date | |
|---|---|---|---|
| 34454748f4 | |||
| 8e5763c5ab | |||
| 388d4518ba | |||
| 232dbe4176 | |||
| 10c2e2fc87 | |||
| 5e74f06b20 | |||
| 07931b1101 |
@@ -486,8 +486,8 @@ class LeRobotDataset(torch.utils.data.Dataset):
|
||||
self.episode_data_index = get_episode_data_index(self.meta.episodes, self.episodes)
|
||||
|
||||
# Check timestamps
|
||||
timestamps = torch.stack(self.hf_dataset["timestamp"]).numpy()
|
||||
episode_indices = torch.stack(self.hf_dataset["episode_index"]).numpy()
|
||||
timestamps = torch.tensor(self.hf_dataset["timestamp"]).numpy()
|
||||
episode_indices = torch.tensor(self.hf_dataset["episode_index"]).numpy()
|
||||
ep_data_index_np = {k: t.numpy() for k, t in self.episode_data_index.items()}
|
||||
check_timestamps_sync(timestamps, episode_indices, ep_data_index_np, self.fps, self.tolerance_s)
|
||||
|
||||
@@ -667,7 +667,7 @@ class LeRobotDataset(torch.utils.data.Dataset):
|
||||
for key in self.meta.video_keys:
|
||||
if query_indices is not None and key in query_indices:
|
||||
timestamps = self.hf_dataset.select(query_indices[key])["timestamp"]
|
||||
query_timestamps[key] = torch.stack(timestamps).tolist()
|
||||
query_timestamps[key] = torch.tensor(timestamps).tolist()
|
||||
else:
|
||||
query_timestamps[key] = [current_ts]
|
||||
|
||||
@@ -675,7 +675,7 @@ class LeRobotDataset(torch.utils.data.Dataset):
|
||||
|
||||
def _query_hf_dataset(self, query_indices: dict[str, list[int]]) -> dict:
|
||||
return {
|
||||
key: torch.stack(self.hf_dataset.select(q_idx)[key])
|
||||
key: torch.tensor(self.hf_dataset.select(q_idx)[key])
|
||||
for key, q_idx in query_indices.items()
|
||||
if key not in self.meta.video_keys
|
||||
}
|
||||
|
||||
@@ -632,7 +632,7 @@ def cycle(iterable):
|
||||
iterator = iter(iterable)
|
||||
|
||||
|
||||
def create_branch(repo_id, *, branch: str, repo_type: str | None = None) -> None:
|
||||
def create_branch(repo_id, *, branch: str, repo_type: str | None = None, revision: str | None = None) -> None:
|
||||
"""Create a branch on a existing Hugging Face repo. Delete the branch if it already
|
||||
exists before creating it.
|
||||
"""
|
||||
@@ -644,7 +644,7 @@ def create_branch(repo_id, *, branch: str, repo_type: str | None = None) -> None
|
||||
if ref in refs:
|
||||
api.delete_branch(repo_id, repo_type=repo_type, branch=branch)
|
||||
|
||||
api.create_branch(repo_id, repo_type=repo_type, branch=branch)
|
||||
api.create_branch(repo_id, repo_type=repo_type, branch=branch, revision=revision)
|
||||
|
||||
|
||||
def create_lerobot_dataset_card(
|
||||
|
||||
@@ -105,6 +105,7 @@ import filecmp
|
||||
import json
|
||||
import logging
|
||||
import math
|
||||
import re
|
||||
import shutil
|
||||
import subprocess
|
||||
import tempfile
|
||||
@@ -119,6 +120,7 @@ from huggingface_hub import HfApi
|
||||
from huggingface_hub.errors import EntryNotFoundError, HfHubHTTPError
|
||||
from safetensors.torch import load_file
|
||||
|
||||
from lerobot.datasets.backward_compatibility import CompatibilityError
|
||||
from lerobot.datasets.utils import (
|
||||
DEFAULT_CHUNK_SIZE,
|
||||
DEFAULT_PARQUET_PATH,
|
||||
@@ -130,6 +132,7 @@ from lerobot.datasets.utils import (
|
||||
create_branch,
|
||||
create_lerobot_dataset_card,
|
||||
flatten_dict,
|
||||
get_repo_versions,
|
||||
get_safe_version,
|
||||
load_json,
|
||||
unflatten_dict,
|
||||
@@ -205,7 +208,7 @@ def convert_stats_to_json(v1_dir: Path, v2_dir: Path) -> None:
|
||||
def get_features_from_hf_dataset(
|
||||
dataset: Dataset, robot_config: RobotConfig | None = None
|
||||
) -> dict[str, list]:
|
||||
robot_config = parse_robot_config(robot_config)
|
||||
robot_config = parse_robot_config(robot_config) if robot_config else None
|
||||
features = {}
|
||||
for key, ft in dataset.features.items():
|
||||
if isinstance(ft, datasets.Value):
|
||||
@@ -325,7 +328,19 @@ def move_videos(
|
||||
video_files = [str(f.relative_to(work_dir)) for f in work_dir.glob("videos*/*/*/*.mp4")]
|
||||
videos_moved = True # Videos have already been moved
|
||||
|
||||
assert len(video_files) == total_episodes * len(video_keys)
|
||||
expected_count = total_episodes * len(video_keys)
|
||||
if len(video_files) != expected_count:
|
||||
print(
|
||||
f"Warning: expected {expected_count} video files "
|
||||
f"({total_episodes} episodes x {len(video_keys)} keys), "
|
||||
f"found {len(video_files)}. Keeping only videos matching existing episodes."
|
||||
)
|
||||
episode_pattern = re.compile(r"episode_(\d+)")
|
||||
valid_episodes = set(range(total_episodes))
|
||||
video_files = [
|
||||
f for f in video_files
|
||||
if (m := episode_pattern.search(f)) and int(m.group(1)) in valid_episodes
|
||||
]
|
||||
|
||||
lfs_untracked_videos = _get_lfs_untracked_videos(work_dir, video_files)
|
||||
|
||||
@@ -442,8 +457,16 @@ def convert_dataset(
|
||||
test_branch: str | None = None,
|
||||
**card_kwargs,
|
||||
):
|
||||
v1 = get_safe_version(repo_id, V16)
|
||||
v1x_dir = local_dir / V16 / repo_id
|
||||
try:
|
||||
v1 = get_safe_version(repo_id, V16)
|
||||
except CompatibilityError:
|
||||
hub_versions = get_repo_versions(repo_id)
|
||||
v1x_versions = [v for v in hub_versions if v.major == 1]
|
||||
if not v1x_versions:
|
||||
raise
|
||||
v1 = f"v{max(v1x_versions)}"
|
||||
logging.warning(f"v1.6 not found for {repo_id}, falling back to {v1}")
|
||||
v1x_dir = local_dir / v1 / repo_id
|
||||
v20_dir = local_dir / V20 / repo_id
|
||||
v1x_dir.mkdir(parents=True, exist_ok=True)
|
||||
v20_dir.mkdir(parents=True, exist_ok=True)
|
||||
@@ -455,7 +478,7 @@ def convert_dataset(
|
||||
branch = "main"
|
||||
if test_branch:
|
||||
branch = test_branch
|
||||
create_branch(repo_id=repo_id, branch=test_branch, repo_type="dataset")
|
||||
create_branch(repo_id=repo_id, branch=test_branch, repo_type="dataset", revision=v1)
|
||||
|
||||
metadata_v1 = load_json(v1x_dir / V1_INFO_PATH)
|
||||
dataset = datasets.load_dataset("parquet", data_dir=v1x_dir / "data", split="train")
|
||||
@@ -564,6 +587,12 @@ def convert_dataset(
|
||||
"features": features,
|
||||
}
|
||||
write_json(metadata_v2_0, v20_dir / INFO_PATH)
|
||||
|
||||
info = load_json(v20_dir / INFO_PATH)
|
||||
if "language_instruction" in info.get("features", {}):
|
||||
del info["features"]["language_instruction"]
|
||||
write_json(info, v20_dir / INFO_PATH)
|
||||
|
||||
convert_stats_to_json(v1x_dir, v20_dir)
|
||||
card = create_lerobot_dataset_card(tags=repo_tags, dataset_info=metadata_v2_0, **card_kwargs)
|
||||
|
||||
@@ -677,6 +706,8 @@ def main():
|
||||
|
||||
if args.robot is not None:
|
||||
robot_config = make_robot_config(args.robot)
|
||||
else:
|
||||
robot_config = None
|
||||
|
||||
del args.robot
|
||||
|
||||
|
||||
@@ -85,7 +85,7 @@ def convert_dataset(
|
||||
path_in_repo=STATS_PATH, repo_id=dataset.repo_id, revision=branch, repo_type="dataset"
|
||||
)
|
||||
|
||||
hub_api.create_tag(repo_id, tag=CODEBASE_VERSION, revision=branch, repo_type="dataset")
|
||||
#hub_api.create_tag(repo_id, tag=CODEBASE_VERSION, revision=branch, repo_type="dataset")
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
|
||||
@@ -45,6 +45,8 @@ def convert_episode_stats(dataset: LeRobotDataset, ep_idx: int):
|
||||
|
||||
axes_to_reduce = (0, 2, 3) if ft["dtype"] in ["image", "video"] else 0
|
||||
keepdims = True if ft["dtype"] in ["image", "video"] else ep_ft_data.ndim == 1
|
||||
if ft["dtype"] in ["image", "video"] and ep_ft_data.ndim == 3:
|
||||
ep_ft_data = np.expand_dims(ep_ft_data, axis=0)
|
||||
ep_stats[key] = get_feature_stats(ep_ft_data, axis=axes_to_reduce, keepdims=keepdims)
|
||||
|
||||
if ft["dtype"] in ["image", "video"]: # remove batch dim
|
||||
|
||||
Reference in New Issue
Block a user