Compare commits

...

2 Commits

Author SHA1 Message Date
Michel Aractingi e3539cb78e revert initialization to empty 2025-12-17 16:36:59 +01:00
Michel Aractingi 9014f9a7c5 refactor(buffer): use Gymnasium terminology (terminated/truncated)
- Rename 'done' to 'terminated' for true task completion
- Use 'truncated' for time-limit termination
- Change torch.empty to torch.zeros for storage initialization
- Convert _lerobotdataset_to_transitions to generator for memory efficiency
- Add proper docstrings to BatchTransition and ReplayBuffer
- Update concatenate_batch_transitions to use new terminology
- Update tests to use new field names

This aligns ReplayBuffer with Gymnasium's termination semantics where:
- terminated: Episode ended due to task success/failure
- truncated: Episode ended due to time limit or external factors
2025-12-17 15:52:26 +01:00
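
For reference, the Gymnasium step API that this refactor aligns with returns separate terminated and truncated flags. A minimal illustrative sketch (not part of this diff; assumes a standard CartPole environment with a time limit):

    import gymnasium as gym

    env = gym.make("CartPole-v1", max_episode_steps=200)
    obs, info = env.reset(seed=0)
    done = False
    while not done:
        action = env.action_space.sample()
        obs, reward, terminated, truncated, info = env.step(action)
        # The buffer previously collapsed both cases into a single `done` flag;
        # after this change it stores them separately, as Gymnasium reports them.
        done = terminated or truncated
    env.close()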
2 changed files with 95 additions and 71 deletions
+64 -41
@@ -15,7 +15,8 @@
 # limitations under the License.

 import functools
-from collections.abc import Callable, Sequence
+import itertools
+from collections.abc import Callable, Generator, Sequence
 from contextlib import suppress
 from typing import TypedDict
@@ -29,13 +30,20 @@ from lerobot.utils.transition import Transition


 class BatchTransition(TypedDict):
+    """Batch transition for single-step RL algorithms.
+
+    Uses Gymnasium terminology:
+    - terminated: True termination due to task success/failure
+    - truncated: Termination due to time limit or other external factors
+    """
+
     state: dict[str, torch.Tensor]
     action: torch.Tensor
     reward: torch.Tensor
     next_state: dict[str, torch.Tensor]
-    done: torch.Tensor
-    truncated: torch.Tensor
-    complementary_info: dict[str, torch.Tensor | float | int] | None = None
+    terminated: torch.Tensor  # True termination due to task success/failure
+    truncated: torch.Tensor  # Termination due to time limit
+    complementary_info: dict[str, torch.Tensor] | None


 def random_crop_vectorized(images: torch.Tensor, output_size: tuple) -> torch.Tensor:
@@ -78,6 +86,8 @@ def random_shift(images: torch.Tensor, pad: int = 4):


 class ReplayBuffer:
+    """Replay buffer for storing transitions used in RL training (e.g., SAC)."""
+
     def __init__(
         self,
         capacity: int,
@@ -133,7 +143,7 @@ class ReplayBuffer:
         self,
         state: dict[str, torch.Tensor],
         action: torch.Tensor,
-        complementary_info: dict[str, torch.Tensor] | None = None,
+        complementary_info: dict[str, torch.Tensor | float | int] | None = None,
     ):
         """Initialize the storage tensors based on the first transition."""
         # Determine shapes from the first transition
@@ -159,8 +169,8 @@ class ReplayBuffer:
         # Just create a reference to states for consistent API
         self.next_states = self.states  # Just a reference for API consistency

-        self.dones = torch.empty((self.capacity,), dtype=torch.bool, device=self.storage_device)
-        self.truncateds = torch.empty((self.capacity,), dtype=torch.bool, device=self.storage_device)
+        self.terminated = torch.empty((self.capacity,), dtype=torch.bool, device=self.storage_device)
+        self.truncated = torch.empty((self.capacity,), dtype=torch.bool, device=self.storage_device)

         # Initialize storage for complementary_info
         self.has_complementary_info = complementary_info is not None
@@ -195,7 +205,7 @@ class ReplayBuffer:
         next_state: dict[str, torch.Tensor],
         done: bool,
         truncated: bool,
-        complementary_info: dict[str, torch.Tensor] | None = None,
+        complementary_info: dict[str, torch.Tensor | float | int] | None = None,
     ):
         """Saves a transition, ensuring tensors are stored on the designated storage device."""
         # Initialize storage if this is the first transition
@@ -212,8 +222,8 @@ class ReplayBuffer:
         self.actions[self.position].copy_(action.squeeze(dim=0))
         self.rewards[self.position] = reward
-        self.dones[self.position] = done
-        self.truncateds[self.position] = truncated
+        self.terminated[self.position] = done
+        self.truncated[self.position] = truncated

         # Handle complementary_info if provided and storage is initialized
         if complementary_info is not None and self.has_complementary_info:
@@ -283,8 +293,8 @@ class ReplayBuffer:
         # Sample other tensors
         batch_actions = self.actions[idx].to(self.device)
         batch_rewards = self.rewards[idx].to(self.device)
-        batch_dones = self.dones[idx].to(self.device).float()
-        batch_truncateds = self.truncateds[idx].to(self.device).float()
+        batch_terminated = self.terminated[idx].to(self.device).float()
+        batch_truncated = self.truncated[idx].to(self.device).float()

         # Sample complementary_info if available
         batch_complementary_info = None
@@ -298,8 +308,8 @@ class ReplayBuffer:
             action=batch_actions,
             reward=batch_rewards,
             next_state=batch_next_state,
-            done=batch_dones,
-            truncated=batch_truncateds,
+            terminated=batch_terminated,
+            truncated=batch_truncated,
             complementary_info=batch_complementary_info,
         )
@@ -431,7 +441,6 @@ class ReplayBuffer:
             device (str): The device for sampling tensors. Defaults to "cuda:0".
             state_keys (Sequence[str] | None): The list of keys that appear in `state` and `next_state`.
             capacity (int | None): Buffer capacity. If None, uses dataset length.
-            action_mask (Sequence[int] | None): Indices of action dimensions to keep.
             image_augmentation_function (Callable | None): Function for image augmentation.
                 If None, uses default random shift with pad=4.
             use_drq (bool): Whether to use DrQ image augmentation when sampling.
@@ -460,12 +469,16 @@ class ReplayBuffer:
             optimize_memory=optimize_memory,
         )

-        # Convert dataset to transitions
-        list_transition = cls._lerobotdataset_to_transitions(dataset=lerobot_dataset, state_keys=state_keys)
+        # Convert dataset to transitions generator
+        transitions_generator = cls._lerobotdataset_to_transitions(
+            dataset=lerobot_dataset, state_keys=state_keys
+        )
+
+        # Get first transition to initialize storage
+        first_transition = next(transitions_generator, None)

         # Initialize the buffer with the first transition to set up storage tensors
-        if list_transition:
-            first_transition = list_transition[0]
+        if first_transition is not None:
             first_state = {k: v.to(device) for k, v in first_transition["state"].items()}
             first_action = first_transition[ACTION].to(device)
@@ -483,11 +496,13 @@ class ReplayBuffer:
                 state=first_state, action=first_action, complementary_info=first_complementary_info
             )

-        # Fill the buffer with all transitions
-        for data in list_transition:
-            for k, v in data.items():
-                if isinstance(v, dict):
-                    for key, tensor in v.items():
-                        v[key] = tensor.to(storage_device)
-                elif isinstance(v, torch.Tensor):
-                    data[k] = v.to(storage_device)
+        # Fill the buffer with all transitions (first + remaining)
+        if first_transition is not None:
+            for data in itertools.chain([first_transition], transitions_generator):
+                for k, v in data.items():
+                    if isinstance(v, dict):
+                        for key, tensor in v.items():
+                            if isinstance(tensor, torch.Tensor):
+                                v[key] = tensor.to(storage_device)
+                    elif isinstance(v, torch.Tensor):
+                        data[k] = v.to(storage_device)
@@ -500,8 +515,8 @@ class ReplayBuffer:
                     reward=data["reward"],
                     next_state=data["next_state"],
                     done=data["done"],
-                    truncated=False,  # NOTE: Truncation are not supported yet in lerobot dataset
-                    complementary_info=data.get("complementary_info", None),
+                    truncated=data["truncated"],
+                    complementary_info=data.get("complementary_info"),
                 )

         return replay_buffer
@@ -576,10 +591,12 @@ class ReplayBuffer:
             for key in self.states:
                 frame_dict[key] = self.states[key][actual_idx].cpu()

-            # Fill action, reward, done
+            # Fill action, reward, done (done = terminated or truncated)
             frame_dict[ACTION] = self.actions[actual_idx].cpu()
             frame_dict[REWARD] = torch.tensor([self.rewards[actual_idx]], dtype=torch.float32).cpu()
-            frame_dict[DONE] = torch.tensor([self.dones[actual_idx]], dtype=torch.bool).cpu()
+            frame_dict[DONE] = torch.tensor(
+                [self.terminated[actual_idx] or self.truncated[actual_idx]], dtype=torch.bool
+            ).cpu()
             frame_dict["task"] = task_name

             # Add complementary_info if available
@@ -599,7 +616,7 @@ class ReplayBuffer:
             lerobot_dataset.add_frame(frame_dict)

             # If we reached an episode boundary, call save_episode, reset counters
-            if self.dones[actual_idx] or self.truncateds[actual_idx]:
+            if self.terminated[actual_idx] or self.truncated[actual_idx]:
                 lerobot_dataset.save_episode()

         # Save any remaining frames in the buffer
@@ -615,9 +632,11 @@ class ReplayBuffer:
     def _lerobotdataset_to_transitions(
         dataset: LeRobotDataset,
         state_keys: Sequence[str] | None = None,
-    ) -> list[Transition]:
+    ) -> Generator[Transition, None, None]:
         """
-        Convert a LeRobotDataset into a list of RL (s, a, r, s', done) transitions.
+        Convert a LeRobotDataset into a generator of RL (s, a, r, s', done) transitions.
+        Using a generator instead of a list is more memory efficient for large datasets.

         Args:
             dataset (LeRobotDataset):
@@ -637,14 +656,12 @@ class ReplayBuffer:
                 ["observation.state", "observation.environment_state"].
                 If None, you must handle or define default keys.

-        Returns:
-            transitions (List[Transition]):
-                A list of Transition dictionaries with the same length as `dataset`.
+        Yields:
+            Transition: A transition dictionary.
         """
         if state_keys is None:
             raise ValueError("State keys must be provided when converting LeRobotDataset to Transitions.")

-        transitions = []
         num_frames = len(dataset)

         # Check if the dataset has "next.done" key
@@ -687,8 +704,17 @@ class ReplayBuffer:
             if next_sample["episode_index"] != current_sample["episode_index"]:
                 done = True

-            # TODO: (azouitine) Handle truncation (using the same value as done for now)
-            truncated = done
+            # Handle truncation separately from done
+            # This is important if the dataset has truncations (e.g., time limits)
+            truncated = False
+            if not done:
+                # If this is the last frame or if next frame is in a different episode, mark as truncated
+                if i == num_frames - 1:
+                    truncated = True
+                elif i < num_frames - 1:
+                    next_sample = dataset[i + 1]
+                    if next_sample["episode_index"] != current_sample["episode_index"]:
+                        truncated = True

             # ----- 4) Next state -----
             # If not done and the next sample is in the same episode, we pull the next sample's state.
@@ -716,7 +742,6 @@ class ReplayBuffer:
                     if isinstance(val, torch.Tensor):
                         complementary_info[clean_key] = val.unsqueeze(0)  # Add batch dimension
                     else:
-                        # TODO: (azouitine) Check if it's necessary to convert to tensor
                         # For non-tensor values, use directly
                         complementary_info[clean_key] = val
@@ -730,12 +755,10 @@ class ReplayBuffer:
                 truncated=truncated,
                 complementary_info=complementary_info,
             )
-            transitions.append(transition)

-        return transitions
+            yield transition


 # Utility function to guess shapes/dtypes from a tensor
 def guess_feature_info(t, name: str):
     """
     Return a dictionary with the 'dtype' and 'shape' for a given tensor or scalar value.
@@ -805,9 +828,9 @@ def concatenate_batch_transitions(
         for key in left_batch_transitions["next_state"]
     }

-    # Concatenate done and truncated fields
-    left_batch_transitions["done"] = torch.cat(
-        [left_batch_transitions["done"], right_batch_transition["done"]], dim=0
+    # Concatenate terminated and truncated fields
+    left_batch_transitions["terminated"] = torch.cat(
+        [left_batch_transitions["terminated"], right_batch_transition["terminated"]], dim=0
     )
     left_batch_transitions["truncated"] = torch.cat(
         [left_batch_transitions["truncated"], right_batch_transition["truncated"]],
+17 -16
@@ -68,7 +68,7 @@ def create_dummy_transition() -> dict:
         OBS_STATE: torch.randn(
             10,
         ),
-        "done": torch.tensor(False),
+        "terminated": torch.tensor(False),
         "truncated": torch.tensor(False),
         "complementary_info": {},
     }
@@ -191,8 +191,8 @@ def test_add_transition(replay_buffer, dummy_state, dummy_action):
         "Action should be equal to the first transition."
     )
     assert replay_buffer.rewards[0] == 1.0, "Reward should be equal to the first transition."
-    assert not replay_buffer.dones[0], "Done should be False for the first transition."
-    assert not replay_buffer.truncateds[0], "Truncated should be False for the first transition."
+    assert not replay_buffer.terminated[0], "Terminated should be False for the first transition."
+    assert not replay_buffer.truncated[0], "Truncated should be False for the first transition."

     for dim in state_dims():
         assert torch.equal(replay_buffer.states[dim][0], dummy_state[dim]), (
@@ -232,8 +232,8 @@ def test_add_over_capacity():
         "Action should be equal to the last transition."
     )
     assert replay_buffer.rewards[0] == 1.0, "Reward should be equal to the last transition."
-    assert replay_buffer.dones[0], "Done should be True for the first transition."
-    assert replay_buffer.truncateds[0], "Truncated should be True for the first transition."
+    assert replay_buffer.terminated[0], "Terminated should be True for the first transition."
+    assert replay_buffer.truncated[0], "Truncated should be True for the first transition."


 def test_sample_from_empty_buffer(replay_buffer):
@@ -250,7 +250,7 @@ def test_sample_with_1_transition(replay_buffer, dummy_state, next_dummy_state,
         action=dummy_action.clone(),
         reward=1.0,
         next_state=clone_state(next_dummy_state),
-        done=False,
+        terminated=False,
         truncated=False,
     )
@@ -289,7 +289,7 @@ def test_sample_with_batch_bigger_than_buffer_size(
         action=dummy_action,
         reward=1.0,
         next_state=next_dummy_state,
-        done=False,
+        terminated=False,
         truncated=False,
     )
@@ -383,7 +383,8 @@ def test_to_lerobot_dataset(tmp_path):
         elif feature == REWARD:
             assert torch.equal(value, buffer.rewards[i])
         elif feature == DONE:
-            assert torch.equal(value, buffer.dones[i])
+            # DONE in dataset is terminated OR truncated
+            assert torch.equal(value, buffer.terminated[i] | buffer.truncated[i])
         elif feature == OBS_IMAGE:
             # Tensor -> numpy is not precise, so we have some diff there
             # TODO: Check and fix it
@@ -427,12 +428,12 @@ def test_from_lerobot_dataset(tmp_path):
         reconverted_buffer.rewards[: len(replay_buffer)], replay_buffer.rewards[: len(replay_buffer)]
     ), "Rewards from converted buffer should be equal to the original replay buffer."
     assert torch.equal(
-        reconverted_buffer.dones[: len(replay_buffer)], replay_buffer.dones[: len(replay_buffer)]
-    ), "Dones from converted buffer should be equal to the original replay buffer."
+        reconverted_buffer.terminated[: len(replay_buffer)], replay_buffer.terminated[: len(replay_buffer)]
+    ), "Terminated flags from converted buffer should be equal to the original replay buffer."

-    # Lerobot DS haven't supported truncateds yet
-    expected_truncateds = torch.zeros(len(replay_buffer)).bool()
-    assert torch.equal(reconverted_buffer.truncateds[: len(replay_buffer)], expected_truncateds), (
+    # LeRobot DS hasn't supported truncated yet
+    expected_truncated = torch.zeros(len(replay_buffer)).bool()
+    assert torch.equal(reconverted_buffer.truncated[: len(replay_buffer)], expected_truncated), (
         "Truncateds from converted buffer should be equal False"
     )
@@ -498,7 +499,7 @@ def test_buffer_sample_alignment():
         action_val = batch[ACTION][i].item()
         reward_val = batch["reward"][i].item()
         next_state_sig = batch["next_state"]["state_value"][i].item()
-        is_done = batch["done"][i].item() > 0.5
+        is_terminated = batch["terminated"][i].item() > 0.5

         # Verify relationships
         assert abs(action_val - 2.0 * state_sig) < 1e-4, (
@@ -509,9 +510,9 @@ def test_buffer_sample_alignment():
             f"Reward {reward_val} should be 3x state signature {state_sig}"
         )

-        if is_done:
+        if is_terminated:
             assert abs(next_state_sig - state_sig) < 1e-4, (
-                f"For done states, next_state {next_state_sig} should equal state {state_sig}"
+                f"For terminated states, next_state {next_state_sig} should equal state {state_sig}"
             )
         else:
             # Either it's the next sequential state (+0.01) or same state (for episode boundaries)
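
The practical payoff of splitting terminated from truncated shows up in the TD backup of algorithms like SAC: only true termination should zero out the bootstrap term, while a time-limit cut should still bootstrap from the next state. A minimal sketch with made-up numbers, assuming float flags as returned by ReplayBuffer.sample above:

    import torch

    gamma = 0.99
    rewards = torch.tensor([1.0, 0.0])
    terminated = torch.tensor([0.0, 1.0])  # float flags, as sampled
    next_q = torch.tensor([5.0, 7.0])

    # Mask the bootstrap with `terminated` only; a truncated episode still
    # bootstraps from the next state because the task itself did not end.
    td_target = rewards + gamma * (1.0 - terminated) * next_q
    print(td_target)  # tensor([5.9500, 0.0000])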