mirror of
https://github.com/huggingface/lerobot.git
synced 2026-05-11 14:49:43 +00:00
Compare commits
2 Commits
| Author | SHA1 | Date | |
|---|---|---|---|
| e3539cb78e | |||
| 9014f9a7c5 |
+78
-55
@@ -15,7 +15,8 @@
|
||||
# limitations under the License.
|
||||
|
||||
import functools
|
||||
from collections.abc import Callable, Sequence
|
||||
import itertools
|
||||
from collections.abc import Callable, Generator, Sequence
|
||||
from contextlib import suppress
|
||||
from typing import TypedDict
|
||||
|
||||
@@ -29,13 +30,20 @@ from lerobot.utils.transition import Transition
|
||||
|
||||
|
||||
class BatchTransition(TypedDict):
|
||||
"""Batch transition for single-step RL algorithms.
|
||||
|
||||
Uses Gymnasium terminology:
|
||||
- terminated: True termination due to task success/failure
|
||||
- truncated: Termination due to time limit or other external factors
|
||||
"""
|
||||
|
||||
state: dict[str, torch.Tensor]
|
||||
action: torch.Tensor
|
||||
reward: torch.Tensor
|
||||
next_state: dict[str, torch.Tensor]
|
||||
done: torch.Tensor
|
||||
truncated: torch.Tensor
|
||||
complementary_info: dict[str, torch.Tensor | float | int] | None = None
|
||||
terminated: torch.Tensor # True termination due to task success/failure
|
||||
truncated: torch.Tensor # Termination due to time limit
|
||||
complementary_info: dict[str, torch.Tensor] | None
|
||||
|
||||
|
||||
def random_crop_vectorized(images: torch.Tensor, output_size: tuple) -> torch.Tensor:
|
||||
@@ -78,6 +86,8 @@ def random_shift(images: torch.Tensor, pad: int = 4):
|
||||
|
||||
|
||||
class ReplayBuffer:
|
||||
"""Replay buffer for storing transitions used in RL training (e.g., SAC)."""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
capacity: int,
|
||||
@@ -133,7 +143,7 @@ class ReplayBuffer:
|
||||
self,
|
||||
state: dict[str, torch.Tensor],
|
||||
action: torch.Tensor,
|
||||
complementary_info: dict[str, torch.Tensor] | None = None,
|
||||
complementary_info: dict[str, torch.Tensor | float | int] | None = None,
|
||||
):
|
||||
"""Initialize the storage tensors based on the first transition."""
|
||||
# Determine shapes from the first transition
|
||||
@@ -159,8 +169,8 @@ class ReplayBuffer:
|
||||
# Just create a reference to states for consistent API
|
||||
self.next_states = self.states # Just a reference for API consistency
|
||||
|
||||
self.dones = torch.empty((self.capacity,), dtype=torch.bool, device=self.storage_device)
|
||||
self.truncateds = torch.empty((self.capacity,), dtype=torch.bool, device=self.storage_device)
|
||||
self.terminated = torch.empty((self.capacity,), dtype=torch.bool, device=self.storage_device)
|
||||
self.truncated = torch.empty((self.capacity,), dtype=torch.bool, device=self.storage_device)
|
||||
|
||||
# Initialize storage for complementary_info
|
||||
self.has_complementary_info = complementary_info is not None
|
||||
@@ -195,7 +205,7 @@ class ReplayBuffer:
|
||||
next_state: dict[str, torch.Tensor],
|
||||
done: bool,
|
||||
truncated: bool,
|
||||
complementary_info: dict[str, torch.Tensor] | None = None,
|
||||
complementary_info: dict[str, torch.Tensor | float | int] | None = None,
|
||||
):
|
||||
"""Saves a transition, ensuring tensors are stored on the designated storage device."""
|
||||
# Initialize storage if this is the first transition
|
||||
@@ -212,8 +222,8 @@ class ReplayBuffer:
|
||||
|
||||
self.actions[self.position].copy_(action.squeeze(dim=0))
|
||||
self.rewards[self.position] = reward
|
||||
self.dones[self.position] = done
|
||||
self.truncateds[self.position] = truncated
|
||||
self.terminated[self.position] = done
|
||||
self.truncated[self.position] = truncated
|
||||
|
||||
# Handle complementary_info if provided and storage is initialized
|
||||
if complementary_info is not None and self.has_complementary_info:
|
||||
@@ -283,8 +293,8 @@ class ReplayBuffer:
|
||||
# Sample other tensors
|
||||
batch_actions = self.actions[idx].to(self.device)
|
||||
batch_rewards = self.rewards[idx].to(self.device)
|
||||
batch_dones = self.dones[idx].to(self.device).float()
|
||||
batch_truncateds = self.truncateds[idx].to(self.device).float()
|
||||
batch_terminated = self.terminated[idx].to(self.device).float()
|
||||
batch_truncated = self.truncated[idx].to(self.device).float()
|
||||
|
||||
# Sample complementary_info if available
|
||||
batch_complementary_info = None
|
||||
@@ -298,8 +308,8 @@ class ReplayBuffer:
|
||||
action=batch_actions,
|
||||
reward=batch_rewards,
|
||||
next_state=batch_next_state,
|
||||
done=batch_dones,
|
||||
truncated=batch_truncateds,
|
||||
terminated=batch_terminated,
|
||||
truncated=batch_truncated,
|
||||
complementary_info=batch_complementary_info,
|
||||
)
|
||||
|
||||
@@ -431,7 +441,6 @@ class ReplayBuffer:
|
||||
device (str): The device for sampling tensors. Defaults to "cuda:0".
|
||||
state_keys (Sequence[str] | None): The list of keys that appear in `state` and `next_state`.
|
||||
capacity (int | None): Buffer capacity. If None, uses dataset length.
|
||||
action_mask (Sequence[int] | None): Indices of action dimensions to keep.
|
||||
image_augmentation_function (Callable | None): Function for image augmentation.
|
||||
If None, uses default random shift with pad=4.
|
||||
use_drq (bool): Whether to use DrQ image augmentation when sampling.
|
||||
@@ -460,12 +469,16 @@ class ReplayBuffer:
|
||||
optimize_memory=optimize_memory,
|
||||
)
|
||||
|
||||
# Convert dataset to transitions
|
||||
list_transition = cls._lerobotdataset_to_transitions(dataset=lerobot_dataset, state_keys=state_keys)
|
||||
# Convert dataset to transitions generator
|
||||
transitions_generator = cls._lerobotdataset_to_transitions(
|
||||
dataset=lerobot_dataset, state_keys=state_keys
|
||||
)
|
||||
|
||||
# Get first transition to initialize storage
|
||||
first_transition = next(transitions_generator, None)
|
||||
|
||||
# Initialize the buffer with the first transition to set up storage tensors
|
||||
if list_transition:
|
||||
first_transition = list_transition[0]
|
||||
if first_transition is not None:
|
||||
first_state = {k: v.to(device) for k, v in first_transition["state"].items()}
|
||||
first_action = first_transition[ACTION].to(device)
|
||||
|
||||
@@ -483,26 +496,28 @@ class ReplayBuffer:
|
||||
state=first_state, action=first_action, complementary_info=first_complementary_info
|
||||
)
|
||||
|
||||
# Fill the buffer with all transitions
|
||||
for data in list_transition:
|
||||
for k, v in data.items():
|
||||
if isinstance(v, dict):
|
||||
for key, tensor in v.items():
|
||||
v[key] = tensor.to(storage_device)
|
||||
elif isinstance(v, torch.Tensor):
|
||||
data[k] = v.to(storage_device)
|
||||
# Fill the buffer with all transitions (first + remaining)
|
||||
if first_transition is not None:
|
||||
for data in itertools.chain([first_transition], transitions_generator):
|
||||
for k, v in data.items():
|
||||
if isinstance(v, dict):
|
||||
for key, tensor in v.items():
|
||||
if isinstance(tensor, torch.Tensor):
|
||||
v[key] = tensor.to(storage_device)
|
||||
elif isinstance(v, torch.Tensor):
|
||||
data[k] = v.to(storage_device)
|
||||
|
||||
action = data[ACTION]
|
||||
action = data[ACTION]
|
||||
|
||||
replay_buffer.add(
|
||||
state=data["state"],
|
||||
action=action,
|
||||
reward=data["reward"],
|
||||
next_state=data["next_state"],
|
||||
done=data["done"],
|
||||
truncated=False, # NOTE: Truncation are not supported yet in lerobot dataset
|
||||
complementary_info=data.get("complementary_info", None),
|
||||
)
|
||||
replay_buffer.add(
|
||||
state=data["state"],
|
||||
action=action,
|
||||
reward=data["reward"],
|
||||
next_state=data["next_state"],
|
||||
done=data["done"],
|
||||
truncated=data["truncated"],
|
||||
complementary_info=data.get("complementary_info"),
|
||||
)
|
||||
|
||||
return replay_buffer
|
||||
|
||||
@@ -576,10 +591,12 @@ class ReplayBuffer:
|
||||
for key in self.states:
|
||||
frame_dict[key] = self.states[key][actual_idx].cpu()
|
||||
|
||||
# Fill action, reward, done
|
||||
# Fill action, reward, done (done = terminated or truncated)
|
||||
frame_dict[ACTION] = self.actions[actual_idx].cpu()
|
||||
frame_dict[REWARD] = torch.tensor([self.rewards[actual_idx]], dtype=torch.float32).cpu()
|
||||
frame_dict[DONE] = torch.tensor([self.dones[actual_idx]], dtype=torch.bool).cpu()
|
||||
frame_dict[DONE] = torch.tensor(
|
||||
[self.terminated[actual_idx] or self.truncated[actual_idx]], dtype=torch.bool
|
||||
).cpu()
|
||||
frame_dict["task"] = task_name
|
||||
|
||||
# Add complementary_info if available
|
||||
@@ -599,7 +616,7 @@ class ReplayBuffer:
|
||||
lerobot_dataset.add_frame(frame_dict)
|
||||
|
||||
# If we reached an episode boundary, call save_episode, reset counters
|
||||
if self.dones[actual_idx] or self.truncateds[actual_idx]:
|
||||
if self.terminated[actual_idx] or self.truncated[actual_idx]:
|
||||
lerobot_dataset.save_episode()
|
||||
|
||||
# Save any remaining frames in the buffer
|
||||
@@ -615,9 +632,11 @@ class ReplayBuffer:
|
||||
def _lerobotdataset_to_transitions(
|
||||
dataset: LeRobotDataset,
|
||||
state_keys: Sequence[str] | None = None,
|
||||
) -> list[Transition]:
|
||||
) -> Generator[Transition, None, None]:
|
||||
"""
|
||||
Convert a LeRobotDataset into a list of RL (s, a, r, s', done) transitions.
|
||||
Convert a LeRobotDataset into a generator of RL (s, a, r, s', done) transitions.
|
||||
|
||||
Using a generator instead of a list is more memory efficient for large datasets.
|
||||
|
||||
Args:
|
||||
dataset (LeRobotDataset):
|
||||
@@ -637,14 +656,12 @@ class ReplayBuffer:
|
||||
["observation.state", "observation.environment_state"].
|
||||
If None, you must handle or define default keys.
|
||||
|
||||
Returns:
|
||||
transitions (List[Transition]):
|
||||
A list of Transition dictionaries with the same length as `dataset`.
|
||||
Yields:
|
||||
Transition: A transition dictionary.
|
||||
"""
|
||||
if state_keys is None:
|
||||
raise ValueError("State keys must be provided when converting LeRobotDataset to Transitions.")
|
||||
|
||||
transitions = []
|
||||
num_frames = len(dataset)
|
||||
|
||||
# Check if the dataset has "next.done" key
|
||||
@@ -687,8 +704,17 @@ class ReplayBuffer:
|
||||
if next_sample["episode_index"] != current_sample["episode_index"]:
|
||||
done = True
|
||||
|
||||
# TODO: (azouitine) Handle truncation (using the same value as done for now)
|
||||
truncated = done
|
||||
# Handle truncation separately from done
|
||||
# This is important if the dataset has truncations (e.g., time limits)
|
||||
truncated = False
|
||||
if not done:
|
||||
# If this is the last frame or if next frame is in a different episode, mark as truncated
|
||||
if i == num_frames - 1:
|
||||
truncated = True
|
||||
elif i < num_frames - 1:
|
||||
next_sample = dataset[i + 1]
|
||||
if next_sample["episode_index"] != current_sample["episode_index"]:
|
||||
truncated = True
|
||||
|
||||
# ----- 4) Next state -----
|
||||
# If not done and the next sample is in the same episode, we pull the next sample's state.
|
||||
@@ -716,7 +742,6 @@ class ReplayBuffer:
|
||||
if isinstance(val, torch.Tensor):
|
||||
complementary_info[clean_key] = val.unsqueeze(0) # Add batch dimension
|
||||
else:
|
||||
# TODO: (azouitine) Check if it's necessary to convert to tensor
|
||||
# For non-tensor values, use directly
|
||||
complementary_info[clean_key] = val
|
||||
|
||||
@@ -730,12 +755,10 @@ class ReplayBuffer:
|
||||
truncated=truncated,
|
||||
complementary_info=complementary_info,
|
||||
)
|
||||
transitions.append(transition)
|
||||
|
||||
return transitions
|
||||
yield transition
|
||||
|
||||
|
||||
# Utility function to guess shapes/dtypes from a tensor
|
||||
def guess_feature_info(t, name: str):
|
||||
"""
|
||||
Return a dictionary with the 'dtype' and 'shape' for a given tensor or scalar value.
|
||||
@@ -805,9 +828,9 @@ def concatenate_batch_transitions(
|
||||
for key in left_batch_transitions["next_state"]
|
||||
}
|
||||
|
||||
# Concatenate done and truncated fields
|
||||
left_batch_transitions["done"] = torch.cat(
|
||||
[left_batch_transitions["done"], right_batch_transition["done"]], dim=0
|
||||
# Concatenate terminated and truncated fields
|
||||
left_batch_transitions["terminated"] = torch.cat(
|
||||
[left_batch_transitions["terminated"], right_batch_transition["terminated"]], dim=0
|
||||
)
|
||||
left_batch_transitions["truncated"] = torch.cat(
|
||||
[left_batch_transitions["truncated"], right_batch_transition["truncated"]],
|
||||
|
||||
@@ -68,7 +68,7 @@ def create_dummy_transition() -> dict:
|
||||
OBS_STATE: torch.randn(
|
||||
10,
|
||||
),
|
||||
"done": torch.tensor(False),
|
||||
"terminated": torch.tensor(False),
|
||||
"truncated": torch.tensor(False),
|
||||
"complementary_info": {},
|
||||
}
|
||||
@@ -191,8 +191,8 @@ def test_add_transition(replay_buffer, dummy_state, dummy_action):
|
||||
"Action should be equal to the first transition."
|
||||
)
|
||||
assert replay_buffer.rewards[0] == 1.0, "Reward should be equal to the first transition."
|
||||
assert not replay_buffer.dones[0], "Done should be False for the first transition."
|
||||
assert not replay_buffer.truncateds[0], "Truncated should be False for the first transition."
|
||||
assert not replay_buffer.terminated[0], "Terminated should be False for the first transition."
|
||||
assert not replay_buffer.truncated[0], "Truncated should be False for the first transition."
|
||||
|
||||
for dim in state_dims():
|
||||
assert torch.equal(replay_buffer.states[dim][0], dummy_state[dim]), (
|
||||
@@ -232,8 +232,8 @@ def test_add_over_capacity():
|
||||
"Action should be equal to the last transition."
|
||||
)
|
||||
assert replay_buffer.rewards[0] == 1.0, "Reward should be equal to the last transition."
|
||||
assert replay_buffer.dones[0], "Done should be True for the first transition."
|
||||
assert replay_buffer.truncateds[0], "Truncated should be True for the first transition."
|
||||
assert replay_buffer.terminated[0], "Terminated should be True for the first transition."
|
||||
assert replay_buffer.truncated[0], "Truncated should be True for the first transition."
|
||||
|
||||
|
||||
def test_sample_from_empty_buffer(replay_buffer):
|
||||
@@ -250,7 +250,7 @@ def test_sample_with_1_transition(replay_buffer, dummy_state, next_dummy_state,
|
||||
action=dummy_action.clone(),
|
||||
reward=1.0,
|
||||
next_state=clone_state(next_dummy_state),
|
||||
done=False,
|
||||
terminated=False,
|
||||
truncated=False,
|
||||
)
|
||||
|
||||
@@ -289,7 +289,7 @@ def test_sample_with_batch_bigger_than_buffer_size(
|
||||
action=dummy_action,
|
||||
reward=1.0,
|
||||
next_state=next_dummy_state,
|
||||
done=False,
|
||||
terminated=False,
|
||||
truncated=False,
|
||||
)
|
||||
|
||||
@@ -383,7 +383,8 @@ def test_to_lerobot_dataset(tmp_path):
|
||||
elif feature == REWARD:
|
||||
assert torch.equal(value, buffer.rewards[i])
|
||||
elif feature == DONE:
|
||||
assert torch.equal(value, buffer.dones[i])
|
||||
# DONE in dataset is terminated OR truncated
|
||||
assert torch.equal(value, buffer.terminated[i] | buffer.truncated[i])
|
||||
elif feature == OBS_IMAGE:
|
||||
# Tensor -> numpy is not precise, so we have some diff there
|
||||
# TODO: Check and fix it
|
||||
@@ -427,12 +428,12 @@ def test_from_lerobot_dataset(tmp_path):
|
||||
reconverted_buffer.rewards[: len(replay_buffer)], replay_buffer.rewards[: len(replay_buffer)]
|
||||
), "Rewards from converted buffer should be equal to the original replay buffer."
|
||||
assert torch.equal(
|
||||
reconverted_buffer.dones[: len(replay_buffer)], replay_buffer.dones[: len(replay_buffer)]
|
||||
), "Dones from converted buffer should be equal to the original replay buffer."
|
||||
reconverted_buffer.terminated[: len(replay_buffer)], replay_buffer.terminated[: len(replay_buffer)]
|
||||
), "Terminated flags from converted buffer should be equal to the original replay buffer."
|
||||
|
||||
# Lerobot DS haven't supported truncateds yet
|
||||
expected_truncateds = torch.zeros(len(replay_buffer)).bool()
|
||||
assert torch.equal(reconverted_buffer.truncateds[: len(replay_buffer)], expected_truncateds), (
|
||||
# LeRobot DS hasn't supported truncated yet
|
||||
expected_truncated = torch.zeros(len(replay_buffer)).bool()
|
||||
assert torch.equal(reconverted_buffer.truncated[: len(replay_buffer)], expected_truncated), (
|
||||
"Truncateds from converted buffer should be equal False"
|
||||
)
|
||||
|
||||
@@ -498,7 +499,7 @@ def test_buffer_sample_alignment():
|
||||
action_val = batch[ACTION][i].item()
|
||||
reward_val = batch["reward"][i].item()
|
||||
next_state_sig = batch["next_state"]["state_value"][i].item()
|
||||
is_done = batch["done"][i].item() > 0.5
|
||||
is_terminated = batch["terminated"][i].item() > 0.5
|
||||
|
||||
# Verify relationships
|
||||
assert abs(action_val - 2.0 * state_sig) < 1e-4, (
|
||||
@@ -509,9 +510,9 @@ def test_buffer_sample_alignment():
|
||||
f"Reward {reward_val} should be 3x state signature {state_sig}"
|
||||
)
|
||||
|
||||
if is_done:
|
||||
if is_terminated:
|
||||
assert abs(next_state_sig - state_sig) < 1e-4, (
|
||||
f"For done states, next_state {next_state_sig} should equal state {state_sig}"
|
||||
f"For terminated states, next_state {next_state_sig} should equal state {state_sig}"
|
||||
)
|
||||
else:
|
||||
# Either it's the next sequential state (+0.01) or same state (for episode boundaries)
|
||||
|
||||
Reference in New Issue
Block a user