mirror of
https://github.com/huggingface/lerobot.git
synced 2026-05-14 08:09:45 +00:00
Compare commits
41 Commits
| Author | SHA1 | Date | |
|---|---|---|---|
| 112eb70a65 | |||
| 6e3b972534 | |||
| fa5004bd8c | |||
| b98c70376b | |||
| 2fa045eedc | |||
| adc476d8af | |||
| 73dd4f10f7 | |||
| 2889c0650a | |||
| f2ad86831d | |||
| 3ed0425d2c | |||
| 425eced2de | |||
| cc2e91febe | |||
| c66aef878c | |||
| 599c2477c5 | |||
| 456d9fe3ff | |||
| 006185ff4a | |||
| 2dc2a3ae55 | |||
| 0c99b768f4 | |||
| c774818eda | |||
| 3b31c2d9d3 | |||
| 6b6a82bbdf | |||
| 7beb20819e | |||
| 9a5a0ad575 | |||
| 2af40615b8 | |||
| 8d2fb5d298 | |||
| d286ea30d4 | |||
| ca67231892 | |||
| 5245332e36 | |||
| 4367348327 | |||
| c2c0dbf52e | |||
| 3d28dc3681 | |||
| 9bd69bb236 | |||
| 52b080fd8c | |||
| 0d84f4724d | |||
| 1ffdc6f49e | |||
| f688eb160b | |||
| 69868360c7 | |||
| 3c9149e909 | |||
| cf0f878dbb | |||
| 1da9eee095 | |||
| d9f0c8c3ae |
@@ -60,19 +60,12 @@ jobs:
|
||||
runs-on: ubuntu-latest
|
||||
env:
|
||||
MUJOCO_GL: egl
|
||||
HF_HOME: /mnt/cache/.cache/huggingface
|
||||
HF_LEROBOT_HOME: /mnt/cache/.cache/huggingface/lerobot
|
||||
steps:
|
||||
- uses: actions/checkout@v4
|
||||
with:
|
||||
persist-credentials: false
|
||||
lfs: true
|
||||
|
||||
# NOTE(Steven): Mount to `/mnt` to avoid the limited storage on `/home`. Consider cleaning default SDKs or using self-hosted runners for more space.
|
||||
# (As of 2024-06-10, the runner's `/home` has only 6.2 GB free—8% of its 72 GB total.)
|
||||
- name: Setup /mnt storage
|
||||
run: sudo chown -R $USER:$USER /mnt
|
||||
|
||||
# TODO(Steven): Evaluate the need of these dependencies
|
||||
- name: Install apt dependencies
|
||||
run: |
|
||||
|
||||
@@ -58,19 +58,12 @@ jobs:
|
||||
github.event_name == 'workflow_dispatch'
|
||||
env:
|
||||
MUJOCO_GL: egl
|
||||
HF_HOME: /mnt/cache/.cache/huggingface
|
||||
HF_LEROBOT_HOME: /mnt/cache/.cache/huggingface/lerobot
|
||||
steps:
|
||||
- uses: actions/checkout@v4
|
||||
with:
|
||||
lfs: true
|
||||
persist-credentials: false
|
||||
|
||||
# NOTE(Steven): Mount to `/mnt` to avoid the limited storage on `/home`. Consider cleaning default SDKs or using self-hosted runners for more space.
|
||||
# (As of 2024-06-10, the runner's `/home` has only 6.2 GB free—8% of its 72 GB total.)
|
||||
- name: Setup /mnt storage
|
||||
run: sudo chown -R $USER:$USER /mnt
|
||||
|
||||
- name: Install apt dependencies
|
||||
run: |
|
||||
sudo apt-get update && sudo apt-get install -y build-essential \
|
||||
|
||||
@@ -45,19 +45,12 @@ jobs:
|
||||
runs-on: ubuntu-latest
|
||||
env:
|
||||
MUJOCO_GL: egl
|
||||
HF_HOME: /mnt/cache/.cache/huggingface
|
||||
HF_LEROBOT_HOME: /mnt/cache/.cache/huggingface/lerobot
|
||||
steps:
|
||||
- uses: actions/checkout@v4
|
||||
with:
|
||||
lfs: true
|
||||
persist-credentials: false
|
||||
|
||||
# NOTE(Steven): Mount to `/mnt` to avoid the limited storage on `/home`. Consider cleaning default SDKs or using self-hosted runners for more space.
|
||||
# (As of 2024-06-10, the runner's `/home` has only 6.2 GB free—8% of its 72 GB total.)
|
||||
- name: Setup /mnt storage
|
||||
run: sudo chown -R $USER:$USER /mnt
|
||||
|
||||
- name: Install apt dependencies
|
||||
run: |
|
||||
sudo apt-get update && sudo apt-get install -y build-essential \
|
||||
|
||||
@@ -0,0 +1,94 @@
|
||||
#!/usr/bin/env python
|
||||
|
||||
# Copyright 2024 The HuggingFace Inc. team. All rights reserved.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
import threading
|
||||
import time
|
||||
from contextlib import ContextDecorator
|
||||
|
||||
|
||||
class TimeBenchmark(ContextDecorator):
    """Measure execution time as a context manager or decorator.

    Thread-safe: timings are stored in ``threading.local``, so concurrent
    threads entering the same instance do not clobber each other's results.

    Args:
        print: If True, print the elapsed time upon exiting the context or
            completing the function. Defaults to False. (The parameter name
            shadows the builtin inside ``__init__`` but is kept for backward
            compatibility.)

    Examples:

        Using as a context manager:

        >>> benchmark = TimeBenchmark()
        >>> with benchmark:
        ...     time.sleep(1)
        >>> print(f"Block took {benchmark.result:.4f} seconds")
        Block took approximately 1.0000 seconds
    """

    def __init__(self, print=False):
        self.local = threading.local()
        self.print_time = print

    def __enter__(self):
        # perf_counter is a monotonic, high-resolution clock suited for interval timing.
        self.local.start_time = time.perf_counter()
        return self

    def __exit__(self, *exc):
        self.local.end_time = time.perf_counter()
        self.local.elapsed_time = self.local.end_time - self.local.start_time
        if self.print_time:
            print(f"Elapsed time: {self.local.elapsed_time:.4f} seconds")
        # Returning False propagates any exception raised inside the block.
        return False

    @property
    def result(self):
        """Elapsed seconds of the last completed block in this thread, or None."""
        return getattr(self.local, "elapsed_time", None)

    @property
    def result_ms(self):
        """Elapsed milliseconds of the last completed block, or None.

        Fix: previously raised ``TypeError`` (``None * 1e3``) when accessed
        before any timed block had completed in the current thread.
        """
        elapsed = self.result
        return None if elapsed is None else elapsed * 1e3
|
||||
Executable
+102
@@ -0,0 +1,102 @@
|
||||
#!/usr/bin/env python
|
||||
|
||||
# Copyright 2024 The HuggingFace Inc. team. All rights reserved.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
"""Capture video feed from a camera as raw images."""
|
||||
|
||||
import argparse
|
||||
import datetime as dt
|
||||
import os
|
||||
import time
|
||||
from pathlib import Path
|
||||
|
||||
import cv2
|
||||
import rerun as rr
|
||||
|
||||
# see https://rerun.io/docs/howto/visualization/limit-ram
|
||||
RERUN_MEMORY_LIMIT = os.getenv("LEROBOT_RERUN_MEMORY_LIMIT", "5%")
|
||||
|
||||
|
||||
def display_and_save_video_stream(output_dir: Path, fps: int, width: int, height: int, duration: int):
    """Stream the default webcam to a Rerun viewer and save each frame as a PNG.

    Args:
        output_dir: Root directory; frames are written into a date/time-stamped subfolder.
        fps: Requested capture frame rate (the camera driver may not honor it).
        width: Requested frame width in pixels.
        height: Requested frame height in pixels.
        duration: How long to capture, in seconds.
    """
    rr.init("lerobot_capture_camera_feed")
    rr.spawn(memory_limit=RERUN_MEMORY_LIMIT)

    now = dt.datetime.now()
    capture_dir = output_dir / f"{now:%Y-%m-%d}" / f"{now:%H-%M-%S}"
    # exist_ok=True already makes this idempotent; no exists() pre-check needed.
    capture_dir.mkdir(parents=True, exist_ok=True)

    # Opens the default webcam
    cap = cv2.VideoCapture(0)
    if not cap.isOpened():
        print("Error: Could not open video stream.")
        return

    # These are requests; the driver may silently fall back to other values.
    cap.set(cv2.CAP_PROP_FPS, fps)
    cap.set(cv2.CAP_PROP_FRAME_WIDTH, width)
    cap.set(cv2.CAP_PROP_FRAME_HEIGHT, height)

    frame_index = 0
    start_time = time.time()
    while time.time() - start_time < duration:
        ret, frame = cap.read()

        if not ret:
            print("Error: Could not read frame.")
            break
        # NOTE(review): OpenCV frames are BGR while rr.Image assumes RGB by default,
        # so viewer colors may look swapped — confirm and convert if needed.
        rr.log("video/stream", rr.Image(frame), static=True)
        cv2.imwrite(str(capture_dir / f"frame_{frame_index:06d}.png"), frame)
        frame_index += 1

    # Release the capture
    cap.release()

    # TODO(Steven): Add a graceful shutdown via a close() method for the Viewer context, though not currently supported in the Rerun API.
|
||||
|
||||
|
||||
if __name__ == "__main__":
    parser = argparse.ArgumentParser()

    # CLI options as (flag, keyword-options) specs; list order drives --help order.
    argument_specs = [
        (
            "--output-dir",
            dict(
                type=Path,
                default=Path("outputs/cam_capture/"),
                help="Directory where the capture images are written. A subfolder named with the current date & time will be created inside it for each capture.",
            ),
        ),
        ("--fps", dict(type=int, default=30, help="Frames Per Second of the capture.")),
        ("--width", dict(type=int, default=1280, help="Width of the captured images.")),
        ("--height", dict(type=int, default=720, help="Height of the captured images.")),
        (
            "--duration",
            dict(
                type=int,
                default=20,
                help="Duration in seconds for which the video stream should be captured.",
            ),
        ),
    ]
    for flag, options in argument_specs:
        parser.add_argument(flag, **options)

    args = parser.parse_args()
    display_and_save_video_stream(**vars(args))
|
||||
@@ -21,13 +21,11 @@ See the provided README.md or run `python benchmark/video/run_video_benchmark.py
|
||||
|
||||
import argparse
|
||||
import datetime as dt
|
||||
import itertools
|
||||
import random
|
||||
import shutil
|
||||
from collections import OrderedDict
|
||||
from concurrent.futures import ThreadPoolExecutor, as_completed
|
||||
from pathlib import Path
|
||||
from threading import Lock
|
||||
|
||||
import einops
|
||||
import numpy as np
|
||||
@@ -37,13 +35,13 @@ import torch
|
||||
from skimage.metrics import mean_squared_error, peak_signal_noise_ratio, structural_similarity
|
||||
from tqdm import tqdm
|
||||
|
||||
from benchmarks.video.benchmark import TimeBenchmark
|
||||
from lerobot.datasets.lerobot_dataset import LeRobotDataset
|
||||
from lerobot.datasets.video_utils import (
|
||||
decode_video_frames,
|
||||
decode_video_frames_torchvision,
|
||||
encode_video_frames,
|
||||
)
|
||||
from lerobot.utils.constants import OBS_IMAGE
|
||||
from lerobot.utils.utils import TimerManager
|
||||
|
||||
BASE_ENCODING = OrderedDict(
|
||||
[
|
||||
@@ -88,7 +86,7 @@ def load_original_frames(imgs_dir: Path, timestamps: list[float], fps: int) -> t
|
||||
frames = []
|
||||
for ts in timestamps:
|
||||
idx = int(ts * fps)
|
||||
frame = PIL.Image.open(imgs_dir / f"frame-{idx:06d}.png")
|
||||
frame = PIL.Image.open(imgs_dir / f"frame_{idx:06d}.png")
|
||||
frame = torch.from_numpy(np.array(frame))
|
||||
frame = frame.type(torch.float32) / 255
|
||||
frame = einops.rearrange(frame, "h w c -> c h w")
|
||||
@@ -99,21 +97,21 @@ def load_original_frames(imgs_dir: Path, timestamps: list[float], fps: int) -> t
|
||||
def save_decoded_frames(
|
||||
imgs_dir: Path, save_dir: Path, frames: torch.Tensor, timestamps: list[float], fps: int
|
||||
) -> None:
|
||||
if save_dir.exists() and len(list(save_dir.glob("frame-*.png"))) == len(timestamps):
|
||||
if save_dir.exists() and len(list(save_dir.glob("frame_*.png"))) == len(timestamps):
|
||||
return
|
||||
|
||||
save_dir.mkdir(parents=True, exist_ok=True)
|
||||
for i, ts in enumerate(timestamps):
|
||||
idx = int(ts * fps)
|
||||
frame_hwc = (frames[i].permute((1, 2, 0)) * 255).type(torch.uint8).cpu().numpy()
|
||||
PIL.Image.fromarray(frame_hwc).save(save_dir / f"frame-{idx:06d}_decoded.png")
|
||||
shutil.copyfile(imgs_dir / f"frame-{idx:06d}.png", save_dir / f"frame-{idx:06d}_original.png")
|
||||
PIL.Image.fromarray(frame_hwc).save(save_dir / f"frame_{idx:06d}_decoded.png")
|
||||
shutil.copyfile(imgs_dir / f"frame_{idx:06d}.png", save_dir / f"frame_{idx:06d}_original.png")
|
||||
|
||||
|
||||
def save_first_episode(imgs_dir: Path, dataset: LeRobotDataset) -> None:
|
||||
episode_index = 0
|
||||
ep_num_images = dataset.meta.episodes["length"][episode_index]
|
||||
if imgs_dir.exists() and len(list(imgs_dir.glob("frame-*.png"))) == ep_num_images:
|
||||
if imgs_dir.exists() and len(list(imgs_dir.glob("frame_*.png"))) == ep_num_images:
|
||||
return
|
||||
|
||||
imgs_dir.mkdir(parents=True, exist_ok=True)
|
||||
@@ -127,7 +125,7 @@ def save_first_episode(imgs_dir: Path, dataset: LeRobotDataset) -> None:
|
||||
tqdm(imgs_dataset, desc=f"saving {dataset.repo_id} first episode images", leave=False)
|
||||
):
|
||||
img = item[img_keys[0]]
|
||||
img.save(str(imgs_dir / f"frame-{i:06d}.png"), quality=100)
|
||||
img.save(str(imgs_dir / f"frame_{i:06d}.png"), quality=100)
|
||||
|
||||
if i >= ep_num_images - 1:
|
||||
break
|
||||
@@ -151,6 +149,18 @@ def sample_timestamps(timestamps_mode: str, ep_num_images: int, fps: int) -> lis
|
||||
return [idx / fps for idx in frame_indexes]
|
||||
|
||||
|
||||
def decode_video_frames(
    video_path: str,
    timestamps: list[float],
    tolerance_s: float,
    backend: str,
) -> torch.Tensor:
    """Decode the frames at `timestamps` from `video_path` via a torchvision backend.

    Raises:
        NotImplementedError: If `backend` is not one of the supported backends.
    """
    # Guard clause: reject unsupported backends up front.
    if backend not in ("pyav", "video_reader"):
        raise NotImplementedError(backend)
    return decode_video_frames_torchvision(video_path, timestamps, tolerance_s, backend)
|
||||
|
||||
|
||||
def benchmark_decoding(
|
||||
imgs_dir: Path,
|
||||
video_path: Path,
|
||||
@@ -162,8 +172,8 @@ def benchmark_decoding(
|
||||
num_workers: int = 4,
|
||||
save_frames: bool = False,
|
||||
) -> dict:
|
||||
def process_sample(sample: int, lock: Lock):
|
||||
time_benchmark = TimerManager(log=False)
|
||||
def process_sample(sample: int):
|
||||
time_benchmark = TimeBenchmark()
|
||||
timestamps = sample_timestamps(timestamps_mode, ep_num_images, fps)
|
||||
num_frames = len(timestamps)
|
||||
result = {
|
||||
@@ -172,13 +182,13 @@ def benchmark_decoding(
|
||||
"mse_values": [],
|
||||
}
|
||||
|
||||
with time_benchmark, lock:
|
||||
with time_benchmark:
|
||||
frames = decode_video_frames(video_path, timestamps=timestamps, tolerance_s=5e-1, backend=backend)
|
||||
result["load_time_video_ms"] = (time_benchmark.last * 1000) / num_frames
|
||||
result["load_time_video_ms"] = time_benchmark.result_ms / num_frames
|
||||
|
||||
with time_benchmark:
|
||||
original_frames = load_original_frames(imgs_dir, timestamps, fps)
|
||||
result["load_time_images_ms"] = (time_benchmark.last * 1000) / num_frames
|
||||
result["load_time_images_ms"] = time_benchmark.result_ms / num_frames
|
||||
|
||||
frames_np, original_frames_np = frames.numpy(), original_frames.numpy()
|
||||
for i in range(num_frames):
|
||||
@@ -205,10 +215,8 @@ def benchmark_decoding(
|
||||
# A sample is a single set of decoded frames specified by timestamps_mode (e.g. a single frame, 2 frames, etc.).
|
||||
# For each sample, we record metrics (loading time and quality metrics) which are then averaged over all samples.
|
||||
# As these samples are independent, we run them in parallel threads to speed up the benchmark.
|
||||
# Use a single shared lock for all worker threads
|
||||
shared_lock = Lock()
|
||||
with ThreadPoolExecutor(max_workers=num_workers) as executor:
|
||||
futures = [executor.submit(process_sample, i, shared_lock) for i in range(num_samples)]
|
||||
futures = [executor.submit(process_sample, i) for i in range(num_samples)]
|
||||
for future in tqdm(as_completed(futures), total=num_samples, desc="samples", leave=False):
|
||||
result = future.result()
|
||||
load_times_video_ms.append(result["load_time_video_ms"])
|
||||
@@ -350,27 +358,24 @@ def main(
|
||||
imgs_dir = output_dir / "images" / dataset.repo_id.replace("/", "_")
|
||||
# We only use the first episode
|
||||
save_first_episode(imgs_dir, dataset)
|
||||
for duet in [
|
||||
dict(zip(encoding_benchmarks.keys(), unique_combination, strict=False))
|
||||
for unique_combination in itertools.product(*encoding_benchmarks.values())
|
||||
]:
|
||||
encoding_cfg = BASE_ENCODING.copy()
|
||||
encoding_cfg["vcodec"] = video_codec
|
||||
encoding_cfg["pix_fmt"] = pixel_format
|
||||
for key, value in duet.items():
|
||||
for key, values in tqdm(encoding_benchmarks.items(), desc="encodings (g, crf)", leave=False):
|
||||
for value in tqdm(values, desc=f"encodings ({key})", leave=False):
|
||||
encoding_cfg = BASE_ENCODING.copy()
|
||||
encoding_cfg["vcodec"] = video_codec
|
||||
encoding_cfg["pix_fmt"] = pixel_format
|
||||
encoding_cfg[key] = value
|
||||
args_path = Path("_".join(str(value) for value in encoding_cfg.values()))
|
||||
video_path = output_dir / "videos" / args_path / f"{repo_id.replace('/', '_')}.mp4"
|
||||
benchmark_table += benchmark_encoding_decoding(
|
||||
dataset,
|
||||
video_path,
|
||||
imgs_dir,
|
||||
encoding_cfg,
|
||||
decoding_benchmarks,
|
||||
num_samples,
|
||||
num_workers,
|
||||
save_frames,
|
||||
)
|
||||
args_path = Path("_".join(str(value) for value in encoding_cfg.values()))
|
||||
video_path = output_dir / "videos" / args_path / f"{repo_id.replace('/', '_')}.mp4"
|
||||
benchmark_table += benchmark_encoding_decoding(
|
||||
dataset,
|
||||
video_path,
|
||||
imgs_dir,
|
||||
encoding_cfg,
|
||||
decoding_benchmarks,
|
||||
num_samples,
|
||||
num_workers,
|
||||
save_frames,
|
||||
)
|
||||
|
||||
# Save intermediate results
|
||||
benchmark_df = pd.DataFrame(benchmark_table, columns=headers)
|
||||
@@ -404,9 +409,9 @@ if __name__ == "__main__":
|
||||
nargs="*",
|
||||
default=[
|
||||
"lerobot/pusht_image",
|
||||
"lerobot/aloha_mobile_shrimp_image",
|
||||
"lerobot/paris_street",
|
||||
"lerobot/kitchen",
|
||||
"aliberts/aloha_mobile_shrimp_image",
|
||||
"aliberts/paris_street",
|
||||
"aliberts/kitchen",
|
||||
],
|
||||
help="Datasets repo-ids to test against. First episodes only are used. Must be images.",
|
||||
)
|
||||
@@ -414,7 +419,7 @@ if __name__ == "__main__":
|
||||
"--vcodec",
|
||||
type=str,
|
||||
nargs="*",
|
||||
default=["h264", "hevc", "libsvtav1"],
|
||||
default=["libx264", "hevc", "libsvtav1"],
|
||||
help="Video codecs to be tested",
|
||||
)
|
||||
parser.add_argument(
|
||||
@@ -463,7 +468,7 @@ if __name__ == "__main__":
|
||||
"--backends",
|
||||
type=str,
|
||||
nargs="*",
|
||||
default=["torchcodec", "pyav"],
|
||||
default=["pyav", "video_reader"],
|
||||
help="Torchvision decoding backend to be tested.",
|
||||
)
|
||||
parser.add_argument(
|
||||
|
||||
@@ -47,8 +47,8 @@
|
||||
- sections:
|
||||
- local: envhub
|
||||
title: Environments from the Hub
|
||||
- local: envhub_leisaac
|
||||
title: Control & Train Robots in Sim (LeIsaac)
|
||||
- local: il_sim
|
||||
title: Imitation Learning in Sim
|
||||
- local: libero
|
||||
title: Using Libero
|
||||
- local: metaworld
|
||||
@@ -79,8 +79,6 @@
|
||||
title: Hope Jr
|
||||
- local: reachy2
|
||||
title: Reachy 2
|
||||
- local: unitree_g1
|
||||
title: Unitree G1
|
||||
title: "Robots"
|
||||
- sections:
|
||||
- local: phone_teleop
|
||||
|
||||
@@ -196,7 +196,7 @@ client_cfg = RobotClientConfig(
|
||||
server_address="localhost:8080",
|
||||
policy_device="mps",
|
||||
policy_type="smolvla",
|
||||
pretrained_name_or_path="<user>/smolvla_async",
|
||||
pretrained_name_or_path="fracapuano/smolvla_async",
|
||||
chunk_size_threshold=0.5,
|
||||
actions_per_chunk=50, # make sure this is less than the max actions of the policy
|
||||
)
|
||||
|
||||
@@ -1,301 +0,0 @@
|
||||
# LeIsaac × LeRobot EnvHub
|
||||
|
||||
LeRobot EnvHub now supports **imitation learning in simulation** with LeIsaac.
|
||||
Spin up everyday manipulation tasks, teleoperate the robot, collect demos, push them to the Hub, and train policies in LeRobot — all in one loop.
|
||||
|
||||
[LeIsaac](https://github.com/LightwheelAI/leisaac) integrates with IsaacLab and the SO101 Leader/Follower setup to provide:
|
||||
|
||||
- 🕹️ **Teleoperation-first workflows** for data collection
|
||||
- 📦 **Built-in data conversion** ready for LeRobot training
|
||||
- 🤖 **Everyday skills** like picking oranges, lifting cubes, cleaning tables, and folding cloth
|
||||
- ☁️ **Ongoing upgrades** from [LightWheel](https://lightwheel.ai/): cloud simulation, EnvHub support, Sim2Real tooling, and more
|
||||
|
||||
Below you’ll find the currently supported LeIsaac tasks exposed through LeRobot EnvHub.
|
||||
|
||||
# Available Environments
|
||||
|
||||
The following table lists all available tasks and environments in LeIsaac x LeRobot Envhub. You can also get the latest list of environments by running the following command:
|
||||
|
||||
```bash
|
||||
python scripts/environments/list_envs.py
|
||||
```
|
||||
|
||||
| Task | Environment ID | Task Description | Related Robot |
|
||||
| :-------------------------------------------------------------------------------------------------------------------------------------------------------------- | :-------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------- | :------------------------------------------------------------------------------------------------------------------------- | :--------------------------------------------------------- |
|
||||
| <video src="https://github.com/user-attachments/assets/466eddff-f720-4f99-94d5-5e123e4c302c" autoplay loop muted playsinline style="max-width: 300px;"></video> | [LeIsaac-SO101-PickOrange-v0](https://github.com/LightwheelAI/leisaac/blob/main/source/leisaac/leisaac/tasks/pick_orange/pick_orange_env_cfg.py)<br /><br />[LeIsaac-SO101-PickOrange-Direct-v0](https://github.com/LightwheelAI/leisaac/blob/main/source/leisaac/leisaac/tasks/pick_orange/direct/pick_orange_env.py) | Pick three oranges and put them into the plate, then reset the arm to rest state. | Single-Arm SO101 Follower |
|
||||
| <video src="https://github.com/user-attachments/assets/1e4eb83a-0b38-40fb-a0b2-ddb0fe201e6d" autoplay loop muted playsinline style="max-width: 300px;"></video> | [LeIsaac-SO101-LiftCube-v0](https://github.com/LightwheelAI/leisaac/blob/main/source/leisaac/leisaac/tasks/lift_cube/lift_cube_env_cfg.py)<br /><br />[LeIsaac-SO101-LiftCube-Direct-v0](https://github.com/LightwheelAI/leisaac/blob/main/source/leisaac/leisaac/tasks/lift_cube/direct/lift_cube_env.py) | Lift the red cube up. | Single-Arm SO101 Follower |
|
||||
| <video src="https://github.com/user-attachments/assets/e49d8f1c-dcc9-412b-a88f-100680d8a45b" autoplay loop muted playsinline style="max-width: 300px;"></video> | [LeIsaac-SO101-CleanToyTable-v0](https://github.com/LightwheelAI/leisaac/blob/main/source/leisaac/leisaac/tasks/clean_toy_table/clean_toy_table_env_cfg.py)<br /><br />[LeIsaac-SO101-CleanToyTable-BiArm-v0](https://github.com/LightwheelAI/leisaac/blob/main/source/leisaac/leisaac/tasks/clean_toy_table/clean_toy_table_bi_arm_env_cfg.py)<br /><br />[LeIsaac-SO101-CleanToyTable-BiArm-Direct-v0](https://github.com/LightwheelAI/leisaac/blob/main/source/leisaac/leisaac/tasks/clean_toy_table/direct/clean_toy_table_bi_arm_env.py) | Pick two letter e objects into the box, and reset the arm to rest state. | Single-Arm SO101 Follower<br /><br />Bi-Arm SO101 Follower |
|
||||
| <video src="https://github.com/user-attachments/assets/e29a0f8a-9286-4ce6-b45d-342c3d3ba754" autoplay loop muted playsinline style="max-width: 300px;"></video> | [LeIsaac-SO101-FoldCloth-BiArm-v0](https://github.com/LightwheelAI/leisaac/blob/main/source/leisaac/leisaac/tasks/fold_cloth/fold_cloth_bi_arm_env_cfg.py)<br /><br />[LeIsaac-SO101-FoldCloth-BiArm-Direct-v0](https://github.com/LightwheelAI/leisaac/blob/main/source/leisaac/leisaac/tasks/fold_cloth/direct/fold_cloth_bi_arm_env.py) | Fold the cloth, and reset the arm to rest state.<br /><br />_Note: Only the DirectEnv support check_success in this task._ | Bi-Arm SO101 Follower |
|
||||
|
||||
# Load LeIsaac directly in LeRobot with one line of code
|
||||
|
||||
> EnvHub: Share LeIsaac environments through HuggingFace
|
||||
|
||||
[EnvHub](https://huggingface.co/docs/lerobot/envhub) is our reproducible environment hub, spin up a packaged simulation with one line, experiment immediately, and publish your own tasks for the community.
|
||||
|
||||
LeIsaac offers EnvHub support so you can consume or share tasks with only a few commands.
|
||||
|
||||
<video
|
||||
controls
|
||||
src="https://github.com/user-attachments/assets/687666f5-ebe0-421d-84a0-eb86116ac5f8"
|
||||
style={{ width: "100%", maxWidth: "960px", borderRadius: "8px" }}
|
||||
/>
|
||||
|
||||
## Getting Started: Environment Setup
|
||||
|
||||
Run the following commands to set up your environment:
|
||||
|
||||
```bash
|
||||
# First, refer to Getting Started/Installation to install leisaac
|
||||
conda create -n leisaac_envhub python=3.11
|
||||
conda activate leisaac_envhub
|
||||
|
||||
conda install -c "nvidia/label/cuda-12.8.1" cuda-toolkit
|
||||
pip install -U torch==2.7.0 torchvision==0.22.0 --index-url https://download.pytorch.org/whl/cu128
|
||||
pip install 'leisaac[isaaclab] @ git+https://github.com/LightwheelAI/leisaac.git#subdirectory=source/leisaac' --extra-index-url https://pypi.nvidia.com
|
||||
|
||||
# Install lerobot
|
||||
pip install lerobot==0.4.1
|
||||
|
||||
# Fix numpy version
|
||||
pip install numpy==1.26.0
|
||||
```
|
||||
|
||||
## Usage Example
|
||||
|
||||
EnvHub exposes every LeIsaac-supported task in a uniform interface. The examples below load `so101_pick_orange` and demonstrate a random-action rollout and an interactive teleoperation.
|
||||
|
||||
### Random Action
|
||||
|
||||
<details>
|
||||
<summary>Click to expand code example</summary>
|
||||
|
||||
```python
|
||||
# envhub_random_action.py
|
||||
|
||||
import torch
|
||||
from lerobot.envs.factory import make_env
|
||||
|
||||
# Load from the hub
|
||||
envs_dict = make_env("LightwheelAI/leisaac_env:envs/so101_pick_orange.py", n_envs=1, trust_remote_code=True)
|
||||
|
||||
# Access the environment
|
||||
suite_name = next(iter(envs_dict))
|
||||
sync_vector_env = envs_dict[suite_name][0]
|
||||
# retrieve the isaac environment from the sync vector env
|
||||
env = sync_vector_env.envs[0].unwrapped
|
||||
|
||||
# Use it like any gym environment
|
||||
obs, info = env.reset()
|
||||
|
||||
while True:
|
||||
action = torch.tensor(env.action_space.sample())
|
||||
obs, reward, terminated, truncated, info = env.step(action)
|
||||
if terminated or truncated:
|
||||
obs, info = env.reset()
|
||||
|
||||
env.close()
|
||||
```
|
||||
|
||||
</details>
|
||||
|
||||
```bash
|
||||
python envhub_random_action.py
|
||||
```
|
||||
|
||||
You should see the SO101 arm swinging under purely random commands.
|
||||
|
||||
### Teleoperation
|
||||
|
||||
LeRobot’s teleoperation stack can drive the simulated arm.
|
||||
|
||||
Connect the SO101 Leader controller, run the calibration command below.
|
||||
|
||||
```bash
|
||||
lerobot-calibrate \
|
||||
--teleop.type=so101_leader \
|
||||
--teleop.port=/dev/ttyACM0 \
|
||||
--teleop.id=leader
|
||||
```
|
||||
|
||||
And then launch the teleop script.
|
||||
|
||||
<details>
|
||||
<summary>Click to expand code example</summary>
|
||||
|
||||
```python
|
||||
# envhub_teleop_example.py
|
||||
|
||||
import logging
|
||||
import time
|
||||
import gymnasium as gym
|
||||
|
||||
from dataclasses import asdict, dataclass
|
||||
from pprint import pformat
|
||||
|
||||
from lerobot.teleoperators import ( # noqa: F401
|
||||
Teleoperator,
|
||||
TeleoperatorConfig,
|
||||
make_teleoperator_from_config,
|
||||
so101_leader,
|
||||
)
|
||||
from lerobot.utils.robot_utils import precise_sleep
|
||||
from lerobot.utils.utils import init_logging
|
||||
from lerobot.envs.factory import make_env
|
||||
|
||||
|
||||
@dataclass
|
||||
class TeleoperateConfig:
|
||||
teleop: TeleoperatorConfig
|
||||
env_name: str = "so101_pick_orange"
|
||||
fps: int = 60
|
||||
|
||||
|
||||
@dataclass
|
||||
class EnvWrap:
|
||||
env: gym.Env
|
||||
|
||||
|
||||
def make_env_from_leisaac(env_name: str = "so101_pick_orange"):
|
||||
envs_dict = make_env(
|
||||
f'LightwheelAI/leisaac_env:envs/{env_name}.py',
|
||||
n_envs=1,
|
||||
trust_remote_code=True
|
||||
)
|
||||
suite_name = next(iter(envs_dict))
|
||||
sync_vector_env = envs_dict[suite_name][0]
|
||||
env = sync_vector_env.envs[0].unwrapped
|
||||
|
||||
return env
|
||||
|
||||
|
||||
def teleop_loop(teleop: Teleoperator, env: gym.Env, fps: int):
|
||||
from leisaac.devices.action_process import preprocess_device_action
|
||||
from leisaac.assets.robots.lerobot import SO101_FOLLOWER_MOTOR_LIMITS
|
||||
from leisaac.utils.env_utils import dynamic_reset_gripper_effort_limit_sim
|
||||
|
||||
env_wrap = EnvWrap(env=env)
|
||||
|
||||
obs, info = env.reset()
|
||||
while True:
|
||||
loop_start = time.perf_counter()
|
||||
if env.cfg.dynamic_reset_gripper_effort_limit:
|
||||
dynamic_reset_gripper_effort_limit_sim(env, 'so101leader')
|
||||
|
||||
raw_action = teleop.get_action()
|
||||
processed_action = preprocess_device_action(
|
||||
dict(
|
||||
so101_leader=True,
|
||||
joint_state={
|
||||
k.removesuffix(".pos"): v for k, v in raw_action.items()},
|
||||
motor_limits=SO101_FOLLOWER_MOTOR_LIMITS),
|
||||
env_wrap
|
||||
)
|
||||
obs, reward, terminated, truncated, info = env.step(processed_action)
|
||||
if terminated or truncated:
|
||||
obs, info = env.reset()
|
||||
|
||||
dt_s = time.perf_counter() - loop_start
|
||||
precise_sleep(1 / fps - dt_s)
|
||||
loop_s = time.perf_counter() - loop_start
|
||||
print(f"\ntime: {loop_s * 1e3:.2f}ms ({1 / loop_s:.0f} Hz)")
|
||||
|
||||
|
||||
def teleoperate(cfg: TeleoperateConfig):
|
||||
init_logging()
|
||||
logging.info(pformat(asdict(cfg)))
|
||||
|
||||
teleop = make_teleoperator_from_config(cfg.teleop)
|
||||
env = make_env_from_leisaac(cfg.env_name)
|
||||
|
||||
teleop.connect()
|
||||
if hasattr(env, 'initialize'):
|
||||
env.initialize()
|
||||
try:
|
||||
teleop_loop(teleop=teleop, env=env, fps=cfg.fps)
|
||||
except KeyboardInterrupt:
|
||||
pass
|
||||
finally:
|
||||
teleop.disconnect()
|
||||
env.close()
|
||||
|
||||
|
||||
def main():
|
||||
teleoperate(TeleoperateConfig(
|
||||
teleop=so101_leader.SO101LeaderConfig(
|
||||
port="/dev/ttyACM0",
|
||||
id='leader',
|
||||
use_degrees=False,
|
||||
),
|
||||
env_name="so101_pick_orange",
|
||||
fps=60,
|
||||
))
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
main()
|
||||
|
||||
```
|
||||
|
||||
</details>
|
||||
|
||||
```bash
|
||||
python envhub_teleop_example.py
|
||||
```
|
||||
|
||||
Running the script lets you operate the simulated arm using the physical Leader device.
|
||||
|
||||
## ☁️ Cloud Simulation (No GPU Required)
|
||||
|
||||
Don’t have a local GPU or the right drivers? No problem! You can run LeIsaac entirely in the cloud with zero setup.
|
||||
LeIsaac works out-of-the-box on **NVIDIA Brev**, giving you a fully configured environment directly in your browser.
|
||||
|
||||
👉 **Start here:** [https://lightwheelai.github.io/leisaac/docs/cloud_simulation/nvidia_brev](https://lightwheelai.github.io/leisaac/docs/cloud_simulation/nvidia_brev)
|
||||
|
||||
Once your instance is deployed, simply open the link for **port 80 (HTTP)** to launch **Visual Studio Code Server** (default password: `password`). From there, you can run simulations, edit code, and visualize IsaacLab environments — all from your web browser.
|
||||
|
||||
**No GPU, no drivers, no local installation. Just click and run.**
|
||||
|
||||
## Additional Notes
|
||||
|
||||
We keep EnvHub coverage aligned with the LeIsaac tasks. Currently supported:
|
||||
|
||||
- `so101_pick_orange`
|
||||
- `so101_lift_cube`
|
||||
- `so101_clean_toytable`
|
||||
- `bi_so101_fold_cloth`
|
||||
|
||||
Switch tasks by targeting a different script when calling `make_env`, for example:
|
||||
|
||||
```python
|
||||
envs_dict_pick_orange = make_env("LightwheelAI/leisaac_env:envs/so101_pick_orange.py", n_envs=1, trust_remote_code=True)
|
||||
envs_dict_lift_cube = make_env("LightwheelAI/leisaac_env:envs/so101_lift_cube.py", n_envs=1, trust_remote_code=True)
|
||||
envs_dict_clean_toytable = make_env("LightwheelAI/leisaac_env:envs/so101_clean_toytable.py", n_envs=1, trust_remote_code=True)
|
||||
envs_dict_fold_cloth = make_env("LightwheelAI/leisaac_env:envs/bi_so101_fold_cloth.py", n_envs=1, trust_remote_code=True)
|
||||
```
|
||||
|
||||
Note: when working with `bi_so101_fold_cloth`, call `initialize()` immediately after retrieving the env before performing any other operations:
|
||||
|
||||
<details>
|
||||
<summary>Click to expand code example</summary>
|
||||
|
||||
```python
|
||||
import torch
|
||||
from lerobot.envs.factory import make_env
|
||||
|
||||
# Load from the hub
|
||||
envs_dict = make_env("LightwheelAI/leisaac_env:envs/bi_so101_fold_cloth.py", n_envs=1, trust_remote_code=True)
|
||||
|
||||
# Access the environment
|
||||
suite_name = next(iter(envs_dict))
|
||||
sync_vector_env = envs_dict[suite_name][0]
|
||||
# retrieve the isaac environment from the sync vector env
|
||||
env = sync_vector_env.envs[0].unwrapped
|
||||
|
||||
# NOTE: initialize() first
|
||||
env.initialize()
|
||||
|
||||
# other operation with env...
|
||||
```
|
||||
|
||||
</details>
|
||||
@@ -393,7 +393,7 @@ import time
|
||||
from lerobot.datasets.lerobot_dataset import LeRobotDataset
|
||||
from lerobot.robots.so100_follower.config_so100_follower import SO100FollowerConfig
|
||||
from lerobot.robots.so100_follower.so100_follower import SO100Follower
|
||||
from lerobot.utils.robot_utils import precise_sleep
|
||||
from lerobot.utils.robot_utils import busy_wait
|
||||
from lerobot.utils.utils import log_say
|
||||
|
||||
episode_idx = 0
|
||||
@@ -415,7 +415,7 @@ for idx in range(dataset.num_frames):
|
||||
}
|
||||
robot.send_action(action)
|
||||
|
||||
precise_sleep(1.0 / dataset.fps - (time.perf_counter() - t0))
|
||||
busy_wait(1.0 / dataset.fps - (time.perf_counter() - t0))
|
||||
|
||||
robot.disconnect()
|
||||
```
|
||||
|
||||
@@ -0,0 +1,220 @@
|
||||
# Imitation Learning in Sim
|
||||
|
||||
This tutorial will explain how to train a neural network to control a robot in simulation with imitation learning.
|
||||
|
||||
**You'll learn:**
|
||||
|
||||
1. How to record a dataset in simulation with [gym-hil](https://github.com/huggingface/gym-hil) and visualize the dataset.
|
||||
2. How to train a policy using your data.
|
||||
3. How to evaluate your policy in simulation and visualize the results.
|
||||
|
||||
For the simulation environment we use the same [repo](https://github.com/huggingface/gym-hil) that is also being used by the Human-In-the-Loop (HIL) reinforcement learning algorithm.
|
||||
This environment is based on [MuJoCo](https://mujoco.org) and allows you to record datasets in LeRobotDataset format.
|
||||
Teleoperation is easiest with a controller like the Logitech F710, but you can also use your keyboard if you are up for the challenge.
|
||||
|
||||
## Installation
|
||||
|
||||
First, install the `gym_hil` package within the LeRobot environment, go to your LeRobot folder and run this command:
|
||||
|
||||
```bash
|
||||
pip install -e ".[hilserl]"
|
||||
```
|
||||
|
||||
## Teleoperate and Record a Dataset
|
||||
|
||||
To use `gym_hil` with LeRobot, you need to use a configuration file. An example config file can be found [here](https://huggingface.co/datasets/lerobot/config_examples/resolve/main/sim_il/env_config.json).
|
||||
|
||||
To teleoperate and collect a dataset, we need to modify this config file. Here's an example configuration for imitation learning data collection:
|
||||
|
||||
```json
|
||||
{
|
||||
"env": {
|
||||
"type": "gym_manipulator",
|
||||
"name": "gym_hil",
|
||||
"task": "PandaPickCubeGamepad-v0",
|
||||
"fps": 10
|
||||
},
|
||||
"dataset": {
|
||||
"repo_id": "your_username/il_gym",
|
||||
"root": null,
|
||||
"task": "pick_cube",
|
||||
"num_episodes_to_record": 30,
|
||||
"replay_episode": null,
|
||||
"push_to_hub": true
|
||||
},
|
||||
"mode": "record",
|
||||
"device": "cuda"
|
||||
}
|
||||
```
|
||||
|
||||
Key configuration points:
|
||||
|
||||
- Set your `repo_id` in the `dataset` section: `"repo_id": "your_username/il_gym"`
|
||||
- Set `num_episodes_to_record: 30` to collect 30 demonstration episodes
|
||||
- Ensure `mode` is set to `"record"`
|
||||
- If you don't have an NVIDIA GPU, change `"device": "cuda"` to `"mps"` for macOS or `"cpu"`
|
||||
- To use keyboard instead of gamepad, change `"task"` to `"PandaPickCubeKeyboard-v0"`
|
||||
|
||||
Then we can run this command to start:
|
||||
|
||||
<hfoptions id="teleop_sim">
|
||||
<hfoption id="Linux">
|
||||
|
||||
```bash
|
||||
python -m lerobot.rl.gym_manipulator --config_path path/to/env_config_gym_hil_il.json
|
||||
```
|
||||
|
||||
</hfoption>
|
||||
<hfoption id="MacOS">
|
||||
|
||||
```bash
|
||||
mjpython -m lerobot.rl.gym_manipulator --config_path path/to/env_config_gym_hil_il.json
|
||||
```
|
||||
|
||||
</hfoption>
|
||||
</hfoptions>
|
||||
|
||||
Once rendered you can teleoperate the robot with the gamepad or keyboard, below you can find the gamepad/keyboard controls.
|
||||
|
||||
Note that to teleoperate the robot you have to hold the "Human Take Over Pause Policy" Button `RB` to enable control!
|
||||
|
||||
**Gamepad Controls**
|
||||
|
||||
<p align="center">
|
||||
<img
|
||||
src="https://huggingface.co/datasets/huggingface/documentation-images/resolve/main/lerobot/gamepad_guide.jpg?raw=true"
|
||||
alt="Figure shows the control mappings on a Logitech gamepad."
|
||||
title="Gamepad Control Mapping"
|
||||
width="100%"
|
||||
></img>
|
||||
</p>
|
||||
<p align="center">
|
||||
<i>Gamepad button mapping for robot control and episode management</i>
|
||||
</p>
|
||||
|
||||
**Keyboard controls**
|
||||
|
||||
For keyboard controls use the `spacebar` to enable control and the following keys to move the robot:
|
||||
|
||||
```bash
|
||||
Arrow keys: Move in X-Y plane
|
||||
Shift and Shift_R: Move in Z axis
|
||||
Right Ctrl and Left Ctrl: Open and close gripper
|
||||
ESC: Exit
|
||||
```
|
||||
|
||||
## Visualize a dataset
|
||||
|
||||
If you uploaded your dataset to the hub you can [visualize your dataset online](https://huggingface.co/spaces/lerobot/visualize_dataset) by copy pasting your repo id.
|
||||
|
||||
<p align="center">
|
||||
<img
|
||||
src="https://huggingface.co/datasets/huggingface/documentation-images/resolve/main/lerobot/dataset_visualizer_sim.png"
|
||||
alt="Figure shows the dataset visualizer"
|
||||
title="Dataset visualization"
|
||||
width="100%"
|
||||
></img>
|
||||
</p>
|
||||
<p align="center">
|
||||
<i>Dataset visualizer</i>
|
||||
</p>
|
||||
|
||||
## Train a policy
|
||||
|
||||
To train a policy to control your robot, use the [`lerobot-train`](https://github.com/huggingface/lerobot/blob/main/src/lerobot/scripts/train.py) script. A few arguments are required. Here is an example command:
|
||||
|
||||
```bash
|
||||
lerobot-train \
|
||||
--dataset.repo_id=${HF_USER}/il_gym \
|
||||
--policy.type=act \
|
||||
--output_dir=outputs/train/il_sim_test \
|
||||
--job_name=il_sim_test \
|
||||
--policy.device=cuda \
|
||||
--wandb.enable=true
|
||||
```
|
||||
|
||||
Let's explain the command:
|
||||
|
||||
1. We provided the dataset as argument with `--dataset.repo_id=${HF_USER}/il_gym`.
|
||||
2. We provided the policy with `policy.type=act`. This loads configurations from [`configuration_act.py`](https://github.com/huggingface/lerobot/blob/main/src/lerobot/policies/act/configuration_act.py). Importantly, this policy will automatically adapt to the number of motor states, motor actions and cameras of your robot (e.g. `laptop` and `phone`) which have been saved in your dataset.
|
||||
3. We provided `policy.device=cuda` since we are training on a Nvidia GPU, but you could use `policy.device=mps` to train on Apple silicon.
|
||||
4. We provided `wandb.enable=true` to use [Weights and Biases](https://docs.wandb.ai/quickstart) for visualizing training plots. This is optional but if you use it, make sure you are logged in by running `wandb login`.
|
||||
|
||||
Training should take several hours, 100k steps (which is the default) will take about 1h on Nvidia A100. You will find checkpoints in `outputs/train/il_sim_test/checkpoints`.
|
||||
|
||||
#### Train using Colab
|
||||
|
||||
If your local computer doesn't have a powerful GPU, you can use Google Colab to train your model by following the [ACT training notebook](./notebooks#training-act).
|
||||
|
||||
#### Upload policy checkpoints
|
||||
|
||||
Once training is done, upload the latest checkpoint with:
|
||||
|
||||
```bash
|
||||
huggingface-cli upload ${HF_USER}/il_sim_test \
|
||||
outputs/train/il_sim_test/checkpoints/last/pretrained_model
|
||||
```
|
||||
|
||||
You can also upload intermediate checkpoints with:
|
||||
|
||||
```bash
|
||||
CKPT=010000
|
||||
huggingface-cli upload ${HF_USER}/il_sim_test${CKPT} \
|
||||
outputs/train/il_sim_test/checkpoints/${CKPT}/pretrained_model
|
||||
```
|
||||
|
||||
## Evaluate your policy in Sim
|
||||
|
||||
To evaluate your policy we have to use a configuration file. An example can be found [here](https://huggingface.co/datasets/lerobot/config_examples/resolve/main/sim_il/eval_config.json).
|
||||
|
||||
Here's an example evaluation configuration:
|
||||
|
||||
```json
|
||||
{
|
||||
"env": {
|
||||
"type": "gym_manipulator",
|
||||
"name": "gym_hil",
|
||||
"task": "PandaPickCubeGamepad-v0",
|
||||
"fps": 10
|
||||
},
|
||||
"dataset": {
|
||||
"repo_id": "your_username/il_sim_dataset",
|
||||
"dataset_root": null,
|
||||
"task": "pick_cube"
|
||||
},
|
||||
"pretrained_policy_name_or_path": "your_username/il_sim_model",
|
||||
"device": "cuda"
|
||||
}
|
||||
```
|
||||
|
||||
Make sure to replace:
|
||||
|
||||
- `repo_id` with the dataset you trained on (e.g., `your_username/il_sim_dataset`)
|
||||
- `pretrained_policy_name_or_path` with your model ID (e.g., `your_username/il_sim_model`)
|
||||
|
||||
Then you can run this command to visualize your trained policy:
|
||||
|
||||
<hfoptions id="eval_policy">
|
||||
<hfoption id="Linux">
|
||||
|
||||
```bash
|
||||
python -m lerobot.rl.eval_policy --config_path=path/to/eval_config_gym_hil.json
|
||||
```
|
||||
|
||||
</hfoption>
|
||||
<hfoption id="MacOS">
|
||||
|
||||
```bash
|
||||
mjpython -m lerobot.rl.eval_policy --config_path=path/to/eval_config_gym_hil.json
|
||||
```
|
||||
|
||||
</hfoption>
|
||||
</hfoptions>
|
||||
|
||||
> [!WARNING]
|
||||
> While the main workflow of training ACT in simulation is straightforward, there is significant room for exploring how to set up the task, define the initial state of the environment, and determine the type of data required during collection to learn the most effective policy. If your trained policy doesn't perform well, investigate the quality of the dataset it was trained on using our visualizers, as well as the action values and various hyperparameters related to ACT and the simulation.
|
||||
|
||||
Congrats 🎉, you have finished this tutorial. If you want to continue with using LeRobot in simulation follow this [Tutorial on reinforcement learning in sim with HIL-SERL](https://huggingface.co/docs/lerobot/hilserl_sim)
|
||||
|
||||
> [!TIP]
|
||||
> If you have any questions or need help, please reach out on [Discord](https://discord.com/invite/s3KuuzsPFb).
|
||||
@@ -1,203 +0,0 @@
|
||||
# Unitree G1 Robot Setup and Control
|
||||
|
||||
This guide covers the complete setup process for the Unitree G1 humanoid, from initial connection to running gr00t_wbc locomotion.
|
||||
|
||||
## About the Unitree G1
|
||||
|
||||
We offer support for both 29 and 23 DOF G1. In this first PR we introduce:
|
||||
|
||||
- **`unitree g1` robot class, handling low-level communication with the humanoid**
|
||||
- **ZMQ socket bridge** for remote communication over WiFi, allowing one to deploy policies remotely instead of over ethernet or directly on the Orin
|
||||
- **GR00T locomotion policy** for bipedal walking and balance
|
||||
|
||||
---
|
||||
|
||||
## Part 1: Connect to Robot over Ethernet
|
||||
|
||||
### Step 1: Configure Your Computer's Ethernet Interface
|
||||
|
||||
Set a static IP on the same subnet as the robot:
|
||||
|
||||
```bash
|
||||
# Replace 'enp131s0' with your ethernet interface name (check with `ip a`)
|
||||
sudo ip addr flush dev enp131s0
|
||||
sudo ip addr add 192.168.123.200/24 dev enp131s0
|
||||
sudo ip link set enp131s0 up
|
||||
```
|
||||
|
||||
**Note**: The robot's Ethernet IP is fixed at `192.168.123.164`. Your computer must use `192.168.123.x` where x ≠ 164.
|
||||
|
||||
### Step 2: SSH into the Robot
|
||||
|
||||
```bash
|
||||
ssh unitree@192.168.123.164
|
||||
# Password: 123
|
||||
```
|
||||
|
||||
You should now be connected to the robot's onboard computer.
|
||||
|
||||
---
|
||||
|
||||
## Part 2: Enable WiFi on the Robot
|
||||
|
||||
Once connected via Ethernet, follow these steps to enable WiFi:
|
||||
|
||||
### Step 1: Enable WiFi Hardware
|
||||
|
||||
```bash
|
||||
# Unblock WiFi radio
|
||||
sudo rfkill unblock wifi
|
||||
sudo rfkill unblock all
|
||||
|
||||
# Bring up WiFi interface
|
||||
sudo ip link set wlan0 up
|
||||
|
||||
# Enable NetworkManager control
|
||||
sudo nmcli radio wifi on
|
||||
sudo nmcli device set wlan0 managed yes
|
||||
sudo systemctl restart NetworkManager
|
||||
```
|
||||
|
||||
### Step 2: Enable Internet Forwarding
|
||||
|
||||
**On your laptop:**
|
||||
|
||||
```bash
|
||||
# Enable IP forwarding
|
||||
sudo sysctl -w net.ipv4.ip_forward=1
|
||||
|
||||
# Set up NAT (replace wlp132s0f0 with your WiFi interface)
|
||||
sudo iptables -t nat -A POSTROUTING -o wlp132s0f0 -s 192.168.123.0/24 -j MASQUERADE
|
||||
sudo iptables -A FORWARD -i wlp132s0f0 -o enp131s0 -m state --state RELATED,ESTABLISHED -j ACCEPT
|
||||
sudo iptables -A FORWARD -i enp131s0 -o wlp132s0f0 -j ACCEPT
|
||||
```
|
||||
|
||||
**On the robot:**
|
||||
|
||||
```bash
|
||||
# Add laptop as default gateway
|
||||
sudo ip route del default 2>/dev/null || true
|
||||
sudo ip route add default via 192.168.123.200 dev eth0
|
||||
echo "nameserver 8.8.8.8" | sudo tee /etc/resolv.conf
|
||||
|
||||
# Test connection
|
||||
ping -c 3 8.8.8.8
|
||||
```
|
||||
|
||||
### Step 3: Connect to WiFi Network
|
||||
|
||||
```bash
|
||||
# List available networks
|
||||
nmcli device wifi list
|
||||
|
||||
# Connect to your WiFi (example)
|
||||
sudo nmcli connection add type wifi ifname wlan0 con-name "YourNetwork" ssid "YourNetwork"
|
||||
sudo nmcli connection modify "YourNetwork" wifi-sec.key-mgmt wpa-psk
|
||||
sudo nmcli connection modify "YourNetwork" wifi-sec.psk "YourPassword"
|
||||
sudo nmcli connection modify "YourNetwork" connection.autoconnect yes
|
||||
sudo nmcli connection up "YourNetwork"
|
||||
|
||||
# Check WiFi IP address
|
||||
ip a show wlan0
|
||||
```
|
||||
|
||||
### Step 4: SSH Over WiFi
|
||||
|
||||
Once connected to WiFi, note the robot's IP address and disconnect the Ethernet cable. You can now SSH over WiFi:
|
||||
|
||||
```bash
|
||||
ssh unitree@<YOUR_ROBOT_IP>
|
||||
# Password: 123
|
||||
```
|
||||
|
||||
Replace `<YOUR_ROBOT_IP>` with your robot's actual WiFi IP address (e.g., `172.18.129.215`).
|
||||
|
||||
---
|
||||
|
||||
## Part 3: Robot Server Setup
|
||||
|
||||
### Step 1: Install LeRobot on the Orin
|
||||
|
||||
SSH into the robot and install LeRobot:
|
||||
|
||||
```bash
|
||||
ssh unitree@<YOUR_ROBOT_IP>
|
||||
|
||||
conda create -y -n lerobot python=3.10
|
||||
conda activate lerobot
|
||||
git clone https://github.com/huggingface/lerobot.git
|
||||
cd lerobot
|
||||
pip install -e '.[unitree_g1]'
|
||||
git clone https://github.com/unitreerobotics/unitree_sdk2_python.git
|
||||
cd unitree_sdk2_python && pip install -e .
|
||||
```
|
||||
|
||||
**Note**: The Unitree SDK requires CycloneDDS v0.10.2 to be installed. See the [Unitree SDK documentation](https://github.com/unitreerobotics/unitree_sdk2_python) for details.
|
||||
|
||||
### Step 2: Run the Robot Server
|
||||
|
||||
On the robot:
|
||||
|
||||
```bash
|
||||
python src/lerobot/robots/unitree_g1/run_g1_server.py
|
||||
```
|
||||
|
||||
**Important**: Keep this terminal running. The server must be active for remote control.
|
||||
|
||||
---
|
||||
|
||||
## Part 4: Running GR00T Locomotion
|
||||
|
||||
With the robot server running, you can now control the robot from your laptop.
|
||||
|
||||
### Step 1: Install LeRobot on your machine
|
||||
|
||||
```bash
|
||||
conda create -y -n lerobot python=3.10
|
||||
conda activate lerobot
|
||||
git clone https://github.com/huggingface/lerobot.git
|
||||
cd lerobot
|
||||
pip install -e '.[unitree_g1]'
|
||||
git clone https://github.com/unitreerobotics/unitree_sdk2_python.git
|
||||
cd unitree_sdk2_python && pip install -e .
|
||||
```
|
||||
|
||||
### Step 2: Update Robot IP in Config
|
||||
|
||||
Edit the config file to match your robot's WiFi IP:
|
||||
|
||||
```python
|
||||
# In src/lerobot/robots/unitree_g1/config_unitree_g1.py
|
||||
robot_ip: str = "<YOUR_ROBOT_IP>" # Replace with your robot's WiFi IP.
|
||||
```
|
||||
|
||||
**Note**: When running directly on the G1 (not remotely), set `robot_ip: str = "127.0.0.1"` instead.
|
||||
|
||||
### Step 3: Run the Locomotion Policy
|
||||
|
||||
```bash
|
||||
# Run GR00T locomotion controller
|
||||
python examples/unitree_g1/gr00t_locomotion.py --repo-id "nepyope/GR00T-WholeBodyControl_g1"
|
||||
```
|
||||
|
||||
### Step 4: Control with Remote
|
||||
|
||||
- **Left stick**: Forward/backward and left/right movement
|
||||
- **Right stick**: Rotation
|
||||
- **R1 button**: Raise waist height
|
||||
- **R2 button**: Lower waist height
|
||||
|
||||
Press `Ctrl+C` to stop the policy.
|
||||
|
||||
---
|
||||
|
||||
## Additional Resources
|
||||
|
||||
- [Unitree SDK Documentation](https://github.com/unitreerobotics/unitree_sdk2_python)
|
||||
- [GR00T Policy Repository](https://huggingface.co/nepyope/GR00T-WholeBodyControl_g1)
|
||||
- [LeRobot Documentation](https://github.com/huggingface/lerobot)
|
||||
- [Unitree_IL_Lerobot](https://github.com/unitreerobotics/unitree_IL_lerobot)
|
||||
|
||||
---
|
||||
|
||||
_Last updated: December 2025_
|
||||
@@ -45,7 +45,7 @@ from lerobot.robots import ( # noqa: F401
|
||||
so101_follower,
|
||||
)
|
||||
from lerobot.utils.constants import ACTION
|
||||
from lerobot.utils.robot_utils import precise_sleep
|
||||
from lerobot.utils.robot_utils import busy_wait
|
||||
from lerobot.utils.utils import (
|
||||
init_logging,
|
||||
log_say,
|
||||
@@ -97,7 +97,7 @@ def replay(cfg: ReplayConfig):
|
||||
robot.send_action(action)
|
||||
|
||||
dt_s = time.perf_counter() - start_episode_t
|
||||
precise_sleep(1 / dataset.fps - dt_s)
|
||||
busy_wait(1 / dataset.fps - dt_s)
|
||||
|
||||
robot.disconnect()
|
||||
|
||||
|
||||
@@ -34,106 +34,105 @@ from huggingface_hub import HfApi
|
||||
import lerobot
|
||||
from lerobot.datasets.lerobot_dataset import LeRobotDataset, LeRobotDatasetMetadata
|
||||
|
||||
# We ported a number of existing datasets ourselves, use this to see the list:
|
||||
print("List of available datasets:")
|
||||
pprint(lerobot.available_datasets)
|
||||
|
||||
def main():
|
||||
# We ported a number of existing datasets ourselves, use this to see the list:
|
||||
print("List of available datasets:")
|
||||
pprint(lerobot.available_datasets)
|
||||
# You can also browse through the datasets created/ported by the community on the hub using the hub api:
|
||||
hub_api = HfApi()
|
||||
repo_ids = [info.id for info in hub_api.list_datasets(task_categories="robotics", tags=["LeRobot"])]
|
||||
pprint(repo_ids)
|
||||
|
||||
# You can also browse through the datasets created/ported by the community on the hub using the hub api:
|
||||
hub_api = HfApi()
|
||||
repo_ids = [info.id for info in hub_api.list_datasets(task_categories="robotics", tags=["LeRobot"])]
|
||||
pprint(repo_ids)
|
||||
# Or simply explore them in your web browser directly at:
|
||||
# https://huggingface.co/datasets?other=LeRobot
|
||||
|
||||
# Or simply explore them in your web browser directly at:
|
||||
# https://huggingface.co/datasets?other=LeRobot
|
||||
# Let's take this one for this example
|
||||
repo_id = "lerobot/aloha_mobile_cabinet"
|
||||
# We can have a look and fetch its metadata to know more about it:
|
||||
ds_meta = LeRobotDatasetMetadata(repo_id)
|
||||
|
||||
# Let's take this one for this example
|
||||
repo_id = "lerobot/aloha_mobile_cabinet"
|
||||
# We can have a look and fetch its metadata to know more about it:
|
||||
ds_meta = LeRobotDatasetMetadata(repo_id)
|
||||
# By instantiating just this class, you can quickly access useful information about the content and the
|
||||
# structure of the dataset without downloading the actual data yet (only metadata files — which are
|
||||
# lightweight).
|
||||
print(f"Total number of episodes: {ds_meta.total_episodes}")
|
||||
print(f"Average number of frames per episode: {ds_meta.total_frames / ds_meta.total_episodes:.3f}")
|
||||
print(f"Frames per second used during data collection: {ds_meta.fps}")
|
||||
print(f"Robot type: {ds_meta.robot_type}")
|
||||
print(f"keys to access images from cameras: {ds_meta.camera_keys=}\n")
|
||||
|
||||
# By instantiating just this class, you can quickly access useful information about the content and the
|
||||
# structure of the dataset without downloading the actual data yet (only metadata files — which are
|
||||
# lightweight).
|
||||
print(f"Total number of episodes: {ds_meta.total_episodes}")
|
||||
print(f"Average number of frames per episode: {ds_meta.total_frames / ds_meta.total_episodes:.3f}")
|
||||
print(f"Frames per second used during data collection: {ds_meta.fps}")
|
||||
print(f"Robot type: {ds_meta.robot_type}")
|
||||
print(f"keys to access images from cameras: {ds_meta.camera_keys=}\n")
|
||||
print("Tasks:")
|
||||
print(ds_meta.tasks)
|
||||
print("Features:")
|
||||
pprint(ds_meta.features)
|
||||
|
||||
print("Tasks:")
|
||||
print(ds_meta.tasks)
|
||||
print("Features:")
|
||||
pprint(ds_meta.features)
|
||||
# You can also get a short summary by simply printing the object:
|
||||
print(ds_meta)
|
||||
|
||||
# You can also get a short summary by simply printing the object:
|
||||
print(ds_meta)
|
||||
# You can then load the actual dataset from the hub.
|
||||
# Either load any subset of episodes:
|
||||
dataset = LeRobotDataset(repo_id, episodes=[0, 10, 11, 23])
|
||||
|
||||
# You can then load the actual dataset from the hub.
|
||||
# Either load any subset of episodes:
|
||||
dataset = LeRobotDataset(repo_id, episodes=[0, 10, 11, 23])
|
||||
# And see how many frames you have:
|
||||
print(f"Selected episodes: {dataset.episodes}")
|
||||
print(f"Number of episodes selected: {dataset.num_episodes}")
|
||||
print(f"Number of frames selected: {dataset.num_frames}")
|
||||
|
||||
# And see how many frames you have:
|
||||
print(f"Selected episodes: {dataset.episodes}")
|
||||
print(f"Number of episodes selected: {dataset.num_episodes}")
|
||||
print(f"Number of frames selected: {dataset.num_frames}")
|
||||
# Or simply load the entire dataset:
|
||||
dataset = LeRobotDataset(repo_id)
|
||||
print(f"Number of episodes selected: {dataset.num_episodes}")
|
||||
print(f"Number of frames selected: {dataset.num_frames}")
|
||||
|
||||
# Or simply load the entire dataset:
|
||||
dataset = LeRobotDataset(repo_id)
|
||||
print(f"Number of episodes selected: {dataset.num_episodes}")
|
||||
print(f"Number of frames selected: {dataset.num_frames}")
|
||||
# The previous metadata class is contained in the 'meta' attribute of the dataset:
|
||||
print(dataset.meta)
|
||||
|
||||
# The previous metadata class is contained in the 'meta' attribute of the dataset:
|
||||
print(dataset.meta)
|
||||
# LeRobotDataset actually wraps an underlying Hugging Face dataset
|
||||
# (see https://huggingface.co/docs/datasets for more information).
|
||||
print(dataset.hf_dataset)
|
||||
|
||||
# LeRobotDataset actually wraps an underlying Hugging Face dataset
|
||||
# (see https://huggingface.co/docs/datasets for more information).
|
||||
print(dataset.hf_dataset)
|
||||
# LeRobot datasets also subclass PyTorch datasets, so you can do everything you know and love from working
|
||||
# with the latter, like iterating through the dataset.
|
||||
# The __getitem__ iterates over the frames of the dataset. Since our datasets are also structured by
|
||||
# episodes, you can access the frame indices of any episode using dataset.meta.episodes. Here, we access
|
||||
# frame indices associated to the first episode:
|
||||
episode_index = 0
|
||||
from_idx = dataset.meta.episodes["dataset_from_index"][episode_index]
|
||||
to_idx = dataset.meta.episodes["dataset_to_index"][episode_index]
|
||||
|
||||
    # LeRobot datasets also subclass PyTorch datasets, so you can do everything you know and love from working
|
||||
# with the latter, like iterating through the dataset.
|
||||
# The __getitem__ iterates over the frames of the dataset. Since our datasets are also structured by
|
||||
# episodes, you can access the frame indices of any episode using dataset.meta.episodes. Here, we access
|
||||
# frame indices associated to the first episode:
|
||||
episode_index = 0
|
||||
from_idx = dataset.meta.episodes["dataset_from_index"][episode_index]
|
||||
to_idx = dataset.meta.episodes["dataset_to_index"][episode_index]
|
||||
# Then we grab all the image frames from the first camera:
|
||||
camera_key = dataset.meta.camera_keys[0]
|
||||
frames = [dataset[idx][camera_key] for idx in range(from_idx, to_idx)]
|
||||
|
||||
# Then we grab all the image frames from the first camera:
|
||||
camera_key = dataset.meta.camera_keys[0]
|
||||
frames = [dataset[idx][camera_key] for idx in range(from_idx, to_idx)]
|
||||
# The objects returned by the dataset are all torch.Tensors
|
||||
print(type(frames[0]))
|
||||
print(frames[0].shape)
|
||||
|
||||
# The objects returned by the dataset are all torch.Tensors
|
||||
print(type(frames[0]))
|
||||
print(frames[0].shape)
|
||||
# Since we're using pytorch, the shape is in pytorch, channel-first convention (c, h, w).
|
||||
# We can compare this shape with the information available for that feature
|
||||
pprint(dataset.features[camera_key])
|
||||
# In particular:
|
||||
print(dataset.features[camera_key]["shape"])
|
||||
# The shape is in (h, w, c) which is a more universal format.
|
||||
|
||||
# Since we're using pytorch, the shape is in pytorch, channel-first convention (c, h, w).
|
||||
# We can compare this shape with the information available for that feature
|
||||
pprint(dataset.features[camera_key])
|
||||
# In particular:
|
||||
print(dataset.features[camera_key]["shape"])
|
||||
# The shape is in (h, w, c) which is a more universal format.
|
||||
# For many machine learning applications we need to load the history of past observations or trajectories of
|
||||
# future actions. Our datasets can load previous and future frames for each key/modality, using timestamps
|
||||
# differences with the current loaded frame. For instance:
|
||||
delta_timestamps = {
|
||||
# loads 4 images: 1 second before current frame, 500 ms before, 200 ms before, and current frame
|
||||
camera_key: [-1, -0.5, -0.20, 0],
|
||||
# loads 6 state vectors: 1.5 seconds before, 1 second before, ... 200 ms, 100 ms, and current frame
|
||||
"observation.state": [-1.5, -1, -0.5, -0.20, -0.10, 0],
|
||||
# loads 64 action vectors: current frame, 1 frame in the future, 2 frames, ... 63 frames in the future
|
||||
"action": [t / dataset.fps for t in range(64)],
|
||||
}
|
||||
# Note that in any case, these delta_timestamps values need to be multiples of (1/fps) so that added to any
|
||||
# timestamp, you still get a valid timestamp.
|
||||
|
||||
# For many machine learning applications we need to load the history of past observations or trajectories of
|
||||
# future actions. Our datasets can load previous and future frames for each key/modality, using timestamps
|
||||
# differences with the current loaded frame. For instance:
|
||||
delta_timestamps = {
|
||||
# loads 4 images: 1 second before current frame, 500 ms before, 200 ms before, and current frame
|
||||
camera_key: [-1, -0.5, -0.20, 0],
|
||||
# loads 6 state vectors: 1.5 seconds before, 1 second before, ... 200 ms, 100 ms, and current frame
|
||||
"observation.state": [-1.5, -1, -0.5, -0.20, -0.10, 0],
|
||||
# loads 64 action vectors: current frame, 1 frame in the future, 2 frames, ... 63 frames in the future
|
||||
"action": [t / dataset.fps for t in range(64)],
|
||||
}
|
||||
# Note that in any case, these delta_timestamps values need to be multiples of (1/fps) so that added to any
|
||||
# timestamp, you still get a valid timestamp.
|
||||
|
||||
dataset = LeRobotDataset(repo_id, delta_timestamps=delta_timestamps)
|
||||
print(f"\n{dataset[0][camera_key].shape=}") # (4, c, h, w)
|
||||
print(f"{dataset[0]['observation.state'].shape=}") # (6, c)
|
||||
print(f"{dataset[0]['action'].shape=}\n") # (64, c)
|
||||
dataset = LeRobotDataset(repo_id, delta_timestamps=delta_timestamps)
|
||||
print(f"\n{dataset[0][camera_key].shape=}") # (4, c, h, w)
|
||||
print(f"{dataset[0]['observation.state'].shape=}") # (6, c)
|
||||
print(f"{dataset[0]['action'].shape=}\n") # (64, c)
|
||||
|
||||
if __name__ == "__main__":
|
||||
dataloader = torch.utils.data.DataLoader(
|
||||
dataset,
|
||||
num_workers=4,
|
||||
@@ -145,7 +144,3 @@ def main():
|
||||
print(f"{batch['observation.state'].shape=}") # (32, 6, c)
|
||||
print(f"{batch['action'].shape=}") # (32, 64, c)
|
||||
break
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
main()
|
||||
|
||||
File diff suppressed because it is too large
Load Diff
@@ -0,0 +1,525 @@
|
||||
#!/usr/bin/env python
|
||||
|
||||
# Copyright 2025 The HuggingFace Inc. team. All rights reserved.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
"""
|
||||
Visualize SARM Subtask Annotations
|
||||
|
||||
This script creates visualizations of the subtask annotations generated by subtask_annotation.py.
|
||||
For each episode, it shows:
|
||||
- A timeline with dashed vertical lines at subtask boundaries
|
||||
- Sample frames from the episode at key points (start, middle, end of each subtask)
|
||||
- Color-coded subtask segments
|
||||
|
||||
Usage:
|
||||
python visualize_subtask_annotations.py --repo-id pepijn223/mydataset --video-key observation.images.top --num-episodes 5
|
||||
"""
|
||||
|
||||
import argparse
|
||||
import random
|
||||
from pathlib import Path
|
||||
|
||||
import cv2
|
||||
import matplotlib.pyplot as plt
|
||||
import matplotlib.patches as mpatches
|
||||
import numpy as np
|
||||
import pandas as pd
|
||||
from matplotlib.lines import Line2D
|
||||
from rich.console import Console
|
||||
|
||||
from lerobot.datasets.lerobot_dataset import LeRobotDataset
|
||||
from lerobot.datasets.utils import load_episodes
|
||||
from lerobot.policies.sarm.sarm_utils import SubtaskAnnotation, Subtask, Timestamp
|
||||
|
||||
|
||||
def timestamp_to_seconds(timestamp: str) -> float:
    """Convert a colon-separated timestamp string to seconds.

    Accepts "SS", "MM:SS", or "HH:MM:SS". Previously a three-part
    timestamp was silently miscomputed by treating only the first two
    parts as MM:SS; folding each part with a base-60 carry handles any
    number of components while returning identical results for the
    original "SS" and "MM:SS" inputs.

    Args:
        timestamp: Timestamp string, e.g. "45", "01:30", or "1:02:03".

    Returns:
        Total number of seconds.

    Raises:
        ValueError: If any component is not an integer string.
    """
    total = 0
    for part in timestamp.split(":"):
        # Each colon shifts the accumulated value up one base-60 unit.
        total = total * 60 + int(part)
    return total
|
||||
|
||||
|
||||
def load_annotations_from_dataset(dataset_path: Path) -> dict[int, SubtaskAnnotation]:
    """
    Load subtask annotations from a LeRobot dataset's episode metadata.

    Reads the episodes metadata parquet files and rebuilds one
    SubtaskAnnotation per annotated episode, keyed by episode index.
    Returns an empty dict when the metadata is missing, empty, or was
    never annotated (no "subtask_names" column).
    """
    episodes_dataset = load_episodes(dataset_path)

    # Guard clauses: nothing to reconstruct without annotated metadata.
    if episodes_dataset is None or len(episodes_dataset) == 0:
        return {}
    if "subtask_names" not in episodes_dataset.column_names:
        return {}

    # Pandas makes per-episode row access simpler than the raw dataset.
    episodes_df = episodes_dataset.to_pandas()

    def _to_mmss(seconds) -> str:
        # Stored times are seconds; re-encode as zero-padded MM:SS strings.
        sec = int(seconds)
        return f"{sec // 60:02d}:{sec % 60:02d}"

    annotations: dict[int, SubtaskAnnotation] = {}

    for ep_idx in episodes_df.index:
        names = episodes_df.loc[ep_idx, "subtask_names"]

        # Unannotated episodes store None (or NaN after pandas conversion).
        if names is None or (isinstance(names, float) and pd.isna(names)):
            continue

        starts = episodes_df.loc[ep_idx, "subtask_start_times"]
        ends = episodes_df.loc[ep_idx, "subtask_end_times"]

        subtasks = [
            Subtask(
                name=name,
                timestamps=Timestamp(start=_to_mmss(starts[i]), end=_to_mmss(ends[i])),
            )
            for i, name in enumerate(names)
        ]
        annotations[int(ep_idx)] = SubtaskAnnotation(subtasks=subtasks)

    return annotations
|
||||
|
||||
|
||||
# Color palette for subtasks (colorblind-friendly).
# NOTE(review): these hex values appear to match the Okabe-Ito palette — confirm.
# Consumers index this with `i % len(SUBTASK_COLORS)`, so colors repeat when
# an episode has more than eight subtasks.
SUBTASK_COLORS = [
    "#E69F00",  # Orange
    "#56B4E9",  # Sky blue
    "#009E73",  # Bluish green
    "#F0E442",  # Yellow
    "#0072B2",  # Blue
    "#D55E00",  # Vermillion
    "#CC79A7",  # Reddish purple
    "#999999",  # Gray
]
|
||||
|
||||
|
||||
def extract_frame_from_video(video_path: Path, timestamp: float) -> np.ndarray | None:
    """Grab a single frame of a video at a given time.

    Args:
        video_path: Path to the video file on disk.
        timestamp: Position in the video, in seconds.

    Returns:
        The decoded frame as an RGB array, or None if the video could not
        be opened or no frame could be read at that position.
    """
    capture = cv2.VideoCapture(str(video_path))
    if not capture.isOpened():
        return None

    # Seek by millisecond offset, then decode exactly one frame.
    capture.set(cv2.CAP_PROP_POS_MSEC, timestamp * 1000)
    success, bgr_frame = capture.read()
    capture.release()

    if not success:
        return None

    # OpenCV decodes to BGR; convert so downstream display shows true colors.
    return cv2.cvtColor(bgr_frame, cv2.COLOR_BGR2RGB)
|
||||
|
||||
|
||||
def visualize_episode(
    episode_idx: int,
    annotation,
    video_path: Path,
    video_start_timestamp: float,
    video_end_timestamp: float,
    fps: int,
    output_path: Path,
    video_key: str,
):
    """
    Create a visualization for a single episode and save it as a PNG.

    Shows:
    - Top row: one sample frame per subtask (taken from the subtask midpoint)
    - Bottom: a timeline with color-coded subtask segments and dashed
      boundary lines

    Args:
        episode_idx: Episode index, used in the figure title and logs.
        annotation: Annotation object exposing a `.subtasks` list, each with
            `.name` and `.timestamps.start/.end` ("MM:SS" strings).
        video_path: Path to the video file containing this episode.
        video_start_timestamp: Episode start offset within the video file (s).
        video_end_timestamp: Episode end offset within the video file (s).
        fps: Dataset frame rate. Currently unused; kept for interface
            stability with callers.
        output_path: Destination PNG path.
        video_key: Camera key, shown in the subtitle.

    Returns:
        `output_path` on success, or None when the annotation has no subtasks.
    """
    subtasks = annotation.subtasks
    num_subtasks = len(subtasks)

    if num_subtasks == 0:
        print(f"No subtasks found for episode {episode_idx}")
        return

    # Episode duration within the (possibly chunked) video file.
    episode_duration = video_end_timestamp - video_start_timestamp

    # Extract one sample frame per subtask, taken from its midpoint.
    sample_frames = []
    frame_timestamps = []

    for subtask in subtasks:
        start_sec = timestamp_to_seconds(subtask.timestamps.start)
        end_sec = timestamp_to_seconds(subtask.timestamps.end)
        mid_sec = (start_sec + end_sec) / 2

        # Subtask times are episode-relative; shift into video-file time.
        video_timestamp = video_start_timestamp + mid_sec
        frame_timestamps.append(mid_sec)

        frame = extract_frame_from_video(video_path, video_timestamp)
        sample_frames.append(frame)

    # Create figure with a dark background for better contrast.
    fig = plt.figure(figsize=(16, 10))
    fig.patch.set_facecolor('#1a1a2e')

    # Grid layout: top row of frames (one column per subtask), bottom timeline.
    gs = fig.add_gridspec(
        2, max(num_subtasks, 1),
        height_ratios=[2, 1],
        hspace=0.3,
        wspace=0.1,
        left=0.05, right=0.95,
        top=0.88, bottom=0.1
    )

    fig.suptitle(
        f"Episode {episode_idx} - Subtask Annotations",
        fontsize=18,
        fontweight='bold',
        color='white',
        y=0.96
    )

    # Subtitle with camera / duration / subtask-count info.
    fig.text(
        0.5, 0.91,
        f"Camera: {video_key} | Duration: {episode_duration:.1f}s | {num_subtasks} subtasks",
        ha='center',
        fontsize=11,
        color='#888888'
    )

    # Plot the sample frames, one per subtask.
    for i, (frame, subtask) in enumerate(zip(sample_frames, subtasks)):
        ax = fig.add_subplot(gs[0, i])
        ax.set_facecolor('#16213e')

        if frame is not None:
            ax.imshow(frame)
        else:
            ax.text(0.5, 0.5, "Frame\nN/A", ha='center', va='center',
                    fontsize=12, color='white', transform=ax.transAxes)

        ax.set_title(
            subtask.name,
            fontsize=10,
            fontweight='bold',
            color=SUBTASK_COLORS[i % len(SUBTASK_COLORS)],
            pad=8
        )
        ax.axis('off')

        # Frame timestamp (episode-relative) below each thumbnail.
        ax.text(
            0.5, -0.08,
            f"t={frame_timestamps[i]:.1f}s",
            ha='center',
            fontsize=9,
            color='#888888',
            transform=ax.transAxes
        )

    # Timeline subplot spanning all columns.
    ax_timeline = fig.add_subplot(gs[1, :])
    ax_timeline.set_facecolor('#16213e')

    # Total duration from the last subtask's end time.
    total_duration = timestamp_to_seconds(subtasks[-1].timestamps.end)

    # Draw subtask segments as rounded colored bars.
    bar_height = 0.6
    bar_y = 0.5

    for i, subtask in enumerate(subtasks):
        start_sec = timestamp_to_seconds(subtask.timestamps.start)
        end_sec = timestamp_to_seconds(subtask.timestamps.end)
        color = SUBTASK_COLORS[i % len(SUBTASK_COLORS)]

        rect = mpatches.FancyBboxPatch(
            (start_sec, bar_y - bar_height/2),
            end_sec - start_sec,
            bar_height,
            boxstyle="round,pad=0.02,rounding_size=0.1",
            facecolor=color,
            edgecolor='white',
            linewidth=1.5,
            alpha=0.85
        )
        ax_timeline.add_patch(rect)

        mid_x = (start_sec + end_sec) / 2
        duration = end_sec - start_sec

        # Only label segments wide enough to fit text.
        if duration > total_duration * 0.08:
            # BUGFIX: colors cycle with `i % len(SUBTASK_COLORS)`, so the
            # dark-text check for the yellow swatch (palette index 3) must
            # use the same modulo — `i in [3]` missed later cycles.
            ax_timeline.text(
                mid_x, bar_y,
                subtask.name,
                ha='center', va='center',
                fontsize=9,
                fontweight='bold',
                color='black' if i % len(SUBTASK_COLORS) == 3 else 'white',  # Yellow needs dark text
                rotation=0 if duration > total_duration * 0.15 else 45
            )

    # Collect boundary times (dashed vertical lines between subtasks).
    boundary_times = []
    for i, subtask in enumerate(subtasks):
        start_sec = timestamp_to_seconds(subtask.timestamps.start)
        end_sec = timestamp_to_seconds(subtask.timestamps.end)

        # Start boundary (skipped only for a first subtask starting at t=0,
        # which gets the dedicated green start line instead).
        if i == 0 and start_sec > 0:
            boundary_times.append(start_sec)
        elif i > 0:
            boundary_times.append(start_sec)

        # End boundary for the last subtask.
        if i == len(subtasks) - 1:
            boundary_times.append(end_sec)

    # Draw dashed lines and MM:SS labels at each boundary.
    for t in boundary_times:
        ax_timeline.axvline(
            x=t,
            ymin=0.1, ymax=0.9,
            color='white',
            linestyle='--',
            linewidth=2,
            alpha=0.9
        )
        ax_timeline.text(
            t, 0.0,
            f"{int(t//60):02d}:{int(t%60):02d}",
            ha='center', va='top',
            fontsize=8,
            color='#cccccc'
        )

    # Solid green start line at t=0.
    ax_timeline.axvline(x=0, ymin=0.1, ymax=0.9, color='#00ff00', linestyle='-', linewidth=2.5, alpha=0.9)
    ax_timeline.text(0, 0.0, "00:00", ha='center', va='top', fontsize=8, color='#00ff00', fontweight='bold')

    # Configure timeline axes with a small margin on both sides.
    ax_timeline.set_xlim(-total_duration * 0.02, total_duration * 1.02)
    ax_timeline.set_ylim(-0.3, 1.2)
    ax_timeline.set_xlabel("Time (seconds)", fontsize=11, color='white', labelpad=10)
    ax_timeline.set_ylabel("")

    # Style the axes: keep only a dim bottom spine, drop y ticks entirely.
    ax_timeline.spines['top'].set_visible(False)
    ax_timeline.spines['right'].set_visible(False)
    ax_timeline.spines['left'].set_visible(False)
    ax_timeline.spines['bottom'].set_color('#444444')
    ax_timeline.tick_params(axis='x', colors='#888888', labelsize=9)
    ax_timeline.tick_params(axis='y', left=False, labelleft=False)

    # X-axis ticks at ~10 regular intervals across the episode.
    tick_interval = max(1, int(total_duration / 10))
    ax_timeline.set_xticks(np.arange(0, total_duration + tick_interval, tick_interval))

    # Legend explaining the two line styles.
    legend_elements = [
        Line2D([0], [0], color='#00ff00', linewidth=2.5, linestyle='-', label='Start'),
        Line2D([0], [0], color='white', linewidth=2, linestyle='--', label='Subtask boundary'),
    ]
    ax_timeline.legend(
        handles=legend_elements,
        loc='upper right',
        framealpha=0.3,
        facecolor='#16213e',
        edgecolor='#444444',
        fontsize=9,
        labelcolor='white'
    )

    plt.savefig(output_path, dpi=150, facecolor=fig.get_facecolor(), edgecolor='none', bbox_inches='tight')
    # BUGFIX: close the figure we created (not just the "current" figure),
    # so repeated calls don't leak open figures if the caller has others.
    plt.close(fig)

    return output_path
|
||||
|
||||
|
||||
def main():
    """Parse CLI arguments, load a dataset and its subtask annotations, and
    render one visualization PNG per selected episode.

    Fails soft: a missing video or a rendering error for one episode is
    reported and skipped rather than aborting the whole batch.
    """
    parser = argparse.ArgumentParser(
        description="Visualize SARM subtask annotations",
        formatter_class=argparse.RawDescriptionHelpFormatter,
    )
    parser.add_argument(
        "--repo-id",
        type=str,
        required=True,
        help="HuggingFace dataset repository ID",
    )
    parser.add_argument(
        "--num-episodes",
        type=int,
        default=5,
        help="Number of random episodes to visualize (default: 5)",
    )
    parser.add_argument(
        "--episodes",
        type=int,
        nargs="+",
        default=None,
        help="Specific episode indices to visualize (overrides --num-episodes)",
    )
    parser.add_argument(
        "--video-key",
        type=str,
        default=None,
        help="Camera/video key to use. If not specified, uses first available.",
    )
    parser.add_argument(
        "--output-dir",
        type=str,
        default="./subtask_viz",
        help="Output directory for visualizations (default: ./subtask_viz)",
    )
    parser.add_argument(
        "--seed",
        type=int,
        default=None,
        help="Random seed for reproducibility",
    )

    args = parser.parse_args()

    console = Console()

    # Seed the RNG only when requested, so default runs stay random.
    if args.seed is not None:
        random.seed(args.seed)

    console.print(f"\n[cyan]Loading dataset: {args.repo_id}[/cyan]")
    dataset = LeRobotDataset(args.repo_id, download_videos=True)
    fps = dataset.fps

    # Resolve the camera key: validate an explicit choice, else take the first.
    if args.video_key:
        if args.video_key not in dataset.meta.video_keys:
            console.print(f"[red]Error: Video key '{args.video_key}' not found[/red]")
            console.print(f"[yellow]Available: {', '.join(dataset.meta.video_keys)}[/yellow]")
            return
        video_key = args.video_key
    else:
        video_key = dataset.meta.video_keys[0]

    console.print(f"[cyan]Using camera: {video_key}[/cyan]")
    console.print(f"[cyan]FPS: {fps}[/cyan]")

    # Load annotations from the dataset's episode metadata.
    console.print("\n[cyan]Loading annotations...[/cyan]")
    annotations = load_annotations_from_dataset(dataset.root)

    if not annotations:
        console.print("[red]Error: No annotations found in dataset[/red]")
        console.print("[yellow]Run subtask_annotation.py first to generate annotations[/yellow]")
        return

    console.print(f"[green]Found {len(annotations)} annotated episodes[/green]")

    # Determine which episodes to visualize: explicit list wins, else sample.
    if args.episodes:
        episode_indices = args.episodes
        # Warn about (and drop) requested episodes that have no annotation.
        for ep in episode_indices:
            if ep not in annotations:
                console.print(f"[yellow]Warning: Episode {ep} has no annotation, skipping[/yellow]")
        episode_indices = [ep for ep in episode_indices if ep in annotations]
    else:
        available_episodes = list(annotations.keys())
        num_to_select = min(args.num_episodes, len(available_episodes))
        episode_indices = random.sample(available_episodes, num_to_select)
        episode_indices.sort()

    if not episode_indices:
        console.print("[red]Error: No valid episodes to visualize[/red]")
        return

    console.print(f"[cyan]Visualizing episodes: {episode_indices}[/cyan]")

    output_dir = Path(args.output_dir)
    output_dir.mkdir(parents=True, exist_ok=True)

    # Generate one visualization per selected episode.
    for ep_idx in episode_indices:
        console.print(f"\n[cyan]Processing episode {ep_idx}...[/cyan]")

        annotation = annotations[ep_idx]

        video_path = dataset.root / dataset.meta.get_video_file_path(ep_idx, video_key)

        if not video_path.exists():
            console.print(f"[red]Video not found: {video_path}[/red]")
            continue

        # Episodes are chunked into shared video files; look up this
        # episode's start/end offsets within its file.
        video_path_key = f"videos/{video_key}/from_timestamp"
        video_path_key_to = f"videos/{video_key}/to_timestamp"

        video_start_timestamp = float(dataset.meta.episodes[video_path_key][ep_idx])
        video_end_timestamp = float(dataset.meta.episodes[video_path_key_to][ep_idx])

        output_path = output_dir / f"episode_{ep_idx:04d}_subtasks.png"

        try:
            visualize_episode(
                episode_idx=ep_idx,
                annotation=annotation,
                video_path=video_path,
                video_start_timestamp=video_start_timestamp,
                video_end_timestamp=video_end_timestamp,
                fps=fps,
                output_path=output_path,
                video_key=video_key,
            )
            console.print(f"[green]✓ Saved: {output_path}[/green]")
        except Exception as e:
            # Broad catch is deliberate: one bad episode must not abort the batch.
            console.print(f"[red]✗ Failed to visualize episode {ep_idx}: {e}[/red]")

    # Print summary.
    console.print(f"\n[bold green]{'=' * 50}[/bold green]")
    console.print("[bold green]Visualization Complete![/bold green]")
    console.print(f"[bold green]{'=' * 50}[/bold green]")
    console.print(f"Output directory: {output_dir.absolute()}")
    console.print(f"Episodes visualized: {len(episode_indices)}")


if __name__ == "__main__":
    main()
|
||||
|
||||
+80
-86
@@ -33,68 +33,83 @@ TASK_DESCRIPTION = "My task description"
|
||||
HF_MODEL_ID = "<hf_username>/<model_repo_id>"
|
||||
HF_DATASET_ID = "<hf_username>/<eval_dataset_repo_id>"
|
||||
|
||||
# Create the robot configuration & robot
|
||||
robot_config = LeKiwiClientConfig(remote_ip="172.18.134.136", id="lekiwi")
|
||||
|
||||
def main():
|
||||
# Create the robot configuration & robot
|
||||
robot_config = LeKiwiClientConfig(remote_ip="172.18.134.136", id="lekiwi")
|
||||
robot = LeKiwiClient(robot_config)
|
||||
|
||||
robot = LeKiwiClient(robot_config)
|
||||
# Create policy
|
||||
policy = ACTPolicy.from_pretrained(HF_MODEL_ID)
|
||||
|
||||
# Create policy
|
||||
policy = ACTPolicy.from_pretrained(HF_MODEL_ID)
|
||||
# Configure the dataset features
|
||||
action_features = hw_to_dataset_features(robot.action_features, ACTION)
|
||||
obs_features = hw_to_dataset_features(robot.observation_features, OBS_STR)
|
||||
dataset_features = {**action_features, **obs_features}
|
||||
|
||||
# Configure the dataset features
|
||||
action_features = hw_to_dataset_features(robot.action_features, ACTION)
|
||||
obs_features = hw_to_dataset_features(robot.observation_features, OBS_STR)
|
||||
dataset_features = {**action_features, **obs_features}
|
||||
# Create the dataset
|
||||
dataset = LeRobotDataset.create(
|
||||
repo_id=HF_DATASET_ID,
|
||||
fps=FPS,
|
||||
features=dataset_features,
|
||||
robot_type=robot.name,
|
||||
use_videos=True,
|
||||
image_writer_threads=4,
|
||||
)
|
||||
|
||||
# Create the dataset
|
||||
dataset = LeRobotDataset.create(
|
||||
repo_id=HF_DATASET_ID,
|
||||
# Build Policy Processors
|
||||
preprocessor, postprocessor = make_pre_post_processors(
|
||||
policy_cfg=policy,
|
||||
pretrained_path=HF_MODEL_ID,
|
||||
dataset_stats=dataset.meta.stats,
|
||||
# The inference device is automatically set to match the detected hardware, overriding any previous device settings from training to ensure compatibility.
|
||||
preprocessor_overrides={"device_processor": {"device": str(policy.config.device)}},
|
||||
)
|
||||
|
||||
# Connect the robot
|
||||
# To connect you already should have this script running on LeKiwi: `python -m lerobot.robots.lekiwi.lekiwi_host --robot.id=my_awesome_kiwi`
|
||||
robot.connect()
|
||||
|
||||
# TODO(Steven): Update this example to use pipelines
|
||||
teleop_action_processor, robot_action_processor, robot_observation_processor = make_default_processors()
|
||||
|
||||
# Initialize the keyboard listener and rerun visualization
|
||||
listener, events = init_keyboard_listener()
|
||||
init_rerun(session_name="lekiwi_evaluate")
|
||||
|
||||
if not robot.is_connected:
|
||||
raise ValueError("Robot is not connected!")
|
||||
|
||||
print("Starting evaluate loop...")
|
||||
recorded_episodes = 0
|
||||
while recorded_episodes < NUM_EPISODES and not events["stop_recording"]:
|
||||
log_say(f"Running inference, recording eval episode {recorded_episodes} of {NUM_EPISODES}")
|
||||
|
||||
# Main record loop
|
||||
record_loop(
|
||||
robot=robot,
|
||||
events=events,
|
||||
fps=FPS,
|
||||
features=dataset_features,
|
||||
robot_type=robot.name,
|
||||
use_videos=True,
|
||||
image_writer_threads=4,
|
||||
policy=policy,
|
||||
preprocessor=preprocessor, # Pass the pre and post policy processors
|
||||
postprocessor=postprocessor,
|
||||
dataset=dataset,
|
||||
control_time_s=EPISODE_TIME_SEC,
|
||||
single_task=TASK_DESCRIPTION,
|
||||
display_data=True,
|
||||
teleop_action_processor=teleop_action_processor,
|
||||
robot_action_processor=robot_action_processor,
|
||||
robot_observation_processor=robot_observation_processor,
|
||||
)
|
||||
|
||||
# Build Policy Processors
|
||||
preprocessor, postprocessor = make_pre_post_processors(
|
||||
policy_cfg=policy,
|
||||
pretrained_path=HF_MODEL_ID,
|
||||
dataset_stats=dataset.meta.stats,
|
||||
# The inference device is automatically set to match the detected hardware, overriding any previous device settings from training to ensure compatibility.
|
||||
preprocessor_overrides={"device_processor": {"device": str(policy.config.device)}},
|
||||
)
|
||||
|
||||
# Connect the robot
|
||||
# To connect you already should have this script running on LeKiwi: `python -m lerobot.robots.lekiwi.lekiwi_host --robot.id=my_awesome_kiwi`
|
||||
robot.connect()
|
||||
|
||||
# TODO(Steven): Update this example to use pipelines
|
||||
teleop_action_processor, robot_action_processor, robot_observation_processor = make_default_processors()
|
||||
|
||||
# Initialize the keyboard listener and rerun visualization
|
||||
listener, events = init_keyboard_listener()
|
||||
init_rerun(session_name="lekiwi_evaluate")
|
||||
|
||||
if not robot.is_connected:
|
||||
raise ValueError("Robot is not connected!")
|
||||
|
||||
print("Starting evaluate loop...")
|
||||
recorded_episodes = 0
|
||||
while recorded_episodes < NUM_EPISODES and not events["stop_recording"]:
|
||||
log_say(f"Running inference, recording eval episode {recorded_episodes} of {NUM_EPISODES}")
|
||||
|
||||
# Main record loop
|
||||
# Reset the environment if not stopping or re-recording
|
||||
if not events["stop_recording"] and (
|
||||
(recorded_episodes < NUM_EPISODES - 1) or events["rerecord_episode"]
|
||||
):
|
||||
log_say("Reset the environment")
|
||||
record_loop(
|
||||
robot=robot,
|
||||
events=events,
|
||||
fps=FPS,
|
||||
policy=policy,
|
||||
preprocessor=preprocessor, # Pass the pre and post policy processors
|
||||
postprocessor=postprocessor,
|
||||
dataset=dataset,
|
||||
control_time_s=EPISODE_TIME_SEC,
|
||||
single_task=TASK_DESCRIPTION,
|
||||
display_data=True,
|
||||
@@ -103,42 +118,21 @@ def main():
|
||||
robot_observation_processor=robot_observation_processor,
|
||||
)
|
||||
|
||||
# Reset the environment if not stopping or re-recording
|
||||
if not events["stop_recording"] and (
|
||||
(recorded_episodes < NUM_EPISODES - 1) or events["rerecord_episode"]
|
||||
):
|
||||
log_say("Reset the environment")
|
||||
record_loop(
|
||||
robot=robot,
|
||||
events=events,
|
||||
fps=FPS,
|
||||
control_time_s=EPISODE_TIME_SEC,
|
||||
single_task=TASK_DESCRIPTION,
|
||||
display_data=True,
|
||||
teleop_action_processor=teleop_action_processor,
|
||||
robot_action_processor=robot_action_processor,
|
||||
robot_observation_processor=robot_observation_processor,
|
||||
)
|
||||
if events["rerecord_episode"]:
|
||||
log_say("Re-record episode")
|
||||
events["rerecord_episode"] = False
|
||||
events["exit_early"] = False
|
||||
dataset.clear_episode_buffer()
|
||||
continue
|
||||
|
||||
if events["rerecord_episode"]:
|
||||
log_say("Re-record episode")
|
||||
events["rerecord_episode"] = False
|
||||
events["exit_early"] = False
|
||||
dataset.clear_episode_buffer()
|
||||
continue
|
||||
# Save episode
|
||||
dataset.save_episode()
|
||||
recorded_episodes += 1
|
||||
|
||||
# Save episode
|
||||
dataset.save_episode()
|
||||
recorded_episodes += 1
|
||||
# Clean up
|
||||
log_say("Stop recording")
|
||||
robot.disconnect()
|
||||
listener.stop()
|
||||
|
||||
# Clean up
|
||||
log_say("Stop recording")
|
||||
robot.disconnect()
|
||||
listener.stop()
|
||||
|
||||
dataset.finalize()
|
||||
dataset.push_to_hub()
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
main()
|
||||
dataset.finalize()
|
||||
dataset.push_to_hub()
|
||||
|
||||
+76
-82
@@ -34,62 +34,78 @@ RESET_TIME_SEC = 10
|
||||
TASK_DESCRIPTION = "My task description"
|
||||
HF_REPO_ID = "<hf_username>/<dataset_repo_id>"
|
||||
|
||||
# Create the robot and teleoperator configurations
|
||||
robot_config = LeKiwiClientConfig(remote_ip="172.18.134.136", id="lekiwi")
|
||||
leader_arm_config = SO100LeaderConfig(port="/dev/tty.usbmodem585A0077581", id="my_awesome_leader_arm")
|
||||
keyboard_config = KeyboardTeleopConfig()
|
||||
|
||||
def main():
|
||||
# Create the robot and teleoperator configurations
|
||||
robot_config = LeKiwiClientConfig(remote_ip="172.18.134.136", id="lekiwi")
|
||||
leader_arm_config = SO100LeaderConfig(port="/dev/tty.usbmodem585A0077581", id="my_awesome_leader_arm")
|
||||
keyboard_config = KeyboardTeleopConfig()
|
||||
# Initialize the robot and teleoperator
|
||||
robot = LeKiwiClient(robot_config)
|
||||
leader_arm = SO100Leader(leader_arm_config)
|
||||
keyboard = KeyboardTeleop(keyboard_config)
|
||||
|
||||
# Initialize the robot and teleoperator
|
||||
robot = LeKiwiClient(robot_config)
|
||||
leader_arm = SO100Leader(leader_arm_config)
|
||||
keyboard = KeyboardTeleop(keyboard_config)
|
||||
# TODO(Steven): Update this example to use pipelines
|
||||
teleop_action_processor, robot_action_processor, robot_observation_processor = make_default_processors()
|
||||
|
||||
# TODO(Steven): Update this example to use pipelines
|
||||
teleop_action_processor, robot_action_processor, robot_observation_processor = make_default_processors()
|
||||
# Configure the dataset features
|
||||
action_features = hw_to_dataset_features(robot.action_features, ACTION)
|
||||
obs_features = hw_to_dataset_features(robot.observation_features, OBS_STR)
|
||||
dataset_features = {**action_features, **obs_features}
|
||||
|
||||
# Configure the dataset features
|
||||
action_features = hw_to_dataset_features(robot.action_features, ACTION)
|
||||
obs_features = hw_to_dataset_features(robot.observation_features, OBS_STR)
|
||||
dataset_features = {**action_features, **obs_features}
|
||||
# Create the dataset
|
||||
dataset = LeRobotDataset.create(
|
||||
repo_id=HF_REPO_ID,
|
||||
fps=FPS,
|
||||
features=dataset_features,
|
||||
robot_type=robot.name,
|
||||
use_videos=True,
|
||||
image_writer_threads=4,
|
||||
)
|
||||
|
||||
# Create the dataset
|
||||
dataset = LeRobotDataset.create(
|
||||
repo_id=HF_REPO_ID,
|
||||
# Connect the robot and teleoperator
|
||||
# To connect you already should have this script running on LeKiwi: `python -m lerobot.robots.lekiwi.lekiwi_host --robot.id=my_awesome_kiwi`
|
||||
robot.connect()
|
||||
leader_arm.connect()
|
||||
keyboard.connect()
|
||||
|
||||
# Initialize the keyboard listener and rerun visualization
|
||||
listener, events = init_keyboard_listener()
|
||||
init_rerun(session_name="lekiwi_record")
|
||||
|
||||
if not robot.is_connected or not leader_arm.is_connected or not keyboard.is_connected:
|
||||
raise ValueError("Robot or teleop is not connected!")
|
||||
|
||||
print("Starting record loop...")
|
||||
recorded_episodes = 0
|
||||
while recorded_episodes < NUM_EPISODES and not events["stop_recording"]:
|
||||
log_say(f"Recording episode {recorded_episodes}")
|
||||
|
||||
# Main record loop
|
||||
record_loop(
|
||||
robot=robot,
|
||||
events=events,
|
||||
fps=FPS,
|
||||
features=dataset_features,
|
||||
robot_type=robot.name,
|
||||
use_videos=True,
|
||||
image_writer_threads=4,
|
||||
dataset=dataset,
|
||||
teleop=[leader_arm, keyboard],
|
||||
control_time_s=EPISODE_TIME_SEC,
|
||||
single_task=TASK_DESCRIPTION,
|
||||
display_data=True,
|
||||
teleop_action_processor=teleop_action_processor,
|
||||
robot_action_processor=robot_action_processor,
|
||||
robot_observation_processor=robot_observation_processor,
|
||||
)
|
||||
|
||||
# Connect the robot and teleoperator
|
||||
# To connect you already should have this script running on LeKiwi: `python -m lerobot.robots.lekiwi.lekiwi_host --robot.id=my_awesome_kiwi`
|
||||
robot.connect()
|
||||
leader_arm.connect()
|
||||
keyboard.connect()
|
||||
|
||||
# Initialize the keyboard listener and rerun visualization
|
||||
listener, events = init_keyboard_listener()
|
||||
init_rerun(session_name="lekiwi_record")
|
||||
|
||||
if not robot.is_connected or not leader_arm.is_connected or not keyboard.is_connected:
|
||||
raise ValueError("Robot or teleop is not connected!")
|
||||
|
||||
print("Starting record loop...")
|
||||
recorded_episodes = 0
|
||||
while recorded_episodes < NUM_EPISODES and not events["stop_recording"]:
|
||||
log_say(f"Recording episode {recorded_episodes}")
|
||||
|
||||
# Main record loop
|
||||
# Reset the environment if not stopping or re-recording
|
||||
if not events["stop_recording"] and (
|
||||
(recorded_episodes < NUM_EPISODES - 1) or events["rerecord_episode"]
|
||||
):
|
||||
log_say("Reset the environment")
|
||||
record_loop(
|
||||
robot=robot,
|
||||
events=events,
|
||||
fps=FPS,
|
||||
dataset=dataset,
|
||||
teleop=[leader_arm, keyboard],
|
||||
control_time_s=EPISODE_TIME_SEC,
|
||||
control_time_s=RESET_TIME_SEC,
|
||||
single_task=TASK_DESCRIPTION,
|
||||
display_data=True,
|
||||
teleop_action_processor=teleop_action_processor,
|
||||
@@ -97,45 +113,23 @@ def main():
|
||||
robot_observation_processor=robot_observation_processor,
|
||||
)
|
||||
|
||||
# Reset the environment if not stopping or re-recording
|
||||
if not events["stop_recording"] and (
|
||||
(recorded_episodes < NUM_EPISODES - 1) or events["rerecord_episode"]
|
||||
):
|
||||
log_say("Reset the environment")
|
||||
record_loop(
|
||||
robot=robot,
|
||||
events=events,
|
||||
fps=FPS,
|
||||
teleop=[leader_arm, keyboard],
|
||||
control_time_s=RESET_TIME_SEC,
|
||||
single_task=TASK_DESCRIPTION,
|
||||
display_data=True,
|
||||
teleop_action_processor=teleop_action_processor,
|
||||
robot_action_processor=robot_action_processor,
|
||||
robot_observation_processor=robot_observation_processor,
|
||||
)
|
||||
if events["rerecord_episode"]:
|
||||
log_say("Re-record episode")
|
||||
events["rerecord_episode"] = False
|
||||
events["exit_early"] = False
|
||||
dataset.clear_episode_buffer()
|
||||
continue
|
||||
|
||||
if events["rerecord_episode"]:
|
||||
log_say("Re-record episode")
|
||||
events["rerecord_episode"] = False
|
||||
events["exit_early"] = False
|
||||
dataset.clear_episode_buffer()
|
||||
continue
|
||||
# Save episode
|
||||
dataset.save_episode()
|
||||
recorded_episodes += 1
|
||||
|
||||
# Save episode
|
||||
dataset.save_episode()
|
||||
recorded_episodes += 1
|
||||
# Clean up
|
||||
log_say("Stop recording")
|
||||
robot.disconnect()
|
||||
leader_arm.disconnect()
|
||||
keyboard.disconnect()
|
||||
listener.stop()
|
||||
|
||||
# Clean up
|
||||
log_say("Stop recording")
|
||||
robot.disconnect()
|
||||
leader_arm.disconnect()
|
||||
keyboard.disconnect()
|
||||
listener.stop()
|
||||
|
||||
dataset.finalize()
|
||||
dataset.push_to_hub()
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
main()
|
||||
dataset.finalize()
|
||||
dataset.push_to_hub()
|
||||
|
||||
+26
-32
@@ -20,48 +20,42 @@ from lerobot.datasets.lerobot_dataset import LeRobotDataset
|
||||
from lerobot.robots.lekiwi.config_lekiwi import LeKiwiClientConfig
|
||||
from lerobot.robots.lekiwi.lekiwi_client import LeKiwiClient
|
||||
from lerobot.utils.constants import ACTION
|
||||
from lerobot.utils.robot_utils import precise_sleep
|
||||
from lerobot.utils.robot_utils import busy_wait
|
||||
from lerobot.utils.utils import log_say
|
||||
|
||||
EPISODE_IDX = 0
|
||||
|
||||
# Initialize the robot config
|
||||
robot_config = LeKiwiClientConfig(remote_ip="172.18.134.136", id="lekiwi")
|
||||
|
||||
def main():
|
||||
# Initialize the robot config
|
||||
robot_config = LeKiwiClientConfig(remote_ip="172.18.134.136", id="lekiwi")
|
||||
# Initialize the robot
|
||||
robot = LeKiwiClient(robot_config)
|
||||
|
||||
# Initialize the robot
|
||||
robot = LeKiwiClient(robot_config)
|
||||
# Fetch the dataset to replay
|
||||
dataset = LeRobotDataset("<hf_username>/<dataset_repo_id>", episodes=[EPISODE_IDX])
|
||||
# Filter dataset to only include frames from the specified episode since episodes are chunked in dataset V3.0
|
||||
episode_frames = dataset.hf_dataset.filter(lambda x: x["episode_index"] == EPISODE_IDX)
|
||||
actions = episode_frames.select_columns(ACTION)
|
||||
|
||||
# Fetch the dataset to replay
|
||||
dataset = LeRobotDataset("<hf_username>/<dataset_repo_id>", episodes=[EPISODE_IDX])
|
||||
# Filter dataset to only include frames from the specified episode since episodes are chunked in dataset V3.0
|
||||
episode_frames = dataset.hf_dataset.filter(lambda x: x["episode_index"] == EPISODE_IDX)
|
||||
actions = episode_frames.select_columns(ACTION)
|
||||
# Connect to the robot
|
||||
robot.connect()
|
||||
|
||||
# Connect to the robot
|
||||
robot.connect()
|
||||
if not robot.is_connected:
|
||||
raise ValueError("Robot is not connected!")
|
||||
|
||||
if not robot.is_connected:
|
||||
raise ValueError("Robot is not connected!")
|
||||
print("Starting replay loop...")
|
||||
log_say(f"Replaying episode {EPISODE_IDX}")
|
||||
for idx in range(len(episode_frames)):
|
||||
t0 = time.perf_counter()
|
||||
|
||||
print("Starting replay loop...")
|
||||
log_say(f"Replaying episode {EPISODE_IDX}")
|
||||
for idx in range(len(episode_frames)):
|
||||
t0 = time.perf_counter()
|
||||
# Get recorded action from dataset
|
||||
action = {
|
||||
name: float(actions[idx][ACTION][i]) for i, name in enumerate(dataset.features[ACTION]["names"])
|
||||
}
|
||||
|
||||
# Get recorded action from dataset
|
||||
action = {
|
||||
name: float(actions[idx][ACTION][i]) for i, name in enumerate(dataset.features[ACTION]["names"])
|
||||
}
|
||||
# Send action to robot
|
||||
_ = robot.send_action(action)
|
||||
|
||||
# Send action to robot
|
||||
_ = robot.send_action(action)
|
||||
busy_wait(max(1.0 / dataset.fps - (time.perf_counter() - t0), 0.0))
|
||||
|
||||
precise_sleep(max(1.0 / dataset.fps - (time.perf_counter() - t0), 0.0))
|
||||
|
||||
robot.disconnect()
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
main()
|
||||
robot.disconnect()
|
||||
|
||||
@@ -19,60 +19,54 @@ import time
|
||||
from lerobot.robots.lekiwi import LeKiwiClient, LeKiwiClientConfig
|
||||
from lerobot.teleoperators.keyboard.teleop_keyboard import KeyboardTeleop, KeyboardTeleopConfig
|
||||
from lerobot.teleoperators.so100_leader import SO100Leader, SO100LeaderConfig
|
||||
from lerobot.utils.robot_utils import precise_sleep
|
||||
from lerobot.utils.robot_utils import busy_wait
|
||||
from lerobot.utils.visualization_utils import init_rerun, log_rerun_data
|
||||
|
||||
FPS = 30
|
||||
|
||||
# Create the robot and teleoperator configurations
|
||||
robot_config = LeKiwiClientConfig(remote_ip="172.18.134.136", id="my_lekiwi")
|
||||
teleop_arm_config = SO100LeaderConfig(port="/dev/tty.usbmodem585A0077581", id="my_awesome_leader_arm")
|
||||
keyboard_config = KeyboardTeleopConfig(id="my_laptop_keyboard")
|
||||
|
||||
def main():
|
||||
# Create the robot and teleoperator configurations
|
||||
robot_config = LeKiwiClientConfig(remote_ip="172.18.134.136", id="my_lekiwi")
|
||||
teleop_arm_config = SO100LeaderConfig(port="/dev/tty.usbmodem585A0077581", id="my_awesome_leader_arm")
|
||||
keyboard_config = KeyboardTeleopConfig(id="my_laptop_keyboard")
|
||||
# Initialize the robot and teleoperator
|
||||
robot = LeKiwiClient(robot_config)
|
||||
leader_arm = SO100Leader(teleop_arm_config)
|
||||
keyboard = KeyboardTeleop(keyboard_config)
|
||||
|
||||
# Initialize the robot and teleoperator
|
||||
robot = LeKiwiClient(robot_config)
|
||||
leader_arm = SO100Leader(teleop_arm_config)
|
||||
keyboard = KeyboardTeleop(keyboard_config)
|
||||
# Connect to the robot and teleoperator
|
||||
# To connect you already should have this script running on LeKiwi: `python -m lerobot.robots.lekiwi.lekiwi_host --robot.id=my_awesome_kiwi`
|
||||
robot.connect()
|
||||
leader_arm.connect()
|
||||
keyboard.connect()
|
||||
|
||||
# Connect to the robot and teleoperator
|
||||
# To connect you already should have this script running on LeKiwi: `python -m lerobot.robots.lekiwi.lekiwi_host --robot.id=my_awesome_kiwi`
|
||||
robot.connect()
|
||||
leader_arm.connect()
|
||||
keyboard.connect()
|
||||
# Init rerun viewer
|
||||
init_rerun(session_name="lekiwi_teleop")
|
||||
|
||||
# Init rerun viewer
|
||||
init_rerun(session_name="lekiwi_teleop")
|
||||
if not robot.is_connected or not leader_arm.is_connected or not keyboard.is_connected:
|
||||
raise ValueError("Robot or teleop is not connected!")
|
||||
|
||||
if not robot.is_connected or not leader_arm.is_connected or not keyboard.is_connected:
|
||||
raise ValueError("Robot or teleop is not connected!")
|
||||
print("Starting teleop loop...")
|
||||
while True:
|
||||
t0 = time.perf_counter()
|
||||
|
||||
print("Starting teleop loop...")
|
||||
while True:
|
||||
t0 = time.perf_counter()
|
||||
# Get robot observation
|
||||
observation = robot.get_observation()
|
||||
|
||||
# Get robot observation
|
||||
observation = robot.get_observation()
|
||||
# Get teleop action
|
||||
# Arm
|
||||
arm_action = leader_arm.get_action()
|
||||
arm_action = {f"arm_{k}": v for k, v in arm_action.items()}
|
||||
# Keyboard
|
||||
keyboard_keys = keyboard.get_action()
|
||||
base_action = robot._from_keyboard_to_base_action(keyboard_keys)
|
||||
|
||||
# Get teleop action
|
||||
# Arm
|
||||
arm_action = leader_arm.get_action()
|
||||
arm_action = {f"arm_{k}": v for k, v in arm_action.items()}
|
||||
# Keyboard
|
||||
keyboard_keys = keyboard.get_action()
|
||||
base_action = robot._from_keyboard_to_base_action(keyboard_keys)
|
||||
action = {**arm_action, **base_action} if len(base_action) > 0 else arm_action
|
||||
|
||||
action = {**arm_action, **base_action} if len(base_action) > 0 else arm_action
|
||||
# Send action to robot
|
||||
_ = robot.send_action(action)
|
||||
|
||||
# Send action to robot
|
||||
_ = robot.send_action(action)
|
||||
# Visualize
|
||||
log_rerun_data(observation=observation, action=action)
|
||||
|
||||
# Visualize
|
||||
log_rerun_data(observation=observation, action=action)
|
||||
|
||||
precise_sleep(max(1.0 / FPS - (time.perf_counter() - t0), 0.0))
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
main()
|
||||
busy_wait(max(1.0 / FPS - (time.perf_counter() - t0), 0.0))
|
||||
|
||||
+123
-131
@@ -52,114 +52,125 @@ TASK_DESCRIPTION = "My task description"
|
||||
HF_MODEL_ID = "<hf_username>/<model_repo_id>"
|
||||
HF_DATASET_ID = "<hf_username>/<dataset_repo_id>"
|
||||
|
||||
# Create the robot configuration & robot
|
||||
camera_config = {"front": OpenCVCameraConfig(index_or_path=0, width=640, height=480, fps=FPS)}
|
||||
robot_config = SO100FollowerConfig(
|
||||
port="/dev/tty.usbmodem58760434471",
|
||||
id="my_awesome_follower_arm",
|
||||
cameras=camera_config,
|
||||
use_degrees=True,
|
||||
)
|
||||
|
||||
def main():
|
||||
# Create the robot configuration & robot
|
||||
camera_config = {"front": OpenCVCameraConfig(index_or_path=0, width=640, height=480, fps=FPS)}
|
||||
robot_config = SO100FollowerConfig(
|
||||
port="/dev/tty.usbmodem58760434471",
|
||||
id="my_awesome_follower_arm",
|
||||
cameras=camera_config,
|
||||
use_degrees=True,
|
||||
)
|
||||
robot = SO100Follower(robot_config)
|
||||
|
||||
robot = SO100Follower(robot_config)
|
||||
# Create policy
|
||||
policy = ACTPolicy.from_pretrained(HF_MODEL_ID)
|
||||
|
||||
# Create policy
|
||||
policy = ACTPolicy.from_pretrained(HF_MODEL_ID)
|
||||
# NOTE: It is highly recommended to use the urdf in the SO-ARM100 repo: https://github.com/TheRobotStudio/SO-ARM100/blob/main/Simulation/SO101/so101_new_calib.urdf
|
||||
kinematics_solver = RobotKinematics(
|
||||
urdf_path="./SO101/so101_new_calib.urdf",
|
||||
target_frame_name="gripper_frame_link",
|
||||
joint_names=list(robot.bus.motors.keys()),
|
||||
)
|
||||
|
||||
# NOTE: It is highly recommended to use the urdf in the SO-ARM100 repo: https://github.com/TheRobotStudio/SO-ARM100/blob/main/Simulation/SO101/so101_new_calib.urdf
|
||||
kinematics_solver = RobotKinematics(
|
||||
urdf_path="./SO101/so101_new_calib.urdf",
|
||||
target_frame_name="gripper_frame_link",
|
||||
joint_names=list(robot.bus.motors.keys()),
|
||||
)
|
||||
|
||||
# Build pipeline to convert EE action to joints action
|
||||
robot_ee_to_joints_processor = RobotProcessorPipeline[tuple[RobotAction, RobotObservation], RobotAction](
|
||||
steps=[
|
||||
InverseKinematicsEEToJoints(
|
||||
kinematics=kinematics_solver,
|
||||
motor_names=list(robot.bus.motors.keys()),
|
||||
initial_guess_current_joints=True,
|
||||
),
|
||||
],
|
||||
to_transition=robot_action_observation_to_transition,
|
||||
to_output=transition_to_robot_action,
|
||||
)
|
||||
|
||||
# Build pipeline to convert joints observation to EE observation
|
||||
robot_joints_to_ee_pose_processor = RobotProcessorPipeline[RobotObservation, RobotObservation](
|
||||
steps=[
|
||||
ForwardKinematicsJointsToEE(
|
||||
kinematics=kinematics_solver, motor_names=list(robot.bus.motors.keys())
|
||||
)
|
||||
],
|
||||
to_transition=observation_to_transition,
|
||||
to_output=transition_to_observation,
|
||||
)
|
||||
|
||||
# Create the dataset
|
||||
dataset = LeRobotDataset.create(
|
||||
repo_id=HF_DATASET_ID,
|
||||
fps=FPS,
|
||||
features=combine_feature_dicts(
|
||||
aggregate_pipeline_dataset_features(
|
||||
pipeline=robot_joints_to_ee_pose_processor,
|
||||
initial_features=create_initial_features(observation=robot.observation_features),
|
||||
use_videos=True,
|
||||
),
|
||||
# User for now should be explicit on the feature keys that were used for record
|
||||
# Alternatively, the user can pass the processor step that has the right features
|
||||
aggregate_pipeline_dataset_features(
|
||||
pipeline=make_default_teleop_action_processor(),
|
||||
initial_features=create_initial_features(
|
||||
action={
|
||||
f"ee.{k}": PolicyFeature(type=FeatureType.ACTION, shape=(1,))
|
||||
for k in ["x", "y", "z", "wx", "wy", "wz", "gripper_pos"]
|
||||
}
|
||||
),
|
||||
use_videos=True,
|
||||
),
|
||||
# Build pipeline to convert EE action to joints action
|
||||
robot_ee_to_joints_processor = RobotProcessorPipeline[tuple[RobotAction, RobotObservation], RobotAction](
|
||||
steps=[
|
||||
InverseKinematicsEEToJoints(
|
||||
kinematics=kinematics_solver,
|
||||
motor_names=list(robot.bus.motors.keys()),
|
||||
initial_guess_current_joints=True,
|
||||
),
|
||||
robot_type=robot.name,
|
||||
use_videos=True,
|
||||
image_writer_threads=4,
|
||||
],
|
||||
to_transition=robot_action_observation_to_transition,
|
||||
to_output=transition_to_robot_action,
|
||||
)
|
||||
|
||||
# Build pipeline to convert joints observation to EE observation
|
||||
robot_joints_to_ee_pose_processor = RobotProcessorPipeline[RobotObservation, RobotObservation](
|
||||
steps=[
|
||||
ForwardKinematicsJointsToEE(kinematics=kinematics_solver, motor_names=list(robot.bus.motors.keys()))
|
||||
],
|
||||
to_transition=observation_to_transition,
|
||||
to_output=transition_to_observation,
|
||||
)
|
||||
|
||||
# Create the dataset
|
||||
dataset = LeRobotDataset.create(
|
||||
repo_id=HF_DATASET_ID,
|
||||
fps=FPS,
|
||||
features=combine_feature_dicts(
|
||||
aggregate_pipeline_dataset_features(
|
||||
pipeline=robot_joints_to_ee_pose_processor,
|
||||
initial_features=create_initial_features(observation=robot.observation_features),
|
||||
use_videos=True,
|
||||
),
|
||||
# User for now should be explicit on the feature keys that were used for record
|
||||
# Alternatively, the user can pass the processor step that has the right features
|
||||
aggregate_pipeline_dataset_features(
|
||||
pipeline=make_default_teleop_action_processor(),
|
||||
initial_features=create_initial_features(
|
||||
action={
|
||||
f"ee.{k}": PolicyFeature(type=FeatureType.ACTION, shape=(1,))
|
||||
for k in ["x", "y", "z", "wx", "wy", "wz", "gripper_pos"]
|
||||
}
|
||||
),
|
||||
use_videos=True,
|
||||
),
|
||||
),
|
||||
robot_type=robot.name,
|
||||
use_videos=True,
|
||||
image_writer_threads=4,
|
||||
)
|
||||
|
||||
# Build Policy Processors
|
||||
preprocessor, postprocessor = make_pre_post_processors(
|
||||
policy_cfg=policy,
|
||||
pretrained_path=HF_MODEL_ID,
|
||||
dataset_stats=dataset.meta.stats,
|
||||
# The inference device is automatically set to match the detected hardware, overriding any previous device settings from training to ensure compatibility.
|
||||
preprocessor_overrides={"device_processor": {"device": str(policy.config.device)}},
|
||||
)
|
||||
|
||||
# Connect the robot
|
||||
robot.connect()
|
||||
|
||||
# Initialize the keyboard listener and rerun visualization
|
||||
listener, events = init_keyboard_listener()
|
||||
init_rerun(session_name="phone_so100_evaluate")
|
||||
|
||||
if not robot.is_connected:
|
||||
raise ValueError("Robot is not connected!")
|
||||
|
||||
print("Starting evaluate loop...")
|
||||
episode_idx = 0
|
||||
for episode_idx in range(NUM_EPISODES):
|
||||
log_say(f"Running inference, recording eval episode {episode_idx + 1} of {NUM_EPISODES}")
|
||||
|
||||
# Main record loop
|
||||
record_loop(
|
||||
robot=robot,
|
||||
events=events,
|
||||
fps=FPS,
|
||||
policy=policy,
|
||||
preprocessor=preprocessor, # Pass the pre and post policy processors
|
||||
postprocessor=postprocessor,
|
||||
dataset=dataset,
|
||||
control_time_s=EPISODE_TIME_SEC,
|
||||
single_task=TASK_DESCRIPTION,
|
||||
display_data=True,
|
||||
teleop_action_processor=make_default_teleop_action_processor(),
|
||||
robot_action_processor=robot_ee_to_joints_processor,
|
||||
robot_observation_processor=robot_joints_to_ee_pose_processor,
|
||||
)
|
||||
|
||||
# Build Policy Processors
|
||||
preprocessor, postprocessor = make_pre_post_processors(
|
||||
policy_cfg=policy,
|
||||
pretrained_path=HF_MODEL_ID,
|
||||
dataset_stats=dataset.meta.stats,
|
||||
# The inference device is automatically set to match the detected hardware, overriding any previous device settings from training to ensure compatibility.
|
||||
preprocessor_overrides={"device_processor": {"device": str(policy.config.device)}},
|
||||
)
|
||||
|
||||
# Connect the robot
|
||||
robot.connect()
|
||||
|
||||
# Initialize the keyboard listener and rerun visualization
|
||||
listener, events = init_keyboard_listener()
|
||||
init_rerun(session_name="phone_so100_evaluate")
|
||||
|
||||
if not robot.is_connected:
|
||||
raise ValueError("Robot is not connected!")
|
||||
|
||||
print("Starting evaluate loop...")
|
||||
episode_idx = 0
|
||||
for episode_idx in range(NUM_EPISODES):
|
||||
log_say(f"Running inference, recording eval episode {episode_idx + 1} of {NUM_EPISODES}")
|
||||
|
||||
# Main record loop
|
||||
# Reset the environment if not stopping or re-recording
|
||||
if not events["stop_recording"] and ((episode_idx < NUM_EPISODES - 1) or events["rerecord_episode"]):
|
||||
log_say("Reset the environment")
|
||||
record_loop(
|
||||
robot=robot,
|
||||
events=events,
|
||||
fps=FPS,
|
||||
policy=policy,
|
||||
preprocessor=preprocessor, # Pass the pre and post policy processors
|
||||
postprocessor=postprocessor,
|
||||
dataset=dataset,
|
||||
control_time_s=EPISODE_TIME_SEC,
|
||||
single_task=TASK_DESCRIPTION,
|
||||
display_data=True,
|
||||
@@ -168,40 +179,21 @@ def main():
|
||||
robot_observation_processor=robot_joints_to_ee_pose_processor,
|
||||
)
|
||||
|
||||
# Reset the environment if not stopping or re-recording
|
||||
if not events["stop_recording"] and ((episode_idx < NUM_EPISODES - 1) or events["rerecord_episode"]):
|
||||
log_say("Reset the environment")
|
||||
record_loop(
|
||||
robot=robot,
|
||||
events=events,
|
||||
fps=FPS,
|
||||
control_time_s=EPISODE_TIME_SEC,
|
||||
single_task=TASK_DESCRIPTION,
|
||||
display_data=True,
|
||||
teleop_action_processor=make_default_teleop_action_processor(),
|
||||
robot_action_processor=robot_ee_to_joints_processor,
|
||||
robot_observation_processor=robot_joints_to_ee_pose_processor,
|
||||
)
|
||||
if events["rerecord_episode"]:
|
||||
log_say("Re-record episode")
|
||||
events["rerecord_episode"] = False
|
||||
events["exit_early"] = False
|
||||
dataset.clear_episode_buffer()
|
||||
continue
|
||||
|
||||
if events["rerecord_episode"]:
|
||||
log_say("Re-record episode")
|
||||
events["rerecord_episode"] = False
|
||||
events["exit_early"] = False
|
||||
dataset.clear_episode_buffer()
|
||||
continue
|
||||
# Save episode
|
||||
dataset.save_episode()
|
||||
episode_idx += 1
|
||||
|
||||
# Save episode
|
||||
dataset.save_episode()
|
||||
episode_idx += 1
|
||||
# Clean up
|
||||
log_say("Stop recording")
|
||||
robot.disconnect()
|
||||
listener.stop()
|
||||
|
||||
# Clean up
|
||||
log_say("Stop recording")
|
||||
robot.disconnect()
|
||||
listener.stop()
|
||||
|
||||
dataset.finalize()
|
||||
dataset.push_to_hub()
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
main()
|
||||
dataset.finalize()
|
||||
dataset.push_to_hub()
|
||||
|
||||
+132
-141
@@ -50,122 +50,133 @@ RESET_TIME_SEC = 30
|
||||
TASK_DESCRIPTION = "My task description"
|
||||
HF_REPO_ID = "<hf_username>/<dataset_repo_id>"
|
||||
|
||||
# Create the robot and teleoperator configurations
|
||||
camera_config = {"front": OpenCVCameraConfig(index_or_path=0, width=640, height=480, fps=FPS)}
|
||||
robot_config = SO100FollowerConfig(
|
||||
port="/dev/tty.usbmodem5A460814411",
|
||||
id="my_awesome_follower_arm",
|
||||
cameras=camera_config,
|
||||
use_degrees=True,
|
||||
)
|
||||
teleop_config = PhoneConfig(phone_os=PhoneOS.IOS) # or PhoneOS.ANDROID
|
||||
|
||||
def main():
|
||||
# Create the robot and teleoperator configurations
|
||||
camera_config = {"front": OpenCVCameraConfig(index_or_path=0, width=640, height=480, fps=FPS)}
|
||||
robot_config = SO100FollowerConfig(
|
||||
port="/dev/tty.usbmodem5A460814411",
|
||||
id="my_awesome_follower_arm",
|
||||
cameras=camera_config,
|
||||
use_degrees=True,
|
||||
)
|
||||
teleop_config = PhoneConfig(phone_os=PhoneOS.IOS) # or PhoneOS.ANDROID
|
||||
# Initialize the robot and teleoperator
|
||||
robot = SO100Follower(robot_config)
|
||||
phone = Phone(teleop_config)
|
||||
|
||||
# Initialize the robot and teleoperator
|
||||
robot = SO100Follower(robot_config)
|
||||
phone = Phone(teleop_config)
|
||||
# NOTE: It is highly recommended to use the urdf in the SO-ARM100 repo: https://github.com/TheRobotStudio/SO-ARM100/blob/main/Simulation/SO101/so101_new_calib.urdf
|
||||
kinematics_solver = RobotKinematics(
|
||||
urdf_path="./SO101/so101_new_calib.urdf",
|
||||
target_frame_name="gripper_frame_link",
|
||||
joint_names=list(robot.bus.motors.keys()),
|
||||
)
|
||||
|
||||
# NOTE: It is highly recommended to use the urdf in the SO-ARM100 repo: https://github.com/TheRobotStudio/SO-ARM100/blob/main/Simulation/SO101/so101_new_calib.urdf
|
||||
kinematics_solver = RobotKinematics(
|
||||
urdf_path="./SO101/so101_new_calib.urdf",
|
||||
target_frame_name="gripper_frame_link",
|
||||
joint_names=list(robot.bus.motors.keys()),
|
||||
)
|
||||
|
||||
# Build pipeline to convert phone action to EE action
|
||||
phone_to_robot_ee_pose_processor = RobotProcessorPipeline[
|
||||
tuple[RobotAction, RobotObservation], RobotAction
|
||||
](
|
||||
steps=[
|
||||
MapPhoneActionToRobotAction(platform=teleop_config.phone_os),
|
||||
EEReferenceAndDelta(
|
||||
kinematics=kinematics_solver,
|
||||
end_effector_step_sizes={"x": 0.5, "y": 0.5, "z": 0.5},
|
||||
motor_names=list(robot.bus.motors.keys()),
|
||||
use_latched_reference=True,
|
||||
),
|
||||
EEBoundsAndSafety(
|
||||
end_effector_bounds={"min": [-1.0, -1.0, -1.0], "max": [1.0, 1.0, 1.0]},
|
||||
max_ee_step_m=0.20,
|
||||
),
|
||||
GripperVelocityToJoint(speed_factor=20.0),
|
||||
],
|
||||
to_transition=robot_action_observation_to_transition,
|
||||
to_output=transition_to_robot_action,
|
||||
)
|
||||
|
||||
# Build pipeline to convert EE action to joints action
|
||||
robot_ee_to_joints_processor = RobotProcessorPipeline[tuple[RobotAction, RobotObservation], RobotAction](
|
||||
steps=[
|
||||
InverseKinematicsEEToJoints(
|
||||
kinematics=kinematics_solver,
|
||||
motor_names=list(robot.bus.motors.keys()),
|
||||
initial_guess_current_joints=True,
|
||||
),
|
||||
],
|
||||
to_transition=robot_action_observation_to_transition,
|
||||
to_output=transition_to_robot_action,
|
||||
)
|
||||
|
||||
# Build pipeline to convert joint observation to EE observation
|
||||
robot_joints_to_ee_pose = RobotProcessorPipeline[RobotObservation, RobotObservation](
|
||||
steps=[
|
||||
ForwardKinematicsJointsToEE(
|
||||
kinematics=kinematics_solver, motor_names=list(robot.bus.motors.keys())
|
||||
)
|
||||
],
|
||||
to_transition=observation_to_transition,
|
||||
to_output=transition_to_observation,
|
||||
)
|
||||
|
||||
# Create the dataset
|
||||
dataset = LeRobotDataset.create(
|
||||
repo_id=HF_REPO_ID,
|
||||
fps=FPS,
|
||||
features=combine_feature_dicts(
|
||||
# Run the feature contract of the pipelines
|
||||
# This tells you how the features would look like after the pipeline steps
|
||||
aggregate_pipeline_dataset_features(
|
||||
pipeline=phone_to_robot_ee_pose_processor,
|
||||
initial_features=create_initial_features(action=phone.action_features),
|
||||
use_videos=True,
|
||||
),
|
||||
aggregate_pipeline_dataset_features(
|
||||
pipeline=robot_joints_to_ee_pose,
|
||||
initial_features=create_initial_features(observation=robot.observation_features),
|
||||
use_videos=True,
|
||||
),
|
||||
# Build pipeline to convert phone action to EE action
|
||||
phone_to_robot_ee_pose_processor = RobotProcessorPipeline[tuple[RobotAction, RobotObservation], RobotAction](
|
||||
steps=[
|
||||
MapPhoneActionToRobotAction(platform=teleop_config.phone_os),
|
||||
EEReferenceAndDelta(
|
||||
kinematics=kinematics_solver,
|
||||
end_effector_step_sizes={"x": 0.5, "y": 0.5, "z": 0.5},
|
||||
motor_names=list(robot.bus.motors.keys()),
|
||||
use_latched_reference=True,
|
||||
),
|
||||
robot_type=robot.name,
|
||||
use_videos=True,
|
||||
image_writer_threads=4,
|
||||
EEBoundsAndSafety(
|
||||
end_effector_bounds={"min": [-1.0, -1.0, -1.0], "max": [1.0, 1.0, 1.0]},
|
||||
max_ee_step_m=0.20,
|
||||
),
|
||||
GripperVelocityToJoint(speed_factor=20.0),
|
||||
],
|
||||
to_transition=robot_action_observation_to_transition,
|
||||
to_output=transition_to_robot_action,
|
||||
)
|
||||
|
||||
# Build pipeline to convert EE action to joints action
|
||||
robot_ee_to_joints_processor = RobotProcessorPipeline[tuple[RobotAction, RobotObservation], RobotAction](
|
||||
steps=[
|
||||
InverseKinematicsEEToJoints(
|
||||
kinematics=kinematics_solver,
|
||||
motor_names=list(robot.bus.motors.keys()),
|
||||
initial_guess_current_joints=True,
|
||||
),
|
||||
],
|
||||
to_transition=robot_action_observation_to_transition,
|
||||
to_output=transition_to_robot_action,
|
||||
)
|
||||
|
||||
# Build pipeline to convert joint observation to EE observation
|
||||
robot_joints_to_ee_pose = RobotProcessorPipeline[RobotObservation, RobotObservation](
|
||||
steps=[
|
||||
ForwardKinematicsJointsToEE(kinematics=kinematics_solver, motor_names=list(robot.bus.motors.keys()))
|
||||
],
|
||||
to_transition=observation_to_transition,
|
||||
to_output=transition_to_observation,
|
||||
)
|
||||
|
||||
# Create the dataset
|
||||
dataset = LeRobotDataset.create(
|
||||
repo_id=HF_REPO_ID,
|
||||
fps=FPS,
|
||||
features=combine_feature_dicts(
|
||||
# Run the feature contract of the pipelines
|
||||
# This tells you how the features would look like after the pipeline steps
|
||||
aggregate_pipeline_dataset_features(
|
||||
pipeline=phone_to_robot_ee_pose_processor,
|
||||
initial_features=create_initial_features(action=phone.action_features),
|
||||
use_videos=True,
|
||||
),
|
||||
aggregate_pipeline_dataset_features(
|
||||
pipeline=robot_joints_to_ee_pose,
|
||||
initial_features=create_initial_features(observation=robot.observation_features),
|
||||
use_videos=True,
|
||||
),
|
||||
),
|
||||
robot_type=robot.name,
|
||||
use_videos=True,
|
||||
image_writer_threads=4,
|
||||
)
|
||||
|
||||
# Connect the robot and teleoperator
|
||||
robot.connect()
|
||||
phone.connect()
|
||||
|
||||
# Initialize the keyboard listener and rerun visualization
|
||||
listener, events = init_keyboard_listener()
|
||||
init_rerun(session_name="phone_so100_record")
|
||||
|
||||
if not robot.is_connected or not phone.is_connected:
|
||||
raise ValueError("Robot or teleop is not connected!")
|
||||
|
||||
|
||||
print("Starting record loop. Move your phone to teleoperate the robot...")
|
||||
episode_idx = 0
|
||||
while episode_idx < NUM_EPISODES and not events["stop_recording"]:
|
||||
log_say(f"Recording episode {episode_idx + 1} of {NUM_EPISODES}")
|
||||
|
||||
# Main record loop
|
||||
record_loop(
|
||||
robot=robot,
|
||||
events=events,
|
||||
fps=FPS,
|
||||
teleop=phone,
|
||||
dataset=dataset,
|
||||
control_time_s=EPISODE_TIME_SEC,
|
||||
single_task=TASK_DESCRIPTION,
|
||||
display_data=True,
|
||||
teleop_action_processor=phone_to_robot_ee_pose_processor,
|
||||
robot_action_processor=robot_ee_to_joints_processor,
|
||||
robot_observation_processor=robot_joints_to_ee_pose,
|
||||
)
|
||||
|
||||
# Connect the robot and teleoperator
|
||||
robot.connect()
|
||||
phone.connect()
|
||||
|
||||
# Initialize the keyboard listener and rerun visualization
|
||||
listener, events = init_keyboard_listener()
|
||||
init_rerun(session_name="phone_so100_record")
|
||||
|
||||
if not robot.is_connected or not phone.is_connected:
|
||||
raise ValueError("Robot or teleop is not connected!")
|
||||
|
||||
print("Starting record loop. Move your phone to teleoperate the robot...")
|
||||
episode_idx = 0
|
||||
while episode_idx < NUM_EPISODES and not events["stop_recording"]:
|
||||
log_say(f"Recording episode {episode_idx + 1} of {NUM_EPISODES}")
|
||||
|
||||
# Main record loop
|
||||
# Reset the environment if not stopping or re-recording
|
||||
if not events["stop_recording"] and (episode_idx < NUM_EPISODES - 1 or events["rerecord_episode"]):
|
||||
log_say("Reset the environment")
|
||||
record_loop(
|
||||
robot=robot,
|
||||
events=events,
|
||||
fps=FPS,
|
||||
teleop=phone,
|
||||
dataset=dataset,
|
||||
control_time_s=EPISODE_TIME_SEC,
|
||||
control_time_s=RESET_TIME_SEC,
|
||||
single_task=TASK_DESCRIPTION,
|
||||
display_data=True,
|
||||
teleop_action_processor=phone_to_robot_ee_pose_processor,
|
||||
@@ -173,42 +184,22 @@ def main():
|
||||
robot_observation_processor=robot_joints_to_ee_pose,
|
||||
)
|
||||
|
||||
# Reset the environment if not stopping or re-recording
|
||||
if not events["stop_recording"] and (episode_idx < NUM_EPISODES - 1 or events["rerecord_episode"]):
|
||||
log_say("Reset the environment")
|
||||
record_loop(
|
||||
robot=robot,
|
||||
events=events,
|
||||
fps=FPS,
|
||||
teleop=phone,
|
||||
control_time_s=RESET_TIME_SEC,
|
||||
single_task=TASK_DESCRIPTION,
|
||||
display_data=True,
|
||||
teleop_action_processor=phone_to_robot_ee_pose_processor,
|
||||
robot_action_processor=robot_ee_to_joints_processor,
|
||||
robot_observation_processor=robot_joints_to_ee_pose,
|
||||
)
|
||||
if events["rerecord_episode"]:
|
||||
log_say("Re-recording episode")
|
||||
events["rerecord_episode"] = False
|
||||
events["exit_early"] = False
|
||||
dataset.clear_episode_buffer()
|
||||
continue
|
||||
|
||||
if events["rerecord_episode"]:
|
||||
log_say("Re-recording episode")
|
||||
events["rerecord_episode"] = False
|
||||
events["exit_early"] = False
|
||||
dataset.clear_episode_buffer()
|
||||
continue
|
||||
# Save episode
|
||||
dataset.save_episode()
|
||||
episode_idx += 1
|
||||
|
||||
# Save episode
|
||||
dataset.save_episode()
|
||||
episode_idx += 1
|
||||
# Clean up
|
||||
log_say("Stop recording")
|
||||
robot.disconnect()
|
||||
phone.disconnect()
|
||||
listener.stop()
|
||||
|
||||
# Clean up
|
||||
log_say("Stop recording")
|
||||
robot.disconnect()
|
||||
phone.disconnect()
|
||||
listener.stop()
|
||||
|
||||
dataset.finalize()
|
||||
dataset.push_to_hub()
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
main()
|
||||
dataset.finalize()
|
||||
dataset.push_to_hub()
|
||||
|
||||
@@ -29,78 +29,72 @@ from lerobot.robots.so100_follower.robot_kinematic_processor import (
|
||||
)
|
||||
from lerobot.robots.so100_follower.so100_follower import SO100Follower
|
||||
from lerobot.utils.constants import ACTION
|
||||
from lerobot.utils.robot_utils import precise_sleep
|
||||
from lerobot.utils.robot_utils import busy_wait
|
||||
from lerobot.utils.utils import log_say
|
||||
|
||||
EPISODE_IDX = 0
|
||||
HF_REPO_ID = "<hf_username>/<dataset_repo_id>"
|
||||
|
||||
# Initialize the robot config
|
||||
robot_config = SO100FollowerConfig(
|
||||
port="/dev/tty.usbmodem5A460814411", id="my_awesome_follower_arm", use_degrees=True
|
||||
)
|
||||
|
||||
def main():
|
||||
# Initialize the robot config
|
||||
robot_config = SO100FollowerConfig(
|
||||
port="/dev/tty.usbmodem5A460814411", id="my_awesome_follower_arm", use_degrees=True
|
||||
)
|
||||
# Initialize the robot
|
||||
robot = SO100Follower(robot_config)
|
||||
|
||||
# Initialize the robot
|
||||
robot = SO100Follower(robot_config)
|
||||
# NOTE: It is highly recommended to use the urdf in the SO-ARM100 repo: https://github.com/TheRobotStudio/SO-ARM100/blob/main/Simulation/SO101/so101_new_calib.urdf
|
||||
kinematics_solver = RobotKinematics(
|
||||
urdf_path="./SO101/so101_new_calib.urdf",
|
||||
target_frame_name="gripper_frame_link",
|
||||
joint_names=list(robot.bus.motors.keys()),
|
||||
)
|
||||
|
||||
# NOTE: It is highly recommended to use the urdf in the SO-ARM100 repo: https://github.com/TheRobotStudio/SO-ARM100/blob/main/Simulation/SO101/so101_new_calib.urdf
|
||||
kinematics_solver = RobotKinematics(
|
||||
urdf_path="./SO101/so101_new_calib.urdf",
|
||||
target_frame_name="gripper_frame_link",
|
||||
joint_names=list(robot.bus.motors.keys()),
|
||||
)
|
||||
# Build pipeline to convert EE action to joints action
|
||||
robot_ee_to_joints_processor = RobotProcessorPipeline[tuple[RobotAction, RobotObservation], RobotAction](
|
||||
steps=[
|
||||
InverseKinematicsEEToJoints(
|
||||
kinematics=kinematics_solver,
|
||||
motor_names=list(robot.bus.motors.keys()),
|
||||
initial_guess_current_joints=False, # Because replay is open loop
|
||||
),
|
||||
],
|
||||
to_transition=robot_action_observation_to_transition,
|
||||
to_output=transition_to_robot_action,
|
||||
)
|
||||
|
||||
# Build pipeline to convert EE action to joints action
|
||||
robot_ee_to_joints_processor = RobotProcessorPipeline[tuple[RobotAction, RobotObservation], RobotAction](
|
||||
steps=[
|
||||
InverseKinematicsEEToJoints(
|
||||
kinematics=kinematics_solver,
|
||||
motor_names=list(robot.bus.motors.keys()),
|
||||
initial_guess_current_joints=False, # Because replay is open loop
|
||||
),
|
||||
],
|
||||
to_transition=robot_action_observation_to_transition,
|
||||
to_output=transition_to_robot_action,
|
||||
)
|
||||
# Fetch the dataset to replay
|
||||
dataset = LeRobotDataset(HF_REPO_ID, episodes=[EPISODE_IDX])
|
||||
# Filter dataset to only include frames from the specified episode since episodes are chunked in dataset V3.0
|
||||
episode_frames = dataset.hf_dataset.filter(lambda x: x["episode_index"] == EPISODE_IDX)
|
||||
actions = episode_frames.select_columns(ACTION)
|
||||
|
||||
# Fetch the dataset to replay
|
||||
dataset = LeRobotDataset(HF_REPO_ID, episodes=[EPISODE_IDX])
|
||||
# Filter dataset to only include frames from the specified episode since episodes are chunked in dataset V3.0
|
||||
episode_frames = dataset.hf_dataset.filter(lambda x: x["episode_index"] == EPISODE_IDX)
|
||||
actions = episode_frames.select_columns(ACTION)
|
||||
# Connect to the robot
|
||||
robot.connect()
|
||||
|
||||
# Connect to the robot
|
||||
robot.connect()
|
||||
if not robot.is_connected:
|
||||
raise ValueError("Robot is not connected!")
|
||||
|
||||
if not robot.is_connected:
|
||||
raise ValueError("Robot is not connected!")
|
||||
print("Starting replay loop...")
|
||||
log_say(f"Replaying episode {EPISODE_IDX}")
|
||||
for idx in range(len(episode_frames)):
|
||||
t0 = time.perf_counter()
|
||||
|
||||
print("Starting replay loop...")
|
||||
log_say(f"Replaying episode {EPISODE_IDX}")
|
||||
for idx in range(len(episode_frames)):
|
||||
t0 = time.perf_counter()
|
||||
# Get recorded action from dataset
|
||||
ee_action = {
|
||||
name: float(actions[idx][ACTION][i]) for i, name in enumerate(dataset.features[ACTION]["names"])
|
||||
}
|
||||
|
||||
# Get recorded action from dataset
|
||||
ee_action = {
|
||||
name: float(actions[idx][ACTION][i]) for i, name in enumerate(dataset.features[ACTION]["names"])
|
||||
}
|
||||
# Get robot observation
|
||||
robot_obs = robot.get_observation()
|
||||
|
||||
# Get robot observation
|
||||
robot_obs = robot.get_observation()
|
||||
# Dataset EE -> robot joints
|
||||
joint_action = robot_ee_to_joints_processor((ee_action, robot_obs))
|
||||
|
||||
# Dataset EE -> robot joints
|
||||
joint_action = robot_ee_to_joints_processor((ee_action, robot_obs))
|
||||
# Send action to robot
|
||||
_ = robot.send_action(joint_action)
|
||||
|
||||
# Send action to robot
|
||||
_ = robot.send_action(joint_action)
|
||||
busy_wait(1.0 / dataset.fps - (time.perf_counter() - t0))
|
||||
|
||||
precise_sleep(1.0 / dataset.fps - (time.perf_counter() - t0))
|
||||
|
||||
# Clean up
|
||||
robot.disconnect()
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
main()
|
||||
# Clean up
|
||||
robot.disconnect()
|
||||
|
||||
@@ -32,90 +32,82 @@ from lerobot.robots.so100_follower.so100_follower import SO100Follower
|
||||
from lerobot.teleoperators.phone.config_phone import PhoneConfig, PhoneOS
|
||||
from lerobot.teleoperators.phone.phone_processor import MapPhoneActionToRobotAction
|
||||
from lerobot.teleoperators.phone.teleop_phone import Phone
|
||||
from lerobot.utils.robot_utils import precise_sleep
|
||||
from lerobot.utils.robot_utils import busy_wait
|
||||
from lerobot.utils.visualization_utils import init_rerun, log_rerun_data
|
||||
|
||||
FPS = 30
|
||||
|
||||
# Initialize the robot and teleoperator
|
||||
robot_config = SO100FollowerConfig(
|
||||
port="/dev/tty.usbmodem5A460814411", id="my_awesome_follower_arm", use_degrees=True
|
||||
)
|
||||
teleop_config = PhoneConfig(phone_os=PhoneOS.IOS) # or PhoneOS.ANDROID
|
||||
|
||||
def main():
|
||||
# Initialize the robot and teleoperator
|
||||
robot_config = SO100FollowerConfig(
|
||||
port="/dev/tty.usbmodem5A460814411", id="my_awesome_follower_arm", use_degrees=True
|
||||
)
|
||||
teleop_config = PhoneConfig(phone_os=PhoneOS.IOS) # or PhoneOS.ANDROID
|
||||
# Initialize the robot and teleoperator
|
||||
robot = SO100Follower(robot_config)
|
||||
teleop_device = Phone(teleop_config)
|
||||
|
||||
# Initialize the robot and teleoperator
|
||||
robot = SO100Follower(robot_config)
|
||||
teleop_device = Phone(teleop_config)
|
||||
# NOTE: It is highly recommended to use the urdf in the SO-ARM100 repo: https://github.com/TheRobotStudio/SO-ARM100/blob/main/Simulation/SO101/so101_new_calib.urdf
|
||||
kinematics_solver = RobotKinematics(
|
||||
urdf_path="./SO101/so101_new_calib.urdf",
|
||||
target_frame_name="gripper_frame_link",
|
||||
joint_names=list(robot.bus.motors.keys()),
|
||||
)
|
||||
|
||||
# NOTE: It is highly recommended to use the urdf in the SO-ARM100 repo: https://github.com/TheRobotStudio/SO-ARM100/blob/main/Simulation/SO101/so101_new_calib.urdf
|
||||
kinematics_solver = RobotKinematics(
|
||||
urdf_path="./SO101/so101_new_calib.urdf",
|
||||
target_frame_name="gripper_frame_link",
|
||||
joint_names=list(robot.bus.motors.keys()),
|
||||
)
|
||||
# Build pipeline to convert phone action to ee pose action to joint action
|
||||
phone_to_robot_joints_processor = RobotProcessorPipeline[tuple[RobotAction, RobotObservation], RobotAction](
|
||||
steps=[
|
||||
MapPhoneActionToRobotAction(platform=teleop_config.phone_os),
|
||||
EEReferenceAndDelta(
|
||||
kinematics=kinematics_solver,
|
||||
end_effector_step_sizes={"x": 0.5, "y": 0.5, "z": 0.5},
|
||||
motor_names=list(robot.bus.motors.keys()),
|
||||
use_latched_reference=True,
|
||||
),
|
||||
EEBoundsAndSafety(
|
||||
end_effector_bounds={"min": [-1.0, -1.0, -1.0], "max": [1.0, 1.0, 1.0]},
|
||||
max_ee_step_m=0.10,
|
||||
),
|
||||
GripperVelocityToJoint(
|
||||
speed_factor=20.0,
|
||||
),
|
||||
InverseKinematicsEEToJoints(
|
||||
kinematics=kinematics_solver,
|
||||
motor_names=list(robot.bus.motors.keys()),
|
||||
initial_guess_current_joints=True,
|
||||
),
|
||||
],
|
||||
to_transition=robot_action_observation_to_transition,
|
||||
to_output=transition_to_robot_action,
|
||||
)
|
||||
|
||||
# Build pipeline to convert phone action to ee pose action to joint action
|
||||
phone_to_robot_joints_processor = RobotProcessorPipeline[
|
||||
tuple[RobotAction, RobotObservation], RobotAction
|
||||
](
|
||||
steps=[
|
||||
MapPhoneActionToRobotAction(platform=teleop_config.phone_os),
|
||||
EEReferenceAndDelta(
|
||||
kinematics=kinematics_solver,
|
||||
end_effector_step_sizes={"x": 0.5, "y": 0.5, "z": 0.5},
|
||||
motor_names=list(robot.bus.motors.keys()),
|
||||
use_latched_reference=True,
|
||||
),
|
||||
EEBoundsAndSafety(
|
||||
end_effector_bounds={"min": [-1.0, -1.0, -1.0], "max": [1.0, 1.0, 1.0]},
|
||||
max_ee_step_m=0.10,
|
||||
),
|
||||
GripperVelocityToJoint(
|
||||
speed_factor=20.0,
|
||||
),
|
||||
InverseKinematicsEEToJoints(
|
||||
kinematics=kinematics_solver,
|
||||
motor_names=list(robot.bus.motors.keys()),
|
||||
initial_guess_current_joints=True,
|
||||
),
|
||||
],
|
||||
to_transition=robot_action_observation_to_transition,
|
||||
to_output=transition_to_robot_action,
|
||||
)
|
||||
# Connect to the robot and teleoperator
|
||||
robot.connect()
|
||||
teleop_device.connect()
|
||||
|
||||
# Connect to the robot and teleoperator
|
||||
robot.connect()
|
||||
teleop_device.connect()
|
||||
# Init rerun viewer
|
||||
init_rerun(session_name="phone_so100_teleop")
|
||||
|
||||
# Init rerun viewer
|
||||
init_rerun(session_name="phone_so100_teleop")
|
||||
if not robot.is_connected or not teleop_device.is_connected:
|
||||
raise ValueError("Robot or teleop is not connected!")
|
||||
|
||||
if not robot.is_connected or not teleop_device.is_connected:
|
||||
raise ValueError("Robot or teleop is not connected!")
|
||||
print("Starting teleop loop. Move your phone to teleoperate the robot...")
|
||||
while True:
|
||||
t0 = time.perf_counter()
|
||||
|
||||
print("Starting teleop loop. Move your phone to teleoperate the robot...")
|
||||
while True:
|
||||
t0 = time.perf_counter()
|
||||
# Get robot observation
|
||||
robot_obs = robot.get_observation()
|
||||
|
||||
# Get robot observation
|
||||
robot_obs = robot.get_observation()
|
||||
# Get teleop action
|
||||
phone_obs = teleop_device.get_action()
|
||||
|
||||
# Get teleop action
|
||||
phone_obs = teleop_device.get_action()
|
||||
# Phone -> EE pose -> Joints transition
|
||||
joint_action = phone_to_robot_joints_processor((phone_obs, robot_obs))
|
||||
|
||||
# Phone -> EE pose -> Joints transition
|
||||
joint_action = phone_to_robot_joints_processor((phone_obs, robot_obs))
|
||||
# Send action to robot
|
||||
_ = robot.send_action(joint_action)
|
||||
|
||||
# Send action to robot
|
||||
_ = robot.send_action(joint_action)
|
||||
# Visualize
|
||||
log_rerun_data(observation=phone_obs, action=joint_action)
|
||||
|
||||
# Visualize
|
||||
log_rerun_data(observation=phone_obs, action=joint_action)
|
||||
|
||||
precise_sleep(max(1.0 / FPS - (time.perf_counter() - t0), 0.0))
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
main()
|
||||
busy_wait(max(1.0 / FPS - (time.perf_counter() - t0), 0.0))
|
||||
|
||||
@@ -52,114 +52,126 @@ TASK_DESCRIPTION = "My task description"
|
||||
HF_MODEL_ID = "<hf_username>/<model_repo_id>"
|
||||
HF_DATASET_ID = "<hf_username>/<dataset_repo_id>"
|
||||
|
||||
# Create the robot configuration & robot
|
||||
camera_config = {"front": OpenCVCameraConfig(index_or_path=0, width=640, height=480, fps=FPS)}
|
||||
robot_config = SO100FollowerConfig(
|
||||
port="/dev/tty.usbmodem5A460814411",
|
||||
id="my_awesome_follower_arm",
|
||||
cameras=camera_config,
|
||||
use_degrees=True,
|
||||
)
|
||||
|
||||
def main():
|
||||
# Create the robot configuration & robot
|
||||
camera_config = {"front": OpenCVCameraConfig(index_or_path=0, width=640, height=480, fps=FPS)}
|
||||
robot_config = SO100FollowerConfig(
|
||||
port="/dev/tty.usbmodem5A460814411",
|
||||
id="my_awesome_follower_arm",
|
||||
cameras=camera_config,
|
||||
use_degrees=True,
|
||||
)
|
||||
robot = SO100Follower(robot_config)
|
||||
|
||||
robot = SO100Follower(robot_config)
|
||||
# Create policy
|
||||
policy = ACTPolicy.from_pretrained(HF_MODEL_ID)
|
||||
|
||||
# Create policy
|
||||
policy = ACTPolicy.from_pretrained(HF_MODEL_ID)
|
||||
# NOTE: It is highly recommended to use the urdf in the SO-ARM100 repo: https://github.com/TheRobotStudio/SO-ARM100/blob/main/Simulation/SO101/so101_new_calib.urdf
|
||||
kinematics_solver = RobotKinematics(
|
||||
urdf_path="./SO101/so101_new_calib.urdf",
|
||||
target_frame_name="gripper_frame_link",
|
||||
joint_names=list(robot.bus.motors.keys()),
|
||||
)
|
||||
|
||||
# NOTE: It is highly recommended to use the urdf in the SO-ARM100 repo: https://github.com/TheRobotStudio/SO-ARM100/blob/main/Simulation/SO101/so101_new_calib.urdf
|
||||
kinematics_solver = RobotKinematics(
|
||||
urdf_path="./SO101/so101_new_calib.urdf",
|
||||
target_frame_name="gripper_frame_link",
|
||||
joint_names=list(robot.bus.motors.keys()),
|
||||
)
|
||||
|
||||
# Build pipeline to convert EE action to joints action
|
||||
robot_ee_to_joints_processor = RobotProcessorPipeline[tuple[RobotAction, RobotObservation], RobotAction](
|
||||
steps=[
|
||||
InverseKinematicsEEToJoints(
|
||||
kinematics=kinematics_solver,
|
||||
motor_names=list(robot.bus.motors.keys()),
|
||||
initial_guess_current_joints=True,
|
||||
),
|
||||
],
|
||||
to_transition=robot_action_observation_to_transition,
|
||||
to_output=transition_to_robot_action,
|
||||
)
|
||||
|
||||
# Build pipeline to convert joints observation to EE observation
|
||||
robot_joints_to_ee_pose_processor = RobotProcessorPipeline[RobotObservation, RobotObservation](
|
||||
steps=[
|
||||
ForwardKinematicsJointsToEE(
|
||||
kinematics=kinematics_solver, motor_names=list(robot.bus.motors.keys())
|
||||
)
|
||||
],
|
||||
to_transition=observation_to_transition,
|
||||
to_output=transition_to_observation,
|
||||
)
|
||||
|
||||
# Create the dataset
|
||||
dataset = LeRobotDataset.create(
|
||||
repo_id=HF_DATASET_ID,
|
||||
fps=FPS,
|
||||
features=combine_feature_dicts(
|
||||
aggregate_pipeline_dataset_features(
|
||||
pipeline=robot_joints_to_ee_pose_processor,
|
||||
initial_features=create_initial_features(observation=robot.observation_features),
|
||||
use_videos=True,
|
||||
),
|
||||
# User for now should be explicit on the feature keys that were used for record
|
||||
# Alternatively, the user can pass the processor step that has the right features
|
||||
aggregate_pipeline_dataset_features(
|
||||
pipeline=make_default_teleop_action_processor(),
|
||||
initial_features=create_initial_features(
|
||||
action={
|
||||
f"ee.{k}": PolicyFeature(type=FeatureType.ACTION, shape=(1,))
|
||||
for k in ["x", "y", "z", "wx", "wy", "wz", "gripper_pos"]
|
||||
}
|
||||
),
|
||||
use_videos=True,
|
||||
),
|
||||
# Build pipeline to convert EE action to joints action
|
||||
robot_ee_to_joints_processor = RobotProcessorPipeline[tuple[RobotAction, RobotObservation], RobotAction](
|
||||
steps=[
|
||||
InverseKinematicsEEToJoints(
|
||||
kinematics=kinematics_solver,
|
||||
motor_names=list(robot.bus.motors.keys()),
|
||||
initial_guess_current_joints=True,
|
||||
),
|
||||
robot_type=robot.name,
|
||||
use_videos=True,
|
||||
image_writer_threads=4,
|
||||
],
|
||||
to_transition=robot_action_observation_to_transition,
|
||||
to_output=transition_to_robot_action,
|
||||
)
|
||||
|
||||
# Build pipeline to convert joints observation to EE observation
|
||||
robot_joints_to_ee_pose_processor = RobotProcessorPipeline[RobotObservation, RobotObservation](
|
||||
steps=[
|
||||
ForwardKinematicsJointsToEE(kinematics=kinematics_solver, motor_names=list(robot.bus.motors.keys()))
|
||||
],
|
||||
to_transition=observation_to_transition,
|
||||
to_output=transition_to_observation,
|
||||
)
|
||||
|
||||
|
||||
# Create the dataset
|
||||
dataset = LeRobotDataset.create(
|
||||
repo_id=HF_DATASET_ID,
|
||||
fps=FPS,
|
||||
features=combine_feature_dicts(
|
||||
aggregate_pipeline_dataset_features(
|
||||
pipeline=robot_joints_to_ee_pose_processor,
|
||||
initial_features=create_initial_features(observation=robot.observation_features),
|
||||
use_videos=True,
|
||||
),
|
||||
# User for now should be explicit on the feature keys that were used for record
|
||||
# Alternatively, the user can pass the processor step that has the right features
|
||||
aggregate_pipeline_dataset_features(
|
||||
pipeline=make_default_teleop_action_processor(),
|
||||
initial_features=create_initial_features(
|
||||
action={
|
||||
f"ee.{k}": PolicyFeature(type=FeatureType.ACTION, shape=(1,))
|
||||
for k in ["x", "y", "z", "wx", "wy", "wz", "gripper_pos"]
|
||||
}
|
||||
),
|
||||
use_videos=True,
|
||||
),
|
||||
),
|
||||
robot_type=robot.name,
|
||||
use_videos=True,
|
||||
image_writer_threads=4,
|
||||
)
|
||||
|
||||
# Build Policy Processors
|
||||
preprocessor, postprocessor = make_pre_post_processors(
|
||||
policy_cfg=policy,
|
||||
pretrained_path=HF_MODEL_ID,
|
||||
dataset_stats=dataset.meta.stats,
|
||||
# The inference device is automatically set to match the detected hardware, overriding any previous device settings from training to ensure compatibility.
|
||||
preprocessor_overrides={"device_processor": {"device": str(policy.config.device)}},
|
||||
)
|
||||
|
||||
# Connect the robot and teleoperator
|
||||
robot.connect()
|
||||
|
||||
# Initialize the keyboard listener and rerun visualization
|
||||
listener, events = init_keyboard_listener()
|
||||
init_rerun(session_name="so100_so100_evaluate")
|
||||
|
||||
if not robot.is_connected:
|
||||
raise ValueError("Robot is not connected!")
|
||||
|
||||
print("Starting evaluate loop...")
|
||||
episode_idx = 0
|
||||
for episode_idx in range(NUM_EPISODES):
|
||||
log_say(f"Running inference, recording eval episode {episode_idx + 1} of {NUM_EPISODES}")
|
||||
|
||||
# Main record loop
|
||||
record_loop(
|
||||
robot=robot,
|
||||
events=events,
|
||||
fps=FPS,
|
||||
policy=policy,
|
||||
preprocessor=preprocessor, # Pass the pre and post policy processors
|
||||
postprocessor=postprocessor,
|
||||
dataset=dataset,
|
||||
control_time_s=EPISODE_TIME_SEC,
|
||||
single_task=TASK_DESCRIPTION,
|
||||
display_data=True,
|
||||
teleop_action_processor=make_default_teleop_action_processor(),
|
||||
robot_action_processor=robot_ee_to_joints_processor,
|
||||
robot_observation_processor=robot_joints_to_ee_pose_processor,
|
||||
)
|
||||
|
||||
# Build Policy Processors
|
||||
preprocessor, postprocessor = make_pre_post_processors(
|
||||
policy_cfg=policy,
|
||||
pretrained_path=HF_MODEL_ID,
|
||||
dataset_stats=dataset.meta.stats,
|
||||
# The inference device is automatically set to match the detected hardware, overriding any previous device settings from training to ensure compatibility.
|
||||
preprocessor_overrides={"device_processor": {"device": str(policy.config.device)}},
|
||||
)
|
||||
|
||||
# Connect the robot and teleoperator
|
||||
robot.connect()
|
||||
|
||||
# Initialize the keyboard listener and rerun visualization
|
||||
listener, events = init_keyboard_listener()
|
||||
init_rerun(session_name="so100_so100_evaluate")
|
||||
|
||||
if not robot.is_connected:
|
||||
raise ValueError("Robot is not connected!")
|
||||
|
||||
print("Starting evaluate loop...")
|
||||
episode_idx = 0
|
||||
for episode_idx in range(NUM_EPISODES):
|
||||
log_say(f"Running inference, recording eval episode {episode_idx + 1} of {NUM_EPISODES}")
|
||||
|
||||
# Main record loop
|
||||
# Reset the environment if not stopping or re-recording
|
||||
if not events["stop_recording"] and ((episode_idx < NUM_EPISODES - 1) or events["rerecord_episode"]):
|
||||
log_say("Reset the environment")
|
||||
record_loop(
|
||||
robot=robot,
|
||||
events=events,
|
||||
fps=FPS,
|
||||
policy=policy,
|
||||
preprocessor=preprocessor, # Pass the pre and post policy processors
|
||||
postprocessor=postprocessor,
|
||||
dataset=dataset,
|
||||
control_time_s=EPISODE_TIME_SEC,
|
||||
single_task=TASK_DESCRIPTION,
|
||||
display_data=True,
|
||||
@@ -168,40 +180,21 @@ def main():
|
||||
robot_observation_processor=robot_joints_to_ee_pose_processor,
|
||||
)
|
||||
|
||||
# Reset the environment if not stopping or re-recording
|
||||
if not events["stop_recording"] and ((episode_idx < NUM_EPISODES - 1) or events["rerecord_episode"]):
|
||||
log_say("Reset the environment")
|
||||
record_loop(
|
||||
robot=robot,
|
||||
events=events,
|
||||
fps=FPS,
|
||||
control_time_s=EPISODE_TIME_SEC,
|
||||
single_task=TASK_DESCRIPTION,
|
||||
display_data=True,
|
||||
teleop_action_processor=make_default_teleop_action_processor(),
|
||||
robot_action_processor=robot_ee_to_joints_processor,
|
||||
robot_observation_processor=robot_joints_to_ee_pose_processor,
|
||||
)
|
||||
if events["rerecord_episode"]:
|
||||
log_say("Re-record episode")
|
||||
events["rerecord_episode"] = False
|
||||
events["exit_early"] = False
|
||||
dataset.clear_episode_buffer()
|
||||
continue
|
||||
|
||||
if events["rerecord_episode"]:
|
||||
log_say("Re-record episode")
|
||||
events["rerecord_episode"] = False
|
||||
events["exit_early"] = False
|
||||
dataset.clear_episode_buffer()
|
||||
continue
|
||||
# Save episode
|
||||
dataset.save_episode()
|
||||
episode_idx += 1
|
||||
|
||||
# Save episode
|
||||
dataset.save_episode()
|
||||
episode_idx += 1
|
||||
# Clean up
|
||||
log_say("Stop recording")
|
||||
robot.disconnect()
|
||||
listener.stop()
|
||||
|
||||
# Clean up
|
||||
log_say("Stop recording")
|
||||
robot.disconnect()
|
||||
listener.stop()
|
||||
|
||||
dataset.finalize()
|
||||
dataset.push_to_hub()
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
main()
|
||||
dataset.finalize()
|
||||
dataset.push_to_hub()
|
||||
|
||||
@@ -48,122 +48,134 @@ RESET_TIME_SEC = 30
|
||||
TASK_DESCRIPTION = "My task description"
|
||||
HF_REPO_ID = "<hf_username>/<dataset_repo_id>"
|
||||
|
||||
# Create the robot and teleoperator configurations
|
||||
camera_config = {"front": OpenCVCameraConfig(index_or_path=0, width=640, height=480, fps=FPS)}
|
||||
follower_config = SO100FollowerConfig(
|
||||
port="/dev/tty.usbmodem5A460814411", id="my_awesome_follower_arm", cameras=camera_config, use_degrees=True
|
||||
)
|
||||
leader_config = SO100LeaderConfig(port="/dev/tty.usbmodem5A460819811", id="my_awesome_leader_arm")
|
||||
|
||||
def main():
|
||||
# Create the robot and teleoperator configurations
|
||||
camera_config = {"front": OpenCVCameraConfig(index_or_path=0, width=640, height=480, fps=FPS)}
|
||||
follower_config = SO100FollowerConfig(
|
||||
port="/dev/tty.usbmodem5A460814411",
|
||||
id="my_awesome_follower_arm",
|
||||
cameras=camera_config,
|
||||
use_degrees=True,
|
||||
)
|
||||
leader_config = SO100LeaderConfig(port="/dev/tty.usbmodem5A460819811", id="my_awesome_leader_arm")
|
||||
# Initialize the robot and teleoperator
|
||||
follower = SO100Follower(follower_config)
|
||||
leader = SO100Leader(leader_config)
|
||||
|
||||
# Initialize the robot and teleoperator
|
||||
follower = SO100Follower(follower_config)
|
||||
leader = SO100Leader(leader_config)
|
||||
# NOTE: It is highly recommended to use the urdf in the SO-ARM100 repo: https://github.com/TheRobotStudio/SO-ARM100/blob/main/Simulation/SO101/so101_new_calib.urdf
|
||||
follower_kinematics_solver = RobotKinematics(
|
||||
urdf_path="./SO101/so101_new_calib.urdf",
|
||||
target_frame_name="gripper_frame_link",
|
||||
joint_names=list(follower.bus.motors.keys()),
|
||||
)
|
||||
|
||||
# NOTE: It is highly recommended to use the urdf in the SO-ARM100 repo: https://github.com/TheRobotStudio/SO-ARM100/blob/main/Simulation/SO101/so101_new_calib.urdf
|
||||
follower_kinematics_solver = RobotKinematics(
|
||||
urdf_path="./SO101/so101_new_calib.urdf",
|
||||
target_frame_name="gripper_frame_link",
|
||||
joint_names=list(follower.bus.motors.keys()),
|
||||
)
|
||||
# NOTE: It is highly recommended to use the urdf in the SO-ARM100 repo: https://github.com/TheRobotStudio/SO-ARM100/blob/main/Simulation/SO101/so101_new_calib.urdf
|
||||
leader_kinematics_solver = RobotKinematics(
|
||||
urdf_path="./SO101/so101_new_calib.urdf",
|
||||
target_frame_name="gripper_frame_link",
|
||||
joint_names=list(leader.bus.motors.keys()),
|
||||
)
|
||||
|
||||
# NOTE: It is highly recommended to use the urdf in the SO-ARM100 repo: https://github.com/TheRobotStudio/SO-ARM100/blob/main/Simulation/SO101/so101_new_calib.urdf
|
||||
leader_kinematics_solver = RobotKinematics(
|
||||
urdf_path="./SO101/so101_new_calib.urdf",
|
||||
target_frame_name="gripper_frame_link",
|
||||
joint_names=list(leader.bus.motors.keys()),
|
||||
)
|
||||
|
||||
# Build pipeline to convert follower joints to EE observation
|
||||
follower_joints_to_ee = RobotProcessorPipeline[RobotObservation, RobotObservation](
|
||||
steps=[
|
||||
ForwardKinematicsJointsToEE(
|
||||
kinematics=follower_kinematics_solver, motor_names=list(follower.bus.motors.keys())
|
||||
),
|
||||
],
|
||||
to_transition=observation_to_transition,
|
||||
to_output=transition_to_observation,
|
||||
)
|
||||
|
||||
# Build pipeline to convert leader joints to EE action
|
||||
leader_joints_to_ee = RobotProcessorPipeline[tuple[RobotAction, RobotObservation], RobotAction](
|
||||
steps=[
|
||||
ForwardKinematicsJointsToEE(
|
||||
kinematics=leader_kinematics_solver, motor_names=list(leader.bus.motors.keys())
|
||||
),
|
||||
],
|
||||
to_transition=robot_action_observation_to_transition,
|
||||
to_output=transition_to_robot_action,
|
||||
)
|
||||
|
||||
# Build pipeline to convert EE action to follower joints
|
||||
ee_to_follower_joints = RobotProcessorPipeline[tuple[RobotAction, RobotObservation], RobotAction](
|
||||
[
|
||||
EEBoundsAndSafety(
|
||||
end_effector_bounds={"min": [-1.0, -1.0, -1.0], "max": [1.0, 1.0, 1.0]},
|
||||
max_ee_step_m=0.10,
|
||||
),
|
||||
InverseKinematicsEEToJoints(
|
||||
kinematics=follower_kinematics_solver,
|
||||
motor_names=list(follower.bus.motors.keys()),
|
||||
initial_guess_current_joints=True,
|
||||
),
|
||||
],
|
||||
to_transition=robot_action_observation_to_transition,
|
||||
to_output=transition_to_robot_action,
|
||||
)
|
||||
|
||||
# Create the dataset
|
||||
dataset = LeRobotDataset.create(
|
||||
repo_id=HF_REPO_ID,
|
||||
fps=FPS,
|
||||
features=combine_feature_dicts(
|
||||
# Run the feature contract of the pipelines
|
||||
# This tells you how the features would look like after the pipeline steps
|
||||
aggregate_pipeline_dataset_features(
|
||||
pipeline=leader_joints_to_ee,
|
||||
initial_features=create_initial_features(action=leader.action_features),
|
||||
use_videos=True,
|
||||
),
|
||||
aggregate_pipeline_dataset_features(
|
||||
pipeline=follower_joints_to_ee,
|
||||
initial_features=create_initial_features(observation=follower.observation_features),
|
||||
use_videos=True,
|
||||
),
|
||||
# Build pipeline to convert follower joints to EE observation
|
||||
follower_joints_to_ee = RobotProcessorPipeline[RobotObservation, RobotObservation](
|
||||
steps=[
|
||||
ForwardKinematicsJointsToEE(
|
||||
kinematics=follower_kinematics_solver, motor_names=list(follower.bus.motors.keys())
|
||||
),
|
||||
robot_type=follower.name,
|
||||
use_videos=True,
|
||||
image_writer_threads=4,
|
||||
],
|
||||
to_transition=observation_to_transition,
|
||||
to_output=transition_to_observation,
|
||||
)
|
||||
|
||||
# Build pipeline to convert leader joints to EE action
|
||||
leader_joints_to_ee = RobotProcessorPipeline[tuple[RobotAction, RobotObservation], RobotAction](
|
||||
steps=[
|
||||
ForwardKinematicsJointsToEE(
|
||||
kinematics=leader_kinematics_solver, motor_names=list(leader.bus.motors.keys())
|
||||
),
|
||||
],
|
||||
to_transition=robot_action_observation_to_transition,
|
||||
to_output=transition_to_robot_action,
|
||||
)
|
||||
|
||||
# Build pipeline to convert EE action to follower joints
|
||||
ee_to_follower_joints = RobotProcessorPipeline[tuple[RobotAction, RobotObservation], RobotAction](
|
||||
[
|
||||
EEBoundsAndSafety(
|
||||
end_effector_bounds={"min": [-1.0, -1.0, -1.0], "max": [1.0, 1.0, 1.0]},
|
||||
max_ee_step_m=0.10,
|
||||
),
|
||||
InverseKinematicsEEToJoints(
|
||||
kinematics=follower_kinematics_solver,
|
||||
motor_names=list(follower.bus.motors.keys()),
|
||||
initial_guess_current_joints=True,
|
||||
),
|
||||
],
|
||||
to_transition=robot_action_observation_to_transition,
|
||||
to_output=transition_to_robot_action,
|
||||
)
|
||||
|
||||
# Create the dataset
|
||||
dataset = LeRobotDataset.create(
|
||||
repo_id=HF_REPO_ID,
|
||||
fps=FPS,
|
||||
features=combine_feature_dicts(
|
||||
# Run the feature contract of the pipelines
|
||||
# This tells you how the features would look like after the pipeline steps
|
||||
aggregate_pipeline_dataset_features(
|
||||
pipeline=leader_joints_to_ee,
|
||||
initial_features=create_initial_features(action=leader.action_features),
|
||||
use_videos=True,
|
||||
),
|
||||
aggregate_pipeline_dataset_features(
|
||||
pipeline=follower_joints_to_ee,
|
||||
initial_features=create_initial_features(observation=follower.observation_features),
|
||||
use_videos=True,
|
||||
),
|
||||
),
|
||||
robot_type=follower.name,
|
||||
use_videos=True,
|
||||
image_writer_threads=4,
|
||||
)
|
||||
|
||||
|
||||
# Connect the robot and teleoperator
|
||||
leader.connect()
|
||||
follower.connect()
|
||||
|
||||
# Initialize the keyboard listener and rerun visualization
|
||||
listener, events = init_keyboard_listener()
|
||||
init_rerun(session_name="recording_phone")
|
||||
|
||||
if not leader.is_connected or not follower.is_connected:
|
||||
raise ValueError("Robot or teleop is not connected!")
|
||||
|
||||
print("Starting record loop...")
|
||||
episode_idx = 0
|
||||
while episode_idx < NUM_EPISODES and not events["stop_recording"]:
|
||||
log_say(f"Recording episode {episode_idx + 1} of {NUM_EPISODES}")
|
||||
|
||||
# Main record loop
|
||||
record_loop(
|
||||
robot=follower,
|
||||
events=events,
|
||||
fps=FPS,
|
||||
teleop=leader,
|
||||
dataset=dataset,
|
||||
control_time_s=EPISODE_TIME_SEC,
|
||||
single_task=TASK_DESCRIPTION,
|
||||
display_data=True,
|
||||
teleop_action_processor=leader_joints_to_ee,
|
||||
robot_action_processor=ee_to_follower_joints,
|
||||
robot_observation_processor=follower_joints_to_ee,
|
||||
)
|
||||
|
||||
# Connect the robot and teleoperator
|
||||
leader.connect()
|
||||
follower.connect()
|
||||
|
||||
# Initialize the keyboard listener and rerun visualization
|
||||
listener, events = init_keyboard_listener()
|
||||
init_rerun(session_name="recording_phone")
|
||||
|
||||
if not leader.is_connected or not follower.is_connected:
|
||||
raise ValueError("Robot or teleop is not connected!")
|
||||
|
||||
print("Starting record loop...")
|
||||
episode_idx = 0
|
||||
while episode_idx < NUM_EPISODES and not events["stop_recording"]:
|
||||
log_say(f"Recording episode {episode_idx + 1} of {NUM_EPISODES}")
|
||||
|
||||
# Main record loop
|
||||
# Reset the environment if not stopping or re-recording
|
||||
if not events["stop_recording"] and (episode_idx < NUM_EPISODES - 1 or events["rerecord_episode"]):
|
||||
log_say("Reset the environment")
|
||||
record_loop(
|
||||
robot=follower,
|
||||
events=events,
|
||||
fps=FPS,
|
||||
teleop=leader,
|
||||
dataset=dataset,
|
||||
control_time_s=EPISODE_TIME_SEC,
|
||||
control_time_s=RESET_TIME_SEC,
|
||||
single_task=TASK_DESCRIPTION,
|
||||
display_data=True,
|
||||
teleop_action_processor=leader_joints_to_ee,
|
||||
@@ -171,42 +183,22 @@ def main():
|
||||
robot_observation_processor=follower_joints_to_ee,
|
||||
)
|
||||
|
||||
# Reset the environment if not stopping or re-recording
|
||||
if not events["stop_recording"] and (episode_idx < NUM_EPISODES - 1 or events["rerecord_episode"]):
|
||||
log_say("Reset the environment")
|
||||
record_loop(
|
||||
robot=follower,
|
||||
events=events,
|
||||
fps=FPS,
|
||||
teleop=leader,
|
||||
control_time_s=RESET_TIME_SEC,
|
||||
single_task=TASK_DESCRIPTION,
|
||||
display_data=True,
|
||||
teleop_action_processor=leader_joints_to_ee,
|
||||
robot_action_processor=ee_to_follower_joints,
|
||||
robot_observation_processor=follower_joints_to_ee,
|
||||
)
|
||||
if events["rerecord_episode"]:
|
||||
log_say("Re-recording episode")
|
||||
events["rerecord_episode"] = False
|
||||
events["exit_early"] = False
|
||||
dataset.clear_episode_buffer()
|
||||
continue
|
||||
|
||||
if events["rerecord_episode"]:
|
||||
log_say("Re-recording episode")
|
||||
events["rerecord_episode"] = False
|
||||
events["exit_early"] = False
|
||||
dataset.clear_episode_buffer()
|
||||
continue
|
||||
# Save episode
|
||||
dataset.save_episode()
|
||||
episode_idx += 1
|
||||
|
||||
# Save episode
|
||||
dataset.save_episode()
|
||||
episode_idx += 1
|
||||
# Clean up
|
||||
log_say("Stop recording")
|
||||
leader.disconnect()
|
||||
follower.disconnect()
|
||||
listener.stop()
|
||||
|
||||
# Clean up
|
||||
log_say("Stop recording")
|
||||
leader.disconnect()
|
||||
follower.disconnect()
|
||||
listener.stop()
|
||||
|
||||
dataset.finalize()
|
||||
dataset.push_to_hub()
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
main()
|
||||
dataset.finalize()
|
||||
dataset.push_to_hub()
|
||||
|
||||
@@ -30,78 +30,72 @@ from lerobot.robots.so100_follower.robot_kinematic_processor import (
|
||||
)
|
||||
from lerobot.robots.so100_follower.so100_follower import SO100Follower
|
||||
from lerobot.utils.constants import ACTION
|
||||
from lerobot.utils.robot_utils import precise_sleep
|
||||
from lerobot.utils.robot_utils import busy_wait
|
||||
from lerobot.utils.utils import log_say
|
||||
|
||||
EPISODE_IDX = 0
|
||||
HF_REPO_ID = "<hf_username>/<dataset_repo_id>"
|
||||
|
||||
# Initialize the robot config
|
||||
robot_config = SO100FollowerConfig(
|
||||
port="/dev/tty.usbmodem5A460814411", id="my_awesome_follower_arm", use_degrees=True
|
||||
)
|
||||
|
||||
def main():
|
||||
# Initialize the robot config
|
||||
robot_config = SO100FollowerConfig(
|
||||
port="/dev/tty.usbmodem5A460814411", id="my_awesome_follower_arm", use_degrees=True
|
||||
)
|
||||
# Initialize the robot
|
||||
robot = SO100Follower(robot_config)
|
||||
|
||||
# Initialize the robot
|
||||
robot = SO100Follower(robot_config)
|
||||
# NOTE: It is highly recommended to use the urdf in the SO-ARM100 repo: https://github.com/TheRobotStudio/SO-ARM100/blob/main/Simulation/SO101/so101_new_calib.urdf
|
||||
kinematics_solver = RobotKinematics(
|
||||
urdf_path="./SO101/so101_new_calib.urdf",
|
||||
target_frame_name="gripper_frame_link",
|
||||
joint_names=list(robot.bus.motors.keys()),
|
||||
)
|
||||
|
||||
# NOTE: It is highly recommended to use the urdf in the SO-ARM100 repo: https://github.com/TheRobotStudio/SO-ARM100/blob/main/Simulation/SO101/so101_new_calib.urdf
|
||||
kinematics_solver = RobotKinematics(
|
||||
urdf_path="./SO101/so101_new_calib.urdf",
|
||||
target_frame_name="gripper_frame_link",
|
||||
joint_names=list(robot.bus.motors.keys()),
|
||||
)
|
||||
# Build pipeline to convert EE action to joints action
|
||||
robot_ee_to_joints_processor = RobotProcessorPipeline[tuple[RobotAction, RobotObservation], RobotAction](
|
||||
steps=[
|
||||
InverseKinematicsEEToJoints(
|
||||
kinematics=kinematics_solver,
|
||||
motor_names=list(robot.bus.motors.keys()),
|
||||
initial_guess_current_joints=False, # Because replay is open loop
|
||||
),
|
||||
],
|
||||
to_transition=robot_action_observation_to_transition,
|
||||
to_output=transition_to_robot_action,
|
||||
)
|
||||
|
||||
# Build pipeline to convert EE action to joints action
|
||||
robot_ee_to_joints_processor = RobotProcessorPipeline[tuple[RobotAction, RobotObservation], RobotAction](
|
||||
steps=[
|
||||
InverseKinematicsEEToJoints(
|
||||
kinematics=kinematics_solver,
|
||||
motor_names=list(robot.bus.motors.keys()),
|
||||
initial_guess_current_joints=False, # Because replay is open loop
|
||||
),
|
||||
],
|
||||
to_transition=robot_action_observation_to_transition,
|
||||
to_output=transition_to_robot_action,
|
||||
)
|
||||
# Fetch the dataset to replay
|
||||
dataset = LeRobotDataset(HF_REPO_ID, episodes=[EPISODE_IDX])
|
||||
# Filter dataset to only include frames from the specified episode since episodes are chunked in dataset V3.0
|
||||
episode_frames = dataset.hf_dataset.filter(lambda x: x["episode_index"] == EPISODE_IDX)
|
||||
actions = episode_frames.select_columns(ACTION)
|
||||
|
||||
# Fetch the dataset to replay
|
||||
dataset = LeRobotDataset(HF_REPO_ID, episodes=[EPISODE_IDX])
|
||||
# Filter dataset to only include frames from the specified episode since episodes are chunked in dataset V3.0
|
||||
episode_frames = dataset.hf_dataset.filter(lambda x: x["episode_index"] == EPISODE_IDX)
|
||||
actions = episode_frames.select_columns(ACTION)
|
||||
# Connect to the robot
|
||||
robot.connect()
|
||||
|
||||
# Connect to the robot
|
||||
robot.connect()
|
||||
if not robot.is_connected:
|
||||
raise ValueError("Robot is not connected!")
|
||||
|
||||
if not robot.is_connected:
|
||||
raise ValueError("Robot is not connected!")
|
||||
print("Starting replay loop...")
|
||||
log_say(f"Replaying episode {EPISODE_IDX}")
|
||||
for idx in range(len(episode_frames)):
|
||||
t0 = time.perf_counter()
|
||||
|
||||
print("Starting replay loop...")
|
||||
log_say(f"Replaying episode {EPISODE_IDX}")
|
||||
for idx in range(len(episode_frames)):
|
||||
t0 = time.perf_counter()
|
||||
# Get recorded action from dataset
|
||||
ee_action = {
|
||||
name: float(actions[idx][ACTION][i]) for i, name in enumerate(dataset.features[ACTION]["names"])
|
||||
}
|
||||
|
||||
# Get recorded action from dataset
|
||||
ee_action = {
|
||||
name: float(actions[idx][ACTION][i]) for i, name in enumerate(dataset.features[ACTION]["names"])
|
||||
}
|
||||
# Get robot observation
|
||||
robot_obs = robot.get_observation()
|
||||
|
||||
# Get robot observation
|
||||
robot_obs = robot.get_observation()
|
||||
# Dataset EE -> robot joints
|
||||
joint_action = robot_ee_to_joints_processor((ee_action, robot_obs))
|
||||
|
||||
# Dataset EE -> robot joints
|
||||
joint_action = robot_ee_to_joints_processor((ee_action, robot_obs))
|
||||
# Send action to robot
|
||||
_ = robot.send_action(joint_action)
|
||||
|
||||
# Send action to robot
|
||||
_ = robot.send_action(joint_action)
|
||||
busy_wait(1.0 / dataset.fps - (time.perf_counter() - t0))
|
||||
|
||||
precise_sleep(1.0 / dataset.fps - (time.perf_counter() - t0))
|
||||
|
||||
# Clean up
|
||||
robot.disconnect()
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
main()
|
||||
# Clean up
|
||||
robot.disconnect()
|
||||
|
||||
@@ -32,96 +32,90 @@ from lerobot.robots.so100_follower.robot_kinematic_processor import (
|
||||
from lerobot.robots.so100_follower.so100_follower import SO100Follower
|
||||
from lerobot.teleoperators.so100_leader.config_so100_leader import SO100LeaderConfig
|
||||
from lerobot.teleoperators.so100_leader.so100_leader import SO100Leader
|
||||
from lerobot.utils.robot_utils import precise_sleep
|
||||
from lerobot.utils.robot_utils import busy_wait
|
||||
from lerobot.utils.visualization_utils import init_rerun, log_rerun_data
|
||||
|
||||
FPS = 30
|
||||
|
||||
# Initialize the robot and teleoperator config
|
||||
follower_config = SO100FollowerConfig(
|
||||
port="/dev/tty.usbmodem5A460814411", id="my_awesome_follower_arm", use_degrees=True
|
||||
)
|
||||
leader_config = SO100LeaderConfig(port="/dev/tty.usbmodem5A460819811", id="my_awesome_leader_arm")
|
||||
|
||||
def main():
|
||||
# Initialize the robot and teleoperator config
|
||||
follower_config = SO100FollowerConfig(
|
||||
port="/dev/tty.usbmodem5A460814411", id="my_awesome_follower_arm", use_degrees=True
|
||||
)
|
||||
leader_config = SO100LeaderConfig(port="/dev/tty.usbmodem5A460819811", id="my_awesome_leader_arm")
|
||||
# Initialize the robot and teleoperator
|
||||
follower = SO100Follower(follower_config)
|
||||
leader = SO100Leader(leader_config)
|
||||
|
||||
# Initialize the robot and teleoperator
|
||||
follower = SO100Follower(follower_config)
|
||||
leader = SO100Leader(leader_config)
|
||||
# NOTE: It is highly recommended to use the urdf in the SO-ARM100 repo: https://github.com/TheRobotStudio/SO-ARM100/blob/main/Simulation/SO101/so101_new_calib.urdf
|
||||
follower_kinematics_solver = RobotKinematics(
|
||||
urdf_path="./SO101/so101_new_calib.urdf",
|
||||
target_frame_name="gripper_frame_link",
|
||||
joint_names=list(follower.bus.motors.keys()),
|
||||
)
|
||||
|
||||
# NOTE: It is highly recommended to use the urdf in the SO-ARM100 repo: https://github.com/TheRobotStudio/SO-ARM100/blob/main/Simulation/SO101/so101_new_calib.urdf
|
||||
follower_kinematics_solver = RobotKinematics(
|
||||
urdf_path="./SO101/so101_new_calib.urdf",
|
||||
target_frame_name="gripper_frame_link",
|
||||
joint_names=list(follower.bus.motors.keys()),
|
||||
)
|
||||
# NOTE: It is highly recommended to use the urdf in the SO-ARM100 repo: https://github.com/TheRobotStudio/SO-ARM100/blob/main/Simulation/SO101/so101_new_calib.urdf
|
||||
leader_kinematics_solver = RobotKinematics(
|
||||
urdf_path="./SO101/so101_new_calib.urdf",
|
||||
target_frame_name="gripper_frame_link",
|
||||
joint_names=list(leader.bus.motors.keys()),
|
||||
)
|
||||
|
||||
# NOTE: It is highly recommended to use the urdf in the SO-ARM100 repo: https://github.com/TheRobotStudio/SO-ARM100/blob/main/Simulation/SO101/so101_new_calib.urdf
|
||||
leader_kinematics_solver = RobotKinematics(
|
||||
urdf_path="./SO101/so101_new_calib.urdf",
|
||||
target_frame_name="gripper_frame_link",
|
||||
joint_names=list(leader.bus.motors.keys()),
|
||||
)
|
||||
# Build pipeline to convert teleop joints to EE action
|
||||
leader_to_ee = RobotProcessorPipeline[RobotAction, RobotAction](
|
||||
steps=[
|
||||
ForwardKinematicsJointsToEE(
|
||||
kinematics=leader_kinematics_solver, motor_names=list(leader.bus.motors.keys())
|
||||
),
|
||||
],
|
||||
to_transition=robot_action_to_transition,
|
||||
to_output=transition_to_robot_action,
|
||||
)
|
||||
|
||||
# Build pipeline to convert teleop joints to EE action
|
||||
leader_to_ee = RobotProcessorPipeline[RobotAction, RobotAction](
|
||||
steps=[
|
||||
ForwardKinematicsJointsToEE(
|
||||
kinematics=leader_kinematics_solver, motor_names=list(leader.bus.motors.keys())
|
||||
),
|
||||
],
|
||||
to_transition=robot_action_to_transition,
|
||||
to_output=transition_to_robot_action,
|
||||
)
|
||||
# build pipeline to convert EE action to robot joints
|
||||
ee_to_follower_joints = RobotProcessorPipeline[tuple[RobotAction, RobotObservation], RobotAction](
|
||||
[
|
||||
EEBoundsAndSafety(
|
||||
end_effector_bounds={"min": [-1.0, -1.0, -1.0], "max": [1.0, 1.0, 1.0]},
|
||||
max_ee_step_m=0.10,
|
||||
),
|
||||
InverseKinematicsEEToJoints(
|
||||
kinematics=follower_kinematics_solver,
|
||||
motor_names=list(follower.bus.motors.keys()),
|
||||
initial_guess_current_joints=False,
|
||||
),
|
||||
],
|
||||
to_transition=robot_action_observation_to_transition,
|
||||
to_output=transition_to_robot_action,
|
||||
)
|
||||
|
||||
# build pipeline to convert EE action to robot joints
|
||||
ee_to_follower_joints = RobotProcessorPipeline[tuple[RobotAction, RobotObservation], RobotAction](
|
||||
[
|
||||
EEBoundsAndSafety(
|
||||
end_effector_bounds={"min": [-1.0, -1.0, -1.0], "max": [1.0, 1.0, 1.0]},
|
||||
max_ee_step_m=0.10,
|
||||
),
|
||||
InverseKinematicsEEToJoints(
|
||||
kinematics=follower_kinematics_solver,
|
||||
motor_names=list(follower.bus.motors.keys()),
|
||||
initial_guess_current_joints=False,
|
||||
),
|
||||
],
|
||||
to_transition=robot_action_observation_to_transition,
|
||||
to_output=transition_to_robot_action,
|
||||
)
|
||||
# Connect to the robot and teleoperator
|
||||
follower.connect()
|
||||
leader.connect()
|
||||
|
||||
# Connect to the robot and teleoperator
|
||||
follower.connect()
|
||||
leader.connect()
|
||||
# Init rerun viewer
|
||||
init_rerun(session_name="so100_so100_EE_teleop")
|
||||
|
||||
# Init rerun viewer
|
||||
init_rerun(session_name="so100_so100_EE_teleop")
|
||||
print("Starting teleop loop...")
|
||||
while True:
|
||||
t0 = time.perf_counter()
|
||||
|
||||
print("Starting teleop loop...")
|
||||
while True:
|
||||
t0 = time.perf_counter()
|
||||
# Get robot observation
|
||||
robot_obs = follower.get_observation()
|
||||
|
||||
# Get robot observation
|
||||
robot_obs = follower.get_observation()
|
||||
# Get teleop observation
|
||||
leader_joints_obs = leader.get_action()
|
||||
|
||||
# Get teleop observation
|
||||
leader_joints_obs = leader.get_action()
|
||||
# teleop joints -> teleop EE action
|
||||
leader_ee_act = leader_to_ee(leader_joints_obs)
|
||||
|
||||
# teleop joints -> teleop EE action
|
||||
leader_ee_act = leader_to_ee(leader_joints_obs)
|
||||
# teleop EE -> robot joints
|
||||
follower_joints_act = ee_to_follower_joints((leader_ee_act, robot_obs))
|
||||
|
||||
# teleop EE -> robot joints
|
||||
follower_joints_act = ee_to_follower_joints((leader_ee_act, robot_obs))
|
||||
# Send action to robot
|
||||
_ = follower.send_action(follower_joints_act)
|
||||
|
||||
# Send action to robot
|
||||
_ = follower.send_action(follower_joints_act)
|
||||
# Visualize
|
||||
log_rerun_data(observation=leader_ee_act, action=follower_joints_act)
|
||||
|
||||
# Visualize
|
||||
log_rerun_data(observation=leader_ee_act, action=follower_joints_act)
|
||||
|
||||
precise_sleep(max(1.0 / FPS - (time.perf_counter() - t0), 0.0))
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
main()
|
||||
busy_wait(max(1.0 / FPS - (time.perf_counter() - t0), 0.0))
|
||||
|
||||
@@ -19,86 +19,80 @@ def make_delta_timestamps(delta_indices: list[int] | None, fps: int) -> list[flo
|
||||
return [i / fps for i in delta_indices]
|
||||
|
||||
|
||||
def main():
|
||||
output_directory = Path("outputs/robot_learning_tutorial/act")
|
||||
output_directory.mkdir(parents=True, exist_ok=True)
|
||||
output_directory = Path("outputs/robot_learning_tutorial/act")
|
||||
output_directory.mkdir(parents=True, exist_ok=True)
|
||||
|
||||
# Select your device
|
||||
device = torch.device("mps") # or "cuda" or "cpu"
|
||||
# Select your device
|
||||
device = torch.device("mps") # or "cuda" or "cpu"
|
||||
|
||||
dataset_id = "lerobot/svla_so101_pickplace"
|
||||
dataset_id = "lerobot/svla_so101_pickplace"
|
||||
|
||||
# This specifies the inputs the model will be expecting and the outputs it will produce
|
||||
dataset_metadata = LeRobotDatasetMetadata(dataset_id)
|
||||
features = dataset_to_policy_features(dataset_metadata.features)
|
||||
# This specifies the inputs the model will be expecting and the outputs it will produce
|
||||
dataset_metadata = LeRobotDatasetMetadata(dataset_id)
|
||||
features = dataset_to_policy_features(dataset_metadata.features)
|
||||
|
||||
output_features = {key: ft for key, ft in features.items() if ft.type is FeatureType.ACTION}
|
||||
input_features = {key: ft for key, ft in features.items() if key not in output_features}
|
||||
output_features = {key: ft for key, ft in features.items() if ft.type is FeatureType.ACTION}
|
||||
input_features = {key: ft for key, ft in features.items() if key not in output_features}
|
||||
|
||||
cfg = ACTConfig(input_features=input_features, output_features=output_features)
|
||||
policy = ACTPolicy(cfg)
|
||||
preprocessor, postprocessor = make_pre_post_processors(cfg, dataset_stats=dataset_metadata.stats)
|
||||
cfg = ACTConfig(input_features=input_features, output_features=output_features)
|
||||
policy = ACTPolicy(cfg)
|
||||
preprocessor, postprocessor = make_pre_post_processors(cfg, dataset_stats=dataset_metadata.stats)
|
||||
|
||||
policy.train()
|
||||
policy.to(device)
|
||||
policy.train()
|
||||
policy.to(device)
|
||||
|
||||
# To perform action chunking, ACT expects a given number of actions as targets
|
||||
delta_timestamps = {
|
||||
"action": make_delta_timestamps(cfg.action_delta_indices, dataset_metadata.fps),
|
||||
}
|
||||
# To perform action chunking, ACT expects a given number of actions as targets
|
||||
delta_timestamps = {
|
||||
"action": make_delta_timestamps(cfg.action_delta_indices, dataset_metadata.fps),
|
||||
}
|
||||
|
||||
# add image features if they are present
|
||||
delta_timestamps |= {
|
||||
k: make_delta_timestamps(cfg.observation_delta_indices, dataset_metadata.fps)
|
||||
for k in cfg.image_features
|
||||
}
|
||||
# add image features if they are present
|
||||
delta_timestamps |= {
|
||||
k: make_delta_timestamps(cfg.observation_delta_indices, dataset_metadata.fps) for k in cfg.image_features
|
||||
}
|
||||
|
||||
# Instantiate the dataset
|
||||
dataset = LeRobotDataset(dataset_id, delta_timestamps=delta_timestamps)
|
||||
# Instantiate the dataset
|
||||
dataset = LeRobotDataset(dataset_id, delta_timestamps=delta_timestamps)
|
||||
|
||||
# Create the optimizer and dataloader for offline training
|
||||
optimizer = cfg.get_optimizer_preset().build(policy.parameters())
|
||||
batch_size = 32
|
||||
dataloader = torch.utils.data.DataLoader(
|
||||
dataset,
|
||||
batch_size=batch_size,
|
||||
shuffle=True,
|
||||
pin_memory=device.type != "cpu",
|
||||
drop_last=True,
|
||||
)
|
||||
# Create the optimizer and dataloader for offline training
|
||||
optimizer = cfg.get_optimizer_preset().build(policy.parameters())
|
||||
batch_size = 32
|
||||
dataloader = torch.utils.data.DataLoader(
|
||||
dataset,
|
||||
batch_size=batch_size,
|
||||
shuffle=True,
|
||||
pin_memory=device.type != "cpu",
|
||||
drop_last=True,
|
||||
)
|
||||
|
||||
# Number of training steps and logging frequency
|
||||
training_steps = 1
|
||||
log_freq = 1
|
||||
# Number of training steps and logging frequency
|
||||
training_steps = 1
|
||||
log_freq = 1
|
||||
|
||||
# Run training loop
|
||||
step = 0
|
||||
done = False
|
||||
while not done:
|
||||
for batch in dataloader:
|
||||
batch = preprocessor(batch)
|
||||
loss, _ = policy.forward(batch)
|
||||
loss.backward()
|
||||
optimizer.step()
|
||||
optimizer.zero_grad()
|
||||
# Run training loop
|
||||
step = 0
|
||||
done = False
|
||||
while not done:
|
||||
for batch in dataloader:
|
||||
batch = preprocessor(batch)
|
||||
loss, _ = policy.forward(batch)
|
||||
loss.backward()
|
||||
optimizer.step()
|
||||
optimizer.zero_grad()
|
||||
|
||||
if step % log_freq == 0:
|
||||
print(f"step: {step} loss: {loss.item():.3f}")
|
||||
step += 1
|
||||
if step >= training_steps:
|
||||
done = True
|
||||
break
|
||||
if step % log_freq == 0:
|
||||
print(f"step: {step} loss: {loss.item():.3f}")
|
||||
step += 1
|
||||
if step >= training_steps:
|
||||
done = True
|
||||
break
|
||||
|
||||
# Save the policy checkpoint, alongside the pre/post processors
|
||||
policy.save_pretrained(output_directory)
|
||||
preprocessor.save_pretrained(output_directory)
|
||||
postprocessor.save_pretrained(output_directory)
|
||||
# Save the policy checkpoint, alongside the pre/post processors
|
||||
policy.save_pretrained(output_directory)
|
||||
preprocessor.save_pretrained(output_directory)
|
||||
postprocessor.save_pretrained(output_directory)
|
||||
|
||||
# Save all assets to the Hub
|
||||
policy.push_to_hub("<user>/robot_learning_tutorial_act")
|
||||
preprocessor.push_to_hub("<user>/robot_learning_tutorial_act")
|
||||
postprocessor.push_to_hub("<user>/robot_learning_tutorial_act")
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
main()
|
||||
# Save all assets to the Hub
|
||||
policy.push_to_hub("fracapuano/robot_learning_tutorial_act")
|
||||
preprocessor.push_to_hub("fracapuano/robot_learning_tutorial_act")
|
||||
postprocessor.push_to_hub("fracapuano/robot_learning_tutorial_act")
|
||||
|
||||
@@ -8,56 +8,50 @@ from lerobot.policies.utils import build_inference_frame, make_robot_action
|
||||
from lerobot.robots.so100_follower.config_so100_follower import SO100FollowerConfig
|
||||
from lerobot.robots.so100_follower.so100_follower import SO100Follower
|
||||
|
||||
device = torch.device("mps") # or "cuda" or "cpu"
|
||||
model_id = "fracapuano/robot_learning_tutorial_act"
|
||||
model = ACTPolicy.from_pretrained(model_id)
|
||||
|
||||
dataset_id = "lerobot/svla_so101_pickplace"
|
||||
# This only downloads the metadata for the dataset, ~10s of MB even for large-scale datasets
|
||||
dataset_metadata = LeRobotDatasetMetadata(dataset_id)
|
||||
preprocess, postprocess = make_pre_post_processors(model.config, dataset_stats=dataset_metadata.stats)
|
||||
|
||||
# # find ports using lerobot-find-port
|
||||
follower_port = ... # something like "/dev/tty.usbmodem58760431631"
|
||||
|
||||
# # the robot ids are used the load the right calibration files
|
||||
follower_id = ... # something like "follower_so100"
|
||||
|
||||
MAX_EPISODES = 5
|
||||
MAX_STEPS_PER_EPISODE = 20
|
||||
|
||||
# Robot and environment configuration
|
||||
# Camera keys must match the name and resolutions of the ones used for training!
|
||||
# You can check the camera keys expected by a model in the info.json card on the model card on the Hub
|
||||
camera_config = {
|
||||
"side": OpenCVCameraConfig(index_or_path=0, width=640, height=480, fps=30),
|
||||
"up": OpenCVCameraConfig(index_or_path=1, width=640, height=480, fps=30),
|
||||
}
|
||||
|
||||
def main():
|
||||
device = torch.device("mps") # or "cuda" or "cpu"
|
||||
model_id = "<user>/robot_learning_tutorial_act"
|
||||
model = ACTPolicy.from_pretrained(model_id)
|
||||
robot_cfg = SO100FollowerConfig(port=follower_port, id=follower_id, cameras=camera_config)
|
||||
robot = SO100Follower(robot_cfg)
|
||||
robot.connect()
|
||||
|
||||
dataset_id = "lerobot/svla_so101_pickplace"
|
||||
# This only downloads the metadata for the dataset, ~10s of MB even for large-scale datasets
|
||||
dataset_metadata = LeRobotDatasetMetadata(dataset_id)
|
||||
preprocess, postprocess = make_pre_post_processors(model.config, dataset_stats=dataset_metadata.stats)
|
||||
for _ in range(MAX_EPISODES):
|
||||
for _ in range(MAX_STEPS_PER_EPISODE):
|
||||
obs = robot.get_observation()
|
||||
obs_frame = build_inference_frame(
|
||||
observation=obs, ds_features=dataset_metadata.features, device=device
|
||||
)
|
||||
|
||||
# # find ports using lerobot-find-port
|
||||
follower_port = ... # something like "/dev/tty.usbmodem58760431631"
|
||||
obs = preprocess(obs_frame)
|
||||
|
||||
# # the robot ids are used the load the right calibration files
|
||||
follower_id = ... # something like "follower_so100"
|
||||
action = model.select_action(obs)
|
||||
action = postprocess(action)
|
||||
|
||||
# Robot and environment configuration
|
||||
# Camera keys must match the name and resolutions of the ones used for training!
|
||||
# You can check the camera keys expected by a model in the info.json card on the model card on the Hub
|
||||
camera_config = {
|
||||
"side": OpenCVCameraConfig(index_or_path=0, width=640, height=480, fps=30),
|
||||
"up": OpenCVCameraConfig(index_or_path=1, width=640, height=480, fps=30),
|
||||
}
|
||||
action = make_robot_action(action, dataset_metadata.features)
|
||||
|
||||
robot_cfg = SO100FollowerConfig(port=follower_port, id=follower_id, cameras=camera_config)
|
||||
robot = SO100Follower(robot_cfg)
|
||||
robot.connect()
|
||||
robot.send_action(action)
|
||||
|
||||
for _ in range(MAX_EPISODES):
|
||||
for _ in range(MAX_STEPS_PER_EPISODE):
|
||||
obs = robot.get_observation()
|
||||
obs_frame = build_inference_frame(
|
||||
observation=obs, ds_features=dataset_metadata.features, device=device
|
||||
)
|
||||
|
||||
obs = preprocess(obs_frame)
|
||||
|
||||
action = model.select_action(obs)
|
||||
action = postprocess(action)
|
||||
|
||||
action = make_robot_action(action, dataset_metadata.features)
|
||||
|
||||
robot.send_action(action)
|
||||
|
||||
print("Episode finished! Starting new episode...")
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
main()
|
||||
print("Episode finished! Starting new episode...")
|
||||
|
||||
@@ -1,17 +1,11 @@
|
||||
from lerobot.async_inference.configs import PolicyServerConfig
|
||||
from lerobot.async_inference.policy_server import serve
|
||||
|
||||
host = ... # something like "127.0.0.1" if you're exposing to localhost
|
||||
port = ... # something like 8080
|
||||
|
||||
def main():
|
||||
host = ... # something like "127.0.0.1" if you're exposing to localhost
|
||||
port = ... # something like 8080
|
||||
|
||||
config = PolicyServerConfig(
|
||||
host=host,
|
||||
port=port,
|
||||
)
|
||||
serve(config)
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
main()
|
||||
config = PolicyServerConfig(
|
||||
host=host,
|
||||
port=port,
|
||||
)
|
||||
serve(config)
|
||||
|
||||
@@ -6,56 +6,50 @@ from lerobot.async_inference.robot_client import RobotClient
|
||||
from lerobot.cameras.opencv.configuration_opencv import OpenCVCameraConfig
|
||||
from lerobot.robots.so100_follower import SO100FollowerConfig
|
||||
|
||||
# these cameras must match the ones expected by the policy - find your cameras with lerobot-find-cameras
|
||||
# check the config.json on the Hub for the policy you are using to see the expected camera specs
|
||||
camera_cfg = {
|
||||
"up": OpenCVCameraConfig(index_or_path=0, width=640, height=480, fps=30),
|
||||
"side": OpenCVCameraConfig(index_or_path=1, width=640, height=480, fps=30),
|
||||
}
|
||||
|
||||
def main():
|
||||
# these cameras must match the ones expected by the policy - find your cameras with lerobot-find-cameras
|
||||
# check the config.json on the Hub for the policy you are using to see the expected camera specs
|
||||
camera_cfg = {
|
||||
"up": OpenCVCameraConfig(index_or_path=0, width=640, height=480, fps=30),
|
||||
"side": OpenCVCameraConfig(index_or_path=1, width=640, height=480, fps=30),
|
||||
}
|
||||
# # find ports using lerobot-find-port
|
||||
follower_port = ... # something like "/dev/tty.usbmodem58760431631"
|
||||
|
||||
# # find ports using lerobot-find-port
|
||||
follower_port = ... # something like "/dev/tty.usbmodem58760431631"
|
||||
# # the robot ids are used the load the right calibration files
|
||||
follower_id = ... # something like "follower_so100"
|
||||
|
||||
# # the robot ids are used the load the right calibration files
|
||||
follower_id = ... # something like "follower_so100"
|
||||
robot_cfg = SO100FollowerConfig(port=follower_port, id=follower_id, cameras=camera_cfg)
|
||||
|
||||
robot_cfg = SO100FollowerConfig(port=follower_port, id=follower_id, cameras=camera_cfg)
|
||||
server_address = ... # something like "127.0.0.1:8080" if using localhost
|
||||
|
||||
server_address = ... # something like "127.0.0.1:8080" if using localhost
|
||||
# 3. Create client configuration
|
||||
client_cfg = RobotClientConfig(
|
||||
robot=robot_cfg,
|
||||
server_address=server_address,
|
||||
policy_device="mps",
|
||||
policy_type="act",
|
||||
pretrained_name_or_path="fracapuano/robot_learning_tutorial_act",
|
||||
chunk_size_threshold=0.5, # g
|
||||
actions_per_chunk=50, # make sure this is less than the max actions of the policy
|
||||
)
|
||||
|
||||
# 3. Create client configuration
|
||||
client_cfg = RobotClientConfig(
|
||||
robot=robot_cfg,
|
||||
server_address=server_address,
|
||||
policy_device="mps",
|
||||
policy_type="act",
|
||||
pretrained_name_or_path="<user>/robot_learning_tutorial_act",
|
||||
chunk_size_threshold=0.5, # g
|
||||
actions_per_chunk=50, # make sure this is less than the max actions of the policy
|
||||
)
|
||||
# 4. Create and start client
|
||||
client = RobotClient(client_cfg)
|
||||
|
||||
# 4. Create and start client
|
||||
client = RobotClient(client_cfg)
|
||||
# 5. Provide a textual description of the task
|
||||
task = ...
|
||||
|
||||
# 5. Provide a textual description of the task
|
||||
task = ...
|
||||
if client.start():
|
||||
# Start action receiver thread
|
||||
action_receiver_thread = threading.Thread(target=client.receive_actions, daemon=True)
|
||||
action_receiver_thread.start()
|
||||
|
||||
if client.start():
|
||||
# Start action receiver thread
|
||||
action_receiver_thread = threading.Thread(target=client.receive_actions, daemon=True)
|
||||
action_receiver_thread.start()
|
||||
|
||||
try:
|
||||
# Run the control loop
|
||||
client.control_loop(task)
|
||||
except KeyboardInterrupt:
|
||||
client.stop()
|
||||
action_receiver_thread.join()
|
||||
# (Optionally) plot the action queue size
|
||||
visualize_action_queue_size(client.action_queue_size)
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
main()
|
||||
try:
|
||||
# Run the control loop
|
||||
client.control_loop(task)
|
||||
except KeyboardInterrupt:
|
||||
client.stop()
|
||||
action_receiver_thread.join()
|
||||
# (Optionally) plot the action queue size
|
||||
visualize_action_queue_size(client.action_queue_size)
|
||||
|
||||
@@ -19,87 +19,81 @@ def make_delta_timestamps(delta_indices: list[int] | None, fps: int) -> list[flo
|
||||
return [i / fps for i in delta_indices]
|
||||
|
||||
|
||||
def main():
|
||||
output_directory = Path("outputs/robot_learning_tutorial/diffusion")
|
||||
output_directory.mkdir(parents=True, exist_ok=True)
|
||||
output_directory = Path("outputs/robot_learning_tutorial/diffusion")
|
||||
output_directory.mkdir(parents=True, exist_ok=True)
|
||||
|
||||
# Select your device
|
||||
device = torch.device("mps") # or "cuda" or "cpu"
|
||||
# Select your device
|
||||
device = torch.device("mps") # or "cuda" or "cpu"
|
||||
|
||||
dataset_id = "lerobot/svla_so101_pickplace"
|
||||
dataset_id = "lerobot/svla_so101_pickplace"
|
||||
|
||||
# This specifies the inputs the model will be expecting and the outputs it will produce
|
||||
dataset_metadata = LeRobotDatasetMetadata(dataset_id)
|
||||
features = dataset_to_policy_features(dataset_metadata.features)
|
||||
# This specifies the inputs the model will be expecting and the outputs it will produce
|
||||
dataset_metadata = LeRobotDatasetMetadata(dataset_id)
|
||||
features = dataset_to_policy_features(dataset_metadata.features)
|
||||
|
||||
output_features = {key: ft for key, ft in features.items() if ft.type is FeatureType.ACTION}
|
||||
input_features = {key: ft for key, ft in features.items() if key not in output_features}
|
||||
output_features = {key: ft for key, ft in features.items() if ft.type is FeatureType.ACTION}
|
||||
input_features = {key: ft for key, ft in features.items() if key not in output_features}
|
||||
|
||||
cfg = DiffusionConfig(input_features=input_features, output_features=output_features)
|
||||
policy = DiffusionPolicy(cfg)
|
||||
preprocessor, postprocessor = make_pre_post_processors(cfg, dataset_stats=dataset_metadata.stats)
|
||||
cfg = DiffusionConfig(input_features=input_features, output_features=output_features)
|
||||
policy = DiffusionPolicy(cfg)
|
||||
preprocessor, postprocessor = make_pre_post_processors(cfg, dataset_stats=dataset_metadata.stats)
|
||||
|
||||
policy.train()
|
||||
policy.to(device)
|
||||
policy.train()
|
||||
policy.to(device)
|
||||
|
||||
# To perform action chunking, ACT expects a given number of actions as targets
|
||||
delta_timestamps = {
|
||||
"observation.state": make_delta_timestamps(cfg.observation_delta_indices, dataset_metadata.fps),
|
||||
"action": make_delta_timestamps(cfg.action_delta_indices, dataset_metadata.fps),
|
||||
}
|
||||
# To perform action chunking, ACT expects a given number of actions as targets
|
||||
delta_timestamps = {
|
||||
"observation.state": make_delta_timestamps(cfg.observation_delta_indices, dataset_metadata.fps),
|
||||
"action": make_delta_timestamps(cfg.action_delta_indices, dataset_metadata.fps),
|
||||
}
|
||||
|
||||
# add image features if they are present
|
||||
delta_timestamps |= {
|
||||
k: make_delta_timestamps(cfg.observation_delta_indices, dataset_metadata.fps)
|
||||
for k in cfg.image_features
|
||||
}
|
||||
# add image features if they are present
|
||||
delta_timestamps |= {
|
||||
k: make_delta_timestamps(cfg.observation_delta_indices, dataset_metadata.fps) for k in cfg.image_features
|
||||
}
|
||||
|
||||
# Instantiate the dataset
|
||||
dataset = LeRobotDataset(dataset_id, delta_timestamps=delta_timestamps)
|
||||
# Instantiate the dataset
|
||||
dataset = LeRobotDataset(dataset_id, delta_timestamps=delta_timestamps)
|
||||
|
||||
# Create the optimizer and dataloader for offline training
|
||||
optimizer = cfg.get_optimizer_preset().build(policy.parameters())
|
||||
batch_size = 32
|
||||
dataloader = torch.utils.data.DataLoader(
|
||||
dataset,
|
||||
batch_size=batch_size,
|
||||
shuffle=True,
|
||||
pin_memory=device.type != "cpu",
|
||||
drop_last=True,
|
||||
)
|
||||
# Create the optimizer and dataloader for offline training
|
||||
optimizer = cfg.get_optimizer_preset().build(policy.parameters())
|
||||
batch_size = 32
|
||||
dataloader = torch.utils.data.DataLoader(
|
||||
dataset,
|
||||
batch_size=batch_size,
|
||||
shuffle=True,
|
||||
pin_memory=device.type != "cpu",
|
||||
drop_last=True,
|
||||
)
|
||||
|
||||
# Number of training steps and logging frequency
|
||||
training_steps = 1
|
||||
log_freq = 1
|
||||
# Number of training steps and logging frequency
|
||||
training_steps = 1
|
||||
log_freq = 1
|
||||
|
||||
# Run training loop
|
||||
step = 0
|
||||
done = False
|
||||
while not done:
|
||||
for batch in dataloader:
|
||||
batch = preprocessor(batch)
|
||||
loss, _ = policy.forward(batch)
|
||||
loss.backward()
|
||||
optimizer.step()
|
||||
optimizer.zero_grad()
|
||||
# Run training loop
|
||||
step = 0
|
||||
done = False
|
||||
while not done:
|
||||
for batch in dataloader:
|
||||
batch = preprocessor(batch)
|
||||
loss, _ = policy.forward(batch)
|
||||
loss.backward()
|
||||
optimizer.step()
|
||||
optimizer.zero_grad()
|
||||
|
||||
if step % log_freq == 0:
|
||||
print(f"step: {step} loss: {loss.item():.3f}")
|
||||
step += 1
|
||||
if step >= training_steps:
|
||||
done = True
|
||||
break
|
||||
if step % log_freq == 0:
|
||||
print(f"step: {step} loss: {loss.item():.3f}")
|
||||
step += 1
|
||||
if step >= training_steps:
|
||||
done = True
|
||||
break
|
||||
|
||||
# Save the policy checkpoint, alongside the pre/post processors
|
||||
policy.save_pretrained(output_directory)
|
||||
preprocessor.save_pretrained(output_directory)
|
||||
postprocessor.save_pretrained(output_directory)
|
||||
# Save the policy checkpoint, alongside the pre/post processors
|
||||
policy.save_pretrained(output_directory)
|
||||
preprocessor.save_pretrained(output_directory)
|
||||
postprocessor.save_pretrained(output_directory)
|
||||
|
||||
# Save all assets to the Hub
|
||||
policy.push_to_hub("<user>/robot_learning_tutorial_diffusion")
|
||||
preprocessor.push_to_hub("<user>/robot_learning_tutorial_diffusion")
|
||||
postprocessor.push_to_hub("<user>/robot_learning_tutorial_diffusion")
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
main()
|
||||
# Save all assets to the Hub
|
||||
policy.push_to_hub("fracapuano/robot_learning_tutorial_diffusion")
|
||||
preprocessor.push_to_hub("fracapuano/robot_learning_tutorial_diffusion")
|
||||
postprocessor.push_to_hub("fracapuano/robot_learning_tutorial_diffusion")
|
||||
|
||||
@@ -8,57 +8,53 @@ from lerobot.policies.utils import build_inference_frame, make_robot_action
|
||||
from lerobot.robots.so100_follower.config_so100_follower import SO100FollowerConfig
|
||||
from lerobot.robots.so100_follower.so100_follower import SO100Follower
|
||||
|
||||
device = torch.device("mps") # or "cuda" or "cpu"
|
||||
model_id = "fracapuano/robot_learning_tutorial_diffusion"
|
||||
|
||||
model = DiffusionPolicy.from_pretrained(model_id)
|
||||
|
||||
dataset_id = "lerobot/svla_so101_pickplace"
|
||||
# This only downloads the metadata for the dataset, ~10s of MB even for large-scale datasets
|
||||
dataset_metadata = LeRobotDatasetMetadata(dataset_id)
|
||||
preprocess, postprocess = make_pre_post_processors(
|
||||
model.config, model_id, dataset_stats=dataset_metadata.stats
|
||||
)
|
||||
|
||||
MAX_EPISODES = 5
|
||||
MAX_STEPS_PER_EPISODE = 20
|
||||
|
||||
|
||||
def main():
|
||||
device = torch.device("mps") # or "cuda" or "cpu"
|
||||
model_id = "<user>/robot_learning_tutorial_diffusion"
|
||||
# # find ports using lerobot-find-port
|
||||
follower_port = ... # something like "/dev/tty.usbmodem58760431631"
|
||||
|
||||
model = DiffusionPolicy.from_pretrained(model_id)
|
||||
# # the robot ids are used the load the right calibration files
|
||||
follower_id = ... # something like "follower_so100"
|
||||
|
||||
dataset_id = "lerobot/svla_so101_pickplace"
|
||||
# This only downloads the metadata for the dataset, ~10s of MB even for large-scale datasets
|
||||
dataset_metadata = LeRobotDatasetMetadata(dataset_id)
|
||||
preprocess, postprocess = make_pre_post_processors(
|
||||
model.config, model_id, dataset_stats=dataset_metadata.stats
|
||||
)
|
||||
# Robot and environment configuration
|
||||
# Camera keys must match the name and resolutions of the ones used for training!
|
||||
# You can check the camera keys expected by a model in the info.json card on the model card on the Hub
|
||||
camera_config = {
|
||||
"side": OpenCVCameraConfig(index_or_path=0, width=640, height=480, fps=30),
|
||||
"up": OpenCVCameraConfig(index_or_path=1, width=640, height=480, fps=30),
|
||||
}
|
||||
|
||||
# # find ports using lerobot-find-port
|
||||
follower_port = ... # something like "/dev/tty.usbmodem58760431631"
|
||||
|
||||
# # the robot ids are used the load the right calibration files
|
||||
follower_id = ... # something like "follower_so100"
|
||||
|
||||
# Robot and environment configuration
|
||||
# Camera keys must match the name and resolutions of the ones used for training!
|
||||
# You can check the camera keys expected by a model in the info.json card on the model card on the Hub
|
||||
camera_config = {
|
||||
"side": OpenCVCameraConfig(index_or_path=0, width=640, height=480, fps=30),
|
||||
"up": OpenCVCameraConfig(index_or_path=1, width=640, height=480, fps=30),
|
||||
}
|
||||
|
||||
robot_cfg = SO100FollowerConfig(port=follower_port, id=follower_id, cameras=camera_config)
|
||||
robot = SO100Follower(robot_cfg)
|
||||
robot.connect()
|
||||
|
||||
for _ in range(MAX_EPISODES):
|
||||
for _ in range(MAX_STEPS_PER_EPISODE):
|
||||
obs = robot.get_observation()
|
||||
obs_frame = build_inference_frame(
|
||||
observation=obs, ds_features=dataset_metadata.features, device=device
|
||||
)
|
||||
|
||||
obs = preprocess(obs_frame)
|
||||
|
||||
action = model.select_action(obs)
|
||||
action = postprocess(action)
|
||||
action = make_robot_action(action, dataset_metadata.features)
|
||||
robot.send_action(action)
|
||||
|
||||
print("Episode finished! Starting new episode...")
|
||||
robot_cfg = SO100FollowerConfig(port=follower_port, id=follower_id, cameras=camera_config)
|
||||
robot = SO100Follower(robot_cfg)
|
||||
robot.connect()
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
main()
|
||||
for _ in range(MAX_EPISODES):
|
||||
for _ in range(MAX_STEPS_PER_EPISODE):
|
||||
obs = robot.get_observation()
|
||||
obs_frame = build_inference_frame(
|
||||
observation=obs, ds_features=dataset_metadata.features, device=device
|
||||
)
|
||||
|
||||
obs = preprocess(obs_frame)
|
||||
|
||||
action = model.select_action(obs)
|
||||
action = postprocess(action)
|
||||
action = make_robot_action(action, dataset_metadata.features)
|
||||
robot.send_action(action)
|
||||
|
||||
print("Episode finished! Starting new episode...")
|
||||
|
||||
@@ -11,63 +11,57 @@ from lerobot.robots.so100_follower.so100_follower import SO100Follower
|
||||
MAX_EPISODES = 5
|
||||
MAX_STEPS_PER_EPISODE = 20
|
||||
|
||||
device = torch.device("mps") # or "cuda" or "cpu"
|
||||
model_id = "lerobot/pi0_base"
|
||||
|
||||
def main():
|
||||
device = torch.device("mps") # or "cuda" or "cpu"
|
||||
model_id = "lerobot/pi0_base"
|
||||
model = PI0Policy.from_pretrained(model_id)
|
||||
|
||||
model = PI0Policy.from_pretrained(model_id)
|
||||
preprocess, postprocess = make_pre_post_processors(
|
||||
model.config,
|
||||
model_id,
|
||||
# This overrides allows to run on MPS, otherwise defaults to CUDA (if available)
|
||||
preprocessor_overrides={"device_processor": {"device": str(device)}},
|
||||
)
|
||||
|
||||
preprocess, postprocess = make_pre_post_processors(
|
||||
model.config,
|
||||
model_id,
|
||||
# This overrides allows to run on MPS, otherwise defaults to CUDA (if available)
|
||||
preprocessor_overrides={"device_processor": {"device": str(device)}},
|
||||
)
|
||||
# find ports using lerobot-find-port
|
||||
follower_port = ... # something like "/dev/tty.usbmodem58760431631"
|
||||
|
||||
# find ports using lerobot-find-port
|
||||
follower_port = ... # something like "/dev/tty.usbmodem58760431631"
|
||||
# the robot ids are used the load the right calibration files
|
||||
follower_id = ... # something like "follower_so100"
|
||||
|
||||
# the robot ids are used the load the right calibration files
|
||||
follower_id = ... # something like "follower_so100"
|
||||
# Robot and environment configuration
|
||||
# Camera keys must match the name and resolutions of the ones used for training!
|
||||
# You can check the camera keys expected by a model in the info.json card on the model card on the Hub
|
||||
camera_config = {
|
||||
"base_0_rgb": OpenCVCameraConfig(index_or_path=0, width=640, height=480, fps=30),
|
||||
"left_wrist_0_rgb": OpenCVCameraConfig(index_or_path=1, width=640, height=480, fps=30),
|
||||
"right_wrist_0_rgb": OpenCVCameraConfig(index_or_path=2, width=640, height=480, fps=30),
|
||||
}
|
||||
|
||||
# Robot and environment configuration
|
||||
# Camera keys must match the name and resolutions of the ones used for training!
|
||||
# You can check the camera keys expected by a model in the info.json card on the model card on the Hub
|
||||
camera_config = {
|
||||
"base_0_rgb": OpenCVCameraConfig(index_or_path=0, width=640, height=480, fps=30),
|
||||
"left_wrist_0_rgb": OpenCVCameraConfig(index_or_path=1, width=640, height=480, fps=30),
|
||||
"right_wrist_0_rgb": OpenCVCameraConfig(index_or_path=2, width=640, height=480, fps=30),
|
||||
}
|
||||
robot_cfg = SO100FollowerConfig(port=follower_port, id=follower_id, cameras=camera_config)
|
||||
robot = SO100Follower(robot_cfg)
|
||||
robot.connect()
|
||||
|
||||
robot_cfg = SO100FollowerConfig(port=follower_port, id=follower_id, cameras=camera_config)
|
||||
robot = SO100Follower(robot_cfg)
|
||||
robot.connect()
|
||||
task = "" # something like "pick the red block"
|
||||
robot_type = "" # something like "so100_follower" for multi-embodiment datasets
|
||||
|
||||
task = "" # something like "pick the red block"
|
||||
robot_type = "" # something like "so100_follower" for multi-embodiment datasets
|
||||
# This is used to match the raw observation keys to the keys expected by the policy
|
||||
action_features = hw_to_dataset_features(robot.action_features, "action")
|
||||
obs_features = hw_to_dataset_features(robot.observation_features, "observation")
|
||||
dataset_features = {**action_features, **obs_features}
|
||||
|
||||
# This is used to match the raw observation keys to the keys expected by the policy
|
||||
action_features = hw_to_dataset_features(robot.action_features, "action")
|
||||
obs_features = hw_to_dataset_features(robot.observation_features, "observation")
|
||||
dataset_features = {**action_features, **obs_features}
|
||||
for _ in range(MAX_EPISODES):
|
||||
for _ in range(MAX_STEPS_PER_EPISODE):
|
||||
obs = robot.get_observation()
|
||||
obs_frame = build_inference_frame(
|
||||
observation=obs, ds_features=dataset_features, device=device, task=task, robot_type=robot_type
|
||||
)
|
||||
|
||||
for _ in range(MAX_EPISODES):
|
||||
for _ in range(MAX_STEPS_PER_EPISODE):
|
||||
obs = robot.get_observation()
|
||||
obs_frame = build_inference_frame(
|
||||
observation=obs, ds_features=dataset_features, device=device, task=task, robot_type=robot_type
|
||||
)
|
||||
obs = preprocess(obs_frame)
|
||||
|
||||
obs = preprocess(obs_frame)
|
||||
action = model.select_action(obs)
|
||||
action = postprocess(action)
|
||||
action = make_robot_action(action, dataset_features)
|
||||
robot.send_action(action)
|
||||
|
||||
action = model.select_action(obs)
|
||||
action = postprocess(action)
|
||||
action = make_robot_action(action, dataset_features)
|
||||
robot.send_action(action)
|
||||
|
||||
print("Episode finished! Starting new episode...")
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
main()
|
||||
print("Episode finished! Starting new episode...")
|
||||
|
||||
@@ -20,8 +20,6 @@ from lerobot.teleoperators.utils import TeleopEvents
|
||||
|
||||
LOG_EVERY = 10
|
||||
SEND_EVERY = 10
|
||||
MAX_EPISODES = 5
|
||||
MAX_STEPS_PER_EPISODE = 20
|
||||
|
||||
|
||||
def run_learner(
|
||||
@@ -225,123 +223,123 @@ def make_policy_obs(obs, device: torch.device = "cpu"):
|
||||
}
|
||||
|
||||
|
||||
def main():
|
||||
"""Main function - coordinates actor and learner processes."""
|
||||
"""Main function - coordinates actor and learner processes."""
|
||||
|
||||
device = "mps" # or "cuda" or "cpu"
|
||||
output_directory = Path("outputs/robot_learning_tutorial/hil_serl")
|
||||
output_directory.mkdir(parents=True, exist_ok=True)
|
||||
device = "mps" # or "cuda" or "cpu"
|
||||
output_directory = Path("outputs/robot_learning_tutorial/hil_serl")
|
||||
output_directory.mkdir(parents=True, exist_ok=True)
|
||||
|
||||
# find ports using lerobot-find-port
|
||||
follower_port = ...
|
||||
leader_port = ...
|
||||
# find ports using lerobot-find-port
|
||||
follower_port = ...
|
||||
leader_port = ...
|
||||
|
||||
# the robot ids are used the load the right calibration files
|
||||
follower_id = ...
|
||||
leader_id = ...
|
||||
# the robot ids are used the load the right calibration files
|
||||
follower_id = ...
|
||||
leader_id = ...
|
||||
|
||||
# A pretrained model (to be used in-distribution!)
|
||||
reward_classifier_id = "<user>/reward_classifier_hil_serl_example"
|
||||
reward_classifier = Classifier.from_pretrained(reward_classifier_id)
|
||||
# A pretrained model (to be used in-distribution!)
|
||||
reward_classifier_id = "fracapuano/reward_classifier_hil_serl_example"
|
||||
reward_classifier = Classifier.from_pretrained(reward_classifier_id)
|
||||
|
||||
reward_classifier.to(device)
|
||||
reward_classifier.eval()
|
||||
reward_classifier.to(device)
|
||||
reward_classifier.eval()
|
||||
|
||||
# Robot and environment configuration
|
||||
robot_cfg = SO100FollowerConfig(port=follower_port, id=follower_id)
|
||||
teleop_cfg = SO100LeaderConfig(port=leader_port, id=leader_id)
|
||||
processor_cfg = HILSerlProcessorConfig(control_mode="leader")
|
||||
MAX_EPISODES = 5
|
||||
MAX_STEPS_PER_EPISODE = 20
|
||||
|
||||
env_cfg = HILSerlRobotEnvConfig(robot=robot_cfg, teleop=teleop_cfg, processor=processor_cfg)
|
||||
# Robot and environment configuration
|
||||
robot_cfg = SO100FollowerConfig(port=follower_port, id=follower_id)
|
||||
teleop_cfg = SO100LeaderConfig(port=leader_port, id=leader_id)
|
||||
processor_cfg = HILSerlProcessorConfig(control_mode="leader")
|
||||
|
||||
# Create robot environment
|
||||
env, teleop_device = make_robot_env(env_cfg)
|
||||
env_cfg = HILSerlRobotEnvConfig(robot=robot_cfg, teleop=teleop_cfg, processor=processor_cfg)
|
||||
|
||||
obs_features = hw_to_dataset_features(env.robot.observation_features, "observation")
|
||||
action_features = hw_to_dataset_features(env.robot.action_features, "action")
|
||||
# Create robot environment
|
||||
env, teleop_device = make_robot_env(env_cfg)
|
||||
|
||||
# Create SAC policy for action selection
|
||||
policy_cfg = SACConfig(
|
||||
device=device,
|
||||
input_features=obs_features,
|
||||
output_features=action_features,
|
||||
)
|
||||
obs_features = hw_to_dataset_features(env.robot.observation_features, "observation")
|
||||
action_features = hw_to_dataset_features(env.robot.action_features, "action")
|
||||
|
||||
policy_actor = SACPolicy(policy_cfg)
|
||||
policy_learner = SACPolicy(policy_cfg)
|
||||
# Create SAC policy for action selection
|
||||
policy_cfg = SACConfig(
|
||||
device=device,
|
||||
input_features=obs_features,
|
||||
output_features=action_features,
|
||||
)
|
||||
|
||||
demonstrations_repo_id = "lerobot/example_hil_serl_dataset"
|
||||
offline_dataset = LeRobotDataset(repo_id=demonstrations_repo_id)
|
||||
policy_actor = SACPolicy(policy_cfg)
|
||||
policy_learner = SACPolicy(policy_cfg)
|
||||
|
||||
# Online buffer: initialized from scratch
|
||||
online_replay_buffer = ReplayBuffer(device=device, state_keys=list(obs_features.keys()))
|
||||
# Offline buffer: Created from dataset (pre-populated it with demonstrations)
|
||||
offline_replay_buffer = ReplayBuffer.from_lerobot_dataset(
|
||||
lerobot_dataset=offline_dataset, device=device, state_keys=list(obs_features.keys())
|
||||
)
|
||||
demonstrations_repo_id = "lerobot/example_hil_serl_dataset"
|
||||
offline_dataset = LeRobotDataset(repo_id=demonstrations_repo_id)
|
||||
|
||||
# Create communication channels between learner and actor processes
|
||||
transitions_queue = mp.Queue(maxsize=10)
|
||||
parameters_queue = mp.Queue(maxsize=2)
|
||||
shutdown_event = mp.Event()
|
||||
# Online buffer: initialized from scratch
|
||||
online_replay_buffer = ReplayBuffer(device=device, state_keys=list(obs_features.keys()))
|
||||
# Offline buffer: Created from dataset (pre-populated it with demonstrations)
|
||||
offline_replay_buffer = ReplayBuffer.from_lerobot_dataset(
|
||||
lerobot_dataset=offline_dataset, device=device, state_keys=list(obs_features.keys())
|
||||
)
|
||||
|
||||
# Signal handler for graceful shutdown
|
||||
def signal_handler(sig):
|
||||
print(f"\nSignal {sig} received, shutting down...")
|
||||
shutdown_event.set()
|
||||
|
||||
signal.signal(signal.SIGINT, signal_handler)
|
||||
signal.signal(signal.SIGTERM, signal_handler)
|
||||
|
||||
# Create processes
|
||||
learner_process = mp.Process(
|
||||
target=run_learner,
|
||||
args=(
|
||||
transitions_queue,
|
||||
parameters_queue,
|
||||
shutdown_event,
|
||||
policy_learner,
|
||||
online_replay_buffer,
|
||||
offline_replay_buffer,
|
||||
),
|
||||
kwargs={"device": device}, # can run on accelerated hardware for training
|
||||
)
|
||||
|
||||
actor_process = mp.Process(
|
||||
target=run_actor,
|
||||
args=(
|
||||
transitions_queue,
|
||||
parameters_queue,
|
||||
shutdown_event,
|
||||
policy_actor,
|
||||
reward_classifier,
|
||||
env_cfg,
|
||||
output_directory,
|
||||
),
|
||||
kwargs={"device": "cpu"}, # actor is frozen, can run on CPU or accelerate for inference
|
||||
)
|
||||
|
||||
learner_process.start()
|
||||
actor_process.start()
|
||||
|
||||
try:
|
||||
# Wait for actor to finish (it controls the episode loop)
|
||||
actor_process.join()
|
||||
shutdown_event.set()
|
||||
learner_process.join(timeout=10)
|
||||
|
||||
except KeyboardInterrupt:
|
||||
print("Main process interrupted")
|
||||
shutdown_event.set()
|
||||
actor_process.join(timeout=5)
|
||||
learner_process.join(timeout=10)
|
||||
|
||||
finally:
|
||||
if learner_process.is_alive():
|
||||
learner_process.terminate()
|
||||
if actor_process.is_alive():
|
||||
actor_process.terminate()
|
||||
# Create communication channels between learner and actor processes
|
||||
transitions_queue = mp.Queue(maxsize=10)
|
||||
parameters_queue = mp.Queue(maxsize=2)
|
||||
shutdown_event = mp.Event()
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
main()
|
||||
# Signal handler for graceful shutdown
|
||||
def signal_handler(sig):
|
||||
print(f"\nSignal {sig} received, shutting down...")
|
||||
shutdown_event.set()
|
||||
|
||||
|
||||
signal.signal(signal.SIGINT, signal_handler)
|
||||
signal.signal(signal.SIGTERM, signal_handler)
|
||||
|
||||
# Create processes
|
||||
learner_process = mp.Process(
|
||||
target=run_learner,
|
||||
args=(
|
||||
transitions_queue,
|
||||
parameters_queue,
|
||||
shutdown_event,
|
||||
policy_learner,
|
||||
online_replay_buffer,
|
||||
offline_replay_buffer,
|
||||
),
|
||||
kwargs={"device": device}, # can run on accelerated hardware for training
|
||||
)
|
||||
|
||||
actor_process = mp.Process(
|
||||
target=run_actor,
|
||||
args=(
|
||||
transitions_queue,
|
||||
parameters_queue,
|
||||
shutdown_event,
|
||||
policy_actor,
|
||||
reward_classifier,
|
||||
env_cfg,
|
||||
output_directory,
|
||||
),
|
||||
kwargs={"device": "cpu"}, # actor is frozen, can run on CPU or accelerate for inference
|
||||
)
|
||||
|
||||
learner_process.start()
|
||||
actor_process.start()
|
||||
|
||||
try:
|
||||
# Wait for actor to finish (it controls the episode loop)
|
||||
actor_process.join()
|
||||
shutdown_event.set()
|
||||
learner_process.join(timeout=10)
|
||||
|
||||
except KeyboardInterrupt:
|
||||
print("Main process interrupted")
|
||||
shutdown_event.set()
|
||||
actor_process.join(timeout=5)
|
||||
learner_process.join(timeout=10)
|
||||
|
||||
finally:
|
||||
if learner_process.is_alive():
|
||||
learner_process.terminate()
|
||||
if actor_process.is_alive():
|
||||
actor_process.terminate()
|
||||
|
||||
@@ -4,64 +4,59 @@ from lerobot.datasets.lerobot_dataset import LeRobotDataset
|
||||
from lerobot.policies.factory import make_policy, make_pre_post_processors
|
||||
from lerobot.policies.sac.reward_model.configuration_classifier import RewardClassifierConfig
|
||||
|
||||
# Device to use for training
|
||||
device = "mps" # or "cuda", or "cpu"
|
||||
|
||||
def main():
|
||||
# Device to use for training
|
||||
device = "mps" # or "cuda", or "cpu"
|
||||
# Load the dataset used for training
|
||||
repo_id = "lerobot/example_hil_serl_dataset"
|
||||
dataset = LeRobotDataset(repo_id)
|
||||
|
||||
# Load the dataset used for training
|
||||
repo_id = "lerobot/example_hil_serl_dataset"
|
||||
dataset = LeRobotDataset(repo_id)
|
||||
# Configure the policy to extract features from the image frames
|
||||
camera_keys = dataset.meta.camera_keys
|
||||
|
||||
# Configure the policy to extract features from the image frames
|
||||
camera_keys = dataset.meta.camera_keys
|
||||
config = RewardClassifierConfig(
|
||||
num_cameras=len(camera_keys),
|
||||
device=device,
|
||||
# backbone model to extract features from the image frames
|
||||
model_name="microsoft/resnet-18",
|
||||
)
|
||||
|
||||
config = RewardClassifierConfig(
|
||||
num_cameras=len(camera_keys),
|
||||
device=device,
|
||||
# backbone model to extract features from the image frames
|
||||
model_name="microsoft/resnet-18",
|
||||
)
|
||||
|
||||
# Make policy, preprocessor, and optimizer
|
||||
policy = make_policy(config, ds_meta=dataset.meta)
|
||||
optimizer = config.get_optimizer_preset().build(policy.parameters())
|
||||
preprocessor, _ = make_pre_post_processors(policy_cfg=config, dataset_stats=dataset.meta.stats)
|
||||
|
||||
classifier_id = "<user>/reward_classifier_hil_serl_example"
|
||||
|
||||
# Instantiate a dataloader
|
||||
dataloader = torch.utils.data.DataLoader(dataset, batch_size=16, shuffle=True)
|
||||
|
||||
# Training loop
|
||||
num_epochs = 5
|
||||
for epoch in range(num_epochs):
|
||||
total_loss = 0
|
||||
total_accuracy = 0
|
||||
for batch in dataloader:
|
||||
# Preprocess the batch and move it to the correct device.
|
||||
batch = preprocessor(batch)
|
||||
|
||||
# Forward pass
|
||||
loss, output_dict = policy.forward(batch)
|
||||
|
||||
# Backward pass and optimization
|
||||
optimizer.zero_grad()
|
||||
loss.backward()
|
||||
optimizer.step()
|
||||
|
||||
total_loss += loss.item()
|
||||
total_accuracy += output_dict["accuracy"]
|
||||
|
||||
avg_loss = total_loss / len(dataloader)
|
||||
avg_accuracy = total_accuracy / len(dataloader)
|
||||
print(f"Epoch {epoch + 1}/{num_epochs}, Loss: {avg_loss:.4f}, Accuracy: {avg_accuracy:.2f}%")
|
||||
|
||||
print("Training finished!")
|
||||
|
||||
# You can now save the trained policy.
|
||||
policy.push_to_hub(classifier_id)
|
||||
# Make policy, preprocessor, and optimizer
|
||||
policy = make_policy(config, ds_meta=dataset.meta)
|
||||
optimizer = config.get_optimizer_preset().build(policy.parameters())
|
||||
preprocessor, _ = make_pre_post_processors(policy_cfg=config, dataset_stats=dataset.meta.stats)
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
main()
|
||||
classifier_id = "fracapuano/reward_classifier_hil_serl_example"
|
||||
|
||||
# Instantiate a dataloader
|
||||
dataloader = torch.utils.data.DataLoader(dataset, batch_size=16, shuffle=True)
|
||||
|
||||
# Training loop
|
||||
num_epochs = 5
|
||||
for epoch in range(num_epochs):
|
||||
total_loss = 0
|
||||
total_accuracy = 0
|
||||
for batch in dataloader:
|
||||
# Preprocess the batch and move it to the correct device.
|
||||
batch = preprocessor(batch)
|
||||
|
||||
# Forward pass
|
||||
loss, output_dict = policy.forward(batch)
|
||||
|
||||
# Backward pass and optimization
|
||||
optimizer.zero_grad()
|
||||
loss.backward()
|
||||
optimizer.step()
|
||||
|
||||
total_loss += loss.item()
|
||||
total_accuracy += output_dict["accuracy"]
|
||||
|
||||
avg_loss = total_loss / len(dataloader)
|
||||
avg_accuracy = total_accuracy / len(dataloader)
|
||||
print(f"Epoch {epoch + 1}/{num_epochs}, Loss: {avg_loss:.4f}, Accuracy: {avg_accuracy:.2f}%")
|
||||
|
||||
print("Training finished!")
|
||||
|
||||
# You can now save the trained policy.
|
||||
policy.push_to_hub(classifier_id)
|
||||
|
||||
@@ -11,62 +11,56 @@ from lerobot.robots.so100_follower.so100_follower import SO100Follower
|
||||
MAX_EPISODES = 5
|
||||
MAX_STEPS_PER_EPISODE = 20
|
||||
|
||||
device = torch.device("mps") # or "cuda" or "cpu"
|
||||
model_id = "lerobot/smolvla_base"
|
||||
|
||||
def main():
|
||||
device = torch.device("mps") # or "cuda" or "cpu"
|
||||
model_id = "lerobot/smolvla_base"
|
||||
model = SmolVLAPolicy.from_pretrained(model_id)
|
||||
|
||||
model = SmolVLAPolicy.from_pretrained(model_id)
|
||||
preprocess, postprocess = make_pre_post_processors(
|
||||
model.config,
|
||||
model_id,
|
||||
# This overrides allows to run on MPS, otherwise defaults to CUDA (if available)
|
||||
preprocessor_overrides={"device_processor": {"device": str(device)}},
|
||||
)
|
||||
|
||||
preprocess, postprocess = make_pre_post_processors(
|
||||
model.config,
|
||||
model_id,
|
||||
# This overrides allows to run on MPS, otherwise defaults to CUDA (if available)
|
||||
preprocessor_overrides={"device_processor": {"device": str(device)}},
|
||||
)
|
||||
# find ports using lerobot-find-port
|
||||
follower_port = ... # something like "/dev/tty.usbmodem58760431631"
|
||||
|
||||
# find ports using lerobot-find-port
|
||||
follower_port = ... # something like "/dev/tty.usbmodem58760431631"
|
||||
# the robot ids are used the load the right calibration files
|
||||
follower_id = ... # something like "follower_so100"
|
||||
|
||||
# the robot ids are used the load the right calibration files
|
||||
follower_id = ... # something like "follower_so100"
|
||||
# Robot and environment configuration
|
||||
# Camera keys must match the name and resolutions of the ones used for training!
|
||||
# You can check the camera keys expected by a model in the info.json card on the model card on the Hub
|
||||
camera_config = {
|
||||
"camera1": OpenCVCameraConfig(index_or_path=0, width=640, height=480, fps=30),
|
||||
"camera2": OpenCVCameraConfig(index_or_path=1, width=640, height=480, fps=30),
|
||||
}
|
||||
|
||||
# Robot and environment configuration
|
||||
# Camera keys must match the name and resolutions of the ones used for training!
|
||||
# You can check the camera keys expected by a model in the info.json card on the model card on the Hub
|
||||
camera_config = {
|
||||
"camera1": OpenCVCameraConfig(index_or_path=0, width=640, height=480, fps=30),
|
||||
"camera2": OpenCVCameraConfig(index_or_path=1, width=640, height=480, fps=30),
|
||||
}
|
||||
robot_cfg = SO100FollowerConfig(port=follower_port, id=follower_id, cameras=camera_config)
|
||||
robot = SO100Follower(robot_cfg)
|
||||
robot.connect()
|
||||
|
||||
robot_cfg = SO100FollowerConfig(port=follower_port, id=follower_id, cameras=camera_config)
|
||||
robot = SO100Follower(robot_cfg)
|
||||
robot.connect()
|
||||
task = "" # something like "pick the red block"
|
||||
robot_type = "" # something like "so100_follower" for multi-embodiment datasets
|
||||
|
||||
task = "" # something like "pick the red block"
|
||||
robot_type = "" # something like "so100_follower" for multi-embodiment datasets
|
||||
# This is used to match the raw observation keys to the keys expected by the policy
|
||||
action_features = hw_to_dataset_features(robot.action_features, "action")
|
||||
obs_features = hw_to_dataset_features(robot.observation_features, "observation")
|
||||
dataset_features = {**action_features, **obs_features}
|
||||
|
||||
# This is used to match the raw observation keys to the keys expected by the policy
|
||||
action_features = hw_to_dataset_features(robot.action_features, "action")
|
||||
obs_features = hw_to_dataset_features(robot.observation_features, "observation")
|
||||
dataset_features = {**action_features, **obs_features}
|
||||
for _ in range(MAX_EPISODES):
|
||||
for _ in range(MAX_STEPS_PER_EPISODE):
|
||||
obs = robot.get_observation()
|
||||
obs_frame = build_inference_frame(
|
||||
observation=obs, ds_features=dataset_features, device=device, task=task, robot_type=robot_type
|
||||
)
|
||||
|
||||
for _ in range(MAX_EPISODES):
|
||||
for _ in range(MAX_STEPS_PER_EPISODE):
|
||||
obs = robot.get_observation()
|
||||
obs_frame = build_inference_frame(
|
||||
observation=obs, ds_features=dataset_features, device=device, task=task, robot_type=robot_type
|
||||
)
|
||||
obs = preprocess(obs_frame)
|
||||
|
||||
obs = preprocess(obs_frame)
|
||||
action = model.select_action(obs)
|
||||
action = postprocess(action)
|
||||
action = make_robot_action(action, dataset_features)
|
||||
robot.send_action(action)
|
||||
|
||||
action = model.select_action(obs)
|
||||
action = postprocess(action)
|
||||
action = make_robot_action(action, dataset_features)
|
||||
robot.send_action(action)
|
||||
|
||||
print("Episode finished! Starting new episode...")
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
main()
|
||||
print("Episode finished! Starting new episode...")
|
||||
|
||||
@@ -1,347 +0,0 @@
|
||||
#!/usr/bin/env python
|
||||
|
||||
# Copyright 2025 The HuggingFace Inc. team. All rights reserved.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
"""
|
||||
Example: GR00T Locomotion with Pre-loaded Policies
|
||||
|
||||
This example demonstrates the NEW pattern for loading GR00T policies externally
|
||||
and passing them to the robot class.
|
||||
"""
|
||||
|
||||
import argparse
|
||||
import logging
|
||||
import threading
|
||||
import time
|
||||
from collections import deque
|
||||
|
||||
import numpy as np
|
||||
import onnxruntime as ort
|
||||
from huggingface_hub import hf_hub_download
|
||||
|
||||
from lerobot.robots.unitree_g1.config_unitree_g1 import UnitreeG1Config
|
||||
from lerobot.robots.unitree_g1.unitree_g1 import UnitreeG1
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
GROOT_DEFAULT_ANGLES = np.zeros(29, dtype=np.float32)
|
||||
GROOT_DEFAULT_ANGLES[[0, 6]] = -0.1 # hip pitch
|
||||
GROOT_DEFAULT_ANGLES[[3, 9]] = 0.3 # knee
|
||||
GROOT_DEFAULT_ANGLES[[4, 10]] = -0.2 # ankle pitch
|
||||
|
||||
MISSING_JOINTS = []
|
||||
G1_MODEL = "g1_23" # or "g1_29"
|
||||
if G1_MODEL == "g1_23":
|
||||
MISSING_JOINTS = [12, 14, 20, 21, 27, 28] # waist yaw/pitch, wrist pitch/yaw
|
||||
|
||||
LOCOMOTION_ACTION_SCALE = 0.25
|
||||
|
||||
LOCOMOTION_CONTROL_DT = 0.02
|
||||
|
||||
ANG_VEL_SCALE: float = 0.25
|
||||
DOF_POS_SCALE: float = 1.0
|
||||
DOF_VEL_SCALE: float = 0.05
|
||||
CMD_SCALE: list = [2.0, 2.0, 0.25]
|
||||
|
||||
|
||||
DEFAULT_GROOT_REPO_ID = "nepyope/GR00T-WholeBodyControl_g1"
|
||||
|
||||
|
||||
def load_groot_policies(
|
||||
repo_id: str = DEFAULT_GROOT_REPO_ID,
|
||||
) -> tuple[ort.InferenceSession, ort.InferenceSession]:
|
||||
"""Load GR00T dual-policy system (Balance + Walk) from Hugging Face Hub.
|
||||
|
||||
Args:
|
||||
repo_id: Hugging Face Hub repository ID containing the ONNX policies.
|
||||
"""
|
||||
logger.info(f"Loading GR00T dual-policy system from Hugging Face Hub ({repo_id})...")
|
||||
|
||||
# Download ONNX policies from Hugging Face Hub
|
||||
balance_path = hf_hub_download(
|
||||
repo_id=repo_id,
|
||||
filename="GR00T-WholeBodyControl-Balance.onnx",
|
||||
)
|
||||
walk_path = hf_hub_download(
|
||||
repo_id=repo_id,
|
||||
filename="GR00T-WholeBodyControl-Walk.onnx",
|
||||
)
|
||||
|
||||
# Load ONNX policies
|
||||
policy_balance = ort.InferenceSession(balance_path)
|
||||
policy_walk = ort.InferenceSession(walk_path)
|
||||
|
||||
logger.info("GR00T policies loaded successfully")
|
||||
|
||||
return policy_balance, policy_walk
|
||||
|
||||
|
||||
class GrootLocomotionController:
|
||||
"""
|
||||
Handles GR00T-style locomotion control for the Unitree G1 robot.
|
||||
|
||||
This controller manages:
|
||||
- Dual-policy system (Balance + Walk)
|
||||
- 29-joint observation processing
|
||||
- 15D action output (legs + waist)
|
||||
- Policy inference and motor command generation
|
||||
"""
|
||||
|
||||
def __init__(self, policy_balance, policy_walk, robot, config):
|
||||
self.policy_balance = policy_balance
|
||||
self.policy_walk = policy_walk
|
||||
self.robot = robot
|
||||
self.config = config
|
||||
|
||||
self.locomotion_cmd = np.array([0.0, 0.0, 0.0], dtype=np.float32) # vx, vy, theta_dot
|
||||
|
||||
# GR00T-specific state
|
||||
self.groot_qj_all = np.zeros(29, dtype=np.float32)
|
||||
self.groot_dqj_all = np.zeros(29, dtype=np.float32)
|
||||
self.groot_action = np.zeros(15, dtype=np.float32)
|
||||
self.groot_obs_single = np.zeros(86, dtype=np.float32)
|
||||
self.groot_obs_history = deque(maxlen=6)
|
||||
self.groot_obs_stacked = np.zeros(516, dtype=np.float32)
|
||||
self.groot_height_cmd = 0.74 # Default base height
|
||||
self.groot_orientation_cmd = np.array([0.0, 0.0, 0.0], dtype=np.float32)
|
||||
|
||||
# input to gr00t is 6 frames (6*86D=516)
|
||||
for _ in range(6):
|
||||
self.groot_obs_history.append(np.zeros(86, dtype=np.float32))
|
||||
|
||||
# Thread management
|
||||
self.locomotion_running = False
|
||||
self.locomotion_thread = None
|
||||
|
||||
logger.info("GrootLocomotionController initialized")
|
||||
|
||||
def groot_locomotion_run(self):
|
||||
# get current observation
|
||||
robot_state = self.robot.get_observation()
|
||||
|
||||
if robot_state is None:
|
||||
return
|
||||
|
||||
# get command from remote controller
|
||||
if robot_state.wireless_remote is not None:
|
||||
self.robot.remote_controller.set(robot_state.wireless_remote)
|
||||
if self.robot.remote_controller.button[0]: # R1 - raise waist
|
||||
self.groot_height_cmd += 0.001
|
||||
self.groot_height_cmd = np.clip(self.groot_height_cmd, 0.50, 1.00)
|
||||
if self.robot.remote_controller.button[4]: # R2 - lower waist
|
||||
self.groot_height_cmd -= 0.001
|
||||
self.groot_height_cmd = np.clip(self.groot_height_cmd, 0.50, 1.00)
|
||||
else:
|
||||
self.robot.remote_controller.lx = 0.0
|
||||
self.robot.remote_controller.ly = 0.0
|
||||
self.robot.remote_controller.rx = 0.0
|
||||
self.robot.remote_controller.ry = 0.0
|
||||
|
||||
self.locomotion_cmd[0] = self.robot.remote_controller.ly # forward/backward
|
||||
self.locomotion_cmd[1] = self.robot.remote_controller.lx * -1 # left/right
|
||||
self.locomotion_cmd[2] = self.robot.remote_controller.rx * -1 # rotation rate
|
||||
|
||||
for i in range(29):
|
||||
self.groot_qj_all[i] = robot_state.motor_state[i].q
|
||||
self.groot_dqj_all[i] = robot_state.motor_state[i].dq
|
||||
|
||||
# adapt observation for g1_23dof
|
||||
for idx in MISSING_JOINTS:
|
||||
self.groot_qj_all[idx] = 0.0
|
||||
self.groot_dqj_all[idx] = 0.0
|
||||
|
||||
# Scale joint positions and velocities
|
||||
qj_obs = self.groot_qj_all.copy()
|
||||
dqj_obs = self.groot_dqj_all.copy()
|
||||
|
||||
# express imu data in gravity frame of reference
|
||||
quat = robot_state.imu_state.quaternion
|
||||
ang_vel = np.array(robot_state.imu_state.gyroscope, dtype=np.float32)
|
||||
gravity_orientation = self.robot.get_gravity_orientation(quat)
|
||||
|
||||
# scale joint positions and velocities before policy inference
|
||||
qj_obs = (qj_obs - GROOT_DEFAULT_ANGLES) * DOF_POS_SCALE
|
||||
dqj_obs = dqj_obs * DOF_VEL_SCALE
|
||||
ang_vel_scaled = ang_vel * ANG_VEL_SCALE
|
||||
|
||||
# build single frame observation
|
||||
self.groot_obs_single[:3] = self.locomotion_cmd * np.array(CMD_SCALE)
|
||||
self.groot_obs_single[3] = self.groot_height_cmd
|
||||
self.groot_obs_single[4:7] = self.groot_orientation_cmd
|
||||
self.groot_obs_single[7:10] = ang_vel_scaled
|
||||
self.groot_obs_single[10:13] = gravity_orientation
|
||||
self.groot_obs_single[13:42] = qj_obs
|
||||
self.groot_obs_single[42:71] = dqj_obs
|
||||
self.groot_obs_single[71:86] = self.groot_action # 15D previous actions
|
||||
|
||||
# Add to history and stack observations (6 frames × 86D = 516D)
|
||||
self.groot_obs_history.append(self.groot_obs_single.copy())
|
||||
|
||||
# Stack all 6 frames into 516D vector
|
||||
for i, obs_frame in enumerate(self.groot_obs_history):
|
||||
start_idx = i * 86
|
||||
end_idx = start_idx + 86
|
||||
self.groot_obs_stacked[start_idx:end_idx] = obs_frame
|
||||
|
||||
# Run policy inference (ONNX) with 516D stacked observation
|
||||
|
||||
cmd_magnitude = np.linalg.norm(self.locomotion_cmd)
|
||||
|
||||
selected_policy = (
|
||||
self.policy_balance if cmd_magnitude < 0.05 else self.policy_walk
|
||||
) # balance/standing policy for small commands, walking policy for movement commands
|
||||
|
||||
# run policy inference
|
||||
ort_inputs = {selected_policy.get_inputs()[0].name: np.expand_dims(self.groot_obs_stacked, axis=0)}
|
||||
ort_outs = selected_policy.run(None, ort_inputs)
|
||||
self.groot_action = ort_outs[0].squeeze()
|
||||
|
||||
# transform action back to target joint positions
|
||||
target_dof_pos_15 = GROOT_DEFAULT_ANGLES[:15] + self.groot_action * LOCOMOTION_ACTION_SCALE
|
||||
|
||||
# command motors
|
||||
for i in range(15):
|
||||
motor_idx = i
|
||||
self.robot.msg.motor_cmd[motor_idx].q = target_dof_pos_15[i]
|
||||
self.robot.msg.motor_cmd[motor_idx].qd = 0
|
||||
self.robot.msg.motor_cmd[motor_idx].kp = self.robot.kp[motor_idx]
|
||||
self.robot.msg.motor_cmd[motor_idx].kd = self.robot.kd[motor_idx]
|
||||
self.robot.msg.motor_cmd[motor_idx].tau = 0
|
||||
|
||||
# adapt action for g1_23dof
|
||||
for joint_idx in MISSING_JOINTS:
|
||||
self.robot.msg.motor_cmd[joint_idx].q = 0.0
|
||||
self.robot.msg.motor_cmd[joint_idx].qd = 0
|
||||
self.robot.msg.motor_cmd[joint_idx].kp = self.robot.kp[joint_idx]
|
||||
self.robot.msg.motor_cmd[joint_idx].kd = self.robot.kd[joint_idx]
|
||||
self.robot.msg.motor_cmd[joint_idx].tau = 0
|
||||
|
||||
# send action to robot
|
||||
self.robot.send_action(self.robot.msg)
|
||||
|
||||
def _locomotion_thread_loop(self):
|
||||
"""Background thread that runs the locomotion policy at specified rate."""
|
||||
logger.info("Locomotion thread started")
|
||||
while self.locomotion_running:
|
||||
start_time = time.time()
|
||||
try:
|
||||
self.groot_locomotion_run()
|
||||
except Exception as e:
|
||||
logger.error(f"Error in locomotion loop: {e}")
|
||||
|
||||
# Sleep to maintain control rate
|
||||
elapsed = time.time() - start_time
|
||||
sleep_time = max(0, LOCOMOTION_CONTROL_DT - elapsed)
|
||||
time.sleep(sleep_time)
|
||||
logger.info("Locomotion thread stopped")
|
||||
|
||||
def start_locomotion_thread(self):
|
||||
if self.locomotion_running:
|
||||
logger.warning("Locomotion thread already running")
|
||||
return
|
||||
|
||||
logger.info("Starting locomotion control thread...")
|
||||
self.locomotion_running = True
|
||||
self.locomotion_thread = threading.Thread(target=self._locomotion_thread_loop, daemon=True)
|
||||
self.locomotion_thread.start()
|
||||
|
||||
logger.info("Locomotion control thread started!")
|
||||
|
||||
def stop_locomotion_thread(self):
|
||||
if not self.locomotion_running:
|
||||
return
|
||||
|
||||
logger.info("Stopping locomotion control thread...")
|
||||
self.locomotion_running = False
|
||||
if self.locomotion_thread:
|
||||
self.locomotion_thread.join(timeout=2.0)
|
||||
logger.info("Locomotion control thread stopped")
|
||||
|
||||
    def reset_robot(self):
        """Ramp the robot's commanded joints to the GR00T default pose.

        Linearly interpolates each joint from its current position to
        ``GROOT_DEFAULT_ANGLES`` over ``total_time`` seconds, publishing one
        low-level command per ``control_dt`` tick.

        NOTE(review): the ramp duration is ``total_time = 3.0`` seconds (an
        earlier docstring said 2), and the loop commands every index of
        ``GROOT_DEFAULT_ANGLES`` — the "legs only" comments below predate that;
        confirm which joints are actually covered against the motor layout.
        """
        total_time = 3.0
        # Number of interpolation steps at the robot's control period.
        num_step = int(total_time / self.robot.control_dt)

        # Only control legs, not arms (first 12 joints)
        default_pos = GROOT_DEFAULT_ANGLES  # First 12 values are leg angles
        dof_size = len(default_pos)

        # Get current lowstate
        robot_state = self.robot.get_observation()

        # Record the current leg positions
        init_dof_pos = np.zeros(dof_size, dtype=np.float32)
        for i in range(dof_size):
            init_dof_pos[i] = robot_state.motor_state[i].q

        # Move legs to default pos: alpha sweeps 0 → ~1 so each commanded
        # position is a blend of the initial and target angles.
        for i in range(num_step):
            alpha = i / num_step
            for motor_idx in range(dof_size):
                target_pos = default_pos[motor_idx]
                self.robot.msg.motor_cmd[motor_idx].q = (
                    init_dof_pos[motor_idx] * (1 - alpha) + target_pos * alpha
                )
                self.robot.msg.motor_cmd[motor_idx].qd = 0
                self.robot.msg.motor_cmd[motor_idx].kp = self.robot.kp[motor_idx]
                self.robot.msg.motor_cmd[motor_idx].kd = self.robot.kd[motor_idx]
                self.robot.msg.motor_cmd[motor_idx].tau = 0
            # Stamp the CRC and publish one low-level command per step.
            self.robot.msg.crc = self.robot.crc.Crc(self.robot.msg)
            self.robot.lowcmd_publisher.Write(self.robot.msg)
            time.sleep(self.robot.control_dt)
        logger.info("Reached default position (legs only)")
|
||||
|
||||
|
||||
if __name__ == "__main__":
    # Command-line options.
    arg_parser = argparse.ArgumentParser(description="GR00T Locomotion Controller for Unitree G1")
    arg_parser.add_argument(
        "--repo-id",
        type=str,
        default=DEFAULT_GROOT_REPO_ID,
        help=f"Hugging Face Hub repo ID for GR00T policies (default: {DEFAULT_GROOT_REPO_ID})",
    )
    cli_args = arg_parser.parse_args()

    # Fetch both ONNX policies (balance + walk) from the Hub.
    policy_balance, policy_walk = load_groot_policies(repo_id=cli_args.repo_id)

    # Bring up the G1 robot interface.
    config = UnitreeG1Config()
    robot = UnitreeG1(config)

    # Wire the policies and robot into the locomotion controller.
    groot_controller = GrootLocomotionController(
        policy_balance=policy_balance,
        policy_walk=policy_walk,
        robot=robot,
        config=config,
    )

    try:
        # Ramp to the default pose, then run the policy in the background.
        groot_controller.reset_robot()
        groot_controller.start_locomotion_thread()

        logger.info("Robot initialized with GR00T locomotion policies")
        logger.info("Locomotion controller running in background thread")
        logger.info("Press Ctrl+C to stop")

        # Keep the main thread alive; the daemon thread does the work.
        while True:
            time.sleep(1.0)
    except KeyboardInterrupt:
        print("\nStopping locomotion...")
        groot_controller.stop_locomotion_thread()
        print("Done!")
|
||||
@@ -1,479 +0,0 @@
|
||||
#!/usr/bin/env python
|
||||
|
||||
# Copyright 2025 The HuggingFace Inc. team. All rights reserved.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
"""
|
||||
Example: Holosoma Whole-Body Locomotion (23-DOF and 29-DOF)
|
||||
|
||||
This example demonstrates loading Holosoma whole-body locomotion policies
|
||||
and running them on the Unitree G1 robot.
|
||||
|
||||
Supports both:
|
||||
- 23-DOF native policies (82D observations, 23D actions)
|
||||
- 29-DOF policies (100D observations, 29D actions)
|
||||
"""
|
||||
|
||||
import argparse
|
||||
import logging
|
||||
import threading
|
||||
import time
|
||||
|
||||
import numpy as np
|
||||
import onnxruntime as ort
|
||||
from huggingface_hub import hf_hub_download
|
||||
|
||||
from lerobot.robots.unitree_g1.config_unitree_g1 import UnitreeG1Config
|
||||
from lerobot.robots.unitree_g1.unitree_g1 import UnitreeG1
|
||||
|
||||
logging.basicConfig(level=logging.INFO)
logger = logging.getLogger(__name__)

# =============================================================================
# 29-DOF Configuration
# =============================================================================
# fmt: off
# Default joint angles (rad) defining the 29-DOF standing pose.
HOLOSOMA_29DOF_DEFAULT_ANGLES = np.array([
    -0.312, 0.0, 0.0, 0.669, -0.363, 0.0,  # left leg
    -0.312, 0.0, 0.0, 0.669, -0.363, 0.0,  # right leg
    0.0, 0.0, 0.0,  # waist (yaw, roll, pitch)
    0.2, 0.2, 0.0, 0.6, 0.0, 0.0, 0.0,  # left arm
    0.2, -0.2, 0.0, 0.6, 0.0, 0.0, 0.0,  # right arm
], dtype=np.float32)

# Per-joint proportional gains for the 29-DOF configuration.
HOLOSOMA_29DOF_KP = np.array([
    40.179238471, 99.098427777, 40.179238471, 99.098427777, 28.501246196, 28.501246196,  # left leg
    40.179238471, 99.098427777, 40.179238471, 99.098427777, 28.501246196, 28.501246196,  # right leg
    40.179238471, 28.501246196, 28.501246196,  # waist
    14.250623098, 14.250623098, 14.250623098, 14.250623098, 14.250623098, 16.778327481, 16.778327481,  # left arm
    14.250623098, 14.250623098, 14.250623098, 14.250623098, 14.250623098, 16.778327481, 16.778327481,  # right arm
], dtype=np.float32)

# Per-joint derivative gains for the 29-DOF configuration.
HOLOSOMA_29DOF_KD = np.array([
    2.557889765, 6.308801854, 2.557889765, 6.308801854, 1.814445687, 1.814445687,  # left leg
    2.557889765, 6.308801854, 2.557889765, 6.308801854, 1.814445687, 1.814445687,  # right leg
    2.557889765, 1.814445687, 1.814445687,  # waist
    0.907222843, 0.907222843, 0.907222843, 0.907222843, 0.907222843, 1.068141502, 1.068141502,  # left arm
    0.907222843, 0.907222843, 0.907222843, 0.907222843, 0.907222843, 1.068141502, 1.068141502,  # right arm
], dtype=np.float32)

# =============================================================================
# 23-DOF Configuration (native G1-23: no waist_roll/pitch, no wrist_pitch/yaw)
# Derived from 29-DOF Holosoma values
# =============================================================================
# Joint order: 6 left leg, 6 right leg, 1 waist_yaw, 5 left arm, 5 right arm
HOLOSOMA_23DOF_DEFAULT_ANGLES = np.array([
    -0.312, 0.0, 0.0, 0.669, -0.363, 0.0,  # left leg (from 29-DOF)
    -0.312, 0.0, 0.0, 0.669, -0.363, 0.0,  # right leg (from 29-DOF)
    0.0,  # waist_yaw only (from 29-DOF)
    0.2, 0.2, 0.0, 0.6, 0.0,  # left arm first 5 joints (from 29-DOF)
    0.2, -0.2, 0.0, 0.6, 0.0,  # right arm first 5 joints (from 29-DOF)
], dtype=np.float32)

HOLOSOMA_23DOF_KP = np.array([
    40.179238471, 99.098427777, 40.179238471, 99.098427777, 28.501246196, 28.501246196,  # left leg
    40.179238471, 99.098427777, 40.179238471, 99.098427777, 28.501246196, 28.501246196,  # right leg
    40.179238471,  # waist_yaw
    14.250623098, 14.250623098, 14.250623098, 14.250623098, 14.250623098,  # left arm
    14.250623098, 14.250623098, 14.250623098, 14.250623098, 14.250623098,  # right arm
], dtype=np.float32)

HOLOSOMA_23DOF_KD = np.array([
    2.557889765, 6.308801854, 2.557889765, 6.308801854, 1.814445687, 1.814445687,  # left leg
    2.557889765, 6.308801854, 2.557889765, 6.308801854, 1.814445687, 1.814445687,  # right leg
    2.557889765,  # waist_yaw
    0.907222843, 0.907222843, 0.907222843, 0.907222843, 0.907222843,  # left arm
    0.907222843, 0.907222843, 0.907222843, 0.907222843, 0.907222843,  # right arm
], dtype=np.float32)

# Maps 23-DOF policy index → 29-DOF motor index
# 23-DOF: legs(0-11), waist_yaw(12), L_arm(13-17), R_arm(18-22)
# 29-DOF: legs(0-11), waist(12-14), L_arm(15-21), R_arm(22-28)
DOF_23_TO_MOTOR_MAP = [
    0, 1, 2, 3, 4, 5,  # left leg → motor 0-5
    6, 7, 8, 9, 10, 11,  # right leg → motor 6-11
    12,  # waist_yaw → motor 12
    15, 16, 17, 18, 19,  # left arm (skip wrist_pitch/yaw) → motor 15-19
    22, 23, 24, 25, 26,  # right arm (skip wrist_pitch/yaw) → motor 22-26
]
# fmt: on

# Control parameters
LOCOMOTION_CONTROL_DT = 0.02  # 50Hz
LOCOMOTION_ACTION_SCALE = 0.25  # policy output → joint-angle delta (rad)
ANG_VEL_SCALE = 0.25  # gyro scaling used when building observations
DOF_POS_SCALE = 1.0  # joint-position scaling used when building observations
DOF_VEL_SCALE = 0.05  # joint-velocity scaling used when building observations
GAIT_PERIOD = 1.0  # seconds per gait cycle

DEFAULT_HOLOSOMA_REPO_ID = "nepyope/holosoma_locomotion"
|
||||
|
||||
|
||||
def load_holosoma_policy(
    repo_id: str = DEFAULT_HOLOSOMA_REPO_ID,
    policy_name: str = "fastsac",
    local_path: str | None = None,
) -> tuple[ort.InferenceSession, int]:
    """Load a Holosoma ONNX policy and detect its observation dimension.

    Args:
        repo_id: Hugging Face Hub repository to download from when
            ``local_path`` is not given.
        policy_name: Base name of the ONNX file on the Hub; the downloaded
            filename is ``{policy_name}_g1_29dof.onnx``.
        local_path: Optional path to a local ONNX file; skips the Hub download.

    Returns:
        (policy, obs_dim) tuple where obs_dim is 82 (23-DOF) or 100 (29-DOF)
    """
    if local_path is not None:
        logger.info(f"Loading policy from local path: {local_path}")
        policy_path = local_path
    else:
        logger.info(f"Loading policy from Hugging Face Hub: {repo_id}")
        policy_path = hf_hub_download(repo_id=repo_id, filename=f"{policy_name}_g1_29dof.onnx")

    policy = ort.InferenceSession(policy_path)

    # Detect observation dimension from model input shape: axis 1 of a
    # (batch, obs) input, falling back to axis 0 for rank-1 inputs.
    model_input = policy.get_inputs()[0]
    input_shape = model_input.shape
    obs_dim = input_shape[1] if len(input_shape) > 1 else input_shape[0]

    # Plain string here — the original used an f-string with no placeholders (ruff F541).
    logger.info("Policy loaded successfully")
    logger.info(f"  Input: {model_input.name}, shape: {input_shape} → obs_dim={obs_dim}")
    logger.info(f"  Output: {policy.get_outputs()[0].name}, shape: {policy.get_outputs()[0].shape}")

    return policy, obs_dim
|
||||
|
||||
|
||||
class HolosomaLocomotionController:
    """
    Handles Holosoma whole-body locomotion for Unitree G1.
    Supports both 23-DOF (82D obs) and 29-DOF (100D obs) policies.
    """

    def __init__(self, policy, robot, config, obs_dim: int = 100):
        """Set up controller state for the given policy.

        Args:
            policy: ONNX InferenceSession for the locomotion policy.
            robot: UnitreeG1 robot interface (provides observation/command API).
            config: UnitreeG1Config instance (kept for reference; not read here).
            obs_dim: Policy observation size — 82 selects 23-DOF mode,
                anything else selects 29-DOF mode.
        """
        self.policy = policy
        self.robot = robot
        self.config = config
        self.obs_dim = obs_dim

        # Detect policy type from observation dimension
        self.is_23dof = (obs_dim == 82)
        self.num_dof = 23 if self.is_23dof else 29

        # Velocity commands
        # Layout: [vx, vy, yaw_rate] — filled from the remote controller each tick.
        self.locomotion_cmd = np.array([0.0, 0.0, 0.0], dtype=np.float32)

        # State variables sized for policy type
        self.qj = np.zeros(self.num_dof, dtype=np.float32)  # joint positions
        self.dqj = np.zeros(self.num_dof, dtype=np.float32)  # joint velocities
        self.locomotion_action = np.zeros(self.num_dof, dtype=np.float32)  # scaled action
        self.locomotion_obs = np.zeros(obs_dim, dtype=np.float32)  # reused obs buffer
        self.last_unscaled_action = np.zeros(self.num_dof, dtype=np.float32)  # fed back into obs

        # Select config based on DOF
        if self.is_23dof:
            self.default_angles = HOLOSOMA_23DOF_DEFAULT_ANGLES
            self.kp = HOLOSOMA_23DOF_KP
            self.kd = HOLOSOMA_23DOF_KD
            self.motor_map = DOF_23_TO_MOTOR_MAP
        else:
            self.default_angles = HOLOSOMA_29DOF_DEFAULT_ANGLES
            self.kp = HOLOSOMA_29DOF_KP
            self.kd = HOLOSOMA_29DOF_KD
            self.motor_map = list(range(29))  # Identity map for 29-DOF

        # Phase state for gait: two legs half a cycle (pi) apart.
        self.phase = np.zeros((1, 2), dtype=np.float32)
        self.phase[0, 0] = 0.0
        self.phase[0, 1] = np.pi
        # Phase advance per 50 Hz tick for one full cycle every GAIT_PERIOD seconds.
        self.phase_dt = 2 * np.pi / (50.0 * GAIT_PERIOD)
        self.is_standing = False

        self.counter = 0  # control ticks executed so far
        self.locomotion_running = False  # flag polled by the worker thread
        self.locomotion_thread = None

        logger.info(f"HolosomaLocomotionController initialized")
        logger.info(f"  Mode: {'23-DOF (82D obs)' if self.is_23dof else '29-DOF (100D obs)'}")
        logger.info(f"  Action dim: {self.num_dof}")

    def holosoma_locomotion_run(self):
        """Main locomotion loop - handles both 23-DOF and 29-DOF.

        One control tick: read robot state and remote commands, build the
        policy observation, run ONNX inference, and publish motor commands.
        """
        self.counter += 1

        # One-time banner on the first tick.
        if self.counter == 1:
            print("\n" + "=" * 60)
            print(f"🚀 RUNNING HOLOSOMA {self.num_dof}-DOF LOCOMOTION POLICY")
            print(f"   {self.obs_dim}D observations → {self.num_dof}D actions")
            print("=" * 60 + "\n")

        robot_state = self.robot.get_observation()
        if robot_state is None:
            return

        # Remote controller — fall back to zero sticks when no remote data.
        if robot_state.wireless_remote is not None:
            self.robot.remote_controller.set(robot_state.wireless_remote)
        else:
            self.robot.remote_controller.lx = 0.0
            self.robot.remote_controller.ly = 0.0
            self.robot.remote_controller.rx = 0.0
            self.robot.remote_controller.ry = 0.0

        # Deadzone: ignore stick deflections below 0.1.
        ly = self.robot.remote_controller.ly if abs(self.robot.remote_controller.ly) > 0.1 else 0.0
        lx = self.robot.remote_controller.lx if abs(self.robot.remote_controller.lx) > 0.1 else 0.0
        rx = self.robot.remote_controller.rx if abs(self.robot.remote_controller.rx) > 0.1 else 0.0

        # Stick → command mapping: LY forward, LX strafe (inverted), RX yaw (inverted).
        self.locomotion_cmd[0] = ly
        self.locomotion_cmd[1] = -lx
        self.locomotion_cmd[2] = -rx

        # Read joint states using motor map
        for i in range(self.num_dof):
            motor_idx = self.motor_map[i]
            self.qj[i] = robot_state.motor_state[motor_idx].q
            self.dqj[i] = robot_state.motor_state[motor_idx].dq

        # IMU
        quat = robot_state.imu_state.quaternion
        ang_vel = np.array(robot_state.imu_state.gyroscope, dtype=np.float32)
        gravity_orientation = self.robot.get_gravity_orientation(quat)

        # Scale observations (positions relative to the default pose).
        qj_obs = (self.qj - self.default_angles) * DOF_POS_SCALE
        dqj_obs = self.dqj * DOF_VEL_SCALE
        ang_vel_scaled = ang_vel * ANG_VEL_SCALE

        # Phase update: freeze at pi while standing, restart from (0, pi) when
        # motion resumes, otherwise advance and wrap to (-pi, pi].
        cmd_norm = np.linalg.norm(self.locomotion_cmd[:2])
        ang_cmd_norm = np.abs(self.locomotion_cmd[2])

        if cmd_norm < 0.01 and ang_cmd_norm < 0.01:
            self.phase[0, :] = np.pi * np.ones(2)
            self.is_standing = True
        elif self.is_standing:
            self.phase = np.array([[0.0, np.pi]], dtype=np.float32)
            self.is_standing = False
        else:
            phase_tp1 = self.phase + self.phase_dt
            self.phase = np.fmod(phase_tp1 + np.pi, 2 * np.pi) - np.pi

        sin_phase = np.sin(self.phase[0, :])
        cos_phase = np.cos(self.phase[0, :])

        # Build observation (format depends on DOF)
        if self.is_23dof:
            # 82D: [23 actions, 3 ang_vel, 1 cmd_yaw, 2 cmd_lin, 2 cos, 23 pos, 23 vel, 3 grav, 2 sin]
            self.locomotion_obs[0:23] = self.last_unscaled_action
            self.locomotion_obs[23:26] = ang_vel_scaled
            self.locomotion_obs[26] = self.locomotion_cmd[2]
            self.locomotion_obs[27:29] = self.locomotion_cmd[:2]
            self.locomotion_obs[29:31] = cos_phase
            self.locomotion_obs[31:54] = qj_obs
            self.locomotion_obs[54:77] = dqj_obs
            self.locomotion_obs[77:80] = gravity_orientation
            self.locomotion_obs[80:82] = sin_phase
        else:
            # 100D: [29 actions, 3 ang_vel, 1 cmd_yaw, 2 cmd_lin, 2 cos, 29 pos, 29 vel, 3 grav, 2 sin]
            self.locomotion_obs[0:29] = self.last_unscaled_action
            self.locomotion_obs[29:32] = ang_vel_scaled
            self.locomotion_obs[32] = self.locomotion_cmd[2]
            self.locomotion_obs[33:35] = self.locomotion_cmd[:2]
            self.locomotion_obs[35:37] = cos_phase
            self.locomotion_obs[37:66] = qj_obs
            self.locomotion_obs[66:95] = dqj_obs
            self.locomotion_obs[95:98] = gravity_orientation
            self.locomotion_obs[98:100] = sin_phase

        # Policy inference
        obs_input = self.locomotion_obs.reshape(1, -1).astype(np.float32)
        ort_inputs = {self.policy.get_inputs()[0].name: obs_input}
        ort_outs = self.policy.run(None, ort_inputs)

        # Clip raw policy output; the unscaled (but clipped) action is fed
        # back into the next observation.
        raw_action = ort_outs[0].squeeze()
        clipped_action = np.clip(raw_action, -100.0, 100.0)

        self.last_unscaled_action = clipped_action.copy()
        self.locomotion_action = clipped_action * LOCOMOTION_ACTION_SCALE

        # Debug — print diagnostics for the first three ticks only.
        if self.counter <= 3:
            print(f"\n[Holosoma Debug #{self.counter}]")
            print(f"  Phase: ({self.phase[0, 0]:.3f}, {self.phase[0, 1]:.3f})")
            print(f"  Cmd: ({self.locomotion_cmd[0]:.2f}, {self.locomotion_cmd[1]:.2f}, {self.locomotion_cmd[2]:.2f})")
            print(f"  Action range: [{raw_action.min():.3f}, {raw_action.max():.3f}]")

        # Compute target positions
        target_dof_pos = self.default_angles + self.locomotion_action

        # Send commands to motors via motor map
        for i in range(self.num_dof):
            motor_idx = self.motor_map[i]
            self.robot.msg.motor_cmd[motor_idx].q = target_dof_pos[i]
            self.robot.msg.motor_cmd[motor_idx].qd = 0
            self.robot.msg.motor_cmd[motor_idx].kp = self.kp[i]
            self.robot.msg.motor_cmd[motor_idx].kd = self.kd[i]
            self.robot.msg.motor_cmd[motor_idx].tau = 0

        # For 23-DOF: zero out missing joints (waist_roll/pitch, wrist_pitch/yaw)
        if self.is_23dof:
            missing_motors = [13, 14, 20, 21, 27, 28]  # waist_roll, waist_pitch, wrist_pitch/yaw
            for motor_idx in missing_motors:
                self.robot.msg.motor_cmd[motor_idx].q = 0.0
                self.robot.msg.motor_cmd[motor_idx].qd = 0
                self.robot.msg.motor_cmd[motor_idx].kp = 40.0
                self.robot.msg.motor_cmd[motor_idx].kd = 2.0
                self.robot.msg.motor_cmd[motor_idx].tau = 0

        self.robot.send_action(self.robot.msg)

    def _locomotion_thread_loop(self):
        """Background thread: step the policy at LOCOMOTION_CONTROL_DT until stopped."""
        logger.info("Locomotion thread started")
        while self.locomotion_running:
            start_time = time.time()
            try:
                self.holosoma_locomotion_run()
            except Exception as e:
                logger.error(f"Error in locomotion loop: {e}")
                import traceback
                traceback.print_exc()

            # Sleep off the remainder of the control period to hold 50 Hz.
            elapsed = time.time() - start_time
            sleep_time = max(0, LOCOMOTION_CONTROL_DT - elapsed)
            time.sleep(sleep_time)
        logger.info("Locomotion thread stopped")

    def start_locomotion_thread(self):
        """Spawn the daemon thread driving the policy (no-op if already running)."""
        if self.locomotion_running:
            logger.warning("Locomotion thread already running")
            return
        logger.info("Starting locomotion control thread...")
        self.locomotion_running = True
        self.locomotion_thread = threading.Thread(target=self._locomotion_thread_loop, daemon=True)
        self.locomotion_thread.start()
        logger.info("Locomotion control thread started!")

    def stop_locomotion_thread(self):
        """Signal the locomotion thread to exit and wait briefly for it to finish."""
        if not self.locomotion_running:
            return
        logger.info("Stopping locomotion control thread...")
        self.locomotion_running = False
        if self.locomotion_thread:
            self.locomotion_thread.join(timeout=2.0)
        logger.info("Locomotion control thread stopped")

    def reset_robot(self):
        """Move joints to default position.

        Linearly ramps all mapped joints to the default pose over
        ``total_time`` seconds, then holds that pose for 2 seconds, publishing
        one low-level command per control tick throughout.
        """
        logger.info(f"Moving {self.num_dof} joints to default position...")

        total_time = 3.0
        num_step = int(total_time / self.robot.control_dt)

        robot_state = self.robot.get_observation()

        # Record current positions
        init_dof_pos = np.zeros(self.num_dof, dtype=np.float32)
        for i in range(self.num_dof):
            motor_idx = self.motor_map[i]
            init_dof_pos[i] = robot_state.motor_state[motor_idx].q

        # Interpolate to target: blend initial → default as alpha goes 0 → ~1.
        for step in range(num_step):
            alpha = step / num_step
            for i in range(self.num_dof):
                motor_idx = self.motor_map[i]
                target = self.default_angles[i]
                self.robot.msg.motor_cmd[motor_idx].q = init_dof_pos[i] * (1 - alpha) + target * alpha
                self.robot.msg.motor_cmd[motor_idx].qd = 0
                self.robot.msg.motor_cmd[motor_idx].kp = self.kp[i]
                self.robot.msg.motor_cmd[motor_idx].kd = self.kd[i]
                self.robot.msg.motor_cmd[motor_idx].tau = 0

            # Zero missing joints for 23-DOF
            if self.is_23dof:
                for motor_idx in [13, 14, 20, 21, 27, 28]:
                    self.robot.msg.motor_cmd[motor_idx].q = 0.0
                    self.robot.msg.motor_cmd[motor_idx].qd = 0
                    self.robot.msg.motor_cmd[motor_idx].kp = 40.0
                    self.robot.msg.motor_cmd[motor_idx].kd = 2.0
                    self.robot.msg.motor_cmd[motor_idx].tau = 0

            # Stamp the CRC and publish this step's command.
            self.robot.msg.crc = self.robot.crc.Crc(self.robot.msg)
            self.robot.lowcmd_publisher.Write(self.robot.msg)
            time.sleep(self.robot.control_dt)

        logger.info(f"Reached default position ({self.num_dof} joints)")

        # Hold for 2 seconds
        logger.info("Holding default position for 2 seconds...")
        hold_steps = int(2.0 / self.robot.control_dt)
        for _ in range(hold_steps):
            for i in range(self.num_dof):
                motor_idx = self.motor_map[i]
                self.robot.msg.motor_cmd[motor_idx].q = self.default_angles[i]
                self.robot.msg.motor_cmd[motor_idx].qd = 0
                self.robot.msg.motor_cmd[motor_idx].kp = self.kp[i]
                self.robot.msg.motor_cmd[motor_idx].kd = self.kd[i]
                self.robot.msg.motor_cmd[motor_idx].tau = 0

            if self.is_23dof:
                for motor_idx in [13, 14, 20, 21, 27, 28]:
                    self.robot.msg.motor_cmd[motor_idx].q = 0.0
                    self.robot.msg.motor_cmd[motor_idx].qd = 0
                    self.robot.msg.motor_cmd[motor_idx].kp = 40.0
                    self.robot.msg.motor_cmd[motor_idx].kd = 2.0
                    self.robot.msg.motor_cmd[motor_idx].tau = 0

            self.robot.msg.crc = self.robot.crc.Crc(self.robot.msg)
            self.robot.lowcmd_publisher.Write(self.robot.msg)
            time.sleep(self.robot.control_dt)

        logger.info("Ready to start locomotion!")
|
||||
|
||||
|
||||
if __name__ == "__main__":
    # Command-line options.
    arg_parser = argparse.ArgumentParser(description="Holosoma Locomotion Controller for Unitree G1")
    arg_parser.add_argument("--repo-id", type=str, default=DEFAULT_HOLOSOMA_REPO_ID)
    arg_parser.add_argument("--policy", type=str, default="fastsac", choices=["fastsac", "ppo"])
    arg_parser.add_argument("--local-path", type=str, default=None, help="Path to local ONNX file")
    cli_args = arg_parser.parse_args()

    # Load the ONNX policy and auto-detect 23- vs 29-DOF from its input shape.
    policy, obs_dim = load_holosoma_policy(
        repo_id=cli_args.repo_id,
        policy_name=cli_args.policy,
        local_path=cli_args.local_path,
    )

    # Bring up the G1 robot interface.
    config = UnitreeG1Config()
    robot = UnitreeG1(config)

    # Build the controller with the detected observation dimension.
    controller = HolosomaLocomotionController(
        policy=policy,
        robot=robot,
        config=config,
        obs_dim=obs_dim,
    )

    try:
        # Ramp to the default pose, then run the policy in the background.
        controller.reset_robot()
        controller.start_locomotion_thread()

        logger.info(f"Robot initialized with Holosoma {'23-DOF' if obs_dim == 82 else '29-DOF'} policy")
        logger.info("Use remote controller: LY=fwd/back, LX=left/right, RX=rotate")
        logger.info("Press Ctrl+C to stop")

        # Keep the main thread alive; the daemon thread does the work.
        while True:
            time.sleep(1.0)
    except KeyboardInterrupt:
        print("\nStopping locomotion...")
        controller.stop_locomotion_thread()
        print("Done!")
|
||||
@@ -1,447 +0,0 @@
|
||||
#!/usr/bin/env python
|
||||
|
||||
# Copyright 2025 The HuggingFace Inc. team. All rights reserved.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
"""
|
||||
Example: Unitree RL 12-DOF Legs-Only Locomotion (TorchScript)
|
||||
|
||||
This example demonstrates loading a 12-DOF legs-only locomotion policy
|
||||
(TorchScript .pt format) and running it on the Unitree G1 robot.
|
||||
|
||||
Key characteristics:
|
||||
- Single TorchScript policy (.pt)
|
||||
- 47D observations, 12D actions (legs only)
|
||||
- Phase-based gait timing
|
||||
- Arms and waist held at fixed positions
|
||||
"""
|
||||
|
||||
import argparse
|
||||
import logging
|
||||
import threading
|
||||
import time
|
||||
|
||||
import numpy as np
|
||||
import torch
|
||||
from huggingface_hub import hf_hub_download
|
||||
from scipy.spatial.transform import Rotation as R
|
||||
|
||||
from lerobot.robots.unitree_g1.config_unitree_g1 import UnitreeG1Config
|
||||
from lerobot.robots.unitree_g1.unitree_g1 import UnitreeG1
|
||||
|
||||
logging.basicConfig(level=logging.INFO)
logger = logging.getLogger(__name__)

# 12-DOF leg joint configuration
# Joint order: [L_hip_pitch, L_hip_roll, L_hip_yaw, L_knee, L_ankle_pitch, L_ankle_roll,
#               R_hip_pitch, R_hip_roll, R_hip_yaw, R_knee, R_ankle_pitch, R_ankle_roll]
LEG_JOINT_INDICES = [0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11]

# Default leg angles for standing
DEFAULT_LEG_ANGLES = np.array([
    -0.1, 0.0, 0.0, 0.3, -0.2, 0.0,  # left leg
    -0.1, 0.0, 0.0, 0.3, -0.2, 0.0,  # right leg
], dtype=np.float32)

# KP/KD for leg joints
LEG_KPS = np.array([150, 150, 150, 300, 40, 40, 150, 150, 150, 300, 40, 40], dtype=np.float32)
LEG_KDS = np.array([6, 6, 6, 4, 2, 2, 6, 6, 6, 4, 2, 2], dtype=np.float32)

# Waist configuration (held at zero)
WAIST_JOINT_INDICES = [12, 13, 14]  # yaw, roll, pitch
WAIST_KPS = np.array([250, 250, 250], dtype=np.float32)
WAIST_KDS = np.array([5, 5, 5], dtype=np.float32)

# Arm configuration (indices 15-28, held at initial position)
ARM_JOINT_INDICES = list(range(15, 29))
ARM_KPS = np.array([80, 80, 80, 80, 40, 40, 40,  # left arm (shoulder + wrist)
                    80, 80, 80, 80, 40, 40, 40], dtype=np.float32)  # right arm
ARM_KDS = np.array([3, 3, 3, 3, 1.5, 1.5, 1.5,
                    3, 3, 3, 3, 1.5, 1.5, 1.5], dtype=np.float32)

# Control parameters
LOCOMOTION_CONTROL_DT = 0.02  # 50Hz control rate
LOCOMOTION_ACTION_SCALE = 0.25  # policy output → joint-angle delta (rad)
ANG_VEL_SCALE = 0.25  # gyro scaling used when building observations
DOF_POS_SCALE = 1.0  # joint-position scaling used when building observations
DOF_VEL_SCALE = 0.05  # joint-velocity scaling used when building observations
CMD_SCALE = np.array([2.0, 2.0, 0.25], dtype=np.float32)
MAX_CMD = np.array([0.8, 0.5, 1.57], dtype=np.float32)  # max vx, vy, yaw_rate

# Gait parameters
GAIT_PERIOD = 0.8  # seconds

DEFAULT_REPO_ID = "nepyope/unitree_rl_locomotion"
|
||||
|
||||
|
||||
def load_torchscript_policy(
    repo_id: str = DEFAULT_REPO_ID,
    filename: str = "motion.pt",
) -> torch.jit.ScriptModule:
    """Load TorchScript locomotion policy from Hugging Face Hub.

    Args:
        repo_id: Hugging Face Hub repository ID containing the policy.
        filename: Policy filename (default: motion.pt).

    Returns:
        The loaded TorchScript module, switched to eval mode.
    """
    # Fix: the log previously printed the literal "(unknown)" instead of the
    # filename actually being downloaded.
    logger.info(f"Loading TorchScript policy from Hugging Face Hub ({repo_id}/{filename})...")

    policy_path = hf_hub_download(
        repo_id=repo_id,
        filename=filename,
    )

    policy = torch.jit.load(policy_path)
    # Inference mode: disables dropout/batch-norm training behavior.
    policy.eval()

    logger.info("TorchScript policy loaded successfully")

    return policy
|
||||
|
||||
|
||||
class UnitreeRLLocomotionController:
|
||||
"""
|
||||
Handles 12-DOF legs-only locomotion control for the Unitree G1 robot.
|
||||
|
||||
This controller manages:
|
||||
- Single TorchScript policy
|
||||
- 47D observations (single frame)
|
||||
- 12D action output (legs only)
|
||||
- Arms and waist held at fixed positions
|
||||
- Phase-based gait timing
|
||||
"""
|
||||
|
||||
    def __init__(self, policy, robot, config):
        """Set up controller state.

        Args:
            policy: TorchScript locomotion policy module.
            robot: UnitreeG1 robot interface.
            config: UnitreeG1Config instance (kept for reference; not read here).
        """
        self.policy = policy
        self.robot = robot
        self.config = config

        # Velocity commands (vx, vy, yaw_rate)
        self.locomotion_cmd = np.array([0.0, 0.0, 0.0], dtype=np.float32)

        # State variables (12 DOF legs)
        self.qj = np.zeros(12, dtype=np.float32)  # leg joint positions
        self.dqj = np.zeros(12, dtype=np.float32)  # leg joint velocities
        self.locomotion_action = np.zeros(12, dtype=np.float32)  # last policy action
        self.locomotion_obs = np.zeros(47, dtype=np.float32)  # reused observation buffer

        # Initial arm positions (captured on reset)
        self.initial_arm_positions = np.zeros(14, dtype=np.float32)

        # Counter for phase calculation
        self.counter = 0

        # Thread management
        self.locomotion_running = False  # flag polled by the worker thread
        self.locomotion_thread = None

        logger.info("UnitreeRLLocomotionController initialized")
        logger.info("  Observation dim: 47, Action dim: 12 (legs only)")
|
||||
|
||||
def locomotion_run(self):
|
||||
"""12-DOF legs-only locomotion policy loop."""
|
||||
self.counter += 1
|
||||
|
||||
if self.counter == 1:
|
||||
print("\n" + "=" * 60)
|
||||
print("🚀 RUNNING UNITREE RL 12-DOF LOCOMOTION POLICY")
|
||||
print(" 47D observations → 12D actions (legs only)")
|
||||
print(" Arms and waist held at fixed positions")
|
||||
print("=" * 60 + "\n")
|
||||
|
||||
# Get current observation
|
||||
robot_state = self.robot.get_observation()
|
||||
if robot_state is None:
|
||||
return
|
||||
|
||||
# Get command from remote controller
|
||||
if robot_state.wireless_remote is not None:
|
||||
self.robot.remote_controller.set(robot_state.wireless_remote)
|
||||
else:
|
||||
self.robot.remote_controller.lx = 0.0
|
||||
self.robot.remote_controller.ly = 0.0
|
||||
self.robot.remote_controller.rx = 0.0
|
||||
self.robot.remote_controller.ry = 0.0
|
||||
|
||||
self.locomotion_cmd[0] = self.robot.remote_controller.ly # forward/backward
|
||||
self.locomotion_cmd[1] = self.robot.remote_controller.lx * -1 # left/right (inverted)
|
||||
self.locomotion_cmd[2] = self.robot.remote_controller.rx * -1 # yaw (inverted)
|
||||
|
||||
# Get leg joint positions and velocities (12 DOF)
|
||||
for i, motor_idx in enumerate(LEG_JOINT_INDICES):
|
||||
self.qj[i] = robot_state.motor_state[motor_idx].q
|
||||
self.dqj[i] = robot_state.motor_state[motor_idx].dq
|
||||
|
||||
# Get IMU data
|
||||
quat = robot_state.imu_state.quaternion
|
||||
ang_vel = np.array(robot_state.imu_state.gyroscope, dtype=np.float32)
|
||||
|
||||
# Scale observations
|
||||
gravity_orientation = self.robot.get_gravity_orientation(quat)
|
||||
qj_obs = (self.qj - DEFAULT_LEG_ANGLES) * DOF_POS_SCALE
|
||||
dqj_obs = self.dqj * DOF_VEL_SCALE
|
||||
ang_vel_scaled = ang_vel * ANG_VEL_SCALE
|
||||
|
||||
# Calculate phase
|
||||
count = self.counter * LOCOMOTION_CONTROL_DT
|
||||
phase = (count % GAIT_PERIOD) / GAIT_PERIOD
|
||||
sin_phase = np.sin(2 * np.pi * phase)
|
||||
cos_phase = np.cos(2 * np.pi * phase)
|
||||
|
||||
# Build 47D observation vector
|
||||
# [0:3] - angular velocity (scaled)
|
||||
# [3:6] - gravity orientation
|
||||
# [6:9] - velocity command (scaled)
|
||||
# [9:21] - joint positions (12D, relative to default)
|
||||
# [21:33] - joint velocities (12D, scaled)
|
||||
# [33:45] - previous actions (12D)
|
||||
# [45] - sin_phase
|
||||
# [46] - cos_phase
|
||||
self.locomotion_obs[0:3] = ang_vel_scaled
|
||||
self.locomotion_obs[3:6] = gravity_orientation
|
||||
self.locomotion_obs[6:9] = self.locomotion_cmd * CMD_SCALE * MAX_CMD
|
||||
self.locomotion_obs[9:21] = qj_obs
|
||||
self.locomotion_obs[21:33] = dqj_obs
|
||||
self.locomotion_obs[33:45] = self.locomotion_action
|
||||
self.locomotion_obs[45] = sin_phase
|
||||
self.locomotion_obs[46] = cos_phase
|
||||
|
||||
# Run policy inference (TorchScript)
|
||||
obs_tensor = torch.from_numpy(self.locomotion_obs).unsqueeze(0).float()
|
||||
with torch.no_grad():
|
||||
action_tensor = self.policy(obs_tensor)
|
||||
self.locomotion_action = action_tensor.squeeze().numpy()
|
||||
|
||||
# Transform action to target joint positions
|
||||
target_leg_pos = DEFAULT_LEG_ANGLES + self.locomotion_action * LOCOMOTION_ACTION_SCALE
|
||||
|
||||
# Debug logging (first 3 iterations)
|
||||
if self.counter <= 3:
|
||||
print(f"\n[Unitree RL Debug #{self.counter}]")
|
||||
print(f" Phase: {phase:.3f} (sin={sin_phase:.3f}, cos={cos_phase:.3f})")
|
||||
print(f" Cmd (vx, vy, yaw): ({self.locomotion_cmd[0]:.2f}, {self.locomotion_cmd[1]:.2f}, {self.locomotion_cmd[2]:.2f})")
|
||||
print(f" Action range: [{self.locomotion_action.min():.3f}, {self.locomotion_action.max():.3f}]")
|
||||
|
||||
# Send commands to LEG motors (0-11)
|
||||
for i, motor_idx in enumerate(LEG_JOINT_INDICES):
|
||||
self.robot.msg.motor_cmd[motor_idx].q = target_leg_pos[i]
|
||||
self.robot.msg.motor_cmd[motor_idx].qd = 0
|
||||
self.robot.msg.motor_cmd[motor_idx].kp = LEG_KPS[i]
|
||||
self.robot.msg.motor_cmd[motor_idx].kd = LEG_KDS[i]
|
||||
self.robot.msg.motor_cmd[motor_idx].tau = 0
|
||||
|
||||
# Hold WAIST motors at zero (12, 13, 14)
|
||||
for i, motor_idx in enumerate(WAIST_JOINT_INDICES):
|
||||
self.robot.msg.motor_cmd[motor_idx].q = 0.0
|
||||
self.robot.msg.motor_cmd[motor_idx].qd = 0
|
||||
self.robot.msg.motor_cmd[motor_idx].kp = WAIST_KPS[i]
|
||||
self.robot.msg.motor_cmd[motor_idx].kd = WAIST_KDS[i]
|
||||
self.robot.msg.motor_cmd[motor_idx].tau = 0
|
||||
|
||||
# Hold ARM motors at initial position (15-28)
|
||||
for i, motor_idx in enumerate(ARM_JOINT_INDICES):
|
||||
self.robot.msg.motor_cmd[motor_idx].q = self.initial_arm_positions[i]
|
||||
self.robot.msg.motor_cmd[motor_idx].qd = 0
|
||||
self.robot.msg.motor_cmd[motor_idx].kp = ARM_KPS[i]
|
||||
self.robot.msg.motor_cmd[motor_idx].kd = ARM_KDS[i]
|
||||
self.robot.msg.motor_cmd[motor_idx].tau = 0
|
||||
|
||||
# Send command
|
||||
self.robot.send_action(self.robot.msg)
|
||||
|
||||
def _locomotion_thread_loop(self):
|
||||
"""Background thread that runs the locomotion policy at specified rate."""
|
||||
logger.info("Locomotion thread started")
|
||||
while self.locomotion_running:
|
||||
start_time = time.time()
|
||||
try:
|
||||
self.locomotion_run()
|
||||
except Exception as e:
|
||||
logger.error(f"Error in locomotion loop: {e}")
|
||||
import traceback
|
||||
traceback.print_exc()
|
||||
|
||||
# Sleep to maintain control rate
|
||||
elapsed = time.time() - start_time
|
||||
sleep_time = max(0, LOCOMOTION_CONTROL_DT - elapsed)
|
||||
time.sleep(sleep_time)
|
||||
logger.info("Locomotion thread stopped")
|
||||
|
||||
def start_locomotion_thread(self):
|
||||
if self.locomotion_running:
|
||||
logger.warning("Locomotion thread already running")
|
||||
return
|
||||
|
||||
logger.info("Starting locomotion control thread...")
|
||||
self.locomotion_running = True
|
||||
self.locomotion_thread = threading.Thread(target=self._locomotion_thread_loop, daemon=True)
|
||||
self.locomotion_thread.start()
|
||||
|
||||
logger.info("Locomotion control thread started!")
|
||||
|
||||
def stop_locomotion_thread(self):
|
||||
if not self.locomotion_running:
|
||||
return
|
||||
|
||||
logger.info("Stopping locomotion control thread...")
|
||||
self.locomotion_running = False
|
||||
if self.locomotion_thread:
|
||||
self.locomotion_thread.join(timeout=2.0)
|
||||
logger.info("Locomotion control thread stopped")
|
||||
|
||||
def reset_robot(self):
|
||||
"""Move legs to default standing position over 2 seconds (arms are captured and held)."""
|
||||
logger.info("Moving legs to default position...")
|
||||
|
||||
total_time = 2.0
|
||||
num_step = int(total_time / self.robot.control_dt)
|
||||
|
||||
# Get current state
|
||||
robot_state = self.robot.get_observation()
|
||||
|
||||
# Capture initial arm positions (to hold during locomotion)
|
||||
for i, motor_idx in enumerate(ARM_JOINT_INDICES):
|
||||
self.initial_arm_positions[i] = robot_state.motor_state[motor_idx].q
|
||||
logger.info(f"Captured initial arm positions: {self.initial_arm_positions[:4]}...")
|
||||
|
||||
# Record current leg positions
|
||||
init_leg_pos = np.zeros(12, dtype=np.float32)
|
||||
for i, motor_idx in enumerate(LEG_JOINT_INDICES):
|
||||
init_leg_pos[i] = robot_state.motor_state[motor_idx].q
|
||||
|
||||
# Interpolate legs to default position
|
||||
for step in range(num_step):
|
||||
alpha = step / num_step
|
||||
|
||||
# Interpolate leg positions
|
||||
for i, motor_idx in enumerate(LEG_JOINT_INDICES):
|
||||
target_pos = DEFAULT_LEG_ANGLES[i]
|
||||
self.robot.msg.motor_cmd[motor_idx].q = (
|
||||
init_leg_pos[i] * (1 - alpha) + target_pos * alpha
|
||||
)
|
||||
self.robot.msg.motor_cmd[motor_idx].qd = 0
|
||||
self.robot.msg.motor_cmd[motor_idx].kp = LEG_KPS[i]
|
||||
self.robot.msg.motor_cmd[motor_idx].kd = LEG_KDS[i]
|
||||
self.robot.msg.motor_cmd[motor_idx].tau = 0
|
||||
|
||||
# Hold waist at zero
|
||||
for i, motor_idx in enumerate(WAIST_JOINT_INDICES):
|
||||
self.robot.msg.motor_cmd[motor_idx].q = 0.0
|
||||
self.robot.msg.motor_cmd[motor_idx].qd = 0
|
||||
self.robot.msg.motor_cmd[motor_idx].kp = WAIST_KPS[i]
|
||||
self.robot.msg.motor_cmd[motor_idx].kd = WAIST_KDS[i]
|
||||
self.robot.msg.motor_cmd[motor_idx].tau = 0
|
||||
|
||||
# Hold arms at initial position
|
||||
for i, motor_idx in enumerate(ARM_JOINT_INDICES):
|
||||
self.robot.msg.motor_cmd[motor_idx].q = self.initial_arm_positions[i]
|
||||
self.robot.msg.motor_cmd[motor_idx].qd = 0
|
||||
self.robot.msg.motor_cmd[motor_idx].kp = ARM_KPS[i]
|
||||
self.robot.msg.motor_cmd[motor_idx].kd = ARM_KDS[i]
|
||||
self.robot.msg.motor_cmd[motor_idx].tau = 0
|
||||
|
||||
self.robot.msg.crc = self.robot.crc.Crc(self.robot.msg)
|
||||
self.robot.lowcmd_publisher.Write(self.robot.msg)
|
||||
time.sleep(self.robot.control_dt)
|
||||
|
||||
logger.info("Reached default leg position")
|
||||
|
||||
# Hold position for 2 seconds
|
||||
logger.info("Holding default position for 2 seconds...")
|
||||
hold_time = 2.0
|
||||
num_hold_steps = int(hold_time / self.robot.control_dt)
|
||||
|
||||
for _ in range(num_hold_steps):
|
||||
# Hold legs at default
|
||||
for i, motor_idx in enumerate(LEG_JOINT_INDICES):
|
||||
self.robot.msg.motor_cmd[motor_idx].q = DEFAULT_LEG_ANGLES[i]
|
||||
self.robot.msg.motor_cmd[motor_idx].qd = 0
|
||||
self.robot.msg.motor_cmd[motor_idx].kp = LEG_KPS[i]
|
||||
self.robot.msg.motor_cmd[motor_idx].kd = LEG_KDS[i]
|
||||
self.robot.msg.motor_cmd[motor_idx].tau = 0
|
||||
|
||||
# Hold waist at zero
|
||||
for i, motor_idx in enumerate(WAIST_JOINT_INDICES):
|
||||
self.robot.msg.motor_cmd[motor_idx].q = 0.0
|
||||
self.robot.msg.motor_cmd[motor_idx].qd = 0
|
||||
self.robot.msg.motor_cmd[motor_idx].kp = WAIST_KPS[i]
|
||||
self.robot.msg.motor_cmd[motor_idx].kd = WAIST_KDS[i]
|
||||
self.robot.msg.motor_cmd[motor_idx].tau = 0
|
||||
|
||||
# Hold arms at initial position
|
||||
for i, motor_idx in enumerate(ARM_JOINT_INDICES):
|
||||
self.robot.msg.motor_cmd[motor_idx].q = self.initial_arm_positions[i]
|
||||
self.robot.msg.motor_cmd[motor_idx].qd = 0
|
||||
self.robot.msg.motor_cmd[motor_idx].kp = ARM_KPS[i]
|
||||
self.robot.msg.motor_cmd[motor_idx].kd = ARM_KDS[i]
|
||||
self.robot.msg.motor_cmd[motor_idx].tau = 0
|
||||
|
||||
self.robot.msg.crc = self.robot.crc.Crc(self.robot.msg)
|
||||
self.robot.lowcmd_publisher.Write(self.robot.msg)
|
||||
time.sleep(self.robot.control_dt)
|
||||
|
||||
logger.info("Ready to start locomotion!")
|
||||
|
||||
|
||||
if __name__ == "__main__":
    # CLI: choose which Hub repo/file the TorchScript policy is loaded from.
    parser = argparse.ArgumentParser(description="Unitree RL 12-DOF Locomotion Controller for Unitree G1")
    parser.add_argument(
        "--repo-id",
        type=str,
        default=DEFAULT_REPO_ID,
        help=f"Hugging Face Hub repo ID for policy (default: {DEFAULT_REPO_ID})",
    )
    parser.add_argument(
        "--filename",
        type=str,
        default="motion.pt",
        help="Policy filename (default: motion.pt)",
    )
    args = parser.parse_args()

    # Load policy
    policy = load_torchscript_policy(repo_id=args.repo_id, filename=args.filename)

    # Initialize robot
    config = UnitreeG1Config()
    robot = UnitreeG1(config)

    # Initialize locomotion controller
    locomotion_controller = UnitreeRLLocomotionController(
        policy=policy,
        robot=robot,
        config=config,
    )

    # Reset robot and start locomotion thread
    try:
        locomotion_controller.reset_robot()
        locomotion_controller.start_locomotion_thread()

        # Log status
        logger.info("Robot initialized with Unitree RL locomotion policy")
        logger.info("Locomotion controller running in background thread")
        logger.info("Use remote controller to command velocity:")
        logger.info(" Left stick Y: forward/backward")
        logger.info(" Left stick X: left/right")
        logger.info(" Right stick X: rotate")
        logger.info("Press Ctrl+C to stop")

        # Keep robot alive: the control loop runs in the daemon thread started above.
        while True:
            time.sleep(1.0)
    except KeyboardInterrupt:
        # Ctrl+C: stop the control thread cleanly before exiting.
        print("\nStopping locomotion...")
        locomotion_controller.stop_locomotion_thread()
        print("Done!")
|
||||
|
||||
+1
-5
@@ -25,7 +25,7 @@ discord = "https://discord.gg/s3KuuzsPFb"
|
||||
|
||||
[project]
|
||||
name = "lerobot"
|
||||
version = "0.4.3"
|
||||
version = "0.4.2"
|
||||
description = "🤗 LeRobot: State-of-the-art Machine Learning for Real-World Robotics in Pytorch"
|
||||
readme = "README.md"
|
||||
license = { text = "Apache-2.0" }
|
||||
@@ -107,10 +107,6 @@ dynamixel = ["dynamixel-sdk>=3.7.31,<3.9.0"]
|
||||
gamepad = ["lerobot[pygame-dep]", "hidapi>=0.14.0,<0.15.0"]
|
||||
hopejr = ["lerobot[feetech]", "lerobot[pygame-dep]"]
|
||||
lekiwi = ["lerobot[feetech]", "pyzmq>=26.2.1,<28.0.0"]
|
||||
unitree_g1 = [
|
||||
"pyzmq>=26.2.1,<28.0.0",
|
||||
"onnxruntime>=1.16.0"
|
||||
]
|
||||
reachy2 = ["reachy2_sdk>=1.0.14,<1.1.0"]
|
||||
kinematics = ["lerobot[placo-dep]"]
|
||||
intelrealsense = [
|
||||
|
||||
@@ -0,0 +1,761 @@
|
||||
#!/usr/bin/env python
|
||||
|
||||
# Copyright 2025 The HuggingFace Inc. team. All rights reserved.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
"""
|
||||
Inference script for SARM (Stage-Aware Reward Model).
|
||||
|
||||
This script loads a trained SARM model and runs inference on a dataset episode,
|
||||
generating visualizations of the predicted task stages and progress over time.
|
||||
|
||||
Example usage:
|
||||
python scripts/visualize_sarm_predictions.py \
|
||||
--model-id username/sarm-model \
|
||||
--dataset-repo lerobot/aloha_sim_insertion_human \
|
||||
--episode-index 0 \
|
||||
--output-dir outputs/sarm_viz \
|
||||
--task-description "insert the peg into the socket"
|
||||
"""
|
||||
|
||||
import argparse
|
||||
import json
|
||||
import logging
|
||||
from pathlib import Path
|
||||
from typing import Optional
|
||||
|
||||
import matplotlib.pyplot as plt
|
||||
import matplotlib.gridspec as gridspec
|
||||
import matplotlib.patches as mpatches
|
||||
import numpy as np
|
||||
import pandas as pd
|
||||
import torch
|
||||
from tqdm import tqdm
|
||||
|
||||
from lerobot.datasets.lerobot_dataset import LeRobotDataset
|
||||
from lerobot.policies.sarm.modeling_sarm import SARMRewardModel
|
||||
from lerobot.policies.sarm.sarm_utils import (
|
||||
pad_state_to_max_dim,
|
||||
compute_tau,
|
||||
compute_cumulative_progress_batch,
|
||||
)
|
||||
from lerobot.datasets.utils import load_stats
|
||||
|
||||
|
||||
# Module-level logging: INFO level, logger named after this module.
logging.basicConfig(level=logging.INFO)
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
def parse_args():
    """Build and parse the CLI for SARM inference/visualization.

    Returns:
        argparse.Namespace with model, dataset, output, visualization,
        and device options.
    """
    parser = argparse.ArgumentParser(description="Run SARM inference and visualize predictions")
    arg = parser.add_argument  # local alias: every option below goes through it

    # Model arguments
    arg("--model-id", type=str, required=True,
        help="HuggingFace model ID or local path to trained SARM model")

    # Dataset arguments
    arg("--dataset-repo", type=str, required=True,
        help="HuggingFace dataset repository ID (e.g., lerobot/aloha_sim_insertion_human)")
    arg("--episode-index", type=int, default=0,
        help="Index of the episode to visualize (default: 0)")
    arg("--task-description", type=str, default="perform the task",
        help="Task description for the reward model (default: 'perform the task')")

    # Output arguments
    arg("--output-dir", type=str, default="outputs/sarm_inference",
        help="Directory to save visualization outputs (default: outputs/sarm_inference)")
    arg("--image-key", type=str, default=None,
        help="Key for images in dataset (e.g., observation.images.image). If not specified, uses model config's image_key")
    arg("--state-key", type=str, default=None,
        help="Key for joint states in dataset. If None, auto-detects from dataset")

    # Visualization options
    arg("--show-frames", action="store_true",
        help="Include sample frames in the visualization")
    arg("--num-sample-frames", type=int, default=8,
        help="Number of sample frames to show (default: 8)")
    arg("--figsize", type=int, nargs=2, default=[14, 8],
        help="Figure size as width height (default: 14 8)")

    # Device
    arg("--device", type=str, default=None,
        help="Device to run inference on (cuda/cpu, default: auto-detect)")

    return parser.parse_args()
|
||||
|
||||
|
||||
def load_episode_data(
    dataset: LeRobotDataset,
    episode_index: int,
    image_key: str,
    state_key: str | None = None
) -> tuple[np.ndarray, np.ndarray, int, int, str]:
    """
    Load every frame and state of one episode into numpy arrays.

    Args:
        dataset: LeRobotDataset instance
        episode_index: Index of the episode to load
        image_key: Key for accessing images in the dataset
        state_key: Key for accessing joint states (auto-detected if None)

    Returns:
        Tuple of (frames, states, start_index, end_index, task_description)
    """
    # Episode boundaries come from the dataset metadata table.
    episodes_meta = dataset.meta.episodes
    start_idx = episodes_meta["dataset_from_index"][episode_index]
    end_idx = episodes_meta["dataset_to_index"][episode_index]

    logger.info(f"Loading episode {episode_index}: frames {start_idx} to {end_idx} ({end_idx - start_idx} frames)")

    first_item = dataset[start_idx]

    # Auto-detect a state key from the first item when none was supplied.
    if state_key is None:
        candidates = [k for k in first_item.keys() if 'state' in k.lower() or 'qpos' in k.lower()]
        if candidates:
            state_key = candidates[0]
            logger.info(f"Auto-detected state key: {state_key}")

    # Pull the task description off the first frame when the dataset carries one.
    task_description = None
    if "task" in first_item:
        task_description = first_item["task"]
        logger.info(f"✓ Extracted task from episode {episode_index}: '{task_description}'")

    frames = []
    states = []
    for frame_idx in tqdm(range(start_idx, end_idx), desc="Loading frames"):
        item = dataset[frame_idx]

        image = item[image_key]
        if isinstance(image, torch.Tensor):
            image = image.cpu().numpy()

        # Channel-first (C, H, W) -> channel-last (H, W, C).
        if image.shape[0] in [1, 3]:
            image = np.transpose(image, (1, 2, 0))

        # Normalize dtype to uint8, rescaling floats that look like [0, 1].
        if image.dtype != np.uint8:
            image = (image * 255).astype(np.uint8) if image.max() <= 1.0 else image.astype(np.uint8)

        frames.append(image)

        # Collect the state vector when the key exists in this item.
        if state_key and state_key in item:
            state = item[state_key]
            if isinstance(state, torch.Tensor):
                state = state.cpu().numpy()
            states.append(state)

    frames = np.array(frames)
    states = np.array(states) if states else None
    logger.info(f"Loaded {len(frames)} frames with shape {frames[0].shape}")
    if states is not None:
        logger.info(f"Loaded states with shape {states.shape}")

    return frames, states, start_idx, end_idx, task_description
|
||||
|
||||
|
||||
@torch.no_grad()
def run_inference(
    model: SARMRewardModel,
    frames: np.ndarray,
    states: Optional[np.ndarray],
    task_description: str,
    dataset_stats: dict | None = None,
    state_key: str = "observation.state",
    batch_size: int = 32
) -> tuple[np.ndarray, np.ndarray]:
    """
    Run SARM inference on video frames and joint states.

    For every frame t, a temporal slice of embeddings is assembled from the
    offsets in ``model.config.observation_delta_indices`` (clamped to episode
    boundaries). The in-loop comment below describes a symmetric bidirectional
    pattern [initial, t-4*gap, t-3*gap, t-2*gap, t-gap, t, t+gap, t+2*gap,
    t+3*gap]; the SARM paper (Section A.4) describes an initial-frame + 8
    trailing-frames pattern — the actual layout is whatever the config deltas
    encode (TODO confirm against the trained model's config).

    Args:
        model: SARM model
        frames: Video frames (num_frames, H, W, C) - all frames from ONE episode
        states: Joint states (num_frames, state_dim)
        task_description: Task description text
        dataset_stats: Dataset statistics for state normalization (same as training)
        state_key: Key for state in dataset_stats
        batch_size: Batch size for processing slices

    Returns:
        Tuple of (progress_predictions, stage_predictions)
        - progress_predictions: (num_frames,)
        - stage_predictions: (num_frames, num_stages)
    """
    logger.info("Encoding video frames with CLIP...")
    video_embeddings = model.encode_images(frames)

    logger.info("Encoding task description with CLIP...")
    text_embedding = model.encode_text(task_description)

    # Get config values
    # NOTE(review): num_frames_model and frame_gap are assigned but never read
    # below — slice layout comes from observation_delta_indices instead.
    num_frames_model = model.config.num_frames  # 9
    frame_gap = model.config.frame_gap  # 30

    logger.info("Creating video slices (SARM paper: initial frame + 8 consecutive)...")

    # Convert to tensors
    video_embeddings = torch.tensor(video_embeddings, dtype=torch.float32)
    text_embedding = torch.tensor(text_embedding, dtype=torch.float32)
    if states is not None:
        state_embeddings = torch.tensor(states, dtype=torch.float32)

        # Normalize states using dataset stats (same as training processor)
        if dataset_stats is not None and state_key in dataset_stats:
            mean = torch.tensor(dataset_stats[state_key]["mean"], dtype=torch.float32)
            std = torch.tensor(dataset_stats[state_key]["std"], dtype=torch.float32)
            state_embeddings = (state_embeddings - mean) / (std + 1e-8)
            logger.info(f"✓ Applied MEAN_STD normalization to states using {state_key}")
        else:
            logger.warning("⚠ No dataset_stats provided - states not normalized (may differ from training)")
    else:
        state_embeddings = None

    video_slices = []
    state_slices = []

    for current_frame in tqdm(range(len(video_embeddings)), desc="Creating slices"):
        # Compute frame indices using symmetric bidirectional pattern:
        # [initial (0), t-4*gap, t-3*gap, t-2*gap, t-gap, t, t+gap, t+2*gap, t+3*gap]
        # Boundary handling: clamp to [0, last_valid]
        deltas = model.config.observation_delta_indices
        last_valid = len(video_embeddings) - 1

        frame_indices = []
        for delta in deltas:
            idx = current_frame + delta
            idx = max(0, min(idx, last_valid))  # Clamp to valid range
            frame_indices.append(idx)

        video_slice = video_embeddings[frame_indices]
        video_slices.append(video_slice)

        if state_embeddings is not None:
            state_slice = state_embeddings[frame_indices]
            state_slices.append(state_slice)

    video_slices = torch.stack(video_slices)  # (num_frames, num_frames_model, 512)
    if state_embeddings is not None:
        state_slices = torch.stack(state_slices)  # (num_frames, num_frames_model, state_dim)
        # Pad states to max_state_dim (same as training processor)
        state_slices = pad_state_to_max_dim(state_slices, model.config.max_state_dim)
    else:
        state_slices = None

    logger.info("Running SARM inference on all slices...")
    # Process in batches
    all_progress = []
    all_stages = []

    for i in tqdm(range(0, len(video_slices), batch_size), desc="Inference"):
        batch_video = video_slices[i:i + batch_size].to(model.device)
        batch_states = state_slices[i:i + batch_size].to(model.device) if state_slices is not None else None
        batch_size_actual = batch_video.shape[0]

        # Replicate text embedding for batch
        batch_text = text_embedding.unsqueeze(0).repeat(batch_size_actual, 1).to(model.device)

        # Get predictions
        stage_logits, stage_probs, progress_preds = model.sarm_transformer(
            batch_video, batch_text, batch_states
        )

        # Extract predictions at the "current frame" position
        # With symmetric pattern [initial, t-4g, t-3g, t-2g, t-g, t, t+g, t+2g, t+3g],
        # the current frame is at position 5 (0-indexed)
        current_frame_idx = 5
        batch_progress = progress_preds[:, current_frame_idx, 0].cpu().numpy()
        batch_stages = stage_probs[:, current_frame_idx, :].cpu().numpy()

        all_progress.extend(batch_progress)
        all_stages.extend(batch_stages)

    return np.array(all_progress), np.array(all_stages)
|
||||
|
||||
|
||||
def compute_ground_truth_progress(
    dataset: LeRobotDataset,
    episode_index: int,
    temporal_proportions: dict[str, float],
    subtask_names_ordered: list[str],
) -> tuple[np.ndarray, np.ndarray] | tuple[None, None]:
    """
    Compute ground truth progress and stage labels for an episode using annotations.

    Uses SARM Paper Formula (2):
        y_t = P_{k-1} + ᾱ_k × τ_t

    where:
        - τ_t = (t - s_k) / (e_k - s_k) is within-subtask progress
        - P_{k-1} is cumulative prior (sum of previous subtask proportions)
        - ᾱ_k is the temporal proportion for subtask k

    Args:
        dataset: LeRobotDataset instance
        episode_index: Index of the episode
        temporal_proportions: Dict mapping subtask name to proportion
        subtask_names_ordered: Ordered list of subtask names (for consistent stage indexing)

    Returns:
        Tuple of (ground_truth_progress, ground_truth_stages) arrays, or (None, None) if no annotations
    """
    # Load episode metadata
    episodes_df = dataset.meta.episodes.to_pandas()

    # Check if annotations exist
    if "subtask_names" not in episodes_df.columns:
        logger.warning("No subtask_names column found in episodes metadata")
        return None, None

    ep_subtask_names = episodes_df.loc[episode_index, "subtask_names"]
    # A missing annotation may surface as None or as a float NaN from pandas.
    if ep_subtask_names is None or (isinstance(ep_subtask_names, float) and pd.isna(ep_subtask_names)):
        logger.warning(f"No annotations found for episode {episode_index}")
        return None, None

    subtask_start_frames = episodes_df.loc[episode_index, "subtask_start_frames"]
    subtask_end_frames = episodes_df.loc[episode_index, "subtask_end_frames"]

    # Get episode boundaries
    ep_start = dataset.meta.episodes["dataset_from_index"][episode_index]
    ep_end = dataset.meta.episodes["dataset_to_index"][episode_index]
    num_frames = ep_end - ep_start

    # Get temporal proportions as ordered list (missing names default to 0.0)
    temporal_proportions_list = [
        temporal_proportions.get(name, 0.0) for name in subtask_names_ordered
    ]

    logger.info(f"Computing ground truth for {num_frames} frames using {len(ep_subtask_names)} annotated subtasks")
    logger.info(f"Subtask names in episode: {ep_subtask_names}")
    logger.info(f"Subtask start frames: {subtask_start_frames}")
    logger.info(f"Subtask end frames: {subtask_end_frames}")
    logger.info(f"Temporal proportions (ordered): {dict(zip(subtask_names_ordered, temporal_proportions_list))}")

    # Compute ground truth for each frame
    # NOTE(review): start/end frames appear to be episode-relative (compared
    # against frame_rel) — confirm against the annotation writer.
    gt_progress = np.zeros(num_frames)
    gt_stages = np.zeros(num_frames, dtype=np.int32)

    for frame_rel in range(num_frames):
        # Find which subtask this frame belongs to (inclusive on both ends)
        found = False
        for j, (name, start_frame, end_frame) in enumerate(zip(ep_subtask_names, subtask_start_frames, subtask_end_frames)):
            if frame_rel >= start_frame and frame_rel <= end_frame:
                # Found the subtask - get its global index
                stage_idx = subtask_names_ordered.index(name) if name in subtask_names_ordered else 0

                # Compute τ_t using utility function
                tau = compute_tau(frame_rel, start_frame, end_frame)

                # Compute cumulative progress using utility function
                progress = compute_cumulative_progress_batch(tau, stage_idx, temporal_proportions_list)

                gt_progress[frame_rel] = progress
                gt_stages[frame_rel] = stage_idx
                found = True
                break

        if not found:
            # Handle frames outside annotated subtasks
            if frame_rel < subtask_start_frames[0]:
                # Before the first subtask: progress is zero, first stage
                gt_progress[frame_rel] = 0.0
                gt_stages[frame_rel] = 0
            elif frame_rel > subtask_end_frames[-1]:
                # After the last subtask: task considered complete
                gt_progress[frame_rel] = 1.0
                gt_stages[frame_rel] = len(subtask_names_ordered) - 1
            else:
                # Between subtasks - find previous subtask and carry its completed progress
                for j in range(len(ep_subtask_names) - 1):
                    if frame_rel > subtask_end_frames[j] and frame_rel < subtask_start_frames[j + 1]:
                        name = ep_subtask_names[j]
                        stage_idx = subtask_names_ordered.index(name) if name in subtask_names_ordered else j
                        progress = compute_cumulative_progress_batch(1.0, stage_idx, temporal_proportions_list)
                        gt_progress[frame_rel] = progress
                        gt_stages[frame_rel] = stage_idx
                        break

    logger.info(f"✓ Ground truth computed: final={gt_progress[-1]:.3f}, max={gt_progress.max():.3f}")
    return gt_progress, gt_stages
|
||||
|
||||
|
||||
def visualize_predictions(
    frames: np.ndarray,
    progress_predictions: np.ndarray,
    stage_predictions: np.ndarray,
    task_description: str,
    output_path: Path,
    num_sample_frames: int = 8,
    figsize: tuple = (14, 8),
    subtask_names: list[str] | None = None,
    temporal_proportions: dict[str, float] | None = None,
    ground_truth_progress: np.ndarray | None = None,
    ground_truth_stages: np.ndarray | None = None,
):
    """
    Create visualization of SARM predictions with optional ground truth comparison.

    The figure has three stacked panels: predicted (and optionally ground-truth)
    progress over time, per-stage probabilities as a stacked area plot, and a
    strip of evenly spaced sample frames with per-frame progress/stage labels.
    The figure is written to ``output_path`` (parent directories are created)
    and closed before returning.

    Args:
        frames: Video frames (num_frames, H, W, C)
        progress_predictions: Progress predictions (num_frames,)
        stage_predictions: Stage probabilities (num_frames, num_stages)
        task_description: Task description shown in the title
        output_path: Path to save the figure
        num_sample_frames: Number of frames to show in the bottom strip
        figsize: Figure size (width, height); 4 inches are added to the height
            to make room for the sample-frame strip
        subtask_names: Optional list of subtask names for labeling (used only
            when its length matches the number of stages)
        temporal_proportions: Optional dict of temporal proportions for each subtask
        ground_truth_progress: Optional ground truth progress array (num_frames,)
        ground_truth_stages: Optional ground truth stage indices array (num_frames,)
    """
    num_stages = stage_predictions.shape[1]
    stage_colors = plt.cm.tab10(np.linspace(0, 1, num_stages))

    # Use subtask names if available, otherwise use generic labels
    if subtask_names is not None and len(subtask_names) == num_stages:
        stage_labels = subtask_names
    else:
        stage_labels = [f'Stage {i+1}' for i in range(num_stages)]

    # Create figure with progress plot, stage plot, and sample frames
    fig = plt.figure(figsize=(figsize[0], figsize[1] + 4))
    gs = gridspec.GridSpec(3, 1, height_ratios=[2, 1, 1], hspace=0.3)

    ax_progress = fig.add_subplot(gs[0])
    ax_stages = fig.add_subplot(gs[1], sharex=ax_progress)
    ax_frames = fig.add_subplot(gs[2])

    frame_indices = np.arange(len(progress_predictions))

    # Plot 1: Progress over time
    ax_progress.plot(frame_indices, progress_predictions, linewidth=2, color='#2E86AB', label='Predicted Progress')
    ax_progress.fill_between(frame_indices, 0, progress_predictions, alpha=0.3, color='#2E86AB')

    # Plot ground truth if available
    if ground_truth_progress is not None:
        ax_progress.plot(frame_indices, ground_truth_progress, linewidth=2, color='#28A745',
                         linestyle='--', label='Ground Truth Progress')
        ax_progress.fill_between(frame_indices, 0, ground_truth_progress, alpha=0.15, color='#28A745')

    ax_progress.axhline(y=1.0, color='gray', linestyle='--', alpha=0.5, linewidth=1)
    ax_progress.set_ylabel('Task Progress', fontsize=12)
    ax_progress.set_title(f'Task: "{task_description}"', fontsize=14, fontweight='bold')
    ax_progress.grid(True, alpha=0.3)
    ax_progress.set_ylim(-0.05, 1.1)
    ax_progress.legend(loc='upper left')

    # Add statistics box (anchored to the bottom-right of the progress axes)
    stats_text = (
        f'Frames: {len(progress_predictions)}\n'
        f'Final Progress: {progress_predictions[-1]:.3f}\n'
        f'Max Progress: {progress_predictions.max():.3f}\n'
        f'Mean Progress: {progress_predictions.mean():.3f}'
    )
    if ground_truth_progress is not None:
        mse = np.mean((progress_predictions - ground_truth_progress) ** 2)
        stats_text += f'\nMSE vs GT: {mse:.4f}'
        stats_text += f'\nGT Final: {ground_truth_progress[-1]:.3f}'

    ax_progress.text(0.98, 0.02, stats_text, transform=ax_progress.transAxes,
                     fontsize=10, verticalalignment='bottom', horizontalalignment='right',
                     bbox=dict(boxstyle='round', facecolor='wheat', alpha=0.5))

    # Plot 2: Stage predictions (stacked area plot)
    ax_stages.stackplot(frame_indices, *[stage_predictions[:, i] for i in range(num_stages)],
                        colors=stage_colors, alpha=0.8, labels=stage_labels)

    # Plot ground truth stage transitions as vertical lines plus colored markers
    if ground_truth_stages is not None:
        # Find stage transition points in ground truth
        stage_changes = np.where(np.diff(ground_truth_stages) != 0)[0] + 1
        for change_idx in stage_changes:
            ax_stages.axvline(x=change_idx, color='black', linestyle='-', alpha=0.7, linewidth=1.5)
            ax_progress.axvline(x=change_idx, color='black', linestyle='-', alpha=0.3, linewidth=1)

        # Add small markers at the bottom showing the GT stage every 30th frame
        ax_stages.scatter(frame_indices[::30], np.zeros(len(frame_indices[::30])) + 0.02,
                          c=[stage_colors[s] for s in ground_truth_stages[::30]],
                          s=20, marker='|', alpha=0.8, label='GT Stage Markers')

    ax_stages.set_xlabel('Frame Index', fontsize=12)
    ax_stages.set_ylabel('Stage Probability', fontsize=12)
    ax_stages.set_ylim(0, 1)
    ax_stages.grid(True, alpha=0.3)

    # Adjust legend based on number of stages and label lengths
    if num_stages <= 5:
        ax_stages.legend(loc='upper left', ncol=num_stages, fontsize=8)
    else:
        ax_stages.legend(loc='upper left', ncol=3, fontsize=7)

    # Add vertical lines for expected stage transitions (if temporal proportions available)
    if temporal_proportions is not None and subtask_names is not None:
        cumulative_progress = 0.0
        for i, name in enumerate(stage_labels):
            if name in temporal_proportions:
                # Find approximate frame where this stage should end
                stage_end_progress = cumulative_progress + temporal_proportions[name]

                # Find frame index closest to this progress
                progress_diffs = np.abs(progress_predictions - stage_end_progress)
                stage_end_frame = np.argmin(progress_diffs)

                # Draw vertical line
                ax_progress.axvline(x=stage_end_frame, color='gray', linestyle=':', alpha=0.5, linewidth=1)
                ax_stages.axvline(x=stage_end_frame, color='gray', linestyle=':', alpha=0.5, linewidth=1)

                cumulative_progress = stage_end_progress

    # Plot 3: Sample frames tiled horizontally into one image
    frame_indices_to_show = np.linspace(0, len(frames) - 1, num_sample_frames, dtype=int)

    ax_frames.axis('off')

    frame_height = frames[0].shape[0]
    frame_width = frames[0].shape[1]

    combined_width = frame_width * num_sample_frames
    combined_image = np.zeros((frame_height, combined_width, 3), dtype=np.uint8)

    for i, frame_idx in enumerate(frame_indices_to_show):
        frame = frames[frame_idx]
        if frame.shape[-1] == 1:
            # Grayscale frame - replicate the channel to fill the RGB canvas
            frame = np.repeat(frame, 3, axis=-1)

        # Add frame to combined image
        x_start = i * frame_width
        x_end = (i + 1) * frame_width
        combined_image[:, x_start:x_end] = frame

        # Add frame number, progress, and stage
        progress_val = progress_predictions[frame_idx]
        stage_idx = np.argmax(stage_predictions[frame_idx])
        stage_name = stage_labels[stage_idx] if stage_idx < len(stage_labels) else f'{stage_idx+1}'

        # Truncate long stage names for display
        if len(stage_name) > 15:
            stage_name = stage_name[:12] + '...'

        label = f'Frame {frame_idx}\nProg: {progress_val:.2f}\n{stage_name}'

        # Draw label on image
        ax_frames.text(x_start + frame_width / 2, -10, label,
                       ha='center', va='top', fontsize=7,
                       bbox=dict(boxstyle='round', facecolor='white', alpha=0.7))

    ax_frames.imshow(combined_image)
    ax_frames.set_title('Sample Frames', fontsize=12, pad=20)

    plt.tight_layout()
    output_path.parent.mkdir(parents=True, exist_ok=True)
    plt.savefig(output_path, dpi=150, bbox_inches='tight')
    logger.info(f"Saved visualization to {output_path}")

    plt.close()
|
||||
|
||||
def main():
    """Run SARM inference on one dataset episode and save a visualization.

    Pipeline: parse CLI args, load the pretrained SARM reward model and the
    LeRobot dataset, run inference on the selected episode, optionally compute
    ground-truth progress from subtask annotations, then write a PNG
    visualization and an .npz file of the raw predictions to the output dir.
    """
    args = parse_args()

    # Setup device
    if args.device is None:
        device = "cuda" if torch.cuda.is_available() else "cpu"
    else:
        device = args.device
    logger.info(f"Using device: {device}")

    # Load model
    logger.info(f"Loading SARM model from {args.model_id}...")
    model = SARMRewardModel.from_pretrained(args.model_id)
    model.to(device)
    model.eval()
    logger.info("Model loaded successfully")

    # Load dataset
    logger.info(f"Loading dataset {args.dataset_repo}...")
    dataset = LeRobotDataset(args.dataset_repo)
    logger.info(f"Dataset loaded: {len(dataset.meta.episodes)} episodes, {len(dataset)} frames")

    # Validate episode index
    if args.episode_index >= len(dataset.meta.episodes):
        raise ValueError(
            f"Episode index {args.episode_index} out of range. "
            f"Dataset has {len(dataset.meta.episodes)} episodes."
        )

    # CLI keys override the keys baked into the model config.
    image_key = args.image_key if args.image_key is not None else model.config.image_key
    state_key = args.state_key if args.state_key is not None else model.config.state_key
    logger.info(f"Using image key: {image_key}")
    logger.info(f"Using state key: {state_key}")

    # Load dataset stats for state normalization (same as training)
    # NOTE(review): load_stats presumably returns a falsy value when the stats
    # file is missing — confirm against its definition.
    dataset_stats = load_stats(dataset.root)
    if dataset_stats:
        logger.info(f"✓ Loaded dataset stats from {dataset.root}")
    else:
        logger.warning("⚠ Could not load dataset stats - states will not be normalized")

    # Load episode data
    frames, states, start_idx, end_idx, dataset_task = load_episode_data(
        dataset, args.episode_index, image_key, state_key
    )

    # Use task description from dataset if available, otherwise use command-line argument
    task_description = dataset_task if dataset_task is not None else args.task_description
    logger.info(f"Using task description: '{task_description}'")

    # Run inference
    progress_predictions, stage_predictions = run_inference(
        model, frames, states, task_description,
        dataset_stats=dataset_stats, state_key=state_key
    )

    # Extract subtask names and temporal proportions from model config if available
    subtask_names = None
    temporal_proportions = None

    if hasattr(model.config, 'subtask_names') and model.config.subtask_names is not None:
        subtask_names = model.config.subtask_names
        logger.info(f"✓ Found {len(subtask_names)} subtask names in model config: {subtask_names}")

        # Try to load temporal proportions from model config
        if hasattr(model.config, 'temporal_proportions') and model.config.temporal_proportions is not None:
            temporal_proportions = {
                name: prop for name, prop in zip(model.config.subtask_names, model.config.temporal_proportions)
            }
            logger.info(f"✓ Loaded temporal proportions from model config: {temporal_proportions}")

    # Fallback: try to load from dataset meta
    if temporal_proportions is None:
        proportions_path = dataset.root / "meta" / "temporal_proportions.json"
        if proportions_path.exists():
            with open(proportions_path, 'r') as f:
                temporal_proportions = json.load(f)
            logger.info(f"✓ Loaded temporal proportions from dataset: {temporal_proportions}")

            # Also extract subtask names from proportions if not already set
            if subtask_names is None:
                subtask_names = sorted(temporal_proportions.keys())
                logger.info(f"✓ Extracted subtask names from proportions: {subtask_names}")

    # Compute ground truth progress if annotations are available
    ground_truth_progress = None
    ground_truth_stages = None

    if temporal_proportions is not None and subtask_names is not None:
        logger.info("Attempting to compute ground truth progress from annotations...")
        ground_truth_progress, ground_truth_stages = compute_ground_truth_progress(
            dataset,
            args.episode_index,
            temporal_proportions,
            subtask_names
        )
        if ground_truth_progress is None:
            logger.warning("⚠ Ground truth not available - annotations may be missing for this episode")
    else:
        logger.warning("⚠ Cannot compute ground truth - temporal_proportions or subtask_names not available")

    output_dir = Path(args.output_dir)
    output_path = output_dir / f"sarm_prediction_ep{args.episode_index}.png"

    # Render the figure (also creates output_dir if needed).
    visualize_predictions(
        frames,
        progress_predictions,
        stage_predictions,
        task_description,
        output_path,
        num_sample_frames=args.num_sample_frames,
        figsize=tuple(args.figsize),
        subtask_names=subtask_names,
        temporal_proportions=temporal_proportions,
        ground_truth_progress=ground_truth_progress,
        ground_truth_stages=ground_truth_stages,
    )

    # Persist raw predictions (and ground truth, when computed) for later analysis.
    predictions_path = output_dir / f"predictions_ep{args.episode_index}.npz"
    save_dict = {
        'progress': progress_predictions,
        'stages': stage_predictions
    }
    if ground_truth_progress is not None:
        save_dict['gt_progress'] = ground_truth_progress
        save_dict['gt_stages'] = ground_truth_stages
    np.savez(predictions_path, **save_dict)
    logger.info(f"Saved predictions to {predictions_path}")
    logger.info(f"\nVisualization: {output_path}")


if __name__ == "__main__":
    main()
||||
@@ -64,9 +64,26 @@ class TrainPipelineConfig(HubMixin):
|
||||
scheduler: LRSchedulerConfig | None = None
|
||||
eval: EvalConfig = field(default_factory=EvalConfig)
|
||||
wandb: WandBConfig = field(default_factory=WandBConfig)
|
||||
checkpoint_path: Path | None = field(init=False, default=None)
|
||||
|
||||
# RA-BC (Reward-Aligned Behavior Cloning) parameters
|
||||
use_rabc: bool = False # Enable reward-weighted training
|
||||
reward_model_path: str | None = None # Path to pre-trained reward model (e.g., SARM)
|
||||
rabc_kappa: float = 0.01 # Hard threshold for high-quality samples
|
||||
rabc_epsilon: float = 1e-6 # Small constant for numerical stability
|
||||
rabc_update_freq: int = 1 # Compute rewards every N batches (1 = every batch)
|
||||
|
||||
# Rename map for the observation to override the image and state keys
|
||||
rename_map: dict[str, str] = field(default_factory=dict)
|
||||
rename_map: dict[str, str] = field(default_factory=dict)
|
||||
checkpoint_path: Path | None = field(init=False, default=None)
|
||||
|
||||
|
||||
    def validate(self):
        """Check that the RA-BC settings are internally consistent.

        Raises:
            ValueError: if ``use_rabc`` is True but ``reward_model_path`` is unset.
        """
        # Validate RA-BC configuration
        # NOTE(review): a second `def validate` appears later in this class;
        # the later definition shadows this one at class-creation time, so this
        # check may never run — confirm the two methods should be merged.
        if self.use_rabc and not self.reward_model_path:
            raise ValueError(
                "RA-BC is enabled (use_rabc=True) but no reward_model_path provided. "
                "Please specify a pre-trained reward model (e.g., SARM) path."
            )
|
||||
def validate(self) -> None:
|
||||
# HACK: We parse again the cli args here to get the pretrained paths if there was some.
|
||||
|
||||
@@ -999,10 +999,18 @@ def _copy_data_with_feature_changes(
|
||||
df[feature_name] = feature_values
|
||||
else:
|
||||
feature_slice = values[frame_idx:end_idx]
|
||||
if len(feature_slice.shape) > 1 and feature_slice.shape[1] == 1:
|
||||
df[feature_name] = feature_slice.flatten()
|
||||
else:
|
||||
if len(feature_slice.shape) == 1:
|
||||
# 1D array - can assign directly
|
||||
df[feature_name] = feature_slice
|
||||
elif len(feature_slice.shape) == 2 and feature_slice.shape[1] == 1:
|
||||
# 2D array with single column - flatten it
|
||||
df[feature_name] = feature_slice.flatten()
|
||||
elif len(feature_slice.shape) == 2:
|
||||
# 2D array with multiple columns (e.g., embeddings) - convert to list of lists
|
||||
df[feature_name] = feature_slice.tolist()
|
||||
else:
|
||||
# Higher dimensional - convert to list
|
||||
df[feature_name] = [row.tolist() for row in feature_slice]
|
||||
frame_idx = end_idx
|
||||
|
||||
# Write using the same chunk/file structure as source
|
||||
|
||||
@@ -0,0 +1,146 @@
|
||||
# LeRobot Embedding Generation Script
|
||||
|
||||
Generate embeddings for LeRobot datasets to make them more lightweight and efficient for training.
|
||||
|
||||
## Overview
|
||||
|
||||
This script processes v3.0 LeRobot datasets and adds pre-computed embeddings for:
|
||||
|
||||
- **Task embeddings**: Language command embeddings using MiniLM
|
||||
- **Image embeddings**: Frame embeddings using DinoV2
|
||||
|
||||
The resulting dataset can be used more efficiently during training by loading pre-computed embeddings instead of running encoders on-the-fly.
|
||||
|
||||
## Supported Encoders
|
||||
|
||||
### Image Encoders (DinoV2)
|
||||
|
||||
DinoV2 is a self-supervised vision transformer that produces high-quality image embeddings:
|
||||
|
||||
- **`dinov2_vits14`**: ViT-S/14 (384-dim) - Fastest, smaller model
|
||||
- **`dinov2_vitb14`**: ViT-B/14 (768-dim) - **Recommended** - Good balance
|
||||
- **`dinov2_vitl14`**: ViT-L/14 (1024-dim) - Best quality, slower
|
||||
|
||||
### Language Encoders (MiniLM)
|
||||
|
||||
MiniLM is a lightweight sentence transformer model:
|
||||
|
||||
- **`minilm-l6`**: MiniLM-L6-v2 (384-dim) - Faster
|
||||
- **`minilm-l12`**: MiniLM-L12-v2 (384-dim) - **Recommended** - Better quality
|
||||
|
||||
## Usage
|
||||
|
||||
### Basic Command
|
||||
|
||||
```bash
|
||||
python src/lerobot/datasets/generating_embeddings/generate_embeddings.py \
|
||||
--repo-id lerobot/utokyo_xarm_bimanual \
|
||||
--output-repo-id your-username/utokyo_xarm_bimanual_embeddings \
|
||||
--image-encoder dinov2_vitb14 \
|
||||
--language-encoder minilm-l12 \
|
||||
--push-to-hub
|
||||
```
|
||||
|
||||
### Lightweight Version (No Videos)
|
||||
|
||||
Removes video files to significantly reduce storage:
|
||||
|
||||
```bash
|
||||
python src/lerobot/datasets/generating_embeddings/generate_embeddings.py \
|
||||
--repo-id lerobot/utokyo_xarm_bimanual \
|
||||
--output-repo-id your-username/utokyo_xarm_bimanual_lightweight \
|
||||
--image-encoder dinov2_vitb14 \
|
||||
--language-encoder minilm-l12 \
|
||||
--remove-videos \
|
||||
--push-to-hub
|
||||
```
|
||||
|
||||
## Output
|
||||
|
||||
The script adds new features to your dataset:
|
||||
|
||||
### New Features
|
||||
|
||||
1. **`task_embedding`**: Language embedding for each frame
|
||||
- Shape: `[384]` (MiniLM)
|
||||
- One embedding per frame based on its task
|
||||
|
||||
2. **`{camera_key}_embedding`**: Image embedding for each camera view
|
||||
- Shape: `[384]`, `[768]`, or `[1024]` depending on DinoV2 model
|
||||
- Examples: `observation.images.top_embedding`, `observation.images.wrist_embedding`
|
||||
|
||||
### Using Embeddings in Training
|
||||
|
||||
```python
|
||||
from lerobot.datasets.lerobot_dataset import LeRobotDataset
|
||||
|
||||
# Load dataset with embeddings
|
||||
dataset = LeRobotDataset("your-username/utokyo_xarm_bimanual_embeddings")
|
||||
|
||||
# Access embeddings
|
||||
item = dataset[0]
|
||||
task_emb = item["task_embedding"] # Shape: [384]
|
||||
img_emb = item["observation.images.top_embedding"] # Shape: [768]
|
||||
|
||||
# Use in your policy
|
||||
# Instead of running encoders during training, use pre-computed embeddings
|
||||
```
|
||||
|
||||
## Extending with New Encoders
|
||||
|
||||
The script is designed to be easily extensible. To add a new encoder:
|
||||
|
||||
### 1. Create Encoder Class
|
||||
|
||||
```python
|
||||
class MyCustomImageEncoder(ImageEncoder):
|
||||
"""Your custom image encoder."""
|
||||
|
||||
def __init__(self, device: str = "cuda"):
|
||||
super().__init__(device)
|
||||
# Load your model
|
||||
self.model = load_my_model()
|
||||
self.model = self.model.to(self.device)
|
||||
self.model.eval()
|
||||
|
||||
def encode(self, images: list[np.ndarray]) -> np.ndarray:
|
||||
"""Encode a batch of images."""
|
||||
# Your encoding logic here
|
||||
embeddings = []
|
||||
for img in images:
|
||||
emb = self.model(img)
|
||||
embeddings.append(emb)
|
||||
return np.array(embeddings)
|
||||
|
||||
@property
|
||||
def embedding_dim(self) -> int:
|
||||
"""Return embedding dimension."""
|
||||
return 512 # Your embedding dimension
|
||||
```
|
||||
|
||||
### 2. Add to Factory Function
|
||||
|
||||
```python
|
||||
def get_image_encoder(encoder_name: str, device: str = "cuda") -> ImageEncoder:
|
||||
encoders = {
|
||||
"dinov2_vits14": lambda: DinoV2Encoder(model_name="dinov2_vits14", device=device),
|
||||
"dinov2_vitb14": lambda: DinoV2Encoder(model_name="dinov2_vitb14", device=device),
|
||||
"dinov2_vitl14": lambda: DinoV2Encoder(model_name="dinov2_vitl14", device=device),
|
||||
# Add your encoder
|
||||
"my_custom": lambda: MyCustomImageEncoder(device=device),
|
||||
}
|
||||
# ... rest of function
|
||||
```
|
||||
|
||||
## Validating Embeddings
|
||||
|
||||
After generating embeddings, you can validate them using `validate_embeddings.py`:
|
||||
|
||||
```bash
|
||||
python src/lerobot/datasets/generating_embeddings/validate_embeddings.py \
|
||||
--original-repo-id lerobot/utokyo_xarm_bimanual \
|
||||
--embeddings-repo-id pepijn223/utokyo_xarm_bimanual_embeddings \
|
||||
--image-encoder dinov2_vitb14 \
|
||||
--language-encoder minilm-l12 \
|
||||
--num-samples 20
|
||||
```
|
||||
@@ -0,0 +1,147 @@
|
||||
#!/usr/bin/env python
|
||||
|
||||
# Copyright 2024 The HuggingFace Inc. team. All rights reserved.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
import logging
|
||||
|
||||
import numpy as np
|
||||
import torch
|
||||
from PIL import Image
|
||||
|
||||
logging.basicConfig(level=logging.INFO)
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
class ImageEncoder:
    """Abstract parent for image-embedding backends.

    Concrete subclasses must override :meth:`encode`; they are also expected
    to expose an ``embedding_dim`` property describing their output width.
    """

    def __init__(self, device: str = "cuda"):
        # Silently fall back to CPU when CUDA is unavailable on this machine.
        resolved = device if torch.cuda.is_available() else "cpu"
        self.device = torch.device(resolved)

    def encode(self, images: list[np.ndarray]) -> np.ndarray:
        """Encode a batch of images into an array of embeddings."""
        raise NotImplementedError
|
||||
|
||||
class DinoV2Encoder(ImageEncoder):
    """DinoV2 image encoder.

    Wraps a facebookresearch/dinov2 vision transformer fetched through
    ``torch.hub``. Three backbone sizes are supported: ViT-S/14 (384-d),
    ViT-B/14 (768-d) and ViT-L/14 (1024-d).
    """

    def __init__(self, model_name: str = "dinov2_vitb14", device: str = "cuda", batch_size: int = 32):
        super().__init__(device)
        self.batch_size = batch_size
        self.model_name = model_name
        logger.info(f"Loading DinoV2 model: {model_name}")
        self.model = torch.hub.load("facebookresearch/dinov2", model_name)  # nosec B614
        self.model = self.model.to(self.device)
        self.model.eval()

        # Standard DinoV2 preprocessing: resize, center-crop, ImageNet normalization.
        from torchvision import transforms

        self.transform = transforms.Compose(
            [
                transforms.Resize(256, interpolation=transforms.InterpolationMode.BICUBIC),
                transforms.CenterCrop(224),
                transforms.ToTensor(),
                transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]),
            ]
        )

    def encode(self, images: list[np.ndarray]) -> np.ndarray:
        """Encode a batch of images, processing them in mini-batches."""
        chunks = []

        with torch.inference_mode():
            for start in range(0, len(images), self.batch_size):
                batch = images[start : start + self.batch_size]
                # PIL round-trip so the torchvision transform pipeline applies.
                as_pil = [Image.fromarray(arr.astype(np.uint8)) for arr in batch]
                stacked = torch.stack([self.transform(pic) for pic in as_pil]).to(self.device)

                # Forward pass and move the embeddings back to host memory.
                chunks.append(self.model(stacked).cpu().numpy())

        return np.concatenate(chunks, axis=0)

    @property
    def embedding_dim(self) -> int:
        """Return the embedding dimension based on model size."""
        size_to_dim = {"vits14": 384, "vitb14": 768, "vitl14": 1024}
        for tag, dim in size_to_dim.items():
            if tag in self.model_name:
                return dim
        return 768  # default to ViT-B/14
|
||||
|
||||
class LanguageEncoder:
    """Abstract parent for text-embedding backends.

    Concrete subclasses must override :meth:`encode`; they are also expected
    to expose an ``embedding_dim`` property describing their output width.
    """

    def __init__(self, device: str = "cuda"):
        # Silently fall back to CPU when CUDA is unavailable on this machine.
        resolved = device if torch.cuda.is_available() else "cpu"
        self.device = torch.device(resolved)

    def encode(self, texts: list[str]) -> np.ndarray:
        """Encode a batch of texts into an array of embeddings."""
        raise NotImplementedError
|
||||
|
||||
class MiniLMEncoder(LanguageEncoder):
    """MiniLM language encoder.

    Thin wrapper around a sentence-transformers MiniLM checkpoint loaded with
    Hugging Face ``transformers``. Both the L6 and L12 variants are supported
    and produce 384-dimensional sentence embeddings via mean pooling.
    """

    def __init__(self, model_name: str = "sentence-transformers/all-MiniLM-L12-v2", device: str = "cuda"):
        super().__init__(device)
        self.model_name = model_name
        logger.info(f"Loading MiniLM model: {model_name}")

        from transformers import AutoModel, AutoTokenizer

        self.tokenizer = AutoTokenizer.from_pretrained(model_name)
        self.model = AutoModel.from_pretrained(model_name).to(self.device)
        self.model.eval()

    def _mean_pooling(self, model_output, attention_mask):
        """Average token embeddings, ignoring padding positions."""
        tokens = model_output[0]
        # Broadcast the attention mask over the hidden dimension.
        mask = attention_mask.unsqueeze(-1).expand(tokens.size()).float()
        summed = torch.sum(tokens * mask, 1)
        # Clamp avoids division by zero for fully-masked rows.
        counts = torch.clamp(mask.sum(1), min=1e-9)
        return summed / counts

    def encode(self, texts: list[str]) -> np.ndarray:
        """Encode a batch of texts into mean-pooled sentence embeddings."""
        with torch.inference_mode():
            tokenized = self.tokenizer(texts, padding=True, truncation=True, return_tensors="pt")
            tokenized = {key: tensor.to(self.device) for key, tensor in tokenized.items()}

            outputs = self.model(**tokenized)
            pooled = self._mean_pooling(outputs, tokenized["attention_mask"])

        return pooled.cpu().numpy()

    @property
    def embedding_dim(self) -> int:
        """Return the embedding dimension."""
        return 384  # Both MiniLM-L6 and L12 output 384-dim embeddings
||||
@@ -0,0 +1,329 @@
|
||||
#!/usr/bin/env python
|
||||
|
||||
# Copyright 2024 The HuggingFace Inc. team. All rights reserved.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
"""
|
||||
Generate embeddings for LeRobot datasets to make them more lightweight and efficient.
|
||||
|
||||
This script:
|
||||
1. Loads a v3.0 LeRobot dataset from the hub
|
||||
2. Computes embeddings for tasks (language commands) and frames (images)
|
||||
3. Stores embeddings as new features in the dataset
|
||||
4. Optionally removes video files to reduce size
|
||||
5. Pushes the converted dataset to the hub
|
||||
|
||||
Current supported encoders:
|
||||
- Image: DinoV2 (dinov2_vits14, dinov2_vitb14, dinov2_vitl14)
|
||||
- Language: MiniLM (minilm-l6, minilm-l12)
|
||||
|
||||
The architecture is extensible - you can add more encoders by:
|
||||
1. Creating a new encoder class inheriting from ImageEncoder or LanguageEncoder
|
||||
2. Implementing the encode() method and embedding_dim property
|
||||
3. Adding it to the get_image_encoder() or get_language_encoder() factory function
|
||||
|
||||
Usage example:
|
||||
python src/lerobot/datasets/generating_embeddings/generate_embeddings.py \
|
||||
--repo-id lerobot/utokyo_xarm_bimanual \
|
||||
--output-repo-id lerobot/utokyo_xarm_bimanual_embeddings \
|
||||
--image-encoder dinov2_vitb14 \
|
||||
--language-encoder minilm-l12 \
|
||||
--remove-videos \
|
||||
--push-to-hub
|
||||
"""
|
||||
|
||||
import argparse
|
||||
import shutil
|
||||
from pathlib import Path
|
||||
|
||||
import numpy as np
|
||||
import torch
|
||||
from tqdm import tqdm
|
||||
|
||||
from lerobot.datasets.generating_embeddings.encoders import (
|
||||
DinoV2Encoder,
|
||||
ImageEncoder,
|
||||
LanguageEncoder,
|
||||
MiniLMEncoder,
|
||||
)
|
||||
from lerobot.datasets.lerobot_dataset import LeRobotDataset
|
||||
|
||||
|
||||
def get_image_encoder(encoder_name: str, device: str = "cuda") -> ImageEncoder:
    """Factory function to get image encoder.

    To add a new encoder:
    1. Create a new class inheriting from ImageEncoder
    2. Implement encode() and embedding_dim property
    3. Add it to the factory dictionary below
    """
    # Constructors are wrapped in lambdas so only the requested model is loaded.
    factories = {
        "dinov2_vits14": lambda: DinoV2Encoder(model_name="dinov2_vits14", device=device),
        "dinov2_vitb14": lambda: DinoV2Encoder(model_name="dinov2_vitb14", device=device),
        "dinov2_vitl14": lambda: DinoV2Encoder(model_name="dinov2_vitl14", device=device),
    }

    if encoder_name in factories:
        return factories[encoder_name]()

    raise ValueError(f"Unknown image encoder: {encoder_name}. Available options: {list(factories.keys())}")
|
||||
|
||||
def get_language_encoder(encoder_name: str, device: str = "cuda") -> LanguageEncoder:
    """Factory function to get language encoder.

    To add a new encoder:
    1. Create a new class inheriting from LanguageEncoder
    2. Implement encode() and embedding_dim property
    3. Add it to the factory dictionary below
    """
    # Constructors are wrapped in lambdas so only the requested model is loaded.
    factories = {
        "minilm-l6": lambda: MiniLMEncoder(
            model_name="sentence-transformers/all-MiniLM-L6-v2", device=device
        ),
        "minilm-l12": lambda: MiniLMEncoder(
            model_name="sentence-transformers/all-MiniLM-L12-v2", device=device
        ),
    }

    if encoder_name in factories:
        return factories[encoder_name]()

    raise ValueError(
        f"Unknown language encoder: {encoder_name}. Available options: {list(factories.keys())}"
    )
|
||||
|
||||
def generate_embeddings_for_dataset(
    repo_id: str,
    output_repo_id: str,
    image_encoder: ImageEncoder,
    language_encoder: LanguageEncoder,
    remove_videos: bool = False,
    local_dir: Path | None = None,
    output_local_dir: Path | None = None,
    push_to_hub: bool = False,
):
    """Generate embeddings for a LeRobot dataset.

    Pre-computes one language embedding per task and one image embedding per
    camera per frame, then writes them into a copy of the dataset as new
    float32 features named "task_embedding" and "<camera_key>_embedding".

    Args:
        repo_id: Source dataset repository ID
        output_repo_id: Output dataset repository ID
        image_encoder: Image encoder instance
        language_encoder: Language encoder instance
        remove_videos: Whether to remove video files
        local_dir: Local directory for source dataset
        output_local_dir: Local directory for output dataset
        push_to_hub: Whether to push to hub after conversion
    """
    # NOTE(review): imported lazily here, presumably to avoid a circular
    # import at module load time — confirm against dataset_tools.
    from lerobot.datasets.dataset_tools import modify_features

    print(f"Loading dataset: {repo_id}")

    dataset = LeRobotDataset(repo_id, root=local_dir, download_videos=True)
    print(f"Dataset: {dataset.num_episodes} episodes, {dataset.num_frames} frames")

    # Encode each unique task string once; the per-frame loop below reuses
    # these instead of re-encoding per frame.
    print("Computing task embeddings...")
    unique_tasks = dataset.meta.tasks.index.tolist()
    task_embeddings = {}

    for task in tqdm(unique_tasks, desc="Encoding tasks"):
        # Clean up task text
        task_clean = task.strip().capitalize().strip(" .,!?-_")
        embedding = language_encoder.encode([task_clean])[0]
        # Keyed by the *raw* task string so the per-frame lookup matches.
        task_embeddings[task] = embedding

    print(f"Computed {len(task_embeddings)} task embeddings")

    print("Processing frames and computing embeddings...")
    all_task_embeddings = []
    # Despite its name, this dict collects raw frame images per camera; the
    # actual embeddings are computed in one batched pass per camera below.
    all_image_embeddings_dict = {cam_key: [] for cam_key in dataset.meta.camera_keys}

    for frame_idx in tqdm(range(dataset.num_frames), desc="Processing frames"):
        item = dataset.hf_dataset[frame_idx]
        ep_idx = item["episode_index"].item()

        # Look up the pre-computed embedding for this frame's task.
        task = dataset.meta.tasks.iloc[item["task_index"].item()].name
        task_emb = task_embeddings[task]
        all_task_embeddings.append(task_emb)

        for cam_key in dataset.meta.camera_keys:
            if cam_key in dataset.meta.video_keys:
                # Camera stored as video: decode the single frame at this
                # frame's timestamp.
                current_ts = item["timestamp"].item()
                video_frames = dataset._query_videos({cam_key: [current_ts]}, ep_idx)
                img = video_frames[cam_key]

                if isinstance(img, torch.Tensor):
                    if img.ndim == 4:
                        img = img[0]  # (T, C, H, W) -> (C, H, W)
                    elif img.ndim != 3:
                        raise ValueError(f"Unexpected video frame shape {img.shape} for camera {cam_key}")
                    # Assumes CHW float values in [0, 1] — TODO(review):
                    # confirm; converted to HWC uint8 for the image encoder.
                    img_np = (img.permute(1, 2, 0).numpy() * 255).astype(np.uint8)
                else:
                    img_np = np.array(img)
            else:
                # Camera stored as per-frame images.
                img = item[cam_key]
                if isinstance(img, torch.Tensor):
                    if img.ndim == 3:
                        img_np = (img.permute(1, 2, 0).numpy() * 255).astype(np.uint8)
                    else:
                        raise ValueError(f"Unexpected image shape {img.shape} for camera {cam_key}")
                else:
                    img_np = np.array(img)

            all_image_embeddings_dict[cam_key].append(img_np)

    # NOTE(review): every frame of every camera is held in memory before
    # encoding; this may be prohibitive for large datasets.
    print("Computing image embeddings...")
    image_embeddings_dict = {}
    for cam_key, images in all_image_embeddings_dict.items():
        print(f"  {cam_key}: {len(images)} images")
        embeddings = image_encoder.encode(images)
        image_embeddings_dict[cam_key] = embeddings

    # Stack per-frame embeddings into (num_frames, dim) arrays.
    all_task_embeddings = np.array(all_task_embeddings)
    for cam_key in dataset.meta.camera_keys:
        image_embeddings_dict[cam_key] = np.array(image_embeddings_dict[cam_key])

    img_emb_dim = image_encoder.embedding_dim
    lang_emb_dim = language_encoder.embedding_dim

    # New features: one task embedding plus one embedding per camera, per frame.
    add_features_dict = {
        "task_embedding": (
            all_task_embeddings,
            {"dtype": "float32", "shape": [lang_emb_dim], "names": None},
        ),
    }

    for cam_key in dataset.meta.camera_keys:
        add_features_dict[f"{cam_key}_embedding"] = (
            image_embeddings_dict[cam_key],
            {"dtype": "float32", "shape": [img_emb_dim], "names": None},
        )

    print("Adding embeddings to dataset...")
    remove_features_list = None
    if remove_videos:
        remove_features_list = dataset.meta.video_keys

    output_dataset = modify_features(
        dataset=dataset,
        add_features=add_features_dict,
        remove_features=remove_features_list,
        output_dir=output_local_dir,
        repo_id=output_repo_id,
    )

    if remove_videos:
        # The video features were dropped above; also delete the files on disk.
        print("Removing video files...")
        videos_dir = output_dataset.root / "videos"
        if videos_dir.exists():
            shutil.rmtree(videos_dir)

    print(f"Saved to: {output_dataset.root}")

    if push_to_hub:
        print(f"Pushing to hub: {output_repo_id}")
        output_dataset.push_to_hub(push_videos=not remove_videos)
    print("Done!")
|
||||
|
||||
|
||||
def main():
    """CLI entry point: parse arguments, build the encoders, and run the conversion."""
    parser = argparse.ArgumentParser(
        description="Generate embeddings for LeRobot datasets",
        formatter_class=argparse.RawDescriptionHelpFormatter,
        epilog="""
Examples:
  # Basic usage with default encoders (DinoV2 ViT-B/14 + MiniLM-L12)
  python src/lerobot/datasets/generating_embeddings/generate_embeddings.py \\
    --repo-id lerobot/utokyo_xarm_bimanual \\
    --output-repo-id your-username/utokyo_xarm_bimanual_embeddings \\
    --image-encoder dinov2_vitb14 \\
    --language-encoder minilm-l12 \\
    --push-to-hub

  # Generate embeddings and remove videos
  python src/lerobot/datasets/generating_embeddings/generate_embeddings.py \\
    --repo-id lerobot/utokyo_xarm_bimanual \\
    --output-repo-id your-username/utokyo_xarm_bimanual_lightweight \\
    --image-encoder dinov2_vitb14 \\
    --language-encoder minilm-l12 \\
    --remove-videos \\
    --push-to-hub

Available image encoders:
  - dinov2_vits14: DinoV2 ViT-S/14 (384-dim, faster)
  - dinov2_vitb14: DinoV2 ViT-B/14 (768-dim, recommended)
  - dinov2_vitl14: DinoV2 ViT-L/14 (1024-dim, best quality)

Available language encoders:
  - minilm-l6: MiniLM-L6-v2 (384-dim, faster)
  - minilm-l12: MiniLM-L12-v2 (384-dim, recommended)
        """,
    )
    parser.add_argument("--repo-id", type=str, required=True, help="Source dataset repository ID")
    parser.add_argument("--output-repo-id", type=str, required=True, help="Output dataset repository ID")
    parser.add_argument(
        "--image-encoder",
        type=str,
        default="dinov2_vitb14",
        help="Image encoder to use (default: dinov2_vitb14)",
    )
    parser.add_argument(
        "--language-encoder",
        type=str,
        default="minilm-l12",
        help="Language encoder to use (default: minilm-l12)",
    )
    parser.add_argument(
        "--remove-videos",
        action="store_true",
        help="Remove video files after generating embeddings",
    )
    parser.add_argument("--local-dir", type=str, default=None, help="Local directory for source dataset")
    parser.add_argument(
        "--output-local-dir", type=str, default=None, help="Local directory for output dataset"
    )
    parser.add_argument(
        "--push-to-hub",
        action="store_true",
        help="Push the converted dataset to the hub",
    )
    parser.add_argument(
        "--device",
        type=str,
        default="cuda",
        help="Device to use for encoding (default: cuda)",
    )

    args = parser.parse_args()

    # Load encoders
    image_encoder = get_image_encoder(args.image_encoder, device=args.device)
    language_encoder = get_language_encoder(args.language_encoder, device=args.device)

    # Generate embeddings
    generate_embeddings_for_dataset(
        repo_id=args.repo_id,
        output_repo_id=args.output_repo_id,
        image_encoder=image_encoder,
        language_encoder=language_encoder,
        remove_videos=args.remove_videos,
        local_dir=Path(args.local_dir) if args.local_dir else None,
        output_local_dir=Path(args.output_local_dir) if args.output_local_dir else None,
        push_to_hub=args.push_to_hub,
    )
|
||||
|
||||
|
||||
# Allow running this module directly as a script.
if __name__ == "__main__":
    main()
|
||||
@@ -0,0 +1,222 @@
|
||||
#!/usr/bin/env python
|
||||
|
||||
# Copyright 2024 The HuggingFace Inc. team. All rights reserved.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
"""
|
||||
Validate pre-computed embeddings against on-the-fly computed embeddings.
|
||||
|
||||
Usage:
|
||||
python src/lerobot/datasets/generating_embeddings/validate_embeddings.py \
|
||||
--original-repo-id lerobot/utokyo_xarm_bimanual \
|
||||
--embeddings-repo-id <your_username>/utokyo_xarm_bimanual_embeddings \
|
||||
--image-encoder dinov2_vitb14 \
|
||||
--language-encoder minilm-l12 \
|
||||
--num-samples 10
|
||||
"""
|
||||
|
||||
import argparse
|
||||
|
||||
import numpy as np
|
||||
import torch
|
||||
from tqdm import tqdm
|
||||
|
||||
from lerobot.datasets.generating_embeddings.encoders import ImageEncoder, LanguageEncoder
|
||||
from lerobot.datasets.generating_embeddings.generate_embeddings import (
|
||||
get_image_encoder,
|
||||
get_language_encoder,
|
||||
)
|
||||
from lerobot.datasets.lerobot_dataset import LeRobotDataset
|
||||
|
||||
|
||||
def cosine_similarity(a: np.ndarray, b: np.ndarray) -> float:
    """Compute the cosine similarity between two 1-D vectors.

    Args:
        a: First vector.
        b: Second vector, same length as ``a``.

    Returns:
        Cosine similarity in [-1, 1] as a plain Python float. Returns 0.0
        when either vector has zero norm, instead of propagating the NaN
        that a 0/0 division would produce.
    """
    denom = np.linalg.norm(a) * np.linalg.norm(b)
    if denom == 0.0:
        # Degenerate input (zero vector): similarity is undefined; report
        # "no similarity" rather than NaN so threshold comparisons stay sane.
        return 0.0
    return float(np.dot(a, b) / denom)
|
||||
|
||||
|
||||
def validate_embeddings(
    original_repo_id: str,
    embeddings_repo_id: str,
    image_encoder: ImageEncoder,
    language_encoder: LanguageEncoder,
    num_samples: int = 10,
    device: str = "cuda",
):
    """Validate pre-computed embeddings against on-the-fly embeddings.

    Re-encodes a random subset of frames and their task strings, then checks
    that the mean cosine similarity with the stored embeddings clears a 0.99
    threshold per modality. Prints "PASSED"/"FAILED" instead of raising.

    Args:
        original_repo_id: Original dataset repository ID
        embeddings_repo_id: Dataset with pre-computed embeddings repository ID
        image_encoder: Image encoder instance
        language_encoder: Language encoder instance
        num_samples: Number of samples to validate
        device: Device to use for encoding. NOTE(review): unused in this body —
            the encoders are already placed on a device by their factories;
            parameter kept for interface symmetry with main().
    """
    # Load both datasets
    print("Loading datasets...")
    original_dataset = LeRobotDataset(original_repo_id, download_videos=True)
    embeddings_dataset = LeRobotDataset(embeddings_repo_id, download_videos=False)

    # Verify both datasets have the same number of frames
    assert original_dataset.num_frames == embeddings_dataset.num_frames, (
        f"Frame count mismatch: original={original_dataset.num_frames}, "
        f"embeddings={embeddings_dataset.num_frames}"
    )

    camera_keys = original_dataset.meta.camera_keys

    # Check embedding features exist
    expected_features = ["task_embedding"] + [f"{cam}_embedding" for cam in camera_keys]
    for feat in expected_features:
        if feat not in embeddings_dataset.features:
            raise ValueError(f"Embedding feature not found: {feat}")

    # Select random sample indices without replacement. Uses the global numpy
    # RNG; seed upstream if reproducible sampling is needed.
    sample_indices = np.random.choice(
        original_dataset.num_frames, size=min(num_samples, original_dataset.num_frames), replace=False
    )
    print(f"Validating {len(sample_indices)} samples...")

    # Track statistics
    task_similarities = []
    image_similarities = {cam: [] for cam in camera_keys}

    for idx in tqdm(sample_indices, desc="Validating"):
        idx = int(idx)  # np integer -> plain int for dataset indexing

        embeddings_item = embeddings_dataset[idx]
        precomputed_task_emb = embeddings_item["task_embedding"].numpy()
        precomputed_image_embs = {cam: embeddings_item[f"{cam}_embedding"].numpy() for cam in camera_keys}

        original_item = original_dataset[idx]

        # Get task and compute embedding
        task = original_item["task"]
        # Clean up task text (same as in generate_embeddings.py)
        task_clean = task.strip().capitalize().strip(" .,!?-_")
        onthefly_task_emb = language_encoder.encode([task_clean])[0]

        # Get images and compute embeddings
        onthefly_image_embs = {}
        for cam in camera_keys:
            img = original_item[cam]
            # Convert to numpy if needed
            if isinstance(img, torch.Tensor):
                if img.ndim == 3:  # (C, H, W)
                    # Assumes float values in [0, 1] — TODO(review): confirm.
                    img_np = (img.permute(1, 2, 0).numpy() * 255).astype(np.uint8)
                else:
                    raise ValueError(f"Unexpected image shape: {img.shape}")
            else:
                img_np = np.array(img)

            onthefly_image_embs[cam] = image_encoder.encode([img_np])[0]

        # Task embedding comparison
        task_sim = cosine_similarity(precomputed_task_emb, onthefly_task_emb)
        task_similarities.append(task_sim)

        # Image embedding comparison
        for cam in camera_keys:
            img_sim = cosine_similarity(precomputed_image_embs[cam], onthefly_image_embs[cam])
            image_similarities[cam].append(img_sim)

    # Results: the mean similarity per modality must clear its threshold.
    print("\nResults:")
    task_sim_threshold = 0.99
    img_sim_threshold = 0.99

    task_mean_sim = np.mean(task_similarities)
    task_pass = task_mean_sim >= task_sim_threshold

    print(f"  Task: {task_mean_sim:.4f} {'✓' if task_pass else '✗'}")

    for cam in camera_keys:
        cam_mean_sim = np.mean(image_similarities[cam])
        cam_pass = cam_mean_sim >= img_sim_threshold
        print(f"  {cam}: {cam_mean_sim:.4f} {'✓' if cam_pass else '✗'}")

    image_pass = all(np.mean(image_similarities[cam]) >= img_sim_threshold for cam in camera_keys)

    print()
    if task_pass and image_pass:
        print("✓ PASSED")
    else:
        print("✗ FAILED")
|
||||
|
||||
|
||||
def main():
    """CLI entry point: parse arguments, build the encoders, and run validation.

    NOTE(review): the encoders chosen here must match the ones used to
    generate the embeddings, otherwise the similarity check will fail.
    """
    parser = argparse.ArgumentParser(
        description="Validate and compare pre-computed embeddings with on-the-fly embeddings",
        formatter_class=argparse.RawDescriptionHelpFormatter,
        epilog="""
Example:
  python src/lerobot/datasets/generating_embeddings/validate_embeddings.py \\
    --original-repo-id lerobot/utokyo_xarm_bimanual \\
    --embeddings-repo-id lerobot/utokyo_xarm_bimanual_embeddings \\
    --image-encoder dinov2_vitb14 \\
    --language-encoder minilm-l12 \\
    --num-samples 20
        """,
    )
    parser.add_argument("--original-repo-id", type=str, required=True, help="Original dataset repository ID")
    parser.add_argument(
        "--embeddings-repo-id",
        type=str,
        required=True,
        help="Dataset with pre-computed embeddings repository ID",
    )
    parser.add_argument(
        "--image-encoder",
        type=str,
        default="dinov2_vitb14",
        help="Image encoder to use (default: dinov2_vitb14)",
    )
    parser.add_argument(
        "--language-encoder",
        type=str,
        default="minilm-l12",
        help="Language encoder to use (default: minilm-l12)",
    )
    parser.add_argument(
        "--num-samples",
        type=int,
        default=10,
        help="Number of samples to validate (default: 10)",
    )
    parser.add_argument(
        "--device",
        type=str,
        default="cuda",
        help="Device to use for encoding (default: cuda)",
    )

    args = parser.parse_args()

    # Load encoders
    image_encoder = get_image_encoder(args.image_encoder, device=args.device)
    language_encoder = get_language_encoder(args.language_encoder, device=args.device)

    # Validate embeddings
    validate_embeddings(
        original_repo_id=args.original_repo_id,
        embeddings_repo_id=args.embeddings_repo_id,
        image_encoder=image_encoder,
        language_encoder=language_encoder,
        num_samples=args.num_samples,
        device=args.device,
    )
|
||||
|
||||
|
||||
# Allow running this module directly as a script.
if __name__ == "__main__":
    main()
|
||||
@@ -110,8 +110,8 @@ def worker_thread_loop(queue: queue.Queue):
|
||||
if item is None:
|
||||
queue.task_done()
|
||||
break
|
||||
image_array, fpath, compress_level = item
|
||||
write_image(image_array, fpath, compress_level)
|
||||
image_array, fpath = item
|
||||
write_image(image_array, fpath)
|
||||
queue.task_done()
|
||||
|
||||
|
||||
@@ -169,13 +169,11 @@ class AsyncImageWriter:
|
||||
p.start()
|
||||
self.processes.append(p)
|
||||
|
||||
def save_image(
|
||||
self, image: torch.Tensor | np.ndarray | PIL.Image.Image, fpath: Path, compress_level: int = 1
|
||||
):
|
||||
def save_image(self, image: torch.Tensor | np.ndarray | PIL.Image.Image, fpath: Path):
|
||||
if isinstance(image, torch.Tensor):
|
||||
# Convert tensor to numpy array to minimize main process time
|
||||
image = image.cpu().numpy()
|
||||
self.queue.put((image, fpath, compress_level))
|
||||
self.queue.put((image, fpath))
|
||||
|
||||
def wait_until_done(self):
|
||||
self.queue.join()
|
||||
|
||||
@@ -13,7 +13,6 @@
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
import concurrent.futures
|
||||
import contextlib
|
||||
import logging
|
||||
import shutil
|
||||
@@ -540,15 +539,6 @@ class LeRobotDatasetMetadata:
|
||||
return obj
|
||||
|
||||
|
||||
def _encode_video_worker(video_key: str, episode_index: int, root: Path, fps: int) -> Path:
    """Encode one episode's frames for ``video_key`` into a temporary .mp4.

    Returns the path of the encoded video; the source frame directory is
    deleted once encoding succeeds.
    """
    # Encode into a fresh temporary directory created inside the dataset root.
    scratch_dir = Path(tempfile.mkdtemp(dir=root))
    out_path = scratch_dir / f"{video_key}_{episode_index:03d}.mp4"

    # Locate the directory holding this episode's extracted image frames.
    first_frame = DEFAULT_IMAGE_PATH.format(image_key=video_key, episode_index=episode_index, frame_index=0)
    frames_dir = (root / first_frame).parent

    encode_video_frames(frames_dir, out_path, fps, overwrite=True)
    # The raw frames are no longer needed after a successful encode.
    shutil.rmtree(frames_dir)
    return out_path
|
||||
|
||||
|
||||
class LeRobotDataset(torch.utils.data.Dataset):
|
||||
def __init__(
|
||||
self,
|
||||
@@ -1081,7 +1071,6 @@ class LeRobotDataset(torch.utils.data.Dataset):
|
||||
ep_buffer[key] = current_ep_idx if key == "episode_index" else []
|
||||
return ep_buffer
|
||||
|
||||
# TODO(Steven): consider move this to utils
|
||||
def _get_image_file_path(self, episode_index: int, image_key: str, frame_index: int) -> Path:
|
||||
fpath = DEFAULT_IMAGE_PATH.format(
|
||||
image_key=image_key, episode_index=episode_index, frame_index=frame_index
|
||||
@@ -1091,15 +1080,13 @@ class LeRobotDataset(torch.utils.data.Dataset):
|
||||
    def _get_image_file_dir(self, episode_index: int, image_key: str) -> Path:
        """Return the directory holding all frame images of `image_key` for this episode."""
        # frame_index=0 is only used to build a representative path; its parent
        # directory is shared by every frame of the episode.
        return self._get_image_file_path(episode_index, image_key, frame_index=0).parent
|
||||
|
||||
def _save_image(
|
||||
self, image: torch.Tensor | np.ndarray | PIL.Image.Image, fpath: Path, compress_level: int = 1
|
||||
) -> None:
|
||||
def _save_image(self, image: torch.Tensor | np.ndarray | PIL.Image.Image, fpath: Path) -> None:
|
||||
if self.image_writer is None:
|
||||
if isinstance(image, torch.Tensor):
|
||||
image = image.cpu().numpy()
|
||||
write_image(image, fpath, compress_level=compress_level)
|
||||
write_image(image, fpath)
|
||||
else:
|
||||
self.image_writer.save_image(image=image, fpath=fpath, compress_level=compress_level)
|
||||
self.image_writer.save_image(image=image, fpath=fpath)
|
||||
|
||||
def add_frame(self, frame: dict) -> None:
|
||||
"""
|
||||
@@ -1137,19 +1124,14 @@ class LeRobotDataset(torch.utils.data.Dataset):
|
||||
)
|
||||
if frame_index == 0:
|
||||
img_path.parent.mkdir(parents=True, exist_ok=True)
|
||||
compress_level = 1 if self.features[key]["dtype"] == "video" else 6
|
||||
self._save_image(frame[key], img_path, compress_level)
|
||||
self._save_image(frame[key], img_path)
|
||||
self.episode_buffer[key].append(str(img_path))
|
||||
else:
|
||||
self.episode_buffer[key].append(frame[key])
|
||||
|
||||
self.episode_buffer["size"] += 1
|
||||
|
||||
def save_episode(
|
||||
self,
|
||||
episode_data: dict | None = None,
|
||||
parallel_encoding: bool = True,
|
||||
) -> None:
|
||||
def save_episode(self, episode_data: dict | None = None) -> None:
|
||||
"""
|
||||
This will save to disk the current episode in self.episode_buffer.
|
||||
|
||||
@@ -1161,8 +1143,6 @@ class LeRobotDataset(torch.utils.data.Dataset):
|
||||
episode_data (dict | None, optional): Dict containing the episode data to save. If None, this will
|
||||
save the current episode in self.episode_buffer, which is filled with 'add_frame'. Defaults to
|
||||
None.
|
||||
parallel_encoding (bool, optional): If True, encode videos in parallel using ProcessPoolExecutor.
|
||||
Defaults to True on Linux, False on macOS as it tends to use all the CPU available already.
|
||||
"""
|
||||
episode_buffer = episode_data if episode_data is not None else self.episode_buffer
|
||||
|
||||
@@ -1199,40 +1179,8 @@ class LeRobotDataset(torch.utils.data.Dataset):
|
||||
use_batched_encoding = self.batch_encoding_size > 1
|
||||
|
||||
if has_video_keys and not use_batched_encoding:
|
||||
num_cameras = len(self.meta.video_keys)
|
||||
if parallel_encoding and num_cameras > 1:
|
||||
# TODO(Steven): Ideally we would like to control the number of threads per encoding such that:
|
||||
# num_cameras * num_threads = (total_cpu -1)
|
||||
with concurrent.futures.ProcessPoolExecutor(max_workers=num_cameras) as executor:
|
||||
future_to_key = {
|
||||
executor.submit(
|
||||
_encode_video_worker,
|
||||
video_key,
|
||||
episode_index,
|
||||
self.root,
|
||||
self.fps,
|
||||
): video_key
|
||||
for video_key in self.meta.video_keys
|
||||
}
|
||||
|
||||
results = {}
|
||||
for future in concurrent.futures.as_completed(future_to_key):
|
||||
video_key = future_to_key[future]
|
||||
try:
|
||||
temp_path = future.result()
|
||||
results[video_key] = temp_path
|
||||
except Exception as exc:
|
||||
logging.error(f"Video encoding failed for {video_key}: {exc}")
|
||||
raise exc
|
||||
|
||||
for video_key in self.meta.video_keys:
|
||||
temp_path = results[video_key]
|
||||
ep_metadata.update(
|
||||
self._save_episode_video(video_key, episode_index, temp_path=temp_path)
|
||||
)
|
||||
else:
|
||||
for video_key in self.meta.video_keys:
|
||||
ep_metadata.update(self._save_episode_video(video_key, episode_index))
|
||||
for video_key in self.meta.video_keys:
|
||||
ep_metadata.update(self._save_episode_video(video_key, episode_index))
|
||||
|
||||
# `meta.save_episode` need to be executed after encoding the videos
|
||||
self.meta.save_episode(episode_index, episode_length, episode_tasks, ep_stats, ep_metadata)
|
||||
@@ -1397,18 +1345,9 @@ class LeRobotDataset(torch.utils.data.Dataset):
|
||||
|
||||
return metadata
|
||||
|
||||
def _save_episode_video(
|
||||
self,
|
||||
video_key: str,
|
||||
episode_index: int,
|
||||
temp_path: Path | None = None,
|
||||
) -> dict:
|
||||
def _save_episode_video(self, video_key: str, episode_index: int) -> dict:
|
||||
# Encode episode frames into a temporary video
|
||||
if temp_path is None:
|
||||
ep_path = self._encode_temporary_episode_video(video_key, episode_index)
|
||||
else:
|
||||
ep_path = temp_path
|
||||
|
||||
ep_path = self._encode_temporary_episode_video(video_key, episode_index)
|
||||
ep_size_in_mb = get_file_size_in_mb(ep_path)
|
||||
ep_duration_in_s = get_video_duration_in_s(ep_path)
|
||||
|
||||
@@ -1526,7 +1465,11 @@ class LeRobotDataset(torch.utils.data.Dataset):
|
||||
Note: `encode_video_frames` is a blocking call. Making it asynchronous shouldn't speedup encoding,
|
||||
since video encoding with ffmpeg is already using multithreading.
|
||||
"""
|
||||
return _encode_video_worker(video_key, episode_index, self.root, self.fps)
|
||||
temp_path = Path(tempfile.mkdtemp(dir=self.root)) / f"{video_key}_{episode_index:03d}.mp4"
|
||||
img_dir = self._get_image_file_dir(episode_index, video_key)
|
||||
encode_video_frames(img_dir, temp_path, self.fps, overwrite=True)
|
||||
shutil.rmtree(img_dir)
|
||||
return temp_path
|
||||
|
||||
@classmethod
|
||||
def create(
|
||||
|
||||
@@ -0,0 +1,151 @@
|
||||
#!/usr/bin/env python
|
||||
|
||||
# Copyright 2025 The HuggingFace Inc. team. All rights reserved.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
"""
|
||||
SARM Temporal Sampler for reward model training.
|
||||
|
||||
Samples frames uniformly from episodes for SARM's 9-frame symmetric pattern:
|
||||
- 1 initial frame + 4 frames before + current + 3 frames after
|
||||
|
||||
Boundary handling: clamp to first/last frame when indices go out of bounds.
|
||||
This enables truly uniform sampling across entire episodes.
|
||||
"""
|
||||
|
||||
import logging
|
||||
from typing import Iterator, Optional
|
||||
import numpy as np
|
||||
import torch
|
||||
from torch.utils.data import Sampler
|
||||
import random
|
||||
|
||||
|
||||
class SARMTemporalSampler(Sampler):
    """
    Temporal sampler for SARM reward model training with symmetric/bidirectional sampling.

    SARM uses 9 frames per sample:
    - Frame 0: Initial frame of the episode (always frame 0)
    - Frames 1-8: Symmetric context around current frame
      Pattern: [t-4*gap, t-3*gap, t-2*gap, t-gap, t, t+gap, t+2*gap, t+3*gap]

    Boundary handling:
    - Early frames: backward indices clamp to 0 (e.g., [0,0,0,5,35,65,95,125])
    - Late frames: forward indices clamp to last frame (e.g., [850,880,910,940,970,1000,1000,1000])

    This enables truly uniform sampling across entire episodes.

    Args:
        dataset_from_index: Start indices of episodes (global dataset indices)
        dataset_to_index: End indices of episodes (global dataset indices)
        frame_gap: Gap between consecutive frames (default: 30 = 1 second at 30fps)
        shuffle: Whether to shuffle sampling order
        seed: Random seed for reproducibility
        samples_per_epoch: Number of samples per epoch (default: 6400)
        min_episode_length: Minimum episode length to include (default: 1)
    """

    def __init__(
        self,
        dataset_from_index: np.ndarray,
        dataset_to_index: np.ndarray,
        frame_gap: int = 30,
        shuffle: bool = True,
        seed: Optional[int] = None,
        samples_per_epoch: int = 6400,
        min_episode_length: int = 1,
    ):
        self.dataset_from_index = np.array(dataset_from_index)
        self.dataset_to_index = np.array(dataset_to_index)
        self.frame_gap = frame_gap
        self.shuffle = shuffle
        self.samples_per_epoch = samples_per_epoch
        self.min_episode_length = min_episode_length

        # Fix: use a sampler-local RNG instead of seeding the process-global
        # `random` / `np.random` state, which silently changed the randomness
        # of every other component in the program.
        if seed is not None:
            self.seed = seed
            self._rng = np.random.default_rng(seed)
            self.generator = torch.Generator().manual_seed(seed)
        else:
            self._rng = np.random.default_rng()
            self.generator = torch.Generator()

        # Compute valid episodes and sampling positions (ALL frames for uniform sampling)
        self._compute_valid_positions()

        logging.info(
            f"SARMTemporalSampler: {len(self.valid_episodes)} valid episodes, "
            f"{len(self.all_valid_positions)} positions (uniform sampling), "
            f"{self.samples_per_epoch} samples per epoch, "
            f"frame_gap={frame_gap}, symmetric bidirectional pattern"
        )

    def _compute_valid_positions(self):
        """Compute valid episodes and ALL sampling positions for uniform sampling.

        With symmetric bidirectional sampling, we can sample from ANY frame:
        - Early frames: backward indices clamp to first frame
        - Late frames: forward indices clamp to last frame

        Raises:
            ValueError: If no episode reaches `min_episode_length` frames.
        """
        valid_episodes = []
        all_valid_positions = []

        for ep_idx in range(len(self.dataset_from_index)):
            ep_start = self.dataset_from_index[ep_idx]
            ep_end = self.dataset_to_index[ep_idx]
            episode_length = ep_end - ep_start

            # Include all episodes with at least min_episode_length frames
            if episode_length >= self.min_episode_length:
                valid_episodes.append((ep_idx, ep_start, ep_end))
                # Include ALL positions in the episode (truly uniform sampling)
                all_valid_positions.extend(range(ep_start, ep_end))

        self.valid_episodes = np.array(valid_episodes)
        self.all_valid_positions = np.array(all_valid_positions)

        if len(self.all_valid_positions) == 0:
            raise ValueError(
                f"No valid sampling positions found! "
                f"Check that episodes have at least {self.min_episode_length} frames."
            )

    def __len__(self) -> int:
        """Number of samples drawn per epoch (fixed, independent of dataset size)."""
        return self.samples_per_epoch

    def __iter__(self) -> Iterator[int]:
        """
        Yields global dataset indices for uniform sampling across episodes.

        Each yielded index represents the "current frame" position.
        The dataset's observation_delta_indices then handles loading:
        - Frame 0: Episode initial frame (via large negative delta clamping)
        - Frames 1-8: Symmetric context around current frame (with boundary clamping)

        For early frames: backward indices clamp to first frame (progress ~0%)
        For late frames: forward indices clamp to last frame (progress ~100%)
        """
        n_positions = len(self.all_valid_positions)
        if self.shuffle:
            # Randomly sample from all valid positions using the sampler-local RNG.
            for _ in range(self.samples_per_epoch):
                idx = int(self._rng.integers(0, n_positions))
                yield int(self.all_valid_positions[idx])
        else:
            # Sequential sampling with wrap-around
            for i in range(self.samples_per_epoch):
                yield int(self.all_valid_positions[i % n_positions])
|
||||
@@ -49,7 +49,7 @@ from lerobot.utils.utils import SuppressProgressBars, is_valid_numpy_dtype_strin
|
||||
|
||||
DEFAULT_CHUNK_SIZE = 1000 # Max number of files per chunk
|
||||
DEFAULT_DATA_FILE_SIZE_IN_MB = 100 # Max size per file
|
||||
DEFAULT_VIDEO_FILE_SIZE_IN_MB = 200 # Max size per file
|
||||
DEFAULT_VIDEO_FILE_SIZE_IN_MB = 500 # Max size per file
|
||||
|
||||
INFO_PATH = "meta/info.json"
|
||||
STATS_PATH = "meta/stats.json"
|
||||
|
||||
@@ -311,7 +311,6 @@ def encode_video_frames(
|
||||
fast_decode: int = 0,
|
||||
log_level: int | None = av.logging.ERROR,
|
||||
overwrite: bool = False,
|
||||
preset: int | None = None,
|
||||
) -> None:
|
||||
"""More info on ffmpeg arguments tuning on `benchmark/video/README.md`"""
|
||||
# Check encoder availability
|
||||
@@ -360,9 +359,6 @@ def encode_video_frames(
|
||||
value = f"fast-decode={fast_decode}" if vcodec == "libsvtav1" else "fastdecode"
|
||||
video_options[key] = value
|
||||
|
||||
if vcodec == "libsvtav1":
|
||||
video_options["preset"] = str(preset) if preset is not None else "12"
|
||||
|
||||
# Set logging level
|
||||
if log_level is not None:
|
||||
# "While less efficient, it is generally preferable to modify logging with Python's logging"
|
||||
|
||||
@@ -35,6 +35,7 @@ from lerobot.policies.pi0.configuration_pi0 import PI0Config
|
||||
from lerobot.policies.pi05.configuration_pi05 import PI05Config
|
||||
from lerobot.policies.pretrained import PreTrainedPolicy
|
||||
from lerobot.policies.sac.configuration_sac import SACConfig
|
||||
from lerobot.policies.sarm.configuration_sarm import SARMConfig
|
||||
from lerobot.policies.sac.reward_model.configuration_classifier import RewardClassifierConfig
|
||||
from lerobot.policies.smolvla.configuration_smolvla import SmolVLAConfig
|
||||
from lerobot.policies.tdmpc.configuration_tdmpc import TDMPCConfig
|
||||
@@ -103,6 +104,10 @@ def get_policy_class(name: str) -> type[PreTrainedPolicy]:
|
||||
from lerobot.policies.smolvla.modeling_smolvla import SmolVLAPolicy
|
||||
|
||||
return SmolVLAPolicy
|
||||
elif name == "sarm":
|
||||
from lerobot.policies.sarm.modeling_sarm import SARMRewardModel
|
||||
|
||||
return SARMRewardModel
|
||||
elif name == "groot":
|
||||
from lerobot.policies.groot.modeling_groot import GrootPolicy
|
||||
|
||||
@@ -322,6 +327,14 @@ def make_pre_post_processors(
|
||||
dataset_stats=kwargs.get("dataset_stats"),
|
||||
)
|
||||
|
||||
elif isinstance(policy_cfg, SARMConfig):
|
||||
from lerobot.policies.sarm.processor_sarm import make_sarm_pre_post_processors
|
||||
|
||||
processors = make_sarm_pre_post_processors(
|
||||
config=policy_cfg,
|
||||
dataset_stats=kwargs.get("dataset_stats"),
|
||||
dataset_meta=kwargs.get("dataset_meta"),
|
||||
)
|
||||
elif isinstance(policy_cfg, GrootConfig):
|
||||
from lerobot.policies.groot.processor_groot import make_groot_pre_post_processors
|
||||
|
||||
@@ -405,6 +418,13 @@ def make_policy(
|
||||
if not cfg.input_features:
|
||||
cfg.input_features = {key: ft for key, ft in features.items() if key not in cfg.output_features}
|
||||
kwargs["config"] = cfg
|
||||
|
||||
# Pass dataset_stats to the policy if available (needed for some policies like SARM)
|
||||
if ds_meta is not None and hasattr(ds_meta, 'stats'):
|
||||
kwargs["dataset_stats"] = ds_meta.stats
|
||||
|
||||
if ds_meta is not None:
|
||||
kwargs["dataset_meta"] = ds_meta
|
||||
|
||||
if cfg.pretrained_path:
|
||||
# Load a pretrained policy and override the config if needed (for example, if there are inference-time
|
||||
|
||||
@@ -538,8 +538,6 @@ class PI05Pytorch(nn.Module): # see openpi `PI0Pytorch`
|
||||
if config.compile_model:
|
||||
torch.set_float32_matmul_precision("high")
|
||||
self.sample_actions = torch.compile(self.sample_actions, mode=config.compile_mode)
|
||||
# Also compile the main forward pass used during training
|
||||
self.forward = torch.compile(self.forward, mode=config.compile_mode)
|
||||
|
||||
msg = """An incorrect transformer version is used, please create an issue on https://github.com/huggingface/lerobot/issues"""
|
||||
|
||||
|
||||
@@ -14,5 +14,21 @@
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
from .config_unitree_g1 import UnitreeG1Config
|
||||
from .unitree_g1 import UnitreeG1
|
||||
from lerobot.policies.sarm.configuration_sarm import SARMConfig
|
||||
from lerobot.policies.sarm.modeling_sarm import (
|
||||
SARMRewardModel,
|
||||
SARMTransformer,
|
||||
)
|
||||
from lerobot.policies.sarm.processor_sarm import (
|
||||
SARMEncodingProcessorStep,
|
||||
make_sarm_pre_post_processors,
|
||||
)
|
||||
|
||||
__all__ = [
|
||||
"SARMConfig",
|
||||
"SARMRewardModel",
|
||||
"SARMTransformer",
|
||||
"SARMEncodingProcessorStep",
|
||||
"make_sarm_pre_post_processors",
|
||||
]
|
||||
|
||||
@@ -0,0 +1,186 @@
|
||||
#!/usr/bin/env python
|
||||
|
||||
# Copyright 2025 The HuggingFace Inc. team. All rights reserved.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
from dataclasses import dataclass, field
|
||||
|
||||
from lerobot.configs.policies import PreTrainedConfig
|
||||
from lerobot.configs.types import PolicyFeature, FeatureType, NormalizationMode
|
||||
from lerobot.optim.optimizers import AdamWConfig
|
||||
from lerobot.optim.schedulers import CosineDecayWithWarmupSchedulerConfig
|
||||
|
||||
|
||||
@PreTrainedConfig.register_subclass("sarm")
|
||||
@dataclass
|
||||
class SARMConfig(PreTrainedConfig):
|
||||
"""Configuration class for SARM (Stage-Aware Reward Modeling)"""
|
||||
|
||||
# CLIP params
|
||||
image_dim: int = 512
|
||||
text_dim: int = 512
|
||||
num_frames: int = 9 # 1 initial + 8 consecutive frames
|
||||
frame_gap: int = 30 # Frame gap between frames (at 30 fps = 1 second)
|
||||
|
||||
# Architecture params
|
||||
hidden_dim: int = 768
|
||||
num_heads: int = 12
|
||||
num_layers: int = 8
|
||||
max_state_dim: int = 32
|
||||
num_stages: int = 5 # Number of task stages (auto-updated from annotations if available)
|
||||
subtask_names: list | None = None # List of subtask names (auto-populated from annotations)
|
||||
temporal_proportions: list | None = None # Temporal proportions for each stage (auto-computed from annotations)
|
||||
max_length: int = num_frames # Maximum video sequence length (matches num_frames)
|
||||
use_temporal_sampler: bool = True # Always enable temporal sequence loading
|
||||
|
||||
# Training params
|
||||
batch_size: int = 64
|
||||
clip_batch_size: int = 64 # Batch size for CLIP encoding
|
||||
dropout: float = 0.1
|
||||
stage_loss_weight: float = 1.0 # Weight for stage classification loss when using subtask annotations
|
||||
|
||||
pretrained_model_path: str | None = None
|
||||
device: str | None = None
|
||||
|
||||
# Processor settings
|
||||
image_key: str = "observation.images.top" # Key for image used from the dataset
|
||||
|
||||
# State key in the dataset (for normalization)
|
||||
state_key: str = "observation.state"
|
||||
|
||||
# Populated by the processor (video_features, state_features, text_features)
|
||||
input_features: dict = field(default_factory=lambda: {})
|
||||
|
||||
# Output features
|
||||
output_features: dict = field(default_factory=lambda: {
|
||||
"stage": PolicyFeature(shape=(9, 5), type=FeatureType.REWARD),
|
||||
"progress": PolicyFeature(shape=(9, 1), type=FeatureType.REWARD),
|
||||
})
|
||||
|
||||
normalization_mapping: dict[str, NormalizationMode] = field(
|
||||
default_factory=lambda: {
|
||||
"VISUAL": NormalizationMode.IDENTITY,
|
||||
"STATE": NormalizationMode.MEAN_STD,
|
||||
"LANGUAGE": NormalizationMode.IDENTITY,
|
||||
"REWARD": NormalizationMode.IDENTITY,
|
||||
}
|
||||
)
|
||||
|
||||
def __post_init__(self):
|
||||
super().__post_init__()
|
||||
|
||||
# Add the image_key as VISUAL
|
||||
if self.image_key:
|
||||
self.input_features[self.image_key] = PolicyFeature(
|
||||
shape=(480, 640, 3),
|
||||
type=FeatureType.VISUAL
|
||||
)
|
||||
|
||||
# Add state_key as STATE
|
||||
self.input_features[self.state_key] = PolicyFeature(
|
||||
shape=(self.max_state_dim,), # Single frame state, temporal sampling handles sequence
|
||||
type=FeatureType.STATE
|
||||
)
|
||||
|
||||
# Update output features with actual dimensions
|
||||
self.output_features["stage"] = PolicyFeature(
|
||||
shape=(self.num_frames, self.num_stages),
|
||||
type=FeatureType.REWARD
|
||||
)
|
||||
self.output_features["progress"] = PolicyFeature(
|
||||
shape=(self.num_frames, 1),
|
||||
type=FeatureType.REWARD
|
||||
)
|
||||
|
||||
# Validate configuration
|
||||
if self.hidden_dim % self.num_heads != 0:
|
||||
raise ValueError(
|
||||
f"hidden_dim ({self.hidden_dim}) must be divisible by num_heads ({self.num_heads})"
|
||||
)
|
||||
|
||||
if self.max_length != self.num_frames:
|
||||
raise ValueError(
|
||||
f"max_length ({self.max_length}) must equal num_frames ({self.num_frames})"
|
||||
)
|
||||
|
||||
if self.num_stages < 2:
|
||||
raise ValueError(f"num_stages must be at least 2, got {self.num_stages}")
|
||||
|
||||
def get_optimizer_preset(self) -> AdamWConfig:
|
||||
"""Get default optimizer configuration for SARM training."""
|
||||
return AdamWConfig(
|
||||
lr=5e-5,
|
||||
weight_decay=1e-3,
|
||||
betas=(0.9, 0.999),
|
||||
eps=1e-8,
|
||||
)
|
||||
|
||||
def get_scheduler_preset(self) -> CosineDecayWithWarmupSchedulerConfig:
|
||||
"""Get default learning rate scheduler configuration."""
|
||||
return CosineDecayWithWarmupSchedulerConfig(
|
||||
peak_lr=5e-5,
|
||||
decay_lr=5e-6,
|
||||
num_warmup_steps=500,
|
||||
num_decay_steps=50000,
|
||||
)
|
||||
|
||||
def validate_features(self) -> None:
|
||||
"""Validate input and output features."""
|
||||
pass
|
||||
|
||||
@property
|
||||
def observation_delta_indices(self) -> list[int]:
|
||||
"""Load frames for SARM temporal sampling with SYMMETRIC/BIDIRECTIONAL pattern.
|
||||
|
||||
The model uses 9 frames with symmetric context around current frame:
|
||||
- Frame 0: Initial frame of the episode (clamped via large negative delta)
|
||||
- Frames 1-8: Symmetric context: 4 before + current + 3 after
|
||||
|
||||
Pattern: [initial, t-4*gap, t-3*gap, t-2*gap, t-gap, t, t+gap, t+2*gap, t+3*gap]
|
||||
|
||||
Boundary handling (done by dataset loader):
|
||||
- Early frames: backward indices clamp to 0 (first frame)
|
||||
- Late frames: forward indices clamp to episode end (last frame)
|
||||
|
||||
This enables truly uniform sampling across entire episodes.
|
||||
|
||||
Returns:
|
||||
9 delta indices: [-1_000_000, -4*gap, -3*gap, -2*gap, -gap, 0, gap, 2*gap, 3*gap]
|
||||
"""
|
||||
initial_frame_delta = -1_000_000
|
||||
|
||||
# Symmetric pattern: 4 frames before, current (0), 3 frames after = 8 context frames
|
||||
symmetric_deltas = [
|
||||
-4 * self.frame_gap,
|
||||
-3 * self.frame_gap,
|
||||
-2 * self.frame_gap,
|
||||
-1 * self.frame_gap,
|
||||
0, # current frame
|
||||
1 * self.frame_gap,
|
||||
2 * self.frame_gap,
|
||||
3 * self.frame_gap,
|
||||
]
|
||||
|
||||
return [initial_frame_delta] + symmetric_deltas
|
||||
|
||||
@property
|
||||
def action_delta_indices(self) -> None:
|
||||
"""SARM is a reward model, not an action policy."""
|
||||
return None
|
||||
|
||||
@property
|
||||
def reward_delta_indices(self) -> None:
|
||||
"""SARM doesn't use delta rewards."""
|
||||
return None
|
||||
|
||||
@@ -0,0 +1,650 @@
|
||||
#!/usr/bin/env python
|
||||
|
||||
# Copyright 2025 The HuggingFace Inc. team. All rights reserved.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
import logging
|
||||
from typing import List, Union, Optional
|
||||
import random
|
||||
|
||||
import numpy as np
|
||||
import torch
|
||||
import torch.nn as nn
|
||||
import torch.nn.functional as F
|
||||
from PIL import Image
|
||||
from transformers import CLIPModel, CLIPProcessor
|
||||
from torch import Tensor
|
||||
|
||||
from lerobot.policies.sarm.configuration_sarm import SARMConfig
|
||||
from lerobot.policies.sarm.sarm_utils import compute_cumulative_progress_batch, pad_state_to_max_dim
|
||||
from lerobot.policies.pretrained import PreTrainedPolicy
|
||||
|
||||
class SARMTransformer(nn.Module):
|
||||
"""
|
||||
SARM Transformer model for stage-aware reward prediction.
|
||||
|
||||
This model has a dual-head architecture:
|
||||
1. Stage estimator: Predicts the high-level task stage (classification)
|
||||
2. Subtask estimator: Predicts fine-grained progress within the stage (regression)
|
||||
"""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
video_dim: int = 512,
|
||||
text_dim: int = 512,
|
||||
max_state_dim: int = 32,
|
||||
hidden_dim: int = 768,
|
||||
num_heads: int = 12,
|
||||
num_layers: int = 8,
|
||||
num_stages: int = 5,
|
||||
max_length: int = 9,
|
||||
dropout: float = 0.1,
|
||||
temporal_proportions: list[float] | None = None
|
||||
):
|
||||
super().__init__()
|
||||
self.hidden_dim = hidden_dim
|
||||
self.max_length = max_length
|
||||
self.num_stages = num_stages
|
||||
self.max_state_dim = max_state_dim
|
||||
|
||||
if temporal_proportions is None:
|
||||
raise ValueError(
|
||||
"temporal_proportions is required for SARM. "
|
||||
"Provide subtask annotations in your dataset or set temporal_proportions in config."
|
||||
)
|
||||
|
||||
# ᾱ_k: proportion for each stage
|
||||
alpha = torch.tensor(temporal_proportions, dtype=torch.float32)
|
||||
|
||||
# P_k: cumulative proportion up to stage k (P_0 = 0)
|
||||
cumulative = torch.zeros(num_stages + 1, dtype=torch.float32)
|
||||
cumulative[1:] = torch.cumsum(alpha, dim=0)
|
||||
self.register_buffer('alpha', alpha)
|
||||
self.register_buffer('cumulative_prior', cumulative)
|
||||
|
||||
self.video_proj = nn.Linear(video_dim, hidden_dim)
|
||||
self.text_proj = nn.Linear(text_dim, hidden_dim)
|
||||
self.state_proj = nn.Linear(max_state_dim, hidden_dim)
|
||||
|
||||
# Position embedding only for the first frame
|
||||
self.first_pos_embed = nn.Parameter(torch.randn(1, hidden_dim))
|
||||
|
||||
encoder_layer = nn.TransformerEncoderLayer(
|
||||
d_model=hidden_dim,
|
||||
nhead=num_heads,
|
||||
dim_feedforward=hidden_dim * 4,
|
||||
dropout=dropout,
|
||||
batch_first=True
|
||||
)
|
||||
self.transformer = nn.TransformerEncoder(encoder_layer, num_layers=num_layers)
|
||||
|
||||
# Stage estimator head (classification)
|
||||
self.stage_head = nn.Sequential(
|
||||
nn.Linear(hidden_dim, 512),
|
||||
nn.LayerNorm(512),
|
||||
nn.GELU(),
|
||||
nn.Dropout(dropout),
|
||||
nn.Linear(512, num_stages)
|
||||
)
|
||||
|
||||
# Subtask estimator head (regression)
|
||||
self.stage_embedding = nn.Embedding(num_stages, hidden_dim // 4)
|
||||
subtask_input_dim = hidden_dim + hidden_dim // 4
|
||||
self.subtask_head = nn.Sequential(
|
||||
nn.Linear(subtask_input_dim, 512),
|
||||
nn.LayerNorm(512),
|
||||
nn.GELU(),
|
||||
nn.Dropout(dropout),
|
||||
nn.Linear(512, 1),
|
||||
nn.Sigmoid()
|
||||
)
|
||||
|
||||
# Attention mask
|
||||
self.register_buffer("attention_mask", None, persistent=False)
|
||||
|
||||
def _get_attention_mask(self, seq_length: int, device: torch.device) -> torch.Tensor:
|
||||
"""Generate or retrieve cached causal attention mask."""
|
||||
if self.attention_mask is None or self.attention_mask.shape[0] != seq_length:
|
||||
# Create causal mask
|
||||
mask = nn.Transformer.generate_square_subsequent_mask(seq_length, device=device)
|
||||
self.attention_mask = mask
|
||||
return self.attention_mask
|
||||
|
||||
def forward(
|
||||
self,
|
||||
video_frames: torch.Tensor,
|
||||
text_embed: torch.Tensor,
|
||||
state_features: Optional[torch.Tensor] = None
|
||||
) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
|
||||
"""
|
||||
Forward pass through the SARM transformer.
|
||||
|
||||
Args:
|
||||
video_frames: Video frame embeddings (batch_size, seq_len, video_dim)
|
||||
text_embed: Text embeddings (batch_size, text_dim)
|
||||
state_features: Joint state features (batch_size, seq_len, state_dim)
|
||||
|
||||
Returns:
|
||||
Tuple of:
|
||||
- Stage logits for each frame (batch_size, seq_len, num_stages)
|
||||
- Stage probabilities (batch_size, seq_len, num_stages)
|
||||
- Progress predictions for each frame (batch_size, seq_len, 1)
|
||||
"""
|
||||
# Project inputs to common dimension
|
||||
video_embed = self.video_proj(video_frames) # [batch_size, seq_len, hidden_dim]
|
||||
text_embed = self.text_proj(text_embed).unsqueeze(1) # [batch_size, 1, hidden_dim]
|
||||
|
||||
# Pad state features to max_state_dim before projection
|
||||
state_features_padded = pad_state_to_max_dim(state_features, self.max_state_dim)
|
||||
|
||||
state_embed = self.state_proj(state_features_padded) # [batch_size, seq_len, hidden_dim]
|
||||
|
||||
# Fuse video and state features
|
||||
video_embed = video_embed + state_embed
|
||||
|
||||
# Add positional embedding to first video frame
|
||||
video_embed[:, 0] += self.first_pos_embed
|
||||
|
||||
# Combine sequence: [text, video_frames]
|
||||
sequence = torch.cat([text_embed, video_embed], dim=1)
|
||||
|
||||
# Get causal attention mask
|
||||
seq_length = sequence.shape[1]
|
||||
attention_mask = self._get_attention_mask(seq_length, sequence.device)
|
||||
|
||||
# Pass through transformer with causal masking
|
||||
transformed = self.transformer(sequence, mask=attention_mask, is_causal=True)
|
||||
|
||||
# Get frame features
|
||||
frame_features = transformed[:, 1:] # [batch_size, seq_len, hidden_dim]
|
||||
|
||||
# Stage estimation
|
||||
stage_logits = self.stage_head(frame_features) # [batch_size, seq_len, num_stages]
|
||||
stage_probs = F.softmax(stage_logits, dim=-1) # [batch_size, seq_len, num_stages]
|
||||
|
||||
# Get predicted stage indices
|
||||
stage_indices = torch.argmax(stage_probs, dim=-1) # [batch_size, seq_len]
|
||||
|
||||
# Get stage embeddings for conditioning
|
||||
stage_embeds = self.stage_embedding(stage_indices)
|
||||
|
||||
# Concatenate frame features with stage embeddings
|
||||
conditioned_features = torch.cat([frame_features, stage_embeds], dim=-1)
|
||||
|
||||
# Subtask progress estimation (conditioned on stage)
|
||||
# τ̂ = within-subtask progress (0-1)
|
||||
tau_preds = self.subtask_head(conditioned_features) # [batch_size, seq_len, 1]
|
||||
|
||||
# Convert τ̂ to cumulative progress ŷ using Paper Formula (2):
|
||||
# ŷ = P_{k-1} + ᾱ_k × τ̂
|
||||
progress_preds = compute_cumulative_progress_batch(
|
||||
tau_preds, stage_indices, self.alpha, self.cumulative_prior
|
||||
)
|
||||
|
||||
return stage_logits, stage_probs, progress_preds
|
||||
|
||||
|
||||
class SARMRewardModel(PreTrainedPolicy):
|
||||
"""
|
||||
SARM Reward Model for stage-aware task completion rewards.
|
||||
|
||||
Per SARM paper (Appendix A.4): "We employ a frozen clip-vit-base-patch32 encoder
|
||||
to process both RGB image sequences and task descriptions."
|
||||
|
||||
This model combines:
|
||||
- CLIP for encoding video frames AND text descriptions
|
||||
- SARMTransformer for predicting task stage and progress
|
||||
- Optional RA-BC (Reward-Aligned Behavior Cloning) for weighted training
|
||||
"""
|
||||
|
||||
name = "sarm"
|
||||
config_class = SARMConfig
|
||||
|
||||
def __init__(self, config: SARMConfig, dataset_stats: dict | None = None, dataset_meta=None):
|
||||
super().__init__(config, dataset_stats)
|
||||
config.validate_features()
|
||||
self.config = config
|
||||
self.dataset_stats = dataset_stats
|
||||
self.device = torch.device(config.device if config.device else "cuda" if torch.cuda.is_available() else "cpu")
|
||||
|
||||
# Load temporal proportions from dataset
|
||||
if config.temporal_proportions is None and dataset_meta is not None:
|
||||
self._load_temporal_proportions(dataset_meta)
|
||||
|
||||
logging.info("Loading CLIP encoder")
|
||||
self.clip_model = CLIPModel.from_pretrained("openai/clip-vit-base-patch32")
|
||||
self.clip_processor = CLIPProcessor.from_pretrained("openai/clip-vit-base-patch32", use_fast=True)
|
||||
self.clip_model.to(self.device)
|
||||
self.clip_model.eval()
|
||||
|
||||
self.sarm_transformer = SARMTransformer(
|
||||
video_dim=config.image_dim,
|
||||
text_dim=config.text_dim,
|
||||
max_state_dim=config.max_state_dim,
|
||||
hidden_dim=config.hidden_dim,
|
||||
num_heads=config.num_heads,
|
||||
num_layers=config.num_layers,
|
||||
num_stages=config.num_stages,
|
||||
max_length=config.max_length,
|
||||
dropout=config.dropout,
|
||||
temporal_proportions=config.temporal_proportions
|
||||
)
|
||||
self.sarm_transformer.to(self.device)
|
||||
logging.info(f"SARM initialized on {self.device}")
|
||||
|
||||
def _load_temporal_proportions(self, dataset_meta) -> None:
|
||||
"""
|
||||
Load pre-computed temporal proportions from dataset metadata JSON file.
|
||||
|
||||
The temporal proportions are computed during dataset annotation using SARM Paper Formula (1):
|
||||
ᾱ_k = (1/M) × Σ_i (L_{i,k} / T_i)
|
||||
"""
|
||||
import json
|
||||
|
||||
proportions_path = dataset_meta.root / "meta" / "temporal_proportions.json"
|
||||
|
||||
if not proportions_path.exists():
|
||||
raise ValueError(
|
||||
f"Temporal proportions not found at {proportions_path}. "
|
||||
"Run the subtask annotation tool first to compute and save temporal proportions."
|
||||
)
|
||||
|
||||
with open(proportions_path, "r") as f:
|
||||
temporal_proportions_dict = json.load(f)
|
||||
|
||||
# Sort subtask names for consistent ordering
|
||||
subtask_names = sorted(temporal_proportions_dict.keys())
|
||||
|
||||
self.config.num_stages = len(subtask_names)
|
||||
self.config.subtask_names = subtask_names
|
||||
self.config.temporal_proportions = [temporal_proportions_dict[name] for name in subtask_names]
|
||||
|
||||
logging.info(f"Loaded {len(subtask_names)} subtasks: {subtask_names}")
|
||||
logging.info(f"Temporal proportions: {temporal_proportions_dict}")
|
||||
|
||||
def to(self, device):
|
||||
"""Override to method to ensure all components move together."""
|
||||
super().to(device)
|
||||
self.device = device if isinstance(device, torch.device) else torch.device(device)
|
||||
self.clip_model.to(device)
|
||||
self.sarm_transformer.to(device)
|
||||
return self
|
||||
|
||||
@torch.no_grad()
|
||||
def encode_images(self, images: np.ndarray) -> np.ndarray:
|
||||
"""
|
||||
Encode video frames using CLIP.
|
||||
|
||||
Args:
|
||||
images: Video frames with shape (num_videos, num_frames, H, W, C) in uint8.
|
||||
Can also be (num_frames, H, W, C) for a single video.
|
||||
|
||||
Returns:
|
||||
Encoded image features (num_videos, num_frames, 512) or (num_frames, 512).
|
||||
"""
|
||||
# Handle single video case
|
||||
single_video = False
|
||||
if len(images.shape) == 4:
|
||||
images = images[np.newaxis, ...]
|
||||
single_video = True
|
||||
|
||||
assert len(images.shape) == 5, f"Expected 5D input (num_videos, num_frames, H, W, C), got {images.shape}"
|
||||
|
||||
all_embeddings = []
|
||||
|
||||
for video in images:
|
||||
video_embeddings = []
|
||||
|
||||
# Convert frames to PIL images for CLIP processor
|
||||
frames = []
|
||||
for frame in video:
|
||||
if frame.shape[0] == 3: # Channel first
|
||||
frame = frame.transpose(1, 2, 0)
|
||||
if frame.dtype != np.uint8:
|
||||
frame = (frame * 255).astype(np.uint8) if frame.max() <= 1.0 else frame.astype(np.uint8)
|
||||
frames.append(Image.fromarray(frame))
|
||||
|
||||
# Batch process frames with CLIP
|
||||
for i in range(0, len(frames), self.config.clip_batch_size):
|
||||
batch = frames[i:i + self.config.clip_batch_size]
|
||||
inputs = self.clip_processor(images=batch, return_tensors="pt")
|
||||
inputs = {k: v.to(self.device) for k, v in inputs.items()}
|
||||
|
||||
# Get image embeddings from CLIP
|
||||
embeddings = self.clip_model.get_image_features(**inputs).detach().cpu()
|
||||
|
||||
# Handle single frame case
|
||||
if embeddings.dim() == 1:
|
||||
embeddings = embeddings.unsqueeze(0)
|
||||
|
||||
video_embeddings.append(embeddings)
|
||||
|
||||
video_embeddings = torch.cat(video_embeddings)
|
||||
all_embeddings.append(video_embeddings)
|
||||
|
||||
result = torch.stack(all_embeddings).numpy()
|
||||
|
||||
if single_video:
|
||||
result = result[0]
|
||||
|
||||
return result
|
||||
|
||||
@torch.no_grad()
|
||||
def encode_text(self, text: Union[str, List[str]]) -> np.ndarray:
|
||||
"""
|
||||
Encode text using CLIP text encoder (per SARM paper A.4).
|
||||
|
||||
Args:
|
||||
text: Text string or list of text strings.
|
||||
|
||||
Returns:
|
||||
Encoded text features (batch_size, 512) or (512,) for single text.
|
||||
"""
|
||||
if isinstance(text, str):
|
||||
text = [text]
|
||||
single_text = True
|
||||
else:
|
||||
single_text = False
|
||||
|
||||
# Use CLIP's tokenizer directly (avoids image processor validation issues)
|
||||
tokenizer = self.clip_processor.tokenizer
|
||||
|
||||
# Process in batches
|
||||
all_embeddings = []
|
||||
for i in range(0, len(text), self.config.batch_size):
|
||||
batch_text = text[i:i + self.config.batch_size]
|
||||
|
||||
inputs = tokenizer(batch_text, return_tensors="pt", padding=True, truncation=True)
|
||||
inputs = {k: v.to(self.device) for k, v in inputs.items()}
|
||||
|
||||
text_embeddings = self.clip_model.get_text_features(**inputs)
|
||||
all_embeddings.append(text_embeddings.cpu())
|
||||
|
||||
result = torch.cat(all_embeddings).numpy()
|
||||
|
||||
if single_text:
|
||||
result = result[0]
|
||||
|
||||
return result
|
||||
|
||||
@torch.no_grad()
|
||||
def calculate_rewards(
|
||||
self,
|
||||
text_embeddings: Union[np.ndarray, torch.Tensor],
|
||||
video_embeddings: Union[np.ndarray, torch.Tensor],
|
||||
state_features: Optional[Union[np.ndarray, torch.Tensor]] = None,
|
||||
return_all_frames: bool = False,
|
||||
return_stages: bool = False
|
||||
) -> Union[np.ndarray, tuple]:
|
||||
"""
|
||||
Calculate rewards for given text, video, and state representations.
|
||||
|
||||
Args:
|
||||
text_embeddings: Encoded text representations (batch_size, 512)
|
||||
video_embeddings: Encoded video representations (batch_size, num_frames, 512)
|
||||
state_features: Joint state features (batch_size, num_frames, state_dim)
|
||||
return_all_frames: If True, return rewards for all frames
|
||||
return_stages: If True, also return stage predictions
|
||||
|
||||
Returns:
|
||||
If return_stages=False:
|
||||
Reward values (batch_size,) or (batch_size, num_frames)
|
||||
If return_stages=True:
|
||||
Tuple of (rewards, stage_probs)
|
||||
"""
|
||||
if isinstance(text_embeddings, np.ndarray):
|
||||
text_embeddings = torch.tensor(text_embeddings, dtype=torch.float32)
|
||||
if isinstance(video_embeddings, np.ndarray):
|
||||
video_embeddings = torch.tensor(video_embeddings, dtype=torch.float32)
|
||||
if state_features is not None and isinstance(state_features, np.ndarray):
|
||||
state_features = torch.tensor(state_features, dtype=torch.float32)
|
||||
|
||||
# Handle single sample case
|
||||
if text_embeddings.dim() == 1:
|
||||
text_embeddings = text_embeddings.unsqueeze(0)
|
||||
video_embeddings = video_embeddings.unsqueeze(0)
|
||||
if state_features is not None:
|
||||
state_features = state_features.unsqueeze(0)
|
||||
single_sample = True
|
||||
else:
|
||||
single_sample = False
|
||||
|
||||
# Process in batches
|
||||
all_rewards = []
|
||||
all_stage_probs = []
|
||||
|
||||
for i in range(0, len(video_embeddings), self.config.batch_size):
|
||||
batch_texts = text_embeddings[i:i + self.config.batch_size].to(self.device)
|
||||
batch_videos = video_embeddings[i:i + self.config.batch_size].to(self.device)
|
||||
batch_states = None
|
||||
if state_features is not None:
|
||||
batch_states = state_features[i:i + self.config.batch_size].to(self.device)
|
||||
|
||||
# Get predictions
|
||||
stage_logits, stage_probs, progress_preds = self.sarm_transformer(
|
||||
batch_videos.float(), batch_texts.float(), batch_states.float() if batch_states is not None else None
|
||||
)
|
||||
|
||||
if return_all_frames:
|
||||
all_rewards.append(progress_preds.squeeze(-1).cpu())
|
||||
else:
|
||||
# Return only last frame reward
|
||||
all_rewards.append(progress_preds[:, -1, 0].cpu())
|
||||
|
||||
if return_stages:
|
||||
all_stage_probs.append(stage_probs.cpu())
|
||||
|
||||
rewards = torch.cat(all_rewards).numpy()
|
||||
|
||||
if single_sample:
|
||||
rewards = rewards[0] if not return_all_frames else rewards[0]
|
||||
|
||||
if return_stages:
|
||||
stage_probs = torch.cat(all_stage_probs).numpy()
|
||||
if single_sample:
|
||||
stage_probs = stage_probs[0]
|
||||
return rewards, stage_probs
|
||||
|
||||
return rewards
|
||||
|
||||
def train(self, mode: bool = True):
|
||||
"""Overwrite train method to ensure CLIP encoder stays frozen during training"""
|
||||
super().train(mode)
|
||||
self.clip_model.eval()
|
||||
self.sarm_transformer.train(mode)
|
||||
return self
|
||||
|
||||
def eval(self):
|
||||
"""Overwrite eval method to ensure CLIP encoder stays frozen during evaluation"""
|
||||
return self.train(False)
|
||||
|
||||
def parameters(self):
|
||||
"""Override to return trainable parameters (only SARM transformer, not CLIP encoder)."""
|
||||
return self.sarm_transformer.parameters()
|
||||
|
||||
def get_optim_params(self):
|
||||
"""Override to return optimizer parameters (only SARM transformer, not CLIP encoder)."""
|
||||
return self.parameters()
|
||||
|
||||
def reset(self):
|
||||
"""Required by PreTrainedPolicy but not used for reward models."""
|
||||
pass
|
||||
|
||||
def predict_action_chunk(self, batch: dict[str, Tensor]) -> Tensor:
|
||||
"""Required by PreTrainedPolicy but not used for reward models."""
|
||||
raise NotImplementedError("SARM model does not predict action chunks")
|
||||
|
||||
def select_action(self, batch: dict[str, Tensor]) -> Tensor:
|
||||
"""Required by PreTrainedPolicy but not used for SARM."""
|
||||
raise NotImplementedError("SARM model does not select actions")
|
||||
|
||||
def _apply_temporal_augmentation(
|
||||
self,
|
||||
video: torch.Tensor,
|
||||
progress: torch.Tensor,
|
||||
state: torch.Tensor | None,
|
||||
max_length: int,
|
||||
) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor | None]:
|
||||
"""Apply temporal augmentation by appending reversed frames (SARM paper A.4).
|
||||
|
||||
This helps the model learn to handle non-monotonic progress (failures, recoveries).
|
||||
Appends 1-4 reversed frames to simulate going backwards in task progress.
|
||||
"""
|
||||
num_reverse = random.randint(1, min(4, max_length - 1))
|
||||
|
||||
# Reverse and take frames (skip first which is last of original)
|
||||
reversed_video = video.flip(0)[1:num_reverse + 1]
|
||||
reversed_progress = progress.flip(0)[1:num_reverse + 1]
|
||||
|
||||
# Concatenate and trim
|
||||
video = torch.cat([video, reversed_video], dim=0)[:max_length]
|
||||
progress = torch.cat([progress, reversed_progress], dim=0)[:max_length]
|
||||
|
||||
if state is not None:
|
||||
reversed_state = state.flip(0)[1:num_reverse + 1]
|
||||
state = torch.cat([state, reversed_state], dim=0)[:max_length]
|
||||
|
||||
return video, progress, state
|
||||
|
||||
def _ensure_sequence_length(self, tensor: torch.Tensor, target_len: int) -> torch.Tensor:
|
||||
"""Pad or trim tensor to target length."""
|
||||
current_len = tensor.shape[0]
|
||||
if current_len == target_len:
|
||||
return tensor
|
||||
if current_len < target_len:
|
||||
padding = target_len - current_len
|
||||
return torch.cat([tensor, tensor[-1:].expand(padding, *tensor.shape[1:])])
|
||||
return tensor[:target_len]
|
||||
|
||||
def forward(self, batch):
    """
    Forward pass for SARM reward model training.

    Uses annotation-based progress targets following SARM paper Eq. 2:
        yt = Pk-1 + α̅k × τt
    where:
        - τt = (t - sk) / (ek - sk) is within-subtask normalized time
        - Pk-1 is cumulative prior (sum of previous subtask proportions)
        - α̅k is the temporal proportion for subtask k

    Args:
        batch: Dictionary with 'observation' containing:
            - 'video_features': (B, T, 512) pre-encoded video features
            - 'text_features': (B, 512) pre-encoded text features (CLIP)
            - 'state_features': (B, T, state_dim) joint state features (optional)
            - 'stage_labels': (B, T) stage labels from annotations
            - 'progress_targets': (B, T, 1) progress targets from annotations

    Returns:
        Tuple of (total_loss, output_dict with loss components)

    Raises:
        ValueError: if 'progress_targets' or 'stage_labels' is missing.
    """
    observation = batch.get('observation', batch)

    # Extract required features
    video_features = observation['video_features'].to(self.device)
    text_features = observation['text_features'].to(self.device)

    # BUGFIX: 'state_features' is optional. The previous code called
    # .to(self.device) directly on the result of .get(), which raised
    # AttributeError whenever the key was absent and made every
    # `state_features is not None` check below unreachable.
    state_features = observation.get('state_features')
    if state_features is not None:
        state_features = state_features.to(self.device)

    batch_size = video_features.shape[0]
    max_length = self.config.num_frames

    # Ensure 3D video features (B, T, D)
    if video_features.dim() == 2:
        video_features = video_features.unsqueeze(1).expand(-1, max_length, -1)
    if state_features is not None and state_features.dim() == 2:
        state_features = state_features.unsqueeze(1).expand(-1, max_length, -1)

    # Get annotation-based progress targets (required for SARM paper formula)
    progress_from_annotations = observation.get('progress_targets')
    if progress_from_annotations is None:
        raise ValueError("progress_targets from annotations is required for SARM training")

    progress_from_annotations = progress_from_annotations.to(self.device)
    if progress_from_annotations.dim() == 2:
        progress_from_annotations = progress_from_annotations.unsqueeze(-1)
    if progress_from_annotations.dim() == 3 and progress_from_annotations.shape[0] == 1:
        progress_from_annotations = progress_from_annotations.expand(batch_size, -1, -1)

    # Process each sample: apply temporal REWIND augmentation
    processed_videos = []
    processed_states = []
    progress_targets = []

    for i in range(batch_size):
        video = video_features[i]
        state = state_features[i] if state_features is not None else None
        progress = progress_from_annotations[i].squeeze(-1)  # (T,)

        # Apply temporal REWIND augmentation with 50% probability: appends up
        # to 4 reversed frames to simulate failures/recoveries.
        if random.random() < 0.5:
            video, progress, state = self._apply_temporal_augmentation(video, progress, state, max_length)

        # Ensure correct sequence length
        video = self._ensure_sequence_length(video, max_length)
        progress = self._ensure_sequence_length(progress.unsqueeze(-1), max_length).squeeze(-1)
        if state is not None:
            state = self._ensure_sequence_length(state, max_length)

        processed_videos.append(video)
        progress_targets.append(progress)
        if state is not None:
            processed_states.append(state)

    # Stack into batches
    processed_videos = torch.stack(processed_videos)
    progress_targets = torch.stack(progress_targets).unsqueeze(-1)  # (B, T, 1)
    processed_states = torch.stack(processed_states) if processed_states else None

    # Get model predictions
    stage_logits, stage_probs, progress_preds = self.sarm_transformer(
        processed_videos, text_features, processed_states
    )

    # Compute progress loss (MSE)
    progress_loss = F.mse_loss(progress_preds, progress_targets)
    output_dict = {'progress_loss': progress_loss.item()}
    total_loss = progress_loss

    # Compute stage loss (cross-entropy)
    stage_labels = observation.get('stage_labels')
    if stage_labels is None:
        raise ValueError("stage_labels from annotations is required for SARM training")

    stage_labels = stage_labels.to(self.device)
    if stage_labels.dim() == 1:
        stage_labels = stage_labels.unsqueeze(0).expand(batch_size, -1)
    stage_loss = compute_stage_loss(stage_logits, stage_labels)
    total_loss = total_loss + self.config.stage_loss_weight * stage_loss
    output_dict['stage_loss'] = stage_loss.item()

    # Misaligned loss (20% probability): pair videos with shuffled text and
    # push predicted progress toward zero, teaching the model that progress
    # is conditioned on the *matching* task description.
    if random.random() < 0.2:
        shuffle_idx = torch.randperm(batch_size, device=self.device)
        _, _, misaligned_preds = self.sarm_transformer(
            processed_videos, text_features[shuffle_idx], processed_states
        )
        misaligned_loss = F.mse_loss(misaligned_preds, torch.zeros_like(misaligned_preds))
        total_loss = total_loss + misaligned_loss
        output_dict['misaligned_loss'] = misaligned_loss.item()

    output_dict['total_loss'] = total_loss.item()
    return total_loss, output_dict
|
||||
|
||||
def compute_stage_loss(stage_logits: torch.Tensor, target_stages: torch.Tensor) -> torch.Tensor:
    """Mean cross-entropy over per-frame stage predictions.

    Args:
        stage_logits: (B, T, num_stages) unnormalized class scores.
        target_stages: (B, T) integer stage indices.

    Returns:
        Scalar mean cross-entropy loss.
    """
    num_stages = stage_logits.shape[-1]
    # Flatten batch and time so every frame is one classification sample.
    return F.cross_entropy(
        stage_logits.reshape(-1, num_stages),
        target_stages.reshape(-1),
    )
|
||||
@@ -0,0 +1,644 @@
|
||||
#!/usr/bin/env python
|
||||
|
||||
# Copyright 2025 The HuggingFace Inc. team. All rights reserved.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
from typing import Any
|
||||
import numpy as np
|
||||
import torch
|
||||
from PIL import Image
|
||||
import pandas as pd
|
||||
from transformers import CLIPModel, CLIPProcessor
|
||||
|
||||
from lerobot.processor.core import TransitionKey
|
||||
from lerobot.policies.sarm.configuration_sarm import SARMConfig
|
||||
from lerobot.policies.sarm.sarm_utils import compute_tau, compute_cumulative_progress_batch, pad_state_to_max_dim
|
||||
from lerobot.processor import (
|
||||
ProcessorStep,
|
||||
PolicyProcessorPipeline,
|
||||
PolicyAction,
|
||||
DeviceProcessorStep,
|
||||
AddBatchDimensionProcessorStep,
|
||||
NormalizerProcessorStep,
|
||||
)
|
||||
from lerobot.processor.converters import (
|
||||
policy_action_to_transition,
|
||||
transition_to_policy_action,
|
||||
from_tensor_to_numpy,
|
||||
)
|
||||
from lerobot.processor.pipeline import PipelineFeatureType
|
||||
from lerobot.processor.core import EnvTransition, TransitionKey
|
||||
from lerobot.configs.types import PolicyFeature, FeatureType
|
||||
from lerobot.utils.constants import POLICY_POSTPROCESSOR_DEFAULT_NAME, POLICY_PREPROCESSOR_DEFAULT_NAME
|
||||
|
||||
|
||||
class SARMEncodingProcessorStep(ProcessorStep):
|
||||
"""ProcessorStep that encodes images and text with CLIP."""
|
||||
def __init__(
    self,
    config: SARMConfig,
    image_key: str | None = None,
    dataset_meta = None,
    dataset_stats: dict | None = None,
):
    """Set up CLIP encoders and cache annotation metadata from the config."""
    super().__init__()
    self.config = config
    self.image_key = image_key or config.image_key
    self.dataset_meta = dataset_meta
    self.dataset_stats = dataset_stats
    # Map each subtask name to its dataset-level temporal proportion prior.
    self.temporal_proportions = dict(zip(self.config.subtask_names, self.config.temporal_proportions))
    self.subtask_names = self.config.subtask_names

    # Explicit config device wins; otherwise prefer CUDA when available.
    if self.config.device:
        self.device = torch.device(self.config.device)
    else:
        self.device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

    # Frozen CLIP backbone used for both image and text encoding.
    self.clip_model = CLIPModel.from_pretrained("openai/clip-vit-base-patch32")
    self.clip_processor = CLIPProcessor.from_pretrained("openai/clip-vit-base-patch32", use_fast=True)
    self.clip_model.to(self.device)
    self.clip_model.eval()
|
||||
|
||||
def _find_episode_for_frame(self, frame_idx: int) -> int:
    """Locate the episode whose [from, to) index range contains ``frame_idx``.

    Falls back to episode 0 when no episode matches.
    """
    for candidate in range(len(self.dataset_meta.episodes)):
        episode = self.dataset_meta.episodes[candidate]
        if episode["dataset_from_index"] <= frame_idx < episode["dataset_to_index"]:
            return candidate
    return 0
|
||||
|
||||
def _get_episode_indices(self, frame_indices: np.ndarray, episode_index) -> np.ndarray:
    """Resolve a per-frame array of episode indices.

    When ``episode_index`` is missing, or when a single episode index is
    paired with several frames, each frame's episode is looked up
    individually from the dataset metadata.
    """
    def _lookup_each() -> np.ndarray:
        return np.array([self._find_episode_for_frame(int(f)) for f in frame_indices])

    if episode_index is None:
        return _lookup_each()

    provided = np.atleast_1d(np.asarray(from_tensor_to_numpy(episode_index)))

    # One episode index for multiple frames: resolve per frame instead.
    if len(provided) == 1 and len(frame_indices) > 1:
        return _lookup_each()

    return provided
|
||||
|
||||
def _compute_absolute_indices(self, frame_idx: int, ep_start: int, ep_end: int, num_frames: int) -> torch.Tensor:
    """Build the symmetric bidirectional frame-index pattern for one sample.

    Pattern: [ep_start, t-4*gap, t-3*gap, t-2*gap, t-gap, t, t+gap, t+2*gap, t+3*gap]

    Boundary handling:
    - Backward indices clamp to ep_start (first frame)
    - Forward indices clamp to ep_end - 1 (last frame)
    """
    gap = self.config.frame_gap
    last_valid_frame = ep_end - 1

    # 4 context frames before the current one, clamped to the episode start.
    before = [max(ep_start, frame_idx - k * gap) for k in range(4, 0, -1)]
    # 3 context frames after, clamped to the last frame of the episode.
    after = [min(last_valid_frame, frame_idx + k * gap) for k in range(1, 4)]

    # The first position is always the episode's initial frame.
    return torch.tensor([ep_start, *before, frame_idx, *after])
|
||||
|
||||
def _compute_episode_metadata(
    self,
    frame_indices: np.ndarray,
    episode_indices: np.ndarray,
    num_frames: int,
) -> tuple[list | torch.Tensor, torch.Tensor, torch.Tensor]:
    """Compute episode metadata for all samples in the batch.

    Returns:
        Tuple of (absolute_frame_indices, remaining_lengths, episode_lengths)
    """
    per_sample_indices = []
    remaining = []
    lengths = []

    for ep_i, frame_i in zip(episode_indices.tolist(), frame_indices.tolist()):
        ep_i, frame_i = int(ep_i), int(frame_i)
        episode = self.dataset_meta.episodes[ep_i]
        start = episode["dataset_from_index"]
        end = episode["dataset_to_index"]

        lengths.append(end - start)
        sample_indices = self._compute_absolute_indices(frame_i, start, end, num_frames)
        per_sample_indices.append(sample_indices)
        # Frames left in the episode from the first sampled index.
        remaining.append(end - sample_indices[0].item())

    return per_sample_indices, torch.tensor(remaining), torch.tensor(lengths)
|
||||
|
||||
def _compute_stage_and_progress_for_frame(
    self,
    current_frame: int,
    subtask_names: list,
    subtask_start_frames: list,
    subtask_end_frames: list,
    transition_smoothing_frames: int = 15,
) -> tuple[int, float, dict[int, float] | None]:
    """Compute stage index, cumulative progress, and soft stage labels for a single frame.

    Implements SARM Paper Formula (2):
        y_t = P_{k-1} + ᾱ_k × τ_t

    where:
    - τ_t = (t - s_k) / (e_k - s_k) is within-subtask progress
    - P_{k-1} is cumulative prior (sum of previous subtask proportions)
    - ᾱ_k is the temporal proportion for subtask k

    Additionally computes soft stage labels near transitions to mitigate discrete jumps
    in the stage classifier. Near stage boundaries, labels are blended between adjacent
    stages to encourage smoother predictions.

    Args:
        current_frame: Frame index relative to episode start
        subtask_names: List of subtask names for this episode
        subtask_start_frames: List of subtask start frames
        subtask_end_frames: List of subtask end frames
        transition_smoothing_frames: Number of frames over which to smooth labels near transitions

    Returns:
        Tuple of (stage_idx, cumulative_progress, soft_stage_labels)
        - stage_idx: Hard stage index (for compatibility)
        - cumulative_progress: Progress value in [0, 1]
        - soft_stage_labels: Dict mapping stage_idx -> probability, or None if not near transition
    """
    # Get temporal proportions as list for compute_cumulative_progress
    temporal_proportions_list = [
        self.temporal_proportions.get(name, 0.0) for name in self.subtask_names
    ]
    num_stages = len(self.subtask_names)

    # Find which subtask this frame belongs to (interval endpoints inclusive)
    for j, (name, start_frame, end_frame) in enumerate(zip(subtask_names, subtask_start_frames, subtask_end_frames)):
        if current_frame >= start_frame and current_frame <= end_frame:
            # Found the subtask, get its global index; unknown names fall back to stage 0
            stage_idx = self.subtask_names.index(name) if name in self.subtask_names else 0

            # Compute τ_t using utility function (Paper Formula 2)
            tau = compute_tau(current_frame, start_frame, end_frame)

            # Compute cumulative progress using utility function (Paper Formula 2)
            cumulative_progress = compute_cumulative_progress_batch(
                tau, stage_idx, temporal_proportions_list
            )

            # Compute soft stage labels near transitions
            soft_stage_labels = None
            frames_from_start = current_frame - start_frame
            frames_to_end = end_frame - current_frame

            if frames_from_start < transition_smoothing_frames and j > 0:
                # Near start of stage - blend with previous stage.
                # At the exact start (blend=0) all weight sits on the PREVIOUS stage.
                blend = frames_from_start / transition_smoothing_frames
                prev_name = subtask_names[j - 1]
                prev_stage_idx = self.subtask_names.index(prev_name) if prev_name in self.subtask_names else max(0, stage_idx - 1)
                soft_stage_labels = {prev_stage_idx: 1.0 - blend, stage_idx: blend}

            elif frames_to_end < transition_smoothing_frames and j < len(subtask_names) - 1:
                # Near end of stage - blend with next stage.
                # NOTE(review): at the exact end (blend=0) all weight sits on the
                # NEXT stage, while the start-of-stage branch above puts all weight
                # on the PREVIOUS stage just after the boundary — the two blends
                # point in opposite directions across a transition, so labels flip
                # discontinuously at the boundary. Confirm the intended blend
                # direction for this branch.
                blend = frames_to_end / transition_smoothing_frames
                next_name = subtask_names[j + 1]
                next_stage_idx = self.subtask_names.index(next_name) if next_name in self.subtask_names else min(num_stages - 1, stage_idx + 1)
                soft_stage_labels = {stage_idx: blend, next_stage_idx: 1.0 - blend}

            return stage_idx, cumulative_progress, soft_stage_labels

    # No matching subtask found
    if current_frame < subtask_start_frames[0]:
        # Before any annotated subtask: zero progress in the first stage.
        return 0, 0.0, None
    elif current_frame > subtask_end_frames[-1]:
        # After the last annotated subtask: full progress in the last stage.
        return len(self.subtask_names) - 1, 1.0, None
    else:
        # Between subtasks - use previous subtask's end state (tau = 1.0)
        for j in range(len(subtask_names) - 1):
            if current_frame > subtask_end_frames[j] and current_frame < subtask_start_frames[j + 1]:
                name = subtask_names[j]
                stage_idx = self.subtask_names.index(name) if name in self.subtask_names else j

                # Completed subtask, so tau = 1.0
                cumulative_progress = compute_cumulative_progress_batch(
                    1.0, stage_idx, temporal_proportions_list
                )
                return stage_idx, cumulative_progress, None

        # Defensive fallback; should be unreachable given the interval checks above.
        return 0, 0.0, None
|
||||
|
||||
def _compute_labels_for_sample(
    self,
    frame_idx: int,
    ep_idx: int,
    seq_len: int,
    episodes_df: pd.DataFrame,
) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor | None] | tuple[None, None, None]:
    """Compute stage labels, progress targets, and soft stage labels for symmetric bidirectional pattern.

    Pattern: [initial, t-4*gap, t-3*gap, t-2*gap, t-gap, t, t+gap, t+2*gap, t+3*gap]

    Boundary handling:
    - Before episode start: clamp to frame 0 (progress ~0%)
    - After episode end: clamp to last frame (progress ~100%)

    Soft stage labels are computed near stage transitions to mitigate discrete jumps.

    Args:
        frame_idx: The frame index for this sample (absolute, dataset-wide)
        ep_idx: The episode index
        seq_len: Number of frames in the sequence
        episodes_df: DataFrame with episode metadata

    Returns:
        Tuple of (stage_labels, progress_targets, soft_stage_labels):
        - stage_labels: (T,) hard stage indices
        - progress_targets: (T, 1) progress values
        - soft_stage_labels: (T, num_stages) soft probability labels, or None if no transitions nearby
    """
    # Check if episode has valid annotations
    if ep_idx >= len(episodes_df):
        return None, None, None

    subtask_names = episodes_df.loc[ep_idx, 'subtask_names']
    # NaN / missing annotations -> caller substitutes zero labels.
    if subtask_names is None or (isinstance(subtask_names, float) and pd.isna(subtask_names)):
        return None, None, None

    subtask_start_frames = episodes_df.loc[ep_idx, 'subtask_start_frames']
    subtask_end_frames = episodes_df.loc[ep_idx, 'subtask_end_frames']
    ep_start = self.dataset_meta.episodes[ep_idx]["dataset_from_index"]
    ep_end = self.dataset_meta.episodes[ep_idx]["dataset_to_index"]
    ep_length = ep_end - ep_start
    last_valid_frame = ep_length - 1

    num_stages = len(self.subtask_names)

    # Generate labels for each frame in the sequence
    stage_labels = []
    progress_targets = []
    soft_labels_list = []  # List of soft label dicts (or None)
    has_any_soft_labels = False

    # Symmetric pattern: initial + 4 before + current + 3 after = 9 frames
    num_before = 4
    num_after = 3

    # Frame positions here mirror _compute_absolute_indices; offsets are
    # converted to episode-relative frames before label lookup.
    for i in range(seq_len):
        if i == 0:
            # Position 0: Initial frame of the episode
            current_frame = 0  # Relative to episode start
        elif i <= num_before:
            # Positions 1-4: frames before current (with clamping to first frame)
            offset = -(num_before - i + 1) * self.config.frame_gap
            current_frame = max(0, frame_idx + offset - ep_start)
        elif i == num_before + 1:
            # Position 5: current frame
            current_frame = frame_idx - ep_start
        else:
            # Positions 6-8: frames after current (with clamping to last frame)
            offset = (i - num_before - 1) * self.config.frame_gap
            current_frame = min(last_valid_frame, frame_idx + offset - ep_start)

        stage_idx, cumulative_progress, soft_stage_labels = self._compute_stage_and_progress_for_frame(
            current_frame, subtask_names, subtask_start_frames, subtask_end_frames
        )

        stage_labels.append(stage_idx)
        progress_targets.append(cumulative_progress)
        soft_labels_list.append(soft_stage_labels)
        if soft_stage_labels is not None:
            has_any_soft_labels = True

    stage_labels = torch.tensor(stage_labels, dtype=torch.long)
    progress_targets = torch.tensor(progress_targets, dtype=torch.float32).unsqueeze(-1)

    # Convert soft labels to tensor if any exist
    soft_stage_labels_tensor = None
    if has_any_soft_labels:
        soft_stage_labels_tensor = torch.zeros(seq_len, num_stages, dtype=torch.float32)
        for i, soft_dict in enumerate(soft_labels_list):
            if soft_dict is not None:
                for stage_idx, prob in soft_dict.items():
                    soft_stage_labels_tensor[i, stage_idx] = prob
            else:
                # Use hard one-hot label
                soft_stage_labels_tensor[i, stage_labels[i]] = 1.0

    return stage_labels, progress_targets, soft_stage_labels_tensor
|
||||
|
||||
def _generate_stage_and_progress_labels(self, frame_index, episode_index, video_features):
    """Generate stage labels, progress targets, and soft stage labels from subtask annotations.

    Args:
        frame_index: Current frame index or tensor of indices
        episode_index: Episode index or tensor of indices
        video_features: Video features tensor to determine sequence length

    Returns:
        Tuple of (stage_labels, progress_targets, soft_stage_labels) or (None, None, None) if no annotations.
        - stage_labels: (B, T) hard stage indices
        - progress_targets: (B, T, 1) progress values
        - soft_stage_labels: (B, T, num_stages) soft probability labels, or None
    """
    if self.temporal_proportions is None or episode_index is None:
        return None, None, None

    # Normalize inputs to numpy arrays
    frame_indices = np.atleast_1d(np.asarray(from_tensor_to_numpy(frame_index)))
    episode_indices = self._get_episode_indices(frame_indices, episode_index)

    # Determine sequence length from the time dimension of video features
    if video_features is not None and video_features.dim() >= 2:
        seq_len = video_features.shape[1]
    else:
        seq_len = 1

    # NOTE(review): this converts the whole episodes table to pandas on every
    # call — presumably acceptable for preprocessing, but worth caching if hot.
    episodes_df = self.dataset_meta.episodes.to_pandas()
    num_stages = len(self.subtask_names)

    all_stage_labels = []
    all_progress_targets = []
    all_soft_stage_labels = []
    has_any_soft_labels = False

    for ep_idx, frame_idx in zip(episode_indices.tolist(), frame_indices.tolist()):
        stage_labels, progress_targets, soft_labels = self._compute_labels_for_sample(
            int(frame_idx), int(ep_idx), seq_len, episodes_df
        )

        if stage_labels is None:
            # No annotations for this episode: fall back to all-zero labels.
            all_stage_labels.append(torch.zeros(seq_len, dtype=torch.long))
            all_progress_targets.append(torch.zeros(seq_len, 1, dtype=torch.float32))
            all_soft_stage_labels.append(None)
        else:
            all_stage_labels.append(stage_labels)
            all_progress_targets.append(progress_targets)
            all_soft_stage_labels.append(soft_labels)
            if soft_labels is not None:
                has_any_soft_labels = True

    stacked_stage_labels = torch.stack(all_stage_labels, dim=0)
    stacked_progress_targets = torch.stack(all_progress_targets, dim=0)

    # Stack soft labels if any exist; samples without soft labels get one-hot
    # versions of their hard labels so the batch stacks uniformly.
    stacked_soft_labels = None
    if has_any_soft_labels:
        soft_labels_tensors = []
        for i, soft_labels in enumerate(all_soft_stage_labels):
            if soft_labels is not None:
                soft_labels_tensors.append(soft_labels)
            else:
                # Create one-hot from hard labels
                one_hot = torch.zeros(seq_len, num_stages, dtype=torch.float32)
                for t in range(seq_len):
                    one_hot[t, all_stage_labels[i][t]] = 1.0
                soft_labels_tensors.append(one_hot)
        stacked_soft_labels = torch.stack(soft_labels_tensors, dim=0)

    return stacked_stage_labels, stacked_progress_targets, stacked_soft_labels
|
||||
|
||||
def __call__(self, transition: EnvTransition) -> EnvTransition:
    """Encode images, text, and normalize states in the transition.

    Steps performed, in order:
    1. Encode images at ``self.image_key`` into CLIP features ('video_features').
    2. Pad the (already normalized) state to max_state_dim ('state_features').
    3. Encode the task description with CLIP ('text_features').
    4. If dataset metadata is available, attach episode metadata and
       annotation-derived stage/progress labels.

    Raises:
        ValueError: if 'index' or 'episode_index' is missing from
            COMPLEMENTARY_DATA.
    """

    # NOTE(review): this copies the transition mapping, but the observation
    # dict inside it is mutated in place below — the input transition's
    # observation is therefore modified too; confirm that is intended.
    new_transition = transition.copy() if hasattr(transition, 'copy') else dict(transition)
    observation = new_transition.get(TransitionKey.OBSERVATION)

    # NOTE(review): a missing image key yields None here, which fails later
    # inside the encoder with an opaque AttributeError — confirm callers
    # always provide self.image_key.
    image = observation.get(self.image_key)

    if isinstance(image, torch.Tensor):
        image = image.cpu().numpy()
    video_features = self._encode_images_batch(image)
    observation['video_features'] = video_features

    # Extract state and pad to max_state_dim (already normalized by NormalizerProcessorStep)
    state_key = self.config.state_key
    state_data = observation.get(state_key)

    if isinstance(state_data, torch.Tensor):
        state_tensor = state_data.float()
    else:
        state_tensor = torch.tensor(state_data, dtype=torch.float32)

    observation['state_features'] = pad_state_to_max_dim(state_tensor, self.config.max_state_dim)

    comp_data = new_transition.get(TransitionKey.COMPLEMENTARY_DATA, {})

    # Get task description from dataset (complementary_data["task"])
    task = comp_data.get('task')
    if isinstance(task, list):
        # If batch, take first task (assuming same task for all items in batch)
        task = task[0] if task else ""

    # Encode text with CLIP
    batch_size = video_features.shape[0]
    observation['text_features'] = self._encode_text_clip(task, batch_size)

    frame_index = comp_data.get('index')
    episode_index = comp_data.get('episode_index')

    if frame_index is None:
        raise ValueError("Frame index ('index') not found in COMPLEMENTARY_DATA")
    if episode_index is None:
        raise ValueError("Episode index ('episode_index') not found in COMPLEMENTARY_DATA")

    # Compute episode metadata if dataset_meta is available
    if self.dataset_meta is not None:
        frame_indices = np.atleast_1d(np.asarray(from_tensor_to_numpy(frame_index)))
        episode_indices = self._get_episode_indices(frame_indices, episode_index)

        # Determine number of frames from video features
        if video_features.dim() >= 2:
            num_frames = video_features.shape[1]
        else:
            num_frames = 1

        abs_indices, remaining, ep_lengths = self._compute_episode_metadata(
            frame_indices, episode_indices, num_frames
        )
        observation['absolute_frame_indices'] = abs_indices
        observation['remaining_length'] = remaining
        observation['episode_length'] = ep_lengths

    # Generate stage labels, progress targets, and soft stage labels from subtask annotations
    if self.temporal_proportions is not None and self.dataset_meta is not None:
        stage_labels, progress_targets, soft_stage_labels = self._generate_stage_and_progress_labels(
            frame_index, episode_index, video_features
        )
        if stage_labels is not None:
            observation['stage_labels'] = stage_labels
            observation['progress_targets'] = progress_targets
            if soft_stage_labels is not None:
                observation['soft_stage_labels'] = soft_stage_labels

    new_transition[TransitionKey.OBSERVATION] = observation
    return new_transition
|
||||
|
||||
@torch.no_grad()
def _encode_images_batch(self, images: np.ndarray) -> torch.Tensor:
    """Encode a batch of images into CLIP embeddings.

    Args:
        images: Batched images with shape (B, T, C, H, W).

    Returns:
        Feature tensor of shape (B, T, 512) on CPU.
    """
    batch_size, seq_length = images.shape[0], images.shape[1]
    flat = images.reshape(batch_size * seq_length, *images.shape[2:])

    def _to_pil(frame: np.ndarray) -> Image.Image:
        # CLIP's processor expects channel-last uint8 images.
        if frame.shape[0] in [1, 3]:  # Channel first (C, H, W)
            frame = frame.transpose(1, 2, 0)
        # Grayscale -> replicate to 3 channels.
        if frame.shape[-1] == 1:
            frame = np.repeat(frame, 3, axis=-1)
        # Floats in [0, 1] are rescaled; anything else is cast directly.
        if frame.dtype != np.uint8:
            frame = (frame * 255).astype(np.uint8) if frame.max() <= 1.0 else frame.astype(np.uint8)
        return Image.fromarray(frame)

    pil_frames = [_to_pil(flat[idx]) for idx in range(flat.shape[0])]

    # Encode in chunks of clip_batch_size to bound peak memory.
    chunks = []
    step = self.config.clip_batch_size
    for start in range(0, len(pil_frames), step):
        inputs = self.clip_processor(images=pil_frames[start:start + step], return_tensors="pt")
        inputs = {key: val.to(self.device) for key, val in inputs.items()}
        feats = self.clip_model.get_image_features(**inputs).detach().cpu()
        # Single-frame chunk: keep a leading batch dimension.
        if feats.dim() == 1:
            feats = feats.unsqueeze(0)
        chunks.append(feats)

    # (B*T, 512) -> (B, T, 512)
    return torch.cat(chunks).reshape(batch_size, seq_length, -1)
|
||||
|
||||
@torch.no_grad()
def _encode_text_clip(self, text: str, batch_size: int) -> torch.Tensor:
    """Encode text using CLIP text encoder (per SARM paper A.4).

    Args:
        text: Task description text to encode.
        batch_size: Number of batch rows to replicate the embedding across.

    Returns:
        Text feature tensor of shape (B, 512) on CPU.
    """
    # Tokenize directly with CLIP's tokenizer (single-element batch).
    tokens = self.clip_processor.tokenizer([text], return_tensors="pt", padding=True, truncation=True)
    tokens = {key: val.to(self.device) for key, val in tokens.items()}

    embedding = self.clip_model.get_text_features(**tokens).detach().cpu()

    # expand() broadcasts the single row across the batch without copying.
    return embedding.expand(batch_size, -1)
|
||||
|
||||
def transform_features(
    self, features: dict[PipelineFeatureType, dict[str, PolicyFeature]]
) -> dict[PipelineFeatureType, dict[str, PolicyFeature]]:
    """Register the encoded feature entries this step adds to observations."""
    obs_features = features[PipelineFeatureType.OBSERVATION]
    obs_features['video_features'] = PolicyFeature(
        type=FeatureType.VISUAL,
        shape=(self.config.num_frames, self.config.image_dim),
    )
    obs_features['text_features'] = PolicyFeature(
        type=FeatureType.LANGUAGE,
        shape=(self.config.text_dim,),
    )
    obs_features['state_features'] = PolicyFeature(
        type=FeatureType.STATE,
        shape=(self.config.num_frames, self.config.max_state_dim),
    )
    return features
|
||||
|
||||
|
||||
def make_sarm_pre_post_processors(
    config: SARMConfig,
    dataset_stats: dict[str, dict[str, torch.Tensor]] | None = None,
    dataset_meta = None,
) -> tuple[
    PolicyProcessorPipeline[dict[str, Any], dict[str, Any]],
    PolicyProcessorPipeline[PolicyAction, PolicyAction],
]:
    """
    Create pre-processor and post-processor pipelines for SARM.

    The pre-processing pipeline:
    1. Adds batch dimension
    2. Normalizes observation.state using NormalizerProcessorStep (MEAN_STD)
    3. SARMEncodingProcessorStep:
       - Encodes images with CLIP
       - Pads states to max_state_dim
       - Encodes text with CLIP
    4. Moves data to device

    The post-processing pipeline:
    1. Moves data to CPU
    """
    preprocess_steps = [
        AddBatchDimensionProcessorStep(),
        NormalizerProcessorStep(
            features={**config.input_features, **config.output_features},
            norm_map=config.normalization_mapping,
            stats=dataset_stats,
        ),
        SARMEncodingProcessorStep(
            config=config,
            dataset_meta=dataset_meta,
            dataset_stats=dataset_stats,
        ),
        DeviceProcessorStep(device=config.device),
    ]
    preprocessor = PolicyProcessorPipeline[dict[str, Any], dict[str, Any]](
        steps=preprocess_steps,
        name=POLICY_PREPROCESSOR_DEFAULT_NAME,
    )
    postprocessor = PolicyProcessorPipeline[PolicyAction, PolicyAction](
        steps=[DeviceProcessorStep(device="cpu")],
        name=POLICY_POSTPROCESSOR_DEFAULT_NAME,
        to_transition=policy_action_to_transition,
        to_output=transition_to_policy_action,
    )
    return preprocessor, postprocessor
|
||||
@@ -0,0 +1,257 @@
|
||||
#!/usr/bin/env python
|
||||
|
||||
# Copyright 2025 The HuggingFace Inc. team. All rights reserved.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
import numpy as np
|
||||
import torch
|
||||
import torch.nn.functional as F
|
||||
from typing import Sequence, Any
|
||||
from pydantic import BaseModel, Field
|
||||
|
||||
# Pydantic Models for SARM-style Annotation
|
||||
class Timestamp(BaseModel):
|
||||
"""Timestamp in MM:SS or SS format"""
|
||||
start: str = Field(description="Start timestamp (MM:SS or just seconds)")
|
||||
end: str = Field(description="End timestamp (MM:SS or just seconds)")
|
||||
|
||||
|
||||
class Subtask(BaseModel):
|
||||
"""Individual subtask/stage - must use EXACT names from provided list"""
|
||||
name: str = Field(description="Subtask name - MUST match one from the predefined list exactly")
|
||||
timestamps: Timestamp
|
||||
|
||||
|
||||
class SubtaskAnnotation(BaseModel):
|
||||
"""Complete annotation for a robot manipulation episode"""
|
||||
subtasks: list[Subtask] = Field(description="List of all subtasks in temporal order")
|
||||
|
||||
|
||||
def compute_temporal_proportions(annotations: dict[int, Any], fps: int = 30) -> dict[str, float]:
|
||||
"""
|
||||
Compute dataset-level temporal proportions (priors) for each subtask.
|
||||
|
||||
Implements SARM Paper Formula (1):
|
||||
ᾱ_k = (1/M) × Σ_i (L_{i,k} / T_i)
|
||||
|
||||
where:
|
||||
- M is the number of trajectories (episodes)
|
||||
- L_{i,k} is the duration of subtask k in trajectory i
|
||||
- T_i is the total duration of trajectory i
|
||||
|
||||
This averages the PROPORTION of each subtask within each trajectory,
|
||||
giving equal weight to all trajectories regardless of their absolute length.
|
||||
|
||||
Args:
|
||||
annotations: Dict mapping episode index to SubtaskAnnotation object.
|
||||
Each annotation has a .subtasks list where each subtask has:
|
||||
- .name: subtask name
|
||||
- .timestamps.start: start time as "MM:SS" string
|
||||
- .timestamps.end: end time as "MM:SS" string
|
||||
fps: Frames per second (unused, kept for API compatibility)
|
||||
|
||||
Returns:
|
||||
Dict mapping subtask name to its temporal proportion (ᾱ_k).
|
||||
Proportions are normalized to sum to 1.0.
|
||||
"""
|
||||
subtask_proportions: dict[str, list[float]] = {}
|
||||
|
||||
for annotation in annotations.values():
|
||||
total_duration = 0
|
||||
durations: dict[str, int] = {}
|
||||
|
||||
for subtask in annotation.subtasks:
|
||||
start_parts = subtask.timestamps.start.split(":")
|
||||
end_parts = subtask.timestamps.end.split(":")
|
||||
|
||||
start_seconds = int(start_parts[0]) * 60 + int(start_parts[1]) if len(start_parts) == 2 else int(start_parts[0])
|
||||
end_seconds = int(end_parts[0]) * 60 + int(end_parts[1]) if len(end_parts) == 2 else int(end_parts[0])
|
||||
|
||||
duration = end_seconds - start_seconds
|
||||
durations[subtask.name] = duration
|
||||
total_duration += duration
|
||||
|
||||
# Calculate L_{i,k} / T_i for each subtask in this trajectory
|
||||
if total_duration > 0:
|
||||
for name, duration in durations.items():
|
||||
if name not in subtask_proportions:
|
||||
subtask_proportions[name] = []
|
||||
subtask_proportions[name].append(duration / total_duration)
|
||||
|
||||
if not subtask_proportions:
|
||||
return {}
|
||||
|
||||
# Average across trajectories: (1/M) × Σ_i (L_{i,k} / T_i)
|
||||
avg_proportions = {
|
||||
name: sum(props) / len(props)
|
||||
for name, props in subtask_proportions.items()
|
||||
}
|
||||
|
||||
# Normalize to ensure sum = 1
|
||||
total = sum(avg_proportions.values())
|
||||
if total > 0:
|
||||
avg_proportions = {name: prop / total for name, prop in avg_proportions.items()}
|
||||
|
||||
return avg_proportions
|
||||
|
||||
|
||||
def compute_tau(
|
||||
current_frame: int | float,
|
||||
subtask_start: int | float,
|
||||
subtask_end: int | float,
|
||||
) -> float:
|
||||
"""
|
||||
Compute within-subtask normalized time τ_t.
|
||||
|
||||
Implements part of SARM Paper Formula (2):
|
||||
τ_t = (t - s_k) / (e_k - s_k) ∈ [0, 1]
|
||||
|
||||
where:
|
||||
- t is the current frame
|
||||
- s_k is the start frame of subtask k
|
||||
- e_k is the end frame of subtask k
|
||||
|
||||
Args:
|
||||
current_frame: Current frame index (t)
|
||||
subtask_start: Start frame of the subtask (s_k)
|
||||
subtask_end: End frame of the subtask (e_k)
|
||||
|
||||
Returns:
|
||||
Within-subtask progress τ_t ∈ [0, 1]
|
||||
"""
|
||||
subtask_duration = subtask_end - subtask_start
|
||||
|
||||
if subtask_duration <= 0:
|
||||
return 1.0
|
||||
|
||||
tau = (current_frame - subtask_start) / subtask_duration
|
||||
|
||||
return float(np.clip(tau, 0.0, 1.0))
|
||||
|
||||
|
||||
def compute_cumulative_progress_batch(
|
||||
tau: torch.Tensor | float,
|
||||
stage_indices: torch.Tensor | int,
|
||||
alpha: torch.Tensor | Sequence[float],
|
||||
cumulative_prior: torch.Tensor | None = None,
|
||||
) -> torch.Tensor | float:
|
||||
"""
|
||||
Compute cumulative normalized progress from within-subtask progress.
|
||||
|
||||
This function implements the core formula used in SARM for both:
|
||||
|
||||
**Formula 2 (Training labels):**
|
||||
y_t = P_{k-1} + ᾱ_k × τ_t ∈ [0, 1]
|
||||
|
||||
Used to compute ground-truth progress labels from subtask annotations.
|
||||
- τ_t comes from annotated frame position: τ_t = (t - s_k) / (e_k - s_k)
|
||||
- k is the known subtask from annotations
|
||||
|
||||
**Formula 4 (Inference predictions):**
|
||||
ŷ_{1:N} = P̂_{k-1, 1:N} + ᾱ_{k, 1:N} × τ̂_{1:N} ∈ [0, 1]
|
||||
|
||||
Used to convert model outputs to cumulative progress during inference.
|
||||
- τ̂ comes from the subtask MLP head (conditioned on predicted stage)
|
||||
- k = Ŝ is the predicted stage from Formula 3: Ŝ = argmax(softmax(Ψ))
|
||||
|
||||
The formulas are mathematically identical; only the source of inputs differs:
|
||||
- Training: τ and k from annotations → ground-truth labels
|
||||
- Inference: τ̂ and Ŝ from model → predicted progress
|
||||
|
||||
where:
|
||||
- P_{k-1} = Σ_{j=1}^{k-1} ᾱ_j is the cumulative prior (sum of previous proportions)
|
||||
- ᾱ_k is the temporal proportion for subtask k (from Formula 1)
|
||||
- τ is within-subtask progress ∈ [0, 1]
|
||||
|
||||
This ensures:
|
||||
- y at start of subtask k = P_{k-1}
|
||||
- y at end of subtask k = P_k
|
||||
|
||||
Supports both scalar and batched tensor inputs:
|
||||
- Scalar: tau (float), stage_indices (int), alpha (list/sequence)
|
||||
- Batch: tau (Tensor), stage_indices (Tensor), alpha (Tensor), cumulative_prior (Tensor)
|
||||
|
||||
Args:
|
||||
tau: Within-subtask progress τ ∈ [0, 1].
|
||||
For training: computed from frame position in annotated subtask.
|
||||
For inference: predicted by subtask MLP head.
|
||||
Scalar float or Tensor with shape (..., 1)
|
||||
stage_indices: Index of current subtask k (0-indexed).
|
||||
For training: known from annotations.
|
||||
For inference: predicted via argmax(stage_probs) (Formula 3).
|
||||
Scalar int or Tensor with shape (...)
|
||||
alpha: Temporal proportions ᾱ with shape (num_stages,) or Sequence[float].
|
||||
Computed from dataset annotations using Formula 1.
|
||||
cumulative_prior: Optional. Cumulative priors P with shape (num_stages + 1,)
|
||||
where cumulative_prior[k] = P_k = Σ_{j=1}^{k} ᾱ_j.
|
||||
If None, will be computed from alpha.
|
||||
|
||||
Returns:
|
||||
Cumulative progress y ∈ [0, 1].
|
||||
Scalar float if inputs are scalar, otherwise Tensor with shape (..., 1)
|
||||
"""
|
||||
if not isinstance(tau, torch.Tensor):
|
||||
if not alpha:
|
||||
raise ValueError("alpha (temporal_proportions) cannot be empty")
|
||||
|
||||
if isinstance(alpha, torch.Tensor):
|
||||
alpha_list = alpha.tolist()
|
||||
else:
|
||||
alpha_list = list(alpha)
|
||||
|
||||
if stage_indices < 0 or stage_indices >= len(alpha_list):
|
||||
raise ValueError(
|
||||
f"stage_indices {stage_indices} out of range "
|
||||
f"for {len(alpha_list)} subtasks"
|
||||
)
|
||||
|
||||
# P_{k-1} = sum of proportions for subtasks 0 to k-1
|
||||
P_k_minus_1 = sum(alpha_list[:stage_indices])
|
||||
|
||||
# ᾱ_k = proportion for current subtask
|
||||
alpha_k = alpha_list[stage_indices]
|
||||
|
||||
# y_t = P_{k-1} + ᾱ_k × τ_t
|
||||
y_t = P_k_minus_1 + alpha_k * tau
|
||||
|
||||
return float(np.clip(y_t, 0.0, 1.0))
|
||||
|
||||
if not isinstance(alpha, torch.Tensor):
|
||||
alpha = torch.tensor(alpha, dtype=torch.float32)
|
||||
|
||||
# Compute cumulative_prior if not provided
|
||||
if cumulative_prior is None:
|
||||
cumulative_prior = torch.zeros(len(alpha) + 1, dtype=alpha.dtype, device=alpha.device)
|
||||
cumulative_prior[1:] = torch.cumsum(alpha, dim=0)
|
||||
|
||||
# P_{k-1} for each predicted stage
|
||||
P_k_minus_1 = cumulative_prior[stage_indices]
|
||||
|
||||
# ᾱ_k for each predicted stage
|
||||
alpha_k = alpha[stage_indices]
|
||||
|
||||
# ŷ = P_{k-1} + ᾱ_k × τ̂
|
||||
progress = P_k_minus_1.unsqueeze(-1) + alpha_k.unsqueeze(-1) * tau
|
||||
|
||||
return progress
|
||||
|
||||
def pad_state_to_max_dim(state: torch.Tensor, max_state_dim: int) -> torch.Tensor:
|
||||
"""Pad the state tensor's last dimension to max_state_dim with zeros."""
|
||||
current_dim = state.shape[-1]
|
||||
if current_dim >= max_state_dim:
|
||||
return state[..., :max_state_dim] # Truncate if larger
|
||||
|
||||
# Pad with zeros on the right
|
||||
padding = (0, max_state_dim - current_dim) # (left, right) for last dim
|
||||
return F.pad(state, padding, mode='constant', value=0)
|
||||
@@ -230,6 +230,10 @@ def validate_visual_features_consistency(
|
||||
) -> None:
|
||||
"""
|
||||
Validates visual feature consistency between a policy config and provided dataset/environment features.
|
||||
|
||||
Validation passes if EITHER:
|
||||
- Policy's expected visuals are a subset of dataset (policy uses some cameras, dataset has more)
|
||||
- Dataset's provided visuals are a subset of policy (policy declares extras for flexibility)
|
||||
|
||||
Args:
|
||||
cfg (PreTrainedConfig): The model or policy configuration containing input_features and type.
|
||||
@@ -237,5 +241,11 @@ def validate_visual_features_consistency(
|
||||
"""
|
||||
expected_visuals = {k for k, v in cfg.input_features.items() if v.type == FeatureType.VISUAL}
|
||||
provided_visuals = {k for k, v in features.items() if v.type == FeatureType.VISUAL}
|
||||
if not provided_visuals.issubset(expected_visuals):
|
||||
|
||||
# Accept if either direction is a subset
|
||||
policy_subset_of_dataset = expected_visuals.issubset(provided_visuals)
|
||||
dataset_subset_of_policy = provided_visuals.issubset(expected_visuals)
|
||||
|
||||
if not (policy_subset_of_dataset or dataset_subset_of_policy):
|
||||
raise_feature_mismatch_error(provided_visuals, expected_visuals)
|
||||
|
||||
|
||||
@@ -170,8 +170,9 @@ def _extract_complementary_data(batch: dict[str, Any]) -> dict[str, Any]:
|
||||
task_key = {"task": batch["task"]} if "task" in batch else {}
|
||||
index_key = {"index": batch["index"]} if "index" in batch else {}
|
||||
task_index_key = {"task_index": batch["task_index"]} if "task_index" in batch else {}
|
||||
episode_index_key = {"episode_index": batch["episode_index"]} if "episode_index" in batch else {}
|
||||
|
||||
return {**pad_keys, **task_key, **index_key, **task_index_key}
|
||||
return {**pad_keys, **task_key, **index_key, **task_index_key, **episode_index_key}
|
||||
|
||||
|
||||
def create_transition(
|
||||
|
||||
@@ -78,7 +78,7 @@ from lerobot.transport.utils import (
|
||||
transitions_to_bytes,
|
||||
)
|
||||
from lerobot.utils.random_utils import set_seed
|
||||
from lerobot.utils.robot_utils import precise_sleep
|
||||
from lerobot.utils.robot_utils import busy_wait
|
||||
from lerobot.utils.transition import (
|
||||
Transition,
|
||||
move_state_dict_to_device,
|
||||
@@ -398,7 +398,7 @@ def act_with_policy(
|
||||
|
||||
if cfg.env.fps is not None:
|
||||
dt_time = time.perf_counter() - start_time
|
||||
precise_sleep(1 / cfg.env.fps - dt_time)
|
||||
busy_wait(1 / cfg.env.fps - dt_time)
|
||||
|
||||
|
||||
# Communication Functions - Group all gRPC/messaging functions
|
||||
|
||||
@@ -74,7 +74,7 @@ from lerobot.teleoperators import (
|
||||
from lerobot.teleoperators.teleoperator import Teleoperator
|
||||
from lerobot.teleoperators.utils import TeleopEvents
|
||||
from lerobot.utils.constants import ACTION, DONE, OBS_IMAGES, OBS_STATE, REWARD
|
||||
from lerobot.utils.robot_utils import precise_sleep
|
||||
from lerobot.utils.robot_utils import busy_wait
|
||||
from lerobot.utils.utils import log_say
|
||||
|
||||
logging.basicConfig(level=logging.INFO)
|
||||
@@ -114,7 +114,7 @@ def reset_follower_position(robot_arm: Robot, target_position: np.ndarray) -> No
|
||||
for pose in trajectory:
|
||||
action_dict = dict(zip(current_position_dict, pose, strict=False))
|
||||
robot_arm.bus.sync_write("Goal_Position", action_dict)
|
||||
precise_sleep(0.015)
|
||||
busy_wait(0.015)
|
||||
|
||||
|
||||
class RobotEnv(gym.Env):
|
||||
@@ -238,7 +238,7 @@ class RobotEnv(gym.Env):
|
||||
reset_follower_position(self.robot, np.array(self.reset_pose))
|
||||
log_say("Reset the environment done.", play_sounds=True)
|
||||
|
||||
precise_sleep(self.reset_time_s - (time.perf_counter() - start_time))
|
||||
busy_wait(self.reset_time_s - (time.perf_counter() - start_time))
|
||||
|
||||
super().reset(seed=seed, options=options)
|
||||
|
||||
@@ -713,7 +713,7 @@ def control_loop(
|
||||
transition = env_processor(transition)
|
||||
|
||||
# Maintain fps timing
|
||||
precise_sleep(dt - (time.perf_counter() - step_start_time))
|
||||
busy_wait(dt - (time.perf_counter() - step_start_time))
|
||||
|
||||
if dataset is not None and cfg.dataset.push_to_hub:
|
||||
logging.info("Pushing dataset to hub")
|
||||
@@ -745,7 +745,7 @@ def replay_trajectory(
|
||||
)
|
||||
transition = action_processor(transition)
|
||||
env.step(transition[TransitionKey.ACTION])
|
||||
precise_sleep(1 / cfg.env.fps - (time.perf_counter() - start_time))
|
||||
busy_wait(1 / cfg.env.fps - (time.perf_counter() - start_time))
|
||||
|
||||
|
||||
@parser.wrap()
|
||||
|
||||
@@ -1,55 +0,0 @@
|
||||
#!/usr/bin/env python
|
||||
|
||||
# Copyright 2025 The HuggingFace Inc. team. All rights reserved.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
from dataclasses import dataclass, field
|
||||
|
||||
from ..config import RobotConfig
|
||||
|
||||
_GAINS: dict[str, dict[str, list[float]]] = {
|
||||
"left_leg": {
|
||||
"kp": [150, 150, 150, 300, 40, 40],
|
||||
"kd": [2, 2, 2, 4, 2, 2],
|
||||
}, # pitch, roll, yaw, knee, ankle_pitch, ankle_roll
|
||||
"right_leg": {"kp": [150, 150, 150, 300, 40, 40], "kd": [2, 2, 2, 4, 2, 2]},
|
||||
"waist": {"kp": [250, 250, 250], "kd": [5, 5, 5]}, # yaw, roll, pitch
|
||||
"left_arm": {"kp": [80, 80, 80, 80], "kd": [3, 3, 3, 3]}, # shoulder_pitch/roll/yaw, elbow
|
||||
"left_wrist": {"kp": [40, 40, 40], "kd": [1.5, 1.5, 1.5]}, # roll, pitch, yaw
|
||||
"right_arm": {"kp": [80, 80, 80, 80], "kd": [3, 3, 3, 3]},
|
||||
"right_wrist": {"kp": [40, 40, 40], "kd": [1.5, 1.5, 1.5]},
|
||||
"other": {"kp": [80, 80, 80, 80, 80, 80], "kd": [3, 3, 3, 3, 3, 3]},
|
||||
}
|
||||
|
||||
|
||||
def _build_gains() -> tuple[list[float], list[float]]:
|
||||
"""Build kp and kd lists from body-part groupings."""
|
||||
kp = [v for g in _GAINS.values() for v in g["kp"]]
|
||||
kd = [v for g in _GAINS.values() for v in g["kd"]]
|
||||
return kp, kd
|
||||
|
||||
|
||||
_DEFAULT_KP, _DEFAULT_KD = _build_gains()
|
||||
|
||||
|
||||
@RobotConfig.register_subclass("unitree_g1")
|
||||
@dataclass
|
||||
class UnitreeG1Config(RobotConfig):
|
||||
kp: list[float] = field(default_factory=lambda: _DEFAULT_KP.copy())
|
||||
kd: list[float] = field(default_factory=lambda: _DEFAULT_KD.copy())
|
||||
|
||||
control_dt: float = 1.0 / 250.0 # 250Hz
|
||||
|
||||
# socket config for ZMQ bridge
|
||||
robot_ip: str = "172.18.129.215"
|
||||
@@ -1,89 +0,0 @@
|
||||
#!/usr/bin/env python
|
||||
|
||||
# Copyright 2025 The HuggingFace Inc. team. All rights reserved.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
from enum import IntEnum
|
||||
|
||||
# ruff: noqa: N801, N815
|
||||
|
||||
NUM_MOTORS = 35
|
||||
|
||||
|
||||
class G1_29_JointArmIndex(IntEnum):
|
||||
# Left arm
|
||||
kLeftShoulderPitch = 15
|
||||
kLeftShoulderRoll = 16
|
||||
kLeftShoulderYaw = 17
|
||||
kLeftElbow = 18
|
||||
kLeftWristRoll = 19
|
||||
kLeftWristPitch = 20
|
||||
kLeftWristyaw = 21
|
||||
|
||||
# Right arm
|
||||
kRightShoulderPitch = 22
|
||||
kRightShoulderRoll = 23
|
||||
kRightShoulderYaw = 24
|
||||
kRightElbow = 25
|
||||
kRightWristRoll = 26
|
||||
kRightWristPitch = 27
|
||||
kRightWristYaw = 28
|
||||
|
||||
|
||||
class G1_29_JointIndex(IntEnum):
|
||||
# Left leg
|
||||
kLeftHipPitch = 0
|
||||
kLeftHipRoll = 1
|
||||
kLeftHipYaw = 2
|
||||
kLeftKnee = 3
|
||||
kLeftAnklePitch = 4
|
||||
kLeftAnkleRoll = 5
|
||||
|
||||
# Right leg
|
||||
kRightHipPitch = 6
|
||||
kRightHipRoll = 7
|
||||
kRightHipYaw = 8
|
||||
kRightKnee = 9
|
||||
kRightAnklePitch = 10
|
||||
kRightAnkleRoll = 11
|
||||
|
||||
kWaistYaw = 12
|
||||
kWaistRoll = 13
|
||||
kWaistPitch = 14
|
||||
|
||||
# Left arm
|
||||
kLeftShoulderPitch = 15
|
||||
kLeftShoulderRoll = 16
|
||||
kLeftShoulderYaw = 17
|
||||
kLeftElbow = 18
|
||||
kLeftWristRoll = 19
|
||||
kLeftWristPitch = 20
|
||||
kLeftWristyaw = 21
|
||||
|
||||
# Right arm
|
||||
kRightShoulderPitch = 22
|
||||
kRightShoulderRoll = 23
|
||||
kRightShoulderYaw = 24
|
||||
kRightElbow = 25
|
||||
kRightWristRoll = 26
|
||||
kRightWristPitch = 27
|
||||
kRightWristYaw = 28
|
||||
|
||||
# not used
|
||||
kNotUsedJoint0 = 29
|
||||
kNotUsedJoint1 = 30
|
||||
kNotUsedJoint2 = 31
|
||||
kNotUsedJoint3 = 32
|
||||
kNotUsedJoint4 = 33
|
||||
kNotUsedJoint5 = 34
|
||||
@@ -1,212 +0,0 @@
|
||||
#!/usr/bin/env python3
|
||||
|
||||
# Copyright 2025 The HuggingFace Inc. team. All rights reserved.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
"""
|
||||
DDS-to-ZMQ bridge server for Unitree G1 robot.
|
||||
|
||||
This server runs on the robot and forwards:
|
||||
- Robot state (LowState) from DDS to ZMQ (for remote clients)
|
||||
- Robot commands (LowCmd) from ZMQ to DDS (from remote clients)
|
||||
|
||||
Uses JSON for secure serialization instead of pickle.
|
||||
"""
|
||||
|
||||
import base64
|
||||
import contextlib
|
||||
import json
|
||||
import threading
|
||||
import time
|
||||
from typing import Any
|
||||
|
||||
import zmq
|
||||
from unitree_sdk2py.comm.motion_switcher.motion_switcher_client import MotionSwitcherClient
|
||||
from unitree_sdk2py.core.channel import ChannelFactoryInitialize, ChannelPublisher, ChannelSubscriber
|
||||
from unitree_sdk2py.idl.default import unitree_hg_msg_dds__LowCmd_
|
||||
from unitree_sdk2py.idl.unitree_hg.msg.dds_ import LowCmd_ as hg_LowCmd, LowState_ as hg_LowState
|
||||
from unitree_sdk2py.utils.crc import CRC
|
||||
|
||||
# DDS topic names follow Unitree SDK naming conventions
|
||||
# ruff: noqa: N816
|
||||
kTopicLowCommand_Debug = "rt/lowcmd" # action to robot
|
||||
kTopicLowState = "rt/lowstate" # observation from robot
|
||||
|
||||
LOWCMD_PORT = 6000
|
||||
LOWSTATE_PORT = 6001
|
||||
NUM_MOTORS = 35
|
||||
|
||||
|
||||
def lowstate_to_dict(msg: hg_LowState) -> dict[str, Any]:
|
||||
"""Convert LowState SDK message to a JSON-serializable dictionary."""
|
||||
motor_states = []
|
||||
for i in range(NUM_MOTORS):
|
||||
temp = msg.motor_state[i].temperature
|
||||
avg_temp = float(sum(temp) / len(temp)) if isinstance(temp, list) else float(temp)
|
||||
motor_states.append(
|
||||
{
|
||||
"q": float(msg.motor_state[i].q),
|
||||
"dq": float(msg.motor_state[i].dq),
|
||||
"tau_est": float(msg.motor_state[i].tau_est),
|
||||
"temperature": avg_temp,
|
||||
}
|
||||
)
|
||||
|
||||
return {
|
||||
"motor_state": motor_states,
|
||||
"imu_state": {
|
||||
"quaternion": [float(x) for x in msg.imu_state.quaternion],
|
||||
"gyroscope": [float(x) for x in msg.imu_state.gyroscope],
|
||||
"accelerometer": [float(x) for x in msg.imu_state.accelerometer],
|
||||
"rpy": [float(x) for x in msg.imu_state.rpy],
|
||||
"temperature": float(msg.imu_state.temperature),
|
||||
},
|
||||
# Encode bytes as base64 for JSON compatibility
|
||||
"wireless_remote": base64.b64encode(bytes(msg.wireless_remote)).decode("ascii"),
|
||||
"mode_machine": int(msg.mode_machine),
|
||||
}
|
||||
|
||||
|
||||
def dict_to_lowcmd(data: dict[str, Any]) -> hg_LowCmd:
|
||||
"""Convert dictionary back to LowCmd SDK message."""
|
||||
cmd = unitree_hg_msg_dds__LowCmd_()
|
||||
cmd.mode_pr = data.get("mode_pr", 0)
|
||||
cmd.mode_machine = data.get("mode_machine", 0)
|
||||
|
||||
for i, motor_data in enumerate(data.get("motor_cmd", [])):
|
||||
cmd.motor_cmd[i].mode = motor_data.get("mode", 0)
|
||||
cmd.motor_cmd[i].q = motor_data.get("q", 0.0)
|
||||
cmd.motor_cmd[i].dq = motor_data.get("dq", 0.0)
|
||||
cmd.motor_cmd[i].kp = motor_data.get("kp", 0.0)
|
||||
cmd.motor_cmd[i].kd = motor_data.get("kd", 0.0)
|
||||
cmd.motor_cmd[i].tau = motor_data.get("tau", 0.0)
|
||||
|
||||
return cmd
|
||||
|
||||
|
||||
def state_forward_loop(
|
||||
lowstate_sub: ChannelSubscriber,
|
||||
lowstate_sock: zmq.Socket,
|
||||
state_period: float,
|
||||
shutdown_event: threading.Event,
|
||||
) -> None:
|
||||
"""Read observation from DDS and forward to ZMQ clients."""
|
||||
last_state_time = 0.0
|
||||
|
||||
while not shutdown_event.is_set():
|
||||
# read from DDS
|
||||
msg = lowstate_sub.Read()
|
||||
if msg is None:
|
||||
continue
|
||||
|
||||
now = time.time()
|
||||
# optional downsampling (if robot dds rate > state_period)
|
||||
if now - last_state_time >= state_period:
|
||||
# Convert to dict and serialize with JSON
|
||||
state_dict = lowstate_to_dict(msg)
|
||||
payload = json.dumps({"topic": kTopicLowState, "data": state_dict}).encode("utf-8")
|
||||
# if no subscribers / tx buffer full, just drop
|
||||
with contextlib.suppress(zmq.Again):
|
||||
lowstate_sock.send(payload, zmq.NOBLOCK)
|
||||
last_state_time = now
|
||||
|
||||
|
||||
def cmd_forward_loop(
|
||||
lowcmd_sock: zmq.Socket,
|
||||
lowcmd_pub_debug: ChannelPublisher,
|
||||
crc: CRC,
|
||||
) -> None:
|
||||
"""Receive commands from ZMQ and forward to DDS."""
|
||||
while True:
|
||||
try:
|
||||
payload = lowcmd_sock.recv()
|
||||
except zmq.ContextTerminated:
|
||||
break
|
||||
msg_dict = json.loads(payload.decode("utf-8"))
|
||||
|
||||
topic = msg_dict.get("topic", "")
|
||||
cmd_data = msg_dict.get("data", {})
|
||||
|
||||
# Reconstruct LowCmd object from dict
|
||||
cmd = dict_to_lowcmd(cmd_data)
|
||||
|
||||
# recompute crc
|
||||
cmd.crc = crc.Crc(cmd)
|
||||
|
||||
if topic == kTopicLowCommand_Debug:
|
||||
lowcmd_pub_debug.Write(cmd)
|
||||
|
||||
|
||||
def main() -> None:
|
||||
"""Main entry point for the robot server bridge."""
|
||||
# initialize DDS
|
||||
ChannelFactoryInitialize(0)
|
||||
|
||||
# stop all active publishers on the robot
|
||||
msc = MotionSwitcherClient()
|
||||
msc.SetTimeout(5.0)
|
||||
msc.Init()
|
||||
|
||||
status, result = msc.CheckMode()
|
||||
while result is not None and "name" in result and result["name"]:
|
||||
msc.ReleaseMode()
|
||||
status, result = msc.CheckMode()
|
||||
time.sleep(1.0)
|
||||
|
||||
crc = CRC()
|
||||
|
||||
# initialize DDS publisher
|
||||
lowcmd_pub_debug = ChannelPublisher(kTopicLowCommand_Debug, hg_LowCmd)
|
||||
lowcmd_pub_debug.Init()
|
||||
|
||||
# initialize DDS subscriber
|
||||
lowstate_sub = ChannelSubscriber(kTopicLowState, hg_LowState)
|
||||
lowstate_sub.Init()
|
||||
|
||||
# initialize ZMQ
|
||||
ctx = zmq.Context.instance()
|
||||
|
||||
# receive commands from remote client
|
||||
lowcmd_sock = ctx.socket(zmq.PULL)
|
||||
lowcmd_sock.bind(f"tcp://0.0.0.0:{LOWCMD_PORT}")
|
||||
|
||||
# publish state to remote clients
|
||||
lowstate_sock = ctx.socket(zmq.PUB)
|
||||
lowstate_sock.bind(f"tcp://0.0.0.0:{LOWSTATE_PORT}")
|
||||
|
||||
state_period = 0.002 # ~500 hz
|
||||
shutdown_event = threading.Event()
|
||||
|
||||
# start observation forwarding in background thread
|
||||
t_state = threading.Thread(
|
||||
target=state_forward_loop,
|
||||
args=(lowstate_sub, lowstate_sock, state_period, shutdown_event),
|
||||
)
|
||||
t_state.start()
|
||||
|
||||
print("bridge running (lowstate -> zmq, lowcmd -> dds)")
|
||||
|
||||
# run command forwarding in main thread
|
||||
try:
|
||||
cmd_forward_loop(lowcmd_sock, lowcmd_pub_debug, crc)
|
||||
except KeyboardInterrupt:
|
||||
print("shutting down bridge...")
|
||||
finally:
|
||||
shutdown_event.set()
|
||||
ctx.term() # terminates blocking zmq.recv() calls
|
||||
t_state.join(timeout=2.0)
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
main()
|
||||
@@ -1,268 +0,0 @@
|
||||
#!/usr/bin/env python
|
||||
|
||||
# Copyright 2025 The HuggingFace Inc. team. All rights reserved.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
import logging
|
||||
import struct
|
||||
import threading
|
||||
import time
|
||||
from dataclasses import dataclass, field
|
||||
from functools import cached_property
|
||||
from typing import Any
|
||||
|
||||
import numpy as np
|
||||
from unitree_sdk2py.idl.default import unitree_hg_msg_dds__LowCmd_
|
||||
from unitree_sdk2py.idl.unitree_hg.msg.dds_ import (
|
||||
LowCmd_ as hg_LowCmd,
|
||||
LowState_ as hg_LowState,
|
||||
)
|
||||
from unitree_sdk2py.utils.crc import CRC
|
||||
|
||||
from lerobot.robots.unitree_g1.g1_utils import G1_29_JointIndex
|
||||
from lerobot.robots.unitree_g1.unitree_sdk2_socket import (
|
||||
ChannelFactoryInitialize,
|
||||
ChannelPublisher,
|
||||
ChannelSubscriber,
|
||||
)
|
||||
|
||||
from ..robot import Robot
|
||||
from .config_unitree_g1 import UnitreeG1Config
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
# DDS topic names follow Unitree SDK naming conventions
|
||||
# ruff: noqa: N816
|
||||
kTopicLowCommand_Debug = "rt/lowcmd"
|
||||
kTopicLowState = "rt/lowstate"
|
||||
|
||||
G1_29_Num_Motors = 35
|
||||
G1_23_Num_Motors = 35
|
||||
H1_2_Num_Motors = 35
|
||||
H1_Num_Motors = 20
|
||||
|
||||
|
||||
@dataclass
|
||||
class MotorState:
|
||||
q: float | None = None # position
|
||||
dq: float | None = None # velocity
|
||||
tau_est: float | None = None # estimated torque
|
||||
temperature: float | None = None # motor temperature
|
||||
|
||||
|
||||
@dataclass
|
||||
class IMUState:
|
||||
quaternion: np.ndarray | None = None # [w, x, y, z]
|
||||
gyroscope: np.ndarray | None = None # [x, y, z] angular velocity (rad/s)
|
||||
accelerometer: np.ndarray | None = None # [x, y, z] linear acceleration (m/s²)
|
||||
rpy: np.ndarray | None = None # [roll, pitch, yaw] (rad)
|
||||
temperature: float | None = None # IMU temperature
|
||||
|
||||
|
||||
# g1 observation class
|
||||
@dataclass
|
||||
class G1_29_LowState: # noqa: N801
|
||||
motor_state: list[MotorState] = field(
|
||||
default_factory=lambda: [MotorState() for _ in range(G1_29_Num_Motors)]
|
||||
)
|
||||
imu_state: IMUState = field(default_factory=IMUState)
|
||||
wireless_remote: Any = None # Raw wireless remote data
|
||||
mode_machine: int = 0 # Robot mode
|
||||
|
||||
|
||||
class DataBuffer:
|
||||
def __init__(self):
|
||||
self.data = None
|
||||
self.lock = threading.Lock()
|
||||
|
||||
def get_data(self):
|
||||
with self.lock:
|
||||
return self.data
|
||||
|
||||
def set_data(self, data):
|
||||
with self.lock:
|
||||
self.data = data
|
||||
|
||||
|
||||
class UnitreeG1(Robot):
|
||||
config_class = UnitreeG1Config
|
||||
name = "unitree_g1"
|
||||
|
||||
# unitree remote controller
|
||||
class RemoteController:
|
||||
def __init__(self):
|
||||
self.lx = 0
|
||||
self.ly = 0
|
||||
self.rx = 0
|
||||
self.ry = 0
|
||||
self.button = [0] * 16
|
||||
|
||||
def set(self, data):
|
||||
# wireless_remote
|
||||
keys = struct.unpack("H", data[2:4])[0]
|
||||
for i in range(16):
|
||||
self.button[i] = (keys & (1 << i)) >> i
|
||||
self.lx = struct.unpack("f", data[4:8])[0]
|
||||
self.rx = struct.unpack("f", data[8:12])[0]
|
||||
self.ry = struct.unpack("f", data[12:16])[0]
|
||||
self.ly = struct.unpack("f", data[20:24])[0]
|
||||
|
||||
def __init__(self, config: UnitreeG1Config):
|
||||
super().__init__(config)
|
||||
|
||||
logger.info("Initialize UnitreeG1...")
|
||||
|
||||
self.config = config
|
||||
|
||||
self.control_dt = config.control_dt
|
||||
|
||||
# connect robot
|
||||
self.connect()
|
||||
|
||||
# initialize direct motor control interface
|
||||
self.lowcmd_publisher = ChannelPublisher(kTopicLowCommand_Debug, hg_LowCmd)
|
||||
self.lowcmd_publisher.Init()
|
||||
self.lowstate_subscriber = ChannelSubscriber(kTopicLowState, hg_LowState)
|
||||
self.lowstate_subscriber.Init()
|
||||
self.lowstate_buffer = DataBuffer()
|
||||
|
||||
# initialize subscribe thread to read robot state
|
||||
self._shutdown_event = threading.Event()
|
||||
self.subscribe_thread = threading.Thread(target=self._subscribe_motor_state)
|
||||
self.subscribe_thread.start()
|
||||
|
||||
while not self.is_connected:
|
||||
time.sleep(0.1)
|
||||
|
||||
# initialize hg's lowcmd msg
|
||||
self.crc = CRC()
|
||||
self.msg = unitree_hg_msg_dds__LowCmd_()
|
||||
self.msg.mode_pr = 0
|
||||
|
||||
# Wait for first state message to arrive
|
||||
lowstate = None
|
||||
while lowstate is None:
|
||||
lowstate = self.lowstate_buffer.get_data()
|
||||
if lowstate is None:
|
||||
time.sleep(0.01)
|
||||
logger.warning("[UnitreeG1] Waiting for robot state...")
|
||||
logger.warning("[UnitreeG1] Connected to robot.")
|
||||
self.msg.mode_machine = lowstate.mode_machine
|
||||
|
||||
# initialize all motors with unified kp/kd from config
|
||||
self.kp = np.array(config.kp, dtype=np.float32)
|
||||
self.kd = np.array(config.kd, dtype=np.float32)
|
||||
|
||||
for id in G1_29_JointIndex:
|
||||
self.msg.motor_cmd[id].mode = 1
|
||||
self.msg.motor_cmd[id].kp = self.kp[id.value]
|
||||
self.msg.motor_cmd[id].kd = self.kd[id.value]
|
||||
self.msg.motor_cmd[id].q = lowstate.motor_state[id.value].q
|
||||
|
||||
# Initialize remote controller
|
||||
self.remote_controller = self.RemoteController()
|
||||
|
||||
def _subscribe_motor_state(self): # polls robot state @ 250Hz
|
||||
while not self._shutdown_event.is_set():
|
||||
start_time = time.time()
|
||||
msg = self.lowstate_subscriber.Read()
|
||||
if msg is not None:
|
||||
lowstate = G1_29_LowState()
|
||||
|
||||
# Capture motor states
|
||||
for id in range(G1_29_Num_Motors):
|
||||
lowstate.motor_state[id].q = msg.motor_state[id].q
|
||||
lowstate.motor_state[id].dq = msg.motor_state[id].dq
|
||||
lowstate.motor_state[id].tau_est = msg.motor_state[id].tau_est
|
||||
lowstate.motor_state[id].temperature = msg.motor_state[id].temperature
|
||||
|
||||
# Capture IMU state
|
||||
lowstate.imu_state.quaternion = list(msg.imu_state.quaternion)
|
||||
lowstate.imu_state.gyroscope = list(msg.imu_state.gyroscope)
|
||||
lowstate.imu_state.accelerometer = list(msg.imu_state.accelerometer)
|
||||
lowstate.imu_state.rpy = list(msg.imu_state.rpy)
|
||||
lowstate.imu_state.temperature = msg.imu_state.temperature
|
||||
|
||||
# Capture wireless remote data
|
||||
lowstate.wireless_remote = msg.wireless_remote
|
||||
|
||||
# Capture mode_machine
|
||||
lowstate.mode_machine = msg.mode_machine
|
||||
|
||||
self.lowstate_buffer.set_data(lowstate)
|
||||
|
||||
current_time = time.time()
|
||||
all_t_elapsed = current_time - start_time
|
||||
sleep_time = max(0, (self.control_dt - all_t_elapsed)) # maintain constant control dt
|
||||
time.sleep(sleep_time)
|
||||
|
||||
@cached_property
|
||||
def action_features(self) -> dict[str, type]:
|
||||
return {f"{G1_29_JointIndex(motor).name}.pos": float for motor in G1_29_JointIndex}
|
||||
|
||||
def calibrate(self) -> None: # robot is already calibrated
|
||||
pass
|
||||
|
||||
def configure(self) -> None:
|
||||
pass
|
||||
|
||||
def connect(self, calibrate: bool = True) -> None: # connect to DDS
|
||||
ChannelFactoryInitialize(0)
|
||||
|
||||
def disconnect(self):
|
||||
self._shutdown_event.set()
|
||||
self.subscribe_thread.join(timeout=2.0)
|
||||
|
||||
def get_observation(self) -> dict[str, Any]:
|
||||
return self.lowstate_buffer.get_data()
|
||||
|
||||
@property
|
||||
def is_calibrated(self) -> bool:
|
||||
return True
|
||||
|
||||
@property
|
||||
def is_connected(self) -> bool:
|
||||
return self.lowstate_buffer.get_data() is not None
|
||||
|
||||
@property
|
||||
def _motors_ft(self) -> dict[str, type]:
|
||||
return {f"{G1_29_JointIndex(motor).name}.pos": float for motor in G1_29_JointIndex}
|
||||
|
||||
@property
|
||||
def _cameras_ft(self) -> dict[str, tuple]:
|
||||
return {
|
||||
cam: (self.config.cameras[cam].height, self.config.cameras[cam].width, 3) for cam in self.cameras
|
||||
}
|
||||
|
||||
@cached_property
|
||||
def observation_features(self) -> dict[str, type | tuple]:
|
||||
return {**self._motors_ft, **self._cameras_ft}
|
||||
|
||||
def send_action(self, action: dict[str, Any]) -> dict[str, Any]:
|
||||
self.msg.crc = self.crc.Crc(action)
|
||||
self.lowcmd_publisher.Write(action)
|
||||
return action
|
||||
|
||||
def get_gravity_orientation(self, quaternion): # get gravity orientation from quaternion
|
||||
"""Get gravity orientation from quaternion."""
|
||||
qw = quaternion[0]
|
||||
qx = quaternion[1]
|
||||
qy = quaternion[2]
|
||||
qz = quaternion[3]
|
||||
|
||||
gravity_orientation = np.zeros(3)
|
||||
gravity_orientation[0] = 2 * (-qz * qx + qw * qy)
|
||||
gravity_orientation[1] = -2 * (qz * qy + qw * qx)
|
||||
gravity_orientation[2] = 1 - 2 * (qw * qw + qz * qz)
|
||||
return gravity_orientation
|
||||
@@ -1,168 +0,0 @@
|
||||
#!/usr/bin/env python
|
||||
|
||||
# Copyright 2025 The HuggingFace Inc. team. All rights reserved.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
import base64
|
||||
import json
|
||||
from typing import Any
|
||||
|
||||
import zmq
|
||||
|
||||
from lerobot.robots.unitree_g1.config_unitree_g1 import UnitreeG1Config
|
||||
|
||||
_ctx: zmq.Context | None = None
|
||||
_lowcmd_sock: zmq.Socket | None = None
|
||||
_lowstate_sock: zmq.Socket | None = None
|
||||
|
||||
LOWCMD_PORT = 6000
|
||||
LOWSTATE_PORT = 6001
|
||||
|
||||
# DDS topic names follow Unitree SDK naming conventions
|
||||
# ruff: noqa: N816
|
||||
kTopicLowCommand_Debug = "rt/lowcmd"
|
||||
|
||||
|
||||
class LowStateMsg:
|
||||
"""
|
||||
Wrapper class that mimics the Unitree SDK LowState_ message structure.
|
||||
|
||||
Reconstructs the message from deserialized JSON data to maintain
|
||||
compatibility with existing code that expects SDK message objects.
|
||||
"""
|
||||
|
||||
class MotorState:
|
||||
"""Motor state data for a single joint."""
|
||||
|
||||
def __init__(self, data: dict[str, Any]) -> None:
|
||||
self.q: float = data.get("q", 0.0)
|
||||
self.dq: float = data.get("dq", 0.0)
|
||||
self.tau_est: float = data.get("tau_est", 0.0)
|
||||
self.temperature: float = data.get("temperature", 0.0)
|
||||
|
||||
class IMUState:
|
||||
"""IMU sensor data."""
|
||||
|
||||
def __init__(self, data: dict[str, Any]) -> None:
|
||||
self.quaternion: list[float] = data.get("quaternion", [1.0, 0.0, 0.0, 0.0])
|
||||
self.gyroscope: list[float] = data.get("gyroscope", [0.0, 0.0, 0.0])
|
||||
self.accelerometer: list[float] = data.get("accelerometer", [0.0, 0.0, 0.0])
|
||||
self.rpy: list[float] = data.get("rpy", [0.0, 0.0, 0.0])
|
||||
self.temperature: float = data.get("temperature", 0.0)
|
||||
|
||||
def __init__(self, data: dict[str, Any]) -> None:
|
||||
"""Initialize from deserialized JSON data."""
|
||||
self.motor_state = [self.MotorState(m) for m in data.get("motor_state", [])]
|
||||
self.imu_state = self.IMUState(data.get("imu_state", {}))
|
||||
# Decode base64-encoded wireless_remote bytes
|
||||
wireless_b64 = data.get("wireless_remote", "")
|
||||
self.wireless_remote: bytes = base64.b64decode(wireless_b64) if wireless_b64 else b""
|
||||
self.mode_machine: int = data.get("mode_machine", 0)
|
||||
|
||||
|
||||
def lowcmd_to_dict(topic: str, msg: Any) -> dict[str, Any]:
|
||||
"""Convert LowCmd message to a JSON-serializable dictionary."""
|
||||
motor_cmds = []
|
||||
# Iterate over all motor commands in the message
|
||||
for i in range(len(msg.motor_cmd)):
|
||||
motor_cmds.append(
|
||||
{
|
||||
"mode": int(msg.motor_cmd[i].mode),
|
||||
"q": float(msg.motor_cmd[i].q),
|
||||
"dq": float(msg.motor_cmd[i].dq),
|
||||
"kp": float(msg.motor_cmd[i].kp),
|
||||
"kd": float(msg.motor_cmd[i].kd),
|
||||
"tau": float(msg.motor_cmd[i].tau),
|
||||
}
|
||||
)
|
||||
|
||||
return {
|
||||
"topic": topic,
|
||||
"data": {
|
||||
"mode_pr": int(msg.mode_pr),
|
||||
"mode_machine": int(msg.mode_machine),
|
||||
"motor_cmd": motor_cmds,
|
||||
},
|
||||
}
|
||||
|
||||
|
||||
def ChannelFactoryInitialize(*args: Any, **kwargs: Any) -> None: # noqa: N802
|
||||
"""
|
||||
Initialize ZMQ sockets for robot communication.
|
||||
|
||||
This function mimics the Unitree SDK's ChannelFactoryInitialize but uses
|
||||
ZMQ sockets to connect to the robot server bridge instead of DDS.
|
||||
"""
|
||||
global _ctx, _lowcmd_sock, _lowstate_sock
|
||||
|
||||
# read socket config
|
||||
config = UnitreeG1Config()
|
||||
robot_ip = config.robot_ip
|
||||
|
||||
ctx = zmq.Context.instance()
|
||||
_ctx = ctx
|
||||
|
||||
# lowcmd: send robot commands
|
||||
lowcmd_sock = ctx.socket(zmq.PUSH)
|
||||
lowcmd_sock.setsockopt(zmq.CONFLATE, 1) # keep only last message
|
||||
lowcmd_sock.connect(f"tcp://{robot_ip}:{LOWCMD_PORT}")
|
||||
_lowcmd_sock = lowcmd_sock
|
||||
|
||||
# lowstate: receive robot observations
|
||||
lowstate_sock = ctx.socket(zmq.SUB)
|
||||
lowstate_sock.setsockopt(zmq.CONFLATE, 1) # keep only last message
|
||||
lowstate_sock.connect(f"tcp://{robot_ip}:{LOWSTATE_PORT}")
|
||||
lowstate_sock.setsockopt_string(zmq.SUBSCRIBE, "")
|
||||
_lowstate_sock = lowstate_sock
|
||||
|
||||
|
||||
class ChannelPublisher:
|
||||
"""ZMQ-based publisher that sends commands to the robot server."""
|
||||
|
||||
def __init__(self, topic: str, msg_type: type) -> None:
|
||||
self.topic = topic
|
||||
self.msg_type = msg_type
|
||||
|
||||
def Init(self) -> None: # noqa: N802
|
||||
"""Initialize the publisher (no-op for ZMQ)."""
|
||||
pass
|
||||
|
||||
def Write(self, msg: Any) -> None: # noqa: N802
|
||||
"""Serialize and send a command message to the robot."""
|
||||
if _lowcmd_sock is None:
|
||||
raise RuntimeError("ChannelFactoryInitialize must be called first")
|
||||
|
||||
payload = json.dumps(lowcmd_to_dict(self.topic, msg)).encode("utf-8")
|
||||
_lowcmd_sock.send(payload)
|
||||
|
||||
|
||||
class ChannelSubscriber:
|
||||
"""ZMQ-based subscriber that receives state from the robot server."""
|
||||
|
||||
def __init__(self, topic: str, msg_type: type) -> None:
|
||||
self.topic = topic
|
||||
self.msg_type = msg_type
|
||||
|
||||
def Init(self) -> None: # noqa: N802
|
||||
"""Initialize the subscriber (no-op for ZMQ)."""
|
||||
pass
|
||||
|
||||
def Read(self) -> LowStateMsg: # noqa: N802
|
||||
"""Receive and deserialize a state message from the robot."""
|
||||
if _lowstate_sock is None:
|
||||
raise RuntimeError("ChannelFactoryInitialize must be called first")
|
||||
|
||||
payload = _lowstate_sock.recv()
|
||||
msg_dict = json.loads(payload.decode("utf-8"))
|
||||
return LowStateMsg(msg_dict.get("data", {}))
|
||||
@@ -65,6 +65,7 @@ import argparse
|
||||
import gc
|
||||
import logging
|
||||
import time
|
||||
from collections.abc import Iterator
|
||||
from pathlib import Path
|
||||
|
||||
import numpy as np
|
||||
@@ -77,6 +78,19 @@ from lerobot.datasets.lerobot_dataset import LeRobotDataset
|
||||
from lerobot.utils.constants import ACTION, DONE, OBS_STATE, REWARD
|
||||
|
||||
|
||||
class EpisodeSampler(torch.utils.data.Sampler):
|
||||
def __init__(self, dataset: LeRobotDataset, episode_index: int):
|
||||
from_idx = dataset.meta.episodes["dataset_from_index"][episode_index]
|
||||
to_idx = dataset.meta.episodes["dataset_to_index"][episode_index]
|
||||
self.frame_ids = range(from_idx, to_idx)
|
||||
|
||||
def __iter__(self) -> Iterator:
|
||||
return iter(self.frame_ids)
|
||||
|
||||
def __len__(self) -> int:
|
||||
return len(self.frame_ids)
|
||||
|
||||
|
||||
def to_hwc_uint8_numpy(chw_float32_torch: torch.Tensor) -> np.ndarray:
|
||||
assert chw_float32_torch.dtype == torch.float32
|
||||
assert chw_float32_torch.ndim == 3
|
||||
@@ -105,10 +119,12 @@ def visualize_dataset(
|
||||
repo_id = dataset.repo_id
|
||||
|
||||
logging.info("Loading dataloader")
|
||||
episode_sampler = EpisodeSampler(dataset, episode_index)
|
||||
dataloader = torch.utils.data.DataLoader(
|
||||
dataset,
|
||||
num_workers=num_workers,
|
||||
batch_size=batch_size,
|
||||
sampler=episode_sampler,
|
||||
)
|
||||
|
||||
logging.info("Starting Rerun")
|
||||
|
||||
@@ -50,7 +50,7 @@ from lerobot.teleoperators import ( # noqa: F401
|
||||
make_teleoperator_from_config,
|
||||
so100_leader,
|
||||
)
|
||||
from lerobot.utils.robot_utils import precise_sleep
|
||||
from lerobot.utils.robot_utils import busy_wait
|
||||
|
||||
|
||||
@dataclass
|
||||
@@ -114,7 +114,7 @@ def find_joint_and_ee_bounds(cfg: FindJointLimitsConfig):
|
||||
print(f"Min joint pos position {np.round(min_pos, 4).tolist()}")
|
||||
break
|
||||
|
||||
precise_sleep(0.01)
|
||||
busy_wait(0.01)
|
||||
|
||||
|
||||
def main():
|
||||
|
||||
@@ -119,7 +119,7 @@ from lerobot.utils.control_utils import (
|
||||
sanity_check_dataset_robot_compatibility,
|
||||
)
|
||||
from lerobot.utils.import_utils import register_third_party_devices
|
||||
from lerobot.utils.robot_utils import precise_sleep
|
||||
from lerobot.utils.robot_utils import busy_wait
|
||||
from lerobot.utils.utils import (
|
||||
get_safe_torch_device,
|
||||
init_logging,
|
||||
@@ -364,7 +364,7 @@ def record_loop(
|
||||
log_rerun_data(observation=obs_processed, action=action_values)
|
||||
|
||||
dt_s = time.perf_counter() - start_loop_t
|
||||
precise_sleep(1 / fps - dt_s)
|
||||
busy_wait(1 / fps - dt_s)
|
||||
|
||||
timestamp = time.perf_counter() - start_episode_t
|
||||
|
||||
|
||||
@@ -62,7 +62,7 @@ from lerobot.robots import ( # noqa: F401
|
||||
)
|
||||
from lerobot.utils.constants import ACTION
|
||||
from lerobot.utils.import_utils import register_third_party_devices
|
||||
from lerobot.utils.robot_utils import precise_sleep
|
||||
from lerobot.utils.robot_utils import busy_wait
|
||||
from lerobot.utils.utils import (
|
||||
init_logging,
|
||||
log_say,
|
||||
@@ -121,7 +121,7 @@ def replay(cfg: ReplayConfig):
|
||||
_ = robot.send_action(processed_action)
|
||||
|
||||
dt_s = time.perf_counter() - start_episode_t
|
||||
precise_sleep(1 / dataset.fps - dt_s)
|
||||
busy_wait(1 / dataset.fps - dt_s)
|
||||
|
||||
robot.disconnect()
|
||||
|
||||
|
||||
@@ -89,7 +89,7 @@ from lerobot.teleoperators import ( # noqa: F401
|
||||
so101_leader,
|
||||
)
|
||||
from lerobot.utils.import_utils import register_third_party_devices
|
||||
from lerobot.utils.robot_utils import precise_sleep
|
||||
from lerobot.utils.robot_utils import busy_wait
|
||||
from lerobot.utils.utils import init_logging, move_cursor_up
|
||||
from lerobot.utils.visualization_utils import init_rerun, log_rerun_data
|
||||
|
||||
@@ -170,13 +170,12 @@ def teleop_loop(
|
||||
# Display the final robot action that was sent
|
||||
for motor, value in robot_action_to_send.items():
|
||||
print(f"{motor:<{display_len}} | {value:>7.2f}")
|
||||
move_cursor_up(len(robot_action_to_send) + 3)
|
||||
move_cursor_up(len(robot_action_to_send) + 5)
|
||||
|
||||
dt_s = time.perf_counter() - loop_start
|
||||
precise_sleep(1 / fps - dt_s)
|
||||
busy_wait(1 / fps - dt_s)
|
||||
loop_s = time.perf_counter() - loop_start
|
||||
print(f"Teleop loop time: {loop_s * 1e3:.2f}ms ({1 / loop_s:.0f} Hz)")
|
||||
move_cursor_up(1)
|
||||
print(f"\ntime: {loop_s * 1e3:.2f}ms ({1 / loop_s:.0f} Hz)")
|
||||
|
||||
if duration is not None and time.perf_counter() - start >= duration:
|
||||
return
|
||||
|
||||
@@ -61,6 +61,7 @@ def update_policy(
|
||||
accelerator: Accelerator,
|
||||
lr_scheduler=None,
|
||||
lock=None,
|
||||
rabc_weight_computer=None,
|
||||
) -> tuple[MetricsTracker, dict]:
|
||||
"""
|
||||
Performs a single training step to update the policy's weights.
|
||||
@@ -85,10 +86,22 @@ def update_policy(
|
||||
"""
|
||||
start_time = time.perf_counter()
|
||||
policy.train()
|
||||
|
||||
# Compute RA-BC weights if enabled
|
||||
rabc_weights = None
|
||||
if rabc_weight_computer is not None:
|
||||
rabc_weights = rabc_weight_computer.compute_batch_weights(batch)
|
||||
|
||||
# Let accelerator handle mixed precision
|
||||
with accelerator.autocast():
|
||||
loss, output_dict = policy.forward(batch)
|
||||
|
||||
# Apply RA-BC weights if enabled
|
||||
if rabc_weights is not None:
|
||||
# Weight the loss
|
||||
loss = loss * rabc_weights.mean()
|
||||
output_dict['rabc_mean_weight'] = rabc_weights.mean().item()
|
||||
|
||||
# TODO(rcadene): policy.unnormalize_outputs(out_dict)
|
||||
|
||||
# Use accelerator's backward method
|
||||
@@ -140,8 +153,6 @@ def train(cfg: TrainPipelineConfig, accelerator: Accelerator | None = None):
|
||||
cfg: A `TrainPipelineConfig` object containing all training configurations.
|
||||
accelerator: Optional Accelerator instance. If None, one will be created automatically.
|
||||
"""
|
||||
cfg.validate()
|
||||
|
||||
# Create Accelerator if not provided
|
||||
# It will automatically detect if running in distributed mode or single-process mode
|
||||
# We set step_scheduler_with_optimizer=False to prevent accelerate from adjusting the lr_scheduler steps based on the num_processes
|
||||
@@ -158,6 +169,8 @@ def train(cfg: TrainPipelineConfig, accelerator: Accelerator | None = None):
|
||||
# When using accelerate, only the main process should log to avoid duplicate outputs
|
||||
is_main_process = accelerator.is_main_process
|
||||
|
||||
cfg.validate()
|
||||
|
||||
# Only log on main process
|
||||
if is_main_process:
|
||||
logging.info(pformat(cfg.to_dict()))
|
||||
@@ -215,6 +228,10 @@ def train(cfg: TrainPipelineConfig, accelerator: Accelerator | None = None):
|
||||
if (cfg.policy.pretrained_path and not cfg.resume) or not cfg.policy.pretrained_path:
|
||||
# Only provide dataset_stats when not resuming from saved processor state
|
||||
processor_kwargs["dataset_stats"] = dataset.meta.stats
|
||||
|
||||
# For SARM, always provide dataset_meta for progress normalization
|
||||
if cfg.policy.type == "sarm":
|
||||
processor_kwargs["dataset_meta"] = dataset.meta
|
||||
|
||||
if cfg.policy.pretrained_path is not None:
|
||||
processor_kwargs["preprocessor_overrides"] = {
|
||||
@@ -246,6 +263,28 @@ def train(cfg: TrainPipelineConfig, accelerator: Accelerator | None = None):
|
||||
if is_main_process:
|
||||
logging.info("Creating optimizer and scheduler")
|
||||
optimizer, lr_scheduler = make_optimizer_and_scheduler(cfg, policy)
|
||||
|
||||
# Load reward model for RA-BC if enabled
|
||||
rabc_weight_computer = None
|
||||
if cfg.use_rabc:
|
||||
logging.info(f"Loading reward model for RA-BC from {cfg.reward_model_path}")
|
||||
from lerobot.policies.factory import get_policy_class
|
||||
from lerobot.utils.rabc import RABCWeightComputer
|
||||
|
||||
# Detect reward model type from path
|
||||
# For now, assume SARM if not specified
|
||||
reward_model_class = get_policy_class("sarm")
|
||||
reward_model = reward_model_class.from_pretrained(cfg.reward_model_path)
|
||||
reward_model.to(device)
|
||||
reward_model.eval()
|
||||
|
||||
rabc_weight_computer = RABCWeightComputer(
|
||||
reward_model=reward_model,
|
||||
kappa=cfg.rabc_kappa,
|
||||
epsilon=cfg.rabc_epsilon,
|
||||
device=device,
|
||||
)
|
||||
logging.info("RA-BC weight computer initialized")
|
||||
|
||||
step = 0 # number of policy updates (forward + backward + optim)
|
||||
|
||||
@@ -280,6 +319,18 @@ def train(cfg: TrainPipelineConfig, accelerator: Accelerator | None = None):
|
||||
drop_n_last_frames=cfg.policy.drop_n_last_frames,
|
||||
shuffle=True,
|
||||
)
|
||||
elif cfg.policy.type == "sarm" and getattr(cfg.policy, "use_temporal_sampler", False):
|
||||
# Use SARM temporal sampler for reward model training
|
||||
from lerobot.datasets.temporal_sampler import SARMTemporalSampler
|
||||
|
||||
shuffle = False
|
||||
sampler = SARMTemporalSampler(
|
||||
dataset_from_index=dataset.meta.episodes["dataset_from_index"],
|
||||
dataset_to_index=dataset.meta.episodes["dataset_to_index"],
|
||||
frame_gap=getattr(cfg.policy, "frame_gap", 30),
|
||||
shuffle=True,
|
||||
seed=cfg.seed,
|
||||
)
|
||||
else:
|
||||
shuffle = True
|
||||
sampler = None
|
||||
@@ -324,7 +375,7 @@ def train(cfg: TrainPipelineConfig, accelerator: Accelerator | None = None):
|
||||
)
|
||||
|
||||
if is_main_process:
|
||||
logging.info("Start offline training on a fixed dataset")
|
||||
logging.info(f"Start offline training on a fixed dataset, with effective batch size: {effective_batch_size}")
|
||||
|
||||
for _ in range(step, cfg.steps):
|
||||
start_time = time.perf_counter()
|
||||
@@ -340,6 +391,7 @@ def train(cfg: TrainPipelineConfig, accelerator: Accelerator | None = None):
|
||||
cfg.optimizer.grad_clip_norm,
|
||||
accelerator=accelerator,
|
||||
lr_scheduler=lr_scheduler,
|
||||
rabc_weight_computer=rabc_weight_computer,
|
||||
)
|
||||
|
||||
# Note: eval and checkpoint happens *after* the `step`th training update has completed, so we
|
||||
@@ -356,6 +408,14 @@ def train(cfg: TrainPipelineConfig, accelerator: Accelerator | None = None):
|
||||
wandb_log_dict = train_tracker.to_dict()
|
||||
if output_dict:
|
||||
wandb_log_dict.update(output_dict)
|
||||
# Log RA-BC statistics if enabled
|
||||
if rabc_weight_computer is not None:
|
||||
rabc_stats = rabc_weight_computer.get_stats()
|
||||
wandb_log_dict.update({
|
||||
'rabc_progress_mean': rabc_stats['mean'],
|
||||
'rabc_progress_std': rabc_stats['std'],
|
||||
'rabc_samples_seen': rabc_stats['count'],
|
||||
})
|
||||
wandb_logger.log_dict(wandb_log_dict, step)
|
||||
train_tracker.reset_averages()
|
||||
|
||||
|
||||
@@ -0,0 +1,183 @@
|
||||
#!/usr/bin/env python
|
||||
|
||||
# Copyright 2025 The HuggingFace Inc. team. All rights reserved.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
"""
|
||||
Reward-Aligned Behavior Cloning (RA-BC) utilities.
|
||||
|
||||
RA-BC uses a pre-trained reward model (e.g., SARM) to compute progress-based weights
|
||||
for training samples, emphasizing high-quality demonstrations and down-weighting
|
||||
suboptimal ones.
|
||||
"""
|
||||
|
||||
import logging
|
||||
import torch
|
||||
import torch.nn as nn
|
||||
|
||||
|
||||
class RABCWeightComputer:
|
||||
"""
|
||||
Computes RA-BC weights for training batches using a pre-trained reward model.
|
||||
|
||||
Uses Welford's online algorithm for numerically stable running statistics
|
||||
and applies soft weighting based on progress deltas.
|
||||
|
||||
Args:
|
||||
reward_model: Pre-trained reward model (e.g., SARM)
|
||||
kappa: Hard threshold for high-quality samples (default: 0.01)
|
||||
epsilon: Small constant for numerical stability (default: 1e-6)
|
||||
device: Device to run reward model on
|
||||
"""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
reward_model: nn.Module,
|
||||
kappa: float = 0.01,
|
||||
epsilon: float = 1e-6,
|
||||
device: torch.device = None,
|
||||
):
|
||||
self.reward_model = reward_model
|
||||
self.reward_model.eval() # Always in eval mode
|
||||
self.kappa = kappa
|
||||
self.epsilon = epsilon
|
||||
self.device = device or torch.device("cuda" if torch.cuda.is_available() else "cpu")
|
||||
|
||||
# Running statistics (Welford's algorithm)
|
||||
self.mean = 0.0
|
||||
self.m2 = 0.0
|
||||
self.count = 0
|
||||
|
||||
logging.info(f"RA-BC WeightComputer initialized with kappa={kappa}, epsilon={epsilon}")
|
||||
|
||||
def _update_stats(self, deltas: torch.Tensor):
|
||||
"""Update running statistics using Welford's online algorithm."""
|
||||
for delta in deltas:
|
||||
self.count += 1
|
||||
delta_val = delta.item()
|
||||
delta_mean = delta_val - self.mean
|
||||
self.mean += delta_mean / self.count
|
||||
delta_m2 = delta_val - self.mean
|
||||
self.m2 += delta_mean * delta_m2
|
||||
|
||||
def _compute_weights(self, deltas: torch.Tensor) -> torch.Tensor:
|
||||
"""Compute RA-BC weights from progress deltas."""
|
||||
if self.count < 2:
|
||||
# Not enough data, use uniform weights
|
||||
return torch.ones_like(deltas)
|
||||
|
||||
# Get running statistics
|
||||
mean = max(self.mean, 0.0) # Clamp mean to non-negative
|
||||
variance = self.m2 / (self.count - 1)
|
||||
std = torch.tensor(variance).sqrt().item()
|
||||
|
||||
# Compute soft weights
|
||||
lower_bound = mean - 2 * std
|
||||
upper_bound = mean + 2 * std
|
||||
weights = (deltas - lower_bound) / (4 * std + self.epsilon)
|
||||
weights = torch.clamp(weights, 0.0, 1.0)
|
||||
|
||||
# Apply hard threshold
|
||||
high_quality_mask = deltas > self.kappa
|
||||
weights = torch.where(high_quality_mask, torch.ones_like(weights), weights)
|
||||
|
||||
return weights
|
||||
|
||||
@torch.no_grad()
|
||||
def compute_batch_weights(self, batch: dict, chunk_size: int = 1) -> torch.Tensor:
|
||||
"""
|
||||
Compute RA-BC weights for a training batch.
|
||||
|
||||
This function:
|
||||
1. Extracts current and next observations from the batch
|
||||
2. Computes rewards using the reward model
|
||||
3. Calculates progress deltas
|
||||
4. Updates running statistics
|
||||
5. Returns normalized weights
|
||||
|
||||
Args:
|
||||
batch: Training batch containing observations
|
||||
chunk_size: Size of action chunks for computing deltas (default: 1)
|
||||
|
||||
Returns:
|
||||
Weights tensor (batch_size,) normalized to sum to batch_size
|
||||
"""
|
||||
observation = batch.get('observation', batch)
|
||||
batch_size = next(iter(observation.values())).shape[0]
|
||||
|
||||
# Extract features needed for reward computation
|
||||
# These should already be encoded by the preprocessor
|
||||
if 'video_features' not in observation or 'text_features' not in observation:
|
||||
logging.warning("RA-BC: Missing video/text features, using uniform weights")
|
||||
return torch.ones(batch_size, device=self.device)
|
||||
|
||||
video_features = observation['video_features'].to(self.device)
|
||||
text_features = observation['text_features'].to(self.device)
|
||||
state_features = observation.get('state_features', None)
|
||||
if state_features is not None:
|
||||
state_features = state_features.to(self.device)
|
||||
|
||||
# Compute rewards for current observations
|
||||
# Handle both single-frame and multi-frame features
|
||||
if video_features.dim() == 3: # (B, T, D)
|
||||
# Multi-frame: use last frame reward
|
||||
if hasattr(self.reward_model, 'calculate_rewards'):
|
||||
current_rewards = self.reward_model.calculate_rewards(
|
||||
text_features, video_features, state_features,
|
||||
return_all_frames=False
|
||||
)
|
||||
else:
|
||||
# Fallback for models without calculate_rewards
|
||||
current_rewards = torch.zeros(batch_size, device=self.device)
|
||||
else: # (B, D)
|
||||
# Single frame
|
||||
if hasattr(self.reward_model, 'calculate_rewards'):
|
||||
current_rewards = self.reward_model.calculate_rewards(
|
||||
text_features, video_features.unsqueeze(1), state_features,
|
||||
return_all_frames=False
|
||||
)
|
||||
else:
|
||||
current_rewards = torch.zeros(batch_size, device=self.device)
|
||||
|
||||
if isinstance(current_rewards, tuple):
|
||||
current_rewards = current_rewards[0]
|
||||
|
||||
current_rewards = torch.tensor(current_rewards, device=self.device) if isinstance(current_rewards, (list, tuple)) else current_rewards
|
||||
|
||||
# For simplicity, assume progress delta is proportional to reward
|
||||
# In practice, you'd want to compute next_frame rewards and take differences
|
||||
# For now, use current reward as a proxy for progress delta
|
||||
progress_deltas = current_rewards
|
||||
|
||||
# Update running statistics
|
||||
self._update_stats(progress_deltas)
|
||||
|
||||
# Compute weights
|
||||
weights = self._compute_weights(progress_deltas)
|
||||
|
||||
# Normalize weights to sum to batch_size (maintains effective batch size)
|
||||
weight_sum = weights.sum() + self.epsilon
|
||||
weights = weights * batch_size / weight_sum
|
||||
|
||||
return weights
|
||||
|
||||
def get_stats(self) -> dict:
|
||||
"""Get current running statistics."""
|
||||
std = torch.tensor(self.m2 / (self.count - 1)).sqrt().item() if self.count > 1 else 0.0
|
||||
return {
|
||||
'mean': self.mean,
|
||||
'std': std,
|
||||
'count': self.count,
|
||||
}
|
||||
|
||||
@@ -16,40 +16,14 @@ import platform
|
||||
import time
|
||||
|
||||
|
||||
def precise_sleep(seconds: float, spin_threshold: float = 0.010, sleep_margin: float = 0.003):
|
||||
"""
|
||||
Wait for `seconds` with better precision than time.sleep alone at the expense of more CPU usage.
|
||||
|
||||
Parameters:
|
||||
- seconds: duration to wait
|
||||
- spin_threshold: if remaining <= spin_threshold -> spin; otherwise sleep (seconds). Default 10ms
|
||||
- sleep_margin: when sleeping leave this much time before deadline to avoid oversleep. Default 3ms
|
||||
|
||||
Note:
|
||||
The default parameters are chosen to prioritize timing accuracy over CPU usage for the common 30 FPS use case.
|
||||
"""
|
||||
if seconds <= 0:
|
||||
return
|
||||
|
||||
system = platform.system()
|
||||
# On macOS and Windows the scheduler / sleep granularity can make
|
||||
# short sleeps inaccurate. Instead of burning CPU for the whole
|
||||
# duration, sleep for most of the time and spin for the final few
|
||||
# milliseconds to achieve good accuracy with much lower CPU usage.
|
||||
if system in ("Darwin", "Windows"):
|
||||
def busy_wait(seconds):
|
||||
if platform.system() == "Darwin" or platform.system() == "Windows":
|
||||
# On Mac and Windows, `time.sleep` is not accurate and we need to use this while loop trick,
|
||||
# but it consumes CPU cycles.
|
||||
end_time = time.perf_counter() + seconds
|
||||
while True:
|
||||
remaining = end_time - time.perf_counter()
|
||||
if remaining <= 0:
|
||||
break
|
||||
# If there's more than a couple milliseconds left, sleep most
|
||||
# of the remaining time and leave a small margin for the final spin.
|
||||
if remaining > spin_threshold:
|
||||
# Sleep but avoid sleeping past the end by leaving a small margin.
|
||||
time.sleep(max(remaining - sleep_margin, 0))
|
||||
else:
|
||||
# Final short spin to hit precise timing without long sleeps.
|
||||
pass
|
||||
while time.perf_counter() < end_time:
|
||||
pass
|
||||
else:
|
||||
# On Linux time.sleep is accurate enough for most uses
|
||||
time.sleep(seconds)
|
||||
# On Linux time.sleep is accurate
|
||||
if seconds > 0:
|
||||
time.sleep(seconds)
|
||||
|
||||
@@ -0,0 +1,378 @@
|
||||
#!/usr/bin/env python
|
||||
|
||||
# Copyright 2025 The HuggingFace Inc. team. All rights reserved.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
"""
|
||||
Tests for SARM utility functions.
|
||||
|
||||
Tests the implementation of SARM paper formulas:
|
||||
- Formula (1): compute_temporal_proportions - dataset-level temporal proportions
|
||||
- Formula (2): compute_tau, compute_cumulative_progress - progress labels
|
||||
"""
|
||||
|
||||
import pytest
|
||||
import numpy as np
|
||||
import torch
|
||||
|
||||
from lerobot.policies.sarm.sarm_utils import SubtaskAnnotation, Subtask, Timestamp
|
||||
from lerobot.policies.sarm.sarm_utils import (
|
||||
compute_temporal_proportions,
|
||||
compute_tau,
|
||||
compute_cumulative_progress_batch,
|
||||
)
|
||||
|
||||
def make_annotation(subtasks: list[tuple[str, int, int]]) -> SubtaskAnnotation:
|
||||
"""Helper to create SubtaskAnnotation from list of (name, start_sec, end_sec)."""
|
||||
return SubtaskAnnotation(
|
||||
subtasks=[
|
||||
Subtask(
|
||||
name=name,
|
||||
timestamps=Timestamp(
|
||||
start=f"{start // 60:02d}:{start % 60:02d}",
|
||||
end=f"{end // 60:02d}:{end % 60:02d}"
|
||||
)
|
||||
)
|
||||
for name, start, end in subtasks
|
||||
]
|
||||
)
|
||||
|
||||
|
||||
class TestComputeTemporalProportions:
|
||||
"""Tests for compute_temporal_proportions (SARM Paper Formula 1).
|
||||
|
||||
Formula: ᾱ_k = (1/M) × Σ_i (L_{i,k} / T_i)
|
||||
|
||||
Key insight: This averages the PROPORTION of each subtask within each trajectory,
|
||||
giving equal weight to all trajectories regardless of absolute length.
|
||||
"""
|
||||
|
||||
def test_basic_two_trajectories_equal_proportions(self):
|
||||
"""Test with two trajectories that have equal proportions."""
|
||||
# Both trajectories: subtask1 = 50%, subtask2 = 50%
|
||||
# Traj 1: T=100s, subtask1=50s, subtask2=50s
|
||||
# Traj 2: T=200s, subtask1=100s, subtask2=100s
|
||||
annotations = {
|
||||
0: make_annotation([('subtask1', 0, 50), ('subtask2', 50, 100)]),
|
||||
1: make_annotation([('subtask1', 0, 100), ('subtask2', 100, 200)]),
|
||||
}
|
||||
|
||||
result = compute_temporal_proportions(annotations)
|
||||
|
||||
# Both should be 0.5
|
||||
assert abs(result['subtask1'] - 0.5) < 1e-6
|
||||
assert abs(result['subtask2'] - 0.5) < 1e-6
|
||||
|
||||
def test_paper_example_different_from_avg_durations(self):
|
||||
"""Test that compute_temporal_proportions differs from naive average duration approach.
|
||||
|
||||
This is the key test showing the difference between:
|
||||
- Paper formula: average of (L_i,k / T_i)
|
||||
- Naive approach: mean(L_i,k) / sum(mean(L_i,j))
|
||||
"""
|
||||
# Episode 1: T=100s, subtask1=80s, subtask2=20s (proportions: 0.8, 0.2)
|
||||
# Episode 2: T=200s, subtask1=40s, subtask2=160s (proportions: 0.2, 0.8)
|
||||
annotations = {
|
||||
0: make_annotation([('subtask1', 0, 80), ('subtask2', 80, 100)]),
|
||||
1: make_annotation([('subtask1', 0, 40), ('subtask2', 40, 200)]),
|
||||
}
|
||||
|
||||
result = compute_temporal_proportions(annotations)
|
||||
|
||||
# Paper formula:
|
||||
# ᾱ_1 = (1/2) × (80/100 + 40/200) = (1/2) × (0.8 + 0.2) = 0.5
|
||||
# ᾱ_2 = (1/2) × (20/100 + 160/200) = (1/2) × (0.2 + 0.8) = 0.5
|
||||
assert abs(result['subtask1'] - 0.5) < 1e-6
|
||||
assert abs(result['subtask2'] - 0.5) < 1e-6
|
||||
|
||||
def test_single_trajectory(self):
|
||||
"""Test with a single trajectory."""
|
||||
# T=100s, reach=30s, grasp=20s, lift=50s
|
||||
annotations = {
|
||||
0: make_annotation([('reach', 0, 30), ('grasp', 30, 50), ('lift', 50, 100)]),
|
||||
}
|
||||
|
||||
result = compute_temporal_proportions(annotations)
|
||||
|
||||
assert abs(result['reach'] - 0.3) < 1e-6
|
||||
assert abs(result['grasp'] - 0.2) < 1e-6
|
||||
assert abs(result['lift'] - 0.5) < 1e-6
|
||||
|
||||
def test_sum_to_one(self):
|
||||
"""Test that proportions always sum to 1."""
|
||||
# Three episodes with varying proportions
|
||||
annotations = {
|
||||
0: make_annotation([('a', 0, 10), ('b', 10, 50), ('c', 50, 100)]), # 0.1, 0.4, 0.5
|
||||
1: make_annotation([('a', 0, 20), ('b', 20, 70), ('c', 70, 100)]), # 0.2, 0.5, 0.3
|
||||
2: make_annotation([('a', 0, 30), ('b', 30, 90), ('c', 90, 100)]), # 0.3, 0.6, 0.1
|
||||
}
|
||||
|
||||
result = compute_temporal_proportions(annotations)
|
||||
|
||||
total = sum(result.values())
|
||||
assert abs(total - 1.0) < 1e-6
|
||||
|
||||
def test_empty_annotations_returns_empty(self):
|
||||
"""Test that empty annotations returns empty dict."""
|
||||
result = compute_temporal_proportions({})
|
||||
assert result == {}
|
||||
|
||||
def test_uniform_proportions(self):
|
||||
"""Test with uniform proportions across subtasks."""
|
||||
# Each subtask takes 25% of each episode
|
||||
annotations = {
|
||||
0: make_annotation([('a', 0, 25), ('b', 25, 50), ('c', 50, 75), ('d', 75, 100)]),
|
||||
1: make_annotation([('a', 0, 50), ('b', 50, 100), ('c', 100, 150), ('d', 150, 200)]),
|
||||
}
|
||||
|
||||
result = compute_temporal_proportions(annotations)
|
||||
|
||||
for name in ['a', 'b', 'c', 'd']:
|
||||
assert abs(result[name] - 0.25) < 1e-6
|
||||
|
||||
|
||||
class TestComputeTau:
|
||||
"""Tests for compute_tau (within-subtask progress).
|
||||
|
||||
Formula: τ_t = (t - s_k) / (e_k - s_k) ∈ [0, 1]
|
||||
"""
|
||||
|
||||
def test_at_start(self):
|
||||
"""τ should be 0 at subtask start."""
|
||||
tau = compute_tau(current_frame=10, subtask_start=10, subtask_end=50)
|
||||
assert tau == 0.0
|
||||
|
||||
def test_at_end(self):
|
||||
"""τ should be 1 at subtask end."""
|
||||
tau = compute_tau(current_frame=50, subtask_start=10, subtask_end=50)
|
||||
assert tau == 1.0
|
||||
|
||||
def test_at_middle(self):
|
||||
"""τ should be 0.5 at subtask midpoint."""
|
||||
tau = compute_tau(current_frame=30, subtask_start=10, subtask_end=50)
|
||||
assert abs(tau - 0.5) < 1e-6
|
||||
|
||||
def test_quarter_progress(self):
|
||||
"""Test τ at 25% through subtask."""
|
||||
tau = compute_tau(current_frame=20, subtask_start=0, subtask_end=80)
|
||||
assert abs(tau - 0.25) < 1e-6
|
||||
|
||||
def test_zero_duration_subtask(self):
|
||||
"""τ should be 1.0 for zero-duration subtask."""
|
||||
tau = compute_tau(current_frame=10, subtask_start=10, subtask_end=10)
|
||||
assert tau == 1.0
|
||||
|
||||
def test_clamps_below_zero(self):
|
||||
"""τ should be clamped to 0 if frame is before subtask."""
|
||||
tau = compute_tau(current_frame=5, subtask_start=10, subtask_end=50)
|
||||
assert tau == 0.0
|
||||
|
||||
def test_clamps_above_one(self):
|
||||
"""τ should be clamped to 1 if frame is after subtask."""
|
||||
tau = compute_tau(current_frame=60, subtask_start=10, subtask_end=50)
|
||||
assert tau == 1.0
|
||||
|
||||
def test_float_inputs(self):
|
||||
"""Test with float frame indices (from interpolation)."""
|
||||
tau = compute_tau(current_frame=25.5, subtask_start=10.0, subtask_end=50.0)
|
||||
expected = (25.5 - 10.0) / (50.0 - 10.0)
|
||||
assert abs(tau - expected) < 1e-6
|
||||
|
||||
|
||||
class TestComputeCumulativeProgressBatchScalar:
|
||||
"""Tests for compute_cumulative_progress_batch with scalar inputs (normalized progress y_t).
|
||||
|
||||
Formula: y_t = P_{k-1} + ᾱ_k × τ_t ∈ [0, 1]
|
||||
"""
|
||||
|
||||
def test_first_subtask_start(self):
|
||||
"""y should be 0 at start of first subtask."""
|
||||
proportions = [0.3, 0.5, 0.2]
|
||||
y = compute_cumulative_progress_batch(tau=0.0, stage_indices=0, alpha=proportions)
|
||||
assert y == 0.0
|
||||
|
||||
def test_first_subtask_end(self):
|
||||
"""y should equal ᾱ_1 at end of first subtask."""
|
||||
proportions = [0.3, 0.5, 0.2]
|
||||
y = compute_cumulative_progress_batch(tau=1.0, stage_indices=0, alpha=proportions)
|
||||
assert abs(y - 0.3) < 1e-6
|
||||
|
||||
def test_second_subtask_start(self):
|
||||
"""y should equal P_1 at start of second subtask."""
|
||||
proportions = [0.3, 0.5, 0.2]
|
||||
y = compute_cumulative_progress_batch(tau=0.0, stage_indices=1, alpha=proportions)
|
||||
assert abs(y - 0.3) < 1e-6
|
||||
|
||||
def test_second_subtask_end(self):
|
||||
"""y should equal P_2 at end of second subtask."""
|
||||
proportions = [0.3, 0.5, 0.2]
|
||||
y = compute_cumulative_progress_batch(tau=1.0, stage_indices=1, alpha=proportions)
|
||||
assert abs(y - 0.8) < 1e-6 # 0.3 + 0.5
|
||||
|
||||
def test_third_subtask_end(self):
|
||||
"""y should be 1.0 at end of last subtask."""
|
||||
proportions = [0.3, 0.5, 0.2]
|
||||
y = compute_cumulative_progress_batch(tau=1.0, stage_indices=2, alpha=proportions)
|
||||
assert abs(y - 1.0) < 1e-6
|
||||
|
||||
def test_midpoint_of_subtask(self):
|
||||
"""Test progress at midpoint of a subtask."""
|
||||
proportions = [0.4, 0.6]
|
||||
# At τ=0.5 in subtask 1: y = P_0 + ᾱ_1 × 0.5 = 0 + 0.4 × 0.5 = 0.2
|
||||
y = compute_cumulative_progress_batch(tau=0.5, stage_indices=0, alpha=proportions)
|
||||
assert abs(y - 0.2) < 1e-6
|
||||
|
||||
# At τ=0.5 in subtask 2: y = P_1 + ᾱ_2 × 0.5 = 0.4 + 0.6 × 0.5 = 0.7
|
||||
y = compute_cumulative_progress_batch(tau=0.5, stage_indices=1, alpha=proportions)
|
||||
assert abs(y - 0.7) < 1e-6
|
||||
|
||||
def test_uniform_proportions(self):
|
||||
"""Test with uniform proportions."""
|
||||
proportions = [0.25, 0.25, 0.25, 0.25]
|
||||
|
||||
# At end of each subtask, progress should be 0.25, 0.5, 0.75, 1.0
|
||||
for i in range(4):
|
||||
y = compute_cumulative_progress_batch(tau=1.0, stage_indices=i, alpha=proportions)
|
||||
expected = (i + 1) * 0.25
|
||||
assert abs(y - expected) < 1e-6
|
||||
|
||||
|
||||
class TestComputeCumulativeProgressBatchTensor:
|
||||
"""Tests for compute_cumulative_progress_batch with tensor inputs (GPU batch version)."""
|
||||
|
||||
def test_tensor_matches_scalar_version(self):
|
||||
"""Test that tensor version matches scalar version."""
|
||||
proportions = [0.3, 0.5, 0.2]
|
||||
alpha = torch.tensor(proportions, dtype=torch.float32)
|
||||
cumulative = torch.zeros(len(proportions) + 1, dtype=torch.float32)
|
||||
cumulative[1:] = torch.cumsum(alpha, dim=0)
|
||||
|
||||
test_cases = [
|
||||
(0.0, 0), # start of subtask 0
|
||||
(1.0, 0), # end of subtask 0
|
||||
(0.0, 1), # start of subtask 1
|
||||
(0.5, 1), # middle of subtask 1
|
||||
(1.0, 2), # end of subtask 2
|
||||
]
|
||||
|
||||
for tau_val, stage_idx in test_cases:
|
||||
# Scalar version
|
||||
expected = compute_cumulative_progress_batch(tau_val, stage_idx, proportions)
|
||||
|
||||
# Tensor version (single element)
|
||||
tau = torch.tensor([[[tau_val]]]) # (1, 1, 1)
|
||||
stages = torch.tensor([[stage_idx]]) # (1, 1)
|
||||
result = compute_cumulative_progress_batch(tau, stages, alpha, cumulative)
|
||||
|
||||
assert abs(result[0, 0, 0].item() - expected) < 1e-6
|
||||
|
||||
def test_batch_processing(self):
|
||||
"""Test batch processing with multiple samples."""
|
||||
proportions = [0.4, 0.6]
|
||||
alpha = torch.tensor(proportions, dtype=torch.float32)
|
||||
cumulative = torch.zeros(3, dtype=torch.float32)
|
||||
cumulative[1:] = torch.cumsum(alpha, dim=0)
|
||||
|
||||
# Batch of 2 samples, sequence length 3
|
||||
tau = torch.tensor([
|
||||
[[0.0], [0.5], [1.0]], # sample 1
|
||||
[[0.0], [0.5], [1.0]], # sample 2
|
||||
])
|
||||
stages = torch.tensor([
|
||||
[0, 0, 0], # sample 1: all in subtask 0
|
||||
[1, 1, 1], # sample 2: all in subtask 1
|
||||
])
|
||||
|
||||
result = compute_cumulative_progress_batch(tau, stages, alpha, cumulative)
|
||||
|
||||
# Sample 1: subtask 0 with tau 0, 0.5, 1.0 -> y = 0, 0.2, 0.4
|
||||
assert abs(result[0, 0, 0].item() - 0.0) < 1e-6
|
||||
assert abs(result[0, 1, 0].item() - 0.2) < 1e-6
|
||||
assert abs(result[0, 2, 0].item() - 0.4) < 1e-6
|
||||
|
||||
# Sample 2: subtask 1 with tau 0, 0.5, 1.0 -> y = 0.4, 0.7, 1.0
|
||||
assert abs(result[1, 0, 0].item() - 0.4) < 1e-6
|
||||
assert abs(result[1, 1, 0].item() - 0.7) < 1e-6
|
||||
assert abs(result[1, 2, 0].item() - 1.0) < 1e-6
|
||||
|
||||
def test_auto_compute_cumulative_prior(self):
|
||||
"""Test that cumulative_prior is auto-computed when not provided."""
|
||||
proportions = [0.3, 0.5, 0.2]
|
||||
alpha = torch.tensor(proportions, dtype=torch.float32)
|
||||
|
||||
tau = torch.tensor([[[0.5]]])
|
||||
stages = torch.tensor([[1]])
|
||||
|
||||
# Without cumulative_prior (should auto-compute)
|
||||
result = compute_cumulative_progress_batch(tau, stages, alpha)
|
||||
|
||||
# Expected: P_0 + alpha_1 * 0.5 = 0.3 + 0.5 * 0.5 = 0.55
|
||||
assert abs(result[0, 0, 0].item() - 0.55) < 1e-6
|
||||
|
||||
|
||||
class TestEndToEndProgressLabeling:
|
||||
"""End-to-end tests for progress label computation."""
|
||||
|
||||
def test_consistent_semantic_meaning(self):
|
||||
"""Test that same subtask completion maps to same progress across trajectories.
|
||||
|
||||
This is the key semantic property: "end of subtask 1" should always
|
||||
mean the same progress value regardless of trajectory speed.
|
||||
"""
|
||||
proportions = [0.3, 0.5, 0.2]
|
||||
|
||||
# Fast trajectory: subtask 1 ends at frame 30 (of 100)
|
||||
tau_fast = compute_tau(30, 0, 30) # = 1.0
|
||||
y_fast = compute_cumulative_progress_batch(tau_fast, 0, proportions)
|
||||
|
||||
# Slow trajectory: subtask 1 ends at frame 90 (of 300)
|
||||
tau_slow = compute_tau(90, 0, 90) # = 1.0
|
||||
y_slow = compute_cumulative_progress_batch(tau_slow, 0, proportions)
|
||||
|
||||
# Both should map to same progress (0.3 = end of subtask 1)
|
||||
assert abs(y_fast - y_slow) < 1e-6
|
||||
assert abs(y_fast - 0.3) < 1e-6
|
||||
|
||||
def test_monotonic_within_subtask(self):
|
||||
"""Test that progress is monotonically increasing within a subtask."""
|
||||
proportions = [0.4, 0.6]
|
||||
|
||||
prev_y = -1
|
||||
for tau in np.linspace(0, 1, 11):
|
||||
y = compute_cumulative_progress_batch(tau, 0, proportions)
|
||||
assert y > prev_y or (tau == 0 and y == 0)
|
||||
prev_y = y
|
||||
|
||||
def test_continuous_across_subtasks(self):
|
||||
"""Test that progress is continuous at subtask boundaries."""
|
||||
proportions = [0.3, 0.5, 0.2]
|
||||
|
||||
# End of subtask 0 (tau=1.0)
|
||||
y_end_0 = compute_cumulative_progress_batch(1.0, 0, proportions)
|
||||
|
||||
# Start of subtask 1 (tau=0.0)
|
||||
y_start_1 = compute_cumulative_progress_batch(0.0, 1, proportions)
|
||||
|
||||
# Should be equal (P_1 = 0.3)
|
||||
assert abs(y_end_0 - y_start_1) < 1e-6
|
||||
|
||||
# End of subtask 1
|
||||
y_end_1 = compute_cumulative_progress_batch(1.0, 1, proportions)
|
||||
|
||||
# Start of subtask 2
|
||||
y_start_2 = compute_cumulative_progress_batch(0.0, 2, proportions)
|
||||
|
||||
# Should be equal (P_2 = 0.8)
|
||||
assert abs(y_end_1 - y_start_2) < 1e-6
|
||||
|
||||
Reference in New Issue
Block a user