Merge branch 'main' into feat/add-xvla

This commit is contained in:
Jade Choghari
2025-12-01 09:05:08 +01:00
committed by GitHub
@@ -65,7 +65,6 @@ import argparse
import gc import gc
import logging import logging
import time import time
from collections.abc import Iterator
from pathlib import Path from pathlib import Path
import numpy as np import numpy as np
@@ -78,19 +77,6 @@ from lerobot.datasets.lerobot_dataset import LeRobotDataset
from lerobot.utils.constants import ACTION, DONE, OBS_STATE, REWARD from lerobot.utils.constants import ACTION, DONE, OBS_STATE, REWARD
class EpisodeSampler(torch.utils.data.Sampler):
def __init__(self, dataset: LeRobotDataset, episode_index: int):
from_idx = dataset.meta.episodes["dataset_from_index"][episode_index]
to_idx = dataset.meta.episodes["dataset_to_index"][episode_index]
self.frame_ids = range(from_idx, to_idx)
def __iter__(self) -> Iterator:
return iter(self.frame_ids)
def __len__(self) -> int:
return len(self.frame_ids)
def to_hwc_uint8_numpy(chw_float32_torch: torch.Tensor) -> np.ndarray: def to_hwc_uint8_numpy(chw_float32_torch: torch.Tensor) -> np.ndarray:
assert chw_float32_torch.dtype == torch.float32 assert chw_float32_torch.dtype == torch.float32
assert chw_float32_torch.ndim == 3 assert chw_float32_torch.ndim == 3
@@ -119,12 +105,10 @@ def visualize_dataset(
repo_id = dataset.repo_id repo_id = dataset.repo_id
logging.info("Loading dataloader") logging.info("Loading dataloader")
episode_sampler = EpisodeSampler(dataset, episode_index)
dataloader = torch.utils.data.DataLoader( dataloader = torch.utils.data.DataLoader(
dataset, dataset,
num_workers=num_workers, num_workers=num_workers,
batch_size=batch_size, batch_size=batch_size,
sampler=episode_sampler,
) )
logging.info("Starting Rerun") logging.info("Starting Rerun")