mirror of
https://github.com/Tavish9/any4lerobot.git
synced 2026-05-22 17:39:39 +00:00
update state/action names
This commit is contained in:
+17
-2
@@ -39,7 +39,7 @@ import tensorflow as tf
|
|||||||
import tensorflow_datasets as tfds
|
import tensorflow_datasets as tfds
|
||||||
from lerobot.common.datasets.lerobot_dataset import LEROBOT_HOME, LeRobotDataset
|
from lerobot.common.datasets.lerobot_dataset import LEROBOT_HOME, LeRobotDataset
|
||||||
|
|
||||||
from oxe_utils.configs import OXE_DATASET_CONFIGS, StateEncoding
|
from oxe_utils.configs import OXE_DATASET_CONFIGS, ActionEncoding, StateEncoding
|
||||||
from oxe_utils.transforms import OXE_STANDARDIZATION_TRANSFORMS
|
from oxe_utils.transforms import OXE_STANDARDIZATION_TRANSFORMS
|
||||||
|
|
||||||
np.set_printoptions(precision=2)
|
np.set_printoptions(precision=2)
|
||||||
@@ -92,6 +92,21 @@ def generate_features_from_raw(builder: tfds.core.DatasetBuilder, use_videos: bo
|
|||||||
state_names = ["x", "y", "z", "roll", "pitch", "yaw", "gripper", "gripper"] # 2D gripper state
|
state_names = ["x", "y", "z", "roll", "pitch", "yaw", "gripper", "gripper"] # 2D gripper state
|
||||||
elif state_encoding == StateEncoding.POS_QUAT:
|
elif state_encoding == StateEncoding.POS_QUAT:
|
||||||
state_names = ["x", "y", "z", "rx", "ry", "rz", "rw", "gripper"]
|
state_names = ["x", "y", "z", "rx", "ry", "rz", "rw", "gripper"]
|
||||||
|
elif state_encoding == StateEncoding.JOINT:
|
||||||
|
state_names = [f"motor_{i}" for i in range(7)] + ["gripper"]
|
||||||
|
state_obs_keys = OXE_DATASET_CONFIGS[dataset_name]["state_obs_keys"]
|
||||||
|
pad_count = state_obs_keys[:-1].count(None)
|
||||||
|
state_names[-pad_count - 1 : -1] = ["pad"] * pad_count
|
||||||
|
state_names[-1] = "pad" if state_obs_keys[-1] is None else state_names[-1]
|
||||||
|
|
||||||
|
action_names = [f"motor_{i}" for i in range(8)]
|
||||||
|
if dataset_name in OXE_DATASET_CONFIGS:
|
||||||
|
action_encoding = OXE_DATASET_CONFIGS[dataset_name]["action_encoding"]
|
||||||
|
if action_encoding == ActionEncoding.EEF_POS:
|
||||||
|
action_names = ["x", "y", "z", "roll", "pitch", "yaw", "gripper"]
|
||||||
|
elif action_encoding == ActionEncoding.JOINT_POS:
|
||||||
|
action_names = [f"motor_{i}" for i in range(7)] + ["gripper"]
|
||||||
|
|
||||||
|
|
||||||
DEFAULT_FEATURES = {
|
DEFAULT_FEATURES = {
|
||||||
"observation.state": {
|
"observation.state": {
|
||||||
@@ -102,7 +117,7 @@ def generate_features_from_raw(builder: tfds.core.DatasetBuilder, use_videos: bo
|
|||||||
"action": {
|
"action": {
|
||||||
"dtype": "float32",
|
"dtype": "float32",
|
||||||
"shape": (7,),
|
"shape": (7,),
|
||||||
"names": {"motors": ["x", "y", "z", "roll", "pitch", "yaw", "gripper"]},
|
"names": {"motors": action_names},
|
||||||
},
|
},
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|||||||
Reference in New Issue
Block a user