mirror of
https://github.com/Tavish9/any4lerobot.git
synced 2026-05-11 12:09:41 +00:00
Compare commits
2 Commits
8aa7343137
...
2ef2370d66
| Author | SHA1 | Date | |
|---|---|---|---|
| 2ef2370d66 | |||
| 723bd71cf2 |
@@ -32,6 +32,11 @@ import pyarrow.parquet as pq
|
||||
import tqdm
|
||||
from datasets import Dataset
|
||||
from huggingface_hub import snapshot_download
|
||||
from lerobot.datasets.io_utils import (
|
||||
load_info,
|
||||
load_tasks,
|
||||
write_info,
|
||||
)
|
||||
from lerobot.datasets.utils import (
|
||||
DEFAULT_CHUNK_SIZE,
|
||||
DEFAULT_DATA_PATH,
|
||||
@@ -40,11 +45,8 @@ from lerobot.datasets.utils import (
|
||||
LEGACY_EPISODES_PATH,
|
||||
LEGACY_EPISODES_STATS_PATH,
|
||||
LEGACY_TASKS_PATH,
|
||||
load_info,
|
||||
load_tasks,
|
||||
serialize_dict,
|
||||
unflatten_dict,
|
||||
write_info,
|
||||
)
|
||||
from lerobot.utils.constants import HF_LEROBOT_HOME
|
||||
from lerobot.utils.utils import init_logging
|
||||
@@ -52,8 +54,12 @@ from lerobot.utils.utils import init_logging
|
||||
V21 = "v2.1"
|
||||
V30 = "v3.0"
|
||||
|
||||
LEGACY_DATA_PATH_TEMPLATE = "data/chunk-{episode_chunk:03d}/episode_{episode_index:06d}.parquet"
|
||||
LEGACY_VIDEO_PATH_TEMPLATE = "videos/chunk-{episode_chunk:03d}/{video_key}/episode_{episode_index:06d}.mp4"
|
||||
LEGACY_DATA_PATH_TEMPLATE = (
|
||||
"data/chunk-{episode_chunk:03d}/episode_{episode_index:06d}.parquet"
|
||||
)
|
||||
LEGACY_VIDEO_PATH_TEMPLATE = (
|
||||
"videos/chunk-{episode_chunk:03d}/{video_key}/episode_{episode_index:06d}.mp4"
|
||||
)
|
||||
MIN_VIDEO_DURATION = 1e-6
|
||||
LEGACY_STATS_KEYS = ("mean", "std", "min", "max", "count")
|
||||
|
||||
@@ -137,7 +143,9 @@ def convert_info(
|
||||
if ft.get("dtype") != "video":
|
||||
ft.pop("fps", None)
|
||||
|
||||
info["total_chunks"] = math.ceil(total_episodes / chunks_size) if total_episodes > 0 else 0
|
||||
info["total_chunks"] = (
|
||||
math.ceil(total_episodes / chunks_size) if total_episodes > 0 else 0
|
||||
)
|
||||
info["total_videos"] = total_episodes * len(video_keys)
|
||||
|
||||
write_info(info, new_root)
|
||||
@@ -156,14 +164,22 @@ def _group_episodes_by_data_file(
|
||||
return grouped
|
||||
|
||||
|
||||
def convert_data(root: Path, new_root: Path, episode_records: list[dict[str, Any]]) -> None:
|
||||
def convert_data(
|
||||
root: Path, new_root: Path, episode_records: list[dict[str, Any]]
|
||||
) -> None:
|
||||
logging.info("Converting consolidated parquet files back to per-episode files")
|
||||
grouped = _group_episodes_by_data_file(episode_records)
|
||||
|
||||
for (chunk_idx, file_idx), records in tqdm.tqdm(grouped.items(), desc="convert data files"):
|
||||
source_path = root / DEFAULT_DATA_PATH.format(chunk_index=chunk_idx, file_index=file_idx)
|
||||
for (chunk_idx, file_idx), records in tqdm.tqdm(
|
||||
grouped.items(), desc="convert data files"
|
||||
):
|
||||
source_path = root / DEFAULT_DATA_PATH.format(
|
||||
chunk_index=chunk_idx, file_index=file_idx
|
||||
)
|
||||
if not source_path.exists():
|
||||
raise FileNotFoundError(f"Expected source parquet file not found: {source_path}")
|
||||
raise FileNotFoundError(
|
||||
f"Expected source parquet file not found: {source_path}"
|
||||
)
|
||||
|
||||
table = pq.read_table(source_path)
|
||||
records = sorted(records, key=lambda rec: int(rec["dataset_from_index"]))
|
||||
@@ -181,7 +197,7 @@ def convert_data(root: Path, new_root: Path, episode_records: list[dict[str, Any
|
||||
f"episode_index={episode_index}, length={length}"
|
||||
)
|
||||
|
||||
episode_table = table.slice(start, length).to_pandas()
|
||||
episode_table = table.slice(start, length)
|
||||
|
||||
dest_chunk = episode_index // DEFAULT_CHUNK_SIZE
|
||||
dest_path = new_root / LEGACY_DATA_PATH_TEMPLATE.format(
|
||||
@@ -189,7 +205,7 @@ def convert_data(root: Path, new_root: Path, episode_records: list[dict[str, Any
|
||||
episode_index=episode_index,
|
||||
)
|
||||
dest_path.parent.mkdir(parents=True, exist_ok=True)
|
||||
Dataset.from_pandas(episode_table).to_parquet(dest_path)
|
||||
Dataset(episode_table).to_parquet(dest_path)
|
||||
|
||||
|
||||
def _group_episodes_by_video_file(
|
||||
@@ -235,10 +251,14 @@ def _validate_video_paths(src: Path, dst: Path) -> None:
|
||||
# Validate file extensions for video files
|
||||
valid_video_extensions = {".mp4", ".avi", ".mov", ".mkv", ".webm", ".m4v"}
|
||||
if src_resolved.suffix.lower() not in valid_video_extensions:
|
||||
raise ValueError(f"Source file does not have a valid video extension: {src_resolved}")
|
||||
raise ValueError(
|
||||
f"Source file does not have a valid video extension: {src_resolved}"
|
||||
)
|
||||
|
||||
if dst_resolved.suffix.lower() not in valid_video_extensions:
|
||||
raise ValueError(f"Destination file does not have a valid video extension: {dst_resolved}")
|
||||
raise ValueError(
|
||||
f"Destination file does not have a valid video extension: {dst_resolved}"
|
||||
)
|
||||
|
||||
# Check for path traversal attempts in the original paths
|
||||
src_str = str(src)
|
||||
@@ -253,11 +273,16 @@ def _validate_video_paths(src: Path, dst: Path) -> None:
|
||||
|
||||
# Additional check: ensure resolved paths don't point to system directories
|
||||
system_dirs = {"/etc", "/sys", "/proc", "/dev", "/boot", "/root"}
|
||||
for resolved_path, name in [(src_resolved, "source"), (dst_resolved, "destination")]:
|
||||
for resolved_path, name in [
|
||||
(src_resolved, "source"),
|
||||
(dst_resolved, "destination"),
|
||||
]:
|
||||
path_str = str(resolved_path)
|
||||
for sys_dir in system_dirs:
|
||||
if path_str.startswith(sys_dir + "/") or path_str == sys_dir:
|
||||
raise ValueError(f"Path points to system directory: {name} path {resolved_path}")
|
||||
raise ValueError(
|
||||
f"Path points to system directory: {name} path {resolved_path}"
|
||||
)
|
||||
|
||||
# Ensure the destination directory can be created safely
|
||||
try:
|
||||
@@ -324,9 +349,13 @@ def _extract_video_segment(
|
||||
text=True,
|
||||
)
|
||||
except subprocess.TimeoutExpired as exc:
|
||||
raise RuntimeError(f"ffmpeg timed out while processing video '{src}' -> '{dst}'") from exc
|
||||
raise RuntimeError(
|
||||
f"ffmpeg timed out while processing video '{src}' -> '{dst}'"
|
||||
) from exc
|
||||
except FileNotFoundError as exc:
|
||||
raise RuntimeError("ffmpeg executable not found; it is required for video conversion") from exc
|
||||
raise RuntimeError(
|
||||
"ffmpeg executable not found; it is required for video conversion"
|
||||
) from exc
|
||||
except subprocess.CalledProcessError as exc:
|
||||
error_msg = f"ffmpeg failed while splitting video '{src}' into '{dst}'"
|
||||
if exc.stderr:
|
||||
@@ -334,7 +363,12 @@ def _extract_video_segment(
|
||||
raise RuntimeError(error_msg) from exc
|
||||
|
||||
|
||||
def convert_videos(root: Path, new_root: Path, episode_records: list[dict[str, Any]], video_keys: list[str]) -> None:
|
||||
def convert_videos(
|
||||
root: Path,
|
||||
new_root: Path,
|
||||
episode_records: list[dict[str, Any]],
|
||||
video_keys: list[str],
|
||||
) -> None:
|
||||
if len(video_keys) == 0:
|
||||
logging.info("No video features detected; skipping video conversion")
|
||||
return
|
||||
@@ -347,7 +381,9 @@ def convert_videos(root: Path, new_root: Path, episode_records: list[dict[str, A
|
||||
logging.info("No video metadata found for key '%s'; skipping", video_key)
|
||||
continue
|
||||
|
||||
for (chunk_idx, file_idx), records in tqdm.tqdm(grouped.items(), desc=f"convert videos ({video_key})"):
|
||||
for (chunk_idx, file_idx), records in tqdm.tqdm(
|
||||
grouped.items(), desc=f"convert videos ({video_key})"
|
||||
):
|
||||
src_path = root / DEFAULT_VIDEO_PATH.format(
|
||||
video_key=video_key,
|
||||
chunk_index=chunk_idx,
|
||||
@@ -356,7 +392,10 @@ def convert_videos(root: Path, new_root: Path, episode_records: list[dict[str, A
|
||||
if not src_path.exists():
|
||||
raise FileNotFoundError(f"Expected MP4 file not found: {src_path}")
|
||||
|
||||
records = sorted(records, key=lambda rec: float(rec[f"videos/{video_key}/from_timestamp"]))
|
||||
records = sorted(
|
||||
records,
|
||||
key=lambda rec: float(rec[f"videos/{video_key}/from_timestamp"]),
|
||||
)
|
||||
|
||||
for record in records:
|
||||
episode_index = int(record["episode_index"])
|
||||
@@ -373,7 +412,9 @@ def convert_videos(root: Path, new_root: Path, episode_records: list[dict[str, A
|
||||
_extract_video_segment(src_path, dest_path, start=start, end=end)
|
||||
|
||||
|
||||
def convert_episodes_metadata(new_root: Path, episode_records: list[dict[str, Any]]) -> None:
|
||||
def convert_episodes_metadata(
|
||||
new_root: Path, episode_records: list[dict[str, Any]]
|
||||
) -> None:
|
||||
logging.info("Reconstructing legacy episodes and episodes_stats JSONL files")
|
||||
|
||||
episodes_path = new_root / LEGACY_EPISODES_PATH
|
||||
@@ -396,7 +437,9 @@ def convert_episodes_metadata(new_root: Path, episode_records: list[dict[str, An
|
||||
jsonlines.open(episodes_path, mode="w") as episodes_writer,
|
||||
jsonlines.open(stats_path, mode="w") as stats_writer,
|
||||
):
|
||||
for record in sorted(episode_records, key=lambda rec: int(rec["episode_index"])):
|
||||
for record in sorted(
|
||||
episode_records, key=lambda rec: int(rec["episode_index"])
|
||||
):
|
||||
legacy_episode = {
|
||||
key: value
|
||||
for key, value in record.items()
|
||||
@@ -407,10 +450,14 @@ def convert_episodes_metadata(new_root: Path, episode_records: list[dict[str, An
|
||||
and key not in {"dataset_from_index", "dataset_to_index"}
|
||||
}
|
||||
|
||||
serializable_episode = {key: _to_serializable(value) for key, value in legacy_episode.items()}
|
||||
serializable_episode = {
|
||||
key: _to_serializable(value) for key, value in legacy_episode.items()
|
||||
}
|
||||
episodes_writer.write(serializable_episode)
|
||||
|
||||
stats_flat = {key: record[key] for key in record if key.startswith("stats/")}
|
||||
stats_flat = {
|
||||
key: record[key] for key in record if key.startswith("stats/")
|
||||
}
|
||||
stats_nested = unflatten_dict(stats_flat).get("stats", {})
|
||||
stats_serialized = serialize_dict(_filter_stats(stats_nested))
|
||||
stats_writer.write(
|
||||
@@ -453,7 +500,11 @@ def convert_dataset(
|
||||
new_root.mkdir(parents=True, exist_ok=True)
|
||||
|
||||
episode_records = load_episode_records(root)
|
||||
video_keys = [key for key, ft in load_info(root)["features"].items() if ft.get("dtype") == "video"]
|
||||
video_keys = [
|
||||
key
|
||||
for key, ft in load_info(root)["features"].items()
|
||||
if ft.get("dtype") == "video"
|
||||
]
|
||||
|
||||
convert_info(root, new_root, episode_records, video_keys)
|
||||
convert_tasks(root, new_root)
|
||||
|
||||
Reference in New Issue
Block a user