mirror of
https://github.com/huggingface/lerobot.git
synced 2026-05-20 02:59:50 +00:00
d90e4bcfd3
* refactor(dataset): modular files * refactor(dataset): update imports across the codebase
690 lines
27 KiB
Python
690 lines
27 KiB
Python
#!/usr/bin/env python
|
|
|
|
# Copyright 2025 The HuggingFace Inc. team. All rights reserved.
|
|
#
|
|
# Licensed under the Apache License, Version 2.0 (the "License");
|
|
# you may not use this file except in compliance with the License.
|
|
# You may obtain a copy of the License at
|
|
#
|
|
# http://www.apache.org/licenses/LICENSE-2.0
|
|
#
|
|
# Unless required by applicable law or agreed to in writing, software
|
|
# distributed under the License is distributed on an "AS IS" BASIS,
|
|
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
|
# See the License for the specific language governing permissions and
|
|
# limitations under the License.
|
|
from collections import deque
|
|
from collections.abc import Callable, Generator, Iterable, Iterator
|
|
from pathlib import Path
|
|
|
|
import datasets
|
|
import numpy as np
|
|
import torch
|
|
from datasets import load_dataset
|
|
|
|
from lerobot.datasets.dataset_metadata import CODEBASE_VERSION, LeRobotDatasetMetadata
|
|
from lerobot.datasets.feature_utils import get_delta_indices
|
|
from lerobot.datasets.io_utils import item_to_torch
|
|
from lerobot.datasets.utils import (
|
|
check_version_compatibility,
|
|
find_float_index,
|
|
is_float_in_list,
|
|
safe_shard,
|
|
)
|
|
from lerobot.datasets.video_utils import (
|
|
VideoDecoderCache,
|
|
decode_video_frames_torchcodec,
|
|
)
|
|
from lerobot.utils.constants import HF_LEROBOT_HOME, LOOKAHEAD_BACKTRACKTABLE, LOOKBACK_BACKTRACKTABLE
|
|
|
|
|
|
class LookBackError(Exception):
    """Signals that a Backtrackable object cannot reach the requested item in its history."""
|
|
|
|
|
|
class LookAheadError(Exception):
    """Signals that a Backtrackable object cannot reach the requested item ahead of its cursor."""
|
|
|
|
|
|
class Backtrackable[T]:
|
|
"""
|
|
Wrap any iterator/iterable so you can step back up to `history` items
|
|
and look ahead up to `lookahead` items.
|
|
|
|
This is useful for streaming datasets where you need to access previous and future items
|
|
but can't load the entire dataset into memory.
|
|
|
|
Example:
|
|
-------
|
|
```python
|
|
ds = load_dataset("c4", "en", streaming=True, split="train")
|
|
rev = Backtrackable(ds, history=3, lookahead=2)
|
|
|
|
x0 = next(rev) # forward
|
|
x1 = next(rev)
|
|
x2 = next(rev)
|
|
|
|
# Look ahead
|
|
x3_peek = rev.peek_ahead(1) # next item without moving cursor
|
|
x4_peek = rev.peek_ahead(2) # two items ahead
|
|
|
|
# Look back
|
|
x1_again = rev.peek_back(1) # previous item without moving cursor
|
|
x0_again = rev.peek_back(2) # two items back
|
|
|
|
# Move backward
|
|
x1_back = rev.prev() # back one step
|
|
next(rev) # returns x2, continues forward from where we were
|
|
```
|
|
"""
|
|
|
|
__slots__ = ("_source", "_back_buf", "_ahead_buf", "_cursor", "_history", "_lookahead")
|
|
|
|
def __init__(self, iterable: Iterable[T], *, history: int = 1, lookahead: int = 0):
|
|
if history < 1:
|
|
raise ValueError("history must be >= 1")
|
|
if lookahead <= 0:
|
|
raise ValueError("lookahead must be > 0")
|
|
|
|
self._source: Iterator[T] = iter(iterable)
|
|
self._back_buf: deque[T] = deque(maxlen=history)
|
|
self._ahead_buf: deque[T] = deque(maxlen=lookahead) if lookahead > 0 else deque()
|
|
self._cursor: int = 0
|
|
self._history = history
|
|
self._lookahead = lookahead
|
|
|
|
def __iter__(self) -> "Backtrackable[T]":
|
|
return self
|
|
|
|
def __next__(self) -> T:
|
|
# If we've stepped back, consume from back buffer first
|
|
if self._cursor < 0: # -1 means "last item", etc.
|
|
self._cursor += 1
|
|
return self._back_buf[self._cursor]
|
|
|
|
# If we have items in the ahead buffer, use them first
|
|
item = self._ahead_buf.popleft() if self._ahead_buf else next(self._source)
|
|
|
|
# Add current item to back buffer and reset cursor
|
|
self._back_buf.append(item)
|
|
self._cursor = 0
|
|
return item
|
|
|
|
def prev(self) -> T:
|
|
"""
|
|
Step one item back in history and return it.
|
|
Raises IndexError if already at the oldest buffered item.
|
|
"""
|
|
if len(self._back_buf) + self._cursor <= 1:
|
|
raise LookBackError("At start of history")
|
|
|
|
self._cursor -= 1
|
|
return self._back_buf[self._cursor]
|
|
|
|
def peek_back(self, n: int = 1) -> T:
|
|
"""
|
|
Look `n` items back (n=1 == previous item) without moving the cursor.
|
|
"""
|
|
if n < 0 or n + 1 > len(self._back_buf) + self._cursor:
|
|
raise LookBackError("peek_back distance out of range")
|
|
|
|
return self._back_buf[self._cursor - (n + 1)]
|
|
|
|
def peek_ahead(self, n: int = 1) -> T:
|
|
"""
|
|
Look `n` items ahead (n=1 == next item) without moving the cursor.
|
|
Fills the ahead buffer if necessary.
|
|
"""
|
|
if n < 1:
|
|
raise LookAheadError("peek_ahead distance must be 1 or more")
|
|
elif n > self._lookahead:
|
|
raise LookAheadError("peek_ahead distance exceeds lookahead limit")
|
|
|
|
# Fill ahead buffer if we don't have enough items
|
|
while len(self._ahead_buf) < n:
|
|
try:
|
|
item = next(self._source)
|
|
self._ahead_buf.append(item)
|
|
|
|
except StopIteration as err:
|
|
raise LookAheadError("peek_ahead: not enough items in source") from err
|
|
|
|
return self._ahead_buf[n - 1]
|
|
|
|
def history(self) -> list[T]:
|
|
"""
|
|
Return a copy of the buffered history (most recent last).
|
|
The list length ≤ `history` argument passed at construction.
|
|
"""
|
|
if self._cursor == 0:
|
|
return list(self._back_buf)
|
|
|
|
# When cursor<0, slice so the order remains chronological
|
|
return list(self._back_buf)[: self._cursor or None]
|
|
|
|
def can_peek_back(self, steps: int = 1) -> bool:
|
|
"""
|
|
Check if we can go back `steps` items without raising an IndexError.
|
|
"""
|
|
return steps <= len(self._back_buf) + self._cursor
|
|
|
|
def can_peek_ahead(self, steps: int = 1) -> bool:
|
|
"""
|
|
Check if we can peek ahead `steps` items.
|
|
This may involve trying to fill the ahead buffer.
|
|
"""
|
|
if self._lookahead > 0 and steps > self._lookahead:
|
|
return False
|
|
|
|
# Try to fill ahead buffer to check if we can peek that far
|
|
try:
|
|
while len(self._ahead_buf) < steps:
|
|
if self._lookahead > 0 and len(self._ahead_buf) >= self._lookahead:
|
|
return False
|
|
item = next(self._source)
|
|
self._ahead_buf.append(item)
|
|
return True
|
|
except StopIteration:
|
|
return False
|
|
|
|
|
|
class StreamingLeRobotDataset(torch.utils.data.IterableDataset):
|
|
"""LeRobotDataset with streaming capabilities.
|
|
|
|
This class extends LeRobotDataset to add streaming functionality, allowing data to be streamed
|
|
rather than loaded entirely into memory. This is especially useful for large datasets that may
|
|
not fit in memory or when you want to quickly explore a dataset without downloading it completely.
|
|
|
|
The key innovation is using a Backtrackable iterator that maintains a bounded buffer of recent
|
|
items, allowing us to access previous frames for delta timestamps without loading the entire
|
|
dataset into memory.
|
|
|
|
Example:
|
|
Basic usage:
|
|
```python
|
|
from lerobot.common.datasets.streaming_dataset import StreamingLeRobotDataset
|
|
|
|
# Create a streaming dataset with delta timestamps
|
|
delta_timestamps = {
|
|
"observation.image": [-1.0, -0.5, 0.0], # 1 sec ago, 0.5 sec ago, current
|
|
"action": [0.0, 0.1, 0.2], # current, 0.1 sec future, 0.2 sec future
|
|
}
|
|
|
|
dataset = StreamingLeRobotDataset(
|
|
repo_id="your-dataset-repo-id",
|
|
delta_timestamps=delta_timestamps,
|
|
streaming=True,
|
|
buffer_size=1000,
|
|
)
|
|
|
|
# Iterate over the dataset
|
|
for i, item in enumerate(dataset):
|
|
print(f"Sample {i}: Episode {item['episode_index']} Frame {item['frame_index']}")
|
|
# item will contain stacked frames according to delta_timestamps
|
|
if i >= 10:
|
|
break
|
|
```
|
|
"""
|
|
|
|
def __init__(
|
|
self,
|
|
repo_id: str,
|
|
root: str | Path | None = None,
|
|
episodes: list[int] | None = None,
|
|
image_transforms: Callable | None = None,
|
|
delta_timestamps: dict[list[float]] | None = None,
|
|
tolerance_s: float = 1e-4,
|
|
revision: str | None = None,
|
|
force_cache_sync: bool = False,
|
|
streaming: bool = True,
|
|
buffer_size: int = 1000,
|
|
max_num_shards: int = 16,
|
|
seed: int = 42,
|
|
rng: np.random.Generator | None = None,
|
|
shuffle: bool = True,
|
|
):
|
|
"""Initialize a StreamingLeRobotDataset.
|
|
|
|
Args:
|
|
repo_id (str): This is the repo id that will be used to fetch the dataset.
|
|
root (Path | None, optional): Local directory to use for downloading/writing files.
|
|
episodes (list[int] | None, optional): If specified, this will only load episodes specified by
|
|
their episode_index in this list.
|
|
image_transforms (Callable | None, optional): Transform to apply to image data.
|
|
tolerance_s (float, optional): Tolerance in seconds for timestamp matching.
|
|
revision (str, optional): Git revision id (branch name, tag, or commit hash).
|
|
force_cache_sync (bool, optional): Flag to sync and refresh local files first.
|
|
streaming (bool, optional): Whether to stream the dataset or load it all. Defaults to True.
|
|
buffer_size (int, optional): Buffer size for shuffling when streaming. Defaults to 1000.
|
|
max_num_shards (int, optional): Number of shards to re-shard the input dataset into. Defaults to 16.
|
|
seed (int, optional): Reproducibility random seed.
|
|
rng (np.random.Generator | None, optional): Random number generator.
|
|
shuffle (bool, optional): Whether to shuffle the dataset across exhaustions. Defaults to True.
|
|
"""
|
|
super().__init__()
|
|
self.repo_id = repo_id
|
|
self.root = Path(root) if root else HF_LEROBOT_HOME / repo_id
|
|
self.streaming_from_local = root is not None
|
|
|
|
self.image_transforms = image_transforms
|
|
self.episodes = episodes
|
|
self.tolerance_s = tolerance_s
|
|
self.revision = revision if revision else CODEBASE_VERSION
|
|
self.seed = seed
|
|
self.rng = rng if rng is not None else np.random.default_rng(seed)
|
|
self.shuffle = shuffle
|
|
|
|
self.streaming = streaming
|
|
self.buffer_size = buffer_size
|
|
|
|
# We cache the video decoders to avoid re-initializing them at each frame (avoiding a ~10x slowdown)
|
|
self.video_decoder_cache = None
|
|
|
|
self.root.mkdir(exist_ok=True, parents=True)
|
|
|
|
# Load metadata
|
|
self.meta = LeRobotDatasetMetadata(
|
|
self.repo_id, self.root, self.revision, force_cache_sync=force_cache_sync
|
|
)
|
|
# Check version
|
|
check_version_compatibility(self.repo_id, self.meta._version, CODEBASE_VERSION)
|
|
|
|
self.delta_timestamps = None
|
|
self.delta_indices = None
|
|
|
|
if delta_timestamps is not None:
|
|
self._validate_delta_timestamp_keys(delta_timestamps) # raises ValueError if invalid
|
|
self.delta_timestamps = delta_timestamps
|
|
self.delta_indices = get_delta_indices(self.delta_timestamps, self.fps)
|
|
|
|
self.hf_dataset: datasets.IterableDataset = load_dataset(
|
|
self.repo_id if not self.streaming_from_local else str(self.root),
|
|
split="train",
|
|
streaming=self.streaming,
|
|
data_files="data/*/*.parquet",
|
|
revision=self.revision,
|
|
)
|
|
|
|
self.num_shards = min(self.hf_dataset.num_shards, max_num_shards)
|
|
|
|
@property
|
|
def num_frames(self):
|
|
return self.meta.total_frames
|
|
|
|
@property
|
|
def num_episodes(self):
|
|
return self.meta.total_episodes
|
|
|
|
@property
|
|
def fps(self):
|
|
return self.meta.fps
|
|
|
|
@staticmethod
|
|
def _iter_random_indices(
|
|
rng: np.random.Generator, buffer_size: int, random_batch_size=100
|
|
) -> Iterator[int]:
|
|
while True:
|
|
yield from (int(i) for i in rng.integers(0, buffer_size, size=random_batch_size))
|
|
|
|
@staticmethod
|
|
def _infinite_generator_over_elements(rng: np.random.Generator, elements: list[int]) -> Iterator[int]:
|
|
while True:
|
|
yield rng.choice(elements)
|
|
|
|
# TODO(fracapuano): Implement multi-threaded prefetching to accelerate data loading.
|
|
# The current sequential iteration is a bottleneck. A producer-consumer pattern
|
|
# could be used with a ThreadPoolExecutor to run `make_frame` (especially video decoding)
|
|
# in parallel, feeding a queue from which this iterator will yield processed items.
|
|
def __iter__(self) -> Iterator[dict[str, torch.Tensor]]:
|
|
if self.video_decoder_cache is None:
|
|
self.video_decoder_cache = VideoDecoderCache()
|
|
|
|
# keep the same seed across exhaustions if shuffle is False, otherwise shuffle data across exhaustions
|
|
rng = np.random.default_rng(self.seed) if not self.shuffle else self.rng
|
|
|
|
buffer_indices_generator = self._iter_random_indices(rng, self.buffer_size)
|
|
|
|
idx_to_backtrack_dataset = {
|
|
idx: self._make_backtrackable_dataset(safe_shard(self.hf_dataset, idx, self.num_shards))
|
|
for idx in range(self.num_shards)
|
|
}
|
|
|
|
# This buffer is populated while iterating on the dataset's shards
|
|
# the logic is to add 2 levels of randomness:
|
|
# (1) sample one shard at random from the ones available, and
|
|
# (2) sample one frame from the shard sampled at (1)
|
|
frames_buffer = []
|
|
while available_shards := list(idx_to_backtrack_dataset.keys()):
|
|
shard_key = next(self._infinite_generator_over_elements(rng, available_shards))
|
|
backtrack_dataset = idx_to_backtrack_dataset[shard_key] # selects which shard to iterate on
|
|
|
|
try:
|
|
for frame in self.make_frame(backtrack_dataset):
|
|
if len(frames_buffer) == self.buffer_size:
|
|
i = next(buffer_indices_generator) # samples a element from the buffer
|
|
yield frames_buffer[i]
|
|
frames_buffer[i] = frame
|
|
else:
|
|
frames_buffer.append(frame)
|
|
break # random shard sampled, switch shard
|
|
except (
|
|
RuntimeError,
|
|
StopIteration,
|
|
): # NOTE: StopIteration inside a generator throws a RuntimeError since python 3.7
|
|
del idx_to_backtrack_dataset[shard_key] # Remove exhausted shard, onto another shard
|
|
|
|
# Once shards are all exhausted, shuffle the buffer and yield the remaining frames
|
|
rng.shuffle(frames_buffer)
|
|
yield from frames_buffer
|
|
|
|
def _get_window_steps(
|
|
self, delta_timestamps: dict[str, list[float]] | None = None, dynamic_bounds: bool = False
|
|
) -> tuple[int, int]:
|
|
if delta_timestamps is None:
|
|
return 1, 1
|
|
|
|
if not dynamic_bounds:
|
|
# Fix the windows
|
|
lookback = LOOKBACK_BACKTRACKTABLE
|
|
lookahead = LOOKAHEAD_BACKTRACKTABLE
|
|
else:
|
|
# Dynamically adjust the windows based on the given delta_timesteps
|
|
all_timestamps = sum(delta_timestamps.values(), [])
|
|
lookback = min(all_timestamps) * self.fps
|
|
lookahead = max(all_timestamps) * self.fps
|
|
|
|
# When lookback is >=0 it means no negative timesteps have been provided
|
|
lookback = 0 if lookback >= 0 else (lookback * -1)
|
|
|
|
return lookback, lookahead
|
|
|
|
def _make_backtrackable_dataset(self, dataset: datasets.IterableDataset) -> Backtrackable:
|
|
lookback, lookahead = self._get_window_steps(self.delta_timestamps)
|
|
return Backtrackable(dataset, history=lookback, lookahead=lookahead)
|
|
|
|
def _make_timestamps_from_indices(
|
|
self, start_ts: float, indices: dict[str, list[int]] | None = None
|
|
) -> dict[str, list[float]]:
|
|
if indices is not None:
|
|
return {
|
|
key: (
|
|
start_ts + torch.tensor(indices[key]) / self.fps
|
|
).tolist() # NOTE: why not delta_timestamps directly?
|
|
for key in self.delta_timestamps
|
|
}
|
|
else:
|
|
return dict.fromkeys(self.meta.video_keys, [start_ts])
|
|
|
|
def _make_padding_camera_frame(self, camera_key: str):
|
|
"""Variable-shape padding frame for given camera keys, given in (H, W, C)"""
|
|
return torch.zeros(self.meta.info["features"][camera_key]["shape"]).permute(-1, 0, 1)
|
|
|
|
def _get_video_frame_padding_mask(
|
|
self,
|
|
video_frames: dict[str, torch.Tensor],
|
|
query_timestamps: dict[str, list[float]],
|
|
original_timestamps: dict[str, list[float]],
|
|
) -> dict[str, torch.BoolTensor]:
|
|
padding_mask = {}
|
|
|
|
for video_key, timestamps in original_timestamps.items():
|
|
if video_key not in video_frames:
|
|
continue # only padding on video keys that are available
|
|
frames = []
|
|
mask = []
|
|
padding_frame = self._make_padding_camera_frame(video_key)
|
|
for ts in timestamps:
|
|
if is_float_in_list(ts, query_timestamps[video_key]):
|
|
idx = find_float_index(ts, query_timestamps[video_key])
|
|
frames.append(video_frames[video_key][idx, :])
|
|
mask.append(False)
|
|
else:
|
|
frames.append(padding_frame)
|
|
mask.append(True)
|
|
|
|
padding_mask[f"{video_key}_is_pad"] = torch.BoolTensor(mask)
|
|
|
|
return padding_mask
|
|
|
|
def make_frame(self, dataset_iterator: Backtrackable) -> Generator:
|
|
"""Makes a frame starting from a dataset iterator"""
|
|
item = next(dataset_iterator)
|
|
item = item_to_torch(item)
|
|
|
|
updates = [] # list of "updates" to apply to the item retrieved from hf_dataset (w/o camera features)
|
|
|
|
# Get episode index from the item
|
|
ep_idx = item["episode_index"]
|
|
|
|
# "timestamp" restarts from 0 for each episode, whereas we need a global timestep within the single .mp4 file (given by index/fps)
|
|
current_ts = item["index"] / self.fps
|
|
|
|
episode_boundaries_ts = {
|
|
key: (
|
|
self.meta.episodes[ep_idx][f"videos/{key}/from_timestamp"],
|
|
self.meta.episodes[ep_idx][f"videos/{key}/to_timestamp"],
|
|
)
|
|
for key in self.meta.video_keys
|
|
}
|
|
|
|
# Apply delta querying logic if necessary
|
|
if self.delta_indices is not None:
|
|
query_result, padding = self._get_delta_frames(dataset_iterator, item)
|
|
updates.append(query_result)
|
|
updates.append(padding)
|
|
|
|
# Load video frames, when needed
|
|
if len(self.meta.video_keys) > 0:
|
|
original_timestamps = self._make_timestamps_from_indices(current_ts, self.delta_indices)
|
|
|
|
# Some timestamps might not result available considering the episode's boundaries
|
|
query_timestamps = self._get_query_timestamps(
|
|
current_ts, self.delta_indices, episode_boundaries_ts
|
|
)
|
|
video_frames = self._query_videos(query_timestamps, ep_idx)
|
|
|
|
if self.image_transforms is not None:
|
|
image_keys = self.meta.camera_keys
|
|
for cam in image_keys:
|
|
video_frames[cam] = self.image_transforms(video_frames[cam])
|
|
|
|
updates.append(video_frames)
|
|
|
|
if self.delta_indices is not None:
|
|
# We always return the same number of frames. Unavailable frames are padded.
|
|
padding_mask = self._get_video_frame_padding_mask(
|
|
video_frames, query_timestamps, original_timestamps
|
|
)
|
|
updates.append(padding_mask)
|
|
|
|
result = item.copy()
|
|
for update in updates:
|
|
result.update(update)
|
|
|
|
result["task"] = self.meta.tasks.iloc[item["task_index"]].name
|
|
|
|
yield result
|
|
|
|
def _get_query_timestamps(
|
|
self,
|
|
current_ts: float,
|
|
query_indices: dict[str, list[int]] | None = None,
|
|
episode_boundaries_ts: dict[str, tuple[float, float]] | None = None,
|
|
) -> dict[str, list[float]]:
|
|
query_timestamps = {}
|
|
keys_to_timestamps = self._make_timestamps_from_indices(current_ts, query_indices)
|
|
for key in self.meta.video_keys:
|
|
if query_indices is not None and key in query_indices:
|
|
timestamps = keys_to_timestamps[key]
|
|
# Clamp out timesteps outside of episode boundaries
|
|
query_timestamps[key] = torch.clamp(
|
|
torch.tensor(timestamps), *episode_boundaries_ts[key]
|
|
).tolist()
|
|
|
|
else:
|
|
query_timestamps[key] = [current_ts]
|
|
|
|
return query_timestamps
|
|
|
|
def _query_videos(self, query_timestamps: dict[str, list[float]], ep_idx: int) -> dict:
|
|
"""Note: When using data workers (e.g. DataLoader with num_workers>0), do not call this function
|
|
in the main process (e.g. by using a second Dataloader with num_workers=0). It will result in a
|
|
Segmentation Fault. This probably happens because a memory reference to the video loader is created in
|
|
the main process and a subprocess fails to access it.
|
|
"""
|
|
|
|
item = {}
|
|
for video_key, query_ts in query_timestamps.items():
|
|
root = self.meta.url_root if self.streaming and not self.streaming_from_local else self.root
|
|
video_path = f"{root}/{self.meta.get_video_file_path(ep_idx, video_key)}"
|
|
frames = decode_video_frames_torchcodec(
|
|
video_path, query_ts, self.tolerance_s, decoder_cache=self.video_decoder_cache
|
|
)
|
|
|
|
item[video_key] = frames.squeeze(0) if len(query_ts) == 1 else frames
|
|
|
|
return item
|
|
|
|
def _get_delta_frames(self, dataset_iterator: Backtrackable, current_item: dict):
|
|
# TODO(fracapuano): Modularize this function, refactor the code
|
|
"""Get frames with delta offsets using the backtrackable iterator.
|
|
|
|
Args:
|
|
current_item (dict): Current item from the iterator.
|
|
ep_idx (int): Episode index.
|
|
|
|
Returns:
|
|
tuple: (query_result, padding) - frames at delta offsets and padding info.
|
|
"""
|
|
current_episode_idx = current_item["episode_index"]
|
|
|
|
# Prepare results
|
|
query_result = {}
|
|
padding = {}
|
|
|
|
for key, delta_indices in self.delta_indices.items():
|
|
if key in self.meta.video_keys:
|
|
continue # visual frames are decoded separately
|
|
|
|
target_frames = []
|
|
is_pad = []
|
|
|
|
# Create a results dictionary to store frames in processing order, then reconstruct original order for stacking
|
|
delta_results = {}
|
|
|
|
# Separate and sort deltas by difficulty (easier operations first)
|
|
negative_deltas = sorted([d for d in delta_indices if d < 0], reverse=True) # [-1, -2, -3, ...]
|
|
positive_deltas = sorted([d for d in delta_indices if d > 0]) # [1, 2, 3, ...]
|
|
zero_deltas = [d for d in delta_indices if d == 0]
|
|
|
|
# Process zero deltas (current frame)
|
|
for delta in zero_deltas:
|
|
delta_results[delta] = (
|
|
current_item[key],
|
|
False,
|
|
)
|
|
|
|
# Process negative deltas in order of increasing difficulty
|
|
lookback_failed = False
|
|
|
|
last_successful_frame = current_item[key]
|
|
|
|
for delta in negative_deltas:
|
|
if lookback_failed:
|
|
delta_results[delta] = (last_successful_frame, True)
|
|
continue
|
|
|
|
try:
|
|
steps_back = abs(delta)
|
|
if dataset_iterator.can_peek_back(steps_back):
|
|
past_item = dataset_iterator.peek_back(steps_back)
|
|
past_item = item_to_torch(past_item)
|
|
|
|
if past_item["episode_index"] == current_episode_idx:
|
|
delta_results[delta] = (past_item[key], False)
|
|
last_successful_frame = past_item[key]
|
|
|
|
else:
|
|
raise LookBackError("Retrieved frame is from different episode!")
|
|
else:
|
|
raise LookBackError("Cannot go back further than the history buffer!")
|
|
|
|
except LookBackError:
|
|
delta_results[delta] = (last_successful_frame, True)
|
|
lookback_failed = True # All subsequent negative deltas will also fail
|
|
|
|
# Process positive deltas in order of increasing difficulty
|
|
lookahead_failed = False
|
|
last_successful_frame = current_item[key]
|
|
|
|
for delta in positive_deltas:
|
|
if lookahead_failed:
|
|
delta_results[delta] = (last_successful_frame, True)
|
|
continue
|
|
|
|
try:
|
|
if dataset_iterator.can_peek_ahead(delta):
|
|
future_item = dataset_iterator.peek_ahead(delta)
|
|
future_item = item_to_torch(future_item)
|
|
|
|
if future_item["episode_index"] == current_episode_idx:
|
|
delta_results[delta] = (future_item[key], False)
|
|
last_successful_frame = future_item[key]
|
|
|
|
else:
|
|
raise LookAheadError("Retrieved frame is from different episode!")
|
|
else:
|
|
raise LookAheadError("Cannot go ahead further than the lookahead buffer!")
|
|
|
|
except LookAheadError:
|
|
delta_results[delta] = (last_successful_frame, True)
|
|
lookahead_failed = True # All subsequent positive deltas will also fail
|
|
|
|
# Reconstruct original order for stacking
|
|
for delta in delta_indices:
|
|
frame, is_padded = delta_results[delta]
|
|
|
|
# add batch dimension for stacking
|
|
target_frames.append(frame) # frame.unsqueeze(0))
|
|
is_pad.append(is_padded)
|
|
|
|
# Stack frames and add to results
|
|
if target_frames:
|
|
query_result[key] = torch.stack(target_frames)
|
|
padding[f"{key}_is_pad"] = torch.BoolTensor(is_pad)
|
|
|
|
return query_result, padding
|
|
|
|
def _validate_delta_timestamp_keys(self, delta_timestamps: dict[list[float]]) -> None:
|
|
"""
|
|
Validate that all keys in delta_timestamps correspond to actual features in the dataset.
|
|
|
|
Raises:
|
|
ValueError: If any delta timestamp key doesn't correspond to a dataset feature.
|
|
"""
|
|
if delta_timestamps is None:
|
|
return
|
|
|
|
# Get all available feature keys from the dataset metadata
|
|
available_features = set(self.meta.features.keys())
|
|
|
|
# Get all keys from delta_timestamps
|
|
delta_keys = set(delta_timestamps.keys())
|
|
|
|
# Find any keys that don't correspond to features
|
|
invalid_keys = delta_keys - available_features
|
|
|
|
if invalid_keys:
|
|
raise ValueError(
|
|
f"The following delta_timestamp keys do not correspond to dataset features: {invalid_keys}. "
|
|
f"Available features are: {sorted(available_features)}"
|
|
)
|