From 242f00a87637661a15f663a722c4d37decbd2910 Mon Sep 17 00:00:00 2001 From: Yushun Xiang <73413365+YushunXiang@users.noreply.github.com> Date: Wed, 21 May 2025 14:43:56 +0800 Subject: [PATCH] add functions to merge images with validation. (#27) * add functions to copy images with validation and check the video frame number. * update README to clarify image file handling and maintain metadata structure * Update utils/dataset_merging/merge_lerobot_dataset.py [feat] accelerate the episode file finding with the ` episode_file_mapping`. * Refactor image and video handling in dataset merging tool with enhanced validation * [fix] If `images` directory exists, then use`early_validation` * fix: Import `encode_video_frames` for video encoding in early validation * fix: Add image copying option and update README for PNG support * fix: Remove unused cv2 import from dataset merging utility --------- Co-authored-by: Qizhi Chen Co-authored-by: zhipeng tang <2444198418@qq.com> --- utils/dataset_merging/README.md | 14 +- .../dataset_merging/merge_lerobot_dataset.py | 333 ++++++++++++++++++ 2 files changed, 345 insertions(+), 2 deletions(-) diff --git a/utils/dataset_merging/README.md b/utils/dataset_merging/README.md index 8af7ced..5316546 100644 --- a/utils/dataset_merging/README.md +++ b/utils/dataset_merging/README.md @@ -8,8 +8,11 @@ merge_lerobot_dataset是一个功能强大的Python脚本,专门用于合并 2. **索引重编号**:重新编号所有的episode索引和任务索引,确保合并后数据的连续性和一致性。 3. **向量维度填充**:自动检测并填充向量维度,使所有数据在observation.state和action等方面具有一致性。 4. **统计信息合并**:智能合并多个数据集的统计信息,正确处理复杂的数据结构,如图像特征的嵌套结构。 -5. **视频文件处理**:正确处理和复制视频文件,保持视频与其他数据之间的正确索引关系,支持多种视频存储结构。 -6. **元数据更新**:更新所有元数据文件,准确反映合并后的数据集结构。 +5. **图像视频验证**:如果有图片 `images` 文件夹的话,自动检测视频 / 图像 / 元数据文件之间的数量关系。 +6. **视频文件处理**:正确处理和复制视频文件,保持视频与其他数据之间的正确索引关系,支持多种视频存储结构。 +7. **元数据更新**:更新所有元数据文件,准确反映合并后的数据集结构。 +8. 
**图像文件处理**:复制图像文件,保持图片与其他数据之间的正确索引关系。 + ## 三、安装 本脚本依赖于以下Python库: @@ -32,6 +35,7 @@ python dataset_merger.py --sources /path/to/dataset1 /path/to/dataset2 /path/to/ - **--output**:输出数据集文件夹路径,用于指定合并后数据集的存储位置。 - **--max_dim**:向量的最大维度,默认值为32。 - **--fps**:数据集的帧率,默认值为20。 +- **--copy_images**: 是否将图像从源文件夹复制且合并到输出文件夹。(default: `False`) ### (三)示例 ```bash @@ -51,6 +55,9 @@ dataset/ ├── data/ # 包含parquet格式的episode数据 │ └── chunk-xxx/ │ └── episode_xxxxxx.parquet +├── images/ # 可选的图片文件 +│ └── episode_xxxxxx/ +│ └── frame_xxxxxx.png └── videos/ # 可选的视频文件 └── chunk-xxx/ └── video_key/ @@ -63,11 +70,14 @@ dataset/ 3. **统计信息合并**:智能合并多个数据集的统计数据,能够正确处理复杂的数据结构,如图像特征的嵌套结构,确保统计信息的准确性和完整性。 4. **视频文件处理**:正确复制视频文件,并保持视频与其他数据之间的正确索引关系,支持多种视频存储结构,保证视频数据与其他数据的同步性。 5. **任务映射**:自动检测并合并相同的任务描述,创建新的任务索引映射,方便对任务进行统一管理和调用。 +6. **数据预验证**:在执行合并操作前对数据集执行全面的预验证,检查视频帧数、图片数量与元数据中记录的帧长度的一致性,确保合并后数据的准确性和完整性,并自动修复可修复的问题(从图片重新编码视频)。 +7. **图像文件管理**:复制和整理图像文件,保持正确的命名规则和目录结构,确保图像与视频和其他数据保持正确的索引对应关系,支持按需启用图像复制功能来优化存储空间使用。 ## 七、注意事项 1. 确保所有源数据集具有兼容的结构,否则可能导致合并失败或数据错误。 2. 合并后的数据集可能占用较大磁盘空间,在进行合并操作前,请确保有足够的存储空间。 3. 对于非常大的数据集,合并过程可能需要较长时间,请耐心等待。 +4. 图像文件夹占用磁盘空间很大,默认不开启 `copy_images` 参数。本工具在处理 `images` 文件夹下的图片时,**仅支持 PNG 格式的图片**,且要求图片文件名为 `frame_XXXXXX.png`(X为6位数字,例如 `frame_000001.png`)。合并过程中会自动检测并处理这些 PNG 图片。 ## 八、常见问题 1. 
# ---------------------------------------------------------------------------
# Patch context preserved verbatim from the original span (tail of the
# README.md hunk and header of the merge_lerobot_dataset.py diff):
#
#   **Q: 合并不同维度的数据集会发生什么?**
#   diff --git a/utils/dataset_merging/merge_lerobot_dataset.py b/utils/dataset_merging/merge_lerobot_dataset.py
#   index 660ea9c..f543167 100644
#   --- a/utils/dataset_merging/merge_lerobot_dataset.py
#   +++ b/utils/dataset_merging/merge_lerobot_dataset.py
#   @@ -7,6 +7,7 @@ import traceback
#    import numpy as np
#    import pandas as pd
#   +from termcolor import colored
#
#
#    def load_jsonl(file_path):
#   @@ -888,6 +889,321 @@ def pad_parquet_data(source_path, target_path, original_dim=14, target_dim=18):
#        return new_df
#
# The functions below are the additions of the second hunk. They rely on the
# module-level names `os`, `json`, `shutil`, `colored` and `load_jsonl` of
# merge_lerobot_dataset.py.
# ---------------------------------------------------------------------------


def count_video_frames_torchvision(video_path):
    """
    Count the number of frames in a video file using torchvision.

    The count reported by the reader's metadata is returned when it is
    positive; otherwise the frames are counted by iterating over the reader.

    Args:
        video_path (str): Path of the video file to inspect.

    Returns:
        int: Number of frames, or 0 when torchvision cannot be imported, the
            file cannot be read, or the video yields no frames.
    """
    try:
        import torchvision

        reader = torchvision.io.VideoReader(video_path, "video")

        # Fast path: use the frame count from the stream metadata if present.
        metadata = reader.get_metadata()
        if "video" in metadata and "num_frames" in metadata["video"] and len(metadata["video"]["num_frames"]) > 0:
            # The value is read as a sequence; the first entry is used.
            frame_count = int(metadata["video"]["num_frames"][0])
            if frame_count > 0:
                return frame_count

        # Slow path: no usable metadata count, so decode and count the frames.
        manual_count = sum(1 for _ in reader)
        if manual_count == 0:
            print(f"Video appears to have no frames: {video_path}")
        return manual_count

    except ImportError:
        print("Warning: torchvision or its dependencies (like ffmpeg) not installed, cannot count video frames")
        return 0
    except RuntimeError as e:
        # Distinguish the common failure messages to give a more useful hint.
        if "No video stream found" in str(e):
            print(f"Error: No video stream found in video file: {video_path}")
        elif "Could not open" in str(e) or "Demuxing video" in str(e):
            print(f"Error: Could not open or demux video file (possibly unsupported format or corrupted file): {video_path} - {e}")
        else:
            print(f"Runtime error counting video frames: {e}")
        return 0
    except Exception as e:
        # Deliberate best effort: any other failure is reported as "0 frames".
        print(f"Error counting video frames: {e}")
        return 0


def _load_info_and_video_keys(source_folders):
    """Load meta/info.json of the first source folder and list its video feature names."""
    info_path = os.path.join(source_folders[0], "meta", "info.json")
    with open(info_path) as f:
        info = json.load(f)

    video_keys = [
        feature_name
        for feature_name, feature_info in info["features"].items()
        if feature_info.get("dtype") == "video"
    ]
    return info, video_keys


def _expected_episode_length(old_folder, old_index, episodes_cache):
    """Return the "length" recorded for an episode in meta/episodes.jsonl, or 0 if unavailable."""
    episode_file = os.path.join(old_folder, "meta", "episodes.jsonl")
    if not os.path.exists(episode_file):
        return 0

    # Parse each episodes.jsonl only once and index it by episode_index.
    if episode_file not in episodes_cache:
        episodes_cache[episode_file] = {ep["episode_index"]: ep for ep in load_jsonl(episode_file)}

    episode_data = episodes_cache[episode_file].get(old_index, None)
    if episode_data and "length" in episode_data:
        return episode_data["length"]
    return 0


def _encode_video_from_images(source_image_dir, source_video_path, fps):
    """(Re-)encode the video of one episode from its image directory."""
    # Imported lazily so that lerobot is only required when encoding is needed.
    from lerobot.common.datasets.video_utils import encode_video_frames

    encode_video_frames(source_image_dir, source_video_path, fps, overwrite=True)


def early_validation(source_folders, episode_mapping, default_fps=20, fps=None):
    """
    Validate videos and images of the source datasets before anything is copied.

    For every episode and every feature whose dtype is "video", the episode
    length recorded in meta/episodes.jsonl is compared with the number of PNG
    files in the image directory and with the frame count of the video.
    A missing video, or a video whose frame count differs from the expected
    length, is (re-)encoded in place from the image directory when that
    directory exists. Nothing is written to an output folder.

    Args:
        source_folders (list): List of source dataset folder paths. The
            meta/info.json of the first folder provides the video path
            template, the chunk size, the features and the default FPS.
        episode_mapping (list): List of tuples containing (old_folder, old_index, new_index)
        default_fps (int): Frame rate used when `fps` is None and info.json has no "fps"
        fps (int): Frame rate to use for video encoding

    Returns:
        dict: Validation results keyed by "{old_folder}_{old_index}". Each
            value holds "expected_frames", "image_counts", "video_frames",
            "old_index", "new_index" and "is_valid".

    Raises:
        AssertionError: If an image directory exists but its PNG count does
            not equal a positive expected length, or if the video has more
            frames than expected.
    """
    info, image_keys = _load_info_and_video_keys(source_folders)
    if fps is None:
        fps = info.get("fps", default_fps)

    print(f"Using FPS={fps}")

    video_path_template = info["video_path"]
    print(f"Found video/image keys: {image_keys}")

    # Validate first before copying anything
    print("Starting validation of images and videos...")
    validation_results = {}
    validation_failed = False
    episodes_cache = {}

    for old_folder, old_index, new_index in episode_mapping:
        expected_frames = _expected_episode_length(old_folder, old_index, episodes_cache)

        result = {
            "expected_frames": expected_frames,
            "image_counts": {},
            "video_frames": {},
            "old_index": old_index,
            "new_index": new_index,
            "is_valid": True,  # Default to valid
        }
        validation_results[f"{old_folder}_{old_index}"] = result

        episode_chunk = old_index // info["chunks_size"]
        for image_dir in image_keys:
            source_video_path = os.path.join(
                old_folder,
                video_path_template.format(
                    episode_chunk=episode_chunk, video_key=image_dir, episode_index=old_index
                ),
            )
            source_image_dir = os.path.join(old_folder, "images", image_dir, f"episode_{old_index:06d}")
            image_dir_exists = os.path.exists(source_image_dir)

            if not os.path.exists(source_video_path):
                print(f"{colored('WARNING', 'yellow', attrs=['bold'])}: Video file not found for {image_dir}, episode {old_index} in {old_folder}")
                if not image_dir_exists:
                    print(f"{colored('ERROR', 'red', attrs=['bold'])}: No video or image directory found for {image_dir}, episode {old_index} in {old_folder}")
                    result["is_valid"] = False
                    validation_failed = True
                    continue
                print(" Image directory exists, encoding video from images.")
                _encode_video_from_images(source_image_dir, source_video_path, fps)
                print(" Encoded video frames successfully.")

            # Count video frames
            video_frame_count = count_video_frames_torchvision(source_video_path)
            result["video_frames"][image_dir] = video_frame_count

            if image_dir_exists:
                # Only PNG files are counted as frames.
                images_count = sum(1 for f in os.listdir(source_image_dir) if f.endswith(".png"))
                result["image_counts"][image_dir] = images_count

                error_msg = f"expected_frames: {expected_frames}, images_count: {images_count}, video_frame_count: {video_frame_count}"
                # Raised explicitly (instead of `assert`) so the checks are not
                # stripped when Python runs with -O; the exception type is kept.
                if not (expected_frames > 0 and expected_frames == images_count):
                    raise AssertionError(
                        f"{colored('ERROR', 'red', attrs=['bold'])}: Image count should match expected frames for {source_image_dir}.\n {error_msg}"
                    )
                if not expected_frames >= video_frame_count:
                    raise AssertionError(
                        f"{colored('ERROR', 'red', attrs=['bold'])}: Video frame count should be less or equal than expected frames for {source_video_path}.\n {error_msg}"
                    )

                # The images are complete, so a short video can be rebuilt from them.
                if video_frame_count != expected_frames:
                    print(f"{colored('WARNING', 'yellow', attrs=['bold'])}: Video frame count mismatch for {source_video_path}")
                    print(f" Expected: {expected_frames}, Found: {video_frame_count}")
                    print(f" Re-encoded video frames from {source_image_dir} to {source_video_path}")

                    _encode_video_from_images(source_image_dir, source_video_path, fps)
                    print(" Re-encoded video frames successfully.")
            else:
                print(f"{colored('WARNING', 'yellow', attrs=['bold'])}: No image directory {image_dir} found for episode {old_index} in {old_folder}")
                print(" You can ignore this if you are not using images and your video frame count is equal to expected frames.")
                # Without images the video cannot be repaired, so a mismatch is an error.
                if expected_frames > 0 and video_frame_count != expected_frames:
                    print(f"{colored('ERROR', 'red', attrs=['bold'])}: Video frame count mismatch for {source_video_path}")
                    print(f" Expected: {expected_frames}, Found: {video_frame_count}")

                    result["is_valid"] = False
                    validation_failed = True

    # Print validation summary
    print("\nValidation Results:")
    valid_count = sum(1 for result in validation_results.values() if result["is_valid"])
    print(f"{valid_count} of {len(validation_results)} episodes are valid")

    # A failed validation is only reported here; the results are returned so
    # that the caller can decide whether to continue.
    if validation_failed:
        print(colored("Validation failed. Please fix the issues before continuing.", "red", attrs=["bold"]))

    return validation_results


def copy_images(source_folders, output_folder, episode_mapping, default_fps=20, fps=None):
    """
    Copy image files from source folders to output folder.

    For every feature whose dtype is "video" in the first source folder's
    meta/info.json, the PNG files of images/<key>/episode_<old_index> are
    copied to images/<key>/episode_<new_index> of the output folder and named
    frame_<number>.png with a six-digit number. Files that cannot be parsed
    or copied are reported and skipped.

    Args:
        source_folders (list): List of source dataset folder paths
        output_folder (str): Output folder path
        episode_mapping (list): List of tuples containing (old_folder, old_index, new_index)
        default_fps (int): Not used by the copy; kept for backward compatibility
        fps (int): Not used by the copy; kept for backward compatibility

    Returns:
        int: Number of images copied
    """
    _, image_keys = _load_info_and_video_keys(source_folders)

    images_root = os.path.join(output_folder, "images")
    os.makedirs(images_root, exist_ok=True)

    print(f"Starting to copy images for {len(image_keys)} video keys...")
    total_copied = 0
    skipped_episodes = 0

    for old_folder, old_index, new_index in episode_mapping:
        episode_copied = False

        for image_dir in image_keys:
            # The per-key directory is created even when this episode has no images.
            os.makedirs(os.path.join(images_root, image_dir), exist_ok=True)

            source_image_dir = os.path.join(old_folder, "images", image_dir, f"episode_{old_index:06d}")
            if not os.path.exists(source_image_dir):
                continue

            target_image_dir = os.path.join(images_root, image_dir, f"episode_{new_index:06d}")
            os.makedirs(target_image_dir, exist_ok=True)

            image_files = sorted(f for f in os.listdir(source_image_dir) if f.endswith(".png"))
            if image_files:
                print(f"Copying {len(image_files)} images from {source_image_dir} to {target_image_dir}")

            for image_file in image_files:
                try:
                    # "frame_000012.png" -> 12; a name without "_" is parsed as a whole.
                    frame_part = image_file.split("_")[1] if "_" in image_file else image_file
                    frame_num = int(frame_part.split(".")[0])

                    shutil.copy2(
                        os.path.join(source_image_dir, image_file),
                        os.path.join(target_image_dir, f"frame_{frame_num:06d}.png"),
                    )
                except (ValueError, OSError) as e:
                    # ValueError: no frame number in the name; OSError: copy failed.
                    print(f"Error copying image {image_file}: {e}")
                else:
                    total_copied += 1
                    episode_copied = True

        if not episode_copied:
            skipped_episodes += 1

    print(f"\nCopied {total_copied} images for {len(episode_mapping) - skipped_episodes} episodes")
    if skipped_episodes > 0:
        print(f"{colored('WARNING', 'yellow', attrs=['bold'])}: Skipped {skipped_episodes} episodes with no images")

    return total_copied


# ---------------------------------------------------------------------------
# Patch context preserved verbatim from the original span (start of the
# merge_datasets hunks, which continue on the next line of the patch):
#
#    def merge_datasets(
#        source_folders, output_folder, validate_ts=False, tolerance_s=1e-4, max_dim=18, default_fps=20
#    ):
#   @@ -910,6 +1226,7 @@ def merge_datasets(
#        3. 填充向量维度使其一致 (Pads vector dimensions for consistency)
#        4. 更新元数据文件 (Updates metadata files)
#        5. 复制并处理数据和视频文件 (Copies and processes data and video files)
#   +    6.
# ---------------------------------------------------------------------------
复制并验证图像文件 (Copies and validates image files) """ # Create output folder if it doesn't exist os.makedirs(output_folder, exist_ok=True) @@ -967,6 +1284,8 @@ def merge_datasets( # 从info.json获取chunks_size info_path = os.path.join(source_folders[0], "meta", "info.json") + # Check if all source folders have images directory + images_dir_exists = all(os.path.exists(os.path.join(folder, "images")) for folder in source_folders) chunks_size = 1000 # 默认值 if os.path.exists(info_path): with open(info_path) as f: @@ -1276,6 +1595,13 @@ def merge_datasets( with open(os.path.join(output_folder, "meta", "info.json"), "w") as f: json.dump(info, f, indent=4) + # Validate before video copying + if images_dir_exists: + early_validation( + source_folders, + episode_mapping, + ) + # Copy video and data files copy_videos(source_folders, output_folder, episode_mapping) copy_data_files( @@ -1289,6 +1615,12 @@ def merge_datasets( chunks_size=chunks_size, ) + # Copy images and check with video frames + if args.copy_images: + print("Starting to copy images and validate video frame counts") + copy_images(source_folders, output_folder, episode_mapping) + + print(f"Merged {total_episodes} episodes with {total_frames} frames into {output_folder}") @@ -1301,6 +1633,7 @@ if __name__ == "__main__": parser.add_argument("--output", required=True, help="Output folder path") parser.add_argument("--max_dim", type=int, default=32, help="Maximum dimension (default: 32)") parser.add_argument("--fps", type=int, default=20, help="Your datasets FPS (default: 20)") + parser.add_argument("--copy_images", action="store_true", help="Whether copy images from source folders to output folder with validation. (default: False)",) # Parse arguments args = parser.parse_args()