Compare commits

...

2 Commits

Author SHA1 Message Date
Michel Aractingi e3539cb78e revert initialization to empty 2025-12-17 16:36:59 +01:00
Michel Aractingi 9014f9a7c5 refactor(buffer): use Gymnasium terminology (terminated/truncated)
- Rename 'done' to 'terminated' for true task completion
- Use 'truncated' for time-limit termination
- Change torch.empty to torch.zeros for storage initialization
- Convert _lerobotdataset_to_transitions to generator for memory efficiency
- Add proper docstrings to BatchTransition and ReplayBuffer
- Update concatenate_batch_transitions to use new terminology
- Update tests to use new field names

This aligns ReplayBuffer with Gymnasium's termination semantics where:
- terminated: Episode ended due to task success/failure
- truncated: Episode ended due to time limit or external factors
2025-12-17 15:52:26 +01:00
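
The terminated/truncated split matters most when computing TD targets: bootstrapping should stop at a true termination but continue through a time-limit truncation, since the episode would otherwise have gone on. A minimal sketch of that masking (illustrative only; `td_target` and its inputs are hypothetical, not code from this diff):

```python
import torch

def td_target(reward, terminated, q_next, gamma=0.99):
    # Only `terminated` masks the bootstrap term; a truncated transition
    # still bootstraps from the estimated next-state value.
    return reward + gamma * (1.0 - terminated) * q_next

reward = torch.tensor([1.0, 1.0])
terminated = torch.tensor([1.0, 0.0])  # second transition was merely truncated
q_next = torch.tensor([5.0, 5.0])
print(td_target(reward, terminated, q_next))  # tensor([1.0000, 5.9500])
```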
2 changed files with 95 additions and 71 deletions
+78 -55
@@ -15,7 +15,8 @@
 # limitations under the License.
 import functools
-from collections.abc import Callable, Sequence
+import itertools
+from collections.abc import Callable, Generator, Sequence
 from contextlib import suppress
 from typing import TypedDict
@@ -29,13 +30,20 @@ from lerobot.utils.transition import Transition
 class BatchTransition(TypedDict):
+    """Batch transition for single-step RL algorithms.
+
+    Uses Gymnasium terminology:
+    - terminated: True termination due to task success/failure
+    - truncated: Termination due to time limit or other external factors
+    """
+
     state: dict[str, torch.Tensor]
     action: torch.Tensor
     reward: torch.Tensor
     next_state: dict[str, torch.Tensor]
-    done: torch.Tensor
-    truncated: torch.Tensor
-    complementary_info: dict[str, torch.Tensor | float | int] | None = None
+    terminated: torch.Tensor  # True termination due to task success/failure
+    truncated: torch.Tensor  # Termination due to time limit
+    complementary_info: dict[str, torch.Tensor] | None


 def random_crop_vectorized(images: torch.Tensor, output_size: tuple) -> torch.Tensor:
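
A note on the dropped `= None` default: TypedDict fields are type declarations for plain dicts, so per-field defaults are not supported (type checkers reject them), which presumably motivated removing it. A self-contained sketch of constructing such a batch (shapes and keys are made up for illustration):

```python
import torch
from typing import TypedDict

class BatchTransition(TypedDict):
    state: dict[str, torch.Tensor]
    action: torch.Tensor
    reward: torch.Tensor
    next_state: dict[str, torch.Tensor]
    terminated: torch.Tensor
    truncated: torch.Tensor
    complementary_info: dict[str, torch.Tensor] | None

batch: BatchTransition = {
    "state": {"observation.state": torch.zeros(4, 10)},
    "action": torch.zeros(4, 2),
    "reward": torch.zeros(4),
    "next_state": {"observation.state": torch.zeros(4, 10)},
    "terminated": torch.zeros(4, dtype=torch.bool),
    "truncated": torch.zeros(4, dtype=torch.bool),
    "complementary_info": None,  # no default on the field, so pass it explicitly
}
```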
@@ -78,6 +86,8 @@ def random_shift(images: torch.Tensor, pad: int = 4):
 class ReplayBuffer:
+    """Replay buffer for storing transitions used in RL training (e.g., SAC)."""
+
     def __init__(
         self,
         capacity: int,
@@ -133,7 +143,7 @@ class ReplayBuffer:
         self,
         state: dict[str, torch.Tensor],
         action: torch.Tensor,
-        complementary_info: dict[str, torch.Tensor] | None = None,
+        complementary_info: dict[str, torch.Tensor | float | int] | None = None,
     ):
         """Initialize the storage tensors based on the first transition."""
         # Determine shapes from the first transition
@@ -159,8 +169,8 @@ class ReplayBuffer:
             # Just create a reference to states for consistent API
             self.next_states = self.states  # Just a reference for API consistency

-        self.dones = torch.empty((self.capacity,), dtype=torch.bool, device=self.storage_device)
-        self.truncateds = torch.empty((self.capacity,), dtype=torch.bool, device=self.storage_device)
+        self.terminated = torch.empty((self.capacity,), dtype=torch.bool, device=self.storage_device)
+        self.truncated = torch.empty((self.capacity,), dtype=torch.bool, device=self.storage_device)

        # Initialize storage for complementary_info
        self.has_complementary_info = complementary_info is not None
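
Note the allocations still use `torch.empty`: the first commit switched them to `torch.zeros`, and the follow-up commit ("revert initialization to empty") restored `torch.empty`. That is safe as long as every slot is written by `add()` before it can be sampled, which the buffer's position/size bookkeeping appears to guarantee. A quick illustration of the difference (not code from this diff):

```python
import torch

flags_empty = torch.empty((5,), dtype=torch.bool)  # contents arbitrary until written
flags_zeros = torch.zeros((5,), dtype=torch.bool)  # guaranteed all False

print(flags_zeros.any().item())  # always False
# flags_empty may contain True values until each slot is explicitly assigned.
```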
@@ -195,7 +205,7 @@ class ReplayBuffer:
         next_state: dict[str, torch.Tensor],
         done: bool,
         truncated: bool,
-        complementary_info: dict[str, torch.Tensor] | None = None,
+        complementary_info: dict[str, torch.Tensor | float | int] | None = None,
     ):
         """Saves a transition, ensuring tensors are stored on the designated storage device."""
         # Initialize storage if this is the first transition
@@ -212,8 +222,8 @@ class ReplayBuffer:
         self.actions[self.position].copy_(action.squeeze(dim=0))
         self.rewards[self.position] = reward
-        self.dones[self.position] = done
-        self.truncateds[self.position] = truncated
+        self.terminated[self.position] = done
+        self.truncated[self.position] = truncated

         # Handle complementary_info if provided and storage is initialized
         if complementary_info is not None and self.has_complementary_info:
@@ -283,8 +293,8 @@ class ReplayBuffer:
         # Sample other tensors
         batch_actions = self.actions[idx].to(self.device)
         batch_rewards = self.rewards[idx].to(self.device)
-        batch_dones = self.dones[idx].to(self.device).float()
-        batch_truncateds = self.truncateds[idx].to(self.device).float()
+        batch_terminated = self.terminated[idx].to(self.device).float()
+        batch_truncated = self.truncated[idx].to(self.device).float()

         # Sample complementary_info if available
         batch_complementary_info = None
@@ -298,8 +308,8 @@ class ReplayBuffer:
             action=batch_actions,
             reward=batch_rewards,
             next_state=batch_next_state,
-            done=batch_dones,
-            truncated=batch_truncateds,
+            terminated=batch_terminated,
+            truncated=batch_truncated,
             complementary_info=batch_complementary_info,
         )
@@ -431,7 +441,6 @@ class ReplayBuffer:
             device (str): The device for sampling tensors. Defaults to "cuda:0".
             state_keys (Sequence[str] | None): The list of keys that appear in `state` and `next_state`.
             capacity (int | None): Buffer capacity. If None, uses dataset length.
-            action_mask (Sequence[int] | None): Indices of action dimensions to keep.
             image_augmentation_function (Callable | None): Function for image augmentation.
                 If None, uses default random shift with pad=4.
             use_drq (bool): Whether to use DrQ image augmentation when sampling.
@@ -460,12 +469,16 @@ class ReplayBuffer:
             optimize_memory=optimize_memory,
         )

-        # Convert dataset to transitions
-        list_transition = cls._lerobotdataset_to_transitions(dataset=lerobot_dataset, state_keys=state_keys)
+        # Convert dataset to transitions generator
+        transitions_generator = cls._lerobotdataset_to_transitions(
+            dataset=lerobot_dataset, state_keys=state_keys
+        )
+
+        # Get first transition to initialize storage
+        first_transition = next(transitions_generator, None)

         # Initialize the buffer with the first transition to set up storage tensors
-        if list_transition:
-            first_transition = list_transition[0]
+        if first_transition is not None:
             first_state = {k: v.to(device) for k, v in first_transition["state"].items()}
             first_action = first_transition[ACTION].to(device)
@@ -483,26 +496,28 @@ class ReplayBuffer:
                 state=first_state, action=first_action, complementary_info=first_complementary_info
             )

-        # Fill the buffer with all transitions
-        for data in list_transition:
-            for k, v in data.items():
-                if isinstance(v, dict):
-                    for key, tensor in v.items():
-                        v[key] = tensor.to(storage_device)
-                elif isinstance(v, torch.Tensor):
-                    data[k] = v.to(storage_device)
+        # Fill the buffer with all transitions (first + remaining)
+        if first_transition is not None:
+            for data in itertools.chain([first_transition], transitions_generator):
+                for k, v in data.items():
+                    if isinstance(v, dict):
+                        for key, tensor in v.items():
+                            if isinstance(tensor, torch.Tensor):
+                                v[key] = tensor.to(storage_device)
+                    elif isinstance(v, torch.Tensor):
+                        data[k] = v.to(storage_device)

-            action = data[ACTION]
+                action = data[ACTION]

-            replay_buffer.add(
-                state=data["state"],
-                action=action,
-                reward=data["reward"],
-                next_state=data["next_state"],
-                done=data["done"],
-                truncated=False,  # NOTE: Truncation are not supported yet in lerobot dataset
-                complementary_info=data.get("complementary_info", None),
-            )
+                replay_buffer.add(
+                    state=data["state"],
+                    action=action,
+                    reward=data["reward"],
+                    next_state=data["next_state"],
+                    done=data["done"],
+                    truncated=data["truncated"],
+                    complementary_info=data.get("complementary_info"),
+                )

         return replay_buffer
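
The conversion above peeks one transition off the generator to initialize storage, then replays it with `itertools.chain` so nothing is skipped. The pattern in isolation (hypothetical stand-in data, not LeRobot's API):

```python
import itertools

def make_transitions():
    # Stand-in for _lerobotdataset_to_transitions (made-up items).
    for i in range(3):
        yield {"step": i}

gen = make_transitions()
first = next(gen, None)  # peek; None signals an empty source

if first is not None:
    # ... use `first` to set up storage shapes/dtypes ...
    for item in itertools.chain([first], gen):
        print(item["step"])  # 0, 1, 2 -- the peeked item is not lost
```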
@@ -576,10 +591,12 @@ class ReplayBuffer:
             for key in self.states:
                 frame_dict[key] = self.states[key][actual_idx].cpu()

-            # Fill action, reward, done
+            # Fill action, reward, done (done = terminated or truncated)
             frame_dict[ACTION] = self.actions[actual_idx].cpu()
             frame_dict[REWARD] = torch.tensor([self.rewards[actual_idx]], dtype=torch.float32).cpu()
-            frame_dict[DONE] = torch.tensor([self.dones[actual_idx]], dtype=torch.bool).cpu()
+            frame_dict[DONE] = torch.tensor(
+                [self.terminated[actual_idx] or self.truncated[actual_idx]], dtype=torch.bool
+            ).cpu()
             frame_dict["task"] = task_name

             # Add complementary_info if available
@@ -599,7 +616,7 @@ class ReplayBuffer:
             lerobot_dataset.add_frame(frame_dict)

             # If we reached an episode boundary, call save_episode, reset counters
-            if self.dones[actual_idx] or self.truncateds[actual_idx]:
+            if self.terminated[actual_idx] or self.truncated[actual_idx]:
                 lerobot_dataset.save_episode()

         # Save any remaining frames in the buffer
@@ -615,9 +632,11 @@
     def _lerobotdataset_to_transitions(
         dataset: LeRobotDataset,
         state_keys: Sequence[str] | None = None,
-    ) -> list[Transition]:
+    ) -> Generator[Transition, None, None]:
         """
-        Convert a LeRobotDataset into a list of RL (s, a, r, s', done) transitions.
+        Convert a LeRobotDataset into a generator of RL (s, a, r, s', done) transitions.
+
+        Using a generator instead of a list is more memory efficient for large datasets.

         Args:
             dataset (LeRobotDataset):
@@ -637,14 +656,12 @@ class ReplayBuffer:
                 ["observation.state", "observation.environment_state"].
                 If None, you must handle or define default keys.

-        Returns:
-            transitions (List[Transition]):
-                A list of Transition dictionaries with the same length as `dataset`.
+        Yields:
+            Transition: A transition dictionary.
         """
         if state_keys is None:
             raise ValueError("State keys must be provided when converting LeRobotDataset to Transitions.")

-        transitions = []
         num_frames = len(dataset)

         # Check if the dataset has "next.done" key
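
On the new return annotation: `Generator[YieldType, SendType, ReturnType]` takes three parameters, and a plain yielding function sends nothing and returns `None`, hence `Generator[Transition, None, None]`. A minimal sketch (toy function, not from this diff):

```python
from collections.abc import Generator

def count_up(n: int) -> Generator[int, None, None]:
    for i in range(n):
        yield i  # items are produced lazily, one at a time

print(list(count_up(3)))  # [0, 1, 2]
```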
@@ -687,8 +704,17 @@ class ReplayBuffer:
             if next_sample["episode_index"] != current_sample["episode_index"]:
                 done = True

-            # TODO: (azouitine) Handle truncation (using the same value as done for now)
-            truncated = done
+            # Handle truncation separately from done
+            # This is important if the dataset has truncations (e.g., time limits)
+            truncated = False
+            if not done:
+                # If this is the last frame or if next frame is in a different episode, mark as truncated
+                if i == num_frames - 1:
+                    truncated = True
+                elif i < num_frames - 1:
+                    next_sample = dataset[i + 1]
+                    if next_sample["episode_index"] != current_sample["episode_index"]:
+                        truncated = True

             # ----- 4) Next state -----
             # If not done and the next sample is in the same episode, we pull the next sample's state.
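
The boundary rule above can be traced on a toy sequence of episode indices: a frame is marked truncated when the episode did not terminate but the data ends or switches episodes. A standalone sketch (hypothetical frames, not the dataset API):

```python
episode_index = [0, 0, 0, 1, 1]  # two episodes flattened into one dataset
num_frames = len(episode_index)

for i in range(num_frames):
    done = False  # assume no explicit termination signal in this toy data
    truncated = False
    if not done:
        if i == num_frames - 1:
            truncated = True  # final frame of the dataset
        elif episode_index[i + 1] != episode_index[i]:
            truncated = True  # next frame starts a new episode
    print(i, truncated)  # True at i=2 (episode change) and i=4 (dataset end)
```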
@@ -716,7 +742,6 @@ class ReplayBuffer:
                     if isinstance(val, torch.Tensor):
                         complementary_info[clean_key] = val.unsqueeze(0)  # Add batch dimension
                     else:
-                        # TODO: (azouitine) Check if it's necessary to convert to tensor
                         # For non-tensor values, use directly
                         complementary_info[clean_key] = val
@@ -730,12 +755,10 @@ class ReplayBuffer:
                 truncated=truncated,
                 complementary_info=complementary_info,
             )
-            transitions.append(transition)
-
-        return transitions
+            yield transition


 # Utility function to guess shapes/dtypes from a tensor
 def guess_feature_info(t, name: str):
     """
     Return a dictionary with the 'dtype' and 'shape' for a given tensor or scalar value.
@@ -805,9 +828,9 @@ def concatenate_batch_transitions(
         for key in left_batch_transitions["next_state"]
     }

-    # Concatenate done and truncated fields
-    left_batch_transitions["done"] = torch.cat(
-        [left_batch_transitions["done"], right_batch_transition["done"]], dim=0
+    # Concatenate terminated and truncated fields
+    left_batch_transitions["terminated"] = torch.cat(
+        [left_batch_transitions["terminated"], right_batch_transition["terminated"]], dim=0
     )
     left_batch_transitions["truncated"] = torch.cat(
         [left_batch_transitions["truncated"], right_batch_transition["truncated"]],
+17 -16
@@ -68,7 +68,7 @@ def create_dummy_transition() -> dict:
         OBS_STATE: torch.randn(
             10,
         ),
-        "done": torch.tensor(False),
+        "terminated": torch.tensor(False),
         "truncated": torch.tensor(False),
         "complementary_info": {},
     }
@@ -191,8 +191,8 @@ def test_add_transition(replay_buffer, dummy_state, dummy_action):
         "Action should be equal to the first transition."
     )
     assert replay_buffer.rewards[0] == 1.0, "Reward should be equal to the first transition."
-    assert not replay_buffer.dones[0], "Done should be False for the first transition."
-    assert not replay_buffer.truncateds[0], "Truncated should be False for the first transition."
+    assert not replay_buffer.terminated[0], "Terminated should be False for the first transition."
+    assert not replay_buffer.truncated[0], "Truncated should be False for the first transition."

     for dim in state_dims():
         assert torch.equal(replay_buffer.states[dim][0], dummy_state[dim]), (
@@ -232,8 +232,8 @@ def test_add_over_capacity():
         "Action should be equal to the last transition."
     )
     assert replay_buffer.rewards[0] == 1.0, "Reward should be equal to the last transition."
-    assert replay_buffer.dones[0], "Done should be True for the first transition."
-    assert replay_buffer.truncateds[0], "Truncated should be True for the first transition."
+    assert replay_buffer.terminated[0], "Terminated should be True for the first transition."
+    assert replay_buffer.truncated[0], "Truncated should be True for the first transition."


 def test_sample_from_empty_buffer(replay_buffer):
@@ -250,7 +250,7 @@ def test_sample_with_1_transition(replay_buffer, dummy_state, next_dummy_state,
         action=dummy_action.clone(),
         reward=1.0,
         next_state=clone_state(next_dummy_state),
-        done=False,
+        terminated=False,
         truncated=False,
     )
@@ -289,7 +289,7 @@ def test_sample_with_batch_bigger_than_buffer_size(
         action=dummy_action,
         reward=1.0,
         next_state=next_dummy_state,
-        done=False,
+        terminated=False,
         truncated=False,
     )
@@ -383,7 +383,8 @@ def test_to_lerobot_dataset(tmp_path):
         elif feature == REWARD:
             assert torch.equal(value, buffer.rewards[i])
         elif feature == DONE:
-            assert torch.equal(value, buffer.dones[i])
+            # DONE in dataset is terminated OR truncated
+            assert torch.equal(value, buffer.terminated[i] | buffer.truncated[i])
         elif feature == OBS_IMAGE:
             # Tensor -> numpy is not precise, so we have some diff there
             # TODO: Check and fix it
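
The `|` in the updated assertion is elementwise OR on bool tensors, so the dataset's single DONE flag is set wherever either source flag is; a quick check (illustrative only):

```python
import torch

terminated = torch.tensor([True, False, False])
truncated = torch.tensor([False, True, False])
print(terminated | truncated)  # tensor([ True,  True, False])
```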
@@ -427,12 +428,12 @@ def test_from_lerobot_dataset(tmp_path):
         reconverted_buffer.rewards[: len(replay_buffer)], replay_buffer.rewards[: len(replay_buffer)]
     ), "Rewards from converted buffer should be equal to the original replay buffer."
     assert torch.equal(
-        reconverted_buffer.dones[: len(replay_buffer)], replay_buffer.dones[: len(replay_buffer)]
-    ), "Dones from converted buffer should be equal to the original replay buffer."
+        reconverted_buffer.terminated[: len(replay_buffer)], replay_buffer.terminated[: len(replay_buffer)]
+    ), "Terminated flags from converted buffer should be equal to the original replay buffer."

-    # Lerobot DS haven't supported truncateds yet
-    expected_truncateds = torch.zeros(len(replay_buffer)).bool()
-    assert torch.equal(reconverted_buffer.truncateds[: len(replay_buffer)], expected_truncateds), (
+    # LeRobot DS hasn't supported truncated yet
+    expected_truncated = torch.zeros(len(replay_buffer)).bool()
+    assert torch.equal(reconverted_buffer.truncated[: len(replay_buffer)], expected_truncated), (
         "Truncateds from converted buffer should be equal False"
     )
@@ -498,7 +499,7 @@ def test_buffer_sample_alignment():
         action_val = batch[ACTION][i].item()
         reward_val = batch["reward"][i].item()
         next_state_sig = batch["next_state"]["state_value"][i].item()
-        is_done = batch["done"][i].item() > 0.5
+        is_terminated = batch["terminated"][i].item() > 0.5

         # Verify relationships
         assert abs(action_val - 2.0 * state_sig) < 1e-4, (
@@ -509,9 +510,9 @@
             f"Reward {reward_val} should be 3x state signature {state_sig}"
         )

-        if is_done:
+        if is_terminated:
             assert abs(next_state_sig - state_sig) < 1e-4, (
-                f"For done states, next_state {next_state_sig} should equal state {state_sig}"
+                f"For terminated states, next_state {next_state_sig} should equal state {state_sig}"
             )
         else:
             # Either it's the next sequential state (+0.01) or same state (for episode boundaries)