mirror of
https://github.com/huggingface/lerobot.git
synced 2026-05-23 12:40:08 +00:00
fix(tests) add features argument to load_nested_dataset
This commit is contained in:
@@ -652,7 +652,8 @@ class LeRobotDataset(torch.utils.data.Dataset):
|
|||||||
|
|
||||||
def load_hf_dataset(self) -> datasets.Dataset:
|
def load_hf_dataset(self) -> datasets.Dataset:
|
||||||
"""hf_dataset contains all the observations, states, actions, rewards, etc."""
|
"""hf_dataset contains all the observations, states, actions, rewards, etc."""
|
||||||
hf_dataset = load_nested_dataset(self.root / "data")
|
features = get_hf_features_from_features(self.features)
|
||||||
|
hf_dataset = load_nested_dataset(self.root / "data", features=features)
|
||||||
hf_dataset.set_transform(hf_transform_to_torch)
|
hf_dataset.set_transform(hf_transform_to_torch)
|
||||||
return hf_dataset
|
return hf_dataset
|
||||||
|
|
||||||
|
|||||||
@@ -116,17 +116,21 @@ def update_chunk_file_indices(chunk_idx: int, file_idx: int, chunks_size: int):
|
|||||||
return chunk_idx, file_idx
|
return chunk_idx, file_idx
|
||||||
|
|
||||||
|
|
||||||
def load_nested_dataset(pq_dir: Path) -> Dataset:
|
def load_nested_dataset(pq_dir: Path, features: datasets.Features | None = None) -> Dataset:
|
||||||
"""Find parquet files in provided directory {pq_dir}/chunk-xxx/file-xxx.parquet
|
"""Find parquet files in provided directory {pq_dir}/chunk-xxx/file-xxx.parquet
|
||||||
Convert parquet files to pyarrow memory mapped in a cache folder for efficient RAM usage
|
Convert parquet files to pyarrow memory mapped in a cache folder for efficient RAM usage
|
||||||
Concatenate all pyarrow references to return HF Dataset format
|
Concatenate all pyarrow references to return HF Dataset format
|
||||||
|
|
||||||
|
Args:
|
||||||
|
pq_dir: Directory containing parquet files
|
||||||
|
features: Optional features schema to ensure consistent loading of complex types like images
|
||||||
"""
|
"""
|
||||||
paths = sorted(pq_dir.glob("*/*.parquet"))
|
paths = sorted(pq_dir.glob("*/*.parquet"))
|
||||||
if len(paths) == 0:
|
if len(paths) == 0:
|
||||||
raise FileNotFoundError(f"Provided directory does not contain any parquet file: {pq_dir}")
|
raise FileNotFoundError(f"Provided directory does not contain any parquet file: {pq_dir}")
|
||||||
|
|
||||||
# TODO(rcadene): set num_proc to accelerate conversion to pyarrow
|
# TODO(rcadene): set num_proc to accelerate conversion to pyarrow
|
||||||
datasets = [Dataset.from_parquet(str(path)) for path in paths]
|
datasets = [Dataset.from_parquet(str(path), features=features) for path in paths]
|
||||||
return concatenate_datasets(datasets)
|
return concatenate_datasets(datasets)
|
||||||
|
|
||||||
|
|
||||||
|
|||||||
@@ -564,10 +564,7 @@ class ReplayBuffer:
|
|||||||
lerobot_dataset.start_image_writer(num_processes=0, num_threads=3)
|
lerobot_dataset.start_image_writer(num_processes=0, num_threads=3)
|
||||||
|
|
||||||
# Convert transitions into episodes and frames
|
# Convert transitions into episodes and frames
|
||||||
episode_index = 0
|
|
||||||
lerobot_dataset.episode_buffer = lerobot_dataset.create_episode_buffer(episode_index=episode_index)
|
|
||||||
|
|
||||||
frame_idx_in_episode = 0
|
|
||||||
for idx in range(self.size):
|
for idx in range(self.size):
|
||||||
actual_idx = (self.position - self.size + idx) % self.capacity
|
actual_idx = (self.position - self.size + idx) % self.capacity
|
||||||
|
|
||||||
@@ -581,6 +578,7 @@ class ReplayBuffer:
|
|||||||
frame_dict["action"] = self.actions[actual_idx].cpu()
|
frame_dict["action"] = self.actions[actual_idx].cpu()
|
||||||
frame_dict["next.reward"] = torch.tensor([self.rewards[actual_idx]], dtype=torch.float32).cpu()
|
frame_dict["next.reward"] = torch.tensor([self.rewards[actual_idx]], dtype=torch.float32).cpu()
|
||||||
frame_dict["next.done"] = torch.tensor([self.dones[actual_idx]], dtype=torch.bool).cpu()
|
frame_dict["next.done"] = torch.tensor([self.dones[actual_idx]], dtype=torch.bool).cpu()
|
||||||
|
frame_dict["task"] = task_name
|
||||||
|
|
||||||
# Add complementary_info if available
|
# Add complementary_info if available
|
||||||
if self.has_complementary_info:
|
if self.has_complementary_info:
|
||||||
@@ -596,20 +594,14 @@ class ReplayBuffer:
|
|||||||
frame_dict[f"complementary_info.{key}"] = val
|
frame_dict[f"complementary_info.{key}"] = val
|
||||||
|
|
||||||
# Add to the dataset's buffer
|
# Add to the dataset's buffer
|
||||||
frame_dict["task"] = task_name
|
|
||||||
lerobot_dataset.add_frame(frame_dict)
|
lerobot_dataset.add_frame(frame_dict)
|
||||||
|
|
||||||
# Move to next frame
|
# Move to next frame
|
||||||
frame_idx_in_episode += 1
|
# frame_idx_in_episode += 1
|
||||||
|
|
||||||
# If we reached an episode boundary, call save_episode, reset counters
|
# If we reached an episode boundary, call save_episode, reset counters
|
||||||
if self.dones[actual_idx] or self.truncateds[actual_idx]:
|
if self.dones[actual_idx] or self.truncateds[actual_idx]:
|
||||||
lerobot_dataset.save_episode()
|
lerobot_dataset.save_episode()
|
||||||
episode_index += 1
|
|
||||||
frame_idx_in_episode = 0
|
|
||||||
lerobot_dataset.episode_buffer = lerobot_dataset.create_episode_buffer(
|
|
||||||
episode_index=episode_index
|
|
||||||
)
|
|
||||||
|
|
||||||
# Save any remaining frames in the buffer
|
# Save any remaining frames in the buffer
|
||||||
if lerobot_dataset.episode_buffer["size"] > 0:
|
if lerobot_dataset.episode_buffer["size"] > 0:
|
||||||
|
|||||||
@@ -384,7 +384,7 @@ def test_to_lerobot_dataset(tmp_path):
|
|||||||
elif feature == "next.done":
|
elif feature == "next.done":
|
||||||
assert torch.equal(value, buffer.dones[i])
|
assert torch.equal(value, buffer.dones[i])
|
||||||
elif feature == "observation.image":
|
elif feature == "observation.image":
|
||||||
# Tenssor -> numpy is not precise, so we have some diff there
|
# Tensor -> numpy is not precise, so we have some diff there
|
||||||
# TODO: Check and fix it
|
# TODO: Check and fix it
|
||||||
torch.testing.assert_close(value, buffer.states["observation.image"][i], rtol=0.3, atol=0.003)
|
torch.testing.assert_close(value, buffer.states["observation.image"][i], rtol=0.3, atol=0.003)
|
||||||
elif feature == "observation.state":
|
elif feature == "observation.state":
|
||||||
|
|||||||
Reference in New Issue
Block a user