Mirror of https://github.com/huggingface/lerobot.git
synced 2026-05-11 14:49:43 +00:00
Compare commits
35 Commits
| SHA1 |
|---|
| 73780046b2 |
| 093a85f946 |
| a669049da2 |
| ce348a3460 |
| cb920235c4 |
| 7f40b3bf82 |
| 2e9c9fd832 |
| f9cb5e659c |
| 0217e1e3ad |
| d79dd6d31f |
| 56b43cc888 |
| 77fe5a09ed |
| 89ae7813a7 |
| e003108cf8 |
| 5766eea377 |
| f8a4cf225b |
| 43b0f17eb9 |
| b0b755471b |
| 35c5a27352 |
| afb90e17e7 |
| 9ec9ee781a |
| 0b497fc37d |
| 797cd2725a |
| af4766b602 |
| 37f43df88a |
| 5f7b5f2817 |
| c55fbe1b3e |
| 58f70b6bd3 |
| b07160eb1b |
| 648ea8f485 |
| 581dd45eae |
| 17581a9449 |
| 87bee86640 |
| 18b32dced9 |
| 36e8feefe3 |
@@ -31,7 +31,8 @@ jobs:
    name: Upload Preview and Comment
    if: >
      github.event.workflow_run.event == 'pull_request' &&
      github.event.workflow_run.conclusion == 'success'
      github.event.workflow_run.conclusion == 'success' &&
      github.repository == 'huggingface/lerobot'
    uses: huggingface/doc-builder/.github/workflows/upload_pr_documentation.yml@main
    with:
      package_name: lerobot

@@ -42,7 +42,9 @@ jobs:
  # This job builds and deploys the official documentation.
  build_main_docs:
    name: Build Main Docs
    if: github.event_name == 'push' || github.event_name == 'workflow_dispatch'
    if: >
      (github.event_name == 'push' || github.event_name == 'workflow_dispatch') &&
      github.repository == 'huggingface/lerobot'
    permissions:
      contents: read
    uses: huggingface/doc-builder/.github/workflows/build_main_documentation.yml@main

@@ -58,7 +60,7 @@ jobs:
  # The result of this job triggers the 'Upload PR Documentation' workflow.
  build_pr_docs:
    name: Build PR Docs
    if: github.event_name == 'pull_request'
    if: github.event_name == 'pull_request' && github.repository == 'huggingface/lerobot'
    permissions:
      contents: read
      pull-requests: write
@@ -45,7 +45,6 @@ permissions:
env:
  UV_VERSION: "0.8.0"
  PYTHON_VERSION: "3.10"
  DOCKER_IMAGE_NAME: huggingface/lerobot-gpu

# Ensures that only the latest commit for a PR or branch is built, canceling older runs.
concurrency:

@@ -60,12 +59,19 @@ jobs:
    runs-on: ubuntu-latest
    env:
      MUJOCO_GL: egl
      HF_HOME: /mnt/cache/.cache/huggingface
      HF_LEROBOT_HOME: /mnt/cache/.cache/huggingface/lerobot
    steps:
      - uses: actions/checkout@v4
        with:
          persist-credentials: false
          lfs: true

      # NOTE(Steven): Mount to `/mnt` to avoid the limited storage on `/home`. Consider cleaning default SDKs or using self-hosted runners for more space.
      # (As of 2024-06-10, the runner's `/home` has only 6.2 GB free—8% of its 72 GB total.)
      - name: Setup /mnt storage
        run: sudo chown -R $USER:$USER /mnt

      # TODO(Steven): Evaluate the need of these dependencies
      - name: Install apt dependencies
        run: |
@@ -58,12 +58,19 @@ jobs:
      github.event_name == 'workflow_dispatch'
    env:
      MUJOCO_GL: egl
      HF_HOME: /mnt/cache/.cache/huggingface
      HF_LEROBOT_HOME: /mnt/cache/.cache/huggingface/lerobot
    steps:
      - uses: actions/checkout@v4
        with:
          lfs: true
          persist-credentials: false

      # NOTE(Steven): Mount to `/mnt` to avoid the limited storage on `/home`. Consider cleaning default SDKs or using self-hosted runners for more space.
      # (As of 2024-06-10, the runner's `/home` has only 6.2 GB free—8% of its 72 GB total.)
      - name: Setup /mnt storage
        run: sudo chown -R $USER:$USER /mnt

      - name: Install apt dependencies
        run: |
          sudo apt-get update && sudo apt-get install -y build-essential \
@@ -43,6 +43,7 @@ jobs:
    name: Build CPU Docker for Nightly
    runs-on:
      group: aws-general-8-plus
    if: github.repository == 'huggingface/lerobot'
    outputs:
      image_tag: ${{ env.DOCKER_IMAGE_NAME_CPU }}
    steps:

@@ -77,6 +78,7 @@ jobs:
    name: Build GPU Docker for Nightly
    runs-on:
      group: aws-general-8-plus
    if: github.repository == 'huggingface/lerobot'
    outputs:
      image_tag: ${{ env.DOCKER_IMAGE_NAME_GPU }}
    steps:
@@ -29,6 +29,7 @@ jobs:
  build-and-publish:
    name: Build and publish Python distributions
    runs-on: ubuntu-latest
    if: github.repository == 'huggingface/lerobot'
    outputs:
      version: ${{ steps.extract_info.outputs.tag_version }}
    permissions:
@@ -45,6 +45,7 @@ jobs:
  stale:
    name: Close Stale Issues and PRs
    runs-on: ubuntu-latest
    if: github.repository == 'huggingface/lerobot'
    permissions:
      actions: write
      contents: write # only for delete-branch option
@@ -43,14 +43,22 @@ jobs:
  full-tests:
    name: Full Unbound Tests
    runs-on: ubuntu-latest
    if: github.repository == 'huggingface/lerobot'
    env:
      MUJOCO_GL: egl
      HF_HOME: /mnt/cache/.cache/huggingface
      HF_LEROBOT_HOME: /mnt/cache/.cache/huggingface/lerobot
    steps:
      - uses: actions/checkout@v4
        with:
          lfs: true
          persist-credentials: false

      # NOTE(Steven): Mount to `/mnt` to avoid the limited storage on `/home`. Consider cleaning default SDKs or using self-hosted runners for more space.
      # (As of 2024-06-10, the runner's `/home` has only 6.2 GB free—8% of its 72 GB total.)
      - name: Setup /mnt storage
        run: sudo chown -R $USER:$USER /mnt

      - name: Install apt dependencies
        run: |
          sudo apt-get update && sudo apt-get install -y build-essential \
@@ -1,94 +0,0 @@
#!/usr/bin/env python

# Copyright 2024 The HuggingFace Inc. team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import threading
import time
from contextlib import ContextDecorator


class TimeBenchmark(ContextDecorator):
    """
    Measures execution time using a context manager or decorator.

    This class supports both context manager and decorator usage, and is thread-safe for multithreaded
    environments.

    Args:
        print: If True, prints the elapsed time upon exiting the context or completing the function. Defaults
            to False.

    Examples:

        Using as a context manager:

        >>> benchmark = TimeBenchmark()
        >>> with benchmark:
        ...     time.sleep(1)
        >>> print(f"Block took {benchmark.result:.4f} seconds")
        Block took approximately 1.0000 seconds

        Using with multithreading:

        ```python
        import threading

        benchmark = TimeBenchmark()


        def context_manager_example():
            with benchmark:
                time.sleep(0.01)
            print(f"Block took {benchmark.result_ms:.2f} milliseconds")


        threads = []
        for _ in range(3):
            t1 = threading.Thread(target=context_manager_example)
            threads.append(t1)

        for t in threads:
            t.start()

        for t in threads:
            t.join()
        ```

        Expected output:
        Block took approximately 10.00 milliseconds
        Block took approximately 10.00 milliseconds
        Block took approximately 10.00 milliseconds
    """

    def __init__(self, print=False):
        self.local = threading.local()
        self.print_time = print

    def __enter__(self):
        self.local.start_time = time.perf_counter()
        return self

    def __exit__(self, *exc):
        self.local.end_time = time.perf_counter()
        self.local.elapsed_time = self.local.end_time - self.local.start_time
        if self.print_time:
            print(f"Elapsed time: {self.local.elapsed_time:.4f} seconds")
        return False

    @property
    def result(self):
        return getattr(self.local, "elapsed_time", None)

    @property
    def result_ms(self):
        return self.result * 1e3
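This timing helper is removed in this comparison; the video benchmark changes below switch to `TimerManager` from `lerobot.utils.utils`. A minimal usage sketch, inferred only from how the updated benchmark calls it (`TimerManager(log=False)` and `timer.last`):

```python
from lerobot.utils.utils import TimerManager

timer = TimerManager(log=False)  # log=False is assumed to suppress automatic logging of the timing

with timer:
    run_workload()  # hypothetical placeholder for the code being timed

elapsed_ms = timer.last * 1000  # `last` is assumed to hold the most recent elapsed time in seconds
```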
@@ -1,102 +0,0 @@
#!/usr/bin/env python

# Copyright 2024 The HuggingFace Inc. team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Capture video feed from a camera as raw images."""

import argparse
import datetime as dt
import os
import time
from pathlib import Path

import cv2
import rerun as rr

# see https://rerun.io/docs/howto/visualization/limit-ram
RERUN_MEMORY_LIMIT = os.getenv("LEROBOT_RERUN_MEMORY_LIMIT", "5%")


def display_and_save_video_stream(output_dir: Path, fps: int, width: int, height: int, duration: int):
    rr.init("lerobot_capture_camera_feed")
    rr.spawn(memory_limit=RERUN_MEMORY_LIMIT)

    now = dt.datetime.now()
    capture_dir = output_dir / f"{now:%Y-%m-%d}" / f"{now:%H-%M-%S}"
    if not capture_dir.exists():
        capture_dir.mkdir(parents=True, exist_ok=True)

    # Opens the default webcam
    cap = cv2.VideoCapture(0)
    if not cap.isOpened():
        print("Error: Could not open video stream.")
        return

    cap.set(cv2.CAP_PROP_FPS, fps)
    cap.set(cv2.CAP_PROP_FRAME_WIDTH, width)
    cap.set(cv2.CAP_PROP_FRAME_HEIGHT, height)

    frame_index = 0
    start_time = time.time()
    while time.time() - start_time < duration:
        ret, frame = cap.read()

        if not ret:
            print("Error: Could not read frame.")
            break
        rr.log("video/stream", rr.Image(frame), static=True)
        cv2.imwrite(str(capture_dir / f"frame_{frame_index:06d}.png"), frame)
        frame_index += 1

    # Release the capture
    cap.release()

    # TODO(Steven): Add a graceful shutdown via a close() method for the Viewer context, though not currently supported in the Rerun API.


if __name__ == "__main__":
    parser = argparse.ArgumentParser()

    parser.add_argument(
        "--output-dir",
        type=Path,
        default=Path("outputs/cam_capture/"),
        help="Directory where the capture images are written. A subfolder named with the current date & time will be created inside it for each capture.",
    )
    parser.add_argument(
        "--fps",
        type=int,
        default=30,
        help="Frames Per Second of the capture.",
    )
    parser.add_argument(
        "--width",
        type=int,
        default=1280,
        help="Width of the captured images.",
    )
    parser.add_argument(
        "--height",
        type=int,
        default=720,
        help="Height of the captured images.",
    )
    parser.add_argument(
        "--duration",
        type=int,
        default=20,
        help="Duration in seconds for which the video stream should be captured.",
    )
    args = parser.parse_args()
    display_and_save_video_stream(**vars(args))
@@ -21,11 +21,13 @@ See the provided README.md or run `python benchmark/video/run_video_benchmark.py

import argparse
import datetime as dt
import itertools
import random
import shutil
from collections import OrderedDict
from concurrent.futures import ThreadPoolExecutor, as_completed
from pathlib import Path
from threading import Lock

import einops
import numpy as np

@@ -35,13 +37,13 @@ import torch
from skimage.metrics import mean_squared_error, peak_signal_noise_ratio, structural_similarity
from tqdm import tqdm

from benchmarks.video.benchmark import TimeBenchmark
from lerobot.datasets.lerobot_dataset import LeRobotDataset
from lerobot.datasets.video_utils import (
decode_video_frames_torchvision,
decode_video_frames,
encode_video_frames,
)
from lerobot.utils.constants import OBS_IMAGE
from lerobot.utils.utils import TimerManager

BASE_ENCODING = OrderedDict(
[

@@ -86,7 +88,7 @@ def load_original_frames(imgs_dir: Path, timestamps: list[float], fps: int) -> t
frames = []
for ts in timestamps:
idx = int(ts * fps)
frame = PIL.Image.open(imgs_dir / f"frame_{idx:06d}.png")
frame = PIL.Image.open(imgs_dir / f"frame-{idx:06d}.png")
frame = torch.from_numpy(np.array(frame))
frame = frame.type(torch.float32) / 255
frame = einops.rearrange(frame, "h w c -> c h w")

@@ -97,21 +99,21 @@ def load_original_frames(imgs_dir: Path, timestamps: list[float], fps: int) -> t
def save_decoded_frames(
imgs_dir: Path, save_dir: Path, frames: torch.Tensor, timestamps: list[float], fps: int
) -> None:
if save_dir.exists() and len(list(save_dir.glob("frame_*.png"))) == len(timestamps):
if save_dir.exists() and len(list(save_dir.glob("frame-*.png"))) == len(timestamps):
return

save_dir.mkdir(parents=True, exist_ok=True)
for i, ts in enumerate(timestamps):
idx = int(ts * fps)
frame_hwc = (frames[i].permute((1, 2, 0)) * 255).type(torch.uint8).cpu().numpy()
PIL.Image.fromarray(frame_hwc).save(save_dir / f"frame_{idx:06d}_decoded.png")
shutil.copyfile(imgs_dir / f"frame_{idx:06d}.png", save_dir / f"frame_{idx:06d}_original.png")
PIL.Image.fromarray(frame_hwc).save(save_dir / f"frame-{idx:06d}_decoded.png")
shutil.copyfile(imgs_dir / f"frame-{idx:06d}.png", save_dir / f"frame-{idx:06d}_original.png")


def save_first_episode(imgs_dir: Path, dataset: LeRobotDataset) -> None:
episode_index = 0
ep_num_images = dataset.meta.episodes["length"][episode_index]
if imgs_dir.exists() and len(list(imgs_dir.glob("frame_*.png"))) == ep_num_images:
if imgs_dir.exists() and len(list(imgs_dir.glob("frame-*.png"))) == ep_num_images:
return

imgs_dir.mkdir(parents=True, exist_ok=True)

@@ -125,7 +127,7 @@ def save_first_episode(imgs_dir: Path, dataset: LeRobotDataset) -> None:
tqdm(imgs_dataset, desc=f"saving {dataset.repo_id} first episode images", leave=False)
):
img = item[img_keys[0]]
img.save(str(imgs_dir / f"frame_{i:06d}.png"), quality=100)
img.save(str(imgs_dir / f"frame-{i:06d}.png"), quality=100)

if i >= ep_num_images - 1:
break

@@ -149,18 +151,6 @@ def sample_timestamps(timestamps_mode: str, ep_num_images: int, fps: int) -> lis
return [idx / fps for idx in frame_indexes]


def decode_video_frames(
video_path: str,
timestamps: list[float],
tolerance_s: float,
backend: str,
) -> torch.Tensor:
if backend in ["pyav", "video_reader"]:
return decode_video_frames_torchvision(video_path, timestamps, tolerance_s, backend)
else:
raise NotImplementedError(backend)


def benchmark_decoding(
imgs_dir: Path,
video_path: Path,

@@ -172,8 +162,8 @@ def benchmark_decoding(
num_workers: int = 4,
save_frames: bool = False,
) -> dict:
def process_sample(sample: int):
time_benchmark = TimeBenchmark()
def process_sample(sample: int, lock: Lock):
time_benchmark = TimerManager(log=False)
timestamps = sample_timestamps(timestamps_mode, ep_num_images, fps)
num_frames = len(timestamps)
result = {

@@ -182,13 +172,13 @@ def benchmark_decoding(
"mse_values": [],
}

with time_benchmark:
with time_benchmark, lock:
frames = decode_video_frames(video_path, timestamps=timestamps, tolerance_s=5e-1, backend=backend)
result["load_time_video_ms"] = time_benchmark.result_ms / num_frames
result["load_time_video_ms"] = (time_benchmark.last * 1000) / num_frames

with time_benchmark:
original_frames = load_original_frames(imgs_dir, timestamps, fps)
result["load_time_images_ms"] = time_benchmark.result_ms / num_frames
result["load_time_images_ms"] = (time_benchmark.last * 1000) / num_frames

frames_np, original_frames_np = frames.numpy(), original_frames.numpy()
for i in range(num_frames):

@@ -215,8 +205,10 @@ def benchmark_decoding(
# A sample is a single set of decoded frames specified by timestamps_mode (e.g. a single frame, 2 frames, etc.).
# For each sample, we record metrics (loading time and quality metrics) which are then averaged over all samples.
# As these samples are independent, we run them in parallel threads to speed up the benchmark.
# Use a single shared lock for all worker threads
shared_lock = Lock()
with ThreadPoolExecutor(max_workers=num_workers) as executor:
futures = [executor.submit(process_sample, i) for i in range(num_samples)]
futures = [executor.submit(process_sample, i, shared_lock) for i in range(num_samples)]
for future in tqdm(as_completed(futures), total=num_samples, desc="samples", leave=False):
result = future.result()
load_times_video_ms.append(result["load_time_video_ms"])

@@ -358,24 +350,27 @@ def main(
imgs_dir = output_dir / "images" / dataset.repo_id.replace("/", "_")
# We only use the first episode
save_first_episode(imgs_dir, dataset)
for key, values in tqdm(encoding_benchmarks.items(), desc="encodings (g, crf)", leave=False):
for value in tqdm(values, desc=f"encodings ({key})", leave=False):
encoding_cfg = BASE_ENCODING.copy()
encoding_cfg["vcodec"] = video_codec
encoding_cfg["pix_fmt"] = pixel_format
for duet in [
dict(zip(encoding_benchmarks.keys(), unique_combination, strict=False))
for unique_combination in itertools.product(*encoding_benchmarks.values())
]:
encoding_cfg = BASE_ENCODING.copy()
encoding_cfg["vcodec"] = video_codec
encoding_cfg["pix_fmt"] = pixel_format
for key, value in duet.items():
encoding_cfg[key] = value
args_path = Path("_".join(str(value) for value in encoding_cfg.values()))
video_path = output_dir / "videos" / args_path / f"{repo_id.replace('/', '_')}.mp4"
benchmark_table += benchmark_encoding_decoding(
dataset,
video_path,
imgs_dir,
encoding_cfg,
decoding_benchmarks,
num_samples,
num_workers,
save_frames,
)
args_path = Path("_".join(str(value) for value in encoding_cfg.values()))
video_path = output_dir / "videos" / args_path / f"{repo_id.replace('/', '_')}.mp4"
benchmark_table += benchmark_encoding_decoding(
dataset,
video_path,
imgs_dir,
encoding_cfg,
decoding_benchmarks,
num_samples,
num_workers,
save_frames,
)

# Save intermediate results
benchmark_df = pd.DataFrame(benchmark_table, columns=headers)

@@ -409,9 +404,9 @@ if __name__ == "__main__":
nargs="*",
default=[
"lerobot/pusht_image",
"aliberts/aloha_mobile_shrimp_image",
"aliberts/paris_street",
"aliberts/kitchen",
"lerobot/aloha_mobile_shrimp_image",
"lerobot/paris_street",
"lerobot/kitchen",
],
help="Datasets repo-ids to test against. First episodes only are used. Must be images.",
)

@@ -419,7 +414,7 @@ if __name__ == "__main__":
"--vcodec",
type=str,
nargs="*",
default=["libx264", "hevc", "libsvtav1"],
default=["h264", "hevc", "libsvtav1"],
help="Video codecs to be tested",
)
parser.add_argument(

@@ -468,7 +463,7 @@ if __name__ == "__main__":
"--backends",
type=str,
nargs="*",
default=["pyav", "video_reader"],
default=["torchcodec", "pyav"],
help="Torchvision decoding backend to be tested.",
)
parser.add_argument(

@@ -9,6 +9,8 @@
    title: Imitation Learning for Robots
  - local: cameras
    title: Cameras
  - local: bring_your_own_policies
    title: Bring Your Own Policies
  - local: integrate_hardware
    title: Bring Your Own Hardware
  - local: hilserl

@@ -37,6 +39,8 @@
    title: π₀.₅ (Pi05)
  - local: groot
    title: NVIDIA GR00T N1.5
  - local: xvla
    title: X-VLA
  title: "Policies"
- sections:
  - local: async

@@ -47,8 +51,8 @@
- sections:
  - local: envhub
    title: Environments from the Hub
  - local: il_sim
    title: Imitation Learning in Sim
  - local: envhub_leisaac
    title: Control & Train Robots in Sim (LeIsaac)
  - local: libero
    title: Using Libero
  - local: metaworld

@@ -79,11 +83,19 @@
    title: Hope Jr
  - local: reachy2
    title: Reachy 2
  - local: unitree_g1
    title: Unitree G1
  - local: earthrover_mini_plus
    title: Earth Rover Mini
  title: "Robots"
- sections:
  - local: phone_teleop
    title: Phone
  title: "Teleoperators"
- sections:
  - local: torch_accelerators
    title: PyTorch accelerators
  title: "Supported Hardware"
- sections:
  - local: notebooks
    title: Notebooks
@@ -196,7 +196,7 @@ client_cfg = RobotClientConfig(
    server_address="localhost:8080",
    policy_device="mps",
    policy_type="smolvla",
    pretrained_name_or_path="fracapuano/smolvla_async",
    pretrained_name_or_path="<user>/smolvla_async",
    chunk_size_threshold=0.5,
    actions_per_chunk=50,  # make sure this is less than the max actions of the policy
)

@@ -278,7 +278,7 @@ We found the default values of `actions_per_chunk` and `chunk_size_threshold` to
2. **Adjust your `fps` based on inference latency.** While the server generates a new action chunk, the client is not idle and is stepping through its current action queue. If the two processes happen at fundamentally different speeds, the client might end up with an empty queue. As such, you should reduce your fps if you consistently run out of actions in queue.
3. **Adjust `chunk_size_threshold`**.
   - Values closer to `0.0` result in almost sequential behavior. Values closer to `1.0` → send observation every step (more bandwidth, relies on good world-model).
   - We found values around 0.5-0.6 to work well. If you want to tweak this, spin up a `RobotClient` setting the `--debug-visualize-queue-size` to `True`. This will plot the action queue size evolution at runtime, and you can use it to find the value of `chunk_size_threshold` that works best for your setup.
   - We found values around 0.5-0.6 to work well. If you want to tweak this, spin up a `RobotClient` setting the `--debug_visualize_queue_size` to `True`. This will plot the action queue size evolution at runtime, and you can use it to find the value of `chunk_size_threshold` that works best for your setup.

<p align="center">
  <img

@@ -289,7 +289,7 @@ We found the default values of `actions_per_chunk` and `chunk_size_threshold` to
<p align="center">
  <i>
    The action queue size is plotted at runtime when the
    `--debug-visualize-queue-size` flag is passed, for various levels of
    `--debug_visualize_queue_size` flag is passed, for various levels of
    `chunk_size_threshold` (`g` in the SmolVLA paper).
  </i>
</p>
@@ -0,0 +1,175 @@
# Bring Your Own Policies

This tutorial explains how to integrate your own custom policy implementations into the LeRobot ecosystem, allowing you to leverage all LeRobot tools for training, evaluation, and deployment while using your own algorithms.

## Step 1: Create a Policy Package

Your custom policy should be organized as an installable Python package following LeRobot's plugin conventions.

### Package Structure

Create a package with the prefix `lerobot_policy_` (IMPORTANT!) followed by your policy name:

```bash
lerobot_policy_my_custom_policy/
├── pyproject.toml
└── src/
    └── lerobot_policy_my_custom_policy/
        ├── __init__.py
        ├── configuration_my_custom_policy.py
        ├── modeling_my_custom_policy.py
        └── processor_my_custom_policy.py
```

### Package Configuration

Set up your `pyproject.toml`:

```toml
[project]
name = "lerobot_policy_my_custom_policy"
version = "0.1.0"
dependencies = [
    # your policy-specific dependencies
]
requires-python = ">= 3.11"

[build-system]
build-backend = "your-build-backend"  # e.g. "setuptools.build_meta" or "hatchling.build"
requires = ["your-build-system"]      # e.g. ["setuptools"] or ["hatchling"]
```

## Step 2: Define the Policy Configuration

Create a configuration class that inherits from `PreTrainedConfig` and registers your policy type:

```python
# configuration_my_custom_policy.py
from dataclasses import dataclass, field

from lerobot.configs.policies import PreTrainedConfig
from lerobot.configs.types import NormalizationMode


@PreTrainedConfig.register_subclass("my_custom_policy")
@dataclass
class MyCustomPolicyConfig(PreTrainedConfig):
    """Configuration class for MyCustomPolicy.

    Args:
        n_obs_steps: Number of observation steps to use as input
        horizon: Action prediction horizon
        n_action_steps: Number of action steps to execute
        hidden_dim: Hidden dimension for the policy network
        # Add your policy-specific parameters here
    """

    # ...PreTrainedConfig fields...
    pass

    def __post_init__(self):
        super().__post_init__()
        # Add any validation logic here

    def validate_features(self) -> None:
        """Validate input/output feature compatibility."""
        # Implement validation logic for your policy's requirements
        pass
```
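The fields themselves are up to you. As a purely illustrative sketch (these names and defaults are not required by `PreTrainedConfig`; they only mirror the docstring above), the body of the dataclass could start like this:

```python
    # Illustrative hyperparameters only; replace them with what your policy actually needs.
    n_obs_steps: int = 2      # number of past observations fed to the policy
    horizon: int = 16         # length of the predicted action sequence
    n_action_steps: int = 8   # actions executed before re-planning
    hidden_dim: int = 256     # width of the policy network
```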

## Step 3: Implement the Policy Class

Create your policy implementation by inheriting from LeRobot's base `PreTrainedPolicy` class:

```python
# modeling_my_custom_policy.py
import torch
import torch.nn as nn
from typing import Dict, Any

from lerobot.policies.pretrained import PreTrainedPolicy
from .configuration_my_custom_policy import MyCustomPolicyConfig


class MyCustomPolicy(PreTrainedPolicy):
    config_class = MyCustomPolicyConfig
    name = "my_custom_policy"

    def __init__(self, config: MyCustomPolicyConfig, dataset_stats: Dict[str, Any] = None):
        super().__init__(config, dataset_stats)
        ...
```
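Beyond `__init__`, the base class expects the standard policy interface to be filled in. The stubs below are only a sketch of the usual methods; the exact names and signatures (`get_optim_params`, `reset`, `forward`, `select_action`) are assumptions that should be checked against the `PreTrainedPolicy` version you inherit from:

```python
    # Sketch of the usual PreTrainedPolicy interface (verify against your LeRobot version).
    def get_optim_params(self):
        # Parameters (or parameter groups) the optimizer should update.
        return self.parameters()

    def reset(self):
        # Clear any inference-time state (e.g. queued actions) between episodes.
        pass

    def forward(self, batch: Dict[str, Any]):
        # Compute and return the training loss (plus any metrics) for a batch.
        raise NotImplementedError

    def select_action(self, batch: Dict[str, Any]) -> torch.Tensor:
        # Return the next action to execute given the current observations.
        raise NotImplementedError
```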

## Step 4: Add Data Processors

Create processor functions:

```python
# processor_my_custom_policy.py
from typing import Dict, Any

import torch

# NOTE: the pipeline types used in the signature below come from LeRobot; the exact
# import path may differ between versions (e.g. `from lerobot.processor import ...`).
from lerobot.processor import PolicyAction, PolicyProcessorPipeline


def make_my_custom_policy_pre_post_processors(
    config,
) -> tuple[
    PolicyProcessorPipeline[dict[str, Any], dict[str, Any]],
    PolicyProcessorPipeline[PolicyAction, PolicyAction],
]:
    """Create preprocessing and postprocessing functions for your policy."""
    pass  # Define your preprocessing and postprocessing logic here
```

## Step 5: Package Initialization

Expose your classes in the package's `__init__.py`:

```python
# __init__.py
"""Custom policy package for LeRobot."""

try:
    import lerobot  # noqa: F401
except ImportError:
    raise ImportError(
        "lerobot is not installed. Please install lerobot to use this policy package."
    )

from .configuration_my_custom_policy import MyCustomPolicyConfig
from .modeling_my_custom_policy import MyCustomPolicy
from .processor_my_custom_policy import make_my_custom_policy_pre_post_processors

__all__ = [
    "MyCustomPolicyConfig",
    "MyCustomPolicy",
    "make_my_custom_policy_pre_post_processors",
]
```

## Step 6: Installation and Usage

### Install Your Policy Package

```bash
cd lerobot_policy_my_custom_policy
pip install -e .

# Or install from PyPI if published
pip install lerobot_policy_my_custom_policy
```

### Use Your Policy

Once installed, your policy automatically integrates with LeRobot's training and evaluation tools:

```bash
lerobot-train \
    --policy.type my_custom_policy \
    --env.type pusht \
    --steps 200000
```
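Evaluation works the same way once a checkpoint exists. For example, a simulated rollout of a trained checkpoint might look like the command below (a sketch that assumes the `lerobot-eval` entry point and an installed `pusht` environment; adjust the checkpoint path to your own run):

```bash
lerobot-eval \
    --policy.path outputs/train/my_custom_policy_run/checkpoints/last/pretrained_model \
    --env.type pusht \
    --eval.n_episodes 10
```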
|
||||
|
||||
## Examples and Community Contributions
|
||||
|
||||
Check out these example policy implementations:
|
||||
|
||||
- [DiTFlow Policy](https://github.com/danielsanjosepro/lerobot_policy_ditflow) - Diffusion Transformer policy with flow-matching objective. Try it out in this example: [DiTFlow Example](https://github.com/danielsanjosepro/test_lerobot_policy_ditflow)
|
||||
|
||||
Share your policy implementations with the community! 🤗
|
||||
@@ -0,0 +1,206 @@
|
||||
# EarthRover Mini Plus
|
||||
|
||||
The EarthRover Mini Plus is a fully open source mobile robot that connects through the cloud using the Frodobots SDK. This lets you control the robot and record datasets for training AI models.
|
||||
|
||||
## What You Need
|
||||
|
||||
### Hardware
|
||||
|
||||
- EarthRover Mini robot
|
||||
- Computer with Python 3.10 or newer
|
||||
- Internet connection
|
||||
|
||||
### Setting Up the Frodobots SDK
|
||||
|
||||
The robot needs the [Frodobots SDK](https://github.com/Frodobots/earth-rovers-sdk) running on your computer. Here's how:
|
||||
|
||||
1. Download and install the SDK:
|
||||
|
||||
```bash
|
||||
git clone https://github.com/Frodobots/earth-rovers-sdk.git
|
||||
cd earth-rovers-sdk
|
||||
pip install -r requirements.txt
|
||||
```
|
||||
|
||||
2. Start the SDK:
|
||||
|
||||
```bash
|
||||
hypercorn main:app --reload
|
||||
```
|
||||
|
||||
3. Open your web browser and go to `http://localhost:8000`, then click "Join"
|
||||
|
||||
The SDK gives you:
|
||||
|
||||
- Live video from front and rear cameras
|
||||
|
||||
> [!IMPORTANT]
|
||||
> The SDK must be running before you can use the robot.
|
||||
|
||||
## Install LeRobot
|
||||
|
||||
Follow our [Installation Guide](./installation) to install LeRobot.
|
||||
|
||||
In addition to the base installation, install the EarthRover Mini dependencies:
|
||||
|
||||
```bash
|
||||
pip install -e .
|
||||
```
|
||||
|
||||
## How It Works
|
||||
|
||||
The robot uses the internet to communicate:
|
||||
|
||||
- **Movement commands**: Sent through the SDK
|
||||
- **Camera video**: Received from the SDK
|
||||
- **Robot info**: Battery, location, speed from the SDK
|
||||
|
||||
You don't need to plug anything in - it all works through the SDK.
|
||||
|
||||
## Calibration
|
||||
|
||||
No calibration needed! The robot is ready to use as soon as the SDK is running.
|
||||
|
||||
## Controlling the Robot
|
||||
|
||||
You control the robot using your keyboard - just like playing a video game with WASD keys.
|
||||
|
||||
### Keyboard Controls
|
||||
|
||||
| Key | Action |
|
||||
| --- | -------------------------------- |
|
||||
| W | Move forward |
|
||||
| S | Move backward |
|
||||
| A | Turn left (with forward motion) |
|
||||
| D | Turn right (with forward motion) |
|
||||
| Q | Rotate left in place |
|
||||
| E | Rotate right in place |
|
||||
| X | Stop all movement |
|
||||
| +/= | Increase speed |
|
||||
| - | Decrease speed |
|
||||
| ESC | Disconnect |
|
||||
|
||||
### Speed Settings
|
||||
|
||||
You can adjust how fast the robot moves:
|
||||
|
||||
- **Forward/backward speed**: Default is full speed (1.0)
|
||||
- **Turning speed**: Default is full speed (1.0)
|
||||
- **Speed changes**: Use +/- keys to adjust by 0.1 each time
|
||||
|
||||
### Try It Out
|
||||
|
||||
Test driving the robot before recording data:
|
||||
|
||||
```python
|
||||
from lerobot.robots.earthrover_mini_plus import EarthRoverMiniPlus, EarthRoverMiniPlusConfig
|
||||
from lerobot.teleoperators.keyboard import KeyboardRoverTeleop, KeyboardRoverTeleopConfig
|
||||
|
||||
# Initialize robot
|
||||
robot_config = EarthRoverMiniPlusConfig()
|
||||
robot = EarthRoverMiniPlus(robot_config)
|
||||
|
||||
# Initialize teleoperator
|
||||
teleop_config = KeyboardRoverTeleopConfig(
|
||||
linear_speed=1.0,
|
||||
angular_speed=1.0,
|
||||
speed_increment=0.1
|
||||
)
|
||||
teleop = KeyboardRoverTeleop(teleop_config)
|
||||
|
||||
# Connect
|
||||
robot.connect()
|
||||
teleop.connect()
|
||||
|
||||
# Teleoperate (use keyboard controls)
|
||||
try:
|
||||
while True:
|
||||
action = teleop.get_action()
|
||||
robot.send_action(action)
|
||||
except KeyboardInterrupt:
|
||||
pass
|
||||
finally:
|
||||
robot.disconnect()
|
||||
teleop.disconnect()
|
||||
```
|
||||
|
||||
> [!TIP]
|
||||
> If you're using a Mac, you might need to give Terminal permission to access your keyboard for teleoperation. Go to System Preferences > Security & Privacy > Input Monitoring and check the box for Terminal.
|
||||
|
||||
## Recording Data
|
||||
|
||||
Once you can drive the robot well, you can start recording data to train AI models. The system records:
|
||||
|
||||
- **What you do**: How you move the robot (forward, backward, turning)
|
||||
- **What the robot sees**:
|
||||
- Videos from both cameras
|
||||
- Robot speed and direction
|
||||
- Battery level and location
|
||||
- GPS position and signal
|
||||
- Other sensor data
|
||||
- **When it happened**: Timestamps for everything
|
||||
|
||||
### Setting Up Hugging Face
|
||||
|
||||
We use Hugging Face to store your data online. First, log in with your token from [Hugging Face settings](https://huggingface.co/settings/tokens):
|
||||
|
||||
```bash
|
||||
huggingface-cli login --token ${HUGGINGFACE_TOKEN} --add-to-git-credential
|
||||
```
|
||||
|
||||
Store your Hugging Face username:
|
||||
|
||||
```bash
|
||||
HF_USER=$(huggingface-cli whoami | head -n 1)
|
||||
echo $HF_USER
|
||||
```
|
||||
|
||||
### Start Recording
|
||||
|
||||
Use the standard recording command:
|
||||
|
||||
```bash
|
||||
python src/lerobot/scripts/lerobot_record.py \
|
||||
--robot.type=earthrover_mini_plus \
|
||||
--teleop.type=keyboard_rover \
|
||||
--dataset.repo_id=your_username/dataset_name \
|
||||
--dataset.num_episodes=2 \
|
||||
--dataset.fps=10 \
|
||||
--dataset.single_task="Navigate around obstacles" \
|
||||
--display_data=true
|
||||
```
|
||||
|
||||
Replace `your_username/dataset_name` with your Hugging Face username and a name for your dataset.
|
||||
|
||||
### What Gets Saved
|
||||
|
||||
Your dataset includes:
|
||||
|
||||
**Your Actions (2 things)**:
|
||||
|
||||
- How much you moved forward/backward
|
||||
- How much you turned left/right
|
||||
|
||||
**Robot Observations (12 things)**:
|
||||
|
||||
- Front camera video
|
||||
- Rear camera video
|
||||
- Current speed
|
||||
- Battery level
|
||||
- Which way the robot is facing
|
||||
- GPS location (latitude, longitude, signal strength)
|
||||
- Network signal strength
|
||||
- Vibration level
|
||||
- Lamp status (on/off)
|
||||
|
||||
### Where Your Data Goes
|
||||
|
||||
On your computer: `~/.cache/huggingface/lerobot/{repo-id}`
|
||||
|
||||
After recording, your data automatically uploads to your Hugging Face page:
|
||||
|
||||
```bash
|
||||
echo https://huggingface.co/datasets/${HF_USER}/earthrover-navigation
|
||||
```
|
||||
|
||||
Your dataset will be tagged with `LeRobot` for community discovery.
|
||||
@@ -0,0 +1,301 @@
|
||||
# LeIsaac × LeRobot EnvHub
|
||||
|
||||
LeRobot EnvHub now supports **imitation learning in simulation** with LeIsaac.
|
||||
Spin up everyday manipulation tasks, teleoperate the robot, collect demos, push them to the Hub, and train policies in LeRobot — all in one loop.
|
||||
|
||||
[LeIsaac](https://github.com/LightwheelAI/leisaac) integrates with IsaacLab and the SO101 Leader/Follower setup to provide:
|
||||
|
||||
- 🕹️ **Teleoperation-first workflows** for data collection
|
||||
- 📦 **Built-in data conversion** ready for LeRobot training
|
||||
- 🤖 **Everyday skills** like picking oranges, lifting cubes, cleaning tables, and folding cloth
|
||||
- ☁️ **Ongoing upgrades** from [LightWheel](https://lightwheel.ai/): cloud simulation, EnvHub support, Sim2Real tooling, and more
|
||||
|
||||
Below you’ll find the currently supported LeIsaac tasks exposed through LeRobot EnvHub.
|
||||
|
||||
# Available Environments
|
||||
|
||||
The following table lists all available tasks and environments in LeIsaac x LeRobot Envhub. You can also get the latest list of environments by running the following command:
|
||||
|
||||
```bash
|
||||
python scripts/environments/list_envs.py
|
||||
```
|
||||
|
||||
| Task | Environment ID | Task Description | Related Robot |
|
||||
| :-------------------------------------------------------------------------------------------------------------------------------------------------------------- | :-------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------- | :------------------------------------------------------------------------------------------------------------------------- | :--------------------------------------------------------- |
|
||||
| <video src="https://github.com/user-attachments/assets/466eddff-f720-4f99-94d5-5e123e4c302c" autoplay loop muted playsinline style="max-width: 300px;"></video> | [LeIsaac-SO101-PickOrange-v0](https://github.com/LightwheelAI/leisaac/blob/main/source/leisaac/leisaac/tasks/pick_orange/pick_orange_env_cfg.py)<br /><br />[LeIsaac-SO101-PickOrange-Direct-v0](https://github.com/LightwheelAI/leisaac/blob/main/source/leisaac/leisaac/tasks/pick_orange/direct/pick_orange_env.py) | Pick three oranges and put them into the plate, then reset the arm to rest state. | Single-Arm SO101 Follower |
|
||||
| <video src="https://github.com/user-attachments/assets/1e4eb83a-0b38-40fb-a0b2-ddb0fe201e6d" autoplay loop muted playsinline style="max-width: 300px;"></video> | [LeIsaac-SO101-LiftCube-v0](https://github.com/LightwheelAI/leisaac/blob/main/source/leisaac/leisaac/tasks/lift_cube/lift_cube_env_cfg.py)<br /><br />[LeIsaac-SO101-LiftCube-Direct-v0](https://github.com/LightwheelAI/leisaac/blob/main/source/leisaac/leisaac/tasks/lift_cube/direct/lift_cube_env.py) | Lift the red cube up. | Single-Arm SO101 Follower |
|
||||
| <video src="https://github.com/user-attachments/assets/e49d8f1c-dcc9-412b-a88f-100680d8a45b" autoplay loop muted playsinline style="max-width: 300px;"></video> | [LeIsaac-SO101-CleanToyTable-v0](https://github.com/LightwheelAI/leisaac/blob/main/source/leisaac/leisaac/tasks/clean_toy_table/clean_toy_table_env_cfg.py)<br /><br />[LeIsaac-SO101-CleanToyTable-BiArm-v0](https://github.com/LightwheelAI/leisaac/blob/main/source/leisaac/leisaac/tasks/clean_toy_table/clean_toy_table_bi_arm_env_cfg.py)<br /><br />[LeIsaac-SO101-CleanToyTable-BiArm-Direct-v0](https://github.com/LightwheelAI/leisaac/blob/main/source/leisaac/leisaac/tasks/clean_toy_table/direct/clean_toy_table_bi_arm_env.py) | Pick two letter e objects into the box, and reset the arm to rest state. | Single-Arm SO101 Follower<br /><br />Bi-Arm SO101 Follower |
|
||||
| <video src="https://github.com/user-attachments/assets/e29a0f8a-9286-4ce6-b45d-342c3d3ba754" autoplay loop muted playsinline style="max-width: 300px;"></video> | [LeIsaac-SO101-FoldCloth-BiArm-v0](https://github.com/LightwheelAI/leisaac/blob/main/source/leisaac/leisaac/tasks/fold_cloth/fold_cloth_bi_arm_env_cfg.py)<br /><br />[LeIsaac-SO101-FoldCloth-BiArm-Direct-v0](https://github.com/LightwheelAI/leisaac/blob/main/source/leisaac/leisaac/tasks/fold_cloth/direct/fold_cloth_bi_arm_env.py) | Fold the cloth, and reset the arm to rest state.<br /><br />_Note: Only the DirectEnv support check_success in this task._ | Bi-Arm SO101 Follower |
|
||||
|
||||
# Load LeIsaac directly in LeRobot with one line of code
|
||||
|
||||
> EnvHub: Share LeIsaac environments through HuggingFace
|
||||
|
||||
[EnvHub](https://huggingface.co/docs/lerobot/envhub) is our reproducible environment hub, spin up a packaged simulation with one line, experiment immediately, and publish your own tasks for the community.
|
||||
|
||||
LeIsaac offers EnvHub support so you can consume or share tasks with only a few commands.
|
||||
|
||||
<video
|
||||
controls
|
||||
src="https://github.com/user-attachments/assets/687666f5-ebe0-421d-84a0-eb86116ac5f8"
|
||||
style={{ width: "100%", maxWidth: "960px", borderRadius: "8px" }}
|
||||
/>
|
||||
|
||||
## How to get started, environment Setup
|
||||
|
||||
Run the following commands to setup your code environments:
|
||||
|
||||
```bash
|
||||
# Refer to Getting Started/Installation to install leisaac firstly
|
||||
conda create -n leisaac_envhub python=3.11
|
||||
conda activate leisaac_envhub
|
||||
|
||||
conda install -c "nvidia/label/cuda-12.8.1" cuda-toolkit
|
||||
pip install -U torch==2.7.0 torchvision==0.22.0 --index-url https://download.pytorch.org/whl/cu128
|
||||
pip install 'leisaac[isaaclab] @ git+https://github.com/LightwheelAI/leisaac.git#subdirectory=source/leisaac' --extra-index-url https://pypi.nvidia.com
|
||||
|
||||
# Install lerobot
|
||||
pip install lerobot==0.4.1
|
||||
|
||||
# Fix numpy version
|
||||
pip install numpy==1.26.0
|
||||
```
|
||||
|
||||
## Usage Example
|
||||
|
||||
EnvHub exposes every LeIsaac-supported task in a uniform interface. The examples below load `so101_pick_orange` and demonstrate a random-action rollout and an interactive teleoperation.
|
||||
|
||||
### Random Action
|
||||
|
||||
<details>
|
||||
<summary>Click to expand code example</summary>
|
||||
|
||||
```python
|
||||
# envhub_random_action.py
|
||||
|
||||
import torch
|
||||
from lerobot.envs.factory import make_env
|
||||
|
||||
# Load from the hub
|
||||
envs_dict = make_env("LightwheelAI/leisaac_env:envs/so101_pick_orange.py", n_envs=1, trust_remote_code=True)
|
||||
|
||||
# Access the environment
|
||||
suite_name = next(iter(envs_dict))
|
||||
sync_vector_env = envs_dict[suite_name][0]
|
||||
# retrieve the isaac environment from the sync vector env
|
||||
env = sync_vector_env.envs[0].unwrapped
|
||||
|
||||
# Use it like any gym environment
|
||||
obs, info = env.reset()
|
||||
|
||||
while True:
|
||||
action = torch.tensor(env.action_space.sample())
|
||||
obs, reward, terminated, truncated, info = env.step(action)
|
||||
if terminated or truncated:
|
||||
obs, info = env.reset()
|
||||
|
||||
env.close()
|
||||
```
|
||||
|
||||
</details>
|
||||
|
||||
```bash
|
||||
python envhub_random_action.py
|
||||
```
|
||||
|
||||
You should see the SO101 arm swinging under purely random commands.
|
||||
|
||||
### Teleoperation
|
||||
|
||||
LeRobot’s teleoperation stack can drive the simulated arm.
|
||||
|
||||
Connect the SO101 Leader controller, run the calibration command below.
|
||||
|
||||
```bash
|
||||
lerobot-calibrate \
|
||||
--teleop.type=so101_leader \
|
||||
--teleop.port=/dev/ttyACM0 \
|
||||
--teleop.id=leader
|
||||
```
|
||||
|
||||
And then launch the teleop script.
|
||||
|
||||
<details>
|
||||
<summary>Click to expand code example</summary>
|
||||
|
||||
```python
|
||||
# envhub_teleop_example.py
|
||||
|
||||
import logging
|
||||
import time
|
||||
import gymnasium as gym
|
||||
|
||||
from dataclasses import asdict, dataclass
|
||||
from pprint import pformat
|
||||
|
||||
from lerobot.teleoperators import ( # noqa: F401
|
||||
Teleoperator,
|
||||
TeleoperatorConfig,
|
||||
make_teleoperator_from_config,
|
||||
so101_leader,
|
||||
)
|
||||
from lerobot.utils.robot_utils import precise_sleep
|
||||
from lerobot.utils.utils import init_logging
|
||||
from lerobot.envs.factory import make_env
|
||||
|
||||
|
||||
@dataclass
|
||||
class TeleoperateConfig:
|
||||
teleop: TeleoperatorConfig
|
||||
env_name: str = "so101_pick_orange"
|
||||
fps: int = 60
|
||||
|
||||
|
||||
@dataclass
|
||||
class EnvWrap:
|
||||
env: gym.Env
|
||||
|
||||
|
||||
def make_env_from_leisaac(env_name: str = "so101_pick_orange"):
|
||||
envs_dict = make_env(
|
||||
f'LightwheelAI/leisaac_env:envs/{env_name}.py',
|
||||
n_envs=1,
|
||||
trust_remote_code=True
|
||||
)
|
||||
suite_name = next(iter(envs_dict))
|
||||
sync_vector_env = envs_dict[suite_name][0]
|
||||
env = sync_vector_env.envs[0].unwrapped
|
||||
|
||||
return env
|
||||
|
||||
|
||||
def teleop_loop(teleop: Teleoperator, env: gym.Env, fps: int):
|
||||
from leisaac.devices.action_process import preprocess_device_action
|
||||
from leisaac.assets.robots.lerobot import SO101_FOLLOWER_MOTOR_LIMITS
|
||||
from leisaac.utils.env_utils import dynamic_reset_gripper_effort_limit_sim
|
||||
|
||||
env_wrap = EnvWrap(env=env)
|
||||
|
||||
obs, info = env.reset()
|
||||
while True:
|
||||
loop_start = time.perf_counter()
|
||||
if env.cfg.dynamic_reset_gripper_effort_limit:
|
||||
dynamic_reset_gripper_effort_limit_sim(env, 'so101leader')
|
||||
|
||||
raw_action = teleop.get_action()
|
||||
processed_action = preprocess_device_action(
|
||||
dict(
|
||||
so101_leader=True,
|
||||
joint_state={
|
||||
k.removesuffix(".pos"): v for k, v in raw_action.items()},
|
||||
motor_limits=SO101_FOLLOWER_MOTOR_LIMITS),
|
||||
env_wrap
|
||||
)
|
||||
obs, reward, terminated, truncated, info = env.step(processed_action)
|
||||
if terminated or truncated:
|
||||
obs, info = env.reset()
|
||||
|
||||
dt_s = time.perf_counter() - loop_start
|
||||
precise_sleep(1 / fps - dt_s)
|
||||
loop_s = time.perf_counter() - loop_start
|
||||
print(f"\ntime: {loop_s * 1e3:.2f}ms ({1 / loop_s:.0f} Hz)")
|
||||
|
||||
|
||||
def teleoperate(cfg: TeleoperateConfig):
|
||||
init_logging()
|
||||
logging.info(pformat(asdict(cfg)))
|
||||
|
||||
teleop = make_teleoperator_from_config(cfg.teleop)
|
||||
env = make_env_from_leisaac(cfg.env_name)
|
||||
|
||||
teleop.connect()
|
||||
if hasattr(env, 'initialize'):
|
||||
env.initialize()
|
||||
try:
|
||||
teleop_loop(teleop=teleop, env=env, fps=cfg.fps)
|
||||
except KeyboardInterrupt:
|
||||
pass
|
||||
finally:
|
||||
teleop.disconnect()
|
||||
env.close()
|
||||
|
||||
|
||||
def main():
|
||||
teleoperate(TeleoperateConfig(
|
||||
teleop=so101_leader.SO101LeaderConfig(
|
||||
port="/dev/ttyACM0",
|
||||
id='leader',
|
||||
use_degrees=False,
|
||||
),
|
||||
env_name="so101_pick_orange",
|
||||
fps=60,
|
||||
))
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
main()
|
||||
|
||||
```
|
||||
|
||||
</details>
|
||||
|
||||
```bash
|
||||
python envhub_teleop_example.py
|
||||
```
|
||||
|
||||
Running the script lets you operate the simulated arm using the physical Leader device.
|
||||
|
||||
## ☁️ Cloud Simulation (No GPU Required)
|
||||
|
||||
Don’t have a local GPU or the right drivers? No problem! You can run LeIsaac entirely in the cloud with zero setup.
|
||||
LeIsaac works out-of-the-box on **NVIDIA Brev**, giving you a fully configured environment directly in your browser.
|
||||
|
||||
👉 **Start here:** [https://lightwheelai.github.io/leisaac/docs/cloud_simulation/nvidia_brev](https://lightwheelai.github.io/leisaac/docs/cloud_simulation/nvidia_brev)
|
||||
|
||||
Once your instance is deployed, simply open the link for **port 80 (HTTP)** to launch **Visual Studio Code Server** (default password: `password`). From there, you can run simulations, edit code, and visualize IsaacLab environments — all from your web browser.
|
||||
|
||||
**No GPU, no drivers, no local installation. Just click and run.**
|
||||
|
||||
## Additional Notes
|
||||
|
||||
We keep EnvHub coverage aligned with the LeIsaac task. Currently supported:
|
||||
|
||||
- `so101_pick_orange`
|
||||
- `so101_lift_cube`
|
||||
- `so101_clean_toytable`
|
||||
- `bi_so101_fold_cloth`
|
||||
|
||||
Switch tasks by targeting a different script when calling `make_env`, for example:
|
||||
|
||||
```python
|
||||
envs_dict_pick_orange = make_env("LightwheelAI/leisaac_env:envs/so101_pick_orange.py", n_envs=1, trust_remote_code=True)
|
||||
envs_dict_lift_cube = make_env("LightwheelAI/leisaac_env:envs/so101_lift_cube.py", n_envs=1, trust_remote_code=True)
|
||||
envs_dict_clean_toytable = make_env("LightwheelAI/leisaac_env:envs/so101_clean_toytable.py", n_envs=1, trust_remote_code=True)
|
||||
envs_dict_fold_cloth = make_env("LightwheelAI/leisaac_env:envs/bi_so101_fold_cloth.py", n_envs=1, trust_remote_code=True)
|
||||
```
|
||||
|
||||
Note: when working with `bi_so101_fold_cloth`, call `initialize()` immediately after retrieving the env before performing any other operations:
|
||||
|
||||
<details>
|
||||
<summary>Click to expand code example</summary>
|
||||
|
||||
```python
|
||||
import torch
|
||||
from lerobot.envs.factory import make_env
|
||||
|
||||
# Load from the hub
|
||||
envs_dict = make_env("LightwheelAI/leisaac_env:envs/bi_so101_fold_cloth.py", n_envs=1, trust_remote_code=True)
|
||||
|
||||
# Access the environment
|
||||
suite_name = next(iter(envs_dict))
|
||||
sync_vector_env = envs_dict[suite_name][0]
|
||||
# retrieve the isaac environment from the sync vector env
|
||||
env = sync_vector_env.envs[0].unwrapped
|
||||
|
||||
# NOTE: initialize() first
|
||||
env.initialize()
|
||||
|
||||
# other operation with env...
|
||||
```
|
||||
|
||||
</details>
|
||||
@@ -393,7 +393,7 @@ import time
|
||||
from lerobot.datasets.lerobot_dataset import LeRobotDataset
|
||||
from lerobot.robots.so100_follower.config_so100_follower import SO100FollowerConfig
|
||||
from lerobot.robots.so100_follower.so100_follower import SO100Follower
|
||||
from lerobot.utils.robot_utils import busy_wait
|
||||
from lerobot.utils.robot_utils import precise_sleep
|
||||
from lerobot.utils.utils import log_say
|
||||
|
||||
episode_idx = 0
|
||||
@@ -415,7 +415,7 @@ for idx in range(dataset.num_frames):
|
||||
}
|
||||
robot.send_action(action)
|
||||
|
||||
busy_wait(1.0 / dataset.fps - (time.perf_counter() - t0))
|
||||
precise_sleep(1.0 / dataset.fps - (time.perf_counter() - t0))
|
||||
|
||||
robot.disconnect()
|
||||
```
|
||||
@@ -428,7 +428,7 @@ Your robot should replicate movements similar to those you recorded. For example
|
||||
|
||||
## Train a policy
|
||||
|
||||
To train a policy to control your robot, use the [`lerobot-train`](https://github.com/huggingface/lerobot/blob/main/src/lerobot/scripts/train.py) script. A few arguments are required. Here is an example command:
|
||||
To train a policy to control your robot, use the [`lerobot-train`](https://github.com/huggingface/lerobot/blob/main/src/lerobot/scripts/lerobot_train.py) script. A few arguments are required. Here is an example command:
|
||||
|
||||
```bash
|
||||
lerobot-train \
|
||||
@@ -485,7 +485,7 @@ huggingface-cli upload ${HF_USER}/act_so101_test${CKPT} \
|
||||
|
||||
## Run inference and evaluate your policy
|
||||
|
||||
You can use the `record` script from [`lerobot/record.py`](https://github.com/huggingface/lerobot/blob/main/src/lerobot/record.py) with a policy checkpoint as input, to run inference and evaluate your policy. For instance, run this command or API example to run inference and record 10 evaluation episodes:
|
||||
You can use the `record` script from [`lerobot-record`](https://github.com/huggingface/lerobot/blob/main/src/lerobot/scripts/lerobot_record.py) with a policy checkpoint as input, to run inference and evaluate your policy. For instance, run this command or API example to run inference and record 10 evaluation episodes:
|
||||
|
||||
<hfoptions id="eval">
|
||||
<hfoption id="Command">
|
||||
|
||||
@@ -1,220 +0,0 @@
|
||||
# Imitation Learning in Sim
|
||||
|
||||
This tutorial will explain how to train a neural network to control a robot in simulation with imitation learning.
|
||||
|
||||
**You'll learn:**
|
||||
|
||||
1. How to record a dataset in simulation with [gym-hil](https://github.com/huggingface/gym-hil) and visualize the dataset.
|
||||
2. How to train a policy using your data.
|
||||
3. How to evaluate your policy in simulation and visualize the results.
|
||||
|
||||
For the simulation environment we use the same [repo](https://github.com/huggingface/gym-hil) that is also being used by the Human-In-the-Loop (HIL) reinforcement learning algorithm.
|
||||
This environment is based on [MuJoCo](https://mujoco.org) and allows you to record datasets in LeRobotDataset format.
|
||||
Teleoperation is easiest with a controller like the Logitech F710, but you can also use your keyboard if you are up for the challenge.
|
||||
|
||||
## Installation
|
||||
|
||||
First, install the `gym_hil` package within the LeRobot environment, go to your LeRobot folder and run this command:
|
||||
|
||||
```bash
|
||||
pip install -e ".[hilserl]"
|
||||
```
|
||||
|
||||
## Teleoperate and Record a Dataset
|
||||
|
||||
To use `gym_hil` with LeRobot, you need to use a configuration file. An example config file can be found [here](https://huggingface.co/datasets/lerobot/config_examples/resolve/main/sim_il/env_config.json).
|
||||
|
||||
To teleoperate and collect a dataset, we need to modify this config file. Here's an example configuration for imitation learning data collection:
|
||||
|
||||
```json
|
||||
{
|
||||
"env": {
|
||||
"type": "gym_manipulator",
|
||||
"name": "gym_hil",
|
||||
"task": "PandaPickCubeGamepad-v0",
|
||||
"fps": 10
|
||||
},
|
||||
"dataset": {
|
||||
"repo_id": "your_username/il_gym",
|
||||
"root": null,
|
||||
"task": "pick_cube",
|
||||
"num_episodes_to_record": 30,
|
||||
"replay_episode": null,
|
||||
"push_to_hub": true
|
||||
},
|
||||
"mode": "record",
|
||||
"device": "cuda"
|
||||
}
|
||||
```
|
||||
|
||||
Key configuration points:
|
||||
|
||||
- Set your `repo_id` in the `dataset` section: `"repo_id": "your_username/il_gym"`
|
||||
- Set `num_episodes_to_record: 30` to collect 30 demonstration episodes
|
||||
- Ensure `mode` is set to `"record"`
|
||||
- If you don't have an NVIDIA GPU, change `"device": "cuda"` to `"mps"` for macOS or `"cpu"`
|
||||
- To use keyboard instead of gamepad, change `"task"` to `"PandaPickCubeKeyboard-v0"`
|
||||
|
||||
Then we can run this command to start:
|
||||
|
||||
<hfoptions id="teleop_sim">
|
||||
<hfoption id="Linux">
|
||||
|
||||
```bash
|
||||
python -m lerobot.rl.gym_manipulator --config_path path/to/env_config_gym_hil_il.json
|
||||
```
|
||||
|
||||
</hfoption>
|
||||
<hfoption id="MacOS">
|
||||
|
||||
```bash
|
||||
mjpython -m lerobot.rl.gym_manipulator --config_path path/to/env_config_gym_hil_il.json
|
||||
```
|
||||
|
||||
</hfoption>
|
||||
</hfoptions>
|
||||
|
||||
Once the environment is rendered, you can teleoperate the robot with the gamepad or keyboard. You can find the gamepad/keyboard controls below.
|
||||
|
||||
Note that to teleoperate the robot you have to hold the "Human Take Over Pause Policy" Button `RB` to enable control!
|
||||
|
||||
**Gamepad Controls**
|
||||
|
||||
<p align="center">
|
||||
<img
|
||||
src="https://huggingface.co/datasets/huggingface/documentation-images/resolve/main/lerobot/gamepad_guide.jpg?raw=true"
|
||||
alt="Figure shows the control mappings on a Logitech gamepad."
|
||||
title="Gamepad Control Mapping"
|
||||
width="100%"
|
||||
></img>
|
||||
</p>
|
||||
<p align="center">
|
||||
<i>Gamepad button mapping for robot control and episode management</i>
|
||||
</p>
|
||||
|
||||
**Keyboard controls**
|
||||
|
||||
For keyboard controls use the `spacebar` to enable control and the following keys to move the robot:
|
||||
|
||||
```bash
|
||||
Arrow keys: Move in X-Y plane
|
||||
Shift and Shift_R: Move in Z axis
|
||||
Right Ctrl and Left Ctrl: Open and close gripper
|
||||
ESC: Exit
|
||||
```
|
||||
|
||||
## Visualize a dataset
|
||||
|
||||
If you uploaded your dataset to the hub, you can [visualize your dataset online](https://huggingface.co/spaces/lerobot/visualize_dataset) by copy-pasting your repo id.
|
||||
|
||||
<p align="center">
|
||||
<img
|
||||
src="https://huggingface.co/datasets/huggingface/documentation-images/resolve/main/lerobot/dataset_visualizer_sim.png"
|
||||
alt="Figure shows the dataset visualizer"
|
||||
title="Dataset visualization"
|
||||
width="100%"
|
||||
></img>
|
||||
</p>
|
||||
<p align="center">
|
||||
<i>Dataset visualizer</i>
|
||||
</p>
|
||||
|
||||
## Train a policy
|
||||
|
||||
To train a policy to control your robot, use the [`lerobot-train`](https://github.com/huggingface/lerobot/blob/main/src/lerobot/scripts/train.py) script. A few arguments are required. Here is an example command:
|
||||
|
||||
```bash
|
||||
lerobot-train \
|
||||
--dataset.repo_id=${HF_USER}/il_gym \
|
||||
--policy.type=act \
|
||||
--output_dir=outputs/train/il_sim_test \
|
||||
--job_name=il_sim_test \
|
||||
--policy.device=cuda \
|
||||
--wandb.enable=true
|
||||
```
|
||||
|
||||
Let's explain the command:
|
||||
|
||||
1. We provided the dataset as argument with `--dataset.repo_id=${HF_USER}/il_gym`.
|
||||
2. We provided the policy with `policy.type=act`. This loads configurations from [`configuration_act.py`](https://github.com/huggingface/lerobot/blob/main/src/lerobot/policies/act/configuration_act.py). Importantly, this policy will automatically adapt to the number of motor states, motor actions and cameras of your robot (e.g. `laptop` and `phone`) which have been saved in your dataset.
|
||||
3. We provided `policy.device=cuda` since we are training on a Nvidia GPU, but you could use `policy.device=mps` to train on Apple silicon.
|
||||
4. We provided `wandb.enable=true` to use [Weights and Biases](https://docs.wandb.ai/quickstart) for visualizing training plots. This is optional but if you use it, make sure you are logged in by running `wandb login`.
|
||||
|
||||
Training should take several hours; 100k steps (which is the default) will take about 1 hour on an Nvidia A100. You will find checkpoints in `outputs/train/il_sim_test/checkpoints`.
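As a quick sanity check that a checkpoint was saved correctly, you can try loading it back in Python. This is only a sketch: the path matches the command above, and the exact import path may differ between LeRobot versions.

```python
from lerobot.policies.act.modeling_act import ACTPolicy

# Path from the training run above; adjust if you changed --output_dir.
ckpt = "outputs/train/il_sim_test/checkpoints/last/pretrained_model"
policy = ACTPolicy.from_pretrained(ckpt)
print(sum(p.numel() for p in policy.parameters()), "parameters loaded")
```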
|
||||
|
||||
#### Train using Colab
|
||||
|
||||
If your local computer doesn't have a powerful GPU, you can use Google Colab to train your model by following the [ACT training notebook](./notebooks#training-act).
|
||||
|
||||
#### Upload policy checkpoints
|
||||
|
||||
Once training is done, upload the latest checkpoint with:
|
||||
|
||||
```bash
|
||||
huggingface-cli upload ${HF_USER}/il_sim_test \
|
||||
outputs/train/il_sim_test/checkpoints/last/pretrained_model
|
||||
```
|
||||
|
||||
You can also upload intermediate checkpoints with:
|
||||
|
||||
```bash
|
||||
CKPT=010000
|
||||
huggingface-cli upload ${HF_USER}/il_sim_test${CKPT} \
|
||||
outputs/train/il_sim_test/checkpoints/${CKPT}/pretrained_model
|
||||
```
|
||||
|
||||
## Evaluate your policy in Sim
|
||||
|
||||
To evaluate your policy we have to use a configuration file. An example can be found [here](https://huggingface.co/datasets/lerobot/config_examples/resolve/main/sim_il/eval_config.json).
|
||||
|
||||
Here's an example evaluation configuration:
|
||||
|
||||
```json
|
||||
{
|
||||
"env": {
|
||||
"type": "gym_manipulator",
|
||||
"name": "gym_hil",
|
||||
"task": "PandaPickCubeGamepad-v0",
|
||||
"fps": 10
|
||||
},
|
||||
"dataset": {
|
||||
"repo_id": "your_username/il_sim_dataset",
|
||||
"dataset_root": null,
|
||||
"task": "pick_cube"
|
||||
},
|
||||
"pretrained_policy_name_or_path": "your_username/il_sim_model",
|
||||
"device": "cuda"
|
||||
}
|
||||
```
|
||||
|
||||
Make sure to replace:
|
||||
|
||||
- `repo_id` with the dataset you trained on (e.g., `your_username/il_sim_dataset`)
|
||||
- `pretrained_policy_name_or_path` with your model ID (e.g., `your_username/il_sim_model`)
|
||||
|
||||
Then you can run this command to visualize your trained policy
|
||||
|
||||
<hfoptions id="eval_policy">
|
||||
<hfoption id="Linux">
|
||||
|
||||
```bash
|
||||
python -m lerobot.rl.eval_policy --config_path=path/to/eval_config_gym_hil.json
|
||||
```
|
||||
|
||||
</hfoption>
|
||||
<hfoption id="MacOS">
|
||||
|
||||
```bash
|
||||
mjpython -m lerobot.rl.eval_policy --config_path=path/to/eval_config_gym_hil.json
|
||||
```
|
||||
|
||||
</hfoption>
|
||||
</hfoptions>
|
||||
|
||||
> [!WARNING]
|
||||
> While the main workflow of training ACT in simulation is straightforward, there is significant room for exploring how to set up the task, define the initial state of the environment, and determine the type of data required during collection to learn the most effective policy. If your trained policy doesn't perform well, investigate the quality of the dataset it was trained on using our visualizers, as well as the action values and various hyperparameters related to ACT and the simulation.
|
||||
|
||||
Congrats 🎉, you have finished this tutorial. If you want to continue using LeRobot in simulation, follow this [Tutorial on reinforcement learning in sim with HIL-SERL](https://huggingface.co/docs/lerobot/hilserl_sim)
|
||||
|
||||
> [!TIP]
|
||||
> If you have any questions or need help, please reach out on [Discord](https://discord.com/invite/s3KuuzsPFb).
|
||||
@@ -90,7 +90,7 @@ If you encounter build errors, you may need to install additional dependencies:
|
||||
To install these for linux run:
|
||||
|
||||
```bash
|
||||
sudo apt-get install cmake build-essential python-dev pkg-config libavformat-dev libavcodec-dev libavdevice-dev libavutil-dev libswscale-dev libswresample-dev libavfilter-dev pkg-config
|
||||
sudo apt-get install cmake build-essential python3-dev pkg-config libavformat-dev libavcodec-dev libavdevice-dev libavutil-dev libswscale-dev libswresample-dev libavfilter-dev
|
||||
```
|
||||
|
||||
For other systems, see: [Compiling PyAV](https://pyav.org/docs/develop/overview/installation.html#bring-your-own-ffmpeg)
|
||||
|
||||
@@ -62,6 +62,11 @@ lerobot-eval \
|
||||
|
||||
- Pass a comma-separated list to `--env.task` for multi-suite evaluation.
|
||||
|
||||
### Control Mode
|
||||
|
||||
LIBERO now supports two control modes: relative and absolute. This matters because different VLA checkpoints are trained to output actions in different modes, and hence with different control parameterizations.
You can switch between them with `env.control_mode = "relative"` and `env.control_mode = "absolute"`
|
||||
|
||||
### Policy inputs and outputs
|
||||
|
||||
When using LIBERO through LeRobot, policies interact with the environment via **observations** and **actions**:
|
||||
|
||||
@@ -30,131 +30,6 @@ The follower arm uses 6x STS3215 motors with 1/345 gearing. The leader, however,
|
||||
| Wrist Roll | 5 | 1 / 147 |
|
||||
| Gripper | 6 | 1 / 147 |
|
||||
|
||||
### Clean Parts
|
||||
|
||||
Remove all support material from the 3D-printed parts. The easiest way to do this is using a small screwdriver to get underneath the support material.
|
||||
|
||||
It is advisable to install a 3-pin cable in each motor after placing it, before continuing assembly.
|
||||
|
||||
### Joint 1
|
||||
|
||||
- Place the first motor into the base.
|
||||
- Fasten the motor with 4 M2x6mm screws (smallest screws). Two from the top and two from the bottom.
|
||||
- Slide over the first motor holder and fasten it using two M2x6mm screws (one on each side).
|
||||
- Install both motor horns, securing the top horn with a M3x6mm screw.
|
||||
- Attach the shoulder part.
|
||||
- Tighten the shoulder part with 4 M3x6mm screws on top and 4 M3x6mm screws on the bottom
|
||||
- Add the shoulder motor holder.
|
||||
|
||||
<div class="video-container">
|
||||
<video controls width="600">
|
||||
<source
|
||||
src="https://huggingface.co/datasets/huggingface/documentation-images/resolve/main/lerobot/Joint1_v2.mp4"
|
||||
type="video/mp4"
|
||||
/>
|
||||
</video>
|
||||
</div>
|
||||
|
||||
### Joint 2
|
||||
|
||||
- Slide the second motor in from the top.
|
||||
- Fasten the second motor with 4 M2x6mm screws.
|
||||
- Attach both motor horns to motor 2, again use the M3x6mm horn screw.
|
||||
- Attach the upper arm with 4 M3x6mm screws on each side.
|
||||
|
||||
<div class="video-container">
|
||||
<video controls width="600">
|
||||
<source
|
||||
src="https://huggingface.co/datasets/huggingface/documentation-images/resolve/main/lerobot/Joint2_v2.mp4"
|
||||
type="video/mp4"
|
||||
/>
|
||||
</video>
|
||||
</div>
|
||||
|
||||
### Joint 3
|
||||
|
||||
- Insert motor 3 and fasten using 4 M2x6mm screws
|
||||
- Attach both motor horns to motor 3 and secure one again with a M3x6mm horn screw.
|
||||
- Connect the forearm to motor 3 using 4 M3x6mm screws on each side.
|
||||
|
||||
<div class="video-container">
|
||||
<video controls width="600">
|
||||
<source
|
||||
src="https://huggingface.co/datasets/huggingface/documentation-images/resolve/main/lerobot/Joint3_v2.mp4"
|
||||
type="video/mp4"
|
||||
/>
|
||||
</video>
|
||||
</div>
|
||||
|
||||
### Joint 4
|
||||
|
||||
- Slide over motor holder 4.
|
||||
- Slide in motor 4.
|
||||
- Fasten motor 4 with 4 M2x6mm screws and attach its motor horns, use a M3x6mm horn screw.
|
||||
|
||||
<div class="video-container">
|
||||
<video controls width="600">
|
||||
<source
|
||||
src="https://huggingface.co/datasets/huggingface/documentation-images/resolve/main/lerobot/Joint4_v2.mp4"
|
||||
type="video/mp4"
|
||||
/>
|
||||
</video>
|
||||
</div>
|
||||
|
||||
### Joint 5
|
||||
|
||||
- Insert motor 5 into the wrist holder and secure it with 2 M2x6mm front screws.
|
||||
- Install only one motor horn on the wrist motor and secure it with a M3x6mm horn screw.
|
||||
- Secure the wrist to motor 4 using 4 M3x6mm screws on both sides.
|
||||
|
||||
<div class="video-container">
|
||||
<video controls width="600">
|
||||
<source
|
||||
src="https://huggingface.co/datasets/huggingface/documentation-images/resolve/main/lerobot/Joint5_v2.mp4"
|
||||
type="video/mp4"
|
||||
/>
|
||||
</video>
|
||||
</div>
|
||||
|
||||
### Gripper / Handle
|
||||
|
||||
<hfoptions id="assembly">
|
||||
<hfoption id="Follower">
|
||||
|
||||
- Attach the gripper to motor 5, attach it to the motor horn on the wrist using 4 M3x6mm screws.
|
||||
- Insert the gripper motor and secure it with 2 M2x6mm screws on each side.
|
||||
- Attach the motor horns and again use a M3x6mm horn screw.
|
||||
- Install the gripper claw and secure it with 4 M3x6mm screws on both sides.
|
||||
|
||||
<div class="video-container">
|
||||
<video controls width="600">
|
||||
<source
|
||||
src="https://huggingface.co/datasets/huggingface/documentation-images/resolve/main/lerobot/Gripper_v2.mp4"
|
||||
type="video/mp4"
|
||||
/>
|
||||
</video>
|
||||
</div>
|
||||
|
||||
</hfoption>
|
||||
<hfoption id="Leader">
|
||||
|
||||
- Mount the leader holder onto the wrist and secure it with 4 M3x6mm screws.
|
||||
- Attach the handle to motor 5 using 1 M2x6mm screw.
|
||||
- Insert the gripper motor, secure it with 2 M2x6mm screws on each side, attach a motor horn using a M3x6mm horn screw.
|
||||
- Attach the follower trigger with 4 M3x6mm screws.
|
||||
|
||||
<div class="video-container">
|
||||
<video controls width="600">
|
||||
<source
|
||||
src="https://huggingface.co/datasets/huggingface/documentation-images/resolve/main/lerobot/Leader_v2.mp4"
|
||||
type="video/mp4"
|
||||
/>
|
||||
</video>
|
||||
</div>
|
||||
|
||||
</hfoption>
|
||||
</hfoptions>
|
||||
|
||||
## Configure the motors
|
||||
|
||||
### 1. Find the USB ports associated with each arm
|
||||
@@ -340,6 +215,131 @@ leader.setup_motors()
|
||||
</hfoption>
|
||||
</hfoptions>
|
||||
|
||||
### Clean Parts
|
||||
|
||||
Remove all support material from the 3D-printed parts. The easiest way to do this is using a small screwdriver to get underneath the support material.
|
||||
|
||||
It is advisable to install a 3-pin cable in each motor after placing it, before continuing assembly.
|
||||
|
||||
### Joint 1
|
||||
|
||||
- Place the first motor into the base.
|
||||
- Fasten the motor with 4 M2x6mm screws (smallest screws). Two from the top and two from the bottom.
|
||||
- Slide over the first motor holder and fasten it using two M2x6mm screws (one on each side).
|
||||
- Install both motor horns, securing the top horn with a M3x6mm screw.
|
||||
- Attach the shoulder part.
|
||||
- Tighten the shoulder part with 4 M3x6mm screws on top and 4 M3x6mm screws on the bottom
|
||||
- Add the shoulder motor holder.
|
||||
|
||||
<div class="video-container">
|
||||
<video controls width="600">
|
||||
<source
|
||||
src="https://huggingface.co/datasets/huggingface/documentation-images/resolve/main/lerobot/Joint1_v2.mp4"
|
||||
type="video/mp4"
|
||||
/>
|
||||
</video>
|
||||
</div>
|
||||
|
||||
### Joint 2
|
||||
|
||||
- Slide the second motor in from the top.
|
||||
- Fasten the second motor with 4 M2x6mm screws.
|
||||
- Attach both motor horns to motor 2, again use the M3x6mm horn screw.
|
||||
- Attach the upper arm with 4 M3x6mm screws on each side.
|
||||
|
||||
<div class="video-container">
|
||||
<video controls width="600">
|
||||
<source
|
||||
src="https://huggingface.co/datasets/huggingface/documentation-images/resolve/main/lerobot/Joint2_v2.mp4"
|
||||
type="video/mp4"
|
||||
/>
|
||||
</video>
|
||||
</div>
|
||||
|
||||
### Joint 3
|
||||
|
||||
- Insert motor 3 and fasten using 4 M2x6mm screws
|
||||
- Attach both motor horns to motor 3 and secure one again with a M3x6mm horn screw.
|
||||
- Connect the forearm to motor 3 using 4 M3x6mm screws on each side.
|
||||
|
||||
<div class="video-container">
|
||||
<video controls width="600">
|
||||
<source
|
||||
src="https://huggingface.co/datasets/huggingface/documentation-images/resolve/main/lerobot/Joint3_v2.mp4"
|
||||
type="video/mp4"
|
||||
/>
|
||||
</video>
|
||||
</div>
|
||||
|
||||
### Joint 4
|
||||
|
||||
- Slide over motor holder 4.
|
||||
- Slide in motor 4.
|
||||
- Fasten motor 4 with 4 M2x6mm screws and attach its motor horns, use a M3x6mm horn screw.
|
||||
|
||||
<div class="video-container">
|
||||
<video controls width="600">
|
||||
<source
|
||||
src="https://huggingface.co/datasets/huggingface/documentation-images/resolve/main/lerobot/Joint4_v2.mp4"
|
||||
type="video/mp4"
|
||||
/>
|
||||
</video>
|
||||
</div>
|
||||
|
||||
### Joint 5
|
||||
|
||||
- Insert motor 5 into the wrist holder and secure it with 2 M2x6mm front screws.
|
||||
- Install only one motor horn on the wrist motor and secure it with a M3x6mm horn screw.
|
||||
- Secure the wrist to motor 4 using 4 M3x6mm screws on both sides.
|
||||
|
||||
<div class="video-container">
|
||||
<video controls width="600">
|
||||
<source
|
||||
src="https://huggingface.co/datasets/huggingface/documentation-images/resolve/main/lerobot/Joint5_v2.mp4"
|
||||
type="video/mp4"
|
||||
/>
|
||||
</video>
|
||||
</div>
|
||||
|
||||
### Gripper / Handle
|
||||
|
||||
<hfoptions id="assembly">
|
||||
<hfoption id="Follower">
|
||||
|
||||
- Attach the gripper to motor 5, attach it to the motor horn on the wrist using 4 M3x6mm screws.
|
||||
- Insert the gripper motor and secure it with 2 M2x6mm screws on each side.
|
||||
- Attach the motor horns and again use a M3x6mm horn screw.
|
||||
- Install the gripper claw and secure it with 4 M3x6mm screws on both sides.
|
||||
|
||||
<div class="video-container">
|
||||
<video controls width="600">
|
||||
<source
|
||||
src="https://huggingface.co/datasets/huggingface/documentation-images/resolve/main/lerobot/Gripper_v2.mp4"
|
||||
type="video/mp4"
|
||||
/>
|
||||
</video>
|
||||
</div>
|
||||
|
||||
</hfoption>
|
||||
<hfoption id="Leader">
|
||||
|
||||
- Mount the leader holder onto the wrist and secure it with 4 M3x6mm screws.
|
||||
- Attach the handle to motor 5 using 1 M2x6mm screw.
|
||||
- Insert the gripper motor, secure it with 2 M2x6mm screws on each side, attach a motor horn using a M3x6mm horn screw.
|
||||
- Attach the follower trigger with 4 M3x6mm screws.
|
||||
|
||||
<div class="video-container">
|
||||
<video controls width="600">
|
||||
<source
|
||||
src="https://huggingface.co/datasets/huggingface/documentation-images/resolve/main/lerobot/Leader_v2.mp4"
|
||||
type="video/mp4"
|
||||
/>
|
||||
</video>
|
||||
</div>
|
||||
|
||||
</hfoption>
|
||||
</hfoptions>
|
||||
|
||||
## Calibrate
|
||||
|
||||
Next, you'll need to calibrate your robot to ensure that the leader and follower arms have the same position values when they are in the same physical position.
|
||||
|
||||
@@ -0,0 +1,42 @@
|
||||
# PyTorch accelerators
|
||||
|
||||
LeRobot supports multiple hardware acceleration options for both training and inference.
|
||||
|
||||
These options include:
|
||||
|
||||
- **CPU**: CPU executes all computations, no dedicated accelerator is used
|
||||
- **CUDA**: acceleration with NVIDIA & AMD GPUs
|
||||
- **MPS**: acceleration with Apple Silicon GPUs
|
||||
- **XPU**: acceleration with Intel integrated and discrete GPUs
|
||||
|
||||
## Getting Started
|
||||
|
||||
To use a particular accelerator, a suitable version of PyTorch must be installed.
|
||||
|
||||
For CPU, CUDA, and MPS backends follow instructions provided on [PyTorch installation page](https://pytorch.org/get-started/locally).
|
||||
For XPU backend, follow instructions from [PyTorch documentation](https://docs.pytorch.org/docs/stable/notes/get_start_xpu.html).
|
||||
|
||||
### Verifying the installation
|
||||
|
||||
After installation, accelerator availability can be verified by running
|
||||
|
||||
```python
|
||||
import torch
|
||||
print(torch.<backend_name>.is_available()) # <backend_name> is cuda, mps, or xpu
|
||||
```
|
||||
|
||||
## How to run training or evaluation
|
||||
|
||||
To select the desired accelerator, use the `--policy.device` flag when running `lerobot-train` or `lerobot-eval`. For example, to use MPS on Apple Silicon, run:
|
||||
|
||||
```bash
|
||||
lerobot-train \
--policy.device=mps ...
|
||||
```
|
||||
|
||||
```bash
|
||||
lerobot-eval \
|
||||
--policy.device=mps ...
|
||||
```
|
||||
|
||||
However, in most cases, the presence of an accelerator is detected automatically and the `policy.device` parameter can be omitted from CLI commands.
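As a rough illustration, such auto-detection boils down to a few availability checks in PyTorch. The helper below is only a sketch of the idea, not LeRobot's actual detection code.

```python
import torch

def autodetect_device() -> str:
    """Illustrative: pick the best available accelerator, falling back to CPU."""
    if torch.cuda.is_available():  # NVIDIA and AMD (ROCm) GPUs
        return "cuda"
    if torch.backends.mps.is_available():  # Apple Silicon GPUs
        return "mps"
    if hasattr(torch, "xpu") and torch.xpu.is_available():  # Intel GPUs
        return "xpu"
    return "cpu"

print(autodetect_device())
```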
|
||||
@@ -0,0 +1,203 @@
|
||||
# Unitree G1 Robot Setup and Control
|
||||
|
||||
This guide covers the complete setup process for the Unitree G1 humanoid, from initial connection to running gr00t_wbc locomotion.
|
||||
|
||||
## About the Unitree G1
|
||||
|
||||
We offer support for both the 29-DOF and 23-DOF G1. In this first PR we introduce:
|
||||
|
||||
- **`unitree g1` robot class**, handling low-level communication with the humanoid
|
||||
- **ZMQ socket bridge** for remote communication over WiFi, allowing one to deploy policies remotely instead of over ethernet or directly on the Orin
|
||||
- **GR00T locomotion policy** for bipedal walking and balance
|
||||
|
||||
---
|
||||
|
||||
## Part 1: Connect to Robot over Ethernet
|
||||
|
||||
### Step 1: Configure Your Computer's Ethernet Interface
|
||||
|
||||
Set a static IP on the same subnet as the robot:
|
||||
|
||||
```bash
|
||||
# Replace 'enp131s0' with your ethernet interface name (check with `ip a`)
|
||||
sudo ip addr flush dev enp131s0
|
||||
sudo ip addr add 192.168.123.200/24 dev enp131s0
|
||||
sudo ip link set enp131s0 up
|
||||
```
|
||||
|
||||
**Note**: The robot's Ethernet IP is fixed at `192.168.123.164`. Your computer must use `192.168.123.x` where x ≠ 164.
|
||||
|
||||
### Step 2: SSH into the Robot
|
||||
|
||||
```bash
|
||||
ssh unitree@192.168.123.164
|
||||
# Password: 123
|
||||
```
|
||||
|
||||
You should now be connected to the robot's onboard computer.
|
||||
|
||||
---
|
||||
|
||||
## Part 2: Enable WiFi on the Robot
|
||||
|
||||
Once connected via Ethernet, follow these steps to enable WiFi:
|
||||
|
||||
### Step 1: Enable WiFi Hardware
|
||||
|
||||
```bash
|
||||
# Unblock WiFi radio
|
||||
sudo rfkill unblock wifi
|
||||
sudo rfkill unblock all
|
||||
|
||||
# Bring up WiFi interface
|
||||
sudo ip link set wlan0 up
|
||||
|
||||
# Enable NetworkManager control
|
||||
sudo nmcli radio wifi on
|
||||
sudo nmcli device set wlan0 managed yes
|
||||
sudo systemctl restart NetworkManager
|
||||
```
|
||||
|
||||
### Step 2: Enable Internet Forwarding
|
||||
|
||||
**On your laptop:**
|
||||
|
||||
```bash
|
||||
# Enable IP forwarding
|
||||
sudo sysctl -w net.ipv4.ip_forward=1
|
||||
|
||||
# Set up NAT (replace wlp132s0f0 with your WiFi interface)
|
||||
sudo iptables -t nat -A POSTROUTING -o wlp132s0f0 -s 192.168.123.0/24 -j MASQUERADE
|
||||
sudo iptables -A FORWARD -i wlp132s0f0 -o enp131s0 -m state --state RELATED,ESTABLISHED -j ACCEPT
|
||||
sudo iptables -A FORWARD -i enp131s0 -o wlp132s0f0 -j ACCEPT
|
||||
```
|
||||
|
||||
**On the robot:**
|
||||
|
||||
```bash
|
||||
# Add laptop as default gateway
|
||||
sudo ip route del default 2>/dev/null || true
|
||||
sudo ip route add default via 192.168.123.200 dev eth0
|
||||
echo "nameserver 8.8.8.8" | sudo tee /etc/resolv.conf
|
||||
|
||||
# Test connection
|
||||
ping -c 3 8.8.8.8
|
||||
```
|
||||
|
||||
### Step 3: Connect to WiFi Network
|
||||
|
||||
```bash
|
||||
# List available networks
|
||||
nmcli device wifi list
|
||||
|
||||
# Connect to your WiFi (example)
|
||||
sudo nmcli connection add type wifi ifname wlan0 con-name "YourNetwork" ssid "YourNetwork"
|
||||
sudo nmcli connection modify "YourNetwork" wifi-sec.key-mgmt wpa-psk
|
||||
sudo nmcli connection modify "YourNetwork" wifi-sec.psk "YourPassword"
|
||||
sudo nmcli connection modify "YourNetwork" connection.autoconnect yes
|
||||
sudo nmcli connection up "YourNetwork"
|
||||
|
||||
# Check WiFi IP address
|
||||
ip a show wlan0
|
||||
```
|
||||
|
||||
### Step 4: SSH Over WiFi
|
||||
|
||||
Once connected to WiFi, note the robot's IP address and disconnect the Ethernet cable. You can now SSH over WiFi:
|
||||
|
||||
```bash
|
||||
ssh unitree@<YOUR_ROBOT_IP>
|
||||
# Password: 123
|
||||
```
|
||||
|
||||
Replace `<YOUR_ROBOT_IP>` with your robot's actual WiFi IP address (e.g., `172.18.129.215`).
|
||||
|
||||
---
|
||||
|
||||
## Part 3: Robot Server Setup
|
||||
|
||||
### Step 1: Install LeRobot on the Orin
|
||||
|
||||
SSH into the robot and install LeRobot:
|
||||
|
||||
```bash
|
||||
ssh unitree@<YOUR_ROBOT_IP>
|
||||
|
||||
conda create -y -n lerobot python=3.10
|
||||
conda activate lerobot
|
||||
git clone https://github.com/huggingface/lerobot.git
|
||||
cd lerobot
|
||||
pip install -e '.[unitree_g1]'
|
||||
git clone https://github.com/unitreerobotics/unitree_sdk2_python.git
|
||||
cd unitree_sdk2_python && pip install -e .
|
||||
```
|
||||
|
||||
**Note**: The Unitree SDK requires CycloneDDS v0.10.2 to be installed. See the [Unitree SDK documentation](https://github.com/unitreerobotics/unitree_sdk2_python) for details.
|
||||
|
||||
### Step 2: Run the Robot Server
|
||||
|
||||
On the robot:
|
||||
|
||||
```bash
|
||||
python src/lerobot/robots/unitree_g1/run_g1_server.py
|
||||
```
|
||||
|
||||
**Important**: Keep this terminal running. The server must be active for remote control.
|
||||
|
||||
---
|
||||
|
||||
## Part 4: Running GR00T Locomotion
|
||||
|
||||
With the robot server running, you can now control the robot from your laptop.
|
||||
|
||||
### Step 1: Install LeRobot on your machine
|
||||
|
||||
```bash
|
||||
conda create -y -n lerobot python=3.10
|
||||
conda activate lerobot
|
||||
git clone https://github.com/huggingface/lerobot.git
|
||||
cd lerobot
|
||||
pip install -e '.[unitree_g1]'
|
||||
git clone https://github.com/unitreerobotics/unitree_sdk2_python.git
|
||||
cd unitree_sdk2_python && pip install -e .
|
||||
```
|
||||
|
||||
### Step 2: Update Robot IP in Config
|
||||
|
||||
Edit the config file to match your robot's WiFi IP:
|
||||
|
||||
```python
|
||||
# In src/lerobot/robots/unitree_g1/config_unitree_g1.py
|
||||
robot_ip: str = "<YOUR_ROBOT_IP>" # Replace with your robot's WiFi IP.
|
||||
```
|
||||
|
||||
**Note**: When running directly on the G1 (not remotely), set `robot_ip: str = "127.0.0.1"` instead.
|
||||
|
||||
### Step 3: Run the Locomotion Policy
|
||||
|
||||
```bash
|
||||
# Run GR00T locomotion controller
|
||||
python examples/unitree_g1/gr00t_locomotion.py --repo-id "nepyope/GR00T-WholeBodyControl_g1"
|
||||
```
|
||||
|
||||
### Step 4: Control with Remote
|
||||
|
||||
- **Left stick**: Forward/backward and left/right movement
|
||||
- **Right stick**: Rotation
|
||||
- **R1 button**: Raise waist height
|
||||
- **R2 button**: Lower waist height
|
||||
|
||||
Press `Ctrl+C` to stop the policy.
|
||||
|
||||
---
|
||||
|
||||
## Additional Resources
|
||||
|
||||
- [Unitree SDK Documentation](https://github.com/unitreerobotics/unitree_sdk2_python)
|
||||
- [GR00T Policy Repository](https://huggingface.co/nepyope/GR00T-WholeBodyControl_g1)
|
||||
- [LeRobot Documentation](https://github.com/huggingface/lerobot)
|
||||
- [Unitree_IL_Lerobot](https://github.com/unitreerobotics/unitree_IL_lerobot)
|
||||
|
||||
---
|
||||
|
||||
_Last updated: December 2025_
|
||||
@@ -11,13 +11,14 @@ LeRobot provides several utilities for manipulating datasets:
|
||||
3. **Merge Datasets** - Combine multiple datasets into one. The datasets must have identical features, and episodes are concatenated in the order specified in `repo_ids`
|
||||
4. **Add Features** - Add new features to a dataset
|
||||
5. **Remove Features** - Remove features from a dataset
|
||||
6. **Convert to Video** - Convert image-based datasets to video format for efficient storage
|
||||
|
||||
The core implementation is in `lerobot.datasets.dataset_tools`.
|
||||
An example script detailing how to use the tools API is available in `examples/dataset/use_dataset_tools.py`.
|
||||
|
||||
## Command-Line Tool: lerobot-edit-dataset
|
||||
|
||||
`lerobot-edit-dataset` is a command-line script for editing datasets. It can be used to delete episodes, split datasets, merge datasets, add features, and remove features.
|
||||
`lerobot-edit-dataset` is a command-line script for editing datasets. It can be used to delete episodes, split datasets, merge datasets, add features, remove features, and convert image datasets to video format.
|
||||
|
||||
Run `lerobot-edit-dataset --help` for more information on the configuration of each operation.
|
||||
|
||||
@@ -86,9 +87,71 @@ lerobot-edit-dataset \
|
||||
--operation.feature_names "['observation.images.top']"
|
||||
```
|
||||
|
||||
#### Convert to Video
|
||||
|
||||
Convert an image-based dataset to video format, creating a new LeRobotDataset where images are stored as videos. This is useful for reducing storage requirements and improving data loading performance. The new dataset will have the exact same structure as the original, but with images encoded as MP4 videos in the proper LeRobot format.
|
||||
|
||||
```bash
|
||||
# Local-only: Save to a custom output directory (no hub push)
|
||||
lerobot-edit-dataset \
|
||||
--repo_id lerobot/pusht_image \
|
||||
--operation.type convert_to_video \
|
||||
--operation.output_dir /path/to/output/pusht_video
|
||||
|
||||
# Save with new repo_id (local storage)
|
||||
lerobot-edit-dataset \
|
||||
--repo_id lerobot/pusht_image \
|
||||
--new_repo_id lerobot/pusht_video \
|
||||
--operation.type convert_to_video
|
||||
|
||||
# Convert and push to Hugging Face Hub
|
||||
lerobot-edit-dataset \
|
||||
--repo_id lerobot/pusht_image \
|
||||
--new_repo_id lerobot/pusht_video \
|
||||
--operation.type convert_to_video \
|
||||
--push_to_hub true
|
||||
|
||||
# Convert with custom video codec and quality settings
|
||||
lerobot-edit-dataset \
|
||||
--repo_id lerobot/pusht_image \
|
||||
--operation.type convert_to_video \
|
||||
--operation.output_dir outputs/pusht_video \
|
||||
--operation.vcodec libsvtav1 \
|
||||
--operation.pix_fmt yuv420p \
|
||||
--operation.g 2 \
|
||||
--operation.crf 30
|
||||
|
||||
# Convert only specific episodes
|
||||
lerobot-edit-dataset \
|
||||
--repo_id lerobot/pusht_image \
|
||||
--operation.type convert_to_video \
|
||||
--operation.output_dir outputs/pusht_video \
|
||||
--operation.episode_indices "[0, 1, 2, 5, 10]"
|
||||
|
||||
# Convert with multiple workers for parallel processing
|
||||
lerobot-edit-dataset \
|
||||
--repo_id lerobot/pusht_image \
|
||||
--operation.type convert_to_video \
|
||||
--operation.output_dir outputs/pusht_video \
|
||||
--operation.num_workers 8
|
||||
```
|
||||
|
||||
**Parameters:**
|
||||
|
||||
- `output_dir`: Custom output directory (optional - by default uses `new_repo_id` or `{repo_id}_video`)
|
||||
- `vcodec`: Video codec to use - options: `h264`, `hevc`, `libsvtav1` (default: `libsvtav1`)
|
||||
- `pix_fmt`: Pixel format - options: `yuv420p`, `yuv444p` (default: `yuv420p`)
|
||||
- `g`: Group of pictures (GOP) size - lower values give better quality but larger files (default: 2)
|
||||
- `crf`: Constant rate factor - lower values give better quality but larger files, 0 is lossless (default: 30)
|
||||
- `fast_decode`: Fast decode tuning option (default: 0)
|
||||
- `episode_indices`: List of specific episodes to convert (default: all episodes)
|
||||
- `num_workers`: Number of parallel workers for processing (default: 4)
|
||||
|
||||
**Note:** The resulting dataset will be a proper LeRobotDataset with all cameras encoded as videos in the `videos/` directory, with parquet files containing only metadata (no raw image data). All episodes, stats, and tasks are preserved.
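To sanity-check a conversion, the resulting dataset can be loaded like any other LeRobotDataset. The repo id and output path below are illustrative and assume the first command above was used.

```python
from lerobot.datasets.lerobot_dataset import LeRobotDataset

# Load the converted dataset from its local output directory (paths are illustrative).
dataset = LeRobotDataset("lerobot/pusht_video", root="/path/to/output/pusht_video")

print(dataset.meta.camera_keys)  # camera features are now decoded from MP4 videos
frame = dataset[0]
print({k: v.shape for k, v in frame.items() if k in dataset.meta.camera_keys})
```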
|
||||
|
||||
### Push to Hub
|
||||
|
||||
Add the `--push_to_hub` flag to any command to automatically upload the resulting dataset to the Hugging Face Hub:
|
||||
Add the `--push_to_hub true` flag to any command to automatically upload the resulting dataset to the Hugging Face Hub:
|
||||
|
||||
```bash
|
||||
lerobot-edit-dataset \
|
||||
@@ -96,7 +159,7 @@ lerobot-edit-dataset \
|
||||
--new_repo_id lerobot/pusht_after_deletion \
|
||||
--operation.type delete_episodes \
|
||||
--operation.episode_indices "[0, 2, 5]" \
|
||||
--push_to_hub
|
||||
--push_to_hub true
|
||||
```
|
||||
|
||||
There is also a tool for adding features to a dataset that is not yet covered in `lerobot-edit-dataset`.
|
||||
|
||||
@@ -0,0 +1,528 @@
|
||||
# X-VLA: The First Soft-Prompted Robot Foundation Model for Any Robot, Any Task
|
||||
|
||||
## Overview
|
||||
|
||||
For years, robotics has aspired to build agents that can follow natural human instructions and operate dexterously across many environments and robot bodies. Recent breakthroughs in LLMs and VLMs suggest a path forward: extend these foundation-model architectures to embodied control by grounding them in actions. This has led to the rise of Vision-Language-Action (VLA) models, with the hope that a single generalist model could combine broad semantic understanding with robust manipulation skills.
|
||||
|
||||
But training such models is difficult. Robot data is fragmented across platforms, sensors, embodiments, and collection protocols. Heterogeneity appears everywhere: different arm configurations, different action spaces, different camera setups, different visual domains, and different task distributions. These inconsistencies create major distribution shifts that make pretraining unstable and adaptation unreliable.
|
||||
|
||||
Inspired by meta-learning and prompt learning, we ask: **"What if a VLA model could learn the structure of each robot and dataset the same way LLMs learn tasks, through prompts?"**
|
||||
|
||||
**X-VLA** is a soft-prompted, flow-matching VLA framework that treats each hardware setup as a "task" and encodes it using a small set of learnable embeddings. These **Soft Prompts** capture embodiment and domain-specific variations, guiding the Transformer from the earliest stages of multimodal fusion. With this mechanism, X-VLA can reconcile diverse robot morphologies, data types, and sensor setups within a single unified architecture.
|
||||
|
||||
<p align="center">
|
||||
<img
|
||||
src="https://huggingface.co/datasets/huggingface/documentation-images/resolve/main/lerobot/xvla-architecture.png"
|
||||
alt="XVLA Architecture"
|
||||
style="max-width: 100%; height: auto; width: 800px;"
|
||||
/>
|
||||
</p>
|
||||
|
||||
Built from pure Transformer encoders, X-VLA scales naturally with model size and dataset diversity. Across 6 simulation benchmarks and 3 real robots, Soft Prompts consistently outperform existing methods in handling hardware and domain differences. X-VLA-0.9B, trained on 290K episodes spanning seven robotic platforms, learns an embodiment-agnostic generalist policy in Phase I, and adapts efficiently to new robots in Phase II simply by learning a new set of prompts, while keeping the backbone frozen.
|
||||
|
||||
<p align="center">
|
||||
<img
|
||||
src="https://huggingface.co/datasets/huggingface/documentation-images/resolve/main/lerobot/xvla-architecture2.png"
|
||||
alt="XVLA Architecture 2"
|
||||
style="width: 60%; height: auto;"
|
||||
/>
|
||||
</p>
|
||||
|
||||
With only 1% of parameters tuned (9M), X-VLA-0.9B achieves near-π₀ performance on LIBERO and Simpler-WidowX, despite using **300× fewer trainable parameters**. It also demonstrates strong real-world dexterity with minimal demonstrations, including folding cloths in under two minutes.
|
||||
|
||||
<p align="center">
|
||||
<img
|
||||
src="https://huggingface.co/datasets/huggingface/documentation-images/resolve/main/lerobot/xvla-fold.png"
|
||||
alt="XVLA fold visualization"
|
||||
style="width: 95%; max-width: 1100px; height: auto;"
|
||||
/>
|
||||
</p>
|
||||
|
||||
X-VLA shows that generalist robot intelligence does not require increasingly complex architectures, only the right way to absorb heterogeneity. Soft Prompts offer a simple, scalable mechanism for unifying diverse robotic data, paving the way toward adaptable, cross-embodiment robot foundation models.
|
||||
|
||||
## Installation
|
||||
|
||||
After installing LeRobot, install the X-VLA dependencies:
|
||||
|
||||
```bash
|
||||
pip install -e .[xvla]
|
||||
```
|
||||
|
||||
After the new release, you'll be able to do:
|
||||
|
||||
```bash
|
||||
pip install lerobot[xvla]
|
||||
```
|
||||
|
||||
## Quick Start
|
||||
|
||||
### Basic Usage
|
||||
|
||||
To use X-VLA in your LeRobot configuration, specify the policy type as:
|
||||
|
||||
```bash
|
||||
policy.type=xvla
|
||||
```
|
||||
|
||||
### Evaluating Pre-trained Checkpoints
|
||||
|
||||
Example evaluation with LIBERO:
|
||||
|
||||
```bash
|
||||
lerobot-eval \
|
||||
--policy.path="lerobot/xvla-libero" \
|
||||
--env.type=libero \
|
||||
--env.task=libero_spatial,libero_goal,libero_10 \
|
||||
--env.control_mode=absolute \
|
||||
--eval.batch_size=1 \
|
||||
--eval.n_episodes=1 \
|
||||
--env.episode_length=800 \
|
||||
--seed=142
|
||||
```
|
||||
|
||||
## Available Checkpoints
|
||||
|
||||
### 🎯 Base Model
|
||||
|
||||
**[lerobot/xvla-base](https://huggingface.co/lerobot/xvla-base)**
|
||||
|
||||
A 0.9B parameter instantiation of X-VLA, trained with a carefully designed data processing and learning recipe. The training pipeline consists of two phases:
|
||||
|
||||
- **Phase I: Pretraining** - Pretrained on 290K episodes from Droid, Robomind, and Agibot, spanning seven platforms across five types of robotic arms (single-arm to bi-manual setups). By leveraging soft prompts to absorb embodiment-specific variations, the model learns an embodiment-agnostic generalist policy.
|
||||
|
||||
- **Phase II: Domain Adaptation** - Adapted to deployable policies for target domains. A new set of soft prompts is introduced and optimized to encode the hardware configuration of the novel domain, while the pretrained backbone remains frozen.
|
||||
|
||||
### Simulation Checkpoints
|
||||
|
||||
**[lerobot/xvla-libero](https://huggingface.co/lerobot/xvla-libero)**
|
||||
|
||||
Achieves 93% success rate on LIBERO benchmarks. Fine-tuned from the base model for simulation tasks.
|
||||
|
||||
**[lerobot/xvla-widowx](https://huggingface.co/lerobot/xvla-widowx)**
|
||||
|
||||
Fine-tuned on BridgeData for pick-and-place experiments on compact WidowX platforms. Demonstrates robust manipulation capabilities.
|
||||
|
||||
### 🤖 Real-World Checkpoints
|
||||
|
||||
**[lerobot/xvla-folding](https://huggingface.co/lerobot/xvla-folding)**
|
||||
|
||||
A fine-tuned dexterous manipulation model trained on the high-quality Soft-FOLD cloth folding dataset. Achieves 100% success rate over 2 hours of continuous cloth folding.
|
||||
|
||||
**[lerobot/xvla-agibot-world](https://huggingface.co/lerobot/xvla-agibot-world)**
|
||||
|
||||
Optimized for AgileX robot dexterous manipulation tasks.
|
||||
|
||||
**[lerobot/xvla-google-robot](https://huggingface.co/lerobot/xvla-google-robot)**
|
||||
|
||||
Adapted for Google Robot platforms.
|
||||
|
||||
## Training X-VLA
|
||||
|
||||
### Recommended Training Configuration
|
||||
|
||||
When fine-tuning X-VLA for a new embodiment or task, we recommend not freezing the VLM and setting `policy.dtype=bfloat16` to avoid OOM errors.
|
||||
|
||||
```bash
|
||||
lerobot-train \
|
||||
--dataset.repo_id=YOUR_DATASET \
|
||||
--output_dir=./outputs/xvla_training \
|
||||
--job_name=xvla_training \
|
||||
--policy.path="lerobot/xvla-base" \
|
||||
--policy.repo_id="HF_USER/xvla-your-robot" \
|
||||
--policy.dtype=bfloat16 \
|
||||
--steps=3000 \
|
||||
--policy.device=cuda \
|
||||
--policy.freeze_vision_encoder=false \
|
||||
--policy.freeze_language_encoder=false \
|
||||
--policy.train_policy_transformer=true \
|
||||
--policy.train_soft_prompts=true \
|
||||
--policy.action_mode=YOUR_ACTION_MODE
|
||||
```
|
||||
|
||||
### Training Parameters Explained
|
||||
|
||||
| Parameter | Default | Description |
|
||||
| -------------------------- | ------- | ---------------------------------------------- |
|
||||
| `freeze_vision_encoder` | `false` | Do not freeze the VLM vision encoder weights |
|
||||
| `freeze_language_encoder` | `false` | Do not freeze the VLM language encoder weights |
|
||||
| `train_policy_transformer` | `true` | Allow policy transformer layers to train |
|
||||
| `train_soft_prompts` | `true` | Allow soft prompts to train |
|
||||
|
||||
**💡 Best Practice**: For Phase II adaptation to new embodiments, do not freeze the VLM encoders and also train the policy transformer and soft prompts.
|
||||
|
||||
### Example: Training on Bimanual Robot
|
||||
|
||||
```bash
|
||||
lerobot-train \
|
||||
--dataset.repo_id=pepijn223/bimanual-so100-handover-cube \
|
||||
--output_dir=./outputs/xvla_bimanual \
|
||||
--job_name=xvla_so101_training \
|
||||
--policy.path="lerobot/xvla-base" \
|
||||
--policy.dtype=bfloat16 \
|
||||
--policy.repo_id="YOUR_USERNAME/xvla-biso101" \
|
||||
--steps=3000 \
|
||||
--policy.device=cuda \
|
||||
--policy.action_mode=so101_bimanual \
|
||||
--policy.freeze_vision_encoder=false \
|
||||
--policy.freeze_language_encoder=false \
|
||||
--policy.train_policy_transformer=true \
|
||||
--policy.train_soft_prompts=true
|
||||
```
|
||||
|
||||
💡 **Best Performance:** If you have sufficient computational resources and want to achieve the best X-VLA finetuning performance, you should follow the official finetuning strategy:
|
||||
|
||||
**🔥 Full-finetune all components with a custom learning-rate scheme**
|
||||
|
||||
To ensure stable optimization, the Vision-Language Model (VLM) must be trained with only 1/10 of the base learning rate, while all other components use the full LR.
|
||||
This LR ratio is crucial for achieving strong and stable finetuning performance. This is already done for you by default.
|
||||
❕Note
|
||||
|
||||
Completely matching the official reported performance may require an additional warm-up LR schedule for soft-prompts, which can bring minor improvements.
|
||||
We encourage implementing this in your customized training pipeline for optimal results.
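For readers who do customize the pipeline, the LR split itself is just PyTorch parameter groups. The snippet below is a minimal sketch with a stand-in module layout, not the actual X-VLA policy; LeRobot already applies this ratio by default.

```python
import torch
import torch.nn as nn

# Stand-in module layout: "vlm" represents the vision-language model, "transformer"
# the policy transformer and soft prompts. This is illustrative only.
policy = nn.ModuleDict({
    "vlm": nn.Linear(16, 16),
    "transformer": nn.Linear(16, 16),
})

base_lr = 1e-4
vlm_params = [p for n, p in policy.named_parameters() if n.startswith("vlm")]
other_params = [p for n, p in policy.named_parameters() if not n.startswith("vlm")]

optimizer = torch.optim.AdamW([
    {"params": vlm_params, "lr": base_lr / 10},  # VLM at 1/10 of the base LR
    {"params": other_params, "lr": base_lr},     # everything else at the full LR
])
```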
|
||||
|
||||
## Core Concepts
|
||||
|
||||
### 1. Action Modes
|
||||
|
||||
X-VLA uses an **Action Registry** system to handle different action spaces and embodiments. The `action_mode` parameter defines how actions are processed, what loss functions are used, and how predictions are post-processed.
|
||||
|
||||
#### Available Action Modes
|
||||
|
||||
| Action Mode | Action Dim | Description | Use Case |
|
||||
| ---------------- | ----------------------- | ------------------------------------------- | ------------------------------------ |
|
||||
| `ee6d` | 20 | End-effector with xyz, 6D rotation, gripper | Dual-arm setups with spatial control |
|
||||
| `joint` | 14 | Joint-space with gripper | Direct joint control robots |
|
||||
| `agibot_ee6d` | 20 | AGI-bot variant with MSE loss | AGI-bot platforms |
|
||||
| `so101_bimanual` | 20 (model), 12 (real) | SO101 bimanual robot | Bimanual manipulation tasks |
|
||||
| `auto` | 20 (model), auto (real) | Auto-detects action dim from dataset | **Recommended** for new robots |
|
||||
|
||||
#### Why Action Modes Matter
|
||||
|
||||
When you have a pretrained checkpoint like `lerobot/xvla-base` trained with `action_dim=20`, and you want to train on a dataset with a different action dimension (e.g., 14 for bimanual arms), you can't simply trim the action dimension. The action mode orchestrates:
|
||||
|
||||
1. **Loss Computation**: Different loss functions for different action components (MSE for joints, BCE for grippers, etc.)
|
||||
2. **Preprocessing**: Zeroing out gripper channels, padding dimensions
|
||||
3. **Postprocessing**: Applying sigmoid to gripper logits, trimming padding
|
||||
|
||||
#### Example: BimanualSO101 Action Space
|
||||
|
||||
The `so101_bimanual` action mode handles the mismatch between model output (20D) and real robot control (12D):
|
||||
|
||||
```python
|
||||
# Model outputs 20 dimensions for compatibility
|
||||
dim_action = 20
|
||||
|
||||
# Real robot only needs 12 dimensions
|
||||
# [left_arm (6), right_arm (6)] = [joints (5) + gripper (1)] × 2
|
||||
REAL_DIM = 12
|
||||
|
||||
# Preprocessing: Pad 12D actions to 20D for training
|
||||
# Postprocessing: Trim 20D predictions to 12D for deployment
|
||||
```
|
||||
|
||||
See the [action_hub.py](https://github.com/huggingface/lerobot/blob/main/src/lerobot/policies/xvla/action_hub.py) implementation for details.
|
||||
|
||||
#### Auto Action Mode (Recommended)
|
||||
|
||||
The `auto` action mode is the easiest way to use X-VLA with any robot. It automatically detects your dataset's action dimension and handles padding/trimming:
|
||||
|
||||
```bash
|
||||
lerobot-train \
|
||||
--policy.path="lerobot/xvla-base" \
|
||||
--policy.action_mode=auto \
|
||||
--policy.max_action_dim=20 \
|
||||
...
|
||||
```
|
||||
|
||||
**How it works:**
|
||||
|
||||
- Reads `action_feature.shape[-1]` from your dataset (e.g., 7 for Franka)
|
||||
- Model outputs `max_action_dim` (default 20) for pretrained compatibility
|
||||
- Loss is computed **only on the real dimensions**: `MSE(pred[:,:,:real_dim], target[:,:,:real_dim])`
|
||||
- Postprocess trims output back to `real_dim` for robot control
|
||||
|
||||
This eliminates the need to create custom action modes for most robots.
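The padding-and-masking idea behind `auto` can be sketched in a few lines of PyTorch. The shapes and dimensions below are illustrative, not the actual X-VLA implementation.

```python
import torch
import torch.nn.functional as F

max_action_dim = 20   # model output width, matching the pretrained checkpoint
real_dim = 7          # detected from the dataset, e.g. a single Franka arm

pred = torch.randn(8, 32, max_action_dim)   # (batch, chunk_size, padded action dim)
target = torch.randn(8, 32, real_dim)       # dataset actions in their real dimension

padded_target = F.pad(target, (0, max_action_dim - real_dim))        # zero-pad to model width
loss = F.mse_loss(pred[..., :real_dim], padded_target[..., :real_dim])  # loss on real dims only

action = pred[..., :real_dim]  # postprocess: trim back to the robot's dimension
```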
|
||||
|
||||
### 2. Domain IDs
|
||||
|
||||
Domain IDs are learnable identifiers for different robot configurations and camera setups. They allow X-VLA to distinguish between:
|
||||
|
||||
- Different robots (Robot 1 vs Robot 2)
|
||||
- Different camera configurations (cam1 vs cam2)
|
||||
- Different combinations (Robot1-cam1-cam2 vs Robot1-cam1 vs Robot2-cam1)
|
||||
|
||||
#### Setting Domain IDs
|
||||
|
||||
**During Training**: By default, domain_id is set to 0 for general training.
|
||||
|
||||
**During Evaluation**: Specify the domain_id that matches your checkpoint's training configuration.
|
||||
|
||||
```python
|
||||
# Example: LIBERO checkpoint uses domain_id=3
|
||||
domain_id = 3
|
||||
```
|
||||
|
||||
The domain_id is automatically added to observations by the `XVLAAddDomainIdProcessorStep` in the preprocessing pipeline.
|
||||
|
||||
The `lerobot/xvla-base` model has been trained on the following domain IDs. It is recommended to choose one that most resembles your robot/configuration:
|
||||
|
||||
#### Fine-tuning Datasets
|
||||
|
||||
| Dataset Name | Domain ID |
|
||||
| ---------------- | --------- |
|
||||
| Bridge | 0 |
|
||||
| RT1 | 1 |
|
||||
| Calvin | 2 |
|
||||
| libero | 3 |
|
||||
| widowx-air | 4 |
|
||||
| AIR-AGILEX-HQ | 5 |
|
||||
| robotwin2_abs_ee | 6 |
|
||||
| robotwin2_clean | 6 |
|
||||
| robocasa-human | 7 |
|
||||
| VLABench | 8 |
|
||||
| AGIBOT-challenge | 9 |
|
||||
| AIR-AGILEX | 10 |
|
||||
| AIRBOT | 18 |
|
||||
|
||||
### 3. Processor Steps
|
||||
|
||||
X-VLA requires specific preprocessing and postprocessing steps for proper operation.
|
||||
|
||||
#### Required Preprocessing Steps
|
||||
|
||||
1. **XVLAImageToFloatProcessorStep**: Converts images from [0, 255] to [0, 1] range
|
||||
2. **XVLAImageNetNormalizeProcessorStep**: Applies ImageNet normalization (required for VLM backbone)
|
||||
3. **XVLAAddDomainIdProcessorStep**: Adds domain_id to observations
|
||||
|
||||
#### Example Custom Processor
|
||||
|
||||
For LIBERO environments, a custom processor handles the specific observation format:
|
||||
|
||||
```python
|
||||
from lerobot.policies.xvla.processor_xvla import LiberoProcessorStep
|
||||
|
||||
processor = LiberoProcessorStep()
|
||||
# Handles robot_state dictionary, converts rotation matrices to 6D representation
|
||||
# Applies 180° image rotation for camera convention
|
||||
```
|
||||
|
||||
### 4. Configuration Parameters
|
||||
|
||||
Key configuration parameters for X-VLA:
|
||||
|
||||
```python
|
||||
# Observation and action
|
||||
n_obs_steps: int = 1 # Number of observation timesteps
|
||||
chunk_size: int = 32 # Action sequence length
|
||||
n_action_steps: int = 32 # Number of action steps to execute
|
||||
|
||||
# Model architecture
|
||||
hidden_size: int = 1024 # Transformer hidden dimension
|
||||
depth: int = 24 # Number of transformer layers
|
||||
num_heads: int = 16 # Number of attention heads
|
||||
num_domains: int = 30 # Maximum number of domain IDs
|
||||
len_soft_prompts: int = 32 # Length of soft prompt embeddings
|
||||
|
||||
# Action space
|
||||
action_mode: str = "ee6d" # Action space type (use "auto" for auto-detection)
|
||||
use_proprio: bool = True # Use proprioceptive state
|
||||
max_state_dim: int = 32 # Maximum state dimension
|
||||
max_action_dim: int = 20 # Max action dim for padding (used by "auto" mode)
|
||||
|
||||
# Vision
|
||||
num_image_views: int | None # Number of camera views
|
||||
resize_imgs_with_padding: tuple[int, int] | None # Target image size with padding
|
||||
|
||||
# Training
|
||||
num_denoising_steps: int = 10 # Flow matching denoising steps
|
||||
```
|
||||
|
||||
## Creating Custom Action Modes
|
||||
|
||||
If your robot has a unique action space, you can create a custom action mode:
|
||||
|
||||
### Step 1: Define Your Action Space
|
||||
|
||||
```python
|
||||
from lerobot.policies.xvla.action_hub import BaseActionSpace, register_action
|
||||
import torch.nn as nn
|
||||
|
||||
@register_action("my_custom_robot")
|
||||
class MyCustomActionSpace(BaseActionSpace):
|
||||
"""Custom action space for my robot."""
|
||||
|
||||
dim_action = 15 # Your robot's action dimension
|
||||
gripper_idx = (7, 14) # Gripper channel indices
|
||||
|
||||
def __init__(self):
|
||||
super().__init__()
|
||||
self.mse = nn.MSELoss()
|
||||
self.bce = nn.BCEWithLogitsLoss()
|
||||
|
||||
def compute_loss(self, pred, target):
|
||||
"""Define your loss computation."""
|
||||
# Example: MSE for joints, BCE for grippers
|
||||
joints_loss = self.mse(pred[:, :, :7], target[:, :, :7])
|
||||
gripper_loss = self.bce(pred[:, :, self.gripper_idx],
|
||||
target[:, :, self.gripper_idx])
|
||||
|
||||
return {
|
||||
"joints_loss": joints_loss,
|
||||
"gripper_loss": gripper_loss,
|
||||
}
|
||||
|
||||
def preprocess(self, proprio, action, mode="train"):
|
||||
"""Preprocess actions before training."""
|
||||
# Example: Zero out grippers in proprioception
|
||||
proprio_m = proprio.clone()
|
||||
action_m = action.clone() if action is not None else None
|
||||
proprio_m[..., self.gripper_idx] = 0.0
|
||||
if action_m is not None:
|
||||
action_m[..., self.gripper_idx] = 0.0
|
||||
return proprio_m, action_m
|
||||
|
||||
def postprocess(self, action):
|
||||
"""Post-process predictions for deployment."""
|
||||
# Example: Apply sigmoid to gripper logits
|
||||
action[..., self.gripper_idx] = torch.sigmoid(action[..., self.gripper_idx])
|
||||
return action
|
||||
```
|
||||
|
||||
### Step 2: Use Your Custom Action Mode
|
||||
|
||||
```bash
|
||||
lerobot-train \
|
||||
--policy.action_mode=my_custom_robot \
|
||||
--dataset.repo_id=YOUR_DATASET \
|
||||
--policy.path="lerobot/xvla-base" \
|
||||
...
|
||||
```
|
||||
|
||||
## Advanced Topics
|
||||
|
||||
### Multi-Camera Support
|
||||
|
||||
X-VLA supports multiple camera views through the `num_image_views` parameter:
|
||||
|
||||
```python
|
||||
# Configure for 3 camera views
|
||||
policy.num_image_views=3
|
||||
|
||||
# Add empty cameras if you have fewer physical cameras
|
||||
policy.empty_cameras=1 # Adds 1 zero-padded camera view
|
||||
```
|
||||
|
||||
### Custom Preprocessing Pipeline
|
||||
|
||||
Create a custom preprocessing pipeline for your environment:
|
||||
|
||||
```python
|
||||
from lerobot.processor import PolicyProcessorPipeline
|
||||
from lerobot.policies.xvla.processor_xvla import (
|
||||
XVLAImageToFloatProcessorStep,
|
||||
XVLAImageNetNormalizeProcessorStep,
|
||||
XVLAAddDomainIdProcessorStep,
|
||||
)
|
||||
|
||||
# Build custom pipeline
|
||||
preprocessor = PolicyProcessorPipeline(
|
||||
steps=[
|
||||
YourCustomProcessorStep(), # Your custom processing
|
||||
XVLAImageToFloatProcessorStep(), # Required: convert to float
|
||||
XVLAImageNetNormalizeProcessorStep(), # Required: ImageNet norm
|
||||
XVLAAddDomainIdProcessorStep(domain_id=5), # Your domain ID
|
||||
]
|
||||
)
|
||||
```
|
||||
|
||||
### Handling Different Action Dimensions
|
||||
|
||||
When your dataset has fewer action dimensions than the pretrained model:
|
||||
|
||||
**Option 1 (Recommended)**: Use `auto` action mode
|
||||
|
||||
```bash
|
||||
# Automatically detects your dataset's action dimension
|
||||
# Works with any robot without custom code
|
||||
policy.action_mode=auto
|
||||
policy.max_action_dim=20 # Match pretrained model
|
||||
```
|
||||
|
||||
**Option 2**: Use a predefined action mode with built-in padding
|
||||
|
||||
```python
|
||||
# Model expects 20D, dataset has 12D
|
||||
# Action mode handles padding internally
|
||||
action_mode = "so101_bimanual" # Pads 12 → 20
|
||||
```
|
||||
|
||||
**Option 3**: Create a custom action mode that maps dimensions explicitly
|
||||
|
||||
```python
|
||||
@register_action("my_mapped_action")
|
||||
class MappedActionSpace(BaseActionSpace):
|
||||
dim_action = 20
|
||||
REAL_DIM = 12
|
||||
|
||||
def _pad_to_model_dim(self, x):
|
||||
# Custom padding logic
|
||||
...
|
||||
```
|
||||
|
||||
## Troubleshooting
|
||||
|
||||
### Common Issues
|
||||
|
||||
**Issue**: "Action dimension mismatch"
|
||||
|
||||
- **Solution**: Check that your `action_mode` matches your robot's action space. Create a custom action mode if needed.
|
||||
|
||||
**Issue**: "Image values outside [0, 1] range"
|
||||
|
||||
- **Solution**: Ensure images are preprocessed with `XVLAImageToFloatProcessorStep` before normalization.
|
||||
|
||||
**Issue**: "Domain ID not found"
|
||||
|
||||
- **Solution**: Make sure `XVLAAddDomainIdProcessorStep` is in your preprocessing pipeline with the correct domain_id.
|
||||
|
||||
**Issue**: "Low success rate on new embodiment"
|
||||
|
||||
- **Solution**:
|
||||
1. Verify your action_mode is correct
|
||||
2. Check that soft prompts are being trained (`train_soft_prompts=True`)
|
||||
3. Ensure proper preprocessing (ImageNet normalization, domain_id)
|
||||
4. Consider increasing training steps
|
||||
|
||||
**Issue**: "Out of memory during training"
|
||||
|
||||
- **Solution**:
|
||||
1. Reduce `chunk_size` (e.g., from 32 to 16)
|
||||
2. Enable gradient checkpointing
|
||||
3. Reduce batch size
|
||||
4. Freeze more components
|
||||
|
||||
## Citation
|
||||
|
||||
If you use X-VLA in your research, please cite:
|
||||
|
||||
```bibtex
|
||||
@article{zheng2025x,
|
||||
title = {X-VLA: Soft-Prompted Transformer as Scalable Cross-Embodiment Vision-Language-Action Model},
|
||||
author = {Zheng, Jinliang and Li, Jianxiong and Wang, Zhihao and Liu, Dongxiu and Kang, Xirui
|
||||
and Feng, Yuchun and Zheng, Yinan and Zou, Jiayin and Chen, Yilun and Zeng, Jia and others},
|
||||
journal = {arXiv preprint arXiv:2510.10274},
|
||||
year = {2025}
|
||||
}
|
||||
```
|
||||
|
||||
## Additional Resources
|
||||
|
||||
- [X-VLA Paper](https://arxiv.org/pdf/2510.10274)
|
||||
- [LeRobot Documentation](https://github.com/huggingface/lerobot)
|
||||
- [Action Registry Implementation](https://github.com/huggingface/lerobot/src/lerobot/policies/xvla/action_hub.py)
|
||||
- [Processor Implementation](https://github.com/huggingface/lerobot/src/lerobot/policies/xvla/processor_xvla.py)
|
||||
- [Model Configuration](https://github.com/huggingface/lerobot/src/lerobot/policies/xvla/configuration_xvla.py)
|
||||
|
||||
## Contributing
|
||||
|
||||
We welcome contributions! If you've implemented a new action mode or processor for your robot, please consider submitting a PR to help the community.
|
||||
@@ -45,7 +45,7 @@ from lerobot.robots import ( # noqa: F401
|
||||
so101_follower,
|
||||
)
|
||||
from lerobot.utils.constants import ACTION
|
||||
from lerobot.utils.robot_utils import busy_wait
|
||||
from lerobot.utils.robot_utils import precise_sleep
|
||||
from lerobot.utils.utils import (
|
||||
init_logging,
|
||||
log_say,
|
||||
@@ -97,7 +97,7 @@ def replay(cfg: ReplayConfig):
|
||||
robot.send_action(action)
|
||||
|
||||
dt_s = time.perf_counter() - start_episode_t
|
||||
busy_wait(1 / dataset.fps - dt_s)
|
||||
precise_sleep(1 / dataset.fps - dt_s)
|
||||
|
||||
robot.disconnect()
|
||||
|
||||
|
||||
@@ -34,105 +34,106 @@ from huggingface_hub import HfApi
|
||||
import lerobot
|
||||
from lerobot.datasets.lerobot_dataset import LeRobotDataset, LeRobotDatasetMetadata
|
||||
|
||||
# We ported a number of existing datasets ourselves, use this to see the list:
|
||||
print("List of available datasets:")
|
||||
pprint(lerobot.available_datasets)
|
||||
|
||||
# You can also browse through the datasets created/ported by the community on the hub using the hub api:
|
||||
hub_api = HfApi()
|
||||
repo_ids = [info.id for info in hub_api.list_datasets(task_categories="robotics", tags=["LeRobot"])]
|
||||
pprint(repo_ids)
|
||||
def main():
|
||||
# We ported a number of existing datasets ourselves, use this to see the list:
|
||||
print("List of available datasets:")
|
||||
pprint(lerobot.available_datasets)
|
||||
|
||||
# Or simply explore them in your web browser directly at:
|
||||
# https://huggingface.co/datasets?other=LeRobot
|
||||
# You can also browse through the datasets created/ported by the community on the hub using the hub api:
|
||||
hub_api = HfApi()
|
||||
repo_ids = [info.id for info in hub_api.list_datasets(task_categories="robotics", tags=["LeRobot"])]
|
||||
pprint(repo_ids)
|
||||
|
||||
# Let's take this one for this example
|
||||
repo_id = "lerobot/aloha_mobile_cabinet"
|
||||
# We can have a look and fetch its metadata to know more about it:
|
||||
ds_meta = LeRobotDatasetMetadata(repo_id)
|
||||
# Or simply explore them in your web browser directly at:
|
||||
# https://huggingface.co/datasets?other=LeRobot
|
||||
|
||||
# By instantiating just this class, you can quickly access useful information about the content and the
|
||||
# structure of the dataset without downloading the actual data yet (only metadata files — which are
|
||||
# lightweight).
|
||||
print(f"Total number of episodes: {ds_meta.total_episodes}")
|
||||
print(f"Average number of frames per episode: {ds_meta.total_frames / ds_meta.total_episodes:.3f}")
|
||||
print(f"Frames per second used during data collection: {ds_meta.fps}")
|
||||
print(f"Robot type: {ds_meta.robot_type}")
|
||||
print(f"keys to access images from cameras: {ds_meta.camera_keys=}\n")
|
||||
# Let's take this one for this example
|
||||
repo_id = "lerobot/aloha_mobile_cabinet"
|
||||
# We can have a look and fetch its metadata to know more about it:
|
||||
ds_meta = LeRobotDatasetMetadata(repo_id)
|
||||
|
||||
print("Tasks:")
|
||||
print(ds_meta.tasks)
|
||||
print("Features:")
|
||||
pprint(ds_meta.features)
|
||||
# By instantiating just this class, you can quickly access useful information about the content and the
|
||||
# structure of the dataset without downloading the actual data yet (only metadata files — which are
|
||||
# lightweight).
|
||||
print(f"Total number of episodes: {ds_meta.total_episodes}")
|
||||
print(f"Average number of frames per episode: {ds_meta.total_frames / ds_meta.total_episodes:.3f}")
|
||||
print(f"Frames per second used during data collection: {ds_meta.fps}")
|
||||
print(f"Robot type: {ds_meta.robot_type}")
|
||||
print(f"keys to access images from cameras: {ds_meta.camera_keys=}\n")
|
||||
|
||||
# You can also get a short summary by simply printing the object:
|
||||
print(ds_meta)
|
||||
print("Tasks:")
|
||||
print(ds_meta.tasks)
|
||||
print("Features:")
|
||||
pprint(ds_meta.features)
|
||||
|
||||
# You can then load the actual dataset from the hub.
|
||||
# Either load any subset of episodes:
|
||||
dataset = LeRobotDataset(repo_id, episodes=[0, 10, 11, 23])
|
||||
# You can also get a short summary by simply printing the object:
|
||||
print(ds_meta)
|
||||
|
||||
# And see how many frames you have:
|
||||
print(f"Selected episodes: {dataset.episodes}")
|
||||
print(f"Number of episodes selected: {dataset.num_episodes}")
|
||||
print(f"Number of frames selected: {dataset.num_frames}")
|
||||
# You can then load the actual dataset from the hub.
|
||||
# Either load any subset of episodes:
|
||||
dataset = LeRobotDataset(repo_id, episodes=[0, 10, 11, 23])
|
||||
|
||||
# Or simply load the entire dataset:
|
||||
dataset = LeRobotDataset(repo_id)
|
||||
print(f"Number of episodes selected: {dataset.num_episodes}")
|
||||
print(f"Number of frames selected: {dataset.num_frames}")
|
||||
# And see how many frames you have:
|
||||
print(f"Selected episodes: {dataset.episodes}")
|
||||
print(f"Number of episodes selected: {dataset.num_episodes}")
|
||||
print(f"Number of frames selected: {dataset.num_frames}")
|
||||
|
||||
# The previous metadata class is contained in the 'meta' attribute of the dataset:
|
||||
print(dataset.meta)
|
||||
# Or simply load the entire dataset:
|
||||
dataset = LeRobotDataset(repo_id)
|
||||
print(f"Number of episodes selected: {dataset.num_episodes}")
|
||||
print(f"Number of frames selected: {dataset.num_frames}")
|
||||
|
||||
# LeRobotDataset actually wraps an underlying Hugging Face dataset
|
||||
# (see https://huggingface.co/docs/datasets for more information).
|
||||
print(dataset.hf_dataset)
|
||||
# The previous metadata class is contained in the 'meta' attribute of the dataset:
|
||||
print(dataset.meta)
|
||||
|
||||
# LeRobot datasets also subclasses PyTorch datasets so you can do everything you know and love from working
|
||||
# with the latter, like iterating through the dataset.
|
||||
# The __getitem__ iterates over the frames of the dataset. Since our datasets are also structured by
|
||||
# episodes, you can access the frame indices of any episode using dataset.meta.episodes. Here, we access
|
||||
# frame indices associated to the first episode:
|
||||
episode_index = 0
|
||||
from_idx = dataset.meta.episodes["dataset_from_index"][episode_index]
|
||||
to_idx = dataset.meta.episodes["dataset_to_index"][episode_index]
|
||||
# LeRobotDataset actually wraps an underlying Hugging Face dataset
|
||||
# (see https://huggingface.co/docs/datasets for more information).
|
||||
print(dataset.hf_dataset)
|
||||
|
||||
# Then we grab all the image frames from the first camera:
|
||||
camera_key = dataset.meta.camera_keys[0]
|
||||
frames = [dataset[idx][camera_key] for idx in range(from_idx, to_idx)]
|
||||
# LeRobot datasets also subclasses PyTorch datasets so you can do everything you know and love from working
|
||||
# with the latter, like iterating through the dataset.
|
||||
# The __getitem__ iterates over the frames of the dataset. Since our datasets are also structured by
|
||||
# episodes, you can access the frame indices of any episode using dataset.meta.episodes. Here, we access
|
||||
# frame indices associated to the first episode:
|
||||
episode_index = 0
|
||||
from_idx = dataset.meta.episodes["dataset_from_index"][episode_index]
|
||||
to_idx = dataset.meta.episodes["dataset_to_index"][episode_index]
|
||||
|
||||
# The objects returned by the dataset are all torch.Tensors
|
||||
print(type(frames[0]))
|
||||
print(frames[0].shape)
|
||||
# Then we grab all the image frames from the first camera:
|
||||
camera_key = dataset.meta.camera_keys[0]
|
||||
frames = [dataset[idx][camera_key] for idx in range(from_idx, to_idx)]
|
||||
|
||||
# Since we're using pytorch, the shape is in pytorch, channel-first convention (c, h, w).
|
||||
# We can compare this shape with the information available for that feature
|
||||
pprint(dataset.features[camera_key])
|
||||
# In particular:
|
||||
print(dataset.features[camera_key]["shape"])
|
||||
# The shape is in (h, w, c) which is a more universal format.
|
||||
# The objects returned by the dataset are all torch.Tensors
|
||||
print(type(frames[0]))
|
||||
print(frames[0].shape)
|
||||
|
||||
# For many machine learning applications we need to load the history of past observations or trajectories of
|
||||
# future actions. Our datasets can load previous and future frames for each key/modality, using timestamps
|
||||
# differences with the current loaded frame. For instance:
|
||||
delta_timestamps = {
|
||||
# loads 4 images: 1 second before current frame, 500 ms before, 200 ms before, and current frame
|
||||
camera_key: [-1, -0.5, -0.20, 0],
|
||||
# loads 6 state vectors: 1.5 seconds before, 1 second before, ... 200 ms, 100 ms, and current frame
|
||||
"observation.state": [-1.5, -1, -0.5, -0.20, -0.10, 0],
|
||||
# loads 64 action vectors: current frame, 1 frame in the future, 2 frames, ... 63 frames in the future
|
||||
"action": [t / dataset.fps for t in range(64)],
|
||||
}
|
||||
# Note that in any case, these delta_timestamps values need to be multiples of (1/fps) so that added to any
|
||||
# timestamp, you still get a valid timestamp.
|
||||
# Since we're using pytorch, the shape is in pytorch, channel-first convention (c, h, w).
|
||||
# We can compare this shape with the information available for that feature
|
||||
pprint(dataset.features[camera_key])
|
||||
# In particular:
|
||||
print(dataset.features[camera_key]["shape"])
|
||||
# The shape is in (h, w, c) which is a more universal format.
|
||||
|
||||
dataset = LeRobotDataset(repo_id, delta_timestamps=delta_timestamps)
|
||||
print(f"\n{dataset[0][camera_key].shape=}") # (4, c, h, w)
|
||||
print(f"{dataset[0]['observation.state'].shape=}") # (6, c)
|
||||
print(f"{dataset[0]['action'].shape=}\n") # (64, c)
|
||||
# For many machine learning applications we need to load the history of past observations or trajectories of
|
||||
# future actions. Our datasets can load previous and future frames for each key/modality, using timestamps
|
||||
# differences with the current loaded frame. For instance:
|
||||
delta_timestamps = {
|
||||
# loads 4 images: 1 second before current frame, 500 ms before, 200 ms before, and current frame
|
||||
camera_key: [-1, -0.5, -0.20, 0],
|
||||
# loads 6 state vectors: 1.5 seconds before, 1 second before, ... 200 ms, 100 ms, and current frame
|
||||
"observation.state": [-1.5, -1, -0.5, -0.20, -0.10, 0],
|
||||
# loads 64 action vectors: current frame, 1 frame in the future, 2 frames, ... 63 frames in the future
|
||||
"action": [t / dataset.fps for t in range(64)],
|
||||
}
|
||||
# Note that in any case, these delta_timestamps values need to be multiples of (1/fps) so that added to any
|
||||
# timestamp, you still get a valid timestamp.
|
||||
|
||||
dataset = LeRobotDataset(repo_id, delta_timestamps=delta_timestamps)
|
||||
print(f"\n{dataset[0][camera_key].shape=}") # (4, c, h, w)
|
||||
print(f"{dataset[0]['observation.state'].shape=}") # (6, c)
|
||||
print(f"{dataset[0]['action'].shape=}\n") # (64, c)
|
||||
|
||||
if __name__ == "__main__":
|
||||
dataloader = torch.utils.data.DataLoader(
|
||||
dataset,
|
||||
num_workers=4,
|
||||
@@ -144,3 +145,7 @@ if __name__ == "__main__":
|
||||
print(f"{batch['observation.state'].shape=}") # (32, 6, c)
|
||||
print(f"{batch['action'].shape=}") # (32, 64, c)
|
||||
break
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
main()
|
||||
|
||||
+86
-80
@@ -33,83 +33,68 @@ TASK_DESCRIPTION = "My task description"
|
||||
HF_MODEL_ID = "<hf_username>/<model_repo_id>"
|
||||
HF_DATASET_ID = "<hf_username>/<eval_dataset_repo_id>"
|
||||
|
||||
# Create the robot configuration & robot
|
||||
robot_config = LeKiwiClientConfig(remote_ip="172.18.134.136", id="lekiwi")
|
||||
|
||||
robot = LeKiwiClient(robot_config)
|
||||
def main():
|
||||
# Create the robot configuration & robot
|
||||
robot_config = LeKiwiClientConfig(remote_ip="172.18.134.136", id="lekiwi")
|
||||
|
||||
# Create policy
|
||||
policy = ACTPolicy.from_pretrained(HF_MODEL_ID)
|
||||
robot = LeKiwiClient(robot_config)
|
||||
|
||||
# Configure the dataset features
|
||||
action_features = hw_to_dataset_features(robot.action_features, ACTION)
|
||||
obs_features = hw_to_dataset_features(robot.observation_features, OBS_STR)
|
||||
dataset_features = {**action_features, **obs_features}
|
||||
# Create policy
|
||||
policy = ACTPolicy.from_pretrained(HF_MODEL_ID)
|
||||
|
||||
# Create the dataset
|
||||
dataset = LeRobotDataset.create(
|
||||
repo_id=HF_DATASET_ID,
|
||||
fps=FPS,
|
||||
features=dataset_features,
|
||||
robot_type=robot.name,
|
||||
use_videos=True,
|
||||
image_writer_threads=4,
|
||||
)
|
||||
# Configure the dataset features
|
||||
action_features = hw_to_dataset_features(robot.action_features, ACTION)
|
||||
obs_features = hw_to_dataset_features(robot.observation_features, OBS_STR)
|
||||
dataset_features = {**action_features, **obs_features}
|
||||
|
||||
# Build Policy Processors
|
||||
preprocessor, postprocessor = make_pre_post_processors(
|
||||
policy_cfg=policy,
|
||||
pretrained_path=HF_MODEL_ID,
|
||||
dataset_stats=dataset.meta.stats,
|
||||
# The inference device is automatically set to match the detected hardware, overriding any previous device settings from training to ensure compatibility.
|
||||
preprocessor_overrides={"device_processor": {"device": str(policy.config.device)}},
|
||||
)
|
||||
|
||||
# Connect the robot
|
||||
# To connect you already should have this script running on LeKiwi: `python -m lerobot.robots.lekiwi.lekiwi_host --robot.id=my_awesome_kiwi`
|
||||
robot.connect()
|
||||
|
||||
# TODO(Steven): Update this example to use pipelines
|
||||
teleop_action_processor, robot_action_processor, robot_observation_processor = make_default_processors()
|
||||
|
||||
# Initialize the keyboard listener and rerun visualization
|
||||
listener, events = init_keyboard_listener()
|
||||
init_rerun(session_name="lekiwi_evaluate")
|
||||
|
||||
if not robot.is_connected:
|
||||
raise ValueError("Robot is not connected!")
|
||||
|
||||
print("Starting evaluate loop...")
|
||||
recorded_episodes = 0
|
||||
while recorded_episodes < NUM_EPISODES and not events["stop_recording"]:
|
||||
log_say(f"Running inference, recording eval episode {recorded_episodes} of {NUM_EPISODES}")
|
||||
|
||||
# Main record loop
|
||||
record_loop(
|
||||
robot=robot,
|
||||
events=events,
|
||||
# Create the dataset
|
||||
dataset = LeRobotDataset.create(
|
||||
repo_id=HF_DATASET_ID,
|
||||
fps=FPS,
|
||||
policy=policy,
|
||||
preprocessor=preprocessor, # Pass the pre and post policy processors
|
||||
postprocessor=postprocessor,
|
||||
dataset=dataset,
|
||||
control_time_s=EPISODE_TIME_SEC,
|
||||
single_task=TASK_DESCRIPTION,
|
||||
display_data=True,
|
||||
teleop_action_processor=teleop_action_processor,
|
||||
robot_action_processor=robot_action_processor,
|
||||
robot_observation_processor=robot_observation_processor,
|
||||
features=dataset_features,
|
||||
robot_type=robot.name,
|
||||
use_videos=True,
|
||||
image_writer_threads=4,
|
||||
)
|
||||
|
||||
# Reset the environment if not stopping or re-recording
|
||||
if not events["stop_recording"] and (
|
||||
(recorded_episodes < NUM_EPISODES - 1) or events["rerecord_episode"]
|
||||
):
|
||||
log_say("Reset the environment")
|
||||
# Build Policy Processors
|
||||
preprocessor, postprocessor = make_pre_post_processors(
|
||||
policy_cfg=policy,
|
||||
pretrained_path=HF_MODEL_ID,
|
||||
dataset_stats=dataset.meta.stats,
|
||||
# The inference device is automatically set to match the detected hardware, overriding any previous device settings from training to ensure compatibility.
|
||||
preprocessor_overrides={"device_processor": {"device": str(policy.config.device)}},
|
||||
)
|
||||
|
||||
# Connect the robot
|
||||
# To connect you already should have this script running on LeKiwi: `python -m lerobot.robots.lekiwi.lekiwi_host --robot.id=my_awesome_kiwi`
|
||||
robot.connect()
|
||||
|
||||
# TODO(Steven): Update this example to use pipelines
|
||||
teleop_action_processor, robot_action_processor, robot_observation_processor = make_default_processors()
|
||||
|
||||
# Initialize the keyboard listener and rerun visualization
|
||||
listener, events = init_keyboard_listener()
|
||||
init_rerun(session_name="lekiwi_evaluate")
|
||||
|
||||
if not robot.is_connected:
|
||||
raise ValueError("Robot is not connected!")
|
||||
|
||||
print("Starting evaluate loop...")
|
||||
recorded_episodes = 0
|
||||
while recorded_episodes < NUM_EPISODES and not events["stop_recording"]:
|
||||
log_say(f"Running inference, recording eval episode {recorded_episodes} of {NUM_EPISODES}")
|
||||
|
||||
# Main record loop
|
||||
record_loop(
|
||||
robot=robot,
|
||||
events=events,
|
||||
fps=FPS,
|
||||
policy=policy,
|
||||
preprocessor=preprocessor, # Pass the pre and post policy processors
|
||||
postprocessor=postprocessor,
|
||||
dataset=dataset,
|
||||
control_time_s=EPISODE_TIME_SEC,
|
||||
single_task=TASK_DESCRIPTION,
|
||||
display_data=True,
|
||||
@@ -118,21 +103,42 @@ while recorded_episodes < NUM_EPISODES and not events["stop_recording"]:
|
||||
robot_observation_processor=robot_observation_processor,
|
||||
)
|
||||
|
||||
if events["rerecord_episode"]:
|
||||
log_say("Re-record episode")
|
||||
events["rerecord_episode"] = False
|
||||
events["exit_early"] = False
|
||||
dataset.clear_episode_buffer()
|
||||
continue
|
||||
# Reset the environment if not stopping or re-recording
|
||||
if not events["stop_recording"] and (
|
||||
(recorded_episodes < NUM_EPISODES - 1) or events["rerecord_episode"]
|
||||
):
|
||||
log_say("Reset the environment")
|
||||
record_loop(
|
||||
robot=robot,
|
||||
events=events,
|
||||
fps=FPS,
|
||||
control_time_s=EPISODE_TIME_SEC,
|
||||
single_task=TASK_DESCRIPTION,
|
||||
display_data=True,
|
||||
teleop_action_processor=teleop_action_processor,
|
||||
robot_action_processor=robot_action_processor,
|
||||
robot_observation_processor=robot_observation_processor,
|
||||
)
|
||||
|
||||
# Save episode
|
||||
dataset.save_episode()
|
||||
recorded_episodes += 1
|
||||
if events["rerecord_episode"]:
|
||||
log_say("Re-record episode")
|
||||
events["rerecord_episode"] = False
|
||||
events["exit_early"] = False
|
||||
dataset.clear_episode_buffer()
|
||||
continue
|
||||
|
||||
# Clean up
|
||||
log_say("Stop recording")
|
||||
robot.disconnect()
|
||||
listener.stop()
|
||||
# Save episode
|
||||
dataset.save_episode()
|
||||
recorded_episodes += 1
|
||||
|
||||
dataset.finalize()
|
||||
dataset.push_to_hub()
|
||||
# Clean up
|
||||
log_say("Stop recording")
|
||||
robot.disconnect()
|
||||
listener.stop()
|
||||
|
||||
dataset.finalize()
|
||||
dataset.push_to_hub()
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
main()
|
||||
|
||||
+82
-76
@@ -34,78 +34,62 @@ RESET_TIME_SEC = 10
|
||||
TASK_DESCRIPTION = "My task description"
|
||||
HF_REPO_ID = "<hf_username>/<dataset_repo_id>"
|
||||
|
||||
# Create the robot and teleoperator configurations
|
||||
robot_config = LeKiwiClientConfig(remote_ip="172.18.134.136", id="lekiwi")
|
||||
leader_arm_config = SO100LeaderConfig(port="/dev/tty.usbmodem585A0077581", id="my_awesome_leader_arm")
|
||||
keyboard_config = KeyboardTeleopConfig()
|
||||
|
||||
# Initialize the robot and teleoperator
|
||||
robot = LeKiwiClient(robot_config)
|
||||
leader_arm = SO100Leader(leader_arm_config)
|
||||
keyboard = KeyboardTeleop(keyboard_config)
|
||||
def main():
|
||||
# Create the robot and teleoperator configurations
|
||||
robot_config = LeKiwiClientConfig(remote_ip="172.18.134.136", id="lekiwi")
|
||||
leader_arm_config = SO100LeaderConfig(port="/dev/tty.usbmodem585A0077581", id="my_awesome_leader_arm")
|
||||
keyboard_config = KeyboardTeleopConfig()
|
||||
|
||||
# TODO(Steven): Update this example to use pipelines
|
||||
teleop_action_processor, robot_action_processor, robot_observation_processor = make_default_processors()
|
||||
# Initialize the robot and teleoperator
|
||||
robot = LeKiwiClient(robot_config)
|
||||
leader_arm = SO100Leader(leader_arm_config)
|
||||
keyboard = KeyboardTeleop(keyboard_config)
|
||||
|
||||
# Configure the dataset features
|
||||
action_features = hw_to_dataset_features(robot.action_features, ACTION)
|
||||
obs_features = hw_to_dataset_features(robot.observation_features, OBS_STR)
|
||||
dataset_features = {**action_features, **obs_features}
|
||||
# TODO(Steven): Update this example to use pipelines
|
||||
teleop_action_processor, robot_action_processor, robot_observation_processor = make_default_processors()
|
||||
|
||||
# Create the dataset
|
||||
dataset = LeRobotDataset.create(
|
||||
repo_id=HF_REPO_ID,
|
||||
fps=FPS,
|
||||
features=dataset_features,
|
||||
robot_type=robot.name,
|
||||
use_videos=True,
|
||||
image_writer_threads=4,
|
||||
)
|
||||
# Configure the dataset features
|
||||
action_features = hw_to_dataset_features(robot.action_features, ACTION)
|
||||
obs_features = hw_to_dataset_features(robot.observation_features, OBS_STR)
|
||||
dataset_features = {**action_features, **obs_features}
|
||||
|
||||
# Connect the robot and teleoperator
|
||||
# To connect you already should have this script running on LeKiwi: `python -m lerobot.robots.lekiwi.lekiwi_host --robot.id=my_awesome_kiwi`
|
||||
robot.connect()
|
||||
leader_arm.connect()
|
||||
keyboard.connect()
|
||||
|
||||
# Initialize the keyboard listener and rerun visualization
|
||||
listener, events = init_keyboard_listener()
|
||||
init_rerun(session_name="lekiwi_record")
|
||||
|
||||
if not robot.is_connected or not leader_arm.is_connected or not keyboard.is_connected:
|
||||
raise ValueError("Robot or teleop is not connected!")
|
||||
|
||||
print("Starting record loop...")
|
||||
recorded_episodes = 0
|
||||
while recorded_episodes < NUM_EPISODES and not events["stop_recording"]:
|
||||
log_say(f"Recording episode {recorded_episodes}")
|
||||
|
||||
# Main record loop
|
||||
record_loop(
|
||||
robot=robot,
|
||||
events=events,
|
||||
# Create the dataset
|
||||
dataset = LeRobotDataset.create(
|
||||
repo_id=HF_REPO_ID,
|
||||
fps=FPS,
|
||||
dataset=dataset,
|
||||
teleop=[leader_arm, keyboard],
|
||||
control_time_s=EPISODE_TIME_SEC,
|
||||
single_task=TASK_DESCRIPTION,
|
||||
display_data=True,
|
||||
teleop_action_processor=teleop_action_processor,
|
||||
robot_action_processor=robot_action_processor,
|
||||
robot_observation_processor=robot_observation_processor,
|
||||
features=dataset_features,
|
||||
robot_type=robot.name,
|
||||
use_videos=True,
|
||||
image_writer_threads=4,
|
||||
)
|
||||
|
||||
# Reset the environment if not stopping or re-recording
|
||||
if not events["stop_recording"] and (
|
||||
(recorded_episodes < NUM_EPISODES - 1) or events["rerecord_episode"]
|
||||
):
|
||||
log_say("Reset the environment")
|
||||
# Connect the robot and teleoperator
|
||||
# To connect you already should have this script running on LeKiwi: `python -m lerobot.robots.lekiwi.lekiwi_host --robot.id=my_awesome_kiwi`
|
||||
robot.connect()
|
||||
leader_arm.connect()
|
||||
keyboard.connect()
|
||||
|
||||
# Initialize the keyboard listener and rerun visualization
|
||||
listener, events = init_keyboard_listener()
|
||||
init_rerun(session_name="lekiwi_record")
|
||||
|
||||
if not robot.is_connected or not leader_arm.is_connected or not keyboard.is_connected:
|
||||
raise ValueError("Robot or teleop is not connected!")
|
||||
|
||||
print("Starting record loop...")
|
||||
recorded_episodes = 0
|
||||
while recorded_episodes < NUM_EPISODES and not events["stop_recording"]:
|
||||
log_say(f"Recording episode {recorded_episodes}")
|
||||
|
||||
# Main record loop
|
||||
record_loop(
|
||||
robot=robot,
|
||||
events=events,
|
||||
fps=FPS,
|
||||
dataset=dataset,
|
||||
teleop=[leader_arm, keyboard],
|
||||
control_time_s=RESET_TIME_SEC,
|
||||
control_time_s=EPISODE_TIME_SEC,
|
||||
single_task=TASK_DESCRIPTION,
|
||||
display_data=True,
|
||||
teleop_action_processor=teleop_action_processor,
|
||||
@@ -113,23 +97,45 @@ while recorded_episodes < NUM_EPISODES and not events["stop_recording"]:
|
||||
robot_observation_processor=robot_observation_processor,
|
||||
)
|
||||
|
||||
if events["rerecord_episode"]:
|
||||
log_say("Re-record episode")
|
||||
events["rerecord_episode"] = False
|
||||
events["exit_early"] = False
|
||||
dataset.clear_episode_buffer()
|
||||
continue
|
||||
# Reset the environment if not stopping or re-recording
|
||||
if not events["stop_recording"] and (
|
||||
(recorded_episodes < NUM_EPISODES - 1) or events["rerecord_episode"]
|
||||
):
|
||||
log_say("Reset the environment")
|
||||
record_loop(
|
||||
robot=robot,
|
||||
events=events,
|
||||
fps=FPS,
|
||||
teleop=[leader_arm, keyboard],
|
||||
control_time_s=RESET_TIME_SEC,
|
||||
single_task=TASK_DESCRIPTION,
|
||||
display_data=True,
|
||||
teleop_action_processor=teleop_action_processor,
|
||||
robot_action_processor=robot_action_processor,
|
||||
robot_observation_processor=robot_observation_processor,
|
||||
)
|
||||
|
||||
# Save episode
|
||||
dataset.save_episode()
|
||||
recorded_episodes += 1
|
||||
if events["rerecord_episode"]:
|
||||
log_say("Re-record episode")
|
||||
events["rerecord_episode"] = False
|
||||
events["exit_early"] = False
|
||||
dataset.clear_episode_buffer()
|
||||
continue
|
||||
|
||||
# Clean up
|
||||
log_say("Stop recording")
|
||||
robot.disconnect()
|
||||
leader_arm.disconnect()
|
||||
keyboard.disconnect()
|
||||
listener.stop()
|
||||
# Save episode
|
||||
dataset.save_episode()
|
||||
recorded_episodes += 1
|
||||
|
||||
dataset.finalize()
|
||||
dataset.push_to_hub()
|
||||
# Clean up
|
||||
log_say("Stop recording")
|
||||
robot.disconnect()
|
||||
leader_arm.disconnect()
|
||||
keyboard.disconnect()
|
||||
listener.stop()
|
||||
|
||||
dataset.finalize()
|
||||
dataset.push_to_hub()
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
main()
|
||||
|
||||
+32
-26
@@ -20,42 +20,48 @@ from lerobot.datasets.lerobot_dataset import LeRobotDataset
from lerobot.robots.lekiwi.config_lekiwi import LeKiwiClientConfig
from lerobot.robots.lekiwi.lekiwi_client import LeKiwiClient
from lerobot.utils.constants import ACTION
from lerobot.utils.robot_utils import busy_wait
from lerobot.utils.robot_utils import precise_sleep
from lerobot.utils.utils import log_say

EPISODE_IDX = 0

# Initialize the robot config
robot_config = LeKiwiClientConfig(remote_ip="172.18.134.136", id="lekiwi")

# Initialize the robot
robot = LeKiwiClient(robot_config)
def main():
# Initialize the robot config
robot_config = LeKiwiClientConfig(remote_ip="172.18.134.136", id="lekiwi")

# Fetch the dataset to replay
dataset = LeRobotDataset("<hf_username>/<dataset_repo_id>", episodes=[EPISODE_IDX])
# Filter dataset to only include frames from the specified episode since episodes are chunked in dataset V3.0
episode_frames = dataset.hf_dataset.filter(lambda x: x["episode_index"] == EPISODE_IDX)
actions = episode_frames.select_columns(ACTION)
# Initialize the robot
robot = LeKiwiClient(robot_config)

# Connect to the robot
robot.connect()
# Fetch the dataset to replay
dataset = LeRobotDataset("<hf_username>/<dataset_repo_id>", episodes=[EPISODE_IDX])
# Filter dataset to only include frames from the specified episode since episodes are chunked in dataset V3.0
episode_frames = dataset.hf_dataset.filter(lambda x: x["episode_index"] == EPISODE_IDX)
actions = episode_frames.select_columns(ACTION)

if not robot.is_connected:
raise ValueError("Robot is not connected!")
# Connect to the robot
robot.connect()

print("Starting replay loop...")
log_say(f"Replaying episode {EPISODE_IDX}")
for idx in range(len(episode_frames)):
t0 = time.perf_counter()
if not robot.is_connected:
raise ValueError("Robot is not connected!")

# Get recorded action from dataset
action = {
name: float(actions[idx][ACTION][i]) for i, name in enumerate(dataset.features[ACTION]["names"])
}
print("Starting replay loop...")
log_say(f"Replaying episode {EPISODE_IDX}")
for idx in range(len(episode_frames)):
t0 = time.perf_counter()

# Send action to robot
_ = robot.send_action(action)
# Get recorded action from dataset
action = {
name: float(actions[idx][ACTION][i]) for i, name in enumerate(dataset.features[ACTION]["names"])
}

busy_wait(max(1.0 / dataset.fps - (time.perf_counter() - t0), 0.0))
# Send action to robot
_ = robot.send_action(action)

robot.disconnect()
precise_sleep(max(1.0 / dataset.fps - (time.perf_counter() - t0), 0.0))

robot.disconnect()


if __name__ == "__main__":
main()

@@ -19,54 +19,60 @@ import time
from lerobot.robots.lekiwi import LeKiwiClient, LeKiwiClientConfig
from lerobot.teleoperators.keyboard.teleop_keyboard import KeyboardTeleop, KeyboardTeleopConfig
from lerobot.teleoperators.so100_leader import SO100Leader, SO100LeaderConfig
from lerobot.utils.robot_utils import busy_wait
from lerobot.utils.robot_utils import precise_sleep
from lerobot.utils.visualization_utils import init_rerun, log_rerun_data

FPS = 30

# Create the robot and teleoperator configurations
robot_config = LeKiwiClientConfig(remote_ip="172.18.134.136", id="my_lekiwi")
teleop_arm_config = SO100LeaderConfig(port="/dev/tty.usbmodem585A0077581", id="my_awesome_leader_arm")
keyboard_config = KeyboardTeleopConfig(id="my_laptop_keyboard")

# Initialize the robot and teleoperator
robot = LeKiwiClient(robot_config)
leader_arm = SO100Leader(teleop_arm_config)
keyboard = KeyboardTeleop(keyboard_config)
def main():
# Create the robot and teleoperator configurations
robot_config = LeKiwiClientConfig(remote_ip="172.18.134.136", id="my_lekiwi")
teleop_arm_config = SO100LeaderConfig(port="/dev/tty.usbmodem585A0077581", id="my_awesome_leader_arm")
keyboard_config = KeyboardTeleopConfig(id="my_laptop_keyboard")

# Connect to the robot and teleoperator
# To connect you already should have this script running on LeKiwi: `python -m lerobot.robots.lekiwi.lekiwi_host --robot.id=my_awesome_kiwi`
robot.connect()
leader_arm.connect()
keyboard.connect()
# Initialize the robot and teleoperator
robot = LeKiwiClient(robot_config)
leader_arm = SO100Leader(teleop_arm_config)
keyboard = KeyboardTeleop(keyboard_config)

# Init rerun viewer
init_rerun(session_name="lekiwi_teleop")
# Connect to the robot and teleoperator
# To connect you already should have this script running on LeKiwi: `python -m lerobot.robots.lekiwi.lekiwi_host --robot.id=my_awesome_kiwi`
robot.connect()
leader_arm.connect()
keyboard.connect()

if not robot.is_connected or not leader_arm.is_connected or not keyboard.is_connected:
raise ValueError("Robot or teleop is not connected!")
# Init rerun viewer
init_rerun(session_name="lekiwi_teleop")

print("Starting teleop loop...")
while True:
t0 = time.perf_counter()
if not robot.is_connected or not leader_arm.is_connected or not keyboard.is_connected:
raise ValueError("Robot or teleop is not connected!")

# Get robot observation
observation = robot.get_observation()
print("Starting teleop loop...")
while True:
t0 = time.perf_counter()

# Get teleop action
# Arm
arm_action = leader_arm.get_action()
arm_action = {f"arm_{k}": v for k, v in arm_action.items()}
# Keyboard
keyboard_keys = keyboard.get_action()
base_action = robot._from_keyboard_to_base_action(keyboard_keys)
# Get robot observation
observation = robot.get_observation()

action = {**arm_action, **base_action} if len(base_action) > 0 else arm_action
# Get teleop action
# Arm
arm_action = leader_arm.get_action()
arm_action = {f"arm_{k}": v for k, v in arm_action.items()}
# Keyboard
keyboard_keys = keyboard.get_action()
base_action = robot._from_keyboard_to_base_action(keyboard_keys)

# Send action to robot
_ = robot.send_action(action)
action = {**arm_action, **base_action} if len(base_action) > 0 else arm_action

# Visualize
log_rerun_data(observation=observation, action=action)
# Send action to robot
_ = robot.send_action(action)

busy_wait(max(1.0 / FPS - (time.perf_counter() - t0), 0.0))
# Visualize
log_rerun_data(observation=observation, action=action)

precise_sleep(max(1.0 / FPS - (time.perf_counter() - t0), 0.0))


if __name__ == "__main__":
main()

+135
-127
@@ -52,125 +52,114 @@ TASK_DESCRIPTION = "My task description"
|
||||
HF_MODEL_ID = "<hf_username>/<model_repo_id>"
|
||||
HF_DATASET_ID = "<hf_username>/<dataset_repo_id>"
|
||||
|
||||
# Create the robot configuration & robot
|
||||
camera_config = {"front": OpenCVCameraConfig(index_or_path=0, width=640, height=480, fps=FPS)}
|
||||
robot_config = SO100FollowerConfig(
|
||||
port="/dev/tty.usbmodem58760434471",
|
||||
id="my_awesome_follower_arm",
|
||||
cameras=camera_config,
|
||||
use_degrees=True,
|
||||
)
|
||||
|
||||
robot = SO100Follower(robot_config)
|
||||
|
||||
# Create policy
|
||||
policy = ACTPolicy.from_pretrained(HF_MODEL_ID)
|
||||
|
||||
# NOTE: It is highly recommended to use the urdf in the SO-ARM100 repo: https://github.com/TheRobotStudio/SO-ARM100/blob/main/Simulation/SO101/so101_new_calib.urdf
|
||||
kinematics_solver = RobotKinematics(
|
||||
urdf_path="./SO101/so101_new_calib.urdf",
|
||||
target_frame_name="gripper_frame_link",
|
||||
joint_names=list(robot.bus.motors.keys()),
|
||||
)
|
||||
|
||||
# Build pipeline to convert EE action to joints action
|
||||
robot_ee_to_joints_processor = RobotProcessorPipeline[tuple[RobotAction, RobotObservation], RobotAction](
|
||||
steps=[
|
||||
InverseKinematicsEEToJoints(
|
||||
kinematics=kinematics_solver,
|
||||
motor_names=list(robot.bus.motors.keys()),
|
||||
initial_guess_current_joints=True,
|
||||
),
|
||||
],
|
||||
to_transition=robot_action_observation_to_transition,
|
||||
to_output=transition_to_robot_action,
|
||||
)
|
||||
|
||||
# Build pipeline to convert joints observation to EE observation
|
||||
robot_joints_to_ee_pose_processor = RobotProcessorPipeline[RobotObservation, RobotObservation](
|
||||
steps=[
|
||||
ForwardKinematicsJointsToEE(kinematics=kinematics_solver, motor_names=list(robot.bus.motors.keys()))
|
||||
],
|
||||
to_transition=observation_to_transition,
|
||||
to_output=transition_to_observation,
|
||||
)
|
||||
|
||||
# Create the dataset
|
||||
dataset = LeRobotDataset.create(
|
||||
repo_id=HF_DATASET_ID,
|
||||
fps=FPS,
|
||||
features=combine_feature_dicts(
|
||||
aggregate_pipeline_dataset_features(
|
||||
pipeline=robot_joints_to_ee_pose_processor,
|
||||
initial_features=create_initial_features(observation=robot.observation_features),
|
||||
use_videos=True,
|
||||
),
|
||||
# User for now should be explicit on the feature keys that were used for record
|
||||
# Alternatively, the user can pass the processor step that has the right features
|
||||
aggregate_pipeline_dataset_features(
|
||||
pipeline=make_default_teleop_action_processor(),
|
||||
initial_features=create_initial_features(
|
||||
action={
|
||||
f"ee.{k}": PolicyFeature(type=FeatureType.ACTION, shape=(1,))
|
||||
for k in ["x", "y", "z", "wx", "wy", "wz", "gripper_pos"]
|
||||
}
|
||||
),
|
||||
use_videos=True,
|
||||
),
|
||||
),
|
||||
robot_type=robot.name,
|
||||
use_videos=True,
|
||||
image_writer_threads=4,
|
||||
)
|
||||
|
||||
# Build Policy Processors
|
||||
preprocessor, postprocessor = make_pre_post_processors(
|
||||
policy_cfg=policy,
|
||||
pretrained_path=HF_MODEL_ID,
|
||||
dataset_stats=dataset.meta.stats,
|
||||
# The inference device is automatically set to match the detected hardware, overriding any previous device settings from training to ensure compatibility.
|
||||
preprocessor_overrides={"device_processor": {"device": str(policy.config.device)}},
|
||||
)
|
||||
|
||||
# Connect the robot
|
||||
robot.connect()
|
||||
|
||||
# Initialize the keyboard listener and rerun visualization
|
||||
listener, events = init_keyboard_listener()
|
||||
init_rerun(session_name="phone_so100_evaluate")
|
||||
|
||||
if not robot.is_connected:
|
||||
raise ValueError("Robot is not connected!")
|
||||
|
||||
print("Starting evaluate loop...")
|
||||
episode_idx = 0
|
||||
for episode_idx in range(NUM_EPISODES):
|
||||
log_say(f"Running inference, recording eval episode {episode_idx + 1} of {NUM_EPISODES}")
|
||||
|
||||
# Main record loop
|
||||
record_loop(
|
||||
robot=robot,
|
||||
events=events,
|
||||
fps=FPS,
|
||||
policy=policy,
|
||||
preprocessor=preprocessor, # Pass the pre and post policy processors
|
||||
postprocessor=postprocessor,
|
||||
dataset=dataset,
|
||||
control_time_s=EPISODE_TIME_SEC,
|
||||
single_task=TASK_DESCRIPTION,
|
||||
display_data=True,
|
||||
teleop_action_processor=make_default_teleop_action_processor(),
|
||||
robot_action_processor=robot_ee_to_joints_processor,
|
||||
robot_observation_processor=robot_joints_to_ee_pose_processor,
|
||||
def main():
|
||||
# Create the robot configuration & robot
|
||||
camera_config = {"front": OpenCVCameraConfig(index_or_path=0, width=640, height=480, fps=FPS)}
|
||||
robot_config = SO100FollowerConfig(
|
||||
port="/dev/tty.usbmodem58760434471",
|
||||
id="my_awesome_follower_arm",
|
||||
cameras=camera_config,
|
||||
use_degrees=True,
|
||||
)
|
||||
|
||||
# Reset the environment if not stopping or re-recording
|
||||
if not events["stop_recording"] and ((episode_idx < NUM_EPISODES - 1) or events["rerecord_episode"]):
|
||||
log_say("Reset the environment")
|
||||
robot = SO100Follower(robot_config)
|
||||
|
||||
# Create policy
|
||||
policy = ACTPolicy.from_pretrained(HF_MODEL_ID)
|
||||
|
||||
# NOTE: It is highly recommended to use the urdf in the SO-ARM100 repo: https://github.com/TheRobotStudio/SO-ARM100/blob/main/Simulation/SO101/so101_new_calib.urdf
|
||||
kinematics_solver = RobotKinematics(
|
||||
urdf_path="./SO101/so101_new_calib.urdf",
|
||||
target_frame_name="gripper_frame_link",
|
||||
joint_names=list(robot.bus.motors.keys()),
|
||||
)
|
||||
|
||||
# Build pipeline to convert EE action to joints action
|
||||
robot_ee_to_joints_processor = RobotProcessorPipeline[tuple[RobotAction, RobotObservation], RobotAction](
|
||||
steps=[
|
||||
InverseKinematicsEEToJoints(
|
||||
kinematics=kinematics_solver,
|
||||
motor_names=list(robot.bus.motors.keys()),
|
||||
initial_guess_current_joints=True,
|
||||
),
|
||||
],
|
||||
to_transition=robot_action_observation_to_transition,
|
||||
to_output=transition_to_robot_action,
|
||||
)
|
||||
|
||||
# Build pipeline to convert joints observation to EE observation
|
||||
robot_joints_to_ee_pose_processor = RobotProcessorPipeline[RobotObservation, RobotObservation](
|
||||
steps=[
|
||||
ForwardKinematicsJointsToEE(
|
||||
kinematics=kinematics_solver, motor_names=list(robot.bus.motors.keys())
|
||||
)
|
||||
],
|
||||
to_transition=observation_to_transition,
|
||||
to_output=transition_to_observation,
|
||||
)
|
||||
|
||||
# Create the dataset
|
||||
dataset = LeRobotDataset.create(
|
||||
repo_id=HF_DATASET_ID,
|
||||
fps=FPS,
|
||||
features=combine_feature_dicts(
|
||||
aggregate_pipeline_dataset_features(
|
||||
pipeline=robot_joints_to_ee_pose_processor,
|
||||
initial_features=create_initial_features(observation=robot.observation_features),
|
||||
use_videos=True,
|
||||
),
|
||||
# User for now should be explicit on the feature keys that were used for record
|
||||
# Alternatively, the user can pass the processor step that has the right features
|
||||
aggregate_pipeline_dataset_features(
|
||||
pipeline=make_default_teleop_action_processor(),
|
||||
initial_features=create_initial_features(
|
||||
action={
|
||||
f"ee.{k}": PolicyFeature(type=FeatureType.ACTION, shape=(1,))
|
||||
for k in ["x", "y", "z", "wx", "wy", "wz", "gripper_pos"]
|
||||
}
|
||||
),
|
||||
use_videos=True,
|
||||
),
|
||||
),
|
||||
robot_type=robot.name,
|
||||
use_videos=True,
|
||||
image_writer_threads=4,
|
||||
)
|
||||
|
||||
# Build Policy Processors
|
||||
preprocessor, postprocessor = make_pre_post_processors(
|
||||
policy_cfg=policy,
|
||||
pretrained_path=HF_MODEL_ID,
|
||||
dataset_stats=dataset.meta.stats,
|
||||
# The inference device is automatically set to match the detected hardware, overriding any previous device settings from training to ensure compatibility.
|
||||
preprocessor_overrides={"device_processor": {"device": str(policy.config.device)}},
|
||||
)
|
||||
|
||||
# Connect the robot
|
||||
robot.connect()
|
||||
|
||||
# Initialize the keyboard listener and rerun visualization
|
||||
listener, events = init_keyboard_listener()
|
||||
init_rerun(session_name="phone_so100_evaluate")
|
||||
|
||||
if not robot.is_connected:
|
||||
raise ValueError("Robot is not connected!")
|
||||
|
||||
print("Starting evaluate loop...")
|
||||
episode_idx = 0
|
||||
for episode_idx in range(NUM_EPISODES):
|
||||
log_say(f"Running inference, recording eval episode {episode_idx + 1} of {NUM_EPISODES}")
|
||||
|
||||
# Main record loop
|
||||
record_loop(
|
||||
robot=robot,
|
||||
events=events,
|
||||
fps=FPS,
|
||||
policy=policy,
|
||||
preprocessor=preprocessor, # Pass the pre and post policy processors
|
||||
postprocessor=postprocessor,
|
||||
dataset=dataset,
|
||||
control_time_s=EPISODE_TIME_SEC,
|
||||
single_task=TASK_DESCRIPTION,
|
||||
display_data=True,
|
||||
@@ -179,21 +168,40 @@ for episode_idx in range(NUM_EPISODES):
|
||||
robot_observation_processor=robot_joints_to_ee_pose_processor,
|
||||
)
|
||||
|
||||
if events["rerecord_episode"]:
|
||||
log_say("Re-record episode")
|
||||
events["rerecord_episode"] = False
|
||||
events["exit_early"] = False
|
||||
dataset.clear_episode_buffer()
|
||||
continue
|
||||
# Reset the environment if not stopping or re-recording
|
||||
if not events["stop_recording"] and ((episode_idx < NUM_EPISODES - 1) or events["rerecord_episode"]):
|
||||
log_say("Reset the environment")
|
||||
record_loop(
|
||||
robot=robot,
|
||||
events=events,
|
||||
fps=FPS,
|
||||
control_time_s=EPISODE_TIME_SEC,
|
||||
single_task=TASK_DESCRIPTION,
|
||||
display_data=True,
|
||||
teleop_action_processor=make_default_teleop_action_processor(),
|
||||
robot_action_processor=robot_ee_to_joints_processor,
|
||||
robot_observation_processor=robot_joints_to_ee_pose_processor,
|
||||
)
|
||||
|
||||
# Save episode
|
||||
dataset.save_episode()
|
||||
episode_idx += 1
|
||||
if events["rerecord_episode"]:
|
||||
log_say("Re-record episode")
|
||||
events["rerecord_episode"] = False
|
||||
events["exit_early"] = False
|
||||
dataset.clear_episode_buffer()
|
||||
continue
|
||||
|
||||
# Clean up
|
||||
log_say("Stop recording")
|
||||
robot.disconnect()
|
||||
listener.stop()
|
||||
# Save episode
|
||||
dataset.save_episode()
|
||||
episode_idx += 1
|
||||
|
||||
dataset.finalize()
|
||||
dataset.push_to_hub()
|
||||
# Clean up
|
||||
log_say("Stop recording")
|
||||
robot.disconnect()
|
||||
listener.stop()
|
||||
|
||||
dataset.finalize()
|
||||
dataset.push_to_hub()
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
main()
|
||||
|
||||
+142
-133
@@ -50,133 +50,122 @@ RESET_TIME_SEC = 30
|
||||
TASK_DESCRIPTION = "My task description"
|
||||
HF_REPO_ID = "<hf_username>/<dataset_repo_id>"
|
||||
|
||||
# Create the robot and teleoperator configurations
|
||||
camera_config = {"front": OpenCVCameraConfig(index_or_path=0, width=640, height=480, fps=FPS)}
|
||||
robot_config = SO100FollowerConfig(
|
||||
port="/dev/tty.usbmodem5A460814411",
|
||||
id="my_awesome_follower_arm",
|
||||
cameras=camera_config,
|
||||
use_degrees=True,
|
||||
)
|
||||
teleop_config = PhoneConfig(phone_os=PhoneOS.IOS) # or PhoneOS.ANDROID
|
||||
|
||||
# Initialize the robot and teleoperator
|
||||
robot = SO100Follower(robot_config)
|
||||
phone = Phone(teleop_config)
|
||||
def main():
|
||||
# Create the robot and teleoperator configurations
|
||||
camera_config = {"front": OpenCVCameraConfig(index_or_path=0, width=640, height=480, fps=FPS)}
|
||||
robot_config = SO100FollowerConfig(
|
||||
port="/dev/tty.usbmodem5A460814411",
|
||||
id="my_awesome_follower_arm",
|
||||
cameras=camera_config,
|
||||
use_degrees=True,
|
||||
)
|
||||
teleop_config = PhoneConfig(phone_os=PhoneOS.IOS) # or PhoneOS.ANDROID
|
||||
|
||||
# NOTE: It is highly recommended to use the urdf in the SO-ARM100 repo: https://github.com/TheRobotStudio/SO-ARM100/blob/main/Simulation/SO101/so101_new_calib.urdf
|
||||
kinematics_solver = RobotKinematics(
|
||||
urdf_path="./SO101/so101_new_calib.urdf",
|
||||
target_frame_name="gripper_frame_link",
|
||||
joint_names=list(robot.bus.motors.keys()),
|
||||
)
|
||||
# Initialize the robot and teleoperator
|
||||
robot = SO100Follower(robot_config)
|
||||
phone = Phone(teleop_config)
|
||||
|
||||
# Build pipeline to convert phone action to EE action
|
||||
phone_to_robot_ee_pose_processor = RobotProcessorPipeline[tuple[RobotAction, RobotObservation], RobotAction](
|
||||
steps=[
|
||||
MapPhoneActionToRobotAction(platform=teleop_config.phone_os),
|
||||
EEReferenceAndDelta(
|
||||
kinematics=kinematics_solver,
|
||||
end_effector_step_sizes={"x": 0.5, "y": 0.5, "z": 0.5},
|
||||
motor_names=list(robot.bus.motors.keys()),
|
||||
use_latched_reference=True,
|
||||
),
|
||||
EEBoundsAndSafety(
|
||||
end_effector_bounds={"min": [-1.0, -1.0, -1.0], "max": [1.0, 1.0, 1.0]},
|
||||
max_ee_step_m=0.20,
|
||||
),
|
||||
GripperVelocityToJoint(speed_factor=20.0),
|
||||
],
|
||||
to_transition=robot_action_observation_to_transition,
|
||||
to_output=transition_to_robot_action,
|
||||
)
|
||||
|
||||
# Build pipeline to convert EE action to joints action
|
||||
robot_ee_to_joints_processor = RobotProcessorPipeline[tuple[RobotAction, RobotObservation], RobotAction](
|
||||
steps=[
|
||||
InverseKinematicsEEToJoints(
|
||||
kinematics=kinematics_solver,
|
||||
motor_names=list(robot.bus.motors.keys()),
|
||||
initial_guess_current_joints=True,
|
||||
),
|
||||
],
|
||||
to_transition=robot_action_observation_to_transition,
|
||||
to_output=transition_to_robot_action,
|
||||
)
|
||||
|
||||
# Build pipeline to convert joint observation to EE observation
|
||||
robot_joints_to_ee_pose = RobotProcessorPipeline[RobotObservation, RobotObservation](
|
||||
steps=[
|
||||
ForwardKinematicsJointsToEE(kinematics=kinematics_solver, motor_names=list(robot.bus.motors.keys()))
|
||||
],
|
||||
to_transition=observation_to_transition,
|
||||
to_output=transition_to_observation,
|
||||
)
|
||||
|
||||
# Create the dataset
|
||||
dataset = LeRobotDataset.create(
|
||||
repo_id=HF_REPO_ID,
|
||||
fps=FPS,
|
||||
features=combine_feature_dicts(
|
||||
# Run the feature contract of the pipelines
|
||||
# This tells you how the features would look like after the pipeline steps
|
||||
aggregate_pipeline_dataset_features(
|
||||
pipeline=phone_to_robot_ee_pose_processor,
|
||||
initial_features=create_initial_features(action=phone.action_features),
|
||||
use_videos=True,
|
||||
),
|
||||
aggregate_pipeline_dataset_features(
|
||||
pipeline=robot_joints_to_ee_pose,
|
||||
initial_features=create_initial_features(observation=robot.observation_features),
|
||||
use_videos=True,
|
||||
),
|
||||
),
|
||||
robot_type=robot.name,
|
||||
use_videos=True,
|
||||
image_writer_threads=4,
|
||||
)
|
||||
|
||||
# Connect the robot and teleoperator
|
||||
robot.connect()
|
||||
phone.connect()
|
||||
|
||||
# Initialize the keyboard listener and rerun visualization
|
||||
listener, events = init_keyboard_listener()
|
||||
init_rerun(session_name="phone_so100_record")
|
||||
|
||||
if not robot.is_connected or not phone.is_connected:
|
||||
raise ValueError("Robot or teleop is not connected!")
|
||||
|
||||
|
||||
print("Starting record loop. Move your phone to teleoperate the robot...")
|
||||
episode_idx = 0
|
||||
while episode_idx < NUM_EPISODES and not events["stop_recording"]:
|
||||
log_say(f"Recording episode {episode_idx + 1} of {NUM_EPISODES}")
|
||||
|
||||
# Main record loop
|
||||
record_loop(
|
||||
robot=robot,
|
||||
events=events,
|
||||
fps=FPS,
|
||||
teleop=phone,
|
||||
dataset=dataset,
|
||||
control_time_s=EPISODE_TIME_SEC,
|
||||
single_task=TASK_DESCRIPTION,
|
||||
display_data=True,
|
||||
teleop_action_processor=phone_to_robot_ee_pose_processor,
|
||||
robot_action_processor=robot_ee_to_joints_processor,
|
||||
robot_observation_processor=robot_joints_to_ee_pose,
|
||||
# NOTE: It is highly recommended to use the urdf in the SO-ARM100 repo: https://github.com/TheRobotStudio/SO-ARM100/blob/main/Simulation/SO101/so101_new_calib.urdf
|
||||
kinematics_solver = RobotKinematics(
|
||||
urdf_path="./SO101/so101_new_calib.urdf",
|
||||
target_frame_name="gripper_frame_link",
|
||||
joint_names=list(robot.bus.motors.keys()),
|
||||
)
|
||||
|
||||
# Reset the environment if not stopping or re-recording
|
||||
if not events["stop_recording"] and (episode_idx < NUM_EPISODES - 1 or events["rerecord_episode"]):
|
||||
log_say("Reset the environment")
|
||||
# Build pipeline to convert phone action to EE action
|
||||
phone_to_robot_ee_pose_processor = RobotProcessorPipeline[
|
||||
tuple[RobotAction, RobotObservation], RobotAction
|
||||
](
|
||||
steps=[
|
||||
MapPhoneActionToRobotAction(platform=teleop_config.phone_os),
|
||||
EEReferenceAndDelta(
|
||||
kinematics=kinematics_solver,
|
||||
end_effector_step_sizes={"x": 0.5, "y": 0.5, "z": 0.5},
|
||||
motor_names=list(robot.bus.motors.keys()),
|
||||
use_latched_reference=True,
|
||||
),
|
||||
EEBoundsAndSafety(
|
||||
end_effector_bounds={"min": [-1.0, -1.0, -1.0], "max": [1.0, 1.0, 1.0]},
|
||||
max_ee_step_m=0.20,
|
||||
),
|
||||
GripperVelocityToJoint(speed_factor=20.0),
|
||||
],
|
||||
to_transition=robot_action_observation_to_transition,
|
||||
to_output=transition_to_robot_action,
|
||||
)
|
||||
|
||||
# Build pipeline to convert EE action to joints action
|
||||
robot_ee_to_joints_processor = RobotProcessorPipeline[tuple[RobotAction, RobotObservation], RobotAction](
|
||||
steps=[
|
||||
InverseKinematicsEEToJoints(
|
||||
kinematics=kinematics_solver,
|
||||
motor_names=list(robot.bus.motors.keys()),
|
||||
initial_guess_current_joints=True,
|
||||
),
|
||||
],
|
||||
to_transition=robot_action_observation_to_transition,
|
||||
to_output=transition_to_robot_action,
|
||||
)
|
||||
|
||||
# Build pipeline to convert joint observation to EE observation
|
||||
robot_joints_to_ee_pose = RobotProcessorPipeline[RobotObservation, RobotObservation](
|
||||
steps=[
|
||||
ForwardKinematicsJointsToEE(
|
||||
kinematics=kinematics_solver, motor_names=list(robot.bus.motors.keys())
|
||||
)
|
||||
],
|
||||
to_transition=observation_to_transition,
|
||||
to_output=transition_to_observation,
|
||||
)
|
||||
|
||||
# Create the dataset
|
||||
dataset = LeRobotDataset.create(
|
||||
repo_id=HF_REPO_ID,
|
||||
fps=FPS,
|
||||
features=combine_feature_dicts(
|
||||
# Run the feature contract of the pipelines
|
||||
# This tells you how the features would look like after the pipeline steps
|
||||
aggregate_pipeline_dataset_features(
|
||||
pipeline=phone_to_robot_ee_pose_processor,
|
||||
initial_features=create_initial_features(action=phone.action_features),
|
||||
use_videos=True,
|
||||
),
|
||||
aggregate_pipeline_dataset_features(
|
||||
pipeline=robot_joints_to_ee_pose,
|
||||
initial_features=create_initial_features(observation=robot.observation_features),
|
||||
use_videos=True,
|
||||
),
|
||||
),
|
||||
robot_type=robot.name,
|
||||
use_videos=True,
|
||||
image_writer_threads=4,
|
||||
)
|
||||
|
||||
# Connect the robot and teleoperator
|
||||
robot.connect()
|
||||
phone.connect()
|
||||
|
||||
# Initialize the keyboard listener and rerun visualization
|
||||
listener, events = init_keyboard_listener()
|
||||
init_rerun(session_name="phone_so100_record")
|
||||
|
||||
if not robot.is_connected or not phone.is_connected:
|
||||
raise ValueError("Robot or teleop is not connected!")
|
||||
|
||||
print("Starting record loop. Move your phone to teleoperate the robot...")
|
||||
episode_idx = 0
|
||||
while episode_idx < NUM_EPISODES and not events["stop_recording"]:
|
||||
log_say(f"Recording episode {episode_idx + 1} of {NUM_EPISODES}")
|
||||
|
||||
# Main record loop
|
||||
record_loop(
|
||||
robot=robot,
|
||||
events=events,
|
||||
fps=FPS,
|
||||
teleop=phone,
|
||||
control_time_s=RESET_TIME_SEC,
|
||||
dataset=dataset,
|
||||
control_time_s=EPISODE_TIME_SEC,
|
||||
single_task=TASK_DESCRIPTION,
|
||||
display_data=True,
|
||||
teleop_action_processor=phone_to_robot_ee_pose_processor,
|
||||
@@ -184,22 +173,42 @@ while episode_idx < NUM_EPISODES and not events["stop_recording"]:
|
||||
robot_observation_processor=robot_joints_to_ee_pose,
|
||||
)
|
||||
|
||||
if events["rerecord_episode"]:
|
||||
log_say("Re-recording episode")
|
||||
events["rerecord_episode"] = False
|
||||
events["exit_early"] = False
|
||||
dataset.clear_episode_buffer()
|
||||
continue
|
||||
# Reset the environment if not stopping or re-recording
|
||||
if not events["stop_recording"] and (episode_idx < NUM_EPISODES - 1 or events["rerecord_episode"]):
|
||||
log_say("Reset the environment")
|
||||
record_loop(
|
||||
robot=robot,
|
||||
events=events,
|
||||
fps=FPS,
|
||||
teleop=phone,
|
||||
control_time_s=RESET_TIME_SEC,
|
||||
single_task=TASK_DESCRIPTION,
|
||||
display_data=True,
|
||||
teleop_action_processor=phone_to_robot_ee_pose_processor,
|
||||
robot_action_processor=robot_ee_to_joints_processor,
|
||||
robot_observation_processor=robot_joints_to_ee_pose,
|
||||
)
|
||||
|
||||
# Save episode
|
||||
dataset.save_episode()
|
||||
episode_idx += 1
|
||||
if events["rerecord_episode"]:
|
||||
log_say("Re-recording episode")
|
||||
events["rerecord_episode"] = False
|
||||
events["exit_early"] = False
|
||||
dataset.clear_episode_buffer()
|
||||
continue
|
||||
|
||||
# Clean up
|
||||
log_say("Stop recording")
|
||||
robot.disconnect()
|
||||
phone.disconnect()
|
||||
listener.stop()
|
||||
# Save episode
|
||||
dataset.save_episode()
|
||||
episode_idx += 1
|
||||
|
||||
dataset.finalize()
|
||||
dataset.push_to_hub()
|
||||
# Clean up
|
||||
log_say("Stop recording")
|
||||
robot.disconnect()
|
||||
phone.disconnect()
|
||||
listener.stop()
|
||||
|
||||
dataset.finalize()
|
||||
dataset.push_to_hub()
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
main()
|
||||
|
||||
@@ -29,72 +29,78 @@ from lerobot.robots.so100_follower.robot_kinematic_processor import (
|
||||
)
|
||||
from lerobot.robots.so100_follower.so100_follower import SO100Follower
|
||||
from lerobot.utils.constants import ACTION
|
||||
from lerobot.utils.robot_utils import busy_wait
|
||||
from lerobot.utils.robot_utils import precise_sleep
|
||||
from lerobot.utils.utils import log_say
|
||||
|
||||
EPISODE_IDX = 0
|
||||
HF_REPO_ID = "<hf_username>/<dataset_repo_id>"
|
||||
|
||||
# Initialize the robot config
|
||||
robot_config = SO100FollowerConfig(
|
||||
port="/dev/tty.usbmodem5A460814411", id="my_awesome_follower_arm", use_degrees=True
|
||||
)
|
||||
|
||||
# Initialize the robot
|
||||
robot = SO100Follower(robot_config)
|
||||
def main():
|
||||
# Initialize the robot config
|
||||
robot_config = SO100FollowerConfig(
|
||||
port="/dev/tty.usbmodem5A460814411", id="my_awesome_follower_arm", use_degrees=True
|
||||
)
|
||||
|
||||
# NOTE: It is highly recommended to use the urdf in the SO-ARM100 repo: https://github.com/TheRobotStudio/SO-ARM100/blob/main/Simulation/SO101/so101_new_calib.urdf
|
||||
kinematics_solver = RobotKinematics(
|
||||
urdf_path="./SO101/so101_new_calib.urdf",
|
||||
target_frame_name="gripper_frame_link",
|
||||
joint_names=list(robot.bus.motors.keys()),
|
||||
)
|
||||
# Initialize the robot
|
||||
robot = SO100Follower(robot_config)
|
||||
|
||||
# Build pipeline to convert EE action to joints action
|
||||
robot_ee_to_joints_processor = RobotProcessorPipeline[tuple[RobotAction, RobotObservation], RobotAction](
|
||||
steps=[
|
||||
InverseKinematicsEEToJoints(
|
||||
kinematics=kinematics_solver,
|
||||
motor_names=list(robot.bus.motors.keys()),
|
||||
initial_guess_current_joints=False, # Because replay is open loop
|
||||
),
|
||||
],
|
||||
to_transition=robot_action_observation_to_transition,
|
||||
to_output=transition_to_robot_action,
|
||||
)
|
||||
# NOTE: It is highly recommended to use the urdf in the SO-ARM100 repo: https://github.com/TheRobotStudio/SO-ARM100/blob/main/Simulation/SO101/so101_new_calib.urdf
|
||||
kinematics_solver = RobotKinematics(
|
||||
urdf_path="./SO101/so101_new_calib.urdf",
|
||||
target_frame_name="gripper_frame_link",
|
||||
joint_names=list(robot.bus.motors.keys()),
|
||||
)
|
||||
|
||||
# Fetch the dataset to replay
|
||||
dataset = LeRobotDataset(HF_REPO_ID, episodes=[EPISODE_IDX])
|
||||
# Filter dataset to only include frames from the specified episode since episodes are chunked in dataset V3.0
|
||||
episode_frames = dataset.hf_dataset.filter(lambda x: x["episode_index"] == EPISODE_IDX)
|
||||
actions = episode_frames.select_columns(ACTION)
|
||||
# Build pipeline to convert EE action to joints action
|
||||
robot_ee_to_joints_processor = RobotProcessorPipeline[tuple[RobotAction, RobotObservation], RobotAction](
|
||||
steps=[
|
||||
InverseKinematicsEEToJoints(
|
||||
kinematics=kinematics_solver,
|
||||
motor_names=list(robot.bus.motors.keys()),
|
||||
initial_guess_current_joints=False, # Because replay is open loop
|
||||
),
|
||||
],
|
||||
to_transition=robot_action_observation_to_transition,
|
||||
to_output=transition_to_robot_action,
|
||||
)
|
||||
|
||||
# Connect to the robot
|
||||
robot.connect()
|
||||
# Fetch the dataset to replay
|
||||
dataset = LeRobotDataset(HF_REPO_ID, episodes=[EPISODE_IDX])
|
||||
# Filter dataset to only include frames from the specified episode since episodes are chunked in dataset V3.0
|
||||
episode_frames = dataset.hf_dataset.filter(lambda x: x["episode_index"] == EPISODE_IDX)
|
||||
actions = episode_frames.select_columns(ACTION)
|
||||
|
||||
if not robot.is_connected:
|
||||
raise ValueError("Robot is not connected!")
|
||||
# Connect to the robot
|
||||
robot.connect()
|
||||
|
||||
print("Starting replay loop...")
|
||||
log_say(f"Replaying episode {EPISODE_IDX}")
|
||||
for idx in range(len(episode_frames)):
|
||||
t0 = time.perf_counter()
|
||||
if not robot.is_connected:
|
||||
raise ValueError("Robot is not connected!")
|
||||
|
||||
# Get recorded action from dataset
|
||||
ee_action = {
|
||||
name: float(actions[idx][ACTION][i]) for i, name in enumerate(dataset.features[ACTION]["names"])
|
||||
}
|
||||
print("Starting replay loop...")
|
||||
log_say(f"Replaying episode {EPISODE_IDX}")
|
||||
for idx in range(len(episode_frames)):
|
||||
t0 = time.perf_counter()
|
||||
|
||||
# Get robot observation
|
||||
robot_obs = robot.get_observation()
|
||||
# Get recorded action from dataset
|
||||
ee_action = {
|
||||
name: float(actions[idx][ACTION][i]) for i, name in enumerate(dataset.features[ACTION]["names"])
|
||||
}
|
||||
|
||||
# Dataset EE -> robot joints
|
||||
joint_action = robot_ee_to_joints_processor((ee_action, robot_obs))
|
||||
# Get robot observation
|
||||
robot_obs = robot.get_observation()
|
||||
|
||||
# Send action to robot
|
||||
_ = robot.send_action(joint_action)
|
||||
# Dataset EE -> robot joints
|
||||
joint_action = robot_ee_to_joints_processor((ee_action, robot_obs))
|
||||
|
||||
busy_wait(1.0 / dataset.fps - (time.perf_counter() - t0))
|
||||
# Send action to robot
|
||||
_ = robot.send_action(joint_action)
|
||||
|
||||
# Clean up
|
||||
robot.disconnect()
|
||||
precise_sleep(1.0 / dataset.fps - (time.perf_counter() - t0))
|
||||
|
||||
# Clean up
|
||||
robot.disconnect()
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
main()
|
||||
|
||||
@@ -32,82 +32,90 @@ from lerobot.robots.so100_follower.so100_follower import SO100Follower
|
||||
from lerobot.teleoperators.phone.config_phone import PhoneConfig, PhoneOS
|
||||
from lerobot.teleoperators.phone.phone_processor import MapPhoneActionToRobotAction
|
||||
from lerobot.teleoperators.phone.teleop_phone import Phone
|
||||
from lerobot.utils.robot_utils import busy_wait
|
||||
from lerobot.utils.robot_utils import precise_sleep
|
||||
from lerobot.utils.visualization_utils import init_rerun, log_rerun_data
|
||||
|
||||
FPS = 30
|
||||
|
||||
# Initialize the robot and teleoperator
|
||||
robot_config = SO100FollowerConfig(
|
||||
port="/dev/tty.usbmodem5A460814411", id="my_awesome_follower_arm", use_degrees=True
|
||||
)
|
||||
teleop_config = PhoneConfig(phone_os=PhoneOS.IOS) # or PhoneOS.ANDROID
|
||||
|
||||
# Initialize the robot and teleoperator
|
||||
robot = SO100Follower(robot_config)
|
||||
teleop_device = Phone(teleop_config)
|
||||
def main():
|
||||
# Initialize the robot and teleoperator
|
||||
robot_config = SO100FollowerConfig(
|
||||
port="/dev/tty.usbmodem5A460814411", id="my_awesome_follower_arm", use_degrees=True
|
||||
)
|
||||
teleop_config = PhoneConfig(phone_os=PhoneOS.IOS) # or PhoneOS.ANDROID
|
||||
|
||||
# NOTE: It is highly recommended to use the urdf in the SO-ARM100 repo: https://github.com/TheRobotStudio/SO-ARM100/blob/main/Simulation/SO101/so101_new_calib.urdf
|
||||
kinematics_solver = RobotKinematics(
|
||||
urdf_path="./SO101/so101_new_calib.urdf",
|
||||
target_frame_name="gripper_frame_link",
|
||||
joint_names=list(robot.bus.motors.keys()),
|
||||
)
|
||||
# Initialize the robot and teleoperator
|
||||
robot = SO100Follower(robot_config)
|
||||
teleop_device = Phone(teleop_config)
|
||||
|
||||
# Build pipeline to convert phone action to ee pose action to joint action
|
||||
phone_to_robot_joints_processor = RobotProcessorPipeline[tuple[RobotAction, RobotObservation], RobotAction](
|
||||
steps=[
|
||||
MapPhoneActionToRobotAction(platform=teleop_config.phone_os),
|
||||
EEReferenceAndDelta(
|
||||
kinematics=kinematics_solver,
|
||||
end_effector_step_sizes={"x": 0.5, "y": 0.5, "z": 0.5},
|
||||
motor_names=list(robot.bus.motors.keys()),
|
||||
use_latched_reference=True,
|
||||
),
|
||||
EEBoundsAndSafety(
|
||||
end_effector_bounds={"min": [-1.0, -1.0, -1.0], "max": [1.0, 1.0, 1.0]},
|
||||
max_ee_step_m=0.10,
|
||||
),
|
||||
GripperVelocityToJoint(
|
||||
speed_factor=20.0,
|
||||
),
|
||||
InverseKinematicsEEToJoints(
|
||||
kinematics=kinematics_solver,
|
||||
motor_names=list(robot.bus.motors.keys()),
|
||||
initial_guess_current_joints=True,
|
||||
),
|
||||
],
|
||||
to_transition=robot_action_observation_to_transition,
|
||||
to_output=transition_to_robot_action,
|
||||
)
|
||||
# NOTE: It is highly recommended to use the urdf in the SO-ARM100 repo: https://github.com/TheRobotStudio/SO-ARM100/blob/main/Simulation/SO101/so101_new_calib.urdf
|
||||
kinematics_solver = RobotKinematics(
|
||||
urdf_path="./SO101/so101_new_calib.urdf",
|
||||
target_frame_name="gripper_frame_link",
|
||||
joint_names=list(robot.bus.motors.keys()),
|
||||
)
|
||||
|
||||
# Connect to the robot and teleoperator
|
||||
robot.connect()
|
||||
teleop_device.connect()
|
||||
# Build pipeline to convert phone action to ee pose action to joint action
|
||||
phone_to_robot_joints_processor = RobotProcessorPipeline[
|
||||
tuple[RobotAction, RobotObservation], RobotAction
|
||||
](
|
||||
steps=[
|
||||
MapPhoneActionToRobotAction(platform=teleop_config.phone_os),
|
||||
EEReferenceAndDelta(
|
||||
kinematics=kinematics_solver,
|
||||
end_effector_step_sizes={"x": 0.5, "y": 0.5, "z": 0.5},
|
||||
motor_names=list(robot.bus.motors.keys()),
|
||||
use_latched_reference=True,
|
||||
),
|
||||
EEBoundsAndSafety(
|
||||
end_effector_bounds={"min": [-1.0, -1.0, -1.0], "max": [1.0, 1.0, 1.0]},
|
||||
max_ee_step_m=0.10,
|
||||
),
|
||||
GripperVelocityToJoint(
|
||||
speed_factor=20.0,
|
||||
),
|
||||
InverseKinematicsEEToJoints(
|
||||
kinematics=kinematics_solver,
|
||||
motor_names=list(robot.bus.motors.keys()),
|
||||
initial_guess_current_joints=True,
|
||||
),
|
||||
],
|
||||
to_transition=robot_action_observation_to_transition,
|
||||
to_output=transition_to_robot_action,
|
||||
)
|
||||
|
||||
# Init rerun viewer
|
||||
init_rerun(session_name="phone_so100_teleop")
|
||||
# Connect to the robot and teleoperator
|
||||
robot.connect()
|
||||
teleop_device.connect()
|
||||
|
||||
if not robot.is_connected or not teleop_device.is_connected:
|
||||
raise ValueError("Robot or teleop is not connected!")
|
||||
# Init rerun viewer
|
||||
init_rerun(session_name="phone_so100_teleop")
|
||||
|
||||
print("Starting teleop loop. Move your phone to teleoperate the robot...")
|
||||
while True:
|
||||
t0 = time.perf_counter()
|
||||
if not robot.is_connected or not teleop_device.is_connected:
|
||||
raise ValueError("Robot or teleop is not connected!")
|
||||
|
||||
# Get robot observation
|
||||
robot_obs = robot.get_observation()
|
||||
print("Starting teleop loop. Move your phone to teleoperate the robot...")
|
||||
while True:
|
||||
t0 = time.perf_counter()
|
||||
|
||||
# Get teleop action
|
||||
phone_obs = teleop_device.get_action()
|
||||
# Get robot observation
|
||||
robot_obs = robot.get_observation()
|
||||
|
||||
# Phone -> EE pose -> Joints transition
|
||||
joint_action = phone_to_robot_joints_processor((phone_obs, robot_obs))
|
||||
# Get teleop action
|
||||
phone_obs = teleop_device.get_action()
|
||||
|
||||
# Send action to robot
|
||||
_ = robot.send_action(joint_action)
|
||||
# Phone -> EE pose -> Joints transition
|
||||
joint_action = phone_to_robot_joints_processor((phone_obs, robot_obs))
|
||||
|
||||
# Visualize
|
||||
log_rerun_data(observation=phone_obs, action=joint_action)
|
||||
# Send action to robot
|
||||
_ = robot.send_action(joint_action)
|
||||
|
||||
busy_wait(max(1.0 / FPS - (time.perf_counter() - t0), 0.0))
|
||||
# Visualize
|
||||
log_rerun_data(observation=phone_obs, action=joint_action)
|
||||
|
||||
precise_sleep(max(1.0 / FPS - (time.perf_counter() - t0), 0.0))
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
main()
|
||||
|
||||
@@ -52,126 +52,114 @@ TASK_DESCRIPTION = "My task description"
|
||||
HF_MODEL_ID = "<hf_username>/<model_repo_id>"
|
||||
HF_DATASET_ID = "<hf_username>/<dataset_repo_id>"
|
||||
|
||||
# Create the robot configuration & robot
|
||||
camera_config = {"front": OpenCVCameraConfig(index_or_path=0, width=640, height=480, fps=FPS)}
|
||||
robot_config = SO100FollowerConfig(
|
||||
port="/dev/tty.usbmodem5A460814411",
|
||||
id="my_awesome_follower_arm",
|
||||
cameras=camera_config,
|
||||
use_degrees=True,
|
||||
)
|
||||
|
||||
robot = SO100Follower(robot_config)
|
||||
|
||||
# Create policy
|
||||
policy = ACTPolicy.from_pretrained(HF_MODEL_ID)
|
||||
|
||||
# NOTE: It is highly recommended to use the urdf in the SO-ARM100 repo: https://github.com/TheRobotStudio/SO-ARM100/blob/main/Simulation/SO101/so101_new_calib.urdf
|
||||
kinematics_solver = RobotKinematics(
|
||||
urdf_path="./SO101/so101_new_calib.urdf",
|
||||
target_frame_name="gripper_frame_link",
|
||||
joint_names=list(robot.bus.motors.keys()),
|
||||
)
|
||||
|
||||
# Build pipeline to convert EE action to joints action
|
||||
robot_ee_to_joints_processor = RobotProcessorPipeline[tuple[RobotAction, RobotObservation], RobotAction](
|
||||
steps=[
|
||||
InverseKinematicsEEToJoints(
|
||||
kinematics=kinematics_solver,
|
||||
motor_names=list(robot.bus.motors.keys()),
|
||||
initial_guess_current_joints=True,
|
||||
),
|
||||
],
|
||||
to_transition=robot_action_observation_to_transition,
|
||||
to_output=transition_to_robot_action,
|
||||
)
|
||||
|
||||
# Build pipeline to convert joints observation to EE observation
|
||||
robot_joints_to_ee_pose_processor = RobotProcessorPipeline[RobotObservation, RobotObservation](
|
||||
steps=[
|
||||
ForwardKinematicsJointsToEE(kinematics=kinematics_solver, motor_names=list(robot.bus.motors.keys()))
|
||||
],
|
||||
to_transition=observation_to_transition,
|
||||
to_output=transition_to_observation,
|
||||
)
|
||||
|
||||
|
||||
# Create the dataset
|
||||
dataset = LeRobotDataset.create(
|
||||
repo_id=HF_DATASET_ID,
|
||||
fps=FPS,
|
||||
features=combine_feature_dicts(
|
||||
aggregate_pipeline_dataset_features(
|
||||
pipeline=robot_joints_to_ee_pose_processor,
|
||||
initial_features=create_initial_features(observation=robot.observation_features),
|
||||
use_videos=True,
|
||||
),
|
||||
# For now, the user should be explicit about the feature keys that were used for recording
# Alternatively, the user can pass the processor step that has the right features
|
||||
aggregate_pipeline_dataset_features(
|
||||
pipeline=make_default_teleop_action_processor(),
|
||||
initial_features=create_initial_features(
|
||||
action={
|
||||
f"ee.{k}": PolicyFeature(type=FeatureType.ACTION, shape=(1,))
|
||||
for k in ["x", "y", "z", "wx", "wy", "wz", "gripper_pos"]
|
||||
}
|
||||
),
|
||||
use_videos=True,
|
||||
),
|
||||
),
|
||||
robot_type=robot.name,
|
||||
use_videos=True,
|
||||
image_writer_threads=4,
|
||||
)
|
||||
|
||||
# Build Policy Processors
|
||||
preprocessor, postprocessor = make_pre_post_processors(
|
||||
policy_cfg=policy,
|
||||
pretrained_path=HF_MODEL_ID,
|
||||
dataset_stats=dataset.meta.stats,
|
||||
# The inference device is automatically set to match the detected hardware, overriding any previous device settings from training to ensure compatibility.
|
||||
preprocessor_overrides={"device_processor": {"device": str(policy.config.device)}},
|
||||
)
|
||||
|
||||
# Connect the robot and teleoperator
|
||||
robot.connect()
|
||||
|
||||
# Initialize the keyboard listener and rerun visualization
|
||||
listener, events = init_keyboard_listener()
|
||||
init_rerun(session_name="so100_so100_evaluate")
|
||||
|
||||
if not robot.is_connected:
|
||||
raise ValueError("Robot is not connected!")
|
||||
|
||||
print("Starting evaluate loop...")
|
||||
episode_idx = 0
|
||||
for episode_idx in range(NUM_EPISODES):
|
||||
log_say(f"Running inference, recording eval episode {episode_idx + 1} of {NUM_EPISODES}")
|
||||
|
||||
# Main record loop
|
||||
record_loop(
|
||||
robot=robot,
|
||||
events=events,
|
||||
fps=FPS,
|
||||
policy=policy,
|
||||
preprocessor=preprocessor, # Pass the pre and post policy processors
|
||||
postprocessor=postprocessor,
|
||||
dataset=dataset,
|
||||
control_time_s=EPISODE_TIME_SEC,
|
||||
single_task=TASK_DESCRIPTION,
|
||||
display_data=True,
|
||||
teleop_action_processor=make_default_teleop_action_processor(),
|
||||
robot_action_processor=robot_ee_to_joints_processor,
|
||||
robot_observation_processor=robot_joints_to_ee_pose_processor,
|
||||
def main():
|
||||
# Create the robot configuration & robot
|
||||
camera_config = {"front": OpenCVCameraConfig(index_or_path=0, width=640, height=480, fps=FPS)}
|
||||
robot_config = SO100FollowerConfig(
|
||||
port="/dev/tty.usbmodem5A460814411",
|
||||
id="my_awesome_follower_arm",
|
||||
cameras=camera_config,
|
||||
use_degrees=True,
|
||||
)
|
||||
|
||||
# Reset the environment if not stopping or re-recording
|
||||
if not events["stop_recording"] and ((episode_idx < NUM_EPISODES - 1) or events["rerecord_episode"]):
|
||||
log_say("Reset the environment")
|
||||
robot = SO100Follower(robot_config)
|
||||
|
||||
# Create policy
|
||||
policy = ACTPolicy.from_pretrained(HF_MODEL_ID)
|
||||
|
||||
# NOTE: It is highly recommended to use the urdf in the SO-ARM100 repo: https://github.com/TheRobotStudio/SO-ARM100/blob/main/Simulation/SO101/so101_new_calib.urdf
|
||||
kinematics_solver = RobotKinematics(
|
||||
urdf_path="./SO101/so101_new_calib.urdf",
|
||||
target_frame_name="gripper_frame_link",
|
||||
joint_names=list(robot.bus.motors.keys()),
|
||||
)
|
||||
|
||||
# Build pipeline to convert EE action to joints action
|
||||
robot_ee_to_joints_processor = RobotProcessorPipeline[tuple[RobotAction, RobotObservation], RobotAction](
|
||||
steps=[
|
||||
InverseKinematicsEEToJoints(
|
||||
kinematics=kinematics_solver,
|
||||
motor_names=list(robot.bus.motors.keys()),
|
||||
initial_guess_current_joints=True,
|
||||
),
|
||||
],
|
||||
to_transition=robot_action_observation_to_transition,
|
||||
to_output=transition_to_robot_action,
|
||||
)
|
||||
|
||||
# Build pipeline to convert joints observation to EE observation
|
||||
robot_joints_to_ee_pose_processor = RobotProcessorPipeline[RobotObservation, RobotObservation](
|
||||
steps=[
|
||||
ForwardKinematicsJointsToEE(
|
||||
kinematics=kinematics_solver, motor_names=list(robot.bus.motors.keys())
|
||||
)
|
||||
],
|
||||
to_transition=observation_to_transition,
|
||||
to_output=transition_to_observation,
|
||||
)
|
||||
|
||||
# Create the dataset
|
||||
dataset = LeRobotDataset.create(
|
||||
repo_id=HF_DATASET_ID,
|
||||
fps=FPS,
|
||||
features=combine_feature_dicts(
|
||||
aggregate_pipeline_dataset_features(
|
||||
pipeline=robot_joints_to_ee_pose_processor,
|
||||
initial_features=create_initial_features(observation=robot.observation_features),
|
||||
use_videos=True,
|
||||
),
|
||||
# For now, the user should be explicit about the feature keys that were used for recording
# Alternatively, the user can pass the processor step that has the right features
|
||||
aggregate_pipeline_dataset_features(
|
||||
pipeline=make_default_teleop_action_processor(),
|
||||
initial_features=create_initial_features(
|
||||
action={
|
||||
f"ee.{k}": PolicyFeature(type=FeatureType.ACTION, shape=(1,))
|
||||
for k in ["x", "y", "z", "wx", "wy", "wz", "gripper_pos"]
|
||||
}
|
||||
),
|
||||
use_videos=True,
|
||||
),
|
||||
),
|
||||
robot_type=robot.name,
|
||||
use_videos=True,
|
||||
image_writer_threads=4,
|
||||
)
|
||||
|
||||
# Build Policy Processors
|
||||
preprocessor, postprocessor = make_pre_post_processors(
|
||||
policy_cfg=policy,
|
||||
pretrained_path=HF_MODEL_ID,
|
||||
dataset_stats=dataset.meta.stats,
|
||||
# The inference device is automatically set to match the detected hardware, overriding any previous device settings from training to ensure compatibility.
|
||||
preprocessor_overrides={"device_processor": {"device": str(policy.config.device)}},
|
||||
)
|
||||
|
||||
# Connect the robot and teleoperator
|
||||
robot.connect()
|
||||
|
||||
# Initialize the keyboard listener and rerun visualization
|
||||
listener, events = init_keyboard_listener()
|
||||
init_rerun(session_name="so100_so100_evaluate")
|
||||
|
||||
if not robot.is_connected:
|
||||
raise ValueError("Robot is not connected!")
|
||||
|
||||
print("Starting evaluate loop...")
|
||||
episode_idx = 0
|
||||
for episode_idx in range(NUM_EPISODES):
|
||||
log_say(f"Running inference, recording eval episode {episode_idx + 1} of {NUM_EPISODES}")
|
||||
|
||||
# Main record loop
|
||||
record_loop(
|
||||
robot=robot,
|
||||
events=events,
|
||||
fps=FPS,
|
||||
policy=policy,
|
||||
preprocessor=preprocessor, # Pass the pre and post policy processors
|
||||
postprocessor=postprocessor,
|
||||
dataset=dataset,
|
||||
control_time_s=EPISODE_TIME_SEC,
|
||||
single_task=TASK_DESCRIPTION,
|
||||
display_data=True,
|
||||
@@ -180,21 +168,40 @@ for episode_idx in range(NUM_EPISODES):
|
||||
robot_observation_processor=robot_joints_to_ee_pose_processor,
|
||||
)
|
||||
|
||||
if events["rerecord_episode"]:
|
||||
log_say("Re-record episode")
|
||||
events["rerecord_episode"] = False
|
||||
events["exit_early"] = False
|
||||
dataset.clear_episode_buffer()
|
||||
continue
|
||||
# Reset the environment if not stopping or re-recording
|
||||
if not events["stop_recording"] and ((episode_idx < NUM_EPISODES - 1) or events["rerecord_episode"]):
|
||||
log_say("Reset the environment")
|
||||
record_loop(
|
||||
robot=robot,
|
||||
events=events,
|
||||
fps=FPS,
|
||||
control_time_s=EPISODE_TIME_SEC,
|
||||
single_task=TASK_DESCRIPTION,
|
||||
display_data=True,
|
||||
teleop_action_processor=make_default_teleop_action_processor(),
|
||||
robot_action_processor=robot_ee_to_joints_processor,
|
||||
robot_observation_processor=robot_joints_to_ee_pose_processor,
|
||||
)
|
||||
|
||||
# Save episode
|
||||
dataset.save_episode()
|
||||
episode_idx += 1
|
||||
if events["rerecord_episode"]:
|
||||
log_say("Re-record episode")
|
||||
events["rerecord_episode"] = False
|
||||
events["exit_early"] = False
|
||||
dataset.clear_episode_buffer()
|
||||
continue
|
||||
|
||||
# Clean up
|
||||
log_say("Stop recording")
|
||||
robot.disconnect()
|
||||
listener.stop()
|
||||
# Save episode
|
||||
dataset.save_episode()
|
||||
episode_idx += 1
|
||||
|
||||
dataset.finalize()
|
||||
dataset.push_to_hub()
|
||||
# Clean up
|
||||
log_say("Stop recording")
|
||||
robot.disconnect()
|
||||
listener.stop()
|
||||
|
||||
dataset.finalize()
|
||||
dataset.push_to_hub()
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
main()
|
||||
|
||||
@@ -48,134 +48,122 @@ RESET_TIME_SEC = 30
|
||||
TASK_DESCRIPTION = "My task description"
|
||||
HF_REPO_ID = "<hf_username>/<dataset_repo_id>"
|
||||
|
||||
# Create the robot and teleoperator configurations
|
||||
camera_config = {"front": OpenCVCameraConfig(index_or_path=0, width=640, height=480, fps=FPS)}
|
||||
follower_config = SO100FollowerConfig(
|
||||
port="/dev/tty.usbmodem5A460814411", id="my_awesome_follower_arm", cameras=camera_config, use_degrees=True
|
||||
)
|
||||
leader_config = SO100LeaderConfig(port="/dev/tty.usbmodem5A460819811", id="my_awesome_leader_arm")
|
||||
|
||||
# Initialize the robot and teleoperator
|
||||
follower = SO100Follower(follower_config)
|
||||
leader = SO100Leader(leader_config)
|
||||
def main():
|
||||
# Create the robot and teleoperator configurations
|
||||
camera_config = {"front": OpenCVCameraConfig(index_or_path=0, width=640, height=480, fps=FPS)}
|
||||
follower_config = SO100FollowerConfig(
|
||||
port="/dev/tty.usbmodem5A460814411",
|
||||
id="my_awesome_follower_arm",
|
||||
cameras=camera_config,
|
||||
use_degrees=True,
|
||||
)
|
||||
leader_config = SO100LeaderConfig(port="/dev/tty.usbmodem5A460819811", id="my_awesome_leader_arm")
|
||||
|
||||
# NOTE: It is highly recommended to use the urdf in the SO-ARM100 repo: https://github.com/TheRobotStudio/SO-ARM100/blob/main/Simulation/SO101/so101_new_calib.urdf
|
||||
follower_kinematics_solver = RobotKinematics(
|
||||
urdf_path="./SO101/so101_new_calib.urdf",
|
||||
target_frame_name="gripper_frame_link",
|
||||
joint_names=list(follower.bus.motors.keys()),
|
||||
)
|
||||
# Initialize the robot and teleoperator
|
||||
follower = SO100Follower(follower_config)
|
||||
leader = SO100Leader(leader_config)
|
||||
|
||||
# NOTE: It is highly recommended to use the urdf in the SO-ARM100 repo: https://github.com/TheRobotStudio/SO-ARM100/blob/main/Simulation/SO101/so101_new_calib.urdf
|
||||
leader_kinematics_solver = RobotKinematics(
|
||||
urdf_path="./SO101/so101_new_calib.urdf",
|
||||
target_frame_name="gripper_frame_link",
|
||||
joint_names=list(leader.bus.motors.keys()),
|
||||
)
|
||||
|
||||
# Build pipeline to convert follower joints to EE observation
|
||||
follower_joints_to_ee = RobotProcessorPipeline[RobotObservation, RobotObservation](
|
||||
steps=[
|
||||
ForwardKinematicsJointsToEE(
|
||||
kinematics=follower_kinematics_solver, motor_names=list(follower.bus.motors.keys())
|
||||
),
|
||||
],
|
||||
to_transition=observation_to_transition,
|
||||
to_output=transition_to_observation,
|
||||
)
|
||||
|
||||
# Build pipeline to convert leader joints to EE action
|
||||
leader_joints_to_ee = RobotProcessorPipeline[tuple[RobotAction, RobotObservation], RobotAction](
|
||||
steps=[
|
||||
ForwardKinematicsJointsToEE(
|
||||
kinematics=leader_kinematics_solver, motor_names=list(leader.bus.motors.keys())
|
||||
),
|
||||
],
|
||||
to_transition=robot_action_observation_to_transition,
|
||||
to_output=transition_to_robot_action,
|
||||
)
|
||||
|
||||
# Build pipeline to convert EE action to follower joints
|
||||
ee_to_follower_joints = RobotProcessorPipeline[tuple[RobotAction, RobotObservation], RobotAction](
|
||||
[
|
||||
EEBoundsAndSafety(
|
||||
end_effector_bounds={"min": [-1.0, -1.0, -1.0], "max": [1.0, 1.0, 1.0]},
|
||||
max_ee_step_m=0.10,
|
||||
),
|
||||
InverseKinematicsEEToJoints(
|
||||
kinematics=follower_kinematics_solver,
|
||||
motor_names=list(follower.bus.motors.keys()),
|
||||
initial_guess_current_joints=True,
|
||||
),
|
||||
],
|
||||
to_transition=robot_action_observation_to_transition,
|
||||
to_output=transition_to_robot_action,
|
||||
)
|
||||
|
||||
# Create the dataset
|
||||
dataset = LeRobotDataset.create(
|
||||
repo_id=HF_REPO_ID,
|
||||
fps=FPS,
|
||||
features=combine_feature_dicts(
|
||||
# Run the feature contract of the pipelines
# This tells you how the features will look after the pipeline steps
|
||||
aggregate_pipeline_dataset_features(
|
||||
pipeline=leader_joints_to_ee,
|
||||
initial_features=create_initial_features(action=leader.action_features),
|
||||
use_videos=True,
|
||||
),
|
||||
aggregate_pipeline_dataset_features(
|
||||
pipeline=follower_joints_to_ee,
|
||||
initial_features=create_initial_features(observation=follower.observation_features),
|
||||
use_videos=True,
|
||||
),
|
||||
),
|
||||
robot_type=follower.name,
|
||||
use_videos=True,
|
||||
image_writer_threads=4,
|
||||
)
|
||||
|
||||
|
||||
# Connect the robot and teleoperator
|
||||
leader.connect()
|
||||
follower.connect()
|
||||
|
||||
# Initialize the keyboard listener and rerun visualization
|
||||
listener, events = init_keyboard_listener()
|
||||
init_rerun(session_name="recording_phone")
|
||||
|
||||
if not leader.is_connected or not follower.is_connected:
|
||||
raise ValueError("Robot or teleop is not connected!")
|
||||
|
||||
print("Starting record loop...")
|
||||
episode_idx = 0
|
||||
while episode_idx < NUM_EPISODES and not events["stop_recording"]:
|
||||
log_say(f"Recording episode {episode_idx + 1} of {NUM_EPISODES}")
|
||||
|
||||
# Main record loop
|
||||
record_loop(
|
||||
robot=follower,
|
||||
events=events,
|
||||
fps=FPS,
|
||||
teleop=leader,
|
||||
dataset=dataset,
|
||||
control_time_s=EPISODE_TIME_SEC,
|
||||
single_task=TASK_DESCRIPTION,
|
||||
display_data=True,
|
||||
teleop_action_processor=leader_joints_to_ee,
|
||||
robot_action_processor=ee_to_follower_joints,
|
||||
robot_observation_processor=follower_joints_to_ee,
|
||||
# NOTE: It is highly recommended to use the urdf in the SO-ARM100 repo: https://github.com/TheRobotStudio/SO-ARM100/blob/main/Simulation/SO101/so101_new_calib.urdf
|
||||
follower_kinematics_solver = RobotKinematics(
|
||||
urdf_path="./SO101/so101_new_calib.urdf",
|
||||
target_frame_name="gripper_frame_link",
|
||||
joint_names=list(follower.bus.motors.keys()),
|
||||
)
|
||||
|
||||
# Reset the environment if not stopping or re-recording
|
||||
if not events["stop_recording"] and (episode_idx < NUM_EPISODES - 1 or events["rerecord_episode"]):
|
||||
log_say("Reset the environment")
|
||||
# NOTE: It is highly recommended to use the urdf in the SO-ARM100 repo: https://github.com/TheRobotStudio/SO-ARM100/blob/main/Simulation/SO101/so101_new_calib.urdf
|
||||
leader_kinematics_solver = RobotKinematics(
|
||||
urdf_path="./SO101/so101_new_calib.urdf",
|
||||
target_frame_name="gripper_frame_link",
|
||||
joint_names=list(leader.bus.motors.keys()),
|
||||
)
|
||||
|
||||
# Build pipeline to convert follower joints to EE observation
|
||||
follower_joints_to_ee = RobotProcessorPipeline[RobotObservation, RobotObservation](
|
||||
steps=[
|
||||
ForwardKinematicsJointsToEE(
|
||||
kinematics=follower_kinematics_solver, motor_names=list(follower.bus.motors.keys())
|
||||
),
|
||||
],
|
||||
to_transition=observation_to_transition,
|
||||
to_output=transition_to_observation,
|
||||
)
|
||||
|
||||
# Build pipeline to convert leader joints to EE action
|
||||
leader_joints_to_ee = RobotProcessorPipeline[tuple[RobotAction, RobotObservation], RobotAction](
|
||||
steps=[
|
||||
ForwardKinematicsJointsToEE(
|
||||
kinematics=leader_kinematics_solver, motor_names=list(leader.bus.motors.keys())
|
||||
),
|
||||
],
|
||||
to_transition=robot_action_observation_to_transition,
|
||||
to_output=transition_to_robot_action,
|
||||
)
|
||||
|
||||
# Build pipeline to convert EE action to follower joints
|
||||
ee_to_follower_joints = RobotProcessorPipeline[tuple[RobotAction, RobotObservation], RobotAction](
|
||||
[
|
||||
EEBoundsAndSafety(
|
||||
end_effector_bounds={"min": [-1.0, -1.0, -1.0], "max": [1.0, 1.0, 1.0]},
|
||||
max_ee_step_m=0.10,
|
||||
),
|
||||
InverseKinematicsEEToJoints(
|
||||
kinematics=follower_kinematics_solver,
|
||||
motor_names=list(follower.bus.motors.keys()),
|
||||
initial_guess_current_joints=True,
|
||||
),
|
||||
],
|
||||
to_transition=robot_action_observation_to_transition,
|
||||
to_output=transition_to_robot_action,
|
||||
)
|
||||
|
||||
# Create the dataset
|
||||
dataset = LeRobotDataset.create(
|
||||
repo_id=HF_REPO_ID,
|
||||
fps=FPS,
|
||||
features=combine_feature_dicts(
|
||||
# Run the feature contract of the pipelines
# This tells you how the features will look after the pipeline steps
|
||||
aggregate_pipeline_dataset_features(
|
||||
pipeline=leader_joints_to_ee,
|
||||
initial_features=create_initial_features(action=leader.action_features),
|
||||
use_videos=True,
|
||||
),
|
||||
aggregate_pipeline_dataset_features(
|
||||
pipeline=follower_joints_to_ee,
|
||||
initial_features=create_initial_features(observation=follower.observation_features),
|
||||
use_videos=True,
|
||||
),
|
||||
),
|
||||
robot_type=follower.name,
|
||||
use_videos=True,
|
||||
image_writer_threads=4,
|
||||
)
|
||||
|
||||
# Connect the robot and teleoperator
|
||||
leader.connect()
|
||||
follower.connect()
|
||||
|
||||
# Initialize the keyboard listener and rerun visualization
|
||||
listener, events = init_keyboard_listener()
|
||||
init_rerun(session_name="recording_phone")
|
||||
|
||||
if not leader.is_connected or not follower.is_connected:
|
||||
raise ValueError("Robot or teleop is not connected!")
|
||||
|
||||
print("Starting record loop...")
|
||||
episode_idx = 0
|
||||
while episode_idx < NUM_EPISODES and not events["stop_recording"]:
|
||||
log_say(f"Recording episode {episode_idx + 1} of {NUM_EPISODES}")
|
||||
|
||||
# Main record loop
|
||||
record_loop(
|
||||
robot=follower,
|
||||
events=events,
|
||||
fps=FPS,
|
||||
teleop=leader,
|
||||
control_time_s=RESET_TIME_SEC,
|
||||
dataset=dataset,
|
||||
control_time_s=EPISODE_TIME_SEC,
|
||||
single_task=TASK_DESCRIPTION,
|
||||
display_data=True,
|
||||
teleop_action_processor=leader_joints_to_ee,
|
||||
@@ -183,22 +171,42 @@ while episode_idx < NUM_EPISODES and not events["stop_recording"]:
|
||||
robot_observation_processor=follower_joints_to_ee,
|
||||
)
|
||||
|
||||
if events["rerecord_episode"]:
|
||||
log_say("Re-recording episode")
|
||||
events["rerecord_episode"] = False
|
||||
events["exit_early"] = False
|
||||
dataset.clear_episode_buffer()
|
||||
continue
|
||||
# Reset the environment if not stopping or re-recording
|
||||
if not events["stop_recording"] and (episode_idx < NUM_EPISODES - 1 or events["rerecord_episode"]):
|
||||
log_say("Reset the environment")
|
||||
record_loop(
|
||||
robot=follower,
|
||||
events=events,
|
||||
fps=FPS,
|
||||
teleop=leader,
|
||||
control_time_s=RESET_TIME_SEC,
|
||||
single_task=TASK_DESCRIPTION,
|
||||
display_data=True,
|
||||
teleop_action_processor=leader_joints_to_ee,
|
||||
robot_action_processor=ee_to_follower_joints,
|
||||
robot_observation_processor=follower_joints_to_ee,
|
||||
)
|
||||
|
||||
# Save episode
|
||||
dataset.save_episode()
|
||||
episode_idx += 1
|
||||
if events["rerecord_episode"]:
|
||||
log_say("Re-recording episode")
|
||||
events["rerecord_episode"] = False
|
||||
events["exit_early"] = False
|
||||
dataset.clear_episode_buffer()
|
||||
continue
|
||||
|
||||
# Clean up
|
||||
log_say("Stop recording")
|
||||
leader.disconnect()
|
||||
follower.disconnect()
|
||||
listener.stop()
|
||||
# Save episode
|
||||
dataset.save_episode()
|
||||
episode_idx += 1
|
||||
|
||||
dataset.finalize()
|
||||
dataset.push_to_hub()
|
||||
# Clean up
|
||||
log_say("Stop recording")
|
||||
leader.disconnect()
|
||||
follower.disconnect()
|
||||
listener.stop()
|
||||
|
||||
dataset.finalize()
|
||||
dataset.push_to_hub()
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
main()
|
||||
|
||||
@@ -30,72 +30,78 @@ from lerobot.robots.so100_follower.robot_kinematic_processor import (
|
||||
)
|
||||
from lerobot.robots.so100_follower.so100_follower import SO100Follower
|
||||
from lerobot.utils.constants import ACTION
|
||||
from lerobot.utils.robot_utils import busy_wait
|
||||
from lerobot.utils.robot_utils import precise_sleep
|
||||
from lerobot.utils.utils import log_say
|
||||
|
||||
EPISODE_IDX = 0
|
||||
HF_REPO_ID = "<hf_username>/<dataset_repo_id>"
|
||||
|
||||
# Initialize the robot config
|
||||
robot_config = SO100FollowerConfig(
|
||||
port="/dev/tty.usbmodem5A460814411", id="my_awesome_follower_arm", use_degrees=True
|
||||
)
|
||||
|
||||
# Initialize the robot
|
||||
robot = SO100Follower(robot_config)
|
||||
def main():
|
||||
# Initialize the robot config
|
||||
robot_config = SO100FollowerConfig(
|
||||
port="/dev/tty.usbmodem5A460814411", id="my_awesome_follower_arm", use_degrees=True
|
||||
)
|
||||
|
||||
# NOTE: It is highly recommended to use the urdf in the SO-ARM100 repo: https://github.com/TheRobotStudio/SO-ARM100/blob/main/Simulation/SO101/so101_new_calib.urdf
|
||||
kinematics_solver = RobotKinematics(
|
||||
urdf_path="./SO101/so101_new_calib.urdf",
|
||||
target_frame_name="gripper_frame_link",
|
||||
joint_names=list(robot.bus.motors.keys()),
|
||||
)
|
||||
# Initialize the robot
|
||||
robot = SO100Follower(robot_config)
|
||||
|
||||
# Build pipeline to convert EE action to joints action
|
||||
robot_ee_to_joints_processor = RobotProcessorPipeline[tuple[RobotAction, RobotObservation], RobotAction](
|
||||
steps=[
|
||||
InverseKinematicsEEToJoints(
|
||||
kinematics=kinematics_solver,
|
||||
motor_names=list(robot.bus.motors.keys()),
|
||||
initial_guess_current_joints=False, # Because replay is open loop
|
||||
),
|
||||
],
|
||||
to_transition=robot_action_observation_to_transition,
|
||||
to_output=transition_to_robot_action,
|
||||
)
|
||||
# NOTE: It is highly recommended to use the urdf in the SO-ARM100 repo: https://github.com/TheRobotStudio/SO-ARM100/blob/main/Simulation/SO101/so101_new_calib.urdf
|
||||
kinematics_solver = RobotKinematics(
|
||||
urdf_path="./SO101/so101_new_calib.urdf",
|
||||
target_frame_name="gripper_frame_link",
|
||||
joint_names=list(robot.bus.motors.keys()),
|
||||
)
|
||||
|
||||
# Fetch the dataset to replay
|
||||
dataset = LeRobotDataset(HF_REPO_ID, episodes=[EPISODE_IDX])
|
||||
# Filter dataset to only include frames from the specified episode since episodes are chunked in dataset V3.0
|
||||
episode_frames = dataset.hf_dataset.filter(lambda x: x["episode_index"] == EPISODE_IDX)
|
||||
actions = episode_frames.select_columns(ACTION)
|
||||
# Build pipeline to convert EE action to joints action
|
||||
robot_ee_to_joints_processor = RobotProcessorPipeline[tuple[RobotAction, RobotObservation], RobotAction](
|
||||
steps=[
|
||||
InverseKinematicsEEToJoints(
|
||||
kinematics=kinematics_solver,
|
||||
motor_names=list(robot.bus.motors.keys()),
|
||||
initial_guess_current_joints=False, # Because replay is open loop
|
||||
),
|
||||
],
|
||||
to_transition=robot_action_observation_to_transition,
|
||||
to_output=transition_to_robot_action,
|
||||
)
|
||||
|
||||
# Connect to the robot
|
||||
robot.connect()
|
||||
# Fetch the dataset to replay
|
||||
dataset = LeRobotDataset(HF_REPO_ID, episodes=[EPISODE_IDX])
|
||||
# Filter dataset to only include frames from the specified episode since episodes are chunked in dataset V3.0
|
||||
episode_frames = dataset.hf_dataset.filter(lambda x: x["episode_index"] == EPISODE_IDX)
|
||||
actions = episode_frames.select_columns(ACTION)
|
||||
|
||||
if not robot.is_connected:
|
||||
raise ValueError("Robot is not connected!")
|
||||
# Connect to the robot
|
||||
robot.connect()
|
||||
|
||||
print("Starting replay loop...")
|
||||
log_say(f"Replaying episode {EPISODE_IDX}")
|
||||
for idx in range(len(episode_frames)):
|
||||
t0 = time.perf_counter()
|
||||
if not robot.is_connected:
|
||||
raise ValueError("Robot is not connected!")
|
||||
|
||||
# Get recorded action from dataset
|
||||
ee_action = {
|
||||
name: float(actions[idx][ACTION][i]) for i, name in enumerate(dataset.features[ACTION]["names"])
|
||||
}
|
||||
print("Starting replay loop...")
|
||||
log_say(f"Replaying episode {EPISODE_IDX}")
|
||||
for idx in range(len(episode_frames)):
|
||||
t0 = time.perf_counter()
|
||||
|
||||
# Get robot observation
|
||||
robot_obs = robot.get_observation()
|
||||
# Get recorded action from dataset
|
||||
ee_action = {
|
||||
name: float(actions[idx][ACTION][i]) for i, name in enumerate(dataset.features[ACTION]["names"])
|
||||
}
|
||||
|
||||
# Dataset EE -> robot joints
|
||||
joint_action = robot_ee_to_joints_processor((ee_action, robot_obs))
|
||||
# Get robot observation
|
||||
robot_obs = robot.get_observation()
|
||||
|
||||
# Send action to robot
|
||||
_ = robot.send_action(joint_action)
|
||||
# Dataset EE -> robot joints
|
||||
joint_action = robot_ee_to_joints_processor((ee_action, robot_obs))
|
||||
|
||||
busy_wait(1.0 / dataset.fps - (time.perf_counter() - t0))
|
||||
# Send action to robot
|
||||
_ = robot.send_action(joint_action)
|
||||
|
||||
# Clean up
|
||||
robot.disconnect()
|
||||
precise_sleep(1.0 / dataset.fps - (time.perf_counter() - t0))
|
||||
|
||||
# Clean up
|
||||
robot.disconnect()
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
main()
|
||||
|
||||
@@ -32,90 +32,96 @@ from lerobot.robots.so100_follower.robot_kinematic_processor import (
|
||||
from lerobot.robots.so100_follower.so100_follower import SO100Follower
|
||||
from lerobot.teleoperators.so100_leader.config_so100_leader import SO100LeaderConfig
|
||||
from lerobot.teleoperators.so100_leader.so100_leader import SO100Leader
|
||||
from lerobot.utils.robot_utils import busy_wait
|
||||
from lerobot.utils.robot_utils import precise_sleep
|
||||
from lerobot.utils.visualization_utils import init_rerun, log_rerun_data
|
||||
|
||||
FPS = 30
|
||||
|
||||
# Initialize the robot and teleoperator config
|
||||
follower_config = SO100FollowerConfig(
|
||||
port="/dev/tty.usbmodem5A460814411", id="my_awesome_follower_arm", use_degrees=True
|
||||
)
|
||||
leader_config = SO100LeaderConfig(port="/dev/tty.usbmodem5A460819811", id="my_awesome_leader_arm")
|
||||
|
||||
# Initialize the robot and teleoperator
|
||||
follower = SO100Follower(follower_config)
|
||||
leader = SO100Leader(leader_config)
|
||||
def main():
|
||||
# Initialize the robot and teleoperator config
|
||||
follower_config = SO100FollowerConfig(
|
||||
port="/dev/tty.usbmodem5A460814411", id="my_awesome_follower_arm", use_degrees=True
|
||||
)
|
||||
leader_config = SO100LeaderConfig(port="/dev/tty.usbmodem5A460819811", id="my_awesome_leader_arm")
|
||||
|
||||
# NOTE: It is highly recommended to use the urdf in the SO-ARM100 repo: https://github.com/TheRobotStudio/SO-ARM100/blob/main/Simulation/SO101/so101_new_calib.urdf
|
||||
follower_kinematics_solver = RobotKinematics(
|
||||
urdf_path="./SO101/so101_new_calib.urdf",
|
||||
target_frame_name="gripper_frame_link",
|
||||
joint_names=list(follower.bus.motors.keys()),
|
||||
)
|
||||
# Initialize the robot and teleoperator
|
||||
follower = SO100Follower(follower_config)
|
||||
leader = SO100Leader(leader_config)
|
||||
|
||||
# NOTE: It is highly recommended to use the urdf in the SO-ARM100 repo: https://github.com/TheRobotStudio/SO-ARM100/blob/main/Simulation/SO101/so101_new_calib.urdf
|
||||
leader_kinematics_solver = RobotKinematics(
|
||||
urdf_path="./SO101/so101_new_calib.urdf",
|
||||
target_frame_name="gripper_frame_link",
|
||||
joint_names=list(leader.bus.motors.keys()),
|
||||
)
|
||||
# NOTE: It is highly recommended to use the urdf in the SO-ARM100 repo: https://github.com/TheRobotStudio/SO-ARM100/blob/main/Simulation/SO101/so101_new_calib.urdf
|
||||
follower_kinematics_solver = RobotKinematics(
|
||||
urdf_path="./SO101/so101_new_calib.urdf",
|
||||
target_frame_name="gripper_frame_link",
|
||||
joint_names=list(follower.bus.motors.keys()),
|
||||
)
|
||||
|
||||
# Build pipeline to convert teleop joints to EE action
|
||||
leader_to_ee = RobotProcessorPipeline[RobotAction, RobotAction](
|
||||
steps=[
|
||||
ForwardKinematicsJointsToEE(
|
||||
kinematics=leader_kinematics_solver, motor_names=list(leader.bus.motors.keys())
|
||||
),
|
||||
],
|
||||
to_transition=robot_action_to_transition,
|
||||
to_output=transition_to_robot_action,
|
||||
)
|
||||
# NOTE: It is highly recommended to use the urdf in the SO-ARM100 repo: https://github.com/TheRobotStudio/SO-ARM100/blob/main/Simulation/SO101/so101_new_calib.urdf
|
||||
leader_kinematics_solver = RobotKinematics(
|
||||
urdf_path="./SO101/so101_new_calib.urdf",
|
||||
target_frame_name="gripper_frame_link",
|
||||
joint_names=list(leader.bus.motors.keys()),
|
||||
)
|
||||
|
||||
# Build pipeline to convert EE action to robot joints
|
||||
ee_to_follower_joints = RobotProcessorPipeline[tuple[RobotAction, RobotObservation], RobotAction](
|
||||
[
|
||||
EEBoundsAndSafety(
|
||||
end_effector_bounds={"min": [-1.0, -1.0, -1.0], "max": [1.0, 1.0, 1.0]},
|
||||
max_ee_step_m=0.10,
|
||||
),
|
||||
InverseKinematicsEEToJoints(
|
||||
kinematics=follower_kinematics_solver,
|
||||
motor_names=list(follower.bus.motors.keys()),
|
||||
initial_guess_current_joints=False,
|
||||
),
|
||||
],
|
||||
to_transition=robot_action_observation_to_transition,
|
||||
to_output=transition_to_robot_action,
|
||||
)
|
||||
# Build pipeline to convert teleop joints to EE action
|
||||
leader_to_ee = RobotProcessorPipeline[RobotAction, RobotAction](
|
||||
steps=[
|
||||
ForwardKinematicsJointsToEE(
|
||||
kinematics=leader_kinematics_solver, motor_names=list(leader.bus.motors.keys())
|
||||
),
|
||||
],
|
||||
to_transition=robot_action_to_transition,
|
||||
to_output=transition_to_robot_action,
|
||||
)
|
||||
|
||||
# Connect to the robot and teleoperator
|
||||
follower.connect()
|
||||
leader.connect()
|
||||
# Build pipeline to convert EE action to robot joints
|
||||
ee_to_follower_joints = RobotProcessorPipeline[tuple[RobotAction, RobotObservation], RobotAction](
|
||||
[
|
||||
EEBoundsAndSafety(
|
||||
end_effector_bounds={"min": [-1.0, -1.0, -1.0], "max": [1.0, 1.0, 1.0]},
|
||||
max_ee_step_m=0.10,
|
||||
),
|
||||
InverseKinematicsEEToJoints(
|
||||
kinematics=follower_kinematics_solver,
|
||||
motor_names=list(follower.bus.motors.keys()),
|
||||
initial_guess_current_joints=False,
|
||||
),
|
||||
],
|
||||
to_transition=robot_action_observation_to_transition,
|
||||
to_output=transition_to_robot_action,
|
||||
)
|
||||
|
||||
# Init rerun viewer
|
||||
init_rerun(session_name="so100_so100_EE_teleop")
|
||||
# Connect to the robot and teleoperator
|
||||
follower.connect()
|
||||
leader.connect()
|
||||
|
||||
print("Starting teleop loop...")
|
||||
while True:
|
||||
t0 = time.perf_counter()
|
||||
# Init rerun viewer
|
||||
init_rerun(session_name="so100_so100_EE_teleop")
|
||||
|
||||
# Get robot observation
|
||||
robot_obs = follower.get_observation()
|
||||
print("Starting teleop loop...")
|
||||
while True:
|
||||
t0 = time.perf_counter()
|
||||
|
||||
# Get teleop observation
|
||||
leader_joints_obs = leader.get_action()
|
||||
# Get robot observation
|
||||
robot_obs = follower.get_observation()
|
||||
|
||||
# teleop joints -> teleop EE action
|
||||
leader_ee_act = leader_to_ee(leader_joints_obs)
|
||||
# Get teleop observation
|
||||
leader_joints_obs = leader.get_action()
|
||||
|
||||
# teleop EE -> robot joints
|
||||
follower_joints_act = ee_to_follower_joints((leader_ee_act, robot_obs))
|
||||
# teleop joints -> teleop EE action
|
||||
leader_ee_act = leader_to_ee(leader_joints_obs)
|
||||
|
||||
# Send action to robot
|
||||
_ = follower.send_action(follower_joints_act)
|
||||
# teleop EE -> robot joints
|
||||
follower_joints_act = ee_to_follower_joints((leader_ee_act, robot_obs))
|
||||
|
||||
# Visualize
|
||||
log_rerun_data(observation=leader_ee_act, action=follower_joints_act)
|
||||
# Send action to robot
|
||||
_ = follower.send_action(follower_joints_act)
|
||||
|
||||
busy_wait(max(1.0 / FPS - (time.perf_counter() - t0), 0.0))
|
||||
# Visualize
|
||||
log_rerun_data(observation=leader_ee_act, action=follower_joints_act)
|
||||
|
||||
precise_sleep(max(1.0 / FPS - (time.perf_counter() - t0), 0.0))
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
main()
|
||||
|
||||
@@ -19,80 +19,86 @@ def make_delta_timestamps(delta_indices: list[int] | None, fps: int) -> list[flo
|
||||
return [i / fps for i in delta_indices]
|
||||
|
||||
|
||||
output_directory = Path("outputs/robot_learning_tutorial/act")
|
||||
output_directory.mkdir(parents=True, exist_ok=True)
|
||||
def main():
|
||||
output_directory = Path("outputs/robot_learning_tutorial/act")
|
||||
output_directory.mkdir(parents=True, exist_ok=True)
|
||||
|
||||
# Select your device
|
||||
device = torch.device("mps") # or "cuda" or "cpu"
|
||||
# Select your device
|
||||
device = torch.device("mps") # or "cuda" or "cpu"
|
||||
|
||||
dataset_id = "lerobot/svla_so101_pickplace"
|
||||
dataset_id = "lerobot/svla_so101_pickplace"
|
||||
|
||||
# This specifies the inputs the model will be expecting and the outputs it will produce
|
||||
dataset_metadata = LeRobotDatasetMetadata(dataset_id)
|
||||
features = dataset_to_policy_features(dataset_metadata.features)
|
||||
# This specifies the inputs the model will be expecting and the outputs it will produce
|
||||
dataset_metadata = LeRobotDatasetMetadata(dataset_id)
|
||||
features = dataset_to_policy_features(dataset_metadata.features)
|
||||
|
||||
output_features = {key: ft for key, ft in features.items() if ft.type is FeatureType.ACTION}
|
||||
input_features = {key: ft for key, ft in features.items() if key not in output_features}
|
||||
output_features = {key: ft for key, ft in features.items() if ft.type is FeatureType.ACTION}
|
||||
input_features = {key: ft for key, ft in features.items() if key not in output_features}
|
||||
|
||||
cfg = ACTConfig(input_features=input_features, output_features=output_features)
|
||||
policy = ACTPolicy(cfg)
|
||||
preprocessor, postprocessor = make_pre_post_processors(cfg, dataset_stats=dataset_metadata.stats)
|
||||
cfg = ACTConfig(input_features=input_features, output_features=output_features)
|
||||
policy = ACTPolicy(cfg)
|
||||
preprocessor, postprocessor = make_pre_post_processors(cfg, dataset_stats=dataset_metadata.stats)
|
||||
|
||||
policy.train()
|
||||
policy.to(device)
|
||||
policy.train()
|
||||
policy.to(device)
|
||||
|
||||
# To perform action chunking, ACT expects a given number of actions as targets
|
||||
delta_timestamps = {
|
||||
"action": make_delta_timestamps(cfg.action_delta_indices, dataset_metadata.fps),
|
||||
}
|
||||
# To perform action chunking, ACT expects a given number of actions as targets
|
||||
delta_timestamps = {
|
||||
"action": make_delta_timestamps(cfg.action_delta_indices, dataset_metadata.fps),
|
||||
}
|
||||
|
||||
# add image features if they are present
|
||||
delta_timestamps |= {
|
||||
k: make_delta_timestamps(cfg.observation_delta_indices, dataset_metadata.fps) for k in cfg.image_features
|
||||
}
|
||||
# add image features if they are present
|
||||
delta_timestamps |= {
|
||||
k: make_delta_timestamps(cfg.observation_delta_indices, dataset_metadata.fps)
|
||||
for k in cfg.image_features
|
||||
}
|
||||
|
||||
# Instantiate the dataset
|
||||
dataset = LeRobotDataset(dataset_id, delta_timestamps=delta_timestamps)
|
||||
# Instantiate the dataset
|
||||
dataset = LeRobotDataset(dataset_id, delta_timestamps=delta_timestamps)
|
||||
|
||||
# Create the optimizer and dataloader for offline training
|
||||
optimizer = cfg.get_optimizer_preset().build(policy.parameters())
|
||||
batch_size = 32
|
||||
dataloader = torch.utils.data.DataLoader(
|
||||
dataset,
|
||||
batch_size=batch_size,
|
||||
shuffle=True,
|
||||
pin_memory=device.type != "cpu",
|
||||
drop_last=True,
|
||||
)
|
||||
# Create the optimizer and dataloader for offline training
|
||||
optimizer = cfg.get_optimizer_preset().build(policy.parameters())
|
||||
batch_size = 32
|
||||
dataloader = torch.utils.data.DataLoader(
|
||||
dataset,
|
||||
batch_size=batch_size,
|
||||
shuffle=True,
|
||||
pin_memory=device.type != "cpu",
|
||||
drop_last=True,
|
||||
)
|
||||
|
||||
# Number of training steps and logging frequency
|
||||
training_steps = 1
|
||||
log_freq = 1
|
||||
# Number of training steps and logging frequency
|
||||
training_steps = 1
|
||||
log_freq = 1
|
||||
|
||||
# Run training loop
|
||||
step = 0
|
||||
done = False
|
||||
while not done:
|
||||
for batch in dataloader:
|
||||
batch = preprocessor(batch)
|
||||
loss, _ = policy.forward(batch)
|
||||
loss.backward()
|
||||
optimizer.step()
|
||||
optimizer.zero_grad()
|
||||
# Run training loop
|
||||
step = 0
|
||||
done = False
|
||||
while not done:
|
||||
for batch in dataloader:
|
||||
batch = preprocessor(batch)
|
||||
loss, _ = policy.forward(batch)
|
||||
loss.backward()
|
||||
optimizer.step()
|
||||
optimizer.zero_grad()
|
||||
|
||||
if step % log_freq == 0:
|
||||
print(f"step: {step} loss: {loss.item():.3f}")
|
||||
step += 1
|
||||
if step >= training_steps:
|
||||
done = True
|
||||
break
|
||||
if step % log_freq == 0:
|
||||
print(f"step: {step} loss: {loss.item():.3f}")
|
||||
step += 1
|
||||
if step >= training_steps:
|
||||
done = True
|
||||
break
|
||||
|
||||
# Save the policy checkpoint, alongside the pre/post processors
|
||||
policy.save_pretrained(output_directory)
|
||||
preprocessor.save_pretrained(output_directory)
|
||||
postprocessor.save_pretrained(output_directory)
|
||||
# Save the policy checkpoint, alongside the pre/post processors
|
||||
policy.save_pretrained(output_directory)
|
||||
preprocessor.save_pretrained(output_directory)
|
||||
postprocessor.save_pretrained(output_directory)
|
||||
|
||||
# Save all assets to the Hub
|
||||
policy.push_to_hub("fracapuano/robot_learning_tutorial_act")
|
||||
preprocessor.push_to_hub("fracapuano/robot_learning_tutorial_act")
|
||||
postprocessor.push_to_hub("fracapuano/robot_learning_tutorial_act")
|
||||
# Save all assets to the Hub
|
||||
policy.push_to_hub("<user>/robot_learning_tutorial_act")
|
||||
preprocessor.push_to_hub("<user>/robot_learning_tutorial_act")
|
||||
postprocessor.push_to_hub("<user>/robot_learning_tutorial_act")
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
main()
|
||||
|
||||
@@ -8,50 +8,56 @@ from lerobot.policies.utils import build_inference_frame, make_robot_action
|
||||
from lerobot.robots.so100_follower.config_so100_follower import SO100FollowerConfig
|
||||
from lerobot.robots.so100_follower.so100_follower import SO100Follower
|
||||
|
||||
device = torch.device("mps") # or "cuda" or "cpu"
|
||||
model_id = "fracapuano/robot_learning_tutorial_act"
|
||||
model = ACTPolicy.from_pretrained(model_id)
|
||||
|
||||
dataset_id = "lerobot/svla_so101_pickplace"
|
||||
# This only downloads the metadata for the dataset, a few tens of MB even for large-scale datasets
|
||||
dataset_metadata = LeRobotDatasetMetadata(dataset_id)
|
||||
preprocess, postprocess = make_pre_post_processors(model.config, dataset_stats=dataset_metadata.stats)
|
||||
|
||||
# # find ports using lerobot-find-port
|
||||
follower_port = ... # something like "/dev/tty.usbmodem58760431631"
|
||||
|
||||
# # the robot ids are used to load the right calibration files
|
||||
follower_id = ... # something like "follower_so100"
|
||||
|
||||
MAX_EPISODES = 5
|
||||
MAX_STEPS_PER_EPISODE = 20
|
||||
|
||||
# Robot and environment configuration
|
||||
# Camera keys must match the names and resolutions of the ones used for training!
|
||||
# You can check the camera keys expected by a model in the info.json file on the model card on the Hub
|
||||
camera_config = {
|
||||
"side": OpenCVCameraConfig(index_or_path=0, width=640, height=480, fps=30),
|
||||
"up": OpenCVCameraConfig(index_or_path=1, width=640, height=480, fps=30),
|
||||
}
|
||||
|
||||
robot_cfg = SO100FollowerConfig(port=follower_port, id=follower_id, cameras=camera_config)
|
||||
robot = SO100Follower(robot_cfg)
|
||||
robot.connect()
|
||||
def main():
|
||||
device = torch.device("mps") # or "cuda" or "cpu"
|
||||
model_id = "<user>/robot_learning_tutorial_act"
|
||||
model = ACTPolicy.from_pretrained(model_id)
|
||||
|
||||
for _ in range(MAX_EPISODES):
|
||||
for _ in range(MAX_STEPS_PER_EPISODE):
|
||||
obs = robot.get_observation()
|
||||
obs_frame = build_inference_frame(
|
||||
observation=obs, ds_features=dataset_metadata.features, device=device
|
||||
)
|
||||
dataset_id = "lerobot/svla_so101_pickplace"
|
||||
# This only downloads the metadata for the dataset, a few tens of MB even for large-scale datasets
|
||||
dataset_metadata = LeRobotDatasetMetadata(dataset_id)
|
||||
preprocess, postprocess = make_pre_post_processors(model.config, dataset_stats=dataset_metadata.stats)
|
||||
|
||||
obs = preprocess(obs_frame)
|
||||
# # find ports using lerobot-find-port
|
||||
follower_port = ... # something like "/dev/tty.usbmodem58760431631"
|
||||
|
||||
action = model.select_action(obs)
|
||||
action = postprocess(action)
|
||||
# # the robot ids are used to load the right calibration files
|
||||
follower_id = ... # something like "follower_so100"
|
||||
|
||||
action = make_robot_action(action, dataset_metadata.features)
|
||||
# Robot and environment configuration
|
||||
# Camera keys must match the names and resolutions of the ones used for training!
|
||||
# You can check the camera keys expected by a model in the info.json file on the model card on the Hub
|
||||
camera_config = {
|
||||
"side": OpenCVCameraConfig(index_or_path=0, width=640, height=480, fps=30),
|
||||
"up": OpenCVCameraConfig(index_or_path=1, width=640, height=480, fps=30),
|
||||
}
|
||||
|
||||
robot.send_action(action)
|
||||
robot_cfg = SO100FollowerConfig(port=follower_port, id=follower_id, cameras=camera_config)
|
||||
robot = SO100Follower(robot_cfg)
|
||||
robot.connect()
|
||||
|
||||
print("Episode finished! Starting new episode...")
|
||||
for _ in range(MAX_EPISODES):
|
||||
for _ in range(MAX_STEPS_PER_EPISODE):
|
||||
obs = robot.get_observation()
|
||||
obs_frame = build_inference_frame(
|
||||
observation=obs, ds_features=dataset_metadata.features, device=device
|
||||
)
|
||||
|
||||
obs = preprocess(obs_frame)
|
||||
|
||||
action = model.select_action(obs)
|
||||
action = postprocess(action)
|
||||
|
||||
action = make_robot_action(action, dataset_metadata.features)
|
||||
|
||||
robot.send_action(action)
|
||||
|
||||
print("Episode finished! Starting new episode...")
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
main()
|
||||
|
||||
@@ -1,11 +1,17 @@
|
||||
from lerobot.async_inference.configs import PolicyServerConfig
|
||||
from lerobot.async_inference.policy_server import serve
|
||||
|
||||
host = ... # something like "127.0.0.1" if you're serving on localhost
|
||||
port = ... # something like 8080
|
||||
|
||||
config = PolicyServerConfig(
|
||||
host=host,
|
||||
port=port,
|
||||
)
|
||||
serve(config)
|
||||
def main():
|
||||
host = ... # something like "127.0.0.1" if you're serving on localhost
|
||||
port = ... # something like 8080
|
||||
|
||||
config = PolicyServerConfig(
|
||||
host=host,
|
||||
port=port,
|
||||
)
|
||||
serve(config)
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
main()
|
||||
|
||||
@@ -6,50 +6,56 @@ from lerobot.async_inference.robot_client import RobotClient
|
||||
from lerobot.cameras.opencv.configuration_opencv import OpenCVCameraConfig
|
||||
from lerobot.robots.so100_follower import SO100FollowerConfig
|
||||
|
||||
# these cameras must match the ones expected by the policy - find your cameras with lerobot-find-cameras
|
||||
# check the config.json on the Hub for the policy you are using to see the expected camera specs
|
||||
camera_cfg = {
|
||||
"up": OpenCVCameraConfig(index_or_path=0, width=640, height=480, fps=30),
|
||||
"side": OpenCVCameraConfig(index_or_path=1, width=640, height=480, fps=30),
|
||||
}
|
||||
|
||||
# # find ports using lerobot-find-port
|
||||
follower_port = ... # something like "/dev/tty.usbmodem58760431631"
|
||||
def main():
|
||||
# these cameras must match the ones expected by the policy - find your cameras with lerobot-find-cameras
|
||||
# check the config.json on the Hub for the policy you are using to see the expected camera specs
|
||||
camera_cfg = {
|
||||
"up": OpenCVCameraConfig(index_or_path=0, width=640, height=480, fps=30),
|
||||
"side": OpenCVCameraConfig(index_or_path=1, width=640, height=480, fps=30),
|
||||
}
|
||||
|
||||
# # the robot ids are used to load the right calibration files
|
||||
follower_id = ... # something like "follower_so100"
|
||||
# # find ports using lerobot-find-port
|
||||
follower_port = ... # something like "/dev/tty.usbmodem58760431631"
|
||||
|
||||
robot_cfg = SO100FollowerConfig(port=follower_port, id=follower_id, cameras=camera_cfg)
|
||||
# # the robot ids are used to load the right calibration files
|
||||
follower_id = ... # something like "follower_so100"
|
||||
|
||||
server_address = ... # something like "127.0.0.1:8080" if using localhost
|
||||
robot_cfg = SO100FollowerConfig(port=follower_port, id=follower_id, cameras=camera_cfg)
|
||||
|
||||
# 3. Create client configuration
|
||||
client_cfg = RobotClientConfig(
|
||||
robot=robot_cfg,
|
||||
server_address=server_address,
|
||||
policy_device="mps",
|
||||
policy_type="act",
|
||||
pretrained_name_or_path="fracapuano/robot_learning_tutorial_act",
|
||||
chunk_size_threshold=0.5, # g
|
||||
actions_per_chunk=50, # make sure this is less than the max actions of the policy
|
||||
)
|
||||
server_address = ... # something like "127.0.0.1:8080" if using localhost
|
||||
|
||||
# 4. Create and start client
|
||||
client = RobotClient(client_cfg)
|
||||
# 3. Create client configuration
|
||||
client_cfg = RobotClientConfig(
|
||||
robot=robot_cfg,
|
||||
server_address=server_address,
|
||||
policy_device="mps",
|
||||
policy_type="act",
|
||||
pretrained_name_or_path="<user>/robot_learning_tutorial_act",
|
||||
chunk_size_threshold=0.5, # g
|
||||
actions_per_chunk=50, # make sure this is less than the max actions of the policy
|
||||
)
|
||||
|
||||
# 5. Provide a textual description of the task
|
||||
task = ...
|
||||
# 4. Create and start client
|
||||
client = RobotClient(client_cfg)
|
||||
|
||||
if client.start():
|
||||
# Start action receiver thread
|
||||
action_receiver_thread = threading.Thread(target=client.receive_actions, daemon=True)
|
||||
action_receiver_thread.start()
|
||||
# 5. Provide a textual description of the task
|
||||
task = ...
|
||||
|
||||
try:
|
||||
# Run the control loop
|
||||
client.control_loop(task)
|
||||
except KeyboardInterrupt:
|
||||
client.stop()
|
||||
action_receiver_thread.join()
|
||||
# (Optionally) plot the action queue size
|
||||
visualize_action_queue_size(client.action_queue_size)
|
||||
if client.start():
|
||||
# Start action receiver thread
|
||||
action_receiver_thread = threading.Thread(target=client.receive_actions, daemon=True)
|
||||
action_receiver_thread.start()
|
||||
|
||||
try:
|
||||
# Run the control loop
|
||||
client.control_loop(task)
|
||||
except KeyboardInterrupt:
|
||||
client.stop()
|
||||
action_receiver_thread.join()
|
||||
# (Optionally) plot the action queue size
|
||||
visualize_action_queue_size(client.action_queue_size)
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
main()
|
||||
|
||||
@@ -19,81 +19,87 @@ def make_delta_timestamps(delta_indices: list[int] | None, fps: int) -> list[flo
|
||||
return [i / fps for i in delta_indices]
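# Illustrative only: make_delta_timestamps just divides frame offsets by the dataset fps,
# e.g. with a 30 fps dataset:
#   make_delta_timestamps([-1, 0], fps=30)   -> [-0.0333..., 0.0]
#   make_delta_timestamps([0, 1, 2], fps=30) -> [0.0, 0.0333..., 0.0666...]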
|
||||
|
||||
|
||||
output_directory = Path("outputs/robot_learning_tutorial/diffusion")
|
||||
output_directory.mkdir(parents=True, exist_ok=True)
|
||||
def main():
|
||||
output_directory = Path("outputs/robot_learning_tutorial/diffusion")
|
||||
output_directory.mkdir(parents=True, exist_ok=True)
|
||||
|
||||
# Select your device
|
||||
device = torch.device("mps") # or "cuda" or "cpu"
|
||||
# Select your device
|
||||
device = torch.device("mps") # or "cuda" or "cpu"
|
||||
|
||||
dataset_id = "lerobot/svla_so101_pickplace"
|
||||
dataset_id = "lerobot/svla_so101_pickplace"
|
||||
|
||||
# This specifies the inputs the model will be expecting and the outputs it will produce
|
||||
dataset_metadata = LeRobotDatasetMetadata(dataset_id)
|
||||
features = dataset_to_policy_features(dataset_metadata.features)
|
||||
# This specifies the inputs the model will be expecting and the outputs it will produce
|
||||
dataset_metadata = LeRobotDatasetMetadata(dataset_id)
|
||||
features = dataset_to_policy_features(dataset_metadata.features)
|
||||
|
||||
output_features = {key: ft for key, ft in features.items() if ft.type is FeatureType.ACTION}
|
||||
input_features = {key: ft for key, ft in features.items() if key not in output_features}
|
||||
output_features = {key: ft for key, ft in features.items() if ft.type is FeatureType.ACTION}
|
||||
input_features = {key: ft for key, ft in features.items() if key not in output_features}
|
||||
|
||||
cfg = DiffusionConfig(input_features=input_features, output_features=output_features)
|
||||
policy = DiffusionPolicy(cfg)
|
||||
preprocessor, postprocessor = make_pre_post_processors(cfg, dataset_stats=dataset_metadata.stats)
|
||||
cfg = DiffusionConfig(input_features=input_features, output_features=output_features)
|
||||
policy = DiffusionPolicy(cfg)
|
||||
preprocessor, postprocessor = make_pre_post_processors(cfg, dataset_stats=dataset_metadata.stats)
|
||||
|
||||
policy.train()
|
||||
policy.to(device)
|
||||
policy.train()
|
||||
policy.to(device)
|
||||
|
||||
# To perform action chunking, ACT expects a given number of actions as targets
|
||||
delta_timestamps = {
|
||||
"observation.state": make_delta_timestamps(cfg.observation_delta_indices, dataset_metadata.fps),
|
||||
"action": make_delta_timestamps(cfg.action_delta_indices, dataset_metadata.fps),
|
||||
}
|
||||
# To perform action chunking, ACT expects a given number of actions as targets
|
||||
delta_timestamps = {
|
||||
"observation.state": make_delta_timestamps(cfg.observation_delta_indices, dataset_metadata.fps),
|
||||
"action": make_delta_timestamps(cfg.action_delta_indices, dataset_metadata.fps),
|
||||
}
|
||||
|
||||
# add image features if they are present
|
||||
delta_timestamps |= {
|
||||
k: make_delta_timestamps(cfg.observation_delta_indices, dataset_metadata.fps) for k in cfg.image_features
|
||||
}
|
||||
# add image features if they are present
|
||||
delta_timestamps |= {
|
||||
k: make_delta_timestamps(cfg.observation_delta_indices, dataset_metadata.fps)
|
||||
for k in cfg.image_features
|
||||
}
|
||||
|
||||
# Instantiate the dataset
|
||||
dataset = LeRobotDataset(dataset_id, delta_timestamps=delta_timestamps)
|
||||
# Instantiate the dataset
|
||||
dataset = LeRobotDataset(dataset_id, delta_timestamps=delta_timestamps)
|
||||
|
||||
# Create the optimizer and dataloader for offline training
|
||||
optimizer = cfg.get_optimizer_preset().build(policy.parameters())
|
||||
batch_size = 32
|
||||
dataloader = torch.utils.data.DataLoader(
|
||||
dataset,
|
||||
batch_size=batch_size,
|
||||
shuffle=True,
|
||||
pin_memory=device.type != "cpu",
|
||||
drop_last=True,
|
||||
)
|
||||
# Create the optimizer and dataloader for offline training
|
||||
optimizer = cfg.get_optimizer_preset().build(policy.parameters())
|
||||
batch_size = 32
|
||||
dataloader = torch.utils.data.DataLoader(
|
||||
dataset,
|
||||
batch_size=batch_size,
|
||||
shuffle=True,
|
||||
pin_memory=device.type != "cpu",
|
||||
drop_last=True,
|
||||
)
|
||||
|
||||
# Number of training steps and logging frequency
|
||||
training_steps = 1
|
||||
log_freq = 1
|
||||
# Number of training steps and logging frequency
|
||||
training_steps = 1
|
||||
log_freq = 1
|
||||
|
||||
# Run training loop
|
||||
step = 0
|
||||
done = False
|
||||
while not done:
|
||||
for batch in dataloader:
|
||||
batch = preprocessor(batch)
|
||||
loss, _ = policy.forward(batch)
|
||||
loss.backward()
|
||||
optimizer.step()
|
||||
optimizer.zero_grad()
|
||||
# Run training loop
|
||||
step = 0
|
||||
done = False
|
||||
while not done:
|
||||
for batch in dataloader:
|
||||
batch = preprocessor(batch)
|
||||
loss, _ = policy.forward(batch)
|
||||
loss.backward()
|
||||
optimizer.step()
|
||||
optimizer.zero_grad()
|
||||
|
||||
if step % log_freq == 0:
|
||||
print(f"step: {step} loss: {loss.item():.3f}")
|
||||
step += 1
|
||||
if step >= training_steps:
|
||||
done = True
|
||||
break
|
||||
if step % log_freq == 0:
|
||||
print(f"step: {step} loss: {loss.item():.3f}")
|
||||
step += 1
|
||||
if step >= training_steps:
|
||||
done = True
|
||||
break
|
||||
|
||||
# Save the policy checkpoint, alongside the pre/post processors
|
||||
policy.save_pretrained(output_directory)
|
||||
preprocessor.save_pretrained(output_directory)
|
||||
postprocessor.save_pretrained(output_directory)
|
||||
# Save the policy checkpoint, alongside the pre/post processors
|
||||
policy.save_pretrained(output_directory)
|
||||
preprocessor.save_pretrained(output_directory)
|
||||
postprocessor.save_pretrained(output_directory)
|
||||
|
||||
# Save all assets to the Hub
|
||||
policy.push_to_hub("fracapuano/robot_learning_tutorial_diffusion")
|
||||
preprocessor.push_to_hub("fracapuano/robot_learning_tutorial_diffusion")
|
||||
postprocessor.push_to_hub("fracapuano/robot_learning_tutorial_diffusion")
|
||||
# Save all assets to the Hub
|
||||
policy.push_to_hub("<user>/robot_learning_tutorial_diffusion")
|
||||
preprocessor.push_to_hub("<user>/robot_learning_tutorial_diffusion")
|
||||
postprocessor.push_to_hub("<user>/robot_learning_tutorial_diffusion")
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
main()
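# Note (illustrative): the checkpoint written to output_directory can be reloaded later with
# from_pretrained, the same way the Hub checkpoints are loaded in the inference examples below.
# Passing a local path instead of a repo id is an assumption (standard Hugging Face behavior):
#   policy = DiffusionPolicy.from_pretrained("outputs/robot_learning_tutorial/diffusion")
#   preprocessor, postprocessor = make_pre_post_processors(
#       policy.config, "outputs/robot_learning_tutorial/diffusion"
#   )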
|
||||
|
||||
@@ -8,53 +8,57 @@ from lerobot.policies.utils import build_inference_frame, make_robot_action
|
||||
from lerobot.robots.so100_follower.config_so100_follower import SO100FollowerConfig
|
||||
from lerobot.robots.so100_follower.so100_follower import SO100Follower
|
||||
|
||||
device = torch.device("mps") # or "cuda" or "cpu"
|
||||
model_id = "fracapuano/robot_learning_tutorial_diffusion"
|
||||
|
||||
model = DiffusionPolicy.from_pretrained(model_id)
|
||||
|
||||
dataset_id = "lerobot/svla_so101_pickplace"
|
||||
# This only downloads the dataset metadata, which is just tens of MB even for large-scale datasets
|
||||
dataset_metadata = LeRobotDatasetMetadata(dataset_id)
|
||||
preprocess, postprocess = make_pre_post_processors(
|
||||
model.config, model_id, dataset_stats=dataset_metadata.stats
|
||||
)
|
||||
|
||||
MAX_EPISODES = 5
|
||||
MAX_STEPS_PER_EPISODE = 20
|
||||
|
||||
|
||||
# find ports using lerobot-find-port
|
||||
follower_port = ... # something like "/dev/tty.usbmodem58760431631"
|
||||
def main():
|
||||
device = torch.device("mps") # or "cuda" or "cpu"
|
||||
model_id = "<user>/robot_learning_tutorial_diffusion"
|
||||
|
||||
# the robot ids are used to load the right calibration files
|
||||
follower_id = ... # something like "follower_so100"
|
||||
model = DiffusionPolicy.from_pretrained(model_id)
|
||||
|
||||
# Robot and environment configuration
|
||||
# Camera keys must match the names and resolutions of the ones used for training!
# You can check the camera keys expected by a model in the info.json file on its model card on the Hub
|
||||
camera_config = {
|
||||
"side": OpenCVCameraConfig(index_or_path=0, width=640, height=480, fps=30),
|
||||
"up": OpenCVCameraConfig(index_or_path=1, width=640, height=480, fps=30),
|
||||
}
|
||||
dataset_id = "lerobot/svla_so101_pickplace"
|
||||
# This only downloads the dataset metadata, which is just tens of MB even for large-scale datasets
|
||||
dataset_metadata = LeRobotDatasetMetadata(dataset_id)
|
||||
preprocess, postprocess = make_pre_post_processors(
|
||||
model.config, model_id, dataset_stats=dataset_metadata.stats
|
||||
)
|
||||
|
||||
robot_cfg = SO100FollowerConfig(port=follower_port, id=follower_id, cameras=camera_config)
|
||||
robot = SO100Follower(robot_cfg)
|
||||
robot.connect()
|
||||
# find ports using lerobot-find-port
|
||||
follower_port = ... # something like "/dev/tty.usbmodem58760431631"
|
||||
|
||||
# the robot ids are used to load the right calibration files
|
||||
follower_id = ... # something like "follower_so100"
|
||||
|
||||
# Robot and environment configuration
|
||||
# Camera keys must match the names and resolutions of the ones used for training!
# You can check the camera keys expected by a model in the info.json file on its model card on the Hub
|
||||
camera_config = {
|
||||
"side": OpenCVCameraConfig(index_or_path=0, width=640, height=480, fps=30),
|
||||
"up": OpenCVCameraConfig(index_or_path=1, width=640, height=480, fps=30),
|
||||
}
|
||||
|
||||
robot_cfg = SO100FollowerConfig(port=follower_port, id=follower_id, cameras=camera_config)
|
||||
robot = SO100Follower(robot_cfg)
|
||||
robot.connect()
|
||||
|
||||
for _ in range(MAX_EPISODES):
|
||||
for _ in range(MAX_STEPS_PER_EPISODE):
|
||||
obs = robot.get_observation()
|
||||
obs_frame = build_inference_frame(
|
||||
observation=obs, ds_features=dataset_metadata.features, device=device
|
||||
)
|
||||
|
||||
obs = preprocess(obs_frame)
|
||||
|
||||
action = model.select_action(obs)
|
||||
action = postprocess(action)
|
||||
action = make_robot_action(action, dataset_metadata.features)
|
||||
robot.send_action(action)
|
||||
|
||||
print("Episode finished! Starting new episode...")
|
||||
|
||||
|
||||
for _ in range(MAX_EPISODES):
|
||||
for _ in range(MAX_STEPS_PER_EPISODE):
|
||||
obs = robot.get_observation()
|
||||
obs_frame = build_inference_frame(
|
||||
observation=obs, ds_features=dataset_metadata.features, device=device
|
||||
)
|
||||
|
||||
obs = preprocess(obs_frame)
|
||||
|
||||
action = model.select_action(obs)
|
||||
action = postprocess(action)
|
||||
action = make_robot_action(action, dataset_metadata.features)
|
||||
robot.send_action(action)
|
||||
|
||||
print("Episode finished! Starting new episode...")
|
||||
if __name__ == "__main__":
|
||||
main()
|
||||
|
||||
@@ -11,57 +11,63 @@ from lerobot.robots.so100_follower.so100_follower import SO100Follower
|
||||
MAX_EPISODES = 5
|
||||
MAX_STEPS_PER_EPISODE = 20
|
||||
|
||||
device = torch.device("mps") # or "cuda" or "cpu"
|
||||
model_id = "lerobot/pi0_base"
|
||||
|
||||
model = PI0Policy.from_pretrained(model_id)
|
||||
def main():
|
||||
device = torch.device("mps") # or "cuda" or "cpu"
|
||||
model_id = "lerobot/pi0_base"
|
||||
|
||||
preprocess, postprocess = make_pre_post_processors(
|
||||
model.config,
|
||||
model_id,
|
||||
# This override allows running on MPS; otherwise it defaults to CUDA (if available)
|
||||
preprocessor_overrides={"device_processor": {"device": str(device)}},
|
||||
)
|
||||
model = PI0Policy.from_pretrained(model_id)
|
||||
|
||||
# find ports using lerobot-find-port
|
||||
follower_port = ... # something like "/dev/tty.usbmodem58760431631"
|
||||
preprocess, postprocess = make_pre_post_processors(
|
||||
model.config,
|
||||
model_id,
|
||||
# This override allows running on MPS; otherwise it defaults to CUDA (if available)
|
||||
preprocessor_overrides={"device_processor": {"device": str(device)}},
|
||||
)
|
||||
|
||||
# the robot ids are used to load the right calibration files
|
||||
follower_id = ... # something like "follower_so100"
|
||||
# find ports using lerobot-find-port
|
||||
follower_port = ... # something like "/dev/tty.usbmodem58760431631"
|
||||
|
||||
# Robot and environment configuration
|
||||
# Camera keys must match the names and resolutions of the ones used for training!
# You can check the camera keys expected by a model in the info.json file on its model card on the Hub
|
||||
camera_config = {
|
||||
"base_0_rgb": OpenCVCameraConfig(index_or_path=0, width=640, height=480, fps=30),
|
||||
"left_wrist_0_rgb": OpenCVCameraConfig(index_or_path=1, width=640, height=480, fps=30),
|
||||
"right_wrist_0_rgb": OpenCVCameraConfig(index_or_path=2, width=640, height=480, fps=30),
|
||||
}
|
||||
# the robot ids are used to load the right calibration files
|
||||
follower_id = ... # something like "follower_so100"
|
||||
|
||||
robot_cfg = SO100FollowerConfig(port=follower_port, id=follower_id, cameras=camera_config)
|
||||
robot = SO100Follower(robot_cfg)
|
||||
robot.connect()
|
||||
# Robot and environment configuration
|
||||
# Camera keys must match the names and resolutions of the ones used for training!
# You can check the camera keys expected by a model in the info.json file on its model card on the Hub
|
||||
camera_config = {
|
||||
"base_0_rgb": OpenCVCameraConfig(index_or_path=0, width=640, height=480, fps=30),
|
||||
"left_wrist_0_rgb": OpenCVCameraConfig(index_or_path=1, width=640, height=480, fps=30),
|
||||
"right_wrist_0_rgb": OpenCVCameraConfig(index_or_path=2, width=640, height=480, fps=30),
|
||||
}
|
||||
|
||||
task = "" # something like "pick the red block"
|
||||
robot_type = "" # something like "so100_follower" for multi-embodiment datasets
|
||||
robot_cfg = SO100FollowerConfig(port=follower_port, id=follower_id, cameras=camera_config)
|
||||
robot = SO100Follower(robot_cfg)
|
||||
robot.connect()
|
||||
|
||||
# This is used to match the raw observation keys to the keys expected by the policy
|
||||
action_features = hw_to_dataset_features(robot.action_features, "action")
|
||||
obs_features = hw_to_dataset_features(robot.observation_features, "observation")
|
||||
dataset_features = {**action_features, **obs_features}
|
||||
task = "" # something like "pick the red block"
|
||||
robot_type = "" # something like "so100_follower" for multi-embodiment datasets
|
||||
|
||||
for _ in range(MAX_EPISODES):
|
||||
for _ in range(MAX_STEPS_PER_EPISODE):
|
||||
obs = robot.get_observation()
|
||||
obs_frame = build_inference_frame(
|
||||
observation=obs, ds_features=dataset_features, device=device, task=task, robot_type=robot_type
|
||||
)
|
||||
# This is used to match the raw observation keys to the keys expected by the policy
|
||||
action_features = hw_to_dataset_features(robot.action_features, "action")
|
||||
obs_features = hw_to_dataset_features(robot.observation_features, "observation")
|
||||
dataset_features = {**action_features, **obs_features}
|
||||
|
||||
obs = preprocess(obs_frame)
|
||||
for _ in range(MAX_EPISODES):
|
||||
for _ in range(MAX_STEPS_PER_EPISODE):
|
||||
obs = robot.get_observation()
|
||||
obs_frame = build_inference_frame(
|
||||
observation=obs, ds_features=dataset_features, device=device, task=task, robot_type=robot_type
|
||||
)
|
||||
|
||||
action = model.select_action(obs)
|
||||
action = postprocess(action)
|
||||
action = make_robot_action(action, dataset_features)
|
||||
robot.send_action(action)
|
||||
obs = preprocess(obs_frame)
|
||||
|
||||
print("Episode finished! Starting new episode...")
|
||||
action = model.select_action(obs)
|
||||
action = postprocess(action)
|
||||
action = make_robot_action(action, dataset_features)
|
||||
robot.send_action(action)
|
||||
|
||||
print("Episode finished! Starting new episode...")
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
main()
|
||||
|
||||
@@ -20,6 +20,8 @@ from lerobot.teleoperators.utils import TeleopEvents
|
||||
|
||||
LOG_EVERY = 10
|
||||
SEND_EVERY = 10
|
||||
MAX_EPISODES = 5
|
||||
MAX_STEPS_PER_EPISODE = 20
|
||||
|
||||
|
||||
def run_learner(
|
||||
@@ -223,123 +225,123 @@ def make_policy_obs(obs, device: torch.device = "cpu"):
|
||||
}
|
||||
|
||||
|
||||
"""Main function - coordinates actor and learner processes."""
|
||||
def main():
|
||||
"""Main function - coordinates actor and learner processes."""
|
||||
|
||||
device = "mps" # or "cuda" or "cpu"
|
||||
output_directory = Path("outputs/robot_learning_tutorial/hil_serl")
|
||||
output_directory.mkdir(parents=True, exist_ok=True)
|
||||
device = "mps" # or "cuda" or "cpu"
|
||||
output_directory = Path("outputs/robot_learning_tutorial/hil_serl")
|
||||
output_directory.mkdir(parents=True, exist_ok=True)
|
||||
|
||||
# find ports using lerobot-find-port
|
||||
follower_port = ...
|
||||
leader_port = ...
|
||||
# find ports using lerobot-find-port
|
||||
follower_port = ...
|
||||
leader_port = ...
|
||||
|
||||
# the robot ids are used to load the right calibration files
|
||||
follower_id = ...
|
||||
leader_id = ...
|
||||
# the robot ids are used to load the right calibration files
|
||||
follower_id = ...
|
||||
leader_id = ...
|
||||
|
||||
# A pretrained model (to be used in-distribution!)
|
||||
reward_classifier_id = "fracapuano/reward_classifier_hil_serl_example"
|
||||
reward_classifier = Classifier.from_pretrained(reward_classifier_id)
|
||||
# A pretrained model (to be used in-distribution!)
|
||||
reward_classifier_id = "<user>/reward_classifier_hil_serl_example"
|
||||
reward_classifier = Classifier.from_pretrained(reward_classifier_id)
|
||||
|
||||
reward_classifier.to(device)
|
||||
reward_classifier.eval()
|
||||
reward_classifier.to(device)
|
||||
reward_classifier.eval()
|
||||
|
||||
MAX_EPISODES = 5
|
||||
MAX_STEPS_PER_EPISODE = 20
|
||||
# Robot and environment configuration
|
||||
robot_cfg = SO100FollowerConfig(port=follower_port, id=follower_id)
|
||||
teleop_cfg = SO100LeaderConfig(port=leader_port, id=leader_id)
|
||||
processor_cfg = HILSerlProcessorConfig(control_mode="leader")
|
||||
|
||||
# Robot and environment configuration
|
||||
robot_cfg = SO100FollowerConfig(port=follower_port, id=follower_id)
|
||||
teleop_cfg = SO100LeaderConfig(port=leader_port, id=leader_id)
|
||||
processor_cfg = HILSerlProcessorConfig(control_mode="leader")
|
||||
env_cfg = HILSerlRobotEnvConfig(robot=robot_cfg, teleop=teleop_cfg, processor=processor_cfg)
|
||||
|
||||
env_cfg = HILSerlRobotEnvConfig(robot=robot_cfg, teleop=teleop_cfg, processor=processor_cfg)
|
||||
# Create robot environment
|
||||
env, teleop_device = make_robot_env(env_cfg)
|
||||
|
||||
# Create robot environment
|
||||
env, teleop_device = make_robot_env(env_cfg)
|
||||
obs_features = hw_to_dataset_features(env.robot.observation_features, "observation")
|
||||
action_features = hw_to_dataset_features(env.robot.action_features, "action")
|
||||
|
||||
obs_features = hw_to_dataset_features(env.robot.observation_features, "observation")
|
||||
action_features = hw_to_dataset_features(env.robot.action_features, "action")
|
||||
# Create SAC policy for action selection
|
||||
policy_cfg = SACConfig(
|
||||
device=device,
|
||||
input_features=obs_features,
|
||||
output_features=action_features,
|
||||
)
|
||||
|
||||
# Create SAC policy for action selection
|
||||
policy_cfg = SACConfig(
|
||||
device=device,
|
||||
input_features=obs_features,
|
||||
output_features=action_features,
|
||||
)
|
||||
policy_actor = SACPolicy(policy_cfg)
|
||||
policy_learner = SACPolicy(policy_cfg)
|
||||
|
||||
policy_actor = SACPolicy(policy_cfg)
|
||||
policy_learner = SACPolicy(policy_cfg)
|
||||
demonstrations_repo_id = "lerobot/example_hil_serl_dataset"
|
||||
offline_dataset = LeRobotDataset(repo_id=demonstrations_repo_id)
|
||||
|
||||
demonstrations_repo_id = "lerobot/example_hil_serl_dataset"
|
||||
offline_dataset = LeRobotDataset(repo_id=demonstrations_repo_id)
|
||||
# Online buffer: initialized from scratch
|
||||
online_replay_buffer = ReplayBuffer(device=device, state_keys=list(obs_features.keys()))
|
||||
# Offline buffer: created from the dataset (pre-populated with demonstrations)
|
||||
offline_replay_buffer = ReplayBuffer.from_lerobot_dataset(
|
||||
lerobot_dataset=offline_dataset, device=device, state_keys=list(obs_features.keys())
|
||||
)
|
||||
|
||||
# Online buffer: initialized from scratch
|
||||
online_replay_buffer = ReplayBuffer(device=device, state_keys=list(obs_features.keys()))
|
||||
# Offline buffer: created from the dataset (pre-populated with demonstrations)
|
||||
offline_replay_buffer = ReplayBuffer.from_lerobot_dataset(
|
||||
lerobot_dataset=offline_dataset, device=device, state_keys=list(obs_features.keys())
|
||||
)
|
||||
# Create communication channels between learner and actor processes
|
||||
transitions_queue = mp.Queue(maxsize=10)
|
||||
parameters_queue = mp.Queue(maxsize=2)
|
||||
shutdown_event = mp.Event()
|
||||
|
||||
# Create communication channels between learner and actor processes
|
||||
transitions_queue = mp.Queue(maxsize=10)
|
||||
parameters_queue = mp.Queue(maxsize=2)
|
||||
shutdown_event = mp.Event()
|
||||
# Signal handler for graceful shutdown
|
||||
def signal_handler(sig, frame):
|
||||
print(f"\nSignal {sig} received, shutting down...")
|
||||
shutdown_event.set()
|
||||
|
||||
signal.signal(signal.SIGINT, signal_handler)
|
||||
signal.signal(signal.SIGTERM, signal_handler)
|
||||
|
||||
# Create processes
|
||||
learner_process = mp.Process(
|
||||
target=run_learner,
|
||||
args=(
|
||||
transitions_queue,
|
||||
parameters_queue,
|
||||
shutdown_event,
|
||||
policy_learner,
|
||||
online_replay_buffer,
|
||||
offline_replay_buffer,
|
||||
),
|
||||
kwargs={"device": device}, # can run on accelerated hardware for training
|
||||
)
|
||||
|
||||
actor_process = mp.Process(
|
||||
target=run_actor,
|
||||
args=(
|
||||
transitions_queue,
|
||||
parameters_queue,
|
||||
shutdown_event,
|
||||
policy_actor,
|
||||
reward_classifier,
|
||||
env_cfg,
|
||||
output_directory,
|
||||
),
|
||||
kwargs={"device": "cpu"}, # actor is frozen, can run on CPU or accelerate for inference
|
||||
)
|
||||
|
||||
learner_process.start()
|
||||
actor_process.start()
|
||||
|
||||
try:
|
||||
# Wait for actor to finish (it controls the episode loop)
|
||||
actor_process.join()
|
||||
shutdown_event.set()
|
||||
learner_process.join(timeout=10)
|
||||
|
||||
except KeyboardInterrupt:
|
||||
print("Main process interrupted")
|
||||
shutdown_event.set()
|
||||
actor_process.join(timeout=5)
|
||||
learner_process.join(timeout=10)
|
||||
|
||||
finally:
|
||||
if learner_process.is_alive():
|
||||
learner_process.terminate()
|
||||
if actor_process.is_alive():
|
||||
actor_process.terminate()
|
||||
|
||||
|
||||
# Signal handler for graceful shutdown
|
||||
def signal_handler(sig, frame):
|
||||
print(f"\nSignal {sig} received, shutting down...")
|
||||
shutdown_event.set()
|
||||
|
||||
|
||||
signal.signal(signal.SIGINT, signal_handler)
|
||||
signal.signal(signal.SIGTERM, signal_handler)
|
||||
|
||||
# Create processes
|
||||
learner_process = mp.Process(
|
||||
target=run_learner,
|
||||
args=(
|
||||
transitions_queue,
|
||||
parameters_queue,
|
||||
shutdown_event,
|
||||
policy_learner,
|
||||
online_replay_buffer,
|
||||
offline_replay_buffer,
|
||||
),
|
||||
kwargs={"device": device}, # can run on accelerated hardware for training
|
||||
)
|
||||
|
||||
actor_process = mp.Process(
|
||||
target=run_actor,
|
||||
args=(
|
||||
transitions_queue,
|
||||
parameters_queue,
|
||||
shutdown_event,
|
||||
policy_actor,
|
||||
reward_classifier,
|
||||
env_cfg,
|
||||
output_directory,
|
||||
),
|
||||
kwargs={"device": "cpu"}, # actor is frozen, can run on CPU or accelerate for inference
|
||||
)
|
||||
|
||||
learner_process.start()
|
||||
actor_process.start()
|
||||
|
||||
try:
|
||||
# Wait for actor to finish (it controls the episode loop)
|
||||
actor_process.join()
|
||||
shutdown_event.set()
|
||||
learner_process.join(timeout=10)
|
||||
|
||||
except KeyboardInterrupt:
|
||||
print("Main process interrupted")
|
||||
shutdown_event.set()
|
||||
actor_process.join(timeout=5)
|
||||
learner_process.join(timeout=10)
|
||||
|
||||
finally:
|
||||
if learner_process.is_alive():
|
||||
learner_process.terminate()
|
||||
if actor_process.is_alive():
|
||||
actor_process.terminate()
|
||||
if __name__ == "__main__":
|
||||
main()
|
||||
|
||||
@@ -4,59 +4,64 @@ from lerobot.datasets.lerobot_dataset import LeRobotDataset
|
||||
from lerobot.policies.factory import make_policy, make_pre_post_processors
|
||||
from lerobot.policies.sac.reward_model.configuration_classifier import RewardClassifierConfig
|
||||
|
||||
# Device to use for training
|
||||
device = "mps" # or "cuda", or "cpu"
|
||||
|
||||
# Load the dataset used for training
|
||||
repo_id = "lerobot/example_hil_serl_dataset"
|
||||
dataset = LeRobotDataset(repo_id)
|
||||
def main():
|
||||
# Device to use for training
|
||||
device = "mps" # or "cuda", or "cpu"
|
||||
|
||||
# Configure the policy to extract features from the image frames
|
||||
camera_keys = dataset.meta.camera_keys
|
||||
# Load the dataset used for training
|
||||
repo_id = "lerobot/example_hil_serl_dataset"
|
||||
dataset = LeRobotDataset(repo_id)
|
||||
|
||||
config = RewardClassifierConfig(
|
||||
num_cameras=len(camera_keys),
|
||||
device=device,
|
||||
# backbone model to extract features from the image frames
|
||||
model_name="microsoft/resnet-18",
|
||||
)
|
||||
# Configure the policy to extract features from the image frames
|
||||
camera_keys = dataset.meta.camera_keys
|
||||
|
||||
# Make policy, preprocessor, and optimizer
|
||||
policy = make_policy(config, ds_meta=dataset.meta)
|
||||
optimizer = config.get_optimizer_preset().build(policy.parameters())
|
||||
preprocessor, _ = make_pre_post_processors(policy_cfg=config, dataset_stats=dataset.meta.stats)
|
||||
config = RewardClassifierConfig(
|
||||
num_cameras=len(camera_keys),
|
||||
device=device,
|
||||
# backbone model to extract features from the image frames
|
||||
model_name="microsoft/resnet-18",
|
||||
)
|
||||
|
||||
# Make policy, preprocessor, and optimizer
|
||||
policy = make_policy(config, ds_meta=dataset.meta)
|
||||
optimizer = config.get_optimizer_preset().build(policy.parameters())
|
||||
preprocessor, _ = make_pre_post_processors(policy_cfg=config, dataset_stats=dataset.meta.stats)
|
||||
|
||||
classifier_id = "<user>/reward_classifier_hil_serl_example"
|
||||
|
||||
# Instantiate a dataloader
|
||||
dataloader = torch.utils.data.DataLoader(dataset, batch_size=16, shuffle=True)
|
||||
|
||||
# Training loop
|
||||
num_epochs = 5
|
||||
for epoch in range(num_epochs):
|
||||
total_loss = 0
|
||||
total_accuracy = 0
|
||||
for batch in dataloader:
|
||||
# Preprocess the batch and move it to the correct device.
|
||||
batch = preprocessor(batch)
|
||||
|
||||
# Forward pass
|
||||
loss, output_dict = policy.forward(batch)
|
||||
|
||||
# Backward pass and optimization
|
||||
optimizer.zero_grad()
|
||||
loss.backward()
|
||||
optimizer.step()
|
||||
|
||||
total_loss += loss.item()
|
||||
total_accuracy += output_dict["accuracy"]
|
||||
|
||||
avg_loss = total_loss / len(dataloader)
|
||||
avg_accuracy = total_accuracy / len(dataloader)
|
||||
print(f"Epoch {epoch + 1}/{num_epochs}, Loss: {avg_loss:.4f}, Accuracy: {avg_accuracy:.2f}%")
|
||||
|
||||
print("Training finished!")
|
||||
|
||||
# You can now save the trained policy.
|
||||
policy.push_to_hub(classifier_id)
|
||||
|
||||
|
||||
classifier_id = "fracapuano/reward_classifier_hil_serl_example"
|
||||
|
||||
# Instantiate a dataloader
|
||||
dataloader = torch.utils.data.DataLoader(dataset, batch_size=16, shuffle=True)
|
||||
|
||||
# Training loop
|
||||
num_epochs = 5
|
||||
for epoch in range(num_epochs):
|
||||
total_loss = 0
|
||||
total_accuracy = 0
|
||||
for batch in dataloader:
|
||||
# Preprocess the batch and move it to the correct device.
|
||||
batch = preprocessor(batch)
|
||||
|
||||
# Forward pass
|
||||
loss, output_dict = policy.forward(batch)
|
||||
|
||||
# Backward pass and optimization
|
||||
optimizer.zero_grad()
|
||||
loss.backward()
|
||||
optimizer.step()
|
||||
|
||||
total_loss += loss.item()
|
||||
total_accuracy += output_dict["accuracy"]
|
||||
|
||||
avg_loss = total_loss / len(dataloader)
|
||||
avg_accuracy = total_accuracy / len(dataloader)
|
||||
print(f"Epoch {epoch + 1}/{num_epochs}, Loss: {avg_loss:.4f}, Accuracy: {avg_accuracy:.2f}%")
|
||||
|
||||
print("Training finished!")
|
||||
|
||||
# You can now save the trained policy.
|
||||
policy.push_to_hub(classifier_id)
|
||||
if __name__ == "__main__":
|
||||
main()
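# Note (illustrative): the pushed classifier is what the HIL-SERL script above reloads with
# Classifier.from_pretrained(...). The import path below is an assumption, mirroring the
# configuration module imported at the top of this file:
#   from lerobot.policies.sac.reward_model.modeling_classifier import Classifier
#   reward_classifier = Classifier.from_pretrained("<user>/reward_classifier_hil_serl_example")
#   reward_classifier.to("mps")  # or "cuda" / "cpu"
#   reward_classifier.eval()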
|
||||
|
||||
@@ -11,56 +11,62 @@ from lerobot.robots.so100_follower.so100_follower import SO100Follower
|
||||
MAX_EPISODES = 5
|
||||
MAX_STEPS_PER_EPISODE = 20
|
||||
|
||||
device = torch.device("mps") # or "cuda" or "cpu"
|
||||
model_id = "lerobot/smolvla_base"
|
||||
|
||||
model = SmolVLAPolicy.from_pretrained(model_id)
|
||||
def main():
|
||||
device = torch.device("mps") # or "cuda" or "cpu"
|
||||
model_id = "lerobot/smolvla_base"
|
||||
|
||||
preprocess, postprocess = make_pre_post_processors(
|
||||
model.config,
|
||||
model_id,
|
||||
# This override allows running on MPS; otherwise it defaults to CUDA (if available)
|
||||
preprocessor_overrides={"device_processor": {"device": str(device)}},
|
||||
)
|
||||
model = SmolVLAPolicy.from_pretrained(model_id)
|
||||
|
||||
# find ports using lerobot-find-port
|
||||
follower_port = ... # something like "/dev/tty.usbmodem58760431631"
|
||||
preprocess, postprocess = make_pre_post_processors(
|
||||
model.config,
|
||||
model_id,
|
||||
# This override allows running on MPS; otherwise it defaults to CUDA (if available)
|
||||
preprocessor_overrides={"device_processor": {"device": str(device)}},
|
||||
)
|
||||
|
||||
# the robot ids are used to load the right calibration files
|
||||
follower_id = ... # something like "follower_so100"
|
||||
# find ports using lerobot-find-port
|
||||
follower_port = ... # something like "/dev/tty.usbmodem58760431631"
|
||||
|
||||
# Robot and environment configuration
|
||||
# Camera keys must match the names and resolutions of the ones used for training!
# You can check the camera keys expected by a model in the info.json file on its model card on the Hub
|
||||
camera_config = {
|
||||
"camera1": OpenCVCameraConfig(index_or_path=0, width=640, height=480, fps=30),
|
||||
"camera2": OpenCVCameraConfig(index_or_path=1, width=640, height=480, fps=30),
|
||||
}
|
||||
# the robot ids are used to load the right calibration files
|
||||
follower_id = ... # something like "follower_so100"
|
||||
|
||||
robot_cfg = SO100FollowerConfig(port=follower_port, id=follower_id, cameras=camera_config)
|
||||
robot = SO100Follower(robot_cfg)
|
||||
robot.connect()
|
||||
# Robot and environment configuration
|
||||
# Camera keys must match the names and resolutions of the ones used for training!
# You can check the camera keys expected by a model in the info.json file on its model card on the Hub
|
||||
camera_config = {
|
||||
"camera1": OpenCVCameraConfig(index_or_path=0, width=640, height=480, fps=30),
|
||||
"camera2": OpenCVCameraConfig(index_or_path=1, width=640, height=480, fps=30),
|
||||
}
|
||||
|
||||
task = "" # something like "pick the red block"
|
||||
robot_type = "" # something like "so100_follower" for multi-embodiment datasets
|
||||
robot_cfg = SO100FollowerConfig(port=follower_port, id=follower_id, cameras=camera_config)
|
||||
robot = SO100Follower(robot_cfg)
|
||||
robot.connect()
|
||||
|
||||
# This is used to match the raw observation keys to the keys expected by the policy
|
||||
action_features = hw_to_dataset_features(robot.action_features, "action")
|
||||
obs_features = hw_to_dataset_features(robot.observation_features, "observation")
|
||||
dataset_features = {**action_features, **obs_features}
|
||||
task = "" # something like "pick the red block"
|
||||
robot_type = "" # something like "so100_follower" for multi-embodiment datasets
|
||||
|
||||
for _ in range(MAX_EPISODES):
|
||||
for _ in range(MAX_STEPS_PER_EPISODE):
|
||||
obs = robot.get_observation()
|
||||
obs_frame = build_inference_frame(
|
||||
observation=obs, ds_features=dataset_features, device=device, task=task, robot_type=robot_type
|
||||
)
|
||||
# This is used to match the raw observation keys to the keys expected by the policy
|
||||
action_features = hw_to_dataset_features(robot.action_features, "action")
|
||||
obs_features = hw_to_dataset_features(robot.observation_features, "observation")
|
||||
dataset_features = {**action_features, **obs_features}
|
||||
|
||||
obs = preprocess(obs_frame)
|
||||
for _ in range(MAX_EPISODES):
|
||||
for _ in range(MAX_STEPS_PER_EPISODE):
|
||||
obs = robot.get_observation()
|
||||
obs_frame = build_inference_frame(
|
||||
observation=obs, ds_features=dataset_features, device=device, task=task, robot_type=robot_type
|
||||
)
|
||||
|
||||
action = model.select_action(obs)
|
||||
action = postprocess(action)
|
||||
action = make_robot_action(action, dataset_features)
|
||||
robot.send_action(action)
|
||||
obs = preprocess(obs_frame)
|
||||
|
||||
print("Episode finished! Starting new episode...")
|
||||
action = model.select_action(obs)
|
||||
action = postprocess(action)
|
||||
action = make_robot_action(action, dataset_features)
|
||||
robot.send_action(action)
|
||||
|
||||
print("Episode finished! Starting new episode...")
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
main()
|
||||
|
||||
@@ -0,0 +1,347 @@
|
||||
#!/usr/bin/env python
|
||||
|
||||
# Copyright 2025 The HuggingFace Inc. team. All rights reserved.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
"""
|
||||
Example: GR00T Locomotion with Pre-loaded Policies
|
||||
|
||||
This example demonstrates the NEW pattern for loading GR00T policies externally
|
||||
and passing them to the robot class.
|
||||
"""
|
||||
|
||||
import argparse
|
||||
import logging
|
||||
import threading
|
||||
import time
|
||||
from collections import deque
|
||||
|
||||
import numpy as np
|
||||
import onnxruntime as ort
|
||||
from huggingface_hub import hf_hub_download
|
||||
|
||||
from lerobot.robots.unitree_g1.config_unitree_g1 import UnitreeG1Config
|
||||
from lerobot.robots.unitree_g1.unitree_g1 import UnitreeG1
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
GROOT_DEFAULT_ANGLES = np.zeros(29, dtype=np.float32)
|
||||
GROOT_DEFAULT_ANGLES[[0, 6]] = -0.1 # hip pitch
|
||||
GROOT_DEFAULT_ANGLES[[3, 9]] = 0.3 # knee
|
||||
GROOT_DEFAULT_ANGLES[[4, 10]] = -0.2 # ankle pitch
|
||||
|
||||
MISSING_JOINTS = []
|
||||
G1_MODEL = "g1_23" # or "g1_29"
|
||||
if G1_MODEL == "g1_23":
|
||||
MISSING_JOINTS = [12, 14, 20, 21, 27, 28] # waist yaw/pitch, wrist pitch/yaw
|
||||
|
||||
LOCOMOTION_ACTION_SCALE = 0.25
|
||||
|
||||
LOCOMOTION_CONTROL_DT = 0.02
|
||||
|
||||
ANG_VEL_SCALE: float = 0.25
|
||||
DOF_POS_SCALE: float = 1.0
|
||||
DOF_VEL_SCALE: float = 0.05
|
||||
CMD_SCALE: list = [2.0, 2.0, 0.25]
|
||||
|
||||
|
||||
DEFAULT_GROOT_REPO_ID = "nepyope/GR00T-WholeBodyControl_g1"
|
||||
|
||||
|
||||
def load_groot_policies(
|
||||
repo_id: str = DEFAULT_GROOT_REPO_ID,
|
||||
) -> tuple[ort.InferenceSession, ort.InferenceSession]:
|
||||
"""Load GR00T dual-policy system (Balance + Walk) from Hugging Face Hub.
|
||||
|
||||
Args:
|
||||
repo_id: Hugging Face Hub repository ID containing the ONNX policies.
|
||||
"""
|
||||
logger.info(f"Loading GR00T dual-policy system from Hugging Face Hub ({repo_id})...")
|
||||
|
||||
# Download ONNX policies from Hugging Face Hub
|
||||
balance_path = hf_hub_download(
|
||||
repo_id=repo_id,
|
||||
filename="GR00T-WholeBodyControl-Balance.onnx",
|
||||
)
|
||||
walk_path = hf_hub_download(
|
||||
repo_id=repo_id,
|
||||
filename="GR00T-WholeBodyControl-Walk.onnx",
|
||||
)
|
||||
|
||||
# Load ONNX policies
|
||||
policy_balance = ort.InferenceSession(balance_path)
|
||||
policy_walk = ort.InferenceSession(walk_path)
|
||||
|
||||
logger.info("GR00T policies loaded successfully")
|
||||
|
||||
return policy_balance, policy_walk
|
||||
|
||||
|
||||
class GrootLocomotionController:
|
||||
"""
|
||||
Handles GR00T-style locomotion control for the Unitree G1 robot.
|
||||
|
||||
This controller manages:
|
||||
- Dual-policy system (Balance + Walk)
|
||||
- 29-joint observation processing
|
||||
- 15D action output (legs + waist)
|
||||
- Policy inference and motor command generation
|
||||
"""
|
||||
|
||||
def __init__(self, policy_balance, policy_walk, robot, config):
|
||||
self.policy_balance = policy_balance
|
||||
self.policy_walk = policy_walk
|
||||
self.robot = robot
|
||||
self.config = config
|
||||
|
||||
self.locomotion_cmd = np.array([0.0, 0.0, 0.0], dtype=np.float32) # vx, vy, theta_dot
|
||||
|
||||
# GR00T-specific state
|
||||
self.groot_qj_all = np.zeros(29, dtype=np.float32)
|
||||
self.groot_dqj_all = np.zeros(29, dtype=np.float32)
|
||||
self.groot_action = np.zeros(15, dtype=np.float32)
|
||||
self.groot_obs_single = np.zeros(86, dtype=np.float32)
|
||||
self.groot_obs_history = deque(maxlen=6)
|
||||
self.groot_obs_stacked = np.zeros(516, dtype=np.float32)
|
||||
self.groot_height_cmd = 0.74 # Default base height
|
||||
self.groot_orientation_cmd = np.array([0.0, 0.0, 0.0], dtype=np.float32)
|
||||
|
||||
# input to gr00t is 6 frames (6*86D=516)
|
||||
for _ in range(6):
|
||||
self.groot_obs_history.append(np.zeros(86, dtype=np.float32))
|
||||
|
||||
# Thread management
|
||||
self.locomotion_running = False
|
||||
self.locomotion_thread = None
|
||||
|
||||
logger.info("GrootLocomotionController initialized")
|
||||
|
||||
def groot_locomotion_run(self):
|
||||
# get current observation
|
||||
robot_state = self.robot.get_observation()
|
||||
|
||||
if robot_state is None:
|
||||
return
|
||||
|
||||
# get command from remote controller
|
||||
if robot_state.wireless_remote is not None:
|
||||
self.robot.remote_controller.set(robot_state.wireless_remote)
|
||||
if self.robot.remote_controller.button[0]: # R1 - raise waist
|
||||
self.groot_height_cmd += 0.001
|
||||
self.groot_height_cmd = np.clip(self.groot_height_cmd, 0.50, 1.00)
|
||||
if self.robot.remote_controller.button[4]: # R2 - lower waist
|
||||
self.groot_height_cmd -= 0.001
|
||||
self.groot_height_cmd = np.clip(self.groot_height_cmd, 0.50, 1.00)
|
||||
else:
|
||||
self.robot.remote_controller.lx = 0.0
|
||||
self.robot.remote_controller.ly = 0.0
|
||||
self.robot.remote_controller.rx = 0.0
|
||||
self.robot.remote_controller.ry = 0.0
|
||||
|
||||
self.locomotion_cmd[0] = self.robot.remote_controller.ly # forward/backward
|
||||
self.locomotion_cmd[1] = self.robot.remote_controller.lx * -1 # left/right
|
||||
self.locomotion_cmd[2] = self.robot.remote_controller.rx * -1 # rotation rate
|
||||
|
||||
for i in range(29):
|
||||
self.groot_qj_all[i] = robot_state.motor_state[i].q
|
||||
self.groot_dqj_all[i] = robot_state.motor_state[i].dq
|
||||
|
||||
# adapt observation for g1_23dof
|
||||
for idx in MISSING_JOINTS:
|
||||
self.groot_qj_all[idx] = 0.0
|
||||
self.groot_dqj_all[idx] = 0.0
|
||||
|
||||
# Scale joint positions and velocities
|
||||
qj_obs = self.groot_qj_all.copy()
|
||||
dqj_obs = self.groot_dqj_all.copy()
|
||||
|
||||
# express imu data in gravity frame of reference
|
||||
quat = robot_state.imu_state.quaternion
|
||||
ang_vel = np.array(robot_state.imu_state.gyroscope, dtype=np.float32)
|
||||
gravity_orientation = self.robot.get_gravity_orientation(quat)
|
||||
|
||||
# scale joint positions and velocities before policy inference
|
||||
qj_obs = (qj_obs - GROOT_DEFAULT_ANGLES) * DOF_POS_SCALE
|
||||
dqj_obs = dqj_obs * DOF_VEL_SCALE
|
||||
ang_vel_scaled = ang_vel * ANG_VEL_SCALE
|
||||
|
||||
# build single frame observation
|
||||
self.groot_obs_single[:3] = self.locomotion_cmd * np.array(CMD_SCALE)
|
||||
self.groot_obs_single[3] = self.groot_height_cmd
|
||||
self.groot_obs_single[4:7] = self.groot_orientation_cmd
|
||||
self.groot_obs_single[7:10] = ang_vel_scaled
|
||||
self.groot_obs_single[10:13] = gravity_orientation
|
||||
self.groot_obs_single[13:42] = qj_obs
|
||||
self.groot_obs_single[42:71] = dqj_obs
|
||||
self.groot_obs_single[71:86] = self.groot_action # 15D previous actions
|
||||
|
||||
# Add to history and stack observations (6 frames × 86D = 516D)
|
||||
self.groot_obs_history.append(self.groot_obs_single.copy())
|
||||
|
||||
# Stack all 6 frames into 516D vector
|
||||
for i, obs_frame in enumerate(self.groot_obs_history):
|
||||
start_idx = i * 86
|
||||
end_idx = start_idx + 86
|
||||
self.groot_obs_stacked[start_idx:end_idx] = obs_frame
|
||||
|
||||
# Run policy inference (ONNX) with 516D stacked observation
|
||||
|
||||
cmd_magnitude = np.linalg.norm(self.locomotion_cmd)
|
||||
|
||||
selected_policy = (
|
||||
self.policy_balance if cmd_magnitude < 0.05 else self.policy_walk
|
||||
) # balance/standing policy for small commands, walking policy for movement commands
|
||||
|
||||
# run policy inference
|
||||
ort_inputs = {selected_policy.get_inputs()[0].name: np.expand_dims(self.groot_obs_stacked, axis=0)}
|
||||
ort_outs = selected_policy.run(None, ort_inputs)
|
||||
self.groot_action = ort_outs[0].squeeze()
|
||||
|
||||
# transform action back to target joint positions
|
||||
target_dof_pos_15 = GROOT_DEFAULT_ANGLES[:15] + self.groot_action * LOCOMOTION_ACTION_SCALE
|
||||
|
||||
# command motors
|
||||
for i in range(15):
|
||||
motor_idx = i
|
||||
self.robot.msg.motor_cmd[motor_idx].q = target_dof_pos_15[i]
|
||||
self.robot.msg.motor_cmd[motor_idx].qd = 0
|
||||
self.robot.msg.motor_cmd[motor_idx].kp = self.robot.kp[motor_idx]
|
||||
self.robot.msg.motor_cmd[motor_idx].kd = self.robot.kd[motor_idx]
|
||||
self.robot.msg.motor_cmd[motor_idx].tau = 0
|
||||
|
||||
# adapt action for g1_23dof
|
||||
for joint_idx in MISSING_JOINTS:
|
||||
self.robot.msg.motor_cmd[joint_idx].q = 0.0
|
||||
self.robot.msg.motor_cmd[joint_idx].qd = 0
|
||||
self.robot.msg.motor_cmd[joint_idx].kp = self.robot.kp[joint_idx]
|
||||
self.robot.msg.motor_cmd[joint_idx].kd = self.robot.kd[joint_idx]
|
||||
self.robot.msg.motor_cmd[joint_idx].tau = 0
|
||||
|
||||
# send action to robot
|
||||
self.robot.send_action(self.robot.msg)
|
||||
|
||||
def _locomotion_thread_loop(self):
|
||||
"""Background thread that runs the locomotion policy at specified rate."""
|
||||
logger.info("Locomotion thread started")
|
||||
while self.locomotion_running:
|
||||
start_time = time.time()
|
||||
try:
|
||||
self.groot_locomotion_run()
|
||||
except Exception as e:
|
||||
logger.error(f"Error in locomotion loop: {e}")
|
||||
|
||||
# Sleep to maintain control rate
|
||||
elapsed = time.time() - start_time
|
||||
sleep_time = max(0, LOCOMOTION_CONTROL_DT - elapsed)
|
||||
time.sleep(sleep_time)
|
||||
logger.info("Locomotion thread stopped")
|
||||
|
||||
def start_locomotion_thread(self):
|
||||
if self.locomotion_running:
|
||||
logger.warning("Locomotion thread already running")
|
||||
return
|
||||
|
||||
logger.info("Starting locomotion control thread...")
|
||||
self.locomotion_running = True
|
||||
self.locomotion_thread = threading.Thread(target=self._locomotion_thread_loop, daemon=True)
|
||||
self.locomotion_thread.start()
|
||||
|
||||
logger.info("Locomotion control thread started!")
|
||||
|
||||
def stop_locomotion_thread(self):
|
||||
if not self.locomotion_running:
|
||||
return
|
||||
|
||||
logger.info("Stopping locomotion control thread...")
|
||||
self.locomotion_running = False
|
||||
if self.locomotion_thread:
|
||||
self.locomotion_thread.join(timeout=2.0)
|
||||
logger.info("Locomotion control thread stopped")
|
||||
|
||||
def reset_robot(self):
|
||||
"""Move robot legs to default standing position over 2 seconds (arms are not moved)."""
|
||||
total_time = 3.0
|
||||
num_step = int(total_time / self.robot.control_dt)
|
||||
|
||||
# Only control legs, not arms (first 12 joints)
|
||||
default_pos = GROOT_DEFAULT_ANGLES # First 12 values are leg angles
|
||||
dof_size = len(default_pos)
|
||||
|
||||
# Get current lowstate
|
||||
robot_state = self.robot.get_observation()
|
||||
|
||||
# Record the current leg positions
|
||||
init_dof_pos = np.zeros(dof_size, dtype=np.float32)
|
||||
for i in range(dof_size):
|
||||
init_dof_pos[i] = robot_state.motor_state[i].q
|
||||
|
||||
# Move legs to default pos
|
||||
for i in range(num_step):
|
||||
alpha = i / num_step
|
||||
for motor_idx in range(dof_size):
|
||||
target_pos = default_pos[motor_idx]
|
||||
self.robot.msg.motor_cmd[motor_idx].q = (
|
||||
init_dof_pos[motor_idx] * (1 - alpha) + target_pos * alpha
|
||||
)
|
||||
self.robot.msg.motor_cmd[motor_idx].qd = 0
|
||||
self.robot.msg.motor_cmd[motor_idx].kp = self.robot.kp[motor_idx]
|
||||
self.robot.msg.motor_cmd[motor_idx].kd = self.robot.kd[motor_idx]
|
||||
self.robot.msg.motor_cmd[motor_idx].tau = 0
|
||||
self.robot.msg.crc = self.robot.crc.Crc(self.robot.msg)
|
||||
self.robot.lowcmd_publisher.Write(self.robot.msg)
|
||||
time.sleep(self.robot.control_dt)
|
||||
logger.info("Reached default position (legs only)")
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
parser = argparse.ArgumentParser(description="GR00T Locomotion Controller for Unitree G1")
|
||||
parser.add_argument(
|
||||
"--repo-id",
|
||||
type=str,
|
||||
default=DEFAULT_GROOT_REPO_ID,
|
||||
help=f"Hugging Face Hub repo ID for GR00T policies (default: {DEFAULT_GROOT_REPO_ID})",
|
||||
)
|
||||
args = parser.parse_args()
|
||||
|
||||
# load policies
|
||||
policy_balance, policy_walk = load_groot_policies(repo_id=args.repo_id)
|
||||
|
||||
# initialize robot
|
||||
config = UnitreeG1Config()
|
||||
robot = UnitreeG1(config)
|
||||
|
||||
# initialize gr00t locomotion controller
|
||||
groot_controller = GrootLocomotionController(
|
||||
policy_balance=policy_balance,
|
||||
policy_walk=policy_walk,
|
||||
robot=robot,
|
||||
config=config,
|
||||
)
|
||||
|
||||
# reset legs and start locomotion thread
|
||||
try:
|
||||
groot_controller.reset_robot()
|
||||
groot_controller.start_locomotion_thread()
|
||||
|
||||
# log status
|
||||
logger.info("Robot initialized with GR00T locomotion policies")
|
||||
logger.info("Locomotion controller running in background thread")
|
||||
logger.info("Press Ctrl+C to stop")
|
||||
|
||||
# keep robot alive
|
||||
while True:
|
||||
time.sleep(1.0)
|
||||
except KeyboardInterrupt:
|
||||
print("\nStopping locomotion...")
|
||||
groot_controller.stop_locomotion_thread()
|
||||
print("Done!")
|
||||
@@ -25,7 +25,7 @@ discord = "https://discord.gg/s3KuuzsPFb"
|
||||
|
||||
[project]
|
||||
name = "lerobot"
|
||||
version = "0.4.2"
|
||||
version = "0.4.3"
|
||||
description = "🤗 LeRobot: State-of-the-art Machine Learning for Real-World Robotics in Pytorch"
|
||||
readme = "README.md"
|
||||
license = { text = "Apache-2.0" }
|
||||
@@ -107,6 +107,10 @@ dynamixel = ["dynamixel-sdk>=3.7.31,<3.9.0"]
|
||||
gamepad = ["lerobot[pygame-dep]", "hidapi>=0.14.0,<0.15.0"]
|
||||
hopejr = ["lerobot[feetech]", "lerobot[pygame-dep]"]
|
||||
lekiwi = ["lerobot[feetech]", "pyzmq>=26.2.1,<28.0.0"]
|
||||
unitree_g1 = [
|
||||
"pyzmq>=26.2.1,<28.0.0",
|
||||
"onnxruntime>=1.16.0"
|
||||
]
|
||||
reachy2 = ["reachy2_sdk>=1.0.14,<1.1.0"]
|
||||
kinematics = ["lerobot[placo-dep]"]
|
||||
intelrealsense = [
|
||||
@@ -129,6 +133,7 @@ groot = [
|
||||
"ninja>=1.11.1,<2.0.0",
|
||||
"flash-attn>=2.5.9,<3.0.0 ; sys_platform != 'darwin'"
|
||||
]
|
||||
xvla = ["lerobot[transformers-dep]"]
|
||||
hilserl = ["lerobot[transformers-dep]", "gym-hil>=0.1.13,<0.2.0", "lerobot[grpcio-dep]", "lerobot[placo-dep]"]
|
||||
|
||||
# Features
|
||||
@@ -157,6 +162,7 @@ all = [
|
||||
"lerobot[pi]",
|
||||
"lerobot[smolvla]",
|
||||
# "lerobot[groot]", TODO(Steven): Gr00t requires specific installation instructions for flash-attn
|
||||
"lerobot[xvla]",
|
||||
"lerobot[hilserl]",
|
||||
"lerobot[async]",
|
||||
"lerobot[dev]",
|
||||
@@ -356,9 +362,9 @@ ignore_errors = false
|
||||
# module = "lerobot.async_inference.*"
|
||||
# ignore_errors = false
|
||||
|
||||
# [[tool.mypy.overrides]]
|
||||
# module = "lerobot.transport.*"
|
||||
# ignore_errors = false
|
||||
[[tool.mypy.overrides]]
|
||||
module = "lerobot.transport.*"
|
||||
ignore_errors = false
|
||||
|
||||
# [[tool.mypy.overrides]]
|
||||
# module = "lerobot.scripts.*"
|
||||
|
||||
@@ -136,21 +136,40 @@ def update_meta_data(
|
||||
df["_orig_chunk"] = df[orig_chunk_col].copy()
|
||||
df["_orig_file"] = df[orig_file_col].copy()
|
||||
|
||||
# Update chunk and file indices to point to destination
|
||||
df[orig_chunk_col] = video_idx["chunk"]
|
||||
df[orig_file_col] = video_idx["file"]
|
||||
|
||||
# Apply per-source-file timestamp offsets
|
||||
# Get mappings for this video key
|
||||
src_to_offset = video_idx.get("src_to_offset", {})
|
||||
if src_to_offset:
|
||||
# Apply offset based on original source file
|
||||
src_to_dst = video_idx.get("src_to_dst", {})
|
||||
|
||||
# Apply per-source-file mappings
|
||||
if src_to_dst:
|
||||
# Map each episode to its correct destination file and apply offset
|
||||
for idx in df.index:
|
||||
src_key = (df.at[idx, "_orig_chunk"], df.at[idx, "_orig_file"])
|
||||
# Convert to Python int to avoid numpy type mismatch in dict lookup
|
||||
src_key = (int(df.at[idx, "_orig_chunk"]), int(df.at[idx, "_orig_file"]))
|
||||
|
||||
# Get destination chunk/file for this source file
|
||||
dst_chunk, dst_file = src_to_dst.get(src_key, (video_idx["chunk"], video_idx["file"]))
|
||||
df.at[idx, orig_chunk_col] = dst_chunk
|
||||
df.at[idx, orig_file_col] = dst_file
|
||||
|
||||
# Apply timestamp offset
|
||||
offset = src_to_offset.get(src_key, 0)
|
||||
df.at[idx, f"videos/{key}/from_timestamp"] += offset
|
||||
df.at[idx, f"videos/{key}/to_timestamp"] += offset
|
||||
elif src_to_offset:
|
||||
# Fallback: use same destination for all, but apply per-file offsets
|
||||
df[orig_chunk_col] = video_idx["chunk"]
|
||||
df[orig_file_col] = video_idx["file"]
|
||||
for idx in df.index:
|
||||
# Convert to Python int to avoid numpy type mismatch in dict lookup
|
||||
src_key = (int(df.at[idx, "_orig_chunk"]), int(df.at[idx, "_orig_file"]))
|
||||
offset = src_to_offset.get(src_key, 0)
|
||||
df.at[idx, f"videos/{key}/from_timestamp"] += offset
|
||||
df.at[idx, f"videos/{key}/to_timestamp"] += offset
|
||||
else:
|
||||
# Fallback to simple offset (for backward compatibility)
|
||||
df[orig_chunk_col] = video_idx["chunk"]
|
||||
df[orig_file_col] = video_idx["file"]
|
||||
df[f"videos/{key}/from_timestamp"] = (
|
||||
df[f"videos/{key}/from_timestamp"] + video_idx["latest_duration"]
|
||||
)
|
||||
@@ -268,6 +287,12 @@ def aggregate_videos(src_meta, dst_meta, videos_idx, video_files_size_in_mb, chu
|
||||
videos_idx[key]["episode_duration"] = 0
|
||||
# Track offset for each source (chunk, file) pair
|
||||
videos_idx[key]["src_to_offset"] = {}
|
||||
# Track destination (chunk, file) for each source (chunk, file) pair
|
||||
videos_idx[key]["src_to_dst"] = {}
|
||||
# Initialize dst_file_durations if not present
|
||||
# dst_file_durations tracks duration of each destination file
|
||||
if "dst_file_durations" not in videos_idx[key]:
|
||||
videos_idx[key]["dst_file_durations"] = {}
|
||||
|
||||
for key, video_idx in videos_idx.items():
|
||||
unique_chunk_file_pairs = {
|
||||
@@ -282,9 +307,13 @@ def aggregate_videos(src_meta, dst_meta, videos_idx, video_files_size_in_mb, chu
|
||||
|
||||
chunk_idx = video_idx["chunk"]
|
||||
file_idx = video_idx["file"]
|
||||
current_offset = video_idx["latest_duration"]
|
||||
dst_file_durations = video_idx["dst_file_durations"]
|
||||
|
||||
for src_chunk_idx, src_file_idx in unique_chunk_file_pairs:
|
||||
# Convert to Python int to ensure consistent dict keys
|
||||
src_chunk_idx = int(src_chunk_idx)
|
||||
src_file_idx = int(src_file_idx)
|
||||
|
||||
src_path = src_meta.root / DEFAULT_VIDEO_PATH.format(
|
||||
video_key=key,
|
||||
chunk_index=src_chunk_idx,
|
||||
@@ -298,14 +327,17 @@ def aggregate_videos(src_meta, dst_meta, videos_idx, video_files_size_in_mb, chu
|
||||
)
|
||||
|
||||
src_duration = get_video_duration_in_s(src_path)
|
||||
dst_key = (chunk_idx, file_idx)
|
||||
|
||||
if not dst_path.exists():
|
||||
# Store offset before incrementing
|
||||
videos_idx[key]["src_to_offset"][(src_chunk_idx, src_file_idx)] = current_offset
|
||||
# New destination file: offset is 0
|
||||
videos_idx[key]["src_to_offset"][(src_chunk_idx, src_file_idx)] = 0
|
||||
videos_idx[key]["src_to_dst"][(src_chunk_idx, src_file_idx)] = dst_key
|
||||
dst_path.parent.mkdir(parents=True, exist_ok=True)
|
||||
shutil.copy(str(src_path), str(dst_path))
|
||||
# Track duration of this destination file
|
||||
dst_file_durations[dst_key] = src_duration
|
||||
videos_idx[key]["episode_duration"] += src_duration
|
||||
current_offset += src_duration
|
||||
continue
|
||||
|
||||
# Check file sizes before appending
|
||||
@@ -313,10 +345,11 @@ def aggregate_videos(src_meta, dst_meta, videos_idx, video_files_size_in_mb, chu
|
||||
dst_size = get_file_size_in_mb(dst_path)
|
||||
|
||||
if dst_size + src_size >= video_files_size_in_mb:
|
||||
# Rotate to a new file, this source becomes start of new destination
|
||||
# So its offset should be 0
|
||||
videos_idx[key]["src_to_offset"][(src_chunk_idx, src_file_idx)] = 0
|
||||
# Rotate to a new file - offset is 0
|
||||
chunk_idx, file_idx = update_chunk_file_indices(chunk_idx, file_idx, chunk_size)
|
||||
dst_key = (chunk_idx, file_idx)
|
||||
videos_idx[key]["src_to_offset"][(src_chunk_idx, src_file_idx)] = 0
|
||||
videos_idx[key]["src_to_dst"][(src_chunk_idx, src_file_idx)] = dst_key
|
||||
dst_path = dst_meta.root / DEFAULT_VIDEO_PATH.format(
|
||||
video_key=key,
|
||||
chunk_index=chunk_idx,
|
||||
@@ -324,16 +357,20 @@ def aggregate_videos(src_meta, dst_meta, videos_idx, video_files_size_in_mb, chu
|
||||
)
|
||||
dst_path.parent.mkdir(parents=True, exist_ok=True)
|
||||
shutil.copy(str(src_path), str(dst_path))
|
||||
# Reset offset for next file
|
||||
current_offset = src_duration
|
||||
# Track duration of this new destination file
|
||||
dst_file_durations[dst_key] = src_duration
|
||||
else:
|
||||
# Append to existing video file - use current accumulated offset
|
||||
videos_idx[key]["src_to_offset"][(src_chunk_idx, src_file_idx)] = current_offset
|
||||
# Append to existing destination file
|
||||
# Offset is the current duration of this destination file
|
||||
current_dst_duration = dst_file_durations.get(dst_key, 0)
|
||||
videos_idx[key]["src_to_offset"][(src_chunk_idx, src_file_idx)] = current_dst_duration
|
||||
videos_idx[key]["src_to_dst"][(src_chunk_idx, src_file_idx)] = dst_key
|
||||
concatenate_video_files(
|
||||
[dst_path, src_path],
|
||||
dst_path,
|
||||
)
|
||||
current_offset += src_duration
|
||||
# Update duration of this destination file
|
||||
dst_file_durations[dst_key] = current_dst_duration + src_duration
|
||||
|
||||
videos_idx[key]["episode_duration"] += src_duration
|
||||
|
||||
|
||||
@@ -110,8 +110,8 @@ def worker_thread_loop(queue: queue.Queue):
if item is None:
queue.task_done()
break
image_array, fpath = item
write_image(image_array, fpath)
image_array, fpath, compress_level = item
write_image(image_array, fpath, compress_level)
queue.task_done()

@@ -169,11 +169,13 @@ class AsyncImageWriter:
p.start()
self.processes.append(p)

def save_image(self, image: torch.Tensor | np.ndarray | PIL.Image.Image, fpath: Path):
def save_image(
self, image: torch.Tensor | np.ndarray | PIL.Image.Image, fpath: Path, compress_level: int = 1
):
if isinstance(image, torch.Tensor):
# Convert tensor to numpy array to minimize main process time
image = image.cpu().numpy()
self.queue.put((image, fpath))
self.queue.put((image, fpath, compress_level))

def wait_until_done(self):
self.queue.join()
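The `compress_level` threaded through the writer here is a PNG (zlib) compression level, presumably handed to Pillow by `write_image`. As a standalone illustration of the trade-off (this sketch calls Pillow directly and does not touch the lerobot writer):

from pathlib import Path

import numpy as np
import PIL.Image

frame = np.zeros((480, 640, 3), dtype=np.uint8)
# Level 1: fast, larger files - fine for frames that are only intermediate inputs to video encoding.
PIL.Image.fromarray(frame).save(Path("frame_fast.png"), compress_level=1)
# Level 6 (Pillow's default): slower, smaller files - used for frames kept as image features.
PIL.Image.fromarray(frame).save(Path("frame_small.png"), compress_level=6)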
@@ -13,6 +13,7 @@
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import concurrent.futures
import contextlib
import logging
import shutil

@@ -539,6 +540,15 @@ class LeRobotDatasetMetadata:
return obj

def _encode_video_worker(video_key: str, episode_index: int, root: Path, fps: int) -> Path:
temp_path = Path(tempfile.mkdtemp(dir=root)) / f"{video_key}_{episode_index:03d}.mp4"
fpath = DEFAULT_IMAGE_PATH.format(image_key=video_key, episode_index=episode_index, frame_index=0)
img_dir = (root / fpath).parent
encode_video_frames(img_dir, temp_path, fps, overwrite=True)
shutil.rmtree(img_dir)
return temp_path

class LeRobotDataset(torch.utils.data.Dataset):
def __init__(
self,

@@ -1071,6 +1081,7 @@ class LeRobotDataset(torch.utils.data.Dataset):
ep_buffer[key] = current_ep_idx if key == "episode_index" else []
return ep_buffer

# TODO(Steven): consider move this to utils
def _get_image_file_path(self, episode_index: int, image_key: str, frame_index: int) -> Path:
fpath = DEFAULT_IMAGE_PATH.format(
image_key=image_key, episode_index=episode_index, frame_index=frame_index

@@ -1080,13 +1091,15 @@ class LeRobotDataset(torch.utils.data.Dataset):
def _get_image_file_dir(self, episode_index: int, image_key: str) -> Path:
return self._get_image_file_path(episode_index, image_key, frame_index=0).parent

def _save_image(self, image: torch.Tensor | np.ndarray | PIL.Image.Image, fpath: Path) -> None:
def _save_image(
self, image: torch.Tensor | np.ndarray | PIL.Image.Image, fpath: Path, compress_level: int = 1
) -> None:
if self.image_writer is None:
if isinstance(image, torch.Tensor):
image = image.cpu().numpy()
write_image(image, fpath)
write_image(image, fpath, compress_level=compress_level)
else:
self.image_writer.save_image(image=image, fpath=fpath)
self.image_writer.save_image(image=image, fpath=fpath, compress_level=compress_level)

def add_frame(self, frame: dict) -> None:
"""

@@ -1124,14 +1137,19 @@ class LeRobotDataset(torch.utils.data.Dataset):
)
if frame_index == 0:
img_path.parent.mkdir(parents=True, exist_ok=True)
self._save_image(frame[key], img_path)
compress_level = 1 if self.features[key]["dtype"] == "video" else 6
self._save_image(frame[key], img_path, compress_level)
self.episode_buffer[key].append(str(img_path))
else:
self.episode_buffer[key].append(frame[key])

self.episode_buffer["size"] += 1

def save_episode(self, episode_data: dict | None = None) -> None:
def save_episode(
self,
episode_data: dict | None = None,
parallel_encoding: bool = True,
) -> None:
"""
This will save to disk the current episode in self.episode_buffer.

@@ -1143,6 +1161,8 @@ class LeRobotDataset(torch.utils.data.Dataset):
episode_data (dict | None, optional): Dict containing the episode data to save. If None, this will
save the current episode in self.episode_buffer, which is filled with 'add_frame'. Defaults to
None.
parallel_encoding (bool, optional): If True, encode videos in parallel using ProcessPoolExecutor.
Defaults to True on Linux, False on macOS as it tends to use all the CPU available already.
"""
episode_buffer = episode_data if episode_data is not None else self.episode_buffer

@@ -1179,8 +1199,40 @@ class LeRobotDataset(torch.utils.data.Dataset):
use_batched_encoding = self.batch_encoding_size > 1

if has_video_keys and not use_batched_encoding:
for video_key in self.meta.video_keys:
ep_metadata.update(self._save_episode_video(video_key, episode_index))
num_cameras = len(self.meta.video_keys)
if parallel_encoding and num_cameras > 1:
# TODO(Steven): Ideally we would like to control the number of threads per encoding such that:
# num_cameras * num_threads = (total_cpu -1)
with concurrent.futures.ProcessPoolExecutor(max_workers=num_cameras) as executor:
future_to_key = {
executor.submit(
_encode_video_worker,
video_key,
episode_index,
self.root,
self.fps,
): video_key
for video_key in self.meta.video_keys
}

results = {}
for future in concurrent.futures.as_completed(future_to_key):
video_key = future_to_key[future]
try:
temp_path = future.result()
results[video_key] = temp_path
except Exception as exc:
logging.error(f"Video encoding failed for {video_key}: {exc}")
raise exc

for video_key in self.meta.video_keys:
temp_path = results[video_key]
ep_metadata.update(
self._save_episode_video(video_key, episode_index, temp_path=temp_path)
)
else:
for video_key in self.meta.video_keys:
ep_metadata.update(self._save_episode_video(video_key, episode_index))

# `meta.save_episode` need to be executed after encoding the videos
self.meta.save_episode(episode_index, episode_length, episode_tasks, ep_stats, ep_metadata)

@@ -1345,9 +1397,18 @@ class LeRobotDataset(torch.utils.data.Dataset):

return metadata

def _save_episode_video(self, video_key: str, episode_index: int) -> dict:
def _save_episode_video(
self,
video_key: str,
episode_index: int,
temp_path: Path | None = None,
) -> dict:
# Encode episode frames into a temporary video
ep_path = self._encode_temporary_episode_video(video_key, episode_index)
if temp_path is None:
ep_path = self._encode_temporary_episode_video(video_key, episode_index)
else:
ep_path = temp_path

ep_size_in_mb = get_file_size_in_mb(ep_path)
ep_duration_in_s = get_video_duration_in_s(ep_path)

@@ -1465,11 +1526,7 @@ class LeRobotDataset(torch.utils.data.Dataset):
Note: `encode_video_frames` is a blocking call. Making it asynchronous shouldn't speedup encoding,
since video encoding with ffmpeg is already using multithreading.
"""
temp_path = Path(tempfile.mkdtemp(dir=self.root)) / f"{video_key}_{episode_index:03d}.mp4"
img_dir = self._get_image_file_dir(episode_index, video_key)
encode_video_frames(img_dir, temp_path, self.fps, overwrite=True)
shutil.rmtree(img_dir)
return temp_path
return _encode_video_worker(video_key, episode_index, self.root, self.fps)

@classmethod
def create(
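A usage sketch for the new `parallel_encoding` flag; the recording loop and frame dictionaries are placeholders, and only the `save_episode` call reflects this diff.

# Hypothetical episode recording; `dataset` is a LeRobotDataset with two video keys (cameras).
for frame in recorded_frames:  # placeholder iterable of frame dicts
    dataset.add_frame(frame)

# Default: one _encode_video_worker process per camera runs in a ProcessPoolExecutor.
# Pass parallel_encoding=False (e.g. on macOS, or with a single camera) to keep sequential encoding.
dataset.save_episode(parallel_encoding=True)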
@@ -49,7 +49,7 @@ from lerobot.utils.utils import SuppressProgressBars, is_valid_numpy_dtype_strin

DEFAULT_CHUNK_SIZE = 1000 # Max number of files per chunk
DEFAULT_DATA_FILE_SIZE_IN_MB = 100 # Max size per file
DEFAULT_VIDEO_FILE_SIZE_IN_MB = 500 # Max size per file
DEFAULT_VIDEO_FILE_SIZE_IN_MB = 200 # Max size per file

INFO_PATH = "meta/info.json"
STATS_PATH = "meta/stats.json"

@@ -311,6 +311,7 @@ def encode_video_frames(
fast_decode: int = 0,
log_level: int | None = av.logging.ERROR,
overwrite: bool = False,
preset: int | None = None,
) -> None:
"""More info on ffmpeg arguments tuning on `benchmark/video/README.md`"""
# Check encoder availability

@@ -359,6 +360,9 @@ def encode_video_frames(
value = f"fast-decode={fast_decode}" if vcodec == "libsvtav1" else "fastdecode"
video_options[key] = value

if vcodec == "libsvtav1":
video_options["preset"] = str(preset) if preset is not None else "12"

# Set logging level
if log_level is not None:
# "While less efficient, it is generally preferable to modify logging with Python's logging"
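The new `preset` maps to SVT-AV1's speed/quality preset (lower values are slower but compress better; the fallback above is "12", which favours speed). A hedged call sketch: the positional arguments follow the `_encode_video_worker` call earlier in this diff, and keyword names not visible in this hunk are assumptions.

from pathlib import Path

# Slower, better-compressed encode for a dataset kept long-term (preset value is an example).
encode_video_frames(
    Path("images/observation.images.cam_top/episode_000000"),   # frames directory
    Path("videos/observation.images.cam_top/episode_000000.mp4"),
    30,                    # fps
    vcodec="libsvtav1",    # assumed keyword, matching the branch above
    preset=8,
    overwrite=True,
)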
@@ -245,7 +245,7 @@ class HILSerlRobotEnvConfig(EnvConfig):
class LiberoEnv(EnvConfig):
task: str = "libero_10" # can also choose libero_spatial, libero_object, etc.
fps: int = 30
episode_length: int = 520
episode_length: int | None = None
obs_type: str = "pixels_agent_pos"
render_mode: str = "rgb_array"
camera_name: str = "agentview_image,robot0_eye_in_hand_image"

@@ -272,6 +272,7 @@ class LiberoEnv(EnvConfig):
LIBERO_KEY_PIXELS_EYE_IN_HAND: f"{OBS_IMAGES}.image2",
}
)
control_mode: str = "relative" # or "absolute"

def __post_init__(self):
if self.obs_type == "pixels":

@@ -19,8 +19,10 @@ from typing import Any
import gymnasium as gym
from gymnasium.envs.registration import registry as gym_registry

from lerobot.configs.policies import PreTrainedConfig
from lerobot.envs.configs import AlohaEnv, EnvConfig, LiberoEnv, PushtEnv
from lerobot.envs.utils import _call_make_env, _download_hub_file, _import_hub_module, _normalize_hub_result
from lerobot.policies.xvla.configuration_xvla import XVLAConfig
from lerobot.processor import ProcessorStep
from lerobot.processor.env_processor import LiberoProcessorStep
from lerobot.processor.pipeline import PolicyProcessorPipeline

@@ -39,6 +41,7 @@ def make_env_config(env_type: str, **kwargs) -> EnvConfig:

def make_env_pre_post_processors(
env_cfg: EnvConfig,
policy_cfg: PreTrainedConfig,
) -> tuple[
PolicyProcessorPipeline[dict[str, Any], dict[str, Any]],
PolicyProcessorPipeline[dict[str, Any], dict[str, Any]],

@@ -61,6 +64,10 @@ def make_env_pre_post_processors(
# Preprocessor and Postprocessor steps are Identity for most environments
preprocessor_steps: list[ProcessorStep] = []
postprocessor_steps: list[ProcessorStep] = []
if isinstance(policy_cfg, XVLAConfig):
from lerobot.policies.xvla.processor_xvla import make_xvla_libero_pre_post_processors

return make_xvla_libero_pre_post_processors()

# For LIBERO environments, add the LiberoProcessorStep to preprocessor
if isinstance(env_cfg, LiberoEnv) or "libero" in env_cfg.type:

@@ -136,6 +143,8 @@ def make_env(
init_states=cfg.init_states,
gym_kwargs=cfg.gym_kwargs,
env_cls=env_cls,
control_mode=cfg.control_mode,
episode_length=cfg.episode_length,
)
elif "metaworld" in cfg.type:
from lerobot.envs.metaworld import create_metaworld_envs
@@ -80,10 +80,7 @@ def get_libero_dummy_action():
return [0, 0, 0, 0, 0, 0, -1]

OBS_STATE_DIM = 8
ACTION_DIM = 7
AGENT_POS_LOW = -1000.0
AGENT_POS_HIGH = 1000.0
ACTION_LOW = -1.0
ACTION_HIGH = 1.0
TASK_SUITE_MAX_STEPS: dict[str, int] = {

@@ -103,6 +100,7 @@ class LiberoEnv(gym.Env):
task_suite: Any,
task_id: int,
task_suite_name: str,
episode_length: int | None = None,
camera_name: str | Sequence[str] = "agentview_image,robot0_eye_in_hand_image",
obs_type: str = "pixels",
render_mode: str = "rgb_array",

@@ -114,6 +112,7 @@ class LiberoEnv(gym.Env):
episode_index: int = 0,
camera_name_mapping: dict[str, str] | None = None,
num_steps_wait: int = 10,
control_mode: str = "relative",
):
super().__init__()
self.task_id = task_id

@@ -141,14 +140,19 @@ class LiberoEnv(gym.Env):
self.camera_name_mapping = camera_name_mapping
self.num_steps_wait = num_steps_wait
self.episode_index = episode_index
self.episode_length = episode_length
# Load once and keep
self._init_states = get_task_init_states(task_suite, self.task_id) if self.init_states else None
self._init_state_id = self.episode_index # tie each sub-env to a fixed init state

self._env = self._make_envs_task(task_suite, self.task_id)
default_steps = 500
self._max_episode_steps = TASK_SUITE_MAX_STEPS.get(task_suite_name, default_steps)

self._max_episode_steps = (
TASK_SUITE_MAX_STEPS.get(task_suite_name, default_steps)
if self.episode_length is None
else self.episode_length
)
self.control_mode = control_mode
images = {}
for cam in self.camera_name:
images[self.camera_name_mapping[cam]] = spaces.Box(

@@ -296,6 +300,15 @@ class LiberoEnv(gym.Env):
# Increasing this value can improve determinism and reproducibility across resets.
for _ in range(self.num_steps_wait):
raw_obs, _, _, _ = self._env.step(get_libero_dummy_action())

if self.control_mode == "absolute":
for robot in self._env.robots:
robot.controller.use_delta = False
elif self.control_mode == "relative":
for robot in self._env.robots:
robot.controller.use_delta = True
else:
raise ValueError(f"Invalid control mode: {self.control_mode}")
observation = self._format_raw_obs(raw_obs)
info = {"is_success": False}
return observation, info

@@ -341,8 +354,10 @@ def _make_env_fns(
task_id: int,
n_envs: int,
camera_names: list[str],
episode_length: int | None,
init_states: bool,
gym_kwargs: Mapping[str, Any],
control_mode: str,
) -> list[Callable[[], LiberoEnv]]:
"""Build n_envs factory callables for a single (suite, task_id)."""

@@ -354,7 +369,9 @@ def _make_env_fns(
task_suite_name=suite_name,
camera_name=camera_names,
init_states=init_states,
episode_length=episode_length,
episode_index=episode_index,
control_mode=control_mode,
**local_kwargs,
)

@@ -374,6 +391,8 @@ def create_libero_envs(
camera_name: str | Sequence[str] = "agentview_image,robot0_eye_in_hand_image",
init_states: bool = True,
env_cls: Callable[[Sequence[Callable[[], Any]]], Any] | None = None,
control_mode: str = "relative",
episode_length: int | None = None,
) -> dict[str, dict[int, Any]]:
"""
Create vectorized LIBERO environments with a consistent return shape.

@@ -415,12 +434,14 @@ def create_libero_envs(
for tid in selected:
fns = _make_env_fns(
suite=suite,
episode_length=episode_length,
suite_name=suite_name,
task_id=tid,
n_envs=n_envs,
camera_names=camera_names,
init_states=init_states,
gym_kwargs=gym_kwargs,
control_mode=control_mode,
)
out[suite_name][tid] = env_cls(fns)
print(f"Built vec env | suite={suite_name} | task_id={tid} | n_envs={n_envs}")
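Sketch of the new knobs threaded through `create_libero_envs`; the task/suite selection arguments are elided because they are unchanged and not visible in this hunk, so this is illustrative rather than runnable.

import gymnasium as gym

envs = create_libero_envs(
    ...,                               # task/suite selection and n_envs, as before
    control_mode="absolute",           # robot controllers get use_delta = False after reset
    episode_length=300,                # overrides the TASK_SUITE_MAX_STEPS default for the suite
    env_cls=gym.vector.SyncVectorEnv,
)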
@@ -81,10 +81,14 @@ class AdamWConfig(OptimizerConfig):
eps: float = 1e-8
weight_decay: float = 1e-2
grad_clip_norm: float = 10.0
fused: bool = False

def build(self, params: dict) -> torch.optim.Optimizer:
kwargs = asdict(self)
kwargs.pop("grad_clip_norm")
# Fused optimizer only works on CUDA
if kwargs.get("fused") and not torch.cuda.is_available():
kwargs["fused"] = False
return torch.optim.AdamW(params, **kwargs)

@@ -104,6 +108,107 @@ class SGDConfig(OptimizerConfig):
return torch.optim.SGD(params, **kwargs)

@OptimizerConfig.register_subclass("xvla-adamw")
@dataclass
class XVLAAdamWConfig(OptimizerConfig):
"""Custom AdamW optimizer for XVLA with differential learning rates.

The Vision-Language Model (VLM) is trained with 1/10 of the base learning rate
for stable optimization, while all other components use the full LR.

This LR ratio is crucial for achieving strong and stable finetuning performance.

Soft-prompts can optionally use a separate learning rate with warm-up support.
Set `soft_prompt_lr_scale` to a value < 1.0 (e.g., 0.1) to start soft-prompts
at a lower LR. Combine with a warmup scheduler for optimal results.

Note:
Completely matching official reported performance may require an additional
warm-up LR schedule for soft-prompts, which can bring minor improvements.
When `soft_prompt_warmup_lr_scale` is set, soft-prompts start at
`lr * soft_prompt_warmup_lr_scale` and should be warmed up via the scheduler.

Parameter Groups:
- Group 0 (vlm): VLM parameters at lr * 0.1, weight_decay * 0.1
- Group 1 (soft_prompts): Soft-prompt parameters at lr * soft_prompt_lr_scale
- Group 2 (other): All other parameters at full lr
"""

lr: float = 1e-4
betas: tuple[float, float] = (0.9, 0.99)
eps: float = 1e-8
weight_decay: float = 0.0
grad_clip_norm: float = 10.0
# Soft-prompt specific settings
soft_prompt_lr_scale: float = 1.0 # Scale factor for soft-prompt LR (1.0 = same as base LR)
soft_prompt_warmup_lr_scale: float | None = None # If set, start soft-prompts at this scale (e.g., 0.01)

def build(self, params: dict) -> torch.optim.Optimizer:
"""
Build AdamW optimizer with differential learning rates.

Expects `named_parameters()` as input (dict of name -> param).
Applies:
- lr * 0.1 for all VLM-related parameters
- lr * soft_prompt_lr_scale for soft-prompt parameters (with optional warmup)
- full lr for all other parameters

Args:
params: Dictionary of parameter names to parameters (from named_parameters())

Returns:
AdamW optimizer with parameter groups for VLM, soft-prompts, and other components
"""
assert isinstance(params, dict), "Custom LR optimizer requires `named_parameters()` as inputs."

vlm_group, soft_prompt_group, other_group = [], [], []
for name, p in params.items():
if not p.requires_grad:
continue
if "vlm" in name.lower():
vlm_group.append(p)
elif "soft_prompt" in name.lower():
soft_prompt_group.append(p)
else:
other_group.append(p)

# Determine soft-prompt LR
soft_prompt_lr = self.lr * self.soft_prompt_lr_scale
if self.soft_prompt_warmup_lr_scale is not None:
# Start at warmup scale, scheduler will warm up to soft_prompt_lr
soft_prompt_lr = self.lr * self.soft_prompt_warmup_lr_scale

param_groups = [
{
"params": vlm_group,
"lr": self.lr * 0.1,
"weight_decay": self.weight_decay * 0.1,
"name": "vlm",
},
{
"params": soft_prompt_group,
"lr": soft_prompt_lr,
"weight_decay": self.weight_decay,
"name": "soft_prompts",
},
{
"params": other_group,
"lr": self.lr,
"weight_decay": self.weight_decay,
"name": "other",
},
]

# Filter out empty groups
param_groups = [g for g in param_groups if len(g["params"]) > 0]

return torch.optim.AdamW(
param_groups,
betas=self.betas,
eps=self.eps,
)

@OptimizerConfig.register_subclass("multi_adam")
@dataclass
class MultiAdamConfig(OptimizerConfig):
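A minimal sketch of how `XVLAAdamWConfig.build` groups parameters. The toy module names are invented purely to trigger the substring matching on "vlm" and "soft_prompt", and the import assumes the config lands in `lerobot.optim.optimizers` alongside `AdamWConfig`.

import torch.nn as nn

from lerobot.optim.optimizers import XVLAAdamWConfig

model = nn.ModuleDict(
    {
        "vlm_backbone": nn.Linear(4, 4),        # name contains "vlm"         -> lr * 0.1
        "soft_prompt_emb": nn.Embedding(8, 4),  # name contains "soft_prompt" -> lr * soft_prompt_lr_scale
        "action_head": nn.Linear(4, 4),         # everything else             -> full lr
    }
)

cfg = XVLAAdamWConfig(lr=1e-4, soft_prompt_lr_scale=0.1)
optimizer = cfg.build(dict(model.named_parameters()))  # expects named_parameters(), not parameters()
for group in optimizer.param_groups:
    print(group["name"], group["lr"])  # vlm 1e-05, soft_prompts 1e-05, other 0.0001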
@@ -21,6 +21,7 @@ from .smolvla.configuration_smolvla import SmolVLAConfig as SmolVLAConfig
from .smolvla.processor_smolvla import SmolVLANewLineProcessor
from .tdmpc.configuration_tdmpc import TDMPCConfig as TDMPCConfig
from .vqbet.configuration_vqbet import VQBeTConfig as VQBeTConfig
from .xvla.configuration_xvla import XVLAConfig as XVLAConfig

__all__ = [
"ACTConfig",

@@ -31,4 +32,5 @@ __all__ = [
"TDMPCConfig",
"VQBeTConfig",
"GrootConfig",
"XVLAConfig",
]

@@ -136,6 +136,7 @@ class ACTConfig(PreTrainedConfig):
optimizer_lr: float = 1e-5
optimizer_weight_decay: float = 1e-4
optimizer_lr_backbone: float = 1e-5
optimizer_fused: bool = False # Use CUDA fused AdamW kernel

def __post_init__(self):
super().__post_init__()

@@ -164,6 +165,7 @@ class ACTConfig(PreTrainedConfig):
return AdamWConfig(
lr=self.optimizer_lr,
weight_decay=self.optimizer_weight_decay,
fused=self.optimizer_fused,
)

def get_scheduler_preset(self) -> None:
@@ -16,6 +16,7 @@

from __future__ import annotations

import importlib
import logging
from typing import Any, TypedDict

@@ -40,6 +41,7 @@ from lerobot.policies.smolvla.configuration_smolvla import SmolVLAConfig
from lerobot.policies.tdmpc.configuration_tdmpc import TDMPCConfig
from lerobot.policies.utils import validate_visual_features_consistency
from lerobot.policies.vqbet.configuration_vqbet import VQBeTConfig
from lerobot.policies.xvla.configuration_xvla import XVLAConfig
from lerobot.processor import PolicyAction, PolicyProcessorPipeline
from lerobot.processor.converters import (
batch_to_transition,

@@ -107,8 +109,15 @@ def get_policy_class(name: str) -> type[PreTrainedPolicy]:
from lerobot.policies.groot.modeling_groot import GrootPolicy

return GrootPolicy
elif name == "xvla":
from lerobot.policies.xvla.modeling_xvla import XVLAPolicy

return XVLAPolicy
else:
raise NotImplementedError(f"Policy with name {name} is not implemented.")
try:
return _get_policy_cls_from_policy_name(name=name)
except Exception as e:
raise ValueError(f"Policy type '{name}' is not available.") from e

def make_policy_config(policy_type: str, **kwargs) -> PreTrainedConfig:

@@ -150,8 +159,14 @@ def make_policy_config(policy_type: str, **kwargs) -> PreTrainedConfig:
return RewardClassifierConfig(**kwargs)
elif policy_type == "groot":
return GrootConfig(**kwargs)
elif policy_type == "xvla":
return XVLAConfig(**kwargs)
else:
raise ValueError(f"Policy type '{policy_type}' is not available.")
try:
config_cls = PreTrainedConfig.get_choice_class(policy_type)
return config_cls(**kwargs)
except Exception as e:
raise ValueError(f"Policy type '{policy_type}' is not available.") from e

class ProcessorConfigKwargs(TypedDict, total=False):

@@ -329,9 +344,24 @@ def make_pre_post_processors(
config=policy_cfg,
dataset_stats=kwargs.get("dataset_stats"),
)
elif isinstance(policy_cfg, XVLAConfig):
from lerobot.policies.xvla.processor_xvla import (
make_xvla_pre_post_processors,
)

processors = make_xvla_pre_post_processors(
config=policy_cfg,
dataset_stats=kwargs.get("dataset_stats"),
)

else:
raise NotImplementedError(f"Processor for policy type '{policy_cfg.type}' is not implemented.")
try:
processors = _make_processors_from_policy_config(
config=policy_cfg,
dataset_stats=kwargs.get("dataset_stats"),
)
except Exception as e:
raise ValueError(f"Processor for policy type '{policy_cfg.type}' is not implemented.") from e

return processors

@@ -400,8 +430,7 @@ def make_policy(
raise ValueError("env_cfg cannot be None when ds_meta is not provided")
features = env_to_policy_features(env_cfg)

if not cfg.output_features:
cfg.output_features = {key: ft for key, ft in features.items() if ft.type is FeatureType.ACTION}
cfg.output_features = {key: ft for key, ft in features.items() if ft.type is FeatureType.ACTION}
if not cfg.input_features:
cfg.input_features = {key: ft for key, ft in features.items() if key not in cfg.output_features}
kwargs["config"] = cfg

@@ -425,3 +454,65 @@ def make_policy(
# TODO: (jadechoghari) - add a check_state(cfg, features) and check_action(cfg, features)

return policy

def _get_policy_cls_from_policy_name(name: str) -> type[PreTrainedConfig]:
"""Get policy class from its registered name using dynamic imports.

This is used as a helper function to import policies from 3rd party lerobot plugins.

Args:
name: The name of the policy.
Returns:
The policy class corresponding to the given name.
"""
if name not in PreTrainedConfig.get_known_choices():
raise ValueError(
f"Unknown policy name '{name}'. Available policies: {PreTrainedConfig.get_known_choices()}"
)

config_cls = PreTrainedConfig.get_choice_class(name)
config_cls_name = config_cls.__name__

model_name = config_cls_name.removesuffix("Config") # e.g., DiffusionConfig -> Diffusion
if model_name == config_cls_name:
raise ValueError(
f"The config class name '{config_cls_name}' does not follow the expected naming convention."
f"Make sure it ends with 'Config'!"
)
cls_name = model_name + "Policy" # e.g., DiffusionConfig -> DiffusionPolicy
module_path = config_cls.__module__.replace(
"configuration_", "modeling_"
) # e.g., configuration_diffusion -> modeling_diffusion

module = importlib.import_module(module_path)
policy_cls = getattr(module, cls_name)
return policy_cls

def _make_processors_from_policy_config(
config: PreTrainedConfig,
dataset_stats: dict[str, dict[str, torch.Tensor]] | None = None,
) -> tuple[Any, Any]:
"""Create pre- and post-processors from a policy configuration using dynamic imports.

This is used as a helper function to import processor factories from 3rd party lerobot plugins.

Args:
config: The policy configuration object.
dataset_stats: Dataset statistics for normalization.
Returns:
A tuple containing the input (pre-processor) and output (post-processor) pipelines.
"""

policy_type = config.type
function_name = f"make_{policy_type}_pre_post_processors"
module_path = config.__class__.__module__.replace(
"configuration_", "processor_"
) # e.g., configuration_diffusion -> processor_diffusion
logging.debug(
f"Instantiating pre/post processors using function '{function_name}' from module '{module_path}'"
)
module = importlib.import_module(module_path)
function = getattr(module, function_name)
return function(config, dataset_stats=dataset_stats)
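The dynamic fallbacks above replace hard-coded `elif` branches with a naming convention; a sketch of what a third-party plugin would have to expose for `get_policy_class("myvla")` and `make_pre_post_processors` to resolve it (package, class, and function names are illustrative only).

# my_plugin/configuration_myvla.py
from dataclasses import dataclass

from lerobot.configs.policies import PreTrainedConfig

@PreTrainedConfig.register_subclass("myvla")
@dataclass
class MyVLAConfig(PreTrainedConfig):
    ...

# my_plugin/modeling_myvla.py   -> must define class MyVLAPolicy
# my_plugin/processor_myvla.py  -> must define make_myvla_pre_post_processors(config, dataset_stats=None)
#
# Resolution path used by the helpers above:
#   "myvla" -> MyVLAConfig -> module my_plugin.modeling_myvla  -> MyVLAPolicy
#   config  -> module my_plugin.processor_myvla -> make_myvla_pre_post_processors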
@@ -94,6 +94,7 @@ class GrootConfig(PreTrainedConfig):
optimizer_betas: tuple[float, float] = (0.95, 0.999)
optimizer_eps: float = 1e-8
optimizer_weight_decay: float = 1e-5
optimizer_fused: bool = False # Use CUDA fused AdamW kernel
warmup_ratio: float = 0.05
use_bf16: bool = True

@@ -174,6 +175,7 @@ class GrootConfig(PreTrainedConfig):
betas=self.optimizer_betas,
eps=self.optimizer_eps,
weight_decay=self.optimizer_weight_decay,
fused=self.optimizer_fused,
)

def get_scheduler_preset(self) -> CosineDecayWithWarmupSchedulerConfig:

@@ -23,6 +23,8 @@ from lerobot.optim.schedulers import CosineDecayWithWarmupSchedulerConfig
from lerobot.policies.rtc.configuration_rtc import RTCConfig
from lerobot.utils.constants import OBS_IMAGES

DEFAULT_IMAGE_SIZE = 224

@PreTrainedConfig.register_subclass("pi0")
@dataclass

@@ -51,7 +53,10 @@ class PI0Config(PreTrainedConfig):
# Real-Time Chunking (RTC) configuration
rtc_config: RTCConfig | None = None

image_resolution: tuple[int, int] = (224, 224) # see openpi `preprocessing_pytorch.py`
image_resolution: tuple[int, int] = (
DEFAULT_IMAGE_SIZE,
DEFAULT_IMAGE_SIZE,
) # see openpi `preprocessing_pytorch.py`

# Add empty images. Used to add empty cameras when no image features are present.
empty_cameras: int = 0

@@ -69,6 +74,7 @@ class PI0Config(PreTrainedConfig):
gradient_checkpointing: bool = False # Enable gradient checkpointing for memory optimization
compile_model: bool = False # Whether to use torch.compile for model optimization
compile_mode: str = "max-autotune" # Torch compile mode
optimizer_fused: bool = False # Use CUDA fused AdamW kernel
device: str | None = None # Device to use for the model (None = auto-detect)

# Optimizer settings: see openpi `AdamW``

@@ -136,6 +142,7 @@ class PI0Config(PreTrainedConfig):
eps=self.optimizer_eps,
weight_decay=self.optimizer_weight_decay,
grad_clip_norm=self.optimizer_grad_clip_norm,
fused=self.optimizer_fused,
)

def get_scheduler_preset(self):
@@ -41,7 +41,7 @@ else:
PaliGemmaForConditionalGeneration = None

from lerobot.configs.policies import PreTrainedConfig
from lerobot.policies.pi0.configuration_pi0 import PI0Config
from lerobot.policies.pi0.configuration_pi0 import DEFAULT_IMAGE_SIZE, PI0Config
from lerobot.policies.pretrained import PreTrainedPolicy, T
from lerobot.policies.rtc.modeling_rtc import RTCProcessor
from lerobot.utils.constants import (

@@ -337,6 +337,7 @@ class PaliGemmaWithExpertModel(
action_expert_config,
use_adarms=None,
precision: Literal["bfloat16", "float32"] = "bfloat16",
image_size: int = DEFAULT_IMAGE_SIZE,
):
if use_adarms is None:
use_adarms = [False, False]

@@ -356,6 +357,7 @@ class PaliGemmaWithExpertModel(
vlm_config_hf.text_config.vocab_size = 257152
vlm_config_hf.text_config.use_adarms = use_adarms[0]
vlm_config_hf.text_config.adarms_cond_dim = vlm_config.width if use_adarms[0] else None
vlm_config_hf.vision_config.image_size = image_size
vlm_config_hf.vision_config.intermediate_size = 4304
vlm_config_hf.vision_config.projection_dim = 2048
vlm_config_hf.vision_config.projector_hidden_act = "gelu_fast"

@@ -519,11 +521,17 @@ class PI0Pytorch(nn.Module): # see openpi `PI0Pytorch`
paligemma_config = get_gemma_config(config.paligemma_variant)
action_expert_config = get_gemma_config(config.action_expert_variant)

if config.image_resolution[0] != config.image_resolution[1]:
raise ValueError(
f"PaliGemma expects square image resolution, invalid resolution: {config.image_resolution}"
)

self.paligemma_with_expert = PaliGemmaWithExpertModel(
paligemma_config,
action_expert_config,
use_adarms=[False, False],
precision=config.dtype,
image_size=config.image_resolution[0],
)

self.action_in_proj = nn.Linear(config.max_action_dim, action_expert_config.width)

@@ -812,16 +820,13 @@ class PI0Pytorch(nn.Module): # see openpi `PI0Pytorch`
)

dt = -1.0 / num_steps
dt = torch.tensor(dt, dtype=torch.float32, device=device)

x_t = noise
time = torch.tensor(1.0, dtype=torch.float32, device=device)
while time >= -dt / 2:
expanded_time = time.expand(bsize)
for step in range(num_steps):
time = 1.0 + step * dt
time_tensor = torch.tensor(time, dtype=torch.float32, device=device).expand(bsize)

# Define a closure function to properly capture expanded_time
# This avoids the lambda expression (E731) and loop variable binding (B023) issues
def denoise_step_partial_call(input_x_t, current_timestep=expanded_time):
def denoise_step_partial_call(input_x_t, current_timestep=time_tensor):
return self.denoise_step(
state=state,
prefix_pad_masks=prefix_pad_masks,

@@ -846,15 +851,11 @@ class PI0Pytorch(nn.Module): # see openpi `PI0Pytorch`
else:
v_t = denoise_step_partial_call(x_t)

# Euler step
x_t += dt * v_t
x_t = x_t + dt * v_t

# Record x_t and v_t after Euler step
if self.rtc_processor is not None and self.rtc_processor.is_debug_enabled():
self.rtc_processor.track(time=time, x_t=x_t, v_t=v_t)

time += dt

return x_t

def denoise_step(
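For reference, the rewritten sampling loop is a fixed-step Euler integration of the flow-matching ODE from t = 1 (noise) down to t = 0: with dt = -1/num_steps it evaluates the velocity at t_k = 1 + k*dt and updates x <- x + dt*v, making exactly num_steps denoising calls; the old while-loop followed the same schedule but accumulated t by repeated addition. A standalone toy sketch (the velocity field is invented for illustration and is not the policy's):

import torch

def euler_flow(x1: torch.Tensor, velocity, num_steps: int = 10) -> torch.Tensor:
    """Integrate dx/dt = velocity(x, t) from t=1 down to t=0 with fixed steps dt = -1/num_steps."""
    dt = -1.0 / num_steps
    x_t = x1
    for step in range(num_steps):
        t = 1.0 + step * dt          # 1.0, 1 - 1/N, ..., 1/N
        v_t = velocity(x_t, t)
        x_t = x_t + dt * v_t         # out-of-place update, as in the diff
    return x_t

# Toy field whose exact flow shrinks x linearly to 0 as t goes to 0.
sample = euler_flow(torch.randn(4), lambda x, t: x / t, num_steps=10)
print(sample)  # ~zeros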
@@ -22,6 +22,8 @@ from lerobot.optim.optimizers import AdamWConfig
from lerobot.optim.schedulers import CosineDecayWithWarmupSchedulerConfig
from lerobot.policies.rtc.configuration_rtc import RTCConfig

DEFAULT_IMAGE_SIZE = 224

@PreTrainedConfig.register_subclass("pi05")
@dataclass

@@ -50,7 +52,10 @@ class PI05Config(PreTrainedConfig):
# Real-Time Chunking (RTC) configuration
rtc_config: RTCConfig | None = None

image_resolution: tuple[int, int] = (224, 224) # see openpi `preprocessing_pytorch.py`
image_resolution: tuple[int, int] = (
DEFAULT_IMAGE_SIZE,
DEFAULT_IMAGE_SIZE,
) # see openpi `preprocessing_pytorch.py`

# Add empty images. Used to add empty cameras when no image features are present.
empty_cameras: int = 0

@@ -69,6 +74,7 @@ class PI05Config(PreTrainedConfig):
gradient_checkpointing: bool = False # Enable gradient checkpointing for memory optimization
compile_model: bool = False # Whether to use torch.compile for model optimization
compile_mode: str = "max-autotune" # Torch compile mode
optimizer_fused: bool = False # Use CUDA fused AdamW kernel
device: str | None = None # Device to use for the model (None = auto-detect)

# Optimizer settings: see openpi `AdamW`

@@ -136,6 +142,7 @@ class PI05Config(PreTrainedConfig):
eps=self.optimizer_eps,
weight_decay=self.optimizer_weight_decay,
grad_clip_norm=self.optimizer_grad_clip_norm,
fused=self.optimizer_fused,
)

def get_scheduler_preset(self):
@@ -41,7 +41,7 @@ else:
PaliGemmaForConditionalGeneration = None

from lerobot.configs.policies import PreTrainedConfig
from lerobot.policies.pi05.configuration_pi05 import PI05Config
from lerobot.policies.pi05.configuration_pi05 import DEFAULT_IMAGE_SIZE, PI05Config
from lerobot.policies.pretrained import PreTrainedPolicy, T
from lerobot.policies.rtc.modeling_rtc import RTCProcessor
from lerobot.utils.constants import (

@@ -336,6 +336,7 @@ class PaliGemmaWithExpertModel(
action_expert_config,
use_adarms=None,
precision: Literal["bfloat16", "float32"] = "bfloat16",
image_size: int = DEFAULT_IMAGE_SIZE,
):
if use_adarms is None:
use_adarms = [False, False]

@@ -355,6 +356,7 @@ class PaliGemmaWithExpertModel(
vlm_config_hf.text_config.vocab_size = 257152
vlm_config_hf.text_config.use_adarms = use_adarms[0]
vlm_config_hf.text_config.adarms_cond_dim = vlm_config.width if use_adarms[0] else None
vlm_config_hf.vision_config.image_size = image_size
vlm_config_hf.vision_config.intermediate_size = 4304
vlm_config_hf.vision_config.projection_dim = 2048
vlm_config_hf.vision_config.projector_hidden_act = "gelu_fast"

@@ -518,11 +520,17 @@ class PI05Pytorch(nn.Module): # see openpi `PI0Pytorch`
paligemma_config = get_gemma_config(config.paligemma_variant)
action_expert_config = get_gemma_config(config.action_expert_variant)

if config.image_resolution[0] != config.image_resolution[1]:
raise ValueError(
f"PaliGemma expects square image resolution, invalid resolution: {config.image_resolution}"
)

self.paligemma_with_expert = PaliGemmaWithExpertModel(
paligemma_config,
action_expert_config,
use_adarms=[False, True],
precision=config.dtype,
image_size=config.image_resolution[0],
)

self.action_in_proj = nn.Linear(config.max_action_dim, action_expert_config.width)

@@ -538,6 +546,8 @@ class PI05Pytorch(nn.Module): # see openpi `PI0Pytorch`
if config.compile_model:
torch.set_float32_matmul_precision("high")
self.sample_actions = torch.compile(self.sample_actions, mode=config.compile_mode)
# Also compile the main forward pass used during training
self.forward = torch.compile(self.forward, mode=config.compile_mode)

msg = """An incorrect transformer version is used, please create an issue on https://github.com/huggingface/lerobot/issues"""

@@ -785,16 +795,13 @@ class PI05Pytorch(nn.Module): # see openpi `PI0Pytorch`
)

dt = -1.0 / num_steps
dt = torch.tensor(dt, dtype=torch.float32, device=device)

x_t = noise
time = torch.tensor(1.0, dtype=torch.float32, device=device)
while time >= -dt / 2:
expanded_time = time.expand(bsize)
for step in range(num_steps):
time = 1.0 + step * dt
time_tensor = torch.tensor(time, dtype=torch.float32, device=device).expand(bsize)

# Define a closure function to properly capture expanded_time
# This avoids the lambda expression (E731) and loop variable binding (B023) issues
def denoise_step_partial_call(input_x_t, current_timestep=expanded_time):
def denoise_step_partial_call(input_x_t, current_timestep=time_tensor):
return self.denoise_step(
prefix_pad_masks=prefix_pad_masks,
past_key_values=past_key_values,

@@ -818,15 +825,11 @@ class PI05Pytorch(nn.Module): # see openpi `PI0Pytorch`
else:
v_t = denoise_step_partial_call(x_t)

# Euler step
x_t += dt * v_t
x_t = x_t + dt * v_t

# Record x_t and v_t after Euler step
if self.rtc_processor is not None and self.rtc_processor.is_debug_enabled():
self.rtc_processor.track(time=time, x_t=x_t, v_t=v_t)

time += dt

return x_t

def denoise_step(
@@ -79,6 +79,7 @@ class SmolVLAConfig(PreTrainedConfig):
optimizer_eps: float = 1e-8
optimizer_weight_decay: float = 1e-10
optimizer_grad_clip_norm: float = 10
optimizer_fused: bool = False

scheduler_warmup_steps: int = 1_000
scheduler_decay_steps: int = 30_000

@@ -136,6 +137,7 @@ class SmolVLAConfig(PreTrainedConfig):
eps=self.optimizer_eps,
weight_decay=self.optimizer_weight_decay,
grad_clip_norm=self.optimizer_grad_clip_norm,
fused=self.optimizer_fused,
)

def get_scheduler_preset(self):

@@ -783,18 +783,15 @@ class VLAFlowMatching(nn.Module):
use_cache=self.config.use_cache,
fill_kv_cache=True,
)
dt = -1.0 / self.config.num_steps
dt = torch.tensor(dt, dtype=torch.float32, device=device)
num_steps = self.config.num_steps
dt = -1.0 / num_steps

x_t = noise
time = torch.tensor(1.0, dtype=torch.float32, device=device)
for step in range(num_steps):
time = 1.0 + step * dt
time_tensor = torch.tensor(time, dtype=torch.float32, device=device).expand(bsize)

while time >= -dt / 2:
expanded_time = time.expand(bsize)

# Define a closure function to properly capture expanded_time
# This avoids the lambda expression (E731) and loop variable binding (B023) issues
def denoise_step_partial_call(input_x_t, current_timestep=expanded_time):
def denoise_step_partial_call(input_x_t, current_timestep=time_tensor):
return self.denoise_step(
x_t=input_x_t,
prefix_pad_masks=prefix_pad_masks,

@@ -818,15 +815,11 @@ class VLAFlowMatching(nn.Module):
else:
v_t = denoise_step_partial_call(x_t)

# Euler step
x_t += dt * v_t
x_t = x_t + dt * v_t

# Record x_t and v_t after Euler step (other params are recorded in rtc_processor.denoise_step)
if self.rtc_processor is not None and self.rtc_processor.is_debug_enabled():
self.rtc_processor.track(time=time, x_t=x_t, v_t=v_t)

time += dt

return x_t

def denoise_step(
@@ -0,0 +1,6 @@
# register the processor steps
from lerobot.policies.xvla.processor_xvla import (
XVLAAddDomainIdProcessorStep,
XVLAImageNetNormalizeProcessorStep,
XVLAImageToFloatProcessorStep,
)
@@ -0,0 +1,588 @@
# ------------------------------------------------------------------------------
# Copyright 2025 2toINF and HuggingFace Inc. (https://github.com/2toINF)
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ------------------------------------------------------------------------------

from __future__ import annotations

from collections.abc import Iterable

import torch
import torch.nn as nn

# =============================================================================
# Registry
# =============================================================================
ACTION_REGISTRY: dict[str, type[BaseActionSpace]] = {}

def register_action(name: str):
"""Decorator for registering a new action space."""

def _wrap(cls):
key = name.lower()
if key in ACTION_REGISTRY:
raise KeyError(f"ActionSpace '{key}' already registered -> {ACTION_REGISTRY[key]}")
ACTION_REGISTRY[key] = cls
cls.name = key
return cls

return _wrap

def build_action_space(name: str, **kwargs) -> BaseActionSpace:
"""Instantiate a registered action space by name."""
key = name.lower()
if key not in ACTION_REGISTRY:
raise KeyError(f"Unknown action space '{name}'. Available: {list(ACTION_REGISTRY.keys())}")
return ACTION_REGISTRY[key](**kwargs)
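Typical use of this registry, including registering a project-specific space; the custom class below is illustrative only.

# Instantiate a built-in space by its registered name.
ee6d = build_action_space("ee6d")

# Register a new space; the decorator key is what build_action_space() looks up later.
@register_action("my_arm")
class MyArmActionSpace(BaseActionSpace):
    dim_action = 6
    gripper_idx = ()

    def compute_loss(self, pred, target):
        return {"joints_loss": nn.functional.mse_loss(pred, target)}

my_space = build_action_space("my_arm")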
# =============================================================================
# Base class
# =============================================================================
class BaseActionSpace(nn.Module):
"""
Abstract base class for all action-space definitions.

Each subclass defines:
- `dim_action`: dimension of the action vector.
- `gripper_idx`: indices of gripper channels.
- `compute_loss(pred, target)`: supervised loss for this space.
- `preprocess(proprio, action, mode)`: pre-step modifications.
- `postprocess(action)`: post-step corrections (e.g. apply sigmoid).
"""

name: str = "base"
dim_action: int = 0
gripper_idx: tuple[int, ...] = ()

def __init__(self):
super().__init__()

# ---------------------------------------------------------------------
# Core supervised loss
# ---------------------------------------------------------------------
def compute_loss(self, pred: torch.Tensor, target: torch.Tensor) -> dict[str, torch.Tensor]:
raise NotImplementedError

def forward(self, pred: torch.Tensor, target: torch.Tensor) -> dict[str, torch.Tensor]:
"""Alias for compute_loss."""
return self.compute_loss(pred, target)

# ---------------------------------------------------------------------
# Space-level hooks
# ---------------------------------------------------------------------
def preprocess(
self,
proprio: torch.Tensor,
action: torch.Tensor,
mode: str = "train",
) -> tuple[torch.Tensor, torch.Tensor]:
"""Default: return unchanged."""
return proprio, action

def postprocess(self, action: torch.Tensor) -> torch.Tensor:
"""Default: return unchanged."""
return action
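How the three hooks compose in practice: a hypothetical supervised step (the calling policy code is not part of this file, and the tensors here are random placeholders).

import torch

space = build_action_space("ee6d")
proprio = torch.zeros(2, 1, 20)
target = torch.rand(2, 1, 20)       # BCE targets must lie in [0, 1]
pred = torch.randn(2, 1, 20)

# Training: zero the gripper channels, then get the per-component losses.
proprio, target = space.preprocess(proprio, target, mode="train")
losses = space.compute_loss(pred, target)   # position_loss, rotate6D_loss, gripper_loss
total_loss = sum(losses.values())

# Inference: squash gripper logits to [0, 1] before sending commands to the robot.
action = space.postprocess(pred.clone())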
# =============================================================================
# Utilities
# =============================================================================
def _ensure_indices_valid(dim_action: int, idx: Iterable[int], name: str) -> None:
bad = [i for i in idx if i < 0 or i >= dim_action]
if bad:
raise IndexError(f"{name} contains out-of-range indices {bad} for action dim dim_action={dim_action}")

# =============================================================================
# Implementations
# =============================================================================
@register_action("ee6d")
class EE6DActionSpace(BaseActionSpace):
"""End-effector layout with xyz, 6D rotation, and gripper channels."""

dim_action = 20
gripper_idx = (9, 19)
GRIPPER_SCALE = 1.0
XYZ_SCALE = 500.0
ROT_SCALE = 10.0

POS_IDX_1 = (0, 1, 2)
POS_IDX_2 = (10, 11, 12)
ROT_IDX_1 = (3, 4, 5, 6, 7, 8)
ROT_IDX_2 = (13, 14, 15, 16, 17, 18)

def __init__(self):
super().__init__()
self.mse = nn.MSELoss()
self.bce = nn.BCEWithLogitsLoss()

def compute_loss(self, pred, target):
assert pred.shape == target.shape, "pred/target shapes must match"
batch_size, seq_len, action_dim = pred.shape
_ensure_indices_valid(action_dim, self.gripper_idx, "gripper_idx")

# Gripper BCE
g_losses = [self.bce(pred[:, :, gi], target[:, :, gi]) for gi in self.gripper_idx]
gripper_loss = sum(g_losses) / len(self.gripper_idx) * self.GRIPPER_SCALE

# XYZ position
pos_loss = (
self.mse(pred[:, :, self.POS_IDX_1], target[:, :, self.POS_IDX_1])
+ self.mse(pred[:, :, self.POS_IDX_2], target[:, :, self.POS_IDX_2])
) * self.XYZ_SCALE

# Rotation 6D
rot_loss = (
self.mse(pred[:, :, self.ROT_IDX_1], target[:, :, self.ROT_IDX_1])
+ self.mse(pred[:, :, self.ROT_IDX_2], target[:, :, self.ROT_IDX_2])
) * self.ROT_SCALE

return {
"position_loss": pos_loss,
"rotate6D_loss": rot_loss,
"gripper_loss": gripper_loss,
}

def preprocess(self, proprio, action, mode="train"):
"""Zero-out gripper channels in proprio/action."""
proprio_m = proprio.clone()
action_m = action.clone()
proprio_m[..., self.gripper_idx] = 0.0
action_m[..., self.gripper_idx] = 0.0
return proprio_m, action_m

def postprocess(self, action: torch.Tensor) -> torch.Tensor:
"""Apply sigmoid to gripper logits."""
if action.size(-1) > max(self.gripper_idx):
action[..., self.gripper_idx] = torch.sigmoid(action[..., self.gripper_idx])
return action

@register_action("joint")
class JointActionSpace(BaseActionSpace):
"""Joint-space layout with joints + gripper only."""

dim_action = 14
gripper_idx = (6, 13)
GRIPPER_SCALE = 0.1
JOINTS_SCALE = 1.0

def __init__(self):
super().__init__()
self.mse = nn.MSELoss()
self.bce = nn.BCEWithLogitsLoss()

def compute_loss(self, pred, target):
assert pred.shape == target.shape
batch_size, seq_len, action_dim = pred.shape
_ensure_indices_valid(action_dim, self.gripper_idx, "gripper_idx")

g_losses = [self.bce(pred[:, :, gi], target[:, :, gi]) for gi in self.gripper_idx]
gripper_loss = sum(g_losses) / len(self.gripper_idx) * self.GRIPPER_SCALE

joints_idx = tuple(i for i in range(action_dim) if i not in set(self.gripper_idx))
joints_loss = self.mse(pred[:, :, joints_idx], target[:, :, joints_idx]) * self.JOINTS_SCALE

return {
"joints_loss": joints_loss,
"gripper_loss": gripper_loss,
}

def preprocess(self, proprio, action, mode="train"):
"""Zero-out gripper channels in proprio/action."""
proprio_m = proprio.clone()
action_m = action.clone()
proprio_m[..., self.gripper_idx] = 0.0
action_m[..., self.gripper_idx] = 0.0
return proprio_m, action_m

def postprocess(self, action: torch.Tensor) -> torch.Tensor:
"""Apply sigmoid to gripper logits."""
if action.size(-1) > max(self.gripper_idx):
action[..., self.gripper_idx] = torch.sigmoid(action[..., self.gripper_idx])
return action

@register_action("agibot_ee6d")
class AGIBOTEE6DActionSpace(BaseActionSpace):
"""AGI-bot variant of EE6DActionSpace using MSE for all components."""

dim_action = 20
gripper_idx = (9, 19)
GRIPPER_SCALE = 10.0
XYZ_SCALE = 500.0
ROT_SCALE = 10.0
POS_IDX_1 = (0, 1, 2)
POS_IDX_2 = (10, 11, 12)
ROT_IDX_1 = (3, 4, 5, 6, 7, 8)
ROT_IDX_2 = (13, 14, 15, 16, 17, 18)

def __init__(self):
super().__init__()
self.mse = nn.MSELoss()

def compute_loss(self, pred, target):
assert pred.shape == target.shape
batch_size, seq_len, action_dim = pred.shape
_ensure_indices_valid(action_dim, self.gripper_idx, "gripper_idx")

gripper_loss = (
self.mse(pred[:, :, self.gripper_idx], target[:, :, self.gripper_idx]) * self.GRIPPER_SCALE
)
pos_loss = (
self.mse(pred[:, :, self.POS_IDX_1], target[:, :, self.POS_IDX_1])
+ self.mse(pred[:, :, self.POS_IDX_2], target[:, :, self.POS_IDX_2])
) * self.XYZ_SCALE
rot_loss = (
self.mse(pred[:, :, self.ROT_IDX_1], target[:, :, self.ROT_IDX_1])
+ self.mse(pred[:, :, self.ROT_IDX_2], target[:, :, self.ROT_IDX_2])
) * self.ROT_SCALE

return {
"position_loss": pos_loss,
"rotate6D_loss": rot_loss,
"gripper_loss": gripper_loss,
}

def preprocess(self, proprio, action, mode="train"):
"""No preprocessing applied in AGIBOT variant."""
return proprio, action

def postprocess(self, action: torch.Tensor) -> torch.Tensor:
"""AGIBOT does not postprocess."""
return action
@register_action("franka_joint7")
|
||||
class FrankaJoint7ActionSpace(BaseActionSpace):
|
||||
"""
|
||||
Franka Panda joint-space: 7 joints, with gripper.
|
||||
|
||||
- Real robot action dim: 7
|
||||
- Model-facing dim: 20 (padded with zeros)
|
||||
compatible with pretrained VLA models expecting 20D.
|
||||
"""
|
||||
|
||||
dim_action = 20 # model dimension
|
||||
REAL_DIM = 7 # actual Franka joints
|
||||
|
||||
JOINTS_SCALE = 1.0
|
||||
|
||||
def __init__(self):
|
||||
super().__init__()
|
||||
self.mse = nn.MSELoss()
|
||||
|
||||
def _pad_to_model_dim(self, x: torch.Tensor) -> torch.Tensor:
|
||||
"""Pad 7 → 20 dims (zeros for the dummy channels)."""
|
||||
if x is None:
|
||||
return None
|
||||
if x.size(-1) == self.dim_action:
|
||||
return x
|
||||
if x.size(-1) != self.REAL_DIM:
|
||||
raise ValueError(
|
||||
f"Expected last dim to be {self.REAL_DIM} or {self.dim_action}, got {x.size(-1)}"
|
||||
)
|
||||
|
||||
pad_shape = list(x.shape[:-1]) + [self.dim_action - self.REAL_DIM] # 13 zeros
|
||||
pad = x.new_zeros(pad_shape)
|
||||
return torch.cat([x, pad], dim=-1)
|
||||
|
||||
def _trim_to_real_dim(self, x: torch.Tensor) -> torch.Tensor:
|
||||
"""Trim model output 20 → 7 dims."""
|
||||
return x[..., : self.REAL_DIM]
|
||||
|
||||
def compute_loss(self, pred, target):
|
||||
"""
|
||||
pred : [B, T, 20]
|
||||
target : [B, T, 7] or [B, T, 20]
|
||||
|
||||
Only compute MSE on the first 7 dims.
|
||||
"""
|
||||
pred = self._pad_to_model_dim(pred)
|
||||
target = self._pad_to_model_dim(target)
|
||||
|
||||
assert pred.shape == target.shape
|
||||
|
||||
joints_loss = (
|
||||
self.mse(
|
||||
pred[:, :, : self.REAL_DIM], # use only the first 7 joints
|
||||
target[:, :, : self.REAL_DIM],
|
||||
)
|
||||
* self.JOINTS_SCALE
|
||||
)
|
||||
|
||||
return {"joints_loss": joints_loss}
|
||||
|
||||
def preprocess(self, proprio, action, mode="train"):
|
||||
"""
|
||||
During training:
|
||||
- Pad [7] → [20]
|
||||
"""
|
||||
return proprio, self._pad_to_model_dim(action)
|
||||
|
||||
def postprocess(self, action: torch.Tensor) -> torch.Tensor:
|
||||
"""
|
||||
After model prediction:
|
||||
- Trim [20] → [7] for real robot control.
|
||||
"""
|
||||
return self._trim_to_real_dim(action)
|
||||
|
||||
|
||||
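The padding contract in numbers (shapes only; the tensors are random):

import torch

space = FrankaJoint7ActionSpace()
proprio = torch.zeros(2, 10, 7)
real_action = torch.randn(2, 10, 7)                 # [B, T, 7] from the dataset
_, padded = space.preprocess(proprio, real_action)  # action padded to [2, 10, 20], last 13 channels zero
model_out = torch.randn(2, 10, 20)                  # what the 20-D model predicts
robot_cmd = space.postprocess(model_out)            # trimmed back to [2, 10, 7] for the real robot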
@register_action("auto")
|
||||
class AutoActionSpace(BaseActionSpace):
|
||||
"""
|
||||
Auto-detecting action space that adapts to any action dimension.
|
||||
|
||||
- Auto-detects the real action dimension from the policy feature
|
||||
- Model outputs max_dim for compatibility with pretrained models
|
||||
- Loss is computed only on the first real_dim dimensions
|
||||
- Postprocess trims output back to real_dim
|
||||
|
||||
Args:
|
||||
real_dim: The actual action dimension from the dataset/policy feature
|
||||
max_dim: The model's output dimension for pretrained VLA compatibility
|
||||
"""
|
||||
|
||||
JOINTS_SCALE = 1.0
|
||||
|
||||
def __init__(self, real_dim: int, max_dim: int):
|
||||
super().__init__()
|
||||
self.real_dim = real_dim
|
||||
self.dim_action = max_dim # Model-facing dimension
|
||||
self.mse = nn.MSELoss()
|
||||
|
||||
def _pad_to_model_dim(self, x: torch.Tensor) -> torch.Tensor:
|
||||
"""Pad real_dim → max_dim (zeros for the dummy channels)."""
|
||||
if x is None:
|
||||
return None
|
||||
if x.size(-1) == self.dim_action:
|
||||
return x
|
||||
if x.size(-1) != self.real_dim:
|
||||
# If dimension doesn't match either, pad/trim to real_dim first
|
||||
if x.size(-1) < self.real_dim:
|
||||
pad_shape = list(x.shape[:-1]) + [self.real_dim - x.size(-1)]
|
||||
pad = x.new_zeros(pad_shape)
|
||||
x = torch.cat([x, pad], dim=-1)
|
||||
else:
|
||||
x = x[..., : self.real_dim]
|
||||
|
||||
pad_shape = list(x.shape[:-1]) + [self.dim_action - self.real_dim]
|
||||
pad = x.new_zeros(pad_shape)
|
||||
return torch.cat([x, pad], dim=-1)
|
||||
|
||||
def _trim_to_real_dim(self, x: torch.Tensor) -> torch.Tensor:
|
||||
"""Trim model output max_dim → real_dim."""
|
||||
return x[..., : self.real_dim]
|
||||
|
||||
def compute_loss(self, pred: torch.Tensor, target: torch.Tensor) -> dict[str, torch.Tensor]:
|
||||
"""
|
||||
Compute loss only on the first real_dim dimensions.
|
||||
|
||||
pred: [B, T, max_dim] from the model
|
||||
target: [B, T, real_dim] or [B, T, max_dim]
|
||||
|
||||
Loss = MSE(pred[:,:,:real_dim], target[:,:,:real_dim])
|
||||
"""
|
||||
pred = self._pad_to_model_dim(pred)
|
||||
target = self._pad_to_model_dim(target)
|
||||
assert pred.shape == target.shape, f"Shape mismatch: pred {pred.shape} vs target {target.shape}"
|
||||
|
||||
# only compute loss on the real dimensions
|
||||
joints_loss = (
|
||||
self.mse(
|
||||
pred[:, :, : self.real_dim],
|
||||
target[:, :, : self.real_dim],
|
||||
)
|
||||
* self.JOINTS_SCALE
|
||||
)
|
||||
|
||||
return {"joints_loss": joints_loss}
|
||||
|
||||
def preprocess(self, proprio: torch.Tensor, action: torch.Tensor, mode: str = "train"):
|
||||
"""
|
||||
Pad action from real_dim to max_dim for the model.
|
||||
"""
|
||||
return proprio, self._pad_to_model_dim(action)
|
||||
|
||||
def postprocess(self, action: torch.Tensor) -> torch.Tensor:
|
||||
"""
|
||||
Trim model output from max_dim to real_dim for real robot control.
|
||||
"""
|
||||
return self._trim_to_real_dim(action)
|
||||
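# --- Illustrative usage (not part of the original API) -----------------------
# Minimal sketch of the AutoActionSpace pad/trim round trip, assuming a
# hypothetical 7-dim robot padded up to a 20-dim model output; the batch,
# chunk and proprio sizes below are made up for illustration only.
def _example_auto_action_space_roundtrip():
    space = AutoActionSpace(real_dim=7, max_dim=20)
    proprio = torch.randn(2, 32)                    # passes through unchanged
    action = torch.randn(2, 16, 7)                  # [B, T, real_dim]
    proprio, padded = space.preprocess(proprio, action)
    assert padded.shape == (2, 16, 20)              # zero-padded to the model dim
    pred = torch.randn(2, 16, 20)                   # fake model output
    losses = space.compute_loss(pred, padded)       # MSE over the first 7 dims only
    assert space.postprocess(pred).shape == (2, 16, 7)  # trimmed for the real robot
    return losses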
|
||||
|
||||
@register_action("so101_bimanual")
|
||||
class BimanualSO101ActionSpace(BaseActionSpace):
|
||||
"""
|
||||
Bimanual SO101 robot: 2 arms with 5 joints each + gripper.
|
||||
|
||||
Layout (real robot):
|
||||
[left_arm (5 joints + gripper), right_arm (5 joints + gripper)]
|
||||
- Left arm: shoulder_pan, shoulder_lift, elbow_flex, wrist_flex, wrist_roll, gripper
|
||||
- Right arm: shoulder_pan, shoulder_lift, elbow_flex, wrist_flex, wrist_roll, gripper
|
||||
|
||||
Real action dim: 12
|
||||
Model-facing dim: 20 (extra 8 dummy dims at the end)
|
||||
"""
|
||||
|
||||
# Model output / training dimension (to match pretrained policy)
|
||||
dim_action = 20
|
||||
|
||||
# Real robot action dimension
|
||||
REAL_DIM = 12
|
||||
|
||||
# Indices of real vs dummy channels
|
||||
REAL_IDXS = tuple(range(REAL_DIM)) # 0..11
|
||||
DUMMY_IDXS = tuple(range(REAL_DIM, dim_action)) # 12..19
|
||||
|
||||
# Grippers live in the real part
|
||||
gripper_idx = (5, 11) # left_gripper at idx 5, right_gripper at idx 11
|
||||
GRIPPER_SCALE = 1.0
|
||||
JOINTS_SCALE = 1.0
|
||||
|
||||
# Indices for left and right arm joints (excluding grippers)
|
||||
LEFT_ARM_JOINTS = (0, 1, 2, 3, 4)
|
||||
RIGHT_ARM_JOINTS = (6, 7, 8, 9, 10)
|
||||
|
||||
def __init__(self):
|
||||
super().__init__()
|
||||
self.mse = nn.MSELoss()
|
||||
self.bce = nn.BCEWithLogitsLoss()
|
||||
|
||||
# ---------- helpers ----------
|
||||
|
||||
def _pad_to_model_dim(self, x: torch.Tensor) -> torch.Tensor:
|
||||
"""If last dim is REAL_DIM (12), pad zeros to reach dim_action (20)."""
|
||||
if x is None:
|
||||
return None
|
||||
if x.size(-1) == self.dim_action:
|
||||
return x
|
||||
if x.size(-1) != self.REAL_DIM:
|
||||
raise ValueError(
|
||||
f"Expected last dim to be {self.REAL_DIM} or {self.dim_action}, got {x.size(-1)}"
|
||||
)
|
||||
pad_shape = list(x.shape[:-1]) + [self.dim_action - self.REAL_DIM]
|
||||
pad = x.new_zeros(pad_shape)
|
||||
return torch.cat([x, pad], dim=-1)
|
||||
|
||||
def _trim_to_real_dim(self, x: torch.Tensor) -> torch.Tensor:
|
||||
"""Keep only the first REAL_DIM (12) dims for the real robot."""
|
||||
return x[..., : self.REAL_DIM]
|
||||
|
||||
# ---------- loss ----------
|
||||
|
||||
def compute_loss(self, pred, target):
|
||||
"""
|
||||
pred: [B, T, 20] from the model
|
||||
target: [B, T, 12] or [B, T, 20]
|
||||
We pad target → 20 and compute loss only on the real dims.
|
||||
"""
|
||||
# Ensure both are [B, T, 20]
|
||||
pred = self._pad_to_model_dim(pred)
|
||||
target = self._pad_to_model_dim(target)
|
||||
assert pred.shape == target.shape
|
||||
|
||||
# ---- MSE for all real dims (0–11) ----
|
||||
real_dims = 12
|
||||
|
||||
joints_loss = (
|
||||
self.mse(
|
||||
pred[:, :, :real_dims],
|
||||
target[:, :, :real_dims],
|
||||
)
|
||||
* self.JOINTS_SCALE
|
||||
)
|
||||
|
||||
left_arm_loss = self.mse(pred[:, :, :6], target[:, :, :6])
|
||||
right_arm_loss = self.mse(pred[:, :, 6:12], target[:, :, 6:12])
|
||||
|
||||
gripper_loss = (
|
||||
self.mse(
|
||||
pred[:, :, [5, 11]],
|
||||
target[:, :, [5, 11]],
|
||||
)
|
||||
* self.GRIPPER_SCALE
|
||||
)
|
||||
|
||||
return {
|
||||
"joints_loss": joints_loss,
|
||||
"gripper_loss": gripper_loss,
|
||||
"left_arm_loss": left_arm_loss,
|
||||
"right_arm_loss": right_arm_loss,
|
||||
}
|
||||
|
||||
# ---------- preprocess / postprocess ----------
|
||||
|
||||
def preprocess(self, proprio, action, mode="train"):
|
||||
"""
|
||||
- If proprio/action are 12-dim, pad them to 20 for the model.
|
||||
- Zero-out gripper channels in proprio/action to focus learning on joints.
|
||||
"""
|
||||
proprio_m = self._pad_to_model_dim(proprio.clone())
|
||||
action_m = self._pad_to_model_dim(action.clone()) if action is not None else None
|
||||
|
||||
proprio_m[..., self.gripper_idx] = 0.0
|
||||
if action_m is not None:
|
||||
action_m[..., self.gripper_idx] = 0.0
|
||||
|
||||
return proprio_m, action_m
|
||||
|
||||
def postprocess(self, action: torch.Tensor) -> torch.Tensor:
|
||||
"""
|
||||
- Model outputs [*, 20]
|
||||
- Apply sigmoid to gripper logits
|
||||
- Return only the first 12 dims for the real robot:
|
||||
["left_shoulder_pan.pos",
|
||||
"left_shoulder_lift.pos",
|
||||
"left_elbow_flex.pos",
|
||||
"left_wrist_flex.pos",
|
||||
"left_wrist_roll.pos",
|
||||
"left_gripper.pos",
|
||||
"right_shoulder_pan.pos",
|
||||
"right_shoulder_lift.pos",
|
||||
"right_elbow_flex.pos",
|
||||
"right_wrist_flex.pos",
|
||||
"right_wrist_roll.pos",
|
||||
"right_gripper.pos"]
|
||||
"""
|
||||
# Ensure we at least have the real dims + grippers
|
||||
if action.size(-1) < self.REAL_DIM:
|
||||
raise ValueError(f"Expected at least {self.REAL_DIM} dims in action, got {action.size(-1)}")
|
||||
|
||||
# Apply sigmoid on gripper channels in model space (indices 5 and 11)
|
||||
if action.size(-1) > max(self.gripper_idx):
|
||||
action[..., self.gripper_idx] = torch.sigmoid(action[..., self.gripper_idx])
|
||||
|
||||
# Return only the real 12-dim control vector for the env
|
||||
return self._trim_to_real_dim(action)
|
||||
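# --- Illustrative usage (not part of the original API) -----------------------
# Minimal sketch of the bimanual SO101 flow with made-up batch/chunk sizes:
# 12-dim robot vectors are padded to the 20-dim model space with the gripper
# channels zeroed, and model outputs get a sigmoid on the two gripper logits
# before being trimmed back to 12 dims.
def _example_bimanual_so101_roundtrip():
    space = BimanualSO101ActionSpace()
    proprio_m, action_m = space.preprocess(torch.randn(2, 12), torch.randn(2, 16, 12))
    assert proprio_m.shape == (2, 20) and action_m.shape == (2, 16, 20)
    pred = torch.randn(2, 16, 20)                   # fake model output
    robot_action = space.postprocess(pred)          # sigmoid on grippers, then trim
    assert robot_action.shape == (2, 16, 12)
    return space.compute_loss(pred, action_m)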
|
||||
|
||||
# =============================================================================
|
||||
# Exports
|
||||
# =============================================================================
|
||||
__all__ = [
|
||||
"BaseActionSpace",
|
||||
"build_action_space",
|
||||
"register_action",
|
||||
"EE6DActionSpace",
|
||||
"JointActionSpace",
|
||||
"AGIBOTEE6DActionSpace",
|
||||
"FrankaJoint7ActionSpace",
|
||||
"AutoActionSpace",
|
||||
"BimanualSO101ActionSpace",
|
||||
"ACTION_REGISTRY",
|
||||
]
|
||||
@@ -0,0 +1,353 @@
|
||||
# Copyright 2024 Microsoft and the HuggingFace Inc. team. All rights reserved.
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
import warnings
|
||||
|
||||
from transformers.configuration_utils import PretrainedConfig
|
||||
from transformers.utils import logging
|
||||
|
||||
""" Florence-2 configuration"""
|
||||
|
||||
logger = logging.get_logger(__name__)
|
||||
|
||||
|
||||
class Florence2VisionConfig(PretrainedConfig):
|
||||
r"""
|
||||
This is the configuration class to store the configuration of a [`Florence2VisionModel`]. It is used to instantiate a Florence2VisionModel
|
||||
according to the specified arguments, defining the model architecture. Instantiating a configuration with the
|
||||
defaults will yield a similar configuration to that of the Florence2VisionModel architecture.
|
||||
|
||||
Configuration objects inherit from [`PretrainedConfig`] and can be used to control the model outputs. Read the
|
||||
documentation from [`PretrainedConfig`] for more information.
|
||||
|
||||
Args:
|
||||
drop_path_rate (`float`, *optional*, defaults to 0.1):
|
||||
The dropout rate of the drop path layer.
|
||||
patch_size (`List[int]`, *optional*, defaults to [7, 3, 3, 3]):
|
||||
The patch size of the image.
|
||||
patch_stride (`List[int]`, *optional*, defaults to [4, 2, 2, 2]):
|
||||
The patch stride of the image.
|
||||
patch_padding (`List[int]`, *optional*, defaults to [3, 1, 1, 1]):
|
||||
The patch padding of the image.
|
||||
patch_prenorm (`List[bool]`, *optional*, defaults to [False, True, True, True]):
|
||||
Whether to apply layer normalization before the patch embedding layer.
|
||||
enable_checkpoint (`bool`, *optional*, defaults to False):
|
||||
Whether to enable checkpointing.
|
||||
dim_embed (`List[int]`, *optional*, defaults to [256, 512, 1024, 2048]):
|
||||
The dimension of the embedding layer.
|
||||
num_heads (`List[int]`, *optional*, defaults to [8, 16, 32, 64]):
|
||||
The number of attention heads.
|
||||
num_groups (`List[int]`, *optional*, defaults to [8, 16, 32, 64]):
|
||||
The number of groups.
|
||||
depths (`List[int]`, *optional*, defaults to [1, 1, 9, 1]):
|
||||
The depth of the model.
|
||||
window_size (`int`, *optional*, defaults to 12):
|
||||
The window size of the model.
|
||||
projection_dim (`int`, *optional*, defaults to 1024):
|
||||
The dimension of the projection layer.
|
||||
visual_temporal_embedding (`dict`, *optional*):
|
||||
The configuration of the visual temporal embedding.
|
||||
image_pos_embed (`dict`, *optional*):
|
||||
The configuration of the image position embedding.
|
||||
image_feature_source (`List[str]`, *optional*, defaults to ["spatial_avg_pool", "temporal_avg_pool"]):
|
||||
The source of the image feature.
|
||||
Example:
|
||||
|
||||
```python
|
||||
>>> from transformers import Florence2VisionConfig, Florence2VisionModel
|
||||
|
||||
>>> # Initializing a Florence2 Vision style configuration
|
||||
>>> configuration = Florence2VisionConfig()
|
||||
|
||||
>>> # Initializing a model (with random weights)
|
||||
>>> model = Florence2VisionModel(configuration)
|
||||
|
||||
>>> # Accessing the model configuration
|
||||
>>> configuration = model.config
|
||||
```"""
|
||||
|
||||
model_type = "davit"
|
||||
keys_to_ignore_at_inference = ["past_key_values"]
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
drop_path_rate=0.1,
|
||||
patch_size=None,
|
||||
patch_stride=None,
|
||||
patch_padding=None,
|
||||
patch_prenorm=None,
|
||||
enable_checkpoint=False,
|
||||
dim_embed=None,
|
||||
num_heads=None,
|
||||
num_groups=None,
|
||||
depths=None,
|
||||
window_size=12,
|
||||
projection_dim=1024,
|
||||
visual_temporal_embedding=None,
|
||||
image_pos_embed=None,
|
||||
image_feature_source=None,
|
||||
**kwargs,
|
||||
):
|
||||
self.drop_path_rate = drop_path_rate
|
||||
self.patch_size = patch_size if patch_size is not None else [7, 3, 3, 3]
|
||||
self.patch_stride = patch_stride if patch_stride is not None else [4, 2, 2, 2]
|
||||
self.patch_padding = patch_padding if patch_padding is not None else [3, 1, 1, 1]
|
||||
self.patch_prenorm = patch_prenorm if patch_prenorm is not None else [False, True, True, True]
|
||||
self.enable_checkpoint = enable_checkpoint
|
||||
self.dim_embed = dim_embed if dim_embed is not None else [256, 512, 1024, 2048]
|
||||
self.num_heads = num_heads if num_heads is not None else [8, 16, 32, 64]
|
||||
self.num_groups = num_groups if num_groups is not None else [8, 16, 32, 64]
|
||||
self.depths = depths if depths is not None else [1, 1, 9, 1]
|
||||
self.window_size = window_size
|
||||
self.projection_dim = projection_dim
|
||||
|
||||
if visual_temporal_embedding is None:
|
||||
visual_temporal_embedding = {
|
||||
"type": "COSINE",
|
||||
"max_temporal_embeddings": 100,
|
||||
}
|
||||
self.visual_temporal_embedding = visual_temporal_embedding
|
||||
|
||||
if image_pos_embed is None:
|
||||
image_pos_embed = {
|
||||
"type": "learned_abs_2d",
|
||||
"max_pos_embeddings": 1000,
|
||||
}
|
||||
self.image_pos_embed = image_pos_embed
|
||||
|
||||
self.image_feature_source = (
|
||||
image_feature_source
|
||||
if image_feature_source is not None
|
||||
else ["spatial_avg_pool", "temporal_avg_pool"]
|
||||
)
|
||||
|
||||
super().__init__(**kwargs)
|
||||
|
||||
|
||||
class Florence2LanguageConfig(PretrainedConfig):
|
||||
r"""
|
||||
This is the configuration class to store the configuration of a [`Florence2LanguagePreTrainedModel`]. It is used to instantiate a BART
|
||||
model according to the specified arguments, defining the model architecture. Instantiating a configuration with the
|
||||
defaults will yield a similar configuration to that of the BART
|
||||
[facebook/bart-large](https://huggingface.co/facebook/bart-large) architecture.
|
||||
|
||||
Configuration objects inherit from [`PretrainedConfig`] and can be used to control the model outputs. Read the
|
||||
documentation from [`PretrainedConfig`] for more information.
|
||||
|
||||
|
||||
Args:
|
||||
vocab_size (`int`, *optional*, defaults to 51289):
|
||||
Vocabulary size of the Florence2Language model. Defines the number of different tokens that can be represented by the
|
||||
`inputs_ids` passed when calling [`Florence2LanguageModel`].
|
||||
d_model (`int`, *optional*, defaults to 1024):
|
||||
Dimensionality of the layers and the pooler layer.
|
||||
encoder_layers (`int`, *optional*, defaults to 12):
|
||||
Number of encoder layers.
|
||||
decoder_layers (`int`, *optional*, defaults to 12):
|
||||
Number of decoder layers.
|
||||
encoder_attention_heads (`int`, *optional*, defaults to 16):
|
||||
Number of attention heads for each attention layer in the Transformer encoder.
|
||||
decoder_attention_heads (`int`, *optional*, defaults to 16):
|
||||
Number of attention heads for each attention layer in the Transformer decoder.
|
||||
decoder_ffn_dim (`int`, *optional*, defaults to 4096):
|
||||
Dimensionality of the "intermediate" (often named feed-forward) layer in decoder.
|
||||
encoder_ffn_dim (`int`, *optional*, defaults to 4096):
|
||||
Dimensionality of the "intermediate" (often named feed-forward) layer in encoder.
|
||||
activation_function (`str` or `function`, *optional*, defaults to `"gelu"`):
|
||||
The non-linear activation function (function or string) in the encoder and pooler. If string, `"gelu"`,
|
||||
`"relu"`, `"silu"` and `"gelu_new"` are supported.
|
||||
dropout (`float`, *optional*, defaults to 0.1):
|
||||
The dropout probability for all fully connected layers in the embeddings, encoder, and pooler.
|
||||
attention_dropout (`float`, *optional*, defaults to 0.0):
|
||||
The dropout ratio for the attention probabilities.
|
||||
activation_dropout (`float`, *optional*, defaults to 0.0):
|
||||
The dropout ratio for activations inside the fully connected layer.
|
||||
classifier_dropout (`float`, *optional*, defaults to 0.0):
|
||||
The dropout ratio for classifier.
|
||||
max_position_embeddings (`int`, *optional*, defaults to 1024):
|
||||
The maximum sequence length that this model might ever be used with. Typically set this to something large
|
||||
just in case (e.g., 512 or 1024 or 2048).
|
||||
init_std (`float`, *optional*, defaults to 0.02):
|
||||
The standard deviation of the truncated_normal_initializer for initializing all weight matrices.
|
||||
encoder_layerdrop (`float`, *optional*, defaults to 0.0):
|
||||
The LayerDrop probability for the encoder. See the [LayerDrop paper](https://arxiv.org/abs/1909.11556)
|
||||
for more details.
|
||||
decoder_layerdrop (`float`, *optional*, defaults to 0.0):
|
||||
The LayerDrop probability for the decoder. See the [LayerDrop paper](https://arxiv.org/abs/1909.11556)
|
||||
for more details.
|
||||
scale_embedding (`bool`, *optional*, defaults to `False`):
|
||||
Scale embeddings by dividing by sqrt(d_model).
|
||||
use_cache (`bool`, *optional*, defaults to `True`):
|
||||
Whether or not the model should return the last key/values attentions (not used by all models).
|
||||
num_labels (`int`, *optional*, defaults to 3):
|
||||
The number of labels to use in [`Florence2LanguageForSequenceClassification`].
|
||||
forced_eos_token_id (`int`, *optional*, defaults to 2):
|
||||
The id of the token to force as the last generated token when `max_length` is reached. Usually set to
|
||||
`eos_token_id`.
|
||||
|
||||
Example:
|
||||
|
||||
```python
|
||||
>>> from transformers import Florence2LanguageConfig, Florence2LanguageModel
|
||||
|
||||
>>> # Initializing a Florence2 Language style configuration
|
||||
>>> configuration = Florence2LanguageConfig()
|
||||
|
||||
>>> # Initializing a model (with random weights)
|
||||
>>> model = Florence2LanguageModel(configuration)
|
||||
|
||||
>>> # Accessing the model configuration
|
||||
>>> configuration = model.config
|
||||
```"""
|
||||
|
||||
model_type = "florence2_language"
|
||||
keys_to_ignore_at_inference = ["past_key_values"]
|
||||
attribute_map = {"num_attention_heads": "encoder_attention_heads", "hidden_size": "d_model"}
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
vocab_size=51289,
|
||||
max_position_embeddings=1024,
|
||||
encoder_layers=12,
|
||||
encoder_ffn_dim=4096,
|
||||
encoder_attention_heads=16,
|
||||
decoder_layers=12,
|
||||
decoder_ffn_dim=4096,
|
||||
decoder_attention_heads=16,
|
||||
encoder_layerdrop=0.0,
|
||||
decoder_layerdrop=0.0,
|
||||
activation_function="gelu",
|
||||
d_model=1024,
|
||||
dropout=0.1,
|
||||
attention_dropout=0.0,
|
||||
activation_dropout=0.0,
|
||||
init_std=0.02,
|
||||
classifier_dropout=0.0,
|
||||
scale_embedding=False,
|
||||
use_cache=True,
|
||||
num_labels=3,
|
||||
pad_token_id=1,
|
||||
bos_token_id=0,
|
||||
eos_token_id=2,
|
||||
is_encoder_decoder=True,
|
||||
decoder_start_token_id=2,
|
||||
forced_eos_token_id=2,
|
||||
**kwargs,
|
||||
):
|
||||
self.vocab_size = vocab_size
|
||||
self.max_position_embeddings = max_position_embeddings
|
||||
self.d_model = d_model
|
||||
self.encoder_ffn_dim = encoder_ffn_dim
|
||||
self.encoder_layers = encoder_layers
|
||||
self.encoder_attention_heads = encoder_attention_heads
|
||||
self.decoder_ffn_dim = decoder_ffn_dim
|
||||
self.decoder_layers = decoder_layers
|
||||
self.decoder_attention_heads = decoder_attention_heads
|
||||
self.dropout = dropout
|
||||
self.attention_dropout = attention_dropout
|
||||
self.activation_dropout = activation_dropout
|
||||
self.activation_function = activation_function
|
||||
self.init_std = init_std
|
||||
self.encoder_layerdrop = encoder_layerdrop
|
||||
self.decoder_layerdrop = decoder_layerdrop
|
||||
self.classifier_dropout = classifier_dropout
|
||||
self.use_cache = use_cache
|
||||
self.num_hidden_layers = encoder_layers
|
||||
self.scale_embedding = scale_embedding # scale factor will be sqrt(d_model) if True
|
||||
|
||||
super().__init__(
|
||||
num_labels=num_labels,
|
||||
pad_token_id=pad_token_id,
|
||||
bos_token_id=bos_token_id,
|
||||
eos_token_id=eos_token_id,
|
||||
is_encoder_decoder=is_encoder_decoder,
|
||||
decoder_start_token_id=decoder_start_token_id,
|
||||
forced_eos_token_id=forced_eos_token_id,
|
||||
**kwargs,
|
||||
)
|
||||
|
||||
# ensure backward compatibility for BART CNN models
|
||||
if self.forced_bos_token_id is None and kwargs.get("force_bos_token_to_be_generated", False):
|
||||
self.forced_bos_token_id = self.bos_token_id
|
||||
warnings.warn(
|
||||
f"Please make sure the config includes `forced_bos_token_id={self.bos_token_id}` in future versions. "
|
||||
"The config can simply be saved and uploaded again to be fixed.",
|
||||
stacklevel=2,
|
||||
)
|
||||
|
||||
|
||||
class Florence2Config(PretrainedConfig):
|
||||
r"""
|
||||
This is the configuration class to store the configuration of a [`Florence2ForConditionalGeneration`]. It is used to instantiate an
|
||||
Florence-2 model according to the specified arguments, defining the model architecture.
|
||||
|
||||
Configuration objects inherit from [`PretrainedConfig`] and can be used to control the model outputs. Read the
|
||||
documentation from [`PretrainedConfig`] for more information.
|
||||
|
||||
Args:
|
||||
vision_config (`Florence2VisionConfig`, *optional*):
|
||||
Custom vision config or dict
|
||||
text_config (`Union[AutoConfig, dict]`, *optional*):
|
||||
The config object of the text backbone.
|
||||
ignore_index (`int`, *optional*, defaults to -100):
|
||||
The ignore index for the loss function.
|
||||
vocab_size (`int`, *optional*, defaults to 51289):
|
||||
Vocabulary size of the Florence2 model. Defines the number of different tokens that can be represented by the
|
||||
`inputs_ids` passed when calling [`~Florence2ForConditionalGeneration`]
|
||||
projection_dim (`int`, *optional*, defaults to 1024):
|
||||
Dimension of the multimodal projection space.
|
||||
|
||||
Example:
|
||||
|
||||
```python
|
||||
>>> from transformers import Florence2ForConditionalGeneration, Florence2Config, CLIPVisionConfig, BartConfig
|
||||
|
||||
>>> # Initializing a clip-like vision config
|
||||
>>> vision_config = CLIPVisionConfig()
|
||||
|
||||
>>> # Initializing a Bart config
|
||||
>>> text_config = BartConfig()
|
||||
|
||||
>>> # Initializing a Florence-2 configuration
|
||||
>>> configuration = Florence2Config(vision_config, text_config)
|
||||
|
||||
>>> # Initializing a model from the florence-2 configuration
|
||||
>>> model = Florence2ForConditionalGeneration(configuration)
|
||||
|
||||
>>> # Accessing the model configuration
|
||||
>>> configuration = model.config
|
||||
```"""
|
||||
|
||||
model_type = "florence2"
|
||||
is_composition = False
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
vision_config=None,
|
||||
text_config=None,
|
||||
ignore_index=-100,
|
||||
vocab_size=51289,
|
||||
projection_dim=1024,
|
||||
**kwargs,
|
||||
):
|
||||
self.ignore_index = ignore_index
|
||||
self.vocab_size = vocab_size
|
||||
self.projection_dim = projection_dim
|
||||
if vision_config is not None:
|
||||
vision_config = Florence2VisionConfig(**vision_config)
|
||||
self.vision_config = vision_config
|
||||
|
||||
self.text_config = text_config
|
||||
if text_config is not None:
|
||||
self.text_config = Florence2LanguageConfig(**text_config)
|
||||
|
||||
super().__init__(**kwargs)
|
||||
@@ -0,0 +1,203 @@
|
||||
#!/usr/bin/env python
|
||||
|
||||
# ------------------------------------------------------------------------------
|
||||
# Copyright 2025 The HuggingFace Inc. team and 2toINF (https://github.com/2toINF)
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
# ------------------------------------------------------------------------------
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
from dataclasses import dataclass, field
|
||||
from typing import TYPE_CHECKING, Any
|
||||
|
||||
from lerobot.configs.policies import PreTrainedConfig
|
||||
from lerobot.configs.types import FeatureType, NormalizationMode, PolicyFeature
|
||||
from lerobot.optim.optimizers import XVLAAdamWConfig
|
||||
from lerobot.optim.schedulers import CosineDecayWithWarmupSchedulerConfig
|
||||
from lerobot.utils.constants import OBS_IMAGES
|
||||
|
||||
# Conditional import for type checking and lazy loading
|
||||
from lerobot.utils.import_utils import _transformers_available
|
||||
|
||||
if TYPE_CHECKING or _transformers_available:
|
||||
from .configuration_florence2 import Florence2Config
|
||||
else:
|
||||
Florence2Config = None
|
||||
|
||||
|
||||
@PreTrainedConfig.register_subclass("xvla")
|
||||
@dataclass
|
||||
class XVLAConfig(PreTrainedConfig):
|
||||
"""
|
||||
Configuration class for the XVLA (Extended Vision-Language-Action) policy so it can
|
||||
plug into the LeRobot training stack.
|
||||
|
||||
The config mirrors the knobs exposed in the original XVLA repository but also
|
||||
declares the input/output feature contract required by LeRobot.
|
||||
"""
|
||||
|
||||
# Input / output structure
|
||||
n_obs_steps: int = 1
|
||||
chunk_size: int = 32
|
||||
n_action_steps: int = 32
|
||||
dtype: str = "float32" # Options: "bfloat16", "float32"
|
||||
|
||||
normalization_mapping: dict[str, NormalizationMode] = field(
|
||||
default_factory=lambda: {
|
||||
"VISUAL": NormalizationMode.IDENTITY,
|
||||
"STATE": NormalizationMode.IDENTITY,
|
||||
"ACTION": NormalizationMode.IDENTITY,
|
||||
}
|
||||
)
|
||||
|
||||
# Florence2 backbone and tokenizer configuration
|
||||
florence_config: dict[str, Any] = field(default_factory=dict)
|
||||
tokenizer_name: str = "facebook/bart-large"
|
||||
tokenizer_max_length: int = 64
|
||||
tokenizer_padding_side: str = "right"
|
||||
pad_language_to: str = "max_length"
|
||||
|
||||
# Transformer head
|
||||
hidden_size: int = 1024
|
||||
depth: int = 24
|
||||
num_heads: int = 16
|
||||
mlp_ratio: float = 4.0
|
||||
num_domains: int = 30
|
||||
len_soft_prompts: int = 32
|
||||
dim_time: int = 32
|
||||
max_len_seq: int = 512
|
||||
use_hetero_proj: bool = False
|
||||
|
||||
# Action & proprioception
|
||||
action_mode: str = "ee6d"
|
||||
num_denoising_steps: int = 10
|
||||
use_proprio: bool = True
|
||||
max_state_dim: int = 32
|
||||
max_action_dim: int = 20 # Maximum action dimension for padding (used by "auto" action mode)
|
||||
domain_feature_key: str | None = None
|
||||
|
||||
# Vision preprocessing
|
||||
resize_imgs_with_padding: tuple[int, int] | None = None
|
||||
num_image_views: int | None = None
|
||||
empty_cameras: int = 0
|
||||
|
||||
# Freezing options for VLM components
|
||||
# When the freeze flags below are enabled, the VLM encoders stay fixed and only the policy transformer + soft prompts train
|
||||
freeze_vision_encoder: bool = False # Freeze VLM vision encoder weights
|
||||
freeze_language_encoder: bool = False # Freeze VLM language encoder weights
|
||||
train_policy_transformer: bool = True # Allow policy transformer to train
|
||||
train_soft_prompts: bool = True # Allow soft prompts to train
|
||||
|
||||
# Training presets
|
||||
optimizer_lr: float = 1e-4
|
||||
optimizer_betas: tuple[float, float] = (0.9, 0.99)
|
||||
optimizer_eps: float = 1e-8
|
||||
optimizer_weight_decay: float = 0.0
|
||||
optimizer_grad_clip_norm: float = 10.0
|
||||
# Soft-prompt LR settings (for optional warm-up)
|
||||
optimizer_soft_prompt_lr_scale: float = 1.0 # Scale factor for soft-prompt LR
|
||||
optimizer_soft_prompt_warmup_lr_scale: float | None = None # Start scale for warmup (e.g., 0.01)
|
||||
|
||||
scheduler_warmup_steps: int = 1_000
|
||||
scheduler_decay_steps: int = 30_000
|
||||
scheduler_decay_lr: float = 2.5e-6
|
||||
|
||||
def __post_init__(self) -> None:
|
||||
super().__post_init__()
|
||||
|
||||
if self.chunk_size <= 0:
|
||||
raise ValueError("`chunk_size` must be strictly positive.")
|
||||
if self.n_action_steps > self.chunk_size:
|
||||
raise ValueError(
|
||||
f"`n_action_steps` ({self.n_action_steps}) must be <= `chunk_size` ({self.chunk_size})."
|
||||
)
|
||||
if self.num_image_views is not None and self.num_image_views <= 0:
|
||||
raise ValueError("`num_image_views` must be > 0 when specified.")
|
||||
if self.dtype not in ["bfloat16", "float32"]:
|
||||
raise ValueError(f"Invalid dtype: {self.dtype}")
|
||||
self._florence_config_obj: Florence2Config | None = None
|
||||
|
||||
def get_florence_config(self) -> Florence2Config:
|
||||
"""
|
||||
Build (and cache) the Florence2 transformer config that should back the VLM.
|
||||
"""
|
||||
if self._florence_config_obj is None:
|
||||
config_dict = dict(self.florence_config)
|
||||
if "vision_config" not in config_dict or config_dict["vision_config"] is None:
|
||||
raise ValueError("vision_config is required")
|
||||
|
||||
if "text_config" not in config_dict or config_dict["text_config"] is None:
|
||||
raise ValueError("text_config is required")
|
||||
self._florence_config_obj = Florence2Config(**config_dict)
|
||||
return self._florence_config_obj
|
||||
|
||||
def validate_features(self) -> None:
|
||||
if not self.image_features:
|
||||
raise ValueError("XVLA requires at least one visual feature in the inputs.")
|
||||
if self.use_proprio and self.robot_state_feature is None:
|
||||
raise ValueError("`use_proprio=True` requires a proprioceptive state feature.")
|
||||
if self.num_image_views is None:
|
||||
self.num_image_views = len(self.image_features) + self.empty_cameras
|
||||
else:
|
||||
self.num_image_views = max(self.num_image_views, len(self.image_features) + self.empty_cameras)
|
||||
|
||||
if self.empty_cameras > 0:
|
||||
height, width = (480, 640)
|
||||
if self.resize_imgs_with_padding is not None:
|
||||
height, width = self.resize_imgs_with_padding
|
||||
for idx in range(self.empty_cameras):
|
||||
key = f"{OBS_IMAGES}.empty_camera_{idx}"
|
||||
if key not in self.input_features:
|
||||
self.input_features[key] = PolicyFeature(
|
||||
type=FeatureType.VISUAL,
|
||||
shape=(3, height, width),
|
||||
)
|
||||
|
||||
def get_optimizer_preset(self) -> XVLAAdamWConfig:
|
||||
"""Return the XVLA-specific optimizer with differential learning rates.
|
||||
|
||||
This optimizer applies:
|
||||
- 1/10 LR for VLM parameters (stable optimization)
|
||||
- Full LR for transformer/action head
|
||||
- Configurable LR for soft-prompts (with optional warm-up)
|
||||
"""
|
||||
return XVLAAdamWConfig(
|
||||
lr=self.optimizer_lr,
|
||||
betas=self.optimizer_betas,
|
||||
eps=self.optimizer_eps,
|
||||
weight_decay=self.optimizer_weight_decay,
|
||||
grad_clip_norm=self.optimizer_grad_clip_norm,
|
||||
soft_prompt_lr_scale=self.optimizer_soft_prompt_lr_scale,
|
||||
soft_prompt_warmup_lr_scale=self.optimizer_soft_prompt_warmup_lr_scale,
|
||||
)
|
||||
|
||||
def get_scheduler_preset(self) -> CosineDecayWithWarmupSchedulerConfig:
|
||||
return CosineDecayWithWarmupSchedulerConfig(
|
||||
peak_lr=self.optimizer_lr,
|
||||
decay_lr=self.scheduler_decay_lr,
|
||||
num_warmup_steps=self.scheduler_warmup_steps,
|
||||
num_decay_steps=self.scheduler_decay_steps,
|
||||
)
|
||||
|
||||
@property
|
||||
def observation_delta_indices(self) -> list[int] | None:
|
||||
return None
|
||||
|
||||
@property
|
||||
def action_delta_indices(self) -> list[int]:
|
||||
return list(range(self.chunk_size))
|
||||
|
||||
@property
|
||||
def reward_delta_indices(self) -> list[int] | None:
|
||||
return None
|
||||
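# --- Illustrative usage (not part of the original API) -----------------------
# Minimal sketch of constructing an XVLAConfig; the values are placeholders,
# not the released XVLA hyperparameters. `get_florence_config` requires both
# `vision_config` and `text_config` entries inside `florence_config` (empty
# dicts here simply fall back to the Florence-2 defaults).
def _example_build_xvla_config() -> XVLAConfig:
    cfg = XVLAConfig(
        chunk_size=32,
        n_action_steps=32,
        action_mode="auto",      # real action dim auto-detected from the dataset feature
        max_action_dim=20,
        florence_config={"vision_config": {}, "text_config": {}},
    )
    cfg.get_florence_config()    # builds and caches the backing Florence2Config
    return cfg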
File diff suppressed because it is too large
@@ -0,0 +1,548 @@
|
||||
#!/usr/bin/env python
|
||||
|
||||
# ------------------------------------------------------------------------------
|
||||
# Copyright 2025 The HuggingFace Inc. team and 2toINF (https://github.com/2toINF)
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
# ------------------------------------------------------------------------------
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import builtins
|
||||
import logging
|
||||
import os
|
||||
from collections import deque
|
||||
from pathlib import Path
|
||||
|
||||
import torch
|
||||
import torch.nn.functional as F # noqa: N812
|
||||
from torch import Tensor, nn
|
||||
|
||||
from lerobot.configs.policies import PreTrainedConfig
|
||||
from lerobot.policies.pretrained import PreTrainedPolicy, T
|
||||
from lerobot.policies.utils import populate_queues
|
||||
from lerobot.utils.constants import ACTION, OBS_LANGUAGE_TOKENS, OBS_STATE
|
||||
|
||||
from .action_hub import build_action_space
|
||||
from .configuration_florence2 import Florence2Config
|
||||
from .configuration_xvla import XVLAConfig
|
||||
from .modeling_florence2 import Florence2ForConditionalGeneration
|
||||
from .soft_transformer import SoftPromptedTransformer
|
||||
|
||||
|
||||
class XVLAModel(nn.Module):
|
||||
"""
|
||||
XVLA backbone that stitches Florence-2 embeddings with the temporal/action transformer head.
|
||||
"""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
config: XVLAConfig,
|
||||
florence_config: Florence2Config,
|
||||
proprio_dim: int,
|
||||
) -> None:
|
||||
super().__init__()
|
||||
self.config = config
|
||||
self.chunk_size: int = config.chunk_size
|
||||
self.use_proprio: bool = config.use_proprio
|
||||
|
||||
# Build action space with auto-detection for "auto" mode
|
||||
if config.action_mode.lower() == "auto":
|
||||
# Auto-detect real action dim from config.action_feature
|
||||
real_dim = (
|
||||
config.action_feature.shape[-1]
|
||||
if config.action_feature is not None
|
||||
else config.max_action_dim
|
||||
)
|
||||
self.action_space = build_action_space(
|
||||
config.action_mode.lower(),
|
||||
real_dim=real_dim,
|
||||
max_dim=config.max_action_dim,
|
||||
)
|
||||
else:
|
||||
self.action_space = build_action_space(config.action_mode.lower())
|
||||
|
||||
self.dim_action = self.action_space.dim_action
|
||||
self.dim_proprio = proprio_dim
|
||||
|
||||
self.vlm = Florence2ForConditionalGeneration(florence_config)
|
||||
if hasattr(self.vlm, "language_model"):
|
||||
lm = self.vlm.language_model
|
||||
if hasattr(lm, "model") and hasattr(lm.model, "decoder"):
|
||||
del lm.model.decoder
|
||||
if hasattr(lm, "lm_head"):
|
||||
del lm.lm_head
|
||||
|
||||
projection_dim = getattr(self.vlm.config, "projection_dim", None)
|
||||
if projection_dim is None:
|
||||
raise ValueError("Florence2 config must provide `projection_dim` for multimodal fusion.")
|
||||
|
||||
self.transformer = SoftPromptedTransformer(
|
||||
hidden_size=config.hidden_size,
|
||||
multi_modal_input_size=projection_dim,
|
||||
depth=config.depth,
|
||||
num_heads=config.num_heads,
|
||||
mlp_ratio=config.mlp_ratio,
|
||||
num_domains=config.num_domains,
|
||||
dim_action=self.dim_action,
|
||||
dim_propio=self.dim_proprio,
|
||||
len_soft_prompts=config.len_soft_prompts,
|
||||
dim_time=config.dim_time,
|
||||
max_len_seq=config.max_len_seq,
|
||||
use_hetero_proj=config.use_hetero_proj,
|
||||
)
|
||||
|
||||
# Apply freezing based on config
|
||||
self._apply_freezing()
|
||||
|
||||
# Apply dtype casting based on config
|
||||
self._apply_dtype()
|
||||
|
||||
def _get_target_dtype(self) -> torch.dtype:
|
||||
"""Get the target dtype based on config."""
|
||||
if self.config.dtype == "bfloat16":
|
||||
return torch.bfloat16
|
||||
return torch.float32
|
||||
|
||||
def _apply_dtype(self) -> None:
|
||||
"""
|
||||
Apply dtype casting to model components based on config.
|
||||
"""
|
||||
target_dtype = self._get_target_dtype()
|
||||
self.to(dtype=target_dtype)
|
||||
|
||||
def _apply_freezing(self) -> None:
|
||||
"""
|
||||
Freeze VLM vision and language encoders based on config options.
|
||||
Keep only policy transformer and soft prompts trainable.
|
||||
"""
|
||||
# Freeze vision encoder
|
||||
if self.config.freeze_vision_encoder and hasattr(self.vlm, "vision_tower"):
|
||||
for param in self.vlm.vision_tower.parameters():
|
||||
param.requires_grad = False
|
||||
|
||||
# Freeze language encoder
|
||||
if self.config.freeze_language_encoder and hasattr(self.vlm, "language_model"):
|
||||
lm = self.vlm.language_model
|
||||
# Freeze encoder
|
||||
if hasattr(lm, "model") and hasattr(lm.model, "encoder"):
|
||||
for param in lm.model.encoder.parameters():
|
||||
param.requires_grad = False
|
||||
# Freeze shared embeddings
|
||||
if hasattr(lm, "model") and hasattr(lm.model, "shared"):
|
||||
for param in lm.model.shared.parameters():
|
||||
param.requires_grad = False
|
||||
|
||||
# Freeze or unfreeze policy transformer
|
||||
if not self.config.train_policy_transformer:
|
||||
for name, param in self.transformer.named_parameters():
|
||||
if "soft_prompts" not in name:
|
||||
param.requires_grad = False
|
||||
|
||||
# Freeze or unfreeze soft prompts
|
||||
if not self.config.train_soft_prompts and hasattr(self.transformer, "soft_prompt_hub"):
|
||||
for param in self.transformer.soft_prompt_hub.parameters():
|
||||
param.requires_grad = False
|
||||
|
||||
def forward_vlm(
|
||||
self,
|
||||
input_ids: torch.LongTensor,
|
||||
pixel_values: torch.FloatTensor,
|
||||
image_mask: torch.Tensor,
|
||||
) -> dict[str, torch.Tensor]:
|
||||
"""
|
||||
Encode text and multi-view images via Florence2 encoder.
|
||||
"""
|
||||
batch_size, num_views = pixel_values.shape[:2]
|
||||
flat_mask = image_mask.view(-1).to(dtype=torch.bool)
|
||||
flat_images = pixel_values.flatten(0, 1)
|
||||
num_valid = int(flat_mask.sum().item())
|
||||
if num_valid == 0:
|
||||
raise ValueError("At least one image view must be valid per batch.")
|
||||
|
||||
valid_images = flat_images[flat_mask]
|
||||
valid_feats = self.vlm._encode_image(valid_images)
|
||||
tokens_per_view, hidden_dim = valid_feats.shape[1:]
|
||||
|
||||
image_features = valid_feats.new_zeros((batch_size * num_views, tokens_per_view, hidden_dim))
|
||||
image_features[flat_mask] = valid_feats
|
||||
image_features = image_features.view(batch_size, num_views, tokens_per_view, hidden_dim)
|
||||
inputs_embeds = self.vlm.get_input_embeddings()(input_ids)
|
||||
merged_embeds, attention_mask = self.vlm._merge_input_ids_with_image_features(
|
||||
image_features[:, 0],
|
||||
inputs_embeds,
|
||||
)
|
||||
|
||||
enc_out = self.vlm.language_model.model.encoder(
|
||||
attention_mask=attention_mask,
|
||||
inputs_embeds=merged_embeds,
|
||||
)[0]
|
||||
|
||||
aux_visual_inputs = image_features[:, 1:].reshape(batch_size, -1, hidden_dim)
|
||||
return {"vlm_features": enc_out, "aux_visual_inputs": aux_visual_inputs}
|
||||
|
||||
def forward(
|
||||
self,
|
||||
input_ids: torch.LongTensor,
|
||||
image_input: torch.FloatTensor,
|
||||
image_mask: torch.Tensor,
|
||||
domain_id: torch.LongTensor,
|
||||
proprio: torch.Tensor,
|
||||
action: torch.Tensor,
|
||||
) -> dict[str, torch.Tensor]:
|
||||
"""
|
||||
Forward pass for the XVLA model.
|
||||
"""
|
||||
target_dtype = self._get_target_dtype()
|
||||
image_input = image_input.to(dtype=target_dtype)
|
||||
proprio = proprio.to(dtype=target_dtype)
|
||||
action = action.to(dtype=target_dtype)
|
||||
|
||||
enc = self.forward_vlm(input_ids, image_input, image_mask)
|
||||
|
||||
batch_size = input_ids.shape[0]
|
||||
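        # Stratified noise-level sampling: a single shared random offset plus evenly
        # spaced per-sample offsets, wrapped into [0, 1) so each batch covers the whole
        # time range. The interpolant below mixes pure Gaussian noise (t -> 1) with the
        # clean action chunk (t -> 0); the transformer is trained to recover the clean
        # actions, and the action space computes the loss only on its real dimensions.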
t = (
|
||||
torch.rand(1, device=input_ids.device, dtype=target_dtype)
|
||||
+ torch.arange(batch_size, device=input_ids.device, dtype=target_dtype) / batch_size
|
||||
) % (1 - 1e-5)
|
||||
|
||||
action_noisy = torch.randn_like(action) * t.view(-1, 1, 1) + action * (1 - t).view(-1, 1, 1)
|
||||
proprio_m, action_noisy_m = self.action_space.preprocess(proprio, action_noisy)
|
||||
|
||||
pred_action = self.transformer(
|
||||
domain_id=domain_id,
|
||||
action_with_noise=action_noisy_m,
|
||||
t=t,
|
||||
proprio=proprio_m,
|
||||
**enc,
|
||||
)
|
||||
return self.action_space.compute_loss(pred_action, action)
|
||||
|
||||
@torch.no_grad()
|
||||
def generate_actions(
|
||||
self,
|
||||
input_ids: torch.LongTensor,
|
||||
image_input: torch.FloatTensor,
|
||||
image_mask: torch.Tensor,
|
||||
domain_id: torch.LongTensor,
|
||||
proprio: torch.Tensor,
|
||||
steps: int,
|
||||
) -> torch.Tensor:
|
||||
self.eval()
|
||||
|
||||
target_dtype = self._get_target_dtype()
|
||||
image_input = image_input.to(dtype=target_dtype)
|
||||
proprio = proprio.to(dtype=target_dtype)
|
||||
|
||||
enc = self.forward_vlm(input_ids, image_input, image_mask)
|
||||
|
||||
batch_size = input_ids.shape[0]
|
||||
action_dim = self.dim_action
|
||||
|
||||
x1 = torch.randn(batch_size, self.chunk_size, action_dim, device=proprio.device, dtype=target_dtype)
|
||||
action = torch.zeros_like(x1)
|
||||
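        # Iterative denoising: starting from pure noise `x1`, t walks from 1 down to
        # 1/steps. At every step the current estimate `action` is re-mixed with the
        # initial noise at level t and the transformer re-predicts the clean action
        # chunk from that mixture; `postprocess` finally maps the prediction back to
        # the robot's action space (trimming padded dims, gripper handling, etc.).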
|
||||
steps = max(1, int(steps))
|
||||
for i in range(steps, 0, -1):
|
||||
t = torch.full((batch_size,), i / steps, device=proprio.device, dtype=target_dtype)
|
||||
x_t = x1 * t.view(-1, 1, 1) + action * (1 - t).view(-1, 1, 1)
|
||||
proprio_m, x_t_m = self.action_space.preprocess(proprio, x_t)
|
||||
action = self.transformer(
|
||||
domain_id=domain_id,
|
||||
action_with_noise=x_t_m,
|
||||
proprio=proprio_m,
|
||||
t=t,
|
||||
**enc,
|
||||
)
|
||||
return self.action_space.postprocess(action)
|
||||
|
||||
|
||||
class XVLAPolicy(PreTrainedPolicy):
|
||||
"""LeRobot-compliant wrapper built around the XVLA model."""
|
||||
|
||||
config_class = XVLAConfig
|
||||
name = "xvla"
|
||||
|
||||
def __init__(self, config: XVLAConfig):
|
||||
super().__init__(config)
|
||||
config.validate_features()
|
||||
florence_config = config.get_florence_config()
|
||||
proprio_dim = config.max_state_dim if config.use_proprio else 0
|
||||
self.model = XVLAModel(config=config, florence_config=florence_config, proprio_dim=proprio_dim)
|
||||
self.reset()
|
||||
|
||||
def reset(self) -> None:
|
||||
self._queues = {
|
||||
ACTION: deque(maxlen=self.config.n_action_steps),
|
||||
}
|
||||
|
||||
def get_optim_params(self) -> dict:
|
||||
"""Return trainable named parameters for optimization.
|
||||
|
||||
Returns a dict of name -> param for all trainable parameters.
|
||||
This enables the xvla-adamw optimizer to apply differential learning rates
|
||||
based on parameter names (e.g., 1/10 LR for VLM components).
|
||||
"""
|
||||
return dict(filter(lambda kv: kv[1].requires_grad, self.named_parameters()))
|
||||
|
||||
def _prepare_state(self, batch: dict[str, Tensor], batch_size: int, device: torch.device) -> Tensor:
|
||||
if not self.config.use_proprio or OBS_STATE not in batch:
|
||||
return torch.zeros(batch_size, 0, device=device)
|
||||
state = batch[OBS_STATE]
|
||||
if state.ndim > 2:
|
||||
state = state[:, -1, :]
|
||||
return pad_vector(state, self.model.dim_proprio)
|
||||
|
||||
def _prepare_images(self, batch: dict[str, Tensor]) -> tuple[Tensor, Tensor]:
|
||||
present_img_keys = [key for key in self.config.image_features if key in batch]
|
||||
if len(present_img_keys) == 0:
|
||||
raise ValueError(
|
||||
"All image features are missing from the batch. "
|
||||
f"Batch keys: {list(batch.keys())}, expected at least one of {list(self.config.image_features)}."
|
||||
)
|
||||
|
||||
images = []
|
||||
masks = []
|
||||
for key in present_img_keys:
|
||||
img = batch[key][:, -1] if batch[key].ndim == 5 else batch[key]
|
||||
if self.config.resize_imgs_with_padding is not None:
|
||||
img = resize_with_pad(img, *self.config.resize_imgs_with_padding)
|
||||
images.append(img)
|
||||
masks.append(torch.ones(img.size(0), dtype=torch.bool, device=img.device))
|
||||
|
||||
stacked_imgs = torch.stack(images, dim=1)
|
||||
stacked_masks = torch.stack(masks, dim=1)
|
||||
|
||||
total_views = self.config.num_image_views or stacked_imgs.size(1)
|
||||
total_views = max(total_views, stacked_imgs.size(1))
|
||||
num_pad = total_views - stacked_imgs.size(1)
|
||||
if num_pad > 0:
|
||||
pad_shape = (stacked_imgs.size(0), num_pad, *stacked_imgs.shape[2:])
|
||||
pad_imgs = stacked_imgs.new_zeros(pad_shape)
|
||||
pad_masks = stacked_masks.new_zeros((stacked_masks.size(0), num_pad))
|
||||
stacked_imgs = torch.cat([stacked_imgs, pad_imgs], dim=1)
|
||||
stacked_masks = torch.cat([stacked_masks, pad_masks], dim=1)
|
||||
|
||||
return stacked_imgs, stacked_masks
|
||||
|
||||
def _get_domain_id(self, batch: dict[str, Tensor], batch_size: int, device: torch.device) -> Tensor:
|
||||
candidate = None
|
||||
if self.config.domain_feature_key and self.config.domain_feature_key in batch:
|
||||
candidate = batch[self.config.domain_feature_key]
|
||||
elif "domain_id" in batch:
|
||||
candidate = batch["domain_id"]
|
||||
|
||||
if candidate is None:
|
||||
return torch.zeros(batch_size, dtype=torch.long, device=device)
|
||||
|
||||
if not isinstance(candidate, torch.Tensor):
|
||||
candidate = torch.as_tensor(candidate, device=device)
|
||||
else:
|
||||
candidate = candidate.to(device=device)
|
||||
|
||||
if candidate.ndim == 0:
|
||||
candidate = candidate.expand(batch_size)
|
||||
if candidate.ndim > 1:
|
||||
candidate = candidate.view(candidate.shape[0], -1)[:, 0]
|
||||
if candidate.shape[0] != batch_size:
|
||||
candidate = candidate.expand(batch_size)
|
||||
return candidate.to(dtype=torch.long)
|
||||
|
||||
def _prepare_action_targets(self, batch: dict[str, Tensor]) -> Tensor:
|
||||
if ACTION not in batch:
|
||||
raise ValueError("Batch is missing action targets required for training.")
|
||||
actions = batch[ACTION]
|
||||
if actions.ndim == 2:
|
||||
actions = actions.unsqueeze(1)
|
||||
actions = pad_tensor_along_dim(actions, self.config.chunk_size, dim=1)
|
||||
if actions.shape[-1] != self.model.dim_action:
|
||||
actions = pad_vector(actions, self.model.dim_action)
|
||||
return actions
|
||||
|
||||
def _build_model_inputs(self, batch: dict[str, Tensor]) -> dict[str, Tensor]:
|
||||
input_ids = batch[OBS_LANGUAGE_TOKENS]
|
||||
batch_size = input_ids.shape[0]
|
||||
images, image_mask = self._prepare_images(batch)
|
||||
domain_id = self._get_domain_id(batch, batch_size, images.device)
|
||||
proprio = self._prepare_state(batch, batch_size, images.device)
|
||||
return {
|
||||
"input_ids": input_ids,
|
||||
"image_input": images,
|
||||
"image_mask": image_mask,
|
||||
"domain_id": domain_id,
|
||||
"proprio": proprio,
|
||||
}
|
||||
|
||||
def forward(self, batch: dict[str, Tensor]) -> tuple[Tensor, dict]:
|
||||
inputs = self._build_model_inputs(batch)
|
||||
targets = self._prepare_action_targets(batch)
|
||||
losses = self.model(action=targets, **inputs)
|
||||
total_loss = sum(losses.values())
|
||||
|
||||
log_dict = {k: v.detach().item() for k, v in losses.items()}
|
||||
log_dict["loss"] = total_loss.detach().item()
|
||||
return total_loss, log_dict
|
||||
|
||||
def _get_action_chunk(self, batch: dict[str, Tensor]) -> Tensor:
|
||||
inputs = self._build_model_inputs(batch)
|
||||
actions = self.model.generate_actions(**inputs, steps=self.config.num_denoising_steps)
|
||||
return actions
|
||||
|
||||
@torch.no_grad()
|
||||
def predict_action_chunk(self, batch: dict[str, Tensor], noise: Tensor | None = None) -> Tensor: # noqa: ARG002
|
||||
self.eval()
|
||||
self._queues = populate_queues(self._queues, batch, exclude_keys=[ACTION])
|
||||
return self._get_action_chunk(batch)
|
||||
|
||||
@torch.no_grad()
|
||||
def select_action(self, batch: dict[str, Tensor], noise: Tensor | None = None) -> Tensor: # noqa: ARG002
|
||||
self.eval()
|
||||
self._queues = populate_queues(self._queues, batch, exclude_keys=[ACTION])
|
||||
|
||||
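        # Action chunking: when the queue is empty, predict a fresh chunk, keep the
        # first `n_action_steps` actions (time-major after the transpose below), and
        # pop one action per call until the queue runs dry and a new chunk is needed.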
if len(self._queues[ACTION]) == 0:
|
||||
actions = self._get_action_chunk(batch)
|
||||
self._queues[ACTION].extend(actions.transpose(0, 1)[: self.config.n_action_steps])
|
||||
|
||||
return self._queues[ACTION].popleft()
|
||||
|
||||
@classmethod
|
||||
def from_pretrained(
|
||||
cls: builtins.type[T],
|
||||
pretrained_name_or_path: str | Path,
|
||||
*,
|
||||
config: PreTrainedConfig | None = None,
|
||||
force_download: bool = False,
|
||||
resume_download: bool | None = None,
|
||||
proxies: dict | None = None,
|
||||
token: str | bool | None = None,
|
||||
cache_dir: str | Path | None = None,
|
||||
local_files_only: bool = False,
|
||||
revision: str | None = None,
|
||||
strict: bool = False,
|
||||
**kwargs,
|
||||
):
|
||||
"""
|
||||
Loads XVLA model weights from a local directory or the Hub (`model.safetensors`),
|
||||
duplicating the encoder `embed_tokens` weight into the shared-embedding key so that
|
||||
a strict `load_state_dict` succeeds, then re-applies the configured dtype and device.
|
||||
"""
|
||||
import safetensors.torch
|
||||
|
||||
# step 1: load config
|
||||
# TODO: jadechoghari, fix this
|
||||
if config is None:
|
||||
config = PreTrainedConfig.from_pretrained(
|
||||
pretrained_name_or_path=pretrained_name_or_path,
|
||||
force_download=force_download,
|
||||
resume_download=resume_download,
|
||||
proxies=proxies,
|
||||
token=token,
|
||||
cache_dir=cache_dir,
|
||||
local_files_only=local_files_only,
|
||||
revision=revision,
|
||||
**kwargs,
|
||||
)
|
||||
|
||||
model_id = str(pretrained_name_or_path)
|
||||
instance = cls(config, **kwargs)
|
||||
# step 2: locate model.safetensors
|
||||
if os.path.isdir(model_id):
|
||||
logging.info("Loading weights from local directory")
|
||||
model_file = os.path.join(model_id, "model.safetensors")
|
||||
else:
|
||||
try:
|
||||
from huggingface_hub import hf_hub_download
|
||||
from huggingface_hub.utils import HfHubHTTPError
|
||||
|
||||
model_file = hf_hub_download(
|
||||
repo_id=model_id,
|
||||
filename="model.safetensors",
|
||||
revision=revision,
|
||||
cache_dir=cache_dir,
|
||||
force_download=force_download,
|
||||
proxies=proxies,
|
||||
resume_download=resume_download,
|
||||
token=token,
|
||||
local_files_only=local_files_only,
|
||||
)
|
||||
except HfHubHTTPError as e:
|
||||
raise FileNotFoundError(f"model.safetensors not found on the Hub at {model_id}") from e
|
||||
|
||||
logging.info(f"Loading checkpoint from {model_file}")
|
||||
# step 3: load state dict
|
||||
state_dict = safetensors.torch.load_file(model_file)
|
||||
encoder_key = "model.vlm.language_model.model.encoder.embed_tokens.weight"
|
||||
shared_key = "model.vlm.language_model.model.shared.weight"
|
||||
if encoder_key in state_dict:
|
||||
state_dict[shared_key] = state_dict[encoder_key]
|
||||
# ties the shared embedding to the encoder's embed_tokens tensor (a deepcopy would give them separate storage)
|
||||
# step 4: load into instance
|
||||
instance.load_state_dict(state_dict, strict=True)
|
||||
logging.info("Loaded XVLA checkpoint")
|
||||
# step 5: finalize
|
||||
# Reapply dtype after loading state dict
|
||||
instance.model._apply_dtype()
|
||||
instance.to(config.device)
|
||||
instance.eval()
|
||||
return instance
|
||||
|
||||
|
||||
def resize_with_pad(img: torch.Tensor, height: int, width: int, pad_value: float = 0.0) -> torch.Tensor:
|
||||
if img.ndim != 4:
|
||||
raise ValueError(f"(b,c,h,w) expected, but got {img.shape}")
|
||||
|
||||
current_height, current_width = img.shape[2:]
|
||||
if current_height == height and current_width == width:
|
||||
return img
|
||||
|
||||
ratio = max(current_width / width, current_height / height)
|
||||
resized_height = int(current_height / ratio)
|
||||
resized_width = int(current_width / ratio)
|
||||
resized_img = F.interpolate(
|
||||
img, size=(resized_height, resized_width), mode="bilinear", align_corners=False
|
||||
)
|
||||
|
||||
pad_height = max(0, height - resized_height)
|
||||
pad_width = max(0, width - resized_width)
|
||||
padded_img = F.pad(resized_img, (pad_width, 0, pad_height, 0), value=pad_value)
|
||||
return padded_img
|
||||
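# --- Illustrative usage (not part of the original API) -----------------------
# Minimal sketch of resize_with_pad on a hypothetical 480x640 camera frame: the
# image is resized to fit the target box while keeping its aspect ratio, then
# zero-padded on the top/left edges up to the requested size.
def _example_resize_with_pad() -> torch.Tensor:
    frames = torch.rand(2, 3, 480, 640)                 # (b, c, h, w)
    out = resize_with_pad(frames, height=224, width=224)
    assert out.shape == (2, 3, 224, 224)
    return out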
|
||||
|
||||
def pad_vector(vector: Tensor, new_dim: int) -> Tensor:
|
||||
if vector.shape[-1] == new_dim:
|
||||
return vector
|
||||
if new_dim == 0:
|
||||
shape = list(vector.shape)
|
||||
shape[-1] = 0
|
||||
return vector.new_zeros(*shape)
|
||||
shape = list(vector.shape)
|
||||
current_dim = shape[-1]
|
||||
shape[-1] = new_dim
|
||||
new_vector = vector.new_zeros(*shape)
|
||||
length = min(current_dim, new_dim)
|
||||
new_vector[..., :length] = vector[..., :length]
|
||||
return new_vector
|
||||
|
||||
|
||||
def pad_tensor_along_dim(tensor: Tensor, target_len: int, dim: int = 1) -> Tensor:
|
||||
current_len = tensor.size(dim)
|
||||
if current_len == target_len:
|
||||
return tensor
|
||||
if current_len > target_len:
|
||||
slices = [slice(None)] * tensor.dim()
|
||||
slices[dim] = slice(0, target_len)
|
||||
return tensor[tuple(slices)]
|
||||
pad_shape = list(tensor.shape)
|
||||
pad_shape[dim] = target_len - current_len
|
||||
pad_tensor = tensor.new_zeros(pad_shape)
|
||||
return torch.cat([tensor, pad_tensor], dim=dim)
|
||||
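# --- Illustrative usage (not part of the original API) -----------------------
# Minimal sketch of the two padding helpers with made-up dimensions: a 7-dim
# state is zero-padded to the model's proprio width and a short action chunk is
# zero-padded along the time axis to the configured chunk size.
def _example_padding_helpers() -> tuple[Tensor, Tensor]:
    state = pad_vector(torch.randn(4, 7), 32)                  # (4, 7) -> (4, 32)
    chunk = pad_tensor_along_dim(torch.randn(4, 10, 20), 32)   # (4, 10, 20) -> (4, 32, 20)
    assert state.shape == (4, 32) and chunk.shape == (4, 32, 20)
    return state, chunk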
@@ -0,0 +1,554 @@
|
||||
# ------------------------------------------------------------------------------
|
||||
# Copyright 2025 The HuggingFace Inc. team and 2toINF (https://github.com/2toINF)
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
# ------------------------------------------------------------------------------
|
||||
|
||||
from dataclasses import dataclass
|
||||
from typing import Any
|
||||
|
||||
import numpy as np
|
||||
import torch
|
||||
|
||||
from lerobot.configs.types import PipelineFeatureType, PolicyFeature
|
||||
from lerobot.datasets.factory import IMAGENET_STATS
|
||||
from lerobot.policies.xvla.configuration_xvla import XVLAConfig
|
||||
from lerobot.policies.xvla.utils import rotate6d_to_axis_angle
|
||||
from lerobot.processor import (
|
||||
AddBatchDimensionProcessorStep,
|
||||
DeviceProcessorStep,
|
||||
NormalizerProcessorStep,
|
||||
ObservationProcessorStep,
|
||||
PolicyAction,
|
||||
PolicyProcessorPipeline,
|
||||
ProcessorStep,
|
||||
ProcessorStepRegistry,
|
||||
RenameObservationsProcessorStep,
|
||||
TokenizerProcessorStep,
|
||||
UnnormalizerProcessorStep,
|
||||
)
|
||||
from lerobot.processor.converters import policy_action_to_transition, transition_to_policy_action
|
||||
from lerobot.processor.core import EnvTransition, TransitionKey
|
||||
from lerobot.utils.constants import (
|
||||
OBS_IMAGES,
|
||||
OBS_STATE,
|
||||
POLICY_POSTPROCESSOR_DEFAULT_NAME,
|
||||
POLICY_PREPROCESSOR_DEFAULT_NAME,
|
||||
)
|
||||
|
||||
|
||||
def make_xvla_pre_post_processors(
|
||||
config: XVLAConfig,
|
||||
dataset_stats: dict[str, dict[str, torch.Tensor]] | None = None,
|
||||
) -> tuple[
|
||||
PolicyProcessorPipeline[dict[str, Any], dict[str, Any]],
|
||||
PolicyProcessorPipeline[PolicyAction, PolicyAction],
|
||||
]:
|
||||
"""
|
||||
Build the LeRobot processor pipelines for XVLA.
|
||||
"""
|
||||
|
||||
features = {**config.input_features, **config.output_features}
|
||||
input_steps = [
|
||||
RenameObservationsProcessorStep(rename_map={}),
|
||||
AddBatchDimensionProcessorStep(),
|
||||
TokenizerProcessorStep(
|
||||
tokenizer_name=config.tokenizer_name,
|
||||
max_length=config.tokenizer_max_length,
|
||||
padding=config.pad_language_to,
|
||||
padding_side=config.tokenizer_padding_side,
|
||||
),
|
||||
XVLAImageToFloatProcessorStep(),
|
||||
XVLAImageNetNormalizeProcessorStep(),
|
||||
XVLAAddDomainIdProcessorStep(),
|
||||
DeviceProcessorStep(device=config.device),
|
||||
NormalizerProcessorStep(
|
||||
features=features, norm_map=config.normalization_mapping, stats=dataset_stats
|
||||
),
|
||||
]
|
||||
output_steps = [
|
||||
UnnormalizerProcessorStep(
|
||||
features=config.output_features,
|
||||
norm_map=config.normalization_mapping,
|
||||
stats=dataset_stats,
|
||||
),
|
||||
DeviceProcessorStep(device="cpu"),
|
||||
]
|
||||
|
||||
return (
|
||||
PolicyProcessorPipeline[dict[str, Any], dict[str, Any]](
|
||||
steps=input_steps,
|
||||
name=POLICY_PREPROCESSOR_DEFAULT_NAME,
|
||||
),
|
||||
PolicyProcessorPipeline[PolicyAction, PolicyAction](
|
||||
steps=output_steps,
|
||||
name=POLICY_POSTPROCESSOR_DEFAULT_NAME,
|
||||
to_transition=policy_action_to_transition,
|
||||
to_output=transition_to_policy_action,
|
||||
),
|
||||
)
|
||||
|
||||
|
||||
# Custom XVLA processor steps
|
||||
@dataclass
|
||||
class LiberoProcessorStep(ObservationProcessorStep):
|
||||
"""
|
||||
Processes LIBERO observations into the LeRobot format.
|
||||
|
||||
This step handles the specific observation structure from LIBERO environments,
|
||||
which includes nested robot_state dictionaries and image observations.
|
||||
|
||||
**State Processing:**
|
||||
- Processes the `robot_state` dictionary which contains nested end-effector,
|
||||
gripper, and joint information.
|
||||
- Extracts and concatenates:
|
||||
- End-effector position (3D)
|
||||
- End-effector rotation matrix converted to a 6D rotation representation (6D)
|
||||
- A single zero placeholder, giving a 10-dim proprio vector that is zero-padded to 20 dims
|
||||
- Maps the concatenated state to `"observation.state"`.
|
||||
|
||||
**Image Processing:**
|
||||
- Rotates images by 180 degrees by flipping both height and width dimensions.
|
||||
- This accounts for the HuggingFaceVLA/libero camera orientation convention.
|
||||
"""
|
||||
|
||||
def _process_observation(self, observation):
|
||||
"""
|
||||
Processes both image and robot_state observations from LIBERO.
|
||||
"""
|
||||
processed_obs = observation.copy()
|
||||
for key in list(processed_obs.keys()):
|
||||
if key.startswith(f"{OBS_IMAGES}."):
|
||||
img = processed_obs[key]
|
||||
|
||||
if key == f"{OBS_IMAGES}.image":
|
||||
# Flip both H and W
|
||||
img = torch.flip(img, dims=[2, 3])
|
||||
|
||||
processed_obs[key] = img
|
||||
# Process robot_state into a flat state vector
|
||||
if "observation.robot_state" in processed_obs:
|
||||
robot_state = processed_obs.pop("observation.robot_state")
|
||||
|
||||
# Extract components
|
||||
eef_pos = robot_state["eef"]["pos"]  # (B, 3)
|
||||
eef_mat = robot_state["eef"]["mat"] # (B, 3, 3)
|
||||
eef_rot6d = self._mat_to_rotate6d(eef_mat) # (B, 6)
|
||||
|
||||
extra = torch.zeros((eef_pos.shape[0], 1), dtype=torch.float32, device=eef_pos.device)
|
||||
|
||||
proprio_state = torch.cat((eef_pos, eef_rot6d, extra), dim=-1) # (B, 10)
|
||||
state = torch.cat((proprio_state, torch.zeros_like(proprio_state)), dim=-1) # (B, 20)
|
||||
# ensure float32
|
||||
state = state.float()
|
||||
if state.dim() == 1:
|
||||
state = state.unsqueeze(0)
|
||||
|
||||
processed_obs[OBS_STATE] = state
|
||||
return processed_obs
|
||||
|
||||
def transform_features(
|
||||
self, features: dict[PipelineFeatureType, dict[str, PolicyFeature]]
|
||||
) -> dict[PipelineFeatureType, dict[str, PolicyFeature]]:
|
||||
"""
|
||||
Transforms feature keys from the LIBERO format to the LeRobot standard.
|
||||
"""
|
||||
new_features: dict[PipelineFeatureType, dict[str, PolicyFeature]] = {}
|
||||
|
||||
# copy over non-STATE features
|
||||
for ft, feats in features.items():
|
||||
if ft != PipelineFeatureType.STATE:
|
||||
new_features[ft] = feats.copy()
|
||||
|
||||
# rebuild STATE features
|
||||
state_feats = {}
|
||||
|
||||
# add our new flattened state
|
||||
state_feats["observation.state"] = PolicyFeature(
|
||||
key="observation.state",
|
||||
shape=(20,),
|
||||
dtype="float32",
|
||||
)
|
||||
|
||||
new_features[PipelineFeatureType.STATE] = state_feats
|
||||
|
||||
return new_features
|
||||
|
||||
def _mat_to_rotate6d(self, rot_mats: torch.Tensor) -> torch.Tensor:
|
||||
"""
|
||||
Convert batched rotation matrices (B, 3, 3) into 6D rotation representation (B, 6).
|
||||
|
||||
Args:
|
||||
rot_mats (Tensor): Rotation matrices of shape (B, 3, 3)
|
||||
|
||||
Returns:
|
||||
Tensor: 6D rotation representation, shape (B, 6)
|
||||
|
||||
Raises:
|
||||
TypeError: if input is not a torch tensor
|
||||
ValueError: if shape is not (B, 3, 3)
|
||||
"""
|
||||
|
||||
if not isinstance(rot_mats, torch.Tensor):
|
||||
raise TypeError(f"mat_to_rot6d expects a torch.Tensor, got {type(rot_mats)}")
|
||||
|
||||
if rot_mats.ndim != 3 or rot_mats.shape[1:] != (3, 3):
|
||||
raise ValueError(f"mat_to_rot6d expects shape (B, 3, 3), got {tuple(rot_mats.shape)}")
|
||||
|
||||
rot_mats = rot_mats.to(torch.float32)
|
||||
|
||||
col1 = rot_mats[:, :3, 0] # (B, 3)
|
||||
col2 = rot_mats[:, :3, 1] # (B, 3)
|
||||
|
||||
rot6d = torch.cat([col1, col2], dim=-1) # (B, 6)
|
||||
|
||||
return rot6d
|
||||
|
||||
def observation(self, observation):
|
||||
return self._process_observation(observation)
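# Quick numeric check (illustrative only): for an identity rotation matrix the 6D
# representation is the first two columns stacked, i.e. [1, 0, 0, 0, 1, 0].
#
#   step = LiberoProcessorStep()
#   step._mat_to_rotate6d(torch.eye(3).unsqueeze(0))
#   # -> tensor([[1., 0., 0., 0., 1., 0.]])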
|
||||
|
||||
|
||||
@dataclass
|
||||
@ProcessorStepRegistry.register(name="xvla_image_scale")
|
||||
class XVLAImageScaleProcessorStep(ProcessorStep):
|
||||
"""Scale image observations by 255 to convert from [0, 1] to [0, 255] range.
|
||||
|
||||
This processor step multiplies all image observations by 255, which is required
|
||||
for XVLA models that expect images in uint8-like range.
|
||||
|
||||
Args:
|
||||
image_keys: List of observation keys that contain images to scale.
|
||||
If None, will automatically detect keys starting with "observation.images."
|
||||
"""
|
||||
|
||||
image_keys: list[str] | None = None
|
||||
|
||||
def __call__(self, transition: EnvTransition) -> EnvTransition:
|
||||
"""Scale image observations by 255."""
|
||||
new_transition = transition.copy()
|
||||
obs = new_transition.get(TransitionKey.OBSERVATION, {})
|
||||
if obs is None:
|
||||
return new_transition
|
||||
|
||||
# Make a copy of observations to avoid modifying the original
|
||||
obs = obs.copy()
|
||||
|
||||
# Determine which keys to scale
|
||||
keys_to_scale = self.image_keys
|
||||
if keys_to_scale is None:
|
||||
# Auto-detect image keys
|
||||
keys_to_scale = [k for k in obs if k.startswith("observation.images.")]
|
||||
|
||||
# Scale each image
|
||||
for key in keys_to_scale:
|
||||
if key in obs and isinstance(obs[key], torch.Tensor):
|
||||
obs[key] = obs[key] * 255
|
||||
|
||||
new_transition[TransitionKey.OBSERVATION] = obs
|
||||
return new_transition
|
||||
|
||||
def transform_features(self, features):
|
||||
"""Image scaling doesn't change feature structure."""
|
||||
return features
|
||||
|
||||
def get_config(self) -> dict[str, Any]:
|
||||
"""Return serializable configuration."""
|
||||
return {
|
||||
"image_keys": self.image_keys,
|
||||
}
|
||||
|
||||
|
||||
@dataclass
|
||||
@ProcessorStepRegistry.register(name="xvla_image_to_float")
|
||||
class XVLAImageToFloatProcessorStep(ProcessorStep):
|
||||
"""Convert image observations from [0, 255] to [0, 1] range.
|
||||
|
||||
This processor step divides image observations by 255 to convert from uint8-like
|
||||
range [0, 255] to float range [0, 1]. This is typically used when loading images
|
||||
that are stored as uint8 values. Images whose maximum value is already <= 1.0 are
only cast to float and left unscaled.
|
||||
|
||||
Args:
|
||||
image_keys: List of observation keys that contain images to convert.
|
||||
If None, will automatically detect keys starting with "observation.images."
|
||||
validate_range: If True, validates that input values are in [0, 255] range (default: True)
|
||||
|
||||
Raises:
|
||||
ValueError: If validate_range is True and image values are not in [0, 255] range.
|
||||
"""
|
||||
|
||||
image_keys: list[str] | None = None
|
||||
validate_range: bool = True
|
||||
|
||||
def __call__(self, transition: EnvTransition) -> EnvTransition:
|
||||
"""Convert image observations from [0, 255] to [0, 1]."""
|
||||
new_transition = transition.copy()
|
||||
obs = new_transition.get(TransitionKey.OBSERVATION, {})
|
||||
if obs is None:
|
||||
return new_transition
|
||||
|
||||
# Make a copy of observations to avoid modifying the original
|
||||
obs = obs.copy()
|
||||
|
||||
# Determine which keys to convert
|
||||
keys_to_convert = self.image_keys
|
||||
if keys_to_convert is None:
|
||||
# Auto-detect image keys
|
||||
keys_to_convert = [k for k in obs if k.startswith("observation.images.")]
|
||||
|
||||
# Convert each image
|
||||
for key in keys_to_convert:
|
||||
if key in obs and isinstance(obs[key], torch.Tensor):
|
||||
tensor = obs[key]
|
||||
|
||||
min_val = tensor.min().item()
|
||||
max_val = tensor.max().item()
|
||||
|
||||
if max_val <= 1.0:
|
||||
obs[key] = tensor.float() # ensure float dtype, but no division
|
||||
continue
|
||||
# Validate that values are in [0, 255] range if requested
|
||||
if self.validate_range and (min_val < 0.0 or max_val > 255.0):
|
||||
raise ValueError(
|
||||
f"Image '{key}' has values outside [0, 255] range: "
|
||||
f"min={min_val:.4f}, max={max_val:.4f}. "
|
||||
f"Cannot convert to [0, 1] range."
|
||||
)
|
||||
|
||||
# Convert to float and divide by 255
|
||||
obs[key] = tensor.float() / 255.0
|
||||
|
||||
new_transition[TransitionKey.OBSERVATION] = obs
|
||||
return new_transition
|
||||
|
||||
def transform_features(self, features):
|
||||
"""Image conversion doesn't change feature structure."""
|
||||
return features
|
||||
|
||||
def get_config(self) -> dict[str, Any]:
|
||||
"""Return serializable configuration."""
|
||||
return {
|
||||
"image_keys": self.image_keys,
|
||||
"validate_range": self.validate_range,
|
||||
}
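# Sketch of this step in isolation (illustrative; building the transition as a plain
# dict keyed by TransitionKey is an assumption made for the example):
#
#   step = XVLAImageToFloatProcessorStep()
#   img = torch.full((1, 3, 4, 4), 128.0)  # uint8-like values in [0, 255]
#   out = step({TransitionKey.OBSERVATION: {"observation.images.cam": img}})
#   out[TransitionKey.OBSERVATION]["observation.images.cam"][0, 0, 0, 0]  # ~0.502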
|
||||
|
||||
|
||||
@dataclass
|
||||
@ProcessorStepRegistry.register(name="xvla_imagenet_normalize")
|
||||
class XVLAImageNetNormalizeProcessorStep(ProcessorStep):
|
||||
"""Normalize image observations using ImageNet statistics.
|
||||
|
||||
This processor step applies ImageNet normalization (mean and std) to image observations.
|
||||
It validates that input values are in the [0, 1] range before normalizing.
|
||||
|
||||
The normalization formula is: (image - mean) / std
|
||||
|
||||
Args:
|
||||
image_keys: List of observation keys that contain images to normalize.
|
||||
If None, will automatically detect keys starting with "observation.images."
|
||||
|
||||
Raises:
|
||||
ValueError: If image values are not in the [0, 1] range.
|
||||
"""
|
||||
|
||||
image_keys: list[str] | None = None
|
||||
|
||||
def __call__(self, transition: EnvTransition) -> EnvTransition:
|
||||
"""Normalize image observations using ImageNet statistics."""
|
||||
new_transition = transition.copy()
|
||||
obs = new_transition.get(TransitionKey.OBSERVATION, {})
|
||||
if obs is None:
|
||||
return new_transition
|
||||
|
||||
# Make a copy of observations to avoid modifying the original
|
||||
obs = obs.copy()
|
||||
|
||||
# Determine which keys to normalize
|
||||
keys_to_normalize = self.image_keys
|
||||
if keys_to_normalize is None:
|
||||
# Auto-detect image keys
|
||||
keys_to_normalize = [k for k in obs if k.startswith("observation.images.")]
|
||||
|
||||
# Normalize each image
|
||||
for key in keys_to_normalize:
|
||||
if key in obs and isinstance(obs[key], torch.Tensor):
|
||||
tensor = obs[key]
|
||||
|
||||
# Validate that values are in [0, 1] range
|
||||
min_val = tensor.min().item()
|
||||
max_val = tensor.max().item()
|
||||
if min_val < 0.0 or max_val > 1.0:
|
||||
raise ValueError(
|
||||
f"Image '{key}' has values outside [0, 1] range: "
|
||||
f"min={min_val:.4f}, max={max_val:.4f}. "
|
||||
f"ImageNet normalization requires input values in [0, 1]."
|
||||
)
|
||||
|
||||
# Apply ImageNet normalization
|
||||
mean = torch.tensor(IMAGENET_STATS["mean"], device=tensor.device, dtype=tensor.dtype)
|
||||
std = torch.tensor(IMAGENET_STATS["std"], device=tensor.device, dtype=tensor.dtype)
|
||||
|
||||
# Expand mean/std to match tensor dims (e.g., BCHW or BNCHW)
|
||||
while mean.dim() < tensor.dim():
|
||||
mean = mean.unsqueeze(0)
|
||||
std = std.unsqueeze(0)
|
||||
|
||||
# Normalize: (image - mean) / std
|
||||
obs[key] = (tensor - mean) / std
|
||||
|
||||
new_transition[TransitionKey.OBSERVATION] = obs
|
||||
return new_transition
|
||||
|
||||
def transform_features(self, features):
|
||||
"""ImageNet normalization doesn't change feature structure."""
|
||||
return features
|
||||
|
||||
def get_config(self) -> dict[str, Any]:
|
||||
"""Return serializable configuration."""
|
||||
return {
|
||||
"image_keys": self.image_keys,
|
||||
}
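# The step above implements the standard per-channel ImageNet transform. As a
# plain-torch sketch (the literal statistics below are the widely used ImageNet
# values and are assumed to match IMAGENET_STATS):
#
#   mean = torch.tensor([0.485, 0.456, 0.406]).view(1, 3, 1, 1)
#   std = torch.tensor([0.229, 0.224, 0.225]).view(1, 3, 1, 1)
#   img = torch.rand(1, 3, 224, 224)       # values already in [0, 1]
#   normalized = (img - mean) / std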
|
||||
|
||||
|
||||
@dataclass
|
||||
@ProcessorStepRegistry.register(name="xvla_add_domain_id")
|
||||
class XVLAAddDomainIdProcessorStep(ProcessorStep):
|
||||
"""Add domain_id to complementary data.
|
||||
|
||||
This processor step adds a domain_id tensor to the complementary data,
|
||||
which is used by XVLA to identify different robot embodiments or task domains.
|
||||
|
||||
Args:
|
||||
domain_id: The domain ID to add (default: 0)
|
||||
"""
|
||||
|
||||
domain_id: int = 0
|
||||
|
||||
def __call__(self, transition: EnvTransition) -> EnvTransition:
|
||||
"""Add domain_id to complementary data."""
|
||||
new_transition = transition.copy()
|
||||
comp = new_transition.get(TransitionKey.COMPLEMENTARY_DATA, {})
|
||||
comp = {} if comp is None else comp.copy()
|
||||
|
||||
# Infer batch size from observation tensors
|
||||
obs = new_transition.get(TransitionKey.OBSERVATION, {})
|
||||
batch_size = 1
|
||||
if obs:
|
||||
for v in obs.values():
|
||||
if isinstance(v, torch.Tensor):
|
||||
batch_size = v.shape[0]
|
||||
break
|
||||
|
||||
# Add domain_id tensor
|
||||
comp["domain_id"] = torch.tensor([int(self.domain_id)] * batch_size, dtype=torch.long)
|
||||
|
||||
new_transition[TransitionKey.COMPLEMENTARY_DATA] = comp
|
||||
return new_transition
|
||||
|
||||
def transform_features(self, features):
|
||||
"""Domain ID addition doesn't change feature structure."""
|
||||
return features
|
||||
|
||||
def get_config(self) -> dict[str, Any]:
|
||||
"""Return serializable configuration."""
|
||||
return {
|
||||
"domain_id": self.domain_id,
|
||||
}
|
||||
|
||||
|
||||
@dataclass
|
||||
@ProcessorStepRegistry.register(name="xvla_rotation_6d_to_axis_angle")
|
||||
class XVLARotation6DToAxisAngleProcessorStep(ProcessorStep):
|
||||
"""Convert 6D rotation representation to axis-angle and reorganize action dimensions.
|
||||
|
||||
This processor step takes actions with 6D rotation representation and converts them to
|
||||
axis-angle representation, reorganizing the action dimensions as:
|
||||
- action[:, :3] -> target_eef (end-effector position)
|
||||
- action[:, 3:9] -> 6D rotation (converted to axis-angle, 3D)
|
||||
- action[:, 9:10] -> gripper action
|
||||
|
||||
Final output: [target_eef (3), axis_angle (3), gripper (1)] = 7D action
|
||||
|
||||
Args:
|
||||
expected_action_dim: Expected input action dimension (default: 10, supports 6D rotation + extras)
|
||||
"""
|
||||
|
||||
expected_action_dim: int = 10
|
||||
|
||||
def __call__(self, transition: EnvTransition) -> EnvTransition:
|
||||
"""Convert 6D rotation to axis-angle in action."""
|
||||
new_transition = transition.copy()
|
||||
action = new_transition.get(TransitionKey.ACTION)
|
||||
|
||||
if action is None or not isinstance(action, torch.Tensor):
|
||||
return new_transition
|
||||
|
||||
# Convert to numpy for processing
|
||||
device = action.device
|
||||
dtype = action.dtype
|
||||
action_np = action.cpu().numpy()
|
||||
|
||||
# Extract components
|
||||
# action shape: (B, D) where D >= 10
|
||||
target_eef = action_np[:, :3] # (B, 3)
|
||||
rotation_6d = action_np[:, 3:9] # (B, 6)
|
||||
target_act = action_np[:, 9:10] # (B, 1)
|
||||
|
||||
# Convert 6D rotation to axis-angle
|
||||
target_axis = rotate6d_to_axis_angle(rotation_6d) # (B, 3)
|
||||
|
||||
# Concatenate: [eef (3), axis_angle (3), gripper (1)] = 7D
|
||||
action_np = np.concatenate([target_eef, target_axis, target_act], axis=-1)
|
||||
|
||||
# Convert gripper action to -1 or 1
|
||||
action_np[:, -1] = np.where(action_np[:, -1] > 0.5, 1.0, -1.0)
|
||||
|
||||
# Convert back to tensor
|
||||
action = torch.from_numpy(action_np).to(device=device, dtype=dtype)
|
||||
|
||||
new_transition[TransitionKey.ACTION] = action
|
||||
return new_transition
|
||||
|
||||
def transform_features(self, features):
|
||||
"""Rotation conversion changes action dimension from 10 to 7."""
|
||||
# Note: This is a simplified version. In practice, you might want to
|
||||
# update the action feature shape in the features dict.
|
||||
return features
|
||||
|
||||
def get_config(self) -> dict[str, Any]:
|
||||
"""Return serializable configuration."""
|
||||
return {
|
||||
"expected_action_dim": self.expected_action_dim,
|
||||
}
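# Worked example of the 10D -> 7D reorganization (illustrative; a plain dict keyed
# by TransitionKey stands in for EnvTransition): an identity rotation in 6D form
# ([1, 0, 0, 0, 1, 0]) maps to a (near) zero axis-angle, and a gripper value above
# 0.5 is mapped to +1.0.
#
#   step = XVLARotation6DToAxisAngleProcessorStep()
#   act = torch.tensor([[0.1, 0.2, 0.3, 1.0, 0.0, 0.0, 0.0, 1.0, 0.0, 0.9]])
#   step({TransitionKey.ACTION: act})[TransitionKey.ACTION]
#   # -> approximately tensor([[0.1, 0.2, 0.3, 0.0, 0.0, 0.0, 1.0]])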
|
||||
|
||||
|
||||
def make_xvla_libero_pre_post_processors() -> tuple[
|
||||
PolicyProcessorPipeline[dict[str, Any], dict[str, Any]],
|
||||
PolicyProcessorPipeline[PolicyAction, PolicyAction],
|
||||
]:
|
||||
"""
|
||||
Build the LeRobot processor pipelines for XVLA with LIBERO environment.
|
||||
"""
|
||||
pre_processor_steps: list[ProcessorStep] = []
|
||||
post_processor_steps: list[ProcessorStep] = []
|
||||
pre_processor_steps.extend(
|
||||
[LiberoProcessorStep(), XVLAImageNetNormalizeProcessorStep(), XVLAAddDomainIdProcessorStep()]
|
||||
)
|
||||
post_processor_steps.extend([XVLARotation6DToAxisAngleProcessorStep()])
|
||||
return (
|
||||
PolicyProcessorPipeline[dict[str, Any], dict[str, Any]](
|
||||
steps=pre_processor_steps,
|
||||
),
|
||||
PolicyProcessorPipeline[PolicyAction, PolicyAction](
|
||||
steps=post_processor_steps,
|
||||
),
|
||||
)
|
||||
@@ -0,0 +1,415 @@
|
||||
# ------------------------------------------------------------------------------
|
||||
# Copyright 2025 2toINF (https://github.com/2toINF)
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
# ------------------------------------------------------------------------------
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import math
|
||||
from collections.abc import Iterable
|
||||
from functools import partial
|
||||
from typing import Final
|
||||
|
||||
import torch
|
||||
import torch.nn as nn
|
||||
import torch.nn.functional as functional
|
||||
|
||||
# ------------------------------- Small utils ----------------------------------
|
||||
|
||||
|
||||
def _to_2tuple(x) -> tuple:
|
||||
"""Minimal replacement for timm.layers.to_2tuple."""
|
||||
if isinstance(x, Iterable) and not isinstance(x, (str, bytes)):
|
||||
t = tuple(x)
|
||||
return (t[0], t[1]) if len(t) >= 2 else (t[0], t[0])
|
||||
return (x, x)
|
||||
|
||||
|
||||
def _has_sdp_attention() -> bool:
|
||||
"""Check if we can use PyTorch fused scaled_dot_product_attention."""
|
||||
return hasattr(functional, "scaled_dot_product_attention")
|
||||
|
||||
|
||||
# ---------------------------------- MLP --------------------------------------
|
||||
|
||||
|
||||
class Mlp(nn.Module):
|
||||
"""
|
||||
MLP used in ViT-style blocks.
|
||||
|
||||
Supports Linear or 1x1 Conv 'linear_layer' for token/channel mixing.
|
||||
"""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
in_features: int,
|
||||
hidden_features: int | None = None,
|
||||
out_features: int | None = None,
|
||||
norm_layer: type[nn.Module] | None = None,
|
||||
bias: bool | tuple[bool, bool] = True,
|
||||
drop: float | tuple[float, float] = 0.0,
|
||||
use_conv: bool = False,
|
||||
) -> None:
|
||||
super().__init__()
|
||||
out_features = out_features or in_features
|
||||
hidden_features = hidden_features or in_features
|
||||
bias = _to_2tuple(bias)
|
||||
drop_probs = _to_2tuple(drop)
|
||||
linear_layer = partial(nn.Conv2d, kernel_size=1) if use_conv else nn.Linear
|
||||
|
||||
self.fc1 = linear_layer(in_features, hidden_features, bias=bias[0])
|
||||
self.act = nn.GELU(approximate="tanh")
|
||||
self.drop1 = nn.Dropout(drop_probs[0])
|
||||
self.norm = norm_layer(hidden_features) if norm_layer is not None else nn.Identity()
|
||||
self.fc2 = linear_layer(hidden_features, out_features, bias=bias[1])
|
||||
self.drop2 = nn.Dropout(drop_probs[1])
|
||||
|
||||
def forward(self, x: torch.Tensor) -> torch.Tensor:
|
||||
# Expect [B, T, C] for Linear variant; caller is responsible for shapes.
|
||||
x = self.fc1(x)
|
||||
x = self.act(x)
|
||||
x = self.drop1(x)
|
||||
x = self.norm(x)
|
||||
x = self.fc2(x)
|
||||
x = self.drop2(x)
|
||||
return x
|
||||
|
||||
|
||||
# -------------------------------- Attention ----------------------------------
|
||||
|
||||
|
||||
class Attention(nn.Module):
|
||||
"""
|
||||
Multi-Head Self-Attention with optional fused SDPA fallback.
|
||||
|
||||
If PyTorch provides `scaled_dot_product_attention`, it will be used
|
||||
(usually faster and more stable); otherwise we use a manual implementation.
|
||||
"""
|
||||
|
||||
fused_attn: Final[bool]
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
dim: int,
|
||||
num_heads: int = 8,
|
||||
qkv_bias: bool = False,
|
||||
qk_norm: bool = False,
|
||||
attn_drop: float = 0.0,
|
||||
proj_drop: float = 0.0,
|
||||
norm_layer: type[nn.Module] = nn.LayerNorm,
|
||||
) -> None:
|
||||
super().__init__()
|
||||
assert dim % num_heads == 0, "dim should be divisible by num_heads"
|
||||
self.num_heads = num_heads
|
||||
self.head_dim = dim // num_heads
|
||||
self.scale = self.head_dim**-0.5
|
||||
self.fused_attn = _has_sdp_attention()
|
||||
|
||||
self.qkv = nn.Linear(dim, dim * 3, bias=qkv_bias)
|
||||
self.q_norm = norm_layer(self.head_dim) if qk_norm else nn.Identity()
|
||||
self.k_norm = norm_layer(self.head_dim) if qk_norm else nn.Identity()
|
||||
self.attn_drop = nn.Dropout(attn_drop)
|
||||
self.proj = nn.Linear(dim, dim)
|
||||
self.proj_drop = nn.Dropout(proj_drop)
|
||||
|
||||
def forward(self, x: torch.Tensor) -> torch.Tensor:
|
||||
"""
|
||||
Parameters
|
||||
----------
|
||||
x : Tensor, shape [batch_size, seq_len, channels]
|
||||
Input sequence.
|
||||
|
||||
Returns
|
||||
-------
|
||||
Tensor, shape [batch_size, seq_len, channels]
|
||||
Output sequence after MHSA + projection.
|
||||
"""
|
||||
batch_size, seq_len, channels = x.shape
|
||||
qkv = (
|
||||
self.qkv(x)
|
||||
.reshape(batch_size, seq_len, 3, self.num_heads, self.head_dim)
|
||||
.permute(2, 0, 3, 1, 4) # 3 x [batch_size, num_heads, seq_len, head_dim]
|
||||
)
|
||||
q, k, v = qkv.unbind(0) # each: [batch_size, num_heads, seq_len, head_dim]
|
||||
q, k = self.q_norm(q), self.k_norm(k)
|
||||
|
||||
if self.fused_attn:
|
||||
x = functional.scaled_dot_product_attention(
|
||||
q,
|
||||
k,
|
||||
v,
|
||||
dropout_p=self.attn_drop.p if self.training else 0.0,
|
||||
) # [batch_size, num_heads, seq_len, head_dim]
|
||||
else:
|
||||
q = q * self.scale
|
||||
attn = q @ k.transpose(-2, -1) # [batch_size, num_heads, seq_len, seq_len]
|
||||
attn = attn.softmax(dim=-1)
|
||||
attn = self.attn_drop(attn)
|
||||
x = attn @ v # [batch_size, num_heads, seq_len, head_dim]
|
||||
|
||||
x = x.transpose(1, 2).reshape(batch_size, seq_len, channels) # [batch_size, seq_len, channels]
|
||||
x = self.proj(x)
|
||||
x = self.proj_drop(x)
|
||||
return x
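# Shape sanity check (illustrative): self-attention preserves the [B, T, C] layout.
#
#   attn = Attention(dim=64, num_heads=8)
#   attn(torch.randn(2, 10, 64)).shape   # torch.Size([2, 10, 64])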
|
||||
|
||||
|
||||
# ------------------------------- Utilities -----------------------------------
|
||||
|
||||
|
||||
def basic_init(module: nn.Module) -> None:
|
||||
"""
|
||||
Apply a basic initialization scheme to Linear layers.
|
||||
|
||||
- Weight: Xavier uniform initialization.
|
||||
- Bias: Set to zero.
|
||||
"""
|
||||
if isinstance(module, nn.Linear):
|
||||
nn.init.xavier_uniform_(module.weight)
|
||||
if module.bias is not None:
|
||||
nn.init.constant_(module.bias, 0.0)
|
||||
|
||||
|
||||
def timestep_embedding(t: torch.Tensor, dim: int, max_period: int = 100) -> torch.Tensor:
|
||||
"""
|
||||
Create sinusoidal timestep embeddings.
|
||||
|
||||
Parameters
|
||||
----------
|
||||
t : torch.Tensor
|
||||
Shape [B]. Each element is a timestep index and may be fractional.
|
||||
dim : int
|
||||
Dimensionality of the output embedding.
|
||||
max_period : int, default=100
|
||||
Controls the minimum frequency of the sinusoids.
|
||||
|
||||
Returns
|
||||
-------
|
||||
torch.Tensor
|
||||
Shape [B, dim]. Sinusoidal embeddings.
|
||||
"""
|
||||
half = dim // 2
|
||||
freqs = torch.exp(
|
||||
-math.log(max_period) * torch.arange(start=0, end=half, dtype=t.dtype, device=t.device) / half
|
||||
)
|
||||
args = t[:, None] * freqs[None]
|
||||
embedding = torch.cat([torch.cos(args), torch.sin(args)], dim=-1)
|
||||
if dim % 2 == 1:
|
||||
embedding = torch.cat([embedding, torch.zeros_like(embedding[:, :1])], dim=-1)
|
||||
return embedding
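# Example (illustrative): embedding two timesteps into an 8-dim vector; the first
# half of each row holds cosines and the second half sines at geometrically spaced
# frequencies, so t = 0 yields [1, 1, 1, 1, 0, 0, 0, 0].
#
#   emb = timestep_embedding(torch.tensor([0.0, 5.0]), dim=8)  # shape (2, 8)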
|
||||
|
||||
|
||||
# ------------------------------- Core Layers ----------------------------------
|
||||
|
||||
|
||||
class DomainAwareLinear(nn.Module):
|
||||
"""
|
||||
Linear layer with domain-conditioned parameters (per-sample).
|
||||
|
||||
Each domain has its own weight and bias vectors, stored in embeddings.
|
||||
"""
|
||||
|
||||
def __init__(self, input_size: int, output_size: int, num_domains: int = 20) -> None:
|
||||
super().__init__()
|
||||
self.input_size = input_size
|
||||
self.output_size = output_size
|
||||
self.fc = nn.Embedding(num_domains, output_size * input_size)
|
||||
self.bias = nn.Embedding(num_domains, output_size)
|
||||
nn.init.xavier_uniform_(self.fc.weight)
|
||||
nn.init.zeros_(self.bias.weight)
|
||||
|
||||
def forward(self, x: torch.Tensor, domain_id: torch.LongTensor) -> torch.Tensor:
|
||||
"""
|
||||
Parameters
|
||||
----------
|
||||
x : Tensor
|
||||
[B, I] or [B, T, I]
|
||||
domain_id : LongTensor
|
||||
[B], domain indices.
|
||||
|
||||
Returns
|
||||
-------
|
||||
Tensor
|
||||
[batch_size, output_size] or [batch_size, seq_len, output_size]
|
||||
"""
|
||||
batch_size = domain_id.shape[0]
|
||||
squeeze_seq = False
|
||||
if x.dim() == 2:
|
||||
x = x.unsqueeze(1)
|
||||
squeeze_seq = True
|
||||
weight = self.fc(domain_id).view(batch_size, self.input_size, self.output_size)
|
||||
bias = self.bias(domain_id).view(batch_size, self.output_size)
|
||||
y = torch.matmul(x, weight) + bias.view(batch_size, 1, self.output_size)
|
||||
if squeeze_seq:
|
||||
y = y.squeeze(1)
|
||||
return y
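# Example (illustrative): the weight and bias are looked up per sample, so two
# samples with different domain ids pass through different linear maps.
#
#   layer = DomainAwareLinear(input_size=4, output_size=2, num_domains=3)
#   x = torch.randn(2, 4)
#   domain_id = torch.tensor([0, 2])
#   layer(x, domain_id).shape            # torch.Size([2, 2])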
|
||||
|
||||
|
||||
class TransformerBlock(nn.Module):
|
||||
"""
|
||||
Standard Transformer block (pre-LN): LN → MHSA → residual, LN → MLP → residual.
|
||||
"""
|
||||
|
||||
def __init__(self, hidden_size: int, num_heads: int, mlp_ratio: float = 4.0) -> None:
|
||||
super().__init__()
|
||||
self.norm1 = nn.LayerNorm(hidden_size)
|
||||
self.norm2 = nn.LayerNorm(hidden_size)
|
||||
self.attn = Attention(hidden_size, num_heads=num_heads, qkv_bias=True, attn_drop=0.1)
|
||||
self.mlp = Mlp(
|
||||
in_features=hidden_size,
|
||||
hidden_features=int(hidden_size * mlp_ratio),
|
||||
drop=0.1,
|
||||
)
|
||||
|
||||
def forward(self, x: torch.Tensor) -> torch.Tensor:
|
||||
"""
|
||||
Parameters
|
||||
----------
|
||||
x : Tensor, [B, T, H]
|
||||
|
||||
Returns
|
||||
-------
|
||||
Tensor, [B, T, H]
|
||||
"""
|
||||
x = x + self.attn(self.norm1(x))
|
||||
x = x + self.mlp(self.norm2(x))
|
||||
return x
|
||||
|
||||
|
||||
# --------------------------- Main Model ---------------------------------------
|
||||
|
||||
|
||||
class SoftPromptedTransformer(nn.Module):
|
||||
"""
|
||||
Multi-modal, domain-aware Transformer with optional soft prompts.
|
||||
|
||||
See parameter and forward I/O descriptions inside the docstrings.
|
||||
"""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
hidden_size: int = 768,
|
||||
multi_modal_input_size: int = 768,
|
||||
depth: int = 24,
|
||||
num_heads: int = 16,
|
||||
mlp_ratio: float = 4.0,
|
||||
num_domains: int = 20,
|
||||
dim_action: int = 20,
|
||||
dim_propio: int = 20,
|
||||
dim_time: int = 32,
|
||||
len_soft_prompts: int = 32,
|
||||
max_len_seq: int = 512,
|
||||
use_hetero_proj: bool = False,
|
||||
) -> None:
|
||||
super().__init__()
|
||||
self.hidden_size = hidden_size
|
||||
self.dim_action = dim_action
|
||||
self.dim_time = dim_time
|
||||
self.len_soft_prompts = len_soft_prompts
|
||||
self.use_hetero_proj = use_hetero_proj
|
||||
|
||||
self.blocks = nn.ModuleList(
|
||||
[TransformerBlock(hidden_size, num_heads, mlp_ratio=mlp_ratio) for _ in range(depth)]
|
||||
)
|
||||
|
||||
if use_hetero_proj:
|
||||
self.vlm_proj = DomainAwareLinear(multi_modal_input_size, hidden_size, num_domains=num_domains)
|
||||
self.aux_visual_proj = DomainAwareLinear(
|
||||
multi_modal_input_size, hidden_size, num_domains=num_domains
|
||||
)
|
||||
else:
|
||||
self.vlm_proj = nn.Linear(multi_modal_input_size, hidden_size)
|
||||
self.aux_visual_proj = nn.Linear(multi_modal_input_size, hidden_size)
|
||||
|
||||
self.pos_emb = nn.Parameter(torch.zeros(1, max_len_seq, hidden_size), requires_grad=True)
|
||||
nn.init.normal_(self.pos_emb, std=0.02)
|
||||
|
||||
self.norm = nn.LayerNorm(hidden_size)
|
||||
self.action_encoder = DomainAwareLinear(
|
||||
dim_action + dim_time + dim_propio, hidden_size, num_domains=num_domains
|
||||
)
|
||||
self.action_decoder = DomainAwareLinear(hidden_size, dim_action, num_domains=num_domains)
|
||||
|
||||
if len_soft_prompts > 0:
|
||||
self.soft_prompt_hub = nn.Embedding(num_domains, len_soft_prompts * hidden_size)
|
||||
nn.init.normal_(self.soft_prompt_hub.weight, std=0.02)
|
||||
|
||||
self.apply(basic_init)
|
||||
|
||||
def forward(
|
||||
self,
|
||||
domain_id: torch.LongTensor,
|
||||
vlm_features: torch.Tensor,
|
||||
aux_visual_inputs: torch.Tensor,
|
||||
action_with_noise: torch.Tensor,
|
||||
proprio: torch.Tensor,
|
||||
t: torch.Tensor,
|
||||
) -> torch.Tensor:
|
||||
"""
|
||||
Forward pass.
|
||||
|
||||
Inputs
|
||||
------
|
||||
domain_id : [B]
|
||||
vlm_features : [B, T_vlm, D]
|
||||
aux_visual_inputs : [B, T_aux, D]
|
||||
action_with_noise : [B, T_action, dim_action]
|
||||
proprio : [B, dim_propio]
|
||||
t : [B]
|
||||
|
||||
Returns
|
||||
-------
|
||||
Tensor
|
||||
Predicted actions, [batch_size, num_actions, dim_action]
|
||||
"""
|
||||
batch_size, num_actions = action_with_noise.shape[:2]
|
||||
|
||||
# Encode (action + proprio + time) → tokens
|
||||
time_emb = timestep_embedding(t, self.dim_time) # [batch_size, dim_time]
|
||||
time_tokens = time_emb.unsqueeze(1).expand(batch_size, num_actions, self.dim_time)
|
||||
proprio_tokens = proprio.unsqueeze(1).expand(batch_size, num_actions, proprio.shape[-1])
|
||||
action_tokens = torch.cat([action_with_noise, proprio_tokens, time_tokens], dim=-1)
|
||||
x = self.action_encoder(action_tokens, domain_id) # [batch_size, num_actions, hidden_size]
|
||||
|
||||
# Project visual streams and concatenate
|
||||
if self.use_hetero_proj:
|
||||
x = torch.cat(
|
||||
[
|
||||
x,
|
||||
self.vlm_proj(vlm_features, domain_id),
|
||||
self.aux_visual_proj(aux_visual_inputs, domain_id),
|
||||
],
|
||||
dim=1,
|
||||
)
|
||||
else:
|
||||
x = torch.cat([x, self.vlm_proj(vlm_features), self.aux_visual_proj(aux_visual_inputs)], dim=1)
|
||||
|
||||
# Add positional embeddings (raise if the sequence exceeds max_len_seq)
|
||||
seq_len = x.shape[1]
|
||||
if seq_len > self.pos_emb.shape[1]:
|
||||
raise ValueError(f"Sequence length {seq_len} exceeds max_len_seq={self.pos_emb.shape[1]}.")
|
||||
x = x + self.pos_emb[:, :seq_len, :]
|
||||
|
||||
# Append soft prompts
|
||||
if self.len_soft_prompts > 0:
|
||||
soft_prompts = self.soft_prompt_hub(domain_id).view(
|
||||
batch_size, self.len_soft_prompts, self.hidden_size
|
||||
)
|
||||
x = torch.cat([x, soft_prompts], dim=1)
|
||||
|
||||
# Transformer backbone
|
||||
for block in self.blocks:
|
||||
x = block(x)
|
||||
|
||||
# Decode only the action segment
|
||||
return self.action_decoder(self.norm(x[:, :num_actions]), domain_id)
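# Toy forward pass (illustrative; the dimensions are arbitrary and only need to be
# mutually consistent with the constructor arguments):
#
#   model = SoftPromptedTransformer(
#       hidden_size=64, multi_modal_input_size=64, depth=2, num_heads=4,
#       num_domains=4, dim_action=20, dim_propio=20, dim_time=32,
#       len_soft_prompts=8, max_len_seq=128,
#   )
#   out = model(
#       domain_id=torch.zeros(2, dtype=torch.long),
#       vlm_features=torch.randn(2, 16, 64),
#       aux_visual_inputs=torch.randn(2, 8, 64),
#       action_with_noise=torch.randn(2, 10, 20),
#       proprio=torch.randn(2, 20),
#       t=torch.rand(2),
#   )
#   out.shape  # torch.Size([2, 10, 20])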
|
||||
@@ -0,0 +1,138 @@
|
||||
import math
|
||||
|
||||
import numpy as np
|
||||
|
||||
|
||||
def mat2quat(rmat):
|
||||
"""
|
||||
Converts given rotation matrix to quaternion.
|
||||
|
||||
Args:
|
||||
rmat (np.array): 3x3 rotation matrix
|
||||
|
||||
Returns:
|
||||
np.array: (x,y,z,w) float quaternion angles
|
||||
"""
|
||||
mat = np.asarray(rmat).astype(np.float32)[:3, :3]
|
||||
|
||||
m00 = mat[0, 0]
|
||||
m01 = mat[0, 1]
|
||||
m02 = mat[0, 2]
|
||||
m10 = mat[1, 0]
|
||||
m11 = mat[1, 1]
|
||||
m12 = mat[1, 2]
|
||||
m20 = mat[2, 0]
|
||||
m21 = mat[2, 1]
|
||||
m22 = mat[2, 2]
|
||||
# symmetric matrix k
|
||||
k = np.array(
|
||||
[
|
||||
[m00 - m11 - m22, np.float32(0.0), np.float32(0.0), np.float32(0.0)],
|
||||
[m01 + m10, m11 - m00 - m22, np.float32(0.0), np.float32(0.0)],
|
||||
[m02 + m20, m12 + m21, m22 - m00 - m11, np.float32(0.0)],
|
||||
[m21 - m12, m02 - m20, m10 - m01, m00 + m11 + m22],
|
||||
]
|
||||
)
|
||||
k /= 3.0
|
||||
# quaternion is Eigen vector of k that corresponds to largest eigenvalue
|
||||
w, v = np.linalg.eigh(k)
|
||||
inds = np.array([3, 0, 1, 2])
|
||||
q1 = v[inds, np.argmax(w)]
|
||||
if q1[0] < 0.0:
|
||||
np.negative(q1, q1)
|
||||
inds = np.array([1, 2, 3, 0])
|
||||
return q1[inds]
|
||||
|
||||
|
||||
def quat2axisangle(quat):
|
||||
"""
|
||||
Converts quaternion to axis-angle format.
|
||||
Returns a unit vector direction scaled by its angle in radians.
|
||||
|
||||
Args:
|
||||
quat (np.array): (x,y,z,w) vec4 float angles
|
||||
|
||||
Returns:
|
||||
np.array: (ax,ay,az) axis-angle exponential coordinates
|
||||
"""
|
||||
# clip quaternion
|
||||
if quat[3] > 1.0:
|
||||
quat[3] = 1.0
|
||||
elif quat[3] < -1.0:
|
||||
quat[3] = -1.0
|
||||
|
||||
den = np.sqrt(1.0 - quat[3] * quat[3])
|
||||
if math.isclose(den, 0.0):
|
||||
# This is (close to) a zero degree rotation, immediately return
|
||||
return np.zeros(3)
|
||||
|
||||
return (quat[:3] * 2.0 * math.acos(quat[3])) / den
|
||||
|
||||
|
||||
def rotate6d_to_axis_angle(r6d):
|
||||
"""
|
||||
r6d: np.ndarray, shape (N, 6)
|
||||
return: np.ndarray, shape (N, 3), axis-angle vectors
|
||||
"""
|
||||
flag = 0
|
||||
if len(r6d.shape) == 1:
|
||||
r6d = r6d[None, ...]
|
||||
flag = 1
|
||||
|
||||
a1 = r6d[:, 0:3]
|
||||
a2 = r6d[:, 3:6]
|
||||
|
||||
# b1
|
||||
b1 = a1 / (np.linalg.norm(a1, axis=-1, keepdims=True) + 1e-6)
|
||||
|
||||
# b2
|
||||
dot_prod = np.sum(b1 * a2, axis=-1, keepdims=True)
|
||||
b2_orth = a2 - dot_prod * b1
|
||||
b2 = b2_orth / (np.linalg.norm(b2_orth, axis=-1, keepdims=True) + 1e-6)
|
||||
|
||||
# b3
|
||||
b3 = np.cross(b1, b2, axis=-1)
|
||||
|
||||
rotation_matrix = np.stack([b1, b2, b3], axis=-1) # shape: (N, 3, 3)
|
||||
|
||||
axis_angle_list = []
|
||||
for i in range(rotation_matrix.shape[0]):
|
||||
quat = mat2quat(rotation_matrix[i])
|
||||
axis_angle = quat2axisangle(quat)
|
||||
axis_angle_list.append(axis_angle)
|
||||
|
||||
axis_angle_array = np.stack(axis_angle_list, axis=0) # shape: (N, 3)
|
||||
|
||||
if flag == 1:
|
||||
axis_angle_array = axis_angle_array[0]
|
||||
|
||||
return axis_angle_array
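# Example (illustrative): the identity rotation in 6D form maps to a (near) zero
# axis-angle vector; an unbatched (6,) input returns an unbatched (3,) output.
#
#   rotate6d_to_axis_angle(np.array([1.0, 0.0, 0.0, 0.0, 1.0, 0.0]))
#   # -> array([0., 0., 0.]) up to float32 rounding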
|
||||
|
||||
|
||||
def mat_to_rotate6d(abs_action):
|
||||
if len(abs_action.shape) == 2:
|
||||
return np.concatenate([abs_action[:3, 0], abs_action[:3, 1]], axis=-1)
|
||||
elif len(abs_action.shape) == 3:
|
||||
return np.concatenate([abs_action[:, :3, 0], abs_action[:, :3, 1]], axis=-1)
|
||||
else:
|
||||
raise NotImplementedError
|
||||
|
||||
|
||||
def drop_path(x, drop_prob: float = 0.0, training: bool = False, scale_by_keep: bool = True):
|
||||
"""Drop paths (Stochastic Depth) per sample (when applied in main path of residual blocks).
|
||||
|
||||
This is the same as the DropConnect impl I created for EfficientNet, etc networks, however,
|
||||
the original name is misleading as 'Drop Connect' is a different form of dropout in a separate paper...
|
||||
See discussion: https://github.com/tensorflow/tpu/issues/494#issuecomment-532968956 ... I've opted for
|
||||
changing the layer and argument names to 'drop path' rather than mix DropConnect as a layer name and use
|
||||
'survival rate' as the argument.
|
||||
|
||||
"""
|
||||
if drop_prob == 0.0 or not training:
|
||||
return x
|
||||
keep_prob = 1 - drop_prob
|
||||
shape = (x.shape[0],) + (1,) * (x.ndim - 1) # work with diff dim tensors, not just 2D ConvNets
|
||||
random_tensor = x.new_empty(shape).bernoulli_(keep_prob)
|
||||
if keep_prob > 0.0 and scale_by_keep:
|
||||
random_tensor.div_(keep_prob)
|
||||
return x * random_tensor
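# Example (illustrative; assumes a torch tensor input, torch being imported by the
# caller): in training mode roughly `drop_prob` of the samples are zeroed and the
# survivors are rescaled by 1 / keep_prob.
#
#   x = torch.ones(4, 3)
#   drop_path(x, drop_prob=0.5, training=True)
#   # each row is either all zeros or all 2.0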
|
||||
@@ -78,7 +78,7 @@ from lerobot.transport.utils import (
|
||||
transitions_to_bytes,
|
||||
)
|
||||
from lerobot.utils.random_utils import set_seed
|
||||
from lerobot.utils.robot_utils import busy_wait
|
||||
from lerobot.utils.robot_utils import precise_sleep
|
||||
from lerobot.utils.transition import (
|
||||
Transition,
|
||||
move_state_dict_to_device,
|
||||
@@ -398,7 +398,7 @@ def act_with_policy(
|
||||
|
||||
if cfg.env.fps is not None:
|
||||
dt_time = time.perf_counter() - start_time
|
||||
busy_wait(1 / cfg.env.fps - dt_time)
|
||||
precise_sleep(1 / cfg.env.fps - dt_time)
|
||||
|
||||
|
||||
# Communication Functions - Group all gRPC/messaging functions
|
||||
|
||||
@@ -74,7 +74,7 @@ from lerobot.teleoperators import (
|
||||
from lerobot.teleoperators.teleoperator import Teleoperator
|
||||
from lerobot.teleoperators.utils import TeleopEvents
|
||||
from lerobot.utils.constants import ACTION, DONE, OBS_IMAGES, OBS_STATE, REWARD
|
||||
from lerobot.utils.robot_utils import busy_wait
|
||||
from lerobot.utils.robot_utils import precise_sleep
|
||||
from lerobot.utils.utils import log_say
|
||||
|
||||
logging.basicConfig(level=logging.INFO)
|
||||
@@ -114,7 +114,7 @@ def reset_follower_position(robot_arm: Robot, target_position: np.ndarray) -> No
|
||||
for pose in trajectory:
|
||||
action_dict = dict(zip(current_position_dict, pose, strict=False))
|
||||
robot_arm.bus.sync_write("Goal_Position", action_dict)
|
||||
busy_wait(0.015)
|
||||
precise_sleep(0.015)
|
||||
|
||||
|
||||
class RobotEnv(gym.Env):
|
||||
@@ -238,7 +238,7 @@ class RobotEnv(gym.Env):
|
||||
reset_follower_position(self.robot, np.array(self.reset_pose))
|
||||
log_say("Reset the environment done.", play_sounds=True)
|
||||
|
||||
busy_wait(self.reset_time_s - (time.perf_counter() - start_time))
|
||||
precise_sleep(self.reset_time_s - (time.perf_counter() - start_time))
|
||||
|
||||
super().reset(seed=seed, options=options)
|
||||
|
||||
@@ -713,7 +713,7 @@ def control_loop(
|
||||
transition = env_processor(transition)
|
||||
|
||||
# Maintain fps timing
|
||||
busy_wait(dt - (time.perf_counter() - step_start_time))
|
||||
precise_sleep(dt - (time.perf_counter() - step_start_time))
|
||||
|
||||
if dataset is not None and cfg.dataset.push_to_hub:
|
||||
logging.info("Pushing dataset to hub")
|
||||
@@ -745,7 +745,7 @@ def replay_trajectory(
|
||||
)
|
||||
transition = action_processor(transition)
|
||||
env.step(transition[TransitionKey.ACTION])
|
||||
busy_wait(1 / cfg.env.fps - (time.perf_counter() - start_time))
|
||||
precise_sleep(1 / cfg.env.fps - (time.perf_counter() - start_time))
|
||||
|
||||
|
||||
@parser.wrap()
|
||||
|
||||
@@ -0,0 +1,20 @@
|
||||
#!/usr/bin/env python
|
||||
|
||||
# Copyright 2025 The HuggingFace Inc. team. All rights reserved.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
from .config_earthrover_mini_plus import EarthRoverMiniPlusConfig
|
||||
from .robot_earthrover_mini_plus import EarthRoverMiniPlus
|
||||
|
||||
__all__ = ["EarthRoverMiniPlus", "EarthRoverMiniPlusConfig"]
|
||||
@@ -0,0 +1,35 @@
|
||||
#!/usr/bin/env python
|
||||
|
||||
# Copyright 2025 The HuggingFace Inc. team. All rights reserved.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
"""Configuration for EarthRover Mini Plus robot."""
|
||||
|
||||
from dataclasses import dataclass
|
||||
|
||||
from ..config import RobotConfig
|
||||
|
||||
|
||||
@RobotConfig.register_subclass("earthrover_mini_plus")
|
||||
@dataclass
|
||||
class EarthRoverMiniPlusConfig(RobotConfig):
|
||||
"""Configuration for EarthRover Mini Plus robot using Frodobots SDK.
|
||||
|
||||
This robot uses cloud-based control via the Frodobots SDK HTTP API.
|
||||
Camera frames are accessed directly through SDK HTTP endpoints.
|
||||
|
||||
Attributes:
|
||||
sdk_url: URL of the Frodobots SDK server (default: http://localhost:8000)
|
||||
"""
|
||||
|
||||
sdk_url: str = "http://localhost:8000"
|
||||
@@ -0,0 +1 @@
|
||||
../../../../docs/source/earthrover_mini_plus.mdx
|
||||
@@ -0,0 +1,473 @@
|
||||
#!/usr/bin/env python
|
||||
|
||||
# Copyright 2025 The HuggingFace Inc. team. All rights reserved.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
"""EarthRover Mini Plus robot using Frodobots SDK."""
|
||||
|
||||
import base64
|
||||
import logging
|
||||
from functools import cached_property
|
||||
from typing import Any
|
||||
|
||||
import cv2
|
||||
import numpy as np
|
||||
import requests
|
||||
|
||||
from lerobot.utils.errors import DeviceAlreadyConnectedError, DeviceNotConnectedError
|
||||
|
||||
from ..robot import Robot
|
||||
from .config_earthrover_mini_plus import EarthRoverMiniPlusConfig
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
# Action feature keys
|
||||
ACTION_LINEAR_VEL = "linear.vel"
|
||||
ACTION_ANGULAR_VEL = "angular.vel"
|
||||
|
||||
# Observation feature keys
|
||||
OBS_FRONT = "front"
|
||||
OBS_REAR = "rear"
|
||||
OBS_LINEAR_VEL = "linear.vel"
|
||||
OBS_BATTERY_LEVEL = "battery.level"
|
||||
OBS_ORIENTATION_DEG = "orientation.deg"
|
||||
OBS_GPS_LATITUDE = "gps.latitude"
|
||||
OBS_GPS_LONGITUDE = "gps.longitude"
|
||||
OBS_GPS_SIGNAL = "gps.signal"
|
||||
OBS_SIGNAL_LEVEL = "signal.level"
|
||||
OBS_VIBRATION = "vibration"
|
||||
OBS_LAMP_STATE = "lamp.state"
|
||||
|
||||
|
||||
class EarthRoverMiniPlus(Robot):
|
||||
"""
|
||||
EarthRover Mini Plus robot controlled via Frodobots SDK HTTP API.
|
||||
|
||||
This robot uses cloud-based control through the Frodobots SDK instead of direct
|
||||
hardware connection. Cameras stream via WebRTC through Agora cloud, and control
|
||||
commands are sent via HTTP POST requests.
|
||||
|
||||
The robot supports:
|
||||
- Dual cameras (front and rear) accessed via SDK HTTP endpoints
|
||||
- Linear and angular velocity control
|
||||
- Battery and orientation telemetry
|
||||
|
||||
Attributes:
|
||||
config: Robot configuration
|
||||
sdk_base_url: URL of the Frodobots SDK server (default: http://localhost:8000)
|
||||
"""
|
||||
|
||||
config_class = EarthRoverMiniPlusConfig
|
||||
name = "earthrover_mini_plus"
|
||||
|
||||
def __init__(self, config: EarthRoverMiniPlusConfig):
|
||||
"""Initialize EarthRover Mini Plus robot.
|
||||
|
||||
Args:
|
||||
config: Robot configuration including SDK URL
|
||||
"""
|
||||
super().__init__(config)
|
||||
self.config = config
|
||||
self.sdk_base_url = config.sdk_url
|
||||
|
||||
# Empty cameras dict for compatibility with recording script
|
||||
# Cameras are accessed directly via SDK, not through Camera objects
|
||||
self.cameras = {}
|
||||
self._is_connected = False
|
||||
|
||||
# Cache for camera frames (fallback when requests fail)
|
||||
self._last_front_frame = None
|
||||
self._last_rear_frame = None
|
||||
|
||||
# Cache for robot telemetry data (fallback when requests fail)
|
||||
self._last_robot_data = None
|
||||
|
||||
logger.info(f"Initialized {self.name} with SDK at {self.sdk_base_url}")
|
||||
|
||||
@property
|
||||
def is_connected(self) -> bool:
|
||||
"""Check if robot is connected to SDK."""
|
||||
return self._is_connected
|
||||
|
||||
def connect(self, calibrate: bool = True) -> None:
|
||||
"""Connect to robot via Frodobots SDK.
|
||||
|
||||
Args:
|
||||
calibrate: Not used for SDK-based robot (kept for API compatibility)
|
||||
|
||||
Raises:
|
||||
DeviceAlreadyConnectedError: If robot is already connected
|
||||
DeviceNotConnectedError: If cannot connect to SDK server
|
||||
"""
|
||||
if self._is_connected:
|
||||
raise DeviceAlreadyConnectedError(f"{self.name} is already connected")
|
||||
|
||||
# Verify SDK is running and accessible
|
||||
try:
|
||||
response = requests.get(f"{self.sdk_base_url}/data", timeout=10.0)
|
||||
if response.status_code != 200:
|
||||
raise DeviceNotConnectedError(
|
||||
f"Cannot connect to SDK at {self.sdk_base_url}. "
|
||||
"Make sure it's running: hypercorn main:app --reload"
|
||||
)
|
||||
except requests.RequestException as e:
|
||||
raise DeviceNotConnectedError(f"Cannot connect to SDK at {self.sdk_base_url}: {e}") from e
|
||||
|
||||
self._is_connected = True
|
||||
logger.info(f"{self.name} connected to SDK")
|
||||
|
||||
if calibrate:
|
||||
self.calibrate()
|
||||
|
||||
def calibrate(self) -> None:
|
||||
"""Calibration not needed for SDK-based robot."""
|
||||
logger.info("Calibration not required for SDK-based robot")
|
||||
|
||||
@property
|
||||
def is_calibrated(self) -> bool:
|
||||
"""SDK robot doesn't require calibration.
|
||||
|
||||
Returns:
|
||||
bool: Always True for SDK-based robots
|
||||
"""
|
||||
return True
|
||||
|
||||
def configure(self) -> None:
|
||||
"""Configure robot (no-op for SDK-based robot)."""
|
||||
pass
|
||||
|
||||
@cached_property
|
||||
def observation_features(self) -> dict[str, type | tuple]:
|
||||
"""Define the observation space for dataset recording.
|
||||
|
||||
Returns:
|
||||
dict: Observation features with types/shapes:
|
||||
- front: (480, 640, 3) - Front camera RGB image
|
||||
- rear: (480, 640, 3) - Rear camera RGB image
|
||||
- linear.vel: float - Current speed (0-1, SDK reports only positive speeds)
|
||||
- battery.level: float - Battery level (0-1, normalized from 0-100)
|
||||
- orientation.deg: float - Robot orientation (0-1, normalized from raw value)
|
||||
- gps.latitude: float - GPS latitude coordinate
|
||||
- gps.longitude: float - GPS longitude coordinate
|
||||
- gps.signal: float - GPS signal strength (0-1, normalized from percentage)
|
||||
- signal.level: float - Network signal level (0-1, normalized from 0-5)
|
||||
- vibration: float - Vibration sensor reading
|
||||
- lamp.state: float - Lamp state (0=off, 1=on)
|
||||
"""
|
||||
return {
|
||||
# Cameras (height, width, channels)
|
||||
OBS_FRONT: (480, 640, 3),
|
||||
OBS_REAR: (480, 640, 3),
|
||||
# Motion state
|
||||
OBS_LINEAR_VEL: float,
|
||||
# Robot state
|
||||
OBS_BATTERY_LEVEL: float,
|
||||
OBS_ORIENTATION_DEG: float,
|
||||
# GPS
|
||||
OBS_GPS_LATITUDE: float,
|
||||
OBS_GPS_LONGITUDE: float,
|
||||
OBS_GPS_SIGNAL: float,
|
||||
# Sensors
|
||||
OBS_SIGNAL_LEVEL: float,
|
||||
OBS_VIBRATION: float,
|
||||
OBS_LAMP_STATE: float,
|
||||
}
|
||||
|
||||
@cached_property
|
||||
def action_features(self) -> dict[str, type]:
|
||||
"""Define the action space.
|
||||
|
||||
Returns:
|
||||
dict: Action features with types:
|
||||
- linear.vel: float - Target linear velocity
|
||||
- angular.vel: float - Target angular velocity
|
||||
"""
|
||||
return {
|
||||
ACTION_LINEAR_VEL: float,
|
||||
ACTION_ANGULAR_VEL: float,
|
||||
}
|
||||
|
||||
def get_observation(self) -> dict[str, Any]:
|
||||
"""Get current robot observation from SDK.
|
||||
|
||||
Returns:
|
||||
dict: Observation containing:
|
||||
- front: Front camera image (480, 640, 3) in RGB format
|
||||
- rear: Rear camera image (480, 640, 3) in RGB format
|
||||
- linear.vel: Current speed (0-1, SDK reports only positive speeds)
|
||||
- battery.level: Battery level (0-1, normalized from 0-100)
|
||||
- orientation.deg: Robot orientation (0-1, normalized from raw value)
|
||||
- gps.latitude: GPS latitude coordinate
|
||||
- gps.longitude: GPS longitude coordinate
|
||||
- gps.signal: GPS signal strength (0-1, normalized from percentage)
|
||||
- signal.level: Network signal level (0-1, normalized from 0-5)
|
||||
- vibration: Vibration sensor reading
|
||||
- lamp.state: Lamp state (0=off, 1=on)
|
||||
|
||||
Raises:
|
||||
DeviceNotConnectedError: If robot is not connected
|
||||
|
||||
Note:
|
||||
Camera frames are retrieved from SDK endpoints /v2/front and /v2/rear.
|
||||
Frames are decoded from base64 and converted from BGR to RGB format.
|
||||
Robot telemetry is retrieved from /data endpoint.
|
||||
All SDK values are normalized to appropriate ranges for dataset recording.
|
||||
"""
|
||||
if not self._is_connected:
|
||||
raise DeviceNotConnectedError(f"{self.name} is not connected")
|
||||
|
||||
observation = {}
|
||||
|
||||
# Get camera images from SDK
|
||||
frames = self._get_camera_frames()
|
||||
observation[OBS_FRONT] = frames["front"]
|
||||
observation[OBS_REAR] = frames["rear"]
|
||||
|
||||
# Get robot state from SDK
|
||||
robot_data = self._get_robot_data()
|
||||
|
||||
# Motion state
|
||||
observation[OBS_LINEAR_VEL] = robot_data["speed"] / 100.0 # Normalize 0-100 to 0-1
|
||||
|
||||
# Robot state
|
||||
observation[OBS_BATTERY_LEVEL] = robot_data["battery"] / 100.0 # Normalize 0-100 to 0-1
|
||||
observation[OBS_ORIENTATION_DEG] = robot_data["orientation"] / 360.0 # Normalize to 0-1
|
||||
|
||||
# GPS data
|
||||
observation[OBS_GPS_LATITUDE] = robot_data["latitude"]
|
||||
observation[OBS_GPS_LONGITUDE] = robot_data["longitude"]
|
||||
observation[OBS_GPS_SIGNAL] = robot_data["gps_signal"] / 100.0 # Normalize percentage to 0-1
|
||||
|
||||
# Sensors
|
||||
observation[OBS_SIGNAL_LEVEL] = robot_data["signal_level"] / 5.0 # Normalize 0-5 to 0-1
|
||||
observation[OBS_VIBRATION] = robot_data["vibration"]
|
||||
observation[OBS_LAMP_STATE] = float(robot_data["lamp"]) # 0 or 1
|
||||
|
||||
return observation
|
||||
|
||||
def send_action(self, action: dict[str, Any]) -> dict[str, Any]:
|
||||
"""Send action to robot via SDK.
|
||||
|
||||
Args:
|
||||
action: Action dict with keys:
|
||||
- linear.vel: Target linear velocity (-1 to 1)
|
||||
- angular.vel: Target angular velocity (-1 to 1)
|
||||
|
||||
Returns:
|
||||
dict: The action that was sent (matches action_features keys)
|
||||
|
||||
Raises:
|
||||
DeviceNotConnectedError: If robot is not connected
|
||||
|
||||
Note:
|
||||
Actions are sent to SDK via POST /control endpoint.
|
||||
SDK expects commands in range [-1, 1].
|
||||
"""
|
||||
if not self._is_connected:
|
||||
raise DeviceNotConnectedError(f"{self.name} is not connected")
|
||||
|
||||
# Extract action values and convert to float
|
||||
linear = float(action.get(ACTION_LINEAR_VEL, 0.0))
|
||||
angular = float(action.get(ACTION_ANGULAR_VEL, 0.0))
|
||||
|
||||
# Send command to SDK
|
||||
try:
|
||||
self._send_command_to_sdk(linear, angular)
|
||||
except Exception as e:
|
||||
logger.error(f"Error sending action: {e}")
|
||||
|
||||
# Return action in format matching action_features
|
||||
return {
|
||||
ACTION_LINEAR_VEL: linear,
|
||||
ACTION_ANGULAR_VEL: angular,
|
||||
}
|
||||
|
||||
def disconnect(self) -> None:
|
||||
"""Disconnect from robot.
|
||||
|
||||
Stops the robot and closes connection to SDK.
|
||||
|
||||
Raises:
|
||||
DeviceNotConnectedError: If robot is not connected
|
||||
"""
|
||||
if not self._is_connected:
|
||||
raise DeviceNotConnectedError(f"{self.name} is not connected")
|
||||
|
||||
# Stop the robot before disconnecting
|
||||
try:
|
||||
self._send_command_to_sdk(0.0, 0.0)
|
||||
except Exception as e:
|
||||
logger.warning(f"Failed to stop robot during disconnect: {e}")
|
||||
|
||||
self._is_connected = False
|
||||
logger.info(f"{self.name} disconnected")
|
||||
|
||||
# Private helper methods for SDK communication
|
||||
|
||||
def _get_camera_frames(self) -> dict[str, np.ndarray]:
|
||||
"""Get camera frames from SDK using v2 endpoints with caching fallback.
|
||||
|
||||
Returns:
|
||||
dict: Dictionary with 'front' and 'rear' keys containing:
|
||||
- Current frame (if request succeeds)
|
||||
- Cached frame (if request fails but cache exists)
|
||||
- Zero array (if request fails and no cache exists yet)
|
||||
|
||||
Note:
|
||||
Uses /v2/front and /v2/rear endpoints which are 15x faster than /screenshot.
|
||||
Images are base64 encoded, resized to 640x480, and converted from BGR to RGB.
|
||||
If request fails, returns the last successfully retrieved frame (cached).
|
||||
"""
|
||||
frames = {}
|
||||
|
||||
# Get front camera
|
||||
try:
|
||||
response = requests.get(f"{self.sdk_base_url}/v2/front", timeout=2.0)
|
||||
if response.status_code == 200:
|
||||
data = response.json()
|
||||
if "front_frame" in data and data["front_frame"]:
|
||||
front_img = self._decode_base64_image(data["front_frame"])
|
||||
if front_img is not None:
|
||||
# Resize and convert BGR to RGB
|
||||
front_img = cv2.resize(front_img, (640, 480))
|
||||
front_rgb = cv2.cvtColor(front_img, cv2.COLOR_BGR2RGB)
|
||||
frames["front"] = front_rgb
|
||||
# Cache the successful frame
|
||||
self._last_front_frame = front_rgb
|
||||
        except Exception as e:
            logger.warning(f"Error fetching front camera: {e}")

        # Fallback: use cache or zero array
        if "front" not in frames:
            if self._last_front_frame is not None:
                frames["front"] = self._last_front_frame
            else:
                frames["front"] = np.zeros((480, 640, 3), dtype=np.uint8)

        # Get rear camera
        try:
            response = requests.get(f"{self.sdk_base_url}/v2/rear", timeout=2.0)
            if response.status_code == 200:
                data = response.json()
                if "rear_frame" in data and data["rear_frame"]:
                    rear_img = self._decode_base64_image(data["rear_frame"])
                    if rear_img is not None:
                        # Resize and convert BGR to RGB
                        rear_img = cv2.resize(rear_img, (640, 480))
                        rear_rgb = cv2.cvtColor(rear_img, cv2.COLOR_BGR2RGB)
                        frames["rear"] = rear_rgb
                        # Cache the successful frame
                        self._last_rear_frame = rear_rgb
        except Exception as e:
            logger.warning(f"Error fetching rear camera: {e}")

        # Fallback: use cache or zero array
        if "rear" not in frames:
            if self._last_rear_frame is not None:
                frames["rear"] = self._last_rear_frame
            else:
                frames["rear"] = np.zeros((480, 640, 3), dtype=np.uint8)

        return frames

    def _decode_base64_image(self, base64_string: str) -> np.ndarray | None:
        """Decode base64 string to image.

        Args:
            base64_string: Base64 encoded image string

        Returns:
            np.ndarray: Decoded image in BGR format (OpenCV default), or None if decoding fails
        """
        try:
            img_bytes = base64.b64decode(base64_string)
            nparr = np.frombuffer(img_bytes, np.uint8)
            img = cv2.imdecode(nparr, cv2.IMREAD_COLOR)
            return img  # Return in BGR format (OpenCV default)
        except Exception as e:
            logger.error(f"Error decoding image: {e}")
            return None

    def _get_robot_data(self) -> dict:
        """Get robot telemetry data from SDK.

        Returns:
            dict: Robot telemetry data including battery, speed, orientation, GPS, etc:
                - Current data (if request succeeds)
                - Cached data (if request fails but cache exists)
                - Default values (if request fails and no cache exists yet)

        Note:
            Uses /data endpoint which provides comprehensive robot state.
            If request fails, returns the last successfully retrieved data (cached).
        """
        try:
            response = requests.get(f"{self.sdk_base_url}/data", timeout=2.0)
            if response.status_code == 200:
                data = response.json()
                # Cache the successful data
                self._last_robot_data = data
                return data
        except Exception as e:
            logger.warning(f"Error fetching robot data: {e}")

        # Fallback: use cache or default values
        if self._last_robot_data is not None:
            return self._last_robot_data
        else:
            # Return dict with default values (used only on first failure before any cache exists)
            return {
                "speed": 0,
                "battery": 0,
                "orientation": 0,
                "latitude": 0.0,
                "longitude": 0.0,
                "gps_signal": 0,
                "signal_level": 0,
                "vibration": 0.0,
                "lamp": 0,
            }

    def _send_command_to_sdk(self, linear: float, angular: float, lamp: int = 0) -> bool:
        """Send control command to SDK.

        Args:
            linear: Linear velocity command (-1 to 1)
            angular: Angular velocity command (-1 to 1)
            lamp: Lamp control (0=off, 1=on)

        Returns:
            bool: True if command sent successfully, False otherwise

        Note:
            Uses POST /control endpoint. Commands are sent as JSON payload.
        """
        try:
            payload = {
                "command": {
                    "linear": linear,
                    "angular": angular,
                    "lamp": lamp,
                }
            }

            response = requests.post(
                f"{self.sdk_base_url}/control",
                json=payload,
                timeout=1.0,
            )

            return response.status_code == 200
        except Exception as e:
            logger.error(f"Error sending command: {e}")
            return False
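For reference, the /control request built above can be exercised on its own. The snippet below is an illustrative sketch only: the base URL is an assumption and should match whatever address the robot's SDK server exposes as sdk_base_url.

# Illustrative only: mirrors the payload built by _send_command_to_sdk.
# The base URL below is hypothetical; substitute the robot's actual SDK server address.
import requests

sdk_base_url = "http://192.168.1.10:8080"  # hypothetical SDK server
payload = {"command": {"linear": 0.2, "angular": 0.0, "lamp": 1}}
response = requests.post(f"{sdk_base_url}/control", json=payload, timeout=1.0)
print("command accepted:", response.status_code == 200)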
@@ -0,0 +1,18 @@
#!/usr/bin/env python

# Copyright 2025 The HuggingFace Inc. team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

from .config_unitree_g1 import UnitreeG1Config
from .unitree_g1 import UnitreeG1
@@ -0,0 +1,55 @@
#!/usr/bin/env python

# Copyright 2025 The HuggingFace Inc. team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

from dataclasses import dataclass, field

from ..config import RobotConfig

_GAINS: dict[str, dict[str, list[float]]] = {
    "left_leg": {
        "kp": [150, 150, 150, 300, 40, 40],
        "kd": [2, 2, 2, 4, 2, 2],
    },  # pitch, roll, yaw, knee, ankle_pitch, ankle_roll
    "right_leg": {"kp": [150, 150, 150, 300, 40, 40], "kd": [2, 2, 2, 4, 2, 2]},
    "waist": {"kp": [250, 250, 250], "kd": [5, 5, 5]},  # yaw, roll, pitch
    "left_arm": {"kp": [80, 80, 80, 80], "kd": [3, 3, 3, 3]},  # shoulder_pitch/roll/yaw, elbow
    "left_wrist": {"kp": [40, 40, 40], "kd": [1.5, 1.5, 1.5]},  # roll, pitch, yaw
    "right_arm": {"kp": [80, 80, 80, 80], "kd": [3, 3, 3, 3]},
    "right_wrist": {"kp": [40, 40, 40], "kd": [1.5, 1.5, 1.5]},
    "other": {"kp": [80, 80, 80, 80, 80, 80], "kd": [3, 3, 3, 3, 3, 3]},
}


def _build_gains() -> tuple[list[float], list[float]]:
    """Build kp and kd lists from body-part groupings."""
    kp = [v for g in _GAINS.values() for v in g["kp"]]
    kd = [v for g in _GAINS.values() for v in g["kd"]]
    return kp, kd


_DEFAULT_KP, _DEFAULT_KD = _build_gains()


@RobotConfig.register_subclass("unitree_g1")
@dataclass
class UnitreeG1Config(RobotConfig):
    kp: list[float] = field(default_factory=lambda: _DEFAULT_KP.copy())
    kd: list[float] = field(default_factory=lambda: _DEFAULT_KD.copy())

    control_dt: float = 1.0 / 250.0  # 250Hz

    # socket config for ZMQ bridge
    robot_ip: str = "192.168.123.164"
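A quick sanity check of the flattened gains (a sketch, not part of the patch): the per-group kp/kd lists above should concatenate to exactly one value per motor, matching NUM_MOTORS = 35 in g1_utils.

# Sketch: verify the default gain vectors cover all 35 motors.
from lerobot.robots.unitree_g1.config_unitree_g1 import _GAINS, UnitreeG1Config

cfg = UnitreeG1Config()
expected = sum(len(group["kp"]) for group in _GAINS.values())
assert len(cfg.kp) == len(cfg.kd) == expected == 35
print(f"{expected} kp/kd pairs, one per motor")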
@@ -0,0 +1,89 @@
#!/usr/bin/env python

# Copyright 2025 The HuggingFace Inc. team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

from enum import IntEnum

# ruff: noqa: N801, N815

NUM_MOTORS = 35


class G1_29_JointArmIndex(IntEnum):
    # Left arm
    kLeftShoulderPitch = 15
    kLeftShoulderRoll = 16
    kLeftShoulderYaw = 17
    kLeftElbow = 18
    kLeftWristRoll = 19
    kLeftWristPitch = 20
    kLeftWristyaw = 21

    # Right arm
    kRightShoulderPitch = 22
    kRightShoulderRoll = 23
    kRightShoulderYaw = 24
    kRightElbow = 25
    kRightWristRoll = 26
    kRightWristPitch = 27
    kRightWristYaw = 28


class G1_29_JointIndex(IntEnum):
    # Left leg
    kLeftHipPitch = 0
    kLeftHipRoll = 1
    kLeftHipYaw = 2
    kLeftKnee = 3
    kLeftAnklePitch = 4
    kLeftAnkleRoll = 5

    # Right leg
    kRightHipPitch = 6
    kRightHipRoll = 7
    kRightHipYaw = 8
    kRightKnee = 9
    kRightAnklePitch = 10
    kRightAnkleRoll = 11

    kWaistYaw = 12
    kWaistRoll = 13
    kWaistPitch = 14

    # Left arm
    kLeftShoulderPitch = 15
    kLeftShoulderRoll = 16
    kLeftShoulderYaw = 17
    kLeftElbow = 18
    kLeftWristRoll = 19
    kLeftWristPitch = 20
    kLeftWristyaw = 21

    # Right arm
    kRightShoulderPitch = 22
    kRightShoulderRoll = 23
    kRightShoulderYaw = 24
    kRightElbow = 25
    kRightWristRoll = 26
    kRightWristPitch = 27
    kRightWristYaw = 28

    # not used
    kNotUsedJoint0 = 29
    kNotUsedJoint1 = 30
    kNotUsedJoint2 = 31
    kNotUsedJoint3 = 32
    kNotUsedJoint4 = 33
    kNotUsedJoint5 = 34
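Since the two enums share the same index space, the arm-only enum can be used to slice arm joints out of a full 35-motor vector; a small illustrative sketch:

# Sketch: select the 14 arm joints (indices 15-28) from a full 35-motor state vector.
import numpy as np

from lerobot.robots.unitree_g1.g1_utils import NUM_MOTORS, G1_29_JointArmIndex

q_all = np.zeros(NUM_MOTORS)
arm_ids = [joint.value for joint in G1_29_JointArmIndex]
q_arms = q_all[arm_ids]
print(len(arm_ids), q_arms.shape)  # 14 (7 per arm), (14,)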
@@ -0,0 +1,212 @@
#!/usr/bin/env python3

# Copyright 2025 The HuggingFace Inc. team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

"""
DDS-to-ZMQ bridge server for Unitree G1 robot.

This server runs on the robot and forwards:
- Robot state (LowState) from DDS to ZMQ (for remote clients)
- Robot commands (LowCmd) from ZMQ to DDS (from remote clients)

Uses JSON for secure serialization instead of pickle.
"""

import base64
import contextlib
import json
import threading
import time
from typing import Any

import zmq
from unitree_sdk2py.comm.motion_switcher.motion_switcher_client import MotionSwitcherClient
from unitree_sdk2py.core.channel import ChannelFactoryInitialize, ChannelPublisher, ChannelSubscriber
from unitree_sdk2py.idl.default import unitree_hg_msg_dds__LowCmd_
from unitree_sdk2py.idl.unitree_hg.msg.dds_ import LowCmd_ as hg_LowCmd, LowState_ as hg_LowState
from unitree_sdk2py.utils.crc import CRC

# DDS topic names follow Unitree SDK naming conventions
# ruff: noqa: N816
kTopicLowCommand_Debug = "rt/lowcmd"  # action to robot
kTopicLowState = "rt/lowstate"  # observation from robot

LOWCMD_PORT = 6000
LOWSTATE_PORT = 6001
NUM_MOTORS = 35


def lowstate_to_dict(msg: hg_LowState) -> dict[str, Any]:
    """Convert LowState SDK message to a JSON-serializable dictionary."""
    motor_states = []
    for i in range(NUM_MOTORS):
        temp = msg.motor_state[i].temperature
        avg_temp = float(sum(temp) / len(temp)) if isinstance(temp, list) else float(temp)
        motor_states.append(
            {
                "q": float(msg.motor_state[i].q),
                "dq": float(msg.motor_state[i].dq),
                "tau_est": float(msg.motor_state[i].tau_est),
                "temperature": avg_temp,
            }
        )

    return {
        "motor_state": motor_states,
        "imu_state": {
            "quaternion": [float(x) for x in msg.imu_state.quaternion],
            "gyroscope": [float(x) for x in msg.imu_state.gyroscope],
            "accelerometer": [float(x) for x in msg.imu_state.accelerometer],
            "rpy": [float(x) for x in msg.imu_state.rpy],
            "temperature": float(msg.imu_state.temperature),
        },
        # Encode bytes as base64 for JSON compatibility
        "wireless_remote": base64.b64encode(bytes(msg.wireless_remote)).decode("ascii"),
        "mode_machine": int(msg.mode_machine),
    }


def dict_to_lowcmd(data: dict[str, Any]) -> hg_LowCmd:
    """Convert dictionary back to LowCmd SDK message."""
    cmd = unitree_hg_msg_dds__LowCmd_()
    cmd.mode_pr = data.get("mode_pr", 0)
    cmd.mode_machine = data.get("mode_machine", 0)

    for i, motor_data in enumerate(data.get("motor_cmd", [])):
        cmd.motor_cmd[i].mode = motor_data.get("mode", 0)
        cmd.motor_cmd[i].q = motor_data.get("q", 0.0)
        cmd.motor_cmd[i].dq = motor_data.get("dq", 0.0)
        cmd.motor_cmd[i].kp = motor_data.get("kp", 0.0)
        cmd.motor_cmd[i].kd = motor_data.get("kd", 0.0)
        cmd.motor_cmd[i].tau = motor_data.get("tau", 0.0)

    return cmd


def state_forward_loop(
    lowstate_sub: ChannelSubscriber,
    lowstate_sock: zmq.Socket,
    state_period: float,
) -> None:
    """Read observation from DDS and forward to ZMQ clients."""
    last_state_time = 0.0

    while True:
        # read from DDS
        msg = lowstate_sub.Read()
        if msg is None:
            continue

        now = time.time()
        # optional downsampling (if robot dds rate > state_period)
        if now - last_state_time >= state_period:
            # Convert to dict and serialize with JSON
            state_dict = lowstate_to_dict(msg)
            payload = json.dumps({"topic": kTopicLowState, "data": state_dict}).encode("utf-8")
            # if no subscribers / tx buffer full, just drop
            with contextlib.suppress(zmq.Again):
                lowstate_sock.send(payload, zmq.NOBLOCK)
            last_state_time = now


def cmd_forward_loop(
    lowcmd_sock: zmq.Socket,
    lowcmd_pub_debug: ChannelPublisher,
    crc: CRC,
) -> None:
    """Receive commands from ZMQ and forward to DDS."""
    while True:
        payload = lowcmd_sock.recv()
        msg_dict = json.loads(payload.decode("utf-8"))

        topic = msg_dict.get("topic", "")
        cmd_data = msg_dict.get("data", {})

        # Reconstruct LowCmd object from dict
        cmd = dict_to_lowcmd(cmd_data)

        # recompute crc
        cmd.crc = crc.Crc(cmd)

        if topic == kTopicLowCommand_Debug:
            lowcmd_pub_debug.Write(cmd)


def main() -> None:
    """Main entry point for the robot server bridge."""
    # initialize DDS
    ChannelFactoryInitialize(0)

    # stop all active publishers on the robot
    msc = MotionSwitcherClient()
    msc.SetTimeout(5.0)
    msc.Init()

    status, result = msc.CheckMode()
    while result is not None and "name" in result and result["name"]:
        msc.ReleaseMode()
        status, result = msc.CheckMode()
        time.sleep(1.0)

    crc = CRC()

    # initialize DDS publisher
    lowcmd_pub_debug = ChannelPublisher(kTopicLowCommand_Debug, hg_LowCmd)
    lowcmd_pub_debug.Init()

    # initialize DDS subscriber
    lowstate_sub = ChannelSubscriber(kTopicLowState, hg_LowState)
    lowstate_sub.Init()

    # initialize ZMQ
    ctx = zmq.Context.instance()

    # receive commands from remote client
    lowcmd_sock = ctx.socket(zmq.PULL)
    lowcmd_sock.bind(f"tcp://0.0.0.0:{LOWCMD_PORT}")

    # publish state to remote clients
    lowstate_sock = ctx.socket(zmq.PUB)
    lowstate_sock.bind(f"tcp://0.0.0.0:{LOWSTATE_PORT}")

    state_period = 0.002  # ~500 hz

    # start observation forwarding thread
    t_state = threading.Thread(
        target=state_forward_loop,
        args=(lowstate_sub, lowstate_sock, state_period),
        daemon=True,
    )
    t_state.start()

    # start action forwarding thread
    t_cmd = threading.Thread(
        target=cmd_forward_loop,
        args=(lowcmd_sock, lowcmd_pub_debug, crc),
        daemon=True,
    )
    t_cmd.start()

    print("bridge running (lowstate -> zmq, lowcmd -> dds)")
    # keep main thread alive so daemon threads don't exit
    try:
        while True:
            time.sleep(1.0)
    except KeyboardInterrupt:
        print("shutting down bridge...")


if __name__ == "__main__":
    main()
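To smoke-test the bridge from a remote machine, a plain ZMQ subscriber is enough. A minimal sketch, assuming the robot is reachable at the default robot_ip from UnitreeG1Config:

# Minimal sketch: read one LowState message published by the bridge.
import json

import zmq

ctx = zmq.Context.instance()
sub = ctx.socket(zmq.SUB)
sub.connect("tcp://192.168.123.164:6001")  # default robot_ip, LOWSTATE_PORT
sub.setsockopt_string(zmq.SUBSCRIBE, "")

msg = json.loads(sub.recv().decode("utf-8"))
print(msg["topic"])                     # "rt/lowstate"
print(msg["data"]["imu_state"]["rpy"])  # [roll, pitch, yaw]
print(len(msg["data"]["motor_state"]))  # 35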
@@ -0,0 +1,267 @@
|
||||
#!/usr/bin/env python
|
||||
|
||||
# Copyright 2025 The HuggingFace Inc. team. All rights reserved.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
import logging
|
||||
import struct
|
||||
import threading
|
||||
import time
|
||||
from dataclasses import dataclass, field
|
||||
from functools import cached_property
|
||||
from typing import Any
|
||||
|
||||
import numpy as np
|
||||
from unitree_sdk2py.idl.default import unitree_hg_msg_dds__LowCmd_
|
||||
from unitree_sdk2py.idl.unitree_hg.msg.dds_ import (
|
||||
LowCmd_ as hg_LowCmd,
|
||||
LowState_ as hg_LowState,
|
||||
)
|
||||
from unitree_sdk2py.utils.crc import CRC
|
||||
|
||||
from lerobot.robots.unitree_g1.g1_utils import G1_29_JointIndex
|
||||
from lerobot.robots.unitree_g1.unitree_sdk2_socket import (
|
||||
ChannelFactoryInitialize,
|
||||
ChannelPublisher,
|
||||
ChannelSubscriber,
|
||||
)
|
||||
|
||||
from ..robot import Robot
|
||||
from .config_unitree_g1 import UnitreeG1Config
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
# DDS topic names follow Unitree SDK naming conventions
|
||||
# ruff: noqa: N816
|
||||
kTopicLowCommand_Debug = "rt/lowcmd"
|
||||
kTopicLowState = "rt/lowstate"
|
||||
|
||||
G1_29_Num_Motors = 35
|
||||
G1_23_Num_Motors = 35
|
||||
H1_2_Num_Motors = 35
|
||||
H1_Num_Motors = 20
|
||||
|
||||
|
||||
@dataclass
|
||||
class MotorState:
|
||||
q: float | None = None # position
|
||||
dq: float | None = None # velocity
|
||||
tau_est: float | None = None # estimated torque
|
||||
temperature: float | None = None # motor temperature
|
||||
|
||||
|
||||
@dataclass
|
||||
class IMUState:
|
||||
quaternion: np.ndarray | None = None # [w, x, y, z]
|
||||
gyroscope: np.ndarray | None = None # [x, y, z] angular velocity (rad/s)
|
||||
accelerometer: np.ndarray | None = None # [x, y, z] linear acceleration (m/s²)
|
||||
rpy: np.ndarray | None = None # [roll, pitch, yaw] (rad)
|
||||
temperature: float | None = None # IMU temperature
|
||||
|
||||
|
||||
# g1 observation class
|
||||
@dataclass
|
||||
class G1_29_LowState: # noqa: N801
|
||||
motor_state: list[MotorState] = field(
|
||||
default_factory=lambda: [MotorState() for _ in range(G1_29_Num_Motors)]
|
||||
)
|
||||
imu_state: IMUState = field(default_factory=IMUState)
|
||||
wireless_remote: Any = None # Raw wireless remote data
|
||||
mode_machine: int = 0 # Robot mode
|
||||
|
||||
|
||||
class DataBuffer:
|
||||
def __init__(self):
|
||||
self.data = None
|
||||
self.lock = threading.Lock()
|
||||
|
||||
def get_data(self):
|
||||
with self.lock:
|
||||
return self.data
|
||||
|
||||
def set_data(self, data):
|
||||
with self.lock:
|
||||
self.data = data
|
||||
|
||||
|
||||
class UnitreeG1(Robot):
|
||||
config_class = UnitreeG1Config
|
||||
name = "unitree_g1"
|
||||
|
||||
# unitree remote controller
|
||||
class RemoteController:
|
||||
def __init__(self):
|
||||
self.lx = 0
|
||||
self.ly = 0
|
||||
self.rx = 0
|
||||
self.ry = 0
|
||||
self.button = [0] * 16
|
||||
|
||||
def set(self, data):
|
||||
# wireless_remote
|
||||
keys = struct.unpack("H", data[2:4])[0]
|
||||
for i in range(16):
|
||||
self.button[i] = (keys & (1 << i)) >> i
|
||||
self.lx = struct.unpack("f", data[4:8])[0]
|
||||
self.rx = struct.unpack("f", data[8:12])[0]
|
||||
self.ry = struct.unpack("f", data[12:16])[0]
|
||||
self.ly = struct.unpack("f", data[20:24])[0]
|
||||
|
||||
def __init__(self, config: UnitreeG1Config):
|
||||
super().__init__(config)
|
||||
|
||||
logger.info("Initialize UnitreeG1...")
|
||||
|
||||
self.config = config
|
||||
|
||||
self.control_dt = config.control_dt
|
||||
|
||||
# connect robot
|
||||
self.connect()
|
||||
|
||||
# initialize direct motor control interface
|
||||
self.lowcmd_publisher = ChannelPublisher(kTopicLowCommand_Debug, hg_LowCmd)
|
||||
self.lowcmd_publisher.Init()
|
||||
self.lowstate_subscriber = ChannelSubscriber(kTopicLowState, hg_LowState)
|
||||
self.lowstate_subscriber.Init()
|
||||
self.lowstate_buffer = DataBuffer()
|
||||
|
||||
# initialize subscribe thread to read robot state
|
||||
self.subscribe_thread = threading.Thread(target=self._subscribe_motor_state)
|
||||
self.subscribe_thread.daemon = True
|
||||
self.subscribe_thread.start()
|
||||
|
||||
while not self.is_connected:
|
||||
time.sleep(0.1)
|
||||
|
||||
# initialize hg's lowcmd msg
|
||||
self.crc = CRC()
|
||||
self.msg = unitree_hg_msg_dds__LowCmd_()
|
||||
self.msg.mode_pr = 0
|
||||
|
||||
# Wait for first state message to arrive
|
||||
lowstate = None
|
||||
while lowstate is None:
|
||||
lowstate = self.lowstate_buffer.get_data()
|
||||
if lowstate is None:
|
||||
time.sleep(0.01)
|
||||
logger.warning("[UnitreeG1] Waiting for robot state...")
|
||||
logger.warning("[UnitreeG1] Connected to robot.")
|
||||
self.msg.mode_machine = lowstate.mode_machine
|
||||
|
||||
# initialize all motors with unified kp/kd from config
|
||||
self.kp = np.array(config.kp, dtype=np.float32)
|
||||
self.kd = np.array(config.kd, dtype=np.float32)
|
||||
|
||||
for id in G1_29_JointIndex:
|
||||
self.msg.motor_cmd[id].mode = 1
|
||||
self.msg.motor_cmd[id].kp = self.kp[id.value]
|
||||
self.msg.motor_cmd[id].kd = self.kd[id.value]
|
||||
self.msg.motor_cmd[id].q = lowstate.motor_state[id.value].q
|
||||
|
||||
# Initialize remote controller
|
||||
self.remote_controller = self.RemoteController()
|
||||
|
||||
def _subscribe_motor_state(self): # polls robot state @ 250Hz
|
||||
while True:
|
||||
start_time = time.time()
|
||||
msg = self.lowstate_subscriber.Read()
|
||||
if msg is not None:
|
||||
lowstate = G1_29_LowState()
|
||||
|
||||
# Capture motor states
|
||||
for id in range(G1_29_Num_Motors):
|
||||
lowstate.motor_state[id].q = msg.motor_state[id].q
|
||||
lowstate.motor_state[id].dq = msg.motor_state[id].dq
|
||||
lowstate.motor_state[id].tau_est = msg.motor_state[id].tau_est
|
||||
lowstate.motor_state[id].temperature = msg.motor_state[id].temperature
|
||||
|
||||
# Capture IMU state
|
||||
lowstate.imu_state.quaternion = list(msg.imu_state.quaternion)
|
||||
lowstate.imu_state.gyroscope = list(msg.imu_state.gyroscope)
|
||||
lowstate.imu_state.accelerometer = list(msg.imu_state.accelerometer)
|
||||
lowstate.imu_state.rpy = list(msg.imu_state.rpy)
|
||||
lowstate.imu_state.temperature = msg.imu_state.temperature
|
||||
|
||||
# Capture wireless remote data
|
||||
lowstate.wireless_remote = msg.wireless_remote
|
||||
|
||||
# Capture mode_machine
|
||||
lowstate.mode_machine = msg.mode_machine
|
||||
|
||||
self.lowstate_buffer.set_data(lowstate)
|
||||
|
||||
current_time = time.time()
|
||||
all_t_elapsed = current_time - start_time
|
||||
sleep_time = max(0, (self.control_dt - all_t_elapsed)) # maintain constant control dt
|
||||
time.sleep(sleep_time)
|
||||
|
||||
@cached_property
|
||||
def action_features(self) -> dict[str, type]:
|
||||
return {f"{G1_29_JointIndex(motor).name}.pos": float for motor in G1_29_JointIndex}
|
||||
|
||||
def calibrate(self) -> None: # robot is already calibrated
|
||||
pass
|
||||
|
||||
def configure(self) -> None:
|
||||
pass
|
||||
|
||||
def connect(self, calibrate: bool = True) -> None: # connect to DDS
|
||||
ChannelFactoryInitialize(0)
|
||||
|
||||
def disconnect(self):
|
||||
pass
|
||||
|
||||
def get_observation(self) -> dict[str, Any]:
|
||||
return self.lowstate_buffer.get_data()
|
||||
|
||||
@property
|
||||
def is_calibrated(self) -> bool:
|
||||
return True
|
||||
|
||||
@property
|
||||
def is_connected(self) -> bool:
|
||||
return self.lowstate_buffer.get_data() is not None
|
||||
|
||||
@property
|
||||
def _motors_ft(self) -> dict[str, type]:
|
||||
return {f"{G1_29_JointIndex(motor).name}.pos": float for motor in G1_29_JointIndex}
|
||||
|
||||
@property
|
||||
def _cameras_ft(self) -> dict[str, tuple]:
|
||||
return {
|
||||
cam: (self.config.cameras[cam].height, self.config.cameras[cam].width, 3) for cam in self.cameras
|
||||
}
|
||||
|
||||
@cached_property
|
||||
def observation_features(self) -> dict[str, type | tuple]:
|
||||
return {**self._motors_ft, **self._cameras_ft}
|
||||
|
||||
def send_action(self, action: dict[str, Any]) -> dict[str, Any]:
|
||||
self.msg.crc = self.crc.Crc(action)
|
||||
self.lowcmd_publisher.Write(action)
|
||||
return action
|
||||
|
||||
def get_gravity_orientation(self, quaternion): # get gravity orientation from quaternion
|
||||
"""Get gravity orientation from quaternion."""
|
||||
qw = quaternion[0]
|
||||
qx = quaternion[1]
|
||||
qy = quaternion[2]
|
||||
qz = quaternion[3]
|
||||
|
||||
gravity_orientation = np.zeros(3)
|
||||
gravity_orientation[0] = 2 * (-qz * qx + qw * qy)
|
||||
gravity_orientation[1] = -2 * (qz * qy + qw * qx)
|
||||
gravity_orientation[2] = 1 - 2 * (qw * qw + qz * qz)
|
||||
return gravity_orientation
|
||||
@@ -0,0 +1,168 @@
#!/usr/bin/env python

# Copyright 2025 The HuggingFace Inc. team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import base64
import json
from typing import Any

import zmq

from lerobot.robots.unitree_g1.config_unitree_g1 import UnitreeG1Config

_ctx: zmq.Context | None = None
_lowcmd_sock: zmq.Socket | None = None
_lowstate_sock: zmq.Socket | None = None

LOWCMD_PORT = 6000
LOWSTATE_PORT = 6001

# DDS topic names follow Unitree SDK naming conventions
# ruff: noqa: N816
kTopicLowCommand_Debug = "rt/lowcmd"


class LowStateMsg:
    """
    Wrapper class that mimics the Unitree SDK LowState_ message structure.

    Reconstructs the message from deserialized JSON data to maintain
    compatibility with existing code that expects SDK message objects.
    """

    class MotorState:
        """Motor state data for a single joint."""

        def __init__(self, data: dict[str, Any]) -> None:
            self.q: float = data.get("q", 0.0)
            self.dq: float = data.get("dq", 0.0)
            self.tau_est: float = data.get("tau_est", 0.0)
            self.temperature: float = data.get("temperature", 0.0)

    class IMUState:
        """IMU sensor data."""

        def __init__(self, data: dict[str, Any]) -> None:
            self.quaternion: list[float] = data.get("quaternion", [1.0, 0.0, 0.0, 0.0])
            self.gyroscope: list[float] = data.get("gyroscope", [0.0, 0.0, 0.0])
            self.accelerometer: list[float] = data.get("accelerometer", [0.0, 0.0, 0.0])
            self.rpy: list[float] = data.get("rpy", [0.0, 0.0, 0.0])
            self.temperature: float = data.get("temperature", 0.0)

    def __init__(self, data: dict[str, Any]) -> None:
        """Initialize from deserialized JSON data."""
        self.motor_state = [self.MotorState(m) for m in data.get("motor_state", [])]
        self.imu_state = self.IMUState(data.get("imu_state", {}))
        # Decode base64-encoded wireless_remote bytes
        wireless_b64 = data.get("wireless_remote", "")
        self.wireless_remote: bytes = base64.b64decode(wireless_b64) if wireless_b64 else b""
        self.mode_machine: int = data.get("mode_machine", 0)


def lowcmd_to_dict(topic: str, msg: Any) -> dict[str, Any]:
    """Convert LowCmd message to a JSON-serializable dictionary."""
    motor_cmds = []
    # Iterate over all motor commands in the message
    for i in range(len(msg.motor_cmd)):
        motor_cmds.append(
            {
                "mode": int(msg.motor_cmd[i].mode),
                "q": float(msg.motor_cmd[i].q),
                "dq": float(msg.motor_cmd[i].dq),
                "kp": float(msg.motor_cmd[i].kp),
                "kd": float(msg.motor_cmd[i].kd),
                "tau": float(msg.motor_cmd[i].tau),
            }
        )

    return {
        "topic": topic,
        "data": {
            "mode_pr": int(msg.mode_pr),
            "mode_machine": int(msg.mode_machine),
            "motor_cmd": motor_cmds,
        },
    }


def ChannelFactoryInitialize(*args: Any, **kwargs: Any) -> None:  # noqa: N802
    """
    Initialize ZMQ sockets for robot communication.

    This function mimics the Unitree SDK's ChannelFactoryInitialize but uses
    ZMQ sockets to connect to the robot server bridge instead of DDS.
    """
    global _ctx, _lowcmd_sock, _lowstate_sock

    # read socket config
    config = UnitreeG1Config()
    robot_ip = config.robot_ip

    ctx = zmq.Context.instance()
    _ctx = ctx

    # lowcmd: send robot commands
    lowcmd_sock = ctx.socket(zmq.PUSH)
    lowcmd_sock.setsockopt(zmq.CONFLATE, 1)  # keep only last message
    lowcmd_sock.connect(f"tcp://{robot_ip}:{LOWCMD_PORT}")
    _lowcmd_sock = lowcmd_sock

    # lowstate: receive robot observations
    lowstate_sock = ctx.socket(zmq.SUB)
    lowstate_sock.setsockopt(zmq.CONFLATE, 1)  # keep only last message
    lowstate_sock.connect(f"tcp://{robot_ip}:{LOWSTATE_PORT}")
    lowstate_sock.setsockopt_string(zmq.SUBSCRIBE, "")
    _lowstate_sock = lowstate_sock


class ChannelPublisher:
    """ZMQ-based publisher that sends commands to the robot server."""

    def __init__(self, topic: str, msg_type: type) -> None:
        self.topic = topic
        self.msg_type = msg_type

    def Init(self) -> None:  # noqa: N802
        """Initialize the publisher (no-op for ZMQ)."""
        pass

    def Write(self, msg: Any) -> None:  # noqa: N802
        """Serialize and send a command message to the robot."""
        if _lowcmd_sock is None:
            raise RuntimeError("ChannelFactoryInitialize must be called first")

        payload = json.dumps(lowcmd_to_dict(self.topic, msg)).encode("utf-8")
        _lowcmd_sock.send(payload)


class ChannelSubscriber:
    """ZMQ-based subscriber that receives state from the robot server."""

    def __init__(self, topic: str, msg_type: type) -> None:
        self.topic = topic
        self.msg_type = msg_type

    def Init(self) -> None:  # noqa: N802
        """Initialize the subscriber (no-op for ZMQ)."""
        pass

    def Read(self) -> LowStateMsg:  # noqa: N802
        """Receive and deserialize a state message from the robot."""
        if _lowstate_sock is None:
            raise RuntimeError("ChannelFactoryInitialize must be called first")

        payload = _lowstate_sock.recv()
        msg_dict = json.loads(payload.decode("utf-8"))
        return LowStateMsg(msg_dict.get("data", {}))
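For reference, one command message on the wire is just the JSON produced by lowcmd_to_dict above and consumed by dict_to_lowcmd in the bridge. An illustrative example with only two of the 35 motor entries shown (values hypothetical):

# Illustrative wire format for a single LowCmd message (truncated to 2 of 35 motor entries).
example_lowcmd_wire = {
    "topic": "rt/lowcmd",
    "data": {
        "mode_pr": 0,
        "mode_machine": 5,  # hypothetical value; echoed back from the robot's LowState
        "motor_cmd": [
            {"mode": 1, "q": 0.00, "dq": 0.0, "kp": 150.0, "kd": 2.0, "tau": 0.0},
            {"mode": 1, "q": 0.05, "dq": 0.0, "kp": 150.0, "kd": 2.0, "tau": 0.0},
        ],
    },
}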
@@ -52,7 +52,7 @@ from lerobot.teleoperators import ( # noqa: F401
    so100_leader,
    so101_leader,
)
from lerobot.utils.import_utils import register_third_party_devices
from lerobot.utils.import_utils import register_third_party_plugins
from lerobot.utils.utils import init_logging


@@ -84,7 +84,7 @@ def calibrate(cfg: CalibrateConfig):


def main():
    register_third_party_devices()
    register_third_party_plugins()
    calibrate()



@@ -65,7 +65,6 @@ import argparse
|
||||
import gc
|
||||
import logging
|
||||
import time
|
||||
from collections.abc import Iterator
|
||||
from pathlib import Path
|
||||
|
||||
import numpy as np
|
||||
@@ -78,19 +77,6 @@ from lerobot.datasets.lerobot_dataset import LeRobotDataset
|
||||
from lerobot.utils.constants import ACTION, DONE, OBS_STATE, REWARD
|
||||
|
||||
|
||||
class EpisodeSampler(torch.utils.data.Sampler):
|
||||
def __init__(self, dataset: LeRobotDataset, episode_index: int):
|
||||
from_idx = dataset.meta.episodes["dataset_from_index"][episode_index]
|
||||
to_idx = dataset.meta.episodes["dataset_to_index"][episode_index]
|
||||
self.frame_ids = range(from_idx, to_idx)
|
||||
|
||||
def __iter__(self) -> Iterator:
|
||||
return iter(self.frame_ids)
|
||||
|
||||
def __len__(self) -> int:
|
||||
return len(self.frame_ids)
|
||||
|
||||
|
||||
def to_hwc_uint8_numpy(chw_float32_torch: torch.Tensor) -> np.ndarray:
|
||||
assert chw_float32_torch.dtype == torch.float32
|
||||
assert chw_float32_torch.ndim == 3
|
||||
@@ -119,12 +105,10 @@ def visualize_dataset(
|
||||
repo_id = dataset.repo_id
|
||||
|
||||
logging.info("Loading dataloader")
|
||||
episode_sampler = EpisodeSampler(dataset, episode_index)
|
||||
dataloader = torch.utils.data.DataLoader(
|
||||
dataset,
|
||||
num_workers=num_workers,
|
||||
batch_size=batch_size,
|
||||
sampler=episode_sampler,
|
||||
)
|
||||
|
||||
logging.info("Starting Rerun")
|
||||
|
||||
@@ -18,7 +18,8 @@
|
||||
Edit LeRobot datasets using various transformation tools.
|
||||
|
||||
This script allows you to delete episodes, split datasets, merge datasets,
|
||||
and remove features. When new_repo_id is specified, creates a new dataset.
|
||||
remove features, and convert image datasets to video format.
|
||||
When new_repo_id is specified, creates a new dataset.
|
||||
|
||||
Usage Examples:
|
||||
|
||||
@@ -65,6 +66,25 @@ Remove camera feature:
|
||||
--operation.type remove_feature \
|
||||
--operation.feature_names "['observation.images.top']"
|
||||
|
||||
Convert image dataset to video format (saves locally):
|
||||
python -m lerobot.scripts.lerobot_edit_dataset \
|
||||
--repo_id lerobot/pusht_image \
|
||||
--operation.type convert_to_video \
|
||||
--operation.output_dir /path/to/output/pusht_video
|
||||
|
||||
Convert image dataset and save with new repo_id:
|
||||
python -m lerobot.scripts.lerobot_edit_dataset \
|
||||
--repo_id lerobot/pusht_image \
|
||||
--new_repo_id lerobot/pusht_video \
|
||||
--operation.type convert_to_video
|
||||
|
||||
Convert and push to hub:
|
||||
python -m lerobot.scripts.lerobot_edit_dataset \
|
||||
--repo_id lerobot/pusht_image \
|
||||
--new_repo_id lerobot/pusht_video \
|
||||
--operation.type convert_to_video \
|
||||
--push_to_hub true
|
||||
|
||||
Using JSON config file:
|
||||
python -m lerobot.scripts.lerobot_edit_dataset \
|
||||
--config_path path/to/edit_config.json
|
||||
@@ -72,9 +92,13 @@ Using JSON config file:
|
||||
|
||||
import logging
|
||||
import shutil
|
||||
from concurrent.futures import ThreadPoolExecutor, as_completed
|
||||
from dataclasses import dataclass
|
||||
from pathlib import Path
|
||||
|
||||
import pandas as pd
|
||||
from tqdm import tqdm
|
||||
|
||||
from lerobot.configs import parser
|
||||
from lerobot.datasets.dataset_tools import (
|
||||
delete_episodes,
|
||||
@@ -82,8 +106,10 @@ from lerobot.datasets.dataset_tools import (
|
||||
remove_feature,
|
||||
split_dataset,
|
||||
)
|
||||
from lerobot.datasets.lerobot_dataset import LeRobotDataset
|
||||
from lerobot.utils.constants import HF_LEROBOT_HOME
|
||||
from lerobot.datasets.lerobot_dataset import LeRobotDataset, LeRobotDatasetMetadata
|
||||
from lerobot.datasets.utils import write_stats, write_tasks
|
||||
from lerobot.datasets.video_utils import encode_video_frames, get_video_info
|
||||
from lerobot.utils.constants import HF_LEROBOT_HOME, OBS_IMAGE
|
||||
from lerobot.utils.utils import init_logging
|
||||
|
||||
|
||||
@@ -111,10 +137,23 @@ class RemoveFeatureConfig:
|
||||
feature_names: list[str] | None = None
|
||||
|
||||
|
||||
@dataclass
|
||||
class ConvertToVideoConfig:
|
||||
type: str = "convert_to_video"
|
||||
output_dir: str | None = None
|
||||
vcodec: str = "libsvtav1"
|
||||
pix_fmt: str = "yuv420p"
|
||||
g: int = 2
|
||||
crf: int = 30
|
||||
fast_decode: int = 0
|
||||
episode_indices: list[int] | None = None
|
||||
num_workers: int = 4
|
||||
|
||||
|
||||
@dataclass
|
||||
class EditDatasetConfig:
|
||||
repo_id: str
|
||||
operation: DeleteEpisodesConfig | SplitConfig | MergeConfig | RemoveFeatureConfig
|
||||
operation: DeleteEpisodesConfig | SplitConfig | MergeConfig | RemoveFeatureConfig | ConvertToVideoConfig
|
||||
root: str | None = None
|
||||
new_repo_id: str | None = None
|
||||
push_to_hub: bool = False
|
||||
@@ -258,6 +297,415 @@ def handle_remove_feature(cfg: EditDatasetConfig) -> None:
|
||||
LeRobotDataset(output_repo_id, root=output_dir).push_to_hub()
|
||||
|
||||
|
||||
def save_episode_images_for_video(
|
||||
dataset: LeRobotDataset,
|
||||
imgs_dir: Path,
|
||||
img_key: str,
|
||||
episode_index: int,
|
||||
num_workers: int = 4,
|
||||
) -> None:
|
||||
"""Save images from a specific episode and camera to disk for video encoding.
|
||||
|
||||
Args:
|
||||
dataset: The LeRobot dataset to extract images from
|
||||
imgs_dir: Directory to save images to
|
||||
img_key: The image key (camera) to extract
|
||||
episode_index: Index of the episode to save
|
||||
num_workers: Number of threads for parallel image saving
|
||||
"""
|
||||
# Create directory
|
||||
imgs_dir.mkdir(parents=True, exist_ok=True)
|
||||
|
||||
# Get dataset without torch format for PIL image access
|
||||
hf_dataset = dataset.hf_dataset.with_format(None)
|
||||
|
||||
# Select only this camera's images
|
||||
imgs_dataset = hf_dataset.select_columns(img_key)
|
||||
|
||||
# Get episode start and end indices
|
||||
from_idx = dataset.meta.episodes["dataset_from_index"][episode_index]
|
||||
to_idx = dataset.meta.episodes["dataset_to_index"][episode_index]
|
||||
|
||||
# Get all items for this episode
|
||||
episode_dataset = imgs_dataset.select(range(from_idx, to_idx))
|
||||
|
||||
# Define function to save a single image
|
||||
def save_single_image(i_item_tuple):
|
||||
i, item = i_item_tuple
|
||||
img = item[img_key]
|
||||
# Use frame-XXXXXX.png format to match encode_video_frames expectations
|
||||
img.save(str(imgs_dir / f"frame-{i:06d}.png"), quality=100)
|
||||
return i
|
||||
|
||||
# Save images with proper naming convention for encode_video_frames (frame-XXXXXX.png)
|
||||
items = list(enumerate(episode_dataset))
|
||||
|
||||
with ThreadPoolExecutor(max_workers=num_workers) as executor:
|
||||
futures = [executor.submit(save_single_image, item) for item in items]
|
||||
for future in as_completed(futures):
|
||||
future.result() # This will raise any exceptions that occurred
|
||||
|
||||
|
||||
def encode_episode_videos(
|
||||
dataset: LeRobotDataset,
|
||||
new_meta: LeRobotDatasetMetadata,
|
||||
episode_index: int,
|
||||
vcodec: str,
|
||||
pix_fmt: str,
|
||||
g: int,
|
||||
crf: int,
|
||||
fast_decode: int,
|
||||
temp_dir: Path,
|
||||
num_image_workers: int = 4,
|
||||
) -> dict[str, dict]:
|
||||
"""Encode videos for a single episode and return video metadata.
|
||||
|
||||
Args:
|
||||
dataset: Source dataset with images
|
||||
new_meta: Metadata object for the new video dataset
|
||||
episode_index: Episode index to process
|
||||
vcodec: Video codec
|
||||
pix_fmt: Pixel format
|
||||
g: Group of pictures size
|
||||
crf: Constant rate factor
|
||||
fast_decode: Fast decode tuning
|
||||
temp_dir: Temporary directory for images
|
||||
num_image_workers: Number of workers for saving images
|
||||
|
||||
Returns:
|
||||
Dictionary mapping video keys to their metadata (chunk_index, file_index, timestamps)
|
||||
"""
|
||||
hf_dataset = dataset.hf_dataset.with_format(None)
|
||||
img_keys = [key for key in hf_dataset.features if key.startswith(OBS_IMAGE)]
|
||||
|
||||
video_metadata = {}
|
||||
fps = int(dataset.fps) # Convert to int for PyAV compatibility
|
||||
episode_length = dataset.meta.episodes["length"][episode_index]
|
||||
episode_duration = episode_length / dataset.fps # Use original fps for duration calculation
|
||||
|
||||
for img_key in img_keys:
|
||||
# Save images temporarily
|
||||
imgs_dir = temp_dir / f"episode_{episode_index:06d}" / img_key
|
||||
save_episode_images_for_video(dataset, imgs_dir, img_key, episode_index, num_image_workers)
|
||||
|
||||
# Determine chunk and file indices
|
||||
# For simplicity, we'll put each episode in its own file
|
||||
chunk_idx = episode_index // new_meta.chunks_size
|
||||
file_idx = episode_index % new_meta.chunks_size
|
||||
|
||||
# Create video path in the new dataset structure
|
||||
video_path = new_meta.root / new_meta.video_path.format(
|
||||
video_key=img_key, chunk_index=chunk_idx, file_index=file_idx
|
||||
)
|
||||
video_path.parent.mkdir(parents=True, exist_ok=True)
|
||||
|
||||
# Encode video
|
||||
encode_video_frames(
|
||||
imgs_dir=imgs_dir,
|
||||
video_path=video_path,
|
||||
fps=fps,
|
||||
vcodec=vcodec,
|
||||
pix_fmt=pix_fmt,
|
||||
g=g,
|
||||
crf=crf,
|
||||
fast_decode=fast_decode,
|
||||
overwrite=True,
|
||||
)
|
||||
|
||||
# Clean up temporary images
|
||||
shutil.rmtree(imgs_dir)
|
||||
|
||||
# Store video metadata
|
||||
video_metadata[img_key] = {
|
||||
f"videos/{img_key}/chunk_index": chunk_idx,
|
||||
f"videos/{img_key}/file_index": file_idx,
|
||||
f"videos/{img_key}/from_timestamp": 0.0,
|
||||
f"videos/{img_key}/to_timestamp": episode_duration,
|
||||
}
|
||||
|
||||
return video_metadata
|
||||
|
||||
|
||||
def convert_dataset_to_videos(
|
||||
dataset: LeRobotDataset,
|
||||
output_dir: Path,
|
||||
repo_id: str | None = None,
|
||||
vcodec: str = "libsvtav1",
|
||||
pix_fmt: str = "yuv420p",
|
||||
g: int = 2,
|
||||
crf: int = 30,
|
||||
fast_decode: int = 0,
|
||||
episode_indices: list[int] | None = None,
|
||||
num_workers: int = 4,
|
||||
) -> LeRobotDataset:
|
||||
"""Convert image-based dataset to video-based dataset.
|
||||
|
||||
Creates a new LeRobotDataset with videos instead of images, following the proper
|
||||
LeRobot dataset structure with videos stored in chunked MP4 files.
|
||||
|
||||
Args:
|
||||
dataset: The source LeRobot dataset with images
|
||||
output_dir: Directory to save the new video dataset
|
||||
repo_id: Repository ID for the new dataset (default: original_id + "_video")
|
||||
vcodec: Video codec (default: libsvtav1)
|
||||
pix_fmt: Pixel format (default: yuv420p)
|
||||
g: Group of pictures size (default: 2)
|
||||
crf: Constant rate factor (default: 30)
|
||||
fast_decode: Fast decode tuning (default: 0)
|
||||
episode_indices: List of episode indices to convert (None = all episodes)
|
||||
num_workers: Number of threads for parallel processing (default: 4)
|
||||
|
||||
Returns:
|
||||
New LeRobotDataset with videos
|
||||
"""
|
||||
# Check that it's an image dataset
|
||||
if len(dataset.meta.video_keys) > 0:
|
||||
raise ValueError(
|
||||
f"This operation is for image datasets only. Video dataset provided: {dataset.repo_id}"
|
||||
)
|
||||
|
||||
# Get all image keys
|
||||
hf_dataset = dataset.hf_dataset.with_format(None)
|
||||
img_keys = [key for key in hf_dataset.features if key.startswith(OBS_IMAGE)]
|
||||
|
||||
if len(img_keys) == 0:
|
||||
raise ValueError(f"No image keys found in dataset {dataset.repo_id}")
|
||||
|
||||
# Determine which episodes to process
|
||||
if episode_indices is None:
|
||||
episode_indices = list(range(dataset.meta.total_episodes))
|
||||
|
||||
if repo_id is None:
|
||||
repo_id = f"{dataset.repo_id}_video"
|
||||
|
||||
logging.info(
|
||||
f"Converting {len(episode_indices)} episodes with {len(img_keys)} cameras from {dataset.repo_id}"
|
||||
)
|
||||
logging.info(f"Video codec: {vcodec}, pixel format: {pix_fmt}, GOP: {g}, CRF: {crf}")
|
||||
|
||||
# Create new features dict, converting image features to video features
|
||||
new_features = {}
|
||||
for key, value in dataset.meta.features.items():
|
||||
if key not in img_keys:
|
||||
new_features[key] = value
|
||||
else:
|
||||
# Convert image key to video format
|
||||
new_features[key] = value.copy()
|
||||
new_features[key]["dtype"] = "video" # Change dtype from "image" to "video"
|
||||
# Video info will be updated after episodes are encoded
|
||||
|
||||
# Create new metadata for video dataset
|
||||
new_meta = LeRobotDatasetMetadata.create(
|
||||
repo_id=repo_id,
|
||||
fps=dataset.meta.fps,
|
||||
features=new_features,
|
||||
robot_type=dataset.meta.robot_type,
|
||||
root=output_dir,
|
||||
use_videos=True,
|
||||
chunks_size=dataset.meta.chunks_size,
|
||||
data_files_size_in_mb=dataset.meta.data_files_size_in_mb,
|
||||
video_files_size_in_mb=dataset.meta.video_files_size_in_mb,
|
||||
)
|
||||
|
||||
# Create temporary directory for image extraction
|
||||
temp_dir = output_dir / "temp_images"
|
||||
temp_dir.mkdir(parents=True, exist_ok=True)
|
||||
|
||||
# Process each episode
|
||||
all_episode_metadata = []
|
||||
|
||||
try:
|
||||
for ep_idx in tqdm(episode_indices, desc="Converting episodes to videos"):
|
||||
# Get episode metadata from source
|
||||
src_episode = dataset.meta.episodes[ep_idx]
|
||||
|
||||
# Encode videos for this episode
|
||||
video_metadata = encode_episode_videos(
|
||||
dataset=dataset,
|
||||
new_meta=new_meta,
|
||||
episode_index=ep_idx,
|
||||
vcodec=vcodec,
|
||||
pix_fmt=pix_fmt,
|
||||
g=g,
|
||||
crf=crf,
|
||||
fast_decode=fast_decode,
|
||||
temp_dir=temp_dir,
|
||||
num_image_workers=num_workers,
|
||||
)
|
||||
|
||||
# Build episode metadata
|
||||
episode_meta = {
|
||||
"episode_index": ep_idx,
|
||||
"length": src_episode["length"],
|
||||
"dataset_from_index": ep_idx * src_episode["length"],
|
||||
"dataset_to_index": (ep_idx + 1) * src_episode["length"],
|
||||
}
|
||||
|
||||
# Add video metadata
|
||||
for img_key in img_keys:
|
||||
episode_meta.update(video_metadata[img_key])
|
||||
|
||||
# Add data chunk/file info (using same structure as source)
|
||||
if "data/chunk_index" in src_episode:
|
||||
episode_meta["data/chunk_index"] = src_episode["data/chunk_index"]
|
||||
episode_meta["data/file_index"] = src_episode["data/file_index"]
|
||||
|
||||
all_episode_metadata.append(episode_meta)
|
||||
|
||||
# Copy and transform data files (removing image columns)
|
||||
_copy_data_without_images(dataset, new_meta, episode_indices, img_keys)
|
||||
|
||||
# Save episode metadata
|
||||
episodes_df = pd.DataFrame(all_episode_metadata)
|
||||
episodes_path = new_meta.root / "meta" / "episodes" / "chunk-000" / "file-000.parquet"
|
||||
episodes_path.parent.mkdir(parents=True, exist_ok=True)
|
||||
episodes_df.to_parquet(episodes_path, index=False)
|
||||
|
||||
# Update metadata info
|
||||
new_meta.info["total_episodes"] = len(episode_indices)
|
||||
new_meta.info["total_frames"] = sum(ep["length"] for ep in all_episode_metadata)
|
||||
new_meta.info["total_tasks"] = dataset.meta.total_tasks
|
||||
new_meta.info["splits"] = {"train": f"0:{len(episode_indices)}"}
|
||||
|
||||
# Update video info for all image keys (now videos)
|
||||
# We need to manually set video info since update_video_info() checks video_keys first
|
||||
for img_key in img_keys:
|
||||
if not new_meta.features[img_key].get("info", None):
|
||||
video_path = new_meta.root / new_meta.video_path.format(
|
||||
video_key=img_key, chunk_index=0, file_index=0
|
||||
)
|
||||
new_meta.info["features"][img_key]["info"] = get_video_info(video_path)
|
||||
|
||||
from lerobot.datasets.utils import write_info
|
||||
|
||||
write_info(new_meta.info, new_meta.root)
|
||||
|
||||
# Copy stats and tasks
|
||||
if dataset.meta.stats is not None:
|
||||
# Remove image stats
|
||||
new_stats = {k: v for k, v in dataset.meta.stats.items() if k not in img_keys}
|
||||
write_stats(new_stats, new_meta.root)
|
||||
|
||||
if dataset.meta.tasks is not None:
|
||||
write_tasks(dataset.meta.tasks, new_meta.root)
|
||||
|
||||
finally:
|
||||
# Clean up temporary directory
|
||||
if temp_dir.exists():
|
||||
shutil.rmtree(temp_dir)
|
||||
|
||||
logging.info(f"✓ Completed converting {dataset.repo_id} to video format")
|
||||
logging.info(f"New dataset saved to: {output_dir}")
|
||||
|
||||
# Return new dataset
|
||||
return LeRobotDataset(repo_id=repo_id, root=output_dir)
|
||||
|
||||
|
||||
def _copy_data_without_images(
|
||||
src_dataset: LeRobotDataset,
|
||||
dst_meta: LeRobotDatasetMetadata,
|
||||
episode_indices: list[int],
|
||||
img_keys: list[str],
|
||||
) -> None:
|
||||
"""Copy data files without image columns.
|
||||
|
||||
Args:
|
||||
src_dataset: Source dataset
|
||||
dst_meta: Destination metadata
|
||||
episode_indices: Episodes to include
|
||||
img_keys: Image keys to remove
|
||||
"""
|
||||
from lerobot.datasets.utils import DATA_DIR
|
||||
|
||||
data_dir = src_dataset.root / DATA_DIR
|
||||
parquet_files = sorted(data_dir.glob("*/*.parquet"))
|
||||
|
||||
if not parquet_files:
|
||||
raise ValueError(f"No parquet files found in {data_dir}")
|
||||
|
||||
episode_set = set(episode_indices)
|
||||
|
||||
for src_path in tqdm(parquet_files, desc="Processing data files"):
|
||||
df = pd.read_parquet(src_path).reset_index(drop=True)
|
||||
|
||||
# Filter to only include selected episodes
|
||||
df = df[df["episode_index"].isin(episode_set)].copy()
|
||||
|
||||
if len(df) == 0:
|
||||
continue
|
||||
|
||||
# Remove image columns
|
||||
columns_to_drop = [col for col in img_keys if col in df.columns]
|
||||
if columns_to_drop:
|
||||
df = df.drop(columns=columns_to_drop)
|
||||
|
||||
# Get chunk and file indices from path
|
||||
relative_path = src_path.relative_to(src_dataset.root)
|
||||
chunk_dir = relative_path.parts[1]
|
||||
file_name = relative_path.parts[2]
|
||||
chunk_idx = int(chunk_dir.split("-")[1])
|
||||
file_idx = int(file_name.split("-")[1].split(".")[0])
|
||||
|
||||
# Write to destination without pandas index
|
||||
dst_path = dst_meta.root / f"data/chunk-{chunk_idx:03d}/file-{file_idx:03d}.parquet"
|
||||
dst_path.parent.mkdir(parents=True, exist_ok=True)
|
||||
df.to_parquet(dst_path, index=False)
|
||||
|
||||
|
||||
def handle_convert_to_video(cfg: EditDatasetConfig) -> None:
|
||||
# Note: Parser may create any config type with the right fields, so we access fields directly
|
||||
# instead of checking isinstance()
|
||||
dataset = LeRobotDataset(cfg.repo_id, root=cfg.root)
|
||||
|
||||
# Determine output directory and repo_id
|
||||
# Priority: 1) new_repo_id, 2) operation.output_dir, 3) auto-generated name
|
||||
output_dir_config = getattr(cfg.operation, "output_dir", None)
|
||||
|
||||
if cfg.new_repo_id:
|
||||
# Use new_repo_id for both local storage and hub push
|
||||
output_repo_id = cfg.new_repo_id
|
||||
output_dir = Path(cfg.root) / cfg.new_repo_id if cfg.root else HF_LEROBOT_HOME / cfg.new_repo_id
|
||||
logging.info(f"Saving to new dataset: {cfg.new_repo_id}")
|
||||
elif output_dir_config:
|
||||
# Use custom output directory for local-only storage
|
||||
output_dir = Path(output_dir_config)
|
||||
# Extract repo name from output_dir for the dataset
|
||||
output_repo_id = output_dir.name
|
||||
logging.info(f"Saving to local directory: {output_dir}")
|
||||
else:
|
||||
# Auto-generate name: append "_video" to original repo_id
|
||||
output_repo_id = f"{cfg.repo_id}_video"
|
||||
output_dir = Path(cfg.root) / output_repo_id if cfg.root else HF_LEROBOT_HOME / output_repo_id
|
||||
logging.info(f"Saving to auto-generated location: {output_dir}")
|
||||
|
||||
logging.info(f"Converting dataset {cfg.repo_id} to video format")
|
||||
|
||||
new_dataset = convert_dataset_to_videos(
|
||||
dataset=dataset,
|
||||
output_dir=output_dir,
|
||||
repo_id=output_repo_id,
|
||||
vcodec=getattr(cfg.operation, "vcodec", "libsvtav1"),
|
||||
pix_fmt=getattr(cfg.operation, "pix_fmt", "yuv420p"),
|
||||
g=getattr(cfg.operation, "g", 2),
|
||||
crf=getattr(cfg.operation, "crf", 30),
|
||||
fast_decode=getattr(cfg.operation, "fast_decode", 0),
|
||||
episode_indices=getattr(cfg.operation, "episode_indices", None),
|
||||
num_workers=getattr(cfg.operation, "num_workers", 4),
|
||||
)
|
||||
|
||||
logging.info("Video dataset created successfully!")
|
||||
logging.info(f"Location: {output_dir}")
|
||||
logging.info(f"Episodes: {new_dataset.meta.total_episodes}")
|
||||
logging.info(f"Frames: {new_dataset.meta.total_frames}")
|
||||
|
||||
if cfg.push_to_hub:
|
||||
logging.info(f"Pushing to hub as {output_repo_id}...")
|
||||
new_dataset.push_to_hub()
|
||||
logging.info("✓ Successfully pushed to hub!")
|
||||
else:
|
||||
logging.info("Dataset saved locally (not pushed to hub)")
|
||||
|
||||
|
||||
@parser.wrap()
|
||||
def edit_dataset(cfg: EditDatasetConfig) -> None:
|
||||
operation_type = cfg.operation.type
|
||||
@@ -270,10 +718,12 @@ def edit_dataset(cfg: EditDatasetConfig) -> None:
        handle_merge(cfg)
    elif operation_type == "remove_feature":
        handle_remove_feature(cfg)
    elif operation_type == "convert_to_video":
        handle_convert_to_video(cfg)
    else:
        raise ValueError(
            f"Unknown operation type: {operation_type}\n"
            f"Available operations: delete_episodes, split, merge, remove_feature"
            f"Available operations: delete_episodes, split, merge, remove_feature, convert_to_video"
        )



@@ -82,6 +82,7 @@ from lerobot.policies.factory import make_policy, make_pre_post_processors
from lerobot.policies.pretrained import PreTrainedPolicy
from lerobot.processor import PolicyAction, PolicyProcessorPipeline
from lerobot.utils.constants import ACTION, DONE, OBS_STR, REWARD
from lerobot.utils.import_utils import register_third_party_plugins
from lerobot.utils.io_utils import write_video
from lerobot.utils.random_utils import set_seed
from lerobot.utils.utils import (
@@ -533,7 +534,7 @@ def eval_main(cfg: EvalPipelineConfig):
    )

    # Create environment-specific preprocessor and postprocessor (e.g., for LIBERO environments)
    env_preprocessor, env_postprocessor = make_env_pre_post_processors(env_cfg=cfg.env)
    env_preprocessor, env_postprocessor = make_env_pre_post_processors(env_cfg=cfg.env, policy_cfg=cfg.policy)

    with torch.no_grad(), torch.autocast(device_type=device.type) if cfg.policy.use_amp else nullcontext():
        info = eval_policy_all(
@@ -792,6 +793,7 @@ def eval_policy_all(

def main():
    init_logging()
    register_third_party_plugins()
    eval_main()



@@ -15,18 +15,23 @@
|
||||
# limitations under the License.
|
||||
|
||||
"""
|
||||
Simple script to control a robot from teleoperation.
|
||||
Script to find joint limits and end-effector bounds via teleoperation.
|
||||
|
||||
Example:
|
||||
|
||||
```shell
|
||||
lerobot-find-joint-limits \
|
||||
--robot.type=so100_follower \
|
||||
--robot.port=/dev/tty.usbmodem58760431541 \
|
||||
--robot.port=/dev/tty.usbmodem58760432981 \
|
||||
--robot.id=black \
|
||||
--teleop.type=so100_leader \
|
    --teleop.port=/dev/tty.usbmodem58760431551 \
    --teleop.id=blue
    --teleop.port=/dev/tty.usbmodem58760434471 \
    --teleop.id=blue \
    --urdf_path=<user>/SO-ARM100-main/Simulation/SO101/so101_new_calib.urdf \
    --target_frame_name=gripper \
    --teleop_time_s=30 \
    --warmup_time_s=5 \
    --control_loop_fps=30
```
"""

@@ -42,6 +47,7 @@ from lerobot.robots import ( # noqa: F401
koch_follower,
make_robot_from_config,
so100_follower,
so101_follower,
)
from lerobot.teleoperators import ( # noqa: F401
TeleoperatorConfig,
@@ -49,18 +55,28 @@ from lerobot.teleoperators import ( # noqa: F401
koch_leader,
make_teleoperator_from_config,
so100_leader,
so101_leader,
)
from lerobot.utils.robot_utils import busy_wait
from lerobot.utils.robot_utils import precise_sleep


@dataclass
class FindJointLimitsConfig:
teleop: TeleoperatorConfig
robot: RobotConfig
# Limit the maximum frames per second. By default, no limit.

# Path to URDF file for kinematics
# NOTE: It is highly recommended to use the urdf in the SO-ARM100 repo:
# https://github.com/TheRobotStudio/SO-ARM100/blob/main/Simulation/SO101/so101_new_calib.urdf
urdf_path: str
target_frame_name: str = "gripper"

# Duration of the recording phase in seconds
teleop_time_s: float = 30
# Display all cameras on screen
display_data: bool = False
# Duration of the warmup phase in seconds
warmup_time_s: float = 5
# Control loop frequency
control_loop_fps: int = 30


@draccus.wrap()
@@ -68,53 +84,127 @@ def find_joint_and_ee_bounds(cfg: FindJointLimitsConfig):
teleop = make_teleoperator_from_config(cfg.teleop)
robot = make_robot_from_config(cfg.robot)

print(f"Connecting to robot: {cfg.robot.type}...")
teleop.connect()
robot.connect()
print("Devices connected.")

start_episode_t = time.perf_counter()
robot_type = getattr(robot.config, "robot_type", "so101")
if "so100" in robot_type or "so101" in robot_type:
# Note to be compatible with the rest of the codebase,
# we are using the new calibration method for so101 and so100
robot_type = "so_new_calibration"
kinematics = RobotKinematics(cfg.robot.urdf_path, cfg.robot.target_frame_name)
# Initialize Kinematics
try:
kinematics = RobotKinematics(cfg.urdf_path, cfg.target_frame_name)
except Exception as e:
print(f"Error initializing kinematics: {e}")
print("Ensure URDF path and target frame name are correct.")
robot.disconnect()
teleop.disconnect()
return

# Initialize min/max values
observation = robot.get_observation()
joint_positions = np.array([observation[f"{key}.pos"] for key in robot.bus.motors])
ee_pos = kinematics.forward_kinematics(joint_positions)[:3, 3]
# Initialize variables
max_pos = None
min_pos = None
max_ee = None
min_ee = None

max_pos = joint_positions.copy()
min_pos = joint_positions.copy()
max_ee = ee_pos.copy()
min_ee = ee_pos.copy()
start_t = time.perf_counter()
warmup_done = False

while True:
action = teleop.get_action()
robot.send_action(action)
print("\n" + "=" * 40)
print(f" WARMUP PHASE ({cfg.warmup_time_s}s)")
print(" Move the robot freely to ensure control works.")
print(" Data is NOT being recorded yet.")
print("=" * 40 + "\n")

observation = robot.get_observation()
joint_positions = np.array([observation[f"{key}.pos"] for key in robot.bus.motors])
ee_pos = kinematics.forward_kinematics(joint_positions)[:3, 3]
try:
while True:
t0 = time.perf_counter()

# Skip initial warmup period
if (time.perf_counter() - start_episode_t) < 5:
continue
# 1. Teleoperation Control Loop
action = teleop.get_action()
robot.send_action(action)

# Update min/max values
max_ee = np.maximum(max_ee, ee_pos)
min_ee = np.minimum(min_ee, ee_pos)
max_pos = np.maximum(max_pos, joint_positions)
min_pos = np.minimum(min_pos, joint_positions)
# 2. Read Observations
observation = robot.get_observation()
joint_positions = np.array([observation[f"{key}.pos"] for key in robot.bus.motors])

if time.perf_counter() - start_episode_t > cfg.teleop_time_s:
print(f"Max ee position {np.round(max_ee, 4).tolist()}")
print(f"Min ee position {np.round(min_ee, 4).tolist()}")
print(f"Max joint pos position {np.round(max_pos, 4).tolist()}")
print(f"Min joint pos position {np.round(min_pos, 4).tolist()}")
break
# 3. Calculate Kinematics
# Forward kinematics to get (x, y, z) translation
ee_pos = kinematics.forward_kinematics(joint_positions)[:3, 3]

busy_wait(0.01)
current_time = time.perf_counter()
elapsed = current_time - start_t

# 4. Handle Phases
if elapsed < cfg.warmup_time_s:
# Still in warmup
pass

else:
# Phase Transition: Warmup -> Recording
if not warmup_done:
print("\n" + "=" * 40)
print(" RECORDING STARTED")
print(" Move robot to ALL joint limits.")
print(" Press Ctrl+C to stop early and save results.")
print("=" * 40 + "\n")

# Initialize limits with current position at start of recording
max_pos = joint_positions.copy()
min_pos = joint_positions.copy()
max_ee = ee_pos.copy()
min_ee = ee_pos.copy()
warmup_done = True

# Update Limits
max_ee = np.maximum(max_ee, ee_pos)
min_ee = np.minimum(min_ee, ee_pos)
max_pos = np.maximum(max_pos, joint_positions)
min_pos = np.minimum(min_pos, joint_positions)

# Time check
recording_time = elapsed - cfg.warmup_time_s
remaining = cfg.teleop_time_s - recording_time

# Simple throttle for print statements (every ~1 sec)
if int(recording_time * 100) % 100 == 0:
print(f"Time remaining: {remaining:.1f}s", end="\r")

if recording_time > cfg.teleop_time_s:
print("\nTime limit reached.")
break

precise_sleep(max(1.0 / cfg.control_loop_fps - (time.perf_counter() - t0), 0.0))

except KeyboardInterrupt:
print("\n\nInterrupted by user. Stopping safely...")

finally:
# Safety: Disconnect devices
print("\nDisconnecting devices...")
robot.disconnect()
teleop.disconnect()

# Results Output
if max_pos is not None:
print("\n" + "=" * 40)
print("FINAL RESULTS")
print("=" * 40)

# Rounding for readability
r_max_ee = np.round(max_ee, 4).tolist()
r_min_ee = np.round(min_ee, 4).tolist()
r_max_pos = np.round(max_pos, 4).tolist()
r_min_pos = np.round(min_pos, 4).tolist()

print("\n# End Effector Bounds (x, y, z):")
print(f"max_ee = {r_max_ee}")
print(f"min_ee = {r_min_ee}")

print("\n# Joint Position Limits (radians):")
print(f"max_pos = {r_max_pos}")
print(f"min_pos = {r_min_pos}")

else:
print("No data recorded (exited during warmup).")


def main():
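Distilled from the hunks above: the reworked script runs a warmup phase, then a timed recording phase during which it tracks the element-wise min/max of the joint vector and of the end-effector translation. Below is a minimal, self-contained sketch of that bound-tracking pattern — `read_joints` and `forward_kinematics` are placeholder callables for illustration, not lerobot APIs.

```python
import time

import numpy as np


def track_bounds(read_joints, forward_kinematics, warmup_s=5.0, record_s=30.0, fps=30):
    """Warm up, then track per-joint and end-effector min/max for `record_s` seconds."""
    min_pos = max_pos = min_ee = max_ee = None
    start = time.perf_counter()
    while True:
        t0 = time.perf_counter()
        joints = np.asarray(read_joints())       # current joint positions (placeholder callable)
        ee = forward_kinematics(joints)[:3, 3]   # assumed 4x4 transform -> (x, y, z) translation
        elapsed = time.perf_counter() - start
        if elapsed >= warmup_s:                  # recording phase only
            if max_pos is None:                  # initialize bounds on the first recorded frame
                min_pos, max_pos = joints.copy(), joints.copy()
                min_ee, max_ee = ee.copy(), ee.copy()
            min_pos, max_pos = np.minimum(min_pos, joints), np.maximum(max_pos, joints)
            min_ee, max_ee = np.minimum(min_ee, ee), np.maximum(max_ee, ee)
            if elapsed - warmup_s > record_s:
                break
        # pad each iteration out to roughly 1/fps seconds
        time.sleep(max(1.0 / fps - (time.perf_counter() - t0), 0.0))
    return min_pos, max_pos, min_ee, max_ee
```

The actual script additionally sends teleoperator actions every iteration, prints progress, and reports the rounded bounds on exit or Ctrl+C; the sketch only isolates the bound-tracking logic.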
@@ -93,6 +93,7 @@ from lerobot.robots import ( # noqa: F401
Robot,
RobotConfig,
bi_so100_follower,
earthrover_mini_plus,
hope_jr,
koch_follower,
make_robot_from_config,
@@ -118,8 +119,8 @@ from lerobot.utils.control_utils import (
sanity_check_dataset_name,
sanity_check_dataset_robot_compatibility,
)
from lerobot.utils.import_utils import register_third_party_devices
from lerobot.utils.robot_utils import busy_wait
from lerobot.utils.import_utils import register_third_party_plugins
from lerobot.utils.robot_utils import precise_sleep
from lerobot.utils.utils import (
get_safe_torch_device,
init_logging,
@@ -364,7 +365,7 @@ def record_loop(
log_rerun_data(observation=obs_processed, action=action_values)

dt_s = time.perf_counter() - start_loop_t
busy_wait(1 / fps - dt_s)
precise_sleep(1 / fps - dt_s)

timestamp = time.perf_counter() - start_episode_t
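In this hunk, and in the replay and teleoperate hunks further down, the loop-pacing call changes from `busy_wait(1 / fps - dt_s)` to `precise_sleep(1 / fps - dt_s)`. The implementation of `precise_sleep` in `lerobot.utils.robot_utils` is not part of this diff; the sketch below only illustrates the general idea such a helper usually captures — yield the CPU for most of the interval instead of spinning, then finish with a short busy-wait — together with the fixed-rate loop pattern it serves. The helper name and margin here are illustrative assumptions.

```python
import time


def precise_sleep_sketch(seconds: float, spin_margin_s: float = 0.001) -> None:
    """Illustrative stand-in only: coarse sleep for most of the wait, short spin for the rest."""
    if seconds <= 0:
        return
    deadline = time.perf_counter() + seconds
    coarse = seconds - spin_margin_s
    if coarse > 0:
        time.sleep(coarse)                    # yield the CPU for the bulk of the interval
    while time.perf_counter() < deadline:     # brief busy-wait for sub-millisecond accuracy
        pass


fps = 30
for _ in range(3):                            # a few iterations, just to keep the sketch runnable
    start_loop_t = time.perf_counter()
    # ... one iteration of work (read observation, send action, log) ...
    dt_s = time.perf_counter() - start_loop_t
    precise_sleep_sketch(1 / fps - dt_s)      # pad the iteration out to 1/fps seconds
```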
@@ -512,7 +513,7 @@ def record(cfg: RecordConfig) -> LeRobotDataset:


def main():
register_third_party_devices()
register_third_party_plugins()
record()

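Each script's `main()` entry point now calls `register_third_party_plugins()` (previously `register_third_party_devices()`) before running its command. How the helper discovers plugins is not shown in this diff; purely as a hypothetical illustration, one common way to implement such a hook is to import every installed package advertised under a dedicated entry-point group, so that importing it registers its robots and teleoperators as a side effect.

```python
# Hypothetical sketch, not the lerobot implementation.
from importlib.metadata import entry_points


def register_plugins_sketch(group: str = "example.plugins") -> None:
    """Load every entry point in `group`; the group name is an assumption for illustration."""
    for ep in entry_points(group=group):
        ep.load()  # importing the plugin module performs its own registration
```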
@@ -54,6 +54,7 @@ from lerobot.robots import ( # noqa: F401
Robot,
RobotConfig,
bi_so100_follower,
earthrover_mini_plus,
hope_jr,
koch_follower,
make_robot_from_config,
@@ -61,8 +62,8 @@ from lerobot.robots import ( # noqa: F401
so101_follower,
)
from lerobot.utils.constants import ACTION
from lerobot.utils.import_utils import register_third_party_devices
from lerobot.utils.robot_utils import busy_wait
from lerobot.utils.import_utils import register_third_party_plugins
from lerobot.utils.robot_utils import precise_sleep
from lerobot.utils.utils import (
init_logging,
log_say,
@@ -121,13 +122,13 @@ def replay(cfg: ReplayConfig):
_ = robot.send_action(processed_action)

dt_s = time.perf_counter() - start_episode_t
busy_wait(1 / dataset.fps - dt_s)
precise_sleep(1 / dataset.fps - dt_s)

robot.disconnect()


def main():
register_third_party_devices()
register_third_party_plugins()
replay()


@@ -71,6 +71,7 @@ from lerobot.robots import ( # noqa: F401
Robot,
RobotConfig,
bi_so100_follower,
earthrover_mini_plus,
hope_jr,
koch_follower,
make_robot_from_config,
@@ -83,13 +84,14 @@ from lerobot.teleoperators import ( # noqa: F401
bi_so100_leader,
gamepad,
homunculus,
keyboard,
koch_leader,
make_teleoperator_from_config,
so100_leader,
so101_leader,
)
from lerobot.utils.import_utils import register_third_party_devices
from lerobot.utils.robot_utils import busy_wait
from lerobot.utils.import_utils import register_third_party_plugins
from lerobot.utils.robot_utils import precise_sleep
from lerobot.utils.utils import init_logging, move_cursor_up
from lerobot.utils.visualization_utils import init_rerun, log_rerun_data

@@ -170,12 +172,13 @@ def teleop_loop(
# Display the final robot action that was sent
for motor, value in robot_action_to_send.items():
print(f"{motor:<{display_len}} | {value:>7.2f}")
move_cursor_up(len(robot_action_to_send) + 5)
move_cursor_up(len(robot_action_to_send) + 3)

dt_s = time.perf_counter() - loop_start
busy_wait(1 / fps - dt_s)
precise_sleep(1 / fps - dt_s)
loop_s = time.perf_counter() - loop_start
print(f"\ntime: {loop_s * 1e3:.2f}ms ({1 / loop_s:.0f} Hz)")
print(f"Teleop loop time: {loop_s * 1e3:.2f}ms ({1 / loop_s:.0f} Hz)")
move_cursor_up(1)

if duration is not None and time.perf_counter() - start >= duration:
return
@@ -216,7 +219,7 @@ def teleoperate(cfg: TeleoperateConfig):


def main():
register_third_party_devices()
register_third_party_plugins()
teleoperate()


@@ -36,6 +36,7 @@ from lerobot.policies.factory import make_policy, make_pre_post_processors
from lerobot.policies.pretrained import PreTrainedPolicy
from lerobot.rl.wandb_utils import WandBLogger
from lerobot.scripts.lerobot_eval import eval_policy_all
from lerobot.utils.import_utils import register_third_party_plugins
from lerobot.utils.logging_utils import AverageMeter, MetricsTracker
from lerobot.utils.random_utils import set_seed
from lerobot.utils.train_utils import (
@@ -260,7 +261,9 @@ def train(cfg: TrainPipelineConfig, accelerator: Accelerator | None = None):
if cfg.env is not None:
logging.info(f"{cfg.env.task=}")
logging.info("Creating environment processors")
env_preprocessor, env_postprocessor = make_env_pre_post_processors(env_cfg=cfg.env)
env_preprocessor, env_postprocessor = make_env_pre_post_processors(
env_cfg=cfg.env, policy_cfg=cfg.policy
)
logging.info(f"{cfg.steps=} ({format_big_number(cfg.steps)})")
logging.info(f"{dataset.num_frames=} ({format_big_number(dataset.num_frames)})")
logging.info(f"{dataset.num_episodes=}")
@@ -446,6 +449,7 @@ def train(cfg: TrainPipelineConfig, accelerator: Accelerator | None = None):


def main():
register_third_party_plugins()
train()

Some files were not shown because too many files have changed in this diff.