mirror of
https://github.com/huggingface/lerobot.git
synced 2026-05-18 10:10:08 +00:00
refactor(buffer): use Gymnasium terminology (terminated/truncated)
- Rename 'done' to 'terminated' for true task completion
- Use 'truncated' for time-limit termination
- Change torch.empty to torch.zeros for storage initialization
- Convert _lerobotdataset_to_transitions to a generator for memory efficiency
- Add proper docstrings to BatchTransition and ReplayBuffer
- Update concatenate_batch_transitions to use the new terminology
- Update tests to use the new field names

This aligns ReplayBuffer with Gymnasium's termination semantics, where:
- terminated: the episode ended due to task success/failure
- truncated: the episode ended due to a time limit or external factors
This commit is contained in:
@@ -68,7 +68,7 @@ def create_dummy_transition() -> dict:
|
||||
OBS_STATE: torch.randn(
|
||||
10,
|
||||
),
|
||||
"done": torch.tensor(False),
|
||||
"terminated": torch.tensor(False),
|
||||
"truncated": torch.tensor(False),
|
||||
"complementary_info": {},
|
||||
}
|
||||
@@ -191,8 +191,8 @@ def test_add_transition(replay_buffer, dummy_state, dummy_action):
|
||||
"Action should be equal to the first transition."
|
||||
)
|
||||
assert replay_buffer.rewards[0] == 1.0, "Reward should be equal to the first transition."
|
||||
assert not replay_buffer.dones[0], "Done should be False for the first transition."
|
||||
assert not replay_buffer.truncateds[0], "Truncated should be False for the first transition."
|
||||
assert not replay_buffer.terminated[0], "Terminated should be False for the first transition."
|
||||
assert not replay_buffer.truncated[0], "Truncated should be False for the first transition."
|
||||
|
||||
for dim in state_dims():
|
||||
assert torch.equal(replay_buffer.states[dim][0], dummy_state[dim]), (
|
||||
@@ -232,8 +232,8 @@ def test_add_over_capacity():
|
||||
"Action should be equal to the last transition."
|
||||
)
|
||||
assert replay_buffer.rewards[0] == 1.0, "Reward should be equal to the last transition."
|
||||
assert replay_buffer.dones[0], "Done should be True for the first transition."
|
||||
assert replay_buffer.truncateds[0], "Truncated should be True for the first transition."
|
||||
assert replay_buffer.terminated[0], "Terminated should be True for the first transition."
|
||||
assert replay_buffer.truncated[0], "Truncated should be True for the first transition."
|
||||
|
||||
|
||||
def test_sample_from_empty_buffer(replay_buffer):
|
||||
@@ -250,7 +250,7 @@ def test_sample_with_1_transition(replay_buffer, dummy_state, next_dummy_state,
|
||||
action=dummy_action.clone(),
|
||||
reward=1.0,
|
||||
next_state=clone_state(next_dummy_state),
|
||||
done=False,
|
||||
terminated=False,
|
||||
truncated=False,
|
||||
)
|
||||
|
||||
@@ -289,7 +289,7 @@ def test_sample_with_batch_bigger_than_buffer_size(
|
||||
action=dummy_action,
|
||||
reward=1.0,
|
||||
next_state=next_dummy_state,
|
||||
done=False,
|
||||
terminated=False,
|
||||
truncated=False,
|
||||
)
|
||||
|
||||
@@ -383,7 +383,8 @@ def test_to_lerobot_dataset(tmp_path):
|
||||
elif feature == REWARD:
|
||||
assert torch.equal(value, buffer.rewards[i])
|
||||
elif feature == DONE:
|
||||
assert torch.equal(value, buffer.dones[i])
|
||||
# DONE in dataset is terminated OR truncated
|
||||
assert torch.equal(value, buffer.terminated[i] | buffer.truncated[i])
|
||||
elif feature == OBS_IMAGE:
|
||||
# Tensor -> numpy is not precise, so we have some diff there
|
||||
# TODO: Check and fix it
|
||||
@@ -427,12 +428,12 @@ def test_from_lerobot_dataset(tmp_path):
|
||||
reconverted_buffer.rewards[: len(replay_buffer)], replay_buffer.rewards[: len(replay_buffer)]
|
||||
), "Rewards from converted buffer should be equal to the original replay buffer."
|
||||
assert torch.equal(
|
||||
reconverted_buffer.dones[: len(replay_buffer)], replay_buffer.dones[: len(replay_buffer)]
|
||||
), "Dones from converted buffer should be equal to the original replay buffer."
|
||||
reconverted_buffer.terminated[: len(replay_buffer)], replay_buffer.terminated[: len(replay_buffer)]
|
||||
), "Terminated flags from converted buffer should be equal to the original replay buffer."
|
||||
|
||||
# Lerobot DS haven't supported truncateds yet
|
||||
expected_truncateds = torch.zeros(len(replay_buffer)).bool()
|
||||
assert torch.equal(reconverted_buffer.truncateds[: len(replay_buffer)], expected_truncateds), (
|
||||
# LeRobot DS hasn't supported truncated yet
|
||||
expected_truncated = torch.zeros(len(replay_buffer)).bool()
|
||||
assert torch.equal(reconverted_buffer.truncated[: len(replay_buffer)], expected_truncated), (
|
||||
"Truncateds from converted buffer should be equal False"
|
||||
)
|
||||
|
||||
@@ -498,7 +499,7 @@ def test_buffer_sample_alignment():
|
||||
action_val = batch[ACTION][i].item()
|
||||
reward_val = batch["reward"][i].item()
|
||||
next_state_sig = batch["next_state"]["state_value"][i].item()
|
||||
is_done = batch["done"][i].item() > 0.5
|
||||
is_terminated = batch["terminated"][i].item() > 0.5
|
||||
|
||||
# Verify relationships
|
||||
assert abs(action_val - 2.0 * state_sig) < 1e-4, (
|
||||
@@ -509,9 +510,9 @@ def test_buffer_sample_alignment():
|
||||
f"Reward {reward_val} should be 3x state signature {state_sig}"
|
||||
)
|
||||
|
||||
if is_done:
|
||||
if is_terminated:
|
||||
assert abs(next_state_sig - state_sig) < 1e-4, (
|
||||
f"For done states, next_state {next_state_sig} should equal state {state_sig}"
|
||||
f"For terminated states, next_state {next_state_sig} should equal state {state_sig}"
|
||||
)
|
||||
else:
|
||||
# Either it's the next sequential state (+0.01) or same state (for episode boundaries)
|
||||
|
||||
Reference in New Issue
Block a user