diff --git a/src/lerobot/datasets/lerobot_dataset.py b/src/lerobot/datasets/lerobot_dataset.py index b4777582d..7d578e77b 100644 --- a/src/lerobot/datasets/lerobot_dataset.py +++ b/src/lerobot/datasets/lerobot_dataset.py @@ -652,7 +652,8 @@ class LeRobotDataset(torch.utils.data.Dataset): def load_hf_dataset(self) -> datasets.Dataset: """hf_dataset contains all the observations, states, actions, rewards, etc.""" - hf_dataset = load_nested_dataset(self.root / "data") + features = get_hf_features_from_features(self.features) + hf_dataset = load_nested_dataset(self.root / "data", features=features) hf_dataset.set_transform(hf_transform_to_torch) return hf_dataset diff --git a/src/lerobot/datasets/utils.py b/src/lerobot/datasets/utils.py index bd4005443..1d1101f59 100644 --- a/src/lerobot/datasets/utils.py +++ b/src/lerobot/datasets/utils.py @@ -116,17 +116,21 @@ def update_chunk_file_indices(chunk_idx: int, file_idx: int, chunks_size: int): return chunk_idx, file_idx -def load_nested_dataset(pq_dir: Path) -> Dataset: +def load_nested_dataset(pq_dir: Path, features: datasets.Features | None = None) -> Dataset: """Find parquet files in provided directory {pq_dir}/chunk-xxx/file-xxx.parquet Convert parquet files to pyarrow memory mapped in a cache folder for efficient RAM usage Concatenate all pyarrow references to return HF Dataset format + + Args: + pq_dir: Directory containing parquet files + features: Optional features schema to ensure consistent loading of complex types like images """ paths = sorted(pq_dir.glob("*/*.parquet")) if len(paths) == 0: raise FileNotFoundError(f"Provided directory does not contain any parquet file: {pq_dir}") # TODO(rcadene): set num_proc to accelerate conversion to pyarrow - datasets = [Dataset.from_parquet(str(path)) for path in paths] + datasets = [Dataset.from_parquet(str(path), features=features) for path in paths] return concatenate_datasets(datasets) diff --git a/src/lerobot/utils/buffer.py b/src/lerobot/utils/buffer.py index 
e276ef453..ab075f949 100644 --- a/src/lerobot/utils/buffer.py +++ b/src/lerobot/utils/buffer.py @@ -564,10 +564,7 @@ class ReplayBuffer: lerobot_dataset.start_image_writer(num_processes=0, num_threads=3) # Convert transitions into episodes and frames - episode_index = 0 - lerobot_dataset.episode_buffer = lerobot_dataset.create_episode_buffer(episode_index=episode_index) - frame_idx_in_episode = 0 for idx in range(self.size): actual_idx = (self.position - self.size + idx) % self.capacity @@ -581,6 +578,7 @@ class ReplayBuffer: frame_dict["action"] = self.actions[actual_idx].cpu() frame_dict["next.reward"] = torch.tensor([self.rewards[actual_idx]], dtype=torch.float32).cpu() frame_dict["next.done"] = torch.tensor([self.dones[actual_idx]], dtype=torch.bool).cpu() + frame_dict["task"] = task_name # Add complementary_info if available if self.has_complementary_info: @@ -596,20 +594,14 @@ class ReplayBuffer: frame_dict[f"complementary_info.{key}"] = val # Add to the dataset's buffer - frame_dict["task"] = task_name lerobot_dataset.add_frame(frame_dict) # Move to next frame - frame_idx_in_episode += 1 + # Frame indexing is now handled internally by LeRobotDataset's episode buffer # If we reached an episode boundary, call save_episode, reset counters if self.dones[actual_idx] or self.truncateds[actual_idx]: lerobot_dataset.save_episode() - episode_index += 1 - frame_idx_in_episode = 0 - lerobot_dataset.episode_buffer = lerobot_dataset.create_episode_buffer( - episode_index=episode_index ) # Save any remaining frames in the buffer if lerobot_dataset.episode_buffer["size"] > 0: diff --git a/tests/utils/test_replay_buffer.py b/tests/utils/test_replay_buffer.py index 260276032..b616334ce 100644 --- a/tests/utils/test_replay_buffer.py +++ b/tests/utils/test_replay_buffer.py @@ -384,7 +384,7 @@ def test_to_lerobot_dataset(tmp_path): elif feature == "next.done": assert torch.equal(value, buffer.dones[i]) elif feature == "observation.image": - # Tenssor -> numpy is not precise, so we have some diff there + # Tensor -> numpy is
not precise, so we have some diff there # TODO: Check and fix it torch.testing.assert_close(value, buffer.states["observation.image"][i], rtol=0.3, atol=0.003) elif feature == "observation.state":