From 9014f9a7c571393f5fe7733b91a8928c76882f40 Mon Sep 17 00:00:00 2001 From: Michel Aractingi Date: Wed, 17 Dec 2025 15:52:26 +0100 Subject: [PATCH] refactor(buffer): use Gymnasium terminology (terminated/truncated) - Rename 'done' to 'terminated' for true task completion - Use 'truncated' for time-limit termination - Change torch.empty to torch.zeros for storage initialization - Convert _lerobotdataset_to_transitions to generator for memory efficiency - Add proper docstrings to BatchTransition and ReplayBuffer - Update concatenate_batch_transitions to use new terminology - Update tests to use new field names This aligns ReplayBuffer with Gymnasium's termination semantics where: - terminated: Episode ended due to task success/failure - truncated: Episode ended due to time limit or external factors --- src/lerobot/rl/buffer.py | 146 +++++++++++++++++------------- tests/utils/test_replay_buffer.py | 33 +++---- 2 files changed, 101 insertions(+), 78 deletions(-) diff --git a/src/lerobot/rl/buffer.py b/src/lerobot/rl/buffer.py index 81aa29c48..625f4bf6a 100644 --- a/src/lerobot/rl/buffer.py +++ b/src/lerobot/rl/buffer.py @@ -15,7 +15,8 @@ # limitations under the License. import functools -from collections.abc import Callable, Sequence +import itertools +from collections.abc import Callable, Generator, Sequence from contextlib import suppress from typing import TypedDict @@ -29,13 +30,20 @@ from lerobot.utils.transition import Transition class BatchTransition(TypedDict): + """Batch transition for single-step RL algorithms. 
+ + Uses Gymnasium terminology: + - terminated: True termination due to task success/failure + - truncated: Termination due to time limit or other external factors + """ + state: dict[str, torch.Tensor] action: torch.Tensor reward: torch.Tensor next_state: dict[str, torch.Tensor] - done: torch.Tensor - truncated: torch.Tensor - complementary_info: dict[str, torch.Tensor | float | int] | None = None + terminated: torch.Tensor # True termination due to task success/failure + truncated: torch.Tensor # Termination due to time limit + complementary_info: dict[str, torch.Tensor] | None def random_crop_vectorized(images: torch.Tensor, output_size: tuple) -> torch.Tensor: @@ -78,6 +86,8 @@ def random_shift(images: torch.Tensor, pad: int = 4): class ReplayBuffer: + """Replay buffer for storing transitions used in RL training (e.g., SAC).""" + def __init__( self, capacity: int, @@ -133,25 +143,24 @@ class ReplayBuffer: self, state: dict[str, torch.Tensor], action: torch.Tensor, - complementary_info: dict[str, torch.Tensor] | None = None, + complementary_info: dict[str, torch.Tensor | float | int] | None = None, ): """Initialize the storage tensors based on the first transition.""" # Determine shapes from the first transition state_shapes = {key: val.squeeze(0).shape for key, val in state.items()} action_shape = action.squeeze(0).shape - # Pre-allocate tensors for storage self.states = { - key: torch.empty((self.capacity, *shape), device=self.storage_device) + key: torch.zeros((self.capacity, *shape), device=self.storage_device) for key, shape in state_shapes.items() } - self.actions = torch.empty((self.capacity, *action_shape), device=self.storage_device) - self.rewards = torch.empty((self.capacity,), device=self.storage_device) + self.actions = torch.zeros((self.capacity, *action_shape), device=self.storage_device) + self.rewards = torch.zeros((self.capacity,), device=self.storage_device) if not self.optimize_memory: # Standard approach: store states and next_states 
separately self.next_states = { - key: torch.empty((self.capacity, *shape), device=self.storage_device) + key: torch.zeros((self.capacity, *shape), device=self.storage_device) for key, shape in state_shapes.items() } else: @@ -159,8 +168,8 @@ class ReplayBuffer: # Just create a reference to states for consistent API self.next_states = self.states # Just a reference for API consistency - self.dones = torch.empty((self.capacity,), dtype=torch.bool, device=self.storage_device) - self.truncateds = torch.empty((self.capacity,), dtype=torch.bool, device=self.storage_device) + self.terminated = torch.zeros((self.capacity,), dtype=torch.bool, device=self.storage_device) + self.truncated = torch.zeros((self.capacity,), dtype=torch.bool, device=self.storage_device) # Initialize storage for complementary_info self.has_complementary_info = complementary_info is not None @@ -173,12 +182,12 @@ class ReplayBuffer: for key, value in complementary_info.items(): if isinstance(value, torch.Tensor): value_shape = value.squeeze(0).shape - self.complementary_info[key] = torch.empty( + self.complementary_info[key] = torch.zeros( (self.capacity, *value_shape), device=self.storage_device ) elif isinstance(value, (int | float)): # Handle scalar values similar to reward - self.complementary_info[key] = torch.empty((self.capacity,), device=self.storage_device) + self.complementary_info[key] = torch.zeros((self.capacity,), device=self.storage_device) else: raise ValueError(f"Unsupported type {type(value)} for complementary_info[{key}]") @@ -195,7 +204,7 @@ class ReplayBuffer: next_state: dict[str, torch.Tensor], done: bool, truncated: bool, - complementary_info: dict[str, torch.Tensor] | None = None, + complementary_info: dict[str, torch.Tensor | float | int] | None = None, ): """Saves a transition, ensuring tensors are stored on the designated storage device.""" # Initialize storage if this is the first transition @@ -212,8 +221,8 @@ class ReplayBuffer: 
self.actions[self.position].copy_(action.squeeze(dim=0)) self.rewards[self.position] = reward - self.dones[self.position] = done - self.truncateds[self.position] = truncated + self.terminated[self.position] = done + self.truncated[self.position] = truncated # Handle complementary_info if provided and storage is initialized if complementary_info is not None and self.has_complementary_info: @@ -283,8 +292,8 @@ class ReplayBuffer: # Sample other tensors batch_actions = self.actions[idx].to(self.device) batch_rewards = self.rewards[idx].to(self.device) - batch_dones = self.dones[idx].to(self.device).float() - batch_truncateds = self.truncateds[idx].to(self.device).float() + batch_terminated = self.terminated[idx].to(self.device).float() + batch_truncated = self.truncated[idx].to(self.device).float() # Sample complementary_info if available batch_complementary_info = None @@ -298,8 +307,8 @@ class ReplayBuffer: action=batch_actions, reward=batch_rewards, next_state=batch_next_state, - done=batch_dones, - truncated=batch_truncateds, + terminated=batch_terminated, + truncated=batch_truncated, complementary_info=batch_complementary_info, ) @@ -431,7 +440,6 @@ class ReplayBuffer: device (str): The device for sampling tensors. Defaults to "cuda:0". state_keys (Sequence[str] | None): The list of keys that appear in `state` and `next_state`. capacity (int | None): Buffer capacity. If None, uses dataset length. - action_mask (Sequence[int] | None): Indices of action dimensions to keep. image_augmentation_function (Callable | None): Function for image augmentation. If None, uses default random shift with pad=4. use_drq (bool): Whether to use DrQ image augmentation when sampling. 
@@ -460,12 +468,16 @@ class ReplayBuffer: optimize_memory=optimize_memory, ) - # Convert dataset to transitions - list_transition = cls._lerobotdataset_to_transitions(dataset=lerobot_dataset, state_keys=state_keys) + # Convert dataset to transitions generator + transitions_generator = cls._lerobotdataset_to_transitions( + dataset=lerobot_dataset, state_keys=state_keys + ) + + # Get first transition to initialize storage + first_transition = next(transitions_generator, None) # Initialize the buffer with the first transition to set up storage tensors - if list_transition: - first_transition = list_transition[0] + if first_transition is not None: first_state = {k: v.to(device) for k, v in first_transition["state"].items()} first_action = first_transition[ACTION].to(device) @@ -483,26 +495,28 @@ class ReplayBuffer: state=first_state, action=first_action, complementary_info=first_complementary_info ) - # Fill the buffer with all transitions - for data in list_transition: - for k, v in data.items(): - if isinstance(v, dict): - for key, tensor in v.items(): - v[key] = tensor.to(storage_device) - elif isinstance(v, torch.Tensor): - data[k] = v.to(storage_device) + # Fill the buffer with all transitions (first + remaining) + if first_transition is not None: + for data in itertools.chain([first_transition], transitions_generator): + for k, v in data.items(): + if isinstance(v, dict): + for key, tensor in v.items(): + if isinstance(tensor, torch.Tensor): + v[key] = tensor.to(storage_device) + elif isinstance(v, torch.Tensor): + data[k] = v.to(storage_device) - action = data[ACTION] + action = data[ACTION] - replay_buffer.add( - state=data["state"], - action=action, - reward=data["reward"], - next_state=data["next_state"], - done=data["done"], - truncated=False, # NOTE: Truncation are not supported yet in lerobot dataset - complementary_info=data.get("complementary_info", None), - ) + replay_buffer.add( + state=data["state"], + action=action, + reward=data["reward"], + 
next_state=data["next_state"], + done=data["done"], + truncated=data["truncated"], + complementary_info=data.get("complementary_info"), + ) return replay_buffer @@ -576,10 +590,12 @@ class ReplayBuffer: for key in self.states: frame_dict[key] = self.states[key][actual_idx].cpu() - # Fill action, reward, done + # Fill action, reward, done (done = terminated or truncated) frame_dict[ACTION] = self.actions[actual_idx].cpu() frame_dict[REWARD] = torch.tensor([self.rewards[actual_idx]], dtype=torch.float32).cpu() - frame_dict[DONE] = torch.tensor([self.dones[actual_idx]], dtype=torch.bool).cpu() + frame_dict[DONE] = torch.tensor( + [self.terminated[actual_idx] or self.truncated[actual_idx]], dtype=torch.bool + ).cpu() frame_dict["task"] = task_name # Add complementary_info if available @@ -599,7 +615,7 @@ class ReplayBuffer: lerobot_dataset.add_frame(frame_dict) # If we reached an episode boundary, call save_episode, reset counters - if self.dones[actual_idx] or self.truncateds[actual_idx]: + if self.terminated[actual_idx] or self.truncated[actual_idx]: lerobot_dataset.save_episode() # Save any remaining frames in the buffer @@ -615,9 +631,11 @@ class ReplayBuffer: def _lerobotdataset_to_transitions( dataset: LeRobotDataset, state_keys: Sequence[str] | None = None, - ) -> list[Transition]: + ) -> Generator[Transition, None, None]: """ - Convert a LeRobotDataset into a list of RL (s, a, r, s', done) transitions. + Convert a LeRobotDataset into a generator of RL (s, a, r, s', done) transitions. + + Using a generator instead of a list is more memory efficient for large datasets. Args: dataset (LeRobotDataset): @@ -637,14 +655,12 @@ class ReplayBuffer: ["observation.state", "observation.environment_state"]. If None, you must handle or define default keys. - Returns: - transitions (List[Transition]): - A list of Transition dictionaries with the same length as `dataset`. + Yields: + Transition: A transition dictionary. 
""" if state_keys is None: raise ValueError("State keys must be provided when converting LeRobotDataset to Transitions.") - transitions = [] num_frames = len(dataset) # Check if the dataset has "next.done" key @@ -687,8 +703,17 @@ class ReplayBuffer: if next_sample["episode_index"] != current_sample["episode_index"]: done = True - # TODO: (azouitine) Handle truncation (using the same value as done for now) - truncated = done + # Handle truncation separately from done + # This is important if the dataset has truncations (e.g., time limits) + truncated = False + if not done: + # If this is the last frame or if next frame is in a different episode, mark as truncated + if i == num_frames - 1: + truncated = True + elif i < num_frames - 1: + next_sample = dataset[i + 1] + if next_sample["episode_index"] != current_sample["episode_index"]: + truncated = True # ----- 4) Next state ----- # If not done and the next sample is in the same episode, we pull the next sample's state. @@ -716,7 +741,6 @@ class ReplayBuffer: if isinstance(val, torch.Tensor): complementary_info[clean_key] = val.unsqueeze(0) # Add batch dimension else: - # TODO: (azouitine) Check if it's necessary to convert to tensor # For non-tensor values, use directly complementary_info[clean_key] = val @@ -730,12 +754,10 @@ class ReplayBuffer: truncated=truncated, complementary_info=complementary_info, ) - transitions.append(transition) - return transitions + yield transition -# Utility function to guess shapes/dtypes from a tensor def guess_feature_info(t, name: str): """ Return a dictionary with the 'dtype' and 'shape' for a given tensor or scalar value. 
@@ -805,9 +827,9 @@ def concatenate_batch_transitions( for key in left_batch_transitions["next_state"] } - # Concatenate done and truncated fields - left_batch_transitions["done"] = torch.cat( - [left_batch_transitions["done"], right_batch_transition["done"]], dim=0 + # Concatenate terminated and truncated fields + left_batch_transitions["terminated"] = torch.cat( + [left_batch_transitions["terminated"], right_batch_transition["terminated"]], dim=0 ) left_batch_transitions["truncated"] = torch.cat( [left_batch_transitions["truncated"], right_batch_transition["truncated"]], diff --git a/tests/utils/test_replay_buffer.py b/tests/utils/test_replay_buffer.py index b9d3a1ac0..207dfc742 100644 --- a/tests/utils/test_replay_buffer.py +++ b/tests/utils/test_replay_buffer.py @@ -68,7 +68,7 @@ def create_dummy_transition() -> dict: OBS_STATE: torch.randn( 10, ), - "done": torch.tensor(False), + "terminated": torch.tensor(False), "truncated": torch.tensor(False), "complementary_info": {}, } @@ -191,8 +191,8 @@ def test_add_transition(replay_buffer, dummy_state, dummy_action): "Action should be equal to the first transition." ) assert replay_buffer.rewards[0] == 1.0, "Reward should be equal to the first transition." - assert not replay_buffer.dones[0], "Done should be False for the first transition." - assert not replay_buffer.truncateds[0], "Truncated should be False for the first transition." + assert not replay_buffer.terminated[0], "Terminated should be False for the first transition." + assert not replay_buffer.truncated[0], "Truncated should be False for the first transition." for dim in state_dims(): assert torch.equal(replay_buffer.states[dim][0], dummy_state[dim]), ( @@ -232,8 +232,8 @@ def test_add_over_capacity(): "Action should be equal to the last transition." ) assert replay_buffer.rewards[0] == 1.0, "Reward should be equal to the last transition." - assert replay_buffer.dones[0], "Done should be True for the first transition." 
- assert replay_buffer.truncateds[0], "Truncated should be True for the first transition." + assert replay_buffer.terminated[0], "Terminated should be True for the first transition." + assert replay_buffer.truncated[0], "Truncated should be True for the first transition." def test_sample_from_empty_buffer(replay_buffer): @@ -250,7 +250,7 @@ def test_sample_with_1_transition(replay_buffer, dummy_state, next_dummy_state, action=dummy_action.clone(), reward=1.0, next_state=clone_state(next_dummy_state), - done=False, + terminated=False, truncated=False, ) @@ -289,7 +289,7 @@ def test_sample_with_batch_bigger_than_buffer_size( action=dummy_action, reward=1.0, next_state=next_dummy_state, - done=False, + terminated=False, truncated=False, ) @@ -383,7 +383,8 @@ def test_to_lerobot_dataset(tmp_path): elif feature == REWARD: assert torch.equal(value, buffer.rewards[i]) elif feature == DONE: - assert torch.equal(value, buffer.dones[i]) + # DONE in dataset is terminated OR truncated + assert torch.equal(value, buffer.terminated[i] | buffer.truncated[i]) elif feature == OBS_IMAGE: # Tensor -> numpy is not precise, so we have some diff there # TODO: Check and fix it @@ -427,12 +428,12 @@ def test_from_lerobot_dataset(tmp_path): reconverted_buffer.rewards[: len(replay_buffer)], replay_buffer.rewards[: len(replay_buffer)] ), "Rewards from converted buffer should be equal to the original replay buffer." assert torch.equal( - reconverted_buffer.dones[: len(replay_buffer)], replay_buffer.dones[: len(replay_buffer)] - ), "Dones from converted buffer should be equal to the original replay buffer." + reconverted_buffer.terminated[: len(replay_buffer)], replay_buffer.terminated[: len(replay_buffer)] + ), "Terminated flags from converted buffer should be equal to the original replay buffer." 
- # Lerobot DS haven't supported truncateds yet - expected_truncateds = torch.zeros(len(replay_buffer)).bool() - assert torch.equal(reconverted_buffer.truncateds[: len(replay_buffer)], expected_truncateds), ( + # LeRobot datasets do not support truncated flags yet + expected_truncated = torch.zeros(len(replay_buffer)).bool() + assert torch.equal(reconverted_buffer.truncated[: len(replay_buffer)], expected_truncated), ( "Truncateds from converted buffer should be equal False" ) @@ -498,7 +499,7 @@ def test_buffer_sample_alignment(): action_val = batch[ACTION][i].item() reward_val = batch["reward"][i].item() next_state_sig = batch["next_state"]["state_value"][i].item() - is_done = batch["done"][i].item() > 0.5 + is_terminated = batch["terminated"][i].item() > 0.5 # Verify relationships assert abs(action_val - 2.0 * state_sig) < 1e-4, ( @@ -509,9 +510,9 @@ def test_buffer_sample_alignment(): f"Reward {reward_val} should be 3x state signature {state_sig}" ) - if is_done: + if is_terminated: assert abs(next_state_sig - state_sig) < 1e-4, ( - f"For done states, next_state {next_state_sig} should equal state {state_sig}" + f"For terminated states, next_state {next_state_sig} should equal state {state_sig}" ) else: # Either it's the next sequential state (+0.01) or same state (for episode boundaries)