mirror of
https://github.com/huggingface/lerobot.git
synced 2026-05-11 14:49:43 +00:00
Compare commits
16 Commits
| Author | SHA1 | Date | |
|---|---|---|---|
| 58f70b6bd3 | |||
| b07160eb1b | |||
| 648ea8f485 | |||
| 581dd45eae | |||
| 17581a9449 | |||
| 87bee86640 | |||
| 18b32dced9 | |||
| 36e8feefe3 | |||
| 0f551df8f4 | |||
| 6e86a69dcd | |||
| 8a915c6b6f | |||
| b464d9f8bc | |||
| 784cdae55a | |||
| d9e74a9d37 | |||
| a5b29d4301 | |||
| a4aa316470 |
@@ -83,11 +83,11 @@ jobs:
|
||||
fi
|
||||
|
||||
- name: Remove Tags with Git dependencies
|
||||
# TODO(Steven): Temporary patch to remove libero and pi from PyPi 0.4.0 release due to its reliance on git dependencies.
|
||||
# TODO(Steven): Temporary patch to remove pi from PyPi 0.4.0 release due to its reliance on git dependencies.
|
||||
run: |
|
||||
echo "::info:: Checking for Git dependencies to remove from pyproject.toml..."
|
||||
grep -E '@ git\+https|lerobot\[pi\]|lerobot\[libero\]' pyproject.toml | sed 's/^/::warning:: Removing line: /' || true
|
||||
sed -E -i '/@ git\+https|lerobot\[pi\]|lerobot\[libero\]/d' pyproject.toml
|
||||
grep -E '@ git\+https|lerobot\[pi\]' pyproject.toml | sed 's/^/::warning:: Removing line: /' || true
|
||||
sed -E -i '/@ git\+https|lerobot\[pi\]/d' pyproject.toml
|
||||
echo "::info:: Git dependencies removed. Proceeding with build."
|
||||
|
||||
- name: Install build dependencies
|
||||
|
||||
@@ -70,7 +70,7 @@ jobs:
|
||||
echo "Dependencies unbound:" && cat pyproject.toml
|
||||
|
||||
- name: Install lerobot with all extras
|
||||
run: uv sync --all-extras
|
||||
run: uv sync --all-extras --no-extra groot # TODO(Steven): Make flash-attn optional
|
||||
|
||||
- name: Run pytest (all extras)
|
||||
run: uv run pytest tests -vv
|
||||
|
||||
@@ -186,7 +186,7 @@ For a full list of optional dependencies, see:
|
||||
https://pypi.org/project/lerobot/
|
||||
|
||||
> [!NOTE]
|
||||
> For lerobot 0.4.0, if you want to install libero or pi tags, you will have to do: `pip install "lerobot[pi,libero]@git+https://github.com/huggingface/lerobot.git"`.
|
||||
> For lerobot 0.4.0, if you want to install pi tags, you will have to do: `pip install "lerobot[pi]@git+https://github.com/huggingface/lerobot.git"`.
|
||||
>
|
||||
> This will be solved in the next patch release
|
||||
|
||||
|
||||
@@ -1,94 +0,0 @@
|
||||
#!/usr/bin/env python
|
||||
|
||||
# Copyright 2024 The HuggingFace Inc. team. All rights reserved.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
import threading
|
||||
import time
|
||||
from contextlib import ContextDecorator
|
||||
|
||||
|
||||
class TimeBenchmark(ContextDecorator):
|
||||
"""
|
||||
Measures execution time using a context manager or decorator.
|
||||
|
||||
This class supports both context manager and decorator usage, and is thread-safe for multithreaded
|
||||
environments.
|
||||
|
||||
Args:
|
||||
print: If True, prints the elapsed time upon exiting the context or completing the function. Defaults
|
||||
to False.
|
||||
|
||||
Examples:
|
||||
|
||||
Using as a context manager:
|
||||
|
||||
>>> benchmark = TimeBenchmark()
|
||||
>>> with benchmark:
|
||||
... time.sleep(1)
|
||||
>>> print(f"Block took {benchmark.result:.4f} seconds")
|
||||
Block took approximately 1.0000 seconds
|
||||
|
||||
Using with multithreading:
|
||||
|
||||
```python
|
||||
import threading
|
||||
|
||||
benchmark = TimeBenchmark()
|
||||
|
||||
|
||||
def context_manager_example():
|
||||
with benchmark:
|
||||
time.sleep(0.01)
|
||||
print(f"Block took {benchmark.result_ms:.2f} milliseconds")
|
||||
|
||||
|
||||
threads = []
|
||||
for _ in range(3):
|
||||
t1 = threading.Thread(target=context_manager_example)
|
||||
threads.append(t1)
|
||||
|
||||
for t in threads:
|
||||
t.start()
|
||||
|
||||
for t in threads:
|
||||
t.join()
|
||||
```
|
||||
Expected output:
|
||||
Block took approximately 10.00 milliseconds
|
||||
Block took approximately 10.00 milliseconds
|
||||
Block took approximately 10.00 milliseconds
|
||||
"""
|
||||
|
||||
def __init__(self, print=False):
|
||||
self.local = threading.local()
|
||||
self.print_time = print
|
||||
|
||||
def __enter__(self):
|
||||
self.local.start_time = time.perf_counter()
|
||||
return self
|
||||
|
||||
def __exit__(self, *exc):
|
||||
self.local.end_time = time.perf_counter()
|
||||
self.local.elapsed_time = self.local.end_time - self.local.start_time
|
||||
if self.print_time:
|
||||
print(f"Elapsed time: {self.local.elapsed_time:.4f} seconds")
|
||||
return False
|
||||
|
||||
@property
|
||||
def result(self):
|
||||
return getattr(self.local, "elapsed_time", None)
|
||||
|
||||
@property
|
||||
def result_ms(self):
|
||||
return self.result * 1e3
|
||||
@@ -1,102 +0,0 @@
|
||||
#!/usr/bin/env python
|
||||
|
||||
# Copyright 2024 The HuggingFace Inc. team. All rights reserved.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
"""Capture video feed from a camera as raw images."""
|
||||
|
||||
import argparse
|
||||
import datetime as dt
|
||||
import os
|
||||
import time
|
||||
from pathlib import Path
|
||||
|
||||
import cv2
|
||||
import rerun as rr
|
||||
|
||||
# see https://rerun.io/docs/howto/visualization/limit-ram
|
||||
RERUN_MEMORY_LIMIT = os.getenv("LEROBOT_RERUN_MEMORY_LIMIT", "5%")
|
||||
|
||||
|
||||
def display_and_save_video_stream(output_dir: Path, fps: int, width: int, height: int, duration: int):
|
||||
rr.init("lerobot_capture_camera_feed")
|
||||
rr.spawn(memory_limit=RERUN_MEMORY_LIMIT)
|
||||
|
||||
now = dt.datetime.now()
|
||||
capture_dir = output_dir / f"{now:%Y-%m-%d}" / f"{now:%H-%M-%S}"
|
||||
if not capture_dir.exists():
|
||||
capture_dir.mkdir(parents=True, exist_ok=True)
|
||||
|
||||
# Opens the default webcam
|
||||
cap = cv2.VideoCapture(0)
|
||||
if not cap.isOpened():
|
||||
print("Error: Could not open video stream.")
|
||||
return
|
||||
|
||||
cap.set(cv2.CAP_PROP_FPS, fps)
|
||||
cap.set(cv2.CAP_PROP_FRAME_WIDTH, width)
|
||||
cap.set(cv2.CAP_PROP_FRAME_HEIGHT, height)
|
||||
|
||||
frame_index = 0
|
||||
start_time = time.time()
|
||||
while time.time() - start_time < duration:
|
||||
ret, frame = cap.read()
|
||||
|
||||
if not ret:
|
||||
print("Error: Could not read frame.")
|
||||
break
|
||||
rr.log("video/stream", rr.Image(frame), static=True)
|
||||
cv2.imwrite(str(capture_dir / f"frame_{frame_index:06d}.png"), frame)
|
||||
frame_index += 1
|
||||
|
||||
# Release the capture
|
||||
cap.release()
|
||||
|
||||
# TODO(Steven): Add a graceful shutdown via a close() method for the Viewer context, though not currently supported in the Rerun API.
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
parser = argparse.ArgumentParser()
|
||||
|
||||
parser.add_argument(
|
||||
"--output-dir",
|
||||
type=Path,
|
||||
default=Path("outputs/cam_capture/"),
|
||||
help="Directory where the capture images are written. A subfolder named with the current date & time will be created inside it for each capture.",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--fps",
|
||||
type=int,
|
||||
default=30,
|
||||
help="Frames Per Second of the capture.",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--width",
|
||||
type=int,
|
||||
default=1280,
|
||||
help="Width of the captured images.",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--height",
|
||||
type=int,
|
||||
default=720,
|
||||
help="Height of the captured images.",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--duration",
|
||||
type=int,
|
||||
default=20,
|
||||
help="Duration in seconds for which the video stream should be captured.",
|
||||
)
|
||||
args = parser.parse_args()
|
||||
display_and_save_video_stream(**vars(args))
|
||||
@@ -21,11 +21,13 @@ See the provided README.md or run `python benchmark/video/run_video_benchmark.py
|
||||
|
||||
import argparse
|
||||
import datetime as dt
|
||||
import itertools
|
||||
import random
|
||||
import shutil
|
||||
from collections import OrderedDict
|
||||
from concurrent.futures import ThreadPoolExecutor, as_completed
|
||||
from pathlib import Path
|
||||
from threading import Lock
|
||||
|
||||
import einops
|
||||
import numpy as np
|
||||
@@ -35,13 +37,13 @@ import torch
|
||||
from skimage.metrics import mean_squared_error, peak_signal_noise_ratio, structural_similarity
|
||||
from tqdm import tqdm
|
||||
|
||||
from benchmarks.video.benchmark import TimeBenchmark
|
||||
from lerobot.datasets.lerobot_dataset import LeRobotDataset
|
||||
from lerobot.datasets.video_utils import (
|
||||
decode_video_frames_torchvision,
|
||||
decode_video_frames,
|
||||
encode_video_frames,
|
||||
)
|
||||
from lerobot.utils.constants import OBS_IMAGE
|
||||
from lerobot.utils.utils import TimerManager
|
||||
|
||||
BASE_ENCODING = OrderedDict(
|
||||
[
|
||||
@@ -86,7 +88,7 @@ def load_original_frames(imgs_dir: Path, timestamps: list[float], fps: int) -> t
|
||||
frames = []
|
||||
for ts in timestamps:
|
||||
idx = int(ts * fps)
|
||||
frame = PIL.Image.open(imgs_dir / f"frame_{idx:06d}.png")
|
||||
frame = PIL.Image.open(imgs_dir / f"frame-{idx:06d}.png")
|
||||
frame = torch.from_numpy(np.array(frame))
|
||||
frame = frame.type(torch.float32) / 255
|
||||
frame = einops.rearrange(frame, "h w c -> c h w")
|
||||
@@ -97,21 +99,21 @@ def load_original_frames(imgs_dir: Path, timestamps: list[float], fps: int) -> t
|
||||
def save_decoded_frames(
|
||||
imgs_dir: Path, save_dir: Path, frames: torch.Tensor, timestamps: list[float], fps: int
|
||||
) -> None:
|
||||
if save_dir.exists() and len(list(save_dir.glob("frame_*.png"))) == len(timestamps):
|
||||
if save_dir.exists() and len(list(save_dir.glob("frame-*.png"))) == len(timestamps):
|
||||
return
|
||||
|
||||
save_dir.mkdir(parents=True, exist_ok=True)
|
||||
for i, ts in enumerate(timestamps):
|
||||
idx = int(ts * fps)
|
||||
frame_hwc = (frames[i].permute((1, 2, 0)) * 255).type(torch.uint8).cpu().numpy()
|
||||
PIL.Image.fromarray(frame_hwc).save(save_dir / f"frame_{idx:06d}_decoded.png")
|
||||
shutil.copyfile(imgs_dir / f"frame_{idx:06d}.png", save_dir / f"frame_{idx:06d}_original.png")
|
||||
PIL.Image.fromarray(frame_hwc).save(save_dir / f"frame-{idx:06d}_decoded.png")
|
||||
shutil.copyfile(imgs_dir / f"frame-{idx:06d}.png", save_dir / f"frame-{idx:06d}_original.png")
|
||||
|
||||
|
||||
def save_first_episode(imgs_dir: Path, dataset: LeRobotDataset) -> None:
|
||||
episode_index = 0
|
||||
ep_num_images = dataset.meta.episodes["length"][episode_index]
|
||||
if imgs_dir.exists() and len(list(imgs_dir.glob("frame_*.png"))) == ep_num_images:
|
||||
if imgs_dir.exists() and len(list(imgs_dir.glob("frame-*.png"))) == ep_num_images:
|
||||
return
|
||||
|
||||
imgs_dir.mkdir(parents=True, exist_ok=True)
|
||||
@@ -125,7 +127,7 @@ def save_first_episode(imgs_dir: Path, dataset: LeRobotDataset) -> None:
|
||||
tqdm(imgs_dataset, desc=f"saving {dataset.repo_id} first episode images", leave=False)
|
||||
):
|
||||
img = item[img_keys[0]]
|
||||
img.save(str(imgs_dir / f"frame_{i:06d}.png"), quality=100)
|
||||
img.save(str(imgs_dir / f"frame-{i:06d}.png"), quality=100)
|
||||
|
||||
if i >= ep_num_images - 1:
|
||||
break
|
||||
@@ -149,18 +151,6 @@ def sample_timestamps(timestamps_mode: str, ep_num_images: int, fps: int) -> lis
|
||||
return [idx / fps for idx in frame_indexes]
|
||||
|
||||
|
||||
def decode_video_frames(
|
||||
video_path: str,
|
||||
timestamps: list[float],
|
||||
tolerance_s: float,
|
||||
backend: str,
|
||||
) -> torch.Tensor:
|
||||
if backend in ["pyav", "video_reader"]:
|
||||
return decode_video_frames_torchvision(video_path, timestamps, tolerance_s, backend)
|
||||
else:
|
||||
raise NotImplementedError(backend)
|
||||
|
||||
|
||||
def benchmark_decoding(
|
||||
imgs_dir: Path,
|
||||
video_path: Path,
|
||||
@@ -172,8 +162,8 @@ def benchmark_decoding(
|
||||
num_workers: int = 4,
|
||||
save_frames: bool = False,
|
||||
) -> dict:
|
||||
def process_sample(sample: int):
|
||||
time_benchmark = TimeBenchmark()
|
||||
def process_sample(sample: int, lock: Lock):
|
||||
time_benchmark = TimerManager(log=False)
|
||||
timestamps = sample_timestamps(timestamps_mode, ep_num_images, fps)
|
||||
num_frames = len(timestamps)
|
||||
result = {
|
||||
@@ -182,13 +172,13 @@ def benchmark_decoding(
|
||||
"mse_values": [],
|
||||
}
|
||||
|
||||
with time_benchmark:
|
||||
with time_benchmark, lock:
|
||||
frames = decode_video_frames(video_path, timestamps=timestamps, tolerance_s=5e-1, backend=backend)
|
||||
result["load_time_video_ms"] = time_benchmark.result_ms / num_frames
|
||||
result["load_time_video_ms"] = (time_benchmark.last * 1000) / num_frames
|
||||
|
||||
with time_benchmark:
|
||||
original_frames = load_original_frames(imgs_dir, timestamps, fps)
|
||||
result["load_time_images_ms"] = time_benchmark.result_ms / num_frames
|
||||
result["load_time_images_ms"] = (time_benchmark.last * 1000) / num_frames
|
||||
|
||||
frames_np, original_frames_np = frames.numpy(), original_frames.numpy()
|
||||
for i in range(num_frames):
|
||||
@@ -215,8 +205,10 @@ def benchmark_decoding(
|
||||
# A sample is a single set of decoded frames specified by timestamps_mode (e.g. a single frame, 2 frames, etc.).
|
||||
# For each sample, we record metrics (loading time and quality metrics) which are then averaged over all samples.
|
||||
# As these samples are independent, we run them in parallel threads to speed up the benchmark.
|
||||
# Use a single shared lock for all worker threads
|
||||
shared_lock = Lock()
|
||||
with ThreadPoolExecutor(max_workers=num_workers) as executor:
|
||||
futures = [executor.submit(process_sample, i) for i in range(num_samples)]
|
||||
futures = [executor.submit(process_sample, i, shared_lock) for i in range(num_samples)]
|
||||
for future in tqdm(as_completed(futures), total=num_samples, desc="samples", leave=False):
|
||||
result = future.result()
|
||||
load_times_video_ms.append(result["load_time_video_ms"])
|
||||
@@ -358,24 +350,27 @@ def main(
|
||||
imgs_dir = output_dir / "images" / dataset.repo_id.replace("/", "_")
|
||||
# We only use the first episode
|
||||
save_first_episode(imgs_dir, dataset)
|
||||
for key, values in tqdm(encoding_benchmarks.items(), desc="encodings (g, crf)", leave=False):
|
||||
for value in tqdm(values, desc=f"encodings ({key})", leave=False):
|
||||
encoding_cfg = BASE_ENCODING.copy()
|
||||
encoding_cfg["vcodec"] = video_codec
|
||||
encoding_cfg["pix_fmt"] = pixel_format
|
||||
for duet in [
|
||||
dict(zip(encoding_benchmarks.keys(), unique_combination, strict=False))
|
||||
for unique_combination in itertools.product(*encoding_benchmarks.values())
|
||||
]:
|
||||
encoding_cfg = BASE_ENCODING.copy()
|
||||
encoding_cfg["vcodec"] = video_codec
|
||||
encoding_cfg["pix_fmt"] = pixel_format
|
||||
for key, value in duet.items():
|
||||
encoding_cfg[key] = value
|
||||
args_path = Path("_".join(str(value) for value in encoding_cfg.values()))
|
||||
video_path = output_dir / "videos" / args_path / f"{repo_id.replace('/', '_')}.mp4"
|
||||
benchmark_table += benchmark_encoding_decoding(
|
||||
dataset,
|
||||
video_path,
|
||||
imgs_dir,
|
||||
encoding_cfg,
|
||||
decoding_benchmarks,
|
||||
num_samples,
|
||||
num_workers,
|
||||
save_frames,
|
||||
)
|
||||
args_path = Path("_".join(str(value) for value in encoding_cfg.values()))
|
||||
video_path = output_dir / "videos" / args_path / f"{repo_id.replace('/', '_')}.mp4"
|
||||
benchmark_table += benchmark_encoding_decoding(
|
||||
dataset,
|
||||
video_path,
|
||||
imgs_dir,
|
||||
encoding_cfg,
|
||||
decoding_benchmarks,
|
||||
num_samples,
|
||||
num_workers,
|
||||
save_frames,
|
||||
)
|
||||
|
||||
# Save intermediate results
|
||||
benchmark_df = pd.DataFrame(benchmark_table, columns=headers)
|
||||
@@ -409,9 +404,9 @@ if __name__ == "__main__":
|
||||
nargs="*",
|
||||
default=[
|
||||
"lerobot/pusht_image",
|
||||
"aliberts/aloha_mobile_shrimp_image",
|
||||
"aliberts/paris_street",
|
||||
"aliberts/kitchen",
|
||||
"lerobot/aloha_mobile_shrimp_image",
|
||||
"lerobot/paris_street",
|
||||
"lerobot/kitchen",
|
||||
],
|
||||
help="Datasets repo-ids to test against. First episodes only are used. Must be images.",
|
||||
)
|
||||
@@ -419,7 +414,7 @@ if __name__ == "__main__":
|
||||
"--vcodec",
|
||||
type=str,
|
||||
nargs="*",
|
||||
default=["libx264", "hevc", "libsvtav1"],
|
||||
default=["h264", "hevc", "libsvtav1"],
|
||||
help="Video codecs to be tested",
|
||||
)
|
||||
parser.add_argument(
|
||||
@@ -468,7 +463,7 @@ if __name__ == "__main__":
|
||||
"--backends",
|
||||
type=str,
|
||||
nargs="*",
|
||||
default=["pyav", "video_reader"],
|
||||
default=["torchcodec", "pyav"],
|
||||
help="Torchvision decoding backend to be tested.",
|
||||
)
|
||||
parser.add_argument(
|
||||
|
||||
@@ -15,8 +15,6 @@
|
||||
title: Train a Robot with RL
|
||||
- local: hilserl_sim
|
||||
title: Train RL in Simulation
|
||||
- local: async
|
||||
title: Use Async Inference
|
||||
- local: multi_gpu_training
|
||||
title: Multi GPU training
|
||||
title: "Tutorials"
|
||||
@@ -40,11 +38,17 @@
|
||||
- local: groot
|
||||
title: NVIDIA GR00T N1.5
|
||||
title: "Policies"
|
||||
- sections:
|
||||
- local: async
|
||||
title: Use Async Inference
|
||||
- local: rtc
|
||||
title: Real-Time Chunking (RTC)
|
||||
title: "Inference"
|
||||
- sections:
|
||||
- local: envhub
|
||||
title: Environments from the Hub
|
||||
- local: il_sim
|
||||
title: Imitation Learning in Sim
|
||||
- local: envhub_leisaac
|
||||
title: Control & Train Robots in Sim (LeIsaac)
|
||||
- local: libero
|
||||
title: Using Libero
|
||||
- local: metaworld
|
||||
@@ -59,6 +63,8 @@
|
||||
title: Implement your own processor
|
||||
- local: processors_robots_teleop
|
||||
title: Processors for Robots and Teleoperators
|
||||
- local: env_processor
|
||||
title: Environment Processors
|
||||
title: "Robot Processors"
|
||||
- sections:
|
||||
- local: so101
|
||||
|
||||
@@ -196,7 +196,7 @@ client_cfg = RobotClientConfig(
|
||||
server_address="localhost:8080",
|
||||
policy_device="mps",
|
||||
policy_type="smolvla",
|
||||
pretrained_name_or_path="fracapuano/smolvla_async",
|
||||
pretrained_name_or_path="<user>/smolvla_async",
|
||||
chunk_size_threshold=0.5,
|
||||
actions_per_chunk=50, # make sure this is less than the max actions of the policy
|
||||
)
|
||||
|
||||
@@ -0,0 +1,418 @@
|
||||
# Environment Processors
|
||||
|
||||
Environment processors are a critical layer in LeRobot's data processing architecture that handle **environment-specific** transformations, separate from policy-specific processing. This separation of concerns enables cleaner code, better modularity, and easier experimentation with different environments and policies.
|
||||
|
||||
## Why Environment Processors?
|
||||
|
||||
When working with different robot environments (LIBERO, MetaWorld, Aloha, etc.), each environment often has unique data formats, coordinate systems, and conventions that need standardization **before** policy processing. Without environment processors, these transformations would be:
|
||||
|
||||
1. **Hardcoded in environment code** - Making it difficult to experiment with different state representations
|
||||
2. **Duplicated across policies** - Each policy would need to handle environment-specific quirks
|
||||
3. **Mixed with policy logic** - Violating separation of concerns and making debugging harder
|
||||
|
||||
Environment processors solve this by providing a **dedicated processing layer** between raw environment observations and policy inputs.
|
||||
|
||||
## The Processing Pipeline
|
||||
|
||||
Here's how data flows through the complete processing pipeline during evaluation:
|
||||
|
||||
```python
|
||||
# In lerobot_eval.py rollout() function:
|
||||
|
||||
# 1. Raw environment observation (numpy arrays, various formats)
|
||||
raw_observation = env.step(action)
|
||||
|
||||
# 2. Convert numpy to torch, normalize images [0,1]
|
||||
observation = preprocess_observation(raw_observation)
|
||||
|
||||
# 3. Add task metadata (for multi-task environments)
|
||||
observation = add_envs_task(env, observation)
|
||||
|
||||
# 4. ENVIRONMENT-SPECIFIC preprocessing (NEW!)
|
||||
# - Flatten robot states
|
||||
# - Rotate images to match dataset conventions
|
||||
# - Handle environment-specific coordinate systems
|
||||
observation = env_preprocessor(observation)
|
||||
|
||||
# 5. POLICY-SPECIFIC preprocessing
|
||||
# - Normalize with dataset statistics
|
||||
# - Add batch dimensions
|
||||
# - Move to GPU
|
||||
# - Tokenize language instructions
|
||||
observation = preprocessor(observation)
|
||||
|
||||
# 6. Policy inference
|
||||
action = policy.select_action(observation)
|
||||
|
||||
# 7. POLICY-SPECIFIC postprocessing
|
||||
# - Unnormalize actions
|
||||
# - Remove batch dimensions
|
||||
action = postprocessor(action)
|
||||
|
||||
# 8. ENVIRONMENT-SPECIFIC postprocessing (NEW!)
|
||||
# - Convert action formats if needed
|
||||
# - Apply environment-specific constraints
|
||||
action_transition = {"action": action}
|
||||
action_transition = env_postprocessor(action_transition)
|
||||
action = action_transition["action"]
|
||||
|
||||
# 9. Execute in environment
|
||||
env.step(action)
|
||||
```
|
||||
|
||||
## The Benefits
|
||||
|
||||
### 1. **Separation of Concerns**
|
||||
|
||||
Environment processors handle transformations specific to the **environment's data format**, while policy processors handle transformations specific to the **model's requirements**.
|
||||
|
||||
```python
|
||||
# ❌ Before: Mixed concerns
|
||||
class LiberoVLAPolicy:
|
||||
def preprocess(self, obs):
|
||||
# Environment-specific: Flatten robot state (shouldn't be in policy!)
|
||||
state = self._flatten_robot_state(obs["robot_state"])
|
||||
# Policy-specific: Normalize with dataset stats
|
||||
state = self.normalizer(state)
|
||||
return state
|
||||
|
||||
# ✅ After: Clear separation
|
||||
# Environment processor: Handles LIBERO's nested robot state
|
||||
env_preprocessor = LiberoProcessorStep() # Flattens robot_state
|
||||
|
||||
# Policy processor: Handles model requirements
|
||||
policy_preprocessor = NormalizerProcessorStep(stats=dataset_stats)
|
||||
```
|
||||
|
||||
### 2. **Flexibility and Reusability**
|
||||
|
||||
The same policy can work with different environment processors, and the same environment processor can work with different policies:
|
||||
|
||||
```python
|
||||
# Use SmolVLA policy with LIBERO environment
|
||||
libero_preprocessor, libero_postprocessor = make_env_pre_post_processors(libero_cfg)
|
||||
smolvla_preprocessor, smolvla_postprocessor = make_pre_post_processors(smolvla_cfg)
|
||||
|
||||
# Or use ACT policy with the same LIBERO environment
|
||||
libero_preprocessor, libero_postprocessor = make_env_pre_post_processors(libero_cfg)
|
||||
act_preprocessor, act_postprocessor = make_pre_post_processors(act_cfg)
|
||||
```
|
||||
|
||||
### 3. **Easier Experimentation**
|
||||
|
||||
Want to try different state representations for LIBERO? Just create a new processor:
|
||||
|
||||
```python
|
||||
# Original: 8D state (pos + quat→axisangle + gripper)
|
||||
@ProcessorStepRegistry.register("libero_processor")
|
||||
class LiberoProcessorStep(ObservationProcessorStep):
|
||||
def _process_observation(self, obs):
|
||||
eef_pos = robot_state["eef"]["pos"] # 3D
|
||||
eef_axisangle = quat2axisangle(quat) # 3D
|
||||
gripper = robot_state["gripper"]["qpos"] # 2D
|
||||
state = torch.cat([eef_pos, eef_axisangle, gripper], dim=-1) # 8D
|
||||
return state
|
||||
|
||||
# Experiment: Add velocity for better control
|
||||
@ProcessorStepRegistry.register("libero_velocity_processor")
|
||||
class LiberoVelocityProcessorStep(ObservationProcessorStep):
|
||||
def _process_observation(self, obs):
|
||||
# Include velocities for 14D state
|
||||
eef_pos = robot_state["eef"]["pos"] # 3D
|
||||
eef_axisangle = quat2axisangle(quat) # 3D
|
||||
eef_vel = robot_state["eef"]["vel"] # 3D (NEW)
|
||||
gripper_pos = robot_state["gripper"]["qpos"] # 2D
|
||||
gripper_vel = robot_state["gripper"]["qvel"] # 3D (NEW)
|
||||
state = torch.cat([eef_pos, eef_axisangle, eef_vel,
|
||||
gripper_pos, gripper_vel], dim=-1) # 14D
|
||||
return state
|
||||
```
|
||||
|
||||
### 4. **Cleaner Environment Code**
|
||||
|
||||
Environments expose **all available data** without needing to know what downstream models will use:
|
||||
|
||||
```python
|
||||
# LIBERO environment exposes full robot state
|
||||
observation = {
|
||||
"pixels": {"image": img, "image2": img2},
|
||||
"robot_state": {
|
||||
"eef": {"pos": ..., "quat": ..., "vel": ..., "mat": ..., "axisangle": ...},
|
||||
"gripper": {"qpos": ..., "qvel": ...},
|
||||
"joints": {"pos": ..., "vel": ...}
|
||||
}
|
||||
}
|
||||
|
||||
# Environment processor decides what to use
|
||||
# Policy processor handles model-specific transformations
|
||||
```
|
||||
|
||||
## Using Environment Processors
|
||||
|
||||
### Factory Function
|
||||
|
||||
The `make_env_pre_post_processors` function follows the same pattern as `make_pre_post_processors` for policies:
|
||||
|
||||
```python
|
||||
from lerobot.envs.factory import make_env_pre_post_processors
|
||||
from lerobot.envs.configs import LiberoEnv, PushtEnv
|
||||
|
||||
# For LIBERO: Returns LiberoProcessorStep in preprocessor
|
||||
libero_cfg = LiberoEnv(task="libero_spatial", camera_name=["agentview"])
|
||||
env_preprocessor, env_postprocessor = make_env_pre_post_processors(libero_cfg)
|
||||
|
||||
# For other environments: Returns identity processors (no-op)
|
||||
pusht_cfg = PushtEnv()
|
||||
env_preprocessor, env_postprocessor = make_env_pre_post_processors(pusht_cfg)
|
||||
```
|
||||
|
||||
### Implementation in `envs/factory.py`
|
||||
|
||||
```python
|
||||
def make_env_pre_post_processors(
|
||||
env_cfg: EnvConfig,
|
||||
) -> tuple[
|
||||
PolicyProcessorPipeline[dict[str, Any], dict[str, Any]],
|
||||
PolicyProcessorPipeline[dict[str, Any], dict[str, Any]],
|
||||
]:
|
||||
"""
|
||||
Create preprocessor and postprocessor pipelines for environment observations.
|
||||
|
||||
Args:
|
||||
env_cfg: The configuration of the environment.
|
||||
|
||||
Returns:
|
||||
A tuple containing:
|
||||
- preprocessor: Pipeline that processes environment observations
|
||||
- postprocessor: Pipeline that processes environment outputs
|
||||
"""
|
||||
# For LIBERO environments, add the LiberoProcessorStep to preprocessor
|
||||
if isinstance(env_cfg, LiberoEnv) or "libero" in env_cfg.type:
|
||||
preprocessor = PolicyProcessorPipeline(steps=[LiberoProcessorStep()])
|
||||
else:
|
||||
# For all other environments, return an identity preprocessor
|
||||
preprocessor = PolicyProcessorPipeline(steps=[])
|
||||
|
||||
# Postprocessor is currently identity for all environments
|
||||
# Future: Could add environment-specific action transformations
|
||||
postprocessor = PolicyProcessorPipeline(steps=[])
|
||||
|
||||
return preprocessor, postprocessor
|
||||
```
|
||||
|
||||
### Integration in Evaluation
|
||||
|
||||
In `lerobot_eval.py`, the environment processors are created once and used throughout:
|
||||
|
||||
```python
|
||||
def eval_main(cfg: EvalPipelineConfig):
|
||||
# Create environment
|
||||
envs = make_env(cfg.env, n_envs=cfg.eval.batch_size)
|
||||
|
||||
# Create policy
|
||||
policy = make_policy(cfg=cfg.policy, env_cfg=cfg.env)
|
||||
|
||||
# Create policy processors
|
||||
preprocessor, postprocessor = make_pre_post_processors(
|
||||
policy_cfg=cfg.policy,
|
||||
pretrained_path=cfg.policy.pretrained_path,
|
||||
)
|
||||
|
||||
# Create environment processors (NEW!)
|
||||
env_preprocessor, env_postprocessor = make_env_pre_post_processors(env_cfg=cfg.env)
|
||||
|
||||
# Run evaluation with both processor types
|
||||
eval_policy_all(
|
||||
envs=envs,
|
||||
policy=policy,
|
||||
env_preprocessor=env_preprocessor, # Environment-specific
|
||||
env_postprocessor=env_postprocessor, # Environment-specific
|
||||
preprocessor=preprocessor, # Policy-specific
|
||||
postprocessor=postprocessor, # Policy-specific
|
||||
n_episodes=cfg.eval.n_episodes,
|
||||
)
|
||||
```
|
||||
|
||||
## Example: LIBERO Environment Processor
|
||||
|
||||
The `LiberoProcessorStep` demonstrates a real-world environment processor:
|
||||
|
||||
```python
|
||||
from lerobot.processor.pipeline import ObservationProcessorStep
|
||||
|
||||
@dataclass
|
||||
@ProcessorStepRegistry.register(name="libero_processor")
|
||||
class LiberoProcessorStep(ObservationProcessorStep):
|
||||
"""
|
||||
Processes LIBERO observations into the LeRobot format.
|
||||
|
||||
**State Processing:**
|
||||
- Extracts end-effector position (3D)
|
||||
- Converts quaternion to axis-angle representation (3D)
|
||||
- Extracts gripper joint positions (2D)
|
||||
- Concatenates into 8D state vector
|
||||
|
||||
**Image Processing:**
|
||||
- Rotates images 180° to match HuggingFaceVLA/libero convention
|
||||
"""
|
||||
|
||||
def _process_observation(self, observation):
|
||||
processed_obs = observation.copy()
|
||||
|
||||
# Process images: Flip 180° for camera convention
|
||||
for key in list(processed_obs.keys()):
|
||||
if key.startswith("observation.images."):
|
||||
img = processed_obs[key]
|
||||
img = torch.flip(img, dims=[2, 3]) # Flip H and W
|
||||
processed_obs[key] = img
|
||||
|
||||
# Process robot_state: Flatten to 8D vector
|
||||
if "observation.robot_state" in processed_obs:
|
||||
robot_state = processed_obs.pop("observation.robot_state")
|
||||
|
||||
eef_pos = robot_state["eef"]["pos"] # (B, 3)
|
||||
eef_quat = robot_state["eef"]["quat"] # (B, 4)
|
||||
gripper_qpos = robot_state["gripper"]["qpos"] # (B, 2)
|
||||
|
||||
# Convert quaternion to axis-angle
|
||||
eef_axisangle = self._quat2axisangle(eef_quat) # (B, 3)
|
||||
|
||||
# Concatenate into single state vector
|
||||
state = torch.cat((eef_pos, eef_axisangle, gripper_qpos), dim=-1)
|
||||
state = state.float()
|
||||
|
||||
processed_obs["observation.state"] = state
|
||||
|
||||
return processed_obs
|
||||
```
|
||||
|
||||
### Why These Transformations?
|
||||
|
||||
1. **Image Rotation**: The HuggingFaceVLA/libero dataset has images rotated 180° from the raw LIBERO simulator. The processor handles this convention mismatch so policies trained on the dataset work seamlessly.
|
||||
|
||||
2. **State Flattening**: The raw LIBERO environment exposes nested dictionaries with all available state information (position, quaternion, velocity, matrix representation, etc.). The processor:
|
||||
- Selects the relevant components (pos, quat, gripper)
|
||||
- Converts quaternion to axis-angle (more suitable for learning)
|
||||
- Flattens to a single 8D vector that policies expect
|
||||
|
||||
3. **Flexibility**: The environment still exposes **all** raw data. If you want to try different state representations (e.g., including velocities, using matrix representation instead of axis-angle), you can create a new processor without modifying the environment code.
|
||||
|
||||
## Adding Environment Processors for New Environments
|
||||
|
||||
To add environment processors for a new environment:
|
||||
|
||||
### 1. Create the Processor Step
|
||||
|
||||
```python
|
||||
# In src/lerobot/processor/env_processor.py
|
||||
|
||||
@dataclass
|
||||
@ProcessorStepRegistry.register(name="myenv_processor")
|
||||
class MyEnvProcessorStep(ObservationProcessorStep):
|
||||
"""Process observations from MyEnv."""
|
||||
|
||||
def _process_observation(self, observation):
|
||||
processed = observation.copy()
|
||||
|
||||
# Your environment-specific transformations
|
||||
if "myenv.specific.state" in processed:
|
||||
state = processed.pop("myenv.specific.state")
|
||||
# Transform to standard format
|
||||
processed["observation.state"] = self._transform_state(state)
|
||||
|
||||
return processed
|
||||
```
|
||||
|
||||
### 2. Update the Factory
|
||||
|
||||
```python
|
||||
# In src/lerobot/envs/factory.py
|
||||
|
||||
def make_env_pre_post_processors(env_cfg: EnvConfig):
|
||||
if isinstance(env_cfg, LiberoEnv) or "libero" in env_cfg.type:
|
||||
preprocessor = PolicyProcessorPipeline(steps=[LiberoProcessorStep()])
|
||||
elif isinstance(env_cfg, MyEnvConfig) or "myenv" in env_cfg.type:
|
||||
preprocessor = PolicyProcessorPipeline(steps=[MyEnvProcessorStep()])
|
||||
else:
|
||||
preprocessor = PolicyProcessorPipeline(steps=[])
|
||||
|
||||
postprocessor = PolicyProcessorPipeline(steps=[])
|
||||
return preprocessor, postprocessor
|
||||
```
|
||||
|
||||
### 3. Use in Evaluation
|
||||
|
||||
No changes needed! The evaluation script automatically uses the appropriate processor:
|
||||
|
||||
```bash
|
||||
lerobot-eval \
|
||||
--policy.path=lerobot/my_policy \
|
||||
--env.type=myenv \ # Automatically uses MyEnvProcessorStep
|
||||
--eval.n_episodes=10
|
||||
```
|
||||
|
||||
## Future: Environment Postprocessors
|
||||
|
||||
Currently, postprocessors are identity (no-op) for all environments. Future use cases include:
|
||||
|
||||
### Action Space Transformations
|
||||
|
||||
```python
|
||||
@dataclass
|
||||
class MyEnvActionPostprocessor(ProcessorStep):
|
||||
"""Convert policy actions to environment-specific format."""
|
||||
|
||||
def __call__(self, transition: EnvTransition) -> EnvTransition:
|
||||
action = transition["action"]
|
||||
|
||||
# Example: Convert from Cartesian to joint space
|
||||
if self.action_space == "joint":
|
||||
action = self.ik_solver(action)
|
||||
|
||||
# Example: Apply environment-specific safety limits
|
||||
action = torch.clamp(action, self.min_action, self.max_action)
|
||||
|
||||
transition["action"] = action
|
||||
return transition
|
||||
```
|
||||
|
||||
### Coordinate System Conversions
|
||||
|
||||
```python
|
||||
@dataclass
|
||||
class CoordinateTransformPostprocessor(ProcessorStep):
|
||||
"""Transform actions between coordinate systems."""
|
||||
|
||||
def __call__(self, transition: EnvTransition) -> EnvTransition:
|
||||
action = transition["action"]
|
||||
|
||||
# Example: Policy outputs in world frame, env expects base frame
|
||||
action = self.world_to_base_transform(action)
|
||||
|
||||
transition["action"] = action
|
||||
return transition
|
||||
```
|
||||
|
||||
## Best Practices
|
||||
|
||||
1. **Keep environment processors simple**: They should only handle environment-specific data format issues, not complex learning-related transformations.
|
||||
|
||||
2. **Use policy processors for model requirements**: Normalization, batching, device placement, and tokenization belong in policy processors.
|
||||
|
||||
3. **Expose all data from environments**: Let processors decide what to use rather than hardcoding choices in the environment.
|
||||
|
||||
4. **Document conventions**: Clearly document any coordinate system conventions, camera orientations, or data formats that your processor handles.
|
||||
|
||||
5. **Test independently**: Environment processors should be testable without loading full policies or environments.
|
||||
|
||||
## Summary
|
||||
|
||||
Environment processors provide a **clean separation** between environment-specific data transformations and policy-specific model requirements. This architecture:
|
||||
|
||||
- ✅ Enables easy experimentation with different state representations
|
||||
- ✅ Allows policies to work seamlessly across different environments
|
||||
- ✅ Keeps environment code focused on simulation/hardware interface
|
||||
- ✅ Makes processor pipelines more maintainable and debuggable
|
||||
- ✅ Follows the single responsibility principle
|
||||
|
||||
The key insight: **Environments define data formats, processors standardize them, policies consume standardized data.** Each layer has a clear, focused responsibility.
|
||||
@@ -0,0 +1,301 @@
|
||||
# LeIsaac × LeRobot EnvHub
|
||||
|
||||
LeRobot EnvHub now supports **imitation learning in simulation** with LeIsaac.
|
||||
Spin up everyday manipulation tasks, teleoperate the robot, collect demos, push them to the Hub, and train policies in LeRobot — all in one loop.
|
||||
|
||||
[LeIsaac](https://github.com/LightwheelAI/leisaac) integrates with IsaacLab and the SO101 Leader/Follower setup to provide:
|
||||
|
||||
- 🕹️ **Teleoperation-first workflows** for data collection
|
||||
- 📦 **Built-in data conversion** ready for LeRobot training
|
||||
- 🤖 **Everyday skills** like picking oranges, lifting cubes, cleaning tables, and folding cloth
|
||||
- ☁️ **Ongoing upgrades** from [LightWheel](https://lightwheel.ai/): cloud simulation, EnvHub support, Sim2Real tooling, and more
|
||||
|
||||
Below you’ll find the currently supported LeIsaac tasks exposed through LeRobot EnvHub.
|
||||
|
||||
# Available Environments
|
||||
|
||||
The following table lists all available tasks and environments in LeIsaac x LeRobot Envhub. You can also get the latest list of environments by running the following command:
|
||||
|
||||
```bash
|
||||
python scripts/environments/list_envs.py
|
||||
```
|
||||
|
||||
| Task | Environment ID | Task Description | Related Robot |
|
||||
| :-------------------------------------------------------------------------------------------------------------------------------------------------------------- | :-------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------- | :------------------------------------------------------------------------------------------------------------------------- | :--------------------------------------------------------- |
|
||||
| <video src="https://github.com/user-attachments/assets/466eddff-f720-4f99-94d5-5e123e4c302c" autoplay loop muted playsinline style="max-width: 300px;"></video> | [LeIsaac-SO101-PickOrange-v0](https://github.com/LightwheelAI/leisaac/blob/main/source/leisaac/leisaac/tasks/pick_orange/pick_orange_env_cfg.py)<br /><br />[LeIsaac-SO101-PickOrange-Direct-v0](https://github.com/LightwheelAI/leisaac/blob/main/source/leisaac/leisaac/tasks/pick_orange/direct/pick_orange_env.py) | Pick three oranges and put them into the plate, then reset the arm to rest state. | Single-Arm SO101 Follower |
|
||||
| <video src="https://github.com/user-attachments/assets/1e4eb83a-0b38-40fb-a0b2-ddb0fe201e6d" autoplay loop muted playsinline style="max-width: 300px;"></video> | [LeIsaac-SO101-LiftCube-v0](https://github.com/LightwheelAI/leisaac/blob/main/source/leisaac/leisaac/tasks/lift_cube/lift_cube_env_cfg.py)<br /><br />[LeIsaac-SO101-LiftCube-Direct-v0](https://github.com/LightwheelAI/leisaac/blob/main/source/leisaac/leisaac/tasks/lift_cube/direct/lift_cube_env.py) | Lift the red cube up. | Single-Arm SO101 Follower |
|
||||
| <video src="https://github.com/user-attachments/assets/e49d8f1c-dcc9-412b-a88f-100680d8a45b" autoplay loop muted playsinline style="max-width: 300px;"></video> | [LeIsaac-SO101-CleanToyTable-v0](https://github.com/LightwheelAI/leisaac/blob/main/source/leisaac/leisaac/tasks/clean_toy_table/clean_toy_table_env_cfg.py)<br /><br />[LeIsaac-SO101-CleanToyTable-BiArm-v0](https://github.com/LightwheelAI/leisaac/blob/main/source/leisaac/leisaac/tasks/clean_toy_table/clean_toy_table_bi_arm_env_cfg.py)<br /><br />[LeIsaac-SO101-CleanToyTable-BiArm-Direct-v0](https://github.com/LightwheelAI/leisaac/blob/main/source/leisaac/leisaac/tasks/clean_toy_table/direct/clean_toy_table_bi_arm_env.py) | Pick two letter e objects into the box, and reset the arm to rest state. | Single-Arm SO101 Follower<br /><br />Bi-Arm SO101 Follower |
|
||||
| <video src="https://github.com/user-attachments/assets/e29a0f8a-9286-4ce6-b45d-342c3d3ba754" autoplay loop muted playsinline style="max-width: 300px;"></video> | [LeIsaac-SO101-FoldCloth-BiArm-v0](https://github.com/LightwheelAI/leisaac/blob/main/source/leisaac/leisaac/tasks/fold_cloth/fold_cloth_bi_arm_env_cfg.py)<br /><br />[LeIsaac-SO101-FoldCloth-BiArm-Direct-v0](https://github.com/LightwheelAI/leisaac/blob/main/source/leisaac/leisaac/tasks/fold_cloth/direct/fold_cloth_bi_arm_env.py) | Fold the cloth, and reset the arm to rest state.<br /><br />_Note: Only the DirectEnv support check_success in this task._ | Bi-Arm SO101 Follower |
|
||||
|
||||
# Load LeIsaac directly in LeRobot with one line of code
|
||||
|
||||
> EnvHub: Share LeIsaac environments through HuggingFace
|
||||
|
||||
[EnvHub](https://huggingface.co/docs/lerobot/envhub) is our reproducible environment hub, spin up a packaged simulation with one line, experiment immediately, and publish your own tasks for the community.
|
||||
|
||||
LeIsaac offers EnvHub support so you can consume or share tasks with only a few commands.
|
||||
|
||||
<video
|
||||
controls
|
||||
src="https://github.com/user-attachments/assets/687666f5-ebe0-421d-84a0-eb86116ac5f8"
|
||||
style={{ width: "100%", maxWidth: "960px", borderRadius: "8px" }}
|
||||
/>
|
||||
|
||||
## How to get started, environment Setup
|
||||
|
||||
Run the following commands to setup your code environments:
|
||||
|
||||
```bash
|
||||
# Refer to Getting Started/Installation to install leisaac firstly
|
||||
conda create -n leisaac_envhub python=3.11
|
||||
conda activate leisaac_envhub
|
||||
|
||||
conda install -c "nvidia/label/cuda-12.8.1" cuda-toolkit
|
||||
pip install -U torch==2.7.0 torchvision==0.22.0 --index-url https://download.pytorch.org/whl/cu128
|
||||
pip install 'leisaac[isaaclab] @ git+https://github.com/LightwheelAI/leisaac.git#subdirectory=source/leisaac' --extra-index-url https://pypi.nvidia.com
|
||||
|
||||
# Install lerobot
|
||||
pip install lerobot==0.4.1
|
||||
|
||||
# Fix numpy version
|
||||
pip install numpy==1.26.0
|
||||
```
|
||||
|
||||
## Usage Example
|
||||
|
||||
EnvHub exposes every LeIsaac-supported task in a uniform interface. The examples below load `so101_pick_orange` and demonstrate a random-action rollout and an interactive teleoperation.
|
||||
|
||||
### Random Action
|
||||
|
||||
<details>
|
||||
<summary>Click to expand code example</summary>
|
||||
|
||||
```python
|
||||
# envhub_random_action.py
|
||||
|
||||
import torch
|
||||
from lerobot.envs.factory import make_env
|
||||
|
||||
# Load from the hub
|
||||
envs_dict = make_env("LightwheelAI/leisaac_env:envs/so101_pick_orange.py", n_envs=1, trust_remote_code=True)
|
||||
|
||||
# Access the environment
|
||||
suite_name = next(iter(envs_dict))
|
||||
sync_vector_env = envs_dict[suite_name][0]
|
||||
# retrieve the isaac environment from the sync vector env
|
||||
env = sync_vector_env.envs[0].unwrapped
|
||||
|
||||
# Use it like any gym environment
|
||||
obs, info = env.reset()
|
||||
|
||||
while True:
|
||||
action = torch.tensor(env.action_space.sample())
|
||||
obs, reward, terminated, truncated, info = env.step(action)
|
||||
if terminated or truncated:
|
||||
obs, info = env.reset()
|
||||
|
||||
env.close()
|
||||
```
|
||||
|
||||
</details>
|
||||
|
||||
```bash
|
||||
python envhub_random_action.py
|
||||
```
|
||||
|
||||
You should see the SO101 arm swinging under purely random commands.
|
||||
|
||||
### Teleoperation
|
||||
|
||||
LeRobot’s teleoperation stack can drive the simulated arm.
|
||||
|
||||
Connect the SO101 Leader controller, run the calibration command below.
|
||||
|
||||
```bash
|
||||
lerobot-calibrate \
|
||||
--teleop.type=so101_leader \
|
||||
--teleop.port=/dev/ttyACM0 \
|
||||
--teleop.id=leader
|
||||
```
|
||||
|
||||
And then launch the teleop script.
|
||||
|
||||
<details>
|
||||
<summary>Click to expand code example</summary>
|
||||
|
||||
```python
|
||||
# envhub_teleop_example.py
|
||||
|
||||
import logging
|
||||
import time
|
||||
import gymnasium as gym
|
||||
|
||||
from dataclasses import asdict, dataclass
|
||||
from pprint import pformat
|
||||
|
||||
from lerobot.teleoperators import ( # noqa: F401
|
||||
Teleoperator,
|
||||
TeleoperatorConfig,
|
||||
make_teleoperator_from_config,
|
||||
so101_leader,
|
||||
)
|
||||
from lerobot.utils.robot_utils import precise_sleep
|
||||
from lerobot.utils.utils import init_logging
|
||||
from lerobot.envs.factory import make_env
|
||||
|
||||
|
||||
@dataclass
|
||||
class TeleoperateConfig:
|
||||
teleop: TeleoperatorConfig
|
||||
env_name: str = "so101_pick_orange"
|
||||
fps: int = 60
|
||||
|
||||
|
||||
@dataclass
|
||||
class EnvWrap:
|
||||
env: gym.Env
|
||||
|
||||
|
||||
def make_env_from_leisaac(env_name: str = "so101_pick_orange"):
|
||||
envs_dict = make_env(
|
||||
f'LightwheelAI/leisaac_env:envs/{env_name}.py',
|
||||
n_envs=1,
|
||||
trust_remote_code=True
|
||||
)
|
||||
suite_name = next(iter(envs_dict))
|
||||
sync_vector_env = envs_dict[suite_name][0]
|
||||
env = sync_vector_env.envs[0].unwrapped
|
||||
|
||||
return env
|
||||
|
||||
|
||||
def teleop_loop(teleop: Teleoperator, env: gym.Env, fps: int):
|
||||
from leisaac.devices.action_process import preprocess_device_action
|
||||
from leisaac.assets.robots.lerobot import SO101_FOLLOWER_MOTOR_LIMITS
|
||||
from leisaac.utils.env_utils import dynamic_reset_gripper_effort_limit_sim
|
||||
|
||||
env_wrap = EnvWrap(env=env)
|
||||
|
||||
obs, info = env.reset()
|
||||
while True:
|
||||
loop_start = time.perf_counter()
|
||||
if env.cfg.dynamic_reset_gripper_effort_limit:
|
||||
dynamic_reset_gripper_effort_limit_sim(env, 'so101leader')
|
||||
|
||||
raw_action = teleop.get_action()
|
||||
processed_action = preprocess_device_action(
|
||||
dict(
|
||||
so101_leader=True,
|
||||
joint_state={
|
||||
k.removesuffix(".pos"): v for k, v in raw_action.items()},
|
||||
motor_limits=SO101_FOLLOWER_MOTOR_LIMITS),
|
||||
env_wrap
|
||||
)
|
||||
obs, reward, terminated, truncated, info = env.step(processed_action)
|
||||
if terminated or truncated:
|
||||
obs, info = env.reset()
|
||||
|
||||
dt_s = time.perf_counter() - loop_start
|
||||
precise_sleep(1 / fps - dt_s)
|
||||
loop_s = time.perf_counter() - loop_start
|
||||
print(f"\ntime: {loop_s * 1e3:.2f}ms ({1 / loop_s:.0f} Hz)")
|
||||
|
||||
|
||||
def teleoperate(cfg: TeleoperateConfig):
|
||||
init_logging()
|
||||
logging.info(pformat(asdict(cfg)))
|
||||
|
||||
teleop = make_teleoperator_from_config(cfg.teleop)
|
||||
env = make_env_from_leisaac(cfg.env_name)
|
||||
|
||||
teleop.connect()
|
||||
if hasattr(env, 'initialize'):
|
||||
env.initialize()
|
||||
try:
|
||||
teleop_loop(teleop=teleop, env=env, fps=cfg.fps)
|
||||
except KeyboardInterrupt:
|
||||
pass
|
||||
finally:
|
||||
teleop.disconnect()
|
||||
env.close()
|
||||
|
||||
|
||||
def main():
|
||||
teleoperate(TeleoperateConfig(
|
||||
teleop=so101_leader.SO101LeaderConfig(
|
||||
port="/dev/ttyACM0",
|
||||
id='leader',
|
||||
use_degrees=False,
|
||||
),
|
||||
env_name="so101_pick_orange",
|
||||
fps=60,
|
||||
))
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
main()
|
||||
|
||||
```
|
||||
|
||||
</details>
|
||||
|
||||
```bash
|
||||
python envhub_teleop_example.py
|
||||
```
|
||||
|
||||
Running the script lets you operate the simulated arm using the physical Leader device.
|
||||
|
||||
## ☁️ Cloud Simulation (No GPU Required)
|
||||
|
||||
Don’t have a local GPU or the right drivers? No problem! You can run LeIsaac entirely in the cloud with zero setup.
|
||||
LeIsaac works out-of-the-box on **NVIDIA Brev**, giving you a fully configured environment directly in your browser.
|
||||
|
||||
👉 **Start here:** [https://lightwheelai.github.io/leisaac/docs/cloud_simulation/nvidia_brev](https://lightwheelai.github.io/leisaac/docs/cloud_simulation/nvidia_brev)
|
||||
|
||||
Once your instance is deployed, simply open the link for **port 80 (HTTP)** to launch **Visual Studio Code Server** (default password: `password`). From there, you can run simulations, edit code, and visualize IsaacLab environments — all from your web browser.
|
||||
|
||||
**No GPU, no drivers, no local installation. Just click and run.**
|
||||
|
||||
## Additional Notes
|
||||
|
||||
We keep EnvHub coverage aligned with the LeIsaac task. Currently supported:
|
||||
|
||||
- `so101_pick_orange`
|
||||
- `so101_lift_cube`
|
||||
- `so101_clean_toytable`
|
||||
- `bi_so101_fold_cloth`
|
||||
|
||||
Switch tasks by targeting a different script when calling `make_env`, for example:
|
||||
|
||||
```python
|
||||
envs_dict_pick_orange = make_env("LightwheelAI/leisaac_env:envs/so101_pick_orange.py", n_envs=1, trust_remote_code=True)
|
||||
envs_dict_lift_cube = make_env("LightwheelAI/leisaac_env:envs/so101_lift_cube.py", n_envs=1, trust_remote_code=True)
|
||||
envs_dict_clean_toytable = make_env("LightwheelAI/leisaac_env:envs/so101_clean_toytable.py", n_envs=1, trust_remote_code=True)
|
||||
envs_dict_fold_cloth = make_env("LightwheelAI/leisaac_env:envs/bi_so101_fold_cloth.py", n_envs=1, trust_remote_code=True)
|
||||
```
|
||||
|
||||
Note: when working with `bi_so101_fold_cloth`, call `initialize()` immediately after retrieving the env before performing any other operations:
|
||||
|
||||
<details>
|
||||
<summary>Click to expand code example</summary>
|
||||
|
||||
```python
|
||||
import torch
|
||||
from lerobot.envs.factory import make_env
|
||||
|
||||
# Load from the hub
|
||||
envs_dict = make_env("LightwheelAI/leisaac_env:envs/bi_so101_fold_cloth.py", n_envs=1, trust_remote_code=True)
|
||||
|
||||
# Access the environment
|
||||
suite_name = next(iter(envs_dict))
|
||||
sync_vector_env = envs_dict[suite_name][0]
|
||||
# retrieve the isaac environment from the sync vector env
|
||||
env = sync_vector_env.envs[0].unwrapped
|
||||
|
||||
# NOTE: initialize() first
|
||||
env.initialize()
|
||||
|
||||
# other operation with env...
|
||||
```
|
||||
|
||||
</details>
|
||||
@@ -393,7 +393,7 @@ import time
|
||||
from lerobot.datasets.lerobot_dataset import LeRobotDataset
|
||||
from lerobot.robots.so100_follower.config_so100_follower import SO100FollowerConfig
|
||||
from lerobot.robots.so100_follower.so100_follower import SO100Follower
|
||||
from lerobot.utils.robot_utils import busy_wait
|
||||
from lerobot.utils.robot_utils import precise_sleep
|
||||
from lerobot.utils.utils import log_say
|
||||
|
||||
episode_idx = 0
|
||||
@@ -415,7 +415,7 @@ for idx in range(dataset.num_frames):
|
||||
}
|
||||
robot.send_action(action)
|
||||
|
||||
busy_wait(1.0 / dataset.fps - (time.perf_counter() - t0))
|
||||
precise_sleep(1.0 / dataset.fps - (time.perf_counter() - t0))
|
||||
|
||||
robot.disconnect()
|
||||
```
|
||||
|
||||
@@ -1,220 +0,0 @@
|
||||
# Imitation Learning in Sim
|
||||
|
||||
This tutorial will explain how to train a neural network to control a robot in simulation with imitation learning.
|
||||
|
||||
**You'll learn:**
|
||||
|
||||
1. How to record a dataset in simulation with [gym-hil](https://github.com/huggingface/gym-hil) and visualize the dataset.
|
||||
2. How to train a policy using your data.
|
||||
3. How to evaluate your policy in simulation and visualize the results.
|
||||
|
||||
For the simulation environment we use the same [repo](https://github.com/huggingface/gym-hil) that is also being used by the Human-In-the-Loop (HIL) reinforcement learning algorithm.
|
||||
This environment is based on [MuJoCo](https://mujoco.org) and allows you to record datasets in LeRobotDataset format.
|
||||
Teleoperation is easiest with a controller like the Logitech F710, but you can also use your keyboard if you are up for the challenge.
|
||||
|
||||
## Installation
|
||||
|
||||
First, install the `gym_hil` package within the LeRobot environment, go to your LeRobot folder and run this command:
|
||||
|
||||
```bash
|
||||
pip install -e ".[hilserl]"
|
||||
```
|
||||
|
||||
## Teleoperate and Record a Dataset
|
||||
|
||||
To use `gym_hil` with LeRobot, you need to use a configuration file. An example config file can be found [here](https://huggingface.co/datasets/lerobot/config_examples/resolve/main/sim_il/env_config.json).
|
||||
|
||||
To teleoperate and collect a dataset, we need to modify this config file. Here's an example configuration for imitation learning data collection:
|
||||
|
||||
```json
|
||||
{
|
||||
"env": {
|
||||
"type": "gym_manipulator",
|
||||
"name": "gym_hil",
|
||||
"task": "PandaPickCubeGamepad-v0",
|
||||
"fps": 10
|
||||
},
|
||||
"dataset": {
|
||||
"repo_id": "your_username/il_gym",
|
||||
"root": null,
|
||||
"task": "pick_cube",
|
||||
"num_episodes_to_record": 30,
|
||||
"replay_episode": null,
|
||||
"push_to_hub": true
|
||||
},
|
||||
"mode": "record",
|
||||
"device": "cuda"
|
||||
}
|
||||
```
|
||||
|
||||
Key configuration points:
|
||||
|
||||
- Set your `repo_id` in the `dataset` section: `"repo_id": "your_username/il_gym"`
|
||||
- Set `num_episodes_to_record: 30` to collect 30 demonstration episodes
|
||||
- Ensure `mode` is set to `"record"`
|
||||
- If you don't have an NVIDIA GPU, change `"device": "cuda"` to `"mps"` for macOS or `"cpu"`
|
||||
- To use keyboard instead of gamepad, change `"task"` to `"PandaPickCubeKeyboard-v0"`
|
||||
|
||||
Then we can run this command to start:
|
||||
|
||||
<hfoptions id="teleop_sim">
|
||||
<hfoption id="Linux">
|
||||
|
||||
```bash
|
||||
python -m lerobot.rl.gym_manipulator --config_path path/to/env_config_gym_hil_il.json
|
||||
```
|
||||
|
||||
</hfoption>
|
||||
<hfoption id="MacOS">
|
||||
|
||||
```bash
|
||||
mjpython -m lerobot.rl.gym_manipulator --config_path path/to/env_config_gym_hil_il.json
|
||||
```
|
||||
|
||||
</hfoption>
|
||||
</hfoptions>
|
||||
|
||||
Once rendered you can teleoperate the robot with the gamepad or keyboard, below you can find the gamepad/keyboard controls.
|
||||
|
||||
Note that to teleoperate the robot you have to hold the "Human Take Over Pause Policy" Button `RB` to enable control!
|
||||
|
||||
**Gamepad Controls**
|
||||
|
||||
<p align="center">
|
||||
<img
|
||||
src="https://huggingface.co/datasets/huggingface/documentation-images/resolve/main/lerobot/gamepad_guide.jpg?raw=true"
|
||||
alt="Figure shows the control mappings on a Logitech gamepad."
|
||||
title="Gamepad Control Mapping"
|
||||
width="100%"
|
||||
></img>
|
||||
</p>
|
||||
<p align="center">
|
||||
<i>Gamepad button mapping for robot control and episode management</i>
|
||||
</p>
|
||||
|
||||
**Keyboard controls**
|
||||
|
||||
For keyboard controls use the `spacebar` to enable control and the following keys to move the robot:
|
||||
|
||||
```bash
|
||||
Arrow keys: Move in X-Y plane
|
||||
Shift and Shift_R: Move in Z axis
|
||||
Right Ctrl and Left Ctrl: Open and close gripper
|
||||
ESC: Exit
|
||||
```
|
||||
|
||||
## Visualize a dataset
|
||||
|
||||
If you uploaded your dataset to the hub you can [visualize your dataset online](https://huggingface.co/spaces/lerobot/visualize_dataset) by copy pasting your repo id.
|
||||
|
||||
<p align="center">
|
||||
<img
|
||||
src="https://huggingface.co/datasets/huggingface/documentation-images/resolve/main/lerobot/dataset_visualizer_sim.png"
|
||||
alt="Figure shows the dataset visualizer"
|
||||
title="Dataset visualization"
|
||||
width="100%"
|
||||
></img>
|
||||
</p>
|
||||
<p align="center">
|
||||
<i>Dataset visualizer</i>
|
||||
</p>
|
||||
|
||||
## Train a policy
|
||||
|
||||
To train a policy to control your robot, use the [`lerobot-train`](https://github.com/huggingface/lerobot/blob/main/src/lerobot/scripts/train.py) script. A few arguments are required. Here is an example command:
|
||||
|
||||
```bash
|
||||
lerobot-train \
|
||||
--dataset.repo_id=${HF_USER}/il_gym \
|
||||
--policy.type=act \
|
||||
--output_dir=outputs/train/il_sim_test \
|
||||
--job_name=il_sim_test \
|
||||
--policy.device=cuda \
|
||||
--wandb.enable=true
|
||||
```
|
||||
|
||||
Let's explain the command:
|
||||
|
||||
1. We provided the dataset as argument with `--dataset.repo_id=${HF_USER}/il_gym`.
|
||||
2. We provided the policy with `policy.type=act`. This loads configurations from [`configuration_act.py`](https://github.com/huggingface/lerobot/blob/main/src/lerobot/policies/act/configuration_act.py). Importantly, this policy will automatically adapt to the number of motor states, motor actions and cameras of your robot (e.g. `laptop` and `phone`) which have been saved in your dataset.
|
||||
3. We provided `policy.device=cuda` since we are training on a Nvidia GPU, but you could use `policy.device=mps` to train on Apple silicon.
|
||||
4. We provided `wandb.enable=true` to use [Weights and Biases](https://docs.wandb.ai/quickstart) for visualizing training plots. This is optional but if you use it, make sure you are logged in by running `wandb login`.
|
||||
|
||||
Training should take several hours, 100k steps (which is the default) will take about 1h on Nvidia A100. You will find checkpoints in `outputs/train/il_sim_test/checkpoints`.
|
||||
|
||||
#### Train using Collab
|
||||
|
||||
If your local computer doesn't have a powerful GPU you could utilize Google Collab to train your model by following the [ACT training notebook](./notebooks#training-act).
|
||||
|
||||
#### Upload policy checkpoints
|
||||
|
||||
Once training is done, upload the latest checkpoint with:
|
||||
|
||||
```bash
|
||||
huggingface-cli upload ${HF_USER}/il_sim_test \
|
||||
outputs/train/il_sim_test/checkpoints/last/pretrained_model
|
||||
```
|
||||
|
||||
You can also upload intermediate checkpoints with:
|
||||
|
||||
```bash
|
||||
CKPT=010000
|
||||
huggingface-cli upload ${HF_USER}/il_sim_test${CKPT} \
|
||||
outputs/train/il_sim_test/checkpoints/${CKPT}/pretrained_model
|
||||
```
|
||||
|
||||
## Evaluate your policy in Sim
|
||||
|
||||
To evaluate your policy we have to use a configuration file. An example can be found [here](https://huggingface.co/datasets/lerobot/config_examples/resolve/main/sim_il/eval_config.json).
|
||||
|
||||
Here's an example evaluation configuration:
|
||||
|
||||
```json
|
||||
{
|
||||
"env": {
|
||||
"type": "gym_manipulator",
|
||||
"name": "gym_hil",
|
||||
"task": "PandaPickCubeGamepad-v0",
|
||||
"fps": 10
|
||||
},
|
||||
"dataset": {
|
||||
"repo_id": "your_username/il_sim_dataset",
|
||||
"dataset_root": null,
|
||||
"task": "pick_cube"
|
||||
},
|
||||
"pretrained_policy_name_or_path": "your_username/il_sim_model",
|
||||
"device": "cuda"
|
||||
}
|
||||
```
|
||||
|
||||
Make sure to replace:
|
||||
|
||||
- `repo_id` with the dataset you trained on (e.g., `your_username/il_sim_dataset`)
|
||||
- `pretrained_policy_name_or_path` with your model ID (e.g., `your_username/il_sim_model`)
|
||||
|
||||
Then you can run this command to visualize your trained policy
|
||||
|
||||
<hfoptions id="eval_policy">
|
||||
<hfoption id="Linux">
|
||||
|
||||
```bash
|
||||
python -m lerobot.rl.eval_policy --config_path=path/to/eval_config_gym_hil.json
|
||||
```
|
||||
|
||||
</hfoption>
|
||||
<hfoption id="MacOS">
|
||||
|
||||
```bash
|
||||
mjpython -m lerobot.rl.eval_policy --config_path=path/to/eval_config_gym_hil.json
|
||||
```
|
||||
|
||||
</hfoption>
|
||||
</hfoptions>
|
||||
|
||||
> [!WARNING]
|
||||
> While the main workflow of training ACT in simulation is straightforward, there is significant room for exploring how to set up the task, define the initial state of the environment, and determine the type of data required during collection to learn the most effective policy. If your trained policy doesn't perform well, investigate the quality of the dataset it was trained on using our visualizers, as well as the action values and various hyperparameters related to ACT and the simulation.
|
||||
|
||||
Congrats 🎉, you have finished this tutorial. If you want to continue with using LeRobot in simulation follow this [Tutorial on reinforcement learning in sim with HIL-SERL](https://huggingface.co/docs/lerobot/hilserl_sim)
|
||||
|
||||
> [!TIP]
|
||||
> If you have any questions or need help, please reach out on [Discord](https://discord.com/invite/s3KuuzsPFb).
|
||||
@@ -82,7 +82,7 @@ For a full list of optional dependencies, see:
|
||||
https://pypi.org/project/lerobot/
|
||||
|
||||
> [!NOTE]
|
||||
> For lerobot 0.4.0, if you want to install libero or pi, you will have to do: `pip install "lerobot[pi,libero]@git+https://github.com/huggingface/lerobot.git"`
|
||||
> For lerobot 0.4.0, if you want to install pi, you will have to do: `pip install "lerobot[pi]@git+https://github.com/huggingface/lerobot.git"`
|
||||
|
||||
### Troubleshooting
|
||||
|
||||
|
||||
@@ -28,11 +28,6 @@ LIBERO is now part of our **multi-eval supported simulation**, meaning you can b
|
||||
To Install LIBERO, after following LeRobot official instructions, just do:
|
||||
`pip install -e ".[libero]"`
|
||||
|
||||
> [!NOTE]
|
||||
> For lerobot 0.4.0, if you want to install libero tag, you will have to do: `pip install "lerobot[libero]@git+https://github.com/huggingface/lerobot.git"`.
|
||||
>
|
||||
> This will be solved in the next patch release
|
||||
|
||||
### Single-suite evaluation
|
||||
|
||||
Evaluate a policy on one LIBERO suite:
|
||||
|
||||
@@ -0,0 +1,188 @@
|
||||
# Real-Time Chunking (RTC)
|
||||
|
||||
Real-Time Chunking (RTC) is an inference-time method that allows large, flow-matching based robotic policies, such as [Pi0](./pi0), [Pi0.5](./pi05), and [SmolVLA](./smolvla), to produce smooth, continuous, and reactive motion despite having high inference latency.
|
||||
|
||||
These policies generate chunks of future actions (e.g., 50 steps at a time) instead of single actions.
|
||||
Because the models are large, producing each chunk takes longer than the time it takes the robot to execute it.
|
||||
Naively executing chunks leads to problems such as pauses, jerky transitions, or sudden changes in strategy whenever the next chunk arrives late or disagrees with the previously executed actions.
|
||||
|
||||
RTC solves this by asynchronously generating the next chunk while the robot continues executing the current one, and by guiding the new chunk so it aligns smoothly with the portion of the previous chunk that has already been executed.
|
||||
|
||||
## How RTC Works (simplified)
|
||||
|
||||
RTC lets the robot think ahead while it’s still moving. When the robot is carrying out one chunk of actions, RTC starts creating the next chunk early.
|
||||
But since the robot has already moved a bit by the time the new chunk is ready, RTC has to make sure the new chunk still lines up smoothly with what the robot is currently doing.
|
||||
|
||||
To do this, RTC treats the beginning of the new chunk like an inpainting or “fill-in-the-gaps” problem:
|
||||
it gently adjusts the first part of the new chunk so it blends naturally with the robot’s ongoing motion. The result is no pauses, no sudden jumps.
|
||||
|
||||
In technical terms, RTC adds a guidance term to the flow-matching denoising process that forces the overlapping timesteps of the new chunk to stay close to the executed portion of the previous chunk, typically using a soft transition mask.
|
||||
|
||||
## Quick Start
|
||||
|
||||
### Installation
|
||||
|
||||
RTC is built into LeRobot. Just install the policy dependencies you need:
|
||||
|
||||
```bash
|
||||
# For Pi0 or Pi0.5
|
||||
pip install -e ".[pi]"
|
||||
|
||||
# For SmolVLA
|
||||
pip install -e ".[smolvla]"
|
||||
```
|
||||
|
||||
### Using RTC with Pi0
|
||||
|
||||
You can find a complete reference implementation in [eval_with_real_robot.py](examples/rtc/eval_with_real_robot.py).
|
||||
The snippet below provides a simplified pseudo-example of how RTC operates with Pi0 in your pipeline:
|
||||
|
||||
```python
|
||||
from lerobot.policies.pi0 import PI0Policy, PI0Config
|
||||
from lerobot.configs.types import RTCAttentionSchedule
|
||||
from lerobot.policies.rtc.configuration_rtc import RTCConfig
|
||||
from lerobot.policies.rtc.action_queue import ActionQueue
|
||||
|
||||
# Load Pi0 with RTC enabled
|
||||
policy_cfg = PI0Config()
|
||||
|
||||
# Enable RTC
|
||||
policy_cfg.rtc_config = RTCConfig(
|
||||
enabled=True,
|
||||
execution_horizon=10, # How many steps to blend with previous chunk
|
||||
max_guidance_weight=10.0, # How strongly to enforce consistency
|
||||
prefix_attention_schedule=RTCAttentionSchedule.EXP, # Exponential blend
|
||||
)
|
||||
|
||||
# Load the policy
|
||||
policy = PI0Policy.from_pretrained("lerobot/pi0_base", policy_cfg=policy_cfg, device="cuda")
|
||||
|
||||
# Now use predict_action_chunk with RTC parameters
|
||||
inference_delay = 4 # How many steps of inference latency, this values should be calculated based on the inference latency of the policy
|
||||
|
||||
# Initialize the action queue
|
||||
action_queue = ActionQueue(policy_cfg.rtc_config)
|
||||
|
||||
# Start in a separate thread with the following function
|
||||
def get_actions():
|
||||
while True:
|
||||
if should_get_actions:
|
||||
|
||||
prev_actions = action_queue.get_left_over()
|
||||
obs = get_robot_observations(robot)
|
||||
|
||||
# Generate actions WITH RTC
|
||||
actions = policy.predict_action_chunk(
|
||||
obs,
|
||||
inference_delay=inference_delay,
|
||||
prev_chunk_left_over=prev_actions,
|
||||
)
|
||||
|
||||
action_queue.merge(
|
||||
actions, actions, inference_delay
|
||||
)
|
||||
|
||||
for step in range(num_steps):
|
||||
action = action_queue.get()
|
||||
|
||||
# Execute the first N actions
|
||||
execute_actions(action)
|
||||
```
|
||||
|
||||
## Key Parameters
|
||||
|
||||
`RTCConfig` has the following parameters to tune:
|
||||
|
||||
**`execution_horizon`**: How many timesteps from the previous chunk to maintain consistency with. Higher values mean smoother transitions but potentially less reactivity.
|
||||
|
||||
Typical values: 8-12 steps
|
||||
|
||||
```python
|
||||
RTCConfig(execution_horizon=10)
|
||||
```
|
||||
|
||||
**`max_guidance_weight`**: How strongly to enforce consistency with the previous chunk. This is a hyperparameter that can be tuned to balance the smoothness of the transitions and the reactivity of the policy. For 10 steps flow matching (SmolVLA, Pi0, Pi0.5), a value of 10.0 is a optimal value.
|
||||
|
||||
**`prefix_attention_schedule`**: How to weight consistency across the overlap region.
|
||||
|
||||
- `LINEAR`: Linear decay from inference_delay to execution_horizon
|
||||
- `EXP`: Exponential decay (recommended for getting started)
|
||||
- `ONES`: Full weight across entire execution_horizon
|
||||
- `ZEROS`: Binary (full weight up to inference_delay, then zero)
|
||||
|
||||
**`inference_delay`**: How many timesteps of inference latency your system has. This is passed to `predict_action_chunk()` rather than the config, since it may vary at runtime.
|
||||
|
||||
## Testing RTC Offline
|
||||
|
||||
Before running on a real robot, test RTC with dataset samples to visualize how it works:
|
||||
|
||||
```bash
|
||||
python examples/rtc/eval_dataset.py \
|
||||
--policy.path=lerobot/pi0_libero_finetuned \
|
||||
--dataset.repo_id=HuggingFaceVLA/libero \
|
||||
--rtc.execution_horizon=10 \
|
||||
--rtc.max_guidance_weight=10.0 \
|
||||
--device=cuda
|
||||
```
|
||||
|
||||
The script generates a visualization of the denoising process, comparing standard generation (left) with RTC (right). In the RTC plots, you can see how the first few steps (blue/purple lines) are guided to match the red ground truth trajectory (previous chunk's tail), ensuring a smooth transition between chunks.
|
||||
|
||||
<p align="center">
|
||||
<img
|
||||
src="https://huggingface.co/datasets/huggingface/documentation-images/resolve/main/lerobot/flow_matching.png"
|
||||
alt="Denoising steps with and without RTC"
|
||||
width="100%"
|
||||
/>
|
||||
</p>
|
||||
|
||||
## Testing RTC with a Real Robot
|
||||
|
||||
```bash
|
||||
python examples/rtc/eval_with_real_robot.py \
|
||||
--policy.path=${HF_USERNAME}/policy_repo_id \
|
||||
--robot.type=so100_follower \
|
||||
--robot.port=/dev/tty.usbmodem58FA0834591 \
|
||||
--robot.cameras="{ gripper: {type: opencv, index_or_path: 1, width: 640, height: 480, fps: 30}, front: {type: opencv, index_or_path: 0, width: 640, height: 480, fps: 30}}" \
|
||||
--task="Move green small object into the purple platform" \
|
||||
--duration=120 \
|
||||
--device=cuda
|
||||
```
|
||||
|
||||
## How It Differs from the Async Inference in LeRobot
|
||||
|
||||
Both RTC and [async inference](./async) improve real-time robot control, but they solve different problems.
|
||||
|
||||
| Aspect | Async Inference | RTC |
|
||||
| ------------- | -------------------------------------------------------------------------- | --------------------------------------------------- |
|
||||
| **Problem** | Idle frames while waiting for inference | Discontinuities between action chunks |
|
||||
| **Solution** | Decouple prediction from execution | Guide new chunks to continue smoothly from previous |
|
||||
| **Benefit** | No waiting, continuous action | Smooth transitions, natural motion |
|
||||
| **Best Used** | Async inference is best used with large models with high inference latency | Flow-matching based policies |
|
||||
|
||||
**Use both together** for maximum smoothness and reactivity!
|
||||
|
||||
## Advanced: Debug Tracking
|
||||
|
||||
RTC includes built-in debug tracking to help you understand what's happening during inference:
|
||||
|
||||
```python
|
||||
# Enable debug tracking
|
||||
policy_cfg.rtc_config.debug = True
|
||||
policy_cfg.rtc_config.debug_maxlen = 100
|
||||
|
||||
# After inference, access debug data
|
||||
debug_data = policy.rtc_processor.get_debug_data()
|
||||
|
||||
# Visualize denoising steps, corrections, etc.
|
||||
from lerobot.policies.rtc.debug_visualizer import RTCDebugVisualizer
|
||||
visualizer = RTCDebugVisualizer()
|
||||
# ... create plots
|
||||
```
|
||||
|
||||
See `examples/rtc/eval_dataset.py` for a complete example of visualization.
|
||||
|
||||
## References
|
||||
|
||||
- [Smooth-As-Butter Robot Policies](https://alexander-soare.github.io/robotics/2025/08/05/smooth-as-butter-robot-policies.html) - Excellent technical explanation with real robot results
|
||||
- [Physical Intelligence - Real-Time Chunking](https://www.physicalintelligence.company/research/real_time_chunking) - Original paper and research
|
||||
- [Kinetix RTC Implementation](https://github.com/Physical-Intelligence/real-time-chunking-kinetix) - Reference implementation from Physical Intelligence
|
||||
@@ -45,7 +45,7 @@ from lerobot.robots import ( # noqa: F401
|
||||
so101_follower,
|
||||
)
|
||||
from lerobot.utils.constants import ACTION
|
||||
from lerobot.utils.robot_utils import busy_wait
|
||||
from lerobot.utils.robot_utils import precise_sleep
|
||||
from lerobot.utils.utils import (
|
||||
init_logging,
|
||||
log_say,
|
||||
@@ -97,7 +97,7 @@ def replay(cfg: ReplayConfig):
|
||||
robot.send_action(action)
|
||||
|
||||
dt_s = time.perf_counter() - start_episode_t
|
||||
busy_wait(1 / dataset.fps - dt_s)
|
||||
precise_sleep(1 / dataset.fps - dt_s)
|
||||
|
||||
robot.disconnect()
|
||||
|
||||
|
||||
@@ -34,105 +34,106 @@ from huggingface_hub import HfApi
|
||||
import lerobot
|
||||
from lerobot.datasets.lerobot_dataset import LeRobotDataset, LeRobotDatasetMetadata
|
||||
|
||||
# We ported a number of existing datasets ourselves, use this to see the list:
|
||||
print("List of available datasets:")
|
||||
pprint(lerobot.available_datasets)
|
||||
|
||||
# You can also browse through the datasets created/ported by the community on the hub using the hub api:
|
||||
hub_api = HfApi()
|
||||
repo_ids = [info.id for info in hub_api.list_datasets(task_categories="robotics", tags=["LeRobot"])]
|
||||
pprint(repo_ids)
|
||||
def main():
|
||||
# We ported a number of existing datasets ourselves, use this to see the list:
|
||||
print("List of available datasets:")
|
||||
pprint(lerobot.available_datasets)
|
||||
|
||||
# Or simply explore them in your web browser directly at:
|
||||
# https://huggingface.co/datasets?other=LeRobot
|
||||
# You can also browse through the datasets created/ported by the community on the hub using the hub api:
|
||||
hub_api = HfApi()
|
||||
repo_ids = [info.id for info in hub_api.list_datasets(task_categories="robotics", tags=["LeRobot"])]
|
||||
pprint(repo_ids)
|
||||
|
||||
# Let's take this one for this example
|
||||
repo_id = "lerobot/aloha_mobile_cabinet"
|
||||
# We can have a look and fetch its metadata to know more about it:
|
||||
ds_meta = LeRobotDatasetMetadata(repo_id)
|
||||
# Or simply explore them in your web browser directly at:
|
||||
# https://huggingface.co/datasets?other=LeRobot
|
||||
|
||||
# By instantiating just this class, you can quickly access useful information about the content and the
|
||||
# structure of the dataset without downloading the actual data yet (only metadata files — which are
|
||||
# lightweight).
|
||||
print(f"Total number of episodes: {ds_meta.total_episodes}")
|
||||
print(f"Average number of frames per episode: {ds_meta.total_frames / ds_meta.total_episodes:.3f}")
|
||||
print(f"Frames per second used during data collection: {ds_meta.fps}")
|
||||
print(f"Robot type: {ds_meta.robot_type}")
|
||||
print(f"keys to access images from cameras: {ds_meta.camera_keys=}\n")
|
||||
# Let's take this one for this example
|
||||
repo_id = "lerobot/aloha_mobile_cabinet"
|
||||
# We can have a look and fetch its metadata to know more about it:
|
||||
ds_meta = LeRobotDatasetMetadata(repo_id)
|
||||
|
||||
print("Tasks:")
|
||||
print(ds_meta.tasks)
|
||||
print("Features:")
|
||||
pprint(ds_meta.features)
|
||||
# By instantiating just this class, you can quickly access useful information about the content and the
|
||||
# structure of the dataset without downloading the actual data yet (only metadata files — which are
|
||||
# lightweight).
|
||||
print(f"Total number of episodes: {ds_meta.total_episodes}")
|
||||
print(f"Average number of frames per episode: {ds_meta.total_frames / ds_meta.total_episodes:.3f}")
|
||||
print(f"Frames per second used during data collection: {ds_meta.fps}")
|
||||
print(f"Robot type: {ds_meta.robot_type}")
|
||||
print(f"keys to access images from cameras: {ds_meta.camera_keys=}\n")
|
||||
|
||||
# You can also get a short summary by simply printing the object:
|
||||
print(ds_meta)
|
||||
print("Tasks:")
|
||||
print(ds_meta.tasks)
|
||||
print("Features:")
|
||||
pprint(ds_meta.features)
|
||||
|
||||
# You can then load the actual dataset from the hub.
|
||||
# Either load any subset of episodes:
|
||||
dataset = LeRobotDataset(repo_id, episodes=[0, 10, 11, 23])
|
||||
# You can also get a short summary by simply printing the object:
|
||||
print(ds_meta)
|
||||
|
||||
# And see how many frames you have:
|
||||
print(f"Selected episodes: {dataset.episodes}")
|
||||
print(f"Number of episodes selected: {dataset.num_episodes}")
|
||||
print(f"Number of frames selected: {dataset.num_frames}")
|
||||
# You can then load the actual dataset from the hub.
|
||||
# Either load any subset of episodes:
|
||||
dataset = LeRobotDataset(repo_id, episodes=[0, 10, 11, 23])
|
||||
|
||||
# Or simply load the entire dataset:
|
||||
dataset = LeRobotDataset(repo_id)
|
||||
print(f"Number of episodes selected: {dataset.num_episodes}")
|
||||
print(f"Number of frames selected: {dataset.num_frames}")
|
||||
# And see how many frames you have:
|
||||
print(f"Selected episodes: {dataset.episodes}")
|
||||
print(f"Number of episodes selected: {dataset.num_episodes}")
|
||||
print(f"Number of frames selected: {dataset.num_frames}")
|
||||
|
||||
# The previous metadata class is contained in the 'meta' attribute of the dataset:
|
||||
print(dataset.meta)
|
||||
# Or simply load the entire dataset:
|
||||
dataset = LeRobotDataset(repo_id)
|
||||
print(f"Number of episodes selected: {dataset.num_episodes}")
|
||||
print(f"Number of frames selected: {dataset.num_frames}")
|
||||
|
||||
# LeRobotDataset actually wraps an underlying Hugging Face dataset
|
||||
# (see https://huggingface.co/docs/datasets for more information).
|
||||
print(dataset.hf_dataset)
|
||||
# The previous metadata class is contained in the 'meta' attribute of the dataset:
|
||||
print(dataset.meta)
|
||||
|
||||
# LeRobot datasets also subclasses PyTorch datasets so you can do everything you know and love from working
|
||||
# with the latter, like iterating through the dataset.
|
||||
# The __getitem__ iterates over the frames of the dataset. Since our datasets are also structured by
|
||||
# episodes, you can access the frame indices of any episode using dataset.meta.episodes. Here, we access
|
||||
# frame indices associated to the first episode:
|
||||
episode_index = 0
|
||||
from_idx = dataset.meta.episodes["dataset_from_index"][episode_index]
|
||||
to_idx = dataset.meta.episodes["dataset_to_index"][episode_index]
|
||||
# LeRobotDataset actually wraps an underlying Hugging Face dataset
|
||||
# (see https://huggingface.co/docs/datasets for more information).
|
||||
print(dataset.hf_dataset)
|
||||
|
||||
# Then we grab all the image frames from the first camera:
|
||||
camera_key = dataset.meta.camera_keys[0]
|
||||
frames = [dataset[idx][camera_key] for idx in range(from_idx, to_idx)]
|
||||
# LeRobot datasets also subclasses PyTorch datasets so you can do everything you know and love from working
|
||||
# with the latter, like iterating through the dataset.
|
||||
# The __getitem__ iterates over the frames of the dataset. Since our datasets are also structured by
|
||||
# episodes, you can access the frame indices of any episode using dataset.meta.episodes. Here, we access
|
||||
# frame indices associated to the first episode:
|
||||
episode_index = 0
|
||||
from_idx = dataset.meta.episodes["dataset_from_index"][episode_index]
|
||||
to_idx = dataset.meta.episodes["dataset_to_index"][episode_index]
|
||||
|
||||
# The objects returned by the dataset are all torch.Tensors
|
||||
print(type(frames[0]))
|
||||
print(frames[0].shape)
|
||||
# Then we grab all the image frames from the first camera:
|
||||
camera_key = dataset.meta.camera_keys[0]
|
||||
frames = [dataset[idx][camera_key] for idx in range(from_idx, to_idx)]
|
||||
|
||||
# Since we're using pytorch, the shape is in pytorch, channel-first convention (c, h, w).
|
||||
# We can compare this shape with the information available for that feature
|
||||
pprint(dataset.features[camera_key])
|
||||
# In particular:
|
||||
print(dataset.features[camera_key]["shape"])
|
||||
# The shape is in (h, w, c) which is a more universal format.
|
||||
# The objects returned by the dataset are all torch.Tensors
|
||||
print(type(frames[0]))
|
||||
print(frames[0].shape)
|
||||
|
||||
# For many machine learning applications we need to load the history of past observations or trajectories of
|
||||
# future actions. Our datasets can load previous and future frames for each key/modality, using timestamps
|
||||
# differences with the current loaded frame. For instance:
|
||||
delta_timestamps = {
|
||||
# loads 4 images: 1 second before current frame, 500 ms before, 200 ms before, and current frame
|
||||
camera_key: [-1, -0.5, -0.20, 0],
|
||||
# loads 6 state vectors: 1.5 seconds before, 1 second before, ... 200 ms, 100 ms, and current frame
|
||||
"observation.state": [-1.5, -1, -0.5, -0.20, -0.10, 0],
|
||||
# loads 64 action vectors: current frame, 1 frame in the future, 2 frames, ... 63 frames in the future
|
||||
"action": [t / dataset.fps for t in range(64)],
|
||||
}
|
||||
# Note that in any case, these delta_timestamps values need to be multiples of (1/fps) so that added to any
|
||||
# timestamp, you still get a valid timestamp.
|
||||
# Since we're using pytorch, the shape is in pytorch, channel-first convention (c, h, w).
|
||||
# We can compare this shape with the information available for that feature
|
||||
pprint(dataset.features[camera_key])
|
||||
# In particular:
|
||||
print(dataset.features[camera_key]["shape"])
|
||||
# The shape is in (h, w, c) which is a more universal format.
|
||||
|
||||
dataset = LeRobotDataset(repo_id, delta_timestamps=delta_timestamps)
|
||||
print(f"\n{dataset[0][camera_key].shape=}") # (4, c, h, w)
|
||||
print(f"{dataset[0]['observation.state'].shape=}") # (6, c)
|
||||
print(f"{dataset[0]['action'].shape=}\n") # (64, c)
|
||||
# For many machine learning applications we need to load the history of past observations or trajectories of
|
||||
# future actions. Our datasets can load previous and future frames for each key/modality, using timestamps
|
||||
# differences with the current loaded frame. For instance:
|
||||
delta_timestamps = {
|
||||
# loads 4 images: 1 second before current frame, 500 ms before, 200 ms before, and current frame
|
||||
camera_key: [-1, -0.5, -0.20, 0],
|
||||
# loads 6 state vectors: 1.5 seconds before, 1 second before, ... 200 ms, 100 ms, and current frame
|
||||
"observation.state": [-1.5, -1, -0.5, -0.20, -0.10, 0],
|
||||
# loads 64 action vectors: current frame, 1 frame in the future, 2 frames, ... 63 frames in the future
|
||||
"action": [t / dataset.fps for t in range(64)],
|
||||
}
|
||||
# Note that in any case, these delta_timestamps values need to be multiples of (1/fps) so that added to any
|
||||
# timestamp, you still get a valid timestamp.
|
||||
|
||||
dataset = LeRobotDataset(repo_id, delta_timestamps=delta_timestamps)
|
||||
print(f"\n{dataset[0][camera_key].shape=}") # (4, c, h, w)
|
||||
print(f"{dataset[0]['observation.state'].shape=}") # (6, c)
|
||||
print(f"{dataset[0]['action'].shape=}\n") # (64, c)
|
||||
|
||||
if __name__ == "__main__":
|
||||
dataloader = torch.utils.data.DataLoader(
|
||||
dataset,
|
||||
num_workers=4,
|
||||
@@ -144,3 +145,7 @@ if __name__ == "__main__":
|
||||
print(f"{batch['observation.state'].shape=}") # (32, 6, c)
|
||||
print(f"{batch['action'].shape=}") # (32, 64, c)
|
||||
break
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
main()
|
||||
|
||||
+86
-80
@@ -33,83 +33,68 @@ TASK_DESCRIPTION = "My task description"
|
||||
HF_MODEL_ID = "<hf_username>/<model_repo_id>"
|
||||
HF_DATASET_ID = "<hf_username>/<eval_dataset_repo_id>"
|
||||
|
||||
# Create the robot configuration & robot
|
||||
robot_config = LeKiwiClientConfig(remote_ip="172.18.134.136", id="lekiwi")
|
||||
|
||||
robot = LeKiwiClient(robot_config)
|
||||
def main():
|
||||
# Create the robot configuration & robot
|
||||
robot_config = LeKiwiClientConfig(remote_ip="172.18.134.136", id="lekiwi")
|
||||
|
||||
# Create policy
|
||||
policy = ACTPolicy.from_pretrained(HF_MODEL_ID)
|
||||
robot = LeKiwiClient(robot_config)
|
||||
|
||||
# Configure the dataset features
|
||||
action_features = hw_to_dataset_features(robot.action_features, ACTION)
|
||||
obs_features = hw_to_dataset_features(robot.observation_features, OBS_STR)
|
||||
dataset_features = {**action_features, **obs_features}
|
||||
# Create policy
|
||||
policy = ACTPolicy.from_pretrained(HF_MODEL_ID)
|
||||
|
||||
# Create the dataset
|
||||
dataset = LeRobotDataset.create(
|
||||
repo_id=HF_DATASET_ID,
|
||||
fps=FPS,
|
||||
features=dataset_features,
|
||||
robot_type=robot.name,
|
||||
use_videos=True,
|
||||
image_writer_threads=4,
|
||||
)
|
||||
# Configure the dataset features
|
||||
action_features = hw_to_dataset_features(robot.action_features, ACTION)
|
||||
obs_features = hw_to_dataset_features(robot.observation_features, OBS_STR)
|
||||
dataset_features = {**action_features, **obs_features}
|
||||
|
||||
# Build Policy Processors
|
||||
preprocessor, postprocessor = make_pre_post_processors(
|
||||
policy_cfg=policy,
|
||||
pretrained_path=HF_MODEL_ID,
|
||||
dataset_stats=dataset.meta.stats,
|
||||
# The inference device is automatically set to match the detected hardware, overriding any previous device settings from training to ensure compatibility.
|
||||
preprocessor_overrides={"device_processor": {"device": str(policy.config.device)}},
|
||||
)
|
||||
|
||||
# Connect the robot
|
||||
# To connect you already should have this script running on LeKiwi: `python -m lerobot.robots.lekiwi.lekiwi_host --robot.id=my_awesome_kiwi`
|
||||
robot.connect()
|
||||
|
||||
# TODO(Steven): Update this example to use pipelines
|
||||
teleop_action_processor, robot_action_processor, robot_observation_processor = make_default_processors()
|
||||
|
||||
# Initialize the keyboard listener and rerun visualization
|
||||
listener, events = init_keyboard_listener()
|
||||
init_rerun(session_name="lekiwi_evaluate")
|
||||
|
||||
if not robot.is_connected:
|
||||
raise ValueError("Robot is not connected!")
|
||||
|
||||
print("Starting evaluate loop...")
|
||||
recorded_episodes = 0
|
||||
while recorded_episodes < NUM_EPISODES and not events["stop_recording"]:
|
||||
log_say(f"Running inference, recording eval episode {recorded_episodes} of {NUM_EPISODES}")
|
||||
|
||||
# Main record loop
|
||||
record_loop(
|
||||
robot=robot,
|
||||
events=events,
|
||||
# Create the dataset
|
||||
dataset = LeRobotDataset.create(
|
||||
repo_id=HF_DATASET_ID,
|
||||
fps=FPS,
|
||||
policy=policy,
|
||||
preprocessor=preprocessor, # Pass the pre and post policy processors
|
||||
postprocessor=postprocessor,
|
||||
dataset=dataset,
|
||||
control_time_s=EPISODE_TIME_SEC,
|
||||
single_task=TASK_DESCRIPTION,
|
||||
display_data=True,
|
||||
teleop_action_processor=teleop_action_processor,
|
||||
robot_action_processor=robot_action_processor,
|
||||
robot_observation_processor=robot_observation_processor,
|
||||
features=dataset_features,
|
||||
robot_type=robot.name,
|
||||
use_videos=True,
|
||||
image_writer_threads=4,
|
||||
)
|
||||
|
||||
# Reset the environment if not stopping or re-recording
|
||||
if not events["stop_recording"] and (
|
||||
(recorded_episodes < NUM_EPISODES - 1) or events["rerecord_episode"]
|
||||
):
|
||||
log_say("Reset the environment")
|
||||
# Build Policy Processors
|
||||
preprocessor, postprocessor = make_pre_post_processors(
|
||||
policy_cfg=policy,
|
||||
pretrained_path=HF_MODEL_ID,
|
||||
dataset_stats=dataset.meta.stats,
|
||||
# The inference device is automatically set to match the detected hardware, overriding any previous device settings from training to ensure compatibility.
|
||||
preprocessor_overrides={"device_processor": {"device": str(policy.config.device)}},
|
||||
)
|
||||
|
||||
# Connect the robot
|
||||
# To connect you already should have this script running on LeKiwi: `python -m lerobot.robots.lekiwi.lekiwi_host --robot.id=my_awesome_kiwi`
|
||||
robot.connect()
|
||||
|
||||
# TODO(Steven): Update this example to use pipelines
|
||||
teleop_action_processor, robot_action_processor, robot_observation_processor = make_default_processors()
|
||||
|
||||
# Initialize the keyboard listener and rerun visualization
|
||||
listener, events = init_keyboard_listener()
|
||||
init_rerun(session_name="lekiwi_evaluate")
|
||||
|
||||
if not robot.is_connected:
|
||||
raise ValueError("Robot is not connected!")
|
||||
|
||||
print("Starting evaluate loop...")
|
||||
recorded_episodes = 0
|
||||
while recorded_episodes < NUM_EPISODES and not events["stop_recording"]:
|
||||
log_say(f"Running inference, recording eval episode {recorded_episodes} of {NUM_EPISODES}")
|
||||
|
||||
# Main record loop
|
||||
record_loop(
|
||||
robot=robot,
|
||||
events=events,
|
||||
fps=FPS,
|
||||
policy=policy,
|
||||
preprocessor=preprocessor, # Pass the pre and post policy processors
|
||||
postprocessor=postprocessor,
|
||||
dataset=dataset,
|
||||
control_time_s=EPISODE_TIME_SEC,
|
||||
single_task=TASK_DESCRIPTION,
|
||||
display_data=True,
|
||||
@@ -118,21 +103,42 @@ while recorded_episodes < NUM_EPISODES and not events["stop_recording"]:
|
||||
robot_observation_processor=robot_observation_processor,
|
||||
)
|
||||
|
||||
if events["rerecord_episode"]:
|
||||
log_say("Re-record episode")
|
||||
events["rerecord_episode"] = False
|
||||
events["exit_early"] = False
|
||||
dataset.clear_episode_buffer()
|
||||
continue
|
||||
# Reset the environment if not stopping or re-recording
|
||||
if not events["stop_recording"] and (
|
||||
(recorded_episodes < NUM_EPISODES - 1) or events["rerecord_episode"]
|
||||
):
|
||||
log_say("Reset the environment")
|
||||
record_loop(
|
||||
robot=robot,
|
||||
events=events,
|
||||
fps=FPS,
|
||||
control_time_s=EPISODE_TIME_SEC,
|
||||
single_task=TASK_DESCRIPTION,
|
||||
display_data=True,
|
||||
teleop_action_processor=teleop_action_processor,
|
||||
robot_action_processor=robot_action_processor,
|
||||
robot_observation_processor=robot_observation_processor,
|
||||
)
|
||||
|
||||
# Save episode
|
||||
dataset.save_episode()
|
||||
recorded_episodes += 1
|
||||
if events["rerecord_episode"]:
|
||||
log_say("Re-record episode")
|
||||
events["rerecord_episode"] = False
|
||||
events["exit_early"] = False
|
||||
dataset.clear_episode_buffer()
|
||||
continue
|
||||
|
||||
# Clean up
|
||||
log_say("Stop recording")
|
||||
robot.disconnect()
|
||||
listener.stop()
|
||||
# Save episode
|
||||
dataset.save_episode()
|
||||
recorded_episodes += 1
|
||||
|
||||
dataset.finalize()
|
||||
dataset.push_to_hub()
|
||||
# Clean up
|
||||
log_say("Stop recording")
|
||||
robot.disconnect()
|
||||
listener.stop()
|
||||
|
||||
dataset.finalize()
|
||||
dataset.push_to_hub()
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
main()
|
||||
|
||||
+82
-76
@@ -34,78 +34,62 @@ RESET_TIME_SEC = 10
|
||||
TASK_DESCRIPTION = "My task description"
|
||||
HF_REPO_ID = "<hf_username>/<dataset_repo_id>"
|
||||
|
||||
# Create the robot and teleoperator configurations
|
||||
robot_config = LeKiwiClientConfig(remote_ip="172.18.134.136", id="lekiwi")
|
||||
leader_arm_config = SO100LeaderConfig(port="/dev/tty.usbmodem585A0077581", id="my_awesome_leader_arm")
|
||||
keyboard_config = KeyboardTeleopConfig()
|
||||
|
||||
# Initialize the robot and teleoperator
|
||||
robot = LeKiwiClient(robot_config)
|
||||
leader_arm = SO100Leader(leader_arm_config)
|
||||
keyboard = KeyboardTeleop(keyboard_config)
|
||||
def main():
|
||||
# Create the robot and teleoperator configurations
|
||||
robot_config = LeKiwiClientConfig(remote_ip="172.18.134.136", id="lekiwi")
|
||||
leader_arm_config = SO100LeaderConfig(port="/dev/tty.usbmodem585A0077581", id="my_awesome_leader_arm")
|
||||
keyboard_config = KeyboardTeleopConfig()
|
||||
|
||||
# TODO(Steven): Update this example to use pipelines
|
||||
teleop_action_processor, robot_action_processor, robot_observation_processor = make_default_processors()
|
||||
# Initialize the robot and teleoperator
|
||||
robot = LeKiwiClient(robot_config)
|
||||
leader_arm = SO100Leader(leader_arm_config)
|
||||
keyboard = KeyboardTeleop(keyboard_config)
|
||||
|
||||
# Configure the dataset features
|
||||
action_features = hw_to_dataset_features(robot.action_features, ACTION)
|
||||
obs_features = hw_to_dataset_features(robot.observation_features, OBS_STR)
|
||||
dataset_features = {**action_features, **obs_features}
|
||||
# TODO(Steven): Update this example to use pipelines
|
||||
teleop_action_processor, robot_action_processor, robot_observation_processor = make_default_processors()
|
||||
|
||||
# Create the dataset
|
||||
dataset = LeRobotDataset.create(
|
||||
repo_id=HF_REPO_ID,
|
||||
fps=FPS,
|
||||
features=dataset_features,
|
||||
robot_type=robot.name,
|
||||
use_videos=True,
|
||||
image_writer_threads=4,
|
||||
)
|
||||
# Configure the dataset features
|
||||
action_features = hw_to_dataset_features(robot.action_features, ACTION)
|
||||
obs_features = hw_to_dataset_features(robot.observation_features, OBS_STR)
|
||||
dataset_features = {**action_features, **obs_features}
|
||||
|
||||
# Connect the robot and teleoperator
|
||||
# To connect you already should have this script running on LeKiwi: `python -m lerobot.robots.lekiwi.lekiwi_host --robot.id=my_awesome_kiwi`
|
||||
robot.connect()
|
||||
leader_arm.connect()
|
||||
keyboard.connect()
|
||||
|
||||
# Initialize the keyboard listener and rerun visualization
|
||||
listener, events = init_keyboard_listener()
|
||||
init_rerun(session_name="lekiwi_record")
|
||||
|
||||
if not robot.is_connected or not leader_arm.is_connected or not keyboard.is_connected:
|
||||
raise ValueError("Robot or teleop is not connected!")
|
||||
|
||||
print("Starting record loop...")
|
||||
recorded_episodes = 0
|
||||
while recorded_episodes < NUM_EPISODES and not events["stop_recording"]:
|
||||
log_say(f"Recording episode {recorded_episodes}")
|
||||
|
||||
# Main record loop
|
||||
record_loop(
|
||||
robot=robot,
|
||||
events=events,
|
||||
# Create the dataset
|
||||
dataset = LeRobotDataset.create(
|
||||
repo_id=HF_REPO_ID,
|
||||
fps=FPS,
|
||||
dataset=dataset,
|
||||
teleop=[leader_arm, keyboard],
|
||||
control_time_s=EPISODE_TIME_SEC,
|
||||
single_task=TASK_DESCRIPTION,
|
||||
display_data=True,
|
||||
teleop_action_processor=teleop_action_processor,
|
||||
robot_action_processor=robot_action_processor,
|
||||
robot_observation_processor=robot_observation_processor,
|
||||
features=dataset_features,
|
||||
robot_type=robot.name,
|
||||
use_videos=True,
|
||||
image_writer_threads=4,
|
||||
)
|
||||
|
||||
# Reset the environment if not stopping or re-recording
|
||||
if not events["stop_recording"] and (
|
||||
(recorded_episodes < NUM_EPISODES - 1) or events["rerecord_episode"]
|
||||
):
|
||||
log_say("Reset the environment")
|
||||
# Connect the robot and teleoperator
|
||||
# To connect you already should have this script running on LeKiwi: `python -m lerobot.robots.lekiwi.lekiwi_host --robot.id=my_awesome_kiwi`
|
||||
robot.connect()
|
||||
leader_arm.connect()
|
||||
keyboard.connect()
|
||||
|
||||
# Initialize the keyboard listener and rerun visualization
|
||||
listener, events = init_keyboard_listener()
|
||||
init_rerun(session_name="lekiwi_record")
|
||||
|
||||
if not robot.is_connected or not leader_arm.is_connected or not keyboard.is_connected:
|
||||
raise ValueError("Robot or teleop is not connected!")
|
||||
|
||||
print("Starting record loop...")
|
||||
recorded_episodes = 0
|
||||
while recorded_episodes < NUM_EPISODES and not events["stop_recording"]:
|
||||
log_say(f"Recording episode {recorded_episodes}")
|
||||
|
||||
# Main record loop
|
||||
record_loop(
|
||||
robot=robot,
|
||||
events=events,
|
||||
fps=FPS,
|
||||
dataset=dataset,
|
||||
teleop=[leader_arm, keyboard],
|
||||
control_time_s=RESET_TIME_SEC,
|
||||
control_time_s=EPISODE_TIME_SEC,
|
||||
single_task=TASK_DESCRIPTION,
|
||||
display_data=True,
|
||||
teleop_action_processor=teleop_action_processor,
|
||||
@@ -113,23 +97,45 @@ while recorded_episodes < NUM_EPISODES and not events["stop_recording"]:
|
||||
robot_observation_processor=robot_observation_processor,
|
||||
)
|
||||
|
||||
if events["rerecord_episode"]:
|
||||
log_say("Re-record episode")
|
||||
events["rerecord_episode"] = False
|
||||
events["exit_early"] = False
|
||||
dataset.clear_episode_buffer()
|
||||
continue
|
||||
# Reset the environment if not stopping or re-recording
|
||||
if not events["stop_recording"] and (
|
||||
(recorded_episodes < NUM_EPISODES - 1) or events["rerecord_episode"]
|
||||
):
|
||||
log_say("Reset the environment")
|
||||
record_loop(
|
||||
robot=robot,
|
||||
events=events,
|
||||
fps=FPS,
|
||||
teleop=[leader_arm, keyboard],
|
||||
control_time_s=RESET_TIME_SEC,
|
||||
single_task=TASK_DESCRIPTION,
|
||||
display_data=True,
|
||||
teleop_action_processor=teleop_action_processor,
|
||||
robot_action_processor=robot_action_processor,
|
||||
robot_observation_processor=robot_observation_processor,
|
||||
)
|
||||
|
||||
# Save episode
|
||||
dataset.save_episode()
|
||||
recorded_episodes += 1
|
||||
if events["rerecord_episode"]:
|
||||
log_say("Re-record episode")
|
||||
events["rerecord_episode"] = False
|
||||
events["exit_early"] = False
|
||||
dataset.clear_episode_buffer()
|
||||
continue
|
||||
|
||||
# Clean up
|
||||
log_say("Stop recording")
|
||||
robot.disconnect()
|
||||
leader_arm.disconnect()
|
||||
keyboard.disconnect()
|
||||
listener.stop()
|
||||
# Save episode
|
||||
dataset.save_episode()
|
||||
recorded_episodes += 1
|
||||
|
||||
dataset.finalize()
|
||||
dataset.push_to_hub()
|
||||
# Clean up
|
||||
log_say("Stop recording")
|
||||
robot.disconnect()
|
||||
leader_arm.disconnect()
|
||||
keyboard.disconnect()
|
||||
listener.stop()
|
||||
|
||||
dataset.finalize()
|
||||
dataset.push_to_hub()
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
main()
|
||||
|
||||
+32
-26
@@ -20,42 +20,48 @@ from lerobot.datasets.lerobot_dataset import LeRobotDataset
|
||||
from lerobot.robots.lekiwi.config_lekiwi import LeKiwiClientConfig
|
||||
from lerobot.robots.lekiwi.lekiwi_client import LeKiwiClient
|
||||
from lerobot.utils.constants import ACTION
|
||||
from lerobot.utils.robot_utils import busy_wait
|
||||
from lerobot.utils.robot_utils import precise_sleep
|
||||
from lerobot.utils.utils import log_say
|
||||
|
||||
EPISODE_IDX = 0
|
||||
|
||||
# Initialize the robot config
|
||||
robot_config = LeKiwiClientConfig(remote_ip="172.18.134.136", id="lekiwi")
|
||||
|
||||
# Initialize the robot
|
||||
robot = LeKiwiClient(robot_config)
|
||||
def main():
|
||||
# Initialize the robot config
|
||||
robot_config = LeKiwiClientConfig(remote_ip="172.18.134.136", id="lekiwi")
|
||||
|
||||
# Fetch the dataset to replay
|
||||
dataset = LeRobotDataset("<hf_username>/<dataset_repo_id>", episodes=[EPISODE_IDX])
|
||||
# Filter dataset to only include frames from the specified episode since episodes are chunked in dataset V3.0
|
||||
episode_frames = dataset.hf_dataset.filter(lambda x: x["episode_index"] == EPISODE_IDX)
|
||||
actions = episode_frames.select_columns(ACTION)
|
||||
# Initialize the robot
|
||||
robot = LeKiwiClient(robot_config)
|
||||
|
||||
# Connect to the robot
|
||||
robot.connect()
|
||||
# Fetch the dataset to replay
|
||||
dataset = LeRobotDataset("<hf_username>/<dataset_repo_id>", episodes=[EPISODE_IDX])
|
||||
# Filter dataset to only include frames from the specified episode since episodes are chunked in dataset V3.0
|
||||
episode_frames = dataset.hf_dataset.filter(lambda x: x["episode_index"] == EPISODE_IDX)
|
||||
actions = episode_frames.select_columns(ACTION)
|
||||
|
||||
if not robot.is_connected:
|
||||
raise ValueError("Robot is not connected!")
|
||||
# Connect to the robot
|
||||
robot.connect()
|
||||
|
||||
print("Starting replay loop...")
|
||||
log_say(f"Replaying episode {EPISODE_IDX}")
|
||||
for idx in range(len(episode_frames)):
|
||||
t0 = time.perf_counter()
|
||||
if not robot.is_connected:
|
||||
raise ValueError("Robot is not connected!")
|
||||
|
||||
# Get recorded action from dataset
|
||||
action = {
|
||||
name: float(actions[idx][ACTION][i]) for i, name in enumerate(dataset.features[ACTION]["names"])
|
||||
}
|
||||
print("Starting replay loop...")
|
||||
log_say(f"Replaying episode {EPISODE_IDX}")
|
||||
for idx in range(len(episode_frames)):
|
||||
t0 = time.perf_counter()
|
||||
|
||||
# Send action to robot
|
||||
_ = robot.send_action(action)
|
||||
# Get recorded action from dataset
|
||||
action = {
|
||||
name: float(actions[idx][ACTION][i]) for i, name in enumerate(dataset.features[ACTION]["names"])
|
||||
}
|
||||
|
||||
busy_wait(max(1.0 / dataset.fps - (time.perf_counter() - t0), 0.0))
|
||||
# Send action to robot
|
||||
_ = robot.send_action(action)
|
||||
|
||||
robot.disconnect()
|
||||
precise_sleep(max(1.0 / dataset.fps - (time.perf_counter() - t0), 0.0))
|
||||
|
||||
robot.disconnect()
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
main()
|
||||
|
||||
@@ -19,54 +19,60 @@ import time
|
||||
from lerobot.robots.lekiwi import LeKiwiClient, LeKiwiClientConfig
|
||||
from lerobot.teleoperators.keyboard.teleop_keyboard import KeyboardTeleop, KeyboardTeleopConfig
|
||||
from lerobot.teleoperators.so100_leader import SO100Leader, SO100LeaderConfig
|
||||
from lerobot.utils.robot_utils import busy_wait
|
||||
from lerobot.utils.robot_utils import precise_sleep
|
||||
from lerobot.utils.visualization_utils import init_rerun, log_rerun_data
|
||||
|
||||
FPS = 30
|
||||
|
||||
# Create the robot and teleoperator configurations
|
||||
robot_config = LeKiwiClientConfig(remote_ip="172.18.134.136", id="my_lekiwi")
|
||||
teleop_arm_config = SO100LeaderConfig(port="/dev/tty.usbmodem585A0077581", id="my_awesome_leader_arm")
|
||||
keyboard_config = KeyboardTeleopConfig(id="my_laptop_keyboard")
|
||||
|
||||
# Initialize the robot and teleoperator
|
||||
robot = LeKiwiClient(robot_config)
|
||||
leader_arm = SO100Leader(teleop_arm_config)
|
||||
keyboard = KeyboardTeleop(keyboard_config)
|
||||
def main():
|
||||
# Create the robot and teleoperator configurations
|
||||
robot_config = LeKiwiClientConfig(remote_ip="172.18.134.136", id="my_lekiwi")
|
||||
teleop_arm_config = SO100LeaderConfig(port="/dev/tty.usbmodem585A0077581", id="my_awesome_leader_arm")
|
||||
keyboard_config = KeyboardTeleopConfig(id="my_laptop_keyboard")
|
||||
|
||||
# Connect to the robot and teleoperator
|
||||
# To connect you already should have this script running on LeKiwi: `python -m lerobot.robots.lekiwi.lekiwi_host --robot.id=my_awesome_kiwi`
|
||||
robot.connect()
|
||||
leader_arm.connect()
|
||||
keyboard.connect()
|
||||
# Initialize the robot and teleoperator
|
||||
robot = LeKiwiClient(robot_config)
|
||||
leader_arm = SO100Leader(teleop_arm_config)
|
||||
keyboard = KeyboardTeleop(keyboard_config)
|
||||
|
||||
# Init rerun viewer
|
||||
init_rerun(session_name="lekiwi_teleop")
|
||||
# Connect to the robot and teleoperator
|
||||
# To connect you already should have this script running on LeKiwi: `python -m lerobot.robots.lekiwi.lekiwi_host --robot.id=my_awesome_kiwi`
|
||||
robot.connect()
|
||||
leader_arm.connect()
|
||||
keyboard.connect()
|
||||
|
||||
if not robot.is_connected or not leader_arm.is_connected or not keyboard.is_connected:
|
||||
raise ValueError("Robot or teleop is not connected!")
|
||||
# Init rerun viewer
|
||||
init_rerun(session_name="lekiwi_teleop")
|
||||
|
||||
print("Starting teleop loop...")
|
||||
while True:
|
||||
t0 = time.perf_counter()
|
||||
if not robot.is_connected or not leader_arm.is_connected or not keyboard.is_connected:
|
||||
raise ValueError("Robot or teleop is not connected!")
|
||||
|
||||
# Get robot observation
|
||||
observation = robot.get_observation()
|
||||
print("Starting teleop loop...")
|
||||
while True:
|
||||
t0 = time.perf_counter()
|
||||
|
||||
# Get teleop action
|
||||
# Arm
|
||||
arm_action = leader_arm.get_action()
|
||||
arm_action = {f"arm_{k}": v for k, v in arm_action.items()}
|
||||
# Keyboard
|
||||
keyboard_keys = keyboard.get_action()
|
||||
base_action = robot._from_keyboard_to_base_action(keyboard_keys)
|
||||
# Get robot observation
|
||||
observation = robot.get_observation()
|
||||
|
||||
action = {**arm_action, **base_action} if len(base_action) > 0 else arm_action
|
||||
# Get teleop action
|
||||
# Arm
|
||||
arm_action = leader_arm.get_action()
|
||||
arm_action = {f"arm_{k}": v for k, v in arm_action.items()}
|
||||
# Keyboard
|
||||
keyboard_keys = keyboard.get_action()
|
||||
base_action = robot._from_keyboard_to_base_action(keyboard_keys)
|
||||
|
||||
# Send action to robot
|
||||
_ = robot.send_action(action)
|
||||
action = {**arm_action, **base_action} if len(base_action) > 0 else arm_action
|
||||
|
||||
# Visualize
|
||||
log_rerun_data(observation=observation, action=action)
|
||||
# Send action to robot
|
||||
_ = robot.send_action(action)
|
||||
|
||||
busy_wait(max(1.0 / FPS - (time.perf_counter() - t0), 0.0))
|
||||
# Visualize
|
||||
log_rerun_data(observation=observation, action=action)
|
||||
|
||||
precise_sleep(max(1.0 / FPS - (time.perf_counter() - t0), 0.0))
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
main()
|
||||
|
||||
+135
-127
@@ -52,125 +52,114 @@ TASK_DESCRIPTION = "My task description"
|
||||
HF_MODEL_ID = "<hf_username>/<model_repo_id>"
|
||||
HF_DATASET_ID = "<hf_username>/<dataset_repo_id>"
|
||||
|
||||
# Create the robot configuration & robot
|
||||
camera_config = {"front": OpenCVCameraConfig(index_or_path=0, width=640, height=480, fps=FPS)}
|
||||
robot_config = SO100FollowerConfig(
|
||||
port="/dev/tty.usbmodem58760434471",
|
||||
id="my_awesome_follower_arm",
|
||||
cameras=camera_config,
|
||||
use_degrees=True,
|
||||
)
|
||||
|
||||
robot = SO100Follower(robot_config)
|
||||
|
||||
# Create policy
|
||||
policy = ACTPolicy.from_pretrained(HF_MODEL_ID)
|
||||
|
||||
# NOTE: It is highly recommended to use the urdf in the SO-ARM100 repo: https://github.com/TheRobotStudio/SO-ARM100/blob/main/Simulation/SO101/so101_new_calib.urdf
|
||||
kinematics_solver = RobotKinematics(
|
||||
urdf_path="./SO101/so101_new_calib.urdf",
|
||||
target_frame_name="gripper_frame_link",
|
||||
joint_names=list(robot.bus.motors.keys()),
|
||||
)
|
||||
|
||||
# Build pipeline to convert EE action to joints action
|
||||
robot_ee_to_joints_processor = RobotProcessorPipeline[tuple[RobotAction, RobotObservation], RobotAction](
|
||||
steps=[
|
||||
InverseKinematicsEEToJoints(
|
||||
kinematics=kinematics_solver,
|
||||
motor_names=list(robot.bus.motors.keys()),
|
||||
initial_guess_current_joints=True,
|
||||
),
|
||||
],
|
||||
to_transition=robot_action_observation_to_transition,
|
||||
to_output=transition_to_robot_action,
|
||||
)
|
||||
|
||||
# Build pipeline to convert joints observation to EE observation
|
||||
robot_joints_to_ee_pose_processor = RobotProcessorPipeline[RobotObservation, RobotObservation](
|
||||
steps=[
|
||||
ForwardKinematicsJointsToEE(kinematics=kinematics_solver, motor_names=list(robot.bus.motors.keys()))
|
||||
],
|
||||
to_transition=observation_to_transition,
|
||||
to_output=transition_to_observation,
|
||||
)
|
||||
|
||||
# Create the dataset
|
||||
dataset = LeRobotDataset.create(
|
||||
repo_id=HF_DATASET_ID,
|
||||
fps=FPS,
|
||||
features=combine_feature_dicts(
|
||||
aggregate_pipeline_dataset_features(
|
||||
pipeline=robot_joints_to_ee_pose_processor,
|
||||
initial_features=create_initial_features(observation=robot.observation_features),
|
||||
use_videos=True,
|
||||
),
|
||||
# User for now should be explicit on the feature keys that were used for record
|
||||
# Alternatively, the user can pass the processor step that has the right features
|
||||
aggregate_pipeline_dataset_features(
|
||||
pipeline=make_default_teleop_action_processor(),
|
||||
initial_features=create_initial_features(
|
||||
action={
|
||||
f"ee.{k}": PolicyFeature(type=FeatureType.ACTION, shape=(1,))
|
||||
for k in ["x", "y", "z", "wx", "wy", "wz", "gripper_pos"]
|
||||
}
|
||||
),
|
||||
use_videos=True,
|
||||
),
|
||||
),
|
||||
robot_type=robot.name,
|
||||
use_videos=True,
|
||||
image_writer_threads=4,
|
||||
)
|
||||
|
||||
# Build Policy Processors
|
||||
preprocessor, postprocessor = make_pre_post_processors(
|
||||
policy_cfg=policy,
|
||||
pretrained_path=HF_MODEL_ID,
|
||||
dataset_stats=dataset.meta.stats,
|
||||
# The inference device is automatically set to match the detected hardware, overriding any previous device settings from training to ensure compatibility.
|
||||
preprocessor_overrides={"device_processor": {"device": str(policy.config.device)}},
|
||||
)
|
||||
|
||||
# Connect the robot
|
||||
robot.connect()
|
||||
|
||||
# Initialize the keyboard listener and rerun visualization
|
||||
listener, events = init_keyboard_listener()
|
||||
init_rerun(session_name="phone_so100_evaluate")
|
||||
|
||||
if not robot.is_connected:
|
||||
raise ValueError("Robot is not connected!")
|
||||
|
||||
print("Starting evaluate loop...")
|
||||
episode_idx = 0
|
||||
for episode_idx in range(NUM_EPISODES):
|
||||
log_say(f"Running inference, recording eval episode {episode_idx + 1} of {NUM_EPISODES}")
|
||||
|
||||
# Main record loop
|
||||
record_loop(
|
||||
robot=robot,
|
||||
events=events,
|
||||
fps=FPS,
|
||||
policy=policy,
|
||||
preprocessor=preprocessor, # Pass the pre and post policy processors
|
||||
postprocessor=postprocessor,
|
||||
dataset=dataset,
|
||||
control_time_s=EPISODE_TIME_SEC,
|
||||
single_task=TASK_DESCRIPTION,
|
||||
display_data=True,
|
||||
teleop_action_processor=make_default_teleop_action_processor(),
|
||||
robot_action_processor=robot_ee_to_joints_processor,
|
||||
robot_observation_processor=robot_joints_to_ee_pose_processor,
|
||||
def main():
|
||||
# Create the robot configuration & robot
|
||||
camera_config = {"front": OpenCVCameraConfig(index_or_path=0, width=640, height=480, fps=FPS)}
|
||||
robot_config = SO100FollowerConfig(
|
||||
port="/dev/tty.usbmodem58760434471",
|
||||
id="my_awesome_follower_arm",
|
||||
cameras=camera_config,
|
||||
use_degrees=True,
|
||||
)
|
||||
|
||||
# Reset the environment if not stopping or re-recording
|
||||
if not events["stop_recording"] and ((episode_idx < NUM_EPISODES - 1) or events["rerecord_episode"]):
|
||||
log_say("Reset the environment")
|
||||
robot = SO100Follower(robot_config)
|
||||
|
||||
# Create policy
|
||||
policy = ACTPolicy.from_pretrained(HF_MODEL_ID)
|
||||
|
||||
# NOTE: It is highly recommended to use the urdf in the SO-ARM100 repo: https://github.com/TheRobotStudio/SO-ARM100/blob/main/Simulation/SO101/so101_new_calib.urdf
|
||||
kinematics_solver = RobotKinematics(
|
||||
urdf_path="./SO101/so101_new_calib.urdf",
|
||||
target_frame_name="gripper_frame_link",
|
||||
joint_names=list(robot.bus.motors.keys()),
|
||||
)
|
||||
|
||||
# Build pipeline to convert EE action to joints action
|
||||
robot_ee_to_joints_processor = RobotProcessorPipeline[tuple[RobotAction, RobotObservation], RobotAction](
|
||||
steps=[
|
||||
InverseKinematicsEEToJoints(
|
||||
kinematics=kinematics_solver,
|
||||
motor_names=list(robot.bus.motors.keys()),
|
||||
initial_guess_current_joints=True,
|
||||
),
|
||||
],
|
||||
to_transition=robot_action_observation_to_transition,
|
||||
to_output=transition_to_robot_action,
|
||||
)
|
||||
|
||||
# Build pipeline to convert joints observation to EE observation
|
||||
robot_joints_to_ee_pose_processor = RobotProcessorPipeline[RobotObservation, RobotObservation](
|
||||
steps=[
|
||||
ForwardKinematicsJointsToEE(
|
||||
kinematics=kinematics_solver, motor_names=list(robot.bus.motors.keys())
|
||||
)
|
||||
],
|
||||
to_transition=observation_to_transition,
|
||||
to_output=transition_to_observation,
|
||||
)
|
||||
|
||||
# Create the dataset
|
||||
dataset = LeRobotDataset.create(
|
||||
repo_id=HF_DATASET_ID,
|
||||
fps=FPS,
|
||||
features=combine_feature_dicts(
|
||||
aggregate_pipeline_dataset_features(
|
||||
pipeline=robot_joints_to_ee_pose_processor,
|
||||
initial_features=create_initial_features(observation=robot.observation_features),
|
||||
use_videos=True,
|
||||
),
|
||||
# User for now should be explicit on the feature keys that were used for record
|
||||
# Alternatively, the user can pass the processor step that has the right features
|
||||
aggregate_pipeline_dataset_features(
|
||||
pipeline=make_default_teleop_action_processor(),
|
||||
initial_features=create_initial_features(
|
||||
action={
|
||||
f"ee.{k}": PolicyFeature(type=FeatureType.ACTION, shape=(1,))
|
||||
for k in ["x", "y", "z", "wx", "wy", "wz", "gripper_pos"]
|
||||
}
|
||||
),
|
||||
use_videos=True,
|
||||
),
|
||||
),
|
||||
robot_type=robot.name,
|
||||
use_videos=True,
|
||||
image_writer_threads=4,
|
||||
)
|
||||
|
||||
# Build Policy Processors
|
||||
preprocessor, postprocessor = make_pre_post_processors(
|
||||
policy_cfg=policy,
|
||||
pretrained_path=HF_MODEL_ID,
|
||||
dataset_stats=dataset.meta.stats,
|
||||
# The inference device is automatically set to match the detected hardware, overriding any previous device settings from training to ensure compatibility.
|
||||
preprocessor_overrides={"device_processor": {"device": str(policy.config.device)}},
|
||||
)
|
||||
|
||||
# Connect the robot
|
||||
robot.connect()
|
||||
|
||||
# Initialize the keyboard listener and rerun visualization
|
||||
listener, events = init_keyboard_listener()
|
||||
init_rerun(session_name="phone_so100_evaluate")
|
||||
|
||||
if not robot.is_connected:
|
||||
raise ValueError("Robot is not connected!")
|
||||
|
||||
print("Starting evaluate loop...")
|
||||
episode_idx = 0
|
||||
for episode_idx in range(NUM_EPISODES):
|
||||
log_say(f"Running inference, recording eval episode {episode_idx + 1} of {NUM_EPISODES}")
|
||||
|
||||
# Main record loop
|
||||
record_loop(
|
||||
robot=robot,
|
||||
events=events,
|
||||
fps=FPS,
|
||||
policy=policy,
|
||||
preprocessor=preprocessor, # Pass the pre and post policy processors
|
||||
postprocessor=postprocessor,
|
||||
dataset=dataset,
|
||||
control_time_s=EPISODE_TIME_SEC,
|
||||
single_task=TASK_DESCRIPTION,
|
||||
display_data=True,
|
||||
@@ -179,21 +168,40 @@ for episode_idx in range(NUM_EPISODES):
|
||||
robot_observation_processor=robot_joints_to_ee_pose_processor,
|
||||
)
|
||||
|
||||
if events["rerecord_episode"]:
|
||||
log_say("Re-record episode")
|
||||
events["rerecord_episode"] = False
|
||||
events["exit_early"] = False
|
||||
dataset.clear_episode_buffer()
|
||||
continue
|
||||
# Reset the environment if not stopping or re-recording
|
||||
if not events["stop_recording"] and ((episode_idx < NUM_EPISODES - 1) or events["rerecord_episode"]):
|
||||
log_say("Reset the environment")
|
||||
record_loop(
|
||||
robot=robot,
|
||||
events=events,
|
||||
fps=FPS,
|
||||
control_time_s=EPISODE_TIME_SEC,
|
||||
single_task=TASK_DESCRIPTION,
|
||||
display_data=True,
|
||||
teleop_action_processor=make_default_teleop_action_processor(),
|
||||
robot_action_processor=robot_ee_to_joints_processor,
|
||||
robot_observation_processor=robot_joints_to_ee_pose_processor,
|
||||
)
|
||||
|
||||
# Save episode
|
||||
dataset.save_episode()
|
||||
episode_idx += 1
|
||||
if events["rerecord_episode"]:
|
||||
log_say("Re-record episode")
|
||||
events["rerecord_episode"] = False
|
||||
events["exit_early"] = False
|
||||
dataset.clear_episode_buffer()
|
||||
continue
|
||||
|
||||
# Clean up
|
||||
log_say("Stop recording")
|
||||
robot.disconnect()
|
||||
listener.stop()
|
||||
# Save episode
|
||||
dataset.save_episode()
|
||||
episode_idx += 1
|
||||
|
||||
dataset.finalize()
|
||||
dataset.push_to_hub()
|
||||
# Clean up
|
||||
log_say("Stop recording")
|
||||
robot.disconnect()
|
||||
listener.stop()
|
||||
|
||||
dataset.finalize()
|
||||
dataset.push_to_hub()
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
main()
|
||||
|
||||
+142
-133
@@ -50,133 +50,122 @@ RESET_TIME_SEC = 30
|
||||
TASK_DESCRIPTION = "My task description"
|
||||
HF_REPO_ID = "<hf_username>/<dataset_repo_id>"
|
||||
|
||||
# Create the robot and teleoperator configurations
|
||||
camera_config = {"front": OpenCVCameraConfig(index_or_path=0, width=640, height=480, fps=FPS)}
|
||||
robot_config = SO100FollowerConfig(
|
||||
port="/dev/tty.usbmodem5A460814411",
|
||||
id="my_awesome_follower_arm",
|
||||
cameras=camera_config,
|
||||
use_degrees=True,
|
||||
)
|
||||
teleop_config = PhoneConfig(phone_os=PhoneOS.IOS) # or PhoneOS.ANDROID
|
||||
|
||||
# Initialize the robot and teleoperator
|
||||
robot = SO100Follower(robot_config)
|
||||
phone = Phone(teleop_config)
|
||||
def main():
|
||||
# Create the robot and teleoperator configurations
|
||||
camera_config = {"front": OpenCVCameraConfig(index_or_path=0, width=640, height=480, fps=FPS)}
|
||||
robot_config = SO100FollowerConfig(
|
||||
port="/dev/tty.usbmodem5A460814411",
|
||||
id="my_awesome_follower_arm",
|
||||
cameras=camera_config,
|
||||
use_degrees=True,
|
||||
)
|
||||
teleop_config = PhoneConfig(phone_os=PhoneOS.IOS) # or PhoneOS.ANDROID
|
||||
|
||||
# NOTE: It is highly recommended to use the urdf in the SO-ARM100 repo: https://github.com/TheRobotStudio/SO-ARM100/blob/main/Simulation/SO101/so101_new_calib.urdf
|
||||
kinematics_solver = RobotKinematics(
|
||||
urdf_path="./SO101/so101_new_calib.urdf",
|
||||
target_frame_name="gripper_frame_link",
|
||||
joint_names=list(robot.bus.motors.keys()),
|
||||
)
|
||||
# Initialize the robot and teleoperator
|
||||
robot = SO100Follower(robot_config)
|
||||
phone = Phone(teleop_config)
|
||||
|
||||
# Build pipeline to convert phone action to EE action
|
||||
phone_to_robot_ee_pose_processor = RobotProcessorPipeline[tuple[RobotAction, RobotObservation], RobotAction](
|
||||
steps=[
|
||||
MapPhoneActionToRobotAction(platform=teleop_config.phone_os),
|
||||
EEReferenceAndDelta(
|
||||
kinematics=kinematics_solver,
|
||||
end_effector_step_sizes={"x": 0.5, "y": 0.5, "z": 0.5},
|
||||
motor_names=list(robot.bus.motors.keys()),
|
||||
use_latched_reference=True,
|
||||
),
|
||||
EEBoundsAndSafety(
|
||||
end_effector_bounds={"min": [-1.0, -1.0, -1.0], "max": [1.0, 1.0, 1.0]},
|
||||
max_ee_step_m=0.20,
|
||||
),
|
||||
GripperVelocityToJoint(speed_factor=20.0),
|
||||
],
|
||||
to_transition=robot_action_observation_to_transition,
|
||||
to_output=transition_to_robot_action,
|
||||
)
|
||||
|
||||
# Build pipeline to convert EE action to joints action
|
||||
robot_ee_to_joints_processor = RobotProcessorPipeline[tuple[RobotAction, RobotObservation], RobotAction](
|
||||
steps=[
|
||||
InverseKinematicsEEToJoints(
|
||||
kinematics=kinematics_solver,
|
||||
motor_names=list(robot.bus.motors.keys()),
|
||||
initial_guess_current_joints=True,
|
||||
),
|
||||
],
|
||||
to_transition=robot_action_observation_to_transition,
|
||||
to_output=transition_to_robot_action,
|
||||
)
|
||||
|
||||
# Build pipeline to convert joint observation to EE observation
|
||||
robot_joints_to_ee_pose = RobotProcessorPipeline[RobotObservation, RobotObservation](
|
||||
steps=[
|
||||
ForwardKinematicsJointsToEE(kinematics=kinematics_solver, motor_names=list(robot.bus.motors.keys()))
|
||||
],
|
||||
to_transition=observation_to_transition,
|
||||
to_output=transition_to_observation,
|
||||
)
|
||||
|
||||
# Create the dataset
|
||||
dataset = LeRobotDataset.create(
|
||||
repo_id=HF_REPO_ID,
|
||||
fps=FPS,
|
||||
features=combine_feature_dicts(
|
||||
# Run the feature contract of the pipelines
|
||||
# This tells you how the features would look like after the pipeline steps
|
||||
aggregate_pipeline_dataset_features(
|
||||
pipeline=phone_to_robot_ee_pose_processor,
|
||||
initial_features=create_initial_features(action=phone.action_features),
|
||||
use_videos=True,
|
||||
),
|
||||
aggregate_pipeline_dataset_features(
|
||||
pipeline=robot_joints_to_ee_pose,
|
||||
initial_features=create_initial_features(observation=robot.observation_features),
|
||||
use_videos=True,
|
||||
),
|
||||
),
|
||||
robot_type=robot.name,
|
||||
use_videos=True,
|
||||
image_writer_threads=4,
|
||||
)
|
||||
|
||||
# Connect the robot and teleoperator
|
||||
robot.connect()
|
||||
phone.connect()
|
||||
|
||||
# Initialize the keyboard listener and rerun visualization
|
||||
listener, events = init_keyboard_listener()
|
||||
init_rerun(session_name="phone_so100_record")
|
||||
|
||||
if not robot.is_connected or not phone.is_connected:
|
||||
raise ValueError("Robot or teleop is not connected!")
|
||||
|
||||
|
||||
print("Starting record loop. Move your phone to teleoperate the robot...")
|
||||
episode_idx = 0
|
||||
while episode_idx < NUM_EPISODES and not events["stop_recording"]:
|
||||
log_say(f"Recording episode {episode_idx + 1} of {NUM_EPISODES}")
|
||||
|
||||
# Main record loop
|
||||
record_loop(
|
||||
robot=robot,
|
||||
events=events,
|
||||
fps=FPS,
|
||||
teleop=phone,
|
||||
dataset=dataset,
|
||||
control_time_s=EPISODE_TIME_SEC,
|
||||
single_task=TASK_DESCRIPTION,
|
||||
display_data=True,
|
||||
teleop_action_processor=phone_to_robot_ee_pose_processor,
|
||||
robot_action_processor=robot_ee_to_joints_processor,
|
||||
robot_observation_processor=robot_joints_to_ee_pose,
|
||||
# NOTE: It is highly recommended to use the urdf in the SO-ARM100 repo: https://github.com/TheRobotStudio/SO-ARM100/blob/main/Simulation/SO101/so101_new_calib.urdf
|
||||
kinematics_solver = RobotKinematics(
|
||||
urdf_path="./SO101/so101_new_calib.urdf",
|
||||
target_frame_name="gripper_frame_link",
|
||||
joint_names=list(robot.bus.motors.keys()),
|
||||
)
|
||||
|
||||
# Reset the environment if not stopping or re-recording
|
||||
if not events["stop_recording"] and (episode_idx < NUM_EPISODES - 1 or events["rerecord_episode"]):
|
||||
log_say("Reset the environment")
|
||||
# Build pipeline to convert phone action to EE action
|
||||
phone_to_robot_ee_pose_processor = RobotProcessorPipeline[
|
||||
tuple[RobotAction, RobotObservation], RobotAction
|
||||
](
|
||||
steps=[
|
||||
MapPhoneActionToRobotAction(platform=teleop_config.phone_os),
|
||||
EEReferenceAndDelta(
|
||||
kinematics=kinematics_solver,
|
||||
end_effector_step_sizes={"x": 0.5, "y": 0.5, "z": 0.5},
|
||||
motor_names=list(robot.bus.motors.keys()),
|
||||
use_latched_reference=True,
|
||||
),
|
||||
EEBoundsAndSafety(
|
||||
end_effector_bounds={"min": [-1.0, -1.0, -1.0], "max": [1.0, 1.0, 1.0]},
|
||||
max_ee_step_m=0.20,
|
||||
),
|
||||
GripperVelocityToJoint(speed_factor=20.0),
|
||||
],
|
||||
to_transition=robot_action_observation_to_transition,
|
||||
to_output=transition_to_robot_action,
|
||||
)
|
||||
|
||||
# Build pipeline to convert EE action to joints action
|
||||
robot_ee_to_joints_processor = RobotProcessorPipeline[tuple[RobotAction, RobotObservation], RobotAction](
|
||||
steps=[
|
||||
InverseKinematicsEEToJoints(
|
||||
kinematics=kinematics_solver,
|
||||
motor_names=list(robot.bus.motors.keys()),
|
||||
initial_guess_current_joints=True,
|
||||
),
|
||||
],
|
||||
to_transition=robot_action_observation_to_transition,
|
||||
to_output=transition_to_robot_action,
|
||||
)
|
||||
|
||||
# Build pipeline to convert joint observation to EE observation
|
||||
robot_joints_to_ee_pose = RobotProcessorPipeline[RobotObservation, RobotObservation](
|
||||
steps=[
|
||||
ForwardKinematicsJointsToEE(
|
||||
kinematics=kinematics_solver, motor_names=list(robot.bus.motors.keys())
|
||||
)
|
||||
],
|
||||
to_transition=observation_to_transition,
|
||||
to_output=transition_to_observation,
|
||||
)
|
||||
|
||||
# Create the dataset
|
||||
dataset = LeRobotDataset.create(
|
||||
repo_id=HF_REPO_ID,
|
||||
fps=FPS,
|
||||
features=combine_feature_dicts(
|
||||
# Run the feature contract of the pipelines
|
||||
# This tells you how the features would look like after the pipeline steps
|
||||
aggregate_pipeline_dataset_features(
|
||||
pipeline=phone_to_robot_ee_pose_processor,
|
||||
initial_features=create_initial_features(action=phone.action_features),
|
||||
use_videos=True,
|
||||
),
|
||||
aggregate_pipeline_dataset_features(
|
||||
pipeline=robot_joints_to_ee_pose,
|
||||
initial_features=create_initial_features(observation=robot.observation_features),
|
||||
use_videos=True,
|
||||
),
|
||||
),
|
||||
robot_type=robot.name,
|
||||
use_videos=True,
|
||||
image_writer_threads=4,
|
||||
)
|
||||
|
||||
# Connect the robot and teleoperator
|
||||
robot.connect()
|
||||
phone.connect()
|
||||
|
||||
# Initialize the keyboard listener and rerun visualization
|
||||
listener, events = init_keyboard_listener()
|
||||
init_rerun(session_name="phone_so100_record")
|
||||
|
||||
if not robot.is_connected or not phone.is_connected:
|
||||
raise ValueError("Robot or teleop is not connected!")
|
||||
|
||||
print("Starting record loop. Move your phone to teleoperate the robot...")
|
||||
episode_idx = 0
|
||||
while episode_idx < NUM_EPISODES and not events["stop_recording"]:
|
||||
log_say(f"Recording episode {episode_idx + 1} of {NUM_EPISODES}")
|
||||
|
||||
# Main record loop
|
||||
record_loop(
|
||||
robot=robot,
|
||||
events=events,
|
||||
fps=FPS,
|
||||
teleop=phone,
|
||||
control_time_s=RESET_TIME_SEC,
|
||||
dataset=dataset,
|
||||
control_time_s=EPISODE_TIME_SEC,
|
||||
single_task=TASK_DESCRIPTION,
|
||||
display_data=True,
|
||||
teleop_action_processor=phone_to_robot_ee_pose_processor,
|
||||
@@ -184,22 +173,42 @@ while episode_idx < NUM_EPISODES and not events["stop_recording"]:
|
||||
robot_observation_processor=robot_joints_to_ee_pose,
|
||||
)
|
||||
|
||||
if events["rerecord_episode"]:
|
||||
log_say("Re-recording episode")
|
||||
events["rerecord_episode"] = False
|
||||
events["exit_early"] = False
|
||||
dataset.clear_episode_buffer()
|
||||
continue
|
||||
# Reset the environment if not stopping or re-recording
|
||||
if not events["stop_recording"] and (episode_idx < NUM_EPISODES - 1 or events["rerecord_episode"]):
|
||||
log_say("Reset the environment")
|
||||
record_loop(
|
||||
robot=robot,
|
||||
events=events,
|
||||
fps=FPS,
|
||||
teleop=phone,
|
||||
control_time_s=RESET_TIME_SEC,
|
||||
single_task=TASK_DESCRIPTION,
|
||||
display_data=True,
|
||||
teleop_action_processor=phone_to_robot_ee_pose_processor,
|
||||
robot_action_processor=robot_ee_to_joints_processor,
|
||||
robot_observation_processor=robot_joints_to_ee_pose,
|
||||
)
|
||||
|
||||
# Save episode
|
||||
dataset.save_episode()
|
||||
episode_idx += 1
|
||||
if events["rerecord_episode"]:
|
||||
log_say("Re-recording episode")
|
||||
events["rerecord_episode"] = False
|
||||
events["exit_early"] = False
|
||||
dataset.clear_episode_buffer()
|
||||
continue
|
||||
|
||||
# Clean up
|
||||
log_say("Stop recording")
|
||||
robot.disconnect()
|
||||
phone.disconnect()
|
||||
listener.stop()
|
||||
# Save episode
|
||||
dataset.save_episode()
|
||||
episode_idx += 1
|
||||
|
||||
dataset.finalize()
|
||||
dataset.push_to_hub()
|
||||
# Clean up
|
||||
log_say("Stop recording")
|
||||
robot.disconnect()
|
||||
phone.disconnect()
|
||||
listener.stop()
|
||||
|
||||
dataset.finalize()
|
||||
dataset.push_to_hub()
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
main()
|
||||
|
||||
@@ -29,72 +29,78 @@ from lerobot.robots.so100_follower.robot_kinematic_processor import (
|
||||
)
|
||||
from lerobot.robots.so100_follower.so100_follower import SO100Follower
|
||||
from lerobot.utils.constants import ACTION
|
||||
from lerobot.utils.robot_utils import busy_wait
|
||||
from lerobot.utils.robot_utils import precise_sleep
|
||||
from lerobot.utils.utils import log_say
|
||||
|
||||
EPISODE_IDX = 0
|
||||
HF_REPO_ID = "<hf_username>/<dataset_repo_id>"
|
||||
|
||||
# Initialize the robot config
|
||||
robot_config = SO100FollowerConfig(
|
||||
port="/dev/tty.usbmodem5A460814411", id="my_awesome_follower_arm", use_degrees=True
|
||||
)
|
||||
|
||||
# Initialize the robot
|
||||
robot = SO100Follower(robot_config)
|
||||
def main():
|
||||
# Initialize the robot config
|
||||
robot_config = SO100FollowerConfig(
|
||||
port="/dev/tty.usbmodem5A460814411", id="my_awesome_follower_arm", use_degrees=True
|
||||
)
|
||||
|
||||
# NOTE: It is highly recommended to use the urdf in the SO-ARM100 repo: https://github.com/TheRobotStudio/SO-ARM100/blob/main/Simulation/SO101/so101_new_calib.urdf
|
||||
kinematics_solver = RobotKinematics(
|
||||
urdf_path="./SO101/so101_new_calib.urdf",
|
||||
target_frame_name="gripper_frame_link",
|
||||
joint_names=list(robot.bus.motors.keys()),
|
||||
)
|
||||
# Initialize the robot
|
||||
robot = SO100Follower(robot_config)
|
||||
|
||||
# Build pipeline to convert EE action to joints action
|
||||
robot_ee_to_joints_processor = RobotProcessorPipeline[tuple[RobotAction, RobotObservation], RobotAction](
|
||||
steps=[
|
||||
InverseKinematicsEEToJoints(
|
||||
kinematics=kinematics_solver,
|
||||
motor_names=list(robot.bus.motors.keys()),
|
||||
initial_guess_current_joints=False, # Because replay is open loop
|
||||
),
|
||||
],
|
||||
to_transition=robot_action_observation_to_transition,
|
||||
to_output=transition_to_robot_action,
|
||||
)
|
||||
# NOTE: It is highly recommended to use the urdf in the SO-ARM100 repo: https://github.com/TheRobotStudio/SO-ARM100/blob/main/Simulation/SO101/so101_new_calib.urdf
|
||||
kinematics_solver = RobotKinematics(
|
||||
urdf_path="./SO101/so101_new_calib.urdf",
|
||||
target_frame_name="gripper_frame_link",
|
||||
joint_names=list(robot.bus.motors.keys()),
|
||||
)
|
||||
|
||||
# Fetch the dataset to replay
|
||||
dataset = LeRobotDataset(HF_REPO_ID, episodes=[EPISODE_IDX])
|
||||
# Filter dataset to only include frames from the specified episode since episodes are chunked in dataset V3.0
|
||||
episode_frames = dataset.hf_dataset.filter(lambda x: x["episode_index"] == EPISODE_IDX)
|
||||
actions = episode_frames.select_columns(ACTION)
|
||||
# Build pipeline to convert EE action to joints action
|
||||
robot_ee_to_joints_processor = RobotProcessorPipeline[tuple[RobotAction, RobotObservation], RobotAction](
|
||||
steps=[
|
||||
InverseKinematicsEEToJoints(
|
||||
kinematics=kinematics_solver,
|
||||
motor_names=list(robot.bus.motors.keys()),
|
||||
initial_guess_current_joints=False, # Because replay is open loop
|
||||
),
|
||||
],
|
||||
to_transition=robot_action_observation_to_transition,
|
||||
to_output=transition_to_robot_action,
|
||||
)
|
||||
|
||||
# Connect to the robot
|
||||
robot.connect()
|
||||
# Fetch the dataset to replay
|
||||
dataset = LeRobotDataset(HF_REPO_ID, episodes=[EPISODE_IDX])
|
||||
# Filter dataset to only include frames from the specified episode since episodes are chunked in dataset V3.0
|
||||
episode_frames = dataset.hf_dataset.filter(lambda x: x["episode_index"] == EPISODE_IDX)
|
||||
actions = episode_frames.select_columns(ACTION)
|
||||
|
||||
if not robot.is_connected:
|
||||
raise ValueError("Robot is not connected!")
|
||||
# Connect to the robot
|
||||
robot.connect()
|
||||
|
||||
print("Starting replay loop...")
|
||||
log_say(f"Replaying episode {EPISODE_IDX}")
|
||||
for idx in range(len(episode_frames)):
|
||||
t0 = time.perf_counter()
|
||||
if not robot.is_connected:
|
||||
raise ValueError("Robot is not connected!")
|
||||
|
||||
# Get recorded action from dataset
|
||||
ee_action = {
|
||||
name: float(actions[idx][ACTION][i]) for i, name in enumerate(dataset.features[ACTION]["names"])
|
||||
}
|
||||
print("Starting replay loop...")
|
||||
log_say(f"Replaying episode {EPISODE_IDX}")
|
||||
for idx in range(len(episode_frames)):
|
||||
t0 = time.perf_counter()
|
||||
|
||||
# Get robot observation
|
||||
robot_obs = robot.get_observation()
|
||||
# Get recorded action from dataset
|
||||
ee_action = {
|
||||
name: float(actions[idx][ACTION][i]) for i, name in enumerate(dataset.features[ACTION]["names"])
|
||||
}
|
||||
|
||||
# Dataset EE -> robot joints
|
||||
joint_action = robot_ee_to_joints_processor((ee_action, robot_obs))
|
||||
# Get robot observation
|
||||
robot_obs = robot.get_observation()
|
||||
|
||||
# Send action to robot
|
||||
_ = robot.send_action(joint_action)
|
||||
# Dataset EE -> robot joints
|
||||
joint_action = robot_ee_to_joints_processor((ee_action, robot_obs))
|
||||
|
||||
busy_wait(1.0 / dataset.fps - (time.perf_counter() - t0))
|
||||
# Send action to robot
|
||||
_ = robot.send_action(joint_action)
|
||||
|
||||
# Clean up
|
||||
robot.disconnect()
|
||||
precise_sleep(1.0 / dataset.fps - (time.perf_counter() - t0))
|
||||
|
||||
# Clean up
|
||||
robot.disconnect()
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
main()
|
||||
|
||||
@@ -32,82 +32,90 @@ from lerobot.robots.so100_follower.so100_follower import SO100Follower
|
||||
from lerobot.teleoperators.phone.config_phone import PhoneConfig, PhoneOS
|
||||
from lerobot.teleoperators.phone.phone_processor import MapPhoneActionToRobotAction
|
||||
from lerobot.teleoperators.phone.teleop_phone import Phone
|
||||
from lerobot.utils.robot_utils import busy_wait
|
||||
from lerobot.utils.robot_utils import precise_sleep
|
||||
from lerobot.utils.visualization_utils import init_rerun, log_rerun_data
|
||||
|
||||
FPS = 30
|
||||
|
||||
# Initialize the robot and teleoperator
|
||||
robot_config = SO100FollowerConfig(
|
||||
port="/dev/tty.usbmodem5A460814411", id="my_awesome_follower_arm", use_degrees=True
|
||||
)
|
||||
teleop_config = PhoneConfig(phone_os=PhoneOS.IOS) # or PhoneOS.ANDROID
|
||||
|
||||
# Initialize the robot and teleoperator
|
||||
robot = SO100Follower(robot_config)
|
||||
teleop_device = Phone(teleop_config)
|
||||
def main():
|
||||
# Initialize the robot and teleoperator
|
||||
robot_config = SO100FollowerConfig(
|
||||
port="/dev/tty.usbmodem5A460814411", id="my_awesome_follower_arm", use_degrees=True
|
||||
)
|
||||
teleop_config = PhoneConfig(phone_os=PhoneOS.IOS) # or PhoneOS.ANDROID
|
||||
|
||||
# NOTE: It is highly recommended to use the urdf in the SO-ARM100 repo: https://github.com/TheRobotStudio/SO-ARM100/blob/main/Simulation/SO101/so101_new_calib.urdf
|
||||
kinematics_solver = RobotKinematics(
|
||||
urdf_path="./SO101/so101_new_calib.urdf",
|
||||
target_frame_name="gripper_frame_link",
|
||||
joint_names=list(robot.bus.motors.keys()),
|
||||
)
|
||||
# Initialize the robot and teleoperator
|
||||
robot = SO100Follower(robot_config)
|
||||
teleop_device = Phone(teleop_config)
|
||||
|
||||
# Build pipeline to convert phone action to ee pose action to joint action
|
||||
phone_to_robot_joints_processor = RobotProcessorPipeline[tuple[RobotAction, RobotObservation], RobotAction](
|
||||
steps=[
|
||||
MapPhoneActionToRobotAction(platform=teleop_config.phone_os),
|
||||
EEReferenceAndDelta(
|
||||
kinematics=kinematics_solver,
|
||||
end_effector_step_sizes={"x": 0.5, "y": 0.5, "z": 0.5},
|
||||
motor_names=list(robot.bus.motors.keys()),
|
||||
use_latched_reference=True,
|
||||
),
|
||||
EEBoundsAndSafety(
|
||||
end_effector_bounds={"min": [-1.0, -1.0, -1.0], "max": [1.0, 1.0, 1.0]},
|
||||
max_ee_step_m=0.10,
|
||||
),
|
||||
GripperVelocityToJoint(
|
||||
speed_factor=20.0,
|
||||
),
|
||||
InverseKinematicsEEToJoints(
|
||||
kinematics=kinematics_solver,
|
||||
motor_names=list(robot.bus.motors.keys()),
|
||||
initial_guess_current_joints=True,
|
||||
),
|
||||
],
|
||||
to_transition=robot_action_observation_to_transition,
|
||||
to_output=transition_to_robot_action,
|
||||
)
|
||||
# NOTE: It is highly recommended to use the urdf in the SO-ARM100 repo: https://github.com/TheRobotStudio/SO-ARM100/blob/main/Simulation/SO101/so101_new_calib.urdf
|
||||
kinematics_solver = RobotKinematics(
|
||||
urdf_path="./SO101/so101_new_calib.urdf",
|
||||
target_frame_name="gripper_frame_link",
|
||||
joint_names=list(robot.bus.motors.keys()),
|
||||
)
|
||||
|
||||
# Connect to the robot and teleoperator
|
||||
robot.connect()
|
||||
teleop_device.connect()
|
||||
# Build pipeline to convert phone action to ee pose action to joint action
|
||||
phone_to_robot_joints_processor = RobotProcessorPipeline[
|
||||
tuple[RobotAction, RobotObservation], RobotAction
|
||||
](
|
||||
steps=[
|
||||
MapPhoneActionToRobotAction(platform=teleop_config.phone_os),
|
||||
EEReferenceAndDelta(
|
||||
kinematics=kinematics_solver,
|
||||
end_effector_step_sizes={"x": 0.5, "y": 0.5, "z": 0.5},
|
||||
motor_names=list(robot.bus.motors.keys()),
|
||||
use_latched_reference=True,
|
||||
),
|
||||
EEBoundsAndSafety(
|
||||
end_effector_bounds={"min": [-1.0, -1.0, -1.0], "max": [1.0, 1.0, 1.0]},
|
||||
max_ee_step_m=0.10,
|
||||
),
|
||||
GripperVelocityToJoint(
|
||||
speed_factor=20.0,
|
||||
),
|
||||
InverseKinematicsEEToJoints(
|
||||
kinematics=kinematics_solver,
|
||||
motor_names=list(robot.bus.motors.keys()),
|
||||
initial_guess_current_joints=True,
|
||||
),
|
||||
],
|
||||
to_transition=robot_action_observation_to_transition,
|
||||
to_output=transition_to_robot_action,
|
||||
)
|
||||
|
||||
# Init rerun viewer
|
||||
init_rerun(session_name="phone_so100_teleop")
|
||||
# Connect to the robot and teleoperator
|
||||
robot.connect()
|
||||
teleop_device.connect()
|
||||
|
||||
if not robot.is_connected or not teleop_device.is_connected:
|
||||
raise ValueError("Robot or teleop is not connected!")
|
||||
# Init rerun viewer
|
||||
init_rerun(session_name="phone_so100_teleop")
|
||||
|
||||
print("Starting teleop loop. Move your phone to teleoperate the robot...")
|
||||
while True:
|
||||
t0 = time.perf_counter()
|
||||
if not robot.is_connected or not teleop_device.is_connected:
|
||||
raise ValueError("Robot or teleop is not connected!")
|
||||
|
||||
# Get robot observation
|
||||
robot_obs = robot.get_observation()
|
||||
print("Starting teleop loop. Move your phone to teleoperate the robot...")
|
||||
while True:
|
||||
t0 = time.perf_counter()
|
||||
|
||||
# Get teleop action
|
||||
phone_obs = teleop_device.get_action()
|
||||
# Get robot observation
|
||||
robot_obs = robot.get_observation()
|
||||
|
||||
# Phone -> EE pose -> Joints transition
|
||||
joint_action = phone_to_robot_joints_processor((phone_obs, robot_obs))
|
||||
# Get teleop action
|
||||
phone_obs = teleop_device.get_action()
|
||||
|
||||
# Send action to robot
|
||||
_ = robot.send_action(joint_action)
|
||||
# Phone -> EE pose -> Joints transition
|
||||
joint_action = phone_to_robot_joints_processor((phone_obs, robot_obs))
|
||||
|
||||
# Visualize
|
||||
log_rerun_data(observation=phone_obs, action=joint_action)
|
||||
# Send action to robot
|
||||
_ = robot.send_action(joint_action)
|
||||
|
||||
busy_wait(max(1.0 / FPS - (time.perf_counter() - t0), 0.0))
|
||||
# Visualize
|
||||
log_rerun_data(observation=phone_obs, action=joint_action)
|
||||
|
||||
precise_sleep(max(1.0 / FPS - (time.perf_counter() - t0), 0.0))
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
main()
|
||||
|
||||
@@ -15,16 +15,12 @@
|
||||
# limitations under the License.
|
||||
|
||||
import argparse
|
||||
import logging
|
||||
from pathlib import Path
|
||||
|
||||
from datatrove.executor import LocalPipelineExecutor
|
||||
from datatrove.executor.slurm import SlurmPipelineExecutor
|
||||
from datatrove.pipeline.base import PipelineStep
|
||||
from port_datasets.droid_rlds.port_droid import DROID_SHARDS
|
||||
|
||||
from lerobot.datasets.aggregate import aggregate_datasets
|
||||
from lerobot.utils.utils import init_logging
|
||||
from port_droid import DROID_SHARDS
|
||||
|
||||
|
||||
class AggregateDatasets(PipelineStep):
|
||||
@@ -38,6 +34,11 @@ class AggregateDatasets(PipelineStep):
|
||||
self.aggr_repo_id = aggregated_repo_id
|
||||
|
||||
def run(self, data=None, rank: int = 0, world_size: int = 1):
|
||||
import logging
|
||||
|
||||
from lerobot.datasets.aggregate import aggregate_datasets
|
||||
from lerobot.utils.utils import init_logging
|
||||
|
||||
init_logging()
|
||||
|
||||
# Since aggregate_datasets already handles parallel processing internally,
|
||||
|
||||
@@ -20,7 +20,7 @@ from pathlib import Path
|
||||
from datatrove.executor import LocalPipelineExecutor
|
||||
from datatrove.executor.slurm import SlurmPipelineExecutor
|
||||
from datatrove.pipeline.base import PipelineStep
|
||||
from port_datasets.droid_rlds.port_droid import DROID_SHARDS
|
||||
from port_droid import DROID_SHARDS
|
||||
|
||||
|
||||
class PortDroidShards(PipelineStep):
|
||||
@@ -35,7 +35,7 @@ class PortDroidShards(PipelineStep):
|
||||
|
||||
def run(self, data=None, rank: int = 0, world_size: int = 1):
|
||||
from datasets.utils.tqdm import disable_progress_bars
|
||||
from port_datasets.droid_rlds.port_droid import port_droid, validate_dataset
|
||||
from port_droid import port_droid, validate_dataset
|
||||
|
||||
from lerobot.utils.utils import init_logging
|
||||
|
||||
|
||||
@@ -24,7 +24,7 @@ from datatrove.executor.slurm import SlurmPipelineExecutor
|
||||
from datatrove.pipeline.base import PipelineStep
|
||||
from huggingface_hub import HfApi
|
||||
from huggingface_hub.constants import REPOCARD_NAME
|
||||
from port_datasets.droid_rlds.port_droid import DROID_SHARDS
|
||||
from port_droid import DROID_SHARDS
|
||||
|
||||
from lerobot.datasets.lerobot_dataset import CODEBASE_VERSION, LeRobotDatasetMetadata
|
||||
from lerobot.datasets.utils import create_lerobot_dataset_card
|
||||
@@ -185,11 +185,11 @@ class UploadDataset(PipelineStep):
|
||||
|
||||
|
||||
def make_upload_executor(
|
||||
repo_id, job_name, logs_dir, workers, partition, cpus_per_task, mem_per_cpu, slurm=True
|
||||
repo_id, job_name, logs_dir, workers, partition, cpus_per_task, mem_per_cpu, private=False, slurm=True
|
||||
):
|
||||
kwargs = {
|
||||
"pipeline": [
|
||||
UploadDataset(repo_id),
|
||||
UploadDataset(repo_id, private=private),
|
||||
],
|
||||
"logging_dir": str(logs_dir / job_name),
|
||||
}
|
||||
@@ -267,6 +267,12 @@ def main():
|
||||
default="1950M",
|
||||
help="Memory per cpu that each worker will use.",
|
||||
)
|
||||
parser.add_argument(
|
||||
"--private",
|
||||
action="store_true",
|
||||
default=False,
|
||||
help="Whether to create a private repository.",
|
||||
)
|
||||
|
||||
init_logging()
|
||||
|
||||
|
||||
@@ -0,0 +1,951 @@
|
||||
#!/usr/bin/env python
|
||||
|
||||
# Copyright 2025 The HuggingFace Inc. team. All rights reserved.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
"""
|
||||
Evaluate Real-Time Chunking (RTC) performance on dataset samples.
|
||||
|
||||
This script takes two random samples from a dataset:
|
||||
- Uses actions from the first sample as previous chunk
|
||||
- Generates new actions for the second sample with and without RTC
|
||||
|
||||
It compares action predictions with and without RTC on dataset samples,
|
||||
measuring consistency and ground truth alignment.
|
||||
|
||||
Usage:
|
||||
# Basic usage with smolvla policy
|
||||
uv run python examples/rtc/eval_dataset.py \
|
||||
--policy.path=helper2424/smolvla_check_rtc_last3 \
|
||||
--dataset.repo_id=helper2424/check_rtc \
|
||||
--rtc.execution_horizon=8 \
|
||||
--device=mps \
|
||||
--rtc.max_guidance_weight=10.0 \
|
||||
--rtc.prefix_attention_schedule=EXP \
|
||||
--seed=10
|
||||
|
||||
# Basic usage with pi0.5 policy
|
||||
uv run python examples/rtc/eval_dataset.py \
|
||||
--policy.path=lerobot/pi05_libero_finetuned \
|
||||
--dataset.repo_id=HuggingFaceVLA/libero \
|
||||
--rtc.execution_horizon=10 \
|
||||
--device=mps
|
||||
--seed=10
|
||||
|
||||
# Basic usage with pi0.5 policy with cuda device
|
||||
uv run python examples/rtc/eval_dataset.py \
|
||||
--policy.path=lerobot/pi05_libero_finetuned \
|
||||
--dataset.repo_id=HuggingFaceVLA/libero \
|
||||
--rtc.execution_horizon=8 \
|
||||
--device=cuda
|
||||
|
||||
# Basic usage with pi0 policy with cuda device
|
||||
uv run python examples/rtc/eval_dataset.py \
|
||||
--policy.path=lerobot/pi0_libero_finetuned \
|
||||
--dataset.repo_id=HuggingFaceVLA/libero \
|
||||
--rtc.execution_horizon=8 \
|
||||
--device=cuda
|
||||
|
||||
uv run python examples/rtc/eval_dataset.py \
|
||||
--policy.path=lipsop/reuben_pi0 \
|
||||
--dataset.repo_id=ReubenLim/so101_cube_in_cup \
|
||||
--rtc.execution_horizon=8 \
|
||||
--device=cuda
|
||||
|
||||
# With torch.compile for faster inference (PyTorch 2.0+)
|
||||
# Note: CUDA graphs disabled by default due to in-place ops in denoising loop
|
||||
uv run python examples/rtc/eval_dataset.py \
|
||||
--policy.path=helper2424/smolvla_check_rtc_last3 \
|
||||
--dataset.repo_id=helper2424/check_rtc \
|
||||
--rtc.execution_horizon=8 \
|
||||
--device=mps \
|
||||
--use_torch_compile=true \
|
||||
--torch_compile_mode=max-autotune
|
||||
|
||||
# With torch.compile on CUDA (CUDA graphs disabled by default)
|
||||
uv run python examples/rtc/eval_dataset.py \
|
||||
--policy.path=helper2424/smolvla_check_rtc_last3 \
|
||||
--dataset.repo_id=helper2424/check_rtc \
|
||||
--rtc.execution_horizon=8 \
|
||||
--device=cuda \
|
||||
--use_torch_compile=true \
|
||||
--torch_compile_mode=reduce-overhead
|
||||
|
||||
# Enable CUDA graphs (advanced - may cause tensor aliasing errors)
|
||||
uv run python examples/rtc/eval_dataset.py \
|
||||
--policy.path=helper2424/smolvla_check_rtc_last3 \
|
||||
--dataset.repo_id=helper2424/check_rtc \
|
||||
--use_torch_compile=true \
|
||||
--torch_compile_backend=inductor \
|
||||
--torch_compile_mode=max-autotune \
|
||||
--torch_compile_disable_cudagraphs=false
|
||||
"""
|
||||
|
||||
import gc
|
||||
import logging
|
||||
import os
|
||||
import random
|
||||
from dataclasses import dataclass, field
|
||||
|
||||
import numpy as np
|
||||
import torch
|
||||
|
||||
try:
|
||||
import matplotlib.pyplot as plt
|
||||
|
||||
MATPLOTLIB_AVAILABLE = True
|
||||
except ImportError:
|
||||
MATPLOTLIB_AVAILABLE = False
|
||||
plt = None
|
||||
|
||||
from lerobot.configs import parser
|
||||
from lerobot.configs.default import DatasetConfig
|
||||
from lerobot.configs.policies import PreTrainedConfig
|
||||
from lerobot.configs.types import RTCAttentionSchedule
|
||||
from lerobot.datasets.factory import resolve_delta_timestamps
|
||||
from lerobot.datasets.lerobot_dataset import LeRobotDataset, LeRobotDatasetMetadata
|
||||
from lerobot.policies.factory import get_policy_class, make_pre_post_processors
|
||||
from lerobot.policies.rtc.configuration_rtc import RTCConfig
|
||||
from lerobot.policies.rtc.debug_visualizer import RTCDebugVisualizer
|
||||
from lerobot.utils.hub import HubMixin
|
||||
from lerobot.utils.utils import init_logging
|
||||
|
||||
|
||||
def set_seed(seed: int):
|
||||
"""Set random seed for reproducibility."""
|
||||
random.seed(seed)
|
||||
np.random.seed(seed)
|
||||
torch.manual_seed(seed)
|
||||
if torch.cuda.is_available():
|
||||
torch.cuda.manual_seed(seed)
|
||||
torch.cuda.manual_seed_all(seed)
|
||||
if torch.backends.mps.is_available():
|
||||
torch.mps.manual_seed(seed)
|
||||
torch.backends.cudnn.deterministic = True
|
||||
torch.backends.cudnn.benchmark = False
|
||||
|
||||
|
||||
def _check_matplotlib_available():
|
||||
"""Check if matplotlib is available, raise helpful error if not."""
|
||||
if not MATPLOTLIB_AVAILABLE:
|
||||
raise ImportError(
|
||||
"matplotlib is required for RTC debug visualizations. "
|
||||
"Please install it by running:\n"
|
||||
" uv pip install matplotlib"
|
||||
)
|
||||
|
||||
|
||||
@dataclass
|
||||
class RTCEvalConfig(HubMixin):
|
||||
"""Configuration for RTC evaluation."""
|
||||
|
||||
# Policy configuration
|
||||
policy: PreTrainedConfig | None = None
|
||||
|
||||
# Dataset configuration
|
||||
dataset: DatasetConfig = field(default_factory=DatasetConfig)
|
||||
|
||||
# RTC configuration
|
||||
rtc: RTCConfig = field(
|
||||
default_factory=lambda: RTCConfig(
|
||||
enabled=True,
|
||||
execution_horizon=20,
|
||||
max_guidance_weight=10.0,
|
||||
prefix_attention_schedule=RTCAttentionSchedule.EXP,
|
||||
debug=True,
|
||||
debug_maxlen=1000,
|
||||
)
|
||||
)
|
||||
|
||||
# Device configuration
|
||||
device: str | None = field(
|
||||
default=None,
|
||||
metadata={"help": "Device to run on (cuda, cpu, mps, auto)"},
|
||||
)
|
||||
|
||||
# Output configuration
|
||||
output_dir: str = field(
|
||||
default="rtc_debug_output",
|
||||
metadata={"help": "Directory to save debug visualizations"},
|
||||
)
|
||||
|
||||
# Seed configuration
|
||||
seed: int = field(
|
||||
default=42,
|
||||
metadata={"help": "Random seed for reproducibility"},
|
||||
)
|
||||
|
||||
inference_delay: int = field(
|
||||
default=4,
|
||||
metadata={"help": "Inference delay for RTC"},
|
||||
)
|
||||
|
||||
# Torch compile configuration
|
||||
use_torch_compile: bool = field(
|
||||
default=False,
|
||||
metadata={"help": "Use torch.compile for faster inference (PyTorch 2.0+)"},
|
||||
)
|
||||
|
||||
torch_compile_backend: str = field(
|
||||
default="inductor",
|
||||
metadata={"help": "Backend for torch.compile (inductor, aot_eager, cudagraphs)"},
|
||||
)
|
||||
|
||||
torch_compile_mode: str = field(
|
||||
default="default",
|
||||
metadata={"help": "Compilation mode (default, reduce-overhead, max-autotune)"},
|
||||
)
|
||||
|
||||
torch_compile_disable_cudagraphs: bool = field(
|
||||
default=True,
|
||||
metadata={
|
||||
"help": "Disable CUDA graphs in torch.compile. Required due to in-place tensor "
|
||||
"operations in denoising loop (x_t += dt * v_t) which cause tensor aliasing issues."
|
||||
},
|
||||
)
|
||||
|
||||
def __post_init__(self):
|
||||
# Parse policy path
|
||||
policy_path = parser.get_path_arg("policy")
|
||||
if policy_path:
|
||||
cli_overrides = parser.get_cli_overrides("policy")
|
||||
self.policy = PreTrainedConfig.from_pretrained(policy_path, cli_overrides=cli_overrides)
|
||||
self.policy.pretrained_path = policy_path
|
||||
else:
|
||||
raise ValueError("Policy path is required (--policy.path)")
|
||||
|
||||
# Auto-detect device if not specified
|
||||
if self.device is None or self.device == "auto":
|
||||
if torch.cuda.is_available():
|
||||
self.device = "cuda"
|
||||
elif torch.backends.mps.is_available():
|
||||
self.device = "mps"
|
||||
else:
|
||||
self.device = "cpu"
|
||||
logging.info(f"Auto-detected device: {self.device}")
|
||||
|
||||
@classmethod
|
||||
def __get_path_fields__(cls) -> list[str]:
|
||||
"""This enables the parser to load config from the policy using `--policy.path=local/dir`"""
|
||||
return ["policy"]
|
||||
|
||||
|
||||
class RTCEvaluator:
|
||||
"""Evaluator for RTC on dataset samples."""
|
||||
|
||||
def __init__(self, cfg: RTCEvalConfig):
|
||||
self.cfg = cfg
|
||||
self.device = cfg.device
|
||||
|
||||
# Load dataset with proper delta_timestamps based on policy configuration
|
||||
# Calculate delta_timestamps using the same logic as make_dataset factory
|
||||
logging.info(f"Loading dataset: {cfg.dataset.repo_id}")
|
||||
|
||||
# Get dataset metadata to extract FPS
|
||||
ds_meta = LeRobotDatasetMetadata(cfg.dataset.repo_id)
|
||||
|
||||
# Calculate delta_timestamps from policy's delta_indices
|
||||
delta_timestamps = resolve_delta_timestamps(cfg.policy, ds_meta)
|
||||
|
||||
# Create dataset with calculated delta_timestamps
|
||||
self.dataset = LeRobotDataset(
|
||||
cfg.dataset.repo_id,
|
||||
delta_timestamps=delta_timestamps,
|
||||
)
|
||||
logging.info(f"Dataset loaded: {len(self.dataset)} samples, {self.dataset.num_episodes} episodes")
|
||||
|
||||
# Create preprocessor/postprocessor
|
||||
self.preprocessor, self.postprocessor = make_pre_post_processors(
|
||||
policy_cfg=cfg.policy,
|
||||
pretrained_path=cfg.policy.pretrained_path,
|
||||
preprocessor_overrides={
|
||||
"device_processor": {"device": self.device},
|
||||
},
|
||||
)
|
||||
|
||||
logging.info("=" * 80)
|
||||
logging.info("Ready to run evaluation with sequential policy loading:")
|
||||
logging.info(" 1. policy_prev_chunk - Generate reference chunk, then destroy")
|
||||
logging.info(" 2. policy_no_rtc - Generate without RTC, then destroy")
|
||||
logging.info(" 3. policy_rtc - Generate with RTC, then destroy")
|
||||
logging.info(" Note: Only one policy in memory at a time for efficient memory usage")
|
||||
logging.info("=" * 80)
|
||||
|
||||
def _init_policy(self, name: str, rtc_enabled: bool, rtc_debug: bool):
|
||||
"""Initialize a single policy instance with specified RTC configuration.
|
||||
|
||||
Args:
|
||||
name: Name identifier for logging purposes
|
||||
rtc_enabled: Whether to enable RTC for this policy
|
||||
rtc_debug: Whether to enable debug tracking for this policy
|
||||
|
||||
Returns:
|
||||
Configured policy instance with optional torch.compile applied
|
||||
"""
|
||||
logging.info(f"Initializing {name}...")
|
||||
|
||||
# Load policy from pretrained
|
||||
policy_class = get_policy_class(self.cfg.policy.type)
|
||||
|
||||
config = PreTrainedConfig.from_pretrained(self.cfg.policy.pretrained_path)
|
||||
|
||||
if self.cfg.policy.type == "pi05" or self.cfg.policy.type == "pi0":
|
||||
config.compile_model = self.cfg.use_torch_compile
|
||||
|
||||
policy = policy_class.from_pretrained(self.cfg.policy.pretrained_path, config=config)
|
||||
policy = policy.to(self.device)
|
||||
policy.eval()
|
||||
|
||||
# Configure RTC
|
||||
rtc_config = RTCConfig(
|
||||
enabled=rtc_enabled,
|
||||
execution_horizon=self.cfg.rtc.execution_horizon,
|
||||
max_guidance_weight=self.cfg.rtc.max_guidance_weight,
|
||||
prefix_attention_schedule=self.cfg.rtc.prefix_attention_schedule,
|
||||
debug=rtc_debug,
|
||||
debug_maxlen=self.cfg.rtc.debug_maxlen,
|
||||
)
|
||||
policy.config.rtc_config = rtc_config
|
||||
policy.init_rtc_processor()
|
||||
|
||||
logging.info(f" RTC enabled: {rtc_enabled}")
|
||||
logging.info(f" RTC debug: {rtc_debug}")
|
||||
logging.info(f" Policy config: {config}")
|
||||
|
||||
# Apply torch.compile to predict_action_chunk method if enabled
|
||||
if self.cfg.use_torch_compile:
|
||||
policy = self._apply_torch_compile(policy, name)
|
||||
|
||||
logging.info(f"✓ {name} initialized successfully")
|
||||
return policy
|
||||
|
||||
def _apply_torch_compile(self, policy, policy_name: str):
|
||||
"""Apply torch.compile to the policy's predict_action_chunk method.
|
||||
|
||||
Args:
|
||||
policy: Policy instance to compile
|
||||
policy_name: Name for logging purposes
|
||||
|
||||
Returns:
|
||||
Policy with compiled predict_action_chunk method
|
||||
"""
|
||||
|
||||
# PI models handle their own compilation
|
||||
if policy.type == "pi05" or policy.type == "pi0":
|
||||
return policy
|
||||
|
||||
try:
|
||||
# Check if torch.compile is available (PyTorch 2.0+)
|
||||
if not hasattr(torch, "compile"):
|
||||
logging.warning(
|
||||
f" [{policy_name}] torch.compile is not available. Requires PyTorch 2.0+. "
|
||||
f"Current version: {torch.__version__}. Skipping compilation."
|
||||
)
|
||||
return policy
|
||||
|
||||
logging.info(f" [{policy_name}] Applying torch.compile to predict_action_chunk...")
|
||||
logging.info(f" Backend: {self.cfg.torch_compile_backend}")
|
||||
logging.info(f" Mode: {self.cfg.torch_compile_mode}")
|
||||
logging.info(f" Disable CUDA graphs: {self.cfg.torch_compile_disable_cudagraphs}")
|
||||
logging.info(" Note: Debug tracker excluded from compilation via @torch._dynamo.disable")
|
||||
|
||||
# Compile the predict_action_chunk method
|
||||
# - Debug tracker is excluded from compilation via @torch._dynamo.disable
|
||||
# - CUDA graphs disabled to prevent tensor aliasing from in-place ops (x_t += dt * v_t)
|
||||
compile_kwargs = {
|
||||
"backend": self.cfg.torch_compile_backend,
|
||||
"mode": self.cfg.torch_compile_mode,
|
||||
}
|
||||
|
||||
# Disable CUDA graphs if requested (prevents tensor aliasing issues)
|
||||
if self.cfg.torch_compile_disable_cudagraphs:
|
||||
compile_kwargs["options"] = {"triton.cudagraphs": False}
|
||||
|
||||
original_method = policy.predict_action_chunk
|
||||
compiled_method = torch.compile(original_method, **compile_kwargs)
|
||||
policy.predict_action_chunk = compiled_method
|
||||
logging.info(f" ✓ [{policy_name}] Successfully compiled predict_action_chunk")
|
||||
|
||||
except Exception as e:
|
||||
logging.error(f" [{policy_name}] Failed to apply torch.compile: {e}")
|
||||
logging.warning(f" [{policy_name}] Continuing without torch.compile")
|
||||
|
||||
return policy
|
||||
|
||||
def _destroy_policy(self, policy, policy_name: str):
|
||||
"""Explicitly destroy a policy and free all associated memory.
|
||||
|
||||
This method performs aggressive cleanup to ensure maximum memory is freed,
|
||||
which is critical for large models (e.g., VLAs with billions of parameters).
|
||||
|
||||
Args:
|
||||
policy: Policy instance to destroy
|
||||
policy_name: Name for logging purposes
|
||||
"""
|
||||
logging.info(f" Destroying {policy_name} and freeing memory...")
|
||||
|
||||
try:
|
||||
# Step 1: Move policy to CPU to free GPU/MPS memory
|
||||
policy.cpu()
|
||||
|
||||
# Step 2: Delete the policy object
|
||||
del policy
|
||||
|
||||
# Step 3: Force garbage collection to reclaim memory immediately
|
||||
gc.collect()
|
||||
|
||||
# Step 4: Clear device-specific caches
|
||||
if torch.cuda.is_available():
|
||||
torch.cuda.empty_cache()
|
||||
torch.cuda.synchronize() # Ensure all operations complete
|
||||
|
||||
if torch.backends.mps.is_available():
|
||||
torch.mps.empty_cache()
|
||||
|
||||
logging.info(f" ✓ {policy_name} destroyed and memory freed")
|
||||
|
||||
except Exception as e:
|
||||
logging.warning(f" Warning: Error during {policy_name} cleanup: {e}")
|
||||
|
||||
def run_evaluation(self):
|
||||
"""Run evaluation on two random dataset samples using three separate policies.
|
||||
|
||||
Note: Policies are deinitalized after each step to free memory. Large models
|
||||
(e.g., VLA models with billions of parameters) cannot fit three instances in
|
||||
memory simultaneously. By deleting and garbage collecting after each step,
|
||||
we ensure only one policy is loaded at a time.
|
||||
"""
|
||||
# Create output directory
|
||||
os.makedirs(self.cfg.output_dir, exist_ok=True)
|
||||
logging.info(f"Output directory: {self.cfg.output_dir}")
|
||||
|
||||
logging.info("=" * 80)
|
||||
logging.info("Starting RTC evaluation")
|
||||
logging.info(f"Inference delay: {self.cfg.inference_delay}")
|
||||
logging.info("=" * 80)
|
||||
|
||||
# Load two random samples from dataset
|
||||
data_loader = torch.utils.data.DataLoader(self.dataset, batch_size=1, shuffle=True)
|
||||
loader_iter = iter(data_loader)
|
||||
first_sample = next(loader_iter)
|
||||
second_sample = next(loader_iter)
|
||||
|
||||
preprocessed_first_sample = self.preprocessor(first_sample)
|
||||
preprocessed_second_sample = self.preprocessor(second_sample)
|
||||
|
||||
# ============================================================================
|
||||
# Step 1: Generate previous chunk using policy_prev_chunk
|
||||
# ============================================================================
|
||||
# This policy is only used to generate the reference chunk and then freed
|
||||
logging.info("=" * 80)
|
||||
logging.info("Step 1: Generating previous chunk with policy_prev_chunk")
|
||||
logging.info("=" * 80)
|
||||
|
||||
# Initialize policy 1
|
||||
policy_prev_chunk_policy = self._init_policy(
|
||||
name="policy_prev_chunk",
|
||||
rtc_enabled=False,
|
||||
rtc_debug=False,
|
||||
)
|
||||
with torch.no_grad():
|
||||
prev_chunk_left_over = policy_prev_chunk_policy.predict_action_chunk(
|
||||
preprocessed_first_sample,
|
||||
)[:, :25, :].squeeze(0)
|
||||
logging.info(f" Generated prev_chunk shape: {prev_chunk_left_over.shape}")
|
||||
|
||||
# Destroy policy_prev_chunk to free memory for large models
|
||||
self._destroy_policy(policy_prev_chunk_policy, "policy_prev_chunk")
|
||||
|
||||
# ============================================================================
|
||||
# Step 2: Generate actions WITHOUT RTC using policy_no_rtc
|
||||
# ============================================================================
|
||||
logging.info("=" * 80)
|
||||
logging.info("Step 2: Generating actions WITHOUT RTC with policy_no_rtc")
|
||||
logging.info("=" * 80)
|
||||
|
||||
set_seed(self.cfg.seed)
|
||||
|
||||
# Initialize policy 2
|
||||
policy_no_rtc_policy = self._init_policy(
|
||||
name="policy_no_rtc",
|
||||
rtc_enabled=False,
|
||||
rtc_debug=True,
|
||||
)
|
||||
|
||||
# Sample noise (use same noise for both RTC and non-RTC for fair comparison)
|
||||
noise_size = (1, policy_no_rtc_policy.config.chunk_size, policy_no_rtc_policy.config.max_action_dim)
|
||||
noise = policy_no_rtc_policy.model.sample_noise(noise_size, self.device)
|
||||
noise_clone = noise.clone()
|
||||
policy_no_rtc_policy.rtc_processor.reset_tracker()
|
||||
with torch.no_grad():
|
||||
no_rtc_actions = policy_no_rtc_policy.predict_action_chunk(
|
||||
preprocessed_second_sample,
|
||||
noise=noise,
|
||||
)
|
||||
no_rtc_tracked_steps = policy_no_rtc_policy.rtc_processor.tracker.get_all_steps()
|
||||
logging.info(f" Tracked {len(no_rtc_tracked_steps)} steps without RTC")
|
||||
logging.info(f" Generated no_rtc_actions shape: {no_rtc_actions.shape}")
|
||||
|
||||
# Destroy policy_no_rtc to free memory before loading policy_rtc
|
||||
self._destroy_policy(policy_no_rtc_policy, "policy_no_rtc")
|
||||
|
||||
# ============================================================================
|
||||
# Step 3: Generate actions WITH RTC using policy_rtc
|
||||
# ============================================================================
|
||||
logging.info("=" * 80)
|
||||
logging.info("Step 3: Generating actions WITH RTC with policy_rtc")
|
||||
logging.info("=" * 80)
|
||||
|
||||
set_seed(self.cfg.seed)
|
||||
|
||||
# Initialize policy 3
|
||||
policy_rtc_policy = self._init_policy(
|
||||
name="policy_rtc",
|
||||
rtc_enabled=True,
|
||||
rtc_debug=True,
|
||||
)
|
||||
policy_rtc_policy.rtc_processor.reset_tracker()
|
||||
with torch.no_grad():
|
||||
rtc_actions = policy_rtc_policy.predict_action_chunk(
|
||||
preprocessed_second_sample,
|
||||
noise=noise_clone,
|
||||
inference_delay=self.cfg.inference_delay,
|
||||
prev_chunk_left_over=prev_chunk_left_over,
|
||||
execution_horizon=self.cfg.rtc.execution_horizon,
|
||||
)
|
||||
rtc_tracked_steps = policy_rtc_policy.rtc_processor.get_all_debug_steps()
|
||||
logging.info(f" Tracked {len(rtc_tracked_steps)} steps with RTC")
|
||||
logging.info(f" Generated rtc_actions shape: {rtc_actions.shape}")
|
||||
|
||||
# Save num_steps before destroying policy (needed for plotting)
|
||||
try:
|
||||
num_steps = policy_rtc_policy.config.num_steps
|
||||
except Exception as e:
|
||||
logging.error(f" Error getting num_steps: {e}")
|
||||
num_steps = policy_rtc_policy.config.num_inference_steps
|
||||
logging.warning(f" Using num_inference_steps: {num_steps} instead of num_steps")
|
||||
|
||||
# Destroy policy_rtc after final use
|
||||
self._destroy_policy(policy_rtc_policy, "policy_rtc")
|
||||
|
||||
# Plot and save results
|
||||
logging.info("=" * 80)
|
||||
logging.info("Plotting results...")
|
||||
self.plot_tracked_data(rtc_tracked_steps, no_rtc_tracked_steps, prev_chunk_left_over, num_steps)
|
||||
|
||||
# Plot final actions comparison
|
||||
logging.info("=" * 80)
|
||||
logging.info("Plotting final actions comparison...")
|
||||
self.plot_final_actions_comparison(rtc_actions, no_rtc_actions, prev_chunk_left_over)
|
||||
|
||||
logging.info("=" * 80)
|
||||
logging.info("Evaluation completed successfully")
|
||||
|
||||
def plot_final_actions_comparison(self, rtc_actions, no_rtc_actions, prev_chunk_left_over):
|
||||
"""Plot final action predictions comparison on a single chart.
|
||||
|
||||
Args:
|
||||
rtc_actions: Final actions from RTC policy
|
||||
no_rtc_actions: Final actions from non-RTC policy
|
||||
prev_chunk_left_over: Previous chunk used as ground truth
|
||||
"""
|
||||
_check_matplotlib_available()
|
||||
|
||||
# Remove batch dimension if present
|
||||
rtc_actions_plot = rtc_actions.squeeze(0).cpu() if len(rtc_actions.shape) == 3 else rtc_actions.cpu()
|
||||
no_rtc_actions_plot = (
|
||||
no_rtc_actions.squeeze(0).cpu() if len(no_rtc_actions.shape) == 3 else no_rtc_actions.cpu()
|
||||
)
|
||||
prev_chunk_plot = prev_chunk_left_over.cpu()
|
||||
|
||||
# Create figure with 6 subplots (one per action dimension)
|
||||
fig, axes = plt.subplots(6, 1, figsize=(16, 12))
|
||||
fig.suptitle("Final Action Predictions Comparison (Raw)", fontsize=16)
|
||||
|
||||
# Plot each action dimension
|
||||
for dim_idx, ax in enumerate(axes):
|
||||
# Plot previous chunk (ground truth) in red
|
||||
RTCDebugVisualizer.plot_waypoints(
|
||||
[ax],
|
||||
prev_chunk_plot[:, dim_idx : dim_idx + 1],
|
||||
start_from=0,
|
||||
color="red",
|
||||
label="Previous Chunk (Ground Truth)",
|
||||
linewidth=2.5,
|
||||
alpha=0.8,
|
||||
)
|
||||
|
||||
# Plot no-RTC actions in blue
|
||||
RTCDebugVisualizer.plot_waypoints(
|
||||
[ax],
|
||||
no_rtc_actions_plot[:, dim_idx : dim_idx + 1],
|
||||
start_from=0,
|
||||
color="blue",
|
||||
label="No RTC",
|
||||
linewidth=2,
|
||||
alpha=0.7,
|
||||
)
|
||||
|
||||
# Plot RTC actions in green
|
||||
RTCDebugVisualizer.plot_waypoints(
|
||||
[ax],
|
||||
rtc_actions_plot[:, dim_idx : dim_idx + 1],
|
||||
start_from=0,
|
||||
color="green",
|
||||
label="RTC",
|
||||
linewidth=2,
|
||||
alpha=0.7,
|
||||
)
|
||||
|
||||
# Add vertical lines for inference delay and execution horizon
|
||||
inference_delay = self.cfg.inference_delay
|
||||
execution_horizon = self.cfg.rtc.execution_horizon
|
||||
|
||||
if inference_delay > 0:
|
||||
ax.axvline(
|
||||
x=inference_delay - 1,
|
||||
color="orange",
|
||||
linestyle="--",
|
||||
alpha=0.5,
|
||||
label=f"Inference Delay ({inference_delay})",
|
||||
)
|
||||
|
||||
if execution_horizon > 0:
|
||||
ax.axvline(
|
||||
x=execution_horizon,
|
||||
color="purple",
|
||||
linestyle="--",
|
||||
alpha=0.5,
|
||||
label=f"Execution Horizon ({execution_horizon})",
|
||||
)
|
||||
|
||||
ax.set_ylabel(f"Dim {dim_idx}", fontsize=10)
|
||||
ax.grid(True, alpha=0.3)
|
||||
|
||||
# Set x-axis ticks to show all integer values
|
||||
max_len = max(rtc_actions_plot.shape[0], no_rtc_actions_plot.shape[0], prev_chunk_plot.shape[0])
|
||||
ax.set_xticks(range(0, max_len, max(1, max_len // 20))) # Show ~20 ticks
|
||||
ax.set_xlim(-0.5, max_len - 0.5)
|
||||
|
||||
axes[-1].set_xlabel("Step", fontsize=10)
|
||||
|
||||
# Collect legend handles and labels from first subplot
|
||||
handles, labels = axes[0].get_legend_handles_labels()
|
||||
# Remove duplicates while preserving order
|
||||
seen = set()
|
||||
unique_handles = []
|
||||
unique_labels = []
|
||||
for handle, label in zip(handles, labels, strict=True):
|
||||
if label not in seen:
|
||||
seen.add(label)
|
||||
unique_handles.append(handle)
|
||||
unique_labels.append(label)
|
||||
|
||||
# Add legend outside the plot area (to the right)
|
||||
fig.legend(
|
||||
unique_handles,
|
||||
unique_labels,
|
||||
loc="center right",
|
||||
fontsize=9,
|
||||
bbox_to_anchor=(1.0, 0.5),
|
||||
framealpha=0.9,
|
||||
)
|
||||
|
||||
# Save figure
|
||||
output_path = os.path.join(self.cfg.output_dir, "final_actions_comparison.png")
|
||||
fig.tight_layout(rect=[0, 0, 0.85, 1]) # Leave space for legend on right
|
||||
fig.savefig(output_path, dpi=150, bbox_inches="tight")
|
||||
logging.info(f"Saved final actions comparison to {output_path}")
|
||||
plt.close(fig)
|
||||
|
||||
def plot_tracked_data(self, rtc_tracked_steps, no_rtc_tracked_steps, prev_chunk_left_over, num_steps):
|
||||
_check_matplotlib_available()
|
||||
|
||||
# Create side-by-side figures for denoising visualization
|
||||
fig_xt, axs_xt = self._create_figure("x_t Denoising: No RTC (left) vs RTC (right)")
|
||||
fig_vt, axs_vt = self._create_figure("v_t Denoising: No RTC (left) vs RTC (right)")
|
||||
fig_corr, axs_corr = self._create_figure("Correction: No RTC (left) vs RTC (right)")
|
||||
fig_x1t, axs_x1t = self._create_figure(
|
||||
"x1_t Predicted State & Error: No RTC (left - empty) vs RTC (right)"
|
||||
)
|
||||
self._plot_denoising_steps_from_tracker(
|
||||
rtc_tracked_steps,
|
||||
axs_xt[:, 1], # Right column for x_t
|
||||
axs_vt[:, 1], # Right column for v_t
|
||||
axs_corr[:, 1], # Right column for correction
|
||||
axs_x1t[:, 1], # Right column for x1_t
|
||||
num_steps,
|
||||
add_labels=True, # Add labels for RTC (right column)
|
||||
)
|
||||
|
||||
self._plot_denoising_steps_from_tracker(
|
||||
no_rtc_tracked_steps,
|
||||
axs_xt[:, 0], # Left column for x_t
|
||||
axs_vt[:, 0], # Left column for v_t
|
||||
axs_corr[:, 0], # Left column for correction
|
||||
axs_x1t[:, 0], # Left column for x1_t
|
||||
num_steps,
|
||||
add_labels=False, # No labels for No RTC (left column)
|
||||
)
|
||||
|
||||
# Plot no-RTC x_t data on right chart as orange dashed line for comparison
|
||||
self._plot_no_rtc_xt_reference(no_rtc_tracked_steps, axs_xt[:, 1], num_steps)
|
||||
|
||||
# Plot ground truth on x_t axes
|
||||
RTCDebugVisualizer.plot_waypoints(
|
||||
axs_xt[:, 1], prev_chunk_left_over, start_from=0, color="red", label="Ground truth"
|
||||
)
|
||||
|
||||
# Plot ground truth on x1_t axes
|
||||
RTCDebugVisualizer.plot_waypoints(
|
||||
axs_x1t[:, 1], prev_chunk_left_over, start_from=0, color="red", label="Ground truth"
|
||||
)
|
||||
|
||||
# Plot ground truth on x_t axes (no labels for left column)
|
||||
RTCDebugVisualizer.plot_waypoints(
|
||||
axs_xt[:, 0], prev_chunk_left_over, start_from=0, color="red", label=None
|
||||
)
|
||||
|
||||
RTCDebugVisualizer.plot_waypoints(
|
||||
axs_x1t[:, 0], prev_chunk_left_over, start_from=0, color="red", label=None
|
||||
)
|
||||
|
||||
# Add legends outside the plot area for each figure
|
||||
self._add_figure_legend(fig_xt, axs_xt)
|
||||
self._add_figure_legend(fig_vt, axs_vt)
|
||||
self._add_figure_legend(fig_corr, axs_corr)
|
||||
self._add_figure_legend(fig_x1t, axs_x1t)
|
||||
|
||||
# Save denoising plots
|
||||
self._save_figure(fig_xt, os.path.join(self.cfg.output_dir, "denoising_xt_comparison.png"))
|
||||
self._save_figure(fig_vt, os.path.join(self.cfg.output_dir, "denoising_vt_comparison.png"))
|
||||
self._save_figure(fig_corr, os.path.join(self.cfg.output_dir, "denoising_correction_comparison.png"))
|
||||
self._save_figure(fig_x1t, os.path.join(self.cfg.output_dir, "denoising_x1t_comparison.png"))
|
||||
|
||||
def _create_figure(self, title):
|
||||
fig, axs = plt.subplots(6, 2, figsize=(24, 12))
|
||||
fig.suptitle(title, fontsize=16)
|
||||
|
||||
for ax in axs[:, 0]:
|
||||
ax.set_title("No RTC (N/A)" if ax == axs[0, 0] else "", fontsize=12)
|
||||
for ax in axs[:, 1]:
|
||||
ax.set_title("RTC" if ax == axs[0, 1] else "", fontsize=12)
|
||||
|
||||
return fig, axs
|
||||
|
||||
def _add_figure_legend(self, fig, axs):
|
||||
"""Add a legend outside the plot area on the right side.
|
||||
|
||||
Args:
|
||||
fig: Matplotlib figure to add legend to
|
||||
axs: Array of axes to collect legend handles from
|
||||
"""
|
||||
# Collect all handles and labels from the first row of axes (right column)
|
||||
handles, labels = axs[0, 1].get_legend_handles_labels()
|
||||
|
||||
# Remove duplicates while preserving order
|
||||
seen = set()
|
||||
unique_handles = []
|
||||
unique_labels = []
|
||||
for handle, label in zip(handles, labels, strict=True):
|
||||
if label not in seen:
|
||||
seen.add(label)
|
||||
unique_handles.append(handle)
|
||||
unique_labels.append(label)
|
||||
|
||||
# Add legend outside the plot area (to the right, close to charts)
|
||||
if unique_handles:
|
||||
fig.legend(
|
||||
unique_handles,
|
||||
unique_labels,
|
||||
loc="center left",
|
||||
fontsize=8,
|
||||
bbox_to_anchor=(0.87, 0.5),
|
||||
framealpha=0.9,
|
||||
ncol=1,
|
||||
)
|
||||
|
||||
def _save_figure(self, fig, path):
|
||||
fig.tight_layout(rect=[0, 0, 0.85, 1]) # Leave space for legend/colorbar on right
|
||||
fig.savefig(path, dpi=150, bbox_inches="tight")
|
||||
logging.info(f"Saved figure to {path}")
|
||||
plt.close(fig)
|
||||
|
||||
def _plot_denoising_steps_from_tracker(
|
||||
self, tracked_steps, xt_axs, vt_axs, corr_axs, x1t_axs, num_steps, add_labels=True
|
||||
):
|
||||
"""Plot denoising steps from tracker data.
|
||||
|
||||
Args:
|
||||
tracked_steps: List of DebugStep objects containing debug steps
|
||||
xt_axs: Matplotlib axes for x_t plots (array of 6 axes)
|
||||
vt_axs: Matplotlib axes for v_t plots (array of 6 axes)
|
||||
corr_axs: Matplotlib axes for correction plots (array of 6 axes)
|
||||
x1t_axs: Matplotlib axes for x1_t plots (array of 6 axes)
|
||||
num_steps: Total number of denoising steps for colormap
|
||||
add_labels: Whether to add legend labels for the plots
|
||||
"""
|
||||
|
||||
logging.info("=" * 80)
|
||||
logging.info(f"Plotting {len(tracked_steps)} steps")
|
||||
|
||||
debug_steps = tracked_steps
|
||||
if not debug_steps:
|
||||
return
|
||||
|
||||
# Define colors for different denoise steps (using a colormap)
|
||||
colors = plt.cm.viridis(np.linspace(0, 1, num_steps))
|
||||
|
||||
for step_idx, debug_step in enumerate(debug_steps):
|
||||
color = colors[step_idx % len(colors)]
|
||||
label = f"Step {step_idx}" if add_labels else None
|
||||
|
||||
# Plot x_t
|
||||
if debug_step.x_t is not None:
|
||||
RTCDebugVisualizer.plot_waypoints(
|
||||
xt_axs, debug_step.x_t, start_from=0, color=color, label=label
|
||||
)
|
||||
|
||||
# Plot v_t
|
||||
if debug_step.v_t is not None:
|
||||
RTCDebugVisualizer.plot_waypoints(
|
||||
vt_axs, debug_step.v_t, start_from=0, color=color, label=label
|
||||
)
|
||||
|
||||
# Plot correction on separate axes
|
||||
if debug_step.correction is not None:
|
||||
RTCDebugVisualizer.plot_waypoints(
|
||||
corr_axs,
|
||||
debug_step.correction,
|
||||
start_from=0,
|
||||
color=color,
|
||||
label=label,
|
||||
)
|
||||
|
||||
# Plot x1_t (predicted state)
|
||||
if x1t_axs is not None and debug_step.x1_t is not None:
|
||||
x1t_label = f"x1_t Step {step_idx}" if add_labels else None
|
||||
RTCDebugVisualizer.plot_waypoints(
|
||||
x1t_axs,
|
||||
debug_step.x1_t,
|
||||
start_from=0,
|
||||
color=color,
|
||||
label=x1t_label,
|
||||
)
|
||||
|
||||
# Plot error in orange dashed
|
||||
if x1t_axs is not None and debug_step.err is not None:
|
||||
error_chunk = (
|
||||
debug_step.err[0].cpu().numpy()
|
||||
if len(debug_step.err.shape) == 3
|
||||
else debug_step.err.cpu().numpy()
|
||||
)
|
||||
|
||||
num_dims = min(error_chunk.shape[-1], 6)
|
||||
error_label = f"error Step {step_idx}" if add_labels else None
|
||||
for j in range(num_dims):
|
||||
x1t_axs[j].plot(
|
||||
np.arange(0, error_chunk.shape[0]),
|
||||
error_chunk[:, j],
|
||||
color="orange",
|
||||
linestyle="--",
|
||||
alpha=0.7,
|
||||
label=error_label,
|
||||
)
|
||||
|
||||
# Recalculate axis limits after plotting to ensure proper scaling
|
||||
self._rescale_axes(xt_axs)
|
||||
self._rescale_axes(vt_axs)
|
||||
self._rescale_axes(corr_axs)
|
||||
self._rescale_axes(x1t_axs)
|
||||
|
||||
def _plot_no_rtc_xt_reference(self, no_rtc_tracked_steps, xt_axs, num_steps):
|
||||
"""Plot final no-RTC x_t data as orange dashed line on the RTC chart for comparison.
|
||||
|
||||
Args:
|
||||
no_rtc_tracked_steps: List of DebugStep objects containing no-RTC debug steps
|
||||
xt_axs: Matplotlib axes for x_t plots (array of 6 axes, right column)
|
||||
num_steps: Total number of denoising steps for colormap
|
||||
"""
|
||||
debug_steps = no_rtc_tracked_steps
|
||||
if not debug_steps:
|
||||
return
|
||||
|
||||
# Plot only the final x_t step as orange dashed line
|
||||
final_step = debug_steps[-1]
|
||||
logging.info("Plotting final no-RTC x_t step as orange dashed reference")
|
||||
|
||||
if final_step.x_t is not None:
|
||||
x_t_chunk = (
|
||||
final_step.x_t[0].cpu().numpy()
|
||||
if len(final_step.x_t.shape) == 3
|
||||
else final_step.x_t.cpu().numpy()
|
||||
)
|
||||
|
||||
num_dims = min(x_t_chunk.shape[-1], 6)
|
||||
for j in range(num_dims):
|
||||
xt_axs[j].plot(
|
||||
np.arange(0, x_t_chunk.shape[0]),
|
||||
x_t_chunk[:, j],
|
||||
color="orange",
|
||||
linestyle="--",
|
||||
alpha=0.7,
|
||||
linewidth=2,
|
||||
label="No RTC (final)" if j == 0 else "",
|
||||
)
|
||||
|
||||
def _rescale_axes(self, axes):
|
||||
"""Rescale axes to show all data with proper margins.
|
||||
|
||||
Args:
|
||||
axes: Array of matplotlib axes to rescale
|
||||
"""
|
||||
for ax in axes:
|
||||
ax.relim()
|
||||
ax.autoscale_view()
|
||||
|
||||
# Add 10% margin to y-axis for better visualization
|
||||
ylim = ax.get_ylim()
|
||||
y_range = ylim[1] - ylim[0]
|
||||
if y_range > 0: # Avoid division by zero
|
||||
margin = y_range * 0.1
|
||||
ax.set_ylim(ylim[0] - margin, ylim[1] + margin)
|
||||
|
||||
# Set x-axis ticks to show all integer values
|
||||
xlim = ax.get_xlim()
|
||||
max_len = int(xlim[1]) + 1
|
||||
if max_len > 0:
|
||||
ax.set_xticks(range(0, max_len, max(1, max_len // 20))) # Show ~20 ticks
|
||||
ax.set_xlim(-0.5, max_len - 0.5)
|
||||
|
||||
|
||||
@parser.wrap()
|
||||
def main(cfg: RTCEvalConfig):
|
||||
"""Main entry point for RTC evaluation."""
|
||||
# Set random seed for reproducibility
|
||||
set_seed(cfg.seed)
|
||||
|
||||
init_logging()
|
||||
|
||||
logging.info("=" * 80)
|
||||
logging.info("RTC Dataset Evaluation")
|
||||
logging.info(f"Config: {cfg}")
|
||||
logging.info("=" * 80)
|
||||
|
||||
evaluator = RTCEvaluator(cfg)
|
||||
evaluator.run_evaluation()
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
main()
|
||||
@@ -0,0 +1,549 @@
|
||||
#!/usr/bin/env python
|
||||
|
||||
# Copyright 2025 The HuggingFace Inc. team. All rights reserved.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
"""
|
||||
Demo script showing how to use Real-Time Chunking (RTC) with action chunking policies on real robots.
|
||||
|
||||
This script demonstrates:
|
||||
1. Creating a robot and policy (SmolVLA, Pi0, etc.) with RTC
|
||||
2. Consuming actions from the policy while the robot executes
|
||||
3. Periodically requesting new action chunks in the background using threads
|
||||
4. Managing action buffers and timing for real-time operation
|
||||
|
||||
For simulation environments, see eval_with_simulation.py
|
||||
|
||||
Usage:
|
||||
# Run RTC with Real robot with RTC
|
||||
uv run examples/rtc/eval_with_real_robot.py \
|
||||
--policy.path=helper2424/smolvla_check_rtc_last3 \
|
||||
--policy.device=mps \
|
||||
--rtc.enabled=true \
|
||||
--rtc.execution_horizon=20 \
|
||||
--robot.type=so100_follower \
|
||||
--robot.port=/dev/tty.usbmodem58FA0834591 \
|
||||
--robot.id=so100_follower \
|
||||
--robot.cameras="{ gripper: {type: opencv, index_or_path: 1, width: 640, height: 480, fps: 30}, front: {type: opencv, index_or_path: 0, width: 640, height: 480, fps: 30}}" \
|
||||
--task="Move green small object into the purple platform" \
|
||||
--duration=120
|
||||
|
||||
# Run RTC with Real robot without RTC
|
||||
uv run examples/rtc/eval_with_real_robot.py \
|
||||
--policy.path=helper2424/smolvla_check_rtc_last3 \
|
||||
--policy.device=mps \
|
||||
--rtc.enabled=false \
|
||||
--robot.type=so100_follower \
|
||||
--robot.port=/dev/tty.usbmodem58FA0834591 \
|
||||
--robot.id=so100_follower \
|
||||
--robot.cameras="{ gripper: {type: opencv, index_or_path: 1, width: 640, height: 480, fps: 30}, front: {type: opencv, index_or_path: 0, width: 640, height: 480, fps: 30}}" \
|
||||
--task="Move green small object into the purple platform" \
|
||||
--duration=120
|
||||
|
||||
# Run RTC with Real robot with pi0.5 policy
|
||||
uv run examples/rtc/eval_with_real_robot.py \
|
||||
--policy.path=helper2424/pi05_check_rtc \
|
||||
--policy.device=mps \
|
||||
--rtc.enabled=true \
|
||||
--rtc.execution_horizon=20 \
|
||||
--robot.type=so100_follower \
|
||||
--robot.port=/dev/tty.usbmodem58FA0834591 \
|
||||
--robot.id=so100_follower \
|
||||
--robot.cameras="{ gripper: {type: opencv, index_or_path: 0, width: 640, height: 480, fps: 30}, front: {type: opencv, index_or_path: 1, width: 640, height: 480, fps: 30}}" \
|
||||
--task="Move green small object into the purple platform" \
|
||||
--duration=120
|
||||
"""
|
||||
|
||||
import logging
|
||||
import math
|
||||
import sys
|
||||
import time
|
||||
import traceback
|
||||
from dataclasses import dataclass, field
|
||||
from threading import Event, Lock, Thread
|
||||
|
||||
import torch
|
||||
from torch import Tensor
|
||||
|
||||
from lerobot.cameras.opencv.configuration_opencv import OpenCVCameraConfig # noqa: F401
|
||||
from lerobot.cameras.realsense.configuration_realsense import RealSenseCameraConfig # noqa: F401
|
||||
from lerobot.configs import parser
|
||||
from lerobot.configs.policies import PreTrainedConfig
|
||||
from lerobot.configs.types import RTCAttentionSchedule
|
||||
from lerobot.datasets.utils import build_dataset_frame, hw_to_dataset_features
|
||||
from lerobot.policies.factory import get_policy_class, make_pre_post_processors
|
||||
from lerobot.policies.rtc.action_queue import ActionQueue
|
||||
from lerobot.policies.rtc.configuration_rtc import RTCConfig
|
||||
from lerobot.policies.rtc.latency_tracker import LatencyTracker
|
||||
from lerobot.processor.factory import (
|
||||
make_default_robot_action_processor,
|
||||
make_default_robot_observation_processor,
|
||||
)
|
||||
from lerobot.rl.process import ProcessSignalHandler
|
||||
from lerobot.robots import ( # noqa: F401
|
||||
Robot,
|
||||
RobotConfig,
|
||||
koch_follower,
|
||||
so100_follower,
|
||||
so101_follower,
|
||||
)
|
||||
from lerobot.robots.utils import make_robot_from_config
|
||||
from lerobot.utils.constants import OBS_IMAGES
|
||||
from lerobot.utils.hub import HubMixin
|
||||
from lerobot.utils.utils import init_logging
|
||||
|
||||
logging.basicConfig(level=logging.INFO)
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
class RobotWrapper:
|
||||
def __init__(self, robot: Robot):
|
||||
self.robot = robot
|
||||
self.lock = Lock()
|
||||
|
||||
def get_observation(self) -> dict[str, Tensor]:
|
||||
with self.lock:
|
||||
return self.robot.get_observation()
|
||||
|
||||
def send_action(self, action: Tensor):
|
||||
with self.lock:
|
||||
self.robot.send_action(action)
|
||||
|
||||
def observation_features(self) -> list[str]:
|
||||
with self.lock:
|
||||
return self.robot.observation_features
|
||||
|
||||
def action_features(self) -> list[str]:
|
||||
with self.lock:
|
||||
return self.robot.action_features
|
||||
|
||||
|
||||
@dataclass
|
||||
class RTCDemoConfig(HubMixin):
|
||||
"""Configuration for RTC demo with action chunking policies and real robots."""
|
||||
|
||||
# Policy configuration
|
||||
policy: PreTrainedConfig | None = None
|
||||
|
||||
# Robot configuration
|
||||
robot: RobotConfig | None = None
|
||||
|
||||
# RTC configuration
|
||||
rtc: RTCConfig = field(
|
||||
default_factory=lambda: RTCConfig(
|
||||
execution_horizon=10,
|
||||
max_guidance_weight=1.0,
|
||||
prefix_attention_schedule=RTCAttentionSchedule.EXP,
|
||||
)
|
||||
)
|
||||
|
||||
# Demo parameters
|
||||
duration: float = 30.0 # Duration to run the demo (seconds)
|
||||
fps: float = 10.0 # Action execution frequency (Hz)
|
||||
|
||||
# Compute device
|
||||
device: str | None = None # Device to run on (cuda, cpu, auto)
|
||||
|
||||
# Get new actions horizon. The amount of executed steps after which will be requested new actions.
|
||||
# It should be higher than inference delay + execution horizon.
|
||||
action_queue_size_to_get_new_actions: int = 30
|
||||
|
||||
# Task to execute
|
||||
task: str = field(default="", metadata={"help": "Task to execute"})
|
||||
|
||||
# Torch compile configuration
|
||||
use_torch_compile: bool = field(
|
||||
default=False,
|
||||
metadata={"help": "Use torch.compile for faster inference (PyTorch 2.0+)"},
|
||||
)
|
||||
|
||||
torch_compile_backend: str = field(
|
||||
default="inductor",
|
||||
metadata={"help": "Backend for torch.compile (inductor, aot_eager, cudagraphs)"},
|
||||
)
|
||||
|
||||
torch_compile_mode: str = field(
|
||||
default="default",
|
||||
metadata={"help": "Compilation mode (default, reduce-overhead, max-autotune)"},
|
||||
)
|
||||
|
||||
torch_compile_disable_cudagraphs: bool = field(
|
||||
default=True,
|
||||
metadata={
|
||||
"help": "Disable CUDA graphs in torch.compile. Required due to in-place tensor "
|
||||
"operations in denoising loop (x_t += dt * v_t) which cause tensor aliasing issues."
|
||||
},
|
||||
)
|
||||
|
||||
def __post_init__(self):
|
||||
# HACK: We parse again the cli args here to get the pretrained path if there was one.
|
||||
policy_path = parser.get_path_arg("policy")
|
||||
if policy_path:
|
||||
cli_overrides = parser.get_cli_overrides("policy")
|
||||
self.policy = PreTrainedConfig.from_pretrained(policy_path, cli_overrides=cli_overrides)
|
||||
self.policy.pretrained_path = policy_path
|
||||
else:
|
||||
raise ValueError("Policy path is required")
|
||||
|
||||
# Validate that robot configuration is provided
|
||||
if self.robot is None:
|
||||
raise ValueError("Robot configuration must be provided")
|
||||
|
||||
@classmethod
|
||||
def __get_path_fields__(cls) -> list[str]:
|
||||
"""This enables the parser to load config from the policy using `--policy.path=local/dir`"""
|
||||
return ["policy"]
|
||||
|
||||
|
||||
def is_image_key(k: str) -> bool:
|
||||
return k.startswith(OBS_IMAGES)
|
||||
|
||||
|
||||
def get_actions(
|
||||
policy,
|
||||
robot: RobotWrapper,
|
||||
robot_observation_processor,
|
||||
action_queue: ActionQueue,
|
||||
shutdown_event: Event,
|
||||
cfg: RTCDemoConfig,
|
||||
):
|
||||
"""Thread function to request action chunks from the policy.
|
||||
|
||||
Args:
|
||||
policy: The policy instance (SmolVLA, Pi0, etc.)
|
||||
robot: The robot instance for getting observations
|
||||
robot_observation_processor: Processor for raw robot observations
|
||||
action_queue: Queue to put new action chunks
|
||||
shutdown_event: Event to signal shutdown
|
||||
cfg: Demo configuration
|
||||
"""
|
||||
try:
|
||||
logger.info("[GET_ACTIONS] Starting get actions thread")
|
||||
|
||||
latency_tracker = LatencyTracker() # Track latency of action chunks
|
||||
fps = cfg.fps
|
||||
time_per_chunk = 1.0 / fps
|
||||
|
||||
dataset_features = hw_to_dataset_features(robot.observation_features(), "observation")
|
||||
policy_device = policy.config.device
|
||||
|
||||
# Load preprocessor and postprocessor from pretrained files
|
||||
# The stats are embedded in the processor .safetensors files
|
||||
logger.info(f"[GET_ACTIONS] Loading preprocessor/postprocessor from {cfg.policy.pretrained_path}")
|
||||
|
||||
preprocessor, postprocessor = make_pre_post_processors(
|
||||
policy_cfg=cfg.policy,
|
||||
pretrained_path=cfg.policy.pretrained_path,
|
||||
dataset_stats=None, # Will load from pretrained processor files
|
||||
preprocessor_overrides={
|
||||
"device_processor": {"device": cfg.policy.device},
|
||||
},
|
||||
)
|
||||
|
||||
logger.info("[GET_ACTIONS] Preprocessor/postprocessor loaded successfully with embedded stats")
|
||||
|
||||
get_actions_threshold = cfg.action_queue_size_to_get_new_actions
|
||||
|
||||
if not cfg.rtc.enabled:
|
||||
get_actions_threshold = 0
|
||||
|
||||
while not shutdown_event.is_set():
|
||||
if action_queue.qsize() <= get_actions_threshold:
|
||||
current_time = time.perf_counter()
|
||||
action_index_before_inference = action_queue.get_action_index()
|
||||
prev_actions = action_queue.get_left_over()
|
||||
|
||||
inference_latency = latency_tracker.max()
|
||||
inference_delay = math.ceil(inference_latency / time_per_chunk)
|
||||
|
||||
obs = robot.get_observation()
|
||||
|
||||
# Apply robot observation processor
|
||||
obs_processed = robot_observation_processor(obs)
|
||||
|
||||
obs_with_policy_features = build_dataset_frame(
|
||||
dataset_features, obs_processed, prefix="observation"
|
||||
)
|
||||
|
||||
for name in obs_with_policy_features:
|
||||
obs_with_policy_features[name] = torch.from_numpy(obs_with_policy_features[name])
|
||||
if "image" in name:
|
||||
obs_with_policy_features[name] = (
|
||||
obs_with_policy_features[name].type(torch.float32) / 255
|
||||
)
|
||||
obs_with_policy_features[name] = (
|
||||
obs_with_policy_features[name].permute(2, 0, 1).contiguous()
|
||||
)
|
||||
obs_with_policy_features[name] = obs_with_policy_features[name].unsqueeze(0)
|
||||
obs_with_policy_features[name] = obs_with_policy_features[name].to(policy_device)
|
||||
|
||||
obs_with_policy_features["task"] = [cfg.task] # Task should be a list, not a string!
|
||||
obs_with_policy_features["robot_type"] = (
|
||||
robot.robot.name if hasattr(robot.robot, "name") else ""
|
||||
)
|
||||
|
||||
preproceseded_obs = preprocessor(obs_with_policy_features)
|
||||
|
||||
# Generate actions WITH RTC
|
||||
actions = policy.predict_action_chunk(
|
||||
preproceseded_obs,
|
||||
inference_delay=inference_delay,
|
||||
prev_chunk_left_over=prev_actions,
|
||||
)
|
||||
|
||||
# Store original actions (before postprocessing) for RTC
|
||||
original_actions = actions.squeeze(0).clone()
|
||||
|
||||
postprocessed_actions = postprocessor(actions)
|
||||
|
||||
postprocessed_actions = postprocessed_actions.squeeze(0)
|
||||
|
||||
new_latency = time.perf_counter() - current_time
|
||||
new_delay = math.ceil(new_latency / time_per_chunk)
|
||||
latency_tracker.add(new_latency)
|
||||
|
||||
if cfg.action_queue_size_to_get_new_actions < cfg.rtc.execution_horizon + new_delay:
|
||||
logger.warning(
|
||||
"[GET_ACTIONS] cfg.action_queue_size_to_get_new_actions Too small, It should be higher than inference delay + execution horizon."
|
||||
)
|
||||
|
||||
action_queue.merge(
|
||||
original_actions, postprocessed_actions, new_delay, action_index_before_inference
|
||||
)
|
||||
else:
|
||||
# Small sleep to prevent busy waiting
|
||||
time.sleep(0.1)
|
||||
|
||||
logger.info("[GET_ACTIONS] get actions thread shutting down")
|
||||
except Exception as e:
|
||||
logger.error(f"[GET_ACTIONS] Fatal exception in get_actions thread: {e}")
|
||||
logger.error(traceback.format_exc())
|
||||
sys.exit(1)
|
||||
|
||||
|
||||
def actor_control(
|
||||
robot: RobotWrapper,
|
||||
robot_action_processor,
|
||||
action_queue: ActionQueue,
|
||||
shutdown_event: Event,
|
||||
cfg: RTCDemoConfig,
|
||||
):
|
||||
"""Thread function to execute actions on the robot.
|
||||
|
||||
Args:
|
||||
robot: The robot instance
|
||||
action_queue: Queue to get actions from
|
||||
shutdown_event: Event to signal shutdown
|
||||
cfg: Demo configuration
|
||||
"""
|
||||
try:
|
||||
logger.info("[ACTOR] Starting actor thread")
|
||||
|
||||
action_count = 0
|
||||
action_interval = 1.0 / cfg.fps
|
||||
|
||||
while not shutdown_event.is_set():
|
||||
start_time = time.perf_counter()
|
||||
|
||||
# Try to get an action from the queue with timeout
|
||||
action = action_queue.get()
|
||||
|
||||
if action is not None:
|
||||
action = action.cpu()
|
||||
action_dict = {key: action[i].item() for i, key in enumerate(robot.action_features())}
|
||||
action_processed = robot_action_processor((action_dict, None))
|
||||
robot.send_action(action_processed)
|
||||
|
||||
action_count += 1
|
||||
|
||||
dt_s = time.perf_counter() - start_time
|
||||
time.sleep(max(0, (action_interval - dt_s) - 0.001))
|
||||
|
||||
logger.info(f"[ACTOR] Actor thread shutting down. Total actions executed: {action_count}")
|
||||
except Exception as e:
|
||||
logger.error(f"[ACTOR] Fatal exception in actor_control thread: {e}")
|
||||
logger.error(traceback.format_exc())
|
||||
sys.exit(1)
|
||||
|
||||
|
||||
def _apply_torch_compile(policy, cfg: RTCDemoConfig):
|
||||
"""Apply torch.compile to the policy's predict_action_chunk method.
|
||||
|
||||
Args:
|
||||
policy: Policy instance to compile
|
||||
cfg: Configuration containing torch compile settings
|
||||
|
||||
Returns:
|
||||
Policy with compiled predict_action_chunk method
|
||||
"""
|
||||
|
||||
# PI models handle their own compilation
|
||||
if policy.type == "pi05" or policy.type == "pi0":
|
||||
return policy
|
||||
|
||||
try:
|
||||
# Check if torch.compile is available (PyTorch 2.0+)
|
||||
if not hasattr(torch, "compile"):
|
||||
logger.warning(
|
||||
f"torch.compile is not available. Requires PyTorch 2.0+. "
|
||||
f"Current version: {torch.__version__}. Skipping compilation."
|
||||
)
|
||||
return policy
|
||||
|
||||
logger.info("Applying torch.compile to predict_action_chunk...")
|
||||
logger.info(f" Backend: {cfg.torch_compile_backend}")
|
||||
logger.info(f" Mode: {cfg.torch_compile_mode}")
|
||||
logger.info(f" Disable CUDA graphs: {cfg.torch_compile_disable_cudagraphs}")
|
||||
|
||||
# Compile the predict_action_chunk method
|
||||
# - CUDA graphs disabled to prevent tensor aliasing from in-place ops (x_t += dt * v_t)
|
||||
compile_kwargs = {
|
||||
"backend": cfg.torch_compile_backend,
|
||||
"mode": cfg.torch_compile_mode,
|
||||
}
|
||||
|
||||
# Disable CUDA graphs if requested (prevents tensor aliasing issues)
|
||||
if cfg.torch_compile_disable_cudagraphs:
|
||||
compile_kwargs["options"] = {"triton.cudagraphs": False}
|
||||
|
||||
original_method = policy.predict_action_chunk
|
||||
compiled_method = torch.compile(original_method, **compile_kwargs)
|
||||
policy.predict_action_chunk = compiled_method
|
||||
logger.info("✓ Successfully compiled predict_action_chunk")
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"Failed to apply torch.compile: {e}")
|
||||
logger.warning("Continuing without torch.compile")
|
||||
|
||||
return policy
|
||||
|
||||
|
||||
@parser.wrap()
|
||||
def demo_cli(cfg: RTCDemoConfig):
|
||||
"""Main entry point for RTC demo with draccus configuration."""
|
||||
|
||||
# Initialize logging
|
||||
init_logging()
|
||||
|
||||
logger.info(f"Using device: {cfg.device}")
|
||||
|
||||
# Setup signal handler for graceful shutdown
|
||||
signal_handler = ProcessSignalHandler(use_threads=True, display_pid=False)
|
||||
shutdown_event = signal_handler.shutdown_event
|
||||
|
||||
policy = None
|
||||
robot = None
|
||||
get_actions_thread = None
|
||||
actor_thread = None
|
||||
|
||||
policy_class = get_policy_class(cfg.policy.type)
|
||||
|
||||
# Load config and set compile_model for pi0/pi05 models
|
||||
config = PreTrainedConfig.from_pretrained(cfg.policy.pretrained_path)
|
||||
|
||||
if cfg.policy.type == "pi05" or cfg.policy.type == "pi0":
|
||||
config.compile_model = cfg.use_torch_compile
|
||||
|
||||
policy = policy_class.from_pretrained(cfg.policy.pretrained_path, config=config)
|
||||
|
||||
# Turn on RTC
|
||||
policy.config.rtc_config = cfg.rtc
|
||||
|
||||
# Init RTC processort, as by default if RTC disabled in the config
|
||||
# The processor won't be created
|
||||
policy.init_rtc_processor()
|
||||
|
||||
assert policy.name in ["smolvla", "pi05", "pi0"], "Only smolvla, pi05, and pi0 are supported for RTC"
|
||||
|
||||
policy = policy.to(cfg.device)
|
||||
policy.eval()
|
||||
|
||||
# Apply torch.compile to predict_action_chunk method if enabled
|
||||
if cfg.use_torch_compile:
|
||||
policy = _apply_torch_compile(policy, cfg)
|
||||
|
||||
# Create robot
|
||||
logger.info(f"Initializing robot: {cfg.robot.type}")
|
||||
robot = make_robot_from_config(cfg.robot)
|
||||
robot.connect()
|
||||
robot_wrapper = RobotWrapper(robot)
|
||||
|
||||
# Create robot observation processor
|
||||
robot_observation_processor = make_default_robot_observation_processor()
|
||||
robot_action_processor = make_default_robot_action_processor()
|
||||
|
||||
# Create action queue for communication between threads
|
||||
action_queue = ActionQueue(cfg.rtc)
|
||||
|
||||
# Start chunk requester thread
|
||||
get_actions_thread = Thread(
|
||||
target=get_actions,
|
||||
args=(policy, robot_wrapper, robot_observation_processor, action_queue, shutdown_event, cfg),
|
||||
daemon=True,
|
||||
name="GetActions",
|
||||
)
|
||||
get_actions_thread.start()
|
||||
logger.info("Started get actions thread")
|
||||
|
||||
# Start action executor thread
|
||||
actor_thread = Thread(
|
||||
target=actor_control,
|
||||
args=(robot_wrapper, robot_action_processor, action_queue, shutdown_event, cfg),
|
||||
daemon=True,
|
||||
name="Actor",
|
||||
)
|
||||
actor_thread.start()
|
||||
logger.info("Started actor thread")
|
||||
|
||||
logger.info("Started stop by duration thread")
|
||||
|
||||
# Main thread monitors for duration or shutdown
|
||||
logger.info(f"Running demo for {cfg.duration} seconds...")
|
||||
start_time = time.time()
|
||||
|
||||
while not shutdown_event.is_set() and (time.time() - start_time) < cfg.duration:
|
||||
time.sleep(10)
|
||||
|
||||
# Log queue status periodically
|
||||
if int(time.time() - start_time) % 5 == 0:
|
||||
logger.info(f"[MAIN] Action queue size: {action_queue.qsize()}")
|
||||
|
||||
if time.time() - start_time > cfg.duration:
|
||||
break
|
||||
|
||||
logger.info("Demo duration reached or shutdown requested")
|
||||
|
||||
# Signal shutdown
|
||||
shutdown_event.set()
|
||||
|
||||
# Wait for threads to finish
|
||||
if get_actions_thread and get_actions_thread.is_alive():
|
||||
logger.info("Waiting for chunk requester thread to finish...")
|
||||
get_actions_thread.join()
|
||||
|
||||
if actor_thread and actor_thread.is_alive():
|
||||
logger.info("Waiting for action executor thread to finish...")
|
||||
actor_thread.join()
|
||||
|
||||
# Cleanup robot
|
||||
if robot:
|
||||
robot.disconnect()
|
||||
logger.info("Robot disconnected")
|
||||
|
||||
logger.info("Cleanup completed")
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
demo_cli()
|
||||
logging.info("RTC demo finished")
|
||||
@@ -52,126 +52,114 @@ TASK_DESCRIPTION = "My task description"
|
||||
HF_MODEL_ID = "<hf_username>/<model_repo_id>"
|
||||
HF_DATASET_ID = "<hf_username>/<dataset_repo_id>"
|
||||
|
||||
# Create the robot configuration & robot
|
||||
camera_config = {"front": OpenCVCameraConfig(index_or_path=0, width=640, height=480, fps=FPS)}
|
||||
robot_config = SO100FollowerConfig(
|
||||
port="/dev/tty.usbmodem5A460814411",
|
||||
id="my_awesome_follower_arm",
|
||||
cameras=camera_config,
|
||||
use_degrees=True,
|
||||
)
|
||||
|
||||
robot = SO100Follower(robot_config)
|
||||
|
||||
# Create policy
|
||||
policy = ACTPolicy.from_pretrained(HF_MODEL_ID)
|
||||
|
||||
# NOTE: It is highly recommended to use the urdf in the SO-ARM100 repo: https://github.com/TheRobotStudio/SO-ARM100/blob/main/Simulation/SO101/so101_new_calib.urdf
|
||||
kinematics_solver = RobotKinematics(
|
||||
urdf_path="./SO101/so101_new_calib.urdf",
|
||||
target_frame_name="gripper_frame_link",
|
||||
joint_names=list(robot.bus.motors.keys()),
|
||||
)
|
||||
|
||||
# Build pipeline to convert EE action to joints action
|
||||
robot_ee_to_joints_processor = RobotProcessorPipeline[tuple[RobotAction, RobotObservation], RobotAction](
|
||||
steps=[
|
||||
InverseKinematicsEEToJoints(
|
||||
kinematics=kinematics_solver,
|
||||
motor_names=list(robot.bus.motors.keys()),
|
||||
initial_guess_current_joints=True,
|
||||
),
|
||||
],
|
||||
to_transition=robot_action_observation_to_transition,
|
||||
to_output=transition_to_robot_action,
|
||||
)
|
||||
|
||||
# Build pipeline to convert joints observation to EE observation
|
||||
robot_joints_to_ee_pose_processor = RobotProcessorPipeline[RobotObservation, RobotObservation](
|
||||
steps=[
|
||||
ForwardKinematicsJointsToEE(kinematics=kinematics_solver, motor_names=list(robot.bus.motors.keys()))
|
||||
],
|
||||
to_transition=observation_to_transition,
|
||||
to_output=transition_to_observation,
|
||||
)
|
||||
|
||||
|
||||
# Create the dataset
|
||||
dataset = LeRobotDataset.create(
|
||||
repo_id=HF_DATASET_ID,
|
||||
fps=FPS,
|
||||
features=combine_feature_dicts(
|
||||
aggregate_pipeline_dataset_features(
|
||||
pipeline=robot_joints_to_ee_pose_processor,
|
||||
initial_features=create_initial_features(observation=robot.observation_features),
|
||||
use_videos=True,
|
||||
),
|
||||
# User for now should be explicit on the feature keys that were used for record
|
||||
# Alternatively, the user can pass the processor step that has the right features
|
||||
aggregate_pipeline_dataset_features(
|
||||
pipeline=make_default_teleop_action_processor(),
|
||||
initial_features=create_initial_features(
|
||||
action={
|
||||
f"ee.{k}": PolicyFeature(type=FeatureType.ACTION, shape=(1,))
|
||||
for k in ["x", "y", "z", "wx", "wy", "wz", "gripper_pos"]
|
||||
}
|
||||
),
|
||||
use_videos=True,
|
||||
),
|
||||
),
|
||||
robot_type=robot.name,
|
||||
use_videos=True,
|
||||
image_writer_threads=4,
|
||||
)
|
||||
|
||||
# Build Policy Processors
|
||||
preprocessor, postprocessor = make_pre_post_processors(
|
||||
policy_cfg=policy,
|
||||
pretrained_path=HF_MODEL_ID,
|
||||
dataset_stats=dataset.meta.stats,
|
||||
# The inference device is automatically set to match the detected hardware, overriding any previous device settings from training to ensure compatibility.
|
||||
preprocessor_overrides={"device_processor": {"device": str(policy.config.device)}},
|
||||
)
|
||||
|
||||
# Connect the robot and teleoperator
|
||||
robot.connect()
|
||||
|
||||
# Initialize the keyboard listener and rerun visualization
|
||||
listener, events = init_keyboard_listener()
|
||||
init_rerun(session_name="so100_so100_evaluate")
|
||||
|
||||
if not robot.is_connected:
|
||||
raise ValueError("Robot is not connected!")
|
||||
|
||||
print("Starting evaluate loop...")
|
||||
episode_idx = 0
|
||||
for episode_idx in range(NUM_EPISODES):
|
||||
log_say(f"Running inference, recording eval episode {episode_idx + 1} of {NUM_EPISODES}")
|
||||
|
||||
# Main record loop
|
||||
record_loop(
|
||||
robot=robot,
|
||||
events=events,
|
||||
fps=FPS,
|
||||
policy=policy,
|
||||
preprocessor=preprocessor, # Pass the pre and post policy processors
|
||||
postprocessor=postprocessor,
|
||||
dataset=dataset,
|
||||
control_time_s=EPISODE_TIME_SEC,
|
||||
single_task=TASK_DESCRIPTION,
|
||||
display_data=True,
|
||||
teleop_action_processor=make_default_teleop_action_processor(),
|
||||
robot_action_processor=robot_ee_to_joints_processor,
|
||||
robot_observation_processor=robot_joints_to_ee_pose_processor,
|
||||
def main():
|
||||
# Create the robot configuration & robot
|
||||
camera_config = {"front": OpenCVCameraConfig(index_or_path=0, width=640, height=480, fps=FPS)}
|
||||
robot_config = SO100FollowerConfig(
|
||||
port="/dev/tty.usbmodem5A460814411",
|
||||
id="my_awesome_follower_arm",
|
||||
cameras=camera_config,
|
||||
use_degrees=True,
|
||||
)
|
||||
|
||||
# Reset the environment if not stopping or re-recording
|
||||
if not events["stop_recording"] and ((episode_idx < NUM_EPISODES - 1) or events["rerecord_episode"]):
|
||||
log_say("Reset the environment")
|
||||
robot = SO100Follower(robot_config)
|
||||
|
||||
# Create policy
|
||||
policy = ACTPolicy.from_pretrained(HF_MODEL_ID)
|
||||
|
||||
# NOTE: It is highly recommended to use the urdf in the SO-ARM100 repo: https://github.com/TheRobotStudio/SO-ARM100/blob/main/Simulation/SO101/so101_new_calib.urdf
|
||||
kinematics_solver = RobotKinematics(
|
||||
urdf_path="./SO101/so101_new_calib.urdf",
|
||||
target_frame_name="gripper_frame_link",
|
||||
joint_names=list(robot.bus.motors.keys()),
|
||||
)
|
||||
|
||||
# Build pipeline to convert EE action to joints action
|
||||
robot_ee_to_joints_processor = RobotProcessorPipeline[tuple[RobotAction, RobotObservation], RobotAction](
|
||||
steps=[
|
||||
InverseKinematicsEEToJoints(
|
||||
kinematics=kinematics_solver,
|
||||
motor_names=list(robot.bus.motors.keys()),
|
||||
initial_guess_current_joints=True,
|
||||
),
|
||||
],
|
||||
to_transition=robot_action_observation_to_transition,
|
||||
to_output=transition_to_robot_action,
|
||||
)
|
||||
|
||||
# Build pipeline to convert joints observation to EE observation
|
||||
robot_joints_to_ee_pose_processor = RobotProcessorPipeline[RobotObservation, RobotObservation](
|
||||
steps=[
|
||||
ForwardKinematicsJointsToEE(
|
||||
kinematics=kinematics_solver, motor_names=list(robot.bus.motors.keys())
|
||||
)
|
||||
],
|
||||
to_transition=observation_to_transition,
|
||||
to_output=transition_to_observation,
|
||||
)
|
||||
|
||||
# Create the dataset
|
||||
dataset = LeRobotDataset.create(
|
||||
repo_id=HF_DATASET_ID,
|
||||
fps=FPS,
|
||||
features=combine_feature_dicts(
|
||||
aggregate_pipeline_dataset_features(
|
||||
pipeline=robot_joints_to_ee_pose_processor,
|
||||
initial_features=create_initial_features(observation=robot.observation_features),
|
||||
use_videos=True,
|
||||
),
|
||||
# User for now should be explicit on the feature keys that were used for record
|
||||
# Alternatively, the user can pass the processor step that has the right features
|
||||
aggregate_pipeline_dataset_features(
|
||||
pipeline=make_default_teleop_action_processor(),
|
||||
initial_features=create_initial_features(
|
||||
action={
|
||||
f"ee.{k}": PolicyFeature(type=FeatureType.ACTION, shape=(1,))
|
||||
for k in ["x", "y", "z", "wx", "wy", "wz", "gripper_pos"]
|
||||
}
|
||||
),
|
||||
use_videos=True,
|
||||
),
|
||||
),
|
||||
robot_type=robot.name,
|
||||
use_videos=True,
|
||||
image_writer_threads=4,
|
||||
)
|
||||
|
||||
# Build Policy Processors
|
||||
preprocessor, postprocessor = make_pre_post_processors(
|
||||
policy_cfg=policy,
|
||||
pretrained_path=HF_MODEL_ID,
|
||||
dataset_stats=dataset.meta.stats,
|
||||
# The inference device is automatically set to match the detected hardware, overriding any previous device settings from training to ensure compatibility.
|
||||
preprocessor_overrides={"device_processor": {"device": str(policy.config.device)}},
|
||||
)
|
||||
|
||||
# Connect the robot and teleoperator
|
||||
robot.connect()
|
||||
|
||||
# Initialize the keyboard listener and rerun visualization
|
||||
listener, events = init_keyboard_listener()
|
||||
init_rerun(session_name="so100_so100_evaluate")
|
||||
|
||||
if not robot.is_connected:
|
||||
raise ValueError("Robot is not connected!")
|
||||
|
||||
print("Starting evaluate loop...")
|
||||
episode_idx = 0
|
||||
for episode_idx in range(NUM_EPISODES):
|
||||
log_say(f"Running inference, recording eval episode {episode_idx + 1} of {NUM_EPISODES}")
|
||||
|
||||
# Main record loop
|
||||
record_loop(
|
||||
robot=robot,
|
||||
events=events,
|
||||
fps=FPS,
|
||||
policy=policy,
|
||||
preprocessor=preprocessor, # Pass the pre and post policy processors
|
||||
postprocessor=postprocessor,
|
||||
dataset=dataset,
|
||||
control_time_s=EPISODE_TIME_SEC,
|
||||
single_task=TASK_DESCRIPTION,
|
||||
display_data=True,
|
||||
@@ -180,21 +168,40 @@ for episode_idx in range(NUM_EPISODES):
|
||||
robot_observation_processor=robot_joints_to_ee_pose_processor,
|
||||
)
|
||||
|
||||
if events["rerecord_episode"]:
|
||||
log_say("Re-record episode")
|
||||
events["rerecord_episode"] = False
|
||||
events["exit_early"] = False
|
||||
dataset.clear_episode_buffer()
|
||||
continue
|
||||
# Reset the environment if not stopping or re-recording
|
||||
if not events["stop_recording"] and ((episode_idx < NUM_EPISODES - 1) or events["rerecord_episode"]):
|
||||
log_say("Reset the environment")
|
||||
record_loop(
|
||||
robot=robot,
|
||||
events=events,
|
||||
fps=FPS,
|
||||
control_time_s=EPISODE_TIME_SEC,
|
||||
single_task=TASK_DESCRIPTION,
|
||||
display_data=True,
|
||||
teleop_action_processor=make_default_teleop_action_processor(),
|
||||
robot_action_processor=robot_ee_to_joints_processor,
|
||||
robot_observation_processor=robot_joints_to_ee_pose_processor,
|
||||
)
|
||||
|
||||
# Save episode
|
||||
dataset.save_episode()
|
||||
episode_idx += 1
|
||||
if events["rerecord_episode"]:
|
||||
log_say("Re-record episode")
|
||||
events["rerecord_episode"] = False
|
||||
events["exit_early"] = False
|
||||
dataset.clear_episode_buffer()
|
||||
continue
|
||||
|
||||
# Clean up
|
||||
log_say("Stop recording")
|
||||
robot.disconnect()
|
||||
listener.stop()
|
||||
# Save episode
|
||||
dataset.save_episode()
|
||||
episode_idx += 1
|
||||
|
||||
dataset.finalize()
|
||||
dataset.push_to_hub()
|
||||
# Clean up
|
||||
log_say("Stop recording")
|
||||
robot.disconnect()
|
||||
listener.stop()
|
||||
|
||||
dataset.finalize()
|
||||
dataset.push_to_hub()
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
main()
|
||||
|
||||
@@ -48,134 +48,122 @@ RESET_TIME_SEC = 30
|
||||
TASK_DESCRIPTION = "My task description"
|
||||
HF_REPO_ID = "<hf_username>/<dataset_repo_id>"
|
||||
|
||||
# Create the robot and teleoperator configurations
|
||||
camera_config = {"front": OpenCVCameraConfig(index_or_path=0, width=640, height=480, fps=FPS)}
|
||||
follower_config = SO100FollowerConfig(
|
||||
port="/dev/tty.usbmodem5A460814411", id="my_awesome_follower_arm", cameras=camera_config, use_degrees=True
|
||||
)
|
||||
leader_config = SO100LeaderConfig(port="/dev/tty.usbmodem5A460819811", id="my_awesome_leader_arm")
|
||||
|
||||
# Initialize the robot and teleoperator
|
||||
follower = SO100Follower(follower_config)
|
||||
leader = SO100Leader(leader_config)
|
||||
def main():
|
||||
# Create the robot and teleoperator configurations
|
||||
camera_config = {"front": OpenCVCameraConfig(index_or_path=0, width=640, height=480, fps=FPS)}
|
||||
follower_config = SO100FollowerConfig(
|
||||
port="/dev/tty.usbmodem5A460814411",
|
||||
id="my_awesome_follower_arm",
|
||||
cameras=camera_config,
|
||||
use_degrees=True,
|
||||
)
|
||||
leader_config = SO100LeaderConfig(port="/dev/tty.usbmodem5A460819811", id="my_awesome_leader_arm")
|
||||
|
||||
# NOTE: It is highly recommended to use the urdf in the SO-ARM100 repo: https://github.com/TheRobotStudio/SO-ARM100/blob/main/Simulation/SO101/so101_new_calib.urdf
|
||||
follower_kinematics_solver = RobotKinematics(
|
||||
urdf_path="./SO101/so101_new_calib.urdf",
|
||||
target_frame_name="gripper_frame_link",
|
||||
joint_names=list(follower.bus.motors.keys()),
|
||||
)
|
||||
# Initialize the robot and teleoperator
|
||||
follower = SO100Follower(follower_config)
|
||||
leader = SO100Leader(leader_config)
|
||||
|
||||
# NOTE: It is highly recommended to use the urdf in the SO-ARM100 repo: https://github.com/TheRobotStudio/SO-ARM100/blob/main/Simulation/SO101/so101_new_calib.urdf
|
||||
leader_kinematics_solver = RobotKinematics(
|
||||
urdf_path="./SO101/so101_new_calib.urdf",
|
||||
target_frame_name="gripper_frame_link",
|
||||
joint_names=list(leader.bus.motors.keys()),
|
||||
)
|
||||
|
||||
# Build pipeline to convert follower joints to EE observation
|
||||
follower_joints_to_ee = RobotProcessorPipeline[RobotObservation, RobotObservation](
|
||||
steps=[
|
||||
ForwardKinematicsJointsToEE(
|
||||
kinematics=follower_kinematics_solver, motor_names=list(follower.bus.motors.keys())
|
||||
),
|
||||
],
|
||||
to_transition=observation_to_transition,
|
||||
to_output=transition_to_observation,
|
||||
)
|
||||
|
||||
# Build pipeline to convert leader joints to EE action
|
||||
leader_joints_to_ee = RobotProcessorPipeline[tuple[RobotAction, RobotObservation], RobotAction](
|
||||
steps=[
|
||||
ForwardKinematicsJointsToEE(
|
||||
kinematics=leader_kinematics_solver, motor_names=list(leader.bus.motors.keys())
|
||||
),
|
||||
],
|
||||
to_transition=robot_action_observation_to_transition,
|
||||
to_output=transition_to_robot_action,
|
||||
)
|
||||
|
||||
# Build pipeline to convert EE action to follower joints
|
||||
ee_to_follower_joints = RobotProcessorPipeline[tuple[RobotAction, RobotObservation], RobotAction](
|
||||
[
|
||||
EEBoundsAndSafety(
|
||||
end_effector_bounds={"min": [-1.0, -1.0, -1.0], "max": [1.0, 1.0, 1.0]},
|
||||
max_ee_step_m=0.10,
|
||||
),
|
||||
InverseKinematicsEEToJoints(
|
||||
kinematics=follower_kinematics_solver,
|
||||
motor_names=list(follower.bus.motors.keys()),
|
||||
initial_guess_current_joints=True,
|
||||
),
|
||||
],
|
||||
to_transition=robot_action_observation_to_transition,
|
||||
to_output=transition_to_robot_action,
|
||||
)
|
||||
|
||||
# Create the dataset
|
||||
dataset = LeRobotDataset.create(
|
||||
repo_id=HF_REPO_ID,
|
||||
fps=FPS,
|
||||
features=combine_feature_dicts(
|
||||
# Run the feature contract of the pipelines
|
||||
# This tells you how the features would look like after the pipeline steps
|
||||
aggregate_pipeline_dataset_features(
|
||||
pipeline=leader_joints_to_ee,
|
||||
initial_features=create_initial_features(action=leader.action_features),
|
||||
use_videos=True,
|
||||
),
|
||||
aggregate_pipeline_dataset_features(
|
||||
pipeline=follower_joints_to_ee,
|
||||
initial_features=create_initial_features(observation=follower.observation_features),
|
||||
use_videos=True,
|
||||
),
|
||||
),
|
||||
robot_type=follower.name,
|
||||
use_videos=True,
|
||||
image_writer_threads=4,
|
||||
)
|
||||
|
||||
|
||||
# Connect the robot and teleoperator
|
||||
leader.connect()
|
||||
follower.connect()
|
||||
|
||||
# Initialize the keyboard listener and rerun visualization
|
||||
listener, events = init_keyboard_listener()
|
||||
init_rerun(session_name="recording_phone")
|
||||
|
||||
if not leader.is_connected or not follower.is_connected:
|
||||
raise ValueError("Robot or teleop is not connected!")
|
||||
|
||||
print("Starting record loop...")
|
||||
episode_idx = 0
|
||||
while episode_idx < NUM_EPISODES and not events["stop_recording"]:
|
||||
log_say(f"Recording episode {episode_idx + 1} of {NUM_EPISODES}")
|
||||
|
||||
# Main record loop
|
||||
record_loop(
|
||||
robot=follower,
|
||||
events=events,
|
||||
fps=FPS,
|
||||
teleop=leader,
|
||||
dataset=dataset,
|
||||
control_time_s=EPISODE_TIME_SEC,
|
||||
single_task=TASK_DESCRIPTION,
|
||||
display_data=True,
|
||||
teleop_action_processor=leader_joints_to_ee,
|
||||
robot_action_processor=ee_to_follower_joints,
|
||||
robot_observation_processor=follower_joints_to_ee,
|
||||
# NOTE: It is highly recommended to use the urdf in the SO-ARM100 repo: https://github.com/TheRobotStudio/SO-ARM100/blob/main/Simulation/SO101/so101_new_calib.urdf
|
||||
follower_kinematics_solver = RobotKinematics(
|
||||
urdf_path="./SO101/so101_new_calib.urdf",
|
||||
target_frame_name="gripper_frame_link",
|
||||
joint_names=list(follower.bus.motors.keys()),
|
||||
)
|
||||
|
||||
# Reset the environment if not stopping or re-recording
|
||||
if not events["stop_recording"] and (episode_idx < NUM_EPISODES - 1 or events["rerecord_episode"]):
|
||||
log_say("Reset the environment")
|
||||
# NOTE: It is highly recommended to use the urdf in the SO-ARM100 repo: https://github.com/TheRobotStudio/SO-ARM100/blob/main/Simulation/SO101/so101_new_calib.urdf
|
||||
leader_kinematics_solver = RobotKinematics(
|
||||
urdf_path="./SO101/so101_new_calib.urdf",
|
||||
target_frame_name="gripper_frame_link",
|
||||
joint_names=list(leader.bus.motors.keys()),
|
||||
)
|
||||
|
||||
# Build pipeline to convert follower joints to EE observation
|
||||
follower_joints_to_ee = RobotProcessorPipeline[RobotObservation, RobotObservation](
|
||||
steps=[
|
||||
ForwardKinematicsJointsToEE(
|
||||
kinematics=follower_kinematics_solver, motor_names=list(follower.bus.motors.keys())
|
||||
),
|
||||
],
|
||||
to_transition=observation_to_transition,
|
||||
to_output=transition_to_observation,
|
||||
)
|
||||
|
||||
# Build pipeline to convert leader joints to EE action
|
||||
leader_joints_to_ee = RobotProcessorPipeline[tuple[RobotAction, RobotObservation], RobotAction](
|
||||
steps=[
|
||||
ForwardKinematicsJointsToEE(
|
||||
kinematics=leader_kinematics_solver, motor_names=list(leader.bus.motors.keys())
|
||||
),
|
||||
],
|
||||
to_transition=robot_action_observation_to_transition,
|
||||
to_output=transition_to_robot_action,
|
||||
)
|
||||
|
||||
# Build pipeline to convert EE action to follower joints
|
||||
ee_to_follower_joints = RobotProcessorPipeline[tuple[RobotAction, RobotObservation], RobotAction](
|
||||
[
|
||||
EEBoundsAndSafety(
|
||||
end_effector_bounds={"min": [-1.0, -1.0, -1.0], "max": [1.0, 1.0, 1.0]},
|
||||
max_ee_step_m=0.10,
|
||||
),
|
||||
InverseKinematicsEEToJoints(
|
||||
kinematics=follower_kinematics_solver,
|
||||
motor_names=list(follower.bus.motors.keys()),
|
||||
initial_guess_current_joints=True,
|
||||
),
|
||||
],
|
||||
to_transition=robot_action_observation_to_transition,
|
||||
to_output=transition_to_robot_action,
|
||||
)
|
||||
|
||||
# Create the dataset
|
||||
dataset = LeRobotDataset.create(
|
||||
repo_id=HF_REPO_ID,
|
||||
fps=FPS,
|
||||
features=combine_feature_dicts(
|
||||
# Run the feature contract of the pipelines
|
||||
# This tells you how the features would look like after the pipeline steps
|
||||
aggregate_pipeline_dataset_features(
|
||||
pipeline=leader_joints_to_ee,
|
||||
initial_features=create_initial_features(action=leader.action_features),
|
||||
use_videos=True,
|
||||
),
|
||||
aggregate_pipeline_dataset_features(
|
||||
pipeline=follower_joints_to_ee,
|
||||
initial_features=create_initial_features(observation=follower.observation_features),
|
||||
use_videos=True,
|
||||
),
|
||||
),
|
||||
robot_type=follower.name,
|
||||
use_videos=True,
|
||||
image_writer_threads=4,
|
||||
)
|
||||
|
||||
# Connect the robot and teleoperator
|
||||
leader.connect()
|
||||
follower.connect()
|
||||
|
||||
# Initialize the keyboard listener and rerun visualization
|
||||
listener, events = init_keyboard_listener()
|
||||
init_rerun(session_name="recording_phone")
|
||||
|
||||
if not leader.is_connected or not follower.is_connected:
|
||||
raise ValueError("Robot or teleop is not connected!")
|
||||
|
||||
print("Starting record loop...")
|
||||
episode_idx = 0
|
||||
while episode_idx < NUM_EPISODES and not events["stop_recording"]:
|
||||
log_say(f"Recording episode {episode_idx + 1} of {NUM_EPISODES}")
|
||||
|
||||
# Main record loop
|
||||
record_loop(
|
||||
robot=follower,
|
||||
events=events,
|
||||
fps=FPS,
|
||||
teleop=leader,
|
||||
control_time_s=RESET_TIME_SEC,
|
||||
dataset=dataset,
|
||||
control_time_s=EPISODE_TIME_SEC,
|
||||
single_task=TASK_DESCRIPTION,
|
||||
display_data=True,
|
||||
teleop_action_processor=leader_joints_to_ee,
|
||||
@@ -183,22 +171,42 @@ while episode_idx < NUM_EPISODES and not events["stop_recording"]:
|
||||
robot_observation_processor=follower_joints_to_ee,
|
||||
)
|
||||
|
||||
if events["rerecord_episode"]:
|
||||
log_say("Re-recording episode")
|
||||
events["rerecord_episode"] = False
|
||||
events["exit_early"] = False
|
||||
dataset.clear_episode_buffer()
|
||||
continue
|
||||
# Reset the environment if not stopping or re-recording
|
||||
if not events["stop_recording"] and (episode_idx < NUM_EPISODES - 1 or events["rerecord_episode"]):
|
||||
log_say("Reset the environment")
|
||||
record_loop(
|
||||
robot=follower,
|
||||
events=events,
|
||||
fps=FPS,
|
||||
teleop=leader,
|
||||
control_time_s=RESET_TIME_SEC,
|
||||
single_task=TASK_DESCRIPTION,
|
||||
display_data=True,
|
||||
teleop_action_processor=leader_joints_to_ee,
|
||||
robot_action_processor=ee_to_follower_joints,
|
||||
robot_observation_processor=follower_joints_to_ee,
|
||||
)
|
||||
|
||||
# Save episode
|
||||
dataset.save_episode()
|
||||
episode_idx += 1
|
||||
if events["rerecord_episode"]:
|
||||
log_say("Re-recording episode")
|
||||
events["rerecord_episode"] = False
|
||||
events["exit_early"] = False
|
||||
dataset.clear_episode_buffer()
|
||||
continue
|
||||
|
||||
# Clean up
|
||||
log_say("Stop recording")
|
||||
leader.disconnect()
|
||||
follower.disconnect()
|
||||
listener.stop()
|
||||
# Save episode
|
||||
dataset.save_episode()
|
||||
episode_idx += 1
|
||||
|
||||
dataset.finalize()
|
||||
dataset.push_to_hub()
|
||||
# Clean up
|
||||
log_say("Stop recording")
|
||||
leader.disconnect()
|
||||
follower.disconnect()
|
||||
listener.stop()
|
||||
|
||||
dataset.finalize()
|
||||
dataset.push_to_hub()
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
main()
|
||||
|
||||
@@ -30,72 +30,78 @@ from lerobot.robots.so100_follower.robot_kinematic_processor import (
|
||||
)
|
||||
from lerobot.robots.so100_follower.so100_follower import SO100Follower
|
||||
from lerobot.utils.constants import ACTION
|
||||
from lerobot.utils.robot_utils import busy_wait
|
||||
from lerobot.utils.robot_utils import precise_sleep
|
||||
from lerobot.utils.utils import log_say
|
||||
|
||||
EPISODE_IDX = 0
|
||||
HF_REPO_ID = "<hf_username>/<dataset_repo_id>"
|
||||
|
||||
# Initialize the robot config
|
||||
robot_config = SO100FollowerConfig(
|
||||
port="/dev/tty.usbmodem5A460814411", id="my_awesome_follower_arm", use_degrees=True
|
||||
)
|
||||
|
||||
# Initialize the robot
|
||||
robot = SO100Follower(robot_config)
|
||||
def main():
|
||||
# Initialize the robot config
|
||||
robot_config = SO100FollowerConfig(
|
||||
port="/dev/tty.usbmodem5A460814411", id="my_awesome_follower_arm", use_degrees=True
|
||||
)
|
||||
|
||||
# NOTE: It is highly recommended to use the urdf in the SO-ARM100 repo: https://github.com/TheRobotStudio/SO-ARM100/blob/main/Simulation/SO101/so101_new_calib.urdf
|
||||
kinematics_solver = RobotKinematics(
|
||||
urdf_path="./SO101/so101_new_calib.urdf",
|
||||
target_frame_name="gripper_frame_link",
|
||||
joint_names=list(robot.bus.motors.keys()),
|
||||
)
|
||||
# Initialize the robot
|
||||
robot = SO100Follower(robot_config)
|
||||
|
||||
# Build pipeline to convert EE action to joints action
|
||||
robot_ee_to_joints_processor = RobotProcessorPipeline[tuple[RobotAction, RobotObservation], RobotAction](
|
||||
steps=[
|
||||
InverseKinematicsEEToJoints(
|
||||
kinematics=kinematics_solver,
|
||||
motor_names=list(robot.bus.motors.keys()),
|
||||
initial_guess_current_joints=False, # Because replay is open loop
|
||||
),
|
||||
],
|
||||
to_transition=robot_action_observation_to_transition,
|
||||
to_output=transition_to_robot_action,
|
||||
)
|
||||
# NOTE: It is highly recommended to use the urdf in the SO-ARM100 repo: https://github.com/TheRobotStudio/SO-ARM100/blob/main/Simulation/SO101/so101_new_calib.urdf
|
||||
kinematics_solver = RobotKinematics(
|
||||
urdf_path="./SO101/so101_new_calib.urdf",
|
||||
target_frame_name="gripper_frame_link",
|
||||
joint_names=list(robot.bus.motors.keys()),
|
||||
)
|
||||
|
||||
# Fetch the dataset to replay
|
||||
dataset = LeRobotDataset(HF_REPO_ID, episodes=[EPISODE_IDX])
|
||||
# Filter dataset to only include frames from the specified episode since episodes are chunked in dataset V3.0
|
||||
episode_frames = dataset.hf_dataset.filter(lambda x: x["episode_index"] == EPISODE_IDX)
|
||||
actions = episode_frames.select_columns(ACTION)
|
||||
# Build pipeline to convert EE action to joints action
|
||||
robot_ee_to_joints_processor = RobotProcessorPipeline[tuple[RobotAction, RobotObservation], RobotAction](
|
||||
steps=[
|
||||
InverseKinematicsEEToJoints(
|
||||
kinematics=kinematics_solver,
|
||||
motor_names=list(robot.bus.motors.keys()),
|
||||
initial_guess_current_joints=False, # Because replay is open loop
|
||||
),
|
||||
],
|
||||
to_transition=robot_action_observation_to_transition,
|
||||
to_output=transition_to_robot_action,
|
||||
)
|
||||
|
||||
# Connect to the robot
|
||||
robot.connect()
|
||||
# Fetch the dataset to replay
|
||||
dataset = LeRobotDataset(HF_REPO_ID, episodes=[EPISODE_IDX])
|
||||
# Filter dataset to only include frames from the specified episode since episodes are chunked in dataset V3.0
|
||||
episode_frames = dataset.hf_dataset.filter(lambda x: x["episode_index"] == EPISODE_IDX)
|
||||
actions = episode_frames.select_columns(ACTION)
|
||||
|
||||
if not robot.is_connected:
|
||||
raise ValueError("Robot is not connected!")
|
||||
# Connect to the robot
|
||||
robot.connect()
|
||||
|
||||
print("Starting replay loop...")
|
||||
log_say(f"Replaying episode {EPISODE_IDX}")
|
||||
for idx in range(len(episode_frames)):
|
||||
t0 = time.perf_counter()
|
||||
if not robot.is_connected:
|
||||
raise ValueError("Robot is not connected!")
|
||||
|
||||
# Get recorded action from dataset
|
||||
ee_action = {
|
||||
name: float(actions[idx][ACTION][i]) for i, name in enumerate(dataset.features[ACTION]["names"])
|
||||
}
|
||||
print("Starting replay loop...")
|
||||
log_say(f"Replaying episode {EPISODE_IDX}")
|
||||
for idx in range(len(episode_frames)):
|
||||
t0 = time.perf_counter()
|
||||
|
||||
# Get robot observation
|
||||
robot_obs = robot.get_observation()
|
||||
# Get recorded action from dataset
|
||||
ee_action = {
|
||||
name: float(actions[idx][ACTION][i]) for i, name in enumerate(dataset.features[ACTION]["names"])
|
||||
}
|
||||
|
||||
# Dataset EE -> robot joints
|
||||
joint_action = robot_ee_to_joints_processor((ee_action, robot_obs))
|
||||
# Get robot observation
|
||||
robot_obs = robot.get_observation()
|
||||
|
||||
# Send action to robot
|
||||
_ = robot.send_action(joint_action)
|
||||
# Dataset EE -> robot joints
|
||||
joint_action = robot_ee_to_joints_processor((ee_action, robot_obs))
|
||||
|
||||
busy_wait(1.0 / dataset.fps - (time.perf_counter() - t0))
|
||||
# Send action to robot
|
||||
_ = robot.send_action(joint_action)
|
||||
|
||||
# Clean up
|
||||
robot.disconnect()
|
||||
precise_sleep(1.0 / dataset.fps - (time.perf_counter() - t0))
|
||||
|
||||
# Clean up
|
||||
robot.disconnect()
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
main()
|
||||
|
||||
@@ -32,90 +32,96 @@ from lerobot.robots.so100_follower.robot_kinematic_processor import (
|
||||
from lerobot.robots.so100_follower.so100_follower import SO100Follower
|
||||
from lerobot.teleoperators.so100_leader.config_so100_leader import SO100LeaderConfig
|
||||
from lerobot.teleoperators.so100_leader.so100_leader import SO100Leader
|
||||
from lerobot.utils.robot_utils import busy_wait
|
||||
from lerobot.utils.robot_utils import precise_sleep
|
||||
from lerobot.utils.visualization_utils import init_rerun, log_rerun_data
|
||||
|
||||
FPS = 30
|
||||
|
||||
# Initialize the robot and teleoperator config
|
||||
follower_config = SO100FollowerConfig(
|
||||
port="/dev/tty.usbmodem5A460814411", id="my_awesome_follower_arm", use_degrees=True
|
||||
)
|
||||
leader_config = SO100LeaderConfig(port="/dev/tty.usbmodem5A460819811", id="my_awesome_leader_arm")
|
||||
|
||||
# Initialize the robot and teleoperator
|
||||
follower = SO100Follower(follower_config)
|
||||
leader = SO100Leader(leader_config)
|
||||
def main():
|
||||
# Initialize the robot and teleoperator config
|
||||
follower_config = SO100FollowerConfig(
|
||||
port="/dev/tty.usbmodem5A460814411", id="my_awesome_follower_arm", use_degrees=True
|
||||
)
|
||||
leader_config = SO100LeaderConfig(port="/dev/tty.usbmodem5A460819811", id="my_awesome_leader_arm")
|
||||
|
||||
# NOTE: It is highly recommended to use the urdf in the SO-ARM100 repo: https://github.com/TheRobotStudio/SO-ARM100/blob/main/Simulation/SO101/so101_new_calib.urdf
|
||||
follower_kinematics_solver = RobotKinematics(
|
||||
urdf_path="./SO101/so101_new_calib.urdf",
|
||||
target_frame_name="gripper_frame_link",
|
||||
joint_names=list(follower.bus.motors.keys()),
|
||||
)
|
||||
# Initialize the robot and teleoperator
|
||||
follower = SO100Follower(follower_config)
|
||||
leader = SO100Leader(leader_config)
|
||||
|
||||
# NOTE: It is highly recommended to use the urdf in the SO-ARM100 repo: https://github.com/TheRobotStudio/SO-ARM100/blob/main/Simulation/SO101/so101_new_calib.urdf
|
||||
leader_kinematics_solver = RobotKinematics(
|
||||
urdf_path="./SO101/so101_new_calib.urdf",
|
||||
target_frame_name="gripper_frame_link",
|
||||
joint_names=list(leader.bus.motors.keys()),
|
||||
)
|
||||
# NOTE: It is highly recommended to use the urdf in the SO-ARM100 repo: https://github.com/TheRobotStudio/SO-ARM100/blob/main/Simulation/SO101/so101_new_calib.urdf
|
||||
follower_kinematics_solver = RobotKinematics(
|
||||
urdf_path="./SO101/so101_new_calib.urdf",
|
||||
target_frame_name="gripper_frame_link",
|
||||
joint_names=list(follower.bus.motors.keys()),
|
||||
)
|
||||
|
||||
# Build pipeline to convert teleop joints to EE action
|
||||
leader_to_ee = RobotProcessorPipeline[RobotAction, RobotAction](
|
||||
steps=[
|
||||
ForwardKinematicsJointsToEE(
|
||||
kinematics=leader_kinematics_solver, motor_names=list(leader.bus.motors.keys())
|
||||
),
|
||||
],
|
||||
to_transition=robot_action_to_transition,
|
||||
to_output=transition_to_robot_action,
|
||||
)
|
||||
# NOTE: It is highly recommended to use the urdf in the SO-ARM100 repo: https://github.com/TheRobotStudio/SO-ARM100/blob/main/Simulation/SO101/so101_new_calib.urdf
|
||||
leader_kinematics_solver = RobotKinematics(
|
||||
urdf_path="./SO101/so101_new_calib.urdf",
|
||||
target_frame_name="gripper_frame_link",
|
||||
joint_names=list(leader.bus.motors.keys()),
|
||||
)
|
||||
|
||||
# build pipeline to convert EE action to robot joints
|
||||
ee_to_follower_joints = RobotProcessorPipeline[tuple[RobotAction, RobotObservation], RobotAction](
|
||||
[
|
||||
EEBoundsAndSafety(
|
||||
end_effector_bounds={"min": [-1.0, -1.0, -1.0], "max": [1.0, 1.0, 1.0]},
|
||||
max_ee_step_m=0.10,
|
||||
),
|
||||
InverseKinematicsEEToJoints(
|
||||
kinematics=follower_kinematics_solver,
|
||||
motor_names=list(follower.bus.motors.keys()),
|
||||
initial_guess_current_joints=False,
|
||||
),
|
||||
],
|
||||
to_transition=robot_action_observation_to_transition,
|
||||
to_output=transition_to_robot_action,
|
||||
)
|
||||
# Build pipeline to convert teleop joints to EE action
|
||||
leader_to_ee = RobotProcessorPipeline[RobotAction, RobotAction](
|
||||
steps=[
|
||||
ForwardKinematicsJointsToEE(
|
||||
kinematics=leader_kinematics_solver, motor_names=list(leader.bus.motors.keys())
|
||||
),
|
||||
],
|
||||
to_transition=robot_action_to_transition,
|
||||
to_output=transition_to_robot_action,
|
||||
)
|
||||
|
||||
# Connect to the robot and teleoperator
|
||||
follower.connect()
|
||||
leader.connect()
|
||||
# build pipeline to convert EE action to robot joints
|
||||
ee_to_follower_joints = RobotProcessorPipeline[tuple[RobotAction, RobotObservation], RobotAction](
|
||||
[
|
||||
EEBoundsAndSafety(
|
||||
end_effector_bounds={"min": [-1.0, -1.0, -1.0], "max": [1.0, 1.0, 1.0]},
|
||||
max_ee_step_m=0.10,
|
||||
),
|
||||
InverseKinematicsEEToJoints(
|
||||
kinematics=follower_kinematics_solver,
|
||||
motor_names=list(follower.bus.motors.keys()),
|
||||
initial_guess_current_joints=False,
|
||||
),
|
||||
],
|
||||
to_transition=robot_action_observation_to_transition,
|
||||
to_output=transition_to_robot_action,
|
||||
)
|
||||
|
||||
# Init rerun viewer
|
||||
init_rerun(session_name="so100_so100_EE_teleop")
|
||||
# Connect to the robot and teleoperator
|
||||
follower.connect()
|
||||
leader.connect()
|
||||
|
||||
print("Starting teleop loop...")
|
||||
while True:
|
||||
t0 = time.perf_counter()
|
||||
# Init rerun viewer
|
||||
init_rerun(session_name="so100_so100_EE_teleop")
|
||||
|
||||
# Get robot observation
|
||||
robot_obs = follower.get_observation()
|
||||
print("Starting teleop loop...")
|
||||
while True:
|
||||
t0 = time.perf_counter()
|
||||
|
||||
# Get teleop observation
|
||||
leader_joints_obs = leader.get_action()
|
||||
# Get robot observation
|
||||
robot_obs = follower.get_observation()
|
||||
|
||||
# teleop joints -> teleop EE action
|
||||
leader_ee_act = leader_to_ee(leader_joints_obs)
|
||||
# Get teleop observation
|
||||
leader_joints_obs = leader.get_action()
|
||||
|
||||
# teleop EE -> robot joints
|
||||
follower_joints_act = ee_to_follower_joints((leader_ee_act, robot_obs))
|
||||
# teleop joints -> teleop EE action
|
||||
leader_ee_act = leader_to_ee(leader_joints_obs)
|
||||
|
||||
# Send action to robot
|
||||
_ = follower.send_action(follower_joints_act)
|
||||
# teleop EE -> robot joints
|
||||
follower_joints_act = ee_to_follower_joints((leader_ee_act, robot_obs))
|
||||
|
||||
# Visualize
|
||||
log_rerun_data(observation=leader_ee_act, action=follower_joints_act)
|
||||
# Send action to robot
|
||||
_ = follower.send_action(follower_joints_act)
|
||||
|
||||
busy_wait(max(1.0 / FPS - (time.perf_counter() - t0), 0.0))
|
||||
# Visualize
|
||||
log_rerun_data(observation=leader_ee_act, action=follower_joints_act)
|
||||
|
||||
precise_sleep(max(1.0 / FPS - (time.perf_counter() - t0), 0.0))
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
main()
|
||||
|
||||
@@ -19,80 +19,86 @@ def make_delta_timestamps(delta_indices: list[int] | None, fps: int) -> list[flo
|
||||
return [i / fps for i in delta_indices]
|
||||
|
||||
|
||||
output_directory = Path("outputs/robot_learning_tutorial/act")
|
||||
output_directory.mkdir(parents=True, exist_ok=True)
|
||||
def main():
|
||||
output_directory = Path("outputs/robot_learning_tutorial/act")
|
||||
output_directory.mkdir(parents=True, exist_ok=True)
|
||||
|
||||
# Select your device
|
||||
device = torch.device("mps") # or "cuda" or "cpu"
|
||||
# Select your device
|
||||
device = torch.device("mps") # or "cuda" or "cpu"
|
||||
|
||||
dataset_id = "lerobot/svla_so101_pickplace"
|
||||
dataset_id = "lerobot/svla_so101_pickplace"
|
||||
|
||||
# This specifies the inputs the model will be expecting and the outputs it will produce
|
||||
dataset_metadata = LeRobotDatasetMetadata(dataset_id)
|
||||
features = dataset_to_policy_features(dataset_metadata.features)
|
||||
# This specifies the inputs the model will be expecting and the outputs it will produce
|
||||
dataset_metadata = LeRobotDatasetMetadata(dataset_id)
|
||||
features = dataset_to_policy_features(dataset_metadata.features)
|
||||
|
||||
output_features = {key: ft for key, ft in features.items() if ft.type is FeatureType.ACTION}
|
||||
input_features = {key: ft for key, ft in features.items() if key not in output_features}
|
||||
output_features = {key: ft for key, ft in features.items() if ft.type is FeatureType.ACTION}
|
||||
input_features = {key: ft for key, ft in features.items() if key not in output_features}
|
||||
|
||||
cfg = ACTConfig(input_features=input_features, output_features=output_features)
|
||||
policy = ACTPolicy(cfg)
|
||||
preprocessor, postprocessor = make_pre_post_processors(cfg, dataset_stats=dataset_metadata.stats)
|
||||
cfg = ACTConfig(input_features=input_features, output_features=output_features)
|
||||
policy = ACTPolicy(cfg)
|
||||
preprocessor, postprocessor = make_pre_post_processors(cfg, dataset_stats=dataset_metadata.stats)
|
||||
|
||||
policy.train()
|
||||
policy.to(device)
|
||||
policy.train()
|
||||
policy.to(device)
|
||||
|
||||
# To perform action chunking, ACT expects a given number of actions as targets
|
||||
delta_timestamps = {
|
||||
"action": make_delta_timestamps(cfg.action_delta_indices, dataset_metadata.fps),
|
||||
}
|
||||
# To perform action chunking, ACT expects a given number of actions as targets
|
||||
delta_timestamps = {
|
||||
"action": make_delta_timestamps(cfg.action_delta_indices, dataset_metadata.fps),
|
||||
}
|
||||
|
||||
# add image features if they are present
|
||||
delta_timestamps |= {
|
||||
k: make_delta_timestamps(cfg.observation_delta_indices, dataset_metadata.fps) for k in cfg.image_features
|
||||
}
|
||||
# add image features if they are present
|
||||
delta_timestamps |= {
|
||||
k: make_delta_timestamps(cfg.observation_delta_indices, dataset_metadata.fps)
|
||||
for k in cfg.image_features
|
||||
}
|
||||
|
||||
# Instantiate the dataset
|
||||
dataset = LeRobotDataset(dataset_id, delta_timestamps=delta_timestamps)
|
||||
# Instantiate the dataset
|
||||
dataset = LeRobotDataset(dataset_id, delta_timestamps=delta_timestamps)
|
||||
|
||||
# Create the optimizer and dataloader for offline training
|
||||
optimizer = cfg.get_optimizer_preset().build(policy.parameters())
|
||||
batch_size = 32
|
||||
dataloader = torch.utils.data.DataLoader(
|
||||
dataset,
|
||||
batch_size=batch_size,
|
||||
shuffle=True,
|
||||
pin_memory=device.type != "cpu",
|
||||
drop_last=True,
|
||||
)
|
||||
# Create the optimizer and dataloader for offline training
|
||||
optimizer = cfg.get_optimizer_preset().build(policy.parameters())
|
||||
batch_size = 32
|
||||
dataloader = torch.utils.data.DataLoader(
|
||||
dataset,
|
||||
batch_size=batch_size,
|
||||
shuffle=True,
|
||||
pin_memory=device.type != "cpu",
|
||||
drop_last=True,
|
||||
)
|
||||
|
||||
# Number of training steps and logging frequency
|
||||
training_steps = 1
|
||||
log_freq = 1
|
||||
# Number of training steps and logging frequency
|
||||
training_steps = 1
|
||||
log_freq = 1
|
||||
|
||||
# Run training loop
|
||||
step = 0
|
||||
done = False
|
||||
while not done:
|
||||
for batch in dataloader:
|
||||
batch = preprocessor(batch)
|
||||
loss, _ = policy.forward(batch)
|
||||
loss.backward()
|
||||
optimizer.step()
|
||||
optimizer.zero_grad()
|
||||
# Run training loop
|
||||
step = 0
|
||||
done = False
|
||||
while not done:
|
||||
for batch in dataloader:
|
||||
batch = preprocessor(batch)
|
||||
loss, _ = policy.forward(batch)
|
||||
loss.backward()
|
||||
optimizer.step()
|
||||
optimizer.zero_grad()
|
||||
|
||||
if step % log_freq == 0:
|
||||
print(f"step: {step} loss: {loss.item():.3f}")
|
||||
step += 1
|
||||
if step >= training_steps:
|
||||
done = True
|
||||
break
|
||||
if step % log_freq == 0:
|
||||
print(f"step: {step} loss: {loss.item():.3f}")
|
||||
step += 1
|
||||
if step >= training_steps:
|
||||
done = True
|
||||
break
|
||||
|
||||
# Save the policy checkpoint, alongside the pre/post processors
|
||||
policy.save_pretrained(output_directory)
|
||||
preprocessor.save_pretrained(output_directory)
|
||||
postprocessor.save_pretrained(output_directory)
|
||||
# Save the policy checkpoint, alongside the pre/post processors
|
||||
policy.save_pretrained(output_directory)
|
||||
preprocessor.save_pretrained(output_directory)
|
||||
postprocessor.save_pretrained(output_directory)
|
||||
|
||||
# Save all assets to the Hub
|
||||
policy.push_to_hub("fracapuano/robot_learning_tutorial_act")
|
||||
preprocessor.push_to_hub("fracapuano/robot_learning_tutorial_act")
|
||||
postprocessor.push_to_hub("fracapuano/robot_learning_tutorial_act")
|
||||
# Save all assets to the Hub
|
||||
policy.push_to_hub("<user>/robot_learning_tutorial_act")
|
||||
preprocessor.push_to_hub("<user>/robot_learning_tutorial_act")
|
||||
postprocessor.push_to_hub("<user>/robot_learning_tutorial_act")
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
main()
|
||||
|
||||
@@ -8,50 +8,56 @@ from lerobot.policies.utils import build_inference_frame, make_robot_action
|
||||
from lerobot.robots.so100_follower.config_so100_follower import SO100FollowerConfig
|
||||
from lerobot.robots.so100_follower.so100_follower import SO100Follower
|
||||
|
||||
device = torch.device("mps") # or "cuda" or "cpu"
|
||||
model_id = "fracapuano/robot_learning_tutorial_act"
|
||||
model = ACTPolicy.from_pretrained(model_id)
|
||||
|
||||
dataset_id = "lerobot/svla_so101_pickplace"
|
||||
# This only downloads the metadata for the dataset, ~10s of MB even for large-scale datasets
|
||||
dataset_metadata = LeRobotDatasetMetadata(dataset_id)
|
||||
preprocess, postprocess = make_pre_post_processors(model.config, dataset_stats=dataset_metadata.stats)
|
||||
|
||||
# # find ports using lerobot-find-port
|
||||
follower_port = ... # something like "/dev/tty.usbmodem58760431631"
|
||||
|
||||
# # the robot ids are used the load the right calibration files
|
||||
follower_id = ... # something like "follower_so100"
|
||||
|
||||
MAX_EPISODES = 5
|
||||
MAX_STEPS_PER_EPISODE = 20
|
||||
|
||||
# Robot and environment configuration
|
||||
# Camera keys must match the name and resolutions of the ones used for training!
|
||||
# You can check the camera keys expected by a model in the info.json card on the model card on the Hub
|
||||
camera_config = {
|
||||
"side": OpenCVCameraConfig(index_or_path=0, width=640, height=480, fps=30),
|
||||
"up": OpenCVCameraConfig(index_or_path=1, width=640, height=480, fps=30),
|
||||
}
|
||||
|
||||
robot_cfg = SO100FollowerConfig(port=follower_port, id=follower_id, cameras=camera_config)
|
||||
robot = SO100Follower(robot_cfg)
|
||||
robot.connect()
|
||||
def main():
|
||||
device = torch.device("mps") # or "cuda" or "cpu"
|
||||
model_id = "<user>/robot_learning_tutorial_act"
|
||||
model = ACTPolicy.from_pretrained(model_id)
|
||||
|
||||
for _ in range(MAX_EPISODES):
|
||||
for _ in range(MAX_STEPS_PER_EPISODE):
|
||||
obs = robot.get_observation()
|
||||
obs_frame = build_inference_frame(
|
||||
observation=obs, ds_features=dataset_metadata.features, device=device
|
||||
)
|
||||
dataset_id = "lerobot/svla_so101_pickplace"
|
||||
# This only downloads the metadata for the dataset, ~10s of MB even for large-scale datasets
|
||||
dataset_metadata = LeRobotDatasetMetadata(dataset_id)
|
||||
preprocess, postprocess = make_pre_post_processors(model.config, dataset_stats=dataset_metadata.stats)
|
||||
|
||||
obs = preprocess(obs_frame)
|
||||
# # find ports using lerobot-find-port
|
||||
follower_port = ... # something like "/dev/tty.usbmodem58760431631"
|
||||
|
||||
action = model.select_action(obs)
|
||||
action = postprocess(action)
|
||||
# # the robot ids are used the load the right calibration files
|
||||
follower_id = ... # something like "follower_so100"
|
||||
|
||||
action = make_robot_action(action, dataset_metadata.features)
|
||||
# Robot and environment configuration
|
||||
# Camera keys must match the name and resolutions of the ones used for training!
|
||||
# You can check the camera keys expected by a model in the info.json card on the model card on the Hub
|
||||
camera_config = {
|
||||
"side": OpenCVCameraConfig(index_or_path=0, width=640, height=480, fps=30),
|
||||
"up": OpenCVCameraConfig(index_or_path=1, width=640, height=480, fps=30),
|
||||
}
|
||||
|
||||
robot.send_action(action)
|
||||
robot_cfg = SO100FollowerConfig(port=follower_port, id=follower_id, cameras=camera_config)
|
||||
robot = SO100Follower(robot_cfg)
|
||||
robot.connect()
|
||||
|
||||
print("Episode finished! Starting new episode...")
|
||||
for _ in range(MAX_EPISODES):
|
||||
for _ in range(MAX_STEPS_PER_EPISODE):
|
||||
obs = robot.get_observation()
|
||||
obs_frame = build_inference_frame(
|
||||
observation=obs, ds_features=dataset_metadata.features, device=device
|
||||
)
|
||||
|
||||
obs = preprocess(obs_frame)
|
||||
|
||||
action = model.select_action(obs)
|
||||
action = postprocess(action)
|
||||
|
||||
action = make_robot_action(action, dataset_metadata.features)
|
||||
|
||||
robot.send_action(action)
|
||||
|
||||
print("Episode finished! Starting new episode...")
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
main()
|
||||
|
||||
@@ -1,11 +1,17 @@
|
||||
from lerobot.async_inference.configs import PolicyServerConfig
|
||||
from lerobot.async_inference.policy_server import serve
|
||||
|
||||
host = ... # something like "127.0.0.1" if you're exposing to localhost
|
||||
port = ... # something like 8080
|
||||
|
||||
config = PolicyServerConfig(
|
||||
host=host,
|
||||
port=port,
|
||||
)
|
||||
serve(config)
|
||||
def main():
|
||||
host = ... # something like "127.0.0.1" if you're exposing to localhost
|
||||
port = ... # something like 8080
|
||||
|
||||
config = PolicyServerConfig(
|
||||
host=host,
|
||||
port=port,
|
||||
)
|
||||
serve(config)
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
main()
|
||||
|
||||
@@ -6,50 +6,56 @@ from lerobot.async_inference.robot_client import RobotClient
|
||||
from lerobot.cameras.opencv.configuration_opencv import OpenCVCameraConfig
|
||||
from lerobot.robots.so100_follower import SO100FollowerConfig
|
||||
|
||||
# these cameras must match the ones expected by the policy - find your cameras with lerobot-find-cameras
|
||||
# check the config.json on the Hub for the policy you are using to see the expected camera specs
|
||||
camera_cfg = {
|
||||
"up": OpenCVCameraConfig(index_or_path=0, width=640, height=480, fps=30),
|
||||
"side": OpenCVCameraConfig(index_or_path=1, width=640, height=480, fps=30),
|
||||
}
|
||||
|
||||
# # find ports using lerobot-find-port
|
||||
follower_port = ... # something like "/dev/tty.usbmodem58760431631"
|
||||
def main():
|
||||
# these cameras must match the ones expected by the policy - find your cameras with lerobot-find-cameras
|
||||
# check the config.json on the Hub for the policy you are using to see the expected camera specs
|
||||
camera_cfg = {
|
||||
"up": OpenCVCameraConfig(index_or_path=0, width=640, height=480, fps=30),
|
||||
"side": OpenCVCameraConfig(index_or_path=1, width=640, height=480, fps=30),
|
||||
}
|
||||
|
||||
# # the robot ids are used the load the right calibration files
|
||||
follower_id = ... # something like "follower_so100"
|
||||
# # find ports using lerobot-find-port
|
||||
follower_port = ... # something like "/dev/tty.usbmodem58760431631"
|
||||
|
||||
robot_cfg = SO100FollowerConfig(port=follower_port, id=follower_id, cameras=camera_cfg)
|
||||
# # the robot ids are used the load the right calibration files
|
||||
follower_id = ... # something like "follower_so100"
|
||||
|
||||
server_address = ... # something like "127.0.0.1:8080" if using localhost
|
||||
robot_cfg = SO100FollowerConfig(port=follower_port, id=follower_id, cameras=camera_cfg)
|
||||
|
||||
# 3. Create client configuration
|
||||
client_cfg = RobotClientConfig(
|
||||
robot=robot_cfg,
|
||||
server_address=server_address,
|
||||
policy_device="mps",
|
||||
policy_type="act",
|
||||
pretrained_name_or_path="fracapuano/robot_learning_tutorial_act",
|
||||
chunk_size_threshold=0.5, # g
|
||||
actions_per_chunk=50, # make sure this is less than the max actions of the policy
|
||||
)
|
||||
server_address = ... # something like "127.0.0.1:8080" if using localhost
|
||||
|
||||
# 4. Create and start client
|
||||
client = RobotClient(client_cfg)
|
||||
# 3. Create client configuration
|
||||
client_cfg = RobotClientConfig(
|
||||
robot=robot_cfg,
|
||||
server_address=server_address,
|
||||
policy_device="mps",
|
||||
policy_type="act",
|
||||
pretrained_name_or_path="<user>/robot_learning_tutorial_act",
|
||||
chunk_size_threshold=0.5, # g
|
||||
actions_per_chunk=50, # make sure this is less than the max actions of the policy
|
||||
)
|
||||
|
||||
# 5. Provide a textual description of the task
|
||||
task = ...
|
||||
# 4. Create and start client
|
||||
client = RobotClient(client_cfg)
|
||||
|
||||
if client.start():
|
||||
# Start action receiver thread
|
||||
action_receiver_thread = threading.Thread(target=client.receive_actions, daemon=True)
|
||||
action_receiver_thread.start()
|
||||
# 5. Provide a textual description of the task
|
||||
task = ...
|
||||
|
||||
try:
|
||||
# Run the control loop
|
||||
client.control_loop(task)
|
||||
except KeyboardInterrupt:
|
||||
client.stop()
|
||||
action_receiver_thread.join()
|
||||
# (Optionally) plot the action queue size
|
||||
visualize_action_queue_size(client.action_queue_size)
|
||||
if client.start():
|
||||
# Start action receiver thread
|
||||
action_receiver_thread = threading.Thread(target=client.receive_actions, daemon=True)
|
||||
action_receiver_thread.start()
|
||||
|
||||
try:
|
||||
# Run the control loop
|
||||
client.control_loop(task)
|
||||
except KeyboardInterrupt:
|
||||
client.stop()
|
||||
action_receiver_thread.join()
|
||||
# (Optionally) plot the action queue size
|
||||
visualize_action_queue_size(client.action_queue_size)
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
main()
|
||||
|
||||
@@ -19,81 +19,87 @@ def make_delta_timestamps(delta_indices: list[int] | None, fps: int) -> list[flo
|
||||
return [i / fps for i in delta_indices]
|
||||
|
||||
|
||||
output_directory = Path("outputs/robot_learning_tutorial/diffusion")
|
||||
output_directory.mkdir(parents=True, exist_ok=True)
|
||||
def main():
|
||||
output_directory = Path("outputs/robot_learning_tutorial/diffusion")
|
||||
output_directory.mkdir(parents=True, exist_ok=True)
|
||||
|
||||
# Select your device
|
||||
device = torch.device("mps") # or "cuda" or "cpu"
|
||||
# Select your device
|
||||
device = torch.device("mps") # or "cuda" or "cpu"
|
||||
|
||||
dataset_id = "lerobot/svla_so101_pickplace"
|
||||
dataset_id = "lerobot/svla_so101_pickplace"
|
||||
|
||||
# This specifies the inputs the model will be expecting and the outputs it will produce
|
||||
dataset_metadata = LeRobotDatasetMetadata(dataset_id)
|
||||
features = dataset_to_policy_features(dataset_metadata.features)
|
||||
# This specifies the inputs the model will be expecting and the outputs it will produce
|
||||
dataset_metadata = LeRobotDatasetMetadata(dataset_id)
|
||||
features = dataset_to_policy_features(dataset_metadata.features)
|
||||
|
||||
output_features = {key: ft for key, ft in features.items() if ft.type is FeatureType.ACTION}
|
||||
input_features = {key: ft for key, ft in features.items() if key not in output_features}
|
||||
output_features = {key: ft for key, ft in features.items() if ft.type is FeatureType.ACTION}
|
||||
input_features = {key: ft for key, ft in features.items() if key not in output_features}
|
||||
|
||||
cfg = DiffusionConfig(input_features=input_features, output_features=output_features)
|
||||
policy = DiffusionPolicy(cfg)
|
||||
preprocessor, postprocessor = make_pre_post_processors(cfg, dataset_stats=dataset_metadata.stats)
|
||||
cfg = DiffusionConfig(input_features=input_features, output_features=output_features)
|
||||
policy = DiffusionPolicy(cfg)
|
||||
preprocessor, postprocessor = make_pre_post_processors(cfg, dataset_stats=dataset_metadata.stats)
|
||||
|
||||
policy.train()
|
||||
policy.to(device)
|
||||
policy.train()
|
||||
policy.to(device)
|
||||
|
||||
# To perform action chunking, ACT expects a given number of actions as targets
|
||||
delta_timestamps = {
|
||||
"observation.state": make_delta_timestamps(cfg.observation_delta_indices, dataset_metadata.fps),
|
||||
"action": make_delta_timestamps(cfg.action_delta_indices, dataset_metadata.fps),
|
||||
}
|
||||
# To perform action chunking, ACT expects a given number of actions as targets
|
||||
delta_timestamps = {
|
||||
"observation.state": make_delta_timestamps(cfg.observation_delta_indices, dataset_metadata.fps),
|
||||
"action": make_delta_timestamps(cfg.action_delta_indices, dataset_metadata.fps),
|
||||
}
|
||||
|
||||
# add image features if they are present
|
||||
delta_timestamps |= {
|
||||
k: make_delta_timestamps(cfg.observation_delta_indices, dataset_metadata.fps) for k in cfg.image_features
|
||||
}
|
||||
# add image features if they are present
|
||||
delta_timestamps |= {
|
||||
k: make_delta_timestamps(cfg.observation_delta_indices, dataset_metadata.fps)
|
||||
for k in cfg.image_features
|
||||
}
|
||||
|
||||
# Instantiate the dataset
|
||||
dataset = LeRobotDataset(dataset_id, delta_timestamps=delta_timestamps)
|
||||
# Instantiate the dataset
|
||||
dataset = LeRobotDataset(dataset_id, delta_timestamps=delta_timestamps)
|
||||
|
||||
# Create the optimizer and dataloader for offline training
|
||||
optimizer = cfg.get_optimizer_preset().build(policy.parameters())
|
||||
batch_size = 32
|
||||
dataloader = torch.utils.data.DataLoader(
|
||||
dataset,
|
||||
batch_size=batch_size,
|
||||
shuffle=True,
|
||||
pin_memory=device.type != "cpu",
|
||||
drop_last=True,
|
||||
)
|
||||
# Create the optimizer and dataloader for offline training
|
||||
optimizer = cfg.get_optimizer_preset().build(policy.parameters())
|
||||
batch_size = 32
|
||||
dataloader = torch.utils.data.DataLoader(
|
||||
dataset,
|
||||
batch_size=batch_size,
|
||||
shuffle=True,
|
||||
pin_memory=device.type != "cpu",
|
||||
drop_last=True,
|
||||
)
|
||||
|
||||
# Number of training steps and logging frequency
|
||||
training_steps = 1
|
||||
log_freq = 1
|
||||
# Number of training steps and logging frequency
|
||||
training_steps = 1
|
||||
log_freq = 1
|
||||
|
||||
# Run training loop
|
||||
step = 0
|
||||
done = False
|
||||
while not done:
|
||||
for batch in dataloader:
|
||||
batch = preprocessor(batch)
|
||||
loss, _ = policy.forward(batch)
|
||||
loss.backward()
|
||||
optimizer.step()
|
||||
optimizer.zero_grad()
|
||||
# Run training loop
|
||||
step = 0
|
||||
done = False
|
||||
while not done:
|
||||
for batch in dataloader:
|
||||
batch = preprocessor(batch)
|
||||
loss, _ = policy.forward(batch)
|
||||
loss.backward()
|
||||
optimizer.step()
|
||||
optimizer.zero_grad()
|
||||
|
||||
if step % log_freq == 0:
|
||||
print(f"step: {step} loss: {loss.item():.3f}")
|
||||
step += 1
|
||||
if step >= training_steps:
|
||||
done = True
|
||||
break
|
||||
if step % log_freq == 0:
|
||||
print(f"step: {step} loss: {loss.item():.3f}")
|
||||
step += 1
|
||||
if step >= training_steps:
|
||||
done = True
|
||||
break
|
||||
|
||||
# Save the policy checkpoint, alongside the pre/post processors
|
||||
policy.save_pretrained(output_directory)
|
||||
preprocessor.save_pretrained(output_directory)
|
||||
postprocessor.save_pretrained(output_directory)
|
||||
# Save the policy checkpoint, alongside the pre/post processors
|
||||
policy.save_pretrained(output_directory)
|
||||
preprocessor.save_pretrained(output_directory)
|
||||
postprocessor.save_pretrained(output_directory)
|
||||
|
||||
# Save all assets to the Hub
|
||||
policy.push_to_hub("fracapuano/robot_learning_tutorial_diffusion")
|
||||
preprocessor.push_to_hub("fracapuano/robot_learning_tutorial_diffusion")
|
||||
postprocessor.push_to_hub("fracapuano/robot_learning_tutorial_diffusion")
|
||||
# Save all assets to the Hub
|
||||
policy.push_to_hub("<user>/robot_learning_tutorial_diffusion")
|
||||
preprocessor.push_to_hub("<user>/robot_learning_tutorial_diffusion")
|
||||
postprocessor.push_to_hub("<user>/robot_learning_tutorial_diffusion")
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
main()
|
||||
|
||||
@@ -8,53 +8,57 @@ from lerobot.policies.utils import build_inference_frame, make_robot_action
|
||||
from lerobot.robots.so100_follower.config_so100_follower import SO100FollowerConfig
|
||||
from lerobot.robots.so100_follower.so100_follower import SO100Follower
|
||||
|
||||
device = torch.device("mps") # or "cuda" or "cpu"
|
||||
model_id = "fracapuano/robot_learning_tutorial_diffusion"
|
||||
|
||||
model = DiffusionPolicy.from_pretrained(model_id)
|
||||
|
||||
dataset_id = "lerobot/svla_so101_pickplace"
|
||||
# This only downloads the metadata for the dataset, ~10s of MB even for large-scale datasets
|
||||
dataset_metadata = LeRobotDatasetMetadata(dataset_id)
|
||||
preprocess, postprocess = make_pre_post_processors(
|
||||
model.config, model_id, dataset_stats=dataset_metadata.stats
|
||||
)
|
||||
|
||||
MAX_EPISODES = 5
|
||||
MAX_STEPS_PER_EPISODE = 20
|
||||
|
||||
|
||||
# # find ports using lerobot-find-port
|
||||
follower_port = ... # something like "/dev/tty.usbmodem58760431631"
|
||||
def main():
|
||||
device = torch.device("mps") # or "cuda" or "cpu"
|
||||
model_id = "<user>/robot_learning_tutorial_diffusion"
|
||||
|
||||
# # the robot ids are used the load the right calibration files
|
||||
follower_id = ... # something like "follower_so100"
|
||||
model = DiffusionPolicy.from_pretrained(model_id)
|
||||
|
||||
# Robot and environment configuration
|
||||
# Camera keys must match the name and resolutions of the ones used for training!
|
||||
# You can check the camera keys expected by a model in the info.json card on the model card on the Hub
|
||||
camera_config = {
|
||||
"side": OpenCVCameraConfig(index_or_path=0, width=640, height=480, fps=30),
|
||||
"up": OpenCVCameraConfig(index_or_path=1, width=640, height=480, fps=30),
|
||||
}
|
||||
dataset_id = "lerobot/svla_so101_pickplace"
|
||||
# This only downloads the metadata for the dataset, ~10s of MB even for large-scale datasets
|
||||
dataset_metadata = LeRobotDatasetMetadata(dataset_id)
|
||||
preprocess, postprocess = make_pre_post_processors(
|
||||
model.config, model_id, dataset_stats=dataset_metadata.stats
|
||||
)
|
||||
|
||||
robot_cfg = SO100FollowerConfig(port=follower_port, id=follower_id, cameras=camera_config)
|
||||
robot = SO100Follower(robot_cfg)
|
||||
robot.connect()
|
||||
# # find ports using lerobot-find-port
|
||||
follower_port = ... # something like "/dev/tty.usbmodem58760431631"
|
||||
|
||||
# # the robot ids are used the load the right calibration files
|
||||
follower_id = ... # something like "follower_so100"
|
||||
|
||||
# Robot and environment configuration
|
||||
# Camera keys must match the name and resolutions of the ones used for training!
|
||||
# You can check the camera keys expected by a model in the info.json card on the model card on the Hub
|
||||
camera_config = {
|
||||
"side": OpenCVCameraConfig(index_or_path=0, width=640, height=480, fps=30),
|
||||
"up": OpenCVCameraConfig(index_or_path=1, width=640, height=480, fps=30),
|
||||
}
|
||||
|
||||
robot_cfg = SO100FollowerConfig(port=follower_port, id=follower_id, cameras=camera_config)
|
||||
robot = SO100Follower(robot_cfg)
|
||||
robot.connect()
|
||||
|
||||
for _ in range(MAX_EPISODES):
|
||||
for _ in range(MAX_STEPS_PER_EPISODE):
|
||||
obs = robot.get_observation()
|
||||
obs_frame = build_inference_frame(
|
||||
observation=obs, ds_features=dataset_metadata.features, device=device
|
||||
)
|
||||
|
||||
obs = preprocess(obs_frame)
|
||||
|
||||
action = model.select_action(obs)
|
||||
action = postprocess(action)
|
||||
action = make_robot_action(action, dataset_metadata.features)
|
||||
robot.send_action(action)
|
||||
|
||||
print("Episode finished! Starting new episode...")
|
||||
|
||||
|
||||
for _ in range(MAX_EPISODES):
|
||||
for _ in range(MAX_STEPS_PER_EPISODE):
|
||||
obs = robot.get_observation()
|
||||
obs_frame = build_inference_frame(
|
||||
observation=obs, ds_features=dataset_metadata.features, device=device
|
||||
)
|
||||
|
||||
obs = preprocess(obs_frame)
|
||||
|
||||
action = model.select_action(obs)
|
||||
action = postprocess(action)
|
||||
action = make_robot_action(action, dataset_metadata.features)
|
||||
robot.send_action(action)
|
||||
|
||||
print("Episode finished! Starting new episode...")
|
||||
if __name__ == "__main__":
|
||||
main()
|
||||
|
||||
@@ -11,57 +11,63 @@ from lerobot.robots.so100_follower.so100_follower import SO100Follower
|
||||
MAX_EPISODES = 5
|
||||
MAX_STEPS_PER_EPISODE = 20
|
||||
|
||||
device = torch.device("mps") # or "cuda" or "cpu"
|
||||
model_id = "lerobot/pi0_base"
|
||||
|
||||
model = PI0Policy.from_pretrained(model_id)
|
||||
def main():
|
||||
device = torch.device("mps") # or "cuda" or "cpu"
|
||||
model_id = "lerobot/pi0_base"
|
||||
|
||||
preprocess, postprocess = make_pre_post_processors(
|
||||
model.config,
|
||||
model_id,
|
||||
# This overrides allows to run on MPS, otherwise defaults to CUDA (if available)
|
||||
preprocessor_overrides={"device_processor": {"device": str(device)}},
|
||||
)
|
||||
model = PI0Policy.from_pretrained(model_id)
|
||||
|
||||
# find ports using lerobot-find-port
|
||||
follower_port = ... # something like "/dev/tty.usbmodem58760431631"
|
||||
preprocess, postprocess = make_pre_post_processors(
|
||||
model.config,
|
||||
model_id,
|
||||
# This overrides allows to run on MPS, otherwise defaults to CUDA (if available)
|
||||
preprocessor_overrides={"device_processor": {"device": str(device)}},
|
||||
)
|
||||
|
||||
# the robot ids are used the load the right calibration files
|
||||
follower_id = ... # something like "follower_so100"
|
||||
# find ports using lerobot-find-port
|
||||
follower_port = ... # something like "/dev/tty.usbmodem58760431631"
|
||||
|
||||
# Robot and environment configuration
|
||||
# Camera keys must match the name and resolutions of the ones used for training!
|
||||
# You can check the camera keys expected by a model in the info.json card on the model card on the Hub
|
||||
camera_config = {
|
||||
"base_0_rgb": OpenCVCameraConfig(index_or_path=0, width=640, height=480, fps=30),
|
||||
"left_wrist_0_rgb": OpenCVCameraConfig(index_or_path=1, width=640, height=480, fps=30),
|
||||
"right_wrist_0_rgb": OpenCVCameraConfig(index_or_path=2, width=640, height=480, fps=30),
|
||||
}
|
||||
# the robot ids are used the load the right calibration files
|
||||
follower_id = ... # something like "follower_so100"
|
||||
|
||||
robot_cfg = SO100FollowerConfig(port=follower_port, id=follower_id, cameras=camera_config)
|
||||
robot = SO100Follower(robot_cfg)
|
||||
robot.connect()
|
||||
# Robot and environment configuration
|
||||
# Camera keys must match the name and resolutions of the ones used for training!
|
||||
# You can check the camera keys expected by a model in the info.json card on the model card on the Hub
|
||||
camera_config = {
|
||||
"base_0_rgb": OpenCVCameraConfig(index_or_path=0, width=640, height=480, fps=30),
|
||||
"left_wrist_0_rgb": OpenCVCameraConfig(index_or_path=1, width=640, height=480, fps=30),
|
||||
"right_wrist_0_rgb": OpenCVCameraConfig(index_or_path=2, width=640, height=480, fps=30),
|
||||
}
|
||||
|
||||
task = "" # something like "pick the red block"
|
||||
robot_type = "" # something like "so100_follower" for multi-embodiment datasets
|
||||
robot_cfg = SO100FollowerConfig(port=follower_port, id=follower_id, cameras=camera_config)
|
||||
robot = SO100Follower(robot_cfg)
|
||||
robot.connect()
|
||||
|
||||
# This is used to match the raw observation keys to the keys expected by the policy
|
||||
action_features = hw_to_dataset_features(robot.action_features, "action")
|
||||
obs_features = hw_to_dataset_features(robot.observation_features, "observation")
|
||||
dataset_features = {**action_features, **obs_features}
|
||||
task = "" # something like "pick the red block"
|
||||
robot_type = "" # something like "so100_follower" for multi-embodiment datasets
|
||||
|
||||
for _ in range(MAX_EPISODES):
|
||||
for _ in range(MAX_STEPS_PER_EPISODE):
|
||||
obs = robot.get_observation()
|
||||
obs_frame = build_inference_frame(
|
||||
observation=obs, ds_features=dataset_features, device=device, task=task, robot_type=robot_type
|
||||
)
|
||||
# This is used to match the raw observation keys to the keys expected by the policy
|
||||
action_features = hw_to_dataset_features(robot.action_features, "action")
|
||||
obs_features = hw_to_dataset_features(robot.observation_features, "observation")
|
||||
dataset_features = {**action_features, **obs_features}
|
||||
|
||||
obs = preprocess(obs_frame)
|
||||
for _ in range(MAX_EPISODES):
|
||||
for _ in range(MAX_STEPS_PER_EPISODE):
|
||||
obs = robot.get_observation()
|
||||
obs_frame = build_inference_frame(
|
||||
observation=obs, ds_features=dataset_features, device=device, task=task, robot_type=robot_type
|
||||
)
|
||||
|
||||
action = model.select_action(obs)
|
||||
action = postprocess(action)
|
||||
action = make_robot_action(action, dataset_features)
|
||||
robot.send_action(action)
|
||||
obs = preprocess(obs_frame)
|
||||
|
||||
print("Episode finished! Starting new episode...")
|
||||
action = model.select_action(obs)
|
||||
action = postprocess(action)
|
||||
action = make_robot_action(action, dataset_features)
|
||||
robot.send_action(action)
|
||||
|
||||
print("Episode finished! Starting new episode...")
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
main()
|
||||
|
||||
@@ -20,6 +20,8 @@ from lerobot.teleoperators.utils import TeleopEvents
|
||||
|
||||
LOG_EVERY = 10
|
||||
SEND_EVERY = 10
|
||||
MAX_EPISODES = 5
|
||||
MAX_STEPS_PER_EPISODE = 20
|
||||
|
||||
|
||||
def run_learner(
|
||||
@@ -223,123 +225,123 @@ def make_policy_obs(obs, device: torch.device = "cpu"):
|
||||
}
|
||||
|
||||
|
||||
"""Main function - coordinates actor and learner processes."""
|
||||
def main():
|
||||
"""Main function - coordinates actor and learner processes."""
|
||||
|
||||
device = "mps" # or "cuda" or "cpu"
|
||||
output_directory = Path("outputs/robot_learning_tutorial/hil_serl")
|
||||
output_directory.mkdir(parents=True, exist_ok=True)
|
||||
device = "mps" # or "cuda" or "cpu"
|
||||
output_directory = Path("outputs/robot_learning_tutorial/hil_serl")
|
||||
output_directory.mkdir(parents=True, exist_ok=True)
|
||||
|
||||
# find ports using lerobot-find-port
|
||||
follower_port = ...
|
||||
leader_port = ...
|
||||
# find ports using lerobot-find-port
|
||||
follower_port = ...
|
||||
leader_port = ...
|
||||
|
||||
# the robot ids are used the load the right calibration files
|
||||
follower_id = ...
|
||||
leader_id = ...
|
||||
# the robot ids are used the load the right calibration files
|
||||
follower_id = ...
|
||||
leader_id = ...
|
||||
|
||||
# A pretrained model (to be used in-distribution!)
|
||||
reward_classifier_id = "fracapuano/reward_classifier_hil_serl_example"
|
||||
reward_classifier = Classifier.from_pretrained(reward_classifier_id)
|
||||
# A pretrained model (to be used in-distribution!)
|
||||
reward_classifier_id = "<user>/reward_classifier_hil_serl_example"
|
||||
reward_classifier = Classifier.from_pretrained(reward_classifier_id)
|
||||
|
||||
reward_classifier.to(device)
|
||||
reward_classifier.eval()
|
||||
reward_classifier.to(device)
|
||||
reward_classifier.eval()
|
||||
|
||||
MAX_EPISODES = 5
|
||||
MAX_STEPS_PER_EPISODE = 20
|
||||
# Robot and environment configuration
|
||||
robot_cfg = SO100FollowerConfig(port=follower_port, id=follower_id)
|
||||
teleop_cfg = SO100LeaderConfig(port=leader_port, id=leader_id)
|
||||
processor_cfg = HILSerlProcessorConfig(control_mode="leader")
|
||||
|
||||
# Robot and environment configuration
|
||||
robot_cfg = SO100FollowerConfig(port=follower_port, id=follower_id)
|
||||
teleop_cfg = SO100LeaderConfig(port=leader_port, id=leader_id)
|
||||
processor_cfg = HILSerlProcessorConfig(control_mode="leader")
|
||||
env_cfg = HILSerlRobotEnvConfig(robot=robot_cfg, teleop=teleop_cfg, processor=processor_cfg)
|
||||
|
||||
env_cfg = HILSerlRobotEnvConfig(robot=robot_cfg, teleop=teleop_cfg, processor=processor_cfg)
|
||||
# Create robot environment
|
||||
env, teleop_device = make_robot_env(env_cfg)
|
||||
|
||||
# Create robot environment
|
||||
env, teleop_device = make_robot_env(env_cfg)
|
||||
obs_features = hw_to_dataset_features(env.robot.observation_features, "observation")
|
||||
action_features = hw_to_dataset_features(env.robot.action_features, "action")
|
||||
|
||||
obs_features = hw_to_dataset_features(env.robot.observation_features, "observation")
|
||||
action_features = hw_to_dataset_features(env.robot.action_features, "action")
|
||||
# Create SAC policy for action selection
|
||||
policy_cfg = SACConfig(
|
||||
device=device,
|
||||
input_features=obs_features,
|
||||
output_features=action_features,
|
||||
)
|
||||
|
||||
# Create SAC policy for action selection
|
||||
policy_cfg = SACConfig(
|
||||
device=device,
|
||||
input_features=obs_features,
|
||||
output_features=action_features,
|
||||
)
|
||||
policy_actor = SACPolicy(policy_cfg)
|
||||
policy_learner = SACPolicy(policy_cfg)
|
||||
|
||||
policy_actor = SACPolicy(policy_cfg)
|
||||
policy_learner = SACPolicy(policy_cfg)
|
||||
demonstrations_repo_id = "lerobot/example_hil_serl_dataset"
|
||||
offline_dataset = LeRobotDataset(repo_id=demonstrations_repo_id)
|
||||
|
||||
demonstrations_repo_id = "lerobot/example_hil_serl_dataset"
|
||||
offline_dataset = LeRobotDataset(repo_id=demonstrations_repo_id)
|
||||
# Online buffer: initialized from scratch
|
||||
online_replay_buffer = ReplayBuffer(device=device, state_keys=list(obs_features.keys()))
|
||||
# Offline buffer: Created from dataset (pre-populated it with demonstrations)
|
||||
offline_replay_buffer = ReplayBuffer.from_lerobot_dataset(
|
||||
lerobot_dataset=offline_dataset, device=device, state_keys=list(obs_features.keys())
|
||||
)
|
||||
|
||||
# Online buffer: initialized from scratch
|
||||
online_replay_buffer = ReplayBuffer(device=device, state_keys=list(obs_features.keys()))
|
||||
# Offline buffer: Created from dataset (pre-populated it with demonstrations)
|
||||
offline_replay_buffer = ReplayBuffer.from_lerobot_dataset(
|
||||
lerobot_dataset=offline_dataset, device=device, state_keys=list(obs_features.keys())
|
||||
)
|
||||
# Create communication channels between learner and actor processes
|
||||
transitions_queue = mp.Queue(maxsize=10)
|
||||
parameters_queue = mp.Queue(maxsize=2)
|
||||
shutdown_event = mp.Event()
|
||||
|
||||
# Create communication channels between learner and actor processes
|
||||
transitions_queue = mp.Queue(maxsize=10)
|
||||
parameters_queue = mp.Queue(maxsize=2)
|
||||
shutdown_event = mp.Event()
|
||||
# Signal handler for graceful shutdown
|
||||
def signal_handler(sig):
|
||||
print(f"\nSignal {sig} received, shutting down...")
|
||||
shutdown_event.set()
|
||||
|
||||
signal.signal(signal.SIGINT, signal_handler)
|
||||
signal.signal(signal.SIGTERM, signal_handler)
|
||||
|
||||
# Create processes
|
||||
learner_process = mp.Process(
|
||||
target=run_learner,
|
||||
args=(
|
||||
transitions_queue,
|
||||
parameters_queue,
|
||||
shutdown_event,
|
||||
policy_learner,
|
||||
online_replay_buffer,
|
||||
offline_replay_buffer,
|
||||
),
|
||||
kwargs={"device": device}, # can run on accelerated hardware for training
|
||||
)
|
||||
|
||||
actor_process = mp.Process(
|
||||
target=run_actor,
|
||||
args=(
|
||||
transitions_queue,
|
||||
parameters_queue,
|
||||
shutdown_event,
|
||||
policy_actor,
|
||||
reward_classifier,
|
||||
env_cfg,
|
||||
output_directory,
|
||||
),
|
||||
kwargs={"device": "cpu"}, # actor is frozen, can run on CPU or accelerate for inference
|
||||
)
|
||||
|
||||
learner_process.start()
|
||||
actor_process.start()
|
||||
|
||||
try:
|
||||
# Wait for actor to finish (it controls the episode loop)
|
||||
actor_process.join()
|
||||
shutdown_event.set()
|
||||
learner_process.join(timeout=10)
|
||||
|
||||
except KeyboardInterrupt:
|
||||
print("Main process interrupted")
|
||||
shutdown_event.set()
|
||||
actor_process.join(timeout=5)
|
||||
learner_process.join(timeout=10)
|
||||
|
||||
finally:
|
||||
if learner_process.is_alive():
|
||||
learner_process.terminate()
|
||||
if actor_process.is_alive():
|
||||
actor_process.terminate()
|
||||
|
||||
|
||||
# Signal handler for graceful shutdown
|
||||
def signal_handler(sig):
|
||||
print(f"\nSignal {sig} received, shutting down...")
|
||||
shutdown_event.set()
|
||||
|
||||
|
||||
signal.signal(signal.SIGINT, signal_handler)
|
||||
signal.signal(signal.SIGTERM, signal_handler)
|
||||
|
||||
# Create processes
|
||||
learner_process = mp.Process(
|
||||
target=run_learner,
|
||||
args=(
|
||||
transitions_queue,
|
||||
parameters_queue,
|
||||
shutdown_event,
|
||||
policy_learner,
|
||||
online_replay_buffer,
|
||||
offline_replay_buffer,
|
||||
),
|
||||
kwargs={"device": device}, # can run on accelerated hardware for training
|
||||
)
|
||||
|
||||
actor_process = mp.Process(
|
||||
target=run_actor,
|
||||
args=(
|
||||
transitions_queue,
|
||||
parameters_queue,
|
||||
shutdown_event,
|
||||
policy_actor,
|
||||
reward_classifier,
|
||||
env_cfg,
|
||||
output_directory,
|
||||
),
|
||||
kwargs={"device": "cpu"}, # actor is frozen, can run on CPU or accelerate for inference
|
||||
)
|
||||
|
||||
learner_process.start()
|
||||
actor_process.start()
|
||||
|
||||
try:
|
||||
# Wait for actor to finish (it controls the episode loop)
|
||||
actor_process.join()
|
||||
shutdown_event.set()
|
||||
learner_process.join(timeout=10)
|
||||
|
||||
except KeyboardInterrupt:
|
||||
print("Main process interrupted")
|
||||
shutdown_event.set()
|
||||
actor_process.join(timeout=5)
|
||||
learner_process.join(timeout=10)
|
||||
|
||||
finally:
|
||||
if learner_process.is_alive():
|
||||
learner_process.terminate()
|
||||
if actor_process.is_alive():
|
||||
actor_process.terminate()
|
||||
if __name__ == "__main__":
|
||||
main()
|
||||
|
||||
@@ -4,59 +4,64 @@ from lerobot.datasets.lerobot_dataset import LeRobotDataset
|
||||
from lerobot.policies.factory import make_policy, make_pre_post_processors
|
||||
from lerobot.policies.sac.reward_model.configuration_classifier import RewardClassifierConfig
|
||||
|
||||
# Device to use for training
|
||||
device = "mps" # or "cuda", or "cpu"
|
||||
|
||||
# Load the dataset used for training
|
||||
repo_id = "lerobot/example_hil_serl_dataset"
|
||||
dataset = LeRobotDataset(repo_id)
|
||||
def main():
|
||||
# Device to use for training
|
||||
device = "mps" # or "cuda", or "cpu"
|
||||
|
||||
# Configure the policy to extract features from the image frames
|
||||
camera_keys = dataset.meta.camera_keys
|
||||
# Load the dataset used for training
|
||||
repo_id = "lerobot/example_hil_serl_dataset"
|
||||
dataset = LeRobotDataset(repo_id)
|
||||
|
||||
config = RewardClassifierConfig(
|
||||
num_cameras=len(camera_keys),
|
||||
device=device,
|
||||
# backbone model to extract features from the image frames
|
||||
model_name="microsoft/resnet-18",
|
||||
)
|
||||
# Configure the policy to extract features from the image frames
|
||||
camera_keys = dataset.meta.camera_keys
|
||||
|
||||
# Make policy, preprocessor, and optimizer
|
||||
policy = make_policy(config, ds_meta=dataset.meta)
|
||||
optimizer = config.get_optimizer_preset().build(policy.parameters())
|
||||
preprocessor, _ = make_pre_post_processors(policy_cfg=config, dataset_stats=dataset.meta.stats)
|
||||
config = RewardClassifierConfig(
|
||||
num_cameras=len(camera_keys),
|
||||
device=device,
|
||||
# backbone model to extract features from the image frames
|
||||
model_name="microsoft/resnet-18",
|
||||
)
|
||||
|
||||
# Make policy, preprocessor, and optimizer
|
||||
policy = make_policy(config, ds_meta=dataset.meta)
|
||||
optimizer = config.get_optimizer_preset().build(policy.parameters())
|
||||
preprocessor, _ = make_pre_post_processors(policy_cfg=config, dataset_stats=dataset.meta.stats)
|
||||
|
||||
classifier_id = "<user>/reward_classifier_hil_serl_example"
|
||||
|
||||
# Instantiate a dataloader
|
||||
dataloader = torch.utils.data.DataLoader(dataset, batch_size=16, shuffle=True)
|
||||
|
||||
# Training loop
|
||||
num_epochs = 5
|
||||
for epoch in range(num_epochs):
|
||||
total_loss = 0
|
||||
total_accuracy = 0
|
||||
for batch in dataloader:
|
||||
# Preprocess the batch and move it to the correct device.
|
||||
batch = preprocessor(batch)
|
||||
|
||||
# Forward pass
|
||||
loss, output_dict = policy.forward(batch)
|
||||
|
||||
# Backward pass and optimization
|
||||
optimizer.zero_grad()
|
||||
loss.backward()
|
||||
optimizer.step()
|
||||
|
||||
total_loss += loss.item()
|
||||
total_accuracy += output_dict["accuracy"]
|
||||
|
||||
avg_loss = total_loss / len(dataloader)
|
||||
avg_accuracy = total_accuracy / len(dataloader)
|
||||
print(f"Epoch {epoch + 1}/{num_epochs}, Loss: {avg_loss:.4f}, Accuracy: {avg_accuracy:.2f}%")
|
||||
|
||||
print("Training finished!")
|
||||
|
||||
# You can now save the trained policy.
|
||||
policy.push_to_hub(classifier_id)
|
||||
|
||||
|
||||
classifier_id = "fracapuano/reward_classifier_hil_serl_example"
|
||||
|
||||
# Instantiate a dataloader
|
||||
dataloader = torch.utils.data.DataLoader(dataset, batch_size=16, shuffle=True)
|
||||
|
||||
# Training loop
|
||||
num_epochs = 5
|
||||
for epoch in range(num_epochs):
|
||||
total_loss = 0
|
||||
total_accuracy = 0
|
||||
for batch in dataloader:
|
||||
# Preprocess the batch and move it to the correct device.
|
||||
batch = preprocessor(batch)
|
||||
|
||||
# Forward pass
|
||||
loss, output_dict = policy.forward(batch)
|
||||
|
||||
# Backward pass and optimization
|
||||
optimizer.zero_grad()
|
||||
loss.backward()
|
||||
optimizer.step()
|
||||
|
||||
total_loss += loss.item()
|
||||
total_accuracy += output_dict["accuracy"]
|
||||
|
||||
avg_loss = total_loss / len(dataloader)
|
||||
avg_accuracy = total_accuracy / len(dataloader)
|
||||
print(f"Epoch {epoch + 1}/{num_epochs}, Loss: {avg_loss:.4f}, Accuracy: {avg_accuracy:.2f}%")
|
||||
|
||||
print("Training finished!")
|
||||
|
||||
# You can now save the trained policy.
|
||||
policy.push_to_hub(classifier_id)
|
||||
if __name__ == "__main__":
|
||||
main()
|
||||
|
||||
@@ -11,56 +11,62 @@ from lerobot.robots.so100_follower.so100_follower import SO100Follower
|
||||
MAX_EPISODES = 5
|
||||
MAX_STEPS_PER_EPISODE = 20
|
||||
|
||||
device = torch.device("mps") # or "cuda" or "cpu"
|
||||
model_id = "lerobot/smolvla_base"
|
||||
|
||||
model = SmolVLAPolicy.from_pretrained(model_id)
|
||||
def main():
|
||||
device = torch.device("mps") # or "cuda" or "cpu"
|
||||
model_id = "lerobot/smolvla_base"
|
||||
|
||||
preprocess, postprocess = make_pre_post_processors(
|
||||
model.config,
|
||||
model_id,
|
||||
# This overrides allows to run on MPS, otherwise defaults to CUDA (if available)
|
||||
preprocessor_overrides={"device_processor": {"device": str(device)}},
|
||||
)
|
||||
model = SmolVLAPolicy.from_pretrained(model_id)
|
||||
|
||||
# find ports using lerobot-find-port
|
||||
follower_port = ... # something like "/dev/tty.usbmodem58760431631"
|
||||
preprocess, postprocess = make_pre_post_processors(
|
||||
model.config,
|
||||
model_id,
|
||||
# This overrides allows to run on MPS, otherwise defaults to CUDA (if available)
|
||||
preprocessor_overrides={"device_processor": {"device": str(device)}},
|
||||
)
|
||||
|
||||
# the robot ids are used the load the right calibration files
|
||||
follower_id = ... # something like "follower_so100"
|
||||
# find ports using lerobot-find-port
|
||||
follower_port = ... # something like "/dev/tty.usbmodem58760431631"
|
||||
|
||||
# Robot and environment configuration
|
||||
# Camera keys must match the name and resolutions of the ones used for training!
|
||||
# You can check the camera keys expected by a model in the info.json card on the model card on the Hub
|
||||
camera_config = {
|
||||
"camera1": OpenCVCameraConfig(index_or_path=0, width=640, height=480, fps=30),
|
||||
"camera2": OpenCVCameraConfig(index_or_path=1, width=640, height=480, fps=30),
|
||||
}
|
||||
# the robot ids are used the load the right calibration files
|
||||
follower_id = ... # something like "follower_so100"
|
||||
|
||||
robot_cfg = SO100FollowerConfig(port=follower_port, id=follower_id, cameras=camera_config)
|
||||
robot = SO100Follower(robot_cfg)
|
||||
robot.connect()
|
||||
# Robot and environment configuration
|
||||
# Camera keys must match the name and resolutions of the ones used for training!
|
||||
# You can check the camera keys expected by a model in the info.json card on the model card on the Hub
|
||||
camera_config = {
|
||||
"camera1": OpenCVCameraConfig(index_or_path=0, width=640, height=480, fps=30),
|
||||
"camera2": OpenCVCameraConfig(index_or_path=1, width=640, height=480, fps=30),
|
||||
}
|
||||
|
||||
task = "" # something like "pick the red block"
|
||||
robot_type = "" # something like "so100_follower" for multi-embodiment datasets
|
||||
robot_cfg = SO100FollowerConfig(port=follower_port, id=follower_id, cameras=camera_config)
|
||||
robot = SO100Follower(robot_cfg)
|
||||
robot.connect()
|
||||
|
||||
# This is used to match the raw observation keys to the keys expected by the policy
|
||||
action_features = hw_to_dataset_features(robot.action_features, "action")
|
||||
obs_features = hw_to_dataset_features(robot.observation_features, "observation")
|
||||
dataset_features = {**action_features, **obs_features}
|
||||
task = "" # something like "pick the red block"
|
||||
robot_type = "" # something like "so100_follower" for multi-embodiment datasets
|
||||
|
||||
for _ in range(MAX_EPISODES):
|
||||
for _ in range(MAX_STEPS_PER_EPISODE):
|
||||
obs = robot.get_observation()
|
||||
obs_frame = build_inference_frame(
|
||||
observation=obs, ds_features=dataset_features, device=device, task=task, robot_type=robot_type
|
||||
)
|
||||
# This is used to match the raw observation keys to the keys expected by the policy
|
||||
action_features = hw_to_dataset_features(robot.action_features, "action")
|
||||
obs_features = hw_to_dataset_features(robot.observation_features, "observation")
|
||||
dataset_features = {**action_features, **obs_features}
|
||||
|
||||
obs = preprocess(obs_frame)
|
||||
for _ in range(MAX_EPISODES):
|
||||
for _ in range(MAX_STEPS_PER_EPISODE):
|
||||
obs = robot.get_observation()
|
||||
obs_frame = build_inference_frame(
|
||||
observation=obs, ds_features=dataset_features, device=device, task=task, robot_type=robot_type
|
||||
)
|
||||
|
||||
action = model.select_action(obs)
|
||||
action = postprocess(action)
|
||||
action = make_robot_action(action, dataset_features)
|
||||
robot.send_action(action)
|
||||
obs = preprocess(obs_frame)
|
||||
|
||||
print("Episode finished! Starting new episode...")
|
||||
action = model.select_action(obs)
|
||||
action = postprocess(action)
|
||||
action = make_robot_action(action, dataset_features)
|
||||
robot.send_action(action)
|
||||
|
||||
print("Episode finished! Starting new episode...")
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
main()
|
||||
|
||||
+1
-1
@@ -25,7 +25,7 @@ discord = "https://discord.gg/s3KuuzsPFb"
|
||||
|
||||
[project]
|
||||
name = "lerobot"
|
||||
version = "0.4.1"
|
||||
version = "0.4.2"
|
||||
description = "🤗 LeRobot: State-of-the-art Machine Learning for Real-World Robotics in Pytorch"
|
||||
readme = "README.md"
|
||||
license = { text = "Apache-2.0" }
|
||||
|
||||
@@ -43,3 +43,10 @@ class NormalizationMode(str, Enum):
|
||||
class PolicyFeature:
|
||||
type: FeatureType
|
||||
shape: tuple[int, ...]
|
||||
|
||||
|
||||
class RTCAttentionSchedule(str, Enum):
|
||||
ZEROS = "ZEROS"
|
||||
ONES = "ONES"
|
||||
LINEAR = "LINEAR"
|
||||
EXP = "EXP"
|
||||
|
||||
@@ -110,8 +110,8 @@ def worker_thread_loop(queue: queue.Queue):
|
||||
if item is None:
|
||||
queue.task_done()
|
||||
break
|
||||
image_array, fpath = item
|
||||
write_image(image_array, fpath)
|
||||
image_array, fpath, compress_level = item
|
||||
write_image(image_array, fpath, compress_level)
|
||||
queue.task_done()
|
||||
|
||||
|
||||
@@ -169,11 +169,13 @@ class AsyncImageWriter:
|
||||
p.start()
|
||||
self.processes.append(p)
|
||||
|
||||
def save_image(self, image: torch.Tensor | np.ndarray | PIL.Image.Image, fpath: Path):
|
||||
def save_image(
|
||||
self, image: torch.Tensor | np.ndarray | PIL.Image.Image, fpath: Path, compress_level: int = 1
|
||||
):
|
||||
if isinstance(image, torch.Tensor):
|
||||
# Convert tensor to numpy array to minimize main process time
|
||||
image = image.cpu().numpy()
|
||||
self.queue.put((image, fpath))
|
||||
self.queue.put((image, fpath, compress_level))
|
||||
|
||||
def wait_until_done(self):
|
||||
self.queue.join()
|
||||
|
||||
@@ -13,6 +13,7 @@
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
import concurrent.futures
|
||||
import contextlib
|
||||
import logging
|
||||
import shutil
|
||||
@@ -539,6 +540,15 @@ class LeRobotDatasetMetadata:
|
||||
return obj
|
||||
|
||||
|
||||
def _encode_video_worker(video_key: str, episode_index: int, root: Path, fps: int) -> Path:
|
||||
temp_path = Path(tempfile.mkdtemp(dir=root)) / f"{video_key}_{episode_index:03d}.mp4"
|
||||
fpath = DEFAULT_IMAGE_PATH.format(image_key=video_key, episode_index=episode_index, frame_index=0)
|
||||
img_dir = (root / fpath).parent
|
||||
encode_video_frames(img_dir, temp_path, fps, overwrite=True)
|
||||
shutil.rmtree(img_dir)
|
||||
return temp_path
|
||||
|
||||
|
||||
class LeRobotDataset(torch.utils.data.Dataset):
|
||||
def __init__(
|
||||
self,
|
||||
@@ -712,6 +722,15 @@ class LeRobotDataset(torch.utils.data.Dataset):
|
||||
self.download(download_videos)
|
||||
self.hf_dataset = self.load_hf_dataset()
|
||||
|
||||
# Create mapping from absolute indices to relative indices when only a subset of the episodes are loaded
|
||||
# Build a mapping: absolute_index -> relative_index_in_filtered_dataset
|
||||
self._absolute_to_relative_idx = None
|
||||
if self.episodes is not None:
|
||||
self._absolute_to_relative_idx = {
|
||||
abs_idx.item() if isinstance(abs_idx, torch.Tensor) else abs_idx: rel_idx
|
||||
for rel_idx, abs_idx in enumerate(self.hf_dataset["index"])
|
||||
}
|
||||
|
||||
# Setup delta_indices
|
||||
if self.delta_timestamps is not None:
|
||||
check_delta_timestamps(self.delta_timestamps, self.fps, self.tolerance_s)
|
||||
@@ -830,7 +849,7 @@ class LeRobotDataset(torch.utils.data.Dataset):
|
||||
def load_hf_dataset(self) -> datasets.Dataset:
|
||||
"""hf_dataset contains all the observations, states, actions, rewards, etc."""
|
||||
features = get_hf_features_from_features(self.features)
|
||||
hf_dataset = load_nested_dataset(self.root / "data", features=features)
|
||||
hf_dataset = load_nested_dataset(self.root / "data", features=features, episodes=self.episodes)
|
||||
hf_dataset.set_transform(hf_transform_to_torch)
|
||||
return hf_dataset
|
||||
|
||||
@@ -847,10 +866,8 @@ class LeRobotDataset(torch.utils.data.Dataset):
|
||||
|
||||
# Determine requested episodes
|
||||
if self.episodes is None:
|
||||
# Requesting all episodes - check if we have all episodes from metadata
|
||||
requested_episodes = set(range(self.meta.total_episodes))
|
||||
else:
|
||||
# Requesting specific episodes
|
||||
requested_episodes = set(self.episodes)
|
||||
|
||||
# Check if all requested episodes are available in cached data
|
||||
@@ -932,7 +949,11 @@ class LeRobotDataset(torch.utils.data.Dataset):
|
||||
query_timestamps = {}
|
||||
for key in self.meta.video_keys:
|
||||
if query_indices is not None and key in query_indices:
|
||||
timestamps = self.hf_dataset[query_indices[key]]["timestamp"]
|
||||
if self._absolute_to_relative_idx is not None:
|
||||
relative_indices = [self._absolute_to_relative_idx[idx] for idx in query_indices[key]]
|
||||
timestamps = self.hf_dataset[relative_indices]["timestamp"]
|
||||
else:
|
||||
timestamps = self.hf_dataset[query_indices[key]]["timestamp"]
|
||||
query_timestamps[key] = torch.stack(timestamps).tolist()
|
||||
else:
|
||||
query_timestamps[key] = [current_ts]
|
||||
@@ -940,11 +961,32 @@ class LeRobotDataset(torch.utils.data.Dataset):
|
||||
return query_timestamps
|
||||
|
||||
def _query_hf_dataset(self, query_indices: dict[str, list[int]]) -> dict:
|
||||
return {
|
||||
key: torch.stack(self.hf_dataset[q_idx][key])
|
||||
for key, q_idx in query_indices.items()
|
||||
if key not in self.meta.video_keys
|
||||
}
|
||||
"""
|
||||
Query dataset for indices across keys, skipping video keys.
|
||||
|
||||
Tries column-first [key][indices] for speed, falls back to row-first.
|
||||
|
||||
Args:
|
||||
query_indices: Dict mapping keys to index lists to retrieve
|
||||
|
||||
Returns:
|
||||
Dict with stacked tensors of queried data (video keys excluded)
|
||||
"""
|
||||
result: dict = {}
|
||||
for key, q_idx in query_indices.items():
|
||||
if key in self.meta.video_keys:
|
||||
continue
|
||||
# Map absolute indices to relative indices if needed
|
||||
relative_indices = (
|
||||
q_idx
|
||||
if self._absolute_to_relative_idx is None
|
||||
else [self._absolute_to_relative_idx[idx] for idx in q_idx]
|
||||
)
|
||||
try:
|
||||
result[key] = torch.stack(self.hf_dataset[key][relative_indices])
|
||||
except (KeyError, TypeError, IndexError):
|
||||
result[key] = torch.stack(self.hf_dataset[relative_indices][key])
|
||||
return result
|
||||
|
||||
def _query_videos(self, query_timestamps: dict[str, list[float]], ep_idx: int) -> dict[str, torch.Tensor]:
|
||||
"""Note: When using data workers (e.g. DataLoader with num_workers>0), do not call this function
|
||||
@@ -1039,6 +1081,7 @@ class LeRobotDataset(torch.utils.data.Dataset):
|
||||
ep_buffer[key] = current_ep_idx if key == "episode_index" else []
|
||||
return ep_buffer
|
||||
|
||||
# TODO(Steven): consider move this to utils
|
||||
def _get_image_file_path(self, episode_index: int, image_key: str, frame_index: int) -> Path:
|
||||
fpath = DEFAULT_IMAGE_PATH.format(
|
||||
image_key=image_key, episode_index=episode_index, frame_index=frame_index
|
||||
@@ -1048,13 +1091,15 @@ class LeRobotDataset(torch.utils.data.Dataset):
|
||||
def _get_image_file_dir(self, episode_index: int, image_key: str) -> Path:
|
||||
return self._get_image_file_path(episode_index, image_key, frame_index=0).parent
|
||||
|
||||
def _save_image(self, image: torch.Tensor | np.ndarray | PIL.Image.Image, fpath: Path) -> None:
|
||||
def _save_image(
|
||||
self, image: torch.Tensor | np.ndarray | PIL.Image.Image, fpath: Path, compress_level: int = 1
|
||||
) -> None:
|
||||
if self.image_writer is None:
|
||||
if isinstance(image, torch.Tensor):
|
||||
image = image.cpu().numpy()
|
||||
write_image(image, fpath)
|
||||
write_image(image, fpath, compress_level=compress_level)
|
||||
else:
|
||||
self.image_writer.save_image(image=image, fpath=fpath)
|
||||
self.image_writer.save_image(image=image, fpath=fpath, compress_level=compress_level)
|
||||
|
||||
def add_frame(self, frame: dict) -> None:
|
||||
"""
|
||||
@@ -1092,14 +1137,19 @@ class LeRobotDataset(torch.utils.data.Dataset):
|
||||
)
|
||||
if frame_index == 0:
|
||||
img_path.parent.mkdir(parents=True, exist_ok=True)
|
||||
self._save_image(frame[key], img_path)
|
||||
compress_level = 1 if self.features[key]["dtype"] == "video" else 6
|
||||
self._save_image(frame[key], img_path, compress_level)
|
||||
self.episode_buffer[key].append(str(img_path))
|
||||
else:
|
||||
self.episode_buffer[key].append(frame[key])
|
||||
|
||||
self.episode_buffer["size"] += 1
|
||||
|
||||
def save_episode(self, episode_data: dict | None = None) -> None:
|
||||
def save_episode(
|
||||
self,
|
||||
episode_data: dict | None = None,
|
||||
parallel_encoding: bool = True,
|
||||
) -> None:
|
||||
"""
|
||||
This will save to disk the current episode in self.episode_buffer.
|
||||
|
||||
@@ -1111,6 +1161,8 @@ class LeRobotDataset(torch.utils.data.Dataset):
|
||||
episode_data (dict | None, optional): Dict containing the episode data to save. If None, this will
|
||||
save the current episode in self.episode_buffer, which is filled with 'add_frame'. Defaults to
|
||||
None.
|
||||
parallel_encoding (bool, optional): If True, encode videos in parallel using ProcessPoolExecutor.
|
||||
Defaults to True on Linux, False on macOS as it tends to use all the CPU available already.
|
||||
"""
|
||||
episode_buffer = episode_data if episode_data is not None else self.episode_buffer
|
||||
|
||||
@@ -1147,8 +1199,40 @@ class LeRobotDataset(torch.utils.data.Dataset):
|
||||
use_batched_encoding = self.batch_encoding_size > 1
|
||||
|
||||
if has_video_keys and not use_batched_encoding:
|
||||
for video_key in self.meta.video_keys:
|
||||
ep_metadata.update(self._save_episode_video(video_key, episode_index))
|
||||
num_cameras = len(self.meta.video_keys)
|
||||
if parallel_encoding and num_cameras > 1:
|
||||
# TODO(Steven): Ideally we would like to control the number of threads per encoding such that:
|
||||
# num_cameras * num_threads = (total_cpu -1)
|
||||
with concurrent.futures.ProcessPoolExecutor(max_workers=num_cameras) as executor:
|
||||
future_to_key = {
|
||||
executor.submit(
|
||||
_encode_video_worker,
|
||||
video_key,
|
||||
episode_index,
|
||||
self.root,
|
||||
self.fps,
|
||||
): video_key
|
||||
for video_key in self.meta.video_keys
|
||||
}
|
||||
|
||||
results = {}
|
||||
for future in concurrent.futures.as_completed(future_to_key):
|
||||
video_key = future_to_key[future]
|
||||
try:
|
||||
temp_path = future.result()
|
||||
results[video_key] = temp_path
|
||||
except Exception as exc:
|
||||
logging.error(f"Video encoding failed for {video_key}: {exc}")
|
||||
raise exc
|
||||
|
||||
for video_key in self.meta.video_keys:
|
||||
temp_path = results[video_key]
|
||||
ep_metadata.update(
|
||||
self._save_episode_video(video_key, episode_index, temp_path=temp_path)
|
||||
)
|
||||
else:
|
||||
for video_key in self.meta.video_keys:
|
||||
ep_metadata.update(self._save_episode_video(video_key, episode_index))
|
||||
|
||||
# `meta.save_episode` need to be executed after encoding the videos
|
||||
self.meta.save_episode(episode_index, episode_length, episode_tasks, ep_stats, ep_metadata)
|
||||
@@ -1313,9 +1397,18 @@ class LeRobotDataset(torch.utils.data.Dataset):
|
||||
|
||||
return metadata
|
||||
|
||||
def _save_episode_video(self, video_key: str, episode_index: int) -> dict:
|
||||
def _save_episode_video(
|
||||
self,
|
||||
video_key: str,
|
||||
episode_index: int,
|
||||
temp_path: Path | None = None,
|
||||
) -> dict:
|
||||
# Encode episode frames into a temporary video
|
||||
ep_path = self._encode_temporary_episode_video(video_key, episode_index)
|
||||
if temp_path is None:
|
||||
ep_path = self._encode_temporary_episode_video(video_key, episode_index)
|
||||
else:
|
||||
ep_path = temp_path
|
||||
|
||||
ep_size_in_mb = get_file_size_in_mb(ep_path)
|
||||
ep_duration_in_s = get_video_duration_in_s(ep_path)
|
||||
|
||||
@@ -1433,11 +1526,7 @@ class LeRobotDataset(torch.utils.data.Dataset):
|
||||
Note: `encode_video_frames` is a blocking call. Making it asynchronous shouldn't speedup encoding,
|
||||
since video encoding with ffmpeg is already using multithreading.
|
||||
"""
|
||||
temp_path = Path(tempfile.mkdtemp(dir=self.root)) / f"{video_key}_{episode_index:03d}.mp4"
|
||||
img_dir = self._get_image_file_dir(episode_index, video_key)
|
||||
encode_video_frames(img_dir, temp_path, self.fps, overwrite=True)
|
||||
shutil.rmtree(img_dir)
|
||||
return temp_path
|
||||
return _encode_video_worker(video_key, episode_index, self.root, self.fps)
|
||||
|
||||
@classmethod
|
||||
def create(
|
||||
@@ -1483,6 +1572,7 @@ class LeRobotDataset(torch.utils.data.Dataset):
|
||||
obj.image_transforms = None
|
||||
obj.delta_timestamps = None
|
||||
obj.delta_indices = None
|
||||
obj._absolute_to_relative_idx = None
|
||||
obj.video_backend = video_backend if video_backend is not None else get_safe_default_codec()
|
||||
obj.writer = None
|
||||
obj.latest_episode = None
|
||||
|
||||
@@ -28,6 +28,7 @@ import numpy as np
|
||||
import packaging.version
|
||||
import pandas
|
||||
import pandas as pd
|
||||
import pyarrow.dataset as pa_ds
|
||||
import pyarrow.parquet as pq
|
||||
import torch
|
||||
from datasets import Dataset
|
||||
@@ -48,7 +49,7 @@ from lerobot.utils.utils import SuppressProgressBars, is_valid_numpy_dtype_strin
|
||||
|
||||
DEFAULT_CHUNK_SIZE = 1000 # Max number of files per chunk
|
||||
DEFAULT_DATA_FILE_SIZE_IN_MB = 100 # Max size per file
|
||||
DEFAULT_VIDEO_FILE_SIZE_IN_MB = 500 # Max size per file
|
||||
DEFAULT_VIDEO_FILE_SIZE_IN_MB = 200 # Max size per file
|
||||
|
||||
INFO_PATH = "meta/info.json"
|
||||
STATS_PATH = "meta/stats.json"
|
||||
@@ -103,7 +104,9 @@ def update_chunk_file_indices(chunk_idx: int, file_idx: int, chunks_size: int) -
|
||||
return chunk_idx, file_idx
|
||||
|
||||
|
||||
def load_nested_dataset(pq_dir: Path, features: datasets.Features | None = None) -> Dataset:
|
||||
def load_nested_dataset(
|
||||
pq_dir: Path, features: datasets.Features | None = None, episodes: list[int] | None = None
|
||||
) -> Dataset:
|
||||
"""Find parquet files in provided directory {pq_dir}/chunk-xxx/file-xxx.parquet
|
||||
Convert parquet files to pyarrow memory mapped in a cache folder for efficient RAM usage
|
||||
Concatenate all pyarrow references to return HF Dataset format
|
||||
@@ -111,15 +114,26 @@ def load_nested_dataset(pq_dir: Path, features: datasets.Features | None = None)
|
||||
Args:
|
||||
pq_dir: Directory containing parquet files
|
||||
features: Optional features schema to ensure consistent loading of complex types like images
|
||||
episodes: Optional list of episode indices to filter. Uses PyArrow predicate pushdown for efficiency.
|
||||
"""
|
||||
paths = sorted(pq_dir.glob("*/*.parquet"))
|
||||
if len(paths) == 0:
|
||||
raise FileNotFoundError(f"Provided directory does not contain any parquet file: {pq_dir}")
|
||||
|
||||
# TODO(rcadene): set num_proc to accelerate conversion to pyarrow
|
||||
with SuppressProgressBars():
|
||||
datasets = Dataset.from_parquet([str(path) for path in paths], features=features)
|
||||
return datasets
|
||||
# When no filtering needed, Dataset uses memory-mapped loading for efficiency
|
||||
# PyArrow loads the entire dataset into memory
|
||||
if episodes is None:
|
||||
return Dataset.from_parquet([str(path) for path in paths], features=features)
|
||||
|
||||
arrow_dataset = pa_ds.dataset(paths, format="parquet")
|
||||
filter_expr = pa_ds.field("episode_index").isin(episodes)
|
||||
table = arrow_dataset.to_table(filter=filter_expr)
|
||||
|
||||
if features is not None:
|
||||
table = table.cast(features.arrow_schema)
|
||||
|
||||
return Dataset(table)
|
||||
|
||||
|
||||
def get_parquet_num_frames(parquet_path: str | Path) -> int:
|
||||
|
||||
@@ -311,6 +311,7 @@ def encode_video_frames(
|
||||
fast_decode: int = 0,
|
||||
log_level: int | None = av.logging.ERROR,
|
||||
overwrite: bool = False,
|
||||
preset: int | None = None,
|
||||
) -> None:
|
||||
"""More info on ffmpeg arguments tuning on `benchmark/video/README.md`"""
|
||||
# Check encoder availability
|
||||
@@ -359,6 +360,9 @@ def encode_video_frames(
|
||||
value = f"fast-decode={fast_decode}" if vcodec == "libsvtav1" else "fastdecode"
|
||||
video_options[key] = value
|
||||
|
||||
if vcodec == "libsvtav1":
|
||||
video_options["preset"] = str(preset) if preset is not None else "12"
|
||||
|
||||
# Set logging level
|
||||
if log_level is not None:
|
||||
# "While less efficient, it is generally preferable to modify logging with Python's logging"
|
||||
|
||||
@@ -21,7 +21,22 @@ import draccus
|
||||
from lerobot.configs.types import FeatureType, PolicyFeature
|
||||
from lerobot.robots import RobotConfig
|
||||
from lerobot.teleoperators.config import TeleoperatorConfig
|
||||
from lerobot.utils.constants import ACTION, OBS_ENV_STATE, OBS_IMAGE, OBS_IMAGES, OBS_STATE
|
||||
from lerobot.utils.constants import (
|
||||
ACTION,
|
||||
LIBERO_KEY_EEF_MAT,
|
||||
LIBERO_KEY_EEF_POS,
|
||||
LIBERO_KEY_EEF_QUAT,
|
||||
LIBERO_KEY_GRIPPER_QPOS,
|
||||
LIBERO_KEY_GRIPPER_QVEL,
|
||||
LIBERO_KEY_JOINTS_POS,
|
||||
LIBERO_KEY_JOINTS_VEL,
|
||||
LIBERO_KEY_PIXELS_AGENTVIEW,
|
||||
LIBERO_KEY_PIXELS_EYE_IN_HAND,
|
||||
OBS_ENV_STATE,
|
||||
OBS_IMAGE,
|
||||
OBS_IMAGES,
|
||||
OBS_STATE,
|
||||
)
|
||||
|
||||
|
||||
@dataclass
|
||||
@@ -246,28 +261,61 @@ class LiberoEnv(EnvConfig):
|
||||
features_map: dict[str, str] = field(
|
||||
default_factory=lambda: {
|
||||
ACTION: ACTION,
|
||||
"agent_pos": OBS_STATE,
|
||||
"pixels/agentview_image": f"{OBS_IMAGES}.image",
|
||||
"pixels/robot0_eye_in_hand_image": f"{OBS_IMAGES}.image2",
|
||||
LIBERO_KEY_EEF_POS: f"{OBS_STATE}.eef_pos",
|
||||
LIBERO_KEY_EEF_QUAT: f"{OBS_STATE}.eef_quat",
|
||||
LIBERO_KEY_EEF_MAT: f"{OBS_STATE}.eef_mat",
|
||||
LIBERO_KEY_GRIPPER_QPOS: f"{OBS_STATE}.gripper_qpos",
|
||||
LIBERO_KEY_GRIPPER_QVEL: f"{OBS_STATE}.gripper_qvel",
|
||||
LIBERO_KEY_JOINTS_POS: f"{OBS_STATE}.joint_pos",
|
||||
LIBERO_KEY_JOINTS_VEL: f"{OBS_STATE}.joint_vel",
|
||||
LIBERO_KEY_PIXELS_AGENTVIEW: f"{OBS_IMAGES}.image",
|
||||
LIBERO_KEY_PIXELS_EYE_IN_HAND: f"{OBS_IMAGES}.image2",
|
||||
}
|
||||
)
|
||||
|
||||
def __post_init__(self):
|
||||
if self.obs_type == "pixels":
|
||||
self.features["pixels/agentview_image"] = PolicyFeature(
|
||||
self.features[LIBERO_KEY_PIXELS_AGENTVIEW] = PolicyFeature(
|
||||
type=FeatureType.VISUAL, shape=(self.observation_height, self.observation_width, 3)
|
||||
)
|
||||
self.features["pixels/robot0_eye_in_hand_image"] = PolicyFeature(
|
||||
self.features[LIBERO_KEY_PIXELS_EYE_IN_HAND] = PolicyFeature(
|
||||
type=FeatureType.VISUAL, shape=(self.observation_height, self.observation_width, 3)
|
||||
)
|
||||
elif self.obs_type == "pixels_agent_pos":
|
||||
self.features["agent_pos"] = PolicyFeature(type=FeatureType.STATE, shape=(8,))
|
||||
self.features["pixels/agentview_image"] = PolicyFeature(
|
||||
self.features[LIBERO_KEY_PIXELS_AGENTVIEW] = PolicyFeature(
|
||||
type=FeatureType.VISUAL, shape=(self.observation_height, self.observation_width, 3)
|
||||
)
|
||||
self.features["pixels/robot0_eye_in_hand_image"] = PolicyFeature(
|
||||
self.features[LIBERO_KEY_PIXELS_EYE_IN_HAND] = PolicyFeature(
|
||||
type=FeatureType.VISUAL, shape=(self.observation_height, self.observation_width, 3)
|
||||
)
|
||||
self.features[LIBERO_KEY_EEF_POS] = PolicyFeature(
|
||||
type=FeatureType.STATE,
|
||||
shape=(3,),
|
||||
)
|
||||
self.features[LIBERO_KEY_EEF_QUAT] = PolicyFeature(
|
||||
type=FeatureType.STATE,
|
||||
shape=(4,),
|
||||
)
|
||||
self.features[LIBERO_KEY_EEF_MAT] = PolicyFeature(
|
||||
type=FeatureType.STATE,
|
||||
shape=(3, 3),
|
||||
)
|
||||
self.features[LIBERO_KEY_GRIPPER_QPOS] = PolicyFeature(
|
||||
type=FeatureType.STATE,
|
||||
shape=(2,),
|
||||
)
|
||||
self.features[LIBERO_KEY_GRIPPER_QVEL] = PolicyFeature(
|
||||
type=FeatureType.STATE,
|
||||
shape=(2,),
|
||||
)
|
||||
self.features[LIBERO_KEY_JOINTS_POS] = PolicyFeature(
|
||||
type=FeatureType.STATE,
|
||||
shape=(7,),
|
||||
)
|
||||
self.features[LIBERO_KEY_JOINTS_VEL] = PolicyFeature(
|
||||
type=FeatureType.STATE,
|
||||
shape=(7,),
|
||||
)
|
||||
else:
|
||||
raise ValueError(f"Unsupported obs_type: {self.obs_type}")
|
||||
|
||||
|
||||
@@ -14,12 +14,16 @@
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
import importlib
|
||||
from typing import Any
|
||||
|
||||
import gymnasium as gym
|
||||
from gymnasium.envs.registration import registry as gym_registry
|
||||
|
||||
from lerobot.envs.configs import AlohaEnv, EnvConfig, LiberoEnv, PushtEnv
|
||||
from lerobot.envs.utils import _call_make_env, _download_hub_file, _import_hub_module, _normalize_hub_result
|
||||
from lerobot.processor import ProcessorStep
|
||||
from lerobot.processor.env_processor import LiberoProcessorStep
|
||||
from lerobot.processor.pipeline import PolicyProcessorPipeline
|
||||
|
||||
|
||||
def make_env_config(env_type: str, **kwargs) -> EnvConfig:
|
||||
@@ -33,6 +37,41 @@ def make_env_config(env_type: str, **kwargs) -> EnvConfig:
|
||||
raise ValueError(f"Policy type '{env_type}' is not available.")
|
||||
|
||||
|
||||
def make_env_pre_post_processors(
|
||||
env_cfg: EnvConfig,
|
||||
) -> tuple[
|
||||
PolicyProcessorPipeline[dict[str, Any], dict[str, Any]],
|
||||
PolicyProcessorPipeline[dict[str, Any], dict[str, Any]],
|
||||
]:
|
||||
"""
|
||||
Create preprocessor and postprocessor pipelines for environment observations.
|
||||
|
||||
This function creates processor pipelines that transform raw environment
|
||||
observations and actions. By default, it returns identity processors that do nothing.
|
||||
For specific environments like LIBERO, it adds environment-specific processing steps.
|
||||
|
||||
Args:
|
||||
env_cfg: The configuration of the environment.
|
||||
|
||||
Returns:
|
||||
A tuple containing:
|
||||
- preprocessor: Pipeline that processes environment observations
|
||||
- postprocessor: Pipeline that processes environment outputs (currently identity)
|
||||
"""
|
||||
# Preprocessor and Postprocessor steps are Identity for most environments
|
||||
preprocessor_steps: list[ProcessorStep] = []
|
||||
postprocessor_steps: list[ProcessorStep] = []
|
||||
|
||||
# For LIBERO environments, add the LiberoProcessorStep to preprocessor
|
||||
if isinstance(env_cfg, LiberoEnv) or "libero" in env_cfg.type:
|
||||
preprocessor_steps.append(LiberoProcessorStep())
|
||||
|
||||
preprocessor = PolicyProcessorPipeline(steps=preprocessor_steps)
|
||||
postprocessor = PolicyProcessorPipeline(steps=postprocessor_steps)
|
||||
|
||||
return preprocessor, postprocessor
|
||||
|
||||
|
||||
def make_env(
|
||||
cfg: EnvConfig | str,
|
||||
n_envs: int = 1,
|
||||
|
||||
+69
-21
@@ -28,7 +28,6 @@ import torch
|
||||
from gymnasium import spaces
|
||||
from libero.libero import benchmark, get_libero_path
|
||||
from libero.libero.envs import OffScreenRenderEnv
|
||||
from robosuite.utils.transform_utils import quat2axisangle
|
||||
|
||||
|
||||
def _parse_camera_names(camera_name: str | Sequence[str]) -> list[str]:
|
||||
@@ -175,11 +174,36 @@ class LiberoEnv(gym.Env):
|
||||
self.observation_space = spaces.Dict(
|
||||
{
|
||||
"pixels": spaces.Dict(images),
|
||||
"agent_pos": spaces.Box(
|
||||
low=AGENT_POS_LOW,
|
||||
high=AGENT_POS_HIGH,
|
||||
shape=(OBS_STATE_DIM,),
|
||||
dtype=np.float64,
|
||||
"robot_state": spaces.Dict(
|
||||
{
|
||||
"eef": spaces.Dict(
|
||||
{
|
||||
"pos": spaces.Box(low=-np.inf, high=np.inf, shape=(3,), dtype=np.float64),
|
||||
"quat": spaces.Box(
|
||||
low=-np.inf, high=np.inf, shape=(4,), dtype=np.float64
|
||||
),
|
||||
"mat": spaces.Box(
|
||||
low=-np.inf, high=np.inf, shape=(3, 3), dtype=np.float64
|
||||
),
|
||||
}
|
||||
),
|
||||
"gripper": spaces.Dict(
|
||||
{
|
||||
"qpos": spaces.Box(
|
||||
low=-np.inf, high=np.inf, shape=(2,), dtype=np.float64
|
||||
),
|
||||
"qvel": spaces.Box(
|
||||
low=-np.inf, high=np.inf, shape=(2,), dtype=np.float64
|
||||
),
|
||||
}
|
||||
),
|
||||
"joints": spaces.Dict(
|
||||
{
|
||||
"pos": spaces.Box(low=-np.inf, high=np.inf, shape=(7,), dtype=np.float64),
|
||||
"vel": spaces.Box(low=-np.inf, high=np.inf, shape=(7,), dtype=np.float64),
|
||||
}
|
||||
),
|
||||
}
|
||||
),
|
||||
}
|
||||
)
|
||||
@@ -191,6 +215,7 @@ class LiberoEnv(gym.Env):
|
||||
def render(self):
|
||||
raw_obs = self._env.env._get_observations()
|
||||
image = self._format_raw_obs(raw_obs)["pixels"]["image"]
|
||||
image = image[::-1, ::-1] # flip both H and W for visualization
|
||||
return image
|
||||
|
||||
def _make_envs_task(self, task_suite: Any, task_id: int = 0):
|
||||
@@ -212,23 +237,48 @@ class LiberoEnv(gym.Env):
|
||||
images = {}
|
||||
for camera_name in self.camera_name:
|
||||
image = raw_obs[camera_name]
|
||||
image = image[::-1, ::-1] # rotate 180 degrees
|
||||
images[self.camera_name_mapping[camera_name]] = image
|
||||
state = np.concatenate(
|
||||
(
|
||||
raw_obs["robot0_eef_pos"],
|
||||
quat2axisangle(raw_obs["robot0_eef_quat"]),
|
||||
raw_obs["robot0_gripper_qpos"],
|
||||
)
|
||||
)
|
||||
agent_pos = state
|
||||
|
||||
eef_pos = raw_obs.get("robot0_eef_pos")
|
||||
eef_quat = raw_obs.get("robot0_eef_quat")
|
||||
|
||||
# rotation matrix from controller
|
||||
eef_mat = self._env.robots[0].controller.ee_ori_mat if eef_pos is not None else None
|
||||
gripper_qpos = raw_obs.get("robot0_gripper_qpos")
|
||||
gripper_qvel = raw_obs.get("robot0_gripper_qvel")
|
||||
joint_pos = raw_obs.get("robot0_joint_pos")
|
||||
joint_vel = raw_obs.get("robot0_joint_vel")
|
||||
obs = {
|
||||
"pixels": images,
|
||||
"robot_state": {
|
||||
"eef": {
|
||||
"pos": eef_pos, # (3,)
|
||||
"quat": eef_quat, # (4,)
|
||||
"mat": eef_mat, # (3, 3)
|
||||
},
|
||||
"gripper": {
|
||||
"qpos": gripper_qpos, # (2,)
|
||||
"qvel": gripper_qvel, # (2,)
|
||||
},
|
||||
"joints": {
|
||||
"pos": joint_pos, # (7,)
|
||||
"vel": joint_vel, # (7,)
|
||||
},
|
||||
},
|
||||
}
|
||||
if self.obs_type == "pixels":
|
||||
return {"pixels": images.copy()}
|
||||
|
||||
if self.obs_type == "pixels_agent_pos":
|
||||
return {
|
||||
"pixels": images.copy(),
|
||||
"agent_pos": agent_pos,
|
||||
}
|
||||
# Validate required fields are present
|
||||
if eef_pos is None or eef_quat is None or gripper_qpos is None:
|
||||
raise ValueError(
|
||||
f"Missing required robot state fields in raw observation. "
|
||||
f"Got eef_pos={eef_pos is not None}, eef_quat={eef_quat is not None}, "
|
||||
f"gripper_qpos={gripper_qpos is not None}"
|
||||
)
|
||||
return obs
|
||||
|
||||
raise NotImplementedError(
|
||||
f"The observation type '{self.obs_type}' is not supported in LiberoEnv. "
|
||||
"Please switch to an image-based obs_type (e.g. 'pixels', 'pixels_agent_pos')."
|
||||
@@ -355,12 +405,10 @@ def create_libero_envs(
|
||||
print(f"Restricting to task_ids={task_ids_filter}")
|
||||
|
||||
out: dict[str, dict[int, Any]] = defaultdict(dict)
|
||||
|
||||
for suite_name in suite_names:
|
||||
suite = _get_suite(suite_name)
|
||||
total = len(suite.tasks)
|
||||
selected = _select_task_ids(total, task_ids_filter)
|
||||
|
||||
if not selected:
|
||||
raise ValueError(f"No tasks selected for suite '{suite_name}' (available: {total}).")
|
||||
|
||||
|
||||
@@ -29,10 +29,22 @@ from torch import Tensor
|
||||
|
||||
from lerobot.configs.types import FeatureType, PolicyFeature
|
||||
from lerobot.envs.configs import EnvConfig
|
||||
from lerobot.utils.constants import OBS_ENV_STATE, OBS_IMAGE, OBS_IMAGES, OBS_STATE
|
||||
from lerobot.utils.constants import OBS_ENV_STATE, OBS_IMAGE, OBS_IMAGES, OBS_STATE, OBS_STR
|
||||
from lerobot.utils.utils import get_channel_first_image_shape
|
||||
|
||||
|
||||
def _convert_nested_dict(d):
|
||||
result = {}
|
||||
for k, v in d.items():
|
||||
if isinstance(v, dict):
|
||||
result[k] = _convert_nested_dict(v)
|
||||
elif isinstance(v, np.ndarray):
|
||||
result[k] = torch.from_numpy(v)
|
||||
else:
|
||||
result[k] = v
|
||||
return result
|
||||
|
||||
|
||||
def preprocess_observation(observations: dict[str, np.ndarray]) -> dict[str, Tensor]:
|
||||
# TODO(aliberts, rcadene): refactor this to use features from the environment (no hardcoding)
|
||||
"""Convert environment observation to LeRobot format observation.
|
||||
@@ -78,12 +90,14 @@ def preprocess_observation(observations: dict[str, np.ndarray]) -> dict[str, Ten
|
||||
|
||||
return_observations[OBS_ENV_STATE] = env_state
|
||||
|
||||
# TODO(rcadene): enable pixels only baseline with `obs_type="pixels"` in environment by removing
|
||||
agent_pos = torch.from_numpy(observations["agent_pos"]).float()
|
||||
if agent_pos.dim() == 1:
|
||||
agent_pos = agent_pos.unsqueeze(0)
|
||||
return_observations[OBS_STATE] = agent_pos
|
||||
if "agent_pos" in observations:
|
||||
agent_pos = torch.from_numpy(observations["agent_pos"]).float()
|
||||
if agent_pos.dim() == 1:
|
||||
agent_pos = agent_pos.unsqueeze(0)
|
||||
return_observations[OBS_STATE] = agent_pos
|
||||
|
||||
if "robot_state" in observations:
|
||||
return_observations[f"{OBS_STR}.robot_state"] = _convert_nested_dict(observations["robot_state"])
|
||||
return return_observations
|
||||
|
||||
|
||||
|
||||
@@ -20,6 +20,7 @@ from lerobot.configs.policies import PreTrainedConfig
|
||||
from lerobot.configs.types import FeatureType, NormalizationMode, PolicyFeature
|
||||
from lerobot.optim.optimizers import AdamWConfig
|
||||
from lerobot.optim.schedulers import CosineDecayWithWarmupSchedulerConfig
|
||||
from lerobot.policies.rtc.configuration_rtc import RTCConfig
|
||||
from lerobot.utils.constants import OBS_IMAGES
|
||||
|
||||
|
||||
@@ -47,6 +48,9 @@ class PI0Config(PreTrainedConfig):
|
||||
min_period: float = 4e-3
|
||||
max_period: float = 4.0
|
||||
|
||||
# Real-Time Chunking (RTC) configuration
|
||||
rtc_config: RTCConfig | None = None
|
||||
|
||||
image_resolution: tuple[int, int] = (224, 224) # see openpi `preprocessing_pytorch.py`
|
||||
|
||||
# Add empty images. Used to add empty cameras when no image features are present.
|
||||
|
||||
@@ -19,11 +19,12 @@ import logging
|
||||
import math
|
||||
from collections import deque
|
||||
from pathlib import Path
|
||||
from typing import TYPE_CHECKING, Literal
|
||||
from typing import TYPE_CHECKING, Literal, TypedDict
|
||||
|
||||
import torch
|
||||
import torch.nn.functional as F # noqa: N812
|
||||
from torch import Tensor, nn
|
||||
from typing_extensions import Unpack
|
||||
|
||||
from lerobot.utils.import_utils import _transformers_available
|
||||
|
||||
@@ -42,6 +43,7 @@ else:
|
||||
from lerobot.configs.policies import PreTrainedConfig
|
||||
from lerobot.policies.pi0.configuration_pi0 import PI0Config
|
||||
from lerobot.policies.pretrained import PreTrainedPolicy, T
|
||||
from lerobot.policies.rtc.modeling_rtc import RTCProcessor
|
||||
from lerobot.utils.constants import (
|
||||
ACTION,
|
||||
OBS_LANGUAGE_ATTENTION_MASK,
|
||||
@@ -51,6 +53,12 @@ from lerobot.utils.constants import (
|
||||
)
|
||||
|
||||
|
||||
class ActionSelectKwargs(TypedDict, total=False):
|
||||
inference_delay: int | None
|
||||
prev_chunk_left_over: Tensor | None
|
||||
execution_horizon: int | None
|
||||
|
||||
|
||||
def get_safe_dtype(target_dtype, device_type):
|
||||
"""Get a safe dtype for the given device type."""
|
||||
if device_type == "mps" and target_dtype == torch.float64:
|
||||
@@ -503,9 +511,10 @@ class PaliGemmaWithExpertModel(
|
||||
class PI0Pytorch(nn.Module): # see openpi `PI0Pytorch`
|
||||
"""Core PI0 PyTorch model."""
|
||||
|
||||
def __init__(self, config: PI0Config):
|
||||
def __init__(self, config: PI0Config, rtc_processor: RTCProcessor | None = None):
|
||||
super().__init__()
|
||||
self.config = config
|
||||
self.rtc_processor = rtc_processor
|
||||
|
||||
paligemma_config = get_gemma_config(config.paligemma_variant)
|
||||
action_expert_config = get_gemma_config(config.action_expert_variant)
|
||||
@@ -560,6 +569,9 @@ class PI0Pytorch(nn.Module): # see openpi `PI0Pytorch`
|
||||
self.paligemma_with_expert.gemma_expert.model.gradient_checkpointing = False
|
||||
logging.info("Disabled gradient checkpointing for PI0Pytorch model")
|
||||
|
||||
def _rtc_enabled(self):
|
||||
return self.config.rtc_config is not None and self.config.rtc_config.enabled
|
||||
|
||||
def _apply_checkpoint(self, func, *args, **kwargs):
|
||||
"""Helper method to apply gradient checkpointing if enabled."""
|
||||
if self.gradient_checkpointing_enabled and self.training:
|
||||
@@ -756,7 +768,15 @@ class PI0Pytorch(nn.Module): # see openpi `PI0Pytorch`
|
||||
|
||||
@torch.no_grad() # see openpi `sample_actions` (slightly adapted)
|
||||
def sample_actions(
|
||||
self, images, img_masks, lang_tokens, lang_masks, state, noise=None, num_steps=None
|
||||
self,
|
||||
images,
|
||||
img_masks,
|
||||
lang_tokens,
|
||||
lang_masks,
|
||||
state,
|
||||
noise=None,
|
||||
num_steps=None,
|
||||
**kwargs: Unpack[ActionSelectKwargs],
|
||||
) -> Tensor:
|
||||
"""Do a full inference forward and compute the action."""
|
||||
if num_steps is None:
|
||||
@@ -798,14 +818,41 @@ class PI0Pytorch(nn.Module): # see openpi `PI0Pytorch`
|
||||
time = torch.tensor(1.0, dtype=torch.float32, device=device)
|
||||
while time >= -dt / 2:
|
||||
expanded_time = time.expand(bsize)
|
||||
v_t = self.denoise_step(
|
||||
state,
|
||||
prefix_pad_masks,
|
||||
past_key_values,
|
||||
x_t,
|
||||
expanded_time,
|
||||
)
|
||||
x_t = x_t + dt * v_t
|
||||
|
||||
# Define a closure function to properly capture expanded_time
|
||||
# This avoids the lambda expression (E731) and loop variable binding (B023) issues
|
||||
def denoise_step_partial_call(input_x_t, current_timestep=expanded_time):
|
||||
return self.denoise_step(
|
||||
state=state,
|
||||
prefix_pad_masks=prefix_pad_masks,
|
||||
past_key_values=past_key_values,
|
||||
x_t=input_x_t,
|
||||
timestep=current_timestep,
|
||||
)
|
||||
|
||||
if self._rtc_enabled():
|
||||
inference_delay = kwargs.get("inference_delay")
|
||||
prev_chunk_left_over = kwargs.get("prev_chunk_left_over")
|
||||
execution_horizon = kwargs.get("execution_horizon")
|
||||
|
||||
v_t = self.rtc_processor.denoise_step(
|
||||
x_t=x_t,
|
||||
prev_chunk_left_over=prev_chunk_left_over,
|
||||
inference_delay=inference_delay,
|
||||
time=time,
|
||||
original_denoise_step_partial=denoise_step_partial_call,
|
||||
execution_horizon=execution_horizon,
|
||||
)
|
||||
else:
|
||||
v_t = denoise_step_partial_call(x_t)
|
||||
|
||||
# Euler step
|
||||
x_t += dt * v_t
|
||||
|
||||
# Record x_t and v_t after Euler step
|
||||
if self.rtc_processor is not None and self.rtc_processor.is_debug_enabled():
|
||||
self.rtc_processor.track(time=time, x_t=x_t, v_t=v_t)
|
||||
|
||||
time += dt
|
||||
|
||||
return x_t
|
||||
@@ -869,7 +916,8 @@ class PI0Policy(PreTrainedPolicy):
|
||||
self.config = config
|
||||
|
||||
# Initialize the core PI0 model
|
||||
self.model = PI0Pytorch(config)
|
||||
self.init_rtc_processor()
|
||||
self.model = PI0Pytorch(config, rtc_processor=self.rtc_processor)
|
||||
|
||||
# Enable gradient checkpointing if requested
|
||||
if config.gradient_checkpointing:
|
||||
@@ -1059,6 +1107,22 @@ class PI0Policy(PreTrainedPolicy):
|
||||
ACTION: deque(maxlen=self.config.n_action_steps),
|
||||
}
|
||||
|
||||
def init_rtc_processor(self):
|
||||
"""Initialize RTC processor if RTC is enabled in config."""
|
||||
self.rtc_processor = None
|
||||
|
||||
# Create processor if config provided
|
||||
# If RTC is not enabled - we can still track the denoising data
|
||||
if self.config.rtc_config is not None:
|
||||
self.rtc_processor = RTCProcessor(self.config.rtc_config)
|
||||
|
||||
model_value = getattr(self, "model", None)
|
||||
if model_value is not None:
|
||||
model_value.rtc_processor = self.rtc_processor
|
||||
|
||||
def _rtc_enabled(self) -> bool:
|
||||
return self.config.rtc_config is not None and self.config.rtc_config.enabled
|
||||
|
||||
def _preprocess_images(self, batch: dict[str, Tensor]) -> tuple[list[Tensor], list[Tensor]]:
|
||||
"""Preprocess images for the model.
|
||||
|
||||
@@ -1137,6 +1201,10 @@ class PI0Policy(PreTrainedPolicy):
|
||||
@torch.no_grad()
|
||||
def select_action(self, batch: dict[str, Tensor]) -> Tensor:
|
||||
"""Select a single action given environment observations."""
|
||||
assert not self._rtc_enabled(), (
|
||||
"RTC is not supported for select_action, use it with predict_action_chunk"
|
||||
)
|
||||
|
||||
self.eval()
|
||||
|
||||
# Action queue logic for n_action_steps > 1
|
||||
@@ -1148,7 +1216,7 @@ class PI0Policy(PreTrainedPolicy):
|
||||
return self._action_queue.popleft()
|
||||
|
||||
@torch.no_grad()
|
||||
def predict_action_chunk(self, batch: dict[str, Tensor]) -> Tensor:
|
||||
def predict_action_chunk(self, batch: dict[str, Tensor], **kwargs: Unpack[ActionSelectKwargs]) -> Tensor:
|
||||
"""Predict a chunk of actions given environment observations."""
|
||||
self.eval()
|
||||
|
||||
@@ -1157,8 +1225,8 @@ class PI0Policy(PreTrainedPolicy):
|
||||
lang_tokens, lang_masks = batch[f"{OBS_LANGUAGE_TOKENS}"], batch[f"{OBS_LANGUAGE_ATTENTION_MASK}"]
|
||||
state = self.prepare_state(batch)
|
||||
|
||||
# Sample actions using the model
|
||||
actions = self.model.sample_actions(images, img_masks, lang_tokens, lang_masks, state)
|
||||
# Sample actions using the model (pass through RTC kwargs)
|
||||
actions = self.model.sample_actions(images, img_masks, lang_tokens, lang_masks, state, **kwargs)
|
||||
|
||||
# Unpad actions to actual action dimension
|
||||
original_action_dim = self.config.output_features[ACTION].shape[0]
|
||||
|
||||
@@ -20,6 +20,7 @@ from lerobot.configs.policies import PreTrainedConfig
|
||||
from lerobot.configs.types import FeatureType, NormalizationMode, PolicyFeature
|
||||
from lerobot.optim.optimizers import AdamWConfig
|
||||
from lerobot.optim.schedulers import CosineDecayWithWarmupSchedulerConfig
|
||||
from lerobot.policies.rtc.configuration_rtc import RTCConfig
|
||||
|
||||
|
||||
@PreTrainedConfig.register_subclass("pi05")
|
||||
@@ -46,6 +47,9 @@ class PI05Config(PreTrainedConfig):
|
||||
min_period: float = 4e-3
|
||||
max_period: float = 4.0
|
||||
|
||||
# Real-Time Chunking (RTC) configuration
|
||||
rtc_config: RTCConfig | None = None
|
||||
|
||||
image_resolution: tuple[int, int] = (224, 224) # see openpi `preprocessing_pytorch.py`
|
||||
|
||||
# Add empty images. Used to add empty cameras when no image features are present.
|
||||
|
||||
@@ -19,11 +19,12 @@ import logging
|
||||
import math
|
||||
from collections import deque
|
||||
from pathlib import Path
|
||||
from typing import TYPE_CHECKING, Literal
|
||||
from typing import TYPE_CHECKING, Literal, TypedDict
|
||||
|
||||
import torch
|
||||
import torch.nn.functional as F # noqa: N812
|
||||
from torch import Tensor, nn
|
||||
from typing_extensions import Unpack
|
||||
|
||||
from lerobot.utils.import_utils import _transformers_available
|
||||
|
||||
@@ -42,6 +43,7 @@ else:
|
||||
from lerobot.configs.policies import PreTrainedConfig
|
||||
from lerobot.policies.pi05.configuration_pi05 import PI05Config
|
||||
from lerobot.policies.pretrained import PreTrainedPolicy, T
|
||||
from lerobot.policies.rtc.modeling_rtc import RTCProcessor
|
||||
from lerobot.utils.constants import (
|
||||
ACTION,
|
||||
OBS_LANGUAGE_ATTENTION_MASK,
|
||||
@@ -50,6 +52,12 @@ from lerobot.utils.constants import (
|
||||
)
|
||||
|
||||
|
||||
class ActionSelectKwargs(TypedDict, total=False):
|
||||
inference_delay: int | None
|
||||
prev_chunk_left_over: Tensor | None
|
||||
execution_horizon: int | None
|
||||
|
||||
|
||||
def get_safe_dtype(target_dtype, device_type):
|
||||
"""Get a safe dtype for the given device type."""
|
||||
if device_type == "mps" and target_dtype == torch.float64:
|
||||
@@ -502,9 +510,10 @@ class PaliGemmaWithExpertModel(
|
||||
class PI05Pytorch(nn.Module): # see openpi `PI0Pytorch`
|
||||
"""Core PI05 PyTorch model."""
|
||||
|
||||
def __init__(self, config: PI05Config):
|
||||
def __init__(self, config: PI05Config, rtc_processor: RTCProcessor | None = None):
|
||||
super().__init__()
|
||||
self.config = config
|
||||
self.rtc_processor = rtc_processor
|
||||
|
||||
paligemma_config = get_gemma_config(config.paligemma_variant)
|
||||
action_expert_config = get_gemma_config(config.action_expert_variant)
|
||||
@@ -556,6 +565,9 @@ class PI05Pytorch(nn.Module): # see openpi `PI0Pytorch`
|
||||
self.paligemma_with_expert.gemma_expert.model.gradient_checkpointing = False
|
||||
logging.info("Disabled gradient checkpointing for PI05Pytorch model")
|
||||
|
||||
def _rtc_enabled(self):
|
||||
return self.config.rtc_config is not None and self.config.rtc_config.enabled
|
||||
|
||||
def _apply_checkpoint(self, func, *args, **kwargs):
|
||||
"""Helper method to apply gradient checkpointing if enabled."""
|
||||
if self.gradient_checkpointing_enabled and self.training:
|
||||
@@ -731,7 +743,16 @@ class PI05Pytorch(nn.Module): # see openpi `PI0Pytorch`
|
||||
return F.mse_loss(u_t, v_t, reduction="none")
|
||||
|
||||
@torch.no_grad() # see openpi `sample_actions` (slightly adapted)
|
||||
def sample_actions(self, images, img_masks, tokens, masks, noise=None, num_steps=None) -> Tensor:
|
||||
def sample_actions(
|
||||
self,
|
||||
images,
|
||||
img_masks,
|
||||
tokens,
|
||||
masks,
|
||||
noise=None,
|
||||
num_steps=None,
|
||||
**kwargs: Unpack[ActionSelectKwargs],
|
||||
) -> Tensor:
|
||||
"""Do a full inference forward and compute the action."""
|
||||
if num_steps is None:
|
||||
num_steps = self.config.num_inference_steps
|
||||
@@ -770,13 +791,40 @@ class PI05Pytorch(nn.Module): # see openpi `PI0Pytorch`
|
||||
time = torch.tensor(1.0, dtype=torch.float32, device=device)
|
||||
while time >= -dt / 2:
|
||||
expanded_time = time.expand(bsize)
|
||||
v_t = self.denoise_step(
|
||||
prefix_pad_masks,
|
||||
past_key_values,
|
||||
x_t,
|
||||
expanded_time,
|
||||
)
|
||||
x_t = x_t + dt * v_t
|
||||
|
||||
# Define a closure function to properly capture expanded_time
|
||||
# This avoids the lambda expression (E731) and loop variable binding (B023) issues
|
||||
def denoise_step_partial_call(input_x_t, current_timestep=expanded_time):
|
||||
return self.denoise_step(
|
||||
prefix_pad_masks=prefix_pad_masks,
|
||||
past_key_values=past_key_values,
|
||||
x_t=input_x_t,
|
||||
timestep=current_timestep,
|
||||
)
|
||||
|
||||
if self._rtc_enabled():
|
||||
inference_delay = kwargs.get("inference_delay")
|
||||
prev_chunk_left_over = kwargs.get("prev_chunk_left_over")
|
||||
execution_horizon = kwargs.get("execution_horizon")
|
||||
|
||||
v_t = self.rtc_processor.denoise_step(
|
||||
x_t=x_t,
|
||||
prev_chunk_left_over=prev_chunk_left_over,
|
||||
inference_delay=inference_delay,
|
||||
time=time,
|
||||
original_denoise_step_partial=denoise_step_partial_call,
|
||||
execution_horizon=execution_horizon,
|
||||
)
|
||||
else:
|
||||
v_t = denoise_step_partial_call(x_t)
|
||||
|
||||
# Euler step
|
||||
x_t += dt * v_t
|
||||
|
||||
# Record x_t and v_t after Euler step
|
||||
if self.rtc_processor is not None and self.rtc_processor.is_debug_enabled():
|
||||
self.rtc_processor.track(time=time, x_t=x_t, v_t=v_t)
|
||||
|
||||
time += dt
|
||||
|
||||
return x_t
|
||||
@@ -839,7 +887,8 @@ class PI05Policy(PreTrainedPolicy):
|
||||
self.config = config
|
||||
|
||||
# Initialize the core PI05 model
|
||||
self.model = PI05Pytorch(config)
|
||||
self.init_rtc_processor()
|
||||
self.model = PI05Pytorch(config, rtc_processor=self.rtc_processor)
|
||||
|
||||
# Enable gradient checkpointing if requested
|
||||
if config.gradient_checkpointing:
|
||||
@@ -1035,6 +1084,22 @@ class PI05Policy(PreTrainedPolicy):
|
||||
ACTION: deque(maxlen=self.config.n_action_steps),
|
||||
}
|
||||
|
||||
def init_rtc_processor(self):
|
||||
"""Initialize RTC processor if RTC is enabled in config."""
|
||||
self.rtc_processor = None
|
||||
|
||||
# Create processor if config provided
|
||||
# If RTC is not enabled - we can still track the denoising data
|
||||
if self.config.rtc_config is not None:
|
||||
self.rtc_processor = RTCProcessor(self.config.rtc_config)
|
||||
|
||||
model_value = getattr(self, "model", None)
|
||||
if model_value is not None:
|
||||
model_value.rtc_processor = self.rtc_processor
|
||||
|
||||
def _rtc_enabled(self) -> bool:
|
||||
return self.config.rtc_config is not None and self.config.rtc_config.enabled
|
||||
|
||||
def _preprocess_images(self, batch: dict[str, Tensor]) -> tuple[list[Tensor], list[Tensor]]:
|
||||
"""Preprocess images for the model.
|
||||
|
||||
@@ -1109,6 +1174,10 @@ class PI05Policy(PreTrainedPolicy):
|
||||
@torch.no_grad()
|
||||
def select_action(self, batch: dict[str, Tensor]) -> Tensor:
|
||||
"""Select a single action given environment observations."""
|
||||
assert not self._rtc_enabled(), (
|
||||
"RTC is not supported for select_action, use it with predict_action_chunk"
|
||||
)
|
||||
|
||||
self.eval()
|
||||
|
||||
# Action queue logic for n_action_steps > 1
|
||||
@@ -1120,7 +1189,7 @@ class PI05Policy(PreTrainedPolicy):
|
||||
return self._action_queue.popleft()
|
||||
|
||||
@torch.no_grad()
|
||||
def predict_action_chunk(self, batch: dict[str, Tensor]) -> Tensor:
|
||||
def predict_action_chunk(self, batch: dict[str, Tensor], **kwargs: Unpack[ActionSelectKwargs]) -> Tensor:
|
||||
"""Predict a chunk of actions given environment observations."""
|
||||
self.eval()
|
||||
|
||||
@@ -1128,8 +1197,8 @@ class PI05Policy(PreTrainedPolicy):
|
||||
images, img_masks = self._preprocess_images(batch)
|
||||
tokens, masks = batch[f"{OBS_LANGUAGE_TOKENS}"], batch[f"{OBS_LANGUAGE_ATTENTION_MASK}"]
|
||||
|
||||
# Sample actions using the model (no separate state needed for PI05)
|
||||
actions = self.model.sample_actions(images, img_masks, tokens, masks)
|
||||
# Sample actions using the model (pass through RTC kwargs, no separate state needed for PI05)
|
||||
actions = self.model.sample_actions(images, img_masks, tokens, masks, **kwargs)
|
||||
|
||||
# Unpad actions to actual action dimension
|
||||
original_action_dim = self.config.output_features[ACTION].shape[0]
|
||||
|
||||
@@ -0,0 +1,38 @@
|
||||
# Real-Time Chunking (RTC)
|
||||
|
||||
This module contains the LeRobot implementation of **Real-Time Chunking (RTC)**, an inference-time technique for flow-matching based policies.
|
||||
|
||||
**Note**: RTC is not a policy itself, but rather an inference enhancement that works with flow-matching based policies including [π₀](../pi0/), [π₀.₅](../pi05/), and [SmolVLA](../smolvla/).
|
||||
|
||||
---
|
||||
|
||||
## Citation
|
||||
|
||||
If you use Real-Time Chunking in your work, please cite:
|
||||
|
||||
```bibtex
|
||||
@misc{openpi2024,
|
||||
author = {Physical Intelligence Lab},
|
||||
title = {OpenPI: PyTorch Implementation of π0 and π0.5 Policies},
|
||||
year = {2024},
|
||||
publisher = {GitHub},
|
||||
howpublished = {\url{https://github.com/Physical-Intelligence/openpi}},
|
||||
license = {Apache-2.0}
|
||||
}
|
||||
|
||||
@misc{black2025realtimeexecutionactionchunking,
|
||||
title={Real-Time Execution of Action Chunking Flow Policies},
|
||||
author={Kevin Black and Manuel Y. Galliker and Sergey Levine},
|
||||
year={2025},
|
||||
eprint={2506.07339},
|
||||
archivePrefix={arXiv},
|
||||
primaryClass={cs.RO},
|
||||
url={https://arxiv.org/abs/2506.07339},
|
||||
}
|
||||
```
|
||||
|
||||
---
|
||||
|
||||
## License
|
||||
|
||||
This implementation follows the **Apache 2.0 License**, consistent with the LeRobot project.
|
||||
@@ -0,0 +1,219 @@
|
||||
#!/usr/bin/env python
|
||||
|
||||
# Copyright 2025 The HuggingFace Inc. team. All rights reserved.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
"""Action queue management for Real-Time Chunking (RTC).
|
||||
|
||||
This module provides ActionQueue, a thread-safe queue for managing action chunks
|
||||
in real-time control scenarios. It supports both RTC-enabled and non-RTC modes,
|
||||
handling action merging and leftover tracking.
|
||||
"""
|
||||
|
||||
import logging
|
||||
from threading import Lock
|
||||
|
||||
import torch
|
||||
from torch import Tensor
|
||||
|
||||
from lerobot.policies.rtc.configuration_rtc import RTCConfig
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
class ActionQueue:
|
||||
"""Thread-safe queue for managing action chunks in real-time control.
|
||||
|
||||
This queue handles two types of action sequences:
|
||||
- Original actions: Used for RTC to compute leftovers from previous chunks
|
||||
- Processed actions: Post-processed actions ready for robot execution
|
||||
|
||||
The queue operates in two modes:
|
||||
1. RTC-enabled: Replaces the entire queue with new actions, accounting for inference delay
|
||||
2. RTC-disabled: Appends new actions to the queue, maintaining continuity
|
||||
|
||||
Args:
|
||||
cfg (RTCConfig): Configuration for Real-Time Chunking behavior.
|
||||
|
||||
Attributes:
|
||||
queue (Tensor | None): Processed actions for robot rollout (time_steps, action_dim).
|
||||
original_queue (Tensor | None): Original actions for RTC computation (time_steps, action_dim).
|
||||
last_index (int): Current consumption index in the queue.
|
||||
"""
|
||||
|
||||
def __init__(self, cfg: RTCConfig):
|
||||
"""Initialize the action queue.
|
||||
|
||||
Args:
|
||||
cfg: RTC configuration controlling queue behavior.
|
||||
"""
|
||||
self.queue = None # Processed actions for robot rollout
|
||||
self.original_queue = None # Original actions for RTC
|
||||
self.lock = Lock()
|
||||
self.last_index = 0
|
||||
self.cfg = cfg
|
||||
|
||||
def get(self) -> Tensor | None:
|
||||
"""Get the next action from the queue.
|
||||
|
||||
Returns:
|
||||
Tensor | None: The next action (action_dim,) or None if queue is empty.
|
||||
Returns a clone to prevent external modifications.
|
||||
"""
|
||||
with self.lock:
|
||||
if self.queue is None or self.last_index >= len(self.queue):
|
||||
return None
|
||||
|
||||
action = self.queue[self.last_index]
|
||||
self.last_index += 1
|
||||
return action.clone()
|
||||
|
||||
def qsize(self) -> int:
|
||||
"""Get the number of remaining actions in the queue.
|
||||
|
||||
Returns:
|
||||
int: Number of unconsumed actions.
|
||||
"""
|
||||
if self.queue is None:
|
||||
return 0
|
||||
length = len(self.queue)
|
||||
return length - self.last_index
|
||||
|
||||
def empty(self) -> bool:
|
||||
"""Check if the queue is empty.
|
||||
|
||||
Returns:
|
||||
bool: True if no actions remain, False otherwise.
|
||||
"""
|
||||
if self.queue is None:
|
||||
return True
|
||||
|
||||
length = len(self.queue)
|
||||
return length - self.last_index <= 0
|
||||
|
||||
def get_action_index(self) -> int:
|
||||
"""Get the current action consumption index.
|
||||
|
||||
Returns:
|
||||
int: Index of the next action to be consumed.
|
||||
"""
|
||||
return self.last_index
|
||||
|
||||
def get_left_over(self) -> Tensor | None:
|
||||
"""Get leftover original actions for RTC prev_chunk_left_over.
|
||||
|
||||
These are the unconsumed actions from the current chunk, which will be
|
||||
used by RTC to compute corrections for the next chunk.
|
||||
|
||||
Returns:
|
||||
Tensor | None: Remaining original actions (remaining_steps, action_dim),
|
||||
or None if no original queue exists.
|
||||
"""
|
||||
with self.lock:
|
||||
if self.original_queue is None:
|
||||
return None
|
||||
return self.original_queue[self.last_index :]
|
||||
|
||||
def merge(
|
||||
self,
|
||||
original_actions: Tensor,
|
||||
processed_actions: Tensor,
|
||||
real_delay: int,
|
||||
action_index_before_inference: int | None = 0,
|
||||
):
|
||||
"""Merge new actions into the queue.
|
||||
|
||||
This method operates differently based on RTC mode:
|
||||
- RTC enabled: Replaces the queue, accounting for inference delay
|
||||
- RTC disabled: Appends to the queue, maintaining continuity
|
||||
|
||||
Args:
|
||||
original_actions: Unprocessed actions from policy (time_steps, action_dim).
|
||||
processed_actions: Post-processed actions for robot (time_steps, action_dim).
|
||||
real_delay: Number of time steps of inference delay.
|
||||
action_index_before_inference: Index before inference started, for validation.
|
||||
"""
|
||||
with self.lock:
|
||||
self._check_delays(real_delay, action_index_before_inference)
|
||||
|
||||
if self.cfg.enabled:
|
||||
self._replace_actions_queue(original_actions, processed_actions, real_delay)
|
||||
return
|
||||
|
||||
self._append_actions_queue(original_actions, processed_actions)
|
||||
|
||||
def _replace_actions_queue(self, original_actions: Tensor, processed_actions: Tensor, real_delay: int):
|
||||
"""Replace the queue with new actions (RTC mode).
|
||||
|
||||
Discards the first `real_delay` actions since they correspond to the time
|
||||
spent during inference, when the robot was executing previous actions.
|
||||
|
||||
Args:
|
||||
original_actions: Unprocessed actions from policy.
|
||||
processed_actions: Post-processed actions for robot.
|
||||
real_delay: Number of time steps to skip due to inference delay.
|
||||
"""
|
||||
self.original_queue = original_actions[real_delay:].clone()
|
||||
self.queue = processed_actions[real_delay:].clone()
|
||||
|
||||
logger.debug(f"original_actions shape: {self.original_queue.shape}")
|
||||
logger.debug(f"processed_actions shape: {self.queue.shape}")
|
||||
logger.debug(f"real_delay: {real_delay}")
|
||||
|
||||
self.last_index = 0
|
||||
|
||||
def _append_actions_queue(self, original_actions: Tensor, processed_actions: Tensor):
|
||||
"""Append new actions to the queue (non-RTC mode).
|
||||
|
||||
Removes already-consumed actions and appends new ones, maintaining
|
||||
queue continuity without replacement.
|
||||
|
||||
Args:
|
||||
original_actions: Unprocessed actions from policy.
|
||||
processed_actions: Post-processed actions for robot.
|
||||
"""
|
||||
if self.queue is None:
|
||||
self.original_queue = original_actions.clone()
|
||||
self.queue = processed_actions.clone()
|
||||
return
|
||||
|
||||
self.original_queue = torch.cat([self.original_queue, original_actions.clone()])
|
||||
self.original_queue = self.original_queue[self.last_index :]
|
||||
|
||||
self.queue = torch.cat([self.queue, processed_actions.clone()])
|
||||
self.queue = self.queue[self.last_index :]
|
||||
|
||||
self.last_index = 0
|
||||
|
||||
def _check_delays(self, real_delay: int, action_index_before_inference: int | None = None):
|
||||
"""Validate that computed delays match expectations.
|
||||
|
||||
Compares the delay computed from inference latency with the actual
|
||||
number of actions consumed during inference.
|
||||
|
||||
Args:
|
||||
real_delay: Delay computed from inference latency.
|
||||
action_index_before_inference: Action index when inference started.
|
||||
"""
|
||||
if action_index_before_inference is None:
|
||||
return
|
||||
|
||||
indexes_diff = self.last_index - action_index_before_inference
|
||||
if indexes_diff != real_delay:
|
||||
# Let's check that action index difference (real delay calculated based on action queue)
|
||||
# is the same as delay calculated based on inference latency
|
||||
logger.warning(
|
||||
f"[ACTION_QUEUE] Indexes diff is not equal to real delay. "
|
||||
f"Indexes diff: {indexes_diff}, real delay: {real_delay}"
|
||||
)
|
||||
@@ -0,0 +1,55 @@
|
||||
#!/usr/bin/env python
|
||||
|
||||
# Copyright 2025 The HuggingFace Inc. team. All rights reserved.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
"""
|
||||
Real Time Chunking (RTC) and Bidirectional Decoding (BID) configuration classes.
|
||||
|
||||
Based on:
|
||||
- Real Time Chunking: https://www.physicalintelligence.company/research/real_time_chunking
|
||||
"""
|
||||
|
||||
from dataclasses import dataclass
|
||||
|
||||
from lerobot.configs.types import RTCAttentionSchedule
|
||||
|
||||
|
||||
@dataclass
|
||||
class RTCConfig:
|
||||
"""Configuration for Real Time Chunking (RTC) inference.
|
||||
|
||||
RTC improves real-time inference by treating chunk generation as an inpainting problem,
|
||||
strategically handling overlapping timesteps between action chunks using prefix attention.
|
||||
"""
|
||||
|
||||
# Infrastructure
|
||||
enabled: bool = False
|
||||
|
||||
# Core RTC settings
|
||||
# Todo change to exp
|
||||
prefix_attention_schedule: RTCAttentionSchedule = RTCAttentionSchedule.LINEAR
|
||||
max_guidance_weight: float = 10.0
|
||||
execution_horizon: int = 10
|
||||
|
||||
# Debug settings
|
||||
debug: bool = False
|
||||
debug_maxlen: int = 100
|
||||
|
||||
def __post_init__(self):
|
||||
"""Validate RTC configuration parameters."""
|
||||
if self.max_guidance_weight <= 0:
|
||||
raise ValueError(f"max_guidance_weight must be positive, got {self.max_guidance_weight}")
|
||||
if self.debug_maxlen <= 0:
|
||||
raise ValueError(f"debug_maxlen must be positive, got {self.debug_maxlen}")
|
||||
@@ -0,0 +1,233 @@
|
||||
#!/usr/bin/env python
|
||||
|
||||
# Copyright 2025 The HuggingFace Inc. team. All rights reserved.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
"""Debug information handler for Real-Time Chunking (RTC)."""
|
||||
|
||||
from dataclasses import dataclass, field
|
||||
from typing import Any
|
||||
|
||||
import torch
|
||||
from torch import Tensor
|
||||
|
||||
|
||||
@dataclass
|
||||
class DebugStep:
|
||||
"""Container for debug information from a single denoising step.
|
||||
|
||||
Attributes:
|
||||
step_idx (int): Step index/counter.
|
||||
x_t (Tensor | None): Current latent/state tensor.
|
||||
v_t (Tensor | None): Velocity from denoiser.
|
||||
x1_t (Tensor | None): Denoised prediction (x_t - time * v_t).
|
||||
correction (Tensor | None): Correction gradient tensor.
|
||||
err (Tensor | None): Weighted error term.
|
||||
weights (Tensor | None): Prefix attention weights.
|
||||
guidance_weight (float | Tensor | None): Applied guidance weight.
|
||||
time (float | Tensor | None): Time parameter.
|
||||
inference_delay (int | None): Inference delay parameter.
|
||||
execution_horizon (int | None): Execution horizon parameter.
|
||||
metadata (dict[str, Any]): Additional metadata.
|
||||
"""
|
||||
|
||||
step_idx: int = 0
|
||||
x_t: Tensor | None = None
|
||||
v_t: Tensor | None = None
|
||||
x1_t: Tensor | None = None
|
||||
correction: Tensor | None = None
|
||||
err: Tensor | None = None
|
||||
weights: Tensor | None = None
|
||||
guidance_weight: float | Tensor | None = None
|
||||
time: float | Tensor | None = None
|
||||
inference_delay: int | None = None
|
||||
execution_horizon: int | None = None
|
||||
metadata: dict[str, Any] = field(default_factory=dict)
|
||||
|
||||
def to_dict(self, include_tensors: bool = False) -> dict[str, Any]:
|
||||
"""Convert debug step to dictionary.
|
||||
|
||||
Args:
|
||||
include_tensors (bool): If True, include tensor values. If False, only include
|
||||
tensor statistics (shape, mean, std, min, max).
|
||||
|
||||
Returns:
|
||||
Dictionary representation of the debug step.
|
||||
"""
|
||||
result = {
|
||||
"step_idx": self.step_idx,
|
||||
"guidance_weight": (
|
||||
self.guidance_weight.item()
|
||||
if isinstance(self.guidance_weight, Tensor)
|
||||
else self.guidance_weight
|
||||
),
|
||||
"time": self.time.item() if isinstance(self.time, Tensor) else self.time,
|
||||
"inference_delay": self.inference_delay,
|
||||
"execution_horizon": self.execution_horizon,
|
||||
"metadata": self.metadata.copy(),
|
||||
}
|
||||
|
||||
# Add tensor information
|
||||
tensor_fields = ["x_t", "v_t", "x1_t", "correction", "err", "weights"]
|
||||
for field_name in tensor_fields:
|
||||
tensor = getattr(self, field_name)
|
||||
if tensor is not None:
|
||||
if include_tensors:
|
||||
result[field_name] = tensor.detach().cpu()
|
||||
else:
|
||||
result[f"{field_name}_stats"] = {
|
||||
"shape": tuple(tensor.shape),
|
||||
"mean": tensor.mean().item(),
|
||||
"std": tensor.std().item(),
|
||||
"min": tensor.min().item(),
|
||||
"max": tensor.max().item(),
|
||||
}
|
||||
|
||||
return result
|
||||
|
||||
|
||||
class Tracker:
|
||||
"""Collects and manages debug information for RTC processing.
|
||||
|
||||
This tracker stores debug information from recent denoising steps in a dictionary,
|
||||
using time as the key for efficient lookups and updates.
|
||||
|
||||
Args:
|
||||
enabled (bool): Whether debug collection is enabled.
|
||||
maxlen (int | None): Optional sliding window size. If provided, only the
|
||||
most recent ``maxlen`` debug steps are kept. If ``None``, keeps all.
|
||||
"""
|
||||
|
||||
def __init__(self, enabled: bool = False, maxlen: int = 100):
|
||||
self.enabled = enabled
|
||||
self._steps = {} if enabled else None # Dictionary with time as key
|
||||
self._maxlen = maxlen
|
||||
self._step_counter = 0
|
||||
|
||||
def reset(self) -> None:
|
||||
"""Clear all recorded debug information."""
|
||||
if self.enabled and self._steps is not None:
|
||||
self._steps.clear()
|
||||
self._step_counter = 0
|
||||
|
||||
@torch._dynamo.disable
|
||||
def track(
|
||||
self,
|
||||
time: float | Tensor,
|
||||
x_t: Tensor | None = None,
|
||||
v_t: Tensor | None = None,
|
||||
x1_t: Tensor | None = None,
|
||||
correction: Tensor | None = None,
|
||||
err: Tensor | None = None,
|
||||
weights: Tensor | None = None,
|
||||
guidance_weight: float | Tensor | None = None,
|
||||
inference_delay: int | None = None,
|
||||
execution_horizon: int | None = None,
|
||||
**metadata,
|
||||
) -> None:
|
||||
"""Track debug information for a denoising step at a given time.
|
||||
|
||||
If a step with the given time already exists, it will be updated with the new data.
|
||||
Otherwise, a new step will be created. Only non-None fields are updated/set.
|
||||
|
||||
Note: This method is excluded from torch.compile to avoid graph breaks from
|
||||
operations like .item() which are incompatible with compiled graphs.
|
||||
|
||||
Args:
|
||||
time (float | Tensor): Time parameter - used as the key to identify the step.
|
||||
x_t (Tensor | None): Current latent/state tensor.
|
||||
v_t (Tensor | None): Velocity from denoiser.
|
||||
x1_t (Tensor | None): Denoised prediction.
|
||||
correction (Tensor | None): Correction gradient tensor.
|
||||
err (Tensor | None): Weighted error term.
|
||||
weights (Tensor | None): Prefix attention weights.
|
||||
guidance_weight (float | Tensor | None): Applied guidance weight.
|
||||
inference_delay (int | None): Inference delay parameter.
|
||||
execution_horizon (int | None): Execution horizon parameter.
|
||||
**metadata: Additional metadata to store.
|
||||
"""
|
||||
if not self.enabled:
|
||||
return
|
||||
|
||||
# Convert time to float and round to avoid float precision issues
|
||||
time_value = time.item() if isinstance(time, Tensor) else time
|
||||
time_key = round(time_value, 6) # Use rounded time as dictionary key
|
||||
|
||||
# Check if step with this time already exists
|
||||
if time_key in self._steps:
|
||||
# Update existing step with non-None fields
|
||||
existing_step = self._steps[time_key]
|
||||
if x_t is not None:
|
||||
existing_step.x_t = x_t.detach().clone()
|
||||
if v_t is not None:
|
||||
existing_step.v_t = v_t.detach().clone()
|
||||
if x1_t is not None:
|
||||
existing_step.x1_t = x1_t.detach().clone()
|
||||
if correction is not None:
|
||||
existing_step.correction = correction.detach().clone()
|
||||
if err is not None:
|
||||
existing_step.err = err.detach().clone()
|
||||
if weights is not None:
|
||||
existing_step.weights = weights.detach().clone()
|
||||
if guidance_weight is not None:
|
||||
existing_step.guidance_weight = guidance_weight
|
||||
if inference_delay is not None:
|
||||
existing_step.inference_delay = inference_delay
|
||||
if execution_horizon is not None:
|
||||
existing_step.execution_horizon = execution_horizon
|
||||
if metadata:
|
||||
existing_step.metadata.update(metadata)
|
||||
else:
|
||||
# Create new step
|
||||
step = DebugStep(
|
||||
step_idx=self._step_counter,
|
||||
x_t=x_t.detach().clone() if x_t is not None else None,
|
||||
v_t=v_t.detach().clone() if v_t is not None else None,
|
||||
x1_t=x1_t.detach().clone() if x1_t is not None else None,
|
||||
correction=correction.detach().clone() if correction is not None else None,
|
||||
err=err.detach().clone() if err is not None else None,
|
||||
weights=weights.detach().clone() if weights is not None else None,
|
||||
guidance_weight=guidance_weight,
|
||||
time=time_value,
|
||||
inference_delay=inference_delay,
|
||||
execution_horizon=execution_horizon,
|
||||
metadata=metadata,
|
||||
)
|
||||
|
||||
# Add to dictionary
|
||||
self._steps[time_key] = step
|
||||
self._step_counter += 1
|
||||
|
||||
# Enforce maxlen if set
|
||||
if self._maxlen is not None and len(self._steps) > self._maxlen:
|
||||
# Remove oldest entry (first key in dict - Python 3.7+ preserves insertion order)
|
||||
oldest_key = next(iter(self._steps))
|
||||
del self._steps[oldest_key]
|
||||
|
||||
def get_all_steps(self) -> list[DebugStep]:
|
||||
"""Get all recorded debug steps.
|
||||
|
||||
Returns:
|
||||
List of all DebugStep objects (may be empty if disabled).
|
||||
"""
|
||||
if not self.enabled or self._steps is None:
|
||||
return []
|
||||
|
||||
return list(self._steps.values())
|
||||
|
||||
def __len__(self) -> int:
|
||||
"""Return the number of recorded debug steps."""
|
||||
if not self.enabled or self._steps is None:
|
||||
return 0
|
||||
return len(self._steps)
|
||||
@@ -0,0 +1,113 @@
|
||||
#!/usr/bin/env python
|
||||
|
||||
# Copyright 2025 The HuggingFace Inc. team. All rights reserved.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
"""Visualization utilities for RTC debug information."""
|
||||
|
||||
import torch
|
||||
|
||||
|
||||
class RTCDebugVisualizer:
|
||||
"""Visualizer for RTC debug information.
|
||||
|
||||
This class provides methods to visualize debug information collected by the Tracker,
|
||||
including corrections, errors, weights, and guidance weights over denoising steps.
|
||||
"""
|
||||
|
||||
@staticmethod
|
||||
def plot_waypoints(
|
||||
axes,
|
||||
tensor,
|
||||
start_from: int = 0,
|
||||
color: str = "blue",
|
||||
label: str = "",
|
||||
alpha: float = 0.7,
|
||||
linewidth: float = 2,
|
||||
marker: str | None = None,
|
||||
markersize: int = 4,
|
||||
):
|
||||
"""Plot trajectories across multiple dimensions.
|
||||
|
||||
This function plots a tensor's values across time for multiple dimensions,
|
||||
with each dimension plotted on a separate axis.
|
||||
|
||||
Args:
|
||||
axes: Array of matplotlib axes (one for each dimension).
|
||||
tensor: The tensor to plot (can be torch.Tensor or numpy array).
|
||||
Shape should be (time_steps, num_dims) or (batch, time_steps, num_dims).
|
||||
start_from: Starting index for the x-axis.
|
||||
color: Color for the plot lines.
|
||||
label: Label for the plot legend.
|
||||
alpha: Transparency level for the plot.
|
||||
linewidth: Width of the plot lines.
|
||||
marker: Marker style for data points (e.g., 'o', 's', '^').
|
||||
markersize: Size of the markers.
|
||||
"""
|
||||
import numpy as np
|
||||
|
||||
# Handle None tensor
|
||||
if tensor is None:
|
||||
return
|
||||
|
||||
# Convert tensor to numpy if needed
|
||||
tensor_np = tensor.detach().cpu().numpy() if isinstance(tensor, torch.Tensor) else tensor
|
||||
|
||||
# Handle different tensor shapes
|
||||
if tensor_np.ndim == 3:
|
||||
# If batch dimension present, take first batch
|
||||
tensor_np = tensor_np[0]
|
||||
elif tensor_np.ndim == 1:
|
||||
# If 1D, reshape to (time_steps, 1)
|
||||
tensor_np = tensor_np.reshape(-1, 1)
|
||||
|
||||
# Get dimensions
|
||||
time_steps, num_dims = tensor_np.shape
|
||||
|
||||
# Create x-axis indices
|
||||
x_indices = np.arange(start_from, start_from + time_steps)
|
||||
|
||||
# Plot each dimension on its corresponding axis
|
||||
num_axes = len(axes) if hasattr(axes, "__len__") else 1
|
||||
for dim_idx in range(min(num_dims, num_axes)):
|
||||
ax = axes[dim_idx] if hasattr(axes, "__len__") else axes
|
||||
|
||||
# Plot the trajectory
|
||||
if marker:
|
||||
ax.plot(
|
||||
x_indices,
|
||||
tensor_np[:, dim_idx],
|
||||
color=color,
|
||||
label=label if dim_idx == 0 else "", # Only show label once
|
||||
alpha=alpha,
|
||||
linewidth=linewidth,
|
||||
marker=marker,
|
||||
markersize=markersize,
|
||||
)
|
||||
else:
|
||||
ax.plot(
|
||||
x_indices,
|
||||
tensor_np[:, dim_idx],
|
||||
color=color,
|
||||
label=label if dim_idx == 0 else "", # Only show label once
|
||||
alpha=alpha,
|
||||
linewidth=linewidth,
|
||||
)
|
||||
|
||||
# Add grid and labels if not already present
|
||||
if not ax.xaxis.get_label().get_text():
|
||||
ax.set_xlabel("Step", fontsize=10)
|
||||
if not ax.yaxis.get_label().get_text():
|
||||
ax.set_ylabel(f"Dim {dim_idx}", fontsize=10)
|
||||
ax.grid(True, alpha=0.3)
|
||||
@@ -0,0 +1,72 @@
|
||||
#!/usr/bin/env python
|
||||
|
||||
# Copyright 2025 The HuggingFace Inc. team. All rights reserved.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
"""Latency tracking utilities for Real-Time Chunking (RTC)."""
|
||||
|
||||
from collections import deque
|
||||
|
||||
import numpy as np
|
||||
|
||||
|
||||
class LatencyTracker:
|
||||
"""Tracks recent latencies and provides max/percentile queries.
|
||||
|
||||
Args:
|
||||
maxlen (int | None): Optional sliding window size. If provided, only the
|
||||
most recent ``maxlen`` latencies are kept. If ``None``, keeps all.
|
||||
"""
|
||||
|
||||
def __init__(self, maxlen: int = 100):
|
||||
self._values = deque(maxlen=maxlen)
|
||||
self.reset()
|
||||
|
||||
def reset(self) -> None:
|
||||
"""Clear all recorded latencies."""
|
||||
self._values.clear()
|
||||
self.max_latency = 0.0
|
||||
|
||||
def add(self, latency: float) -> None:
|
||||
"""Add a latency sample (seconds)."""
|
||||
# Ensure numeric and non-negative
|
||||
val = float(latency)
|
||||
|
||||
if val < 0:
|
||||
return
|
||||
self._values.append(val)
|
||||
self.max_latency = max(self.max_latency, val)
|
||||
|
||||
def __len__(self) -> int:
|
||||
return len(self._values)
|
||||
|
||||
def max(self) -> float | None:
|
||||
"""Return the maximum latency or None if empty."""
|
||||
return self.max_latency
|
||||
|
||||
def percentile(self, q: float) -> float | None:
|
||||
"""Return the q-quantile (q in [0,1]) of recorded latencies or None if empty."""
|
||||
if not self._values:
|
||||
return 0.0
|
||||
q = float(q)
|
||||
if q <= 0.0:
|
||||
return min(self._values)
|
||||
if q >= 1.0:
|
||||
return self.max_latency
|
||||
vals = np.array(list(self._values), dtype=np.float32)
|
||||
return float(np.quantile(vals, q))
|
||||
|
||||
def p95(self) -> float | None:
|
||||
"""Return the 95th percentile latency or None if empty."""
|
||||
return self.percentile(0.95)
|
||||
@@ -0,0 +1,297 @@
|
||||
#!/usr/bin/env python
|
||||
|
||||
# Copyright 2025 The HuggingFace Inc. team. All rights reserved.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
"""
|
||||
Real-Time Chunking (RTC) implementation for LeRobot.
|
||||
|
||||
Based on Physical Intelligence's Kinetix implementation:
|
||||
https://github.com/Physical-Intelligence/real-time-chunking-kinetix/blob/main/src/model.py#L214
|
||||
"""
|
||||
|
||||
import logging
|
||||
import math
|
||||
|
||||
import torch
|
||||
from torch import Tensor
|
||||
|
||||
from lerobot.configs.types import RTCAttentionSchedule
|
||||
from lerobot.policies.rtc.configuration_rtc import RTCConfig
|
||||
from lerobot.policies.rtc.debug_tracker import Tracker
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
class RTCProcessor:
|
||||
"""Real-Time Chunking processor for action chunking policies.
|
||||
|
||||
This class implements RTC techniques including velocity calculation,
|
||||
prefix attention, and adaptive chunk processing.
|
||||
"""
|
||||
|
||||
def __init__(self, rtc_config: RTCConfig):
|
||||
self.rtc_config = rtc_config
|
||||
|
||||
self.tracker = None
|
||||
|
||||
if rtc_config.debug:
|
||||
self.tracker = Tracker(
|
||||
enabled=rtc_config.debug,
|
||||
maxlen=rtc_config.debug_maxlen,
|
||||
)
|
||||
|
||||
# ====================== Tracker Proxy Methods ======================
|
||||
def track(
|
||||
self,
|
||||
time: float | Tensor,
|
||||
x_t: Tensor | None = None,
|
||||
v_t: Tensor | None = None,
|
||||
x1_t: Tensor | None = None,
|
||||
correction: Tensor | None = None,
|
||||
err: Tensor | None = None,
|
||||
weights: Tensor | None = None,
|
||||
guidance_weight: float | Tensor | None = None,
|
||||
inference_delay: int | None = None,
|
||||
execution_horizon: int | None = None,
|
||||
**metadata,
|
||||
) -> None:
|
||||
"""Proxy method to track debug information.
|
||||
|
||||
If tracker is None or disabled, this method does nothing.
|
||||
Otherwise, it forwards the call to tracker.track().
|
||||
"""
|
||||
if self.tracker is not None:
|
||||
self.tracker.track(
|
||||
time=time,
|
||||
x_t=x_t,
|
||||
v_t=v_t,
|
||||
x1_t=x1_t,
|
||||
correction=correction,
|
||||
err=err,
|
||||
weights=weights,
|
||||
guidance_weight=guidance_weight,
|
||||
inference_delay=inference_delay,
|
||||
execution_horizon=execution_horizon,
|
||||
**metadata,
|
||||
)
|
||||
|
||||
def get_all_debug_steps(self) -> list:
|
||||
"""Get all debug steps from tracker.
|
||||
|
||||
Returns empty list if tracker is disabled or None.
|
||||
"""
|
||||
if self.tracker is not None:
|
||||
return self.tracker.get_all_steps()
|
||||
return []
|
||||
|
||||
def is_debug_enabled(self) -> bool:
|
||||
"""Check if debug tracking is enabled.
|
||||
|
||||
Returns True if tracker exists and is enabled.
|
||||
"""
|
||||
return self.tracker is not None and self.tracker.enabled
|
||||
|
||||
def reset_tracker(self) -> None:
|
||||
"""Reset the tracker, clearing all recorded steps.
|
||||
|
||||
Does nothing if tracker is None.
|
||||
"""
|
||||
if self.tracker is not None:
|
||||
self.tracker.reset()
|
||||
|
||||
# ====================== End Tracker Proxy Methods ======================
|
||||
|
||||
def denoise_step(
|
||||
self,
|
||||
x_t,
|
||||
prev_chunk_left_over,
|
||||
inference_delay,
|
||||
time,
|
||||
original_denoise_step_partial,
|
||||
execution_horizon=None,
|
||||
) -> Tensor:
|
||||
"""RTC guidance wrapper around an existing denoiser.
|
||||
|
||||
This method wraps an original denoising callable that only takes ``x_t`` and
|
||||
returns a base denoised velocity ``v_t``. It then applies Real-Time Chunking
|
||||
(RTC) prefix guidance using the leftover prefix from the previous chunk.
|
||||
|
||||
Args:
|
||||
x_t (Tensor): Current latent/state to denoise. Shape ``(B, T, A)`` or ``(T, A)``.
|
||||
prev_chunk_left_over (Tensor | None): Unexecuted prefix from the previous
|
||||
chunk. Shape ``(B, T_prev, A)`` or ``(T_prev, A)``. If ``None``, no guidance
|
||||
is applied and the method returns ``v_t`` from the original denoiser.
|
||||
inference_delay (int): Number of timesteps from the prefix to use for guidance.
|
||||
time (float | Tensor): Scalar in [0, 1] indicating normalized time. Must be
|
||||
broadcastable with ``x_t``.
|
||||
original_denoise_step_partial (Callable[[Tensor], Tensor]): Callable that
|
||||
computes the base denoised velocity given only ``x_t``.
|
||||
execution_horizon (int | None): Horizon used to build prefix weights. If
|
||||
``None``, defaults to ``self.rtc_config.execution_horizon``.
|
||||
|
||||
Returns:
|
||||
Tensor: Guided velocity with the same shape as ``v_t``.
|
||||
|
||||
Notes:
|
||||
- If inputs are 2D, a batch dimension is temporarily added and removed at the end.
|
||||
- If ``prev_chunk_left_over`` is shorter than the current chunk length ``T``, it is
|
||||
right-padded with zeros to match ``T``.
|
||||
- Prefix weights are constructed via ``get_prefix_weights(inference_delay, execution_horizon, T)``
|
||||
and broadcast to ``(B, T, A)``.
|
||||
- Guidance correction is computed via autograd using ``x1_t = x_t + time * v_t`` and
|
||||
``error = (prev_chunk_left_over - x1_t) * weights``.
|
||||
- The final guidance weight is clamped by ``max_guidance_weight`` from the config.
|
||||
|
||||
Reference:
|
||||
https://www.physicalintelligence.company/download/real_time_chunking.pdf
|
||||
"""
|
||||
|
||||
# In the original implementation, the time goes from 0 to 1 and
|
||||
# In our implementation, the time goes from 1 to 0
|
||||
# So we need to invert the time
|
||||
tau = 1 - time
|
||||
|
||||
if prev_chunk_left_over is None:
|
||||
# First step, no guidance - return v_t
|
||||
v_t = original_denoise_step_partial(x_t)
|
||||
return v_t
|
||||
|
||||
x_t = x_t.clone().detach()
|
||||
|
||||
squeezed = False
|
||||
if len(x_t.shape) < 3:
|
||||
# Add batch dimension
|
||||
x_t = x_t.unsqueeze(0)
|
||||
squeezed = True
|
||||
|
||||
if len(prev_chunk_left_over.shape) < 3:
|
||||
# Add batch dimension
|
||||
prev_chunk_left_over = prev_chunk_left_over.unsqueeze(0)
|
||||
|
||||
if execution_horizon is None:
|
||||
execution_horizon = self.rtc_config.execution_horizon
|
||||
|
||||
# If the previous action chunk is to short then it doesn't make sense to use long execution horizon
|
||||
# because there is nothing to merge
|
||||
if execution_horizon > prev_chunk_left_over.shape[1]:
|
||||
execution_horizon = prev_chunk_left_over.shape[1]
|
||||
|
||||
batch_size = x_t.shape[0]
|
||||
action_chunk_size = x_t.shape[1]
|
||||
action_dim = x_t.shape[2]
|
||||
|
||||
if prev_chunk_left_over.shape[1] < action_chunk_size or prev_chunk_left_over.shape[2] < action_dim:
|
||||
padded = torch.zeros(batch_size, action_chunk_size, action_dim).to(x_t.device)
|
||||
padded[:, : prev_chunk_left_over.shape[1], : prev_chunk_left_over.shape[2]] = prev_chunk_left_over
|
||||
prev_chunk_left_over = padded
|
||||
|
||||
assert prev_chunk_left_over.shape == x_t.shape, (
|
||||
"The padded previous chunk must be the same size as the input tensor"
|
||||
)
|
||||
|
||||
weights = (
|
||||
self.get_prefix_weights(inference_delay, execution_horizon, action_chunk_size)
|
||||
.to(x_t.device)
|
||||
.unsqueeze(0)
|
||||
.unsqueeze(-1)
|
||||
)
|
||||
|
||||
with torch.enable_grad():
|
||||
v_t = original_denoise_step_partial(x_t)
|
||||
x_t.requires_grad_(True)
|
||||
|
||||
x1_t = x_t - time * v_t # noqa: N806
|
||||
err = (prev_chunk_left_over - x1_t) * weights
|
||||
grad_outputs = err.clone().detach()
|
||||
correction = torch.autograd.grad(x1_t, x_t, grad_outputs, retain_graph=False)[0]
|
||||
|
||||
max_guidance_weight = torch.as_tensor(self.rtc_config.max_guidance_weight)
|
||||
tau_tensor = torch.as_tensor(tau)
|
||||
squared_one_minus_tau = (1 - tau_tensor) ** 2
|
||||
inv_r2 = (squared_one_minus_tau + tau_tensor**2) / (squared_one_minus_tau)
|
||||
c = torch.nan_to_num((1 - tau_tensor) / tau_tensor, posinf=max_guidance_weight)
|
||||
guidance_weight = torch.nan_to_num(c * inv_r2, posinf=max_guidance_weight)
|
||||
guidance_weight = torch.minimum(guidance_weight, max_guidance_weight)
|
||||
|
||||
result = v_t - guidance_weight * correction
|
||||
|
||||
# Remove the batch dimension if it was added
|
||||
if squeezed:
|
||||
result = result.squeeze(0)
|
||||
correction = correction.squeeze(0)
|
||||
x1_t = x1_t.squeeze(0)
|
||||
err = err.squeeze(0)
|
||||
|
||||
self.track(
|
||||
time=time,
|
||||
x1_t=x1_t,
|
||||
correction=correction,
|
||||
err=err,
|
||||
weights=weights,
|
||||
guidance_weight=guidance_weight,
|
||||
inference_delay=inference_delay,
|
||||
execution_horizon=execution_horizon,
|
||||
)
|
||||
|
||||
return result
|
||||
|
||||
def get_prefix_weights(self, start, end, total):
|
||||
start = min(start, end)
|
||||
|
||||
if self.rtc_config.prefix_attention_schedule == RTCAttentionSchedule.ZEROS:
|
||||
weights = torch.zeros(total)
|
||||
weights[:start] = 1.0
|
||||
elif self.rtc_config.prefix_attention_schedule == RTCAttentionSchedule.ONES:
|
||||
weights = torch.ones(total)
|
||||
weights[end:] = 0.0
|
||||
elif self.rtc_config.prefix_attention_schedule == RTCAttentionSchedule.LINEAR:
|
||||
lin_weights = self._linweights(start, end, total)
|
||||
weights = self._add_trailing_zeros(lin_weights, total, end)
|
||||
weights = self._add_leading_ones(weights, start, total)
|
||||
elif self.rtc_config.prefix_attention_schedule == RTCAttentionSchedule.EXP:
|
||||
lin_weights = self._linweights(start, end, total)
|
||||
lin_weights = lin_weights * torch.expm1(lin_weights).div(math.e - 1)
|
||||
weights = self._add_trailing_zeros(lin_weights, total, end)
|
||||
weights = self._add_leading_ones(weights, start, total)
|
||||
|
||||
return weights
|
||||
|
||||
def _linweights(self, start, end, total):
|
||||
skip_steps_at_end = max(total - end, 0)
|
||||
|
||||
linspace_steps = total - skip_steps_at_end - start
|
||||
|
||||
if end <= start or linspace_steps <= 0:
|
||||
return torch.tensor([])
|
||||
|
||||
return torch.linspace(1, 0, linspace_steps + 2)[1:-1]
|
||||
|
||||
def _add_trailing_zeros(self, weights, total, end):
|
||||
zeros_len = total - end
|
||||
|
||||
if zeros_len <= 0:
|
||||
return weights
|
||||
|
||||
zeros = torch.zeros(zeros_len)
|
||||
return torch.cat([weights, zeros])
|
||||
|
||||
def _add_leading_ones(self, weights, start, total):
|
||||
ones_len = min(start, total)
|
||||
|
||||
if ones_len <= 0:
|
||||
return weights
|
||||
|
||||
ones = torch.ones(ones_len)
|
||||
return torch.cat([ones, weights])
|
||||
@@ -20,6 +20,7 @@ from lerobot.optim.optimizers import AdamWConfig
|
||||
from lerobot.optim.schedulers import (
|
||||
CosineDecayWithWarmupSchedulerConfig,
|
||||
)
|
||||
from lerobot.policies.rtc.configuration_rtc import RTCConfig
|
||||
from lerobot.utils.constants import OBS_IMAGES
|
||||
|
||||
|
||||
@@ -102,6 +103,9 @@ class SmolVLAConfig(PreTrainedConfig):
|
||||
min_period: float = 4e-3 # sensitivity range for the timestep used in sine-cosine positional encoding
|
||||
max_period: float = 4.0
|
||||
|
||||
# Real-Time Chunking (RTC) configuration
|
||||
rtc_config: RTCConfig | None = None
|
||||
|
||||
def __post_init__(self):
|
||||
super().__post_init__()
|
||||
|
||||
|
||||
@@ -54,12 +54,15 @@ policy = SmolVLAPolicy.from_pretrained("lerobot/smolvla_base")
|
||||
|
||||
import math
|
||||
from collections import deque
|
||||
from typing import TypedDict
|
||||
|
||||
import torch
|
||||
import torch.nn.functional as F # noqa: N812
|
||||
from torch import Tensor, nn
|
||||
from typing_extensions import Unpack
|
||||
|
||||
from lerobot.policies.pretrained import PreTrainedPolicy
|
||||
from lerobot.policies.rtc.modeling_rtc import RTCProcessor
|
||||
from lerobot.policies.smolvla.configuration_smolvla import SmolVLAConfig
|
||||
from lerobot.policies.smolvla.smolvlm_with_expert import SmolVLMWithExpertModel
|
||||
from lerobot.policies.utils import (
|
||||
@@ -69,6 +72,12 @@ from lerobot.utils.constants import ACTION, OBS_LANGUAGE_ATTENTION_MASK, OBS_LAN
|
||||
from lerobot.utils.utils import get_safe_dtype
|
||||
|
||||
|
||||
class ActionSelectKwargs(TypedDict, total=False):
|
||||
inference_delay: int | None
|
||||
prev_chunk_left_over: Tensor | None
|
||||
execution_horizon: int | None
|
||||
|
||||
|
||||
def create_sinusoidal_pos_embedding(
|
||||
time: torch.tensor, dimension: int, min_period: float, max_period: float, device="cpu"
|
||||
) -> Tensor:
|
||||
@@ -232,8 +241,8 @@ class SmolVLAPolicy(PreTrainedPolicy):
|
||||
super().__init__(config)
|
||||
config.validate_features()
|
||||
self.config = config
|
||||
|
||||
self.model = VLAFlowMatching(config)
|
||||
self.init_rtc_processor()
|
||||
self.model = VLAFlowMatching(config, rtc_processor=self.rtc_processor)
|
||||
self.reset()
|
||||
|
||||
def reset(self):
|
||||
@@ -242,10 +251,28 @@ class SmolVLAPolicy(PreTrainedPolicy):
|
||||
ACTION: deque(maxlen=self.config.n_action_steps),
|
||||
}
|
||||
|
||||
def init_rtc_processor(self):
|
||||
"""Initialize RTC processor if RTC is enabled in config."""
|
||||
self.rtc_processor = None
|
||||
|
||||
# Lets create processor if the config provided
|
||||
# If RTC is not enabled - we still can track the denoising data
|
||||
if self.config.rtc_config is not None:
|
||||
self.rtc_processor = RTCProcessor(self.config.rtc_config)
|
||||
|
||||
# In case of calling init_rtc_processor after the model is created
|
||||
# We need to set the rtc_processor to the model
|
||||
# During the normal initialization process the model is not created yet
|
||||
model_value = getattr(self, "model", None)
|
||||
if model_value is not None:
|
||||
model_value.rtc_processor = self.rtc_processor
|
||||
|
||||
def get_optim_params(self) -> dict:
|
||||
return self.parameters()
|
||||
|
||||
def _get_action_chunk(self, batch: dict[str, Tensor], noise: Tensor | None = None) -> Tensor:
|
||||
def _get_action_chunk(
|
||||
self, batch: dict[str, Tensor], noise: Tensor | None = None, **kwargs: Unpack[ActionSelectKwargs]
|
||||
) -> Tensor:
|
||||
# TODO: Check if this for loop is needed.
|
||||
# Context: In fact, self.queues contains only ACTION field, and in inference, we don't have action in the batch
|
||||
# In the case of offline inference, we have the action in the batch
|
||||
@@ -260,7 +287,9 @@ class SmolVLAPolicy(PreTrainedPolicy):
|
||||
lang_tokens = batch[f"{OBS_LANGUAGE_TOKENS}"]
|
||||
lang_masks = batch[f"{OBS_LANGUAGE_ATTENTION_MASK}"]
|
||||
|
||||
actions = self.model.sample_actions(images, img_masks, lang_tokens, lang_masks, state, noise=noise)
|
||||
actions = self.model.sample_actions(
|
||||
images, img_masks, lang_tokens, lang_masks, state, noise=noise, **kwargs
|
||||
)
|
||||
|
||||
# Unpad actions
|
||||
original_action_dim = self.config.action_feature.shape[0]
|
||||
@@ -278,30 +307,37 @@ class SmolVLAPolicy(PreTrainedPolicy):
|
||||
return batch
|
||||
|
||||
@torch.no_grad()
|
||||
def predict_action_chunk(self, batch: dict[str, Tensor], noise: Tensor | None = None) -> Tensor:
|
||||
def predict_action_chunk(
|
||||
self, batch: dict[str, Tensor], noise: Tensor | None = None, **kwargs: Unpack[ActionSelectKwargs]
|
||||
) -> Tensor:
|
||||
self.eval()
|
||||
|
||||
batch = self._prepare_batch(batch)
|
||||
self._queues = populate_queues(self._queues, batch, exclude_keys=[ACTION])
|
||||
|
||||
actions = self._get_action_chunk(batch, noise)
|
||||
actions = self._get_action_chunk(batch, noise, **kwargs)
|
||||
return actions
|
||||
|
||||
@torch.no_grad()
|
||||
def select_action(self, batch: dict[str, Tensor], noise: Tensor | None = None) -> Tensor:
|
||||
def select_action(
|
||||
self, batch: dict[str, Tensor], noise: Tensor | None = None, **kwargs: Unpack[ActionSelectKwargs]
|
||||
) -> Tensor:
|
||||
"""Select a single action given environment observations.
|
||||
|
||||
This method wraps `select_actions` in order to return one action at a time for execution in the
|
||||
environment. It works by managing the actions in a queue and only calling `select_actions` when the
|
||||
queue is empty.
|
||||
"""
|
||||
|
||||
assert not self._rtc_enabled(), (
|
||||
"RTC is not supported for select_action, use it with predict_action_chunk"
|
||||
)
|
||||
|
||||
self.eval()
|
||||
batch = self._prepare_batch(batch)
|
||||
self._queues = populate_queues(self._queues, batch, exclude_keys=[ACTION])
|
||||
|
||||
# Action queue logic for n_action_steps > 1. When the action_queue is depleted, populate it by
|
||||
# querying the policy.
|
||||
if len(self._queues[ACTION]) == 0:
|
||||
if self._check_get_actions_condition():
|
||||
actions = self._get_action_chunk(batch, noise)
|
||||
|
||||
# `self.predict_action_chunk` returns a (batch_size, n_action_steps, action_dim) tensor, but the queue
|
||||
@@ -310,6 +346,12 @@ class SmolVLAPolicy(PreTrainedPolicy):
|
||||
|
||||
return self._queues[ACTION].popleft()
|
||||
|
||||
def _check_get_actions_condition(self) -> bool:
|
||||
return len(self._queues[ACTION]) == 0
|
||||
|
||||
def _rtc_enabled(self) -> bool:
|
||||
return self.config.rtc_config is not None and self.config.rtc_config.enabled
|
||||
|
||||
def forward(self, batch: dict[str, Tensor], noise=None, time=None) -> dict[str, Tensor]:
|
||||
"""Do a full training forward pass to compute the loss"""
|
||||
if self.config.adapt_to_pi_aloha:
|
||||
@@ -471,7 +513,7 @@ class VLAFlowMatching(nn.Module):
|
||||
└──────────────────────────────┘
|
||||
"""
|
||||
|
||||
def __init__(self, config: SmolVLAConfig):
|
||||
def __init__(self, config: SmolVLAConfig, rtc_processor: RTCProcessor | None = None):
|
||||
super().__init__()
|
||||
self.config = config
|
||||
|
||||
@@ -485,7 +527,6 @@ class VLAFlowMatching(nn.Module):
|
||||
num_vlm_layers=self.config.num_vlm_layers,
|
||||
self_attn_every_n_layers=self.config.self_attn_every_n_layers,
|
||||
expert_width_multiplier=self.config.expert_width_multiplier,
|
||||
device=self.config.device,
|
||||
)
|
||||
self.state_proj = nn.Linear(
|
||||
self.config.max_state_dim, self.vlm_with_expert.config.text_config.hidden_size
|
||||
@@ -510,6 +551,10 @@ class VLAFlowMatching(nn.Module):
|
||||
self.add_image_special_tokens = self.config.add_image_special_tokens
|
||||
self.image_end_token = torch.tensor([self.fake_image_token], dtype=torch.long)
|
||||
self.prefix_length = self.config.prefix_length
|
||||
self.rtc_processor = rtc_processor
|
||||
|
||||
def _rtc_enabled(self):
|
||||
return self.config.rtc_config is not None and self.config.rtc_config.enabled
|
||||
|
||||
def set_requires_grad(self):
|
||||
for params in self.state_proj.parameters():
|
||||
@@ -706,7 +751,16 @@ class VLAFlowMatching(nn.Module):
|
||||
losses = F.mse_loss(u_t, v_t, reduction="none")
|
||||
return losses
|
||||
|
||||
def sample_actions(self, images, img_masks, lang_tokens, lang_masks, state, noise=None) -> Tensor:
|
||||
def sample_actions(
|
||||
self,
|
||||
images,
|
||||
img_masks,
|
||||
lang_tokens,
|
||||
lang_masks,
|
||||
state,
|
||||
noise=None,
|
||||
**kwargs: Unpack[ActionSelectKwargs],
|
||||
) -> Tensor:
|
||||
"""Do a full inference forward and compute the action (batch_size x num_steps x num_motors)"""
|
||||
bsize = state.shape[0]
|
||||
device = state.device
|
||||
@@ -734,17 +788,45 @@ class VLAFlowMatching(nn.Module):
|
||||
|
||||
x_t = noise
|
||||
time = torch.tensor(1.0, dtype=torch.float32, device=device)
|
||||
|
||||
while time >= -dt / 2:
|
||||
expanded_time = time.expand(bsize)
|
||||
v_t = self.denoise_step(
|
||||
prefix_pad_masks,
|
||||
past_key_values,
|
||||
x_t,
|
||||
expanded_time,
|
||||
)
|
||||
|
||||
# Define a closure function to properly capture expanded_time
|
||||
# This avoids the lambda expression (E731) and loop variable binding (B023) issues
|
||||
def denoise_step_partial_call(input_x_t, current_timestep=expanded_time):
|
||||
return self.denoise_step(
|
||||
x_t=input_x_t,
|
||||
prefix_pad_masks=prefix_pad_masks,
|
||||
past_key_values=past_key_values,
|
||||
timestep=current_timestep,
|
||||
)
|
||||
|
||||
if self._rtc_enabled():
|
||||
inference_delay = kwargs.get("inference_delay")
|
||||
prev_chunk_left_over = kwargs.get("prev_chunk_left_over")
|
||||
execution_horizon = kwargs.get("execution_horizon")
|
||||
|
||||
v_t = self.rtc_processor.denoise_step(
|
||||
x_t=x_t,
|
||||
prev_chunk_left_over=prev_chunk_left_over,
|
||||
inference_delay=inference_delay,
|
||||
time=time,
|
||||
original_denoise_step_partial=denoise_step_partial_call,
|
||||
execution_horizon=execution_horizon,
|
||||
)
|
||||
else:
|
||||
v_t = denoise_step_partial_call(x_t)
|
||||
|
||||
# Euler step
|
||||
x_t += dt * v_t
|
||||
|
||||
# Record x_t and v_t after Euler step (other params are recorded in rtc_processor.denoise_step)
|
||||
if self.rtc_processor is not None and self.rtc_processor.is_debug_enabled():
|
||||
self.rtc_processor.track(time=time, x_t=x_t, v_t=v_t)
|
||||
|
||||
time += dt
|
||||
|
||||
return x_t
|
||||
|
||||
def denoise_step(
|
||||
|
||||
@@ -0,0 +1,154 @@
|
||||
#!/usr/bin/env python
|
||||
|
||||
# Copyright 2025 The HuggingFace Inc. team. All rights reserved.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
from dataclasses import dataclass
|
||||
|
||||
import torch
|
||||
|
||||
from lerobot.configs.types import PipelineFeatureType, PolicyFeature
|
||||
from lerobot.utils.constants import OBS_IMAGES, OBS_STATE
|
||||
|
||||
from .pipeline import ObservationProcessorStep, ProcessorStepRegistry
|
||||
|
||||
|
||||
@dataclass
|
||||
@ProcessorStepRegistry.register(name="libero_processor")
|
||||
class LiberoProcessorStep(ObservationProcessorStep):
|
||||
"""
|
||||
Processes LIBERO observations into the LeRobot format.
|
||||
|
||||
This step handles the specific observation structure from LIBERO environments,
|
||||
which includes nested robot_state dictionaries and image observations.
|
||||
|
||||
**State Processing:**
|
||||
- Processes the `robot_state` dictionary which contains nested end-effector,
|
||||
gripper, and joint information.
|
||||
- Extracts and concatenates:
|
||||
- End-effector position (3D)
|
||||
- End-effector quaternion converted to axis-angle (3D)
|
||||
- Gripper joint positions (2D)
|
||||
- Maps the concatenated state to `"observation.state"`.
|
||||
|
||||
**Image Processing:**
|
||||
- Rotates images by 180 degrees by flipping both height and width dimensions.
|
||||
- This accounts for the HuggingFaceVLA/libero camera orientation convention.
|
||||
"""
|
||||
|
||||
def _process_observation(self, observation):
|
||||
"""
|
||||
Processes both image and robot_state observations from LIBERO.
|
||||
"""
|
||||
processed_obs = observation.copy()
|
||||
for key in list(processed_obs.keys()):
|
||||
if key.startswith(f"{OBS_IMAGES}."):
|
||||
img = processed_obs[key]
|
||||
|
||||
# Flip both H and W
|
||||
img = torch.flip(img, dims=[2, 3])
|
||||
|
||||
processed_obs[key] = img
|
||||
# Process robot_state into a flat state vector
|
||||
if "observation.robot_state" in processed_obs:
|
||||
robot_state = processed_obs.pop("observation.robot_state")
|
||||
|
||||
# Extract components
|
||||
eef_pos = robot_state["eef"]["pos"] # (B, 3,)
|
||||
eef_quat = robot_state["eef"]["quat"] # (B, 4,)
|
||||
gripper_qpos = robot_state["gripper"]["qpos"] # (B, 2,)
|
||||
|
||||
# Convert quaternion to axis-angle
|
||||
eef_axisangle = self._quat2axisangle(eef_quat) # (B, 3)
|
||||
# Concatenate into a single state vector
|
||||
state = torch.cat((eef_pos, eef_axisangle, gripper_qpos), dim=-1)
|
||||
|
||||
# ensure float32
|
||||
state = state.float()
|
||||
if state.dim() == 1:
|
||||
state = state.unsqueeze(0)
|
||||
|
||||
processed_obs[OBS_STATE] = state
|
||||
return processed_obs
|
||||
|
||||
def transform_features(
|
||||
self, features: dict[PipelineFeatureType, dict[str, PolicyFeature]]
|
||||
) -> dict[PipelineFeatureType, dict[str, PolicyFeature]]:
|
||||
"""
|
||||
Transforms feature keys from the LIBERO format to the LeRobot standard.
|
||||
"""
|
||||
new_features: dict[PipelineFeatureType, dict[str, PolicyFeature]] = {}
|
||||
|
||||
# copy over non-STATE features
|
||||
for ft, feats in features.items():
|
||||
if ft != PipelineFeatureType.STATE:
|
||||
new_features[ft] = feats.copy()
|
||||
|
||||
# rebuild STATE features
|
||||
state_feats = {}
|
||||
|
||||
# add our new flattened state
|
||||
state_feats["observation.state"] = PolicyFeature(
|
||||
key="observation.state",
|
||||
shape=(8,), # [eef_pos(3), axis_angle(3), gripper(2)]
|
||||
dtype="float32",
|
||||
description=("Concatenated end-effector position (3), axis-angle (3), and gripper qpos (2)."),
|
||||
)
|
||||
|
||||
new_features[PipelineFeatureType.STATE] = state_feats
|
||||
|
||||
return new_features
|
||||
|
||||
def observation(self, observation):
|
||||
return self._process_observation(observation)
|
||||
|
||||
def _quat2axisangle(self, quat: torch.Tensor) -> torch.Tensor:
|
||||
"""
|
||||
Convert batched quaternions to axis-angle format.
|
||||
Only accepts torch tensors of shape (B, 4).
|
||||
|
||||
Args:
|
||||
quat (Tensor): (B, 4) tensor of quaternions in (x, y, z, w) format
|
||||
|
||||
Returns:
|
||||
Tensor: (B, 3) axis-angle vectors
|
||||
|
||||
Raises:
|
||||
TypeError: if input is not a torch tensor
|
||||
ValueError: if shape is not (B, 4)
|
||||
"""
|
||||
|
||||
if not isinstance(quat, torch.Tensor):
|
||||
raise TypeError(f"_quat2axisangle expected a torch.Tensor, got {type(quat)}")
|
||||
|
||||
if quat.ndim != 2 or quat.shape[1] != 4:
|
||||
raise ValueError(f"_quat2axisangle expected shape (B, 4), got {tuple(quat.shape)}")
|
||||
|
||||
quat = quat.to(dtype=torch.float32)
|
||||
device = quat.device
|
||||
batch_size = quat.shape[0]
|
||||
|
||||
w = quat[:, 3].clamp(-1.0, 1.0)
|
||||
|
||||
den = torch.sqrt(torch.clamp(1.0 - w * w, min=0.0))
|
||||
|
||||
result = torch.zeros((batch_size, 3), device=device)
|
||||
|
||||
mask = den > 1e-10
|
||||
|
||||
if mask.any():
|
||||
angle = 2.0 * torch.acos(w[mask]) # (M,)
|
||||
axis = quat[mask, :3] / den[mask].unsqueeze(1)
|
||||
result[mask] = axis * angle.unsqueeze(1)
|
||||
|
||||
return result
|
||||
@@ -78,7 +78,7 @@ from lerobot.transport.utils import (
|
||||
transitions_to_bytes,
|
||||
)
|
||||
from lerobot.utils.random_utils import set_seed
|
||||
from lerobot.utils.robot_utils import busy_wait
|
||||
from lerobot.utils.robot_utils import precise_sleep
|
||||
from lerobot.utils.transition import (
|
||||
Transition,
|
||||
move_state_dict_to_device,
|
||||
@@ -398,7 +398,7 @@ def act_with_policy(
|
||||
|
||||
if cfg.env.fps is not None:
|
||||
dt_time = time.perf_counter() - start_time
|
||||
busy_wait(1 / cfg.env.fps - dt_time)
|
||||
precise_sleep(1 / cfg.env.fps - dt_time)
|
||||
|
||||
|
||||
# Communication Functions - Group all gRPC/messaging functions
|
||||
|
||||
@@ -74,7 +74,7 @@ from lerobot.teleoperators import (
|
||||
from lerobot.teleoperators.teleoperator import Teleoperator
|
||||
from lerobot.teleoperators.utils import TeleopEvents
|
||||
from lerobot.utils.constants import ACTION, DONE, OBS_IMAGES, OBS_STATE, REWARD
|
||||
from lerobot.utils.robot_utils import busy_wait
|
||||
from lerobot.utils.robot_utils import precise_sleep
|
||||
from lerobot.utils.utils import log_say
|
||||
|
||||
logging.basicConfig(level=logging.INFO)
|
||||
@@ -114,7 +114,7 @@ def reset_follower_position(robot_arm: Robot, target_position: np.ndarray) -> No
|
||||
for pose in trajectory:
|
||||
action_dict = dict(zip(current_position_dict, pose, strict=False))
|
||||
robot_arm.bus.sync_write("Goal_Position", action_dict)
|
||||
busy_wait(0.015)
|
||||
precise_sleep(0.015)
|
||||
|
||||
|
||||
class RobotEnv(gym.Env):
|
||||
@@ -238,7 +238,7 @@ class RobotEnv(gym.Env):
|
||||
reset_follower_position(self.robot, np.array(self.reset_pose))
|
||||
log_say("Reset the environment done.", play_sounds=True)
|
||||
|
||||
busy_wait(self.reset_time_s - (time.perf_counter() - start_time))
|
||||
precise_sleep(self.reset_time_s - (time.perf_counter() - start_time))
|
||||
|
||||
super().reset(seed=seed, options=options)
|
||||
|
||||
@@ -713,7 +713,7 @@ def control_loop(
|
||||
transition = env_processor(transition)
|
||||
|
||||
# Maintain fps timing
|
||||
busy_wait(dt - (time.perf_counter() - step_start_time))
|
||||
precise_sleep(dt - (time.perf_counter() - step_start_time))
|
||||
|
||||
if dataset is not None and cfg.dataset.push_to_hub:
|
||||
logging.info("Pushing dataset to hub")
|
||||
@@ -745,7 +745,7 @@ def replay_trajectory(
|
||||
)
|
||||
transition = action_processor(transition)
|
||||
env.step(transition[TransitionKey.ACTION])
|
||||
busy_wait(1 / cfg.env.fps - (time.perf_counter() - start_time))
|
||||
precise_sleep(1 / cfg.env.fps - (time.perf_counter() - start_time))
|
||||
|
||||
|
||||
@parser.wrap()
|
||||
|
||||
@@ -71,7 +71,7 @@ from tqdm import trange
|
||||
|
||||
from lerobot.configs import parser
|
||||
from lerobot.configs.eval import EvalPipelineConfig
|
||||
from lerobot.envs.factory import make_env
|
||||
from lerobot.envs.factory import make_env, make_env_pre_post_processors
|
||||
from lerobot.envs.utils import (
|
||||
add_envs_task,
|
||||
check_env_attributes_and_types,
|
||||
@@ -94,6 +94,8 @@ from lerobot.utils.utils import (
|
||||
def rollout(
|
||||
env: gym.vector.VectorEnv,
|
||||
policy: PreTrainedPolicy,
|
||||
env_preprocessor: PolicyProcessorPipeline[dict[str, Any], dict[str, Any]],
|
||||
env_postprocessor: PolicyProcessorPipeline[dict[str, Any], dict[str, Any]],
|
||||
preprocessor: PolicyProcessorPipeline[dict[str, Any], dict[str, Any]],
|
||||
postprocessor: PolicyProcessorPipeline[PolicyAction, PolicyAction],
|
||||
seeds: list[int] | None = None,
|
||||
@@ -165,11 +167,19 @@ def rollout(
|
||||
# Infer "task" from attributes of environments.
|
||||
# TODO: works with SyncVectorEnv but not AsyncVectorEnv
|
||||
observation = add_envs_task(env, observation)
|
||||
|
||||
# Apply environment-specific preprocessing (e.g., LiberoProcessorStep for LIBERO)
|
||||
observation = env_preprocessor(observation)
|
||||
|
||||
observation = preprocessor(observation)
|
||||
with torch.inference_mode():
|
||||
action = policy.select_action(observation)
|
||||
action = postprocessor(action)
|
||||
|
||||
action_transition = {"action": action}
|
||||
action_transition = env_postprocessor(action_transition)
|
||||
action = action_transition["action"]
|
||||
|
||||
# Convert to CPU / numpy.
|
||||
action_numpy: np.ndarray = action.to("cpu").numpy()
|
||||
assert action_numpy.ndim == 2, "Action dimensions should be (batch, action_dim)"
|
||||
@@ -239,6 +249,8 @@ def rollout(
|
||||
def eval_policy(
|
||||
env: gym.vector.VectorEnv,
|
||||
policy: PreTrainedPolicy,
|
||||
env_preprocessor: PolicyProcessorPipeline[dict[str, Any], dict[str, Any]],
|
||||
env_postprocessor: PolicyProcessorPipeline[dict[str, Any], dict[str, Any]],
|
||||
preprocessor: PolicyProcessorPipeline[dict[str, Any], dict[str, Any]],
|
||||
postprocessor: PolicyProcessorPipeline[PolicyAction, PolicyAction],
|
||||
n_episodes: int,
|
||||
@@ -319,6 +331,8 @@ def eval_policy(
|
||||
rollout_data = rollout(
|
||||
env=env,
|
||||
policy=policy,
|
||||
env_preprocessor=env_preprocessor,
|
||||
env_postprocessor=env_postprocessor,
|
||||
preprocessor=preprocessor,
|
||||
postprocessor=postprocessor,
|
||||
seeds=list(seeds) if seeds else None,
|
||||
@@ -517,10 +531,16 @@ def eval_main(cfg: EvalPipelineConfig):
|
||||
pretrained_path=cfg.policy.pretrained_path,
|
||||
preprocessor_overrides=preprocessor_overrides,
|
||||
)
|
||||
|
||||
# Create environment-specific preprocessor and postprocessor (e.g., for LIBERO environments)
|
||||
env_preprocessor, env_postprocessor = make_env_pre_post_processors(env_cfg=cfg.env)
|
||||
|
||||
with torch.no_grad(), torch.autocast(device_type=device.type) if cfg.policy.use_amp else nullcontext():
|
||||
info = eval_policy_all(
|
||||
envs=envs,
|
||||
policy=policy,
|
||||
env_preprocessor=env_preprocessor,
|
||||
env_postprocessor=env_postprocessor,
|
||||
preprocessor=preprocessor,
|
||||
postprocessor=postprocessor,
|
||||
n_episodes=cfg.eval.n_episodes,
|
||||
@@ -561,6 +581,8 @@ def eval_one(
|
||||
env: gym.vector.VectorEnv,
|
||||
*,
|
||||
policy: PreTrainedPolicy,
|
||||
env_preprocessor: PolicyProcessorPipeline[dict[str, Any], dict[str, Any]],
|
||||
env_postprocessor: PolicyProcessorPipeline[dict[str, Any], dict[str, Any]],
|
||||
preprocessor: PolicyProcessorPipeline[dict[str, Any], dict[str, Any]],
|
||||
postprocessor: PolicyProcessorPipeline[PolicyAction, PolicyAction],
|
||||
n_episodes: int,
|
||||
@@ -576,6 +598,8 @@ def eval_one(
|
||||
task_result = eval_policy(
|
||||
env=env,
|
||||
policy=policy,
|
||||
env_preprocessor=env_preprocessor,
|
||||
env_postprocessor=env_postprocessor,
|
||||
preprocessor=preprocessor,
|
||||
postprocessor=postprocessor,
|
||||
n_episodes=n_episodes,
|
||||
@@ -600,6 +624,8 @@ def run_one(
|
||||
env,
|
||||
*,
|
||||
policy,
|
||||
env_preprocessor,
|
||||
env_postprocessor,
|
||||
preprocessor,
|
||||
postprocessor,
|
||||
n_episodes: int,
|
||||
@@ -622,6 +648,8 @@ def run_one(
|
||||
metrics = eval_one(
|
||||
env,
|
||||
policy=policy,
|
||||
env_preprocessor=env_preprocessor,
|
||||
env_postprocessor=env_postprocessor,
|
||||
preprocessor=preprocessor,
|
||||
postprocessor=postprocessor,
|
||||
n_episodes=n_episodes,
|
||||
@@ -639,6 +667,8 @@ def run_one(
|
||||
def eval_policy_all(
|
||||
envs: dict[str, dict[int, gym.vector.VectorEnv]],
|
||||
policy,
|
||||
env_preprocessor: PolicyProcessorPipeline[dict[str, Any], dict[str, Any]],
|
||||
env_postprocessor: PolicyProcessorPipeline[dict[str, Any], dict[str, Any]],
|
||||
preprocessor: PolicyProcessorPipeline[dict[str, Any], dict[str, Any]],
|
||||
postprocessor: PolicyProcessorPipeline[PolicyAction, PolicyAction],
|
||||
n_episodes: int,
|
||||
@@ -694,6 +724,8 @@ def eval_policy_all(
|
||||
task_runner = partial(
|
||||
run_one,
|
||||
policy=policy,
|
||||
env_preprocessor=env_preprocessor,
|
||||
env_postprocessor=env_postprocessor,
|
||||
preprocessor=preprocessor,
|
||||
postprocessor=postprocessor,
|
||||
n_episodes=n_episodes,
|
||||
|
||||
@@ -50,7 +50,7 @@ from lerobot.teleoperators import ( # noqa: F401
|
||||
make_teleoperator_from_config,
|
||||
so100_leader,
|
||||
)
|
||||
from lerobot.utils.robot_utils import busy_wait
|
||||
from lerobot.utils.robot_utils import precise_sleep
|
||||
|
||||
|
||||
@dataclass
|
||||
@@ -114,7 +114,7 @@ def find_joint_and_ee_bounds(cfg: FindJointLimitsConfig):
|
||||
print(f"Min joint pos position {np.round(min_pos, 4).tolist()}")
|
||||
break
|
||||
|
||||
busy_wait(0.01)
|
||||
precise_sleep(0.01)
|
||||
|
||||
|
||||
def main():
|
||||
|
||||
@@ -119,7 +119,7 @@ from lerobot.utils.control_utils import (
|
||||
sanity_check_dataset_robot_compatibility,
|
||||
)
|
||||
from lerobot.utils.import_utils import register_third_party_devices
|
||||
from lerobot.utils.robot_utils import busy_wait
|
||||
from lerobot.utils.robot_utils import precise_sleep
|
||||
from lerobot.utils.utils import (
|
||||
get_safe_torch_device,
|
||||
init_logging,
|
||||
@@ -364,7 +364,7 @@ def record_loop(
|
||||
log_rerun_data(observation=obs_processed, action=action_values)
|
||||
|
||||
dt_s = time.perf_counter() - start_loop_t
|
||||
busy_wait(1 / fps - dt_s)
|
||||
precise_sleep(1 / fps - dt_s)
|
||||
|
||||
timestamp = time.perf_counter() - start_episode_t
|
||||
|
||||
|
||||
@@ -62,7 +62,7 @@ from lerobot.robots import ( # noqa: F401
|
||||
)
|
||||
from lerobot.utils.constants import ACTION
|
||||
from lerobot.utils.import_utils import register_third_party_devices
|
||||
from lerobot.utils.robot_utils import busy_wait
|
||||
from lerobot.utils.robot_utils import precise_sleep
|
||||
from lerobot.utils.utils import (
|
||||
init_logging,
|
||||
log_say,
|
||||
@@ -121,7 +121,7 @@ def replay(cfg: ReplayConfig):
|
||||
_ = robot.send_action(processed_action)
|
||||
|
||||
dt_s = time.perf_counter() - start_episode_t
|
||||
busy_wait(1 / dataset.fps - dt_s)
|
||||
precise_sleep(1 / dataset.fps - dt_s)
|
||||
|
||||
robot.disconnect()
|
||||
|
||||
|
||||
@@ -89,7 +89,7 @@ from lerobot.teleoperators import ( # noqa: F401
|
||||
so101_leader,
|
||||
)
|
||||
from lerobot.utils.import_utils import register_third_party_devices
|
||||
from lerobot.utils.robot_utils import busy_wait
|
||||
from lerobot.utils.robot_utils import precise_sleep
|
||||
from lerobot.utils.utils import init_logging, move_cursor_up
|
||||
from lerobot.utils.visualization_utils import init_rerun, log_rerun_data
|
||||
|
||||
@@ -170,12 +170,13 @@ def teleop_loop(
|
||||
# Display the final robot action that was sent
|
||||
for motor, value in robot_action_to_send.items():
|
||||
print(f"{motor:<{display_len}} | {value:>7.2f}")
|
||||
move_cursor_up(len(robot_action_to_send) + 5)
|
||||
move_cursor_up(len(robot_action_to_send) + 3)
|
||||
|
||||
dt_s = time.perf_counter() - loop_start
|
||||
busy_wait(1 / fps - dt_s)
|
||||
precise_sleep(1 / fps - dt_s)
|
||||
loop_s = time.perf_counter() - loop_start
|
||||
print(f"\ntime: {loop_s * 1e3:.2f}ms ({1 / loop_s:.0f} Hz)")
|
||||
print(f"Teleop loop time: {loop_s * 1e3:.2f}ms ({1 / loop_s:.0f} Hz)")
|
||||
move_cursor_up(1)
|
||||
|
||||
if duration is not None and time.perf_counter() - start >= duration:
|
||||
return
|
||||
|
||||
@@ -29,7 +29,7 @@ from lerobot.configs.train import TrainPipelineConfig
|
||||
from lerobot.datasets.factory import make_dataset
|
||||
from lerobot.datasets.sampler import EpisodeAwareSampler
|
||||
from lerobot.datasets.utils import cycle
|
||||
from lerobot.envs.factory import make_env
|
||||
from lerobot.envs.factory import make_env, make_env_pre_post_processors
|
||||
from lerobot.envs.utils import close_envs
|
||||
from lerobot.optim.factory import make_optimizer_and_scheduler
|
||||
from lerobot.policies.factory import make_policy, make_pre_post_processors
|
||||
@@ -259,6 +259,8 @@ def train(cfg: TrainPipelineConfig, accelerator: Accelerator | None = None):
|
||||
logging.info(colored("Output dir:", "yellow", attrs=["bold"]) + f" {cfg.output_dir}")
|
||||
if cfg.env is not None:
|
||||
logging.info(f"{cfg.env.task=}")
|
||||
logging.info("Creating environment processors")
|
||||
env_preprocessor, env_postprocessor = make_env_pre_post_processors(env_cfg=cfg.env)
|
||||
logging.info(f"{cfg.steps=} ({format_big_number(cfg.steps)})")
|
||||
logging.info(f"{dataset.num_frames=} ({format_big_number(dataset.num_frames)})")
|
||||
logging.info(f"{dataset.num_episodes=}")
|
||||
@@ -274,6 +276,7 @@ def train(cfg: TrainPipelineConfig, accelerator: Accelerator | None = None):
|
||||
sampler = EpisodeAwareSampler(
|
||||
dataset.meta.episodes["dataset_from_index"],
|
||||
dataset.meta.episodes["dataset_to_index"],
|
||||
episode_indices_to_use=dataset.episodes,
|
||||
drop_n_last_frames=cfg.policy.drop_n_last_frames,
|
||||
shuffle=True,
|
||||
)
|
||||
@@ -384,6 +387,8 @@ def train(cfg: TrainPipelineConfig, accelerator: Accelerator | None = None):
|
||||
eval_info = eval_policy_all(
|
||||
envs=eval_env, # dict[suite][task_id] -> vec_env
|
||||
policy=accelerator.unwrap_model(policy),
|
||||
env_preprocessor=env_preprocessor,
|
||||
env_postprocessor=env_postprocessor,
|
||||
preprocessor=preprocessor,
|
||||
postprocessor=postprocessor,
|
||||
n_episodes=cfg.eval.n_episodes,
|
||||
|
||||
@@ -70,3 +70,15 @@ LOOKAHEAD_BACKTRACKTABLE = 100
|
||||
|
||||
# openpi
|
||||
OPENPI_ATTENTION_MASK_VALUE = -2.3819763e38 # TODO(pepijn): Modify this when extending support to fp8 models
|
||||
|
||||
# Constants for LIBERO observation keys
|
||||
LIBERO_KEY_EEF_POS = "robot_state/eef/pos"
|
||||
LIBERO_KEY_EEF_QUAT = "robot_state/eef/quat"
|
||||
LIBERO_KEY_EEF_MAT = "robot_state/eef/mat"
|
||||
LIBERO_KEY_EEF_AXISANGLE = "robot_state/eef/axisangle"
|
||||
LIBERO_KEY_GRIPPER_QPOS = "robot_state/gripper/qpos"
|
||||
LIBERO_KEY_GRIPPER_QVEL = "robot_state/gripper/qvel"
|
||||
LIBERO_KEY_JOINTS_POS = "robot_state/joints/pos"
|
||||
LIBERO_KEY_JOINTS_VEL = "robot_state/joints/vel"
|
||||
LIBERO_KEY_PIXELS_AGENTVIEW = "pixels/agentview_image"
|
||||
LIBERO_KEY_PIXELS_EYE_IN_HAND = "pixels/robot0_eye_in_hand_image"
|
||||
|
||||
@@ -16,14 +16,40 @@ import platform
|
||||
import time
|
||||
|
||||
|
||||
def busy_wait(seconds):
|
||||
if platform.system() == "Darwin" or platform.system() == "Windows":
|
||||
# On Mac and Windows, `time.sleep` is not accurate and we need to use this while loop trick,
|
||||
# but it consumes CPU cycles.
|
||||
def precise_sleep(seconds: float, spin_threshold: float = 0.010, sleep_margin: float = 0.003):
|
||||
"""
|
||||
Wait for `seconds` with better precision than time.sleep alone at the expense of more CPU usage.
|
||||
|
||||
Parameters:
|
||||
- seconds: duration to wait
|
||||
- spin_threshold: if remaining <= spin_threshold -> spin; otherwise sleep (seconds). Default 10ms
|
||||
- sleep_margin: when sleeping leave this much time before deadline to avoid oversleep. Default 3ms
|
||||
|
||||
Note:
|
||||
The default parameters are chosen to prioritize timing accuracy over CPU usage for the common 30 FPS use case.
|
||||
"""
|
||||
if seconds <= 0:
|
||||
return
|
||||
|
||||
system = platform.system()
|
||||
# On macOS and Windows the scheduler / sleep granularity can make
|
||||
# short sleeps inaccurate. Instead of burning CPU for the whole
|
||||
# duration, sleep for most of the time and spin for the final few
|
||||
# milliseconds to achieve good accuracy with much lower CPU usage.
|
||||
if system in ("Darwin", "Windows"):
|
||||
end_time = time.perf_counter() + seconds
|
||||
while time.perf_counter() < end_time:
|
||||
pass
|
||||
while True:
|
||||
remaining = end_time - time.perf_counter()
|
||||
if remaining <= 0:
|
||||
break
|
||||
# If there's more than a couple milliseconds left, sleep most
|
||||
# of the remaining time and leave a small margin for the final spin.
|
||||
if remaining > spin_threshold:
|
||||
# Sleep but avoid sleeping past the end by leaving a small margin.
|
||||
time.sleep(max(remaining - sleep_margin, 0))
|
||||
else:
|
||||
# Final short spin to hit precise timing without long sleeps.
|
||||
pass
|
||||
else:
|
||||
# On Linux time.sleep is accurate
|
||||
if seconds > 0:
|
||||
time.sleep(seconds)
|
||||
# On Linux time.sleep is accurate enough for most uses
|
||||
time.sleep(seconds)
|
||||
|
||||
@@ -0,0 +1,336 @@
|
||||
#!/usr/bin/env python
|
||||
|
||||
# Copyright 2025 The HuggingFace Inc. team. All rights reserved.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
"""Test PI0.5 policy with Real-Time Chunking (RTC) enabled during inference."""
|
||||
|
||||
import os
|
||||
|
||||
import pytest
|
||||
import torch
|
||||
|
||||
# Skip this entire module in CI
|
||||
pytestmark = pytest.mark.skipif(
|
||||
os.environ.get("CI") == "true" or os.environ.get("GITHUB_ACTIONS") == "true",
|
||||
reason="This test requires local OpenPI installation and is not meant for CI",
|
||||
)
|
||||
|
||||
from lerobot.configs.types import FeatureType, PolicyFeature, RTCAttentionSchedule # noqa: E402
|
||||
from lerobot.policies.pi05 import PI05Config, PI05Policy, make_pi05_pre_post_processors # noqa: E402
|
||||
from lerobot.policies.rtc.configuration_rtc import RTCConfig # noqa: E402
|
||||
from lerobot.utils.random_utils import set_seed # noqa: E402
|
||||
from tests.utils import require_cuda # noqa: E402
|
||||
|
||||
|
||||
@require_cuda
|
||||
def test_pi05_rtc_initialization():
|
||||
"""Test PI0.5 policy can initialize RTC processor."""
|
||||
set_seed(42)
|
||||
|
||||
config = PI05Config(max_action_dim=7, max_state_dim=14, dtype="float32")
|
||||
|
||||
# Add RTC config
|
||||
config.rtc_config = RTCConfig(
|
||||
enabled=True,
|
||||
execution_horizon=10,
|
||||
max_guidance_weight=5.0,
|
||||
prefix_attention_schedule=RTCAttentionSchedule.EXP,
|
||||
debug=False,
|
||||
)
|
||||
|
||||
config.input_features = {
|
||||
"observation.state": PolicyFeature(type=FeatureType.STATE, shape=(14,)),
|
||||
"observation.images.base_0_rgb": PolicyFeature(type=FeatureType.VISUAL, shape=(3, 224, 224)),
|
||||
}
|
||||
config.output_features = {
|
||||
"action": PolicyFeature(type=FeatureType.ACTION, shape=(7,)),
|
||||
}
|
||||
|
||||
# Instantiate policy
|
||||
policy = PI05Policy(config)
|
||||
|
||||
# Verify RTC processor is initialized
|
||||
assert hasattr(policy, "rtc_processor")
|
||||
assert policy.rtc_processor is not None
|
||||
assert policy.rtc_processor.rtc_config.enabled is True
|
||||
|
||||
print("✓ PI0.5 RTC initialization: Test passed")
|
||||
|
||||
|
||||
@require_cuda
|
||||
def test_pi05_rtc_initialization_without_rtc_config():
|
||||
"""Test PI0.5 policy can initialize without RTC config."""
|
||||
set_seed(42)
|
||||
|
||||
config = PI05Config(max_action_dim=7, max_state_dim=14, dtype="float32")
|
||||
|
||||
# Instantiate policy
|
||||
policy = PI05Policy(config)
|
||||
|
||||
# Verify RTC processor is not initialized
|
||||
assert hasattr(policy, "rtc_processor")
|
||||
assert policy.rtc_processor is None
|
||||
assert policy.model.rtc_processor is None
|
||||
assert policy._rtc_enabled() is False
|
||||
|
||||
print("✓ PI0.5 RTC initialization without RTC config: Test passed")
|
||||
|
||||
|
||||
@require_cuda
|
||||
def test_pi05_rtc_inference_with_prev_chunk():
|
||||
"""Test PI0.5 policy inference with RTC and previous chunk."""
|
||||
set_seed(42)
|
||||
|
||||
config = PI05Config(max_action_dim=7, max_state_dim=14, chunk_size=50, dtype="float32")
|
||||
|
||||
# Add RTC config
|
||||
config.rtc_config = RTCConfig(
|
||||
enabled=True,
|
||||
execution_horizon=10,
|
||||
max_guidance_weight=5.0,
|
||||
prefix_attention_schedule=RTCAttentionSchedule.EXP,
|
||||
debug=False,
|
||||
)
|
||||
|
||||
config.input_features = {
|
||||
"observation.state": PolicyFeature(type=FeatureType.STATE, shape=(14,)),
|
||||
"observation.images.base_0_rgb": PolicyFeature(type=FeatureType.VISUAL, shape=(3, 224, 224)),
|
||||
}
|
||||
config.output_features = {
|
||||
"action": PolicyFeature(type=FeatureType.ACTION, shape=(7,)),
|
||||
}
|
||||
|
||||
# Create dataset stats (PI0.5 uses QUANTILES normalization)
|
||||
dataset_stats = {
|
||||
"observation.state": {
|
||||
"mean": torch.zeros(14),
|
||||
"std": torch.ones(14),
|
||||
"q01": -torch.ones(14),
|
||||
"q99": torch.ones(14),
|
||||
},
|
||||
"action": {
|
||||
"mean": torch.zeros(7),
|
||||
"std": torch.ones(7),
|
||||
"q01": -torch.ones(7),
|
||||
"q99": torch.ones(7),
|
||||
},
|
||||
"observation.images.base_0_rgb": {"mean": torch.zeros(3, 224, 224), "std": torch.ones(3, 224, 224)},
|
||||
}
|
||||
|
||||
# Instantiate policy and preprocessor
|
||||
policy = PI05Policy(config)
|
||||
policy.eval()
|
||||
preprocessor, _ = make_pi05_pre_post_processors(config=config, dataset_stats=dataset_stats)
|
||||
|
||||
device = config.device
|
||||
|
||||
# Create dummy batch
|
||||
batch = {
|
||||
"observation.state": torch.randn(1, 14, dtype=torch.float32, device=device),
|
||||
"observation.images.base_0_rgb": torch.rand(1, 3, 224, 224, dtype=torch.float32, device=device),
|
||||
"task": ["Pick up the object"],
|
||||
}
|
||||
batch = preprocessor(batch)
|
||||
|
||||
# Create previous chunk
|
||||
prev_chunk = torch.randn(1, 25, 7, dtype=torch.float32, device=device)
|
||||
|
||||
with torch.no_grad():
|
||||
# Use same noise for fair comparison
|
||||
noise = policy.model.sample_noise((1, config.chunk_size, 7), device)
|
||||
|
||||
# Test with RTC and previous chunk
|
||||
actions_with_rtc = policy.predict_action_chunk(
|
||||
batch,
|
||||
noise=noise.clone(),
|
||||
prev_chunk_left_over=prev_chunk,
|
||||
inference_delay=4,
|
||||
execution_horizon=10,
|
||||
)
|
||||
|
||||
# Test without RTC for comparison
|
||||
policy.config.rtc_config.enabled = False
|
||||
actions_without_rtc = policy.predict_action_chunk(batch, noise=noise.clone())
|
||||
policy.config.rtc_config.enabled = True
|
||||
|
||||
# Verify shapes
|
||||
assert actions_with_rtc.shape == (1, config.chunk_size, 7)
|
||||
assert actions_without_rtc.shape == (1, config.chunk_size, 7)
|
||||
|
||||
# With previous chunk, actions should be different (RTC guidance applied)
|
||||
assert not torch.allclose(actions_with_rtc, actions_without_rtc, rtol=1e-3)
|
||||
|
||||
print("✓ PI0.5 RTC inference with prev_chunk: Test passed")
|
||||
|
||||
|
||||
@require_cuda
|
||||
def test_pi05_rtc_inference_without_prev_chunk():
|
||||
"""Test PI0.5 policy inference with RTC but no previous chunk (RTC should have no effect)."""
|
||||
set_seed(42)
|
||||
|
||||
config = PI05Config(max_action_dim=7, max_state_dim=14, chunk_size=50, dtype="float32")
|
||||
|
||||
# Add RTC config
|
||||
config.rtc_config = RTCConfig(
|
||||
enabled=True,
|
||||
execution_horizon=10,
|
||||
max_guidance_weight=5.0,
|
||||
prefix_attention_schedule=RTCAttentionSchedule.EXP,
|
||||
debug=False,
|
||||
)
|
||||
|
||||
config.input_features = {
|
||||
"observation.state": PolicyFeature(type=FeatureType.STATE, shape=(14,)),
|
||||
"observation.images.base_0_rgb": PolicyFeature(type=FeatureType.VISUAL, shape=(3, 224, 224)),
|
||||
}
|
||||
config.output_features = {
|
||||
"action": PolicyFeature(type=FeatureType.ACTION, shape=(7,)),
|
||||
}
|
||||
|
||||
# Create dataset stats (PI0.5 uses QUANTILES normalization)
|
||||
dataset_stats = {
|
||||
"observation.state": {
|
||||
"mean": torch.zeros(14),
|
||||
"std": torch.ones(14),
|
||||
"q01": -torch.ones(14),
|
||||
"q99": torch.ones(14),
|
||||
},
|
||||
"action": {
|
||||
"mean": torch.zeros(7),
|
||||
"std": torch.ones(7),
|
||||
"q01": -torch.ones(7),
|
||||
"q99": torch.ones(7),
|
||||
},
|
||||
"observation.images.base_0_rgb": {"mean": torch.zeros(3, 224, 224), "std": torch.ones(3, 224, 224)},
|
||||
}
|
||||
|
||||
# Instantiate policy and preprocessor
|
||||
policy = PI05Policy(config)
|
||||
policy.eval()
|
||||
preprocessor, _ = make_pi05_pre_post_processors(config=config, dataset_stats=dataset_stats)
|
||||
|
||||
device = config.device
|
||||
|
||||
# Create dummy batch
|
||||
batch = {
|
||||
"observation.state": torch.randn(1, 14, dtype=torch.float32, device=device),
|
||||
"observation.images.base_0_rgb": torch.rand(1, 3, 224, 224, dtype=torch.float32, device=device),
|
||||
"task": ["Pick up the object"],
|
||||
}
|
||||
batch = preprocessor(batch)
|
||||
|
||||
with torch.no_grad():
|
||||
# Use same noise for fair comparison
|
||||
noise = policy.model.sample_noise((1, config.chunk_size, 7), device)
|
||||
|
||||
# Test with RTC enabled but no previous chunk
|
||||
actions_with_rtc_no_prev = policy.predict_action_chunk(
|
||||
batch,
|
||||
noise=noise.clone(),
|
||||
prev_chunk_left_over=None,
|
||||
)
|
||||
|
||||
# Test without RTC
|
||||
policy.config.rtc_config.enabled = False
|
||||
actions_without_rtc = policy.predict_action_chunk(batch, noise=noise.clone())
|
||||
policy.config.rtc_config.enabled = True
|
||||
|
||||
# Without previous chunk, RTC should have no effect
|
||||
assert torch.allclose(actions_with_rtc_no_prev, actions_without_rtc, rtol=1e-5)
|
||||
|
||||
print("✓ PI0.5 RTC inference without prev_chunk: Test passed")
|
||||
|
||||
|
||||
@require_cuda
|
||||
def test_pi05_rtc_validation_rules():
|
||||
"""Test PI0.5 policy with RTC follows all three validation rules."""
|
||||
set_seed(42)
|
||||
|
||||
config = PI05Config(max_action_dim=7, max_state_dim=14, chunk_size=50, dtype="float32")
|
||||
|
||||
# Add RTC config
|
||||
config.rtc_config = RTCConfig(
|
||||
enabled=True,
|
||||
execution_horizon=10,
|
||||
max_guidance_weight=5.0,
|
||||
prefix_attention_schedule=RTCAttentionSchedule.EXP,
|
||||
debug=False,
|
||||
)
|
||||
|
||||
config.input_features = {
|
||||
"observation.state": PolicyFeature(type=FeatureType.STATE, shape=(14,)),
|
||||
"observation.images.base_0_rgb": PolicyFeature(type=FeatureType.VISUAL, shape=(3, 224, 224)),
|
||||
}
|
||||
config.output_features = {
|
||||
"action": PolicyFeature(type=FeatureType.ACTION, shape=(7,)),
|
||||
}
|
||||
|
||||
# Create dataset stats (PI0.5 uses QUANTILES normalization)
|
||||
dataset_stats = {
|
||||
"observation.state": {
|
||||
"mean": torch.zeros(14),
|
||||
"std": torch.ones(14),
|
||||
"q01": -torch.ones(14),
|
||||
"q99": torch.ones(14),
|
||||
},
|
||||
"action": {
|
||||
"mean": torch.zeros(7),
|
||||
"std": torch.ones(7),
|
||||
"q01": -torch.ones(7),
|
||||
"q99": torch.ones(7),
|
||||
},
|
||||
"observation.images.base_0_rgb": {"mean": torch.zeros(3, 224, 224), "std": torch.ones(3, 224, 224)},
|
||||
}
|
||||
|
||||
# Instantiate policy and preprocessor
|
||||
policy = PI05Policy(config)
|
||||
policy.eval()
|
||||
preprocessor, _ = make_pi05_pre_post_processors(config=config, dataset_stats=dataset_stats)
|
||||
|
||||
device = config.device
|
||||
|
||||
# Create dummy batch
|
||||
batch = {
|
||||
"observation.state": torch.randn(1, 14, dtype=torch.float32, device=device),
|
||||
"observation.images.base_0_rgb": torch.rand(1, 3, 224, 224, dtype=torch.float32, device=device),
|
||||
"task": ["Pick up the object"],
|
||||
}
|
||||
batch = preprocessor(batch)
|
||||
|
||||
# Create previous chunk
|
||||
prev_chunk = torch.randn(1, 25, 7, dtype=torch.float32, device=device)
|
||||
|
||||
inference_delay = 4
|
||||
execution_horizon = 10
|
||||
|
||||
with torch.no_grad():
|
||||
# Use same noise for fair comparison
|
||||
noise = policy.model.sample_noise((1, config.chunk_size, 7), device)
|
||||
|
||||
# Test with RTC
|
||||
actions_with_rtc = policy.predict_action_chunk(
|
||||
batch,
|
||||
noise=noise.clone(),
|
||||
prev_chunk_left_over=prev_chunk,
|
||||
inference_delay=inference_delay,
|
||||
execution_horizon=execution_horizon,
|
||||
)
|
||||
|
||||
# Test without RTC
|
||||
policy.config.rtc_config.enabled = False
|
||||
actions_without_rtc = policy.predict_action_chunk(batch, noise=noise.clone())
|
||||
policy.config.rtc_config.enabled = True
|
||||
|
||||
assert not torch.allclose(actions_with_rtc, actions_without_rtc, rtol=1e-3)
|
||||
@@ -0,0 +1,378 @@
|
||||
#!/usr/bin/env python
|
||||
|
||||
# Copyright 2025 The HuggingFace Inc. team. All rights reserved.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
"""Test PI0 policy with Real-Time Chunking (RTC) enabled during inference."""
|
||||
|
||||
import os
|
||||
|
||||
import pytest
|
||||
import torch
|
||||
|
||||
# Skip this entire module in CI
|
||||
pytestmark = pytest.mark.skipif(
|
||||
os.environ.get("CI") == "true" or os.environ.get("GITHUB_ACTIONS") == "true",
|
||||
reason="This test requires local OpenPI installation and is not meant for CI",
|
||||
)
|
||||
|
||||
from lerobot.configs.types import FeatureType, PolicyFeature, RTCAttentionSchedule # noqa: E402
|
||||
from lerobot.policies.pi0 import PI0Config, PI0Policy, make_pi0_pre_post_processors # noqa: E402
|
||||
from lerobot.policies.rtc.configuration_rtc import RTCConfig # noqa: E402
|
||||
from lerobot.utils.random_utils import set_seed # noqa: E402
|
||||
from tests.utils import require_cuda # noqa: E402
|
||||
|
||||
|
||||
@require_cuda
|
||||
def test_pi0_rtc_initialization():
|
||||
"""Test PI0 policy can initialize RTC processor."""
|
||||
set_seed(42)
|
||||
|
||||
config = PI0Config(max_action_dim=7, max_state_dim=14, dtype="float32")
|
||||
|
||||
# Add RTC config
|
||||
config.rtc_config = RTCConfig(
|
||||
enabled=True,
|
||||
execution_horizon=10,
|
||||
max_guidance_weight=5.0,
|
||||
prefix_attention_schedule=RTCAttentionSchedule.EXP,
|
||||
debug=False,
|
||||
)
|
||||
|
||||
config.input_features = {
|
||||
"observation.state": PolicyFeature(type=FeatureType.STATE, shape=(14,)),
|
||||
"observation.images.base_0_rgb": PolicyFeature(type=FeatureType.VISUAL, shape=(3, 224, 224)),
|
||||
}
|
||||
config.output_features = {
|
||||
"action": PolicyFeature(type=FeatureType.ACTION, shape=(7,)),
|
||||
}
|
||||
|
||||
# Instantiate policy
|
||||
policy = PI0Policy(config)
|
||||
|
||||
# Verify RTC processor is initialized
|
||||
assert hasattr(policy, "rtc_processor")
|
||||
assert policy.rtc_processor is not None
|
||||
assert policy.rtc_processor.rtc_config.enabled is True
|
||||
|
||||
print("✓ PI0 RTC initialization: Test passed")
|
||||
|
||||
|
||||
@require_cuda
|
||||
def test_pi0_rtc_initialization_without_rtc_config():
|
||||
"""Test PI0 policy can initialize without RTC config."""
|
||||
set_seed(42)
|
||||
|
||||
config = PI0Config(max_action_dim=7, max_state_dim=14, dtype="float32")
|
||||
|
||||
# Instantiate policy
|
||||
policy = PI0Policy(config)
|
||||
|
||||
# Verify RTC processor is not initialized
|
||||
assert hasattr(policy, "rtc_processor")
|
||||
assert policy.rtc_processor is None
|
||||
assert policy.model.rtc_processor is None
|
||||
assert policy._rtc_enabled() is False
|
||||
|
||||
print("✓ PI0 RTC initialization without RTC config: Test passed")
|
||||
|
||||
|
||||
def test_pi0_rtc_inference_with_prev_chunk():
|
||||
"""Test PI0 policy inference with RTC and previous chunk."""
|
||||
set_seed(42)
|
||||
|
||||
config = PI0Config(max_action_dim=7, max_state_dim=14, chunk_size=50, dtype="float32")
|
||||
|
||||
# Add RTC config
|
||||
config.rtc_config = RTCConfig(
|
||||
enabled=True,
|
||||
execution_horizon=10,
|
||||
max_guidance_weight=5.0,
|
||||
prefix_attention_schedule=RTCAttentionSchedule.EXP,
|
||||
debug=False,
|
||||
)
|
||||
|
||||
config.input_features = {
|
||||
"observation.state": PolicyFeature(type=FeatureType.STATE, shape=(14,)),
|
||||
"observation.images.base_0_rgb": PolicyFeature(type=FeatureType.VISUAL, shape=(3, 224, 224)),
|
||||
}
|
||||
config.output_features = {
|
||||
"action": PolicyFeature(type=FeatureType.ACTION, shape=(7,)),
|
||||
}
|
||||
|
||||
# Create dataset stats
|
||||
dataset_stats = {
|
||||
"observation.state": {"mean": torch.zeros(14), "std": torch.ones(14)},
|
||||
"action": {"mean": torch.zeros(7), "std": torch.ones(7)},
|
||||
"observation.images.base_0_rgb": {"mean": torch.zeros(3, 224, 224), "std": torch.ones(3, 224, 224)},
|
||||
}
|
||||
|
||||
# Instantiate policy and preprocessor
|
||||
policy = PI0Policy(config)
|
||||
policy.eval()
|
||||
preprocessor, _ = make_pi0_pre_post_processors(config=config, dataset_stats=dataset_stats)
|
||||
|
||||
device = config.device
|
||||
|
||||
# Create dummy batch
|
||||
batch = {
|
||||
"observation.state": torch.randn(1, 14, dtype=torch.float32, device=device),
|
||||
"observation.images.base_0_rgb": torch.rand(1, 3, 224, 224, dtype=torch.float32, device=device),
|
||||
"task": ["Pick up the object"],
|
||||
}
|
||||
batch = preprocessor(batch)
|
||||
|
||||
# Create previous chunk
|
||||
prev_chunk = torch.randn(1, 25, 7, dtype=torch.float32, device=device)
|
||||
|
||||
with torch.no_grad():
|
||||
# Use same noise for fair comparison
|
||||
noise = policy.model.sample_noise((1, config.chunk_size, 7), device)
|
||||
|
||||
# Test with RTC and previous chunk
|
||||
actions_with_rtc = policy.predict_action_chunk(
|
||||
batch,
|
||||
noise=noise.clone(),
|
||||
prev_chunk_left_over=prev_chunk,
|
||||
inference_delay=4,
|
||||
execution_horizon=10,
|
||||
)
|
||||
|
||||
# Test without RTC for comparison
|
||||
policy.config.rtc_config.enabled = False
|
||||
actions_without_rtc = policy.predict_action_chunk(batch, noise=noise.clone())
|
||||
policy.config.rtc_config.enabled = True
|
||||
|
||||
# Verify shapes
|
||||
assert actions_with_rtc.shape == (1, config.chunk_size, 7)
|
||||
assert actions_without_rtc.shape == (1, config.chunk_size, 7)
|
||||
|
||||
# With previous chunk, actions should be different (RTC guidance applied)
|
||||
assert not torch.allclose(actions_with_rtc, actions_without_rtc, rtol=1e-3)
|
||||
|
||||
print("✓ PI0 RTC inference with prev_chunk: Test passed")
|
||||
|
||||
|
||||
@require_cuda
|
||||
def test_pi0_rtc_inference_without_prev_chunk():
|
||||
"""Test PI0 policy inference with RTC but no previous chunk (RTC should have no effect)."""
|
||||
set_seed(42)
|
||||
|
||||
config = PI0Config(max_action_dim=7, max_state_dim=14, chunk_size=50, dtype="float32")
|
||||
|
||||
# Add RTC config
|
||||
config.rtc_config = RTCConfig(
|
||||
enabled=True,
|
||||
execution_horizon=10,
|
||||
max_guidance_weight=5.0,
|
||||
prefix_attention_schedule=RTCAttentionSchedule.EXP,
|
||||
debug=False,
|
||||
)
|
||||
|
||||
config.input_features = {
|
||||
"observation.state": PolicyFeature(type=FeatureType.STATE, shape=(14,)),
|
||||
"observation.images.base_0_rgb": PolicyFeature(type=FeatureType.VISUAL, shape=(3, 224, 224)),
|
||||
}
|
||||
config.output_features = {
|
||||
"action": PolicyFeature(type=FeatureType.ACTION, shape=(7,)),
|
||||
}
|
||||
|
||||
# Create dataset stats
|
||||
dataset_stats = {
|
||||
"observation.state": {"mean": torch.zeros(14), "std": torch.ones(14)},
|
||||
"action": {"mean": torch.zeros(7), "std": torch.ones(7)},
|
||||
"observation.images.base_0_rgb": {"mean": torch.zeros(3, 224, 224), "std": torch.ones(3, 224, 224)},
|
||||
}
|
||||
|
||||
# Instantiate policy and preprocessor
|
||||
policy = PI0Policy(config)
|
||||
policy.eval()
|
||||
preprocessor, _ = make_pi0_pre_post_processors(config=config, dataset_stats=dataset_stats)
|
||||
|
||||
device = config.device
|
||||
|
||||
# Create dummy batch
|
||||
batch = {
|
||||
"observation.state": torch.randn(1, 14, dtype=torch.float32, device=device),
|
||||
"observation.images.base_0_rgb": torch.rand(1, 3, 224, 224, dtype=torch.float32, device=device),
|
||||
"task": ["Pick up the object"],
|
||||
}
|
||||
batch = preprocessor(batch)
|
||||
|
||||
with torch.no_grad():
|
||||
# Use same noise for fair comparison
|
||||
noise = policy.model.sample_noise((1, config.chunk_size, 7), device)
|
||||
|
||||
# Test with RTC enabled but no previous chunk
|
||||
actions_with_rtc_no_prev = policy.predict_action_chunk(
|
||||
batch,
|
||||
noise=noise.clone(),
|
||||
prev_chunk_left_over=None,
|
||||
)
|
||||
|
||||
# Test without RTC
|
||||
policy.config.rtc_config.enabled = False
|
||||
actions_without_rtc = policy.predict_action_chunk(batch, noise=noise.clone())
|
||||
policy.config.rtc_config.enabled = True
|
||||
|
||||
# Without previous chunk, RTC should have no effect
|
||||
assert torch.allclose(actions_with_rtc_no_prev, actions_without_rtc, rtol=1e-5)
|
||||
|
||||
print("✓ PI0 RTC inference without prev_chunk: Test passed")
|
||||
|
||||
|
||||
@require_cuda
|
||||
def test_pi0_rtc_validation_rules():
|
||||
"""Test PI0 policy with RTC follows all three validation rules."""
|
||||
set_seed(42)
|
||||
|
||||
config = PI0Config(max_action_dim=7, max_state_dim=14, chunk_size=50, dtype="float32")
|
||||
|
||||
# Add RTC config
|
||||
config.rtc_config = RTCConfig(
|
||||
enabled=True,
|
||||
execution_horizon=10,
|
||||
max_guidance_weight=5.0,
|
||||
prefix_attention_schedule=RTCAttentionSchedule.EXP,
|
||||
debug=False,
|
||||
)
|
||||
|
||||
config.input_features = {
|
||||
"observation.state": PolicyFeature(type=FeatureType.STATE, shape=(14,)),
|
||||
"observation.images.base_0_rgb": PolicyFeature(type=FeatureType.VISUAL, shape=(3, 224, 224)),
|
||||
}
|
||||
config.output_features = {
|
||||
"action": PolicyFeature(type=FeatureType.ACTION, shape=(7,)),
|
||||
}
|
||||
|
||||
# Create dataset stats
|
||||
dataset_stats = {
|
||||
"observation.state": {"mean": torch.zeros(14), "std": torch.ones(14)},
|
||||
"action": {"mean": torch.zeros(7), "std": torch.ones(7)},
|
||||
"observation.images.base_0_rgb": {"mean": torch.zeros(3, 224, 224), "std": torch.ones(3, 224, 224)},
|
||||
}
|
||||
|
||||
# Instantiate policy and preprocessor
|
||||
policy = PI0Policy(config)
|
||||
policy.eval()
|
||||
preprocessor, _ = make_pi0_pre_post_processors(config=config, dataset_stats=dataset_stats)
|
||||
|
||||
device = config.device
|
||||
|
||||
# Create dummy batch
|
||||
batch = {
|
||||
"observation.state": torch.randn(1, 14, dtype=torch.float32, device=device),
|
||||
"observation.images.base_0_rgb": torch.rand(1, 3, 224, 224, dtype=torch.float32, device=device),
|
||||
"task": ["Pick up the object"],
|
||||
}
|
||||
batch = preprocessor(batch)
|
||||
|
||||
# Create previous chunk
|
||||
prev_chunk = torch.randn(1, 25, 7, dtype=torch.float32, device=device)
|
||||
|
||||
inference_delay = 4
|
||||
execution_horizon = 10
|
||||
|
||||
with torch.no_grad():
|
||||
# Use same noise for fair comparison
|
||||
noise = policy.model.sample_noise((1, config.chunk_size, 7), device)
|
||||
|
||||
# Test with RTC
|
||||
actions_with_rtc = policy.predict_action_chunk(
|
||||
batch,
|
||||
noise=noise.clone(),
|
||||
prev_chunk_left_over=prev_chunk,
|
||||
inference_delay=inference_delay,
|
||||
execution_horizon=execution_horizon,
|
||||
)
|
||||
|
||||
# Test without RTC
|
||||
policy.config.rtc_config.enabled = False
|
||||
actions_without_rtc = policy.predict_action_chunk(batch, noise=noise.clone())
|
||||
policy.config.rtc_config.enabled = True
|
||||
|
||||
assert not torch.allclose(actions_with_rtc, actions_without_rtc, rtol=1e-3)
|
||||
|
||||
"""Test PI0 with different RTC attention schedules."""
|
||||
set_seed(42)
|
||||
|
||||
schedules = [
|
||||
RTCAttentionSchedule.ZEROS,
|
||||
RTCAttentionSchedule.ONES,
|
||||
RTCAttentionSchedule.LINEAR,
|
||||
RTCAttentionSchedule.EXP,
|
||||
]
|
||||
|
||||
config = PI0Config(max_action_dim=7, max_state_dim=14, chunk_size=50, dtype="float32")
|
||||
|
||||
config.input_features = {
|
||||
"observation.state": PolicyFeature(type=FeatureType.STATE, shape=(14,)),
|
||||
"observation.images.base_0_rgb": PolicyFeature(type=FeatureType.VISUAL, shape=(3, 224, 224)),
|
||||
}
|
||||
config.output_features = {
|
||||
"action": PolicyFeature(type=FeatureType.ACTION, shape=(7,)),
|
||||
}
|
||||
|
||||
# Create dataset stats
|
||||
dataset_stats = {
|
||||
"observation.state": {"mean": torch.zeros(14), "std": torch.ones(14)},
|
||||
"action": {"mean": torch.zeros(7), "std": torch.ones(7)},
|
||||
"observation.images.base_0_rgb": {"mean": torch.zeros(3, 224, 224), "std": torch.ones(3, 224, 224)},
|
||||
}
|
||||
|
||||
device = config.device
|
||||
|
||||
for schedule in schedules:
|
||||
print(f"Testing schedule: {schedule}")
|
||||
|
||||
# Add RTC config with specific schedule
|
||||
config.rtc_config = RTCConfig(
|
||||
enabled=True,
|
||||
execution_horizon=10,
|
||||
max_guidance_weight=5.0,
|
||||
prefix_attention_schedule=schedule,
|
||||
debug=False,
|
||||
)
|
||||
|
||||
# Instantiate policy
|
||||
policy = PI0Policy(config)
|
||||
policy.eval()
|
||||
preprocessor, _ = make_pi0_pre_post_processors(config=config, dataset_stats=dataset_stats)
|
||||
|
||||
# Create dummy batch
|
||||
batch = {
|
||||
"observation.state": torch.randn(1, 14, dtype=torch.float32, device=device),
|
||||
"observation.images.base_0_rgb": torch.rand(1, 3, 224, 224, dtype=torch.float32, device=device),
|
||||
"task": ["Pick up the object"],
|
||||
}
|
||||
batch = preprocessor(batch)
|
||||
|
||||
# Create previous chunk
|
||||
prev_chunk = torch.randn(1, 25, 7, dtype=torch.float32, device=device)
|
||||
|
||||
with torch.no_grad():
|
||||
noise = policy.model.sample_noise((1, config.chunk_size, 7), device)
|
||||
actions = policy.predict_action_chunk(
|
||||
batch,
|
||||
noise=noise,
|
||||
prev_chunk_left_over=prev_chunk,
|
||||
inference_delay=4,
|
||||
execution_horizon=10,
|
||||
)
|
||||
|
||||
# Verify shape
|
||||
assert actions.shape == (1, config.chunk_size, 7)
|
||||
print(f" ✓ Schedule {schedule}: Test passed")
|
||||
|
||||
print("✓ PI0 RTC different schedules: All schedules tested")
|
||||
@@ -0,0 +1,825 @@
|
||||
#!/usr/bin/env python
|
||||
|
||||
# Copyright 2025 The HuggingFace Inc. team. All rights reserved.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
"""Tests for RTC ActionQueue module."""
|
||||
|
||||
import threading
|
||||
import time
|
||||
|
||||
import pytest
|
||||
import torch
|
||||
|
||||
from lerobot.policies.rtc.action_queue import ActionQueue
|
||||
from lerobot.policies.rtc.configuration_rtc import RTCConfig
|
||||
|
||||
# ====================== Fixtures ======================
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def rtc_config_enabled():
|
||||
"""Create an RTC config with RTC enabled."""
|
||||
return RTCConfig(enabled=True, execution_horizon=10, max_guidance_weight=1.0)
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def rtc_config_disabled():
|
||||
"""Create an RTC config with RTC disabled."""
|
||||
return RTCConfig(enabled=False, execution_horizon=10, max_guidance_weight=1.0)
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def sample_actions():
|
||||
"""Create sample action tensors for testing."""
|
||||
return {
|
||||
"original": torch.randn(50, 6), # (time_steps, action_dim)
|
||||
"processed": torch.randn(50, 6),
|
||||
"short": torch.randn(10, 6),
|
||||
"longer": torch.randn(100, 6),
|
||||
}
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def action_queue_rtc_enabled(rtc_config_enabled):
|
||||
"""Create an ActionQueue with RTC enabled."""
|
||||
return ActionQueue(rtc_config_enabled)
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def action_queue_rtc_disabled(rtc_config_disabled):
|
||||
"""Create an ActionQueue with RTC disabled."""
|
||||
return ActionQueue(rtc_config_disabled)
|
||||
|
||||
|
||||
# ====================== Initialization Tests ======================
|
||||
|
||||
|
||||
def test_action_queue_initialization_rtc_enabled(rtc_config_enabled):
|
||||
"""Test ActionQueue initializes correctly with RTC enabled."""
|
||||
queue = ActionQueue(rtc_config_enabled)
|
||||
assert queue.queue is None
|
||||
assert queue.original_queue is None
|
||||
assert queue.last_index == 0
|
||||
assert queue.cfg.enabled is True
|
||||
|
||||
|
||||
def test_action_queue_initialization_rtc_disabled(rtc_config_disabled):
|
||||
"""Test ActionQueue initializes correctly with RTC disabled."""
|
||||
queue = ActionQueue(rtc_config_disabled)
|
||||
assert queue.queue is None
|
||||
assert queue.original_queue is None
|
||||
assert queue.last_index == 0
|
||||
assert queue.cfg.enabled is False
|
||||
|
||||
|
||||
# ====================== get() Tests ======================
|
||||
|
||||
|
||||
def test_get_returns_none_when_empty(action_queue_rtc_enabled):
|
||||
"""Test get() returns None when queue is empty."""
|
||||
action = action_queue_rtc_enabled.get()
|
||||
assert action is None
|
||||
|
||||
|
||||
def test_get_returns_actions_sequentially(action_queue_rtc_enabled, sample_actions):
|
||||
"""Test get() returns actions in sequence."""
|
||||
# Initialize queue with actions
|
||||
action_queue_rtc_enabled.merge(sample_actions["original"], sample_actions["processed"], real_delay=0)
|
||||
|
||||
# Get first action
|
||||
action1 = action_queue_rtc_enabled.get()
|
||||
assert action1 is not None
|
||||
assert action1.shape == (6,)
|
||||
assert torch.equal(action1, sample_actions["processed"][0])
|
||||
|
||||
# Get second action
|
||||
action2 = action_queue_rtc_enabled.get()
|
||||
assert action2 is not None
|
||||
assert torch.equal(action2, sample_actions["processed"][1])
|
||||
|
||||
|
||||
def test_get_returns_none_after_exhaustion(action_queue_rtc_enabled, sample_actions):
|
||||
"""Test get() returns None after all actions are consumed."""
|
||||
# Use short action sequence
|
||||
action_queue_rtc_enabled.merge(sample_actions["short"], sample_actions["short"], real_delay=0)
|
||||
|
||||
# Consume all actions
|
||||
for _ in range(10):
|
||||
action = action_queue_rtc_enabled.get()
|
||||
assert action is not None
|
||||
|
||||
# Next get should return None
|
||||
action = action_queue_rtc_enabled.get()
|
||||
assert action is None
|
||||
|
||||
|
||||
def test_get_increments_last_index(action_queue_rtc_enabled, sample_actions):
|
||||
"""Test get() increments last_index correctly."""
|
||||
action_queue_rtc_enabled.merge(sample_actions["original"], sample_actions["processed"], real_delay=0)
|
||||
|
||||
assert action_queue_rtc_enabled.last_index == 0
|
||||
action_queue_rtc_enabled.get()
|
||||
assert action_queue_rtc_enabled.last_index == 1
|
||||
action_queue_rtc_enabled.get()
|
||||
assert action_queue_rtc_enabled.last_index == 2
|
||||
|
||||
|
||||
# ====================== qsize() Tests ======================
|
||||
|
||||
|
||||
def test_qsize_returns_zero_when_empty(action_queue_rtc_enabled):
|
||||
"""Test qsize() returns 0 when queue is empty."""
|
||||
assert action_queue_rtc_enabled.qsize() == 0
|
||||
|
||||
|
||||
def test_qsize_returns_correct_size(action_queue_rtc_enabled, sample_actions):
|
||||
"""Test qsize() returns correct number of remaining actions."""
|
||||
action_queue_rtc_enabled.merge(sample_actions["short"], sample_actions["short"], real_delay=0)
|
||||
assert action_queue_rtc_enabled.qsize() == 10
|
||||
|
||||
action_queue_rtc_enabled.get()
|
||||
assert action_queue_rtc_enabled.qsize() == 9
|
||||
|
||||
action_queue_rtc_enabled.get()
|
||||
assert action_queue_rtc_enabled.qsize() == 8
|
||||
|
||||
|
||||
def test_qsize_after_exhaustion(action_queue_rtc_enabled, sample_actions):
|
||||
"""Test qsize() returns 0 after queue is exhausted."""
|
||||
action_queue_rtc_enabled.merge(sample_actions["short"], sample_actions["short"], real_delay=0)
|
||||
|
||||
# Consume all actions
|
||||
for _ in range(10):
|
||||
action_queue_rtc_enabled.get()
|
||||
|
||||
assert action_queue_rtc_enabled.qsize() == 0
|
||||
|
||||
|
||||
# ====================== empty() Tests ======================
|
||||
|
||||
|
||||
def test_empty_returns_true_when_empty(action_queue_rtc_enabled):
|
||||
"""Test empty() returns True when queue is empty."""
|
||||
assert action_queue_rtc_enabled.empty() is True
|
||||
|
||||
|
||||
def test_empty_returns_false_when_not_empty(action_queue_rtc_enabled, sample_actions):
|
||||
"""Test empty() returns False when queue has actions."""
|
||||
action_queue_rtc_enabled.merge(sample_actions["short"], sample_actions["short"], real_delay=0)
|
||||
assert action_queue_rtc_enabled.empty() is False
|
||||
|
||||
|
||||
def test_empty_after_partial_consumption(action_queue_rtc_enabled, sample_actions):
|
||||
"""Test empty() returns False after partial consumption."""
|
||||
action_queue_rtc_enabled.merge(sample_actions["short"], sample_actions["short"], real_delay=0)
|
||||
|
||||
action_queue_rtc_enabled.get()
|
||||
action_queue_rtc_enabled.get()
|
||||
|
||||
assert action_queue_rtc_enabled.empty() is False
|
||||
|
||||
|
||||
def test_empty_after_full_consumption(action_queue_rtc_enabled, sample_actions):
|
||||
"""Test empty() returns True after all actions consumed."""
|
||||
action_queue_rtc_enabled.merge(sample_actions["short"], sample_actions["short"], real_delay=0)
|
||||
|
||||
# Consume all
|
||||
for _ in range(10):
|
||||
action_queue_rtc_enabled.get()
|
||||
|
||||
assert action_queue_rtc_enabled.empty() is True
|
||||
|
||||
|
||||
# ====================== get_action_index() Tests ======================
|
||||
|
||||
|
||||
def test_get_action_index_initial_value(action_queue_rtc_enabled):
|
||||
"""Test get_action_index() returns 0 initially."""
|
||||
assert action_queue_rtc_enabled.get_action_index() == 0
|
||||
|
||||
|
||||
def test_get_action_index_after_consumption(action_queue_rtc_enabled, sample_actions):
|
||||
"""Test get_action_index() tracks consumption correctly."""
|
||||
action_queue_rtc_enabled.merge(sample_actions["short"], sample_actions["short"], real_delay=0)
|
||||
|
||||
assert action_queue_rtc_enabled.get_action_index() == 0
|
||||
action_queue_rtc_enabled.get()
|
||||
assert action_queue_rtc_enabled.get_action_index() == 1
|
||||
action_queue_rtc_enabled.get()
|
||||
action_queue_rtc_enabled.get()
|
||||
assert action_queue_rtc_enabled.get_action_index() == 3
|
||||
|
||||
|
||||
# ====================== get_left_over() Tests ======================
|
||||
|
||||
|
||||
def test_get_left_over_returns_none_when_empty(action_queue_rtc_enabled):
|
||||
"""Test get_left_over() returns None when queue is empty."""
|
||||
leftover = action_queue_rtc_enabled.get_left_over()
|
||||
assert leftover is None
|
||||
|
||||
|
||||
def test_get_left_over_returns_all_when_unconsumed(action_queue_rtc_enabled, sample_actions):
|
||||
"""Test get_left_over() returns all original actions when none consumed."""
|
||||
action_queue_rtc_enabled.merge(sample_actions["short"], sample_actions["short"], real_delay=0)
|
||||
|
||||
leftover = action_queue_rtc_enabled.get_left_over()
|
||||
assert leftover is not None
|
||||
assert leftover.shape == (10, 6)
|
||||
assert torch.equal(leftover, sample_actions["short"])
|
||||
|
||||
|
||||
def test_get_left_over_returns_remaining_after_consumption(action_queue_rtc_enabled, sample_actions):
|
||||
"""Test get_left_over() returns only remaining original actions."""
|
||||
action_queue_rtc_enabled.merge(sample_actions["short"], sample_actions["short"], real_delay=0)
|
||||
|
||||
# Consume 3 actions
|
||||
action_queue_rtc_enabled.get()
|
||||
action_queue_rtc_enabled.get()
|
||||
action_queue_rtc_enabled.get()
|
||||
|
||||
leftover = action_queue_rtc_enabled.get_left_over()
|
||||
assert leftover is not None
|
||||
assert leftover.shape == (7, 6)
|
||||
assert torch.equal(leftover, sample_actions["short"][3:])
|
||||
|
||||
|
||||
def test_get_left_over_returns_empty_after_exhaustion(action_queue_rtc_enabled, sample_actions):
|
||||
"""Test get_left_over() returns empty tensor after all consumed."""
|
||||
action_queue_rtc_enabled.merge(sample_actions["short"], sample_actions["short"], real_delay=0)
|
||||
|
||||
# Consume all
|
||||
for _ in range(10):
|
||||
action_queue_rtc_enabled.get()
|
||||
|
||||
leftover = action_queue_rtc_enabled.get_left_over()
|
||||
assert leftover is not None
|
||||
assert leftover.shape == (0, 6)
|
||||
|
||||
|
||||
# ====================== merge() with RTC Enabled Tests ======================
|
||||
|
||||
|
||||
def test_merge_replaces_queue_when_rtc_enabled(action_queue_rtc_enabled, sample_actions):
|
||||
"""Test merge() replaces queue when RTC is enabled."""
|
||||
# Add initial actions
|
||||
action_queue_rtc_enabled.merge(sample_actions["short"], sample_actions["short"], real_delay=0)
|
||||
assert action_queue_rtc_enabled.qsize() == 10
|
||||
|
||||
# Consume some actions
|
||||
action_queue_rtc_enabled.get()
|
||||
action_queue_rtc_enabled.get()
|
||||
assert action_queue_rtc_enabled.qsize() == 8
|
||||
|
||||
# Merge new actions - should replace, not append
|
||||
action_queue_rtc_enabled.merge(sample_actions["original"], sample_actions["processed"], real_delay=5)
|
||||
|
||||
# Queue should be replaced with new actions minus delay
|
||||
# Original has 50 actions, delay is 5, so remaining is 45
|
||||
assert action_queue_rtc_enabled.qsize() == 45
|
||||
assert action_queue_rtc_enabled.get_action_index() == 0
|
||||
|
||||
|
||||
def test_merge_respects_real_delay(action_queue_rtc_enabled, sample_actions):
|
||||
"""Test merge() correctly applies real_delay when RTC is enabled."""
|
||||
delay = 10
|
||||
action_queue_rtc_enabled.merge(sample_actions["original"], sample_actions["processed"], real_delay=delay)
|
||||
|
||||
# Queue should have original length minus delay
|
||||
expected_size = len(sample_actions["original"]) - delay
|
||||
assert action_queue_rtc_enabled.qsize() == expected_size
|
||||
|
||||
# First action should be the one at index [delay]
|
||||
first_action = action_queue_rtc_enabled.get()
|
||||
assert torch.equal(first_action, sample_actions["processed"][delay])
|
||||
|
||||
|
||||
def test_merge_resets_last_index_when_rtc_enabled(action_queue_rtc_enabled, sample_actions):
|
||||
"""Test merge() resets last_index to 0 when RTC is enabled."""
|
||||
action_queue_rtc_enabled.merge(sample_actions["short"], sample_actions["short"], real_delay=0)
|
||||
action_queue_rtc_enabled.get()
|
||||
action_queue_rtc_enabled.get()
|
||||
assert action_queue_rtc_enabled.last_index == 2
|
||||
|
||||
# Merge new actions
|
||||
action_queue_rtc_enabled.merge(sample_actions["original"], sample_actions["processed"], real_delay=5)
|
||||
|
||||
assert action_queue_rtc_enabled.last_index == 0
|
||||
|
||||
|
||||
def test_merge_with_zero_delay(action_queue_rtc_enabled, sample_actions):
|
||||
"""Test merge() with zero delay keeps all actions."""
|
||||
action_queue_rtc_enabled.merge(sample_actions["original"], sample_actions["processed"], real_delay=0)
|
||||
|
||||
assert action_queue_rtc_enabled.qsize() == len(sample_actions["original"])
|
||||
|
||||
|
||||
def test_merge_with_large_delay(action_queue_rtc_enabled, sample_actions):
|
||||
"""Test merge() with delay larger than action sequence."""
|
||||
# Delay is larger than sequence length
|
||||
delay = 100
|
||||
action_queue_rtc_enabled.merge(sample_actions["short"], sample_actions["short"], real_delay=delay)
|
||||
|
||||
# Queue should be empty (delay >= length)
|
||||
assert action_queue_rtc_enabled.qsize() == 0
|
||||
|
||||
|
||||
# ====================== merge() with RTC Disabled Tests ======================
|
||||
|
||||
|
||||
def test_merge_appends_when_rtc_disabled(action_queue_rtc_disabled, sample_actions):
|
||||
"""Test merge() appends actions when RTC is disabled."""
|
||||
# Add initial actions
|
||||
action_queue_rtc_disabled.merge(sample_actions["short"], sample_actions["short"], real_delay=0)
|
||||
initial_size = action_queue_rtc_disabled.qsize()
|
||||
assert initial_size == 10
|
||||
|
||||
# Merge more actions
|
||||
action_queue_rtc_disabled.merge(sample_actions["short"], sample_actions["short"], real_delay=0)
|
||||
|
||||
# Should have appended
|
||||
assert action_queue_rtc_disabled.qsize() == initial_size + 10
|
||||
|
||||
|
||||
def test_merge_removes_consumed_actions_when_appending(action_queue_rtc_disabled, sample_actions):
|
||||
"""Test merge() removes consumed actions before appending when RTC is disabled."""
|
||||
# Add initial actions
|
||||
action_queue_rtc_disabled.merge(sample_actions["short"], sample_actions["short"], real_delay=0)
|
||||
assert action_queue_rtc_disabled.qsize() == 10
|
||||
|
||||
# Consume 3 actions
|
||||
action_queue_rtc_disabled.get()
|
||||
action_queue_rtc_disabled.get()
|
||||
action_queue_rtc_disabled.get()
|
||||
assert action_queue_rtc_disabled.qsize() == 7
|
||||
|
||||
# Merge more actions
|
||||
action_queue_rtc_disabled.merge(sample_actions["short"], sample_actions["short"], real_delay=0)
|
||||
|
||||
# Should have 7 remaining + 10 new = 17
|
||||
assert action_queue_rtc_disabled.qsize() == 17
|
||||
|
||||
|
||||
def test_merge_resets_last_index_after_append(action_queue_rtc_disabled, sample_actions):
|
||||
"""Test merge() resets last_index after appending when RTC is disabled."""
|
||||
action_queue_rtc_disabled.merge(sample_actions["short"], sample_actions["short"], real_delay=0)
|
||||
action_queue_rtc_disabled.get()
|
||||
action_queue_rtc_disabled.get()
|
||||
assert action_queue_rtc_disabled.last_index == 2
|
||||
|
||||
# Merge more actions
|
||||
action_queue_rtc_disabled.merge(sample_actions["short"], sample_actions["short"], real_delay=0)
|
||||
|
||||
# last_index should be reset to 0
|
||||
assert action_queue_rtc_disabled.last_index == 0
|
||||
|
||||
|
||||
def test_merge_ignores_delay_when_rtc_disabled(action_queue_rtc_disabled, sample_actions):
|
||||
"""Test merge() ignores real_delay parameter when RTC is disabled."""
|
||||
action_queue_rtc_disabled.merge(sample_actions["original"], sample_actions["processed"], real_delay=10)
|
||||
|
||||
# All actions should be in queue (delay ignored)
|
||||
assert action_queue_rtc_disabled.qsize() == len(sample_actions["original"])
|
||||
|
||||
|
||||
def test_merge_first_call_with_rtc_disabled(action_queue_rtc_disabled, sample_actions):
|
||||
"""Test merge() on first call with RTC disabled."""
|
||||
action_queue_rtc_disabled.merge(sample_actions["original"], sample_actions["processed"], real_delay=0)
|
||||
|
||||
assert action_queue_rtc_disabled.qsize() == len(sample_actions["original"])
|
||||
assert action_queue_rtc_disabled.last_index == 0
|
||||
|
||||
|
||||
# ====================== merge() with Different Action Shapes Tests ======================
|
||||
|
||||
|
||||
def test_merge_with_different_action_dims():
|
||||
"""Test merge() handles actions with different dimensions."""
|
||||
cfg = RTCConfig(enabled=True, execution_horizon=10)
|
||||
queue = ActionQueue(cfg)
|
||||
|
||||
# Actions with 4 dimensions instead of 6
|
||||
actions_4d = torch.randn(20, 4)
|
||||
queue.merge(actions_4d, actions_4d, real_delay=5)
|
||||
|
||||
action = queue.get()
|
||||
assert action.shape == (4,)
|
||||
|
||||
|
||||
def test_merge_with_different_lengths():
|
||||
"""Test merge() handles action sequences of varying lengths."""
|
||||
cfg = RTCConfig(enabled=False, execution_horizon=10)
|
||||
queue = ActionQueue(cfg)
|
||||
|
||||
# Add sequences of different lengths
|
||||
queue.merge(torch.randn(10, 6), torch.randn(10, 6), real_delay=0)
|
||||
assert queue.qsize() == 10
|
||||
|
||||
queue.merge(torch.randn(25, 6), torch.randn(25, 6), real_delay=0)
|
||||
assert queue.qsize() == 35
|
||||
|
||||
|
||||
# ====================== merge() Delay Validation Tests ======================
|
||||
|
||||
|
||||
def test_merge_validates_delay_consistency(action_queue_rtc_enabled, sample_actions, caplog):
|
||||
"""Test merge() validates that real_delay matches action index difference."""
|
||||
import logging
|
||||
|
||||
caplog.set_level(logging.WARNING)
|
||||
|
||||
# Initialize queue
|
||||
action_queue_rtc_enabled.merge(sample_actions["short"], sample_actions["short"], real_delay=0)
|
||||
|
||||
# Consume 5 actions
|
||||
for _ in range(5):
|
||||
action_queue_rtc_enabled.get()
|
||||
|
||||
# Merge with mismatched delay (should log warning)
|
||||
# We consumed 5 actions, so index is 5. If we pass action_index_before_inference=0,
|
||||
# then indexes_diff=5, but if real_delay=3, it will warn
|
||||
action_queue_rtc_enabled.merge(
|
||||
sample_actions["original"],
|
||||
sample_actions["processed"],
|
||||
real_delay=3,
|
||||
action_index_before_inference=0,
|
||||
)
|
||||
|
||||
# Check warning was logged
|
||||
assert "Indexes diff is not equal to real delay" in caplog.text
|
||||
|
||||
|
||||
def test_merge_no_warning_when_delays_match(action_queue_rtc_enabled, sample_actions, caplog):
|
||||
"""Test merge() doesn't warn when delays are consistent."""
|
||||
import logging
|
||||
|
||||
caplog.set_level(logging.WARNING)
|
||||
|
||||
# Initialize queue
|
||||
action_queue_rtc_enabled.merge(sample_actions["short"], sample_actions["short"], real_delay=0)
|
||||
|
||||
# Consume 5 actions
|
||||
for _ in range(5):
|
||||
action_queue_rtc_enabled.get()
|
||||
|
||||
# Merge with matching delay
|
||||
action_queue_rtc_enabled.merge(
|
||||
sample_actions["original"],
|
||||
sample_actions["processed"],
|
||||
real_delay=5,
|
||||
action_index_before_inference=0,
|
||||
)
|
||||
|
||||
# Should not have warning
|
||||
assert "Indexes diff is not equal to real delay" not in caplog.text
|
||||
|
||||
|
||||
def test_merge_skips_validation_when_action_index_none(action_queue_rtc_enabled, sample_actions, caplog):
|
||||
"""Test merge() skips delay validation when action_index_before_inference is None."""
|
||||
import logging
|
||||
|
||||
caplog.set_level(logging.WARNING)
|
||||
|
||||
action_queue_rtc_enabled.merge(sample_actions["short"], sample_actions["short"], real_delay=0)
|
||||
|
||||
for _ in range(5):
|
||||
action_queue_rtc_enabled.get()
|
||||
|
||||
# Pass None for action_index_before_inference
|
||||
action_queue_rtc_enabled.merge(
|
||||
sample_actions["original"],
|
||||
sample_actions["processed"],
|
||||
real_delay=999, # Doesn't matter
|
||||
action_index_before_inference=None,
|
||||
)
|
||||
|
||||
# Should not warn (validation skipped)
|
||||
assert "Indexes diff is not equal to real delay" not in caplog.text
|
||||
|
||||
|
||||
# ====================== Thread Safety Tests ======================
|
||||
|
||||
|
||||
def test_get_is_thread_safe(action_queue_rtc_enabled, sample_actions):
|
||||
"""Test get() is thread-safe with multiple consumers."""
|
||||
action_queue_rtc_enabled.merge(sample_actions["longer"], sample_actions["longer"], real_delay=0)
|
||||
|
||||
results = []
|
||||
errors = []
|
||||
|
||||
def consumer():
|
||||
try:
|
||||
for _ in range(25):
|
||||
action = action_queue_rtc_enabled.get()
|
||||
if action is not None:
|
||||
results.append(action)
|
||||
time.sleep(0.001)
|
||||
except Exception as e:
|
||||
errors.append(e)
|
||||
|
||||
threads = [threading.Thread(target=consumer) for _ in range(4)]
|
||||
|
||||
for t in threads:
|
||||
t.start()
|
||||
|
||||
for t in threads:
|
||||
t.join()
|
||||
|
||||
# Should not have errors
|
||||
assert len(errors) == 0
|
||||
|
||||
# Should have consumed all actions (100 total, 4 threads * 25 each)
|
||||
assert len(results) == 100
|
||||
|
||||
# All results should be unique (no duplicate consumption)
|
||||
# We can verify by checking that indices are not duplicated
|
||||
# Since we don't track indices in results, we check total count is correct
|
||||
assert action_queue_rtc_enabled.qsize() == 0
|
||||
|
||||
|
||||
def test_merge_is_thread_safe(action_queue_rtc_disabled, sample_actions):
|
||||
"""Test merge() is thread-safe with multiple producers."""
|
||||
errors = []
|
||||
|
||||
def producer():
|
||||
try:
|
||||
for _ in range(5):
|
||||
action_queue_rtc_disabled.merge(
|
||||
sample_actions["short"], sample_actions["short"], real_delay=0
|
||||
)
|
||||
time.sleep(0.001)
|
||||
except Exception as e:
|
||||
errors.append(e)
|
||||
|
||||
threads = [threading.Thread(target=producer) for _ in range(3)]
|
||||
|
||||
for t in threads:
|
||||
t.start()
|
||||
|
||||
for t in threads:
|
||||
t.join()
|
||||
|
||||
# Should not have errors
|
||||
assert len(errors) == 0
|
||||
|
||||
# Should have accumulated all actions (3 threads * 5 merges * 10 actions = 150)
|
||||
assert action_queue_rtc_disabled.qsize() == 150
|
||||
|
||||
|
||||
def test_concurrent_get_and_merge(action_queue_rtc_disabled, sample_actions):
|
||||
"""Test concurrent get() and merge() operations."""
|
||||
errors = []
|
||||
consumed_count = [0]
|
||||
|
||||
def consumer():
|
||||
try:
|
||||
for _ in range(50):
|
||||
action = action_queue_rtc_disabled.get()
|
||||
if action is not None:
|
||||
consumed_count[0] += 1
|
||||
time.sleep(0.001)
|
||||
except Exception as e:
|
||||
errors.append(e)
|
||||
|
||||
def producer():
|
||||
try:
|
||||
for _ in range(10):
|
||||
action_queue_rtc_disabled.merge(
|
||||
sample_actions["short"], sample_actions["short"], real_delay=0
|
||||
)
|
||||
time.sleep(0.005)
|
||||
except Exception as e:
|
||||
errors.append(e)
|
||||
|
||||
consumer_threads = [threading.Thread(target=consumer) for _ in range(2)]
|
||||
producer_threads = [threading.Thread(target=producer) for _ in range(2)]
|
||||
|
||||
for t in consumer_threads + producer_threads:
|
||||
t.start()
|
||||
|
||||
for t in consumer_threads + producer_threads:
|
||||
t.join()
|
||||
|
||||
# Should not have errors
|
||||
assert len(errors) == 0
|
||||
|
||||
# Should have consumed some or all actions (non-deterministic due to timing)
|
||||
# Total produced: 2 producers * 10 merges * 10 actions = 200
|
||||
# Total consumed attempts: 2 consumers * 50 = 100
|
||||
assert consumed_count[0] <= 200
|
||||
|
||||
|
||||
# ====================== get_left_over() Thread Safety Tests ======================
|
||||
|
||||
|
||||
def test_get_left_over_is_thread_safe(action_queue_rtc_enabled, sample_actions):
|
||||
"""Test get_left_over() is thread-safe with concurrent access."""
|
||||
action_queue_rtc_enabled.merge(sample_actions["longer"], sample_actions["longer"], real_delay=0)
|
||||
|
||||
errors = []
|
||||
leftovers = []
|
||||
|
||||
def reader():
|
||||
try:
|
||||
for _ in range(20):
|
||||
leftover = action_queue_rtc_enabled.get_left_over()
|
||||
if leftover is not None:
|
||||
leftovers.append(leftover.shape[0])
|
||||
time.sleep(0.001)
|
||||
except Exception as e:
|
||||
errors.append(e)
|
||||
|
||||
threads = [threading.Thread(target=reader) for _ in range(3)]
|
||||
|
||||
# Also consume some actions concurrently
|
||||
def consumer():
|
||||
try:
|
||||
for _ in range(10):
|
||||
action_queue_rtc_enabled.get()
|
||||
time.sleep(0.002)
|
||||
except Exception as e:
|
||||
errors.append(e)
|
||||
|
||||
consumer_thread = threading.Thread(target=consumer)
|
||||
|
||||
all_threads = threads + [consumer_thread]
|
||||
|
||||
for t in all_threads:
|
||||
t.start()
|
||||
|
||||
for t in all_threads:
|
||||
t.join()
|
||||
|
||||
# Should not have errors
|
||||
assert len(errors) == 0
|
||||
|
||||
# Leftovers should be monotonically decreasing or stable
|
||||
# (as actions are consumed, leftover size decreases)
|
||||
assert len(leftovers) > 0
|
||||
|
||||
|
||||
# ====================== Edge Cases Tests ======================
|
||||
|
||||
|
||||
def test_queue_with_single_action(action_queue_rtc_enabled):
|
||||
"""Test queue behavior with a single action."""
|
||||
single_action_original = torch.randn(1, 6)
|
||||
single_action_processed = torch.randn(1, 6)
|
||||
|
||||
action_queue_rtc_enabled.merge(single_action_original, single_action_processed, real_delay=0)
|
||||
|
||||
assert action_queue_rtc_enabled.qsize() == 1
|
||||
action = action_queue_rtc_enabled.get()
|
||||
assert action is not None
|
||||
assert action.shape == (6,)
|
||||
assert action_queue_rtc_enabled.qsize() == 0
|
||||
|
||||
|
||||
def test_queue_behavior_after_multiple_merge_cycles(action_queue_rtc_enabled, sample_actions):
|
||||
"""Test queue maintains correct state through multiple merge cycles."""
|
||||
for _ in range(5):
|
||||
action_queue_rtc_enabled.merge(sample_actions["short"], sample_actions["short"], real_delay=0)
|
||||
|
||||
# Consume half
|
||||
for _ in range(5):
|
||||
action_queue_rtc_enabled.get()
|
||||
|
||||
# Merge again
|
||||
action_queue_rtc_enabled.merge(sample_actions["original"], sample_actions["processed"], real_delay=3)
|
||||
|
||||
assert action_queue_rtc_enabled.qsize() > 0
|
||||
|
||||
|
||||
def test_queue_with_all_zeros_actions(action_queue_rtc_enabled):
|
||||
"""Test queue handles all-zero action tensors."""
|
||||
zeros_actions = torch.zeros(20, 6)
|
||||
action_queue_rtc_enabled.merge(zeros_actions, zeros_actions, real_delay=0)
|
||||
|
||||
action = action_queue_rtc_enabled.get()
|
||||
assert torch.all(action == 0)
|
||||
|
||||
|
||||
def test_queue_clones_input_tensors(action_queue_rtc_enabled, sample_actions):
|
||||
"""Test that merge() clones input tensors, not storing references."""
|
||||
original_copy = sample_actions["original"].clone()
|
||||
processed_copy = sample_actions["processed"].clone()
|
||||
|
||||
action_queue_rtc_enabled.merge(sample_actions["original"], sample_actions["processed"], real_delay=0)
|
||||
|
||||
# Modify original tensors
|
||||
sample_actions["original"].fill_(999.0)
|
||||
sample_actions["processed"].fill_(-999.0)
|
||||
|
||||
# Queue should have cloned values
|
||||
action = action_queue_rtc_enabled.get()
|
||||
assert not torch.equal(action, sample_actions["processed"][0])
|
||||
assert torch.equal(action, processed_copy[0])
|
||||
|
||||
leftover = action_queue_rtc_enabled.get_left_over()
|
||||
assert not torch.equal(leftover, sample_actions["original"][1:])
|
||||
assert torch.equal(leftover, original_copy[1:])
|
||||
|
||||
|
||||
@pytest.mark.skipif(not torch.cuda.is_available(), reason="CUDA not available")
|
||||
def test_queue_handles_gpu_tensors():
|
||||
"""Test queue correctly handles GPU tensors."""
|
||||
cfg = RTCConfig(enabled=True, execution_horizon=10)
|
||||
queue = ActionQueue(cfg)
|
||||
|
||||
actions_gpu = torch.randn(20, 6, device="cuda")
|
||||
queue.merge(actions_gpu, actions_gpu, real_delay=0)
|
||||
|
||||
action = queue.get()
|
||||
assert action.device.type == "cuda"
|
||||
|
||||
leftover = queue.get_left_over()
|
||||
assert leftover.device.type == "cuda"
|
||||
|
||||
|
||||
def test_queue_handles_different_dtypes():
|
||||
"""Test queue handles actions with different dtypes."""
|
||||
cfg = RTCConfig(enabled=True, execution_horizon=10)
|
||||
queue = ActionQueue(cfg)
|
||||
|
||||
# Use float64 instead of default float32
|
||||
actions_f64 = torch.randn(20, 6, dtype=torch.float64)
|
||||
queue.merge(actions_f64, actions_f64, real_delay=0)
|
||||
|
||||
action = queue.get()
|
||||
assert action.dtype == torch.float64
|
||||
|
||||
|
||||
def test_empty_with_none_queue(action_queue_rtc_enabled):
|
||||
"""Test empty() correctly handles None queue."""
|
||||
assert action_queue_rtc_enabled.queue is None
|
||||
assert action_queue_rtc_enabled.empty() is True
|
||||
|
||||
|
||||
def test_qsize_with_none_queue(action_queue_rtc_enabled):
|
||||
"""Test qsize() correctly handles None queue."""
|
||||
assert action_queue_rtc_enabled.queue is None
|
||||
assert action_queue_rtc_enabled.qsize() == 0
|
||||
|
||||
|
||||
# ====================== Integration Tests ======================
|
||||
|
||||
|
||||
def test_typical_rtc_workflow(action_queue_rtc_enabled, sample_actions):
|
||||
"""Test a typical RTC workflow: merge, consume, merge with delay."""
|
||||
# First inference
|
||||
action_queue_rtc_enabled.merge(sample_actions["original"], sample_actions["processed"], real_delay=0)
|
||||
initial_size = action_queue_rtc_enabled.qsize()
|
||||
assert initial_size == 50
|
||||
|
||||
# Consume 10 actions (execution_horizon)
|
||||
for _ in range(10):
|
||||
action = action_queue_rtc_enabled.get()
|
||||
assert action is not None
|
||||
|
||||
assert action_queue_rtc_enabled.qsize() == 40
|
||||
|
||||
# Second inference with delay
|
||||
action_index_before = action_queue_rtc_enabled.get_action_index()
|
||||
|
||||
action_queue_rtc_enabled.merge(
|
||||
sample_actions["original"],
|
||||
sample_actions["processed"],
|
||||
real_delay=5,
|
||||
action_index_before_inference=action_index_before,
|
||||
)
|
||||
|
||||
# Queue should be replaced, minus delay
|
||||
assert action_queue_rtc_enabled.qsize() == 45
|
||||
assert action_queue_rtc_enabled.get_action_index() == 0
|
||||
|
||||
|
||||
def test_typical_non_rtc_workflow(action_queue_rtc_disabled, sample_actions):
|
||||
"""Test a typical non-RTC workflow: merge, consume, merge again."""
|
||||
# First inference
|
||||
action_queue_rtc_disabled.merge(sample_actions["original"], sample_actions["processed"], real_delay=0)
|
||||
assert action_queue_rtc_disabled.qsize() == 50
|
||||
|
||||
# Consume 40 actions
|
||||
for _ in range(40):
|
||||
action = action_queue_rtc_disabled.get()
|
||||
assert action is not None
|
||||
|
||||
assert action_queue_rtc_disabled.qsize() == 10
|
||||
|
||||
# Second inference (should append)
|
||||
action_queue_rtc_disabled.merge(sample_actions["original"], sample_actions["processed"], real_delay=0)
|
||||
|
||||
# Should have 10 remaining + 50 new = 60
|
||||
assert action_queue_rtc_disabled.qsize() == 60
|
||||
@@ -0,0 +1,65 @@
|
||||
#!/usr/bin/env python
|
||||
|
||||
# Copyright 2025 The HuggingFace Inc. team. All rights reserved.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
"""Tests for RTC configuration module."""
|
||||
|
||||
from lerobot.configs.types import RTCAttentionSchedule
|
||||
from lerobot.policies.rtc.configuration_rtc import RTCConfig
|
||||
|
||||
# ====================== Initialization Tests ======================
|
||||
|
||||
|
||||
def test_rtc_config_default_initialization():
|
||||
"""Test RTCConfig initializes with default values."""
|
||||
config = RTCConfig()
|
||||
|
||||
assert config.enabled is False
|
||||
assert config.prefix_attention_schedule == RTCAttentionSchedule.LINEAR
|
||||
assert config.max_guidance_weight == 10.0
|
||||
assert config.execution_horizon == 10
|
||||
assert config.debug is False
|
||||
assert config.debug_maxlen == 100
|
||||
|
||||
|
||||
def test_rtc_config_custom_initialization():
|
||||
"""Test RTCConfig initializes with custom values."""
|
||||
config = RTCConfig(
|
||||
enabled=True,
|
||||
prefix_attention_schedule=RTCAttentionSchedule.EXP,
|
||||
max_guidance_weight=5.0,
|
||||
execution_horizon=20,
|
||||
debug=True,
|
||||
debug_maxlen=200,
|
||||
)
|
||||
|
||||
assert config.enabled is True
|
||||
assert config.prefix_attention_schedule == RTCAttentionSchedule.EXP
|
||||
assert config.max_guidance_weight == 5.0
|
||||
assert config.execution_horizon == 20
|
||||
assert config.debug is True
|
||||
assert config.debug_maxlen == 200
|
||||
|
||||
|
||||
def test_rtc_config_partial_initialization():
|
||||
"""Test RTCConfig with partial custom values."""
|
||||
config = RTCConfig(enabled=True, max_guidance_weight=15.0)
|
||||
|
||||
assert config.enabled is True
|
||||
assert config.max_guidance_weight == 15.0
|
||||
# Other values should be defaults
|
||||
assert config.prefix_attention_schedule == RTCAttentionSchedule.LINEAR
|
||||
assert config.execution_horizon == 10
|
||||
assert config.debug is False
|
||||
@@ -0,0 +1,488 @@
|
||||
#!/usr/bin/env python
|
||||
|
||||
# Copyright 2025 The HuggingFace Inc. team. All rights reserved.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
"""Tests for RTC debug tracker module."""
|
||||
|
||||
import pytest
|
||||
import torch
|
||||
|
||||
from lerobot.policies.rtc.debug_tracker import DebugStep, Tracker
|
||||
|
||||
# ====================== Fixtures ======================
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def sample_tensors():
|
||||
"""Create sample tensors for testing."""
|
||||
return {
|
||||
"x_t": torch.randn(1, 50, 6),
|
||||
"v_t": torch.randn(1, 50, 6),
|
||||
"x1_t": torch.randn(1, 50, 6),
|
||||
"correction": torch.randn(1, 50, 6),
|
||||
"err": torch.randn(1, 50, 6),
|
||||
"weights": torch.randn(1, 50, 1),
|
||||
}
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def enabled_tracker():
|
||||
"""Create an enabled tracker with default settings."""
|
||||
return Tracker(enabled=True, maxlen=100)
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def disabled_tracker():
|
||||
"""Create a disabled tracker."""
|
||||
return Tracker(enabled=False)
|
||||
|
||||
|
||||
# ====================== DebugStep Tests ======================
|
||||
|
||||
|
||||
def test_debug_step_initialization():
|
||||
"""Test that DebugStep can be initialized with default values."""
|
||||
step = DebugStep()
|
||||
assert step.step_idx == 0
|
||||
assert step.x_t is None
|
||||
assert step.v_t is None
|
||||
assert step.x1_t is None
|
||||
assert step.correction is None
|
||||
assert step.err is None
|
||||
assert step.weights is None
|
||||
assert step.guidance_weight is None
|
||||
assert step.time is None
|
||||
assert step.inference_delay is None
|
||||
assert step.execution_horizon is None
|
||||
assert step.metadata == {}
|
||||
|
||||
|
||||
def test_debug_step_with_values(sample_tensors):
|
||||
"""Test DebugStep initialization with actual values."""
|
||||
step = DebugStep(
|
||||
step_idx=5,
|
||||
x_t=sample_tensors["x_t"],
|
||||
v_t=sample_tensors["v_t"],
|
||||
x1_t=sample_tensors["x1_t"],
|
||||
correction=sample_tensors["correction"],
|
||||
err=sample_tensors["err"],
|
||||
weights=sample_tensors["weights"],
|
||||
guidance_weight=2.5,
|
||||
time=0.8,
|
||||
inference_delay=4,
|
||||
execution_horizon=8,
|
||||
metadata={"custom_key": "custom_value"},
|
||||
)
|
||||
|
||||
assert step.step_idx == 5
|
||||
assert torch.equal(step.x_t, sample_tensors["x_t"])
|
||||
assert torch.equal(step.v_t, sample_tensors["v_t"])
|
||||
assert torch.equal(step.x1_t, sample_tensors["x1_t"])
|
||||
assert torch.equal(step.correction, sample_tensors["correction"])
|
||||
assert torch.equal(step.err, sample_tensors["err"])
|
||||
assert torch.equal(step.weights, sample_tensors["weights"])
|
||||
assert step.guidance_weight == 2.5
|
||||
assert step.time == 0.8
|
||||
assert step.inference_delay == 4
|
||||
assert step.execution_horizon == 8
|
||||
assert step.metadata == {"custom_key": "custom_value"}
|
||||
|
||||
|
||||
def test_debug_step_to_dict_without_tensors(sample_tensors):
|
||||
"""Test converting DebugStep to dictionary without tensor values."""
|
||||
step = DebugStep(
|
||||
step_idx=3,
|
||||
x_t=sample_tensors["x_t"],
|
||||
v_t=sample_tensors["v_t"],
|
||||
guidance_weight=torch.tensor(3.0),
|
||||
time=torch.tensor(0.5),
|
||||
inference_delay=2,
|
||||
execution_horizon=10,
|
||||
)
|
||||
|
||||
result = step.to_dict(include_tensors=False)
|
||||
|
||||
assert result["step_idx"] == 3
|
||||
assert result["guidance_weight"] == 3.0
|
||||
assert result["time"] == 0.5
|
||||
assert result["inference_delay"] == 2
|
||||
assert result["execution_horizon"] == 10
|
||||
|
||||
# Check tensor statistics are included
|
||||
assert "x_t_stats" in result
|
||||
assert "v_t_stats" in result
|
||||
assert "x1_t_stats" not in result # x1_t was None
|
||||
|
||||
# Verify statistics structure
|
||||
assert "shape" in result["x_t_stats"]
|
||||
assert "mean" in result["x_t_stats"]
|
||||
assert "std" in result["x_t_stats"]
|
||||
assert "min" in result["x_t_stats"]
|
||||
assert "max" in result["x_t_stats"]
|
||||
|
||||
# Verify shape matches original tensor
|
||||
assert result["x_t_stats"]["shape"] == tuple(sample_tensors["x_t"].shape)
|
||||
|
||||
|
||||
def test_debug_step_to_dict_with_tensors(sample_tensors):
|
||||
"""Test converting DebugStep to dictionary with tensor values."""
|
||||
step = DebugStep(
|
||||
step_idx=1,
|
||||
x_t=sample_tensors["x_t"],
|
||||
v_t=sample_tensors["v_t"],
|
||||
guidance_weight=1.5,
|
||||
time=0.9,
|
||||
)
|
||||
|
||||
result = step.to_dict(include_tensors=True)
|
||||
|
||||
assert result["step_idx"] == 1
|
||||
assert result["guidance_weight"] == 1.5
|
||||
assert result["time"] == 0.9
|
||||
|
||||
# Check tensors are included (as CPU tensors)
|
||||
assert "x_t" in result
|
||||
assert "v_t" in result
|
||||
assert isinstance(result["x_t"], torch.Tensor)
|
||||
assert isinstance(result["v_t"], torch.Tensor)
|
||||
assert result["x_t"].device.type == "cpu"
|
||||
assert result["v_t"].device.type == "cpu"
|
||||
|
||||
|
||||
def test_debug_step_to_dict_with_none_guidance_weight():
|
||||
"""Test to_dict handles None guidance_weight correctly."""
|
||||
step = DebugStep(step_idx=0, time=1.0, guidance_weight=None)
|
||||
result = step.to_dict(include_tensors=False)
|
||||
assert result["guidance_weight"] is None
|
||||
|
||||
|
||||
def test_tracker_initialization_enabled():
|
||||
"""Test tracker initialization when enabled."""
|
||||
tracker = Tracker(enabled=True, maxlen=50)
|
||||
assert tracker.enabled is True
|
||||
assert tracker._steps == {}
|
||||
assert tracker._maxlen == 50
|
||||
assert tracker._step_counter == 0
|
||||
assert len(tracker) == 0
|
||||
|
||||
|
||||
def test_tracker_reset_when_enabled(enabled_tracker, sample_tensors):
|
||||
"""Test reset clears all steps when tracker is enabled."""
|
||||
# Add some steps
|
||||
enabled_tracker.track(time=1.0, x_t=sample_tensors["x_t"])
|
||||
enabled_tracker.track(time=0.9, x_t=sample_tensors["x_t"])
|
||||
assert len(enabled_tracker) == 2
|
||||
|
||||
# Reset
|
||||
enabled_tracker.reset()
|
||||
assert len(enabled_tracker) == 0
|
||||
assert enabled_tracker._step_counter == 0
|
||||
assert enabled_tracker._steps == {}
|
||||
|
||||
|
||||
def test_tracker_reset_when_disabled(disabled_tracker):
|
||||
"""Test reset on disabled tracker doesn't cause errors."""
|
||||
disabled_tracker.reset()
|
||||
assert len(disabled_tracker) == 0
|
||||
|
||||
|
||||
# ====================== Tracker.track() Tests ======================
|
||||
|
||||
|
||||
def test_track_creates_new_step(enabled_tracker, sample_tensors):
|
||||
"""Test that track creates a new step when time doesn't exist."""
|
||||
enabled_tracker.track(
|
||||
time=1.0,
|
||||
x_t=sample_tensors["x_t"],
|
||||
v_t=sample_tensors["v_t"],
|
||||
guidance_weight=5.0,
|
||||
inference_delay=4,
|
||||
execution_horizon=8,
|
||||
)
|
||||
|
||||
assert len(enabled_tracker) == 1
|
||||
steps = enabled_tracker.get_all_steps()
|
||||
assert len(steps) == 1
|
||||
assert steps[0].step_idx == 0
|
||||
assert steps[0].time == 1.0
|
||||
assert torch.equal(steps[0].x_t, sample_tensors["x_t"])
|
||||
assert torch.equal(steps[0].v_t, sample_tensors["v_t"])
|
||||
assert steps[0].guidance_weight == 5.0
|
||||
assert steps[0].inference_delay == 4
|
||||
assert steps[0].execution_horizon == 8
|
||||
|
||||
|
||||
def test_track_updates_existing_step(enabled_tracker, sample_tensors):
|
||||
"""Test that track updates an existing step at the same time."""
|
||||
# Create initial step
|
||||
enabled_tracker.track(time=0.9, x_t=sample_tensors["x_t"])
|
||||
assert len(enabled_tracker) == 1
|
||||
steps = enabled_tracker.get_all_steps()
|
||||
assert steps[0].v_t is None
|
||||
|
||||
# Update the same timestep with v_t
|
||||
enabled_tracker.track(time=0.9, v_t=sample_tensors["v_t"])
|
||||
assert len(enabled_tracker) == 1 # Still only one step
|
||||
steps = enabled_tracker.get_all_steps()
|
||||
assert torch.equal(steps[0].x_t, sample_tensors["x_t"]) # Original x_t preserved
|
||||
assert torch.equal(steps[0].v_t, sample_tensors["v_t"]) # New v_t added
|
||||
|
||||
|
||||
def test_track_with_tensor_time(enabled_tracker, sample_tensors):
|
||||
"""Test track handles tensor time values correctly."""
|
||||
time_tensor = torch.tensor(0.8)
|
||||
enabled_tracker.track(time=time_tensor, x_t=sample_tensors["x_t"])
|
||||
|
||||
steps = enabled_tracker.get_all_steps()
|
||||
assert len(steps) == 1
|
||||
assert abs(steps[0].time - 0.8) < 1e-6 # Use approximate comparison for floating point
|
||||
|
||||
|
||||
def test_track_time_rounding(enabled_tracker, sample_tensors):
|
||||
"""Test that track rounds time to avoid floating point precision issues."""
|
||||
# These times should be treated as the same after rounding to 6 decimals
|
||||
enabled_tracker.track(time=0.9000001, x_t=sample_tensors["x_t"])
|
||||
enabled_tracker.track(time=0.9000002, v_t=sample_tensors["v_t"])
|
||||
|
||||
# Should still be one step (times rounded to same value)
|
||||
assert len(enabled_tracker) == 1
|
||||
steps = enabled_tracker.get_all_steps()
|
||||
assert torch.equal(steps[0].x_t, sample_tensors["x_t"])
|
||||
assert torch.equal(steps[0].v_t, sample_tensors["v_t"])
|
||||
|
||||
|
||||
def test_track_does_nothing_when_disabled(disabled_tracker, sample_tensors):
|
||||
"""Test that track does nothing when tracker is disabled."""
|
||||
disabled_tracker.track(time=1.0, x_t=sample_tensors["x_t"])
|
||||
assert len(disabled_tracker) == 0
|
||||
|
||||
|
||||
def test_track_with_metadata(enabled_tracker, sample_tensors):
|
||||
"""Test track stores custom metadata."""
|
||||
enabled_tracker.track(time=0.7, x_t=sample_tensors["x_t"], custom_field="custom_value", count=42)
|
||||
|
||||
steps = enabled_tracker.get_all_steps()
|
||||
assert steps[0].metadata["custom_field"] == "custom_value"
|
||||
assert steps[0].metadata["count"] == 42
|
||||
|
||||
|
||||
def test_track_updates_metadata(enabled_tracker):
|
||||
"""Test that track updates metadata for existing steps."""
|
||||
enabled_tracker.track(time=0.6, meta1="value1")
|
||||
enabled_tracker.track(time=0.6, meta2="value2")
|
||||
|
||||
steps = enabled_tracker.get_all_steps()
|
||||
assert steps[0].metadata["meta1"] == "value1"
|
||||
assert steps[0].metadata["meta2"] == "value2"
|
||||
|
||||
|
||||
def test_track_clones_tensors(enabled_tracker, sample_tensors):
|
||||
"""Test that track clones tensors instead of storing references."""
|
||||
x_t_original = sample_tensors["x_t"].clone()
|
||||
enabled_tracker.track(time=0.5, x_t=sample_tensors["x_t"])
|
||||
|
||||
# Modify original tensor
|
||||
sample_tensors["x_t"].fill_(999.0)
|
||||
|
||||
# Tracked tensor should not be affected
|
||||
steps = enabled_tracker.get_all_steps()
|
||||
assert not torch.equal(steps[0].x_t, sample_tensors["x_t"])
|
||||
assert torch.equal(steps[0].x_t, x_t_original)
|
||||
|
||||
|
||||
def test_track_with_none_values(enabled_tracker):
|
||||
"""Test track handles None values correctly."""
|
||||
enabled_tracker.track(
|
||||
time=0.4,
|
||||
x_t=None,
|
||||
v_t=None,
|
||||
guidance_weight=None,
|
||||
inference_delay=None,
|
||||
)
|
||||
|
||||
steps = enabled_tracker.get_all_steps()
|
||||
assert len(steps) == 1
|
||||
assert steps[0].x_t is None
|
||||
assert steps[0].v_t is None
|
||||
assert steps[0].guidance_weight is None
|
||||
assert steps[0].inference_delay is None
|
||||
|
||||
|
||||
def test_track_updates_only_non_none_fields(enabled_tracker, sample_tensors):
|
||||
"""Test that update preserves existing values when None is passed."""
|
||||
# Create step with x_t
|
||||
enabled_tracker.track(time=0.3, x_t=sample_tensors["x_t"], guidance_weight=2.0)
|
||||
|
||||
# Update with v_t only (pass None for other fields)
|
||||
enabled_tracker.track(time=0.3, v_t=sample_tensors["v_t"], x_t=None, guidance_weight=None)
|
||||
|
||||
# Original values should be preserved
|
||||
steps = enabled_tracker.get_all_steps()
|
||||
assert torch.equal(steps[0].x_t, sample_tensors["x_t"]) # Still has x_t
|
||||
assert torch.equal(steps[0].v_t, sample_tensors["v_t"]) # Now has v_t
|
||||
assert steps[0].guidance_weight == 2.0 # Still has guidance_weight
|
||||
|
||||
|
||||
# ====================== Tracker.maxlen Tests ======================
|
||||
|
||||
|
||||
def test_tracker_enforces_maxlen():
|
||||
"""Test that tracker enforces maxlen limit."""
|
||||
tracker = Tracker(enabled=True, maxlen=3)
|
||||
|
||||
# Add 5 steps
|
||||
for i in range(5):
|
||||
time = 1.0 - i * 0.1 # 1.0, 0.9, 0.8, 0.7, 0.6
|
||||
tracker.track(time=time, x_t=torch.randn(1, 10, 6))
|
||||
|
||||
# Should only keep the last 3
|
||||
assert len(tracker) == 3
|
||||
|
||||
# Verify oldest steps were removed (should have 0.6, 0.7, 0.8)
|
||||
steps = tracker.get_all_steps()
|
||||
times = sorted([step.time for step in steps])
|
||||
assert times == [0.6, 0.7, 0.8]
|
||||
|
||||
|
||||
def test_tracker_step_idx_increments_despite_maxlen():
|
||||
"""Test that step_idx continues incrementing even when maxlen is enforced."""
|
||||
tracker = Tracker(enabled=True, maxlen=2)
|
||||
|
||||
# Add 4 steps
|
||||
for i in range(4):
|
||||
time = 1.0 - i * 0.1
|
||||
tracker.track(time=time, x_t=torch.randn(1, 10, 6))
|
||||
|
||||
# Should have 2 steps with step_idx 2 and 3 (oldest removed)
|
||||
steps = sorted(tracker.get_all_steps(), key=lambda s: s.step_idx)
|
||||
assert len(steps) == 2
|
||||
assert steps[0].step_idx == 2
|
||||
assert steps[1].step_idx == 3
|
||||
|
||||
|
||||
def test_tracker_without_maxlen_keeps_all():
|
||||
"""Test that tracker without maxlen keeps all steps."""
|
||||
tracker = Tracker(enabled=True, maxlen=None)
|
||||
|
||||
# Add 100 steps
|
||||
for i in range(100):
|
||||
time = 1.0 - i * 0.01
|
||||
tracker.track(time=time, x_t=torch.randn(1, 10, 6))
|
||||
|
||||
assert len(tracker) == 100
|
||||
|
||||
|
||||
def test_get_all_steps_returns_empty_when_disabled(disabled_tracker):
|
||||
"""Test get_all_steps returns empty list when disabled."""
|
||||
steps = disabled_tracker.get_all_steps()
|
||||
assert steps == []
|
||||
assert isinstance(steps, list)
|
||||
|
||||
|
||||
def test_get_all_steps_returns_empty_when_no_steps(enabled_tracker):
|
||||
"""Test get_all_steps returns empty list when no steps tracked."""
|
||||
steps = enabled_tracker.get_all_steps()
|
||||
assert steps == []
|
||||
|
||||
|
||||
def test_get_all_steps_returns_all_tracked_steps(enabled_tracker, sample_tensors):
|
||||
"""Test get_all_steps returns all tracked steps."""
|
||||
# Track 5 steps
|
||||
for i in range(5):
|
||||
time = 1.0 - i * 0.1
|
||||
enabled_tracker.track(time=time, x_t=sample_tensors["x_t"])
|
||||
|
||||
steps = enabled_tracker.get_all_steps()
|
||||
assert len(steps) == 5
|
||||
|
||||
# Verify all are DebugStep instances
|
||||
for step in steps:
|
||||
assert isinstance(step, DebugStep)
|
||||
|
||||
|
||||
def test_get_all_steps_preserves_insertion_order(enabled_tracker):
|
||||
"""Test that get_all_steps preserves insertion order (Python 3.7+)."""
|
||||
times = [0.9, 0.8, 0.7, 0.6, 0.5]
|
||||
for time in times:
|
||||
enabled_tracker.track(time=time, x_t=torch.randn(1, 10, 6))
|
||||
|
||||
steps = enabled_tracker.get_all_steps()
|
||||
retrieved_times = [step.time for step in steps]
|
||||
|
||||
# Should be in insertion order
|
||||
assert retrieved_times == times
|
||||
|
||||
|
||||
# ====================== Tracker.__len__() Tests ======================
|
||||
|
||||
|
||||
def test_len_returns_zero_when_disabled(disabled_tracker):
|
||||
"""Test __len__ returns 0 when tracker is disabled."""
|
||||
assert len(disabled_tracker) == 0
|
||||
|
||||
|
||||
def test_len_returns_zero_when_empty(enabled_tracker):
|
||||
"""Test __len__ returns 0 when no steps are tracked."""
|
||||
assert len(enabled_tracker) == 0
|
||||
|
||||
|
||||
def test_len_returns_correct_count(enabled_tracker, sample_tensors):
|
||||
"""Test __len__ returns correct number of tracked steps."""
|
||||
assert len(enabled_tracker) == 0
|
||||
|
||||
enabled_tracker.track(time=1.0, x_t=sample_tensors["x_t"])
|
||||
assert len(enabled_tracker) == 1
|
||||
|
||||
enabled_tracker.track(time=0.9, x_t=sample_tensors["x_t"])
|
||||
assert len(enabled_tracker) == 2
|
||||
|
||||
enabled_tracker.track(time=0.8, x_t=sample_tensors["x_t"])
|
||||
assert len(enabled_tracker) == 3
|
||||
|
||||
|
||||
def test_len_after_reset(enabled_tracker, sample_tensors):
|
||||
"""Test __len__ returns 0 after reset."""
|
||||
enabled_tracker.track(time=1.0, x_t=sample_tensors["x_t"])
|
||||
enabled_tracker.track(time=0.9, x_t=sample_tensors["x_t"])
|
||||
assert len(enabled_tracker) == 2
|
||||
|
||||
enabled_tracker.reset()
|
||||
assert len(enabled_tracker) == 0
|
||||
|
||||
|
||||
@pytest.mark.skipif(not torch.cuda.is_available(), reason="CUDA not available")
|
||||
def test_tracker_handles_gpu_tensors():
|
||||
"""Test tracker correctly handles GPU tensors."""
|
||||
tracker = Tracker(enabled=True, maxlen=10)
|
||||
x_t_gpu = torch.randn(1, 50, 6, device="cuda")
|
||||
|
||||
tracker.track(time=1.0, x_t=x_t_gpu)
|
||||
|
||||
steps = tracker.get_all_steps()
|
||||
# Tracker should clone and detach tensors
|
||||
assert steps[0].x_t.device.type == "cuda"
|
||||
|
||||
|
||||
def test_tracker_with_varying_tensor_shapes(enabled_tracker):
|
||||
"""Test tracker handles varying tensor shapes across steps."""
|
||||
enabled_tracker.track(time=1.0, x_t=torch.randn(1, 50, 6))
|
||||
enabled_tracker.track(time=0.9, x_t=torch.randn(1, 25, 6))
|
||||
enabled_tracker.track(time=0.8, x_t=torch.randn(2, 50, 8))
|
||||
|
||||
steps = enabled_tracker.get_all_steps()
|
||||
assert len(steps) == 3
|
||||
assert steps[0].x_t.shape == (1, 50, 6)
|
||||
assert steps[1].x_t.shape == (1, 25, 6)
|
||||
assert steps[2].x_t.shape == (2, 50, 8)
|
||||
@@ -0,0 +1,322 @@
|
||||
#!/usr/bin/env python
|
||||
|
||||
# Copyright 2025 The HuggingFace Inc. team. All rights reserved.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
"""Tests for RTC LatencyTracker module."""
|
||||
|
||||
import pytest
|
||||
|
||||
from lerobot.policies.rtc.latency_tracker import LatencyTracker
|
||||
|
||||
# ====================== Fixtures ======================
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def tracker():
|
||||
"""Create a LatencyTracker with default maxlen."""
|
||||
return LatencyTracker(maxlen=100)
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def small_tracker():
|
||||
"""Create a LatencyTracker with small maxlen for overflow testing."""
|
||||
return LatencyTracker(maxlen=5)
|
||||
|
||||
|
||||
# ====================== Initialization Tests ======================
|
||||
|
||||
|
||||
def test_latency_tracker_initialization():
|
||||
"""Test LatencyTracker initializes correctly."""
|
||||
tracker = LatencyTracker(maxlen=50)
|
||||
assert len(tracker) == 0
|
||||
assert tracker.max_latency == 0.0
|
||||
assert tracker.max() == 0.0
|
||||
|
||||
|
||||
def test_latency_tracker_default_maxlen():
|
||||
"""Test LatencyTracker uses default maxlen."""
|
||||
tracker = LatencyTracker()
|
||||
# Should accept default maxlen=100
|
||||
assert len(tracker) == 0
|
||||
|
||||
|
||||
# ====================== add() Tests ======================
|
||||
|
||||
|
||||
def test_add_single_latency(tracker):
|
||||
"""Test adding a single latency value."""
|
||||
tracker.add(0.5)
|
||||
assert len(tracker) == 1
|
||||
assert tracker.max() == 0.5
|
||||
|
||||
|
||||
def test_add_multiple_latencies(tracker):
|
||||
"""Test adding multiple latency values."""
|
||||
latencies = [0.1, 0.5, 0.3, 0.8, 0.2]
|
||||
for lat in latencies:
|
||||
tracker.add(lat)
|
||||
|
||||
assert len(tracker) == 5
|
||||
assert tracker.max() == 0.8
|
||||
|
||||
|
||||
def test_add_negative_latency_ignored(tracker):
|
||||
"""Test that negative latencies are ignored."""
|
||||
tracker.add(0.5)
|
||||
tracker.add(-0.1)
|
||||
tracker.add(0.3)
|
||||
|
||||
# Should only have 2 valid latencies
|
||||
assert len(tracker) == 2
|
||||
assert tracker.max() == 0.5
|
||||
|
||||
|
||||
def test_add_zero_latency(tracker):
|
||||
"""Test adding zero latency."""
|
||||
tracker.add(0.0)
|
||||
assert len(tracker) == 1
|
||||
assert tracker.max() == 0.0
|
||||
|
||||
|
||||
def test_add_converts_to_float(tracker):
|
||||
"""Test add() converts input to float."""
|
||||
tracker.add(5) # Integer
|
||||
tracker.add("3.5") # String
|
||||
|
||||
assert len(tracker) == 2
|
||||
assert tracker.max() == 5.0
|
||||
|
||||
|
||||
def test_add_updates_max_latency(tracker):
|
||||
"""Test that max_latency is updated correctly."""
|
||||
tracker.add(0.5)
|
||||
assert tracker.max_latency == 0.5
|
||||
|
||||
tracker.add(0.3)
|
||||
assert tracker.max_latency == 0.5 # Should not decrease
|
||||
|
||||
tracker.add(0.9)
|
||||
assert tracker.max_latency == 0.9 # Should increase
|
||||
|
||||
|
||||
# ====================== reset() Tests ======================
|
||||
|
||||
|
||||
def test_reset_clears_values(tracker):
|
||||
"""Test reset() clears all values."""
|
||||
tracker.add(0.5)
|
||||
tracker.add(0.8)
|
||||
tracker.add(0.3)
|
||||
assert len(tracker) == 3
|
||||
|
||||
tracker.reset()
|
||||
assert len(tracker) == 0
|
||||
assert tracker.max_latency == 0.0
|
||||
|
||||
|
||||
def test_reset_clears_max_latency(tracker):
|
||||
"""Test reset() resets max_latency."""
|
||||
tracker.add(1.5)
|
||||
assert tracker.max_latency == 1.5
|
||||
|
||||
tracker.reset()
|
||||
assert tracker.max_latency == 0.0
|
||||
|
||||
|
||||
def test_reset_allows_new_values(tracker):
|
||||
"""Test that tracker works correctly after reset."""
|
||||
tracker.add(0.5)
|
||||
tracker.reset()
|
||||
|
||||
tracker.add(0.3)
|
||||
assert len(tracker) == 1
|
||||
assert tracker.max() == 0.3
|
||||
|
||||
|
||||
# ====================== max() Tests ======================
|
||||
|
||||
|
||||
def test_max_returns_zero_when_empty(tracker):
|
||||
"""Test max() returns 0.0 when tracker is empty."""
|
||||
assert tracker.max() == 0.0
|
||||
|
||||
|
||||
def test_max_returns_maximum_value(tracker):
|
||||
"""Test max() returns the maximum latency."""
|
||||
latencies = [0.2, 0.8, 0.3, 0.5, 0.1]
|
||||
for lat in latencies:
|
||||
tracker.add(lat)
|
||||
|
||||
assert tracker.max() == 0.8
|
||||
|
||||
|
||||
def test_max_persists_after_sliding_window(small_tracker):
|
||||
"""Test max() persists even after values slide out of window."""
|
||||
# Add values that will exceed maxlen=5
|
||||
small_tracker.add(0.1)
|
||||
small_tracker.add(0.9) # This is max
|
||||
small_tracker.add(0.2)
|
||||
small_tracker.add(0.3)
|
||||
small_tracker.add(0.4)
|
||||
small_tracker.add(0.5) # This pushes out 0.1
|
||||
|
||||
# Max should still be 0.9 even though only last 5 values kept
|
||||
assert small_tracker.max() == 0.9
|
||||
|
||||
|
||||
def test_max_after_reset(tracker):
|
||||
"""Test max() returns 0.0 after reset."""
|
||||
tracker.add(1.5)
|
||||
tracker.reset()
|
||||
assert tracker.max() == 0.0
|
||||
|
||||
|
||||
# ====================== p95() Tests ======================
|
||||
|
||||
|
||||
def test_p95_returns_zero_when_empty(tracker):
|
||||
"""Test p95() returns 0.0 when tracker is empty."""
|
||||
assert tracker.p95() == 0.0
|
||||
|
||||
|
||||
def test_p95_returns_95th_percentile(tracker):
|
||||
"""Test p95() returns the 95th percentile."""
|
||||
# Add 100 values
|
||||
for i in range(100):
|
||||
tracker.add(i / 100.0)
|
||||
|
||||
p95 = tracker.p95()
|
||||
assert 0.93 <= p95 <= 0.96
|
||||
|
||||
|
||||
def test_p95_equals_percentile_95(tracker):
|
||||
"""Test p95() equals percentile(0.95)."""
|
||||
for i in range(50):
|
||||
tracker.add(i / 50.0)
|
||||
|
||||
assert tracker.p95() == tracker.percentile(0.95)
|
||||
|
||||
|
||||
# ====================== Edge Cases Tests ======================
|
||||
|
||||
|
||||
def test_single_value(tracker):
|
||||
"""Test tracker behavior with single value."""
|
||||
tracker.add(0.75)
|
||||
|
||||
assert len(tracker) == 1
|
||||
assert tracker.max() == 0.75
|
||||
assert tracker.percentile(0.0) == 0.75
|
||||
assert tracker.percentile(0.5) == 0.75
|
||||
assert tracker.percentile(1.0) == 0.75
|
||||
|
||||
|
||||
def test_all_same_values(tracker):
|
||||
"""Test tracker with all identical values."""
|
||||
for _ in range(10):
|
||||
tracker.add(0.5)
|
||||
|
||||
assert len(tracker) == 10
|
||||
assert tracker.max() == 0.5
|
||||
assert tracker.percentile(0.0) == 0.5
|
||||
assert tracker.percentile(0.5) == 0.5
|
||||
assert tracker.percentile(1.0) == 0.5
|
||||
|
||||
|
||||
def test_very_small_values(tracker):
|
||||
"""Test tracker with very small float values."""
|
||||
tracker.add(1e-10)
|
||||
tracker.add(2e-10)
|
||||
tracker.add(3e-10)
|
||||
|
||||
assert len(tracker) == 3
|
||||
assert tracker.max() == pytest.approx(3e-10)
|
||||
|
||||
|
||||
def test_very_large_values(tracker):
|
||||
"""Test tracker with very large float values."""
|
||||
tracker.add(1e10)
|
||||
tracker.add(2e10)
|
||||
tracker.add(3e10)
|
||||
|
||||
assert len(tracker) == 3
|
||||
assert tracker.max() == pytest.approx(3e10)
|
||||
|
||||
|
||||
# ====================== Integration Tests ======================
|
||||
|
||||
|
||||
def test_typical_usage_pattern(tracker):
|
||||
"""Test a typical usage pattern of the tracker."""
|
||||
# Simulate adding latencies over time
|
||||
latencies = [0.05, 0.08, 0.12, 0.07, 0.15, 0.09, 0.11, 0.06, 0.14, 0.10]
|
||||
|
||||
for lat in latencies:
|
||||
tracker.add(lat)
|
||||
|
||||
# Check statistics
|
||||
assert len(tracker) == 10
|
||||
assert tracker.max() == 0.15
|
||||
|
||||
# p95 should be close to max since we have only 10 values
|
||||
p95 = tracker.p95()
|
||||
assert p95 >= tracker.percentile(0.5) # p95 should be >= median
|
||||
assert p95 <= tracker.max() # p95 should be <= max
|
||||
|
||||
|
||||
def test_reset_and_reuse(tracker):
|
||||
"""Test resetting and reusing tracker."""
|
||||
# First batch
|
||||
tracker.add(1.0)
|
||||
tracker.add(2.0)
|
||||
assert tracker.max() == 2.0
|
||||
|
||||
# Reset
|
||||
tracker.reset()
|
||||
|
||||
# Second batch
|
||||
tracker.add(0.5)
|
||||
tracker.add(0.8)
|
||||
assert len(tracker) == 2
|
||||
assert tracker.max() == 0.8
|
||||
assert tracker.percentile(0.5) <= 0.8
|
||||
|
||||
|
||||
# ====================== Type Conversion Tests ======================
|
||||
|
||||
|
||||
def test_add_with_integer(tracker):
|
||||
"""Test adding integer values."""
|
||||
tracker.add(5)
|
||||
assert len(tracker) == 1
|
||||
assert tracker.max() == 5.0
|
||||
|
||||
|
||||
def test_add_with_string_number(tracker):
|
||||
"""Test adding string representation of number."""
|
||||
tracker.add("3.14")
|
||||
assert len(tracker) == 1
|
||||
assert tracker.max() == pytest.approx(3.14)
|
||||
|
||||
|
||||
def test_percentile_converts_q_to_float(tracker):
|
||||
"""Test percentile converts q parameter to float."""
|
||||
tracker.add(0.5)
|
||||
tracker.add(0.8)
|
||||
|
||||
# Pass integer q
|
||||
result = tracker.percentile(1)
|
||||
assert result == 0.8
|
||||
@@ -0,0 +1,773 @@
|
||||
#!/usr/bin/env python
|
||||
|
||||
# Copyright 2025 The HuggingFace Inc. team. All rights reserved.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
"""Tests for RTC modeling module (RTCProcessor)."""
|
||||
|
||||
import pytest
|
||||
import torch
|
||||
|
||||
from lerobot.configs.types import RTCAttentionSchedule
|
||||
from lerobot.policies.rtc.configuration_rtc import RTCConfig
|
||||
from lerobot.policies.rtc.modeling_rtc import RTCProcessor
|
||||
|
||||
# ====================== Fixtures ======================
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def rtc_config_debug_enabled():
|
||||
"""Create RTC config with debug enabled."""
|
||||
return RTCConfig(
|
||||
enabled=True,
|
||||
prefix_attention_schedule=RTCAttentionSchedule.LINEAR,
|
||||
max_guidance_weight=10.0,
|
||||
execution_horizon=10,
|
||||
debug=True,
|
||||
debug_maxlen=100,
|
||||
)
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def rtc_config_debug_disabled():
|
||||
"""Create RTC config with debug disabled."""
|
||||
return RTCConfig(
|
||||
enabled=True,
|
||||
prefix_attention_schedule=RTCAttentionSchedule.LINEAR,
|
||||
max_guidance_weight=10.0,
|
||||
execution_horizon=10,
|
||||
debug=False,
|
||||
)
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def rtc_processor_debug_enabled(rtc_config_debug_enabled):
|
||||
"""Create RTCProcessor with debug enabled."""
|
||||
return RTCProcessor(rtc_config_debug_enabled)
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def rtc_processor_debug_disabled(rtc_config_debug_disabled):
|
||||
"""Create RTCProcessor with debug disabled."""
|
||||
return RTCProcessor(rtc_config_debug_disabled)
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def sample_x_t():
|
||||
"""Create sample x_t tensor (batch, time, action_dim)."""
|
||||
return torch.randn(1, 50, 6)
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def sample_prev_chunk():
|
||||
"""Create sample previous chunk tensor."""
|
||||
return torch.randn(1, 50, 6)
|
||||
|
||||
|
||||
# ====================== Initialization Tests ======================
|
||||
|
||||
|
||||
def test_rtc_processor_initialization_with_debug(rtc_config_debug_enabled):
|
||||
"""Test RTCProcessor initializes with debug tracker."""
|
||||
processor = RTCProcessor(rtc_config_debug_enabled)
|
||||
assert processor.rtc_config == rtc_config_debug_enabled
|
||||
assert processor.tracker is not None
|
||||
assert processor.tracker.enabled is True
|
||||
|
||||
|
||||
def test_rtc_processor_initialization_without_debug(rtc_config_debug_disabled):
|
||||
"""Test RTCProcessor initializes without debug tracker."""
|
||||
processor = RTCProcessor(rtc_config_debug_disabled)
|
||||
assert processor.rtc_config == rtc_config_debug_disabled
|
||||
assert processor.tracker is None
|
||||
|
||||
|
||||
# ====================== Tracker Proxy Methods Tests ======================
|
||||
|
||||
|
||||
def test_track_when_tracker_enabled(rtc_processor_debug_enabled, sample_x_t):
|
||||
"""Test track() forwards to tracker when enabled."""
|
||||
rtc_processor_debug_enabled.track(
|
||||
time=torch.tensor(0.5),
|
||||
x_t=sample_x_t,
|
||||
v_t=sample_x_t,
|
||||
guidance_weight=2.0,
|
||||
)
|
||||
|
||||
# Should have tracked one step
|
||||
steps = rtc_processor_debug_enabled.get_all_debug_steps()
|
||||
assert len(steps) == 1
|
||||
assert steps[0].time == 0.5
|
||||
|
||||
|
||||
def test_track_when_tracker_disabled(rtc_processor_debug_disabled, sample_x_t):
|
||||
"""Test track() does nothing when tracker disabled."""
|
||||
# Should not raise error
|
||||
rtc_processor_debug_disabled.track(
|
||||
time=torch.tensor(0.5),
|
||||
x_t=sample_x_t,
|
||||
v_t=sample_x_t,
|
||||
)
|
||||
|
||||
# Should return empty list
|
||||
steps = rtc_processor_debug_disabled.get_all_debug_steps()
|
||||
assert len(steps) == 0
|
||||
|
||||
|
||||
def test_get_all_debug_steps_when_enabled(rtc_processor_debug_enabled, sample_x_t):
|
||||
"""Test get_all_debug_steps() returns tracked steps."""
|
||||
rtc_processor_debug_enabled.track(time=torch.tensor(0.5), x_t=sample_x_t)
|
||||
rtc_processor_debug_enabled.track(time=torch.tensor(0.4), x_t=sample_x_t)
|
||||
|
||||
steps = rtc_processor_debug_enabled.get_all_debug_steps()
|
||||
assert len(steps) == 2
|
||||
|
||||
|
||||
def test_get_all_debug_steps_when_disabled(rtc_processor_debug_disabled):
|
||||
"""Test get_all_debug_steps() returns empty list when disabled."""
|
||||
steps = rtc_processor_debug_disabled.get_all_debug_steps()
|
||||
assert steps == []
|
||||
assert isinstance(steps, list)
|
||||
|
||||
|
||||
def test_is_debug_enabled_when_tracker_exists(rtc_processor_debug_enabled):
|
||||
"""Test is_debug_enabled() returns True when tracker enabled."""
|
||||
assert rtc_processor_debug_enabled.is_debug_enabled() is True
|
||||
|
||||
|
||||
def test_is_debug_enabled_when_tracker_disabled(rtc_processor_debug_disabled):
|
||||
"""Test is_debug_enabled() returns False when tracker disabled."""
|
||||
assert rtc_processor_debug_disabled.is_debug_enabled() is False
|
||||
|
||||
|
||||
def test_reset_tracker_when_enabled(rtc_processor_debug_enabled, sample_x_t):
|
||||
"""Test reset_tracker() clears tracked steps."""
|
||||
rtc_processor_debug_enabled.track(time=torch.tensor(0.5), x_t=sample_x_t)
|
||||
rtc_processor_debug_enabled.track(time=torch.tensor(0.4), x_t=sample_x_t)
|
||||
assert len(rtc_processor_debug_enabled.get_all_debug_steps()) == 2
|
||||
|
||||
rtc_processor_debug_enabled.reset_tracker()
|
||||
assert len(rtc_processor_debug_enabled.get_all_debug_steps()) == 0
|
||||
|
||||
|
||||
def test_reset_tracker_when_disabled(rtc_processor_debug_disabled):
|
||||
"""Test reset_tracker() doesn't error when tracker disabled."""
|
||||
rtc_processor_debug_disabled.reset_tracker() # Should not raise
|
||||
|
||||
|
||||
# ====================== get_prefix_weights Tests ======================
|
||||
|
||||
|
||||
def test_get_prefix_weights_zeros_schedule():
|
||||
"""Test get_prefix_weights with ZEROS schedule."""
|
||||
config = RTCConfig(prefix_attention_schedule=RTCAttentionSchedule.ZEROS)
|
||||
processor = RTCProcessor(config)
|
||||
|
||||
weights = processor.get_prefix_weights(start=5, end=10, total=20)
|
||||
|
||||
# First 5 should be 1.0, rest should be 0.0
|
||||
assert weights.shape == (20,)
|
||||
assert torch.all(weights[:5] == 1.0)
|
||||
assert torch.all(weights[5:] == 0.0)
|
||||
|
||||
|
||||
def test_get_prefix_weights_ones_schedule():
|
||||
"""Test get_prefix_weights with ONES schedule."""
|
||||
config = RTCConfig(prefix_attention_schedule=RTCAttentionSchedule.ONES)
|
||||
processor = RTCProcessor(config)
|
||||
|
||||
weights = processor.get_prefix_weights(start=5, end=15, total=20)
|
||||
|
||||
# First 15 should be 1.0, rest should be 0.0
|
||||
assert weights.shape == (20,)
|
||||
assert torch.all(weights[:15] == 1.0)
|
||||
assert torch.all(weights[15:] == 0.0)
|
||||
|
||||
|
||||
def test_get_prefix_weights_linear_schedule():
|
||||
"""Test get_prefix_weights with LINEAR schedule."""
|
||||
config = RTCConfig(prefix_attention_schedule=RTCAttentionSchedule.LINEAR)
|
||||
processor = RTCProcessor(config)
|
||||
|
||||
weights = processor.get_prefix_weights(start=5, end=14, total=25)
|
||||
|
||||
# Should have shape (20,)
|
||||
assert weights.shape == (25,)
|
||||
|
||||
# First 5 should be 1.0 (leading ones)
|
||||
assert torch.all(weights[:5] == 1.0)
|
||||
|
||||
# Middle section (5:15) should be linearly decreasing from 1 to 0
|
||||
middle_weights = torch.tensor([0.9, 0.8, 0.7, 0.6, 0.5, 0.4, 0.3, 0.2, 0.1])
|
||||
assert torch.allclose(weights[5:14], middle_weights)
|
||||
|
||||
# Last 5 should be 0.0 (trailing zeros)
|
||||
assert torch.all(weights[14:] == 0.0)
|
||||
|
||||
|
||||
def test_get_prefix_weights_exp_schedule():
|
||||
"""Test get_prefix_weights with EXP schedule."""
|
||||
config = RTCConfig(prefix_attention_schedule=RTCAttentionSchedule.EXP)
|
||||
processor = RTCProcessor(config)
|
||||
|
||||
weights = processor.get_prefix_weights(start=5, end=14, total=25)
|
||||
|
||||
# Should have shape (20,)
|
||||
assert weights.shape == (25,)
|
||||
|
||||
# First 5 should be 1.0 (leading ones)
|
||||
assert torch.all(weights[:5] == 1.0)
|
||||
|
||||
# Middle section should be exponentially weighted
|
||||
middle_weights = torch.tensor([0.7645, 0.5706, 0.4130, 0.2871, 0.1888, 0.1145, 0.0611, 0.0258, 0.0061])
|
||||
assert torch.allclose(weights[5:14], middle_weights, atol=1e-4)
|
||||
|
||||
# Last 5 should be 0.0 (trailing zeros)
|
||||
assert torch.all(weights[14:] == 0.0)
|
||||
|
||||
|
||||
def test_get_prefix_weights_with_start_equals_end():
|
||||
"""Test get_prefix_weights when start equals end."""
|
||||
config = RTCConfig(prefix_attention_schedule=RTCAttentionSchedule.LINEAR)
|
||||
processor = RTCProcessor(config)
|
||||
|
||||
weights = processor.get_prefix_weights(start=10, end=10, total=20)
|
||||
|
||||
# Should have ones up to start, then zeros
|
||||
assert torch.all(weights[:10] == 1.0)
|
||||
assert torch.all(weights[10:] == 0.0)
|
||||
|
||||
|
||||
def test_get_prefix_weights_with_start_greater_than_end():
|
||||
"""Test get_prefix_weights when start > end (gets clamped)."""
|
||||
config = RTCConfig(prefix_attention_schedule=RTCAttentionSchedule.LINEAR)
|
||||
processor = RTCProcessor(config)
|
||||
|
||||
# start > end should use min(start, end) = end
|
||||
weights = processor.get_prefix_weights(start=15, end=10, total=20)
|
||||
|
||||
# Should have ones up to end (10), then zeros
|
||||
assert torch.all(weights[:10] == 1.0)
|
||||
assert torch.all(weights[10:] == 0.0)
|
||||
|
||||
|
||||
# ====================== Helper Method Tests ======================
|
||||
|
||||
|
||||
def test_linweights_with_end_equals_start():
|
||||
"""Test _linweights when end equals start."""
|
||||
config = RTCConfig()
|
||||
processor = RTCProcessor(config)
|
||||
|
||||
weights = processor._linweights(start=10, end=10, total=20)
|
||||
|
||||
# Should return empty tensor
|
||||
assert len(weights) == 0
|
||||
|
||||
|
||||
def test_linweights_with_end_less_than_start():
|
||||
"""Test _linweights when end < start."""
|
||||
config = RTCConfig()
|
||||
processor = RTCProcessor(config)
|
||||
|
||||
weights = processor._linweights(start=15, end=10, total=20)
|
||||
|
||||
# Should return empty tensor
|
||||
assert len(weights) == 0
|
||||
|
||||
|
||||
def test_add_trailing_zeros_normal():
|
||||
"""Test _add_trailing_zeros adds zeros correctly."""
|
||||
config = RTCConfig()
|
||||
processor = RTCProcessor(config)
|
||||
|
||||
weights = torch.tensor([1.0, 0.8, 0.6, 0.4, 0.2])
|
||||
result = processor._add_trailing_zeros(weights, total=10, end=5)
|
||||
|
||||
# Should add 5 zeros (total - end = 10 - 5 = 5)
|
||||
assert len(result) == 10
|
||||
assert torch.all(result[:5] == weights)
|
||||
assert torch.all(result[5:] == 0.0)
|
||||
|
||||
|
||||
def test_add_trailing_zeros_no_zeros_needed():
|
||||
"""Test _add_trailing_zeros when no zeros needed."""
|
||||
config = RTCConfig()
|
||||
processor = RTCProcessor(config)
|
||||
|
||||
weights = torch.tensor([1.0, 0.8, 0.6])
|
||||
result = processor._add_trailing_zeros(weights, total=3, end=5)
|
||||
|
||||
# zeros_len = 3 - 5 = -2 <= 0, so no zeros added
|
||||
assert torch.equal(result, weights)
|
||||
|
||||
|
||||
def test_add_leading_ones_normal():
|
||||
"""Test _add_leading_ones adds ones correctly."""
|
||||
config = RTCConfig()
|
||||
processor = RTCProcessor(config)
|
||||
|
||||
weights = torch.tensor([0.8, 0.6, 0.4, 0.2, 0.0])
|
||||
result = processor._add_leading_ones(weights, start=3, total=10)
|
||||
|
||||
# Should add 3 ones at the start
|
||||
assert len(result) == 8
|
||||
assert torch.all(result[:3] == 1.0)
|
||||
assert torch.all(result[3:] == weights)
|
||||
|
||||
|
||||
def test_add_leading_ones_no_ones_needed():
|
||||
"""Test _add_leading_ones when no ones needed."""
|
||||
config = RTCConfig()
|
||||
processor = RTCProcessor(config)
|
||||
|
||||
weights = torch.tensor([0.8, 0.6, 0.4])
|
||||
result = processor._add_leading_ones(weights, start=0, total=10)
|
||||
|
||||
# ones_len = 0, so no ones added
|
||||
assert torch.equal(result, weights)
|
||||
|
||||
|
||||
def test_get_prefix_weights_with_start_equals_total():
|
||||
"""Test get_prefix_weights when start equals total."""
|
||||
config = RTCConfig(prefix_attention_schedule=RTCAttentionSchedule.LINEAR)
|
||||
processor = RTCProcessor(config)
|
||||
|
||||
weights = processor.get_prefix_weights(start=10, end=10, total=20)
|
||||
|
||||
# Should have ones up to start, then zeros
|
||||
assert len(weights) == 20
|
||||
assert torch.all(weights[:10] == 1.0)
|
||||
assert torch.all(weights[10:] == 0.0)
|
||||
|
||||
|
||||
def test_get_prefix_weights_with_total_less_than_start():
|
||||
"""Test get_prefix_weights when total less than start."""
|
||||
config = RTCConfig(prefix_attention_schedule=RTCAttentionSchedule.LINEAR)
|
||||
processor = RTCProcessor(config)
|
||||
|
||||
weights = processor.get_prefix_weights(start=10, end=10, total=5)
|
||||
|
||||
# Should have ones up to start, then zeros
|
||||
assert len(weights) == 5
|
||||
assert torch.all(weights == 1.0)
|
||||
|
||||
|
||||
# ====================== denoise_step Tests ======================
|
||||
|
||||
|
||||
def test_denoise_step_without_prev_chunk(rtc_processor_debug_disabled):
|
||||
"""Test denoise_step without previous chunk (no guidance)."""
|
||||
x_t = torch.randn(1, 50, 6)
|
||||
|
||||
# Mock denoiser that returns fixed velocity
|
||||
def mock_denoiser(x):
|
||||
return torch.ones_like(x) * 0.5
|
||||
|
||||
result = rtc_processor_debug_disabled.denoise_step(
|
||||
x_t=x_t,
|
||||
prev_chunk_left_over=None,
|
||||
inference_delay=5,
|
||||
time=torch.tensor(0.5),
|
||||
original_denoise_step_partial=mock_denoiser,
|
||||
)
|
||||
|
||||
# Should return v_t unchanged (no guidance)
|
||||
expected = mock_denoiser(x_t)
|
||||
assert torch.allclose(result, expected)
|
||||
|
||||
|
||||
def test_denoise_step_with_prev_chunk(rtc_processor_debug_disabled):
|
||||
"""Test denoise_step with previous chunk applies guidance."""
|
||||
x_t = torch.ones(1, 20, 1)
|
||||
prev_chunk = torch.full((1, 20, 1), 0.1)
|
||||
|
||||
def mock_denoiser(x):
|
||||
return x * 0.5
|
||||
|
||||
result = rtc_processor_debug_disabled.denoise_step(
|
||||
x_t=x_t,
|
||||
prev_chunk_left_over=prev_chunk,
|
||||
inference_delay=5,
|
||||
time=torch.tensor(0.5),
|
||||
original_denoise_step_partial=mock_denoiser,
|
||||
)
|
||||
|
||||
expected_result = torch.tensor(
|
||||
[
|
||||
[
|
||||
[1.8000],
|
||||
[1.8000],
|
||||
[1.8000],
|
||||
[1.8000],
|
||||
[1.8000],
|
||||
[1.5833],
|
||||
[1.3667],
|
||||
[1.1500],
|
||||
[0.9333],
|
||||
[0.7167],
|
||||
[0.5000],
|
||||
[0.5000],
|
||||
[0.5000],
|
||||
[0.5000],
|
||||
[0.5000],
|
||||
[0.5000],
|
||||
[0.5000],
|
||||
[0.5000],
|
||||
[0.5000],
|
||||
[0.5000],
|
||||
]
|
||||
]
|
||||
)
|
||||
|
||||
assert torch.allclose(result, expected_result, atol=1e-4)
|
||||
|
||||
|
||||
def test_denoise_step_adds_batch_dimension():
|
||||
"""Test denoise_step handles 2D input by adding batch dimension."""
|
||||
config = RTCConfig(execution_horizon=10, max_guidance_weight=5.0)
|
||||
processor = RTCProcessor(config)
|
||||
|
||||
# 2D input (no batch dimension)
|
||||
x_t = torch.randn(10, 6)
|
||||
prev_chunk = torch.randn(5, 6)
|
||||
|
||||
def mock_denoiser(x):
|
||||
return x * 0.5
|
||||
|
||||
result = processor.denoise_step(
|
||||
x_t=x_t,
|
||||
prev_chunk_left_over=prev_chunk,
|
||||
inference_delay=5,
|
||||
time=torch.tensor(0.5),
|
||||
original_denoise_step_partial=mock_denoiser,
|
||||
)
|
||||
|
||||
# Output should be 2D (batch dimension removed)
|
||||
assert result.ndim == 2
|
||||
assert result.shape == (10, 6)
|
||||
|
||||
|
||||
def test_denoise_step_uses_custom_execution_horizon():
|
||||
"""Test denoise_step uses custom execution_horizon parameter."""
|
||||
config = RTCConfig(execution_horizon=10)
|
||||
processor = RTCProcessor(config)
|
||||
|
||||
x_t = torch.ones(1, 20, 1)
|
||||
prev_chunk = torch.full((1, 15, 1), 0.1)
|
||||
|
||||
def mock_denoiser(x):
|
||||
return x * 0.5
|
||||
|
||||
result = processor.denoise_step(
|
||||
x_t=x_t,
|
||||
prev_chunk_left_over=prev_chunk,
|
||||
inference_delay=5,
|
||||
time=torch.tensor(0.5),
|
||||
original_denoise_step_partial=mock_denoiser,
|
||||
execution_horizon=15,
|
||||
)
|
||||
|
||||
expected_result = torch.tensor(
|
||||
[
|
||||
[
|
||||
[1.8000],
|
||||
[1.8000],
|
||||
[1.8000],
|
||||
[1.8000],
|
||||
[1.8000],
|
||||
[1.6818],
|
||||
[1.5636],
|
||||
[1.4455],
|
||||
[1.3273],
|
||||
[1.2091],
|
||||
[1.0909],
|
||||
[0.9727],
|
||||
[0.8545],
|
||||
[0.7364],
|
||||
[0.6182],
|
||||
[0.5000],
|
||||
[0.5000],
|
||||
[0.5000],
|
||||
[0.5000],
|
||||
[0.5000],
|
||||
]
|
||||
]
|
||||
)
|
||||
|
||||
assert torch.allclose(result, expected_result, atol=1e-4)
|
||||
|
||||
|
||||
def test_denoise_step_guidance_weight_at_time_zero():
|
||||
"""Test denoise_step handles time=0 (tau=1) without NaN/Inf."""
|
||||
config = RTCConfig(max_guidance_weight=10.0)
|
||||
processor = RTCProcessor(config)
|
||||
|
||||
x_t = torch.ones(1, 20, 1)
|
||||
prev_chunk = torch.full((1, 20, 1), 0.1)
|
||||
|
||||
def mock_denoiser(x):
|
||||
return x * 0.5
|
||||
|
||||
result = processor.denoise_step(
|
||||
x_t=x_t,
|
||||
prev_chunk_left_over=prev_chunk,
|
||||
inference_delay=5,
|
||||
time=torch.tensor(0.0),
|
||||
original_denoise_step_partial=mock_denoiser,
|
||||
)
|
||||
|
||||
expected_result = torch.tensor(
|
||||
[
|
||||
[
|
||||
[0.5000],
|
||||
[0.5000],
|
||||
[0.5000],
|
||||
[0.5000],
|
||||
[0.5000],
|
||||
[0.5000],
|
||||
[0.5000],
|
||||
[0.5000],
|
||||
[0.5000],
|
||||
[0.5000],
|
||||
[0.5000],
|
||||
[0.5000],
|
||||
[0.5000],
|
||||
[0.5000],
|
||||
[0.5000],
|
||||
[0.5000],
|
||||
[0.5000],
|
||||
[0.5000],
|
||||
[0.5000],
|
||||
[0.5000],
|
||||
]
|
||||
]
|
||||
)
|
||||
|
||||
assert torch.allclose(result, expected_result, atol=1e-4)
|
||||
|
||||
|
||||
def test_denoise_step_with_real_denoise_step_partial():
|
||||
"""Test denoise_step with a real denoiser."""
|
||||
config = RTCConfig(max_guidance_weight=10.0)
|
||||
processor = RTCProcessor(config)
|
||||
|
||||
batch_size = 10
|
||||
action_dim = 6
|
||||
chunk_size = 20
|
||||
|
||||
x_t = torch.ones(batch_size, chunk_size, action_dim)
|
||||
prev_chunk = torch.full((batch_size, chunk_size, action_dim), 0.1)
|
||||
|
||||
velocity_function = torch.nn.Sequential(
|
||||
torch.nn.Linear(action_dim, 1000),
|
||||
torch.nn.ReLU(),
|
||||
torch.nn.Linear(1000, 256),
|
||||
torch.nn.ReLU(),
|
||||
torch.nn.Linear(256, action_dim),
|
||||
)
|
||||
|
||||
def mock_denoiser(x):
|
||||
return velocity_function(x)
|
||||
|
||||
result = processor.denoise_step(
|
||||
x_t=x_t,
|
||||
prev_chunk_left_over=prev_chunk,
|
||||
inference_delay=5,
|
||||
time=torch.tensor(0.5),
|
||||
original_denoise_step_partial=mock_denoiser,
|
||||
)
|
||||
|
||||
assert result.shape == (batch_size, chunk_size, action_dim)
|
||||
|
||||
|
||||
def test_denoise_step_guidance_weight_at_time_one():
|
||||
"""Test denoise_step handles time=1 (tau=0) with max_guidance_weight clamping."""
|
||||
config = RTCConfig(max_guidance_weight=10.0)
|
||||
processor = RTCProcessor(config)
|
||||
|
||||
x_t = torch.randn(1, 50, 6)
|
||||
prev_chunk = torch.randn(1, 50, 6)
|
||||
|
||||
def mock_denoiser(x):
|
||||
return torch.ones_like(x) * 0.5
|
||||
|
||||
# Time = 1 => tau = 0, c = (1-tau)/tau = 1/0 = inf (clamped to max_guidance_weight)
|
||||
result = processor.denoise_step(
|
||||
x_t=x_t,
|
||||
prev_chunk_left_over=prev_chunk,
|
||||
inference_delay=5,
|
||||
time=torch.tensor(1.0),
|
||||
original_denoise_step_partial=mock_denoiser,
|
||||
)
|
||||
|
||||
# Should clamp to max_guidance_weight (no Inf)
|
||||
assert not torch.any(torch.isinf(result))
|
||||
|
||||
|
||||
def test_denoise_step_tracks_debug_info(rtc_processor_debug_enabled):
|
||||
"""Test denoise_step tracks debug information when enabled."""
|
||||
x_t = torch.randn(1, 50, 6)
|
||||
prev_chunk = torch.randn(1, 50, 6)
|
||||
|
||||
def mock_denoiser(x):
|
||||
return torch.ones_like(x) * 0.5
|
||||
|
||||
rtc_processor_debug_enabled.denoise_step(
|
||||
x_t=x_t,
|
||||
prev_chunk_left_over=prev_chunk,
|
||||
inference_delay=5,
|
||||
time=torch.tensor(0.5),
|
||||
original_denoise_step_partial=mock_denoiser,
|
||||
)
|
||||
|
||||
# Should have tracked one step
|
||||
steps = rtc_processor_debug_enabled.get_all_debug_steps()
|
||||
assert len(steps) == 1
|
||||
|
||||
# Check tracked values
|
||||
step = steps[0]
|
||||
assert step.time == 0.5
|
||||
assert step.x1_t is not None
|
||||
assert step.correction is not None
|
||||
assert step.err is not None
|
||||
assert step.weights is not None
|
||||
assert step.guidance_weight is not None
|
||||
assert step.inference_delay == 5
|
||||
|
||||
|
||||
def test_denoise_step_doesnt_track_without_debug(rtc_processor_debug_disabled):
|
||||
"""Test denoise_step doesn't track when debug disabled."""
|
||||
x_t = torch.randn(1, 50, 6)
|
||||
prev_chunk = torch.randn(1, 50, 6)
|
||||
|
||||
def mock_denoiser(x):
|
||||
return torch.ones_like(x) * 0.5
|
||||
|
||||
rtc_processor_debug_disabled.denoise_step(
|
||||
x_t=x_t,
|
||||
prev_chunk_left_over=prev_chunk,
|
||||
inference_delay=5,
|
||||
time=torch.tensor(0.5),
|
||||
original_denoise_step_partial=mock_denoiser,
|
||||
)
|
||||
|
||||
# Should not track
|
||||
steps = rtc_processor_debug_disabled.get_all_debug_steps()
|
||||
assert len(steps) == 0
|
||||
|
||||
|
||||
# ====================== Integration Tests ======================
|
||||
|
||||
|
||||
def test_denoise_step_full_workflow():
|
||||
"""Test complete denoise_step workflow."""
|
||||
config = RTCConfig(
|
||||
enabled=True,
|
||||
prefix_attention_schedule=RTCAttentionSchedule.LINEAR,
|
||||
max_guidance_weight=5.0,
|
||||
execution_horizon=10,
|
||||
debug=True,
|
||||
)
|
||||
processor = RTCProcessor(config)
|
||||
|
||||
# Simulate two denoising steps
|
||||
x_t1 = torch.randn(1, 50, 6)
|
||||
x_t2 = torch.randn(1, 50, 6)
|
||||
|
||||
def mock_denoiser(x):
|
||||
return torch.randn_like(x) * 0.1
|
||||
|
||||
# First step - no guidance
|
||||
result1 = processor.denoise_step(
|
||||
x_t=x_t1,
|
||||
prev_chunk_left_over=None,
|
||||
inference_delay=5,
|
||||
time=torch.tensor(0.8),
|
||||
original_denoise_step_partial=mock_denoiser,
|
||||
)
|
||||
|
||||
# Second step - with guidance
|
||||
result2 = processor.denoise_step(
|
||||
x_t=x_t2,
|
||||
prev_chunk_left_over=result1,
|
||||
inference_delay=5,
|
||||
time=torch.tensor(0.6),
|
||||
original_denoise_step_partial=mock_denoiser,
|
||||
)
|
||||
|
||||
# Both should complete successfully
|
||||
assert result1.shape == (1, 50, 6)
|
||||
assert result2.shape == (1, 50, 6)
|
||||
|
||||
# Should have tracked one step (second one, first had no prev_chunk)
|
||||
steps = processor.get_all_debug_steps()
|
||||
assert len(steps) == 1
|
||||
|
||||
|
||||
@pytest.mark.skipif(not torch.cuda.is_available(), reason="CUDA not available")
|
||||
def test_denoise_step_with_cuda_tensors():
|
||||
"""Test denoise_step works with CUDA tensors."""
|
||||
config = RTCConfig(execution_horizon=10, max_guidance_weight=5.0)
|
||||
processor = RTCProcessor(config)
|
||||
|
||||
x_t = torch.randn(1, 50, 6, device="cuda")
|
||||
prev_chunk = torch.randn(1, 50, 6, device="cuda")
|
||||
|
||||
def mock_denoiser(x):
|
||||
return torch.ones_like(x) * 0.5
|
||||
|
||||
result = processor.denoise_step(
|
||||
x_t=x_t,
|
||||
prev_chunk_left_over=prev_chunk,
|
||||
inference_delay=5,
|
||||
time=torch.tensor(0.5),
|
||||
original_denoise_step_partial=mock_denoiser,
|
||||
)
|
||||
|
||||
# Result should be on CUDA
|
||||
assert result.device.type == "cuda"
|
||||
assert result.shape == x_t.shape
|
||||
|
||||
|
||||
def test_denoise_step_deterministic_with_same_inputs():
|
||||
"""Test denoise_step produces same output with same inputs."""
|
||||
config = RTCConfig(execution_horizon=10, max_guidance_weight=5.0)
|
||||
processor = RTCProcessor(config)
|
||||
|
||||
torch.manual_seed(42)
|
||||
x_t = torch.randn(1, 50, 6)
|
||||
prev_chunk = torch.randn(1, 50, 6)
|
||||
|
||||
def deterministic_denoiser(x):
|
||||
return torch.ones_like(x) * 0.5
|
||||
|
||||
result1 = processor.denoise_step(
|
||||
x_t=x_t.clone(),
|
||||
prev_chunk_left_over=prev_chunk.clone(),
|
||||
inference_delay=5,
|
||||
time=torch.tensor(0.5),
|
||||
original_denoise_step_partial=deterministic_denoiser,
|
||||
)
|
||||
|
||||
result2 = processor.denoise_step(
|
||||
x_t=x_t.clone(),
|
||||
prev_chunk_left_over=prev_chunk.clone(),
|
||||
inference_delay=5,
|
||||
time=torch.tensor(0.5),
|
||||
original_denoise_step_partial=deterministic_denoiser,
|
||||
)
|
||||
|
||||
# Should produce identical results
|
||||
assert torch.allclose(result1, result2)
|
||||
@@ -0,0 +1,323 @@
|
||||
#!/usr/bin/env python
|
||||
|
||||
# Copyright 2025 The HuggingFace Inc. team. All rights reserved.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
"""Test SmolVLA policy with Real-Time Chunking (RTC) enabled during inference."""
|
||||
|
||||
import pytest
|
||||
import torch
|
||||
|
||||
from lerobot.configs.types import FeatureType, PolicyFeature, RTCAttentionSchedule # noqa: E402
|
||||
from lerobot.policies.factory import make_pre_post_processors # noqa: E402
|
||||
from lerobot.policies.rtc.configuration_rtc import RTCConfig # noqa: E402
|
||||
from lerobot.policies.smolvla.configuration_smolvla import SmolVLAConfig # noqa: F401
|
||||
from lerobot.utils.random_utils import set_seed # noqa: E402
|
||||
from tests.utils import require_cuda, require_package # noqa: E402
|
||||
|
||||
|
||||
@require_package("transformers")
|
||||
@require_cuda
|
||||
def test_smolvla_rtc_initialization():
|
||||
from lerobot.policies.smolvla.modeling_smolvla import SmolVLAPolicy # noqa: F401
|
||||
|
||||
"""Test SmolVLA policy can initialize RTC processor."""
|
||||
set_seed(42)
|
||||
|
||||
config = SmolVLAConfig(max_action_dim=7, chunk_size=50)
|
||||
|
||||
# Add RTC config
|
||||
config.rtc_config = RTCConfig(
|
||||
enabled=True,
|
||||
execution_horizon=10,
|
||||
max_guidance_weight=5.0,
|
||||
prefix_attention_schedule=RTCAttentionSchedule.EXP,
|
||||
debug=False,
|
||||
)
|
||||
|
||||
config.input_features = {
|
||||
"observation.state": PolicyFeature(type=FeatureType.STATE, shape=(14,)),
|
||||
"observation.images.base_0_rgb": PolicyFeature(type=FeatureType.VISUAL, shape=(3, 224, 224)),
|
||||
}
|
||||
config.output_features = {
|
||||
"action": PolicyFeature(type=FeatureType.ACTION, shape=(7,)),
|
||||
}
|
||||
|
||||
# Instantiate policy
|
||||
policy = SmolVLAPolicy(config)
|
||||
|
||||
# Verify RTC processor is initialized
|
||||
assert hasattr(policy, "rtc_processor")
|
||||
assert policy.rtc_processor is not None
|
||||
assert policy.rtc_processor.rtc_config.enabled is True
|
||||
|
||||
print("✓ SmolVLA RTC initialization: Test passed")
|
||||
|
||||
|
||||
@require_package("transformers")
|
||||
@require_cuda
|
||||
def test_smolvla_rtc_initialization_without_rtc_config():
|
||||
from lerobot.policies.smolvla.modeling_smolvla import SmolVLAPolicy # noqa: F401
|
||||
|
||||
"""Test SmolVLA policy can initialize without RTC config."""
|
||||
set_seed(42)
|
||||
|
||||
config = SmolVLAConfig(max_action_dim=7, chunk_size=50)
|
||||
|
||||
# Instantiate policy
|
||||
policy = SmolVLAPolicy(config)
|
||||
|
||||
# Verify RTC processor is not initialized
|
||||
assert hasattr(policy, "rtc_processor")
|
||||
assert policy.rtc_processor is None
|
||||
assert policy.model.rtc_processor is None
|
||||
assert policy._rtc_enabled() is False
|
||||
|
||||
print("✓ SmolVLA RTC initialization without RTC config: Test passed")
|
||||
|
||||
|
||||
@require_package("transformers")
|
||||
@require_cuda
|
||||
@pytest.mark.skipif(True, reason="Requires pretrained SmolVLA model weights")
|
||||
def test_smolvla_rtc_inference_with_prev_chunk():
|
||||
from lerobot.policies.smolvla.modeling_smolvla import SmolVLAPolicy # noqa: F401
|
||||
|
||||
"""Test SmolVLA policy inference with RTC and previous chunk."""
|
||||
set_seed(42)
|
||||
|
||||
config = SmolVLAConfig(max_action_dim=7, chunk_size=50)
|
||||
|
||||
# Add RTC config
|
||||
config.rtc_config = RTCConfig(
|
||||
enabled=True,
|
||||
execution_horizon=10,
|
||||
max_guidance_weight=5.0,
|
||||
prefix_attention_schedule=RTCAttentionSchedule.EXP,
|
||||
debug=False,
|
||||
)
|
||||
|
||||
config.input_features = {
|
||||
"observation.state": PolicyFeature(type=FeatureType.STATE, shape=(14,)),
|
||||
"observation.images.base_0_rgb": PolicyFeature(type=FeatureType.VISUAL, shape=(3, 224, 224)),
|
||||
}
|
||||
config.output_features = {
|
||||
"action": PolicyFeature(type=FeatureType.ACTION, shape=(7,)),
|
||||
}
|
||||
|
||||
# Create dataset stats
|
||||
dataset_stats = {
|
||||
"observation.state": {"mean": torch.zeros(14), "std": torch.ones(14)},
|
||||
"action": {"mean": torch.zeros(7), "std": torch.ones(7)},
|
||||
"observation.images.base_0_rgb": {"mean": torch.zeros(3, 224, 224), "std": torch.ones(3, 224, 224)},
|
||||
}
|
||||
|
||||
# Instantiate policy and create preprocessor
|
||||
policy = SmolVLAPolicy(config)
|
||||
policy.eval()
|
||||
preprocessor, _ = make_pre_post_processors(
|
||||
policy_cfg=config, pretrained_path=None, dataset_stats=dataset_stats
|
||||
)
|
||||
|
||||
device = config.device
|
||||
|
||||
# Create dummy batch
|
||||
batch = {
|
||||
"observation.state": torch.randn(1, 14, dtype=torch.float32, device=device),
|
||||
"observation.images.base_0_rgb": torch.rand(1, 3, 224, 224, dtype=torch.float32, device=device),
|
||||
"task": ["Pick up the object"],
|
||||
}
|
||||
batch = preprocessor(batch)
|
||||
|
||||
# Create previous chunk
|
||||
prev_chunk = torch.randn(1, 25, 7, dtype=torch.float32, device=device)
|
||||
|
||||
with torch.no_grad():
|
||||
# Use same noise for fair comparison
|
||||
noise = policy.model.sample_noise((1, config.chunk_size, 7), device)
|
||||
|
||||
# Test with RTC and previous chunk
|
||||
actions_with_rtc = policy.predict_action_chunk(
|
||||
batch,
|
||||
noise=noise.clone(),
|
||||
prev_chunk_left_over=prev_chunk,
|
||||
inference_delay=4,
|
||||
execution_horizon=10,
|
||||
)
|
||||
|
||||
# Test without RTC for comparison
|
||||
policy.config.rtc_config.enabled = False
|
||||
actions_without_rtc = policy.predict_action_chunk(batch, noise=noise.clone())
|
||||
policy.config.rtc_config.enabled = True
|
||||
|
||||
# Verify shapes
|
||||
assert actions_with_rtc.shape == (1, config.chunk_size, 7)
|
||||
assert actions_without_rtc.shape == (1, config.chunk_size, 7)
|
||||
|
||||
# With previous chunk, actions should be different (RTC guidance applied)
|
||||
assert not torch.allclose(actions_with_rtc, actions_without_rtc, rtol=1e-3)
|
||||
|
||||
print("✓ SmolVLA RTC inference with prev_chunk: Test passed")
|
||||
|
||||
|
||||
@require_package("transformers")
|
||||
@require_cuda
|
||||
@pytest.mark.skipif(True, reason="Requires pretrained SmolVLA model weights")
|
||||
def test_smolvla_rtc_inference_without_prev_chunk():
|
||||
from lerobot.policies.smolvla.modeling_smolvla import SmolVLAPolicy # noqa: F401
|
||||
|
||||
"""Test SmolVLA policy inference with RTC but no previous chunk (RTC should have no effect)."""
|
||||
set_seed(42)
|
||||
|
||||
config = SmolVLAConfig(max_action_dim=7, chunk_size=50)
|
||||
|
||||
# Add RTC config
|
||||
config.rtc_config = RTCConfig(
|
||||
enabled=True,
|
||||
execution_horizon=10,
|
||||
max_guidance_weight=5.0,
|
||||
prefix_attention_schedule=RTCAttentionSchedule.EXP,
|
||||
debug=False,
|
||||
)
|
||||
|
||||
config.input_features = {
|
||||
"observation.state": PolicyFeature(type=FeatureType.STATE, shape=(14,)),
|
||||
"observation.images.base_0_rgb": PolicyFeature(type=FeatureType.VISUAL, shape=(3, 224, 224)),
|
||||
}
|
||||
config.output_features = {
|
||||
"action": PolicyFeature(type=FeatureType.ACTION, shape=(7,)),
|
||||
}
|
||||
|
||||
# Create dataset stats
|
||||
dataset_stats = {
|
||||
"observation.state": {"mean": torch.zeros(14), "std": torch.ones(14)},
|
||||
"action": {"mean": torch.zeros(7), "std": torch.ones(7)},
|
||||
"observation.images.base_0_rgb": {"mean": torch.zeros(3, 224, 224), "std": torch.ones(3, 224, 224)},
|
||||
}
|
||||
|
||||
# Instantiate policy and create preprocessor
|
||||
policy = SmolVLAPolicy(config)
|
||||
policy.eval()
|
||||
preprocessor, _ = make_pre_post_processors(
|
||||
policy_cfg=config, pretrained_path=None, dataset_stats=dataset_stats
|
||||
)
|
||||
|
||||
device = config.device
|
||||
|
||||
# Create dummy batch
|
||||
batch = {
|
||||
"observation.state": torch.randn(1, 14, dtype=torch.float32, device=device),
|
||||
"observation.images.base_0_rgb": torch.rand(1, 3, 224, 224, dtype=torch.float32, device=device),
|
||||
"task": ["Pick up the object"],
|
||||
}
|
||||
batch = preprocessor(batch)
|
||||
|
||||
with torch.no_grad():
|
||||
# Use same noise for fair comparison
|
||||
noise = policy.model.sample_noise((1, config.chunk_size, 7), device)
|
||||
|
||||
# Test with RTC enabled but no previous chunk
|
||||
actions_with_rtc_no_prev = policy.predict_action_chunk(
|
||||
batch,
|
||||
noise=noise.clone(),
|
||||
prev_chunk_left_over=None,
|
||||
)
|
||||
|
||||
# Test without RTC
|
||||
policy.config.rtc_config.enabled = False
|
||||
actions_without_rtc = policy.predict_action_chunk(batch, noise=noise.clone())
|
||||
policy.config.rtc_config.enabled = True
|
||||
|
||||
# Without previous chunk, RTC should have no effect
|
||||
assert torch.allclose(actions_with_rtc_no_prev, actions_without_rtc, rtol=1e-5)
|
||||
|
||||
print("✓ SmolVLA RTC inference without prev_chunk: Test passed")
|
||||
|
||||
|
||||
@require_package("transformers")
|
||||
@require_cuda
|
||||
@pytest.mark.skipif(True, reason="Requires pretrained SmolVLA model weights")
|
||||
def test_smolvla_rtc_validation_rules():
|
||||
from lerobot.policies.smolvla.modeling_smolvla import SmolVLAPolicy # noqa: F401
|
||||
|
||||
"""Test SmolVLA policy with RTC follows all three validation rules."""
|
||||
set_seed(42)
|
||||
|
||||
config = SmolVLAConfig(max_action_dim=7, chunk_size=50)
|
||||
|
||||
# Add RTC config
|
||||
config.rtc_config = RTCConfig(
|
||||
enabled=True,
|
||||
execution_horizon=10,
|
||||
max_guidance_weight=5.0,
|
||||
prefix_attention_schedule=RTCAttentionSchedule.EXP,
|
||||
debug=False,
|
||||
)
|
||||
|
||||
config.input_features = {
|
||||
"observation.state": PolicyFeature(type=FeatureType.STATE, shape=(14,)),
|
||||
"observation.images.base_0_rgb": PolicyFeature(type=FeatureType.VISUAL, shape=(3, 224, 224)),
|
||||
}
|
||||
config.output_features = {
|
||||
"action": PolicyFeature(type=FeatureType.ACTION, shape=(7,)),
|
||||
}
|
||||
|
||||
# Create dataset stats
|
||||
dataset_stats = {
|
||||
"observation.state": {"mean": torch.zeros(14), "std": torch.ones(14)},
|
||||
"action": {"mean": torch.zeros(7), "std": torch.ones(7)},
|
||||
"observation.images.base_0_rgb": {"mean": torch.zeros(3, 224, 224), "std": torch.ones(3, 224, 224)},
|
||||
}
|
||||
|
||||
# Instantiate policy and create preprocessor
|
||||
policy = SmolVLAPolicy(config)
|
||||
policy.eval()
|
||||
preprocessor, _ = make_pre_post_processors(
|
||||
policy_cfg=config, pretrained_path=None, dataset_stats=dataset_stats
|
||||
)
|
||||
|
||||
device = config.device
|
||||
|
||||
# Create dummy batch
|
||||
batch = {
|
||||
"observation.state": torch.randn(1, 14, dtype=torch.float32, device=device),
|
||||
"observation.images.base_0_rgb": torch.rand(1, 3, 224, 224, dtype=torch.float32, device=device),
|
||||
"task": ["Pick up the object"],
|
||||
}
|
||||
batch = preprocessor(batch)
|
||||
|
||||
# Create previous chunk
|
||||
prev_chunk = torch.randn(1, 25, 7, dtype=torch.float32, device=device)
|
||||
|
||||
inference_delay = 4
|
||||
execution_horizon = 10
|
||||
|
||||
with torch.no_grad():
|
||||
# Use same noise for fair comparison
|
||||
noise = policy.model.sample_noise((1, config.chunk_size, 7), device)
|
||||
|
||||
# Test with RTC
|
||||
actions_with_rtc = policy.predict_action_chunk(
|
||||
batch,
|
||||
noise=noise.clone(),
|
||||
prev_chunk_left_over=prev_chunk,
|
||||
inference_delay=inference_delay,
|
||||
execution_horizon=execution_horizon,
|
||||
)
|
||||
|
||||
# Test without RTC
|
||||
policy.config.rtc_config.enabled = False
|
||||
actions_without_rtc = policy.predict_action_chunk(batch, noise=noise.clone())
|
||||
policy.config.rtc_config.enabled = True
|
||||
|
||||
assert not torch.allclose(actions_with_rtc, actions_without_rtc, rtol=1e-3)
|
||||
@@ -0,0 +1,72 @@
|
||||
#!/usr/bin/env python
|
||||
|
||||
# Copyright 2025 The HuggingFace Inc. team. All rights reserved.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
import numpy as np
|
||||
import torch
|
||||
|
||||
from lerobot.envs.utils import preprocess_observation
|
||||
from lerobot.processor.env_processor import LiberoProcessorStep
|
||||
from lerobot.processor.pipeline import PolicyProcessorPipeline
|
||||
|
||||
seed = 42
|
||||
np.random.seed(seed)
|
||||
|
||||
B = 5
|
||||
obs1 = {
|
||||
"pixels": {
|
||||
"image": (np.random.rand(B, 256, 256, 3) * 255).astype(np.uint8),
|
||||
"image2": (np.random.rand(B, 256, 256, 3) * 255).astype(np.uint8),
|
||||
},
|
||||
"robot_state": {
|
||||
"eef": {
|
||||
"pos": np.random.randn(B, 3),
|
||||
"quat": np.random.randn(B, 4),
|
||||
"mat": np.random.randn(B, 3, 3),
|
||||
},
|
||||
"gripper": {
|
||||
"qpos": np.random.randn(B, 2),
|
||||
"qvel": np.random.randn(B, 2),
|
||||
},
|
||||
"joints": {
|
||||
"pos": np.random.randn(B, 7),
|
||||
"vel": np.random.randn(B, 7),
|
||||
},
|
||||
},
|
||||
}
|
||||
|
||||
observation = preprocess_observation(obs1)
|
||||
libero_preprocessor = PolicyProcessorPipeline(
|
||||
steps=[
|
||||
LiberoProcessorStep(),
|
||||
]
|
||||
)
|
||||
processed_obs = libero_preprocessor(observation)
|
||||
assert "observation.state" in processed_obs
|
||||
state = processed_obs["observation.state"]
|
||||
assert isinstance(state, torch.Tensor)
|
||||
assert state.dtype == torch.float32
|
||||
|
||||
assert state.shape[0] == B
|
||||
assert state.shape[1] == 8
|
||||
|
||||
assert "observation.images.image" in processed_obs
|
||||
assert "observation.images.image2" in processed_obs
|
||||
|
||||
assert isinstance(processed_obs["observation.images.image"], torch.Tensor)
|
||||
assert isinstance(processed_obs["observation.images.image2"], torch.Tensor)
|
||||
|
||||
assert processed_obs["observation.images.image"].shape == (B, 3, 256, 256)
|
||||
assert processed_obs["observation.images.image2"].shape == (B, 3, 256, 256)
|
||||
Reference in New Issue
Block a user