mirror of
https://github.com/huggingface/lerobot.git
synced 2026-05-21 19:49:49 +00:00
add: support for videos generation in datasets
This commit is contained in:
committed by
Michel Aractingi
parent
8c1503dafa
commit
d4fbf6ef39
Vendored
+63
-18
@@ -69,15 +69,49 @@ def img_tensor_factory():
|
|||||||
|
|
||||||
@pytest.fixture(scope="session")
|
@pytest.fixture(scope="session")
|
||||||
def img_array_factory():
|
def img_array_factory():
|
||||||
def _create_img_array(height=100, width=100, channels=3, dtype=np.uint8) -> np.ndarray:
|
def _create_img_array(height=100, width=100, channels=3, dtype=np.uint8, content=None) -> np.ndarray:
|
||||||
if np.issubdtype(dtype, np.unsignedinteger):
|
if content is None:
|
||||||
# Int array in [0, 255] range
|
# Original random noise behavior
|
||||||
img_array = np.random.randint(0, 256, size=(height, width, channels), dtype=dtype)
|
if np.issubdtype(dtype, np.unsignedinteger):
|
||||||
elif np.issubdtype(dtype, np.floating):
|
# Int array in [0, 255] range
|
||||||
# Float array in [0, 1] range
|
img_array = np.random.randint(0, 256, size=(height, width, channels), dtype=dtype)
|
||||||
img_array = np.random.rand(height, width, channels).astype(dtype)
|
elif np.issubdtype(dtype, np.floating):
|
||||||
|
# Float array in [0, 1] range
|
||||||
|
img_array = np.random.rand(height, width, channels).astype(dtype)
|
||||||
|
else:
|
||||||
|
raise ValueError(dtype)
|
||||||
else:
|
else:
|
||||||
raise ValueError(dtype)
|
# Create image with text content using OpenCV
|
||||||
|
import cv2
|
||||||
|
|
||||||
|
# Create white background
|
||||||
|
img_array = np.ones((height, width, channels), dtype=np.uint8) * 255
|
||||||
|
|
||||||
|
# Font settings
|
||||||
|
font = cv2.FONT_HERSHEY_SIMPLEX
|
||||||
|
font_scale = max(0.5, height / 200) # Scale font with image size
|
||||||
|
font_color = (0, 0, 0) # Black text
|
||||||
|
thickness = max(1, int(height / 100))
|
||||||
|
|
||||||
|
# Get text size to center it
|
||||||
|
text_size = cv2.getTextSize(content, font, font_scale, thickness)[0]
|
||||||
|
text_x = (width - text_size[0]) // 2
|
||||||
|
text_y = (height + text_size[1]) // 2
|
||||||
|
|
||||||
|
# Put text on image
|
||||||
|
cv2.putText(img_array, content, (text_x, text_y), font, font_scale, font_color, thickness)
|
||||||
|
|
||||||
|
# Handle single channel case
|
||||||
|
if channels == 1:
|
||||||
|
img_array = cv2.cvtColor(img_array, cv2.COLOR_BGR2GRAY)
|
||||||
|
img_array = img_array[:, :, np.newaxis]
|
||||||
|
|
||||||
|
# Convert to target dtype
|
||||||
|
if np.issubdtype(dtype, np.floating):
|
||||||
|
img_array = img_array.astype(dtype) / 255.0
|
||||||
|
else:
|
||||||
|
img_array = img_array.astype(dtype)
|
||||||
|
|
||||||
return img_array
|
return img_array
|
||||||
|
|
||||||
return _create_img_array
|
return _create_img_array
|
||||||
@@ -97,7 +131,7 @@ def features_factory():
|
|||||||
def _create_features(
|
def _create_features(
|
||||||
motor_features: dict = DUMMY_MOTOR_FEATURES,
|
motor_features: dict = DUMMY_MOTOR_FEATURES,
|
||||||
camera_features: dict = DUMMY_CAMERA_FEATURES,
|
camera_features: dict = DUMMY_CAMERA_FEATURES,
|
||||||
use_videos: bool = False,
|
use_videos: bool = True,
|
||||||
) -> dict:
|
) -> dict:
|
||||||
if use_videos:
|
if use_videos:
|
||||||
camera_ft = {
|
camera_ft = {
|
||||||
@@ -131,7 +165,7 @@ def info_factory(features_factory):
|
|||||||
video_path: str = DEFAULT_VIDEO_PATH,
|
video_path: str = DEFAULT_VIDEO_PATH,
|
||||||
motor_features: dict = DUMMY_MOTOR_FEATURES,
|
motor_features: dict = DUMMY_MOTOR_FEATURES,
|
||||||
camera_features: dict = DUMMY_CAMERA_FEATURES,
|
camera_features: dict = DUMMY_CAMERA_FEATURES,
|
||||||
use_videos: bool = False,
|
use_videos: bool = True,
|
||||||
) -> dict:
|
) -> dict:
|
||||||
features = features_factory(motor_features, camera_features, use_videos)
|
features = features_factory(motor_features, camera_features, use_videos)
|
||||||
return {
|
return {
|
||||||
@@ -320,17 +354,20 @@ def create_videos(info_factory, img_array_factory):
|
|||||||
|
|
||||||
video_feats = {key: feats for key, feats in info["features"].items() if feats["dtype"] == "video"}
|
video_feats = {key: feats for key, feats in info["features"].items() if feats["dtype"] == "video"}
|
||||||
for key, ft in video_feats.items():
|
for key, ft in video_feats.items():
|
||||||
# create and save images
|
# create and save images with identifiable content
|
||||||
tmp_dir = root / "tmp_images"
|
tmp_dir = root / "tmp_images"
|
||||||
tmp_dir.mkdir(parents=True, exist_ok=True)
|
tmp_dir.mkdir(parents=True, exist_ok=True)
|
||||||
for frame_index in range(info["total_frames"]):
|
for frame_index in range(info["total_frames"]):
|
||||||
img = img_array_factory(height=ft["shape"][1], width=ft["shape"][0])
|
content = f"{key}-{frame_index}"
|
||||||
|
img = img_array_factory(height=ft["shape"][0], width=ft["shape"][1], content=content)
|
||||||
pil_img = PIL.Image.fromarray(img)
|
pil_img = PIL.Image.fromarray(img)
|
||||||
path = tmp_dir / f"frame-{frame_index:06d}.png"
|
path = tmp_dir / f"frame-{frame_index:06d}.png"
|
||||||
pil_img.save(path)
|
pil_img.save(path)
|
||||||
|
|
||||||
video_path = root / DEFAULT_VIDEO_PATH.format(video_key=key, chunk_index=0, file_index=0)
|
video_path = root / DEFAULT_VIDEO_PATH.format(video_key=key, chunk_index=0, file_index=0)
|
||||||
encode_video_frames(tmp_dir, video_path, fps=ft["video.fps"])
|
video_path.parent.mkdir(parents=True, exist_ok=True)
|
||||||
|
# Use the global fps from info, not video-specific fps which might not exist
|
||||||
|
encode_video_frames(tmp_dir, video_path, fps=info["fps"])
|
||||||
shutil.rmtree(tmp_dir)
|
shutil.rmtree(tmp_dir)
|
||||||
|
|
||||||
return _create_video_directory
|
return _create_video_directory
|
||||||
@@ -372,8 +409,8 @@ def hf_dataset_factory(features_factory, tasks_factory, episodes_factory, img_ar
|
|||||||
for key, ft in features.items():
|
for key, ft in features.items():
|
||||||
if ft["dtype"] == "image":
|
if ft["dtype"] == "image":
|
||||||
robot_cols[key] = [
|
robot_cols[key] = [
|
||||||
img_array_factory(height=ft["shape"][1], width=ft["shape"][0])
|
img_array_factory(height=ft["shape"][1], width=ft["shape"][0], content=f"{key}-{i}")
|
||||||
for _ in range(len(index_col))
|
for i in range(len(index_col))
|
||||||
]
|
]
|
||||||
elif ft["shape"][0] > 1 and ft["dtype"] != "video":
|
elif ft["shape"][0] > 1 and ft["dtype"] != "video":
|
||||||
robot_cols[key] = np.random.random((len(index_col), ft["shape"][0])).astype(ft["dtype"])
|
robot_cols[key] = np.random.random((len(index_col), ft["shape"][0])).astype(ft["dtype"])
|
||||||
@@ -464,6 +501,7 @@ def lerobot_dataset_factory(
|
|||||||
total_frames: int = 150,
|
total_frames: int = 150,
|
||||||
total_tasks: int = 1,
|
total_tasks: int = 1,
|
||||||
multi_task: bool = False,
|
multi_task: bool = False,
|
||||||
|
use_videos: bool = True,
|
||||||
info: dict | None = None,
|
info: dict | None = None,
|
||||||
stats: dict | None = None,
|
stats: dict | None = None,
|
||||||
tasks: pd.DataFrame | None = None,
|
tasks: pd.DataFrame | None = None,
|
||||||
@@ -474,23 +512,30 @@ def lerobot_dataset_factory(
|
|||||||
# Instantiate objects
|
# Instantiate objects
|
||||||
if info is None:
|
if info is None:
|
||||||
info = info_factory(
|
info = info_factory(
|
||||||
total_episodes=total_episodes, total_frames=total_frames, total_tasks=total_tasks
|
total_episodes=total_episodes,
|
||||||
|
total_frames=total_frames,
|
||||||
|
total_tasks=total_tasks,
|
||||||
|
use_videos=use_videos,
|
||||||
)
|
)
|
||||||
if stats is None:
|
if stats is None:
|
||||||
stats = stats_factory(features=info["features"])
|
stats = stats_factory(features=info["features"])
|
||||||
if tasks is None:
|
if tasks is None:
|
||||||
tasks = tasks_factory(total_tasks=info["total_tasks"])
|
tasks = tasks_factory(total_tasks=info["total_tasks"])
|
||||||
if episodes_metadata is None:
|
if episodes_metadata is None:
|
||||||
|
video_keys = [key for key, ft in info["features"].items() if ft["dtype"] == "video"]
|
||||||
episodes_metadata = episodes_factory(
|
episodes_metadata = episodes_factory(
|
||||||
features=info["features"],
|
features=info["features"],
|
||||||
fps=info["fps"],
|
fps=info["fps"],
|
||||||
total_episodes=info["total_episodes"],
|
total_episodes=info["total_episodes"],
|
||||||
total_frames=info["total_frames"],
|
total_frames=info["total_frames"],
|
||||||
|
video_keys=video_keys,
|
||||||
tasks=tasks,
|
tasks=tasks,
|
||||||
multi_task=multi_task,
|
multi_task=multi_task,
|
||||||
)
|
)
|
||||||
if not hf_dataset:
|
if hf_dataset is None:
|
||||||
hf_dataset = hf_dataset_factory(tasks=tasks, episodes=episodes_metadata, fps=info["fps"])
|
hf_dataset = hf_dataset_factory(
|
||||||
|
features=info["features"], tasks=tasks, episodes=episodes_metadata, fps=info["fps"]
|
||||||
|
)
|
||||||
|
|
||||||
# Write data on disk
|
# Write data on disk
|
||||||
mock_snapshot_download = mock_snapshot_download_factory(
|
mock_snapshot_download = mock_snapshot_download_factory(
|
||||||
|
|||||||
Reference in New Issue
Block a user