mirror of
https://github.com/huggingface/lerobot.git
synced 2026-05-22 12:09:42 +00:00
feat(dataset): 2x faster dataloader via parallel decode, uint8 transport, and persistent workers (#3406)
* feat(dataset): 2xfaster dataloader * fix(dataset): streaming return uint8 decode * fix(tests): adjust normalization step comparison * fix(dataset): with threadexecutor + False default * chore(dataset): make it a config * fix(test): account for uint8 in training path testing
This commit is contained in:
@@ -35,6 +35,9 @@ class DatasetConfig:
|
|||||||
revision: str | None = None
|
revision: str | None = None
|
||||||
use_imagenet_stats: bool = True
|
use_imagenet_stats: bool = True
|
||||||
video_backend: str = field(default_factory=get_safe_default_codec)
|
video_backend: str = field(default_factory=get_safe_default_codec)
|
||||||
|
# When True, video frames are returned as uint8 tensors (0-255) instead of float32 (0.0-1.0).
|
||||||
|
# This reduces memory and speeds up DataLoader IPC. The training pipeline handles the conversion.
|
||||||
|
return_uint8: bool = False
|
||||||
streaming: bool = False
|
streaming: bool = False
|
||||||
|
|
||||||
def __post_init__(self) -> None:
|
def __post_init__(self) -> None:
|
||||||
|
|||||||
@@ -56,6 +56,8 @@ class TrainPipelineConfig(HubMixin):
|
|||||||
# Number of workers for the dataloader.
|
# Number of workers for the dataloader.
|
||||||
num_workers: int = 4
|
num_workers: int = 4
|
||||||
batch_size: int = 8
|
batch_size: int = 8
|
||||||
|
prefetch_factor: int = 4
|
||||||
|
persistent_workers: bool = True
|
||||||
steps: int = 100_000
|
steps: int = 100_000
|
||||||
eval_freq: int = 20_000
|
eval_freq: int = 20_000
|
||||||
log_freq: int = 200
|
log_freq: int = 200
|
||||||
|
|||||||
@@ -16,6 +16,7 @@
|
|||||||
"""Private reader component for LeRobotDataset. Handles random-access reading (HF dataset, delta indices, video decoding)."""
|
"""Private reader component for LeRobotDataset. Handles random-access reading (HF dataset, delta indices, video decoding)."""
|
||||||
|
|
||||||
from collections.abc import Callable
|
from collections.abc import Callable
|
||||||
|
from concurrent.futures import ThreadPoolExecutor
|
||||||
from pathlib import Path
|
from pathlib import Path
|
||||||
|
|
||||||
import datasets
|
import datasets
|
||||||
@@ -49,6 +50,7 @@ class DatasetReader:
|
|||||||
video_backend: str,
|
video_backend: str,
|
||||||
delta_timestamps: dict[str, list[float]] | None,
|
delta_timestamps: dict[str, list[float]] | None,
|
||||||
image_transforms: Callable | None,
|
image_transforms: Callable | None,
|
||||||
|
return_uint8: bool = False,
|
||||||
):
|
):
|
||||||
"""Initialize the reader with metadata, filtering, and transform config.
|
"""Initialize the reader with metadata, filtering, and transform config.
|
||||||
|
|
||||||
@@ -73,6 +75,7 @@ class DatasetReader:
|
|||||||
self._tolerance_s = tolerance_s
|
self._tolerance_s = tolerance_s
|
||||||
self._video_backend = video_backend
|
self._video_backend = video_backend
|
||||||
self._image_transforms = image_transforms
|
self._image_transforms = image_transforms
|
||||||
|
self._return_uint8 = return_uint8
|
||||||
|
|
||||||
self.hf_dataset: datasets.Dataset | None = None
|
self.hf_dataset: datasets.Dataset | None = None
|
||||||
self._absolute_to_relative_idx: dict[int, int] | None = None
|
self._absolute_to_relative_idx: dict[int, int] | None = None
|
||||||
@@ -233,16 +236,30 @@ class DatasetReader:
|
|||||||
Segmentation Fault.
|
Segmentation Fault.
|
||||||
"""
|
"""
|
||||||
ep = self._meta.episodes[ep_idx]
|
ep = self._meta.episodes[ep_idx]
|
||||||
item = {}
|
|
||||||
for vid_key, query_ts in query_timestamps.items():
|
def _decode_single(vid_key: str, query_ts: list[float]) -> tuple[str, torch.Tensor]:
|
||||||
from_timestamp = ep[f"videos/{vid_key}/from_timestamp"]
|
from_timestamp = ep[f"videos/{vid_key}/from_timestamp"]
|
||||||
shifted_query_ts = [from_timestamp + ts for ts in query_ts]
|
shifted_query_ts = [from_timestamp + ts for ts in query_ts]
|
||||||
|
|
||||||
video_path = self.root / self._meta.get_video_file_path(ep_idx, vid_key)
|
video_path = self.root / self._meta.get_video_file_path(ep_idx, vid_key)
|
||||||
frames = decode_video_frames(video_path, shifted_query_ts, self._tolerance_s, self._video_backend)
|
frames = decode_video_frames(
|
||||||
item[vid_key] = frames.squeeze(0)
|
video_path,
|
||||||
|
shifted_query_ts,
|
||||||
|
self._tolerance_s,
|
||||||
|
self._video_backend,
|
||||||
|
return_uint8=self._return_uint8,
|
||||||
|
)
|
||||||
|
return vid_key, frames.squeeze(0)
|
||||||
|
|
||||||
return item
|
items = list(query_timestamps.items())
|
||||||
|
|
||||||
|
# Single camera: no threading overhead
|
||||||
|
if len(items) <= 1:
|
||||||
|
return {vid_key: _decode_single(vid_key, query_ts)[1] for vid_key, query_ts in items}
|
||||||
|
|
||||||
|
# Multi-camera: decode in parallel (video decoding releases the GIL)
|
||||||
|
with ThreadPoolExecutor(max_workers=len(items)) as pool:
|
||||||
|
futures = [pool.submit(_decode_single, k, ts) for k, ts in items]
|
||||||
|
return dict(f.result() for f in futures)
|
||||||
|
|
||||||
def get_item(self, idx) -> dict:
|
def get_item(self, idx) -> dict:
|
||||||
"""Core __getitem__ logic. Assumes hf_dataset is loaded.
|
"""Core __getitem__ logic. Assumes hf_dataset is loaded.
|
||||||
|
|||||||
@@ -92,6 +92,7 @@ def make_dataset(cfg: TrainPipelineConfig) -> LeRobotDataset | MultiLeRobotDatas
|
|||||||
image_transforms=image_transforms,
|
image_transforms=image_transforms,
|
||||||
revision=cfg.dataset.revision,
|
revision=cfg.dataset.revision,
|
||||||
video_backend=cfg.dataset.video_backend,
|
video_backend=cfg.dataset.video_backend,
|
||||||
|
return_uint8=True,
|
||||||
tolerance_s=cfg.tolerance_s,
|
tolerance_s=cfg.tolerance_s,
|
||||||
)
|
)
|
||||||
else:
|
else:
|
||||||
@@ -104,6 +105,7 @@ def make_dataset(cfg: TrainPipelineConfig) -> LeRobotDataset | MultiLeRobotDatas
|
|||||||
revision=cfg.dataset.revision,
|
revision=cfg.dataset.revision,
|
||||||
max_num_shards=cfg.num_workers,
|
max_num_shards=cfg.num_workers,
|
||||||
tolerance_s=cfg.tolerance_s,
|
tolerance_s=cfg.tolerance_s,
|
||||||
|
return_uint8=True,
|
||||||
)
|
)
|
||||||
else:
|
else:
|
||||||
raise NotImplementedError("The MultiLeRobotDataset isn't supported for now.")
|
raise NotImplementedError("The MultiLeRobotDataset isn't supported for now.")
|
||||||
|
|||||||
@@ -56,6 +56,7 @@ class LeRobotDataset(torch.utils.data.Dataset):
|
|||||||
force_cache_sync: bool = False,
|
force_cache_sync: bool = False,
|
||||||
download_videos: bool = True,
|
download_videos: bool = True,
|
||||||
video_backend: str | None = None,
|
video_backend: str | None = None,
|
||||||
|
return_uint8: bool = False,
|
||||||
batch_encoding_size: int = 1,
|
batch_encoding_size: int = 1,
|
||||||
vcodec: str = "libsvtav1",
|
vcodec: str = "libsvtav1",
|
||||||
streaming_encoding: bool = False,
|
streaming_encoding: bool = False,
|
||||||
@@ -202,6 +203,7 @@ class LeRobotDataset(torch.utils.data.Dataset):
|
|||||||
self.tolerance_s = tolerance_s
|
self.tolerance_s = tolerance_s
|
||||||
self.revision = revision if revision else CODEBASE_VERSION
|
self.revision = revision if revision else CODEBASE_VERSION
|
||||||
self._video_backend = video_backend if video_backend else get_safe_default_codec()
|
self._video_backend = video_backend if video_backend else get_safe_default_codec()
|
||||||
|
self._return_uint8 = return_uint8
|
||||||
self._batch_encoding_size = batch_encoding_size
|
self._batch_encoding_size = batch_encoding_size
|
||||||
self._vcodec = resolve_vcodec(vcodec)
|
self._vcodec = resolve_vcodec(vcodec)
|
||||||
self._encoder_threads = encoder_threads
|
self._encoder_threads = encoder_threads
|
||||||
@@ -225,6 +227,7 @@ class LeRobotDataset(torch.utils.data.Dataset):
|
|||||||
video_backend=self._video_backend,
|
video_backend=self._video_backend,
|
||||||
delta_timestamps=delta_timestamps,
|
delta_timestamps=delta_timestamps,
|
||||||
image_transforms=image_transforms,
|
image_transforms=image_transforms,
|
||||||
|
return_uint8=self._return_uint8,
|
||||||
)
|
)
|
||||||
|
|
||||||
# Load actual data
|
# Load actual data
|
||||||
@@ -288,6 +291,7 @@ class LeRobotDataset(torch.utils.data.Dataset):
|
|||||||
video_backend=self._video_backend,
|
video_backend=self._video_backend,
|
||||||
delta_timestamps=self.delta_timestamps,
|
delta_timestamps=self.delta_timestamps,
|
||||||
image_transforms=self.image_transforms,
|
image_transforms=self.image_transforms,
|
||||||
|
return_uint8=self._return_uint8,
|
||||||
)
|
)
|
||||||
return self.reader
|
return self.reader
|
||||||
|
|
||||||
@@ -683,6 +687,7 @@ class LeRobotDataset(torch.utils.data.Dataset):
|
|||||||
obj.delta_timestamps = None
|
obj.delta_timestamps = None
|
||||||
obj.episodes = None
|
obj.episodes = None
|
||||||
obj._video_backend = video_backend if video_backend is not None else get_safe_default_codec()
|
obj._video_backend = video_backend if video_backend is not None else get_safe_default_codec()
|
||||||
|
obj._return_uint8 = False
|
||||||
obj._batch_encoding_size = batch_encoding_size
|
obj._batch_encoding_size = batch_encoding_size
|
||||||
obj._vcodec = vcodec
|
obj._vcodec = vcodec
|
||||||
obj._encoder_threads = encoder_threads
|
obj._encoder_threads = encoder_threads
|
||||||
@@ -775,6 +780,7 @@ class LeRobotDataset(torch.utils.data.Dataset):
|
|||||||
obj.delta_timestamps = None
|
obj.delta_timestamps = None
|
||||||
obj.episodes = None
|
obj.episodes = None
|
||||||
obj._video_backend = video_backend if video_backend else get_safe_default_codec()
|
obj._video_backend = video_backend if video_backend else get_safe_default_codec()
|
||||||
|
obj._return_uint8 = False
|
||||||
obj._batch_encoding_size = batch_encoding_size
|
obj._batch_encoding_size = batch_encoding_size
|
||||||
obj._vcodec = vcodec
|
obj._vcodec = vcodec
|
||||||
obj._encoder_threads = encoder_threads
|
obj._encoder_threads = encoder_threads
|
||||||
|
|||||||
@@ -251,6 +251,7 @@ class StreamingLeRobotDataset(torch.utils.data.IterableDataset):
|
|||||||
seed: int = 42,
|
seed: int = 42,
|
||||||
rng: np.random.Generator | None = None,
|
rng: np.random.Generator | None = None,
|
||||||
shuffle: bool = True,
|
shuffle: bool = True,
|
||||||
|
return_uint8: bool = False,
|
||||||
):
|
):
|
||||||
"""Initialize a StreamingLeRobotDataset.
|
"""Initialize a StreamingLeRobotDataset.
|
||||||
|
|
||||||
@@ -288,6 +289,7 @@ class StreamingLeRobotDataset(torch.utils.data.IterableDataset):
|
|||||||
|
|
||||||
self.streaming = streaming
|
self.streaming = streaming
|
||||||
self.buffer_size = buffer_size
|
self.buffer_size = buffer_size
|
||||||
|
self._return_uint8 = return_uint8
|
||||||
|
|
||||||
# We cache the video decoders to avoid re-initializing them at each frame (avoiding a ~10x slowdown)
|
# We cache the video decoders to avoid re-initializing them at each frame (avoiding a ~10x slowdown)
|
||||||
self.video_decoder_cache = None
|
self.video_decoder_cache = None
|
||||||
@@ -553,7 +555,11 @@ class StreamingLeRobotDataset(torch.utils.data.IterableDataset):
|
|||||||
root = self.meta.url_root if self.streaming and not self.streaming_from_local else self.root
|
root = self.meta.url_root if self.streaming and not self.streaming_from_local else self.root
|
||||||
video_path = f"{root}/{self.meta.get_video_file_path(ep_idx, video_key)}"
|
video_path = f"{root}/{self.meta.get_video_file_path(ep_idx, video_key)}"
|
||||||
frames = decode_video_frames_torchcodec(
|
frames = decode_video_frames_torchcodec(
|
||||||
video_path, query_ts, self.tolerance_s, decoder_cache=self.video_decoder_cache
|
video_path,
|
||||||
|
query_ts,
|
||||||
|
self.tolerance_s,
|
||||||
|
decoder_cache=self.video_decoder_cache,
|
||||||
|
return_uint8=self._return_uint8,
|
||||||
)
|
)
|
||||||
|
|
||||||
item[video_key] = frames.squeeze(0) if len(query_ts) == 1 else frames
|
item[video_key] = frames.squeeze(0) if len(query_ts) == 1 else frames
|
||||||
|
|||||||
@@ -123,6 +123,7 @@ def decode_video_frames(
|
|||||||
timestamps: list[float],
|
timestamps: list[float],
|
||||||
tolerance_s: float,
|
tolerance_s: float,
|
||||||
backend: str | None = None,
|
backend: str | None = None,
|
||||||
|
return_uint8: bool = False,
|
||||||
) -> torch.Tensor:
|
) -> torch.Tensor:
|
||||||
"""
|
"""
|
||||||
Decodes video frames using the specified backend.
|
Decodes video frames using the specified backend.
|
||||||
@@ -131,19 +132,23 @@ def decode_video_frames(
|
|||||||
video_path (Path): Path to the video file.
|
video_path (Path): Path to the video file.
|
||||||
timestamps (list[float]): List of timestamps to extract frames.
|
timestamps (list[float]): List of timestamps to extract frames.
|
||||||
tolerance_s (float): Allowed deviation in seconds for frame retrieval.
|
tolerance_s (float): Allowed deviation in seconds for frame retrieval.
|
||||||
backend (str, optional): Backend to use for decoding. Defaults to "torchcodec" when available in the platform; otherwise, defaults to "pyav"..
|
backend (str, optional): Backend to use for decoding. Defaults to "torchcodec" when available in the platform; otherwise, defaults to "pyav".
|
||||||
|
return_uint8 (bool): If True, return raw uint8 frames without float32 normalization.
|
||||||
|
This reduces memory for DataLoader IPC; normalization can be done on GPU afterward.
|
||||||
|
|
||||||
Returns:
|
Returns:
|
||||||
torch.Tensor: Decoded frames.
|
torch.Tensor: Decoded frames (float32 in [0,1] by default, or uint8 if return_uint8=True).
|
||||||
|
|
||||||
Currently supports torchcodec on cpu and pyav.
|
Currently supports torchcodec on cpu and pyav.
|
||||||
"""
|
"""
|
||||||
if backend is None:
|
if backend is None:
|
||||||
backend = get_safe_default_codec()
|
backend = get_safe_default_codec()
|
||||||
if backend == "torchcodec":
|
if backend == "torchcodec":
|
||||||
return decode_video_frames_torchcodec(video_path, timestamps, tolerance_s)
|
return decode_video_frames_torchcodec(video_path, timestamps, tolerance_s, return_uint8=return_uint8)
|
||||||
elif backend in ["pyav", "video_reader"]:
|
elif backend in ["pyav", "video_reader"]:
|
||||||
return decode_video_frames_torchvision(video_path, timestamps, tolerance_s, backend)
|
return decode_video_frames_torchvision(
|
||||||
|
video_path, timestamps, tolerance_s, backend, return_uint8=return_uint8
|
||||||
|
)
|
||||||
else:
|
else:
|
||||||
raise ValueError(f"Unsupported video backend: {backend}")
|
raise ValueError(f"Unsupported video backend: {backend}")
|
||||||
|
|
||||||
@@ -154,6 +159,7 @@ def decode_video_frames_torchvision(
|
|||||||
tolerance_s: float,
|
tolerance_s: float,
|
||||||
backend: str = "pyav",
|
backend: str = "pyav",
|
||||||
log_loaded_timestamps: bool = False,
|
log_loaded_timestamps: bool = False,
|
||||||
|
return_uint8: bool = False,
|
||||||
) -> torch.Tensor:
|
) -> torch.Tensor:
|
||||||
"""Loads frames associated to the requested timestamps of a video
|
"""Loads frames associated to the requested timestamps of a video
|
||||||
|
|
||||||
@@ -240,14 +246,17 @@ def decode_video_frames_torchvision(
|
|||||||
if log_loaded_timestamps:
|
if log_loaded_timestamps:
|
||||||
logger.info(f"{closest_ts=}")
|
logger.info(f"{closest_ts=}")
|
||||||
|
|
||||||
# convert to the pytorch format which is float32 in [0,1] range (and channel first)
|
|
||||||
closest_frames = closest_frames.type(torch.float32) / 255
|
|
||||||
|
|
||||||
if len(timestamps) != len(closest_frames):
|
if len(timestamps) != len(closest_frames):
|
||||||
raise FrameTimestampError(
|
raise FrameTimestampError(
|
||||||
f"Number of retrieved frames ({len(closest_frames)}) does not match "
|
f"Number of retrieved frames ({len(closest_frames)}) does not match "
|
||||||
f"number of queried timestamps ({len(timestamps)})"
|
f"number of queried timestamps ({len(timestamps)})"
|
||||||
)
|
)
|
||||||
|
|
||||||
|
if return_uint8:
|
||||||
|
return closest_frames
|
||||||
|
|
||||||
|
# convert to the pytorch format which is float32 in [0,1] range (and channel first)
|
||||||
|
closest_frames = closest_frames.type(torch.float32) / 255
|
||||||
return closest_frames
|
return closest_frames
|
||||||
|
|
||||||
|
|
||||||
@@ -306,6 +315,7 @@ def decode_video_frames_torchcodec(
|
|||||||
tolerance_s: float,
|
tolerance_s: float,
|
||||||
log_loaded_timestamps: bool = False,
|
log_loaded_timestamps: bool = False,
|
||||||
decoder_cache: VideoDecoderCache | None = None,
|
decoder_cache: VideoDecoderCache | None = None,
|
||||||
|
return_uint8: bool = False,
|
||||||
) -> torch.Tensor:
|
) -> torch.Tensor:
|
||||||
"""Loads frames associated with the requested timestamps of a video using torchcodec.
|
"""Loads frames associated with the requested timestamps of a video using torchcodec.
|
||||||
|
|
||||||
@@ -373,14 +383,16 @@ def decode_video_frames_torchcodec(
|
|||||||
if log_loaded_timestamps:
|
if log_loaded_timestamps:
|
||||||
logger.info(f"{closest_ts=}")
|
logger.info(f"{closest_ts=}")
|
||||||
|
|
||||||
# convert to float32 in [0,1] range
|
|
||||||
closest_frames = (closest_frames / 255.0).type(torch.float32)
|
|
||||||
|
|
||||||
if not len(timestamps) == len(closest_frames):
|
if not len(timestamps) == len(closest_frames):
|
||||||
raise FrameTimestampError(
|
raise FrameTimestampError(
|
||||||
f"Retrieved timestamps differ from queried {set(closest_frames) - set(timestamps)}"
|
f"Retrieved timestamps differ from queried {set(closest_frames) - set(timestamps)}"
|
||||||
)
|
)
|
||||||
|
|
||||||
|
if return_uint8:
|
||||||
|
return closest_frames
|
||||||
|
|
||||||
|
# convert to float32 in [0,1] range
|
||||||
|
closest_frames = (closest_frames / 255.0).type(torch.float32)
|
||||||
return closest_frames
|
return closest_frames
|
||||||
|
|
||||||
|
|
||||||
|
|||||||
@@ -386,7 +386,8 @@ def train(cfg: TrainPipelineConfig, accelerator: "Accelerator | None" = None):
|
|||||||
sampler=sampler,
|
sampler=sampler,
|
||||||
pin_memory=device.type == "cuda",
|
pin_memory=device.type == "cuda",
|
||||||
drop_last=False,
|
drop_last=False,
|
||||||
prefetch_factor=2 if cfg.num_workers > 0 else None,
|
prefetch_factor=cfg.prefetch_factor if cfg.num_workers > 0 else None,
|
||||||
|
persistent_workers=cfg.persistent_workers and cfg.num_workers > 0,
|
||||||
)
|
)
|
||||||
|
|
||||||
# Prepare everything with accelerator
|
# Prepare everything with accelerator
|
||||||
@@ -433,6 +434,9 @@ def train(cfg: TrainPipelineConfig, accelerator: "Accelerator | None" = None):
|
|||||||
for _ in range(step, cfg.steps):
|
for _ in range(step, cfg.steps):
|
||||||
start_time = time.perf_counter()
|
start_time = time.perf_counter()
|
||||||
batch = next(dl_iter)
|
batch = next(dl_iter)
|
||||||
|
for cam_key in dataset.meta.camera_keys:
|
||||||
|
if cam_key in batch and batch[cam_key].dtype == torch.uint8:
|
||||||
|
batch[cam_key] = batch[cam_key].to(dtype=torch.float32) / 255.0
|
||||||
batch = preprocessor(batch)
|
batch = preprocessor(batch)
|
||||||
train_tracker.dataloading_s = time.perf_counter() - start_time
|
train_tracker.dataloading_s = time.perf_counter() - start_time
|
||||||
|
|
||||||
|
|||||||
@@ -52,6 +52,9 @@ def get_policy_stats(ds_repo_id: str, policy_name: str, policy_kwargs: dict):
|
|||||||
)
|
)
|
||||||
|
|
||||||
batch = next(iter(dataloader))
|
batch = next(iter(dataloader))
|
||||||
|
for key in batch:
|
||||||
|
if isinstance(batch[key], torch.Tensor) and batch[key].dtype == torch.uint8:
|
||||||
|
batch[key] = batch[key].to(dtype=torch.float32) / 255.0
|
||||||
batch = preprocessor(batch)
|
batch = preprocessor(batch)
|
||||||
loss, output_dict = policy.forward(batch)
|
loss, output_dict = policy.forward(batch)
|
||||||
|
|
||||||
@@ -82,6 +85,9 @@ def get_policy_stats(ds_repo_id: str, policy_name: str, policy_kwargs: dict):
|
|||||||
# indicating padding (those ending with "_is_pad")
|
# indicating padding (those ending with "_is_pad")
|
||||||
dataset.reader.delta_indices = None
|
dataset.reader.delta_indices = None
|
||||||
batch = next(iter(dataloader))
|
batch = next(iter(dataloader))
|
||||||
|
for key in batch:
|
||||||
|
if isinstance(batch[key], torch.Tensor) and batch[key].dtype == torch.uint8:
|
||||||
|
batch[key] = batch[key].to(dtype=torch.float32) / 255.0
|
||||||
obs = {}
|
obs = {}
|
||||||
for k in batch:
|
for k in batch:
|
||||||
# TODO: regenerate the safetensors
|
# TODO: regenerate the safetensors
|
||||||
|
|||||||
@@ -196,6 +196,8 @@ def test_policy(ds_repo_id, env_name, env_kwargs, policy_name, policy_kwargs):
|
|||||||
|
|
||||||
for key in batch:
|
for key in batch:
|
||||||
if isinstance(batch[key], torch.Tensor):
|
if isinstance(batch[key], torch.Tensor):
|
||||||
|
if batch[key].dtype == torch.uint8:
|
||||||
|
batch[key] = batch[key].to(dtype=torch.float32) / 255.0
|
||||||
batch[key] = batch[key].to(DEVICE, non_blocking=True)
|
batch[key] = batch[key].to(DEVICE, non_blocking=True)
|
||||||
|
|
||||||
# Test updating the policy (and test that it does not mutate the batch)
|
# Test updating the policy (and test that it does not mutate the batch)
|
||||||
|
|||||||
Reference in New Issue
Block a user