feat(record): add transition features to dataset and handle scalar vs array formatting in converters (#1861)

- Registered new transition features (`next.reward`, `next.done`, `next.truncated`) in the dataset feature schema built by `record`, each with shape `(1,)` (`float32` for the reward, `bool` for done and truncated).
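- Updated the `transition_to_dataset_frame` function to wrap scalar reward, done, and truncated values in one-element arrays (`float32` for reward, `bool` for done and truncated) when the corresponding dataset feature declares shape `(1,)`; otherwise the scalar value is kept as-is.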
This commit is contained in:
Adil Zouitine
2025-09-04 16:17:31 +02:00
committed by GitHub
parent 793ad86fc9
commit fc43246942
2 changed files with 27 additions and 4 deletions
+18 -3
View File
@@ -359,11 +359,26 @@ def transition_to_dataset_frame(
# Add transition metadata
if tr.get(TransitionKey.REWARD) is not None:
batch[REWARD] = _from_tensor(tr[TransitionKey.REWARD])
reward_val = _from_tensor(tr[TransitionKey.REWARD])
# Check if features expect array format, otherwise keep as scalar
if REWARD in features and features[REWARD].get("shape") == (1,):
batch[REWARD] = np.array([reward_val], dtype=np.float32)
else:
batch[REWARD] = reward_val
if tr.get(TransitionKey.DONE) is not None:
batch[DONE] = _from_tensor(tr[TransitionKey.DONE])
done_val = _from_tensor(tr[TransitionKey.DONE])
if DONE in features and features[DONE].get("shape") == (1,):
batch[DONE] = np.array([done_val], dtype=bool)
else:
batch[DONE] = done_val
if tr.get(TransitionKey.TRUNCATED) is not None:
batch[TRUNCATED] = _from_tensor(tr[TransitionKey.TRUNCATED])
truncated_val = _from_tensor(tr[TransitionKey.TRUNCATED])
if TRUNCATED in features and features[TRUNCATED].get("shape") == (1,):
batch[TRUNCATED] = np.array([truncated_val], dtype=bool)
else:
batch[TRUNCATED] = truncated_val
# Complementary data flags and task
comp = tr.get(TransitionKey.COMPLEMENTARY_DATA) or {}
+9 -1
View File
@@ -406,7 +406,15 @@ def record(cfg: RecordConfig) -> LeRobotDataset:
action_features = hw_to_dataset_features(robot.action_features, "action", cfg.dataset.video)
obs_features = hw_to_dataset_features(robot.observation_features, "observation", cfg.dataset.video)
dataset_features = {**action_features, **obs_features}
# Add next.* features that are generated during recording
transition_features = {
"next.reward": {"dtype": "float32", "shape": (1,), "names": None},
"next.done": {"dtype": "bool", "shape": (1,), "names": None},
"next.truncated": {"dtype": "bool", "shape": (1,), "names": None},
}
dataset_features = {**action_features, **obs_features, **transition_features}
if cfg.resume:
dataset = LeRobotDataset(