refactor(buffer): use Gymnasium terminology (terminated/truncated)

- Rename 'done' to 'terminated' for true task completion
- Use 'truncated' for time-limit termination
- Change torch.empty to torch.zeros for storage initialization
- Convert _lerobotdataset_to_transitions to generator for memory efficiency
- Add proper docstrings to BatchTransition and ReplayBuffer
- Update concatenate_batch_transitions to use new terminology
- Update tests to use new field names

This aligns ReplayBuffer with Gymnasium's termination semantics where:
- terminated: Episode ended due to task success/failure
- truncated: Episode ended due to time limit or external factors
This commit is contained in:
Michel Aractingi
2025-12-17 15:52:26 +01:00
parent cb920235c4
commit 9014f9a7c5
2 changed files with 101 additions and 78 deletions
+17 -16
View File
@@ -68,7 +68,7 @@ def create_dummy_transition() -> dict:
OBS_STATE: torch.randn(
10,
),
"done": torch.tensor(False),
"terminated": torch.tensor(False),
"truncated": torch.tensor(False),
"complementary_info": {},
}
@@ -191,8 +191,8 @@ def test_add_transition(replay_buffer, dummy_state, dummy_action):
"Action should be equal to the first transition."
)
assert replay_buffer.rewards[0] == 1.0, "Reward should be equal to the first transition."
assert not replay_buffer.dones[0], "Done should be False for the first transition."
assert not replay_buffer.truncateds[0], "Truncated should be False for the first transition."
assert not replay_buffer.terminated[0], "Terminated should be False for the first transition."
assert not replay_buffer.truncated[0], "Truncated should be False for the first transition."
for dim in state_dims():
assert torch.equal(replay_buffer.states[dim][0], dummy_state[dim]), (
@@ -232,8 +232,8 @@ def test_add_over_capacity():
"Action should be equal to the last transition."
)
assert replay_buffer.rewards[0] == 1.0, "Reward should be equal to the last transition."
assert replay_buffer.dones[0], "Done should be True for the first transition."
assert replay_buffer.truncateds[0], "Truncated should be True for the first transition."
assert replay_buffer.terminated[0], "Terminated should be True for the first transition."
assert replay_buffer.truncated[0], "Truncated should be True for the first transition."
def test_sample_from_empty_buffer(replay_buffer):
@@ -250,7 +250,7 @@ def test_sample_with_1_transition(replay_buffer, dummy_state, next_dummy_state,
action=dummy_action.clone(),
reward=1.0,
next_state=clone_state(next_dummy_state),
done=False,
terminated=False,
truncated=False,
)
@@ -289,7 +289,7 @@ def test_sample_with_batch_bigger_than_buffer_size(
action=dummy_action,
reward=1.0,
next_state=next_dummy_state,
done=False,
terminated=False,
truncated=False,
)
@@ -383,7 +383,8 @@ def test_to_lerobot_dataset(tmp_path):
elif feature == REWARD:
assert torch.equal(value, buffer.rewards[i])
elif feature == DONE:
assert torch.equal(value, buffer.dones[i])
# DONE in dataset is terminated OR truncated
assert torch.equal(value, buffer.terminated[i] | buffer.truncated[i])
elif feature == OBS_IMAGE:
# Tensor -> numpy is not precise, so we have some diff there
# TODO: Check and fix it
@@ -427,12 +428,12 @@ def test_from_lerobot_dataset(tmp_path):
reconverted_buffer.rewards[: len(replay_buffer)], replay_buffer.rewards[: len(replay_buffer)]
), "Rewards from converted buffer should be equal to the original replay buffer."
assert torch.equal(
reconverted_buffer.dones[: len(replay_buffer)], replay_buffer.dones[: len(replay_buffer)]
), "Dones from converted buffer should be equal to the original replay buffer."
reconverted_buffer.terminated[: len(replay_buffer)], replay_buffer.terminated[: len(replay_buffer)]
), "Terminated flags from converted buffer should be equal to the original replay buffer."
# Lerobot DS haven't supported truncateds yet
expected_truncateds = torch.zeros(len(replay_buffer)).bool()
assert torch.equal(reconverted_buffer.truncateds[: len(replay_buffer)], expected_truncateds), (
# LeRobot DS hasn't supported truncated yet
expected_truncated = torch.zeros(len(replay_buffer)).bool()
assert torch.equal(reconverted_buffer.truncated[: len(replay_buffer)], expected_truncated), (
"Truncateds from converted buffer should be equal False"
)
@@ -498,7 +499,7 @@ def test_buffer_sample_alignment():
action_val = batch[ACTION][i].item()
reward_val = batch["reward"][i].item()
next_state_sig = batch["next_state"]["state_value"][i].item()
is_done = batch["done"][i].item() > 0.5
is_terminated = batch["terminated"][i].item() > 0.5
# Verify relationships
assert abs(action_val - 2.0 * state_sig) < 1e-4, (
@@ -509,9 +510,9 @@ def test_buffer_sample_alignment():
f"Reward {reward_val} should be 3x state signature {state_sig}"
)
if is_done:
if is_terminated:
assert abs(next_state_sig - state_sig) < 1e-4, (
f"For done states, next_state {next_state_sig} should equal state {state_sig}"
f"For terminated states, next_state {next_state_sig} should equal state {state_sig}"
)
else:
# Either it's the next sequential state (+0.01) or same state (for episode boundaries)