mirror of
https://github.com/huggingface/lerobot.git
synced 2026-05-20 11:09:59 +00:00
feat(dataset): 2x faster dataloader via parallel decode, uint8 transport, and persistent workers (#3406)
* feat(dataset): 2xfaster dataloader * fix(dataset): streaming return uint8 decode * fix(tests): adjust normalization step comparison * fix(dataset): with threadexecutor + False default * chore(dataset): make it a config * fix(test): account for uint8 in training path testing
This commit is contained in:
@@ -52,6 +52,9 @@ def get_policy_stats(ds_repo_id: str, policy_name: str, policy_kwargs: dict):
|
||||
)
|
||||
|
||||
batch = next(iter(dataloader))
|
||||
for key in batch:
|
||||
if isinstance(batch[key], torch.Tensor) and batch[key].dtype == torch.uint8:
|
||||
batch[key] = batch[key].to(dtype=torch.float32) / 255.0
|
||||
batch = preprocessor(batch)
|
||||
loss, output_dict = policy.forward(batch)
|
||||
|
||||
@@ -82,6 +85,9 @@ def get_policy_stats(ds_repo_id: str, policy_name: str, policy_kwargs: dict):
|
||||
# indicating padding (those ending with "_is_pad")
|
||||
dataset.reader.delta_indices = None
|
||||
batch = next(iter(dataloader))
|
||||
for key in batch:
|
||||
if isinstance(batch[key], torch.Tensor) and batch[key].dtype == torch.uint8:
|
||||
batch[key] = batch[key].to(dtype=torch.float32) / 255.0
|
||||
obs = {}
|
||||
for k in batch:
|
||||
# TODO: regenerate the safetensors
|
||||
|
||||
Reference in New Issue
Block a user