feat(dataset): 2x faster dataloader via parallel decode, uint8 transport, and persistent workers (#3406)

* feat(dataset): 2x faster dataloader

* fix(dataset): return uint8 from decode in streaming mode

* fix(tests): adjust normalization step comparison

* fix(dataset): use a thread executor, with False as the default

* chore(dataset): make it a config

* fix(test): account for uint8 in training path testing
This commit is contained in:
Steven Palma
2026-04-19 00:08:22 +02:00
committed by GitHub
parent 760220d532
commit a8b72d9615
10 changed files with 78 additions and 18 deletions
@@ -52,6 +52,9 @@ def get_policy_stats(ds_repo_id: str, policy_name: str, policy_kwargs: dict):
)
batch = next(iter(dataloader))
for key in batch:
if isinstance(batch[key], torch.Tensor) and batch[key].dtype == torch.uint8:
batch[key] = batch[key].to(dtype=torch.float32) / 255.0
batch = preprocessor(batch)
loss, output_dict = policy.forward(batch)
@@ -82,6 +85,9 @@ def get_policy_stats(ds_repo_id: str, policy_name: str, policy_kwargs: dict):
# indicating padding (those ending with "_is_pad")
dataset.reader.delta_indices = None
batch = next(iter(dataloader))
for key in batch:
if isinstance(batch[key], torch.Tensor) and batch[key].dtype == torch.uint8:
batch[key] = batch[key].to(dtype=torch.float32) / 255.0
obs = {}
for k in batch:
# TODO: regenerate the safetensors