mirror of
https://github.com/Tavish9/any4lerobot.git
synced 2026-05-18 07:29:44 +00:00
add functions to merge images with validation. (#27)

* add functions to copy images with validation and check the video frame number.
* update README to clarify image file handling and maintain metadata structure
* Update utils/dataset_merging/merge_lerobot_dataset.py [feat] accelerate the episode file finding with the `episode_file_mapping`.
* Refactor image and video handling in dataset merging tool with enhanced validation
* [fix] If `images` directory exists, then use `early_validation`
* fix: Import `encode_video_frames` for video encoding in early validation
* fix: Add image copying option and update README for PNG support
* fix: Remove unused cv2 import from dataset merging utility

---------

Co-authored-by: Qizhi Chen <tavish9.chen@gmail.com>
Co-authored-by: zhipeng tang <2444198418@qq.com>
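@@ -8,8 +8,11 @@ merge_lerobot_dataset is a powerful Python script designed specifically to merge
2. **Index renumbering**: Renumbers all episode indices and task indices so that the merged data stays contiguous and consistent.
3. **Vector dimension padding**: Automatically detects and pads vector dimensions so that fields such as observation.state and action are consistent across all data.
4. **Statistics merging**: Intelligently merges the statistics of multiple datasets, correctly handling complex data structures such as the nested structure of image features.
5. **Video file handling**: Correctly processes and copies video files, preserving the index relationship between videos and other data; supports multiple video storage layouts.
6. **Metadata update**: Updates all metadata files to accurately reflect the structure of the merged dataset.
5. **Image and video validation**: If an `images` folder is present, automatically checks that the counts of video frames, images, and metadata entries agree.
6. **Video file handling**: Correctly processes and copies video files, preserving the index relationship between videos and other data; supports multiple video storage layouts.
7. **Metadata update**: Updates all metadata files to accurately reflect the structure of the merged dataset.
8. **Image file handling**: Copies image files, preserving the index relationship between images and other data.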
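The image count check named in the validation item can be sketched as below. This is a minimal illustration under assumptions, not the script's own code: `count_png_frames` and `counts_agree` are hypothetical helper names, and the expected frame count is passed in directly rather than read from the metadata files.

```python
import os
import tempfile


def count_png_frames(image_dir: str) -> int:
    """Count the PNG files in one episode's image directory."""
    return sum(1 for name in os.listdir(image_dir) if name.endswith(".png"))


def counts_agree(expected_frames: int, image_dir: str) -> bool:
    """True when the metadata length is positive and matches the PNG count."""
    return expected_frames > 0 and count_png_frames(image_dir) == expected_frames


# Build a throwaway episode directory with three empty frame files.
with tempfile.TemporaryDirectory() as tmp:
    for i in range(3):
        open(os.path.join(tmp, f"frame_{i:06d}.png"), "wb").close()
    print(counts_agree(3, tmp))  # True
    print(counts_agree(5, tmp))  # False
```

The same comparison would apply to the video frame count once it has been read from the video file.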


## 3. Installation
This script depends on the following Python libraries:
@@ -32,6 +35,7 @@ python dataset_merger.py --sources /path/to/dataset1 /path/to/dataset2 /path/to/
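- **--output**: Path of the output dataset folder, where the merged dataset is stored.
- **--max_dim**: Maximum vector dimension. Default: 32.
- **--fps**: Frame rate of the dataset. Default: 20.
- **--copy_images**: Whether to copy images from the source folders and merge them into the output folder. (default: `False`)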
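How these options parse can be sketched with `argparse`. This is an illustrative stand-in, not the script's full parser: `build_parser` is a hypothetical name, and `nargs="+"` for `--sources` is an assumption inferred from the usage line, which passes several paths after the flag.

```python
import argparse


def build_parser() -> argparse.ArgumentParser:
    """Build a parser that mirrors the documented options (illustrative only)."""
    parser = argparse.ArgumentParser(description="Merge datasets (sketch)")
    # Assumption: --sources accepts one or more folder paths.
    parser.add_argument("--sources", nargs="+", required=True)
    parser.add_argument("--output", required=True)
    parser.add_argument("--max_dim", type=int, default=32)
    parser.add_argument("--fps", type=int, default=20)
    # store_true makes the flag default to False and become True when present.
    parser.add_argument("--copy_images", action="store_true")
    return parser


args = build_parser().parse_args(
    ["--sources", "ds1", "ds2", "--output", "merged", "--copy_images"]
)
print(args.copy_images, args.max_dim, args.fps)  # True 32 20
```

Because `--copy_images` is a `store_true` flag, it takes no value: leaving it off keeps image copying disabled.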

### (3) Example
```bash
@@ -51,6 +55,9 @@ dataset/
├── data/ # episode data in parquet format
│   └── chunk-xxx/
│       └── episode_xxxxxx.parquet
├── images/ # optional image files
│   └── episode_xxxxxx/
│       └── frame_xxxxxx.png
└── videos/ # optional video files
    └── chunk-xxx/
        └── video_key/
@@ -63,11 +70,14 @@ dataset/
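3. **Statistics merging**: Intelligently merges the statistics of multiple datasets and correctly handles complex data structures, such as the nested structure of image features, to keep the statistics accurate and complete.
4. **Video file handling**: Copies video files correctly and preserves the index relationship between videos and other data; supports multiple video storage layouts and keeps video data in sync with the rest of the data.
5. **Task mapping**: Automatically detects and merges identical task descriptions and builds a new task index mapping, so that tasks can be managed and referenced uniformly.
6. **Data pre-validation**: Before merging, runs a full pre-validation of the datasets, checking that video frame counts and image counts match the frame length recorded in the metadata, and automatically repairs what can be repaired (by re-encoding the video from the images).
7. **Image file management**: Copies and organizes image files with consistent naming and directory structure, keeping images correctly indexed against videos and other data; image copying can be enabled on demand to save storage space.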
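The pre-validation rules can be summarized as a small decision function. This is a simplified sketch of the logic in `early_validation`, not a drop-in replacement: `plan_action` and its return labels are hypothetical, and `None` stands for a missing video file or a missing image directory.

```python
from typing import Optional


def plan_action(expected: int, video_frames: Optional[int], image_count: Optional[int]) -> str:
    """Decide what to do for one episode and one video key.

    expected     -- frame length recorded in the metadata
    video_frames -- frames counted in the video, or None if the file is missing
    image_count  -- PNG files found, or None if there is no image directory
    """
    if video_frames is None:
        # A missing video can only be repaired when images are available.
        return "encode" if image_count is not None else "invalid"
    if image_count is None:
        # Without images, the video itself must match the metadata.
        return "invalid" if expected > 0 and video_frames != expected else "ok"
    if expected <= 0 or image_count != expected or video_frames > expected:
        return "invalid"
    # Images are complete, so a short video can be rebuilt from them.
    return "ok" if video_frames == expected else "re-encode"


print(plan_action(100, 98, 100))   # re-encode
print(plan_action(100, 98, None))  # invalid
```

The design point is that the images act as the repair source: a mismatch is fixable only when a complete image set exists.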
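
## 7. Notes
1. Make sure all source datasets have compatible structures; otherwise the merge may fail or produce incorrect data.
2. The merged dataset may take up a lot of disk space; make sure enough storage is available before merging.
3. For very large datasets, merging may take a long time.
4. The images folder takes up a lot of disk space, so the `copy_images` option is off by default. When processing images under the `images` folder, this tool **supports PNG images only** and requires file names of the form `frame_XXXXXX.png` (where XXXXXX is a six-digit number, e.g. `frame_000001.png`). These PNG images are detected and processed automatically during merging.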
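The naming rule in note 4 can be checked with a regular expression. This is a minimal sketch: `FRAME_NAME` and `parse_frame_index` are hypothetical names, and the merge script itself extracts the frame number by splitting the file name rather than by pattern matching.

```python
import re
from typing import Optional

# Exactly "frame_", six digits, and a ".png" extension.
FRAME_NAME = re.compile(r"frame_(\d{6})\.png")


def parse_frame_index(filename: str) -> Optional[int]:
    """Return the frame index, or None if the name breaks the convention."""
    match = FRAME_NAME.fullmatch(filename)
    return int(match.group(1)) if match else None


for name in ["frame_000001.png", "frame_1.png", "frame_000001.jpg"]:
    print(name, parse_frame_index(name))
```

Running a check like this over an `images` folder before merging would surface files that the PNG-only filter silently skips.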

## 8. FAQ
1. **Q: What happens when datasets with different dimensions are merged?**

@@ -7,6 +7,7 @@ import traceback

import numpy as np
import pandas as pd
from termcolor import colored


def load_jsonl(file_path):
@@ -888,6 +889,321 @@ def pad_parquet_data(source_path, target_path, original_dim=14, target_dim=18):
    return new_df


def count_video_frames_torchvision(video_path):
    """
    Count the number of frames in a video file using torchvision.

    Args:
        video_path (str): Path to the video file.

    Returns:
        int: Frame count, or 0 if the video cannot be read.
    """
    try:
        import torchvision

        # Ensure torchvision version is recent enough for VideoReader and AV1 support
        # (This is a general good practice, specific version checks might be needed
        # depending on the exact AV1 library used by torchvision's backend)
        # print(f"Torchvision version: {torchvision.__version__}")
        # print(f"PyTorch version: {torch.__version__}")

        # VideoReader requires the video path as a string
        reader = torchvision.io.VideoReader(video_path, "video")

        # Attempt to get frame count from metadata
        # Metadata structure can vary; "video" stream usually has "num_frames"
        metadata = reader.get_metadata()
        frame_count = 0

        if "video" in metadata and "num_frames" in metadata["video"] and len(metadata["video"]["num_frames"]) > 0:
            # num_frames is often a list, take the first element
            frame_count = int(metadata["video"]["num_frames"][0])
            if frame_count > 0:
                # If metadata provides a positive frame count, we can often trust it.
                # For some backends/formats, this might be the most reliable way.
                return frame_count

        # If metadata didn't provide a reliable frame count, iterate through the frames.
        # This is more robust but potentially slower.
        count_manually = 0
        for _ in reader:  # Iterating through the reader yields frames
            count_manually += 1

        # Prioritize manual count if it's > 0.
        if count_manually > 0:
            return count_manually
        elif frame_count > 0:
            # Kept as a safeguard; a positive metadata count already returned above.
            print(f"Warning: Manual count is 0, but metadata indicates {frame_count} frames. Video might be empty or there was a read issue. Returning metadata count.")
            return frame_count
        else:
            # This case means both metadata (if available) and manual iteration yielded 0.
            print(f"Video appears to have no frames: {video_path}")
            return 0

    except ImportError:
        print("Warning: torchvision or its dependencies (like ffmpeg) not installed, cannot count video frames")
        return 0
    except RuntimeError as e:
        # RuntimeError can be raised by VideoReader for various issues
        # (e.g., file not found, corrupt file, unsupported codec by the backend)
        if "No video stream found" in str(e):
            print(f"Error: No video stream found in video file: {video_path}")
        elif "Could not open" in str(e) or "Demuxing video" in str(e):
            print(f"Error: Could not open or demux video file (possibly unsupported format or corrupted file): {video_path} - {e}")
        else:
            print(f"Runtime error counting video frames: {e}")
        return 0
    except Exception as e:
        print(f"Error counting video frames: {e}")
        return 0
    finally:
        # VideoReader does not have an explicit close() or release() method.
        # It's managed by its destructor when it goes out of scope.
        pass


def early_validation(source_folders, episode_mapping, default_fps=20, fps=None):
    """
    Validate image and video files in the source folders before anything is copied.
    Missing or mismatched videos are re-encoded in place from the images when possible.

    Args:
        source_folders (list): List of source dataset folder paths
        episode_mapping (list): List of tuples containing (old_folder, old_index, new_index)
        default_fps (int): Default frame rate to use if not specified
        fps (int): Frame rate to use for video encoding

    Returns:
        dict: Validation results containing expected frame count and actual image count for each episode
    """
    if fps is None:
        info_path = os.path.join(source_folders[0], "meta", "info.json")
        if os.path.exists(info_path):
            with open(info_path) as f:
                info = json.load(f)
            fps = info.get("fps", default_fps)
        else:
            fps = default_fps

    print(f"Using FPS={fps}")

    # Get video path template and video keys
    info_path = os.path.join(source_folders[0], "meta", "info.json")
    with open(info_path) as f:
        info = json.load(f)

    video_path_template = info["video_path"]
    image_keys = []

    for feature_name, feature_info in info["features"].items():
        if feature_info.get("dtype") == "video":
            image_keys.append(feature_name)

    print(f"Found video/image keys: {image_keys}")

    # Validate first before copying anything
    print("Starting validation of images and videos...")
    validation_results = {}
    validation_failed = False

    # Cache of parsed episodes.jsonl files, keyed by file path
    episode_file_mapping = {}
    for old_folder, old_index, new_index in episode_mapping:
        # Get expected frame count from episodes.jsonl
        episode_file = os.path.join(old_folder, "meta", "episodes.jsonl")
        expected_frames = 0
        if os.path.exists(episode_file):
            if episode_file not in episode_file_mapping:
                episodes = load_jsonl(episode_file)
                episodes = {ep["episode_index"]: ep for ep in episodes}
                episode_file_mapping[episode_file] = episodes
            episode_data = episode_file_mapping[episode_file].get(old_index, None)
            if episode_data and "length" in episode_data:
                expected_frames = episode_data["length"]

        validation_key = f"{old_folder}_{old_index}"
        validation_results[validation_key] = {
            "expected_frames": expected_frames,
            "image_counts": {},
            "video_frames": {},
            "old_index": old_index,
            "new_index": new_index,
            "is_valid": True  # Default to valid
        }

        # Check each image directory and video
        episode_chunk = old_index // info["chunks_size"]
        for image_dir in image_keys:
            # Find the video file
            source_video_path = os.path.join(
                old_folder,
                video_path_template.format(
                    episode_chunk=episode_chunk, video_key=image_dir, episode_index=old_index
                ),
            )
            source_image_dir = os.path.join(old_folder, "images", image_dir, f"episode_{old_index:06d}")
            image_dir_exists = os.path.exists(source_image_dir)
            video_file_exists = os.path.exists(source_video_path)
            if not video_file_exists:
                print(f"{colored('WARNING', 'yellow', attrs=['bold'])}: Video file not found for {image_dir}, episode {old_index} in {old_folder}")
                if image_dir_exists:
                    print("  Image directory exists, encoding video from images.")
                    from lerobot.common.datasets.video_utils import encode_video_frames
                    encode_video_frames(source_image_dir, source_video_path, fps, overwrite=True)
                    print("  Encoded video frames successfully.")
                else:
                    print(f"{colored('ERROR', 'red', attrs=['bold'])}: No video or image directory found for {image_dir}, episode {old_index} in {old_folder}")
                    validation_results[validation_key]["is_valid"] = False
                    validation_failed = True
                    continue

            # Count video frames
            video_frame_count = count_video_frames_torchvision(source_video_path)
            validation_results[validation_key]["video_frames"][image_dir] = video_frame_count

            # Check if image directory exists

            if image_dir_exists:
                # Count image files
                image_files = sorted([f for f in os.listdir(source_image_dir) if f.endswith('.png')])
                images_count = len(image_files)
                validation_results[validation_key]["image_counts"][image_dir] = images_count

                error_msg = f"expected_frames: {expected_frames}, images_count: {images_count}, video_frame_count: {video_frame_count}"
                assert expected_frames > 0 and expected_frames == images_count, (
                    f"{colored('ERROR', 'red', attrs=['bold'])}: Image count should match expected frames for {source_image_dir}.\n {error_msg}"
                )
                assert expected_frames >= video_frame_count, (
                    f"{colored('ERROR', 'red', attrs=['bold'])}: Video frame count should be less than or equal to expected frames for {source_video_path}.\n {error_msg}"
                )
                # Validate frame counts
                if video_frame_count != expected_frames:
                    print(f"{colored('WARNING', 'yellow', attrs=['bold'])}: Video frame count mismatch for {source_video_path}")
                    print(f"  Expected: {expected_frames}, Found: {video_frame_count}")
                    print(f"  Re-encoding video frames from {source_image_dir} to {source_video_path}")

                    from lerobot.common.datasets.video_utils import encode_video_frames
                    encode_video_frames(source_image_dir, source_video_path, fps, overwrite=True)
                    print("  Re-encoded video frames successfully.")


            else:
                print(f"{colored('WARNING', 'yellow', attrs=['bold'])}: No image directory {image_dir} found for episode {old_index} in {old_folder}")
                print("  You can ignore this if you are not using images and your video frame count is equal to expected frames.")
                # If no images directory, the video frames must match expected frames
                if expected_frames > 0 and video_frame_count != expected_frames:
                    print(f"{colored('ERROR', 'red', attrs=['bold'])}: Video frame count mismatch for {source_video_path}")
                    print(f"  Expected: {expected_frames}, Found: {video_frame_count}")

                    validation_results[validation_key]["is_valid"] = False
                    validation_failed = True

    # Print validation summary
    print("\nValidation Results:")
    valid_count = sum(1 for result in validation_results.values() if result["is_valid"])
    print(f"{valid_count} of {len(validation_results)} episodes are valid")

    # If validation failed, report it; this function does not abort the merge itself
    if validation_failed:
        print(colored("Validation failed. Please fix the issues before continuing.", "red", attrs=["bold"]))

    # Return the results promised by the docstring
    return validation_results


def copy_images(source_folders, output_folder, episode_mapping, default_fps=20, fps=None):
    """
    Copy image files from source folders to output folder.
    This function assumes validation has already been performed with early_validation().

    Args:
        source_folders (list): List of source dataset folder paths
        output_folder (str): Output folder path
        episode_mapping (list): List of tuples containing (old_folder, old_index, new_index)
        default_fps (int): Default frame rate to use if not specified
        fps (int): Frame rate to use for video encoding

    Returns:
        int: Number of images copied
    """
    if fps is None:
        info_path = os.path.join(source_folders[0], "meta", "info.json")
        if os.path.exists(info_path):
            with open(info_path) as f:
                info = json.load(f)
            fps = info.get("fps", default_fps)
        else:
            fps = default_fps

    # Get video path template and video keys
    info_path = os.path.join(source_folders[0], "meta", "info.json")
    with open(info_path) as f:
        info = json.load(f)

    video_path_template = info["video_path"]
    image_keys = []

    for feature_name, feature_info in info["features"].items():
        if feature_info.get("dtype") == "video":
            image_keys.append(feature_name)

    # Create image directories in output folder
    os.makedirs(os.path.join(output_folder, "images"), exist_ok=True)

    print(f"Starting to copy images for {len(image_keys)} video keys...")
    total_copied = 0
    skipped_episodes = 0

    # Copy images for each episode
    for old_folder, old_index, new_index in episode_mapping:
        episode_chunk = old_index // info["chunks_size"]
        new_episode_chunk = new_index // info["chunks_size"]

        episode_copied = False

        for image_dir in image_keys:
            # Create target directory for this video key
            os.makedirs(os.path.join(output_folder, "images", image_dir), exist_ok=True)

            # Check if source image directory exists
            source_image_dir = os.path.join(old_folder, "images", image_dir, f"episode_{old_index:06d}")

            if os.path.exists(source_image_dir):
                # Create target directory
                target_image_dir = os.path.join(output_folder, "images", image_dir, f"episode_{new_index:06d}")
                os.makedirs(target_image_dir, exist_ok=True)

                # Copy image files
                image_files = sorted([f for f in os.listdir(source_image_dir) if f.endswith('.png')])
                num_images = len(image_files)

                if num_images > 0:
                    print(f"Copying {num_images} images from {source_image_dir} to {target_image_dir}")

                    for image_file in image_files:
                        try:
                            # Extract frame number from filename
                            frame_part = image_file.split('_')[1] if '_' in image_file else image_file
                            frame_num = int(frame_part.split('.')[0])

                            # Copy the file with consistent naming
                            dest_file = os.path.join(target_image_dir, f"frame_{frame_num:06d}.png")
                            shutil.copy2(
                                os.path.join(source_image_dir, image_file),
                                dest_file
                            )
                            total_copied += 1
                            episode_copied = True
                        except Exception as e:
                            print(f"Error copying image {image_file}: {e}")

        if not episode_copied:
            skipped_episodes += 1

    print(f"\nCopied {total_copied} images for {len(episode_mapping) - skipped_episodes} episodes")
    if skipped_episodes > 0:
        print(f"{colored('WARNING', 'yellow', attrs=['bold'])}: Skipped {skipped_episodes} episodes with no images")

    # Return the count promised by the docstring
    return total_copied


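The chunk arithmetic and path templating shared by both functions can be sketched on their own. The template string below is an assumption modeled on a typical LeRobot v2 `info.json`; the actual value is read from `info["video_path"]` at run time. The names `video_path_for` and `observation.images.top` are hypothetical.

```python
import os

# Assumed template; the real one comes from info["video_path"].
VIDEO_PATH_TEMPLATE = "videos/chunk-{episode_chunk:03d}/{video_key}/episode_{episode_index:06d}.mp4"


def video_path_for(root: str, episode_index: int, video_key: str,
                   chunks_size: int = 1000, template: str = VIDEO_PATH_TEMPLATE) -> str:
    """Resolve the video file for one episode and one video key."""
    # Episodes are grouped into fixed-size chunks, so integer division picks the chunk.
    episode_chunk = episode_index // chunks_size
    relative = template.format(
        episode_chunk=episode_chunk, video_key=video_key, episode_index=episode_index
    )
    return os.path.join(root, relative)


print(video_path_for("dataset", 1234, "observation.images.top"))
```

Because the chunk is derived from the episode index, an episode that is renumbered during merging can land in a different chunk directory than it had in its source dataset.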
def merge_datasets(
    source_folders, output_folder, validate_ts=False, tolerance_s=1e-4, max_dim=18, default_fps=20
):
@@ -910,6 +1226,7 @@ def merge_datasets(
    3. Pads vector dimensions for consistency
    4. Updates metadata files
    5. Copies and processes data and video files
    6. Copies and validates image files
    """
    # Create output folder if it doesn't exist
    os.makedirs(output_folder, exist_ok=True)
@@ -967,6 +1284,8 @@ def merge_datasets(

    # Get chunks_size from info.json
    info_path = os.path.join(source_folders[0], "meta", "info.json")
    # Check if all source folders have images directory
    images_dir_exists = all(os.path.exists(os.path.join(folder, "images")) for folder in source_folders)
    chunks_size = 1000  # default value
    if os.path.exists(info_path):
        with open(info_path) as f:
@@ -1276,6 +1595,13 @@ def merge_datasets(
    with open(os.path.join(output_folder, "meta", "info.json"), "w") as f:
        json.dump(info, f, indent=4)

    # Validate before video copying
    if images_dir_exists:
        early_validation(
            source_folders,
            episode_mapping,
        )

    # Copy video and data files
    copy_videos(source_folders, output_folder, episode_mapping)
    copy_data_files(
@@ -1289,6 +1615,12 @@ def merge_datasets(
        chunks_size=chunks_size,
    )

    # Copy images (frame counts were already checked by early_validation)
    # NOTE: `args` is the module-level namespace parsed under `__main__`;
    # calling merge_datasets() from another module raises NameError here.
    if args.copy_images:
        print("Starting to copy images")
        copy_images(source_folders, output_folder, episode_mapping)


    print(f"Merged {total_episodes} episodes with {total_frames} frames into {output_folder}")


@@ -1301,6 +1633,7 @@ if __name__ == "__main__":
    parser.add_argument("--output", required=True, help="Output folder path")
    parser.add_argument("--max_dim", type=int, default=32, help="Maximum dimension (default: 32)")
    parser.add_argument("--fps", type=int, default=20, help="Your datasets FPS (default: 20)")
    parser.add_argument("--copy_images", action="store_true", help="Whether to copy images from source folders to output folder with validation. (default: False)",)

    # Parse arguments
    args = parser.parse_args()
