mirror of
https://github.com/huggingface/lerobot.git
synced 2026-06-16 07:49:48 +00:00
Compare commits
2 Commits
| Author | SHA1 | Date | |
|---|---|---|---|
| 2b83956eb5 | |||
| 7309790d56 |
@@ -474,6 +474,8 @@ class LeRobotDataset(torch.utils.data.Dataset):
|
||||
if reader.hf_dataset is None:
|
||||
# One-shot load after finalize()
|
||||
reader.load_and_activate()
|
||||
if reader._absolute_to_relative_idx is not None and idx in reader._absolute_to_relative_idx:
|
||||
idx = reader._absolute_to_relative_idx[idx]
|
||||
return reader.get_item(idx)
|
||||
|
||||
def select_columns(self, column_names: str | list[str]):
|
||||
|
||||
@@ -435,8 +435,6 @@ def train(cfg: TrainPipelineConfig, accelerator: "Accelerator | None" = None):
|
||||
f"Resuming data order at epoch {sampler_state['epoch']}, "
|
||||
f"sample {sampler_state['start_index']}"
|
||||
)
|
||||
if dataset.reader._absolute_to_relative_idx is not None:
|
||||
sampler.indices = [dataset.reader._absolute_to_relative_idx[i] for i in sampler.indices]
|
||||
else:
|
||||
shuffle = True
|
||||
sampler = None
|
||||
@@ -463,13 +461,13 @@ def train(cfg: TrainPipelineConfig, accelerator: "Accelerator | None" = None):
|
||||
if eval_dataset is not None:
|
||||
eval_ds = eval_dataset
|
||||
if cfg.max_eval_samples > 0 and hasattr(eval_dataset, "hf_dataset"):
|
||||
task_indices = eval_dataset.hf_dataset["task_index"]
|
||||
unique_tasks = sorted(set(task_indices))
|
||||
task_arr = eval_dataset.hf_dataset.data.column("task_index").to_numpy()
|
||||
unique_tasks = sorted(set(task_arr.tolist()))
|
||||
per_task = max(1, cfg.max_eval_samples // len(unique_tasks))
|
||||
selected: list[int] = []
|
||||
for t in unique_tasks:
|
||||
frames = [i for i, ti in enumerate(task_indices) if ti == t][:per_task]
|
||||
selected.extend(frames)
|
||||
frames = (task_arr == t).nonzero()[0][:per_task]
|
||||
selected.extend(frames.tolist())
|
||||
eval_ds = torch.utils.data.Subset(eval_dataset, selected)
|
||||
|
||||
eval_collate_fn = lerobot_collate_fn if dataset.meta.has_language_columns else None
|
||||
@@ -481,6 +479,8 @@ def train(cfg: TrainPipelineConfig, accelerator: "Accelerator | None" = None):
|
||||
pin_memory=device.type == "cuda",
|
||||
drop_last=False,
|
||||
collate_fn=eval_collate_fn,
|
||||
prefetch_factor=cfg.prefetch_factor if cfg.num_workers > 0 else None,
|
||||
persistent_workers=cfg.persistent_workers and cfg.num_workers > 0,
|
||||
)
|
||||
|
||||
# Prepare everything with accelerator
|
||||
|
||||
Reference in New Issue
Block a user