feat(dataset): 2x faster dataloader via parallel decode, uint8 transport, and persistent workers (#3406)

* feat(dataset): 2x faster dataloader

* fix(dataset): streaming return uint8 decode

* fix(tests): adjust normalization step comparison

* fix(dataset): with threadexecutor + False default

* chore(dataset): make it a config

* fix(test): account for uint8 in training path testing
This commit is contained in:
Steven Palma
2026-04-19 00:08:22 +02:00
committed by GitHub
parent 760220d532
commit a8b72d9615
10 changed files with 78 additions and 18 deletions
@@ -52,6 +52,9 @@ def get_policy_stats(ds_repo_id: str, policy_name: str, policy_kwargs: dict):
)
batch = next(iter(dataloader))
for key in batch:
if isinstance(batch[key], torch.Tensor) and batch[key].dtype == torch.uint8:
batch[key] = batch[key].to(dtype=torch.float32) / 255.0
batch = preprocessor(batch)
loss, output_dict = policy.forward(batch)
@@ -82,6 +85,9 @@ def get_policy_stats(ds_repo_id: str, policy_name: str, policy_kwargs: dict):
# indicating padding (those ending with "_is_pad")
dataset.reader.delta_indices = None
batch = next(iter(dataloader))
for key in batch:
if isinstance(batch[key], torch.Tensor) and batch[key].dtype == torch.uint8:
batch[key] = batch[key].to(dtype=torch.float32) / 255.0
obs = {}
for k in batch:
# TODO: regenerate the safetensors
+2
View File
@@ -196,6 +196,8 @@ def test_policy(ds_repo_id, env_name, env_kwargs, policy_name, policy_kwargs):
for key in batch:
if isinstance(batch[key], torch.Tensor):
if batch[key].dtype == torch.uint8:
batch[key] = batch[key].to(dtype=torch.float32) / 255.0
batch[key] = batch[key].to(DEVICE, non_blocking=True)
# Test updating the policy (and test that it does not mutate the batch)