mirror of
https://github.com/huggingface/lerobot.git
synced 2026-05-12 07:09:43 +00:00
Compare commits
6 Commits
| Author | SHA1 | Date | |
|---|---|---|---|
| 6216932fb0 | |||
| 5fab1ed5cd | |||
| 8b7c46c5f7 | |||
| ba97f64afd | |||
| 8d861fe94b | |||
| d22fa6446b |
@@ -11,13 +11,14 @@ LeRobot provides several utilities for manipulating datasets:
|
||||
3. **Merge Datasets** - Combine multiple datasets into one. The datasets must have identical features, and episodes are concatenated in the order specified in `repo_ids`
|
||||
4. **Add Features** - Add new features to a dataset
|
||||
5. **Remove Features** - Remove features from a dataset
|
||||
6. **Convert to Video** - Convert image-based datasets to video format for efficient storage
|
||||
|
||||
The core implementation is in `lerobot.datasets.dataset_tools`.
|
||||
An example script detailing how to use the tools API is available in `examples/dataset/use_dataset_tools.py`.
|
||||
|
||||
## Command-Line Tool: lerobot-edit-dataset
|
||||
|
||||
`lerobot-edit-dataset` is a command-line script for editing datasets. It can be used to delete episodes, split datasets, merge datasets, add features, and remove features.
|
||||
`lerobot-edit-dataset` is a command-line script for editing datasets. It can be used to delete episodes, split datasets, merge datasets, add features, remove features, and convert image datasets to video format.
|
||||
|
||||
Run `lerobot-edit-dataset --help` for more information on the configuration of each operation.
|
||||
|
||||
@@ -86,9 +87,71 @@ lerobot-edit-dataset \
|
||||
--operation.feature_names "['observation.images.top']"
|
||||
```
|
||||
|
||||
#### Convert to Video
|
||||
|
||||
Convert an image-based dataset to video format, creating a new LeRobotDataset where images are stored as videos. This is useful for reducing storage requirements and improving data loading performance. The new dataset will have the exact same structure as the original, but with images encoded as MP4 videos in the proper LeRobot format.
|
||||
|
||||
```bash
|
||||
# Local-only: Save to a custom output directory (no hub push)
|
||||
lerobot-edit-dataset \
|
||||
--repo_id lerobot/pusht_image \
|
||||
--operation.type convert_to_video \
|
||||
--operation.output_dir /path/to/output/pusht_video
|
||||
|
||||
# Save with new repo_id (local storage)
|
||||
lerobot-edit-dataset \
|
||||
--repo_id lerobot/pusht_image \
|
||||
--new_repo_id lerobot/pusht_video \
|
||||
--operation.type convert_to_video
|
||||
|
||||
# Convert and push to Hugging Face Hub
|
||||
lerobot-edit-dataset \
|
||||
--repo_id lerobot/pusht_image \
|
||||
--new_repo_id lerobot/pusht_video \
|
||||
--operation.type convert_to_video \
|
||||
--push_to_hub true
|
||||
|
||||
# Convert with custom video codec and quality settings
|
||||
lerobot-edit-dataset \
|
||||
--repo_id lerobot/pusht_image \
|
||||
--operation.type convert_to_video \
|
||||
--operation.output_dir outputs/pusht_video \
|
||||
--operation.vcodec libsvtav1 \
|
||||
--operation.pix_fmt yuv420p \
|
||||
--operation.g 2 \
|
||||
--operation.crf 30
|
||||
|
||||
# Convert only specific episodes
|
||||
lerobot-edit-dataset \
|
||||
--repo_id lerobot/pusht_image \
|
||||
--operation.type convert_to_video \
|
||||
--operation.output_dir outputs/pusht_video \
|
||||
--operation.episode_indices "[0, 1, 2, 5, 10]"
|
||||
|
||||
# Convert with multiple workers for parallel processing
|
||||
lerobot-edit-dataset \
|
||||
--repo_id lerobot/pusht_image \
|
||||
--operation.type convert_to_video \
|
||||
--operation.output_dir outputs/pusht_video \
|
||||
--operation.num_workers 8
|
||||
```
|
||||
|
||||
**Parameters:**
|
||||
|
||||
- `output_dir`: Custom output directory (optional - by default uses `new_repo_id` or `{repo_id}_video`)
|
||||
- `vcodec`: Video codec to use - options: `h264`, `hevc`, `libsvtav1` (default: `libsvtav1`)
|
||||
- `pix_fmt`: Pixel format - options: `yuv420p`, `yuv444p` (default: `yuv420p`)
|
||||
- `g`: Group of pictures (GOP) size - lower values give better quality but larger files (default: 2)
|
||||
- `crf`: Constant rate factor - lower values give better quality but larger files, 0 is lossless (default: 30)
|
||||
- `fast_decode`: Fast decode tuning option (default: 0)
|
||||
- `episode_indices`: List of specific episodes to convert (default: all episodes)
|
||||
- `num_workers`: Number of parallel workers for processing (default: 4)
|
||||
|
||||
**Note:** The resulting dataset will be a proper LeRobotDataset with all cameras encoded as videos in the `videos/` directory, with parquet files containing only metadata (no raw image data). All episodes, stats, and tasks are preserved.
|
||||
|
||||
### Push to Hub
|
||||
|
||||
Add the `--push_to_hub` flag to any command to automatically upload the resulting dataset to the Hugging Face Hub:
|
||||
Add the `--push_to_hub true` flag to any command to automatically upload the resulting dataset to the Hugging Face Hub:
|
||||
|
||||
```bash
|
||||
lerobot-edit-dataset \
|
||||
@@ -96,7 +159,7 @@ lerobot-edit-dataset \
|
||||
--new_repo_id lerobot/pusht_after_deletion \
|
||||
--operation.type delete_episodes \
|
||||
--operation.episode_indices "[0, 2, 5]" \
|
||||
--push_to_hub
|
||||
--push_to_hub true
|
||||
```
|
||||
|
||||
There is also a tool for adding features to a dataset that is not yet covered in `lerobot-edit-dataset`.
|
||||
|
||||
@@ -0,0 +1,20 @@
|
||||
Goal:
|
||||
Create a high and low policy, the high policy take obs + text and return text, its a VLM
|
||||
The low policy is a VLA that take text returned by the high and ouptut actions
|
||||
|
||||
|
||||
High policy is ran every one second or when user send prompt
|
||||
|
||||
|
||||
Synthetic data generation:
|
||||
D demo -> teleop data with a global task annotation
|
||||
D label -> segment data into short skills (one to three seconds)
|
||||
D syn: p-gen will create the high level prompt user might gave to p-hi
|
||||
Given D label prompt p-gen to imagine appropriate action, take images, ALL PRIOR skill labels in the episode: ℓ̂₀, …, ℓ̂ₜ₋₁
|
||||
|
||||
“Given the scene + all previous steps + current needed skill ℓ̂₅,
|
||||
generate a user request that logically leads to ℓ̂₅.”
|
||||
|
||||
Train:
|
||||
Phi(lt| images, global label) cross entropy - next token predictions
|
||||
Plow(At| images, qt, lt)
|
||||
@@ -0,0 +1,141 @@
|
||||
#!/bin/bash
|
||||
#SBATCH --time=24:00:00
|
||||
#SBATCH --partition=hopper-cpu
|
||||
#SBATCH --cpus-per-task=96
|
||||
#SBATCH --output=/fsx/jade_choghari/logs/launcher_%j.out
|
||||
#SBATCH --error=/fsx/jade_choghari/logs/launcher_%j.err
|
||||
# Activate conda environment
|
||||
# Load conda
|
||||
source /fsx/jade_choghari/miniforge3/etc/profile.d/conda.sh
|
||||
conda activate lerobot
|
||||
set -e # Exit on error
|
||||
|
||||
# Input dataset
|
||||
REPO_ID="lerobot"
|
||||
ROOT="/fsx/jade_choghari/vlabench-primitive/"
|
||||
|
||||
# Output paths
|
||||
OUTPUT_DIR="/fsx/jade_choghari/vlabench-primitive-encoded"
|
||||
OUTPUT_REPO_ID="vlabench-primitive-encoded"
|
||||
LOGS_DIR="/fsx/jade_choghari/logs/convert_video"
|
||||
|
||||
# Video encoding settings
|
||||
VCODEC="libsvtav1"
|
||||
PIX_FMT="yuv420p"
|
||||
GOP_SIZE=2
|
||||
CRF=30
|
||||
FAST_DECODE=0
|
||||
|
||||
# Parallelization settings
|
||||
NUM_WORKERS=24 # Number of parallel SLURM workers
|
||||
NUM_IMAGE_WORKERS=4 # Threads per worker for image saving
|
||||
|
||||
# SLURM settings
|
||||
PARTITION="hopper-cpu" # Change to your CPU partition name
|
||||
CPUS_PER_TASK=32 # CPUs per worker
|
||||
MEM_PER_CPU="4G" # Memory per CPU
|
||||
TIME_LIMIT="24:00:00" # Time limit per job
|
||||
|
||||
###############################################################################
|
||||
# STEP 1: Parallel Video Conversion
|
||||
###############################################################################
|
||||
rm -rf "${OUTPUT_DIR}"
|
||||
mkdir -p "${OUTPUT_DIR}"
|
||||
|
||||
echo "=============================================="
|
||||
echo "STEP 1: Starting parallel video conversion"
|
||||
echo " Workers: ${NUM_WORKERS}"
|
||||
echo " Input: ${REPO_ID} (root: ${ROOT})"
|
||||
echo " Output: ${OUTPUT_DIR}"
|
||||
echo "=============================================="
|
||||
|
||||
python /admin/home/jade_choghari/lerobot/examples/port_datasets/slurm_convert_to_video.py\
|
||||
--repo-id "${REPO_ID}" \
|
||||
--root "${ROOT}" \
|
||||
--output-dir "${OUTPUT_DIR}" \
|
||||
--output-repo-id "${OUTPUT_REPO_ID}" \
|
||||
--vcodec "${VCODEC}" \
|
||||
--pix-fmt "${PIX_FMT}" \
|
||||
--g ${GOP_SIZE} \
|
||||
--crf ${CRF} \
|
||||
--fast-decode ${FAST_DECODE} \
|
||||
--num-image-workers ${NUM_IMAGE_WORKERS} \
|
||||
--logs-dir "${LOGS_DIR}" \
|
||||
--job-name "convert_video" \
|
||||
--slurm 1 \
|
||||
--workers ${NUM_WORKERS} \
|
||||
--partition "${PARTITION}" \
|
||||
--cpus-per-task ${CPUS_PER_TASK} \
|
||||
--mem-per-cpu "${MEM_PER_CPU}" \
|
||||
--time-limit "${TIME_LIMIT}"
|
||||
|
||||
echo ""
|
||||
echo "✓ Parallel conversion jobs submitted!"
|
||||
echo " Monitor with: squeue -u \$USER"
|
||||
echo " Check logs in: ${LOGS_DIR}/convert_video"
|
||||
echo ""
|
||||
echo "Wait for all jobs to complete before running Step 2."
|
||||
echo "You can check completion with: squeue -u \$USER | grep convert_video"
|
||||
echo ""
|
||||
echo "After all jobs complete, run Step 2 to aggregate shards:"
|
||||
echo " bash convert_to_video_parallel.sh aggregate"
|
||||
|
||||
###############################################################################
|
||||
# STEP 2: Aggregate Shards (run this after Step 1 completes)
|
||||
###############################################################################
|
||||
|
||||
if [ "$1" == "aggregate" ]; then
|
||||
echo ""
|
||||
echo "=============================================="
|
||||
echo "STEP 2: Aggregating video shards"
|
||||
echo " Shards: ${NUM_WORKERS}"
|
||||
echo " Input: ${OUTPUT_DIR}/shard_XXXX"
|
||||
echo " Output: ${OUTPUT_DIR}_final"
|
||||
echo "=============================================="
|
||||
|
||||
python slurm_aggregate_video_shards.py \
|
||||
--shards-dir "${OUTPUT_DIR}" \
|
||||
--output-dir "${OUTPUT_DIR}_final" \
|
||||
--output-repo-id "${OUTPUT_REPO_ID}" \
|
||||
--num-shards ${NUM_WORKERS} \
|
||||
--logs-dir "${LOGS_DIR}" \
|
||||
--job-name "aggregate_video" \
|
||||
--slurm 1 \
|
||||
--partition "${PARTITION}" \
|
||||
--cpus-per-task 16 \
|
||||
--mem-per-cpu "8G" \
|
||||
--time-limit "08:00:00"
|
||||
|
||||
echo ""
|
||||
echo "✓ Aggregation job submitted!"
|
||||
echo " Monitor with: squeue -u \$USER | grep aggregate_video"
|
||||
echo " Check logs in: ${LOGS_DIR}/aggregate_video"
|
||||
echo ""
|
||||
echo "After completion, your final dataset will be in:"
|
||||
echo " ${OUTPUT_DIR}_final"
|
||||
fi
|
||||
|
||||
###############################################################################
|
||||
# Helpful information
|
||||
###############################################################################
|
||||
|
||||
if [ "$1" != "aggregate" ]; then
|
||||
echo ""
|
||||
echo "=============================================="
|
||||
echo "WORKFLOW SUMMARY"
|
||||
echo "=============================================="
|
||||
echo ""
|
||||
echo "1. Step 1 is now running - it will:"
|
||||
echo " - Split episodes across ${NUM_WORKERS} workers"
|
||||
echo " - Each worker converts its episodes to video"
|
||||
echo " - Creates shard datasets in ${OUTPUT_DIR}/shard_XXXX"
|
||||
echo ""
|
||||
echo "2. After Step 1 completes, run Step 2:"
|
||||
echo " bash convert_to_video_parallel.sh aggregate"
|
||||
echo ""
|
||||
echo "3. Step 2 will merge all shards into a single dataset"
|
||||
echo ""
|
||||
echo "=============================================="
|
||||
fi
|
||||
|
||||
|
||||
@@ -0,0 +1,266 @@
|
||||
#!/usr/bin/env python
|
||||
|
||||
# Copyright 2025 The HuggingFace Inc. team. All rights reserved.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
"""
|
||||
Aggregate video dataset shards into a single dataset.
|
||||
|
||||
After parallel conversion using slurm_convert_to_video.py, this script merges
|
||||
all the shard datasets into one final dataset.
|
||||
|
||||
Example usage:
|
||||
python slurm_aggregate_video_shards.py \
|
||||
--shards-dir /fsx/jade_choghari/libero_video \
|
||||
--output-dir /fsx/jade_choghari/libero_video_final \
|
||||
--output-repo-id lerobot_video \
|
||||
--num-workers 100 \
|
||||
--partition cpu_partition \
|
||||
--cpus-per-task 16
|
||||
"""
|
||||
|
||||
import argparse
|
||||
import logging
|
||||
from pathlib import Path
|
||||
|
||||
from datatrove.executor import LocalPipelineExecutor
|
||||
from datatrove.executor.slurm import SlurmPipelineExecutor
|
||||
from datatrove.pipeline.base import PipelineStep
|
||||
|
||||
|
||||
class AggregateVideoShards(PipelineStep):
|
||||
"""Pipeline step that aggregates video dataset shards."""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
shards_dir: str | Path,
|
||||
output_dir: str | Path,
|
||||
output_repo_id: str,
|
||||
num_shards: int,
|
||||
):
|
||||
super().__init__()
|
||||
self.shards_dir = Path(shards_dir)
|
||||
self.output_dir = Path(output_dir)
|
||||
self.output_repo_id = output_repo_id
|
||||
self.num_shards = num_shards
|
||||
|
||||
def run(self, data=None, rank: int = 0, world_size: int = 1):
|
||||
"""Aggregate all shards into a single dataset."""
|
||||
from lerobot.datasets.dataset_tools import merge_datasets
|
||||
from lerobot.datasets.lerobot_dataset import LeRobotDataset
|
||||
from lerobot.utils.utils import init_logging
|
||||
|
||||
init_logging()
|
||||
|
||||
# Only worker 0 performs aggregation
|
||||
if rank != 0:
|
||||
logging.info(f"Worker {rank} skipping - only worker 0 performs aggregation")
|
||||
return
|
||||
|
||||
logging.info(f"Starting aggregation of {self.num_shards} shards")
|
||||
|
||||
# Collect all shard datasets
|
||||
shard_datasets = []
|
||||
for shard_idx in range(self.num_shards):
|
||||
shard_dir = self.shards_dir / f"shard_{shard_idx:04d}"
|
||||
if not shard_dir.exists():
|
||||
logging.warning(f"Shard directory not found: {shard_dir}")
|
||||
continue
|
||||
|
||||
# Find the repo_id for this shard
|
||||
shard_repo_id = f"{self.output_repo_id}_shard_{shard_idx:04d}"
|
||||
try:
|
||||
shard_dataset = LeRobotDataset(shard_repo_id, root=shard_dir)
|
||||
shard_datasets.append(shard_dataset)
|
||||
logging.info(
|
||||
f"Loaded shard {shard_idx}: {shard_dataset.meta.total_episodes} episodes, "
|
||||
f"{shard_dataset.meta.total_frames} frames"
|
||||
)
|
||||
except Exception as e:
|
||||
logging.error(f"Failed to load shard {shard_idx}: {e}")
|
||||
continue
|
||||
|
||||
if len(shard_datasets) == 0:
|
||||
raise ValueError(f"No valid shards found in {self.shards_dir}")
|
||||
|
||||
logging.info(f"Successfully loaded {len(shard_datasets)} shards, starting merge")
|
||||
|
||||
# Merge all shards
|
||||
self.output_dir.mkdir(parents=True, exist_ok=True)
|
||||
merged_dataset = merge_datasets(
|
||||
shard_datasets,
|
||||
output_repo_id=self.output_repo_id,
|
||||
output_dir=self.output_dir,
|
||||
)
|
||||
|
||||
logging.info("✓ Aggregation complete!")
|
||||
logging.info(f"Merged dataset saved to: {self.output_dir}")
|
||||
logging.info(f"Total episodes: {merged_dataset.meta.total_episodes}")
|
||||
logging.info(f"Total frames: {merged_dataset.meta.total_frames}")
|
||||
|
||||
|
||||
def make_aggregate_executor(
|
||||
shards_dir,
|
||||
output_dir,
|
||||
output_repo_id,
|
||||
num_shards,
|
||||
job_name,
|
||||
logs_dir,
|
||||
partition,
|
||||
cpus_per_task,
|
||||
mem_per_cpu,
|
||||
time_limit,
|
||||
slurm=True,
|
||||
):
|
||||
"""Create executor for shard aggregation."""
|
||||
kwargs = {
|
||||
"pipeline": [
|
||||
AggregateVideoShards(
|
||||
shards_dir=shards_dir,
|
||||
output_dir=output_dir,
|
||||
output_repo_id=output_repo_id,
|
||||
num_shards=num_shards,
|
||||
),
|
||||
],
|
||||
"logging_dir": str(logs_dir / job_name),
|
||||
}
|
||||
|
||||
if slurm:
|
||||
# Only need 1 worker for aggregation
|
||||
kwargs.update(
|
||||
{
|
||||
"job_name": job_name,
|
||||
"tasks": 1,
|
||||
"workers": 1,
|
||||
"time": time_limit,
|
||||
"partition": partition,
|
||||
"cpus_per_task": cpus_per_task,
|
||||
"sbatch_args": {"mem-per-cpu": mem_per_cpu},
|
||||
}
|
||||
)
|
||||
executor = SlurmPipelineExecutor(**kwargs)
|
||||
else:
|
||||
kwargs.update(
|
||||
{
|
||||
"tasks": 1,
|
||||
"workers": 1,
|
||||
}
|
||||
)
|
||||
executor = LocalPipelineExecutor(**kwargs)
|
||||
|
||||
return executor
|
||||
|
||||
|
||||
def main():
|
||||
parser = argparse.ArgumentParser(
|
||||
description="Aggregate video dataset shards into a single dataset",
|
||||
formatter_class=argparse.ArgumentDefaultsHelpFormatter,
|
||||
)
|
||||
|
||||
parser.add_argument(
|
||||
"--shards-dir",
|
||||
type=Path,
|
||||
required=True,
|
||||
help="Directory containing shard_XXXX subdirectories",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--output-dir",
|
||||
type=Path,
|
||||
required=True,
|
||||
help="Output directory for the aggregated dataset",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--output-repo-id",
|
||||
type=str,
|
||||
required=True,
|
||||
help="Repository ID for the aggregated dataset",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--num-shards",
|
||||
type=int,
|
||||
required=True,
|
||||
help="Number of shards to aggregate (should match --workers from conversion)",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--logs-dir",
|
||||
type=Path,
|
||||
required=True,
|
||||
help="Path to logs directory for datatrove",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--job-name",
|
||||
type=str,
|
||||
default="aggregate_video_shards",
|
||||
help="Job name for SLURM",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--slurm",
|
||||
type=int,
|
||||
default=1,
|
||||
help="Launch over SLURM (1) or locally (0)",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--partition",
|
||||
type=str,
|
||||
required=True,
|
||||
help="SLURM partition (use CPU partition)",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--cpus-per-task",
|
||||
type=int,
|
||||
default=16,
|
||||
help="Number of CPUs per task (aggregation can use more CPUs)",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--mem-per-cpu",
|
||||
type=str,
|
||||
default="8G",
|
||||
help="Memory per CPU",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--time-limit",
|
||||
type=str,
|
||||
default="08:00:00",
|
||||
help="Time limit for SLURM job",
|
||||
)
|
||||
|
||||
args = parser.parse_args()
|
||||
|
||||
# Convert slurm flag to boolean
|
||||
slurm = args.slurm == 1
|
||||
|
||||
# Create and run executor
|
||||
executor = make_aggregate_executor(
|
||||
shards_dir=args.shards_dir,
|
||||
output_dir=args.output_dir,
|
||||
output_repo_id=args.output_repo_id,
|
||||
num_shards=args.num_shards,
|
||||
job_name=args.job_name,
|
||||
logs_dir=args.logs_dir,
|
||||
partition=args.partition,
|
||||
cpus_per_task=args.cpus_per_task,
|
||||
mem_per_cpu=args.mem_per_cpu,
|
||||
time_limit=args.time_limit,
|
||||
slurm=slurm,
|
||||
)
|
||||
|
||||
logging.info("Starting shard aggregation")
|
||||
executor.run()
|
||||
logging.info("Aggregation job submitted/completed")
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
main()
|
||||
|
||||
|
||||
@@ -0,0 +1,539 @@
|
||||
#!/usr/bin/env python
|
||||
|
||||
# Copyright 2025 The HuggingFace Inc. team. All rights reserved.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
"""
|
||||
Parallelize video conversion using SLURM and Datatrove.
|
||||
|
||||
This script converts an image-based LeRobot dataset to video format by distributing
|
||||
episodes across multiple workers for parallel processing.
|
||||
|
||||
Example usage:
|
||||
python slurm_convert_to_video.py \
|
||||
--repo-id lerobot \
|
||||
--root /fsx/jade_choghari/libero/ \
|
||||
--output-dir /fsx/jade_choghari/libero_video \
|
||||
--output-repo-id lerobot_video \
|
||||
--workers 100 \
|
||||
--partition cpu_partition \
|
||||
--cpus-per-task 8 \
|
||||
--logs-dir ./logs
|
||||
"""
|
||||
|
||||
import argparse
|
||||
import logging
|
||||
from pathlib import Path
|
||||
|
||||
from datatrove.executor import LocalPipelineExecutor
|
||||
from datatrove.executor.slurm import SlurmPipelineExecutor
|
||||
from datatrove.pipeline.base import PipelineStep
|
||||
|
||||
|
||||
class ConvertEpisodesToVideo(PipelineStep):
|
||||
"""Pipeline step that converts episodes to videos in parallel."""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
repo_id: str,
|
||||
root: str | None,
|
||||
output_dir: str,
|
||||
output_repo_id: str | None,
|
||||
vcodec: str = "libsvtav1",
|
||||
pix_fmt: str = "yuv420p",
|
||||
g: int = 2,
|
||||
crf: int = 30,
|
||||
fast_decode: int = 0,
|
||||
num_image_workers: int = 4,
|
||||
):
|
||||
super().__init__()
|
||||
self.repo_id = repo_id
|
||||
self.root = root
|
||||
self.output_dir = Path(output_dir)
|
||||
self.output_repo_id = output_repo_id or f"{repo_id}_video"
|
||||
self.vcodec = vcodec
|
||||
self.pix_fmt = pix_fmt
|
||||
self.g = g
|
||||
self.crf = crf
|
||||
self.fast_decode = fast_decode
|
||||
self.num_image_workers = num_image_workers
|
||||
|
||||
def run(self, data=None, rank: int = 0, world_size: int = 1):
|
||||
"""Process a shard of episodes."""
|
||||
from datasets.utils.tqdm import disable_progress_bars
|
||||
|
||||
from lerobot.datasets.lerobot_dataset import LeRobotDataset
|
||||
from lerobot.scripts.lerobot_edit_dataset import (
|
||||
encode_episode_videos,
|
||||
)
|
||||
from lerobot.utils.utils import init_logging
|
||||
import logging
|
||||
init_logging()
|
||||
disable_progress_bars()
|
||||
|
||||
logging.info(f"Worker {rank}/{world_size} starting video conversion")
|
||||
|
||||
# Load source dataset
|
||||
dataset = LeRobotDataset(self.repo_id, root=self.root)
|
||||
total_episodes = dataset.meta.total_episodes
|
||||
|
||||
# Determine which episodes this worker processes
|
||||
episodes_per_worker = total_episodes // world_size
|
||||
remainder = total_episodes % world_size
|
||||
|
||||
if rank < remainder:
|
||||
# First 'remainder' workers get one extra episode
|
||||
start_ep = rank * (episodes_per_worker + 1)
|
||||
end_ep = start_ep + episodes_per_worker + 1
|
||||
else:
|
||||
start_ep = remainder * (episodes_per_worker + 1) + (rank - remainder) * episodes_per_worker
|
||||
end_ep = start_ep + episodes_per_worker
|
||||
|
||||
episode_indices = list(range(start_ep, end_ep))
|
||||
|
||||
logging.info(
|
||||
f"Worker {rank} processing episodes {start_ep} to {end_ep-1} ({len(episode_indices)} episodes)"
|
||||
)
|
||||
|
||||
if len(episode_indices) == 0:
|
||||
logging.info(f"Worker {rank} has no episodes to process")
|
||||
return
|
||||
|
||||
# Create shard-specific output directory
|
||||
import shutil
|
||||
shard_output_dir = self.output_dir / f"shard_{rank:04d}"
|
||||
|
||||
# Remove existing directory to avoid conflicts with LeRobotDatasetMetadata.create
|
||||
if shard_output_dir.exists():
|
||||
logging.warning(f"Shard directory {shard_output_dir} already exists, removing it")
|
||||
shutil.rmtree(shard_output_dir)
|
||||
|
||||
# Import conversion function
|
||||
from lerobot.scripts.lerobot_edit_dataset import convert_dataset_to_videos
|
||||
|
||||
logging.info(
|
||||
f"Worker {rank} converting {len(episode_indices)} episodes with codec {self.vcodec}, CRF {self.crf}"
|
||||
)
|
||||
|
||||
# Convert this shard's episodes with remapped indices
|
||||
# We need to remap episode indices to start from 0 for proper file structure
|
||||
self._convert_shard_to_videos(
|
||||
dataset=dataset,
|
||||
shard_output_dir=shard_output_dir,
|
||||
shard_repo_id=f"{self.output_repo_id}_shard_{rank:04d}",
|
||||
episode_indices=episode_indices,
|
||||
)
|
||||
|
||||
logging.info(f"Worker {rank} completed successfully")
|
||||
|
||||
def _convert_shard_to_videos(
|
||||
self,
|
||||
dataset,
|
||||
shard_output_dir,
|
||||
shard_repo_id,
|
||||
episode_indices,
|
||||
):
|
||||
"""Convert a shard's episodes to videos with proper index remapping."""
|
||||
import shutil
|
||||
import pandas as pd
|
||||
from tqdm import tqdm
|
||||
|
||||
from lerobot.datasets.lerobot_dataset import LeRobotDatasetMetadata
|
||||
from lerobot.datasets.utils import write_info, write_stats, write_tasks
|
||||
from lerobot.datasets.video_utils import encode_video_frames, get_video_info
|
||||
from lerobot.scripts.lerobot_edit_dataset import (
|
||||
save_episode_images_for_video,
|
||||
_copy_data_without_images,
|
||||
)
|
||||
from lerobot.utils.constants import OBS_IMAGE
|
||||
|
||||
# Check that it's an image dataset
|
||||
if len(dataset.meta.video_keys) > 0:
|
||||
raise ValueError(
|
||||
f"This operation is for image datasets only. Video dataset provided: {dataset.repo_id}"
|
||||
)
|
||||
|
||||
# Get all image keys
|
||||
hf_dataset = dataset.hf_dataset.with_format(None)
|
||||
img_keys = [key for key in hf_dataset.features if key.startswith(OBS_IMAGE)]
|
||||
|
||||
if len(img_keys) == 0:
|
||||
raise ValueError(f"No image keys found in dataset {dataset.repo_id}")
|
||||
|
||||
logging.info(f"Converting {len(episode_indices)} episodes with {len(img_keys)} cameras")
|
||||
|
||||
# Create new features dict, converting image features to video features
|
||||
new_features = {}
|
||||
for key, value in dataset.meta.features.items():
|
||||
if key not in img_keys:
|
||||
new_features[key] = value
|
||||
else:
|
||||
# Convert image key to video format
|
||||
new_features[key] = value.copy()
|
||||
new_features[key]["dtype"] = "video"
|
||||
|
||||
# Create new metadata for video dataset
|
||||
new_meta = LeRobotDatasetMetadata.create(
|
||||
repo_id=shard_repo_id,
|
||||
fps=dataset.meta.fps,
|
||||
features=new_features,
|
||||
robot_type=dataset.meta.robot_type,
|
||||
root=shard_output_dir,
|
||||
use_videos=True,
|
||||
chunks_size=dataset.meta.chunks_size,
|
||||
data_files_size_in_mb=dataset.meta.data_files_size_in_mb,
|
||||
video_files_size_in_mb=dataset.meta.video_files_size_in_mb,
|
||||
)
|
||||
|
||||
# Create temporary directory for image extraction
|
||||
temp_dir = shard_output_dir / "temp_images"
|
||||
temp_dir.mkdir(parents=True, exist_ok=True)
|
||||
|
||||
# Process each episode with REMAPPED indices (0, 1, 2, ...)
|
||||
all_episode_metadata = []
|
||||
fps = int(dataset.fps)
|
||||
|
||||
try:
|
||||
for new_ep_idx, orig_ep_idx in enumerate(tqdm(episode_indices, desc="Converting episodes")):
|
||||
# Get episode metadata from source using ORIGINAL index
|
||||
src_episode = dataset.meta.episodes[orig_ep_idx]
|
||||
episode_length = src_episode["length"]
|
||||
episode_duration = episode_length / dataset.fps
|
||||
|
||||
video_metadata = {}
|
||||
|
||||
# Encode videos for each camera
|
||||
for img_key in img_keys:
|
||||
# Save images temporarily using ORIGINAL episode index
|
||||
imgs_dir = temp_dir / f"episode_{orig_ep_idx:06d}" / img_key
|
||||
save_episode_images_for_video(
|
||||
dataset, imgs_dir, img_key, orig_ep_idx, self.num_image_workers
|
||||
)
|
||||
|
||||
# Determine chunk and file indices using NEW (remapped) episode index
|
||||
chunk_idx = new_ep_idx // new_meta.chunks_size
|
||||
file_idx = new_ep_idx % new_meta.chunks_size
|
||||
|
||||
# Create video path in the new dataset structure
|
||||
video_path = new_meta.root / new_meta.video_path.format(
|
||||
video_key=img_key, chunk_index=chunk_idx, file_index=file_idx
|
||||
)
|
||||
video_path.parent.mkdir(parents=True, exist_ok=True)
|
||||
|
||||
# Encode video
|
||||
encode_video_frames(
|
||||
imgs_dir=imgs_dir,
|
||||
video_path=video_path,
|
||||
fps=fps,
|
||||
vcodec=self.vcodec,
|
||||
pix_fmt=self.pix_fmt,
|
||||
g=self.g,
|
||||
crf=self.crf,
|
||||
fast_decode=self.fast_decode,
|
||||
overwrite=True,
|
||||
)
|
||||
|
||||
# Clean up temporary images
|
||||
shutil.rmtree(imgs_dir)
|
||||
|
||||
# Store video metadata
|
||||
video_metadata[img_key] = {
|
||||
f"videos/{img_key}/chunk_index": chunk_idx,
|
||||
f"videos/{img_key}/file_index": file_idx,
|
||||
f"videos/{img_key}/from_timestamp": 0.0,
|
||||
f"videos/{img_key}/to_timestamp": episode_duration,
|
||||
}
|
||||
|
||||
# Build episode metadata using NEW index
|
||||
episode_meta = {
|
||||
"episode_index": new_ep_idx,
|
||||
"length": episode_length,
|
||||
"dataset_from_index": new_ep_idx * episode_length,
|
||||
"dataset_to_index": (new_ep_idx + 1) * episode_length,
|
||||
}
|
||||
|
||||
# Add video metadata
|
||||
for img_key in img_keys:
|
||||
episode_meta.update(video_metadata[img_key])
|
||||
|
||||
# Add data chunk/file info
|
||||
if "data/chunk_index" in src_episode:
|
||||
episode_meta["data/chunk_index"] = src_episode["data/chunk_index"]
|
||||
episode_meta["data/file_index"] = src_episode["data/file_index"]
|
||||
|
||||
all_episode_metadata.append(episode_meta)
|
||||
|
||||
# Copy and transform data files (removing image columns)
|
||||
_copy_data_without_images(dataset, new_meta, episode_indices, img_keys)
|
||||
|
||||
# Save episode metadata
|
||||
episodes_df = pd.DataFrame(all_episode_metadata)
|
||||
episodes_path = new_meta.root / "meta" / "episodes" / "chunk-000" / "file-000.parquet"
|
||||
episodes_path.parent.mkdir(parents=True, exist_ok=True)
|
||||
episodes_df.to_parquet(episodes_path, index=False)
|
||||
|
||||
# Update metadata info
|
||||
new_meta.info["total_episodes"] = len(episode_indices)
|
||||
new_meta.info["total_frames"] = sum(ep["length"] for ep in all_episode_metadata)
|
||||
new_meta.info["total_tasks"] = dataset.meta.total_tasks
|
||||
new_meta.info["splits"] = {"train": f"0:{len(episode_indices)}"}
|
||||
|
||||
# Update video info for all image keys using the actual first video file
|
||||
for img_key in img_keys:
|
||||
if not new_meta.features[img_key].get("info", None):
|
||||
# Use the first actually created video file
|
||||
chunk_idx = all_episode_metadata[0][f"videos/{img_key}/chunk_index"]
|
||||
file_idx = all_episode_metadata[0][f"videos/{img_key}/file_index"]
|
||||
video_path = new_meta.root / new_meta.video_path.format(
|
||||
video_key=img_key, chunk_index=chunk_idx, file_index=file_idx
|
||||
)
|
||||
new_meta.info["features"][img_key]["info"] = get_video_info(video_path)
|
||||
|
||||
write_info(new_meta.info, new_meta.root)
|
||||
|
||||
# Copy stats and tasks
|
||||
if dataset.meta.stats is not None:
|
||||
# Remove image stats
|
||||
new_stats = {k: v for k, v in dataset.meta.stats.items() if k not in img_keys}
|
||||
write_stats(new_stats, new_meta.root)
|
||||
|
||||
if dataset.meta.tasks is not None:
|
||||
write_tasks(dataset.meta.tasks, new_meta.root)
|
||||
|
||||
finally:
|
||||
# Clean up temporary directory
|
||||
if temp_dir.exists():
|
||||
shutil.rmtree(temp_dir)
|
||||
|
||||
|
||||
def make_convert_executor(
|
||||
repo_id,
|
||||
root,
|
||||
output_dir,
|
||||
output_repo_id,
|
||||
vcodec,
|
||||
pix_fmt,
|
||||
g,
|
||||
crf,
|
||||
fast_decode,
|
||||
num_image_workers,
|
||||
job_name,
|
||||
logs_dir,
|
||||
workers,
|
||||
partition,
|
||||
cpus_per_task,
|
||||
mem_per_cpu,
|
||||
time_limit,
|
||||
slurm=True,
|
||||
):
|
||||
"""Create executor for parallel video conversion."""
|
||||
kwargs = {
|
||||
"pipeline": [
|
||||
ConvertEpisodesToVideo(
|
||||
repo_id=repo_id,
|
||||
root=root,
|
||||
output_dir=output_dir,
|
||||
output_repo_id=output_repo_id,
|
||||
vcodec=vcodec,
|
||||
pix_fmt=pix_fmt,
|
||||
g=g,
|
||||
crf=crf,
|
||||
fast_decode=fast_decode,
|
||||
num_image_workers=num_image_workers,
|
||||
),
|
||||
],
|
||||
"logging_dir": str(logs_dir / job_name),
|
||||
}
|
||||
|
||||
if slurm:
|
||||
kwargs.update(
|
||||
{
|
||||
"job_name": job_name,
|
||||
"tasks": workers, # Number of parallel tasks = number of workers
|
||||
"workers": workers,
|
||||
"time": time_limit,
|
||||
"partition": partition,
|
||||
"cpus_per_task": cpus_per_task,
|
||||
"sbatch_args": {"mem-per-cpu": mem_per_cpu},
|
||||
}
|
||||
)
|
||||
executor = SlurmPipelineExecutor(**kwargs)
|
||||
else:
|
||||
kwargs.update(
|
||||
{
|
||||
"tasks": workers,
|
||||
"workers": 1, # Local mode: sequential
|
||||
}
|
||||
)
|
||||
executor = LocalPipelineExecutor(**kwargs)
|
||||
|
||||
return executor
|
||||
|
||||
|
||||
def main():
|
||||
parser = argparse.ArgumentParser(
|
||||
description="Parallelize video conversion across SLURM workers",
|
||||
formatter_class=argparse.ArgumentDefaultsHelpFormatter,
|
||||
)
|
||||
|
||||
# Input/Output paths
|
||||
parser.add_argument(
|
||||
"--repo-id",
|
||||
type=str,
|
||||
required=True,
|
||||
help="Repository ID of the source image dataset",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--root",
|
||||
type=str,
|
||||
default=None,
|
||||
help="Root directory containing the source dataset (default: HF cache)",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--output-dir",
|
||||
type=str,
|
||||
required=True,
|
||||
help="Output directory for the video dataset",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--output-repo-id",
|
||||
type=str,
|
||||
default=None,
|
||||
help="Repository ID for output dataset (default: <repo_id>_video)",
|
||||
)
|
||||
|
||||
# Video encoding parameters
|
||||
parser.add_argument(
|
||||
"--vcodec",
|
||||
type=str,
|
||||
default="libsvtav1",
|
||||
help="Video codec",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--pix-fmt",
|
||||
type=str,
|
||||
default="yuv420p",
|
||||
help="Pixel format",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--g",
|
||||
type=int,
|
||||
default=2,
|
||||
help="GOP size (group of pictures)",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--crf",
|
||||
type=int,
|
||||
default=30,
|
||||
help="Constant rate factor (quality)",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--fast-decode",
|
||||
type=int,
|
||||
default=0,
|
||||
help="Fast decode tuning",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--num-image-workers",
|
||||
type=int,
|
||||
default=4,
|
||||
help="Number of threads per worker for saving images",
|
||||
)
|
||||
|
||||
# SLURM parameters
|
||||
parser.add_argument(
|
||||
"--logs-dir",
|
||||
type=Path,
|
||||
required=True,
|
||||
help="Path to logs directory for datatrove",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--job-name",
|
||||
type=str,
|
||||
default="convert_to_video",
|
||||
help="Job name for SLURM",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--slurm",
|
||||
type=int,
|
||||
default=1,
|
||||
help="Launch over SLURM (1) or locally (0)",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--workers",
|
||||
type=int,
|
||||
default=100,
|
||||
help="Number of parallel workers (each processes different episodes)",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--partition",
|
||||
type=str,
|
||||
required=True,
|
||||
help="SLURM partition (use CPU partition)",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--cpus-per-task",
|
||||
type=int,
|
||||
default=8,
|
||||
help="Number of CPUs per SLURM task",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--mem-per-cpu",
|
||||
type=str,
|
||||
default="4G",
|
||||
help="Memory per CPU",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--time-limit",
|
||||
type=str,
|
||||
default="08:00:00",
|
||||
help="Time limit for SLURM job",
|
||||
)
|
||||
|
||||
args = parser.parse_args()
|
||||
|
||||
# Convert slurm flag to boolean
|
||||
slurm = args.slurm == 1
|
||||
|
||||
# Create and run executor
|
||||
executor = make_convert_executor(
|
||||
repo_id=args.repo_id,
|
||||
root=args.root,
|
||||
output_dir=args.output_dir,
|
||||
output_repo_id=args.output_repo_id,
|
||||
vcodec=args.vcodec,
|
||||
pix_fmt=args.pix_fmt,
|
||||
g=args.g,
|
||||
crf=args.crf,
|
||||
fast_decode=args.fast_decode,
|
||||
num_image_workers=args.num_image_workers,
|
||||
job_name=args.job_name,
|
||||
logs_dir=args.logs_dir,
|
||||
workers=args.workers,
|
||||
partition=args.partition,
|
||||
cpus_per_task=args.cpus_per_task,
|
||||
mem_per_cpu=args.mem_per_cpu,
|
||||
time_limit=args.time_limit,
|
||||
slurm=slurm,
|
||||
)
|
||||
|
||||
logging.info(f"Starting video conversion with {args.workers} workers")
|
||||
executor.run()
|
||||
logging.info("All workers submitted/completed")
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
main()
|
||||
|
||||
|
||||
@@ -0,0 +1,7 @@
|
||||
==============================================
|
||||
STEP 1: Starting parallel video conversion
|
||||
Workers: 2
|
||||
Input: lerobot (root: /fsx/jade_choghari/vlabench-primitive/)
|
||||
Output: /fsx/jade_choghari/vlabench-primitive-encoded
|
||||
==============================================
|
||||
python: can't open file '/admin/home/jade_choghari/lerobot/slurm_convert_to_video.py': [Errno 2] No such file or directory
|
||||
@@ -18,7 +18,8 @@
|
||||
Edit LeRobot datasets using various transformation tools.
|
||||
|
||||
This script allows you to delete episodes, split datasets, merge datasets,
|
||||
and remove features. When new_repo_id is specified, creates a new dataset.
|
||||
remove features, and convert image datasets to video format.
|
||||
When new_repo_id is specified, creates a new dataset.
|
||||
|
||||
Usage Examples:
|
||||
|
||||
@@ -65,6 +66,25 @@ Remove camera feature:
|
||||
--operation.type remove_feature \
|
||||
--operation.feature_names "['observation.images.top']"
|
||||
|
||||
Convert image dataset to video format (saves locally):
|
||||
python -m lerobot.scripts.lerobot_edit_dataset \
|
||||
--repo_id lerobot/pusht_image \
|
||||
--operation.type convert_to_video \
|
||||
--operation.output_dir /path/to/output/pusht_video
|
||||
|
||||
Convert image dataset and save with new repo_id:
|
||||
python -m lerobot.scripts.lerobot_edit_dataset \
|
||||
--repo_id lerobot/pusht_image \
|
||||
--new_repo_id lerobot/pusht_video \
|
||||
--operation.type convert_to_video
|
||||
|
||||
Convert and push to hub:
|
||||
python -m lerobot.scripts.lerobot_edit_dataset \
|
||||
--repo_id lerobot/pusht_image \
|
||||
--new_repo_id lerobot/pusht_video \
|
||||
--operation.type convert_to_video \
|
||||
--push_to_hub true
|
||||
|
||||
Using JSON config file:
|
||||
python -m lerobot.scripts.lerobot_edit_dataset \
|
||||
--config_path path/to/edit_config.json
|
||||
@@ -72,9 +92,13 @@ Using JSON config file:
|
||||
|
||||
import logging
|
||||
import shutil
|
||||
from concurrent.futures import ThreadPoolExecutor, as_completed
|
||||
from dataclasses import dataclass
|
||||
from pathlib import Path
|
||||
|
||||
import pandas as pd
|
||||
from tqdm import tqdm
|
||||
|
||||
from lerobot.configs import parser
|
||||
from lerobot.datasets.dataset_tools import (
|
||||
delete_episodes,
|
||||
@@ -82,8 +106,10 @@ from lerobot.datasets.dataset_tools import (
|
||||
remove_feature,
|
||||
split_dataset,
|
||||
)
|
||||
from lerobot.datasets.lerobot_dataset import LeRobotDataset
|
||||
from lerobot.utils.constants import HF_LEROBOT_HOME
|
||||
from lerobot.datasets.lerobot_dataset import LeRobotDataset, LeRobotDatasetMetadata
|
||||
from lerobot.datasets.utils import write_stats, write_tasks
|
||||
from lerobot.datasets.video_utils import encode_video_frames, get_video_info
|
||||
from lerobot.utils.constants import HF_LEROBOT_HOME, OBS_IMAGE
|
||||
from lerobot.utils.utils import init_logging
|
||||
|
||||
|
||||
@@ -111,10 +137,23 @@ class RemoveFeatureConfig:
|
||||
feature_names: list[str] | None = None
|
||||
|
||||
|
||||
@dataclass
|
||||
class ConvertToVideoConfig:
|
||||
type: str = "convert_to_video"
|
||||
output_dir: str | None = None
|
||||
vcodec: str = "libsvtav1"
|
||||
pix_fmt: str = "yuv420p"
|
||||
g: int = 2
|
||||
crf: int = 30
|
||||
fast_decode: int = 0
|
||||
episode_indices: list[int] | None = None
|
||||
num_workers: int = 4
|
||||
|
||||
|
||||
@dataclass
|
||||
class EditDatasetConfig:
|
||||
repo_id: str
|
||||
operation: DeleteEpisodesConfig | SplitConfig | MergeConfig | RemoveFeatureConfig
|
||||
operation: DeleteEpisodesConfig | SplitConfig | MergeConfig | RemoveFeatureConfig | ConvertToVideoConfig
|
||||
root: str | None = None
|
||||
new_repo_id: str | None = None
|
||||
push_to_hub: bool = False
|
||||
@@ -258,6 +297,415 @@ def handle_remove_feature(cfg: EditDatasetConfig) -> None:
|
||||
LeRobotDataset(output_repo_id, root=output_dir).push_to_hub()
|
||||
|
||||
|
||||
def save_episode_images_for_video(
|
||||
dataset: LeRobotDataset,
|
||||
imgs_dir: Path,
|
||||
img_key: str,
|
||||
episode_index: int,
|
||||
num_workers: int = 4,
|
||||
) -> None:
|
||||
"""Save images from a specific episode and camera to disk for video encoding.
|
||||
|
||||
Args:
|
||||
dataset: The LeRobot dataset to extract images from
|
||||
imgs_dir: Directory to save images to
|
||||
img_key: The image key (camera) to extract
|
||||
episode_index: Index of the episode to save
|
||||
num_workers: Number of threads for parallel image saving
|
||||
"""
|
||||
# Create directory
|
||||
imgs_dir.mkdir(parents=True, exist_ok=True)
|
||||
|
||||
# Get dataset without torch format for PIL image access
|
||||
hf_dataset = dataset.hf_dataset.with_format(None)
|
||||
|
||||
# Select only this camera's images
|
||||
imgs_dataset = hf_dataset.select_columns(img_key)
|
||||
|
||||
# Get episode start and end indices
|
||||
from_idx = dataset.meta.episodes["dataset_from_index"][episode_index]
|
||||
to_idx = dataset.meta.episodes["dataset_to_index"][episode_index]
|
||||
|
||||
# Get all items for this episode
|
||||
episode_dataset = imgs_dataset.select(range(from_idx, to_idx))
|
||||
|
||||
# Define function to save a single image
|
||||
def save_single_image(i_item_tuple):
|
||||
i, item = i_item_tuple
|
||||
img = item[img_key]
|
||||
# Use frame-XXXXXX.png format to match encode_video_frames expectations
|
||||
img.save(str(imgs_dir / f"frame-{i:06d}.png"), quality=100)
|
||||
return i
|
||||
|
||||
# Save images with proper naming convention for encode_video_frames (frame-XXXXXX.png)
|
||||
items = list(enumerate(episode_dataset))
|
||||
|
||||
with ThreadPoolExecutor(max_workers=num_workers) as executor:
|
||||
futures = [executor.submit(save_single_image, item) for item in items]
|
||||
for future in as_completed(futures):
|
||||
future.result() # This will raise any exceptions that occurred
|
||||
|
||||
|
||||
def encode_episode_videos(
|
||||
dataset: LeRobotDataset,
|
||||
new_meta: LeRobotDatasetMetadata,
|
||||
episode_index: int,
|
||||
vcodec: str,
|
||||
pix_fmt: str,
|
||||
g: int,
|
||||
crf: int,
|
||||
fast_decode: int,
|
||||
temp_dir: Path,
|
||||
num_image_workers: int = 4,
|
||||
) -> dict[str, dict]:
|
||||
"""Encode videos for a single episode and return video metadata.
|
||||
|
||||
Args:
|
||||
dataset: Source dataset with images
|
||||
new_meta: Metadata object for the new video dataset
|
||||
episode_index: Episode index to process
|
||||
vcodec: Video codec
|
||||
pix_fmt: Pixel format
|
||||
g: Group of pictures size
|
||||
crf: Constant rate factor
|
||||
fast_decode: Fast decode tuning
|
||||
temp_dir: Temporary directory for images
|
||||
num_image_workers: Number of workers for saving images
|
||||
|
||||
Returns:
|
||||
Dictionary mapping video keys to their metadata (chunk_index, file_index, timestamps)
|
||||
"""
|
||||
hf_dataset = dataset.hf_dataset.with_format(None)
|
||||
img_keys = [key for key in hf_dataset.features if key.startswith(OBS_IMAGE)]
|
||||
|
||||
video_metadata = {}
|
||||
fps = int(dataset.fps) # Convert to int for PyAV compatibility
|
||||
episode_length = dataset.meta.episodes["length"][episode_index]
|
||||
episode_duration = episode_length / dataset.fps # Use original fps for duration calculation
|
||||
|
||||
for img_key in img_keys:
|
||||
# Save images temporarily
|
||||
imgs_dir = temp_dir / f"episode_{episode_index:06d}" / img_key
|
||||
save_episode_images_for_video(dataset, imgs_dir, img_key, episode_index, num_image_workers)
|
||||
|
||||
# Determine chunk and file indices
|
||||
# For simplicity, we'll put each episode in its own file
|
||||
chunk_idx = episode_index // new_meta.chunks_size
|
||||
file_idx = episode_index % new_meta.chunks_size
|
||||
|
||||
# Create video path in the new dataset structure
|
||||
video_path = new_meta.root / new_meta.video_path.format(
|
||||
video_key=img_key, chunk_index=chunk_idx, file_index=file_idx
|
||||
)
|
||||
video_path.parent.mkdir(parents=True, exist_ok=True)
|
||||
|
||||
# Encode video
|
||||
encode_video_frames(
|
||||
imgs_dir=imgs_dir,
|
||||
video_path=video_path,
|
||||
fps=fps,
|
||||
vcodec=vcodec,
|
||||
pix_fmt=pix_fmt,
|
||||
g=g,
|
||||
crf=crf,
|
||||
fast_decode=fast_decode,
|
||||
overwrite=True,
|
||||
)
|
||||
|
||||
# Clean up temporary images
|
||||
shutil.rmtree(imgs_dir)
|
||||
|
||||
# Store video metadata
|
||||
video_metadata[img_key] = {
|
||||
f"videos/{img_key}/chunk_index": chunk_idx,
|
||||
f"videos/{img_key}/file_index": file_idx,
|
||||
f"videos/{img_key}/from_timestamp": 0.0,
|
||||
f"videos/{img_key}/to_timestamp": episode_duration,
|
||||
}
|
||||
|
||||
return video_metadata
|
||||
|
||||
|
||||
def convert_dataset_to_videos(
|
||||
dataset: LeRobotDataset,
|
||||
output_dir: Path,
|
||||
repo_id: str | None = None,
|
||||
vcodec: str = "libsvtav1",
|
||||
pix_fmt: str = "yuv420p",
|
||||
g: int = 2,
|
||||
crf: int = 30,
|
||||
fast_decode: int = 0,
|
||||
episode_indices: list[int] | None = None,
|
||||
num_workers: int = 4,
|
||||
) -> LeRobotDataset:
|
||||
"""Convert image-based dataset to video-based dataset.
|
||||
|
||||
Creates a new LeRobotDataset with videos instead of images, following the proper
|
||||
LeRobot dataset structure with videos stored in chunked MP4 files.
|
||||
|
||||
Args:
|
||||
dataset: The source LeRobot dataset with images
|
||||
output_dir: Directory to save the new video dataset
|
||||
repo_id: Repository ID for the new dataset (default: original_id + "_video")
|
||||
vcodec: Video codec (default: libsvtav1)
|
||||
pix_fmt: Pixel format (default: yuv420p)
|
||||
g: Group of pictures size (default: 2)
|
||||
crf: Constant rate factor (default: 30)
|
||||
fast_decode: Fast decode tuning (default: 0)
|
||||
episode_indices: List of episode indices to convert (None = all episodes)
|
||||
num_workers: Number of threads for parallel processing (default: 4)
|
||||
|
||||
Returns:
|
||||
New LeRobotDataset with videos
|
||||
"""
|
||||
# Check that it's an image dataset
|
||||
if len(dataset.meta.video_keys) > 0:
|
||||
raise ValueError(
|
||||
f"This operation is for image datasets only. Video dataset provided: {dataset.repo_id}"
|
||||
)
|
||||
|
||||
# Get all image keys
|
||||
hf_dataset = dataset.hf_dataset.with_format(None)
|
||||
img_keys = [key for key in hf_dataset.features if key.startswith(OBS_IMAGE)]
|
||||
|
||||
if len(img_keys) == 0:
|
||||
raise ValueError(f"No image keys found in dataset {dataset.repo_id}")
|
||||
|
||||
# Determine which episodes to process
|
||||
if episode_indices is None:
|
||||
episode_indices = list(range(dataset.meta.total_episodes))
|
||||
|
||||
if repo_id is None:
|
||||
repo_id = f"{dataset.repo_id}_video"
|
||||
|
||||
logging.info(
|
||||
f"Converting {len(episode_indices)} episodes with {len(img_keys)} cameras from {dataset.repo_id}"
|
||||
)
|
||||
logging.info(f"Video codec: {vcodec}, pixel format: {pix_fmt}, GOP: {g}, CRF: {crf}")
|
||||
|
||||
# Create new features dict, converting image features to video features
|
||||
new_features = {}
|
||||
for key, value in dataset.meta.features.items():
|
||||
if key not in img_keys:
|
||||
new_features[key] = value
|
||||
else:
|
||||
# Convert image key to video format
|
||||
new_features[key] = value.copy()
|
||||
new_features[key]["dtype"] = "video" # Change dtype from "image" to "video"
|
||||
# Video info will be updated after episodes are encoded
|
||||
|
||||
# Create new metadata for video dataset
|
||||
new_meta = LeRobotDatasetMetadata.create(
|
||||
repo_id=repo_id,
|
||||
fps=dataset.meta.fps,
|
||||
features=new_features,
|
||||
robot_type=dataset.meta.robot_type,
|
||||
root=output_dir,
|
||||
use_videos=True,
|
||||
chunks_size=dataset.meta.chunks_size,
|
||||
data_files_size_in_mb=dataset.meta.data_files_size_in_mb,
|
||||
video_files_size_in_mb=dataset.meta.video_files_size_in_mb,
|
||||
)
|
||||
|
||||
# Create temporary directory for image extraction
|
||||
temp_dir = output_dir / "temp_images"
|
||||
temp_dir.mkdir(parents=True, exist_ok=True)
|
||||
|
||||
# Process each episode
|
||||
all_episode_metadata = []
|
||||
|
||||
try:
|
||||
for ep_idx in tqdm(episode_indices, desc="Converting episodes to videos"):
|
||||
# Get episode metadata from source
|
||||
src_episode = dataset.meta.episodes[ep_idx]
|
||||
|
||||
# Encode videos for this episode
|
||||
video_metadata = encode_episode_videos(
|
||||
dataset=dataset,
|
||||
new_meta=new_meta,
|
||||
episode_index=ep_idx,
|
||||
vcodec=vcodec,
|
||||
pix_fmt=pix_fmt,
|
||||
g=g,
|
||||
crf=crf,
|
||||
fast_decode=fast_decode,
|
||||
temp_dir=temp_dir,
|
||||
num_image_workers=num_workers,
|
||||
)
|
||||
|
||||
# Build episode metadata
|
||||
episode_meta = {
|
||||
"episode_index": ep_idx,
|
||||
"length": src_episode["length"],
|
||||
"dataset_from_index": ep_idx * src_episode["length"],
|
||||
"dataset_to_index": (ep_idx + 1) * src_episode["length"],
|
||||
}
|
||||
|
||||
# Add video metadata
|
||||
for img_key in img_keys:
|
||||
episode_meta.update(video_metadata[img_key])
|
||||
|
||||
# Add data chunk/file info (using same structure as source)
|
||||
if "data/chunk_index" in src_episode:
|
||||
episode_meta["data/chunk_index"] = src_episode["data/chunk_index"]
|
||||
episode_meta["data/file_index"] = src_episode["data/file_index"]
|
||||
|
||||
all_episode_metadata.append(episode_meta)
|
||||
|
||||
# Copy and transform data files (removing image columns)
|
||||
_copy_data_without_images(dataset, new_meta, episode_indices, img_keys)
|
||||
|
||||
# Save episode metadata
|
||||
episodes_df = pd.DataFrame(all_episode_metadata)
|
||||
episodes_path = new_meta.root / "meta" / "episodes" / "chunk-000" / "file-000.parquet"
|
||||
episodes_path.parent.mkdir(parents=True, exist_ok=True)
|
||||
episodes_df.to_parquet(episodes_path, index=False)
|
||||
|
||||
# Update metadata info
|
||||
new_meta.info["total_episodes"] = len(episode_indices)
|
||||
new_meta.info["total_frames"] = sum(ep["length"] for ep in all_episode_metadata)
|
||||
new_meta.info["total_tasks"] = dataset.meta.total_tasks
|
||||
new_meta.info["splits"] = {"train": f"0:{len(episode_indices)}"}
|
||||
|
||||
# Update video info for all image keys (now videos)
|
||||
# We need to manually set video info since update_video_info() checks video_keys first
|
||||
for img_key in img_keys:
|
||||
if not new_meta.features[img_key].get("info", None):
|
||||
video_path = new_meta.root / new_meta.video_path.format(
|
||||
video_key=img_key, chunk_index=0, file_index=0
|
||||
)
|
||||
new_meta.info["features"][img_key]["info"] = get_video_info(video_path)
|
||||
|
||||
from lerobot.datasets.utils import write_info
|
||||
|
||||
write_info(new_meta.info, new_meta.root)
|
||||
|
||||
# Copy stats and tasks
|
||||
if dataset.meta.stats is not None:
|
||||
# Remove image stats
|
||||
new_stats = {k: v for k, v in dataset.meta.stats.items() if k not in img_keys}
|
||||
write_stats(new_stats, new_meta.root)
|
||||
|
||||
if dataset.meta.tasks is not None:
|
||||
write_tasks(dataset.meta.tasks, new_meta.root)
|
||||
|
||||
finally:
|
||||
# Clean up temporary directory
|
||||
if temp_dir.exists():
|
||||
shutil.rmtree(temp_dir)
|
||||
|
||||
logging.info(f"✓ Completed converting {dataset.repo_id} to video format")
|
||||
logging.info(f"New dataset saved to: {output_dir}")
|
||||
|
||||
# Return new dataset
|
||||
return LeRobotDataset(repo_id=repo_id, root=output_dir)
|
||||
|
||||
|
||||
def _copy_data_without_images(
|
||||
src_dataset: LeRobotDataset,
|
||||
dst_meta: LeRobotDatasetMetadata,
|
||||
episode_indices: list[int],
|
||||
img_keys: list[str],
|
||||
) -> None:
|
||||
"""Copy data files without image columns.
|
||||
|
||||
Args:
|
||||
src_dataset: Source dataset
|
||||
dst_meta: Destination metadata
|
||||
episode_indices: Episodes to include
|
||||
img_keys: Image keys to remove
|
||||
"""
|
||||
from lerobot.datasets.utils import DATA_DIR
|
||||
|
||||
data_dir = src_dataset.root / DATA_DIR
|
||||
parquet_files = sorted(data_dir.glob("*/*.parquet"))
|
||||
|
||||
if not parquet_files:
|
||||
raise ValueError(f"No parquet files found in {data_dir}")
|
||||
|
||||
episode_set = set(episode_indices)
|
||||
|
||||
for src_path in tqdm(parquet_files, desc="Processing data files"):
|
||||
df = pd.read_parquet(src_path).reset_index(drop=True)
|
||||
|
||||
# Filter to only include selected episodes
|
||||
df = df[df["episode_index"].isin(episode_set)].copy()
|
||||
|
||||
if len(df) == 0:
|
||||
continue
|
||||
|
||||
# Remove image columns
|
||||
columns_to_drop = [col for col in img_keys if col in df.columns]
|
||||
if columns_to_drop:
|
||||
df = df.drop(columns=columns_to_drop)
|
||||
|
||||
# Get chunk and file indices from path
|
||||
relative_path = src_path.relative_to(src_dataset.root)
|
||||
chunk_dir = relative_path.parts[1]
|
||||
file_name = relative_path.parts[2]
|
||||
chunk_idx = int(chunk_dir.split("-")[1])
|
||||
file_idx = int(file_name.split("-")[1].split(".")[0])
|
||||
|
||||
# Write to destination without pandas index
|
||||
dst_path = dst_meta.root / f"data/chunk-{chunk_idx:03d}/file-{file_idx:03d}.parquet"
|
||||
dst_path.parent.mkdir(parents=True, exist_ok=True)
|
||||
df.to_parquet(dst_path, index=False)
|
||||
|
||||
|
||||
def handle_convert_to_video(cfg: EditDatasetConfig) -> None:
|
||||
# Note: Parser may create any config type with the right fields, so we access fields directly
|
||||
# instead of checking isinstance()
|
||||
dataset = LeRobotDataset(cfg.repo_id, root=cfg.root)
|
||||
|
||||
# Determine output directory and repo_id
|
||||
# Priority: 1) new_repo_id, 2) operation.output_dir, 3) auto-generated name
|
||||
output_dir_config = getattr(cfg.operation, "output_dir", None)
|
||||
|
||||
if cfg.new_repo_id:
|
||||
# Use new_repo_id for both local storage and hub push
|
||||
output_repo_id = cfg.new_repo_id
|
||||
output_dir = Path(cfg.root) / cfg.new_repo_id if cfg.root else HF_LEROBOT_HOME / cfg.new_repo_id
|
||||
logging.info(f"Saving to new dataset: {cfg.new_repo_id}")
|
||||
elif output_dir_config:
|
||||
# Use custom output directory for local-only storage
|
||||
output_dir = Path(output_dir_config)
|
||||
# Extract repo name from output_dir for the dataset
|
||||
output_repo_id = output_dir.name
|
||||
logging.info(f"Saving to local directory: {output_dir}")
|
||||
else:
|
||||
# Auto-generate name: append "_video" to original repo_id
|
||||
output_repo_id = f"{cfg.repo_id}_video"
|
||||
output_dir = Path(cfg.root) / output_repo_id if cfg.root else HF_LEROBOT_HOME / output_repo_id
|
||||
logging.info(f"Saving to auto-generated location: {output_dir}")
|
||||
|
||||
logging.info(f"Converting dataset {cfg.repo_id} to video format")
|
||||
|
||||
new_dataset = convert_dataset_to_videos(
|
||||
dataset=dataset,
|
||||
output_dir=output_dir,
|
||||
repo_id=output_repo_id,
|
||||
vcodec=getattr(cfg.operation, "vcodec", "libsvtav1"),
|
||||
pix_fmt=getattr(cfg.operation, "pix_fmt", "yuv420p"),
|
||||
g=getattr(cfg.operation, "g", 2),
|
||||
crf=getattr(cfg.operation, "crf", 30),
|
||||
fast_decode=getattr(cfg.operation, "fast_decode", 0),
|
||||
episode_indices=getattr(cfg.operation, "episode_indices", None),
|
||||
num_workers=getattr(cfg.operation, "num_workers", 4),
|
||||
)
|
||||
|
||||
logging.info("Video dataset created successfully!")
|
||||
logging.info(f"Location: {output_dir}")
|
||||
logging.info(f"Episodes: {new_dataset.meta.total_episodes}")
|
||||
logging.info(f"Frames: {new_dataset.meta.total_frames}")
|
||||
|
||||
if cfg.push_to_hub:
|
||||
logging.info(f"Pushing to hub as {output_repo_id}...")
|
||||
new_dataset.push_to_hub()
|
||||
logging.info("✓ Successfully pushed to hub!")
|
||||
else:
|
||||
logging.info("Dataset saved locally (not pushed to hub)")
|
||||
|
||||
|
||||
@parser.wrap()
|
||||
def edit_dataset(cfg: EditDatasetConfig) -> None:
|
||||
operation_type = cfg.operation.type
|
||||
@@ -270,10 +718,12 @@ def edit_dataset(cfg: EditDatasetConfig) -> None:
|
||||
handle_merge(cfg)
|
||||
elif operation_type == "remove_feature":
|
||||
handle_remove_feature(cfg)
|
||||
elif operation_type == "convert_to_video":
|
||||
handle_convert_to_video(cfg)
|
||||
else:
|
||||
raise ValueError(
|
||||
f"Unknown operation type: {operation_type}\n"
|
||||
f"Available operations: delete_episodes, split, merge, remove_feature"
|
||||
f"Available operations: delete_episodes, split, merge, remove_feature, convert_to_video"
|
||||
)
|
||||
|
||||
|
||||
|
||||
@@ -29,6 +29,7 @@ from lerobot.datasets.dataset_tools import (
|
||||
remove_feature,
|
||||
split_dataset,
|
||||
)
|
||||
from lerobot.scripts.lerobot_edit_dataset import convert_dataset_to_videos
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
@@ -1047,3 +1048,107 @@ def test_modify_features_preserves_file_structure(sample_dataset, tmp_path):
|
||||
assert new_chunk_indices == original_chunk_indices, "Chunk indices should be preserved"
|
||||
assert new_file_indices == original_file_indices, "File indices should be preserved"
|
||||
assert "reward" in modified_dataset.meta.features
|
||||
|
||||
|
||||
def test_convert_dataset_to_videos(tmp_path):
|
||||
"""Test converting lerobot/pusht_image dataset to video format."""
|
||||
from lerobot.datasets.lerobot_dataset import LeRobotDataset
|
||||
|
||||
# Load the actual lerobot/pusht_image dataset (only first 2 episodes for speed)
|
||||
source_dataset = LeRobotDataset("lerobot/pusht_image", episodes=[0, 1])
|
||||
|
||||
output_dir = tmp_path / "pusht_video"
|
||||
|
||||
with (
|
||||
patch("lerobot.datasets.lerobot_dataset.get_safe_version") as mock_get_safe_version,
|
||||
patch("lerobot.datasets.lerobot_dataset.snapshot_download") as mock_snapshot_download,
|
||||
):
|
||||
mock_get_safe_version.return_value = "v3.0"
|
||||
mock_snapshot_download.return_value = str(output_dir)
|
||||
|
||||
# Verify source dataset has images, not videos
|
||||
assert len(source_dataset.meta.video_keys) == 0
|
||||
assert "observation.image" in source_dataset.meta.features
|
||||
|
||||
# Convert to video dataset (only first 2 episodes for speed)
|
||||
video_dataset = convert_dataset_to_videos(
|
||||
dataset=source_dataset,
|
||||
output_dir=output_dir,
|
||||
repo_id="lerobot/pusht_video",
|
||||
vcodec="libsvtav1",
|
||||
pix_fmt="yuv420p",
|
||||
g=2,
|
||||
crf=30,
|
||||
episode_indices=[0, 1],
|
||||
num_workers=2,
|
||||
)
|
||||
|
||||
# Verify new dataset has videos
|
||||
assert len(video_dataset.meta.video_keys) > 0
|
||||
assert "observation.image" in video_dataset.meta.video_keys
|
||||
|
||||
# Verify correct number of episodes and frames (2 episodes)
|
||||
assert video_dataset.meta.total_episodes == 2
|
||||
# Compare against the actual number of frames in the loaded episodes, not metadata total
|
||||
assert len(video_dataset) == len(source_dataset)
|
||||
|
||||
# Verify video files exist
|
||||
for ep_idx in range(video_dataset.meta.total_episodes):
|
||||
for video_key in video_dataset.meta.video_keys:
|
||||
video_path = video_dataset.root / video_dataset.meta.get_video_file_path(ep_idx, video_key)
|
||||
assert video_path.exists(), f"Video file should exist: {video_path}"
|
||||
|
||||
# Verify we can load the dataset and access it
|
||||
assert len(video_dataset) == video_dataset.meta.total_frames
|
||||
|
||||
# Test that we can actually get an item from the video dataset
|
||||
item = video_dataset[0]
|
||||
assert "observation.image" in item
|
||||
assert "action" in item
|
||||
|
||||
# Cleanup
|
||||
import shutil
|
||||
|
||||
if output_dir.exists():
|
||||
shutil.rmtree(output_dir)
|
||||
|
||||
|
||||
def test_convert_dataset_to_videos_subset_episodes(tmp_path):
|
||||
"""Test converting only specific episodes from lerobot/pusht_image to video format."""
|
||||
from lerobot.datasets.lerobot_dataset import LeRobotDataset
|
||||
|
||||
# Load the actual lerobot/pusht_image dataset (only first 3 episodes)
|
||||
source_dataset = LeRobotDataset("lerobot/pusht_image", episodes=[0, 1, 2])
|
||||
|
||||
output_dir = tmp_path / "pusht_video_subset"
|
||||
|
||||
with (
|
||||
patch("lerobot.datasets.lerobot_dataset.get_safe_version") as mock_get_safe_version,
|
||||
patch("lerobot.datasets.lerobot_dataset.snapshot_download") as mock_snapshot_download,
|
||||
):
|
||||
mock_get_safe_version.return_value = "v3.0"
|
||||
mock_snapshot_download.return_value = str(output_dir)
|
||||
|
||||
# Convert only episode 0 to video (subset of loaded episodes)
|
||||
episode_indices = [0]
|
||||
|
||||
video_dataset = convert_dataset_to_videos(
|
||||
dataset=source_dataset,
|
||||
output_dir=output_dir,
|
||||
repo_id="lerobot/pusht_video_subset",
|
||||
episode_indices=episode_indices,
|
||||
num_workers=2,
|
||||
)
|
||||
|
||||
# Verify correct number of episodes
|
||||
assert video_dataset.meta.total_episodes == len(episode_indices)
|
||||
|
||||
# Verify video files exist for selected episodes
|
||||
assert len(video_dataset.meta.video_keys) > 0
|
||||
assert "observation.image" in video_dataset.meta.video_keys
|
||||
|
||||
# Cleanup
|
||||
import shutil
|
||||
|
||||
if output_dir.exists():
|
||||
shutil.rmtree(output_dir)
|
||||
|
||||
Reference in New Issue
Block a user