mirror of
https://github.com/huggingface/lerobot.git
synced 2026-06-19 01:07:18 +00:00
feat(record): add transition features to dataset and handle scalar vs array formatting in converters (#1861)
- Introduced new transition features (`next.reward`, `next.done`, `next.truncated`) in the dataset during recording. - Updated the `transition_to_dataset_frame` function to handle scalar values correctly, ensuring compatibility with expected array formats for reward, done, and truncated features.
This commit is contained in:
@@ -359,11 +359,26 @@ def transition_to_dataset_frame(
|
||||
|
||||
# Add transition metadata
|
||||
if tr.get(TransitionKey.REWARD) is not None:
|
||||
batch[REWARD] = _from_tensor(tr[TransitionKey.REWARD])
|
||||
reward_val = _from_tensor(tr[TransitionKey.REWARD])
|
||||
# Check if features expect array format, otherwise keep as scalar
|
||||
if REWARD in features and features[REWARD].get("shape") == (1,):
|
||||
batch[REWARD] = np.array([reward_val], dtype=np.float32)
|
||||
else:
|
||||
batch[REWARD] = reward_val
|
||||
|
||||
if tr.get(TransitionKey.DONE) is not None:
|
||||
batch[DONE] = _from_tensor(tr[TransitionKey.DONE])
|
||||
done_val = _from_tensor(tr[TransitionKey.DONE])
|
||||
if DONE in features and features[DONE].get("shape") == (1,):
|
||||
batch[DONE] = np.array([done_val], dtype=bool)
|
||||
else:
|
||||
batch[DONE] = done_val
|
||||
|
||||
if tr.get(TransitionKey.TRUNCATED) is not None:
|
||||
batch[TRUNCATED] = _from_tensor(tr[TransitionKey.TRUNCATED])
|
||||
truncated_val = _from_tensor(tr[TransitionKey.TRUNCATED])
|
||||
if TRUNCATED in features and features[TRUNCATED].get("shape") == (1,):
|
||||
batch[TRUNCATED] = np.array([truncated_val], dtype=bool)
|
||||
else:
|
||||
batch[TRUNCATED] = truncated_val
|
||||
|
||||
# Complementary data flags and task
|
||||
comp = tr.get(TransitionKey.COMPLEMENTARY_DATA) or {}
|
||||
|
||||
@@ -406,7 +406,15 @@ def record(cfg: RecordConfig) -> LeRobotDataset:
|
||||
|
||||
action_features = hw_to_dataset_features(robot.action_features, "action", cfg.dataset.video)
|
||||
obs_features = hw_to_dataset_features(robot.observation_features, "observation", cfg.dataset.video)
|
||||
dataset_features = {**action_features, **obs_features}
|
||||
|
||||
# Add next.* features that are generated during recording
|
||||
transition_features = {
|
||||
"next.reward": {"dtype": "float32", "shape": (1,), "names": None},
|
||||
"next.done": {"dtype": "bool", "shape": (1,), "names": None},
|
||||
"next.truncated": {"dtype": "bool", "shape": (1,), "names": None},
|
||||
}
|
||||
|
||||
dataset_features = {**action_features, **obs_features, **transition_features}
|
||||
|
||||
if cfg.resume:
|
||||
dataset = LeRobotDataset(
|
||||
|
||||
Reference in New Issue
Block a user