This commit is contained in:
Jade Choghari
2025-12-01 13:47:15 +01:00
parent d22fa6446b
commit 8d861fe94b
2 changed files with 34 additions and 29 deletions
+1
View File
@@ -124,6 +124,7 @@ lerobot-edit-dataset \
``` ```
**Parameters:** **Parameters:**
- `output_dir`: Directory where videos will be saved (default: `outputs/converted_videos`) - `output_dir`: Directory where videos will be saved (default: `outputs/converted_videos`)
- `vcodec`: Video codec to use - options: `h264`, `hevc`, `libsvtav1` (default: `libsvtav1`) - `vcodec`: Video codec to use - options: `h264`, `hevc`, `libsvtav1` (default: `libsvtav1`)
- `pix_fmt`: Pixel format - options: `yuv420p`, `yuv444p` (default: `yuv420p`) - `pix_fmt`: Pixel format - options: `yuv420p`, `yuv444p` (default: `yuv420p`)
+33 -29
View File
@@ -18,7 +18,7 @@
Edit LeRobot datasets using various transformation tools. Edit LeRobot datasets using various transformation tools.
This script allows you to delete episodes, split datasets, merge datasets, This script allows you to delete episodes, split datasets, merge datasets,
remove features, and convert image datasets to video format. remove features, and convert image datasets to video format.
When new_repo_id is specified, creates a new dataset. When new_repo_id is specified, creates a new dataset.
Usage Examples: Usage Examples:
@@ -293,7 +293,7 @@ def save_episode_images(
num_workers: int = 4, num_workers: int = 4,
) -> None: ) -> None:
"""Save images from a specific episode to disk. """Save images from a specific episode to disk.
Args: Args:
dataset: The LeRobot dataset to extract images from dataset: The LeRobot dataset to extract images from
imgs_dir: Directory to save images to imgs_dir: Directory to save images to
@@ -302,35 +302,35 @@ def save_episode_images(
num_workers: Number of threads for parallel image saving (default: 4) num_workers: Number of threads for parallel image saving (default: 4)
""" """
ep_num_images = dataset.meta.episodes["length"][episode_index] ep_num_images = dataset.meta.episodes["length"][episode_index]
# Check if images already exist # Check if images already exist
if not overwrite and imgs_dir.exists() and len(list(imgs_dir.glob("frame-*.png"))) == ep_num_images: if not overwrite and imgs_dir.exists() and len(list(imgs_dir.glob("frame-*.png"))) == ep_num_images:
logging.info(f"Images for episode {episode_index} already exist in {imgs_dir}. Skipping.") logging.info(f"Images for episode {episode_index} already exist in {imgs_dir}. Skipping.")
return return
# Create directory # Create directory
imgs_dir.mkdir(parents=True, exist_ok=True) imgs_dir.mkdir(parents=True, exist_ok=True)
# Get dataset without torch format for PIL image access # Get dataset without torch format for PIL image access
hf_dataset = dataset.hf_dataset.with_format(None) hf_dataset = dataset.hf_dataset.with_format(None)
# Get all image keys (for all cameras) # Get all image keys (for all cameras)
img_keys = [key for key in hf_dataset.features if key.startswith(OBS_IMAGE)] img_keys = [key for key in hf_dataset.features if key.startswith(OBS_IMAGE)]
if len(img_keys) == 0: if len(img_keys) == 0:
raise ValueError(f"No image keys found in dataset {dataset.repo_id}") raise ValueError(f"No image keys found in dataset {dataset.repo_id}")
# Use first camera only # Use first camera only
img_key = img_keys[0] img_key = img_keys[0]
imgs_dataset = hf_dataset.select_columns(img_key) imgs_dataset = hf_dataset.select_columns(img_key)
# Get episode start and end indices # Get episode start and end indices
from_idx = dataset.meta.episodes["dataset_from_index"][episode_index] from_idx = dataset.meta.episodes["dataset_from_index"][episode_index]
to_idx = dataset.meta.episodes["dataset_to_index"][episode_index] to_idx = dataset.meta.episodes["dataset_to_index"][episode_index]
# Get all items for this episode # Get all items for this episode
episode_dataset = imgs_dataset.select(range(from_idx, to_idx)) episode_dataset = imgs_dataset.select(range(from_idx, to_idx))
# Define function to save a single image # Define function to save a single image
def save_single_image(i_item_tuple): def save_single_image(i_item_tuple):
i, item = i_item_tuple i, item = i_item_tuple
@@ -338,11 +338,11 @@ def save_episode_images(
# Use frame-XXXXXX.png format to match encode_video_frames expectations # Use frame-XXXXXX.png format to match encode_video_frames expectations
img.save(str(imgs_dir / f"frame-{i:06d}.png"), quality=100) img.save(str(imgs_dir / f"frame-{i:06d}.png"), quality=100)
return i return i
# Save images with proper naming convention for encode_video_frames (frame-XXXXXX.png) # Save images with proper naming convention for encode_video_frames (frame-XXXXXX.png)
# Use ThreadPoolExecutor for parallel processing # Use ThreadPoolExecutor for parallel processing
items = list(enumerate(episode_dataset)) items = list(enumerate(episode_dataset))
with ThreadPoolExecutor(max_workers=num_workers) as executor: with ThreadPoolExecutor(max_workers=num_workers) as executor:
futures = [executor.submit(save_single_image, item) for item in items] futures = [executor.submit(save_single_image, item) for item in items]
for future in tqdm( for future in tqdm(
@@ -368,7 +368,7 @@ def process_single_episode(
overwrite: bool, overwrite: bool,
) -> str: ) -> str:
"""Process a single episode: save images and encode to video. """Process a single episode: save images and encode to video.
Args: Args:
dataset: The LeRobot dataset dataset: The LeRobot dataset
episode_index: Index of the episode to process episode_index: Index of the episode to process
@@ -381,24 +381,26 @@ def process_single_episode(
fps: Frames per second fps: Frames per second
num_image_workers: Number of threads for parallel image saving num_image_workers: Number of threads for parallel image saving
overwrite: Whether to overwrite existing files overwrite: Whether to overwrite existing files
Returns: Returns:
Status message for this episode Status message for this episode
""" """
# Create paths # Create paths
imgs_dir = output_dir / "images" / dataset.repo_id.replace("/", "_") / f"episode_{episode_index:06d}" imgs_dir = output_dir / "images" / dataset.repo_id.replace("/", "_") / f"episode_{episode_index:06d}"
# Create video filename with encoding parameters # Create video filename with encoding parameters
video_filename = f"{dataset.repo_id.replace('/', '_')}_ep{episode_index:06d}_{vcodec}_{pix_fmt}_g{g}_crf{crf}.mp4" video_filename = (
f"{dataset.repo_id.replace('/', '_')}_ep{episode_index:06d}_{vcodec}_{pix_fmt}_g{g}_crf{crf}.mp4"
)
video_path = output_dir / "videos" / dataset.repo_id.replace("/", "_") / video_filename video_path = output_dir / "videos" / dataset.repo_id.replace("/", "_") / video_filename
# Save episode images # Save episode images
save_episode_images(dataset, imgs_dir, episode_index, overwrite, num_image_workers) save_episode_images(dataset, imgs_dir, episode_index, overwrite, num_image_workers)
# Encode to video # Encode to video
if overwrite or not video_path.is_file(): if overwrite or not video_path.is_file():
video_path.parent.mkdir(parents=True, exist_ok=True) video_path.parent.mkdir(parents=True, exist_ok=True)
encode_video_frames( encode_video_frames(
imgs_dir=imgs_dir, imgs_dir=imgs_dir,
video_path=video_path, video_path=video_path,
@@ -410,7 +412,7 @@ def process_single_episode(
fast_decode=fast_decode, fast_decode=fast_decode,
overwrite=True, overwrite=True,
) )
return f"✓ Video saved to {video_path}" return f"✓ Video saved to {video_path}"
else: else:
return f"Video already exists: {video_path}. Skipping." return f"Video already exists: {video_path}. Skipping."
@@ -429,7 +431,7 @@ def convert_dataset_to_videos(
overwrite: bool = False, overwrite: bool = False,
) -> None: ) -> None:
"""Convert dataset images to video files. """Convert dataset images to video files.
Args: Args:
dataset: The LeRobot dataset dataset: The LeRobot dataset
output_dir: Base directory for outputs output_dir: Base directory for outputs
@@ -447,16 +449,18 @@ def convert_dataset_to_videos(
raise ValueError( raise ValueError(
f"This operation is for image datasets only. Video dataset provided: {dataset.repo_id}" f"This operation is for image datasets only. Video dataset provided: {dataset.repo_id}"
) )
fps = dataset.fps fps = dataset.fps
# Determine which episodes to process # Determine which episodes to process
num_episodes = len(dataset.meta.episodes) num_episodes = len(dataset.meta.episodes)
if episode_indices is None: if episode_indices is None:
episode_indices = list(range(num_episodes)) episode_indices = list(range(num_episodes))
logging.info(f"Processing {len(episode_indices)} episodes from {dataset.repo_id} with {num_workers} workers") logging.info(
f"Processing {len(episode_indices)} episodes from {dataset.repo_id} with {num_workers} workers"
)
# Process episodes in parallel # Process episodes in parallel
with ThreadPoolExecutor(max_workers=num_workers) as executor: with ThreadPoolExecutor(max_workers=num_workers) as executor:
futures = [ futures = [
@@ -476,7 +480,7 @@ def convert_dataset_to_videos(
) )
for episode_index in episode_indices for episode_index in episode_indices
] ]
for future in tqdm( for future in tqdm(
as_completed(futures), as_completed(futures),
total=len(episode_indices), total=len(episode_indices),
@@ -484,7 +488,7 @@ def convert_dataset_to_videos(
): ):
result = future.result() # This will raise any exceptions that occurred result = future.result() # This will raise any exceptions that occurred
logging.info(result) logging.info(result)
logging.info(f"\n✓ Completed processing {dataset.repo_id}") logging.info(f"\n✓ Completed processing {dataset.repo_id}")