add functions to merge images with validation. (#27)

* add functions to copy images with validation and check the video frame number.

* update README to clarify image file handling and maintain metadata structure

* Update utils/dataset_merging/merge_lerobot_dataset.py

[feat] accelerate the episode file finding with the `episode_file_mapping`.

* Refactor image and video handling in dataset merging tool with enhanced validation

* [fix] If `images` directory exists, then use `early_validation`

* fix: Import `encode_video_frames` for video encoding in early validation

* fix: Add image copying option and update README for PNG support

* fix: Remove unused cv2 import from dataset merging utility

---------

Co-authored-by: Qizhi Chen <tavish9.chen@gmail.com>
Co-authored-by: zhipeng tang <2444198418@qq.com>
Yushun Xiang
2025-05-21 14:43:56 +08:00
committed by GitHub
parent 4364671ea7
commit 242f00a876
2 changed files with 345 additions and 2 deletions
+12 -2
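@@ -8,8 +8,11 @@ merge_lerobot_dataset is a powerful Python script designed specifically for merging
2. **Index renumbering**: Renumbers all episode indices and task indices so that the merged data is contiguous and consistent.
3. **Vector dimension padding**: Automatically detects and pads vector dimensions so that fields such as observation.state and action have a consistent shape across all data.
4. **Statistics merging**: Merges the statistics of multiple datasets and correctly handles complex data structures, such as the nested structure of image features.
5. **Video file handling**: Correctly processes and copies video files, preserves the index correspondence between videos and other data, and supports multiple video storage layouts
6. **Metadata update**: Updates all metadata files to accurately reflect the structure of the merged dataset.
5. **Image and video validation**: If an `images` folder is present, automatically checks that the counts of video frames, images, and metadata entries agree with each other
6. **Video file handling**: Correctly processes and copies video files, preserves the index correspondence between videos and other data, and supports multiple video storage layouts.
7. **Metadata update**: Updates all metadata files to accurately reflect the structure of the merged dataset.
8. **Image file handling**: Copies image files and preserves the index correspondence between images and other data.
## 3. Installation
This script depends on the following Python libraries: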
@@ -32,6 +35,7 @@ python dataset_merger.py --sources /path/to/dataset1 /path/to/dataset2 /path/to/
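- **--output**: Path of the output dataset folder, i.e. where the merged dataset is stored.
- **--max_dim**: Maximum vector dimension. Default: 32.
- **--fps**: Frame rate of the dataset. Default: 20.
- **--copy_images**: Whether to copy images from the source folders and merge them into the output folder. Default: `False`.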
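The options listed above can be sketched with `argparse` as follows. This is a minimal illustration rather than the script's actual parser: `build_parser` is a hypothetical helper, and `nargs="+"` for `--sources` is an assumption inferred from the usage line, which passes several paths after one flag.

```python
import argparse


def build_parser():
    """Build a parser mirroring the documented options (illustrative sketch)."""
    parser = argparse.ArgumentParser(description="Merge datasets (sketch)")
    # Assumption: --sources accepts one or more folder paths.
    parser.add_argument("--sources", nargs="+", required=True, help="Source dataset folders")
    parser.add_argument("--output", required=True, help="Output folder path")
    parser.add_argument("--max_dim", type=int, default=32, help="Maximum vector dimension")
    parser.add_argument("--fps", type=int, default=20, help="Dataset frame rate")
    # store_true makes the flag opt-in: it is False unless given on the command line.
    parser.add_argument("--copy_images", action="store_true", help="Also copy image files")
    return parser


# Without the flag, image copying stays disabled.
default_args = build_parser().parse_args(["--sources", "a", "b", "--output", "out"])
# With the flag, image copying is enabled.
copy_args = build_parser().parse_args(["--sources", "a", "--output", "out", "--copy_images"])
print(default_args.copy_images, copy_args.copy_images)
```

Because `--copy_images` is a `store_true` flag, it takes no value; passing it switches image copying on.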
### (3) Examples
```bash
@@ -51,6 +55,9 @@ dataset/
├── data/          # episode data in parquet format
│   └── chunk-xxx/
│       └── episode_xxxxxx.parquet
├── images/        # optional image files
│   └── episode_xxxxxx/
│       └── frame_xxxxxx.png
└── videos/        # optional video files
    └── chunk-xxx/
        └── video_key/
@@ -63,11 +70,14 @@ dataset/
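3. **Statistics merging**: Merges the statistics of multiple datasets, correctly handling complex data structures such as the nested structure of image features, so that the statistics remain accurate and complete.
4. **Video file handling**: Copies video files correctly and preserves the index correspondence between videos and other data; supports multiple video storage layouts and keeps video data in sync with the rest of the dataset.
5. **Task mapping**: Automatically detects and merges identical task descriptions and creates a new task index mapping, so that tasks can be managed and referenced uniformly.
6. **Data pre-validation**: Before merging, runs a pre-validation pass over the datasets that checks whether video frame counts and image counts match the frame length recorded in the metadata; fixable problems are repaired automatically (videos are re-encoded from the images).
7. **Image file management**: Copies and organizes image files, keeps the naming convention and directory structure intact, and preserves the index correspondence between images, videos, and other data; image copying can be enabled on demand to save storage space.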
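The pre-validation rule in item 6 can be summarized as a small decision function. The sketch below is a simplification under stated assumptions: `classify_episode` and its string results are hypothetical names, and where the merged script raises an `AssertionError` for an image-count mismatch, this sketch simply returns `"invalid"`.

```python
from typing import Optional


def classify_episode(expected_frames: int, video_frames: int,
                     image_count: Optional[int] = None) -> str:
    """Classify one episode/camera pair as "ok", "reencode", or "invalid".

    expected_frames: the episode length recorded in the metadata.
    video_frames:    frames counted in the video file.
    image_count:     PNG files in the image directory, or None if it is absent.
    """
    if image_count is None:
        # No images to fall back on: the video itself must match the metadata.
        if expected_frames > 0 and video_frames != expected_frames:
            return "invalid"
        return "ok"
    # Images exist: they are the reference and must match the metadata exactly.
    if expected_frames <= 0 or image_count != expected_frames:
        return "invalid"
    # A video longer than the metadata length cannot be explained by lost frames.
    if video_frames > expected_frames:
        return "invalid"
    # A shorter video can be repaired by re-encoding it from the images.
    if video_frames != expected_frames:
        return "reencode"
    return "ok"


print(classify_episode(100, 90, image_count=100))
```

The asymmetry is deliberate: a short video is only repairable when a complete set of images is available to re-encode from.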
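## 7. Notes
1. Make sure all source datasets have compatible structures; otherwise merging may fail or produce incorrect data.
2. The merged dataset may take up a large amount of disk space. Make sure enough storage is available before merging.
3. For very large datasets, merging may take a long time.
4. Image folders take up a lot of disk space, so the `copy_images` option is off by default. When processing images under the `images` folder, this tool **supports PNG images only** and requires file names of the form `frame_XXXXXX.png` (where XXXXXX is a 6-digit number, e.g. `frame_000001.png`). These PNG images are detected and processed automatically during merging.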
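To check ahead of time whether a folder follows the naming rule in note 4, a strict filename check along these lines may help. It is a sketch, not part of the tool: `FRAME_PATTERN`, `parse_frame_number`, and `split_frames` are hypothetical names, and the pattern is stricter than the merge script itself, which only filters on the `.png` suffix and parses the number after the underscore.

```python
import re

# Exactly "frame_" + six digits + ".png", as described in the note above.
FRAME_PATTERN = re.compile(r"^frame_(\d{6})\.png$")


def parse_frame_number(filename):
    """Return the frame number encoded in the filename, or None if it does not match."""
    match = FRAME_PATTERN.match(filename)
    return int(match.group(1)) if match else None


def split_frames(filenames):
    """Split filenames into (conforming, non_conforming) lists, preserving order."""
    conforming, non_conforming = [], []
    for name in filenames:
        if parse_frame_number(name) is None:
            non_conforming.append(name)
        else:
            conforming.append(name)
    return conforming, non_conforming


ok, bad = split_frames(["frame_000001.png", "frame_1.png", "frame_000002.jpg"])
print(ok, bad)
```

Running such a check on each `episode_xxxxxx` folder before merging surfaces misnamed files that would otherwise cause an image-count mismatch during validation.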
## 8. FAQ
1. **Q: What happens when datasets with different dimensions are merged?**
@@ -7,6 +7,7 @@ import traceback
import numpy as np
import pandas as pd
from termcolor import colored
def load_jsonl(file_path):
@@ -888,6 +889,321 @@ def pad_parquet_data(source_path, target_path, original_dim=14, target_dim=18):
return new_df
def count_video_frames_torchvision(video_path):
    """
    Count the number of frames in a video file using torchvision.

    Args:
        video_path (str): Path to the video file.

    Returns:
        int: Frame count, or 0 if the frames could not be counted.
    """
    try:
        import torchvision

        # Ensure torchvision version is recent enough for VideoReader and AV1 support
        # (This is a general good practice, specific version checks might be needed
        # depending on the exact AV1 library used by torchvision's backend)
        # print(f"Torchvision version: {torchvision.__version__}")
        # print(f"PyTorch version: {torch.__version__}")

        # VideoReader requires the video path as a string
        reader = torchvision.io.VideoReader(video_path, "video")

        # Attempt to get frame count from metadata
        # Metadata structure can vary; "video" stream usually has "num_frames"
        metadata = reader.get_metadata()
        frame_count = 0
        if "video" in metadata and "num_frames" in metadata["video"] and len(metadata["video"]["num_frames"]) > 0:
            # num_frames is often a list, take the first element
            frame_count = int(metadata["video"]["num_frames"][0])
            if frame_count > 0:
                # If metadata provides a positive frame count, we can often trust it.
                # For some backends/formats, this might be the most reliable way.
                return frame_count

        # If metadata didn't provide a reliable frame count, or to be absolutely sure,
        # we can iterate through the frames.
        # This is more robust but potentially slower.
        count_manually = 0
        for _ in reader:  # Iterating through the reader yields frames
            count_manually += 1

        # If manual count is zero but metadata had a count, it might indicate an issue
        # or an empty video. Prioritize manual count if it's > 0.
        if count_manually > 0:
            return count_manually
        elif frame_count > 0:  # Fallback to metadata if manual count was 0 but metadata had a value
            print(f"Warning: Manual count is 0, but metadata indicates {frame_count} frames. Video might be empty or there was a read issue. Returning metadata count.")
            return frame_count
        else:
            # This case means both metadata (if available) and manual iteration yielded 0.
            print(f"Video appears to have no frames: {video_path}")
            return 0
    except ImportError:
        print("Warning: torchvision or its dependencies (like ffmpeg) not installed, cannot count video frames")
        return 0
    except RuntimeError as e:
        # RuntimeError can be raised by VideoReader for various issues (e.g., file not found, corrupt file, unsupported codec by the backend)
        if "No video stream found" in str(e):
            print(f"Error: No video stream found in video file: {video_path}")
        elif "Could not open" in str(e) or "Demuxing video" in str(e):
            print(f"Error: Could not open or demux video file (possibly unsupported format or corrupted file): {video_path} - {e}")
        else:
            print(f"Runtime error counting video frames: {e}")
        return 0
    except Exception as e:
        print(f"Error counting video frames: {e}")
        return 0
    finally:
        # VideoReader does not have an explicit close() or release() method.
        # It's managed by its destructor when it goes out of scope.
        pass


def early_validation(source_folders, episode_mapping, default_fps=20, fps=None):
    """
    Validate image and video files in the source folders before anything is copied.

    For every episode, compares the frame length recorded in episodes.jsonl with
    the number of PNG images and the number of video frames. Missing or short
    videos are re-encoded from the images when an image directory exists.

    Args:
        source_folders (list): List of source dataset folder paths
        episode_mapping (list): List of tuples containing (old_folder, old_index, new_index)
        default_fps (int): Default frame rate to use if not specified
        fps (int): Frame rate to use for video encoding

    Returns:
        dict: Validation results containing expected frame count and actual image count for each episode
    """
    if fps is None:
        info_path = os.path.join(source_folders[0], "meta", "info.json")
        if os.path.exists(info_path):
            with open(info_path) as f:
                info = json.load(f)
            fps = info.get("fps", default_fps)
        else:
            fps = default_fps
    print(f"Using FPS={fps}")

    # Get video path template and video keys
    info_path = os.path.join(source_folders[0], "meta", "info.json")
    with open(info_path) as f:
        info = json.load(f)
    video_path_template = info["video_path"]
    image_keys = []
    for feature_name, feature_info in info["features"].items():
        if feature_info.get("dtype") == "video":
            image_keys.append(feature_name)
    print(f"Found video/image keys: {image_keys}")

    # Validate first before copying anything
    print("Starting validation of images and videos...")
    validation_results = {}
    validation_failed = False
    # Cache of parsed episodes.jsonl files, keyed by file path
    episode_file_mapping = {}
    for old_folder, old_index, new_index in episode_mapping:
        # Get expected frame count from episodes.jsonl
        episode_file = os.path.join(old_folder, "meta", "episodes.jsonl")
        expected_frames = 0
        if os.path.exists(episode_file):
            if episode_file not in episode_file_mapping:
                episodes = load_jsonl(episode_file)
                episodes = {ep["episode_index"]: ep for ep in episodes}
                episode_file_mapping[episode_file] = episodes
            episode_data = episode_file_mapping[episode_file].get(old_index, None)
            if episode_data and "length" in episode_data:
                expected_frames = episode_data["length"]

        validation_key = f"{old_folder}_{old_index}"
        validation_results[validation_key] = {
            "expected_frames": expected_frames,
            "image_counts": {},
            "video_frames": {},
            "old_index": old_index,
            "new_index": new_index,
            "is_valid": True  # Default to valid
        }

        # Check each image directory and video
        episode_chunk = old_index // info["chunks_size"]
        for image_dir in image_keys:
            # Find the video file
            source_video_path = os.path.join(
                old_folder,
                video_path_template.format(
                    episode_chunk=episode_chunk, video_key=image_dir, episode_index=old_index
                ),
            )
            source_image_dir = os.path.join(old_folder, "images", image_dir, f"episode_{old_index:06d}")
            image_dir_exists = os.path.exists(source_image_dir)
            video_file_exists = os.path.exists(source_video_path)
            if not video_file_exists:
                print(f"{colored('WARNING', 'yellow', attrs=['bold'])}: Video file not found for {image_dir}, episode {old_index} in {old_folder}")
                if image_dir_exists:
                    print("  Image directory exists, encoding video from images.")
                    from lerobot.common.datasets.video_utils import encode_video_frames
                    encode_video_frames(source_image_dir, source_video_path, fps, overwrite=True)
                    print("  Encoded video frames successfully.")
                else:
                    print(f"{colored('ERROR', 'red', attrs=['bold'])}: No video or image directory found for {image_dir}, episode {old_index} in {old_folder}")
                    validation_results[validation_key]["is_valid"] = False
                    validation_failed = True
                    continue

            # Count video frames
            video_frame_count = count_video_frames_torchvision(source_video_path)
            validation_results[validation_key]["video_frames"][image_dir] = video_frame_count

            # Check if image directory exists
            if image_dir_exists:
                # Count image files
                image_files = sorted([f for f in os.listdir(source_image_dir) if f.endswith('.png')])
                images_count = len(image_files)
                validation_results[validation_key]["image_counts"][image_dir] = images_count
                error_msg = f"expected_frames: {expected_frames}, images_count: {images_count}, video_frame_count: {video_frame_count}"
                assert expected_frames > 0 and expected_frames == images_count, (
                    f"{colored('ERROR', 'red', attrs=['bold'])}: Image count should match expected frames for {source_image_dir}.\n {error_msg}"
                )
                assert expected_frames >= video_frame_count, (
                    f"{colored('ERROR', 'red', attrs=['bold'])}: Video frame count should be less than or equal to expected frames for {source_video_path}.\n {error_msg}"
                )
                # Validate frame counts
                if video_frame_count != expected_frames:
                    print(f"{colored('WARNING', 'yellow', attrs=['bold'])}: Video frame count mismatch for {source_video_path}")
                    print(f"  Expected: {expected_frames}, Found: {video_frame_count}")
                    print(f"  Re-encoding video frames from {source_image_dir} to {source_video_path}")
                    from lerobot.common.datasets.video_utils import encode_video_frames
                    encode_video_frames(source_image_dir, source_video_path, fps, overwrite=True)
                    print("  Re-encoded video frames successfully.")
            else:
                print(f"{colored('WARNING', 'yellow', attrs=['bold'])}: No image directory {image_dir} found for episode {old_index} in {old_folder}")
                print("  You can ignore this if you are not using images and your video frame count is equal to expected frames.")
                # If no images directory, the video frames must match expected frames
                if expected_frames > 0 and video_frame_count != expected_frames:
                    print(f"{colored('ERROR', 'red', attrs=['bold'])}: Video frame count mismatch for {source_video_path}")
                    print(f"  Expected: {expected_frames}, Found: {video_frame_count}")
                    validation_results[validation_key]["is_valid"] = False
                    validation_failed = True

    # Print validation summary
    print("\nValidation Results:")
    valid_count = sum(1 for result in validation_results.values() if result["is_valid"])
    print(f"{valid_count} of {len(validation_results)} episodes are valid")

    # If validation failed, report it; the caller decides whether to stop
    if validation_failed:
        print(colored("Validation failed. Please fix the issues before continuing.", "red", attrs=["bold"]))
    return validation_results


def copy_images(source_folders, output_folder, episode_mapping, default_fps=20, fps=None):
    """
    Copy image files from source folders to output folder.
    This function assumes validation has already been performed with early_validation().

    Args:
        source_folders (list): List of source dataset folder paths
        output_folder (str): Output folder path
        episode_mapping (list): List of tuples containing (old_folder, old_index, new_index)
        default_fps (int): Default frame rate to use if not specified
        fps (int): Frame rate to use for video encoding

    Returns:
        int: Number of images copied
    """
    if fps is None:
        info_path = os.path.join(source_folders[0], "meta", "info.json")
        if os.path.exists(info_path):
            with open(info_path) as f:
                info = json.load(f)
            fps = info.get("fps", default_fps)
        else:
            fps = default_fps

    # Get video path template and video keys
    info_path = os.path.join(source_folders[0], "meta", "info.json")
    with open(info_path) as f:
        info = json.load(f)
    video_path_template = info["video_path"]
    image_keys = []
    for feature_name, feature_info in info["features"].items():
        if feature_info.get("dtype") == "video":
            image_keys.append(feature_name)

    # Create image directories in output folder
    os.makedirs(os.path.join(output_folder, "images"), exist_ok=True)
    print(f"Starting to copy images for {len(image_keys)} video keys...")
    total_copied = 0
    skipped_episodes = 0

    # Copy images for each episode
    for old_folder, old_index, new_index in episode_mapping:
        episode_chunk = old_index // info["chunks_size"]
        new_episode_chunk = new_index // info["chunks_size"]
        episode_copied = False
        for image_dir in image_keys:
            # Create target directory for this video key
            os.makedirs(os.path.join(output_folder, "images", image_dir), exist_ok=True)
            # Check if source image directory exists
            source_image_dir = os.path.join(old_folder, "images", image_dir, f"episode_{old_index:06d}")
            if os.path.exists(source_image_dir):
                # Create target directory
                target_image_dir = os.path.join(output_folder, "images", image_dir, f"episode_{new_index:06d}")
                os.makedirs(target_image_dir, exist_ok=True)
                # Copy image files
                image_files = sorted([f for f in os.listdir(source_image_dir) if f.endswith('.png')])
                num_images = len(image_files)
                if num_images > 0:
                    print(f"Copying {num_images} images from {source_image_dir} to {target_image_dir}")
                    for image_file in image_files:
                        try:
                            # Extract frame number from filename
                            frame_part = image_file.split('_')[1] if '_' in image_file else image_file
                            frame_num = int(frame_part.split('.')[0])
                            # Copy the file with consistent naming
                            dest_file = os.path.join(target_image_dir, f"frame_{frame_num:06d}.png")
                            shutil.copy2(
                                os.path.join(source_image_dir, image_file),
                                dest_file
                            )
                            total_copied += 1
                            episode_copied = True
                        except Exception as e:
                            print(f"Error copying image {image_file}: {e}")
        if not episode_copied:
            skipped_episodes += 1

    print(f"\nCopied {total_copied} images for {len(episode_mapping) - skipped_episodes} episodes")
    if skipped_episodes > 0:
        print(f"{colored('WARNING', 'yellow', attrs=['bold'])}: Skipped {skipped_episodes} episodes with no images")
    return total_copied


def merge_datasets(
source_folders, output_folder, validate_ts=False, tolerance_s=1e-4, max_dim=18, default_fps=20
):
@@ -910,6 +1226,7 @@ def merge_datasets(
    3. Pads vector dimensions for consistency
    4. Updates metadata files
    5. Copies and processes data and video files
    6. Copies and validates image files
"""
# Create output folder if it doesn't exist
os.makedirs(output_folder, exist_ok=True)
@@ -967,6 +1284,8 @@ def merge_datasets(
    # Get chunks_size from info.json
    info_path = os.path.join(source_folders[0], "meta", "info.json")
    # Check if all source folders have images directory
    images_dir_exists = all(os.path.exists(os.path.join(folder, "images")) for folder in source_folders)
    chunks_size = 1000  # default value
    if os.path.exists(info_path):
        with open(info_path) as f:
@@ -1276,6 +1595,13 @@ def merge_datasets(
    with open(os.path.join(output_folder, "meta", "info.json"), "w") as f:
        json.dump(info, f, indent=4)
    # Validate before video copying
    if images_dir_exists:
        early_validation(
            source_folders,
            episode_mapping,
        )
    # Copy video and data files
    copy_videos(source_folders, output_folder, episode_mapping)
    copy_data_files(
@@ -1289,6 +1615,12 @@ def merge_datasets(
        chunks_size=chunks_size,
    )
    # Copy images and check with video frames
    if args.copy_images:
        print("Starting to copy images and validate video frame counts")
        copy_images(source_folders, output_folder, episode_mapping)
    print(f"Merged {total_episodes} episodes with {total_frames} frames into {output_folder}")
@@ -1301,6 +1633,7 @@ if __name__ == "__main__":
    parser.add_argument("--output", required=True, help="Output folder path")
    parser.add_argument("--max_dim", type=int, default=32, help="Maximum dimension (default: 32)")
    parser.add_argument("--fps", type=int, default=20, help="Your datasets FPS (default: 20)")
    parser.add_argument("--copy_images", action="store_true", help="Whether to copy images from source folders to output folder with validation. (default: False)")
    # Parse arguments
    args = parser.parse_args()