Validate features during add_frame + Add 2D-to-5D + Add string (#720)

This commit is contained in:
Remi
2025-02-14 19:59:48 +01:00
committed by GitHub
parent 9d6886dd08
commit 7c2bbee613
8 changed files with 448 additions and 53 deletions
+9 -12
View File
@@ -39,6 +39,7 @@ from lerobot.common.datasets.utils import (
TASKS_PATH,
append_jsonlines,
check_delta_timestamps,
check_frame_features,
check_timestamps_sync,
check_version_compatibility,
create_branch,
@@ -724,10 +725,12 @@ class LeRobotDataset(torch.utils.data.Dataset):
temporary directory — nothing is written to disk. To save those frames, the 'save_episode()' method
then needs to be called.
"""
# TODO(aliberts, rcadene): Add sanity check for the input, check it's numpy or torch,
# check the dtype and shape matches, etc.
if "task" not in frame:
raise ValueError("The mandatory feature 'task' wasn't found in `frame` dictionnary.")
# Convert torch to numpy if needed
for name in frame:
if isinstance(frame[name], torch.Tensor):
frame[name] = frame[name].numpy()
check_frame_features(frame, self.features)
if self.episode_buffer is None:
self.episode_buffer = self.create_episode_buffer()
@@ -757,8 +760,7 @@ class LeRobotDataset(torch.utils.data.Dataset):
self._save_image(frame[key], img_path)
self.episode_buffer[key].append(str(img_path))
else:
item = frame[key].numpy() if isinstance(frame[key], torch.Tensor) else frame[key]
self.episode_buffer[key].append(item)
self.episode_buffer[key].append(frame[key])
self.episode_buffer["size"] += 1
@@ -815,12 +817,7 @@ class LeRobotDataset(torch.utils.data.Dataset):
# are processed separately by storing image path and frame info as meta data
if key in ["index", "episode_index", "task_index"] or ft["dtype"] in ["image", "video"]:
continue
elif len(ft["shape"]) == 1 and ft["shape"][0] == 1:
episode_buffer[key] = np.array(episode_buffer[key], dtype=ft["dtype"])
elif len(ft["shape"]) == 1 and ft["shape"][0] > 1:
episode_buffer[key] = np.stack(episode_buffer[key])
else:
raise ValueError(key)
episode_buffer[key] = np.stack(episode_buffer[key])
self._wait_image_writer()
self._save_episode_table(episode_buffer, episode_index)