mirror of
https://github.com/huggingface/lerobot.git
synced 2026-05-15 16:49:55 +00:00
add: support for videos generation in datasets
This commit is contained in:
committed by
Michel Aractingi
parent
8c1503dafa
commit
d4fbf6ef39
Vendored
+63
-18
@@ -69,15 +69,49 @@ def img_tensor_factory():
|
||||
|
||||
@pytest.fixture(scope="session")
|
||||
def img_array_factory():
|
||||
def _create_img_array(height=100, width=100, channels=3, dtype=np.uint8) -> np.ndarray:
|
||||
if np.issubdtype(dtype, np.unsignedinteger):
|
||||
# Int array in [0, 255] range
|
||||
img_array = np.random.randint(0, 256, size=(height, width, channels), dtype=dtype)
|
||||
elif np.issubdtype(dtype, np.floating):
|
||||
# Float array in [0, 1] range
|
||||
img_array = np.random.rand(height, width, channels).astype(dtype)
|
||||
def _create_img_array(height=100, width=100, channels=3, dtype=np.uint8, content=None) -> np.ndarray:
|
||||
if content is None:
|
||||
# Original random noise behavior
|
||||
if np.issubdtype(dtype, np.unsignedinteger):
|
||||
# Int array in [0, 255] range
|
||||
img_array = np.random.randint(0, 256, size=(height, width, channels), dtype=dtype)
|
||||
elif np.issubdtype(dtype, np.floating):
|
||||
# Float array in [0, 1] range
|
||||
img_array = np.random.rand(height, width, channels).astype(dtype)
|
||||
else:
|
||||
raise ValueError(dtype)
|
||||
else:
|
||||
raise ValueError(dtype)
|
||||
# Create image with text content using OpenCV
|
||||
import cv2
|
||||
|
||||
# Create white background
|
||||
img_array = np.ones((height, width, channels), dtype=np.uint8) * 255
|
||||
|
||||
# Font settings
|
||||
font = cv2.FONT_HERSHEY_SIMPLEX
|
||||
font_scale = max(0.5, height / 200) # Scale font with image size
|
||||
font_color = (0, 0, 0) # Black text
|
||||
thickness = max(1, int(height / 100))
|
||||
|
||||
# Get text size to center it
|
||||
text_size = cv2.getTextSize(content, font, font_scale, thickness)[0]
|
||||
text_x = (width - text_size[0]) // 2
|
||||
text_y = (height + text_size[1]) // 2
|
||||
|
||||
# Put text on image
|
||||
cv2.putText(img_array, content, (text_x, text_y), font, font_scale, font_color, thickness)
|
||||
|
||||
# Handle single channel case
|
||||
if channels == 1:
|
||||
img_array = cv2.cvtColor(img_array, cv2.COLOR_BGR2GRAY)
|
||||
img_array = img_array[:, :, np.newaxis]
|
||||
|
||||
# Convert to target dtype
|
||||
if np.issubdtype(dtype, np.floating):
|
||||
img_array = img_array.astype(dtype) / 255.0
|
||||
else:
|
||||
img_array = img_array.astype(dtype)
|
||||
|
||||
return img_array
|
||||
|
||||
return _create_img_array
|
||||
@@ -97,7 +131,7 @@ def features_factory():
|
||||
def _create_features(
|
||||
motor_features: dict = DUMMY_MOTOR_FEATURES,
|
||||
camera_features: dict = DUMMY_CAMERA_FEATURES,
|
||||
use_videos: bool = False,
|
||||
use_videos: bool = True,
|
||||
) -> dict:
|
||||
if use_videos:
|
||||
camera_ft = {
|
||||
@@ -131,7 +165,7 @@ def info_factory(features_factory):
|
||||
video_path: str = DEFAULT_VIDEO_PATH,
|
||||
motor_features: dict = DUMMY_MOTOR_FEATURES,
|
||||
camera_features: dict = DUMMY_CAMERA_FEATURES,
|
||||
use_videos: bool = False,
|
||||
use_videos: bool = True,
|
||||
) -> dict:
|
||||
features = features_factory(motor_features, camera_features, use_videos)
|
||||
return {
|
||||
@@ -320,17 +354,20 @@ def create_videos(info_factory, img_array_factory):
|
||||
|
||||
video_feats = {key: feats for key, feats in info["features"].items() if feats["dtype"] == "video"}
|
||||
for key, ft in video_feats.items():
|
||||
# create and save images
|
||||
# create and save images with identifiable content
|
||||
tmp_dir = root / "tmp_images"
|
||||
tmp_dir.mkdir(parents=True, exist_ok=True)
|
||||
for frame_index in range(info["total_frames"]):
|
||||
img = img_array_factory(height=ft["shape"][1], width=ft["shape"][0])
|
||||
content = f"{key}-{frame_index}"
|
||||
img = img_array_factory(height=ft["shape"][0], width=ft["shape"][1], content=content)
|
||||
pil_img = PIL.Image.fromarray(img)
|
||||
path = tmp_dir / f"frame-{frame_index:06d}.png"
|
||||
pil_img.save(path)
|
||||
|
||||
video_path = root / DEFAULT_VIDEO_PATH.format(video_key=key, chunk_index=0, file_index=0)
|
||||
encode_video_frames(tmp_dir, video_path, fps=ft["video.fps"])
|
||||
video_path.parent.mkdir(parents=True, exist_ok=True)
|
||||
# Use the global fps from info, not video-specific fps which might not exist
|
||||
encode_video_frames(tmp_dir, video_path, fps=info["fps"])
|
||||
shutil.rmtree(tmp_dir)
|
||||
|
||||
return _create_video_directory
|
||||
@@ -372,8 +409,8 @@ def hf_dataset_factory(features_factory, tasks_factory, episodes_factory, img_ar
|
||||
for key, ft in features.items():
|
||||
if ft["dtype"] == "image":
|
||||
robot_cols[key] = [
|
||||
img_array_factory(height=ft["shape"][1], width=ft["shape"][0])
|
||||
for _ in range(len(index_col))
|
||||
img_array_factory(height=ft["shape"][1], width=ft["shape"][0], content=f"{key}-{i}")
|
||||
for i in range(len(index_col))
|
||||
]
|
||||
elif ft["shape"][0] > 1 and ft["dtype"] != "video":
|
||||
robot_cols[key] = np.random.random((len(index_col), ft["shape"][0])).astype(ft["dtype"])
|
||||
@@ -464,6 +501,7 @@ def lerobot_dataset_factory(
|
||||
total_frames: int = 150,
|
||||
total_tasks: int = 1,
|
||||
multi_task: bool = False,
|
||||
use_videos: bool = True,
|
||||
info: dict | None = None,
|
||||
stats: dict | None = None,
|
||||
tasks: pd.DataFrame | None = None,
|
||||
@@ -474,23 +512,30 @@ def lerobot_dataset_factory(
|
||||
# Instantiate objects
|
||||
if info is None:
|
||||
info = info_factory(
|
||||
total_episodes=total_episodes, total_frames=total_frames, total_tasks=total_tasks
|
||||
total_episodes=total_episodes,
|
||||
total_frames=total_frames,
|
||||
total_tasks=total_tasks,
|
||||
use_videos=use_videos,
|
||||
)
|
||||
if stats is None:
|
||||
stats = stats_factory(features=info["features"])
|
||||
if tasks is None:
|
||||
tasks = tasks_factory(total_tasks=info["total_tasks"])
|
||||
if episodes_metadata is None:
|
||||
video_keys = [key for key, ft in info["features"].items() if ft["dtype"] == "video"]
|
||||
episodes_metadata = episodes_factory(
|
||||
features=info["features"],
|
||||
fps=info["fps"],
|
||||
total_episodes=info["total_episodes"],
|
||||
total_frames=info["total_frames"],
|
||||
video_keys=video_keys,
|
||||
tasks=tasks,
|
||||
multi_task=multi_task,
|
||||
)
|
||||
if not hf_dataset:
|
||||
hf_dataset = hf_dataset_factory(tasks=tasks, episodes=episodes_metadata, fps=info["fps"])
|
||||
if hf_dataset is None:
|
||||
hf_dataset = hf_dataset_factory(
|
||||
features=info["features"], tasks=tasks, episodes=episodes_metadata, fps=info["fps"]
|
||||
)
|
||||
|
||||
# Write data on disk
|
||||
mock_snapshot_download = mock_snapshot_download_factory(
|
||||
|
||||
Reference in New Issue
Block a user