remove the sampler cause the relative index is added (#2521)

Co-authored-by: Michel Aractingi <michel.aractingi@huggingface.co>
This commit is contained in:
Sota Nakamura
2025-12-01 06:28:32 +09:00
committed by GitHub
parent c55fbe1b3e
commit 5f7b5f2817
@@ -65,7 +65,6 @@ import argparse
import gc import gc
import logging import logging
import time import time
from collections.abc import Iterator
from pathlib import Path from pathlib import Path
import numpy as np import numpy as np
@@ -78,19 +77,6 @@ from lerobot.datasets.lerobot_dataset import LeRobotDataset
from lerobot.utils.constants import ACTION, DONE, OBS_STATE, REWARD from lerobot.utils.constants import ACTION, DONE, OBS_STATE, REWARD
class EpisodeSampler(torch.utils.data.Sampler):
def __init__(self, dataset: LeRobotDataset, episode_index: int):
from_idx = dataset.meta.episodes["dataset_from_index"][episode_index]
to_idx = dataset.meta.episodes["dataset_to_index"][episode_index]
self.frame_ids = range(from_idx, to_idx)
def __iter__(self) -> Iterator:
return iter(self.frame_ids)
def __len__(self) -> int:
return len(self.frame_ids)
def to_hwc_uint8_numpy(chw_float32_torch: torch.Tensor) -> np.ndarray: def to_hwc_uint8_numpy(chw_float32_torch: torch.Tensor) -> np.ndarray:
assert chw_float32_torch.dtype == torch.float32 assert chw_float32_torch.dtype == torch.float32
assert chw_float32_torch.ndim == 3 assert chw_float32_torch.ndim == 3
@@ -119,12 +105,10 @@ def visualize_dataset(
repo_id = dataset.repo_id repo_id = dataset.repo_id
logging.info("Loading dataloader") logging.info("Loading dataloader")
episode_sampler = EpisodeSampler(dataset, episode_index)
dataloader = torch.utils.data.DataLoader( dataloader = torch.utils.data.DataLoader(
dataset, dataset,
num_workers=num_workers, num_workers=num_workers,
batch_size=batch_size, batch_size=batch_size,
sampler=episode_sampler,
) )
logging.info("Starting Rerun") logging.info("Starting Rerun")