lerobot/tests/utils/test_replay_buffer.py
Khalil Meftah e963e5a0c4 RL stack refactoring (#3075)
* refactor: RL stack refactoring — RLAlgorithm, RLTrainer, DataMixer, and SAC restructuring

* chore: clarify torch.compile disabled note in SACAlgorithm

* fix(teleop): keyboard EE teleop not registering special keys and losing intervention state

Fixes #2345

Co-authored-by: jpizarrom <jpizarrom@gmail.com>

* fix: remove leftover normalization calls from reward classifier predict_reward

Fixes #2355

* fix: add thread synchronization to ReplayBuffer to prevent race condition between add() and sample()

* refactor: update SACAlgorithm to pass action_dim to _init_critics and fix encoder reference

* perf: remove redundant CPU→GPU→CPU transition move in learner

* Fix: add kwargs in reward classifier __init__()

* fix: include IS_INTERVENTION in complementary_info sent to learner for offline replay buffer

* fix: add try/finally to control_loop to ensure image writer cleanup on exit

* fix: use string key for IS_INTERVENTION in complementary_info to avoid torch.load serialization error

* fix: skip tests that require grpc if not available

* fix(tests): ensure tensor stats comparison accounts for reshaping in normalization tests

* fix(tests): skip tests that require grpc if not available

* refactor(rl): expose public API in rl/__init__ and use relative imports in sub-packages

* fix(config): update vision encoder model name to lerobot/resnet10

* fix(sac): clarify torch.compile status

* refactor(rl): update shutdown_event type hints from 'any' to 'Any' for consistency and clarity

* refactor(sac): simplify optimizer return structure

* perf(rl): use async iterators in OnlineOfflineMixer.get_iterator

* refactor(sac): decouple algorithm hyperparameters from policy config

* update losses names in tests

* fix docstring

* remove unused type alias

* fix test for flat dict structure

* refactor(policies): rename policies/sac → policies/gaussian_actor

* refactor(rl/sac): consolidate hyperparameter ownership and clean up discrete critic

* perf(observation_processor): add CUDA support for image processing

* fix(rl): correctly wire HIL-SERL gripper penalty through processor pipeline

(cherry picked from commit 9c2af818ff)

* fix(rl): add time limit processor to environment pipeline

(cherry picked from commit cd105f65cb)

* fix(rl): clarify discrete gripper action mapping in GripperVelocityToJoint for SO100

(cherry picked from commit 494f469a2b)

* fix(rl): update neutral gripper action

(cherry picked from commit 9c9064e5be)

* fix(rl): merge environment and action-processor info in transition processing

(cherry picked from commit 30e1886b64)

* fix(rl): mirror gym_manipulator in actor

(cherry picked from commit d2a046dfc5)

* fix(rl): postprocess action in actor

(cherry picked from commit c2556439e5)

* fix(rl): improve action processing for discrete and continuous actions

(cherry picked from commit f887ab3f6a)

* fix(rl): enhance intervention handling in actor and learner

(cherry picked from commit ef8bfffbd7)

* Revert "perf(observation_processor): add CUDA support for image processing"

This reverts commit 38b88c414c.

* refactor(rl): make algorithm a nested config so all SAC hyperparameters are JSON-addressable

* refactor(rl): add make_algorithm_config function for RLAlgorithmConfig instantiation

* refactor(rl): add type property to RLAlgorithmConfig for better clarity

* refactor(rl): make RLAlgorithmConfig an abstract base class for better extensibility

* refactor(tests): remove grpc import checks from test files for cleaner code

* fix(tests): gate RL tests on the `datasets` extra

* refactor: simplify docstrings for clarity and conciseness across multiple files

* fix(rl): update gripper position key and handle action absence during reset

* fix(rl): record pre-step observation so (obs, action, next.reward) align in gym_manipulator dataset

* refactor: clean up import statements

* chore: address reviewer comments

* chore: improve visual stats reshaping logic and update docstring for clarity

* refactor: enforce mandatory config_class and name attributes in RLAlgorithm

* refactor: implement NotImplementedError for abstract methods in RLAlgorithm and DataMixer

* refactor: replace build_algorithm with make_algorithm for SACAlgorithmConfig and update related tests

* refactor: add require_package calls for grpcio and gym-hil in relevant modules

* refactor(rl): move grpcio guards to runtime entry points

* feat(rl): consolidate HIL-SERL checkpoint into HF-style components

Make `RLAlgorithmConfig` and `RLAlgorithm` `HubMixin`s, add abstract
`state_dict()` / `load_state_dict()` for critic ensemble, target nets
and `log_alpha`, and persist them as a sibling `algorithm/` component
next to `pretrained_model/`. Replace the pickled `training_state.pt`
with an enriched `training_step.json` carrying `step` and
`interaction_step`, so resume restores actor + critics + target nets +
temperature + optimizers + RNG + counters from HF-standard files.

* refactor(rl): move actor weight-sync wire format from policy to algorithm

* refactor(rl): update type hints for learner and actor functions

* refactor(rl): hoist grpcio guard to module top in actor/learner

* chore(rl): manage import pattern in actor (#3564)

* chore(rl): manage import pattern in actor

* chore(rl): optional grpc imports in learner; quote grpc ServicerContext types

---------

Co-authored-by: Khalil Meftah <khalil.meftah@huggingface.co>

* update uv.lock

* chore(doc): update doc

---------

Co-authored-by: jpizarrom <jpizarrom@gmail.com>
Co-authored-by: Steven Palma <imstevenpmwork@ieee.org>
2026-05-12 15:49:54 +02:00


#!/usr/bin/env python
# Copyright 2025 The HuggingFace Inc. team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import sys
from collections.abc import Callable

import pytest

pytest.importorskip("datasets", reason="datasets is required (install lerobot[dataset])")

import torch  # noqa: E402

from lerobot.datasets.lerobot_dataset import LeRobotDataset  # noqa: E402
from lerobot.rl.buffer import BatchTransition, ReplayBuffer, random_crop_vectorized  # noqa: E402
from lerobot.utils.constants import ACTION, DONE, OBS_IMAGE, OBS_STATE, OBS_STR, REWARD  # noqa: E402
from tests.fixtures.constants import DUMMY_REPO_ID  # noqa: E402


def state_dims() -> list[str]:
    return [OBS_IMAGE, OBS_STATE]


@pytest.fixture
def replay_buffer() -> ReplayBuffer:
    return create_empty_replay_buffer()


def clone_state(state: dict) -> dict:
    return {k: v.clone() for k, v in state.items()}


def create_empty_replay_buffer(
    optimize_memory: bool = False,
    use_drq: bool = False,
    image_augmentation_function: Callable | None = None,
) -> ReplayBuffer:
    buffer_capacity = 10
    device = "cpu"
    return ReplayBuffer(
        buffer_capacity,
        device,
        state_dims(),
        optimize_memory=optimize_memory,
        use_drq=use_drq,
        image_augmentation_function=image_augmentation_function,
    )


def create_random_image() -> torch.Tensor:
    return torch.rand(3, 84, 84)


def create_dummy_transition() -> dict:
    return {
        OBS_IMAGE: create_random_image(),
        ACTION: torch.randn(4),
        "reward": torch.tensor(1.0),
        OBS_STATE: torch.randn(
            10,
        ),
        "done": torch.tensor(False),
        "truncated": torch.tensor(False),
        "complementary_info": {},
    }


def create_dataset_from_replay_buffer(tmp_path) -> tuple[LeRobotDataset, ReplayBuffer]:
    dummy_state_1 = create_dummy_state()
    dummy_action_1 = create_dummy_action()

    dummy_state_2 = create_dummy_state()
    dummy_action_2 = create_dummy_action()

    dummy_state_3 = create_dummy_state()
    dummy_action_3 = create_dummy_action()

    dummy_state_4 = create_dummy_state()
    dummy_action_4 = create_dummy_action()

    replay_buffer = create_empty_replay_buffer()
    replay_buffer.add(dummy_state_1, dummy_action_1, 1.0, dummy_state_1, False, False)
    replay_buffer.add(dummy_state_2, dummy_action_2, 1.0, dummy_state_2, False, False)
    replay_buffer.add(dummy_state_3, dummy_action_3, 1.0, dummy_state_3, True, True)
    replay_buffer.add(dummy_state_4, dummy_action_4, 1.0, dummy_state_4, True, True)

    root = tmp_path / "test"
    return (replay_buffer.to_lerobot_dataset(DUMMY_REPO_ID, root=root), replay_buffer)


def create_dummy_state() -> dict:
    return {
        OBS_IMAGE: create_random_image(),
        OBS_STATE: torch.randn(
            10,
        ),
    }


def get_tensor_memory_consumption(tensor):
    return tensor.nelement() * tensor.element_size()


def get_tensors_memory_consumption(obj, visited_addresses):
    total_size = 0

    address = id(obj)
    if address in visited_addresses:
        return 0

    visited_addresses.add(address)

    if isinstance(obj, torch.Tensor):
        return get_tensor_memory_consumption(obj)
    elif isinstance(obj, (list | tuple)):
        for item in obj:
            total_size += get_tensors_memory_consumption(item, visited_addresses)
    elif isinstance(obj, dict):
        for value in obj.values():
            total_size += get_tensors_memory_consumption(value, visited_addresses)
    elif hasattr(obj, "__dict__"):
        # It's an object, we need to get the size of the attributes
        for _, attr in vars(obj).items():
            total_size += get_tensors_memory_consumption(attr, visited_addresses)

    return total_size


def get_object_memory(obj):
    # Track visited addresses to avoid infinite loops
    # and cases when two properties point to the same object
    visited_addresses = set()

    # Get the size of the object in bytes
    total_size = sys.getsizeof(obj)

    # Get the size of the tensor attributes
    total_size += get_tensors_memory_consumption(obj, visited_addresses)

    return total_size


def create_dummy_action() -> torch.Tensor:
    return torch.randn(4)


def dict_properties() -> list:
    return ["state", "next_state"]


@pytest.fixture
def dummy_state() -> dict:
    return create_dummy_state()


@pytest.fixture
def next_dummy_state() -> dict:
    return create_dummy_state()


@pytest.fixture
def dummy_action() -> torch.Tensor:
    return torch.randn(4)


def test_empty_buffer_sample_raises_error(replay_buffer):
    assert len(replay_buffer) == 0, "Replay buffer should be empty."
    assert replay_buffer.capacity == 10, "Replay buffer capacity should be 10."
    with pytest.raises(RuntimeError, match="Cannot sample from an empty buffer"):
        replay_buffer.sample(1)


def test_zero_capacity_buffer_raises_error():
    with pytest.raises(ValueError, match="Capacity must be greater than 0."):
        ReplayBuffer(0, "cpu", [OBS_STR, "next_observation"])


def test_add_transition(replay_buffer, dummy_state, dummy_action):
    replay_buffer.add(dummy_state, dummy_action, 1.0, dummy_state, False, False)
    assert len(replay_buffer) == 1, "Replay buffer should have one transition after adding."
    assert torch.equal(replay_buffer.actions[0], dummy_action), (
        "Action should be equal to the first transition."
    )
    assert replay_buffer.rewards[0] == 1.0, "Reward should be equal to the first transition."
    assert not replay_buffer.dones[0], "Done should be False for the first transition."
    assert not replay_buffer.truncateds[0], "Truncated should be False for the first transition."

    for dim in state_dims():
        assert torch.equal(replay_buffer.states[dim][0], dummy_state[dim]), (
            "Observation should be equal to the first transition."
        )
        assert torch.equal(replay_buffer.next_states[dim][0], dummy_state[dim]), (
            "Next observation should be equal to the first transition."
        )


def test_add_over_capacity():
    replay_buffer = ReplayBuffer(2, "cpu", [OBS_STR, "next_observation"])
    dummy_state_1 = create_dummy_state()
    dummy_action_1 = create_dummy_action()

    dummy_state_2 = create_dummy_state()
    dummy_action_2 = create_dummy_action()

    dummy_state_3 = create_dummy_state()
    dummy_action_3 = create_dummy_action()

    replay_buffer.add(dummy_state_1, dummy_action_1, 1.0, dummy_state_1, False, False)
    replay_buffer.add(dummy_state_2, dummy_action_2, 1.0, dummy_state_2, False, False)
    replay_buffer.add(dummy_state_3, dummy_action_3, 1.0, dummy_state_3, True, True)

    assert len(replay_buffer) == 2, "Replay buffer should have 2 transitions after adding 3."

    for dim in state_dims():
        assert torch.equal(replay_buffer.states[dim][0], dummy_state_3[dim]), (
            "Observation should be equal to the first transition."
        )
        assert torch.equal(replay_buffer.next_states[dim][0], dummy_state_3[dim]), (
            "Next observation should be equal to the first transition."
        )

    assert torch.equal(replay_buffer.actions[0], dummy_action_3), (
        "Action should be equal to the last transition."
    )
    assert replay_buffer.rewards[0] == 1.0, "Reward should be equal to the last transition."
    assert replay_buffer.dones[0], "Done should be True for the first transition."
    assert replay_buffer.truncateds[0], "Truncated should be True for the first transition."


def test_sample_from_empty_buffer(replay_buffer):
    with pytest.raises(RuntimeError, match="Cannot sample from an empty buffer"):
        replay_buffer.sample(1)


def test_sample_with_1_transition(replay_buffer, dummy_state, next_dummy_state, dummy_action):
    replay_buffer.add(dummy_state, dummy_action, 1.0, next_dummy_state, False, False)
    got_batch_transition = replay_buffer.sample(1)

    expected_batch_transition = BatchTransition(
        state=clone_state(dummy_state),
        action=dummy_action.clone(),
        reward=1.0,
        next_state=clone_state(next_dummy_state),
        done=False,
        truncated=False,
    )

    for buffer_property in dict_properties():
        for k, v in expected_batch_transition[buffer_property].items():
            got_state = got_batch_transition[buffer_property][k]

            assert got_state.shape[0] == 1, f"{k} should have 1 transition."
            assert got_state.device.type == "cpu", f"{k} should be on cpu."

            assert torch.equal(got_state[0], v), f"{k} should be equal to the expected batch transition."

    for key, _value in expected_batch_transition.items():
        if key in dict_properties():
            continue

        got_value = got_batch_transition[key]

        v_tensor = expected_batch_transition[key]
        if not isinstance(v_tensor, torch.Tensor):
            v_tensor = torch.tensor(v_tensor)

        assert got_value.shape[0] == 1, f"{key} should have 1 transition."
        assert got_value.device.type == "cpu", f"{key} should be on cpu."
        assert torch.equal(got_value[0], v_tensor), f"{key} should be equal to the expected batch transition."


def test_sample_with_batch_bigger_than_buffer_size(
    replay_buffer, dummy_state, next_dummy_state, dummy_action
):
    replay_buffer.add(dummy_state, dummy_action, 1.0, next_dummy_state, False, False)
    got_batch_transition = replay_buffer.sample(10)

    expected_batch_transition = BatchTransition(
        state=dummy_state,
        action=dummy_action,
        reward=1.0,
        next_state=next_dummy_state,
        done=False,
        truncated=False,
    )

    for buffer_property in dict_properties():
        for k in expected_batch_transition[buffer_property]:
            got_state = got_batch_transition[buffer_property][k]

            assert got_state.shape[0] == 1, f"{k} should have 1 transition."

    for key in expected_batch_transition:
        if key in dict_properties():
            continue

        got_value = got_batch_transition[key]
        assert got_value.shape[0] == 1, f"{key} should have 1 transition."


def test_sample_batch(replay_buffer):
    dummy_state_1 = create_dummy_state()
    dummy_action_1 = create_dummy_action()

    dummy_state_2 = create_dummy_state()
    dummy_action_2 = create_dummy_action()

    dummy_state_3 = create_dummy_state()
    dummy_action_3 = create_dummy_action()

    dummy_state_4 = create_dummy_state()
    dummy_action_4 = create_dummy_action()

    replay_buffer.add(dummy_state_1, dummy_action_1, 1.0, dummy_state_1, False, False)
    replay_buffer.add(dummy_state_2, dummy_action_2, 2.0, dummy_state_2, False, False)
    replay_buffer.add(dummy_state_3, dummy_action_3, 3.0, dummy_state_3, True, True)
    replay_buffer.add(dummy_state_4, dummy_action_4, 4.0, dummy_state_4, True, True)

    dummy_states = [dummy_state_1, dummy_state_2, dummy_state_3, dummy_state_4]
    dummy_actions = [dummy_action_1, dummy_action_2, dummy_action_3, dummy_action_4]

    got_batch_transition = replay_buffer.sample(3)

    for buffer_property in dict_properties():
        for k in got_batch_transition[buffer_property]:
            got_state = got_batch_transition[buffer_property][k]
            assert got_state.shape[0] == 3, f"{k} should have 3 transitions."

            for got_state_item in got_state:
                assert any(torch.equal(got_state_item, dummy_state[k]) for dummy_state in dummy_states), (
                    f"{k} should be equal to one of the dummy states."
                )

    for got_action_item in got_batch_transition[ACTION]:
        assert any(torch.equal(got_action_item, dummy_action) for dummy_action in dummy_actions), (
            "Actions should be equal to the dummy actions."
        )

    for k in got_batch_transition:
        if k in dict_properties() or k == "complementary_info":
            continue

        got_value = got_batch_transition[k]
        assert got_value.shape[0] == 3, f"{k} should have 3 transitions."


def test_to_lerobot_dataset_with_empty_buffer(replay_buffer):
    with pytest.raises(ValueError, match="The replay buffer is empty. Cannot convert to a dataset."):
        replay_buffer.to_lerobot_dataset("dummy_repo")


def test_to_lerobot_dataset(tmp_path):
    ds, buffer = create_dataset_from_replay_buffer(tmp_path)

    assert len(ds) == len(buffer), "Dataset should have the same size as the Replay Buffer"
    assert ds.fps == 1, "FPS should be 1"
    assert ds.repo_id == "dummy/repo", "The dataset should have `dummy/repo` repo id"

    for dim in state_dims():
        assert dim in ds.features
        assert ds.features[dim]["shape"] == buffer.states[dim][0].shape

    assert ds.num_episodes == 2
    assert ds.num_frames == 4

    for j, value in enumerate(ds):
        print(torch.equal(value[OBS_IMAGE], buffer.next_states[OBS_IMAGE][j]))

    for i in range(len(ds)):
        for feature, value in ds[i].items():
            if feature == ACTION:
                assert torch.equal(value, buffer.actions[i])
            elif feature == REWARD:
                assert torch.equal(value, buffer.rewards[i])
            elif feature == DONE:
                assert torch.equal(value, buffer.dones[i])
            elif feature == OBS_IMAGE:
                # Tensor -> numpy is not precise, so we have some diff there
                # TODO: Check and fix it
                torch.testing.assert_close(value, buffer.states[OBS_IMAGE][i], rtol=0.3, atol=0.003)
            elif feature == OBS_STATE:
                assert torch.equal(value, buffer.states[OBS_STATE][i])


def test_from_lerobot_dataset(tmp_path):
    dummy_state_1 = create_dummy_state()
    dummy_action_1 = create_dummy_action()

    dummy_state_2 = create_dummy_state()
    dummy_action_2 = create_dummy_action()

    dummy_state_3 = create_dummy_state()
    dummy_action_3 = create_dummy_action()

    dummy_state_4 = create_dummy_state()
    dummy_action_4 = create_dummy_action()

    replay_buffer = create_empty_replay_buffer()
    replay_buffer.add(dummy_state_1, dummy_action_1, 1.0, dummy_state_1, False, False)
    replay_buffer.add(dummy_state_2, dummy_action_2, 1.0, dummy_state_2, False, False)
    replay_buffer.add(dummy_state_3, dummy_action_3, 1.0, dummy_state_3, True, True)
    replay_buffer.add(dummy_state_4, dummy_action_4, 1.0, dummy_state_4, True, True)

    root = tmp_path / "test"
    ds = replay_buffer.to_lerobot_dataset(DUMMY_REPO_ID, root=root)

    reconverted_buffer = ReplayBuffer.from_lerobot_dataset(
        ds, state_keys=list(state_dims()), device="cpu", capacity=replay_buffer.capacity, use_drq=False
    )

    # Check only the part of the buffer that's actually filled with data
    assert torch.equal(
        reconverted_buffer.actions[: len(replay_buffer)],
        replay_buffer.actions[: len(replay_buffer)],
    ), "Actions from converted buffer should be equal to the original replay buffer."
    assert torch.equal(
        reconverted_buffer.rewards[: len(replay_buffer)], replay_buffer.rewards[: len(replay_buffer)]
    ), "Rewards from converted buffer should be equal to the original replay buffer."
    assert torch.equal(
        reconverted_buffer.dones[: len(replay_buffer)], replay_buffer.dones[: len(replay_buffer)]
    ), "Dones from converted buffer should be equal to the original replay buffer."

    # LeRobot datasets do not support truncateds yet
    expected_truncateds = torch.zeros(len(replay_buffer)).bool()
    assert torch.equal(reconverted_buffer.truncateds[: len(replay_buffer)], expected_truncateds), (
        "Truncateds from converted buffer should all be False"
    )

    assert torch.equal(
        replay_buffer.states[OBS_STATE][: len(replay_buffer)],
        reconverted_buffer.states[OBS_STATE][: len(replay_buffer)],
    ), "State should be the same after converting to a dataset and back"

    for i in range(4):
        torch.testing.assert_close(
            replay_buffer.states[OBS_IMAGE][i],
            reconverted_buffer.states[OBS_IMAGE][i],
            rtol=0.4,
            atol=0.004,
        )

    # Frames 2 and 3 have the done flag set, so their next-state values equal the current state
    for i in range(2):
        # In the current implementation we take the next state from `states` and ignore `next_states`
        next_index = (i + 1) % 4

        torch.testing.assert_close(
            replay_buffer.states[OBS_IMAGE][next_index],
            reconverted_buffer.next_states[OBS_IMAGE][i],
            rtol=0.4,
            atol=0.004,
        )

    for i in range(2, 4):
        assert torch.equal(
            replay_buffer.states[OBS_STATE][i],
            reconverted_buffer.next_states[OBS_STATE][i],
        )


def test_buffer_sample_alignment():
    # Initialize buffer
    buffer = ReplayBuffer(capacity=100, device="cpu", state_keys=["state_value"], storage_device="cpu")

    # Fill buffer with patterned data
    for i in range(100):
        signature = float(i) / 100.0
        state = {"state_value": torch.tensor([[signature]]).float()}
        action = torch.tensor([[2.0 * signature]]).float()
        reward = 3.0 * signature

        is_end = (i + 1) % 10 == 0
        if is_end:
            next_state = {"state_value": torch.tensor([[signature]]).float()}
            done = True
        else:
            next_signature = float(i + 1) / 100.0
            next_state = {"state_value": torch.tensor([[next_signature]]).float()}
            done = False

        buffer.add(state, action, reward, next_state, done, False)

    # Sample and verify
    batch = buffer.sample(50)

    for i in range(50):
        state_sig = batch["state"]["state_value"][i].item()
        action_val = batch[ACTION][i].item()
        reward_val = batch["reward"][i].item()
        next_state_sig = batch["next_state"]["state_value"][i].item()
        is_done = batch["done"][i].item() > 0.5

        # Verify relationships
        assert abs(action_val - 2.0 * state_sig) < 1e-4, (
            f"Action {action_val} should be 2x state signature {state_sig}"
        )

        assert abs(reward_val - 3.0 * state_sig) < 1e-4, (
            f"Reward {reward_val} should be 3x state signature {state_sig}"
        )

        if is_done:
            assert abs(next_state_sig - state_sig) < 1e-4, (
                f"For done states, next_state {next_state_sig} should equal state {state_sig}"
            )
        else:
            # Either it's the next sequential state (+0.01) or same state (for episode boundaries)
            valid_next = (
                abs(next_state_sig - state_sig - 0.01) < 1e-4 or abs(next_state_sig - state_sig) < 1e-4
            )
            assert valid_next, (
                f"Next state {next_state_sig} should be either state+0.01 or same as state {state_sig}"
            )


def test_memory_optimization():
    dummy_state_1 = create_dummy_state()
    dummy_action_1 = create_dummy_action()

    dummy_state_2 = create_dummy_state()
    dummy_action_2 = create_dummy_action()

    dummy_state_3 = create_dummy_state()
    dummy_action_3 = create_dummy_action()

    dummy_state_4 = create_dummy_state()
    dummy_action_4 = create_dummy_action()

    replay_buffer = create_empty_replay_buffer()
    replay_buffer.add(dummy_state_1, dummy_action_1, 1.0, dummy_state_2, False, False)
    replay_buffer.add(dummy_state_2, dummy_action_2, 1.0, dummy_state_3, False, False)
    replay_buffer.add(dummy_state_3, dummy_action_3, 1.0, dummy_state_4, False, False)
    replay_buffer.add(dummy_state_4, dummy_action_4, 1.0, dummy_state_4, True, True)

    optimized_replay_buffer = create_empty_replay_buffer(True)
    optimized_replay_buffer.add(dummy_state_1, dummy_action_1, 1.0, dummy_state_2, False, False)
    optimized_replay_buffer.add(dummy_state_2, dummy_action_2, 1.0, dummy_state_3, False, False)
    optimized_replay_buffer.add(dummy_state_3, dummy_action_3, 1.0, dummy_state_4, False, False)
    optimized_replay_buffer.add(dummy_state_4, dummy_action_4, 1.0, None, True, True)

    assert get_object_memory(optimized_replay_buffer) < get_object_memory(replay_buffer), (
        "Optimized replay buffer should be smaller than the original replay buffer"
    )


def test_check_image_augmentations_with_drq_and_dummy_image_augmentation_function(dummy_state, dummy_action):
    def dummy_image_augmentation_function(x):
        return torch.ones_like(x) * 10

    replay_buffer = create_empty_replay_buffer(
        use_drq=True, image_augmentation_function=dummy_image_augmentation_function
    )

    replay_buffer.add(dummy_state, dummy_action, 1.0, dummy_state, False, False)

    sampled_transitions = replay_buffer.sample(1)
    assert torch.all(sampled_transitions["state"][OBS_IMAGE] == 10), "Image augmentations should be applied"
    assert torch.all(sampled_transitions["next_state"][OBS_IMAGE] == 10), (
        "Image augmentations should be applied"
    )


def test_check_image_augmentations_with_drq_and_default_image_augmentation_function(
    dummy_state, dummy_action
):
    replay_buffer = create_empty_replay_buffer(use_drq=True)

    replay_buffer.add(dummy_state, dummy_action, 1.0, dummy_state, False, False)

    # Let's check that it doesn't fail and shapes are correct
    sampled_transitions = replay_buffer.sample(1)
    assert sampled_transitions["state"][OBS_IMAGE].shape == (1, 3, 84, 84)
    assert sampled_transitions["next_state"][OBS_IMAGE].shape == (1, 3, 84, 84)


def test_random_crop_vectorized_basic():
    # Create a batch of 2 images with known patterns
    batch_size, channels, height, width = 2, 3, 10, 8
    images = torch.zeros((batch_size, channels, height, width))

    # Fill with unique values for testing
    for b in range(batch_size):
        images[b] = b + 1

    crop_size = (6, 4)  # Smaller than original
    cropped = random_crop_vectorized(images, crop_size)

    # Check output shape
    assert cropped.shape == (batch_size, channels, *crop_size)

    # Check that values are preserved (should be either 1s or 2s for respective batches)
    assert torch.all(cropped[0] == 1)
    assert torch.all(cropped[1] == 2)


def test_random_crop_vectorized_invalid_size():
    images = torch.zeros((2, 3, 10, 8))

    # Test crop size larger than image
    with pytest.raises(ValueError, match="Requested crop size .* is bigger than the image size"):
        random_crop_vectorized(images, (12, 8))

    with pytest.raises(ValueError, match="Requested crop size .* is bigger than the image size"):
        random_crop_vectorized(images, (10, 10))


def _populate_buffer_for_async_test(capacity: int = 10) -> ReplayBuffer:
    """Create a small buffer with deterministic 3×128×128 images and 11-D state."""
    buffer = ReplayBuffer(
        capacity=capacity,
        device="cpu",
        state_keys=[OBS_IMAGE, OBS_STATE],
        storage_device="cpu",
    )

    for i in range(capacity):
        img = torch.ones(3, 128, 128) * i
        state_vec = torch.arange(11).float() + i
        state = {
            OBS_IMAGE: img,
            OBS_STATE: state_vec,
        }
        buffer.add(
            state=state,
            action=torch.tensor([0.0]),
            reward=0.0,
            next_state=state,
            done=False,
            truncated=False,
        )
    return buffer


def test_async_iterator_shapes_basic():
    buffer = _populate_buffer_for_async_test()
    batch_size = 2
    iterator = buffer.get_iterator(batch_size=batch_size, async_prefetch=True, queue_size=1)
    batch = next(iterator)

    images = batch["state"][OBS_IMAGE]
    states = batch["state"][OBS_STATE]

    assert images.shape == (batch_size, 3, 128, 128)
    assert states.shape == (batch_size, 11)

    next_images = batch["next_state"][OBS_IMAGE]
    next_states = batch["next_state"][OBS_STATE]

    assert next_images.shape == (batch_size, 3, 128, 128)
    assert next_states.shape == (batch_size, 11)


def test_async_iterator_multiple_iterations():
    buffer = _populate_buffer_for_async_test()
    batch_size = 2
    iterator = buffer.get_iterator(batch_size=batch_size, async_prefetch=True, queue_size=2)

    for _ in range(5):
        batch = next(iterator)
        images = batch["state"][OBS_IMAGE]
        states = batch["state"][OBS_STATE]
        assert images.shape == (batch_size, 3, 128, 128)
        assert states.shape == (batch_size, 11)

        next_images = batch["next_state"][OBS_IMAGE]
        next_states = batch["next_state"][OBS_STATE]
        assert next_images.shape == (batch_size, 3, 128, 128)
        assert next_states.shape == (batch_size, 11)

    # Ensure iterator can be disposed without blocking
    del iterator