fix(precommit) solve precommit issues

This commit is contained in:
Michel Aractingi
2025-06-30 17:24:43 +02:00
parent bb85f4ebea
commit 67485b1edc
7 changed files with 105 additions and 265 deletions
@@ -1,184 +0,0 @@
import json
import logging
from pathlib import Path
import shutil
from huggingface_hub import snapshot_download
import tarfile
import tqdm
from examples.port_datasets.agibot_hdf5.port_agibot import port_agibot
from lerobot.common.constants import HF_LEROBOT_HOME
from lerobot.common.datasets.lerobot_dataset import LeRobotDataset
from lerobot.common.utils.utils import init_logging
from huggingface_hub import HfApi, HfFileSystem
RAW_REPO_ID = "agibot-world/AgiBotWorld-Alpha"
def download(raw_dir, allow_patterns=None, ignore_patterns=None):
snapshot_download(
RAW_REPO_ID,
repo_type="dataset",
local_dir=str(raw_dir),
allow_patterns=allow_patterns,
ignore_patterns=ignore_patterns,
)
def download_proprio_stats(raw_dir):
proprio_stats_dir = raw_dir / "proprio_stats"
if proprio_stats_dir.exists():
logging.info("Skipping download proprio stats")
return
download(raw_dir, allow_patterns="proprio_stats/*.tar")
for path in proprio_stats_dir.glob("*.tar"):
logging.info(f"Untar-ing {path}...")
with tarfile.open(path, 'r') as tar:
tar.extractall(path=proprio_stats_dir)
logging.info(f"Deleting {path}...")
path.unlink()
def download_parameters(raw_dir):
params_dir = raw_dir / "parameters"
if params_dir.exists():
logging.info("Skipping download parameters")
return
download(raw_dir, allow_patterns="parameters/*.tar")
for path in params_dir.glob("*.tar"):
logging.info(f"Untar-ing {path}...")
with tarfile.open(path, 'r') as tar:
tar.extractall(path=params_dir)
logging.info(f"Deleting {path}...")
path.unlink()
def get_observations_files(raw_dir, raw_repo_id):
files_json_path = raw_dir / "observations_files.json"
sizes_json_path = raw_dir / "observations_sizes.json"
if files_json_path.exists() and sizes_json_path.exists():
with open(files_json_path) as f:
files = json.load(f)
with open(sizes_json_path) as f:
sizes = json.load(f)
return files, sizes
api = HfApi()
files = api.list_repo_files(repo_id=raw_repo_id, repo_type="dataset")
files = [file for file in files if "observations/" in file]
fs = HfFileSystem()
sizes = []
for file in tqdm.tqdm(files, desc="Downloading file sizes"):
file_info = fs.info(f"datasets/{raw_repo_id}/{file}")
size = file_info["size"] / 1000**3
sizes.append(size)
# Sort ASC to start with smaller size files
sizes, files = zip(*sorted(zip(sizes, files)))
with open(files_json_path, "w") as f:
json.dump(files, f)
with open(sizes_json_path, "w") as f:
json.dump(sizes, f)
return files, sizes
def display_observations_sizes(files, sizes):
size_per_task = {}
for i, (file, size) in enumerate(zip(files, sizes)):
logging.info(f"{i}/{len(files)}: {file} {size:.2f}GB")
task = int(file.split('/')[1])
if task not in size_per_task:
size_per_task[task] = 0
size_per_task[task] += size
for task, size in size_per_task.items():
logging.info(f"{task} {size:.2f}GB")
total_size = sum(list(size_per_task.values()))
logging.info(f"Total size: {total_size:.2f}GB")
def download_meta_data(raw_dir):
# Download task data
download(raw_dir, allow_patterns="task_info/task_*.json")
# Download all camera parameters ~170 GB
download_parameters(raw_dir)
# Download all proprio stats ~26 GB
download_proprio_stats(raw_dir)
def no_depth(tarinfo, path):
""" Utility to not untar depth data"""
if "depth" in tarinfo.name:
return None
return tarinfo
def main():
init_logging()
repo_id = "cadene/agibot_alpha_v30"
raw_dir = Path("/fsx/remi_cadene/data/AgiBotWorld-Alpha")
download_meta_data(raw_dir)
# Get list of tar files containing observation data (containing several episodes each)
obs_files, obs_sizes = get_observations_files(raw_dir, RAW_REPO_ID)
display_observations_sizes(obs_files, obs_sizes)
shard_indices = range(len(obs_files))
num_shards = len(obs_files)
# TOOD: remove
obs_files = obs_files[:2]
shard_indices = [0,1]
# Iterate on each subset of episodes
for shard_index, obs_file in zip(shard_indices, obs_files):
shard_repo_id = f"{repo_id}_world_{num_shards}_rank_{shard_index}"
dataset_dir = HF_LEROBOT_HOME / shard_repo_id
if dataset_dir.exists():
shutil.rmtree(dataset_dir)
# Download subset
download(raw_dir, allow_patterns=obs_file)
tar_path = raw_dir / obs_file
with tarfile.open(tar_path, 'r') as tar:
extracted_files = tar.getnames()
task_index = int(tar_path.parent.name)
episode_names = [int(p) for p in extracted_files if '/' not in p]
# Untar if needed
if not all([(tar_path.parent / f"{ep_name}").exists() for ep_name in episode_names]):
logging.info(f"Untar-ing {tar_path}...")
with tarfile.open(tar_path, 'r') as tar:
tar.extractall(path=tar_path.parent, filter=no_depth)
port_agibot(raw_dir, shard_repo_id, task_index, episode_names, push_to_hub=False)
for ep_name in episode_names:
shutil.rmtree(tar_path.parent / f"{ep_name}")
tar_path.unlink()
# dataset = LeRobotDataset(shard_repo_id, root=dataset_dir)
# lol=1
if __name__ == "__main__":
main()
@@ -1,18 +1,27 @@
import json import json
import logging import logging
from pathlib import Path
import shutil import shutil
import time import time
import numpy as np from pathlib import Path
import h5py import h5py
import numpy as np
import pandas as pd import pandas as pd
from lerobot.common.datasets.lerobot_dataset import LeRobotDataset from lerobot.common.datasets.lerobot_dataset import LeRobotDataset
from lerobot.common.datasets.utils import DEFAULT_CHUNK_SIZE, DEFAULT_VIDEO_FILE_SIZE_IN_MB, DEFAULT_VIDEO_PATH, EPISODES_DIR, concat_video_files, get_video_duration_in_s, get_video_size_in_mb, update_chunk_file_indices, write_info from lerobot.common.datasets.utils import (
DEFAULT_CHUNK_SIZE,
DEFAULT_VIDEO_FILE_SIZE_IN_MB,
DEFAULT_VIDEO_PATH,
EPISODES_DIR,
concat_video_files,
get_video_duration_in_s,
get_video_size_in_mb,
update_chunk_file_indices,
write_info,
)
from lerobot.common.utils.utils import get_elapsed_time_in_days_hours_minutes_seconds from lerobot.common.utils.utils import get_elapsed_time_in_days_hours_minutes_seconds
AGIBOT_FPS = 30 AGIBOT_FPS = 30
AGIBOT_ROBOT_TYPE = "AgiBot_A2D" AGIBOT_ROBOT_TYPE = "AgiBot_A2D"
AGIBOT_FEATURES = { AGIBOT_FEATURES = {
@@ -77,12 +86,12 @@ AGIBOT_FEATURES = {
"dtype": "float32", "dtype": "float32",
"shape": (20,), "shape": (20,),
"names": { "names": {
"axes": ["head_yaw", "head_pitch"] + \ "axes": ["head_yaw", "head_pitch"]
[f"left_joint_{i}" for i in range(7)] + \ + [f"left_joint_{i}" for i in range(7)]
["left_gripper"] + \ + ["left_gripper"]
[f"right_joint_{i}" for i in range(7)] + \ + [f"right_joint_{i}" for i in range(7)]
["right_gripper"] + \ + ["right_gripper"]
["waist_pitch", "waist_lift"], + ["waist_pitch", "waist_lift"],
}, },
}, },
# gripper open range in mm (0 for pull open, 1 for full close) # gripper open range in mm (0 for pull open, 1 for full close)
@@ -145,13 +154,13 @@ AGIBOT_FEATURES = {
"dtype": "float32", "dtype": "float32",
"shape": (22,), "shape": (22,),
"names": { "names": {
"axes": ["head_yaw", "head_pitch"] + \ "axes": ["head_yaw", "head_pitch"]
[f"left_joint_{i}" for i in range(7)] + \ + [f"left_joint_{i}" for i in range(7)]
["left_gripper"] + \ + ["left_gripper"]
[f"right_joint_{i}" for i in range(7)] + \ + [f"right_joint_{i}" for i in range(7)]
["right_gripper"] + \ + ["right_gripper"]
["waist_pitch", "waist_lift"] + \ + ["waist_pitch", "waist_lift"]
["velocity_x", "yaw_rate"], + ["velocity_x", "yaw_rate"],
}, },
}, },
# episode level annotation # episode level annotation
@@ -217,11 +226,12 @@ AGIBOT_IMAGES_FEATURES = {
}, },
} }
def load_info_per_task(raw_dir): def load_info_per_task(raw_dir):
info_per_task = {} info_per_task = {}
task_info_dir = raw_dir / "task_info" task_info_dir = raw_dir / "task_info"
for path in task_info_dir.glob("task_*.json"): for path in task_info_dir.glob("task_*.json"):
task_index = int(path.name.replace("task_","").replace(".json","")) task_index = int(path.name.replace("task_", "").replace(".json", ""))
with open(path) as f: with open(path) as f:
task_info = json.load(f) task_info = json.load(f)
@@ -230,6 +240,7 @@ def load_info_per_task(raw_dir):
return info_per_task return info_per_task
def create_frame_idx_to_frames_label_idx(ep_info): def create_frame_idx_to_frames_label_idx(ep_info):
frame_idx_to_frames_label_idx = {} frame_idx_to_frames_label_idx = {}
for label_idx, frames_label in enumerate(ep_info["label_info"]["action_config"]): for label_idx, frames_label in enumerate(ep_info["label_info"]["action_config"]):
@@ -237,9 +248,9 @@ def create_frame_idx_to_frames_label_idx(ep_info):
frame_idx_to_frames_label_idx[frame_idx] = label_idx frame_idx_to_frames_label_idx[frame_idx] = label_idx
return frame_idx_to_frames_label_idx return frame_idx_to_frames_label_idx
def generate_lerobot_frames(raw_dir: Path, task_index: int, episode_index: int): def generate_lerobot_frames(raw_dir: Path, task_index: int, episode_index: int):
""" /!\ The frames dont contain observation.cameras.* r"""/!\ The frames dont contain observation.cameras.*"""
"""
info_per_task = load_info_per_task(raw_dir) info_per_task = load_info_per_task(raw_dir)
ep_info = info_per_task[task_index][episode_index] ep_info = info_per_task[task_index][episode_index]
frame_idx_to_frames_label_idx = create_frame_idx_to_frames_label_idx(ep_info) frame_idx_to_frames_label_idx = create_frame_idx_to_frames_label_idx(ep_info)
@@ -297,7 +308,9 @@ def generate_lerobot_frames(raw_dir: Path, task_index: int, episode_index: int):
for h5_key in keys_mapping.values(): for h5_key in keys_mapping.values():
col_num_frames = h5[h5_key].shape[0] col_num_frames = h5[h5_key].shape[0]
if col_num_frames != num_frames: if col_num_frames != num_frames:
raise ValueError(f"HDF5 column '{h5_key}' is expected to have {num_frames} but has {col_num_frames}' frames instead.") raise ValueError(
f"HDF5 column '{h5_key}' is expected to have {num_frames} but has {col_num_frames}' frames instead."
)
for i in range(num_frames): for i in range(num_frames):
# Create frame # Create frame
@@ -308,26 +321,30 @@ def generate_lerobot_frames(raw_dir: Path, task_index: int, episode_index: int):
f["observation.state.end.position"] = f["observation.state.end.position"].reshape(6) f["observation.state.end.position"] = f["observation.state.end.position"].reshape(6)
f["observation.state.end.orientation"] = f["observation.state.end.orientation"].reshape(8) f["observation.state.end.orientation"] = f["observation.state.end.orientation"].reshape(8)
f["observation.state"] = np.concatenate([ f["observation.state"] = np.concatenate(
f["observation.state.head.position"], [
f["observation.state.joint.position"][:7], # left f["observation.state.head.position"],
f["observation.state.effector.position"][[0]], # left f["observation.state.joint.position"][:7], # left
f["observation.state.joint.position"][7:], # right f["observation.state.effector.position"][[0]], # left
f["observation.state.effector.position"][[1]], # right f["observation.state.joint.position"][7:], # right
f["observation.state.waist.position"], f["observation.state.effector.position"][[1]], # right
]) f["observation.state.waist.position"],
]
)
f["action.end.position"] = f["action.end.position"].reshape(6) f["action.end.position"] = f["action.end.position"].reshape(6)
f["action.end.orientation"] = f["action.end.orientation"].reshape(8) f["action.end.orientation"] = f["action.end.orientation"].reshape(8)
f["action"] = np.concatenate([ f["action"] = np.concatenate(
f["action.head.position"], [
f["action.joint.position"][:7], # left f["action.head.position"],
f["action.effector.position"][[0]], # left f["action.joint.position"][:7], # left
f["action.joint.position"][7:], # right f["action.effector.position"][[0]], # left
f["action.effector.position"][[1]], # right f["action.joint.position"][7:], # right
f["action.waist.position"], f["action.effector.position"][[1]], # right
f["action.robot.velocity"], f["action.waist.position"],
]) f["action.robot.velocity"],
]
)
# episode level annotation # episode level annotation
f["task"] = ep_info["task_name"] f["task"] = ep_info["task_name"]
@@ -361,6 +378,7 @@ def update_meta_data(
return df.apply(_update, axis=1) return df.apply(_update, axis=1)
def move_videos_to_lerobot_directory(lerobot_dataset, raw_dir, task_index, episode_names): def move_videos_to_lerobot_directory(lerobot_dataset, raw_dir, task_index, episode_names):
keys_mapping = { keys_mapping = {
"observation.images.top_head": "head_color", "observation.images.top_head": "head_color",
@@ -378,7 +396,6 @@ def move_videos_to_lerobot_directory(lerobot_dataset, raw_dir, task_index, episo
if key not in lerobot_dataset.meta.info["features"]: if key not in lerobot_dataset.meta.info["features"]:
raise ValueError(f"Key '{key}' not found in features.") raise ValueError(f"Key '{key}' not found in features.")
video_keys = keys_mapping.keys() video_keys = keys_mapping.keys()
chunk_idx = dict.fromkeys(video_keys, 0) chunk_idx = dict.fromkeys(video_keys, 0)
file_idx = dict.fromkeys(video_keys, 0) file_idx = dict.fromkeys(video_keys, 0)
@@ -438,12 +455,15 @@ def move_videos_to_lerobot_directory(lerobot_dataset, raw_dir, task_index, episo
latest_duration_in_s[key] += ep_duration_in_s latest_duration_in_s[key] += ep_duration_in_s
# Update episodes meta data # Update episodes meta data
for meta_path in (lerobot_dataset.root / EPISODES_DIR).glob("chunk-*/file-*.parquet"): for meta_path in (lerobot_dataset.root / EPISODES_DIR).glob("chunk-*/file-*.parquet"):
df = pd.read_parquet(meta_path) df = pd.read_parquet(meta_path)
df = update_meta_data(df, ep_to_meta) df = update_meta_data(df, ep_to_meta)
df.to_parquet(meta_path) df.to_parquet(meta_path)
def port_agibot(raw_dir: Path, repo_id: str, task_index: int, episode_indices: list[int], push_to_hub: bool = False):
def port_agibot(
raw_dir: Path, repo_id: str, task_index: int, episode_indices: list[int], push_to_hub: bool = False
):
lerobot_dataset = LeRobotDataset.create( lerobot_dataset = LeRobotDataset.create(
repo_id=repo_id, repo_id=repo_id,
robot_type=AGIBOT_ROBOT_TYPE, robot_type=AGIBOT_ROBOT_TYPE,
@@ -459,7 +479,9 @@ def port_agibot(raw_dir: Path, repo_id: str, task_index: int, episode_indices: l
elapsed_time = time.time() - start_time elapsed_time = time.time() - start_time
d, h, m, s = get_elapsed_time_in_days_hours_minutes_seconds(elapsed_time) d, h, m, s = get_elapsed_time_in_days_hours_minutes_seconds(elapsed_time)
logging.info(f"{i} / {num_episodes} episodes processed (after {d} days, {h} hours, {m} minutes, {s:.3f} seconds)") logging.info(
f"{i} / {num_episodes} episodes processed (after {d} days, {h} hours, {m} minutes, {s:.3f} seconds)"
)
for frame in generate_lerobot_frames(raw_dir, task_index, episode_index): for frame in generate_lerobot_frames(raw_dir, task_index, episode_index):
lerobot_dataset.add_frame(frame) lerobot_dataset.add_frame(frame)
@@ -478,4 +500,4 @@ def port_agibot(raw_dir: Path, repo_id: str, task_index: int, episode_indices: l
# Add agibot tag, since it belongs to the agibot collection of datasets # Add agibot tag, since it belongs to the agibot collection of datasets
tags=["agibot"], tags=["agibot"],
private=False, private=False,
) )
@@ -1,14 +1,17 @@
import argparse import argparse
import logging import logging
from pathlib import Path
import tarfile import tarfile
from pathlib import Path
from datatrove.executor import LocalPipelineExecutor from datatrove.executor import LocalPipelineExecutor
from datatrove.executor.slurm import SlurmPipelineExecutor from datatrove.executor.slurm import SlurmPipelineExecutor
from datatrove.pipeline.base import PipelineStep from datatrove.pipeline.base import PipelineStep
from examples.port_datasets.agibot_hdf5.download import RAW_REPO_ID, download_meta_data, get_observations_files from examples.port_datasets.agibot_hdf5.download import (
RAW_REPO_ID,
download_meta_data,
get_observations_files,
)
class PortAgiBotShards(PipelineStep): class PortAgiBotShards(PipelineStep):
@@ -23,14 +26,18 @@ class PortAgiBotShards(PipelineStep):
def run(self, data=None, rank: int = 0, world_size: int = 1): def run(self, data=None, rank: int = 0, world_size: int = 1):
import shutil import shutil
import logging
import tarfile
from datasets.utils.tqdm import disable_progress_bars from datasets.utils.tqdm import disable_progress_bars
from lerobot.common.constants import HF_LEROBOT_HOME from examples.port_datasets.agibot_hdf5.download import (
RAW_REPO_ID,
download,
get_observations_files,
no_depth,
)
from examples.port_datasets.agibot_hdf5.port_agibot import port_agibot from examples.port_datasets.agibot_hdf5.port_agibot import port_agibot
from examples.port_datasets.agibot_hdf5.download import get_observations_files, download, no_depth, RAW_REPO_ID
from examples.port_datasets.droid_rlds.port_droid import validate_dataset from examples.port_datasets.droid_rlds.port_droid import validate_dataset
from lerobot.common.constants import HF_LEROBOT_HOME
from lerobot.common.utils.utils import init_logging from lerobot.common.utils.utils import init_logging
init_logging() init_logging()
@@ -49,17 +56,17 @@ class PortAgiBotShards(PipelineStep):
download(self.raw_dir, allow_patterns=obs_file) download(self.raw_dir, allow_patterns=obs_file)
tar_path = self.raw_dir / obs_file tar_path = self.raw_dir / obs_file
with tarfile.open(tar_path, 'r') as tar: with tarfile.open(tar_path, "r") as tar:
extracted_files = tar.getnames() extracted_files = tar.getnames()
task_index = int(tar_path.parent.name) task_index = int(tar_path.parent.name)
episode_names = [int(p) for p in extracted_files if '/' not in p] episode_names = [int(p) for p in extracted_files if "/" not in p]
# Untar if needed # Untar if needed
if not all([(tar_path.parent / f"{ep_name}").exists() for ep_name in episode_names]): if not all((tar_path.parent / f"{ep_name}").exists() for ep_name in episode_names):
logging.info(f"Untar-ing {tar_path}...") logging.info(f"Untar-ing {tar_path}...")
with tarfile.open(tar_path, 'r') as tar: with tarfile.open(tar_path, "r") as tar:
tar.extractall(path=tar_path.parent, filter=no_depth) tar.extractall(path=tar_path.parent, filter=no_depth) # nosec B202
port_agibot(self.raw_dir, shard_repo_id, task_index, episode_names, push_to_hub=False) port_agibot(self.raw_dir, shard_repo_id, task_index, episode_names, push_to_hub=False)
@@ -1,30 +1,31 @@
import argparse import argparse
from pathlib import Path
import json import json
from pathlib import Path
def find_missings(completions_dir, world_size):
""" Find workers that are not completed and returns their indices. def find_missing_workers(completions_dir, world_size):
""" """Find workers that are not completed and returns their indices."""
full = list(range(world_size)) full = list(range(world_size))
completed = [] completed = []
for path in completions_dir.glob("*"): for path in completions_dir.glob("*"):
if path.name in ['.', '..']: if path.name in [".", ".."]:
continue continue
index = path.name.lstrip('0') index = path.name.lstrip("0")
index = 0 if index == "" else int(index) index = 0 if index == "" else int(index)
completed.append(index) completed.append(index)
missings = set(full) - set(completed) missing_workers = set(full) - set(completed)
return missings return missing_workers
def find_output_files(slurm_dir, worker_indices): def find_output_files(slurm_dir, worker_indices):
""" Find output files associated to worker indices, and return tuples """Find output files associated to worker indices, and return tuples
of (worker index, output file path) of (worker index, output file path)
""" """
out_files = [] out_files = []
for path in slurm_dir.glob("*.out"): for path in slurm_dir.glob("*.out"):
_, worker_id = path.name.replace(".out", "").split('_') _, worker_id = path.name.replace(".out", "").split("_")
worker_id = int(worker_id) worker_id = int(worker_id)
if worker_id in worker_indices: if worker_id in worker_indices:
out_files.append((worker_id, path)) out_files.append((worker_id, path))
@@ -34,22 +35,15 @@ def find_output_files(slurm_dir, worker_indices):
def display_error_files(logs_dir, job_name): def display_error_files(logs_dir, job_name):
executor_path = Path(logs_dir) / job_name / "executor.json" executor_path = Path(logs_dir) / job_name / "executor.json"
completions_dir = Path(logs_dir) / job_name / "completions" completions_dir = Path(logs_dir) / job_name / "completions"
slurm_dir = Path(logs_dir) / job_name / "slurm_logs"
with open(executor_path) as f: with open(executor_path) as f:
executor = json.load(f) executor = json.load(f)
missings = find_missings(completions_dir, executor["world_size"]) missing_workers = find_missing_workers(completions_dir, executor["world_size"])
for missing in sorted(list(missings))[::-1]: for missing in sorted(missing_workers)[::-1]:
print(missing) print(missing)
# error_files = find_output_files(slurm_dir, missings)
# error_files = sorted(error_files, key=lambda x: x[0])
# for _, path in error_files[::-1]:
# print(path)
def main(): def main():
parser = argparse.ArgumentParser() parser = argparse.ArgumentParser()
@@ -70,5 +64,6 @@ def main():
display_error_files(**vars(args)) display_error_files(**vars(args))
if __name__ == "__main__": if __name__ == "__main__":
main() main()
@@ -32,8 +32,8 @@ class PortDroidShards(PipelineStep):
try: try:
validate_dataset(shard_repo_id) validate_dataset(shard_repo_id)
return return
except: except Exception:
pass pass # nosec B110 - Dataset doesn't exist yet, continue with porting
port_droid( port_droid(
self.raw_dir, self.raw_dir,
+1 -1
View File
@@ -31,7 +31,7 @@ from datasets import Dataset
from huggingface_hub import HfApi, snapshot_download from huggingface_hub import HfApi, snapshot_download
from huggingface_hub.constants import REPOCARD_NAME from huggingface_hub.constants import REPOCARD_NAME
from huggingface_hub.errors import RevisionNotFoundError from huggingface_hub.errors import RevisionNotFoundError
from torch.profiler import record_function
from lerobot.common.constants import HF_LEROBOT_HOME from lerobot.common.constants import HF_LEROBOT_HOME
from lerobot.common.datasets.compute_stats import aggregate_stats, compute_episode_stats from lerobot.common.datasets.compute_stats import aggregate_stats, compute_episode_stats
from lerobot.common.datasets.image_writer import AsyncImageWriter, write_image from lerobot.common.datasets.image_writer import AsyncImageWriter, write_image
+2 -2
View File
@@ -940,8 +940,8 @@ def validate_episode_buffer(episode_buffer: dict, total_episodes: int, features:
def to_parquet_with_hf_images(df: pandas.DataFrame, path: Path): def to_parquet_with_hf_images(df: pandas.DataFrame, path: Path):
""" This function correctly writes to parquet a panda DataFrame that contains images encoded by HF dataset. """This function correctly writes to parquet a panda DataFrame that contains images encoded by HF dataset.
This way, it can be loaded by HF dataset and correctly formated images are returned. This way, it can be loaded by HF dataset and correctly formatted images are returned.
""" """
# TODO(qlhoest): replace this weird synthax by `df.to_parquet(path)` only # TODO(qlhoest): replace this weird synthax by `df.to_parquet(path)` only
datasets.Dataset.from_dict(df.to_dict(orient="list")).to_parquet(path) datasets.Dataset.from_dict(df.to_dict(orient="list")).to_parquet(path)