mirror of
https://github.com/Tavish9/any4lerobot.git
synced 2026-06-16 19:26:59 +00:00
Compare commits
1 Commits
| Author | SHA1 | Date | |
|---|---|---|---|
| 7a8642edfc |
+46
-20
@@ -8,13 +8,13 @@ In this dataset, we have made several key improvements:
|
||||
|
||||
- **OpenVLA-based LIBERO Regeneration**: Resolution enhancement, No-op action filtration, 180° RGB frame rotation, Failed trajectory filtering.
|
||||
- **State Data Preservation**: Maintained native LIBERO state information (accessible via `states.ee_state`, `states.joint_state`, etc.).
|
||||
- **Robust Conversion Pipeline**: Using DataTrove framework for High-speed dataset transformation and automatic failure recovery during conversion
|
||||
- **Robust Conversion Pipeline**: Using the shared `generic_converter` pipeline with local and Ray DataTrove executors for high-speed dataset transformation and resumable conversion.
|
||||
|
||||
Dataset Structure of `meta/info.json`:
|
||||
|
||||
```json
|
||||
{
|
||||
"codebase_version": "v3.0", // lastest lerobot format
|
||||
"codebase_version": "v3.0", // latest lerobot format
|
||||
"robot_type": "franka", // specific robot type
|
||||
"fps": 20, // control frequency
|
||||
"features": {
|
||||
@@ -41,7 +41,30 @@ Dataset Structure of `meta/info.json`:
|
||||
"has_audio": false
|
||||
}
|
||||
},
|
||||
// for more states key, see configs
|
||||
"observation.images.wrist_image": {
|
||||
"dtype": "video",
|
||||
"shape": [
|
||||
256,
|
||||
256,
|
||||
3
|
||||
],
|
||||
"names": [
|
||||
"height",
|
||||
"width",
|
||||
"rgb"
|
||||
],
|
||||
"info": {
|
||||
"video.height": 256,
|
||||
"video.width": 256,
|
||||
"video.codec": "av1",
|
||||
"video.pix_fmt": "yuv420p",
|
||||
"video.is_depth_map": false,
|
||||
"video.fps": 20,
|
||||
"video.channels": 3,
|
||||
"has_audio": false
|
||||
}
|
||||
},
|
||||
// for more state keys, see LiberoAdapter.features in libero_h5.py
|
||||
"observation.state": {
|
||||
"dtype": "float32",
|
||||
"shape": [
|
||||
@@ -52,9 +75,9 @@ Dataset Structure of `meta/info.json`:
|
||||
"x",
|
||||
"y",
|
||||
"z",
|
||||
"roll",
|
||||
"pitch",
|
||||
"yaw",
|
||||
"axis_angle1",
|
||||
"axis_angle2",
|
||||
"axis_angle3",
|
||||
"gripper",
|
||||
"gripper"
|
||||
]
|
||||
@@ -71,9 +94,9 @@ Dataset Structure of `meta/info.json`:
|
||||
"x",
|
||||
"y",
|
||||
"z",
|
||||
"roll",
|
||||
"pitch",
|
||||
"yaw",
|
||||
"axis_angle1",
|
||||
"axis_angle2",
|
||||
"axis_angle3",
|
||||
"gripper"
|
||||
]
|
||||
}
|
||||
@@ -89,31 +112,33 @@ Dataset Structure of `meta/info.json`:
|
||||
Follow instructions in [official repo](https://github.com/huggingface/lerobot?tab=readme-ov-file#installation).
|
||||
|
||||
2. Install the other dependencies:
|
||||
We use datatrove[ray] for parallel conversion, significantly speeding up data processing tasks by distributing the workload across multiple cores or nodes (if any).
|
||||
We use DataTrove for conversion. Install the Ray extra if you want distributed execution across multiple cores or nodes.
|
||||
```bash
|
||||
pip install h5py
|
||||
pip install -U datatrove
|
||||
pip install -U "datatrove[ray]" # if you want ray features
|
||||
pip install -U "datatrove[ray]" # optional, for --executor ray
|
||||
```
|
||||
|
||||
## Get started
|
||||
|
||||
> [!NOTE]
|
||||
> This script supports converting from original hdf5 to lerobot. If you want to convert from rlds to lerobot, check [openx2lerobot](../openx2lerobot/README.md).
|
||||
> This script supports converting LIBERO-style HDF5 directories to LeRobot. If you want to convert from RLDS to LeRobot, check [openx2lerobot](../openx2lerobot/README.md).
|
||||
|
||||
### Download source code:
|
||||
|
||||
```bash
|
||||
git clone https://github.com/Tavish9/any4lerobot.git
|
||||
cd any4lerobot/libero2lerobot
|
||||
```
|
||||
|
||||
### Regenerate LIBERO Trajectory:
|
||||
|
||||
1. [Install LIBERO dependency](https://github.com/Lifelong-Robot-Learning/LIBERO?tab=readme-ov-file#installtion)
|
||||
2. Replace `libero_90` with your target libero dataset.
|
||||
3. The converter feature schema expects `256x256x3` RGB observations. If your source HDF5 files are the original `128x128` LIBERO files, regenerate them first with `--resolution 256`, or update the image feature shapes in `libero_h5.py` to match your data.
|
||||
|
||||
```bash
|
||||
python libero_utils/regenerate_libero_dataset.py \
|
||||
python regenerate_libero_dataset.py \
|
||||
--resolution 256 \
|
||||
--libero_task_suite libero_90 \
|
||||
--libero_raw_data_dir /path/to/libero/datasets/libero_90 \
|
||||
@@ -122,16 +147,17 @@ python libero_utils/regenerate_libero_dataset.py \
|
||||
|
||||
### Modify in `convert.sh`:
|
||||
|
||||
1. If you have installed `datatrove[ray]`, we recommend using `ray` executor for faster conversion.
|
||||
2. Increase `workers` and `tasks-per-job` if you have sufficient computing resources.
|
||||
3. To merge many datasets into one, simply specify both paths like: `--src-paths /path/libero_10 /path/libero_90`
|
||||
4. To resume from a previous conversion, provide the appropriate log directory using `--resume-from-save` and `--resume-from-aggregate`
|
||||
5. If you want different image resolution, regenerate the trajectory, and change the [config](./libero_utils/config.py). (DO NOT use resize)
|
||||
1. `--src-paths` accepts one or more directories containing `*.hdf5` LIBERO task files. To merge many suites into one LeRobot dataset, specify all source directories, for example `--src-paths /path/libero_10 /path/libero_90`.
|
||||
2. `--output-path` is the final aggregated LeRobot dataset root. Temporary per-task datasets are written next to it under `<output-name>_temp` and removed after aggregation.
|
||||
3. If you have installed `datatrove[ray]`, use `--executor ray` for faster conversion. Increase `--workers`, `--tasks-per-job`, and `--cpus-per-task` if you have enough CPU and memory.
|
||||
4. To resume a previous conversion, pass the existing DataTrove log directory with `--resume-dir /path/to/logs/...`.
|
||||
5. Use `--debug` for a small local smoke test. It converts only the first two tasks, forces local execution, and disables Hub upload.
|
||||
6. Use `--repo-id <namespace/name>` together with `--push-to-hub` to upload the aggregated dataset. Without `--push-to-hub`, `--repo-id` only controls the local aggregate repo id.
|
||||
|
||||
```bash
|
||||
python libero_h5.py \
|
||||
--src-paths /path/to/libero/ \
|
||||
--output-path /path/to/local \
|
||||
--src-paths /path/to/libero/datasets/libero_90_no_noops \
|
||||
--output-path /path/to/local/libero_90_lerobot \
|
||||
--executor local \
|
||||
--tasks-per-job 3 \
|
||||
--workers 10
|
||||
|
||||
+128
-204
@@ -1,142 +1,127 @@
|
||||
import argparse
|
||||
import os
|
||||
import re
|
||||
import shutil
|
||||
import sys
|
||||
from collections.abc import Iterable, Sequence
|
||||
from pathlib import Path
|
||||
|
||||
import pandas as pd
|
||||
import ray
|
||||
from datatrove.executor import LocalPipelineExecutor, RayPipelineExecutor
|
||||
from datatrove.pipeline.base import PipelineStep
|
||||
from lerobot.datasets import LeRobotDataset, LeRobotDatasetMetadata
|
||||
from lerobot.datasets.aggregate import (
|
||||
aggregate_data,
|
||||
aggregate_metadata,
|
||||
aggregate_stats,
|
||||
aggregate_videos,
|
||||
validate_all_metadata,
|
||||
)
|
||||
from lerobot.datasets.io_utils import write_info, write_stats, write_tasks
|
||||
from lerobot.datasets.utils import (
|
||||
DEFAULT_CHUNK_SIZE,
|
||||
DEFAULT_DATA_FILE_SIZE_IN_MB,
|
||||
DEFAULT_VIDEO_FILE_SIZE_IN_MB,
|
||||
)
|
||||
from libero_utils.config import LIBERO_FEATURES
|
||||
from libero_utils.libero_utils import load_local_episodes
|
||||
from ray.runtime_env import RuntimeEnv
|
||||
from tqdm import tqdm
|
||||
import numpy as np
|
||||
from h5py import File
|
||||
|
||||
REPO_ROOT = Path(__file__).resolve().parents[1]
|
||||
if str(REPO_ROOT) not in sys.path:
|
||||
sys.path.insert(0, str(REPO_ROOT))
|
||||
|
||||
from generic_converter import BaseAdapter, ConversionTask, run_converter # noqa: E402
|
||||
|
||||
|
||||
def setup_logger():
|
||||
import sys
|
||||
class LiberoAdapter(BaseAdapter):
|
||||
dataset_type = "libero"
|
||||
fps = 20
|
||||
robot_type = "franka"
|
||||
features = {
|
||||
"observation.images.image": {
|
||||
"dtype": "video",
|
||||
"shape": (256, 256, 3),
|
||||
"names": ["height", "width", "rgb"],
|
||||
},
|
||||
"observation.images.wrist_image": {
|
||||
"dtype": "video",
|
||||
"shape": (256, 256, 3),
|
||||
"names": ["height", "width", "rgb"],
|
||||
},
|
||||
"observation.state": {
|
||||
"dtype": "float32",
|
||||
"shape": (8,),
|
||||
"names": {"motors": ["x", "y", "z", "axis_angle1", "axis_angle2", "axis_angle3", "gripper", "gripper"]},
|
||||
},
|
||||
"observation.states.ee_state": {
|
||||
"dtype": "float32",
|
||||
"shape": (6,),
|
||||
"names": {"motors": ["x", "y", "z", "axis_angle1", "axis_angle2", "axis_angle3"]},
|
||||
},
|
||||
"observation.states.joint_state": {
|
||||
"dtype": "float32",
|
||||
"shape": (7,),
|
||||
"names": {"motors": ["joint_0", "joint_1", "joint_2", "joint_3", "joint_4", "joint_5", "joint_6"]},
|
||||
},
|
||||
"observation.states.gripper_state": {
|
||||
"dtype": "float32",
|
||||
"shape": (2,),
|
||||
"names": {"motors": ["gripper", "gripper"]},
|
||||
},
|
||||
"action": {
|
||||
"dtype": "float32",
|
||||
"shape": (7,),
|
||||
"names": {"motors": ["x", "y", "z", "axis_angle1", "axis_angle2", "axis_angle3", "gripper"]},
|
||||
},
|
||||
}
|
||||
tags = ["libero", "franka"]
|
||||
|
||||
from datatrove.utils.logging import logger
|
||||
def __init__(self, src_paths: list[Path], output_path: Path):
|
||||
super().__init__(output_path)
|
||||
self.src_paths = src_paths
|
||||
|
||||
logger.remove()
|
||||
logger.add(sys.stdout, level="INFO", colorize=True)
|
||||
return logger
|
||||
def load_tasks(self) -> list[ConversionTask]:
|
||||
tasks = []
|
||||
for src_path in self.src_paths:
|
||||
for input_h5 in src_path.glob("*.hdf5"):
|
||||
pattern1 = re.compile(r"_SCENE\d+_(.*?)_demo\.hdf5")
|
||||
pattern2 = re.compile(r"(.*?)_demo\.hdf5")
|
||||
|
||||
match = pattern1.search(input_h5.name)
|
||||
if match is None:
|
||||
match = pattern2.search(input_h5.name)
|
||||
if match is None:
|
||||
continue
|
||||
else:
|
||||
task_instruction = match.group(1).replace("_", " ")
|
||||
|
||||
class SaveLerobotDataset(PipelineStep):
|
||||
name = "Save Temp LerobotDataset"
|
||||
type = "libero2lerobot"
|
||||
tasks.append(
|
||||
ConversionTask(
|
||||
input_path=input_h5.resolve(),
|
||||
output_path=(
|
||||
self.temp_output_path
|
||||
/ f"{src_path.name}"
|
||||
/ input_h5.stem
|
||||
).resolve(),
|
||||
local_repo_id=f"{input_h5.parent.name}/{input_h5.name}",
|
||||
metadata={"task": task_instruction},
|
||||
)
|
||||
)
|
||||
return tasks
|
||||
|
||||
def __init__(self, tasks: list[tuple[Path, Path, str]]):
|
||||
super().__init__()
|
||||
self.tasks = tasks
|
||||
|
||||
def run(self, data=None, rank: int = 0, world_size: int = 1):
|
||||
logger = setup_logger()
|
||||
|
||||
input_h5, output_path, task_instruction = self.tasks[rank]
|
||||
|
||||
if output_path.exists():
|
||||
shutil.rmtree(output_path)
|
||||
|
||||
dataset = LeRobotDataset.create(
|
||||
repo_id=f"{input_h5.parent.name}/{input_h5.name}",
|
||||
root=output_path,
|
||||
fps=20,
|
||||
robot_type="franka",
|
||||
features=LIBERO_FEATURES,
|
||||
)
|
||||
|
||||
logger.info(f"start processing for {input_h5}, saving to {output_path}")
|
||||
|
||||
raw_dataset = load_local_episodes(input_h5)
|
||||
for episode_index, episode_data in enumerate(raw_dataset):
|
||||
with self.track_time("saving episode"):
|
||||
for frame_data in episode_data:
|
||||
frame_data["task"] = task_instruction
|
||||
dataset.add_frame(frame_data)
|
||||
dataset.save_episode()
|
||||
logger.info(f"process done for {dataset.repo_id}, episode {episode_index}, len {len(episode_data)}")
|
||||
|
||||
|
||||
def create_aggr_dataset(raw_dirs: list[Path], aggregated_dir: Path):
|
||||
logger = setup_logger()
|
||||
|
||||
all_metadata = [LeRobotDatasetMetadata("", root=raw_dir) for raw_dir in raw_dirs]
|
||||
|
||||
fps, robot_type, features = validate_all_metadata(all_metadata)
|
||||
|
||||
if aggregated_dir.exists():
|
||||
shutil.rmtree(aggregated_dir)
|
||||
|
||||
aggr_meta = LeRobotDatasetMetadata.create(
|
||||
repo_id=f"{aggregated_dir.parent.name}/{aggregated_dir.name}",
|
||||
root=aggregated_dir,
|
||||
fps=fps,
|
||||
robot_type=robot_type,
|
||||
features=features,
|
||||
)
|
||||
|
||||
video_keys = [key for key in features if features[key]["dtype"] == "video"]
|
||||
unique_tasks = pd.concat([m.tasks for m in all_metadata]).index.unique()
|
||||
aggr_meta.tasks = pd.DataFrame({"task_index": range(len(unique_tasks))}, index=unique_tasks)
|
||||
|
||||
meta_idx = {"chunk": 0, "file": 0}
|
||||
data_idx = {"chunk": 0, "file": 0}
|
||||
videos_idx = {key: {"chunk": 0, "file": 0, "latest_duration": 0, "episode_duration": 0} for key in video_keys}
|
||||
|
||||
aggr_meta.episodes = {}
|
||||
|
||||
for src_meta in tqdm(all_metadata, desc="Copy data and videos"):
|
||||
videos_idx = aggregate_videos(
|
||||
src_meta, aggr_meta, videos_idx, DEFAULT_VIDEO_FILE_SIZE_IN_MB, DEFAULT_CHUNK_SIZE
|
||||
)
|
||||
data_idx = aggregate_data(src_meta, aggr_meta, data_idx, DEFAULT_DATA_FILE_SIZE_IN_MB, DEFAULT_CHUNK_SIZE)
|
||||
|
||||
meta_idx = aggregate_metadata(src_meta, aggr_meta, meta_idx, data_idx, videos_idx)
|
||||
|
||||
aggr_meta.info["total_episodes"] += src_meta.total_episodes
|
||||
aggr_meta.info["total_frames"] += src_meta.total_frames
|
||||
|
||||
logger.info("write tasks")
|
||||
write_tasks(aggr_meta.tasks, aggr_meta.root)
|
||||
|
||||
logger.info("write info")
|
||||
aggr_meta.info.update(
|
||||
{
|
||||
"total_tasks": len(aggr_meta.tasks),
|
||||
"total_episodes": sum(m.total_episodes for m in all_metadata),
|
||||
"total_frames": sum(m.total_frames for m in all_metadata),
|
||||
"splits": {"train": f"0:{sum(m.total_episodes for m in all_metadata)}"},
|
||||
}
|
||||
)
|
||||
write_info(aggr_meta.info, aggr_meta.root)
|
||||
|
||||
logger.info("write stats")
|
||||
aggr_meta.stats = aggregate_stats([m.stats for m in all_metadata])
|
||||
write_stats(aggr_meta.stats, aggr_meta.root)
|
||||
|
||||
|
||||
def delete_temp_data(temp_dirs: list[Path]):
|
||||
logger = setup_logger()
|
||||
logger.info("Delete temp data_dir")
|
||||
for temp_dir in temp_dirs:
|
||||
shutil.rmtree(temp_dir)
|
||||
def load_subset(self, task: ConversionTask) -> Iterable[Sequence[dict]]:
|
||||
input_h5 = task.input_path
|
||||
task_instruction = task.metadata.get("task")
|
||||
with File(input_h5, "r") as f:
|
||||
for demo in f["data"].values():
|
||||
demo_len = len(demo["obs/agentview_rgb"])
|
||||
# (-1: open, 1: close) -> (0: close, 1: open)
|
||||
action = np.array(demo["actions"])
|
||||
action = np.concatenate(
|
||||
[
|
||||
action[:, :6],
|
||||
(1 - np.clip(action[:, -1], 0, 1))[:, None],
|
||||
],
|
||||
axis=1,
|
||||
)
|
||||
state = np.concatenate(
|
||||
[
|
||||
np.array(demo["obs/ee_states"]),
|
||||
np.array(demo["obs/gripper_states"]),
|
||||
],
|
||||
axis=1,
|
||||
)
|
||||
episode = {
|
||||
"observation.images.image": np.array(demo["obs/agentview_rgb"]),
|
||||
"observation.images.wrist_image": np.array(demo["obs/eye_in_hand_rgb"]),
|
||||
"observation.state": np.array(state, dtype=np.float32),
|
||||
"observation.states.ee_state": np.array(demo["obs/ee_states"], dtype=np.float32),
|
||||
"observation.states.joint_state": np.array(demo["obs/joint_states"], dtype=np.float32),
|
||||
"observation.states.gripper_state": np.array(demo["obs/gripper_states"], dtype=np.float32),
|
||||
"action": np.array(action, dtype=np.float32),
|
||||
}
|
||||
yield [{**{k: v[i] for k, v in episode.items()}, "task": task_instruction} for i in range(demo_len)]
|
||||
|
||||
|
||||
def main(
|
||||
@@ -146,86 +131,25 @@ def main(
|
||||
cpus_per_task: int,
|
||||
tasks_per_job: int,
|
||||
workers: int,
|
||||
resume_dir: Path = None,
|
||||
resume_dir: Path | None = None,
|
||||
debug: bool = False,
|
||||
repo_id: str = None,
|
||||
repo_id: str | None = None,
|
||||
push_to_hub: bool = False,
|
||||
):
|
||||
tasks = []
|
||||
pattern1 = re.compile(r"_SCENE\d+_(.*?)_demo\.hdf5")
|
||||
pattern2 = re.compile(r"(.*?)_demo\.hdf5")
|
||||
for src_path in src_paths:
|
||||
for input_h5 in src_path.glob("*.hdf5"):
|
||||
match = pattern1.search(input_h5.name)
|
||||
if match is None:
|
||||
match = pattern2.search(input_h5.name)
|
||||
if match is None:
|
||||
continue
|
||||
tasks.append(
|
||||
(
|
||||
input_h5,
|
||||
(output_path / (src_path.name + "_temp") / input_h5.stem).resolve(),
|
||||
match.group(1).replace("_", " "),
|
||||
)
|
||||
)
|
||||
if len(src_paths) > 1:
|
||||
aggregate_output_path = output_path / (
|
||||
"_".join([src_path.name for src_path in src_paths]) + "_aggregated_lerobot"
|
||||
)
|
||||
else:
|
||||
aggregate_output_path = output_path / f"{src_paths[0].name}_lerobot"
|
||||
aggregate_output_path = aggregate_output_path.resolve()
|
||||
adapter = LiberoAdapter(src_paths, output_path)
|
||||
|
||||
if debug:
|
||||
executor = "local"
|
||||
workers = 1
|
||||
tasks = tasks[:2]
|
||||
push_to_hub = False
|
||||
|
||||
match executor:
|
||||
case "local":
|
||||
workers = os.cpu_count() // cpus_per_task if workers == -1 else workers
|
||||
executor = LocalPipelineExecutor
|
||||
case "ray":
|
||||
runtime_env = RuntimeEnv(
|
||||
env_vars={
|
||||
"HDF5_USE_FILE_LOCKING": "FALSE",
|
||||
"HF_DATASETS_DISABLE_PROGRESS_BARS": "TRUE",
|
||||
"SVT_LOG": "1",
|
||||
},
|
||||
)
|
||||
ray.init(runtime_env=runtime_env)
|
||||
executor = RayPipelineExecutor
|
||||
case _:
|
||||
raise ValueError(f"Executor {executor} not supported")
|
||||
|
||||
executor_config = {
|
||||
"tasks": len(tasks),
|
||||
"workers": workers,
|
||||
**({"cpus_per_task": cpus_per_task, "tasks_per_job": tasks_per_job} if executor is RayPipelineExecutor else {}),
|
||||
}
|
||||
|
||||
executor(pipeline=[SaveLerobotDataset(tasks)], **executor_config, logging_dir=resume_dir).run()
|
||||
create_aggr_dataset([task[1] for task in tasks], aggregate_output_path)
|
||||
delete_temp_data([task[1] for task in tasks])
|
||||
|
||||
for task in tasks:
|
||||
shutil.rmtree(task[1].parent, ignore_errors=True)
|
||||
|
||||
if push_to_hub:
|
||||
assert repo_id is not None
|
||||
tags = ["LeRobot", "libero", "franka"]
|
||||
tags.extend([src_path.name for src_path in src_paths])
|
||||
LeRobotDataset(
|
||||
repo_id=repo_id,
|
||||
root=aggregate_output_path,
|
||||
).push_to_hub(
|
||||
tags=tags,
|
||||
private=False,
|
||||
push_videos=True,
|
||||
license="apache-2.0",
|
||||
upload_large_folder=False,
|
||||
)
|
||||
run_converter(
|
||||
adapter=adapter,
|
||||
executor=executor,
|
||||
cpus_per_task=cpus_per_task,
|
||||
tasks_per_job=tasks_per_job,
|
||||
workers=workers,
|
||||
resume_dir=resume_dir,
|
||||
debug=debug,
|
||||
local_repo_id=repo_id,
|
||||
hub_repo_id=repo_id,
|
||||
push_to_hub=push_to_hub,
|
||||
)
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
|
||||
@@ -1,37 +0,0 @@
|
||||
LIBERO_FEATURES = {
|
||||
"observation.images.image": {
|
||||
"dtype": "video",
|
||||
"shape": (256, 256, 3),
|
||||
"names": ["height", "width", "rgb"],
|
||||
},
|
||||
"observation.images.wrist_image": {
|
||||
"dtype": "video",
|
||||
"shape": (256, 256, 3),
|
||||
"names": ["height", "width", "rgb"],
|
||||
},
|
||||
"observation.state": {
|
||||
"dtype": "float32",
|
||||
"shape": (8,),
|
||||
"names": {"motors": ["x", "y", "z", "axis_angle1", "axis_angle2", "axis_angle3", "gripper", "gripper"]},
|
||||
},
|
||||
"observation.states.ee_state": {
|
||||
"dtype": "float32",
|
||||
"shape": (6,),
|
||||
"names": {"motors": ["x", "y", "z", "axis_angle1", "axis_angle2", "axis_angle3"]},
|
||||
},
|
||||
"observation.states.joint_state": {
|
||||
"dtype": "float32",
|
||||
"shape": (7,),
|
||||
"names": {"motors": ["joint_0", "joint_1", "joint_2", "joint_3", "joint_4", "joint_5", "joint_6"]},
|
||||
},
|
||||
"observation.states.gripper_state": {
|
||||
"dtype": "float32",
|
||||
"shape": (2,),
|
||||
"names": {"motors": ["gripper", "gripper"]},
|
||||
},
|
||||
"action": {
|
||||
"dtype": "float32",
|
||||
"shape": (7,),
|
||||
"names": {"motors": ["x", "y", "z", "axis_angle1", "axis_angle2", "axis_angle3", "gripper"]},
|
||||
},
|
||||
}
|
||||
@@ -1,36 +0,0 @@
|
||||
from pathlib import Path
|
||||
|
||||
import numpy as np
|
||||
from h5py import File
|
||||
|
||||
|
||||
def load_local_episodes(input_h5: Path):
|
||||
with File(input_h5, "r") as f:
|
||||
for demo in f["data"].values():
|
||||
demo_len = len(demo["obs/agentview_rgb"])
|
||||
# (-1: open, 1: close) -> (0: close, 1: open)
|
||||
action = np.array(demo["actions"])
|
||||
action = np.concatenate(
|
||||
[
|
||||
action[:, :6],
|
||||
(1 - np.clip(action[:, -1], 0, 1))[:, None],
|
||||
],
|
||||
axis=1,
|
||||
)
|
||||
state = np.concatenate(
|
||||
[
|
||||
np.array(demo["obs/ee_states"]),
|
||||
np.array(demo["obs/gripper_states"]),
|
||||
],
|
||||
axis=1,
|
||||
)
|
||||
episode = {
|
||||
"observation.images.image": np.array(demo["obs/agentview_rgb"]),
|
||||
"observation.images.wrist_image": np.array(demo["obs/eye_in_hand_rgb"]),
|
||||
"observation.state": np.array(state, dtype=np.float32),
|
||||
"observation.states.ee_state": np.array(demo["obs/ee_states"], dtype=np.float32),
|
||||
"observation.states.joint_state": np.array(demo["obs/joint_states"], dtype=np.float32),
|
||||
"observation.states.gripper_state": np.array(demo["obs/gripper_states"], dtype=np.float32),
|
||||
"action": np.array(action, dtype=np.float32),
|
||||
}
|
||||
yield [{**{k: v[i] for k, v in episode.items()}} for i in range(demo_len)]
|
||||
Reference in New Issue
Block a user