mirror of
https://github.com/Tavish9/any4lerobot.git
synced 2026-05-11 12:09:41 +00:00
update state/action names
This commit is contained in:
+17
-2
@@ -39,7 +39,7 @@ import tensorflow as tf
|
||||
import tensorflow_datasets as tfds
|
||||
from lerobot.common.datasets.lerobot_dataset import LEROBOT_HOME, LeRobotDataset
|
||||
|
||||
from oxe_utils.configs import OXE_DATASET_CONFIGS, StateEncoding
|
||||
from oxe_utils.configs import OXE_DATASET_CONFIGS, ActionEncoding, StateEncoding
|
||||
from oxe_utils.transforms import OXE_STANDARDIZATION_TRANSFORMS
|
||||
|
||||
np.set_printoptions(precision=2)
|
||||
@@ -92,6 +92,21 @@ def generate_features_from_raw(builder: tfds.core.DatasetBuilder, use_videos: bo
|
||||
state_names = ["x", "y", "z", "roll", "pitch", "yaw", "gripper", "gripper"] # 2D gripper state
|
||||
elif state_encoding == StateEncoding.POS_QUAT:
|
||||
state_names = ["x", "y", "z", "rx", "ry", "rz", "rw", "gripper"]
|
||||
elif state_encoding == StateEncoding.JOINT:
|
||||
state_names = [f"motor_{i}" for i in range(7)] + ["gripper"]
|
||||
state_obs_keys = OXE_DATASET_CONFIGS[dataset_name]["state_obs_keys"]
|
||||
pad_count = state_obs_keys[:-1].count(None)
|
||||
state_names[-pad_count - 1 : -1] = ["pad"] * pad_count
|
||||
state_names[-1] = "pad" if state_obs_keys[-1] is None else state_names[-1]
|
||||
|
||||
action_names = [f"motor_{i}" for i in range(8)]
|
||||
if dataset_name in OXE_DATASET_CONFIGS:
|
||||
action_encoding = OXE_DATASET_CONFIGS[dataset_name]["action_encoding"]
|
||||
if action_encoding == ActionEncoding.EEF_POS:
|
||||
action_names = ["x", "y", "z", "roll", "pitch", "yaw", "gripper"]
|
||||
elif action_encoding == ActionEncoding.JOINT_POS:
|
||||
action_names = [f"motor_{i}" for i in range(7)] + ["gripper"]
|
||||
|
||||
|
||||
DEFAULT_FEATURES = {
|
||||
"observation.state": {
|
||||
@@ -102,7 +117,7 @@ def generate_features_from_raw(builder: tfds.core.DatasetBuilder, use_videos: bo
|
||||
"action": {
|
||||
"dtype": "float32",
|
||||
"shape": (7,),
|
||||
"names": {"motors": ["x", "y", "z", "roll", "pitch", "yaw", "gripper"]},
|
||||
"names": {"motors": action_names},
|
||||
},
|
||||
}
|
||||
|
||||
|
||||
Reference in New Issue
Block a user