diff --git a/src/lerobot/processor/converters.py b/src/lerobot/processor/converters.py index bc768952f..79db11a2e 100644 --- a/src/lerobot/processor/converters.py +++ b/src/lerobot/processor/converters.py @@ -359,11 +359,26 @@ def transition_to_dataset_frame( # Add transition metadata if tr.get(TransitionKey.REWARD) is not None: - batch[REWARD] = _from_tensor(tr[TransitionKey.REWARD]) + reward_val = _from_tensor(tr[TransitionKey.REWARD]) + # Check if features expect array format, otherwise keep as scalar + if REWARD in features and features[REWARD].get("shape") == (1,): + batch[REWARD] = np.array([reward_val], dtype=np.float32) + else: + batch[REWARD] = reward_val + if tr.get(TransitionKey.DONE) is not None: - batch[DONE] = _from_tensor(tr[TransitionKey.DONE]) + done_val = _from_tensor(tr[TransitionKey.DONE]) + if DONE in features and features[DONE].get("shape") == (1,): + batch[DONE] = np.array([done_val], dtype=bool) + else: + batch[DONE] = done_val + if tr.get(TransitionKey.TRUNCATED) is not None: - batch[TRUNCATED] = _from_tensor(tr[TransitionKey.TRUNCATED]) + truncated_val = _from_tensor(tr[TransitionKey.TRUNCATED]) + if TRUNCATED in features and features[TRUNCATED].get("shape") == (1,): + batch[TRUNCATED] = np.array([truncated_val], dtype=bool) + else: + batch[TRUNCATED] = truncated_val # Complementary data flags and task comp = tr.get(TransitionKey.COMPLEMENTARY_DATA) or {} diff --git a/src/lerobot/record.py b/src/lerobot/record.py index 65dec20f8..b86b1613d 100644 --- a/src/lerobot/record.py +++ b/src/lerobot/record.py @@ -406,7 +406,15 @@ def record(cfg: RecordConfig) -> LeRobotDataset: action_features = hw_to_dataset_features(robot.action_features, "action", cfg.dataset.video) obs_features = hw_to_dataset_features(robot.observation_features, "observation", cfg.dataset.video) - dataset_features = {**action_features, **obs_features} + + # Add next.* features that are generated during recording + transition_features = { + "next.reward": {"dtype": "float32", "shape": (1,), "names": None}, + "next.done": {"dtype": "bool", "shape": (1,), "names": None}, + "next.truncated": {"dtype": "bool", "shape": (1,), "names": None}, + } + + dataset_features = {**action_features, **obs_features, **transition_features} if cfg.resume: dataset = LeRobotDataset(