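Fixed bug in crop_dataset_roi.py: convert the dataset in a single pass, saving each episode when episode_index changes, instead of re-filtering the whole dataset for each of the first three episodes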

Added the missing buffer.pt in the server dir

Co-authored-by: Adil Zouitine <adilzouitinegm@gmail.com>
This commit is contained in:
Michel Aractingi
2025-02-05 18:22:50 +00:00
parent d143043037
commit 273fa2e6e1
2 changed files with 586 additions and 30 deletions
+26 -30
View File
@@ -187,43 +187,39 @@ def convert_lerobot_dataset_to_cropper_lerobot_dataset(
# 2. Process each episode in the original dataset.
episodes_info = original_dataset.meta.episodes
# (Sort episodes by episode_index for consistency.)
episodes_info = sorted(episodes_info, key=lambda x: x["episode_index"])
# Use the first task from the episode metadata (or "unknown" if not provided)
task = episodes_info[0]["tasks"][0] if episodes_info[0].get("tasks") else "unknown"
for ep in tqdm(episodes_info[:3], desc="Processing episodes"):
ep_index = ep.pop("episode_index")
# Use the first task from the episode metadata (or "unknown" if not provided)
task = ep["tasks"][0] if ep.get("tasks") else "unknown"
last_episode_index = 0
for sample in tqdm(original_dataset):
episode_index = sample.pop("episode_index")
if episode_index != last_episode_index:
new_dataset.save_episode(task, encode_videos=True)
last_episode_index = episode_index
sample.pop("frame_index")
# Make a shallow copy of the sample (the values—e.g. torch tensors—are assumed immutable)
new_sample = sample.copy()
# Loop over each observation key that should be cropped/resized.
for key, params in crop_params_dict.items():
if key in new_sample:
top, left, height, width = params
# Apply crop then resize.
cropped = F.crop(new_sample[key], top, left, height, width)
resized = F.resize(cropped, resize_size)
new_sample[key] = resized
# Add the transformed frame to the new dataset.
new_dataset.add_frame(new_sample)
# Reset the episode buffer in the new dataset (this will store frames for one episode).
new_dataset.episode_buffer = new_dataset.create_episode_buffer(episode_index=ep_index)
# 3. Filter and process all frames belonging to this episode.
# Here we loop over the entire dataset and select the frames with the matching episode_index.
# (Depending on the dataset size, you might want a more efficient method.)
ep_frames = [sample for sample in original_dataset if sample["episode_index"] == ep_index]
for sample in tqdm(ep_frames):
sample.pop("episode_index")
sample.pop("frame_index")
# Make a shallow copy of the sample (the values—e.g. torch tensors—are assumed immutable)
new_sample = sample.copy()
# Loop over each observation key that should be cropped/resized.
for key, params in crop_params_dict.items():
if key in new_sample:
top, left, height, width = params
# Apply crop then resize.
cropped = F.crop(new_sample[key], top, left, height, width)
resized = F.resize(cropped, resize_size)
new_sample[key] = resized
# Add the transformed frame to the new dataset.
new_dataset.add_frame(new_sample)
# 4. Save the episode (this writes the parquet file and image files).
new_dataset.save_episode(task, encode_videos=True)
# save last episode
new_dataset.save_episode(task, encode_videos=True)
# Optionally, consolidate the new dataset to compute statistics and update video info.
new_dataset.consolidate(run_compute_stats=True, keep_image_files=True)
new_dataset.push_to_hub(tags=None)
return new_dataset