update state/action names

This commit is contained in:
Tavish
2025-02-20 18:34:45 +08:00
parent 0a89f48fb9
commit d2456c1506
+17 -2
View File
@@ -39,7 +39,7 @@ import tensorflow as tf
import tensorflow_datasets as tfds
from lerobot.common.datasets.lerobot_dataset import LEROBOT_HOME, LeRobotDataset
from oxe_utils.configs import OXE_DATASET_CONFIGS, StateEncoding
from oxe_utils.configs import OXE_DATASET_CONFIGS, ActionEncoding, StateEncoding
from oxe_utils.transforms import OXE_STANDARDIZATION_TRANSFORMS
np.set_printoptions(precision=2)
@@ -92,6 +92,21 @@ def generate_features_from_raw(builder: tfds.core.DatasetBuilder, use_videos: bo
state_names = ["x", "y", "z", "roll", "pitch", "yaw", "gripper", "gripper"] # 2D gripper state
elif state_encoding == StateEncoding.POS_QUAT:
state_names = ["x", "y", "z", "rx", "ry", "rz", "rw", "gripper"]
elif state_encoding == StateEncoding.JOINT:
state_names = [f"motor_{i}" for i in range(7)] + ["gripper"]
state_obs_keys = OXE_DATASET_CONFIGS[dataset_name]["state_obs_keys"]
pad_count = state_obs_keys[:-1].count(None)
state_names[-pad_count - 1 : -1] = ["pad"] * pad_count
state_names[-1] = "pad" if state_obs_keys[-1] is None else state_names[-1]
action_names = [f"motor_{i}" for i in range(8)]
if dataset_name in OXE_DATASET_CONFIGS:
action_encoding = OXE_DATASET_CONFIGS[dataset_name]["action_encoding"]
if action_encoding == ActionEncoding.EEF_POS:
action_names = ["x", "y", "z", "roll", "pitch", "yaw", "gripper"]
elif action_encoding == ActionEncoding.JOINT_POS:
action_names = [f"motor_{i}" for i in range(7)] + ["gripper"]
DEFAULT_FEATURES = {
"observation.state": {
@@ -102,7 +117,7 @@ def generate_features_from_raw(builder: tfds.core.DatasetBuilder, use_videos: bo
"action": {
"dtype": "float32",
"shape": (7,),
"names": {"motors": ["x", "y", "z", "roll", "pitch", "yaw", "gripper"]},
"names": {"motors": action_names},
},
}