mirror of
https://github.com/Tavish9/any4lerobot.git
synced 2026-05-22 09:29:44 +00:00
fix kuka_dataset_transform & filter
This commit is contained in:
+6
-1
@@ -190,7 +190,12 @@ def create_lerobot_dataset(
|
|||||||
|
|
||||||
builder = tfds.builder(dataset_name, data_dir=data_dir, version=version)
|
builder = tfds.builder(dataset_name, data_dir=data_dir, version=version)
|
||||||
features = generate_features_from_raw(builder, use_videos)
|
features = generate_features_from_raw(builder, use_videos)
|
||||||
raw_dataset = builder.as_dataset(split="train").map(partial(transform_raw_dataset, dataset_name=dataset_name))
|
filter_fn = lambda e: e["success"] if dataset_name == "kuka" else True
|
||||||
|
raw_dataset = (
|
||||||
|
builder.as_dataset(split="train")
|
||||||
|
.filter(filter_fn)
|
||||||
|
.map(partial(transform_raw_dataset, dataset_name=dataset_name))
|
||||||
|
)
|
||||||
|
|
||||||
if fps is None:
|
if fps is None:
|
||||||
if dataset_name in OXE_DATASET_CONFIGS:
|
if dataset_name in OXE_DATASET_CONFIGS:
|
||||||
|
|||||||
@@ -192,19 +192,6 @@ def kuka_dataset_transform(trajectory: Dict[str, Any]) -> Dict[str, Any]:
|
|||||||
),
|
),
|
||||||
axis=-1,
|
axis=-1,
|
||||||
)
|
)
|
||||||
# decode compressed state
|
|
||||||
eef_value = tf.io.decode_compressed(
|
|
||||||
trajectory["observation"]["clip_function_input/base_pose_tool_reached"],
|
|
||||||
compression_type="ZLIB",
|
|
||||||
)
|
|
||||||
eef_value = tf.io.decode_raw(eef_value, tf.float32)
|
|
||||||
trajectory["observation"]["clip_function_input/base_pose_tool_reached"] = tf.reshape(eef_value, (-1, 7))
|
|
||||||
gripper_value = tf.io.decode_compressed(trajectory["observation"]["gripper_closed"], compression_type="ZLIB")
|
|
||||||
gripper_value = tf.io.decode_raw(gripper_value, tf.float32)
|
|
||||||
trajectory["observation"]["gripper_closed"] = tf.reshape(gripper_value, (-1, 1))
|
|
||||||
# trajectory["language_instruction"] = tf.fill(
|
|
||||||
# tf.shape(trajectory["observation"]["natural_language_instruction"]), ""
|
|
||||||
# ) # delete uninformative language instruction
|
|
||||||
trajectory["language_instruction"] = trajectory["observation"]["natural_language_instruction"]
|
trajectory["language_instruction"] = trajectory["observation"]["natural_language_instruction"]
|
||||||
return trajectory
|
return trajectory
|
||||||
|
|
||||||
|
|||||||
Reference in New Issue
Block a user