mirror of
https://github.com/huggingface/lerobot.git
synced 2026-05-16 09:09:48 +00:00
Compare commits
2 Commits
| Author | SHA1 | Date | |
|---|---|---|---|
| e3539cb78e | |||
| 9014f9a7c5 |
+64
-41
@@ -15,7 +15,8 @@
|
|||||||
# limitations under the License.
|
# limitations under the License.
|
||||||
|
|
||||||
import functools
|
import functools
|
||||||
from collections.abc import Callable, Sequence
|
import itertools
|
||||||
|
from collections.abc import Callable, Generator, Sequence
|
||||||
from contextlib import suppress
|
from contextlib import suppress
|
||||||
from typing import TypedDict
|
from typing import TypedDict
|
||||||
|
|
||||||
@@ -29,13 +30,20 @@ from lerobot.utils.transition import Transition
|
|||||||
|
|
||||||
|
|
||||||
class BatchTransition(TypedDict):
|
class BatchTransition(TypedDict):
|
||||||
|
"""Batch transition for single-step RL algorithms.
|
||||||
|
|
||||||
|
Uses Gymnasium terminology:
|
||||||
|
- terminated: True termination due to task success/failure
|
||||||
|
- truncated: Termination due to time limit or other external factors
|
||||||
|
"""
|
||||||
|
|
||||||
state: dict[str, torch.Tensor]
|
state: dict[str, torch.Tensor]
|
||||||
action: torch.Tensor
|
action: torch.Tensor
|
||||||
reward: torch.Tensor
|
reward: torch.Tensor
|
||||||
next_state: dict[str, torch.Tensor]
|
next_state: dict[str, torch.Tensor]
|
||||||
done: torch.Tensor
|
terminated: torch.Tensor # True termination due to task success/failure
|
||||||
truncated: torch.Tensor
|
truncated: torch.Tensor # Termination due to time limit
|
||||||
complementary_info: dict[str, torch.Tensor | float | int] | None = None
|
complementary_info: dict[str, torch.Tensor] | None
|
||||||
|
|
||||||
|
|
||||||
def random_crop_vectorized(images: torch.Tensor, output_size: tuple) -> torch.Tensor:
|
def random_crop_vectorized(images: torch.Tensor, output_size: tuple) -> torch.Tensor:
|
||||||
@@ -78,6 +86,8 @@ def random_shift(images: torch.Tensor, pad: int = 4):
|
|||||||
|
|
||||||
|
|
||||||
class ReplayBuffer:
|
class ReplayBuffer:
|
||||||
|
"""Replay buffer for storing transitions used in RL training (e.g., SAC)."""
|
||||||
|
|
||||||
def __init__(
|
def __init__(
|
||||||
self,
|
self,
|
||||||
capacity: int,
|
capacity: int,
|
||||||
@@ -133,7 +143,7 @@ class ReplayBuffer:
|
|||||||
self,
|
self,
|
||||||
state: dict[str, torch.Tensor],
|
state: dict[str, torch.Tensor],
|
||||||
action: torch.Tensor,
|
action: torch.Tensor,
|
||||||
complementary_info: dict[str, torch.Tensor] | None = None,
|
complementary_info: dict[str, torch.Tensor | float | int] | None = None,
|
||||||
):
|
):
|
||||||
"""Initialize the storage tensors based on the first transition."""
|
"""Initialize the storage tensors based on the first transition."""
|
||||||
# Determine shapes from the first transition
|
# Determine shapes from the first transition
|
||||||
@@ -159,8 +169,8 @@ class ReplayBuffer:
|
|||||||
# Just create a reference to states for consistent API
|
# Just create a reference to states for consistent API
|
||||||
self.next_states = self.states # Just a reference for API consistency
|
self.next_states = self.states # Just a reference for API consistency
|
||||||
|
|
||||||
self.dones = torch.empty((self.capacity,), dtype=torch.bool, device=self.storage_device)
|
self.terminated = torch.empty((self.capacity,), dtype=torch.bool, device=self.storage_device)
|
||||||
self.truncateds = torch.empty((self.capacity,), dtype=torch.bool, device=self.storage_device)
|
self.truncated = torch.empty((self.capacity,), dtype=torch.bool, device=self.storage_device)
|
||||||
|
|
||||||
# Initialize storage for complementary_info
|
# Initialize storage for complementary_info
|
||||||
self.has_complementary_info = complementary_info is not None
|
self.has_complementary_info = complementary_info is not None
|
||||||
@@ -195,7 +205,7 @@ class ReplayBuffer:
|
|||||||
next_state: dict[str, torch.Tensor],
|
next_state: dict[str, torch.Tensor],
|
||||||
done: bool,
|
done: bool,
|
||||||
truncated: bool,
|
truncated: bool,
|
||||||
complementary_info: dict[str, torch.Tensor] | None = None,
|
complementary_info: dict[str, torch.Tensor | float | int] | None = None,
|
||||||
):
|
):
|
||||||
"""Saves a transition, ensuring tensors are stored on the designated storage device."""
|
"""Saves a transition, ensuring tensors are stored on the designated storage device."""
|
||||||
# Initialize storage if this is the first transition
|
# Initialize storage if this is the first transition
|
||||||
@@ -212,8 +222,8 @@ class ReplayBuffer:
|
|||||||
|
|
||||||
self.actions[self.position].copy_(action.squeeze(dim=0))
|
self.actions[self.position].copy_(action.squeeze(dim=0))
|
||||||
self.rewards[self.position] = reward
|
self.rewards[self.position] = reward
|
||||||
self.dones[self.position] = done
|
self.terminated[self.position] = done
|
||||||
self.truncateds[self.position] = truncated
|
self.truncated[self.position] = truncated
|
||||||
|
|
||||||
# Handle complementary_info if provided and storage is initialized
|
# Handle complementary_info if provided and storage is initialized
|
||||||
if complementary_info is not None and self.has_complementary_info:
|
if complementary_info is not None and self.has_complementary_info:
|
||||||
@@ -283,8 +293,8 @@ class ReplayBuffer:
|
|||||||
# Sample other tensors
|
# Sample other tensors
|
||||||
batch_actions = self.actions[idx].to(self.device)
|
batch_actions = self.actions[idx].to(self.device)
|
||||||
batch_rewards = self.rewards[idx].to(self.device)
|
batch_rewards = self.rewards[idx].to(self.device)
|
||||||
batch_dones = self.dones[idx].to(self.device).float()
|
batch_terminated = self.terminated[idx].to(self.device).float()
|
||||||
batch_truncateds = self.truncateds[idx].to(self.device).float()
|
batch_truncated = self.truncated[idx].to(self.device).float()
|
||||||
|
|
||||||
# Sample complementary_info if available
|
# Sample complementary_info if available
|
||||||
batch_complementary_info = None
|
batch_complementary_info = None
|
||||||
@@ -298,8 +308,8 @@ class ReplayBuffer:
|
|||||||
action=batch_actions,
|
action=batch_actions,
|
||||||
reward=batch_rewards,
|
reward=batch_rewards,
|
||||||
next_state=batch_next_state,
|
next_state=batch_next_state,
|
||||||
done=batch_dones,
|
terminated=batch_terminated,
|
||||||
truncated=batch_truncateds,
|
truncated=batch_truncated,
|
||||||
complementary_info=batch_complementary_info,
|
complementary_info=batch_complementary_info,
|
||||||
)
|
)
|
||||||
|
|
||||||
@@ -431,7 +441,6 @@ class ReplayBuffer:
|
|||||||
device (str): The device for sampling tensors. Defaults to "cuda:0".
|
device (str): The device for sampling tensors. Defaults to "cuda:0".
|
||||||
state_keys (Sequence[str] | None): The list of keys that appear in `state` and `next_state`.
|
state_keys (Sequence[str] | None): The list of keys that appear in `state` and `next_state`.
|
||||||
capacity (int | None): Buffer capacity. If None, uses dataset length.
|
capacity (int | None): Buffer capacity. If None, uses dataset length.
|
||||||
action_mask (Sequence[int] | None): Indices of action dimensions to keep.
|
|
||||||
image_augmentation_function (Callable | None): Function for image augmentation.
|
image_augmentation_function (Callable | None): Function for image augmentation.
|
||||||
If None, uses default random shift with pad=4.
|
If None, uses default random shift with pad=4.
|
||||||
use_drq (bool): Whether to use DrQ image augmentation when sampling.
|
use_drq (bool): Whether to use DrQ image augmentation when sampling.
|
||||||
@@ -460,12 +469,16 @@ class ReplayBuffer:
|
|||||||
optimize_memory=optimize_memory,
|
optimize_memory=optimize_memory,
|
||||||
)
|
)
|
||||||
|
|
||||||
# Convert dataset to transitions
|
# Convert dataset to transitions generator
|
||||||
list_transition = cls._lerobotdataset_to_transitions(dataset=lerobot_dataset, state_keys=state_keys)
|
transitions_generator = cls._lerobotdataset_to_transitions(
|
||||||
|
dataset=lerobot_dataset, state_keys=state_keys
|
||||||
|
)
|
||||||
|
|
||||||
|
# Get first transition to initialize storage
|
||||||
|
first_transition = next(transitions_generator, None)
|
||||||
|
|
||||||
# Initialize the buffer with the first transition to set up storage tensors
|
# Initialize the buffer with the first transition to set up storage tensors
|
||||||
if list_transition:
|
if first_transition is not None:
|
||||||
first_transition = list_transition[0]
|
|
||||||
first_state = {k: v.to(device) for k, v in first_transition["state"].items()}
|
first_state = {k: v.to(device) for k, v in first_transition["state"].items()}
|
||||||
first_action = first_transition[ACTION].to(device)
|
first_action = first_transition[ACTION].to(device)
|
||||||
|
|
||||||
@@ -483,11 +496,13 @@ class ReplayBuffer:
|
|||||||
state=first_state, action=first_action, complementary_info=first_complementary_info
|
state=first_state, action=first_action, complementary_info=first_complementary_info
|
||||||
)
|
)
|
||||||
|
|
||||||
# Fill the buffer with all transitions
|
# Fill the buffer with all transitions (first + remaining)
|
||||||
for data in list_transition:
|
if first_transition is not None:
|
||||||
|
for data in itertools.chain([first_transition], transitions_generator):
|
||||||
for k, v in data.items():
|
for k, v in data.items():
|
||||||
if isinstance(v, dict):
|
if isinstance(v, dict):
|
||||||
for key, tensor in v.items():
|
for key, tensor in v.items():
|
||||||
|
if isinstance(tensor, torch.Tensor):
|
||||||
v[key] = tensor.to(storage_device)
|
v[key] = tensor.to(storage_device)
|
||||||
elif isinstance(v, torch.Tensor):
|
elif isinstance(v, torch.Tensor):
|
||||||
data[k] = v.to(storage_device)
|
data[k] = v.to(storage_device)
|
||||||
@@ -500,8 +515,8 @@ class ReplayBuffer:
|
|||||||
reward=data["reward"],
|
reward=data["reward"],
|
||||||
next_state=data["next_state"],
|
next_state=data["next_state"],
|
||||||
done=data["done"],
|
done=data["done"],
|
||||||
truncated=False, # NOTE: Truncation are not supported yet in lerobot dataset
|
truncated=data["truncated"],
|
||||||
complementary_info=data.get("complementary_info", None),
|
complementary_info=data.get("complementary_info"),
|
||||||
)
|
)
|
||||||
|
|
||||||
return replay_buffer
|
return replay_buffer
|
||||||
@@ -576,10 +591,12 @@ class ReplayBuffer:
|
|||||||
for key in self.states:
|
for key in self.states:
|
||||||
frame_dict[key] = self.states[key][actual_idx].cpu()
|
frame_dict[key] = self.states[key][actual_idx].cpu()
|
||||||
|
|
||||||
# Fill action, reward, done
|
# Fill action, reward, done (done = terminated or truncated)
|
||||||
frame_dict[ACTION] = self.actions[actual_idx].cpu()
|
frame_dict[ACTION] = self.actions[actual_idx].cpu()
|
||||||
frame_dict[REWARD] = torch.tensor([self.rewards[actual_idx]], dtype=torch.float32).cpu()
|
frame_dict[REWARD] = torch.tensor([self.rewards[actual_idx]], dtype=torch.float32).cpu()
|
||||||
frame_dict[DONE] = torch.tensor([self.dones[actual_idx]], dtype=torch.bool).cpu()
|
frame_dict[DONE] = torch.tensor(
|
||||||
|
[self.terminated[actual_idx] or self.truncated[actual_idx]], dtype=torch.bool
|
||||||
|
).cpu()
|
||||||
frame_dict["task"] = task_name
|
frame_dict["task"] = task_name
|
||||||
|
|
||||||
# Add complementary_info if available
|
# Add complementary_info if available
|
||||||
@@ -599,7 +616,7 @@ class ReplayBuffer:
|
|||||||
lerobot_dataset.add_frame(frame_dict)
|
lerobot_dataset.add_frame(frame_dict)
|
||||||
|
|
||||||
# If we reached an episode boundary, call save_episode, reset counters
|
# If we reached an episode boundary, call save_episode, reset counters
|
||||||
if self.dones[actual_idx] or self.truncateds[actual_idx]:
|
if self.terminated[actual_idx] or self.truncated[actual_idx]:
|
||||||
lerobot_dataset.save_episode()
|
lerobot_dataset.save_episode()
|
||||||
|
|
||||||
# Save any remaining frames in the buffer
|
# Save any remaining frames in the buffer
|
||||||
@@ -615,9 +632,11 @@ class ReplayBuffer:
|
|||||||
def _lerobotdataset_to_transitions(
|
def _lerobotdataset_to_transitions(
|
||||||
dataset: LeRobotDataset,
|
dataset: LeRobotDataset,
|
||||||
state_keys: Sequence[str] | None = None,
|
state_keys: Sequence[str] | None = None,
|
||||||
) -> list[Transition]:
|
) -> Generator[Transition, None, None]:
|
||||||
"""
|
"""
|
||||||
Convert a LeRobotDataset into a list of RL (s, a, r, s', done) transitions.
|
Convert a LeRobotDataset into a generator of RL (s, a, r, s', done) transitions.
|
||||||
|
|
||||||
|
Using a generator instead of a list is more memory efficient for large datasets.
|
||||||
|
|
||||||
Args:
|
Args:
|
||||||
dataset (LeRobotDataset):
|
dataset (LeRobotDataset):
|
||||||
@@ -637,14 +656,12 @@ class ReplayBuffer:
|
|||||||
["observation.state", "observation.environment_state"].
|
["observation.state", "observation.environment_state"].
|
||||||
If None, you must handle or define default keys.
|
If None, you must handle or define default keys.
|
||||||
|
|
||||||
Returns:
|
Yields:
|
||||||
transitions (List[Transition]):
|
Transition: A transition dictionary.
|
||||||
A list of Transition dictionaries with the same length as `dataset`.
|
|
||||||
"""
|
"""
|
||||||
if state_keys is None:
|
if state_keys is None:
|
||||||
raise ValueError("State keys must be provided when converting LeRobotDataset to Transitions.")
|
raise ValueError("State keys must be provided when converting LeRobotDataset to Transitions.")
|
||||||
|
|
||||||
transitions = []
|
|
||||||
num_frames = len(dataset)
|
num_frames = len(dataset)
|
||||||
|
|
||||||
# Check if the dataset has "next.done" key
|
# Check if the dataset has "next.done" key
|
||||||
@@ -687,8 +704,17 @@ class ReplayBuffer:
|
|||||||
if next_sample["episode_index"] != current_sample["episode_index"]:
|
if next_sample["episode_index"] != current_sample["episode_index"]:
|
||||||
done = True
|
done = True
|
||||||
|
|
||||||
# TODO: (azouitine) Handle truncation (using the same value as done for now)
|
# Handle truncation separately from done
|
||||||
truncated = done
|
# This is important if the dataset has truncations (e.g., time limits)
|
||||||
|
truncated = False
|
||||||
|
if not done:
|
||||||
|
# If this is the last frame or if next frame is in a different episode, mark as truncated
|
||||||
|
if i == num_frames - 1:
|
||||||
|
truncated = True
|
||||||
|
elif i < num_frames - 1:
|
||||||
|
next_sample = dataset[i + 1]
|
||||||
|
if next_sample["episode_index"] != current_sample["episode_index"]:
|
||||||
|
truncated = True
|
||||||
|
|
||||||
# ----- 4) Next state -----
|
# ----- 4) Next state -----
|
||||||
# If not done and the next sample is in the same episode, we pull the next sample's state.
|
# If not done and the next sample is in the same episode, we pull the next sample's state.
|
||||||
@@ -716,7 +742,6 @@ class ReplayBuffer:
|
|||||||
if isinstance(val, torch.Tensor):
|
if isinstance(val, torch.Tensor):
|
||||||
complementary_info[clean_key] = val.unsqueeze(0) # Add batch dimension
|
complementary_info[clean_key] = val.unsqueeze(0) # Add batch dimension
|
||||||
else:
|
else:
|
||||||
# TODO: (azouitine) Check if it's necessary to convert to tensor
|
|
||||||
# For non-tensor values, use directly
|
# For non-tensor values, use directly
|
||||||
complementary_info[clean_key] = val
|
complementary_info[clean_key] = val
|
||||||
|
|
||||||
@@ -730,12 +755,10 @@ class ReplayBuffer:
|
|||||||
truncated=truncated,
|
truncated=truncated,
|
||||||
complementary_info=complementary_info,
|
complementary_info=complementary_info,
|
||||||
)
|
)
|
||||||
transitions.append(transition)
|
|
||||||
|
|
||||||
return transitions
|
yield transition
|
||||||
|
|
||||||
|
|
||||||
# Utility function to guess shapes/dtypes from a tensor
|
|
||||||
def guess_feature_info(t, name: str):
|
def guess_feature_info(t, name: str):
|
||||||
"""
|
"""
|
||||||
Return a dictionary with the 'dtype' and 'shape' for a given tensor or scalar value.
|
Return a dictionary with the 'dtype' and 'shape' for a given tensor or scalar value.
|
||||||
@@ -805,9 +828,9 @@ def concatenate_batch_transitions(
|
|||||||
for key in left_batch_transitions["next_state"]
|
for key in left_batch_transitions["next_state"]
|
||||||
}
|
}
|
||||||
|
|
||||||
# Concatenate done and truncated fields
|
# Concatenate terminated and truncated fields
|
||||||
left_batch_transitions["done"] = torch.cat(
|
left_batch_transitions["terminated"] = torch.cat(
|
||||||
[left_batch_transitions["done"], right_batch_transition["done"]], dim=0
|
[left_batch_transitions["terminated"], right_batch_transition["terminated"]], dim=0
|
||||||
)
|
)
|
||||||
left_batch_transitions["truncated"] = torch.cat(
|
left_batch_transitions["truncated"] = torch.cat(
|
||||||
[left_batch_transitions["truncated"], right_batch_transition["truncated"]],
|
[left_batch_transitions["truncated"], right_batch_transition["truncated"]],
|
||||||
|
|||||||
@@ -68,7 +68,7 @@ def create_dummy_transition() -> dict:
|
|||||||
OBS_STATE: torch.randn(
|
OBS_STATE: torch.randn(
|
||||||
10,
|
10,
|
||||||
),
|
),
|
||||||
"done": torch.tensor(False),
|
"terminated": torch.tensor(False),
|
||||||
"truncated": torch.tensor(False),
|
"truncated": torch.tensor(False),
|
||||||
"complementary_info": {},
|
"complementary_info": {},
|
||||||
}
|
}
|
||||||
@@ -191,8 +191,8 @@ def test_add_transition(replay_buffer, dummy_state, dummy_action):
|
|||||||
"Action should be equal to the first transition."
|
"Action should be equal to the first transition."
|
||||||
)
|
)
|
||||||
assert replay_buffer.rewards[0] == 1.0, "Reward should be equal to the first transition."
|
assert replay_buffer.rewards[0] == 1.0, "Reward should be equal to the first transition."
|
||||||
assert not replay_buffer.dones[0], "Done should be False for the first transition."
|
assert not replay_buffer.terminated[0], "Terminated should be False for the first transition."
|
||||||
assert not replay_buffer.truncateds[0], "Truncated should be False for the first transition."
|
assert not replay_buffer.truncated[0], "Truncated should be False for the first transition."
|
||||||
|
|
||||||
for dim in state_dims():
|
for dim in state_dims():
|
||||||
assert torch.equal(replay_buffer.states[dim][0], dummy_state[dim]), (
|
assert torch.equal(replay_buffer.states[dim][0], dummy_state[dim]), (
|
||||||
@@ -232,8 +232,8 @@ def test_add_over_capacity():
|
|||||||
"Action should be equal to the last transition."
|
"Action should be equal to the last transition."
|
||||||
)
|
)
|
||||||
assert replay_buffer.rewards[0] == 1.0, "Reward should be equal to the last transition."
|
assert replay_buffer.rewards[0] == 1.0, "Reward should be equal to the last transition."
|
||||||
assert replay_buffer.dones[0], "Done should be True for the first transition."
|
assert replay_buffer.terminated[0], "Terminated should be True for the first transition."
|
||||||
assert replay_buffer.truncateds[0], "Truncated should be True for the first transition."
|
assert replay_buffer.truncated[0], "Truncated should be True for the first transition."
|
||||||
|
|
||||||
|
|
||||||
def test_sample_from_empty_buffer(replay_buffer):
|
def test_sample_from_empty_buffer(replay_buffer):
|
||||||
@@ -250,7 +250,7 @@ def test_sample_with_1_transition(replay_buffer, dummy_state, next_dummy_state,
|
|||||||
action=dummy_action.clone(),
|
action=dummy_action.clone(),
|
||||||
reward=1.0,
|
reward=1.0,
|
||||||
next_state=clone_state(next_dummy_state),
|
next_state=clone_state(next_dummy_state),
|
||||||
done=False,
|
terminated=False,
|
||||||
truncated=False,
|
truncated=False,
|
||||||
)
|
)
|
||||||
|
|
||||||
@@ -289,7 +289,7 @@ def test_sample_with_batch_bigger_than_buffer_size(
|
|||||||
action=dummy_action,
|
action=dummy_action,
|
||||||
reward=1.0,
|
reward=1.0,
|
||||||
next_state=next_dummy_state,
|
next_state=next_dummy_state,
|
||||||
done=False,
|
terminated=False,
|
||||||
truncated=False,
|
truncated=False,
|
||||||
)
|
)
|
||||||
|
|
||||||
@@ -383,7 +383,8 @@ def test_to_lerobot_dataset(tmp_path):
|
|||||||
elif feature == REWARD:
|
elif feature == REWARD:
|
||||||
assert torch.equal(value, buffer.rewards[i])
|
assert torch.equal(value, buffer.rewards[i])
|
||||||
elif feature == DONE:
|
elif feature == DONE:
|
||||||
assert torch.equal(value, buffer.dones[i])
|
# DONE in dataset is terminated OR truncated
|
||||||
|
assert torch.equal(value, buffer.terminated[i] | buffer.truncated[i])
|
||||||
elif feature == OBS_IMAGE:
|
elif feature == OBS_IMAGE:
|
||||||
# Tensor -> numpy is not precise, so we have some diff there
|
# Tensor -> numpy is not precise, so we have some diff there
|
||||||
# TODO: Check and fix it
|
# TODO: Check and fix it
|
||||||
@@ -427,12 +428,12 @@ def test_from_lerobot_dataset(tmp_path):
|
|||||||
reconverted_buffer.rewards[: len(replay_buffer)], replay_buffer.rewards[: len(replay_buffer)]
|
reconverted_buffer.rewards[: len(replay_buffer)], replay_buffer.rewards[: len(replay_buffer)]
|
||||||
), "Rewards from converted buffer should be equal to the original replay buffer."
|
), "Rewards from converted buffer should be equal to the original replay buffer."
|
||||||
assert torch.equal(
|
assert torch.equal(
|
||||||
reconverted_buffer.dones[: len(replay_buffer)], replay_buffer.dones[: len(replay_buffer)]
|
reconverted_buffer.terminated[: len(replay_buffer)], replay_buffer.terminated[: len(replay_buffer)]
|
||||||
), "Dones from converted buffer should be equal to the original replay buffer."
|
), "Terminated flags from converted buffer should be equal to the original replay buffer."
|
||||||
|
|
||||||
# Lerobot DS haven't supported truncateds yet
|
# LeRobot DS hasn't supported truncated yet
|
||||||
expected_truncateds = torch.zeros(len(replay_buffer)).bool()
|
expected_truncated = torch.zeros(len(replay_buffer)).bool()
|
||||||
assert torch.equal(reconverted_buffer.truncateds[: len(replay_buffer)], expected_truncateds), (
|
assert torch.equal(reconverted_buffer.truncated[: len(replay_buffer)], expected_truncated), (
|
||||||
"Truncateds from converted buffer should be equal False"
|
"Truncateds from converted buffer should be equal False"
|
||||||
)
|
)
|
||||||
|
|
||||||
@@ -498,7 +499,7 @@ def test_buffer_sample_alignment():
|
|||||||
action_val = batch[ACTION][i].item()
|
action_val = batch[ACTION][i].item()
|
||||||
reward_val = batch["reward"][i].item()
|
reward_val = batch["reward"][i].item()
|
||||||
next_state_sig = batch["next_state"]["state_value"][i].item()
|
next_state_sig = batch["next_state"]["state_value"][i].item()
|
||||||
is_done = batch["done"][i].item() > 0.5
|
is_terminated = batch["terminated"][i].item() > 0.5
|
||||||
|
|
||||||
# Verify relationships
|
# Verify relationships
|
||||||
assert abs(action_val - 2.0 * state_sig) < 1e-4, (
|
assert abs(action_val - 2.0 * state_sig) < 1e-4, (
|
||||||
@@ -509,9 +510,9 @@ def test_buffer_sample_alignment():
|
|||||||
f"Reward {reward_val} should be 3x state signature {state_sig}"
|
f"Reward {reward_val} should be 3x state signature {state_sig}"
|
||||||
)
|
)
|
||||||
|
|
||||||
if is_done:
|
if is_terminated:
|
||||||
assert abs(next_state_sig - state_sig) < 1e-4, (
|
assert abs(next_state_sig - state_sig) < 1e-4, (
|
||||||
f"For done states, next_state {next_state_sig} should equal state {state_sig}"
|
f"For terminated states, next_state {next_state_sig} should equal state {state_sig}"
|
||||||
)
|
)
|
||||||
else:
|
else:
|
||||||
# Either it's the next sequential state (+0.01) or same state (for episode boundaries)
|
# Either it's the next sequential state (+0.01) or same state (for episode boundaries)
|
||||||
|
|||||||
Reference in New Issue
Block a user