mirror of
https://github.com/huggingface/lerobot.git
synced 2026-05-12 15:19:43 +00:00
Compare commits
10 Commits
| Author | SHA1 | Date | |
|---|---|---|---|
| 519234a5d8 | |||
| d9371b9a34 | |||
| 17f47b9cbc | |||
| 05395c8b10 | |||
| f495054321 | |||
| 2345c779ee | |||
| aaf8576411 | |||
| d3e6f14d4f | |||
| 1f5487eea8 | |||
| 8d50be9faa |
@@ -173,8 +173,6 @@ jobs:
|
||||
shell: bash
|
||||
working-directory: /lerobot
|
||||
steps:
|
||||
- name: Fix ptxas permissions
|
||||
run: chmod +x /lerobot/.venv/lib/python3.10/site-packages/triton/backends/nvidia/bin/ptxas
|
||||
- name: Run pytest on GPU
|
||||
run: pytest tests -vv --maxfail=10
|
||||
- name: Run end-to-end tests
|
||||
|
||||
@@ -1,3 +1,2 @@
|
||||
include src/lerobot/templates/lerobot_modelcard_template.md
|
||||
include src/lerobot/datasets/card_template.md
|
||||
include src/lerobot/envs/metaworld_config.json
|
||||
|
||||
@@ -85,8 +85,6 @@ RUN if [ "$UNBOUND_DEPS" = "true" ]; then \
|
||||
|
||||
RUN uv pip install --no-cache ".[all]"
|
||||
|
||||
RUN chmod +x /lerobot/.venv/lib/python${PYTHON_VERSION}/site-packages/triton/backends/nvidia/bin/ptxas
|
||||
|
||||
# Copy the rest of the application source code
|
||||
# Make sure to have the git-LFS files for testing
|
||||
COPY --chown=user_lerobot:user_lerobot . .
|
||||
|
||||
@@ -29,8 +29,6 @@
|
||||
title: Using the Dataset Tools
|
||||
- local: dataset_subtask
|
||||
title: Using Subtasks in the Dataset
|
||||
- local: streaming_video_encoding
|
||||
title: Streaming Video Encoding
|
||||
title: "Datasets"
|
||||
- sections:
|
||||
- local: act
|
||||
|
||||
@@ -88,8 +88,5 @@ lerobot-record \
|
||||
--dataset.repo_id=${HF_USER}/eval_act_your_dataset \
|
||||
--dataset.num_episodes=10 \
|
||||
--dataset.single_task="Your task description" \
|
||||
--dataset.streaming_encoding=true \
|
||||
--dataset.encoder_threads=2 \
|
||||
# --dataset.vcodec=auto \
|
||||
--policy.path=${HF_USER}/act_policy
|
||||
```
|
||||
|
||||
@@ -192,9 +192,6 @@ lerobot-record \
|
||||
--dataset.num_episodes=2 \
|
||||
--dataset.fps=10 \
|
||||
--dataset.single_task="Navigate around obstacles" \
|
||||
--dataset.streaming_encoding=true \
|
||||
--dataset.encoder_threads=2 \
|
||||
# --dataset.vcodec=auto \
|
||||
--display_data=true
|
||||
```
|
||||
|
||||
|
||||
@@ -120,12 +120,9 @@ lerobot-record \
|
||||
--display_data=true \
|
||||
--dataset.repo_id=<user>/eval_groot-bimanual \
|
||||
--dataset.num_episodes=10 \
|
||||
--dataset.single_task="Grab and handover the red cube to the other arm" \
|
||||
--dataset.streaming_encoding=true \
|
||||
--dataset.encoder_threads=2 \
|
||||
# --dataset.vcodec=auto \
|
||||
--policy.path=<user>/groot-bimanual \ # your trained model
|
||||
--dataset.episode_time_s=30 \
|
||||
--dataset.single_task="Grab and handover the red cube to the other arm"
|
||||
--policy.path=<user>/groot-bimanual # your trained model
|
||||
--dataset.episode_time_s=30
|
||||
--dataset.reset_time_s=10
|
||||
```
|
||||
|
||||
|
||||
@@ -230,9 +230,6 @@ lerobot-record \
|
||||
--dataset.episode_time_s=5 \
|
||||
--dataset.push_to_hub=true \
|
||||
--dataset.private=true \
|
||||
--dataset.streaming_encoding=true \
|
||||
--dataset.encoder_threads=2 \
|
||||
# --dataset.vcodec=auto \
|
||||
--display_data=true
|
||||
```
|
||||
|
||||
@@ -276,8 +273,5 @@ lerobot-record \
|
||||
--dataset.repo_id=<USER>/eval_hopejr \
|
||||
--dataset.single_task="Evaluate hopejr hand policy" \
|
||||
--dataset.num_episodes=10 \
|
||||
--dataset.streaming_encoding=true \
|
||||
--dataset.encoder_threads=2 \
|
||||
# --dataset.vcodec=auto \
|
||||
--policy.path=outputs/train/hopejr_hand/checkpoints/last/pretrained_model
|
||||
```
|
||||
|
||||
@@ -165,7 +165,7 @@ huggingface-cli login --token ${HUGGINGFACE_TOKEN} --add-to-git-credential
|
||||
Then store your Hugging Face repository name in a variable:
|
||||
|
||||
```bash
|
||||
HF_USER=$(hf auth whoami | awk -F': *' 'NR==1 {print $2}')
|
||||
HF_USER=$(hf auth whoami | head -n 1)
|
||||
echo $HF_USER
|
||||
```
|
||||
|
||||
@@ -185,10 +185,7 @@ lerobot-record \
|
||||
--display_data=true \
|
||||
--dataset.repo_id=${HF_USER}/record-test \
|
||||
--dataset.num_episodes=5 \
|
||||
--dataset.single_task="Grab the black cube" \
|
||||
--dataset.streaming_encoding=true \
|
||||
# --dataset.vcodec=auto \
|
||||
--dataset.encoder_threads=2
|
||||
--dataset.single_task="Grab the black cube"
|
||||
```
|
||||
</hfoption>
|
||||
<hfoption id="API example">
|
||||
@@ -518,9 +515,6 @@ lerobot-record \
|
||||
--display_data=false \
|
||||
--dataset.repo_id=${HF_USER}/eval_so100 \
|
||||
--dataset.single_task="Put lego brick into the transparent box" \
|
||||
--dataset.streaming_encoding=true \
|
||||
--dataset.encoder_threads=2 \
|
||||
# --dataset.vcodec=auto \
|
||||
# <- Teleop optional if you want to teleoperate in between episodes \
|
||||
# --teleop.type=so100_leader \
|
||||
# --teleop.port=/dev/ttyACM0 \
|
||||
|
||||
@@ -40,13 +40,6 @@ conda install ffmpeg -c conda-forge
|
||||
>
|
||||
> - _[On Linux only]_ If you want to bring your own ffmpeg: Install [ffmpeg build dependencies](https://trac.ffmpeg.org/wiki/CompilationGuide/Ubuntu#GettheDependencies) and [compile ffmpeg from source with libsvtav1](https://trac.ffmpeg.org/wiki/CompilationGuide/Ubuntu#libsvtav1), and make sure you use the corresponding ffmpeg binary to your install with `which ffmpeg`.
|
||||
|
||||
> [!NOTE]
|
||||
> When installing LeRobot inside WSL (Windows Subsystem for Linux), make sure to install `evdev` with the following command:
|
||||
>
|
||||
> ```bash
|
||||
> conda install evdev -c conda-forge
|
||||
> ```
|
||||
|
||||
## Step 3: Install LeRobot 🤗
|
||||
|
||||
### From Source
|
||||
|
||||
@@ -41,10 +41,7 @@ lerobot-record \
|
||||
--display_data=true \
|
||||
--dataset.repo_id=${HF_USER}/record-test \
|
||||
--dataset.num_episodes=5 \
|
||||
--dataset.single_task="Grab the black cube" \
|
||||
--dataset.streaming_encoding=true \
|
||||
# --dataset.vcodec=auto \
|
||||
--dataset.encoder_threads=2
|
||||
--dataset.single_task="Grab the black cube"
|
||||
```
|
||||
|
||||
See the [recording guide](./il_robots#record-a-dataset) for more details.
|
||||
|
||||
@@ -66,13 +66,12 @@ Run on of the examples scripts to teleoperate, record a dataset, replay a datase
|
||||
|
||||
All scripts assume you configured your robot (e.g., SO-100 follower) and set the correct serial port.
|
||||
|
||||
Additionally you need to **copy the URDF of the robot into the examples folder**. For the examples in this tutorial (using SO100/SO101), copy the `SO101` folder from the [SO-ARM100 repo](https://github.com/TheRobotStudio/SO-ARM100/blob/main/Simulation/SO101) into the `examples/phone_to_so100/` directory, so that the URDF file path becomes `examples/phone_to_so100/SO101/so101_new_calib.urdf`.
|
||||
Additionally you need to **copy the urdf of the robot to the examples folder**. For the examples in this tutorial (Using SO100/SO101) it is highly recommended to use the urdf in the [SO-ARM100 repo](https://github.com/TheRobotStudio/SO-ARM100/blob/main/Simulation/SO101/so101_new_calib.urdf)
|
||||
|
||||
- Run this example to teleoperate:
|
||||
|
||||
```bash
|
||||
cd examples/phone_to_so100
|
||||
python teleoperate.py
|
||||
python examples/phone_to_so100/teleoperate.py
|
||||
```
|
||||
|
||||
After running the example:
|
||||
@@ -85,22 +84,19 @@ Additionally you can customize mapping or safety limits by editing the processor
|
||||
- Run this example to record a dataset, which saves absolute end effector observations and actions:
|
||||
|
||||
```bash
|
||||
cd examples/phone_to_so100
|
||||
python record.py
|
||||
python examples/phone_to_so100/record.py
|
||||
```
|
||||
|
||||
- Run this example to replay recorded episodes:
|
||||
|
||||
```bash
|
||||
cd examples/phone_to_so100
|
||||
python replay.py
|
||||
python examples/phone_to_so100/replay.py
|
||||
```
|
||||
|
||||
- Run this example to evaluate a pretrained policy:
|
||||
|
||||
```bash
|
||||
cd examples/phone_to_so100
|
||||
python evaluate.py
|
||||
python examples/phone_to_so100/evaluate.py
|
||||
```
|
||||
|
||||
### Important pipeline steps and options
|
||||
|
||||
@@ -159,9 +159,6 @@ lerobot-record \
|
||||
--dataset.fps=15 \
|
||||
--dataset.push_to_hub=true \
|
||||
--dataset.private=true \
|
||||
--dataset.streaming_encoding=true \
|
||||
--dataset.encoder_threads=2 \
|
||||
# --dataset.vcodec=auto \
|
||||
--display_data=true
|
||||
```
|
||||
|
||||
@@ -201,9 +198,6 @@ lerobot-record \
|
||||
--dataset.fps=15 \
|
||||
--dataset.push_to_hub=true \
|
||||
--dataset.private=true \
|
||||
--dataset.streaming_encoding=true \
|
||||
--dataset.encoder_threads=2 \
|
||||
# --dataset.vcodec=auto \
|
||||
--display_data=true
|
||||
```
|
||||
|
||||
|
||||
@@ -106,9 +106,6 @@ lerobot-record \
|
||||
--dataset.repo_id=${HF_USER}/eval_DATASET_NAME_test \ # <- This will be the dataset name on HF Hub
|
||||
--dataset.episode_time_s=50 \
|
||||
--dataset.num_episodes=10 \
|
||||
--dataset.streaming_encoding=true \
|
||||
--dataset.encoder_threads=2 \
|
||||
# --dataset.vcodec=auto \
|
||||
# <- Teleop optional if you want to teleoperate in between episodes \
|
||||
# --teleop.type=so100_leader \
|
||||
# --teleop.port=/dev/ttyACM0 \
|
||||
|
||||
@@ -1,155 +0,0 @@
|
||||
# Streaming Video Encoding Guide
|
||||
|
||||
## 1. Overview
|
||||
|
||||
Streaming video encoding eliminates the traditional PNG round-trip during video dataset recording. Instead of:
|
||||
|
||||
1. Capture frame -> write PNG to disk -> (at episode end) read PNG's -> encode to MP4 -> delete PNG's
|
||||
|
||||
Frames can be encoded in real-time during capture:
|
||||
|
||||
1. Capture frame -> queue to encoder thread -> encode to MP4 directly
|
||||
|
||||
This makes `save_episode()` near-instant (the video is already encoded by the time the episode ends) and removes the blocking wait that previously occurred between episodes, especially with multiple cameras in long episodes.
|
||||
|
||||
## 2. Tuning Parameters
|
||||
|
||||
| Parameter | CLI Flag | Type | Default | Description |
|
||||
| ----------------------- | --------------------------------- | ------------- | ------------- | ----------------------------------------------------------------- |
|
||||
| `streaming_encoding` | `--dataset.streaming_encoding` | `bool` | `True` | Enable real-time encoding during capture |
|
||||
| `vcodec` | `--dataset.vcodec` | `str` | `"libsvtav1"` | Video codec. `"auto"` detects best HW encoder |
|
||||
| `encoder_threads` | `--dataset.encoder_threads` | `int \| None` | `None` (auto) | Threads per encoder instance. `None` will leave the vcoded decide |
|
||||
| `encoder_queue_maxsize` | `--dataset.encoder_queue_maxsize` | `int` | `60` | Max buffered frames per camera (~2s at 30fps). Consumes RAM |
|
||||
|
||||
## 3. Performance Considerations
|
||||
|
||||
Streaming encoding means the CPU is encoding video **during** the capture loop, not after. This creates a CPU budget that must be shared between:
|
||||
|
||||
- **Control loop** (reading cameras, control the robot, writing non-video data)
|
||||
- **Encoder threads** (one pool per camera)
|
||||
- **Rerun visualization** (if enabled)
|
||||
- **OS and other processes**
|
||||
|
||||
### Resolution & Number of Cameras Impact
|
||||
|
||||
| Setup | Throughput (px/sec) | CPU Encoding Load | Notes |
|
||||
| ------------------------- | ------------------- | ----------------- | ------------------------------ |
|
||||
| 2camsx 640x480x3 @30fps | 55M | Low | Works on most systems |
|
||||
| 2camsx 1280x720x3 @30fps | 165M | Moderate | Comfortable on modern systems |
|
||||
| 2camsx 1920x1080x3 @30fps | 373M | High | Requires powerful high-end CPU |
|
||||
|
||||
### `encoder_threads` Tuning
|
||||
|
||||
This parameter controls how many threads each encoder instance uses internally:
|
||||
|
||||
- **Higher values** (e.g., 4-5): Faster encoding, but uses more CPU cores per camera. Good for high-end systems with many cores.
|
||||
- **Lower values** (e.g., 1-2): Less CPU per camera, freeing cores for capture and visualization. Good for low-res images and capable CPUs.
|
||||
- **`None` (default)**: Lets the codec decide. Information available in the codec logs.
|
||||
|
||||
### Backpressure and Frame Dropping
|
||||
|
||||
Each camera has a bounded queue (`encoder_queue_maxsize`, default 60 frames). When the encoder can't keep up:
|
||||
|
||||
1. The queue fills up (consuming RAM)
|
||||
2. New frames are **dropped** (not blocked) — the capture loop continues uninterrupted
|
||||
3. A warning is logged: `"Encoder queue full for {camera}, dropped N frame(s)"`
|
||||
4. At episode end, total dropped frames per camera are reported
|
||||
|
||||
### Symptoms of Encoder Falling Behind
|
||||
|
||||
- **System feels laggy and freezes**: all CPUs are at 100%
|
||||
- **Dropped frame warnings** in the log or lower frames/FPS than expected in the recorded dataset
|
||||
- **Choppy robot movement**: If CPU is severely overloaded, even the capture loop may be affected
|
||||
- **Accumulated rerun lag**: Visualization falls behind real-time
|
||||
|
||||
## 4. Hardware-Accelerated Encoding
|
||||
|
||||
### When to Use
|
||||
|
||||
Use HW encoding when:
|
||||
|
||||
- CPU is the bottleneck (dropped frames, choppy robot, rerun lag)
|
||||
- You have compatible hardware (GPU or dedicated encoder)
|
||||
- You're recording at high throughput (high resolution or with many cameras)
|
||||
|
||||
### Choosing a Codec
|
||||
|
||||
| Codec | CPU Usage | File Size | Quality | Notes |
|
||||
| --------------------- | --------- | -------------- | ------- | ---------------------------------------------------------------- |
|
||||
| `libsvtav1` (default) | High | Smallest | Best | Default. Best compression but most CPU-intensive |
|
||||
| `h264` | Medium | ~30-50% larger | Good | Software H.264. Lower CPU |
|
||||
| HW encoders | Very Low | Largest | Good | Offloads to dedicated hardware. Best for CPU-constrained systems |
|
||||
|
||||
### Available HW Encoders
|
||||
|
||||
| Encoder | Platform | Hardware | CLI Value |
|
||||
| ------------------- | ------------- | ------------------------------------------------------------------------------------------------ | ------------------------------------ |
|
||||
| `h264_videotoolbox` | macOS | Apple Silicon / Intel | `--dataset.vcodec=h264_videotoolbox` |
|
||||
| `hevc_videotoolbox` | macOS | Apple Silicon / Intel | `--dataset.vcodec=hevc_videotoolbox` |
|
||||
| `h264_nvenc` | Linux/Windows | NVIDIA GPU | `--dataset.vcodec=h264_nvenc` |
|
||||
| `hevc_nvenc` | Linux/Windows | NVIDIA GPU | `--dataset.vcodec=hevc_nvenc` |
|
||||
| `h264_vaapi` | Linux | Intel/AMD GPU | `--dataset.vcodec=h264_vaapi` |
|
||||
| `h264_qsv` | Linux/Windows | Intel Quick Sync | `--dataset.vcodec=h264_qsv` |
|
||||
| `auto` | Any | Probes the system for available HW encoders. Falls back to `libsvtav1` if no HW encoder is found | `--dataset.vcodec=auto` |
|
||||
|
||||
> [!NOTE]
|
||||
> In order to use the HW accelerated encoders you might need to upgrade your GPU drivers.
|
||||
|
||||
> [!NOTE]
|
||||
> `libsvtav1` is the default because it provides the best training performance; other vcodecs can reduce CPU usage and be faster, but they typically produce larger files and may affect training time.
|
||||
|
||||
## 5. Troubleshooting
|
||||
|
||||
| Symptom | Likely Cause | Fix |
|
||||
| ------------------------------------------------------------------ | -------------------------------------------- | ------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------ |
|
||||
| System freezes or choppy robot movement or Rerun visualization lag | CPU starved (100% load usage) | Close other apps, reduce encoding throughput, lower `encoder_threads`, use `h264`, use `display_data=False`. If the CPU continues to be at 100% then it might be insufficient for your setup, consider `--dataset.streaming_encoding=false` or HW encoding (`--dataset.vcodec=auto`) |
|
||||
| "Encoder queue full" warnings or dropped frames in dataset | Encoder can't keep up (Queue overflow) | If CPU is not at 100%: Increase `encoder_threads`, increase `encoder_queue_maxsize` or use HW encoding (`--dataset.vcodec=auto`). |
|
||||
| High RAM usage | Queue filling faster than encoding | `encoder_threads` too low or CPU insufficient. Reduce `encoder_queue_maxsize` or use HW encoding |
|
||||
| Large video files | Using HW encoder or H.264 | Expected trade-off. Switch to `libsvtav1` if CPU allows |
|
||||
| `save_episode()` still slow | `streaming_encoding` is `False` | Set `--dataset.streaming_encoding=true` |
|
||||
| Encoder thread crash | Codec not available or invalid settings | Check `vcodec` is installed, try `--dataset.vcodec=auto` |
|
||||
| Recorded dataset is missing frames | CPU/GPU starvation or occasional load spikes | If ~5% of frames are missing, your system is likely overloaded — follow the recommendations above. If fewer frames are missing (~2%), they are probably due to occasional transient load spikes (often at startup) and can be considered expected. |
|
||||
|
||||
## 6. Recommended Configurations
|
||||
|
||||
These estimates are conservative; we recommend testing them on your setup—start with a low load and increase it gradually.
|
||||
|
||||
### High-End Systems: modern 12+ cores (24+ threads)
|
||||
|
||||
A throughput between ~250-500M px/sec should be comfortable in CPU. For even better results try HW encoding if available.
|
||||
|
||||
```bash
|
||||
# 3camsx 1280x720x3 @30fps: Defaults work well. Optionally increase encoder parallelism.
|
||||
# 2camsx 1920x1080x3 @30fps: Defaults work well. Optionally increase encoder parallelism.
|
||||
lerobot-record --dataset.encoder_threads=5 ...
|
||||
|
||||
# 3camsx 1920x1080x3 @30fps: Might require some tuning.
|
||||
```
|
||||
|
||||
### Mid-Range Systems: modern 8+ cores (16+ threads) or Apple Silicon
|
||||
|
||||
A throughput between ~80-300M px/sec should be possible in CPU.
|
||||
|
||||
```bash
|
||||
# 3camsx 640x480x3 @30fps: Defaults work well. Optionally decrease encoder parallelism.
|
||||
# 2camsx 1280x720x3 @30fps: Defaults work well. Optionally decrease encoder parallelism.
|
||||
lerobot-record --dataset.encoder_threads=2 ...
|
||||
|
||||
# 2camsx 1920x1080x3 @30fps: Might require some tuning.
|
||||
```
|
||||
|
||||
### Low-Resource Systems: modern 4+ cores (8+ threads) or Raspberry Pi 5
|
||||
|
||||
On very constrained systems, streaming encoding may compete too heavily with the capture loop. Disabling it falls back to the PNG-based approach where encoding happens between episodes (blocking, but doesn't interfere with capture). Alternatively, record at a lower throughput to reduce both capture and encoding load. Consider also changing codec to `h264` and using batch encoding.
|
||||
|
||||
```bash
|
||||
# 2camsx 640x480x3 @30fps: Requires some tuning.
|
||||
|
||||
# Use H.264, disable streaming, consider batching encoding
|
||||
lerobot-record --dataset.vcodec=h264 --dataset.streaming_encoding=false ...
|
||||
```
|
||||
|
||||
## 7. Closing note
|
||||
|
||||
Performance ultimately depends on your exact setup — frames-per-second, resolution, CPU cores and load, available memory, episode length, and the encoder you choose. Always test with your target workload, be mindful about your CPU & system capabilities and tune `encoder_threads`, `encoder_queue_maxsize`, and
|
||||
`vcodec` reasonably. That said, a common practical configuration (for many applications) is three cameras at 640×480x3 @30fps; this usually runs fine with the default streaming video encoding settings in modern systems. Always verify your recorded dataset is healthy by comparing the video duration to the CLI episode duration and confirming the row count equals FPS × CLI duration.
|
||||
@@ -229,10 +229,7 @@ lerobot-record \
|
||||
--dataset.num_episodes=2 \
|
||||
--dataset.episode_time_s=5 \
|
||||
--dataset.reset_time_s=5 \
|
||||
--dataset.push_to_hub=true \
|
||||
--dataset.streaming_encoding=true \
|
||||
# --dataset.vcodec=auto \
|
||||
--dataset.encoder_threads=2
|
||||
--dataset.push_to_hub=true
|
||||
```
|
||||
|
||||
Example simulation dataset: [nepyope/teleop_test_sim](https://huggingface.co/datasets/nepyope/teleop_test_sim)
|
||||
@@ -282,10 +279,7 @@ lerobot-record \
|
||||
--dataset.num_episodes=2 \
|
||||
--dataset.episode_time_s=5 \
|
||||
--dataset.reset_time_s=5 \
|
||||
--dataset.push_to_hub=true \
|
||||
--dataset.streaming_encoding=true \
|
||||
# --dataset.vcodec=auto \
|
||||
--dataset.encoder_threads=2
|
||||
--dataset.push_to_hub=true
|
||||
```
|
||||
|
||||
**Note**: Update `server_address` to match your robot's camera server IP.
|
||||
|
||||
@@ -4,7 +4,6 @@ from pathlib import Path
|
||||
from queue import Empty, Full
|
||||
|
||||
import torch
|
||||
import torch.optim as optim
|
||||
|
||||
from lerobot.datasets.lerobot_dataset import LeRobotDataset
|
||||
from lerobot.datasets.utils import hw_to_dataset_features
|
||||
@@ -12,6 +11,7 @@ from lerobot.envs.configs import HILSerlProcessorConfig, HILSerlRobotEnvConfig
|
||||
from lerobot.policies.sac.configuration_sac import SACConfig
|
||||
from lerobot.policies.sac.modeling_sac import SACPolicy
|
||||
from lerobot.policies.sac.reward_model.modeling_classifier import Classifier
|
||||
from lerobot.rl.algorithms.sac import SACAlgorithm, SACAlgorithmConfig
|
||||
from lerobot.rl.buffer import ReplayBuffer
|
||||
from lerobot.rl.gym_manipulator import make_robot_env
|
||||
from lerobot.robots.so_follower import SO100FollowerConfig
|
||||
@@ -40,8 +40,9 @@ def run_learner(
|
||||
policy_learner.train()
|
||||
policy_learner.to(device)
|
||||
|
||||
# Create Adam optimizer from scratch - simple and clean
|
||||
optimizer = optim.Adam(policy_learner.parameters(), lr=lr)
|
||||
algo_config = SACAlgorithmConfig.from_policy_config(policy_learner.config)
|
||||
algorithm = SACAlgorithm(policy=policy_learner, config=algo_config)
|
||||
algorithm.make_optimizers()
|
||||
|
||||
print(f"[LEARNER] Online buffer capacity: {online_buffer.capacity}")
|
||||
print(f"[LEARNER] Offline buffer capacity: {offline_buffer.capacity}")
|
||||
@@ -83,24 +84,26 @@ def run_learner(
|
||||
else:
|
||||
batch[key] = online_batch[key]
|
||||
|
||||
loss, _ = policy_learner.forward(batch)
|
||||
def batch_iter(b=batch):
|
||||
while True:
|
||||
yield b
|
||||
|
||||
optimizer.zero_grad()
|
||||
loss.backward()
|
||||
optimizer.step()
|
||||
stats = algorithm.update(batch_iter())
|
||||
training_step += 1
|
||||
|
||||
if training_step % LOG_EVERY == 0:
|
||||
log_dict = stats.to_log_dict()
|
||||
print(
|
||||
f"[LEARNER] Training step {training_step}, Loss: {loss.item():.4f}, "
|
||||
f"[LEARNER] Training step {training_step}, "
|
||||
f"critic_loss: {log_dict.get('critic', 'N/A'):.4f}, "
|
||||
f"Buffers: Online={len(online_buffer)}, Offline={len(offline_buffer)}"
|
||||
)
|
||||
|
||||
# Send updated parameters to actor every 10 training steps
|
||||
if training_step % SEND_EVERY == 0:
|
||||
try:
|
||||
state_dict = {k: v.cpu() for k, v in policy_learner.state_dict().items()}
|
||||
parameters_queue.put_nowait(state_dict)
|
||||
weights = algorithm.get_weights()
|
||||
parameters_queue.put_nowait(weights)
|
||||
print("[LEARNER] Sent updated parameters to actor")
|
||||
except Full:
|
||||
# Missing write due to queue not being consumed (should happen rarely)
|
||||
@@ -144,15 +147,15 @@ def run_actor(
|
||||
|
||||
while step < MAX_STEPS_PER_EPISODE and not shutdown_event.is_set():
|
||||
try:
|
||||
new_params = parameters_queue.get_nowait()
|
||||
policy_actor.load_state_dict(new_params)
|
||||
new_weights = parameters_queue.get_nowait()
|
||||
policy_actor.load_state_dict(new_weights)
|
||||
print("[ACTOR] Updated policy parameters from learner")
|
||||
except Empty: # No new updated parameters available from learner, waiting
|
||||
pass
|
||||
|
||||
# Get action from policy
|
||||
# Get action from policy (returns full action: continuous + discrete)
|
||||
policy_obs = make_policy_obs(obs, device=device)
|
||||
action_tensor = policy_actor.select_action(policy_obs) # predicts a single action
|
||||
action_tensor = policy_actor.select_action(policy_obs)
|
||||
action = action_tensor.squeeze(0).cpu().numpy()
|
||||
|
||||
# Step environment
|
||||
|
||||
+2
-7
@@ -59,7 +59,7 @@ keywords = ["lerobot", "huggingface", "robotics", "machine learning", "artifici
|
||||
dependencies = [
|
||||
|
||||
# Hugging Face dependencies
|
||||
"datasets>=4.0.0,<5.0.0",
|
||||
"datasets>=4.0.0,<4.2.0",
|
||||
"diffusers>=0.27.2,<0.36.0",
|
||||
"huggingface-hub[hf-transfer,cli]>=0.34.2,<0.36.0",
|
||||
"accelerate>=1.10.0,<2.0.0",
|
||||
@@ -98,13 +98,11 @@ pygame-dep = ["pygame>=2.5.1,<2.7.0"]
|
||||
placo-dep = ["placo>=0.9.6,<0.10.0"]
|
||||
transformers-dep = ["transformers>=4.57.1,<5.0.0"]
|
||||
grpcio-dep = ["grpcio==1.73.1", "protobuf>=6.31.1,<6.32.0"]
|
||||
can-dep = ["python-can>=4.2.0,<5.0.0"]
|
||||
|
||||
# Motors
|
||||
feetech = ["feetech-servo-sdk>=1.0.0,<2.0.0"]
|
||||
dynamixel = ["dynamixel-sdk>=3.7.31,<3.9.0"]
|
||||
damiao = ["lerobot[can-dep]"]
|
||||
robstride = ["lerobot[can-dep]"]
|
||||
damiao = ["python-can>=4.2.0,<5.0.0"]
|
||||
|
||||
# Robots
|
||||
openarms = ["lerobot[damiao]"]
|
||||
@@ -214,9 +212,6 @@ lerobot-edit-dataset="lerobot.scripts.lerobot_edit_dataset:main"
|
||||
lerobot-setup-can="lerobot.scripts.lerobot_setup_can:main"
|
||||
|
||||
# ---------------- Tool Configurations ----------------
|
||||
[tool.setuptools.package-data]
|
||||
lerobot = ["envs/*.json"]
|
||||
|
||||
[tool.setuptools.packages.find]
|
||||
where = ["src"]
|
||||
|
||||
|
||||
@@ -211,3 +211,15 @@ class TrainRLServerPipelineConfig(TrainPipelineConfig):
|
||||
# NOTE: In RL, we don't need an offline dataset
|
||||
# TODO: Make `TrainPipelineConfig.dataset` optional
|
||||
dataset: DatasetConfig | None = None # type: ignore[assignment] # because the parent class has made it's type non-optional
|
||||
|
||||
# Algorithm name registered in RLAlgorithmConfig registry
|
||||
algorithm: str = "sac"
|
||||
|
||||
# Data mixer strategy name. Currently supports "online_offline"
|
||||
mixer: str = "online_offline"
|
||||
# Fraction sampled from online replay when using OnlineOfflineMixer
|
||||
online_ratio: float = 0.5
|
||||
|
||||
# RL trainer iterator
|
||||
async_prefetch: bool = True
|
||||
queue_size: int = 2
|
||||
|
||||
@@ -7,13 +7,6 @@
|
||||
|
||||
This dataset was created using [LeRobot](https://github.com/huggingface/lerobot).
|
||||
|
||||
{% if repo_id is defined and repo_id %}
|
||||
<a class="flex" href="https://huggingface.co/spaces/lerobot/visualize_dataset?path={{ repo_id }}">
|
||||
<img class="block dark:hidden" src="https://huggingface.co/datasets/huggingface/badges/resolve/main/visualize-this-dataset-xl.svg"/>
|
||||
<img class="hidden dark:block" src="https://huggingface.co/datasets/huggingface/badges/resolve/main/visualize-this-dataset-xl-dark.svg"/>
|
||||
</a>
|
||||
{% endif %}
|
||||
|
||||
## Dataset Description
|
||||
|
||||
{{ dataset_description | default("", true) }}
|
||||
|
||||
@@ -47,7 +47,6 @@ from lerobot.datasets.utils import (
|
||||
DEFAULT_EPISODES_PATH,
|
||||
get_parquet_file_size_in_mb,
|
||||
load_episodes,
|
||||
load_info,
|
||||
update_chunk_file_indices,
|
||||
write_info,
|
||||
write_stats,
|
||||
@@ -568,22 +567,20 @@ def _copy_and_reindex_data(
|
||||
def _keep_episodes_from_video_with_av(
|
||||
input_path: Path,
|
||||
output_path: Path,
|
||||
episodes_to_keep: list[tuple[int, int]],
|
||||
episodes_to_keep: list[tuple[float, float]],
|
||||
fps: float,
|
||||
vcodec: str = "libsvtav1",
|
||||
pix_fmt: str = "yuv420p",
|
||||
) -> None:
|
||||
"""Keep only specified episodes from a video file using PyAV.
|
||||
|
||||
This function decodes frames from specified frame ranges and re-encodes them with
|
||||
This function decodes frames from specified time ranges and re-encodes them with
|
||||
properly reset timestamps to ensure monotonic progression.
|
||||
|
||||
Args:
|
||||
input_path: Source video file path.
|
||||
output_path: Destination video file path.
|
||||
episodes_to_keep: List of (start_frame, end_frame) tuples for episodes to keep.
|
||||
Ranges are half-open intervals: [start_frame, end_frame), where start_frame
|
||||
is inclusive and end_frame is exclusive.
|
||||
episodes_to_keep: List of (start_time, end_time) tuples for episodes to keep.
|
||||
fps: Frame rate of the video.
|
||||
vcodec: Video codec to use for encoding.
|
||||
pix_fmt: Pixel format for output video.
|
||||
@@ -625,10 +622,9 @@ def _keep_episodes_from_video_with_av(
|
||||
|
||||
# Create set of (start, end) ranges for fast lookup.
|
||||
# Convert to a sorted list for efficient checking.
|
||||
frame_ranges = sorted(episodes_to_keep)
|
||||
time_ranges = sorted(episodes_to_keep)
|
||||
|
||||
# Track frame index for setting PTS and current range being processed.
|
||||
src_frame_count = 0
|
||||
frame_count = 0
|
||||
range_idx = 0
|
||||
|
||||
@@ -638,20 +634,21 @@ def _keep_episodes_from_video_with_av(
|
||||
if frame is None:
|
||||
continue
|
||||
|
||||
# Check if frame is in any of our desired frame ranges.
|
||||
# Get frame timestamp.
|
||||
frame_time = float(frame.pts * frame.time_base) if frame.pts is not None else 0.0
|
||||
|
||||
# Check if frame is in any of our desired time ranges.
|
||||
# Skip ranges that have already passed.
|
||||
while range_idx < len(frame_ranges) and src_frame_count >= frame_ranges[range_idx][1]:
|
||||
while range_idx < len(time_ranges) and frame_time >= time_ranges[range_idx][1]:
|
||||
range_idx += 1
|
||||
|
||||
# If we've passed all ranges, stop processing.
|
||||
if range_idx >= len(frame_ranges):
|
||||
if range_idx >= len(time_ranges):
|
||||
break
|
||||
|
||||
# Check if frame is in current range.
|
||||
start_frame = frame_ranges[range_idx][0]
|
||||
|
||||
if src_frame_count < start_frame:
|
||||
src_frame_count += 1
|
||||
start_ts, end_ts = time_ranges[range_idx]
|
||||
if frame_time < start_ts:
|
||||
continue
|
||||
|
||||
# Frame is in range - create a new frame with reset timestamps.
|
||||
@@ -664,7 +661,6 @@ def _keep_episodes_from_video_with_av(
|
||||
for pkt in v_out.encode(new_frame):
|
||||
out.mux(pkt)
|
||||
|
||||
src_frame_count += 1
|
||||
frame_count += 1
|
||||
|
||||
# Flush encoder.
|
||||
@@ -753,17 +749,15 @@ def _copy_and_reindex_videos(
|
||||
f"videos/{video_key}/to_timestamp"
|
||||
]
|
||||
else:
|
||||
# Build list of frame ranges to keep, in sorted order.
|
||||
# Build list of time ranges to keep, in sorted order.
|
||||
sorted_keep_episodes = sorted(episodes_in_file, key=lambda x: episode_mapping[x])
|
||||
episodes_to_keep_ranges: list[tuple[int, int]] = []
|
||||
episodes_to_keep_ranges: list[tuple[float, float]] = []
|
||||
|
||||
for old_idx in sorted_keep_episodes:
|
||||
src_ep = src_dataset.meta.episodes[old_idx]
|
||||
from_frame = round(src_ep[f"videos/{video_key}/from_timestamp"] * src_dataset.meta.fps)
|
||||
to_frame = round(src_ep[f"videos/{video_key}/to_timestamp"] * src_dataset.meta.fps)
|
||||
assert src_ep["length"] == to_frame - from_frame, (
|
||||
f"Episode length mismatch: {src_ep['length']} vs {to_frame - from_frame}"
|
||||
)
|
||||
episodes_to_keep_ranges.append((from_frame, to_frame))
|
||||
from_ts = src_ep[f"videos/{video_key}/from_timestamp"]
|
||||
to_ts = src_ep[f"videos/{video_key}/to_timestamp"]
|
||||
episodes_to_keep_ranges.append((from_ts, to_ts))
|
||||
|
||||
# Use PyAV filters to efficiently re-encode only the desired segments.
|
||||
assert src_dataset.meta.video_path is not None
|
||||
@@ -1775,296 +1769,3 @@ def convert_image_to_video_dataset(
|
||||
|
||||
# Return new dataset
|
||||
return LeRobotDataset(repo_id=repo_id, root=output_dir)
|
||||
|
||||
|
||||
def trim_episodes_by_frames(
|
||||
dataset: LeRobotDataset,
|
||||
episode_frames_to_keep: dict[int, list[int]],
|
||||
output_dir: str | Path | None = None,
|
||||
repo_id: str | None = None,
|
||||
) -> LeRobotDataset:
|
||||
"""Trim multiple episodes to keep only specific frames.
|
||||
|
||||
This function creates a new dataset where the specified episodes contain only
|
||||
the frames at the given indices. All other episodes are copied as-is.
|
||||
|
||||
Args:
|
||||
dataset: The source LeRobotDataset.
|
||||
episode_frames_to_keep: Dict mapping episode indices to lists of global frame indices to keep.
|
||||
output_dir: Directory to save the new dataset. If None, uses default location.
|
||||
repo_id: Repository ID for the new dataset. If None, appends "_trimmed" to original.
|
||||
|
||||
Returns:
|
||||
A new LeRobotDataset with the trimmed episodes.
|
||||
"""
|
||||
if not episode_frames_to_keep:
|
||||
raise ValueError("No episodes to trim")
|
||||
|
||||
for ep_idx in episode_frames_to_keep:
|
||||
if ep_idx >= dataset.meta.total_episodes:
|
||||
raise ValueError(f"Episode {ep_idx} does not exist")
|
||||
if not episode_frames_to_keep[ep_idx]:
|
||||
raise ValueError(f"No frames to keep for episode {ep_idx}")
|
||||
|
||||
if repo_id is None:
|
||||
repo_id = f"{dataset.repo_id}_trimmed"
|
||||
output_dir = Path(output_dir) if output_dir is not None else HF_LEROBOT_HOME / repo_id
|
||||
|
||||
total_trimmed = sum(len(frames) for frames in episode_frames_to_keep.values())
|
||||
logging.info(f"Trimming {len(episode_frames_to_keep)} episodes, keeping {total_trimmed} frames total")
|
||||
|
||||
# Create new metadata
|
||||
new_meta = LeRobotDatasetMetadata.create(
|
||||
repo_id=repo_id,
|
||||
fps=dataset.meta.fps,
|
||||
features=dataset.meta.features,
|
||||
robot_type=dataset.meta.robot_type,
|
||||
root=output_dir,
|
||||
use_videos=len(dataset.meta.video_keys) > 0,
|
||||
)
|
||||
|
||||
# Build set of all frames to keep (for episodes being trimmed)
|
||||
# and compute new frame counts per episode
|
||||
all_keep_frames: set[int] = set()
|
||||
trimmed_frame_counts: dict[int, int] = {}
|
||||
for ep_idx, frames in episode_frames_to_keep.items():
|
||||
all_keep_frames.update(frames)
|
||||
trimmed_frame_counts[ep_idx] = len(frames)
|
||||
|
||||
# Copy and filter data
|
||||
_copy_and_reindex_data_with_multi_frame_filter(
|
||||
dataset, new_meta, episode_frames_to_keep, all_keep_frames
|
||||
)
|
||||
|
||||
# Handle videos if present
|
||||
if dataset.meta.video_keys:
|
||||
_copy_and_reindex_videos_with_multi_frame_filter(
|
||||
dataset, new_meta, episode_frames_to_keep
|
||||
)
|
||||
|
||||
# Copy episode metadata
|
||||
_copy_and_reindex_episodes_metadata_for_multi_trim(
|
||||
dataset, new_meta, trimmed_frame_counts
|
||||
)
|
||||
|
||||
logging.info(f"Created trimmed dataset with {new_meta.total_frames} frames at {output_dir}")
|
||||
|
||||
# Return the metadata instead of trying to load as LeRobotDataset
|
||||
# This avoids Hub validation issues when the repo doesn't exist yet
|
||||
return new_meta
|
||||
|
||||
|
||||
# Keep old function for backward compatibility
|
||||
def trim_episode_by_frames(
|
||||
dataset: LeRobotDataset,
|
||||
episode_index: int,
|
||||
keep_frame_indices: list[int],
|
||||
output_dir: str | Path | None = None,
|
||||
repo_id: str | None = None,
|
||||
) -> LeRobotDataset:
|
||||
"""Trim a single episode. Wrapper around trim_episodes_by_frames."""
|
||||
return trim_episodes_by_frames(
|
||||
dataset,
|
||||
episode_frames_to_keep={episode_index: keep_frame_indices},
|
||||
output_dir=output_dir,
|
||||
repo_id=repo_id,
|
||||
)
|
||||
|
||||
|
||||
def _copy_and_reindex_data_with_multi_frame_filter(
|
||||
src_dataset: LeRobotDataset,
|
||||
dst_meta: LeRobotDatasetMetadata,
|
||||
episode_frames_to_keep: dict[int, list[int]],
|
||||
all_keep_frames: set[int],
|
||||
) -> None:
|
||||
"""Copy data files with frame-level filtering for multiple episodes."""
|
||||
if src_dataset.meta.episodes is None:
|
||||
src_dataset.meta.episodes = load_episodes(src_dataset.meta.root)
|
||||
|
||||
# Copy tasks
|
||||
if dst_meta.tasks is None and src_dataset.meta.tasks is not None:
|
||||
# Tasks are stored with task string as index
|
||||
dst_meta.save_episode_tasks(list(src_dataset.meta.tasks.index))
|
||||
|
||||
# Get all parquet files
|
||||
data_dir = src_dataset.root / "data"
|
||||
parquet_files = sorted(data_dir.glob("chunk-*/file-*.parquet"))
|
||||
|
||||
trim_episode_set = set(episode_frames_to_keep.keys())
|
||||
global_index = 0
|
||||
|
||||
for parquet_path in tqdm(parquet_files, desc="Processing data files"):
|
||||
df = pd.read_parquet(parquet_path)
|
||||
|
||||
# Filter: keep all frames from non-trimmed episodes,
|
||||
# and only specified frames from trimmed episodes
|
||||
mask = (~df["episode_index"].isin(trim_episode_set)) | (df["index"].isin(all_keep_frames))
|
||||
df = df[mask].copy().reset_index(drop=True)
|
||||
|
||||
if len(df) == 0:
|
||||
continue
|
||||
|
||||
# Reindex
|
||||
df["index"] = range(global_index, global_index + len(df))
|
||||
|
||||
# Recalculate frame_index within each episode
|
||||
for ep_idx in df["episode_index"].unique():
|
||||
ep_mask = df["episode_index"] == ep_idx
|
||||
df.loc[ep_mask, "frame_index"] = range(ep_mask.sum())
|
||||
|
||||
# Recalculate timestamps based on frame_index and fps
|
||||
df["timestamp"] = df["frame_index"] / src_dataset.meta.fps
|
||||
|
||||
# Determine output path (keep same structure)
|
||||
rel_path = parquet_path.relative_to(src_dataset.root)
|
||||
dst_path = dst_meta.root / rel_path
|
||||
dst_path.parent.mkdir(parents=True, exist_ok=True)
|
||||
|
||||
_write_parquet(df, dst_path, dst_meta)
|
||||
global_index += len(df)
|
||||
|
||||
|
||||
def _copy_and_reindex_videos_with_multi_frame_filter(
|
||||
src_dataset: LeRobotDataset,
|
||||
dst_meta: LeRobotDatasetMetadata,
|
||||
episode_frames_to_keep: dict[int, list[int]],
|
||||
) -> None:
|
||||
"""Copy video files for trimmed dataset.
|
||||
|
||||
In v3.0 datasets, multiple episodes are concatenated into single video files.
|
||||
Each episode has from_timestamp/to_timestamp indicating its portion of the video.
|
||||
|
||||
For trimming, we copy the original video files as-is and update the metadata
|
||||
timestamps in _copy_and_reindex_episodes_metadata_for_multi_trim.
|
||||
"""
|
||||
for video_key in src_dataset.meta.video_keys:
|
||||
video_dir = src_dataset.root / "videos" / video_key
|
||||
dst_video_dir = dst_meta.root / "videos" / video_key
|
||||
|
||||
if not video_dir.exists():
|
||||
logging.warning(f"Video directory not found: {video_dir}")
|
||||
continue
|
||||
|
||||
# Copy all video files (they contain concatenated episodes)
|
||||
# The metadata timestamps will handle which portions to use
|
||||
copied_files = set()
|
||||
for chunk_dir in video_dir.glob("chunk-*"):
|
||||
dst_chunk_dir = dst_video_dir / chunk_dir.name
|
||||
dst_chunk_dir.mkdir(parents=True, exist_ok=True)
|
||||
|
||||
for video_file in chunk_dir.glob("*.mp4"):
|
||||
if video_file.name not in copied_files:
|
||||
dst_path = dst_chunk_dir / video_file.name
|
||||
if not dst_path.exists():
|
||||
shutil.copy(video_file, dst_path)
|
||||
copied_files.add(video_file.name)
|
||||
|
||||
logging.info(f"Copied {len(copied_files)} video files for {video_key}")
|
||||
|
||||
|
||||
def _trim_video_frames(
|
||||
src_path: Path,
|
||||
dst_path: Path,
|
||||
keep_frame_indices: list[int],
|
||||
fps: float,
|
||||
episode_start_idx: int,
|
||||
) -> None:
|
||||
"""Trim a video to keep only specific frames using ffmpeg."""
|
||||
import subprocess
|
||||
|
||||
# Convert global indices to local indices within the episode
|
||||
local_indices = sorted([idx - episode_start_idx for idx in keep_frame_indices])
|
||||
|
||||
if not local_indices:
|
||||
logging.warning(f"No frames to keep for video {src_path}")
|
||||
return
|
||||
|
||||
# Calculate start and end times
|
||||
start_frame = local_indices[0]
|
||||
end_frame = local_indices[-1]
|
||||
|
||||
start_time = start_frame / fps
|
||||
duration = (end_frame - start_frame + 1) / fps
|
||||
|
||||
# Use ffmpeg to trim
|
||||
cmd = [
|
||||
"ffmpeg", "-y",
|
||||
"-ss", str(start_time),
|
||||
"-i", str(src_path),
|
||||
"-t", str(duration),
|
||||
"-c", "copy", # Fast copy without re-encoding
|
||||
str(dst_path)
|
||||
]
|
||||
|
||||
try:
|
||||
subprocess.run(cmd, check=True, capture_output=True)
|
||||
except subprocess.CalledProcessError as e:
|
||||
logging.error(f"Failed to trim video: {e.stderr.decode()}")
|
||||
# Fallback: copy the whole video
|
||||
shutil.copy(src_path, dst_path)
|
||||
|
||||
|
||||
def _copy_and_reindex_episodes_metadata_for_multi_trim(
|
||||
src_dataset: LeRobotDataset,
|
||||
dst_meta: LeRobotDatasetMetadata,
|
||||
trimmed_frame_counts: dict[int, int],
|
||||
) -> None:
|
||||
"""Copy and update episode metadata for trimmed dataset."""
|
||||
if src_dataset.meta.episodes is None:
|
||||
src_dataset.meta.episodes = load_episodes(src_dataset.meta.root)
|
||||
|
||||
# Calculate new frame counts and indices
|
||||
episodes_data = []
|
||||
global_idx = 0
|
||||
|
||||
for old_ep_idx in range(src_dataset.meta.total_episodes):
|
||||
src_ep = src_dataset.meta.episodes[old_ep_idx]
|
||||
|
||||
if old_ep_idx in trimmed_frame_counts:
|
||||
ep_length = trimmed_frame_counts[old_ep_idx]
|
||||
else:
|
||||
ep_length = src_ep["length"]
|
||||
|
||||
ep_data = {
|
||||
"episode_index": old_ep_idx,
|
||||
"tasks": src_ep["tasks"],
|
||||
"length": ep_length,
|
||||
"data/chunk_index": src_ep["data/chunk_index"],
|
||||
"data/file_index": src_ep["data/file_index"],
|
||||
"dataset_from_index": global_idx,
|
||||
"dataset_to_index": global_idx + ep_length,
|
||||
}
|
||||
|
||||
# Copy video metadata - preserve timestamps for concatenated videos
|
||||
for video_key in src_dataset.meta.video_keys:
|
||||
ep_data[f"videos/{video_key}/chunk_index"] = src_ep[f"videos/{video_key}/chunk_index"]
|
||||
ep_data[f"videos/{video_key}/file_index"] = src_ep[f"videos/{video_key}/file_index"]
|
||||
|
||||
# Keep original from_timestamp (start position in concatenated video)
|
||||
orig_from_ts = src_ep[f"videos/{video_key}/from_timestamp"]
|
||||
ep_data[f"videos/{video_key}/from_timestamp"] = orig_from_ts
|
||||
|
||||
# For trimmed episodes, update to_timestamp based on new length
|
||||
# For non-trimmed episodes, keep original to_timestamp
|
||||
if old_ep_idx in trimmed_frame_counts:
|
||||
ep_data[f"videos/{video_key}/to_timestamp"] = orig_from_ts + (ep_length / src_dataset.meta.fps)
|
||||
else:
|
||||
ep_data[f"videos/{video_key}/to_timestamp"] = src_ep[f"videos/{video_key}/to_timestamp"]
|
||||
|
||||
ep_data["meta/episodes/chunk_index"] = 0
|
||||
ep_data["meta/episodes/file_index"] = 0
|
||||
|
||||
episodes_data.append(ep_data)
|
||||
global_idx += ep_length
|
||||
|
||||
# Save episodes metadata
|
||||
df = pd.DataFrame(episodes_data)
|
||||
episodes_path = dst_meta.root / DEFAULT_EPISODES_PATH.format(chunk_index=0, file_index=0)
|
||||
episodes_path.parent.mkdir(parents=True, exist_ok=True)
|
||||
df.to_parquet(episodes_path)
|
||||
|
||||
# Update info.json
|
||||
info = load_info(src_dataset.root)
|
||||
info["total_episodes"] = len(episodes_data)
|
||||
info["total_frames"] = global_idx
|
||||
write_info(info, dst_meta.root)
|
||||
|
||||
@@ -68,7 +68,6 @@ from lerobot.datasets.utils import (
|
||||
write_tasks,
|
||||
)
|
||||
from lerobot.datasets.video_utils import (
|
||||
StreamingVideoEncoder,
|
||||
VideoFrame,
|
||||
concatenate_video_files,
|
||||
decode_video_frames,
|
||||
@@ -76,11 +75,11 @@ from lerobot.datasets.video_utils import (
|
||||
get_safe_default_codec,
|
||||
get_video_duration_in_s,
|
||||
get_video_info,
|
||||
resolve_vcodec,
|
||||
)
|
||||
from lerobot.utils.constants import HF_LEROBOT_HOME
|
||||
|
||||
CODEBASE_VERSION = "v3.0"
|
||||
VALID_VIDEO_CODECS = {"h264", "hevc", "libsvtav1"}
|
||||
|
||||
|
||||
class LeRobotDatasetMetadata:
|
||||
@@ -546,19 +545,12 @@ class LeRobotDatasetMetadata:
|
||||
|
||||
|
||||
def _encode_video_worker(
|
||||
video_key: str,
|
||||
episode_index: int,
|
||||
root: Path,
|
||||
fps: int,
|
||||
vcodec: str = "libsvtav1",
|
||||
encoder_threads: int | None = None,
|
||||
video_key: str, episode_index: int, root: Path, fps: int, vcodec: str = "libsvtav1"
|
||||
) -> Path:
|
||||
temp_path = Path(tempfile.mkdtemp(dir=root)) / f"{video_key}_{episode_index:03d}.mp4"
|
||||
fpath = DEFAULT_IMAGE_PATH.format(image_key=video_key, episode_index=episode_index, frame_index=0)
|
||||
img_dir = (root / fpath).parent
|
||||
encode_video_frames(
|
||||
img_dir, temp_path, fps, vcodec=vcodec, overwrite=True, encoder_threads=encoder_threads
|
||||
)
|
||||
encode_video_frames(img_dir, temp_path, fps, vcodec=vcodec, overwrite=True)
|
||||
shutil.rmtree(img_dir)
|
||||
return temp_path
|
||||
|
||||
@@ -578,9 +570,6 @@ class LeRobotDataset(torch.utils.data.Dataset):
|
||||
video_backend: str | None = None,
|
||||
batch_encoding_size: int = 1,
|
||||
vcodec: str = "libsvtav1",
|
||||
streaming_encoding: bool = False,
|
||||
encoder_queue_maxsize: int = 30,
|
||||
encoder_threads: int | None = None,
|
||||
):
|
||||
"""
|
||||
2 modes are available for instantiating this class, depending on 2 different use cases:
|
||||
@@ -694,17 +683,12 @@ class LeRobotDataset(torch.utils.data.Dataset):
|
||||
batch_encoding_size (int, optional): Number of episodes to accumulate before batch encoding videos.
|
||||
Set to 1 for immediate encoding (default), or higher for batched encoding. Defaults to 1.
|
||||
vcodec (str, optional): Video codec for encoding videos during recording. Options: 'h264', 'hevc',
|
||||
'libsvtav1', 'auto', or hardware-specific codecs like 'h264_videotoolbox', 'h264_nvenc'.
|
||||
Defaults to 'libsvtav1'. Use 'auto' to auto-detect the best available hardware encoder.
|
||||
streaming_encoding (bool, optional): If True, encode video frames in real-time during capture
|
||||
instead of writing PNG images first. This makes save_episode() near-instant. Defaults to False.
|
||||
encoder_queue_maxsize (int, optional): Maximum number of frames to buffer per camera when using
|
||||
streaming encoding. Defaults to 30 (~1s at 30fps).
|
||||
encoder_threads (int | None, optional): Number of threads per encoder instance. None lets the
|
||||
codec auto-detect (default). Lower values reduce CPU usage per encoder. Maps to 'lp' (via svtav1-params) for
|
||||
libsvtav1 and 'threads' for h264/hevc.
|
||||
'libsvtav1'. Defaults to 'libsvtav1'. Use 'h264' for faster encoding on systems where AV1
|
||||
encoding is CPU-heavy.
|
||||
"""
|
||||
super().__init__()
|
||||
if vcodec not in VALID_VIDEO_CODECS:
|
||||
raise ValueError(f"Invalid vcodec '{vcodec}'. Must be one of: {sorted(VALID_VIDEO_CODECS)}")
|
||||
self.repo_id = repo_id
|
||||
self.root = Path(root) if root else HF_LEROBOT_HOME / repo_id
|
||||
self.image_transforms = image_transforms
|
||||
@@ -716,8 +700,7 @@ class LeRobotDataset(torch.utils.data.Dataset):
|
||||
self.delta_indices = None
|
||||
self.batch_encoding_size = batch_encoding_size
|
||||
self.episodes_since_last_encoding = 0
|
||||
self.vcodec = resolve_vcodec(vcodec)
|
||||
self._encoder_threads = encoder_threads
|
||||
self.vcodec = vcodec
|
||||
|
||||
# Unused attributes
|
||||
self.image_writer = None
|
||||
@@ -725,7 +708,6 @@ class LeRobotDataset(torch.utils.data.Dataset):
|
||||
self.writer = None
|
||||
self.latest_episode = None
|
||||
self._current_file_start_frame = None # Track the starting frame index of the current parquet file
|
||||
self._streaming_encoder = None
|
||||
|
||||
self.root.mkdir(exist_ok=True, parents=True)
|
||||
|
||||
@@ -747,7 +729,7 @@ class LeRobotDataset(torch.utils.data.Dataset):
|
||||
# Check if cached dataset contains all requested episodes
|
||||
if not self._check_cached_episodes_sufficient():
|
||||
raise FileNotFoundError("Cached dataset doesn't contain all requested episodes")
|
||||
except (FileNotFoundError, NotADirectoryError):
|
||||
except (AssertionError, FileNotFoundError, NotADirectoryError):
|
||||
if is_valid_version(self.revision):
|
||||
self.revision = get_safe_version(self.repo_id, self.revision)
|
||||
self.download(download_videos)
|
||||
@@ -767,19 +749,6 @@ class LeRobotDataset(torch.utils.data.Dataset):
|
||||
check_delta_timestamps(self.delta_timestamps, self.fps, self.tolerance_s)
|
||||
self.delta_indices = get_delta_indices(self.delta_timestamps, self.fps)
|
||||
|
||||
# Initialize streaming encoder for resumed recording
|
||||
if streaming_encoding and len(self.meta.video_keys) > 0:
|
||||
self._streaming_encoder = StreamingVideoEncoder(
|
||||
fps=self.meta.fps,
|
||||
vcodec=self.vcodec,
|
||||
pix_fmt="yuv420p",
|
||||
g=2,
|
||||
crf=30,
|
||||
preset=None,
|
||||
queue_maxsize=encoder_queue_maxsize,
|
||||
encoder_threads=encoder_threads,
|
||||
)
|
||||
|
||||
def _close_writer(self) -> None:
|
||||
"""Close and cleanup the parquet writer if it exists."""
|
||||
writer = getattr(self, "writer", None)
|
||||
@@ -839,7 +808,7 @@ class LeRobotDataset(torch.utils.data.Dataset):
|
||||
hub_api.upload_folder(**upload_kwargs)
|
||||
|
||||
card = create_lerobot_dataset_card(
|
||||
tags=tags, dataset_info=self.meta.info, license=license, repo_id=self.repo_id, **card_kwargs
|
||||
tags=tags, dataset_info=self.meta.info, license=license, **card_kwargs
|
||||
)
|
||||
card.push_to_hub(repo_id=self.repo_id, repo_type="dataset", revision=branch)
|
||||
|
||||
@@ -1135,8 +1104,6 @@ class LeRobotDataset(torch.utils.data.Dataset):
|
||||
"""
|
||||
self._close_writer()
|
||||
self.meta._close_writer()
|
||||
if self._streaming_encoder is not None:
|
||||
self._streaming_encoder.close()
|
||||
|
||||
def create_episode_buffer(self, episode_index: int | None = None) -> dict:
|
||||
current_ep_idx = self.meta.total_episodes if episode_index is None else episode_index
|
||||
@@ -1191,13 +1158,6 @@ class LeRobotDataset(torch.utils.data.Dataset):
|
||||
self.episode_buffer["timestamp"].append(timestamp)
|
||||
self.episode_buffer["task"].append(frame.pop("task")) # Remove task from frame after processing
|
||||
|
||||
# Start streaming encoder on first frame of episode (once, before iterating keys)
|
||||
if frame_index == 0 and self._streaming_encoder is not None:
|
||||
self._streaming_encoder.start_episode(
|
||||
video_keys=list(self.meta.video_keys),
|
||||
temp_dir=self.root,
|
||||
)
|
||||
|
||||
# Add frame features to episode_buffer
|
||||
for key in frame:
|
||||
if key not in self.features:
|
||||
@@ -1205,10 +1165,7 @@ class LeRobotDataset(torch.utils.data.Dataset):
|
||||
f"An element of the frame is not in the features. '{key}' not in '{self.features.keys()}'."
|
||||
)
|
||||
|
||||
if self.features[key]["dtype"] == "video" and self._streaming_encoder is not None:
|
||||
self._streaming_encoder.feed_frame(key, frame[key])
|
||||
self.episode_buffer[key].append(None) # Placeholder (video keys are skipped in parquet)
|
||||
elif self.features[key]["dtype"] in ["image", "video"]:
|
||||
if self.features[key]["dtype"] in ["image", "video"]:
|
||||
img_path = self._get_image_file_path(
|
||||
episode_index=self.episode_buffer["episode_index"], image_key=key, frame_index=frame_index
|
||||
)
|
||||
@@ -1269,38 +1226,13 @@ class LeRobotDataset(torch.utils.data.Dataset):
|
||||
|
||||
# Wait for image writer to end, so that episode stats over images can be computed
|
||||
self._wait_image_writer()
|
||||
|
||||
has_video_keys = len(self.meta.video_keys) > 0
|
||||
use_streaming = self._streaming_encoder is not None and has_video_keys
|
||||
use_batched_encoding = self.batch_encoding_size > 1
|
||||
|
||||
if use_streaming:
|
||||
# Compute stats for non-video features only (video stats come from encoder)
|
||||
non_video_buffer = {
|
||||
k: v
|
||||
for k, v in episode_buffer.items()
|
||||
if self.features.get(k, {}).get("dtype") not in ("video",)
|
||||
}
|
||||
non_video_features = {k: v for k, v in self.features.items() if v["dtype"] != "video"}
|
||||
ep_stats = compute_episode_stats(non_video_buffer, non_video_features)
|
||||
else:
|
||||
ep_stats = compute_episode_stats(episode_buffer, self.features)
|
||||
ep_stats = compute_episode_stats(episode_buffer, self.features)
|
||||
|
||||
ep_metadata = self._save_episode_data(episode_buffer)
|
||||
has_video_keys = len(self.meta.video_keys) > 0
|
||||
use_batched_encoding = self.batch_encoding_size > 1
|
||||
|
||||
if use_streaming:
|
||||
# Finish streaming encoding and collect results
|
||||
streaming_results = self._streaming_encoder.finish_episode()
|
||||
for video_key in self.meta.video_keys:
|
||||
temp_path, video_stats = streaming_results[video_key]
|
||||
if video_stats is not None:
|
||||
# Format stats same as compute_episode_stats: normalize to [0,1], reshape to (C,1,1)
|
||||
ep_stats[video_key] = {
|
||||
k: v if k == "count" else np.squeeze(v.reshape(1, -1, 1, 1) / 255.0, axis=0)
|
||||
for k, v in video_stats.items()
|
||||
}
|
||||
ep_metadata.update(self._save_episode_video(video_key, episode_index, temp_path=temp_path))
|
||||
elif has_video_keys and not use_batched_encoding:
|
||||
if has_video_keys and not use_batched_encoding:
|
||||
num_cameras = len(self.meta.video_keys)
|
||||
if parallel_encoding and num_cameras > 1:
|
||||
# TODO(Steven): Ideally we would like to control the number of threads per encoding such that:
|
||||
@@ -1314,7 +1246,6 @@ class LeRobotDataset(torch.utils.data.Dataset):
|
||||
self.root,
|
||||
self.fps,
|
||||
self.vcodec,
|
||||
self._encoder_threads,
|
||||
): video_key
|
||||
for video_key in self.meta.video_keys
|
||||
}
|
||||
@@ -1583,10 +1514,6 @@ class LeRobotDataset(torch.utils.data.Dataset):
|
||||
return metadata
|
||||
|
||||
def clear_episode_buffer(self, delete_images: bool = True) -> None:
|
||||
# Cancel streaming encoder if active
|
||||
if self._streaming_encoder is not None:
|
||||
self._streaming_encoder.cancel_episode()
|
||||
|
||||
# Clean up image files for the current episode buffer
|
||||
if delete_images:
|
||||
# Wait for the async image writer to finish
|
||||
@@ -1634,9 +1561,7 @@ class LeRobotDataset(torch.utils.data.Dataset):
|
||||
Note: `encode_video_frames` is a blocking call. Making it asynchronous shouldn't speedup encoding,
|
||||
since video encoding with ffmpeg is already using multithreading.
|
||||
"""
|
||||
return _encode_video_worker(
|
||||
video_key, episode_index, self.root, self.fps, self.vcodec, self._encoder_threads
|
||||
)
|
||||
return _encode_video_worker(video_key, episode_index, self.root, self.fps, self.vcodec)
|
||||
|
||||
@classmethod
|
||||
def create(
|
||||
@@ -1653,13 +1578,10 @@ class LeRobotDataset(torch.utils.data.Dataset):
|
||||
video_backend: str | None = None,
|
||||
batch_encoding_size: int = 1,
|
||||
vcodec: str = "libsvtav1",
|
||||
metadata_buffer_size: int = 10,
|
||||
streaming_encoding: bool = False,
|
||||
encoder_queue_maxsize: int = 30,
|
||||
encoder_threads: int | None = None,
|
||||
) -> "LeRobotDataset":
|
||||
"""Create a LeRobot Dataset from scratch in order to record data."""
|
||||
vcodec = resolve_vcodec(vcodec)
|
||||
if vcodec not in VALID_VIDEO_CODECS:
|
||||
raise ValueError(f"Invalid vcodec '{vcodec}'. Must be one of: {sorted(VALID_VIDEO_CODECS)}")
|
||||
obj = cls.__new__(cls)
|
||||
obj.meta = LeRobotDatasetMetadata.create(
|
||||
repo_id=repo_id,
|
||||
@@ -1668,7 +1590,6 @@ class LeRobotDataset(torch.utils.data.Dataset):
|
||||
features=features,
|
||||
root=root,
|
||||
use_videos=use_videos,
|
||||
metadata_buffer_size=metadata_buffer_size,
|
||||
)
|
||||
obj.repo_id = obj.meta.repo_id
|
||||
obj.root = obj.meta.root
|
||||
@@ -1678,7 +1599,6 @@ class LeRobotDataset(torch.utils.data.Dataset):
|
||||
obj.batch_encoding_size = batch_encoding_size
|
||||
obj.episodes_since_last_encoding = 0
|
||||
obj.vcodec = vcodec
|
||||
obj._encoder_threads = encoder_threads
|
||||
|
||||
if image_writer_processes or image_writer_threads:
|
||||
obj.start_image_writer(image_writer_processes, image_writer_threads)
|
||||
@@ -1700,22 +1620,6 @@ class LeRobotDataset(torch.utils.data.Dataset):
|
||||
obj._lazy_loading = False
|
||||
obj._recorded_frames = 0
|
||||
obj._writer_closed_for_reading = False
|
||||
|
||||
# Initialize streaming encoder
|
||||
if streaming_encoding and len(obj.meta.video_keys) > 0:
|
||||
obj._streaming_encoder = StreamingVideoEncoder(
|
||||
fps=fps,
|
||||
vcodec=vcodec,
|
||||
pix_fmt="yuv420p",
|
||||
g=2,
|
||||
crf=30,
|
||||
preset=None,
|
||||
queue_maxsize=encoder_queue_maxsize,
|
||||
encoder_threads=encoder_threads,
|
||||
)
|
||||
else:
|
||||
obj._streaming_encoder = None
|
||||
|
||||
return obj
|
||||
|
||||
|
||||
|
||||
@@ -13,106 +13,25 @@
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
import contextlib
|
||||
import glob
|
||||
import importlib
|
||||
import logging
|
||||
import queue
|
||||
import shutil
|
||||
import tempfile
|
||||
import threading
|
||||
import warnings
|
||||
from dataclasses import dataclass, field
|
||||
from fractions import Fraction
|
||||
from pathlib import Path
|
||||
from threading import Lock
|
||||
from typing import Any, ClassVar
|
||||
|
||||
import av
|
||||
import fsspec
|
||||
import numpy as np
|
||||
import pyarrow as pa
|
||||
import torch
|
||||
import torchvision
|
||||
from datasets.features.features import register_feature
|
||||
from PIL import Image
|
||||
|
||||
# List of hardware encoders to probe for auto-selection. Availability depends on the platform and FFmpeg build.
|
||||
# Determines the order of preference for auto-selection when vcodec="auto" is used.
|
||||
HW_ENCODERS = [
|
||||
"h264_videotoolbox", # macOS
|
||||
"hevc_videotoolbox", # macOS
|
||||
"h264_nvenc", # NVIDIA GPU
|
||||
"hevc_nvenc", # NVIDIA GPU
|
||||
"h264_vaapi", # Linux Intel/AMD
|
||||
"h264_qsv", # Intel Quick Sync
|
||||
]
|
||||
|
||||
VALID_VIDEO_CODECS = {"h264", "hevc", "libsvtav1", "auto"} | set(HW_ENCODERS)
|
||||
|
||||
|
||||
def _get_codec_options(
|
||||
vcodec: str,
|
||||
g: int | None = 2,
|
||||
crf: int | None = 30,
|
||||
preset: int | None = None,
|
||||
) -> dict:
|
||||
"""Build codec-specific options dict for video encoding."""
|
||||
options = {}
|
||||
|
||||
# GOP size (keyframe interval) - supported by VideoToolbox and software encoders
|
||||
if g is not None and (vcodec in ("h264_videotoolbox", "hevc_videotoolbox") or vcodec not in HW_ENCODERS):
|
||||
options["g"] = str(g)
|
||||
|
||||
# Quality control (codec-specific parameter names)
|
||||
if crf is not None:
|
||||
if vcodec in ("h264", "hevc", "libsvtav1"):
|
||||
options["crf"] = str(crf)
|
||||
elif vcodec in ("h264_videotoolbox", "hevc_videotoolbox"):
|
||||
quality = max(1, min(100, int(100 - crf * 2)))
|
||||
options["q:v"] = str(quality)
|
||||
elif vcodec in ("h264_nvenc", "hevc_nvenc"):
|
||||
options["rc"] = "constqp"
|
||||
options["qp"] = str(crf)
|
||||
elif vcodec in ("h264_vaapi",):
|
||||
options["qp"] = str(crf)
|
||||
elif vcodec in ("h264_qsv",):
|
||||
options["global_quality"] = str(crf)
|
||||
|
||||
# Preset (only for libsvtav1)
|
||||
if vcodec == "libsvtav1":
|
||||
options["preset"] = str(preset) if preset is not None else "12"
|
||||
|
||||
return options
|
||||
|
||||
|
||||
def detect_available_hw_encoders() -> list[str]:
|
||||
"""Probe PyAV/FFmpeg for available hardware video encoders."""
|
||||
available = []
|
||||
for codec_name in HW_ENCODERS:
|
||||
try:
|
||||
av.codec.Codec(codec_name, "w")
|
||||
available.append(codec_name)
|
||||
except Exception: # nosec B110
|
||||
pass # nosec B110
|
||||
return available
|
||||
|
||||
|
||||
def resolve_vcodec(vcodec: str) -> str:
|
||||
"""Validate vcodec and resolve 'auto' to best available HW encoder, fallback to libsvtav1."""
|
||||
if vcodec not in VALID_VIDEO_CODECS:
|
||||
raise ValueError(f"Invalid vcodec '{vcodec}'. Must be one of: {sorted(VALID_VIDEO_CODECS)}")
|
||||
if vcodec != "auto":
|
||||
logging.info(f"Using video codec: {vcodec}")
|
||||
return vcodec
|
||||
available = detect_available_hw_encoders()
|
||||
for encoder in HW_ENCODERS:
|
||||
if encoder in available:
|
||||
logging.info(f"Auto-selected video codec: {encoder}")
|
||||
return encoder
|
||||
logging.info("No hardware encoder available, falling back to software encoder 'libsvtav1'")
|
||||
return "libsvtav1"
|
||||
|
||||
|
||||
def get_safe_default_codec():
|
||||
if importlib.util.find_spec("torchcodec"):
|
||||
@@ -227,17 +146,16 @@ def decode_video_frames_torchvision(
|
||||
min_, argmin_ = dist.min(1)
|
||||
|
||||
is_within_tol = min_ < tolerance_s
|
||||
if not is_within_tol.all():
|
||||
raise FrameTimestampError(
|
||||
f"One or several query timestamps unexpectedly violate the tolerance ({min_[~is_within_tol]} > {tolerance_s=})."
|
||||
" It means that the closest frame that can be loaded from the video is too far away in time."
|
||||
" This might be due to synchronization issues with timestamps during data collection."
|
||||
" To be safe, we advise to ignore this item during training."
|
||||
f"\nqueried timestamps: {query_ts}"
|
||||
f"\nloaded timestamps: {loaded_ts}"
|
||||
f"\nvideo: {video_path}"
|
||||
f"\nbackend: {backend}"
|
||||
)
|
||||
assert is_within_tol.all(), (
|
||||
f"One or several query timestamps unexpectedly violate the tolerance ({min_[~is_within_tol]} > {tolerance_s=})."
|
||||
"It means that the closest frame that can be loaded from the video is too far away in time."
|
||||
"This might be due to synchronization issues with timestamps during data collection."
|
||||
"To be safe, we advise to ignore this item during training."
|
||||
f"\nqueried timestamps: {query_ts}"
|
||||
f"\nloaded timestamps: {loaded_ts}"
|
||||
f"\nvideo: {video_path}"
|
||||
f"\nbackend: {backend}"
|
||||
)
|
||||
|
||||
# get closest frames to the query timestamps
|
||||
closest_frames = torch.stack([loaded_frames[idx] for idx in argmin_])
|
||||
@@ -249,11 +167,7 @@ def decode_video_frames_torchvision(
|
||||
# convert to the pytorch format which is float32 in [0,1] range (and channel first)
|
||||
closest_frames = closest_frames.type(torch.float32) / 255
|
||||
|
||||
if len(timestamps) != len(closest_frames):
|
||||
raise FrameTimestampError(
|
||||
f"Number of retrieved frames ({len(closest_frames)}) does not match "
|
||||
f"number of queried timestamps ({len(timestamps)})"
|
||||
)
|
||||
assert len(timestamps) == len(closest_frames)
|
||||
return closest_frames
|
||||
|
||||
|
||||
@@ -358,16 +272,15 @@ def decode_video_frames_torchcodec(
|
||||
min_, argmin_ = dist.min(1)
|
||||
|
||||
is_within_tol = min_ < tolerance_s
|
||||
if not is_within_tol.all():
|
||||
raise FrameTimestampError(
|
||||
f"One or several query timestamps unexpectedly violate the tolerance ({min_[~is_within_tol]} > {tolerance_s=})."
|
||||
" It means that the closest frame that can be loaded from the video is too far away in time."
|
||||
" This might be due to synchronization issues with timestamps during data collection."
|
||||
" To be safe, we advise to ignore this item during training."
|
||||
f"\nqueried timestamps: {query_ts}"
|
||||
f"\nloaded timestamps: {loaded_ts}"
|
||||
f"\nvideo: {video_path}"
|
||||
)
|
||||
assert is_within_tol.all(), (
|
||||
f"One or several query timestamps unexpectedly violate the tolerance ({min_[~is_within_tol]} > {tolerance_s=})."
|
||||
"It means that the closest frame that can be loaded from the video is too far away in time."
|
||||
"This might be due to synchronization issues with timestamps during data collection."
|
||||
"To be safe, we advise to ignore this item during training."
|
||||
f"\nqueried timestamps: {query_ts}"
|
||||
f"\nloaded timestamps: {loaded_ts}"
|
||||
f"\nvideo: {video_path}"
|
||||
)
|
||||
|
||||
# get closest frames to the query timestamps
|
||||
closest_frames = torch.stack([loaded_frames[idx] for idx in argmin_])
|
||||
@@ -396,13 +309,14 @@ def encode_video_frames(
|
||||
g: int | None = 2,
|
||||
crf: int | None = 30,
|
||||
fast_decode: int = 0,
|
||||
log_level: int | None = av.logging.WARNING,
|
||||
log_level: int | None = av.logging.ERROR,
|
||||
overwrite: bool = False,
|
||||
preset: int | None = None,
|
||||
encoder_threads: int | None = None,
|
||||
) -> None:
|
||||
"""More info on ffmpeg arguments tuning on `benchmark/video/README.md`"""
|
||||
vcodec = resolve_vcodec(vcodec)
|
||||
# Check encoder availability
|
||||
if vcodec not in ["h264", "hevc", "libsvtav1"]:
|
||||
raise ValueError(f"Unsupported video codec: {vcodec}. Supported codecs are: h264, hevc, libsvtav1.")
|
||||
|
||||
video_path = Path(video_path)
|
||||
imgs_dir = Path(imgs_dir)
|
||||
@@ -433,22 +347,21 @@ def encode_video_frames(
|
||||
width, height = dummy_image.size
|
||||
|
||||
# Define video codec options
|
||||
video_options = _get_codec_options(vcodec, g, crf, preset)
|
||||
video_options = {}
|
||||
|
||||
if g is not None:
|
||||
video_options["g"] = str(g)
|
||||
|
||||
if crf is not None:
|
||||
video_options["crf"] = str(crf)
|
||||
|
||||
if fast_decode:
|
||||
key = "svtav1-params" if vcodec == "libsvtav1" else "tune"
|
||||
value = f"fast-decode={fast_decode}" if vcodec == "libsvtav1" else "fastdecode"
|
||||
video_options[key] = value
|
||||
|
||||
if encoder_threads is not None:
|
||||
if vcodec == "libsvtav1":
|
||||
lp_param = f"lp={encoder_threads}"
|
||||
if "svtav1-params" in video_options:
|
||||
video_options["svtav1-params"] += f":{lp_param}"
|
||||
else:
|
||||
video_options["svtav1-params"] = lp_param
|
||||
else:
|
||||
video_options["threads"] = str(encoder_threads)
|
||||
if vcodec == "libsvtav1":
|
||||
video_options["preset"] = str(preset) if preset is not None else "12"
|
||||
|
||||
# Set logging level
|
||||
if log_level is not None:
|
||||
@@ -567,348 +480,6 @@ def concatenate_video_files(
|
||||
Path(tmp_concatenate_path).unlink()
|
||||
|
||||
|
||||
class _CameraEncoderThread(threading.Thread):
|
||||
"""A thread that encodes video frames streamed via a queue into an MP4 file.
|
||||
|
||||
One instance is created per camera per episode. Frames are received as numpy arrays
|
||||
from the main thread, encoded in real-time using PyAV (which releases the GIL during
|
||||
encoding), and written to disk. Stats are computed incrementally using
|
||||
RunningQuantileStats and returned via result_queue.
|
||||
"""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
video_path: Path,
|
||||
fps: int,
|
||||
vcodec: str,
|
||||
pix_fmt: str,
|
||||
g: int | None,
|
||||
crf: int | None,
|
||||
preset: int | None,
|
||||
frame_queue: queue.Queue,
|
||||
result_queue: queue.Queue,
|
||||
stop_event: threading.Event,
|
||||
encoder_threads: int | None = None,
|
||||
):
|
||||
super().__init__(daemon=True)
|
||||
self.video_path = video_path
|
||||
self.fps = fps
|
||||
self.vcodec = vcodec
|
||||
self.pix_fmt = pix_fmt
|
||||
self.g = g
|
||||
self.crf = crf
|
||||
self.preset = preset
|
||||
self.frame_queue = frame_queue
|
||||
self.result_queue = result_queue
|
||||
self.stop_event = stop_event
|
||||
self.encoder_threads = encoder_threads
|
||||
|
||||
def run(self) -> None:
|
||||
from lerobot.datasets.compute_stats import RunningQuantileStats, auto_downsample_height_width
|
||||
|
||||
container = None
|
||||
output_stream = None
|
||||
stats_tracker = RunningQuantileStats()
|
||||
frame_count = 0
|
||||
|
||||
try:
|
||||
logging.getLogger("libav").setLevel(av.logging.WARNING)
|
||||
|
||||
while True:
|
||||
try:
|
||||
frame_data = self.frame_queue.get(timeout=1)
|
||||
except queue.Empty:
|
||||
if self.stop_event.is_set():
|
||||
break
|
||||
continue
|
||||
|
||||
if frame_data is None:
|
||||
# Sentinel: flush and close
|
||||
break
|
||||
|
||||
# Ensure HWC uint8 numpy array
|
||||
if isinstance(frame_data, np.ndarray):
|
||||
if frame_data.ndim == 3 and frame_data.shape[0] == 3:
|
||||
# CHW -> HWC
|
||||
frame_data = frame_data.transpose(1, 2, 0)
|
||||
if frame_data.dtype != np.uint8:
|
||||
frame_data = (frame_data * 255).astype(np.uint8)
|
||||
|
||||
# Open container on first frame (to get width/height)
|
||||
if container is None:
|
||||
height, width = frame_data.shape[:2]
|
||||
video_options = _get_codec_options(self.vcodec, self.g, self.crf, self.preset)
|
||||
if self.encoder_threads is not None:
|
||||
if self.vcodec == "libsvtav1":
|
||||
lp_param = f"lp={self.encoder_threads}"
|
||||
if "svtav1-params" in video_options:
|
||||
video_options["svtav1-params"] += f":{lp_param}"
|
||||
else:
|
||||
video_options["svtav1-params"] = lp_param
|
||||
else:
|
||||
video_options["threads"] = str(self.encoder_threads)
|
||||
Path(self.video_path).parent.mkdir(parents=True, exist_ok=True)
|
||||
container = av.open(str(self.video_path), "w")
|
||||
output_stream = container.add_stream(self.vcodec, self.fps, options=video_options)
|
||||
output_stream.pix_fmt = self.pix_fmt
|
||||
output_stream.width = width
|
||||
output_stream.height = height
|
||||
output_stream.time_base = Fraction(1, self.fps)
|
||||
|
||||
# Encode frame with explicit timestamps
|
||||
pil_img = Image.fromarray(frame_data)
|
||||
video_frame = av.VideoFrame.from_image(pil_img)
|
||||
video_frame.pts = frame_count
|
||||
video_frame.time_base = Fraction(1, self.fps)
|
||||
packet = output_stream.encode(video_frame)
|
||||
if packet:
|
||||
container.mux(packet)
|
||||
|
||||
# Update stats with downsampled frame (per-channel stats like compute_episode_stats)
|
||||
img_chw = frame_data.transpose(2, 0, 1) # HWC -> CHW
|
||||
img_downsampled = auto_downsample_height_width(img_chw)
|
||||
# Reshape CHW to (H*W, C) for per-channel stats
|
||||
channels = img_downsampled.shape[0]
|
||||
img_for_stats = img_downsampled.transpose(1, 2, 0).reshape(-1, channels)
|
||||
stats_tracker.update(img_for_stats)
|
||||
|
||||
frame_count += 1
|
||||
|
||||
# Flush encoder
|
||||
if output_stream is not None:
|
||||
packet = output_stream.encode()
|
||||
if packet:
|
||||
container.mux(packet)
|
||||
|
||||
if container is not None:
|
||||
container.close()
|
||||
|
||||
av.logging.restore_default_callback()
|
||||
|
||||
# Get stats and put on result queue
|
||||
if frame_count >= 2:
|
||||
stats = stats_tracker.get_statistics()
|
||||
self.result_queue.put(("ok", stats))
|
||||
else:
|
||||
self.result_queue.put(("ok", None))
|
||||
|
||||
except Exception as e:
|
||||
logging.error(f"Encoder thread error: {e}")
|
||||
if container is not None:
|
||||
with contextlib.suppress(Exception):
|
||||
container.close()
|
||||
self.result_queue.put(("error", str(e)))
|
||||
|
||||
|
||||
class StreamingVideoEncoder:
|
||||
"""Manages per-camera encoder threads for real-time video encoding during recording.
|
||||
|
||||
Instead of writing frames as PNG images and then encoding to MP4 at episode end,
|
||||
this class streams frames directly to encoder threads, eliminating the
|
||||
PNG round-trip and making save_episode() near-instant.
|
||||
|
||||
Uses threading instead of multiprocessing to avoid the overhead of pickling large
|
||||
numpy arrays through multiprocessing.Queue. PyAV's encode() releases the GIL,
|
||||
so encoding runs in parallel with the main recording loop.
|
||||
"""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
fps: int,
|
||||
vcodec: str = "libsvtav1",
|
||||
pix_fmt: str = "yuv420p",
|
||||
g: int | None = 2,
|
||||
crf: int | None = 30,
|
||||
preset: int | None = None,
|
||||
queue_maxsize: int = 30,
|
||||
encoder_threads: int | None = None,
|
||||
):
|
||||
self.fps = fps
|
||||
self.vcodec = resolve_vcodec(vcodec)
|
||||
self.pix_fmt = pix_fmt
|
||||
self.g = g
|
||||
self.crf = crf
|
||||
self.preset = preset
|
||||
self.queue_maxsize = queue_maxsize
|
||||
self.encoder_threads = encoder_threads
|
||||
|
||||
self._frame_queues: dict[str, queue.Queue] = {}
|
||||
self._result_queues: dict[str, queue.Queue] = {}
|
||||
self._threads: dict[str, _CameraEncoderThread] = {}
|
||||
self._stop_events: dict[str, threading.Event] = {}
|
||||
self._video_paths: dict[str, Path] = {}
|
||||
self._dropped_frames: dict[str, int] = {}
|
||||
self._episode_active = False
|
||||
|
||||
def start_episode(self, video_keys: list[str], temp_dir: Path) -> None:
|
||||
"""Start encoder threads for a new episode.
|
||||
|
||||
Args:
|
||||
video_keys: List of video feature keys (e.g. ["observation.images.laptop"])
|
||||
temp_dir: Base directory for temporary MP4 files
|
||||
"""
|
||||
if self._episode_active:
|
||||
self.cancel_episode()
|
||||
|
||||
self._dropped_frames.clear()
|
||||
|
||||
for video_key in video_keys:
|
||||
frame_queue: queue.Queue = queue.Queue(maxsize=self.queue_maxsize)
|
||||
result_queue: queue.Queue = queue.Queue(maxsize=1)
|
||||
stop_event = threading.Event()
|
||||
|
||||
temp_video_dir = Path(tempfile.mkdtemp(dir=temp_dir))
|
||||
video_path = temp_video_dir / f"{video_key.replace('/', '_')}_streaming.mp4"
|
||||
|
||||
encoder_thread = _CameraEncoderThread(
|
||||
video_path=video_path,
|
||||
fps=self.fps,
|
||||
vcodec=self.vcodec,
|
||||
pix_fmt=self.pix_fmt,
|
||||
g=self.g,
|
||||
crf=self.crf,
|
||||
preset=self.preset,
|
||||
frame_queue=frame_queue,
|
||||
result_queue=result_queue,
|
||||
stop_event=stop_event,
|
||||
encoder_threads=self.encoder_threads,
|
||||
)
|
||||
encoder_thread.start()
|
||||
|
||||
self._frame_queues[video_key] = frame_queue
|
||||
self._result_queues[video_key] = result_queue
|
||||
self._threads[video_key] = encoder_thread
|
||||
self._stop_events[video_key] = stop_event
|
||||
self._video_paths[video_key] = video_path
|
||||
|
||||
self._episode_active = True
|
||||
|
||||
def feed_frame(self, video_key: str, image: np.ndarray) -> None:
|
||||
"""Feed a frame to the encoder for a specific camera.
|
||||
|
||||
A copy of the image is made before enqueueing to prevent race conditions
|
||||
with camera drivers that may reuse buffers. If the encoder queue is full
|
||||
(encoder can't keep up), the frame is dropped with a warning instead of
|
||||
crashing the recording session.
|
||||
|
||||
Args:
|
||||
video_key: The video feature key
|
||||
image: numpy array in (H,W,C) or (C,H,W) format, uint8 or float
|
||||
|
||||
Raises:
|
||||
RuntimeError: If the encoder thread has crashed
|
||||
"""
|
||||
if not self._episode_active:
|
||||
raise RuntimeError("No active episode. Call start_episode() first.")
|
||||
|
||||
thread = self._threads[video_key]
|
||||
if not thread.is_alive():
|
||||
# Check for error
|
||||
try:
|
||||
status, msg = self._result_queues[video_key].get_nowait()
|
||||
if status == "error":
|
||||
raise RuntimeError(f"Encoder thread for {video_key} crashed: {msg}")
|
||||
except queue.Empty:
|
||||
pass
|
||||
raise RuntimeError(f"Encoder thread for {video_key} is not alive")
|
||||
|
||||
try:
|
||||
self._frame_queues[video_key].put(image.copy(), timeout=0.1)
|
||||
except queue.Full:
|
||||
self._dropped_frames[video_key] = self._dropped_frames.get(video_key, 0) + 1
|
||||
count = self._dropped_frames[video_key]
|
||||
# Log periodically to avoid spam (1st, then every 10th)
|
||||
if count == 1 or count % 10 == 0:
|
||||
logging.warning(
|
||||
f"Encoder queue full for {video_key}, dropped {count} frame(s). "
|
||||
f"Consider using vcodec='auto' for hardware encoding or increasing encoder_queue_maxsize."
|
||||
)
|
||||
|
||||
def finish_episode(self) -> dict[str, tuple[Path, dict | None]]:
|
||||
"""Finish encoding the current episode.
|
||||
|
||||
Sends sentinel values, waits for encoder threads to complete,
|
||||
and collects results.
|
||||
|
||||
Returns:
|
||||
Dict mapping video_key to (mp4_path, stats_dict_or_None)
|
||||
"""
|
||||
if not self._episode_active:
|
||||
raise RuntimeError("No active episode to finish.")
|
||||
|
||||
results = {}
|
||||
|
||||
# Report dropped frames
|
||||
for video_key, count in self._dropped_frames.items():
|
||||
if count > 0:
|
||||
logging.warning(f"Episode finished with {count} dropped frame(s) for {video_key}.")
|
||||
|
||||
# Send sentinel to all queues
|
||||
for video_key in self._frame_queues:
|
||||
self._frame_queues[video_key].put(None)
|
||||
|
||||
# Wait for all threads and collect results
|
||||
for video_key in self._threads:
|
||||
self._threads[video_key].join(timeout=120)
|
||||
if self._threads[video_key].is_alive():
|
||||
logging.error(f"Encoder thread for {video_key} did not finish in time")
|
||||
self._stop_events[video_key].set()
|
||||
self._threads[video_key].join(timeout=5)
|
||||
results[video_key] = (self._video_paths[video_key], None)
|
||||
continue
|
||||
|
||||
try:
|
||||
status, data = self._result_queues[video_key].get(timeout=5)
|
||||
if status == "error":
|
||||
raise RuntimeError(f"Encoder thread for {video_key} failed: {data}")
|
||||
results[video_key] = (self._video_paths[video_key], data)
|
||||
except queue.Empty:
|
||||
logging.error(f"No result from encoder thread for {video_key}")
|
||||
results[video_key] = (self._video_paths[video_key], None)
|
||||
|
||||
self._cleanup()
|
||||
self._episode_active = False
|
||||
return results
|
||||
|
||||
def cancel_episode(self) -> None:
|
||||
"""Cancel the current episode, stopping encoder threads and cleaning up."""
|
||||
if not self._episode_active:
|
||||
return
|
||||
|
||||
# Signal all threads to stop
|
||||
for video_key in self._stop_events:
|
||||
self._stop_events[video_key].set()
|
||||
|
||||
# Wait for threads to finish
|
||||
for video_key in self._threads:
|
||||
self._threads[video_key].join(timeout=5)
|
||||
|
||||
# Clean up temp MP4 files
|
||||
video_path = self._video_paths.get(video_key)
|
||||
if video_path is not None and video_path.exists():
|
||||
shutil.rmtree(str(video_path.parent), ignore_errors=True)
|
||||
|
||||
self._cleanup()
|
||||
self._episode_active = False
|
||||
|
||||
def close(self) -> None:
|
||||
"""Close the encoder, canceling any in-progress episode."""
|
||||
if self._episode_active:
|
||||
self.cancel_episode()
|
||||
|
||||
def _cleanup(self) -> None:
|
||||
"""Clean up queues and thread tracking dicts."""
|
||||
for q in self._frame_queues.values():
|
||||
with contextlib.suppress(Exception):
|
||||
while not q.empty():
|
||||
q.get_nowait()
|
||||
self._frame_queues.clear()
|
||||
self._result_queues.clear()
|
||||
self._threads.clear()
|
||||
self._stop_events.clear()
|
||||
self._video_paths.clear()
|
||||
|
||||
|
||||
@dataclass
|
||||
class VideoFrame:
|
||||
# TODO(rcadene, lhoestq): move to Hugging Face `datasets` repo
|
||||
@@ -943,7 +514,7 @@ with warnings.catch_warnings():
|
||||
|
||||
def get_audio_info(video_path: Path | str) -> dict:
|
||||
# Set logging level
|
||||
logging.getLogger("libav").setLevel(av.logging.WARNING)
|
||||
logging.getLogger("libav").setLevel(av.logging.ERROR)
|
||||
|
||||
# Getting audio stream information
|
||||
audio_info = {}
|
||||
@@ -975,7 +546,7 @@ def get_audio_info(video_path: Path | str) -> dict:
|
||||
|
||||
def get_video_info(video_path: Path | str) -> dict:
|
||||
# Set logging level
|
||||
logging.getLogger("libav").setLevel(av.logging.WARNING)
|
||||
logging.getLogger("libav").setLevel(av.logging.ERROR)
|
||||
|
||||
# Getting video stream information
|
||||
video_info = {}
|
||||
@@ -1061,15 +632,8 @@ class VideoEncodingManager:
|
||||
return self
|
||||
|
||||
def __exit__(self, exc_type, exc_val, exc_tb):
|
||||
streaming_encoder = getattr(self.dataset, "_streaming_encoder", None)
|
||||
|
||||
if streaming_encoder is not None:
|
||||
# Handle streaming encoder cleanup
|
||||
if exc_type is not None:
|
||||
streaming_encoder.cancel_episode()
|
||||
streaming_encoder.close()
|
||||
elif self.dataset.episodes_since_last_encoding > 0:
|
||||
# Handle any remaining episodes that haven't been batch encoded
|
||||
# Handle any remaining episodes that haven't been batch encoded
|
||||
if self.dataset.episodes_since_last_encoding > 0:
|
||||
if exc_type is not None:
|
||||
logging.info("Exception occurred. Encoding remaining episodes before exit...")
|
||||
else:
|
||||
@@ -1086,8 +650,8 @@ class VideoEncodingManager:
|
||||
# Finalize the dataset to properly close all writers
|
||||
self.dataset.finalize()
|
||||
|
||||
# Clean up episode images if recording was interrupted (only for non-streaming mode)
|
||||
if exc_type is not None and streaming_encoder is None:
|
||||
# Clean up episode images if recording was interrupted
|
||||
if exc_type is not None:
|
||||
interrupted_episode_index = self.dataset.num_episodes
|
||||
for key in self.dataset.meta.video_keys:
|
||||
img_dir = self.dataset._get_image_file_path(
|
||||
@@ -1101,12 +665,14 @@ class VideoEncodingManager:
|
||||
|
||||
# Clean up any remaining images directory if it's empty
|
||||
img_dir = self.dataset.root / "images"
|
||||
if img_dir.exists():
|
||||
png_files = list(img_dir.rglob("*.png"))
|
||||
if len(png_files) == 0:
|
||||
# Check for any remaining PNG files
|
||||
png_files = list(img_dir.rglob("*.png"))
|
||||
if len(png_files) == 0:
|
||||
# Only remove the images directory if no PNG files remain
|
||||
if img_dir.exists():
|
||||
shutil.rmtree(img_dir)
|
||||
logging.debug("Cleaned up empty images directory")
|
||||
else:
|
||||
logging.debug(f"Images directory is not empty, containing {len(png_files)} PNG files")
|
||||
else:
|
||||
logging.debug(f"Images directory is not empty, containing {len(png_files)} PNG files")
|
||||
|
||||
return False # Don't suppress the original exception
|
||||
|
||||
File diff suppressed because it is too large
Load Diff
@@ -1,120 +0,0 @@
|
||||
# Copyright 2025 The HuggingFace Inc. team. All rights reserved.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
"""Configuration tables for Damiao motors."""
|
||||
|
||||
from enum import IntEnum
|
||||
|
||||
|
||||
# Motor type definitions
|
||||
class MotorType(IntEnum):
|
||||
O0 = 0
|
||||
O1 = 1
|
||||
O2 = 2
|
||||
O3 = 3
|
||||
O4 = 4
|
||||
O5 = 5
|
||||
ELO5 = 6
|
||||
O6 = 7
|
||||
|
||||
|
||||
class CommMode(IntEnum):
|
||||
PrivateProtocole = 0
|
||||
CANopen = 1
|
||||
MIT = 2
|
||||
|
||||
|
||||
# Control modes
|
||||
class ControlMode(IntEnum):
|
||||
MIT = 0
|
||||
POS_VEL = 1
|
||||
VEL = 2
|
||||
|
||||
|
||||
# Motor limit parameters [PMAX, VMAX, TMAX]
|
||||
# PMAX: Maximum position (rad)
|
||||
# VMAX: Maximum velocity (rad/s)
|
||||
# TMAX: Maximum torque (N·m)
|
||||
MOTOR_LIMIT_PARAMS: dict[MotorType, tuple[float, float, float]] = {
|
||||
MotorType.O0: (12.57, 33, 14),
|
||||
MotorType.O1: (12.57, 44, 17),
|
||||
MotorType.O2: (12.57, 33, 20),
|
||||
MotorType.O3: (12.57, 33, 60),
|
||||
MotorType.O4: (12.57, 33, 120),
|
||||
MotorType.O5: (12.57, 50, 5.5),
|
||||
MotorType.ELO5: (12.57, 50, 6),
|
||||
MotorType.O6: (112.5, 50, 36),
|
||||
}
|
||||
|
||||
# Motor model names
|
||||
MODEL_NAMES = {
|
||||
MotorType.O0: "O0",
|
||||
MotorType.O1: "O1",
|
||||
MotorType.O2: "O2",
|
||||
MotorType.O3: "O3",
|
||||
MotorType.O4: "O4",
|
||||
MotorType.O5: "O5",
|
||||
MotorType.ELO5: "ELO5",
|
||||
MotorType.O6: "O6",
|
||||
}
|
||||
|
||||
# Motor resolution table (encoder counts per revolution)
|
||||
MODEL_RESOLUTION = {
|
||||
"O0": 65536,
|
||||
"O1": 65536,
|
||||
"O2": 65536,
|
||||
"O3": 65536,
|
||||
"O4": 65536,
|
||||
"O5": 65536,
|
||||
"ELO5": 65536,
|
||||
"O6": 65536,
|
||||
}
|
||||
|
||||
# CAN baudrates supported by Robstride motors
|
||||
AVAILABLE_BAUDRATES = [
|
||||
1000000, # 4: 1 mbps (default)
|
||||
]
|
||||
DEFAULT_BAUDRATE = 1000000
|
||||
|
||||
# Default timeout in milliseconds
|
||||
DEFAULT_TIMEOUT_MS = 0 # disabled by default, otherwise 20000 is 1s
|
||||
|
||||
|
||||
# Data that should be normalized
|
||||
NORMALIZED_DATA = ["Present_Position", "Goal_Position"]
|
||||
|
||||
|
||||
# MIT control parameter ranges
|
||||
MIT_KP_RANGE = (0.0, 500.0)
|
||||
MIT_KD_RANGE = (0.0, 5.0)
|
||||
|
||||
# CAN frame command IDs
|
||||
CAN_CMD_ENABLE = 0xFC
|
||||
CAN_CMD_DISABLE = 0xFD
|
||||
CAN_CMD_SET_ZERO = 0xFE
|
||||
CAN_CMD_CLEAR_FAULT = 0xFB
|
||||
|
||||
|
||||
CAN_CMD_QUERY_PARAM = 0x33
|
||||
CAN_CMD_WRITE_PARAM = 0x55
|
||||
CAN_CMD_SAVE_PARAM = 0xAA
|
||||
|
||||
# CAN ID for parameter operations
|
||||
CAN_PARAM_ID = 0x7FF
|
||||
|
||||
|
||||
RUNNING_TIMEOUT = 0.001
|
||||
PARAM_TIMEOUT = 0.01
|
||||
|
||||
STATE_CACHE_TTL_S = 0.02
|
||||
@@ -139,10 +139,6 @@ class DiffusionConfig(PreTrainedConfig):
|
||||
# Inference
|
||||
num_inference_steps: int | None = None
|
||||
|
||||
# Optimization
|
||||
compile_model: bool = False
|
||||
compile_mode: str = "reduce-overhead"
|
||||
|
||||
# Loss computation
|
||||
do_mask_loss_for_padding: bool = False
|
||||
|
||||
|
||||
@@ -142,9 +142,6 @@ class DiffusionPolicy(PreTrainedPolicy):
|
||||
"""Run the batch through the model and compute the loss for training or validation."""
|
||||
if self.config.image_features:
|
||||
batch = dict(batch) # shallow copy so that adding a key doesn't modify the original
|
||||
for key in self.config.image_features:
|
||||
if self.config.n_obs_steps == 1 and batch[key].ndim == 4:
|
||||
batch[key] = batch[key].unsqueeze(1)
|
||||
batch[OBS_IMAGES] = torch.stack([batch[key] for key in self.config.image_features], dim=-4)
|
||||
loss = self.diffusion.compute_loss(batch)
|
||||
# no output_dict so returning None
|
||||
@@ -185,11 +182,6 @@ class DiffusionModel(nn.Module):
|
||||
|
||||
self.unet = DiffusionConditionalUnet1d(config, global_cond_dim=global_cond_dim * config.n_obs_steps)
|
||||
|
||||
if config.compile_model:
|
||||
# Compile the U-Net. "reduce-overhead" is preferred for the small-batch repetitive loops
|
||||
# common in diffusion inference.
|
||||
self.unet = torch.compile(self.unet, mode=config.compile_mode)
|
||||
|
||||
self.noise_scheduler = _make_noise_scheduler(
|
||||
config.noise_scheduler_type,
|
||||
num_train_timesteps=config.num_train_timesteps,
|
||||
|
||||
+4
-6
@@ -1,6 +1,4 @@
|
||||
#!/usr/bin/env python
|
||||
|
||||
# Copyright 2025 The HuggingFace Inc. team. All rights reserved.
|
||||
# Copyright 2026 The HuggingFace Inc. team. All rights reserved.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
@@ -14,7 +12,7 @@
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
from .config_openarm_mini import OpenArmMiniConfig
|
||||
from .openarm_mini import OpenArmMini
|
||||
from lerobot.policies.rlt.configuration_rlt import RLTConfig
|
||||
from lerobot.policies.rlt.modeling_rlt import RLTPolicy
|
||||
|
||||
__all__ = ["OpenArmMini", "OpenArmMiniConfig"]
|
||||
__all__ = ["RLTConfig", "RLTPolicy"]
|
||||
@@ -0,0 +1,156 @@
|
||||
# Copyright 2026 The HuggingFace Inc. team. All rights reserved.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
"""RLT (RL Token) policy configuration.
|
||||
|
||||
Reference: "RL Token: Bootstrapping Online RL with Vision-Language-Action Models"
|
||||
(Xu et al., Physical Intelligence, 2026)
|
||||
"""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
from dataclasses import dataclass, field
|
||||
|
||||
from lerobot.configs.policies import PreTrainedConfig
|
||||
from lerobot.configs.types import NormalizationMode
|
||||
from lerobot.policies.sac.configuration_sac import ActorLearnerConfig, ConcurrencyConfig
|
||||
from lerobot.utils.constants import ACTION, OBS_IMAGE, OBS_STATE
|
||||
|
||||
|
||||
@dataclass
|
||||
class RLTokenConfig:
|
||||
"""Configuration for the RL-token encoder/decoder transformer."""
|
||||
|
||||
input_dim: int = 2048
|
||||
rl_token_dim: int = 2048
|
||||
num_encoder_layers: int = 2
|
||||
num_decoder_layers: int = 2
|
||||
num_heads: int = 8
|
||||
ff_dim: int = 2048
|
||||
dropout: float = 0.0
|
||||
|
||||
|
||||
@dataclass
|
||||
class RLTActorConfig:
|
||||
"""Configuration for the lightweight RL actor MLP."""
|
||||
|
||||
hidden_dims: list[int] = field(default_factory=lambda: [256, 256])
|
||||
std: float = 0.1
|
||||
|
||||
|
||||
@dataclass
|
||||
class RLTCriticConfig:
|
||||
"""Configuration for the RLT critic MLP."""
|
||||
|
||||
hidden_dims: list[int] = field(default_factory=lambda: [256, 256])
|
||||
|
||||
|
||||
@PreTrainedConfig.register_subclass("rlt")
|
||||
@dataclass
|
||||
class RLTConfig(PreTrainedConfig):
|
||||
"""Configuration for the RLT (RL Token) policy.
|
||||
|
||||
RLT adds an RL-token encoder/decoder to a frozen VLA backbone, then trains
|
||||
a lightweight actor-critic head using the RL token as state representation.
|
||||
The frozen VLA also provides reference action chunks that the actor refines.
|
||||
"""
|
||||
|
||||
normalization_mapping: dict[str, NormalizationMode] = field(
|
||||
default_factory=lambda: {
|
||||
"VISUAL": NormalizationMode.MEAN_STD,
|
||||
"STATE": NormalizationMode.MIN_MAX,
|
||||
"ACTION": NormalizationMode.MIN_MAX,
|
||||
}
|
||||
)
|
||||
|
||||
dataset_stats: dict[str, dict[str, list[float]]] | None = field(
|
||||
default_factory=lambda: {
|
||||
OBS_IMAGE: {
|
||||
"mean": [0.485, 0.456, 0.406],
|
||||
"std": [0.229, 0.224, 0.225],
|
||||
},
|
||||
OBS_STATE: {"min": [0.0], "max": [1.0]},
|
||||
ACTION: {"min": [0.0], "max": [1.0]},
|
||||
}
|
||||
)
|
||||
|
||||
# ── Device ──
|
||||
device: str = "cuda"
|
||||
storage_device: str = "cpu"
|
||||
|
||||
# ── VLA backbone ──
|
||||
vla_checkpoint: str | None = None
|
||||
|
||||
# ── RL-token ──
|
||||
rl_token: RLTokenConfig = field(default_factory=RLTokenConfig)
|
||||
|
||||
# ── Actor / Critic heads ──
|
||||
actor: RLTActorConfig = field(default_factory=RLTActorConfig)
|
||||
critic: RLTCriticConfig = field(default_factory=RLTCriticConfig)
|
||||
|
||||
# ── Action chunks ──
|
||||
chunk_size: int = 10
|
||||
vla_chunk_size: int = 50
|
||||
|
||||
# ── Training parameters ──
|
||||
online_steps: int = 50000
|
||||
offline_steps: int = 5000
|
||||
online_buffer_capacity: int = 100000
|
||||
offline_buffer_capacity: int = 100000
|
||||
online_step_before_learning: int = 500
|
||||
warmup_steps: int = 500
|
||||
async_prefetch: bool = False
|
||||
|
||||
# ── Algorithm hyperparameters ──
|
||||
utd_ratio: int = 5
|
||||
policy_update_freq: int = 2
|
||||
discount: float = 0.99
|
||||
critic_lr: float = 3e-4
|
||||
actor_lr: float = 3e-4
|
||||
rl_token_lr: float = 1e-4
|
||||
tau: float = 0.005
|
||||
clip_grad_norm: float = 10.0
|
||||
num_critics: int = 2
|
||||
bc_reg_coeff: float = 0.1
|
||||
ref_dropout: float = 0.5
|
||||
chunk_stride: int = 2
|
||||
vla_finetune_weight: float = 0.0
|
||||
|
||||
# ── Distributed ──
|
||||
actor_learner_config: ActorLearnerConfig = field(default_factory=ActorLearnerConfig)
|
||||
concurrency: ConcurrencyConfig = field(default_factory=ConcurrencyConfig)
|
||||
|
||||
def __post_init__(self):
|
||||
super().__post_init__()
|
||||
|
||||
def get_optimizer_preset(self):
|
||||
return None
|
||||
|
||||
def get_scheduler_preset(self):
|
||||
return None
|
||||
|
||||
def validate_features(self) -> None:
|
||||
if ACTION not in self.output_features:
|
||||
raise ValueError("You must provide 'action' in the output features")
|
||||
|
||||
@property
|
||||
def observation_delta_indices(self) -> list | None:
|
||||
return None
|
||||
|
||||
@property
|
||||
def action_delta_indices(self) -> list | None:
|
||||
return None
|
||||
|
||||
@property
|
||||
def reward_delta_indices(self) -> None:
|
||||
return None
|
||||
@@ -0,0 +1,318 @@
|
||||
# Copyright 2026 The HuggingFace Inc. team. All rights reserved.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
"""RLT (RL Token) policy networks.
|
||||
|
||||
Reference: "RL Token: Bootstrapping Online RL with Vision-Language-Action Models"
|
||||
(Xu et al., Physical Intelligence, 2026)
|
||||
|
||||
Architecture:
|
||||
- RLTokenEncoder: compresses VLA token embeddings into a single compact RL token
|
||||
- RLTokenDecoder: reconstructs VLA embeddings from the RL token (Stage 1 training only)
|
||||
- RLTActor: refines VLA reference action chunks conditioned on (z_rl, proprioception, ref_action)
|
||||
- RLTCritic: Q(x, action_chunk) where x = (z_rl, proprioception)
|
||||
- RLTPolicy: bundles RL-token modules + actor into a PreTrainedPolicy for inference
|
||||
"""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import math
|
||||
|
||||
import torch
|
||||
import torch.nn as nn
|
||||
from torch import Tensor
|
||||
|
||||
from lerobot.policies.pretrained import PreTrainedPolicy
|
||||
from lerobot.policies.rlt.configuration_rlt import RLTConfig
|
||||
|
||||
# ── Building blocks ──────────────────────────────────────────────────
|
||||
|
||||
|
||||
class MLP(nn.Module):
|
||||
"""Simple feedforward network with ReLU activations."""
|
||||
|
||||
def __init__(self, input_dim: int, hidden_dims: list[int], output_dim: int):
|
||||
super().__init__()
|
||||
layers: list[nn.Module] = []
|
||||
prev = input_dim
|
||||
for h in hidden_dims:
|
||||
layers.append(nn.Linear(prev, h))
|
||||
layers.append(nn.ReLU())
|
||||
prev = h
|
||||
layers.append(nn.Linear(prev, output_dim))
|
||||
self.net = nn.Sequential(*layers)
|
||||
|
||||
def forward(self, x: Tensor) -> Tensor:
|
||||
return self.net(x)
|
||||
|
||||
|
||||
# ── RL Token Encoder ─────────────────────────────────────────────────
|
||||
|
||||
|
||||
class RLTokenEncoder(nn.Module):
|
||||
"""Compress VLA token embeddings into a single RL token via a small transformer.
|
||||
|
||||
Appends a learnable ``e_rl`` embedding to the VLA token sequence, processes
|
||||
through transformer encoder layers, and returns the output at the ``e_rl``
|
||||
position as the RL token ``z_rl``.
|
||||
|
||||
Paper Eq. 1: z_rl = g_phi([z_{1:M}, e_rl])_{M+1}
|
||||
"""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
input_dim: int,
|
||||
rl_token_dim: int,
|
||||
num_layers: int,
|
||||
num_heads: int,
|
||||
ff_dim: int,
|
||||
dropout: float = 0.0,
|
||||
):
|
||||
super().__init__()
|
||||
self.rl_token_dim = rl_token_dim
|
||||
|
||||
self.e_rl = nn.Parameter(torch.randn(1, 1, input_dim) * 0.02)
|
||||
|
||||
if input_dim != rl_token_dim:
|
||||
self.input_proj = nn.Linear(input_dim, rl_token_dim)
|
||||
else:
|
||||
self.input_proj = nn.Identity()
|
||||
|
||||
encoder_layer = nn.TransformerEncoderLayer(
|
||||
d_model=rl_token_dim,
|
||||
nhead=num_heads,
|
||||
dim_feedforward=ff_dim,
|
||||
dropout=dropout,
|
||||
batch_first=True,
|
||||
)
|
||||
self.transformer = nn.TransformerEncoder(encoder_layer, num_layers=num_layers)
|
||||
|
||||
def forward(self, z_vla: Tensor) -> Tensor:
|
||||
"""
|
||||
Args:
|
||||
z_vla: VLA token embeddings, shape ``(B, M, D)``.
|
||||
|
||||
Returns:
|
||||
RL token ``z_rl``, shape ``(B, rl_token_dim)``.
|
||||
"""
|
||||
batch_size = z_vla.shape[0]
|
||||
e_rl = self.e_rl.expand(batch_size, -1, -1)
|
||||
seq = torch.cat([z_vla, e_rl], dim=1) # (B, M+1, D)
|
||||
seq = self.input_proj(seq)
|
||||
out = self.transformer(seq)
|
||||
z_rl = out[:, -1, :] # output at e_rl position
|
||||
return z_rl
|
||||
|
||||
|
||||
# ── RL Token Decoder ─────────────────────────────────────────────────
|
||||
|
||||
|
||||
class RLTokenDecoder(nn.Module):
|
||||
"""Autoregressively reconstruct VLA embeddings from z_rl.
|
||||
|
||||
Used only during Stage 1 (offline RL-token training).
|
||||
|
||||
Paper Eq. 2: L_ro = E[sum_i || h(d([z_rl, z_bar_{1:i-1}]))_i - z_bar_i ||^2]
|
||||
"""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
rl_token_dim: int,
|
||||
output_dim: int,
|
||||
num_layers: int,
|
||||
num_heads: int,
|
||||
ff_dim: int,
|
||||
dropout: float = 0.0,
|
||||
):
|
||||
super().__init__()
|
||||
self.output_dim = output_dim
|
||||
|
||||
if rl_token_dim != output_dim:
|
||||
self.rl_proj = nn.Linear(rl_token_dim, output_dim)
|
||||
else:
|
||||
self.rl_proj = nn.Identity()
|
||||
|
||||
decoder_layer = nn.TransformerDecoderLayer(
|
||||
d_model=output_dim,
|
||||
nhead=num_heads,
|
||||
dim_feedforward=ff_dim,
|
||||
dropout=dropout,
|
||||
batch_first=True,
|
||||
)
|
||||
self.transformer = nn.TransformerDecoder(decoder_layer, num_layers=num_layers)
|
||||
self.output_head = nn.Linear(output_dim, output_dim)
|
||||
|
||||
def forward(self, z_rl: Tensor, z_vla_stopped: Tensor) -> Tensor:
|
||||
"""
|
||||
Args:
|
||||
z_rl: RL token, shape ``(B, D_rl)``.
|
||||
z_vla_stopped: Stop-gradient VLA embeddings, shape ``(B, M, D)``.
|
||||
|
||||
Returns:
|
||||
Reconstructed embeddings, shape ``(B, M, D)``.
|
||||
"""
|
||||
seq_len = z_vla_stopped.shape[1]
|
||||
z_rl_proj = self.rl_proj(z_rl).unsqueeze(1)
|
||||
|
||||
target = torch.cat([z_rl_proj, z_vla_stopped[:, :-1, :]], dim=1)
|
||||
|
||||
causal_mask = nn.Transformer.generate_square_subsequent_mask(seq_len, device=z_rl.device)
|
||||
|
||||
decoded = self.transformer(
|
||||
tgt=target,
|
||||
memory=z_rl_proj,
|
||||
tgt_mask=causal_mask,
|
||||
)
|
||||
return self.output_head(decoded) # (B, M, D)
|
||||
|
||||
|
||||
# ── Actor ────────────────────────────────────────────────────────────
|
||||
|
||||
|
||||
class RLTActor(nn.Module):
|
||||
"""Lightweight actor that refines VLA reference action chunks.
|
||||
|
||||
Paper Eq. 4: pi_theta(a_{1:C} | x, a_tilde_{1:C}) = N(mu_theta(x, a_tilde), sigma^2 I)
|
||||
|
||||
The actor is conditioned on both the RL state and the VLA's proposed action
|
||||
chunk, acting as a "VLA-guided action editor".
|
||||
"""
|
||||
|
||||
def __init__(self, state_dim: int, action_chunk_dim: int, hidden_dims: list[int], std: float = 0.1):
|
||||
super().__init__()
|
||||
input_dim = state_dim + action_chunk_dim
|
||||
self.net = MLP(input_dim, hidden_dims, action_chunk_dim)
|
||||
self.log_std = math.log(std)
|
||||
|
||||
def forward(self, state: Tensor, ref_action_chunk: Tensor) -> Tensor:
|
||||
"""Return the mean action chunk.
|
||||
|
||||
Args:
|
||||
state: RL state ``x = (z_rl, proprioception)``, shape ``(B, state_dim)``.
|
||||
ref_action_chunk: Flattened VLA reference chunk, shape ``(B, C*d)``.
|
||||
|
||||
Returns:
|
||||
Refined action chunk (mean), shape ``(B, C*d)``.
|
||||
"""
|
||||
x = torch.cat([state, ref_action_chunk], dim=-1)
|
||||
return self.net(x)
|
||||
|
||||
def sample(self, state: Tensor, ref_action_chunk: Tensor) -> tuple[Tensor, Tensor]:
|
||||
"""Sample an action and return (action, log_prob)."""
|
||||
mean = self.forward(state, ref_action_chunk)
|
||||
std = math.exp(self.log_std)
|
||||
noise = torch.randn_like(mean) * std
|
||||
action = mean + noise
|
||||
log_prob = -0.5 * (noise / std).pow(2).sum(dim=-1) - mean.shape[-1] * math.log(
|
||||
std * math.sqrt(2 * math.pi)
|
||||
)
|
||||
return action, log_prob
|
||||
|
||||
|
||||
# ── Policy (inference bundle) ────────────────────────────────────────
|
||||
|
||||
|
||||
class RLTPolicy(PreTrainedPolicy):
|
||||
"""RLT policy — bundles the RL-token encoder and actor for inference.
|
||||
|
||||
The frozen VLA backbone is **not** part of this module; it is loaded
|
||||
separately and its embeddings / reference actions are passed in via the
|
||||
observation dict (populated by the actor process or a preprocessor).
|
||||
|
||||
During training, the :class:`RLTAlgorithm` holds the critic, target networks,
|
||||
and optimizers. This class only contains what is needed for ``select_action``.
|
||||
"""
|
||||
|
||||
name = "rlt"
|
||||
config_class = RLTConfig
|
||||
|
||||
def __init__(self, config: RLTConfig, dataset_stats=None):
|
||||
super().__init__(config, dataset_stats)
|
||||
action_dim = config.output_features["action"].shape[0]
|
||||
action_chunk_dim = config.chunk_size * action_dim
|
||||
prop_feature = config.input_features.get("observation.state", None)
|
||||
proprioception_dim = prop_feature.shape[0] if prop_feature is not None else 0
|
||||
|
||||
state_dim = config.rl_token.rl_token_dim + proprioception_dim
|
||||
|
||||
# RL-token encoder (frozen after Stage 1)
|
||||
self.rl_token_encoder = RLTokenEncoder(
|
||||
input_dim=config.rl_token.input_dim,
|
||||
rl_token_dim=config.rl_token.rl_token_dim,
|
||||
num_layers=config.rl_token.num_encoder_layers,
|
||||
num_heads=config.rl_token.num_heads,
|
||||
ff_dim=config.rl_token.ff_dim,
|
||||
dropout=config.rl_token.dropout,
|
||||
)
|
||||
|
||||
# RL-token decoder (used only during Stage 1 training)
|
||||
self.rl_token_decoder = RLTokenDecoder(
|
||||
rl_token_dim=config.rl_token.rl_token_dim,
|
||||
output_dim=config.rl_token.input_dim,
|
||||
num_layers=config.rl_token.num_decoder_layers,
|
||||
num_heads=config.rl_token.num_heads,
|
||||
ff_dim=config.rl_token.ff_dim,
|
||||
dropout=config.rl_token.dropout,
|
||||
)
|
||||
|
||||
# Actor MLP
|
||||
self.actor = RLTActor(
|
||||
state_dim=state_dim,
|
||||
action_chunk_dim=action_chunk_dim,
|
||||
hidden_dims=config.actor.hidden_dims,
|
||||
std=config.actor.std,
|
||||
)
|
||||
|
||||
self._action_dim = action_dim
|
||||
self._action_chunk_dim = action_chunk_dim
|
||||
self._state_dim = state_dim
|
||||
self._proprioception_dim = proprioception_dim
|
||||
|
||||
@torch.no_grad()
|
||||
def select_action(self, batch: dict[str, Tensor]) -> Tensor:
|
||||
"""Select a refined action chunk given an observation.
|
||||
|
||||
Expects the observation dict to contain:
|
||||
- ``"observation.vla_embeddings"``: VLA internal token embeddings ``(M, D)``
|
||||
- ``"observation.reference_action"``: VLA reference chunk ``(C*d,)``
|
||||
- ``"observation.state"`` (optional): proprioceptive state ``(P,)``
|
||||
|
||||
Returns:
|
||||
Action chunk tensor of shape ``(C*d,)``.
|
||||
"""
|
||||
self.eval()
|
||||
|
||||
vla_emb = batch["observation.vla_embeddings"]
|
||||
if vla_emb.dim() == 2:
|
||||
vla_emb = vla_emb.unsqueeze(0)
|
||||
|
||||
z_rl = self.rl_token_encoder(vla_emb) # (1, D_rl)
|
||||
|
||||
parts = [z_rl]
|
||||
if "observation.state" in batch and self._proprioception_dim > 0:
|
||||
prop = batch["observation.state"]
|
||||
if prop.dim() == 1:
|
||||
prop = prop.unsqueeze(0)
|
||||
parts.append(prop)
|
||||
|
||||
state = torch.cat(parts, dim=-1)
|
||||
|
||||
ref = batch["observation.reference_action"]
|
||||
if ref.dim() == 1:
|
||||
ref = ref.unsqueeze(0)
|
||||
|
||||
action = self.actor(state, ref)
|
||||
return action.squeeze(0)
|
||||
|
||||
def reset(self):
|
||||
pass
|
||||
@@ -15,16 +15,11 @@
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
import math
|
||||
from collections.abc import Callable
|
||||
from dataclasses import asdict
|
||||
from typing import Literal
|
||||
|
||||
import einops
|
||||
import numpy as np
|
||||
import torch
|
||||
import torch.nn as nn
|
||||
import torch.nn.functional as F # noqa: N812
|
||||
from torch import Tensor
|
||||
from torch.distributions import MultivariateNormal, TanhTransform, Transform, TransformedDistribution
|
||||
|
||||
@@ -52,20 +47,13 @@ class SACPolicy(
|
||||
|
||||
# Determine action dimension and initialize all components
|
||||
continuous_action_dim = config.output_features[ACTION].shape[0]
|
||||
self._init_encoders()
|
||||
self._init_critics(continuous_action_dim)
|
||||
self.encoder = SACObservationEncoder(config)
|
||||
self._init_actor(continuous_action_dim)
|
||||
self._init_temperature()
|
||||
self._init_discrete_critic()
|
||||
|
||||
def get_optim_params(self) -> dict:
|
||||
optim_params = {
|
||||
"actor": [
|
||||
p
|
||||
for n, p in self.actor.named_parameters()
|
||||
if not n.startswith("encoder") or not self.shared_encoder
|
||||
],
|
||||
"critic": self.critic_ensemble.parameters(),
|
||||
"temperature": self.log_alpha,
|
||||
"actor": [self.actor.parameters()],
|
||||
}
|
||||
if self.config.num_discrete_actions is not None:
|
||||
optim_params["discrete_critic"] = self.discrete_critic.parameters()
|
||||
@@ -83,10 +71,9 @@ class SACPolicy(
|
||||
@torch.no_grad()
|
||||
def select_action(self, batch: dict[str, Tensor]) -> Tensor:
|
||||
"""Select action for inference/evaluation"""
|
||||
|
||||
observations_features = None
|
||||
if self.shared_encoder and self.actor.encoder.has_images:
|
||||
observations_features = self.actor.encoder.get_cached_image_features(batch)
|
||||
if self.encoder.has_images:
|
||||
observations_features = self.encoder.get_cached_image_features(batch)
|
||||
|
||||
actions, _, _ = self.actor(batch, observations_features)
|
||||
|
||||
@@ -97,372 +84,35 @@ class SACPolicy(
|
||||
|
||||
return actions
|
||||
|
||||
def critic_forward(
|
||||
self,
|
||||
observations: dict[str, Tensor],
|
||||
actions: Tensor,
|
||||
use_target: bool = False,
|
||||
observation_features: Tensor | None = None,
|
||||
) -> Tensor:
|
||||
"""Forward pass through a critic network ensemble
|
||||
|
||||
Args:
|
||||
observations: Dictionary of observations
|
||||
actions: Action tensor
|
||||
use_target: If True, use target critics, otherwise use ensemble critics
|
||||
|
||||
Returns:
|
||||
Tensor of Q-values from all critics
|
||||
"""
|
||||
|
||||
critics = self.critic_target if use_target else self.critic_ensemble
|
||||
q_values = critics(observations, actions, observation_features)
|
||||
return q_values
|
||||
|
||||
def discrete_critic_forward(
|
||||
self, observations, use_target=False, observation_features=None
|
||||
) -> torch.Tensor:
|
||||
"""Forward pass through a discrete critic network
|
||||
|
||||
Args:
|
||||
observations: Dictionary of observations
|
||||
use_target: If True, use target critics, otherwise use ensemble critics
|
||||
observation_features: Optional pre-computed observation features to avoid recomputing encoder output
|
||||
|
||||
Returns:
|
||||
Tensor of Q-values from the discrete critic network
|
||||
"""
|
||||
discrete_critic = self.discrete_critic_target if use_target else self.discrete_critic
|
||||
q_values = discrete_critic(observations, observation_features)
|
||||
return q_values
|
||||
|
||||
def forward(
|
||||
self,
|
||||
batch: dict[str, Tensor | dict[str, Tensor]],
|
||||
model: Literal["actor", "critic", "temperature", "discrete_critic"] = "critic",
|
||||
) -> dict[str, Tensor]:
|
||||
"""Compute the loss for the given model
|
||||
"""Actor forward pass."""
|
||||
observations = batch.get("state", batch)
|
||||
observation_features = batch.get("observation_feature") if isinstance(batch, dict) else None
|
||||
actions, log_probs, means = self.actor(observations, observation_features)
|
||||
return {"action": actions, "log_prob": log_probs, "action_mean": means}
|
||||
|
||||
Args:
|
||||
batch: Dictionary containing:
|
||||
- action: Action tensor
|
||||
- reward: Reward tensor
|
||||
- state: Observations tensor dict
|
||||
- next_state: Next observations tensor dict
|
||||
- done: Done mask tensor
|
||||
- observation_feature: Optional pre-computed observation features
|
||||
- next_observation_feature: Optional pre-computed next observation features
|
||||
model: Which model to compute the loss for ("actor", "critic", "discrete_critic", or "temperature")
|
||||
|
||||
Returns:
|
||||
The computed loss tensor
|
||||
"""
|
||||
# Extract common components from batch
|
||||
actions: Tensor = batch[ACTION]
|
||||
observations: dict[str, Tensor] = batch["state"]
|
||||
observation_features: Tensor = batch.get("observation_feature")
|
||||
|
||||
if model == "critic":
|
||||
# Extract critic-specific components
|
||||
rewards: Tensor = batch["reward"]
|
||||
next_observations: dict[str, Tensor] = batch["next_state"]
|
||||
done: Tensor = batch["done"]
|
||||
next_observation_features: Tensor = batch.get("next_observation_feature")
|
||||
|
||||
loss_critic = self.compute_loss_critic(
|
||||
observations=observations,
|
||||
actions=actions,
|
||||
rewards=rewards,
|
||||
next_observations=next_observations,
|
||||
done=done,
|
||||
observation_features=observation_features,
|
||||
next_observation_features=next_observation_features,
|
||||
)
|
||||
|
||||
return {"loss_critic": loss_critic}
|
||||
|
||||
if model == "discrete_critic" and self.config.num_discrete_actions is not None:
|
||||
# Extract critic-specific components
|
||||
rewards: Tensor = batch["reward"]
|
||||
next_observations: dict[str, Tensor] = batch["next_state"]
|
||||
done: Tensor = batch["done"]
|
||||
next_observation_features: Tensor = batch.get("next_observation_feature")
|
||||
complementary_info = batch.get("complementary_info")
|
||||
loss_discrete_critic = self.compute_loss_discrete_critic(
|
||||
observations=observations,
|
||||
actions=actions,
|
||||
rewards=rewards,
|
||||
next_observations=next_observations,
|
||||
done=done,
|
||||
observation_features=observation_features,
|
||||
next_observation_features=next_observation_features,
|
||||
complementary_info=complementary_info,
|
||||
)
|
||||
return {"loss_discrete_critic": loss_discrete_critic}
|
||||
if model == "actor":
|
||||
return {
|
||||
"loss_actor": self.compute_loss_actor(
|
||||
observations=observations,
|
||||
observation_features=observation_features,
|
||||
)
|
||||
}
|
||||
|
||||
if model == "temperature":
|
||||
return {
|
||||
"loss_temperature": self.compute_loss_temperature(
|
||||
observations=observations,
|
||||
observation_features=observation_features,
|
||||
)
|
||||
}
|
||||
|
||||
raise ValueError(f"Unknown model type: {model}")
|
||||
|
||||
def update_target_networks(self):
|
||||
"""Update target networks with exponential moving average"""
|
||||
for target_param, param in zip(
|
||||
self.critic_target.parameters(),
|
||||
self.critic_ensemble.parameters(),
|
||||
strict=True,
|
||||
):
|
||||
target_param.data.copy_(
|
||||
param.data * self.config.critic_target_update_weight
|
||||
+ target_param.data * (1.0 - self.config.critic_target_update_weight)
|
||||
)
|
||||
if self.config.num_discrete_actions is not None:
|
||||
for target_param, param in zip(
|
||||
self.discrete_critic_target.parameters(),
|
||||
self.discrete_critic.parameters(),
|
||||
strict=True,
|
||||
):
|
||||
target_param.data.copy_(
|
||||
param.data * self.config.critic_target_update_weight
|
||||
+ target_param.data * (1.0 - self.config.critic_target_update_weight)
|
||||
)
|
||||
|
||||
@property
|
||||
def temperature(self) -> float:
|
||||
"""Return the current temperature value, always in sync with log_alpha."""
|
||||
return self.log_alpha.exp().item()
|
||||
|
||||
def compute_loss_critic(
|
||||
self,
|
||||
observations,
|
||||
actions,
|
||||
rewards,
|
||||
next_observations,
|
||||
done,
|
||||
observation_features: Tensor | None = None,
|
||||
next_observation_features: Tensor | None = None,
|
||||
) -> Tensor:
|
||||
with torch.no_grad():
|
||||
next_action_preds, next_log_probs, _ = self.actor(next_observations, next_observation_features)
|
||||
|
||||
# 2- compute q targets
|
||||
q_targets = self.critic_forward(
|
||||
observations=next_observations,
|
||||
actions=next_action_preds,
|
||||
use_target=True,
|
||||
observation_features=next_observation_features,
|
||||
)
|
||||
|
||||
# subsample critics to prevent overfitting if use high UTD (update to date)
|
||||
# TODO: Get indices before forward pass to avoid unnecessary computation
|
||||
if self.config.num_subsample_critics is not None:
|
||||
indices = torch.randperm(self.config.num_critics)
|
||||
indices = indices[: self.config.num_subsample_critics]
|
||||
q_targets = q_targets[indices]
|
||||
|
||||
# critics subsample size
|
||||
min_q, _ = q_targets.min(dim=0) # Get values from min operation
|
||||
if self.config.use_backup_entropy:
|
||||
min_q = min_q - (self.temperature * next_log_probs)
|
||||
|
||||
td_target = rewards + (1 - done) * self.config.discount * min_q
|
||||
|
||||
# 3- compute predicted qs
|
||||
if self.config.num_discrete_actions is not None:
|
||||
# NOTE: We only want to keep the continuous action part
|
||||
# In the buffer we have the full action space (continuous + discrete)
|
||||
# We need to split them before concatenating them in the critic forward
|
||||
actions: Tensor = actions[:, :DISCRETE_DIMENSION_INDEX]
|
||||
q_preds = self.critic_forward(
|
||||
observations=observations,
|
||||
actions=actions,
|
||||
use_target=False,
|
||||
observation_features=observation_features,
|
||||
)
|
||||
|
||||
# 4- Calculate loss
|
||||
# Compute state-action value loss (TD loss) for all of the Q functions in the ensemble.
|
||||
td_target_duplicate = einops.repeat(td_target, "b -> e b", e=q_preds.shape[0])
|
||||
# You compute the mean loss of the batch for each critic and then to compute the final loss you sum them up
|
||||
critics_loss = (
|
||||
F.mse_loss(
|
||||
input=q_preds,
|
||||
target=td_target_duplicate,
|
||||
reduction="none",
|
||||
).mean(dim=1)
|
||||
).sum()
|
||||
return critics_loss
|
||||
|
||||
def compute_loss_discrete_critic(
|
||||
self,
|
||||
observations,
|
||||
actions,
|
||||
rewards,
|
||||
next_observations,
|
||||
done,
|
||||
observation_features=None,
|
||||
next_observation_features=None,
|
||||
complementary_info=None,
|
||||
):
|
||||
# NOTE: We only want to keep the discrete action part
|
||||
# In the buffer we have the full action space (continuous + discrete)
|
||||
# We need to split them before concatenating them in the critic forward
|
||||
actions_discrete: Tensor = actions[:, DISCRETE_DIMENSION_INDEX:].clone()
|
||||
actions_discrete = torch.round(actions_discrete)
|
||||
actions_discrete = actions_discrete.long()
|
||||
|
||||
discrete_penalties: Tensor | None = None
|
||||
if complementary_info is not None:
|
||||
discrete_penalties: Tensor | None = complementary_info.get("discrete_penalty")
|
||||
|
||||
with torch.no_grad():
|
||||
# For DQN, select actions using online network, evaluate with target network
|
||||
next_discrete_qs = self.discrete_critic_forward(
|
||||
next_observations, use_target=False, observation_features=next_observation_features
|
||||
)
|
||||
best_next_discrete_action = torch.argmax(next_discrete_qs, dim=-1, keepdim=True)
|
||||
|
||||
# Get target Q-values from target network
|
||||
target_next_discrete_qs = self.discrete_critic_forward(
|
||||
observations=next_observations,
|
||||
use_target=True,
|
||||
observation_features=next_observation_features,
|
||||
)
|
||||
|
||||
# Use gather to select Q-values for best actions
|
||||
target_next_discrete_q = torch.gather(
|
||||
target_next_discrete_qs, dim=1, index=best_next_discrete_action
|
||||
).squeeze(-1)
|
||||
|
||||
# Compute target Q-value with Bellman equation
|
||||
rewards_discrete = rewards
|
||||
if discrete_penalties is not None:
|
||||
rewards_discrete = rewards + discrete_penalties
|
||||
target_discrete_q = rewards_discrete + (1 - done) * self.config.discount * target_next_discrete_q
|
||||
|
||||
# Get predicted Q-values for current observations
|
||||
predicted_discrete_qs = self.discrete_critic_forward(
|
||||
observations=observations, use_target=False, observation_features=observation_features
|
||||
)
|
||||
|
||||
# Use gather to select Q-values for taken actions
|
||||
predicted_discrete_q = torch.gather(predicted_discrete_qs, dim=1, index=actions_discrete).squeeze(-1)
|
||||
|
||||
# Compute MSE loss between predicted and target Q-values
|
||||
discrete_critic_loss = F.mse_loss(input=predicted_discrete_q, target=target_discrete_q)
|
||||
return discrete_critic_loss
|
||||
|
||||
def compute_loss_temperature(self, observations, observation_features: Tensor | None = None) -> Tensor:
|
||||
"""Compute the temperature loss"""
|
||||
# calculate temperature loss
|
||||
with torch.no_grad():
|
||||
_, log_probs, _ = self.actor(observations, observation_features)
|
||||
temperature_loss = (-self.log_alpha.exp() * (log_probs + self.target_entropy)).mean()
|
||||
return temperature_loss
|
||||
|
||||
def compute_loss_actor(
|
||||
self,
|
||||
observations,
|
||||
observation_features: Tensor | None = None,
|
||||
) -> Tensor:
|
||||
actions_pi, log_probs, _ = self.actor(observations, observation_features)
|
||||
|
||||
q_preds = self.critic_forward(
|
||||
observations=observations,
|
||||
actions=actions_pi,
|
||||
use_target=False,
|
||||
observation_features=observation_features,
|
||||
)
|
||||
min_q_preds = q_preds.min(dim=0)[0]
|
||||
|
||||
actor_loss = ((self.temperature * log_probs) - min_q_preds).mean()
|
||||
return actor_loss
|
||||
|
||||
def _init_encoders(self):
|
||||
"""Initialize shared or separate encoders for actor and critic."""
|
||||
self.shared_encoder = self.config.shared_encoder
|
||||
self.encoder_critic = SACObservationEncoder(self.config)
|
||||
self.encoder_actor = (
|
||||
self.encoder_critic if self.shared_encoder else SACObservationEncoder(self.config)
|
||||
)
|
||||
|
||||
def _init_critics(self, continuous_action_dim):
|
||||
"""Build critic ensemble, targets, and optional discrete critic."""
|
||||
heads = [
|
||||
CriticHead(
|
||||
input_dim=self.encoder_critic.output_dim + continuous_action_dim,
|
||||
**asdict(self.config.critic_network_kwargs),
|
||||
)
|
||||
for _ in range(self.config.num_critics)
|
||||
]
|
||||
self.critic_ensemble = CriticEnsemble(encoder=self.encoder_critic, ensemble=heads)
|
||||
target_heads = [
|
||||
CriticHead(
|
||||
input_dim=self.encoder_critic.output_dim + continuous_action_dim,
|
||||
**asdict(self.config.critic_network_kwargs),
|
||||
)
|
||||
for _ in range(self.config.num_critics)
|
||||
]
|
||||
self.critic_target = CriticEnsemble(encoder=self.encoder_critic, ensemble=target_heads)
|
||||
self.critic_target.load_state_dict(self.critic_ensemble.state_dict())
|
||||
|
||||
if self.config.use_torch_compile:
|
||||
self.critic_ensemble = torch.compile(self.critic_ensemble)
|
||||
self.critic_target = torch.compile(self.critic_target)
|
||||
|
||||
if self.config.num_discrete_actions is not None:
|
||||
self._init_discrete_critics()
|
||||
|
||||
def _init_discrete_critics(self):
|
||||
"""Build discrete discrete critic ensemble and target networks."""
|
||||
self.discrete_critic = DiscreteCritic(
|
||||
encoder=self.encoder_critic,
|
||||
input_dim=self.encoder_critic.output_dim,
|
||||
output_dim=self.config.num_discrete_actions,
|
||||
**asdict(self.config.discrete_critic_network_kwargs),
|
||||
)
|
||||
self.discrete_critic_target = DiscreteCritic(
|
||||
encoder=self.encoder_critic,
|
||||
input_dim=self.encoder_critic.output_dim,
|
||||
output_dim=self.config.num_discrete_actions,
|
||||
**asdict(self.config.discrete_critic_network_kwargs),
|
||||
)
|
||||
|
||||
# TODO: (maractingi, azouitine) Compile the discrete critic
|
||||
self.discrete_critic_target.load_state_dict(self.discrete_critic.state_dict())
|
||||
|
||||
def _init_actor(self, continuous_action_dim):
|
||||
"""Initialize policy actor network and default target entropy."""
|
||||
# NOTE: The actor select only the continuous action part
|
||||
def _init_actor(self, continuous_action_dim: int) -> None:
|
||||
self.actor = Policy(
|
||||
encoder=self.encoder_actor,
|
||||
network=MLP(input_dim=self.encoder_actor.output_dim, **asdict(self.config.actor_network_kwargs)),
|
||||
encoder=self.encoder,
|
||||
network=MLP(input_dim=self.encoder.output_dim, **asdict(self.config.actor_network_kwargs)),
|
||||
action_dim=continuous_action_dim,
|
||||
encoder_is_shared=self.shared_encoder,
|
||||
encoder_is_shared=False,
|
||||
**asdict(self.config.policy_kwargs),
|
||||
)
|
||||
|
||||
self.target_entropy = self.config.target_entropy
|
||||
if self.target_entropy is None:
|
||||
dim = continuous_action_dim + (1 if self.config.num_discrete_actions is not None else 0)
|
||||
self.target_entropy = -np.prod(dim) / 2
|
||||
|
||||
def _init_temperature(self) -> None:
|
||||
"""Set up temperature parameter (log_alpha)."""
|
||||
temp_init = self.config.temperature_init
|
||||
self.log_alpha = nn.Parameter(torch.tensor([math.log(temp_init)]))
|
||||
def _init_discrete_critic(self) -> None:
|
||||
if self.config.num_discrete_actions is None:
|
||||
self.discrete_critic = None
|
||||
return
|
||||
self.discrete_critic = DiscreteCritic(
|
||||
encoder=self.encoder,
|
||||
input_dim=self.encoder.output_dim,
|
||||
output_dim=self.config.num_discrete_actions,
|
||||
**asdict(self.config.discrete_critic_network_kwargs),
|
||||
)
|
||||
|
||||
|
||||
class SACObservationEncoder(nn.Module):
|
||||
|
||||
@@ -277,7 +277,9 @@ class SARMEncodingProcessorStep(ProcessorStep):
|
||||
|
||||
# When language is perturbed, targets are zero so perturbed samples don't contribute to progress loss
|
||||
if self.dataset_meta is not None:
|
||||
episodes_df = self.dataset_meta.episodes.to_pandas()
|
||||
episodes_df = None
|
||||
if self.sparse_subtask_names != ["task"]:
|
||||
episodes_df = self.dataset_meta.episodes.to_pandas()
|
||||
|
||||
# Generate sparse targets
|
||||
if self.sparse_temporal_proportions is not None:
|
||||
|
||||
@@ -85,7 +85,7 @@ class SmolVLAConfig(PreTrainedConfig):
|
||||
scheduler_decay_lr: float = 2.5e-6
|
||||
|
||||
vlm_model_name: str = "HuggingFaceTB/SmolVLM2-500M-Video-Instruct" # Select the VLM backbone.
|
||||
load_vlm_weights: bool = False # Set to False in case of training the expert from scratch. True when init from pretrained SmolVLA weights
|
||||
load_vlm_weights: bool = False # Set to True in case of training the expert from scratch. True when init from pretrained SmolVLA weights
|
||||
|
||||
add_image_special_tokens: bool = False # Whether to use special image tokens around image features.
|
||||
|
||||
|
||||
@@ -131,6 +131,15 @@ class _NormalizationMixin:
|
||||
if self.dtype is None:
|
||||
self.dtype = torch.float32
|
||||
self._tensor_stats = to_tensor(self.stats, device=self.device, dtype=self.dtype)
|
||||
self._reshape_visual_stats()
|
||||
|
||||
def _reshape_visual_stats(self) -> None:
|
||||
"""Reshape visual stats from ``[C]`` to ``[C, 1, 1]`` for image broadcasting."""
|
||||
for key, feature in self.features.items():
|
||||
if feature.type == FeatureType.VISUAL and key in self._tensor_stats:
|
||||
for stat_name, stat_tensor in self._tensor_stats[key].items():
|
||||
if isinstance(stat_tensor, Tensor) and stat_tensor.ndim == 1:
|
||||
self._tensor_stats[key][stat_name] = stat_tensor.reshape(-1, 1, 1)
|
||||
|
||||
def to(
|
||||
self, device: torch.device | str | None = None, dtype: torch.dtype | None = None
|
||||
@@ -149,6 +158,7 @@ class _NormalizationMixin:
|
||||
if dtype is not None:
|
||||
self.dtype = dtype
|
||||
self._tensor_stats = to_tensor(self.stats, device=self.device, dtype=self.dtype)
|
||||
self._reshape_visual_stats()
|
||||
return self
|
||||
|
||||
def state_dict(self) -> dict[str, Tensor]:
|
||||
@@ -198,6 +208,7 @@ class _NormalizationMixin:
|
||||
# Don't load from state_dict, keep the explicitly provided stats
|
||||
# But ensure _tensor_stats is properly initialized
|
||||
self._tensor_stats = to_tensor(self.stats, device=self.device, dtype=self.dtype) # type: ignore[assignment]
|
||||
self._reshape_visual_stats()
|
||||
return
|
||||
|
||||
# Normal behavior: load stats from state_dict
|
||||
@@ -208,6 +219,7 @@ class _NormalizationMixin:
|
||||
self._tensor_stats.setdefault(key, {})[stat_name] = tensor.to(
|
||||
dtype=torch.float32, device=self.device
|
||||
)
|
||||
self._reshape_visual_stats()
|
||||
|
||||
# Reconstruct the original stats dict from tensor stats for compatibility with to() method
|
||||
# and other functions that rely on self.stats
|
||||
|
||||
@@ -1,5 +1,3 @@
|
||||
#!/usr/bin/env python
|
||||
|
||||
# Copyright 2026 The HuggingFace Inc. team. All rights reserved.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
@@ -13,6 +11,3 @@
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
from .robstride import RobstrideMotorsBus
|
||||
from .tables import *
|
||||
+9
-19
@@ -61,7 +61,7 @@ from lerobot.cameras import opencv # noqa: F401
|
||||
from lerobot.configs import parser
|
||||
from lerobot.configs.train import TrainRLServerPipelineConfig
|
||||
from lerobot.policies.factory import make_policy
|
||||
from lerobot.policies.sac.modeling_sac import SACPolicy
|
||||
from lerobot.policies.pretrained import PreTrainedPolicy
|
||||
from lerobot.processor import TransitionKey
|
||||
from lerobot.rl.process import ProcessSignalHandler
|
||||
from lerobot.rl.queue import get_last_item_from_queue
|
||||
@@ -248,16 +248,16 @@ def act_with_policy(
|
||||
|
||||
logging.info("make_policy")
|
||||
|
||||
### Instantiate the policy in both the actor and learner processes
|
||||
### To avoid sending a SACPolicy object through the port, we create a policy instance
|
||||
### on both sides, the learner sends the updated parameters every n steps to update the actor's parameters
|
||||
policy: SACPolicy = make_policy(
|
||||
policy = make_policy(
|
||||
cfg=cfg.policy,
|
||||
env_cfg=cfg.env,
|
||||
)
|
||||
policy = policy.eval()
|
||||
assert isinstance(policy, nn.Module)
|
||||
|
||||
# TODO: Re-enable processor pipeline once refactoring is validated against main
|
||||
# preprocessor, postprocessor = None, None
|
||||
|
||||
obs, info = online_env.reset()
|
||||
env_processor.reset()
|
||||
action_processor.reset()
|
||||
@@ -288,7 +288,6 @@ def act_with_policy(
|
||||
|
||||
# Time policy inference and check if it meets FPS requirement
|
||||
with policy_timer:
|
||||
# Extract observation from transition for policy
|
||||
action = policy.select_action(batch=observation)
|
||||
policy_fps = policy_timer.fps_last
|
||||
|
||||
@@ -649,12 +648,12 @@ def interactions_stream(
|
||||
# Policy functions
|
||||
|
||||
|
||||
def update_policy_parameters(policy: SACPolicy, parameters_queue: Queue, device):
|
||||
def update_policy_parameters(policy: PreTrainedPolicy, parameters_queue: Queue, device):
|
||||
"""Load the latest policy weights from the learner."""
|
||||
bytes_state_dict = get_last_item_from_queue(parameters_queue, block=False)
|
||||
if bytes_state_dict is not None:
|
||||
logging.info("[ACTOR] Load new parameters from Learner.")
|
||||
state_dicts = bytes_to_state_dict(bytes_state_dict)
|
||||
|
||||
# TODO: check encoder parameter synchronization possible issues:
|
||||
# 1. When shared_encoder=True, we're loading stale encoder params from actor's state_dict
|
||||
# instead of the updated encoder params from critic (which is optimized separately)
|
||||
@@ -664,18 +663,9 @@ def update_policy_parameters(policy: SACPolicy, parameters_queue: Queue, device)
|
||||
# - Send critic's encoder state when shared_encoder=True
|
||||
# - Skip encoder params entirely when freeze_vision_encoder=True
|
||||
# - Ensure discrete_critic gets correct encoder state (currently uses encoder_critic)
|
||||
|
||||
# Load actor state dict
|
||||
actor_state_dict = move_state_dict_to_device(state_dicts["policy"], device=device)
|
||||
policy.actor.load_state_dict(actor_state_dict)
|
||||
|
||||
# Load discrete critic if present
|
||||
if hasattr(policy, "discrete_critic") and "discrete_critic" in state_dicts:
|
||||
discrete_critic_state_dict = move_state_dict_to_device(
|
||||
state_dicts["discrete_critic"], device=device
|
||||
)
|
||||
policy.discrete_critic.load_state_dict(discrete_critic_state_dict)
|
||||
logging.info("[ACTOR] Loaded discrete critic parameters from Learner.")
|
||||
state_dicts = move_state_dict_to_device(state_dicts, device=device)
|
||||
policy.load_state_dict(state_dicts)
|
||||
|
||||
|
||||
# Utilities functions
|
||||
|
||||
@@ -0,0 +1,70 @@
|
||||
# Copyright 2026 The HuggingFace Inc. team. All rights reserved.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import torch
|
||||
|
||||
from lerobot.rl.algorithms.base import (
|
||||
RLAlgorithm,
|
||||
RLAlgorithmConfig,
|
||||
TrainingStats,
|
||||
)
|
||||
from lerobot.rl.algorithms.rlt import RLTAlgorithm, RLTAlgorithmConfig
|
||||
from lerobot.rl.algorithms.sac import SACAlgorithm, SACAlgorithmConfig
|
||||
|
||||
|
||||
def make_algorithm(
|
||||
policy: torch.nn.Module,
|
||||
policy_cfg,
|
||||
*,
|
||||
algorithm_name: str,
|
||||
) -> RLAlgorithm:
|
||||
"""Construct an :class:`RLAlgorithm` from a policy and its config.
|
||||
|
||||
Algorithm selection is explicit via ``algorithm_name`` (from
|
||||
``cfg.algorithm``).
|
||||
|
||||
This is fully registry-driven — adding a new algorithm only requires
|
||||
registering an ``RLAlgorithmConfig`` subclass; no changes here.
|
||||
|
||||
The returned algorithm has **no optimizers** yet. On the learner side,
|
||||
call ``algorithm.make_optimizers()`` afterwards to create them. On the
|
||||
actor side (inference-only), leave them empty.
|
||||
|
||||
Args:
|
||||
policy: Instantiated policy (e.g. ``SACPolicy``).
|
||||
policy_cfg: The policy's ``PreTrainedConfig`` with the hyper-parameters
|
||||
expected by the algorithm config's ``from_policy_config`` class-method.
|
||||
algorithm_name: Algorithm registry key to instantiate.
|
||||
"""
|
||||
known = RLAlgorithmConfig.get_known_choices()
|
||||
if algorithm_name not in known:
|
||||
raise ValueError(f"No RLAlgorithmConfig registered for '{algorithm_name}'. Known: {list(known)}")
|
||||
|
||||
config_cls = RLAlgorithmConfig.get_choice_class(algorithm_name)
|
||||
algo_config = config_cls.from_policy_config(policy_cfg)
|
||||
return algo_config.build_algorithm(policy)
|
||||
|
||||
|
||||
__all__ = [
|
||||
"RLAlgorithm",
|
||||
"RLAlgorithmConfig",
|
||||
"TrainingStats",
|
||||
"SACAlgorithm",
|
||||
"SACAlgorithmConfig",
|
||||
"RLTAlgorithm",
|
||||
"RLTAlgorithmConfig",
|
||||
"make_algorithm",
|
||||
]
|
||||
@@ -0,0 +1,183 @@
|
||||
# Copyright 2026 The HuggingFace Inc. team. All rights reserved.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
"""Base classes for RL algorithms.
|
||||
|
||||
Defines the abstract interface that every algorithm must implement, a registry
|
||||
for algorithm configs, and a dataclass for training statistics.
|
||||
"""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import abc
|
||||
from collections.abc import Iterator
|
||||
from dataclasses import dataclass, field
|
||||
from typing import TYPE_CHECKING, Any
|
||||
|
||||
import draccus
|
||||
import torch
|
||||
from torch import Tensor
|
||||
from torch.optim import Optimizer
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from lerobot.rl.data_sources.data_mixer import DataMixer
|
||||
|
||||
BatchType = dict[str, Any]
|
||||
|
||||
|
||||
@dataclass
|
||||
class TrainingStats:
|
||||
"""Returned by ``algorithm.update()`` for logging and checkpointing."""
|
||||
|
||||
# Generic containers for all algorithms
|
||||
losses: dict[str, float] = field(default_factory=dict)
|
||||
grad_norms: dict[str, float] = field(default_factory=dict)
|
||||
extra: dict[str, float] = field(default_factory=dict)
|
||||
|
||||
def to_log_dict(self) -> dict[str, float]:
|
||||
"""Flatten all stats into a single dict for logging."""
|
||||
|
||||
d: dict[str, float] = {}
|
||||
for name, val in self.losses.items():
|
||||
d[name] = val
|
||||
for name, val in self.grad_norms.items():
|
||||
d[f"{name}_grad_norm"] = val
|
||||
for name, val in self.extra.items():
|
||||
d[name] = val
|
||||
return d
|
||||
|
||||
|
||||
@dataclass
|
||||
class RLAlgorithmConfig(draccus.ChoiceRegistry):
|
||||
"""Registry for algorithm configs."""
|
||||
|
||||
def build_algorithm(self, policy: torch.nn.Module) -> RLAlgorithm:
|
||||
"""Construct the :class:`RLAlgorithm` for this config.
|
||||
|
||||
Must be overridden by every registered config subclass.
|
||||
"""
|
||||
raise NotImplementedError(f"{type(self).__name__} must implement build_algorithm()")
|
||||
|
||||
@classmethod
|
||||
def from_policy_config(cls, policy_cfg: Any) -> RLAlgorithmConfig:
|
||||
"""Build an algorithm config from a policy config.
|
||||
|
||||
Must be overridden by every registered config subclass.
|
||||
"""
|
||||
raise NotImplementedError(f"{cls.__name__} must implement from_policy_config()")
|
||||
|
||||
|
||||
class RLAlgorithm(abc.ABC):
|
||||
"""Base for all RL algorithms."""
|
||||
|
||||
@abc.abstractmethod
|
||||
def update(self, batch_iterator: Iterator[BatchType]) -> TrainingStats:
|
||||
"""One complete training step.
|
||||
|
||||
The algorithm calls ``next(batch_iterator)`` as many times as it
|
||||
needs (e.g. ``utd_ratio`` times for SAC) to obtain fresh batches.
|
||||
The iterator is owned by the trainer; the algorithm just consumes
|
||||
from it.
|
||||
"""
|
||||
...
|
||||
|
||||
def supports_offline_phase(self) -> bool:
|
||||
"""Whether this algorithm has an offline pretraining phase.
|
||||
|
||||
Algorithms like RLT (RL-token training) or ConRFT (Cal-QL pretraining)
|
||||
return ``True`` here. The learner checks this before the main online
|
||||
loop and routes to :meth:`offline_update` accordingly.
|
||||
"""
|
||||
return False
|
||||
|
||||
def offline_update(self, batch_iterator: Iterator[BatchType]) -> TrainingStats:
|
||||
"""One offline training step (called before any online collection).
|
||||
|
||||
Only called when :meth:`supports_offline_phase` returns ``True``.
|
||||
Uses the same iterator protocol as :meth:`update`.
|
||||
"""
|
||||
raise NotImplementedError(
|
||||
f"{type(self).__name__} does not implement offline_update(). "
|
||||
"Either override this method or return False from supports_offline_phase()."
|
||||
)
|
||||
|
||||
def transition_to_online(self) -> None: # noqa: B027
|
||||
"""Called once when switching from offline to online phase.
|
||||
|
||||
Use this to freeze modules trained offline, rebuild optimizers for the
|
||||
online phase, reset step counters, etc.
|
||||
|
||||
Default is a no-op; subclasses override when they have an offline phase.
|
||||
"""
|
||||
|
||||
def configure_data_iterator(
|
||||
self,
|
||||
data_mixer: DataMixer,
|
||||
batch_size: int,
|
||||
*,
|
||||
async_prefetch: bool = True,
|
||||
queue_size: int = 2,
|
||||
) -> Iterator[BatchType]:
|
||||
"""Create the data iterator this algorithm needs.
|
||||
|
||||
The default implementation uses the standard ``data_mixer.get_iterator()``.
|
||||
Algorithms that need specialised sampling should override this method.
|
||||
"""
|
||||
return data_mixer.get_iterator(
|
||||
batch_size=batch_size,
|
||||
async_prefetch=async_prefetch,
|
||||
queue_size=queue_size,
|
||||
)
|
||||
|
||||
def make_optimizers(self) -> dict[str, Optimizer]:
|
||||
"""Create, store, and return the optimizers needed for training.
|
||||
|
||||
Called on the **learner** side after construction. Subclasses must
|
||||
override this with algorithm-specific optimizer setup.
|
||||
"""
|
||||
return {}
|
||||
|
||||
def get_optimizers(self) -> dict[str, Optimizer]:
|
||||
"""Return optimizers for checkpointing / external scheduling."""
|
||||
return {}
|
||||
|
||||
@property
|
||||
def optimization_step(self) -> int:
|
||||
"""Current learner optimization step.
|
||||
|
||||
Part of the stable contract for checkpoint/resume. Algorithms can
|
||||
either use this default storage or override for custom behavior.
|
||||
"""
|
||||
return getattr(self, "_optimization_step", 0)
|
||||
|
||||
@optimization_step.setter
|
||||
def optimization_step(self, value: int) -> None:
|
||||
self._optimization_step = int(value)
|
||||
|
||||
def get_weights(self) -> dict[str, Any]:
|
||||
"""Policy state-dict to push to actors."""
|
||||
return {}
|
||||
|
||||
@abc.abstractmethod
|
||||
def load_weights(self, weights: dict[str, Any], device: str | torch.device = "cpu") -> None:
|
||||
"""Load policy state-dict received from the learner (inverse of ``get_weights``)."""
|
||||
|
||||
@torch.no_grad()
|
||||
def get_observation_features(
|
||||
self, observations: Tensor, next_observations: Tensor
|
||||
) -> tuple[Tensor | None, Tensor | None]:
|
||||
"""Pre-compute observation features (e.g. frozen encoder cache).
|
||||
|
||||
Returns ``(None, None)`` when caching is not applicable.
|
||||
"""
|
||||
return None, None
|
||||
+4
-16
@@ -1,6 +1,4 @@
|
||||
#!/usr/bin/env python
|
||||
|
||||
# Copyright 2025 The HuggingFace Inc. team. All rights reserved.
|
||||
# Copyright 2026 The HuggingFace Inc. team. All rights reserved.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
@@ -14,17 +12,7 @@
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
from dataclasses import dataclass
|
||||
from lerobot.rl.algorithms.rlt.configuration_rlt import RLTAlgorithmConfig
|
||||
from lerobot.rl.algorithms.rlt.rlt_algorithm import RLTAlgorithm
|
||||
|
||||
from ..config import TeleoperatorConfig
|
||||
|
||||
|
||||
@TeleoperatorConfig.register_subclass("openarm_mini")
|
||||
@dataclass
|
||||
class OpenArmMiniConfig(TeleoperatorConfig):
|
||||
"""Configuration for OpenArm Mini teleoperator with Feetech motors (dual arms)."""
|
||||
|
||||
port_right: str = "/dev/ttyUSB0"
|
||||
port_left: str = "/dev/ttyUSB1"
|
||||
|
||||
use_degrees: bool = True
|
||||
__all__ = ["RLTAlgorithm", "RLTAlgorithmConfig"]
|
||||
@@ -0,0 +1,83 @@
|
||||
# Copyright 2026 The HuggingFace Inc. team. All rights reserved.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
"""RLT algorithm configuration."""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
from dataclasses import dataclass
|
||||
from typing import TYPE_CHECKING
|
||||
|
||||
import torch
|
||||
|
||||
from lerobot.rl.algorithms.base import RLAlgorithmConfig
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from lerobot.rl.algorithms.rlt.rlt_algorithm import RLTAlgorithm
|
||||
|
||||
|
||||
@RLAlgorithmConfig.register_subclass("rlt")
|
||||
@dataclass
|
||||
class RLTAlgorithmConfig(RLAlgorithmConfig):
|
||||
"""RLT-specific hyper-parameters that control the update loop."""
|
||||
|
||||
# ── Action chunks ──
|
||||
chunk_size: int = 10
|
||||
chunk_stride: int = 2
|
||||
|
||||
# ── Update cadence ──
|
||||
utd_ratio: int = 5
|
||||
policy_update_freq: int = 2
|
||||
clip_grad_norm: float = 10.0
|
||||
|
||||
# ── Learning rates ──
|
||||
actor_lr: float = 3e-4
|
||||
critic_lr: float = 3e-4
|
||||
rl_token_lr: float = 1e-4
|
||||
|
||||
# ── TD learning ──
|
||||
discount: float = 0.99
|
||||
tau: float = 0.005
|
||||
num_critics: int = 2
|
||||
|
||||
# ── Policy constraint (paper Eq. 5) ──
|
||||
bc_reg_coeff: float = 0.1
|
||||
ref_dropout: float = 0.5
|
||||
|
||||
# ── Offline RL-token training ──
|
||||
vla_finetune_weight: float = 0.0
|
||||
|
||||
@classmethod
|
||||
def from_policy_config(cls, policy_cfg) -> RLTAlgorithmConfig:
|
||||
"""Build from an existing ``RLTConfig`` (cfg.policy)."""
|
||||
return cls(
|
||||
chunk_size=policy_cfg.chunk_size,
|
||||
chunk_stride=policy_cfg.chunk_stride,
|
||||
utd_ratio=policy_cfg.utd_ratio,
|
||||
policy_update_freq=policy_cfg.policy_update_freq,
|
||||
clip_grad_norm=policy_cfg.clip_grad_norm,
|
||||
actor_lr=policy_cfg.actor_lr,
|
||||
critic_lr=policy_cfg.critic_lr,
|
||||
rl_token_lr=policy_cfg.rl_token_lr,
|
||||
discount=policy_cfg.discount,
|
||||
tau=policy_cfg.tau,
|
||||
num_critics=policy_cfg.num_critics,
|
||||
bc_reg_coeff=policy_cfg.bc_reg_coeff,
|
||||
ref_dropout=policy_cfg.ref_dropout,
|
||||
vla_finetune_weight=policy_cfg.vla_finetune_weight,
|
||||
)
|
||||
|
||||
def build_algorithm(self, policy: torch.nn.Module) -> RLTAlgorithm:
|
||||
from lerobot.rl.algorithms.rlt.rlt_algorithm import RLTAlgorithm
|
||||
|
||||
return RLTAlgorithm(policy=policy, config=self)
|
||||
@@ -0,0 +1,319 @@
|
||||
# Copyright 2026 The HuggingFace Inc. team. All rights reserved.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
"""RLT (RL Token) algorithm.
|
||||
|
||||
Implements the two-stage training from "RL Token: Bootstrapping Online RL
|
||||
with Vision-Language-Action Models" (Xu et al., Physical Intelligence, 2026).
|
||||
|
||||
Stage 1 (offline): Train RL-token encoder/decoder via reconstruction loss.
|
||||
Stage 2 (online): Train actor-critic with chunked TD, BC regularization,
|
||||
reference-action pass-through, and reference-action dropout.
|
||||
"""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import copy
|
||||
from collections.abc import Iterator
|
||||
from typing import Any
|
||||
|
||||
import torch
|
||||
import torch.nn as nn
|
||||
import torch.nn.functional as F # noqa: N812
|
||||
from torch import Tensor
|
||||
from torch.optim import Optimizer
|
||||
|
||||
from lerobot.policies.rlt.modeling_rlt import MLP, RLTPolicy
|
||||
from lerobot.policies.utils import get_device_from_parameters
|
||||
from lerobot.rl.algorithms.base import (
|
||||
BatchType,
|
||||
RLAlgorithm,
|
||||
TrainingStats,
|
||||
)
|
||||
from lerobot.rl.algorithms.rlt.configuration_rlt import RLTAlgorithmConfig
|
||||
from lerobot.utils.constants import ACTION
|
||||
|
||||
|
||||
class RLTCritic(nn.Module):
|
||||
"""Q-function over (state, action_chunk) pairs.
|
||||
|
||||
Paper Eq. 3: Q_psi(x, a_{1:C})
|
||||
|
||||
Training-only component — lives on the algorithm side, not in the policy.
|
||||
"""
|
||||
|
||||
def __init__(self, state_dim: int, action_chunk_dim: int, hidden_dims: list[int]):
|
||||
super().__init__()
|
||||
self.net = MLP(state_dim + action_chunk_dim, hidden_dims, output_dim=1)
|
||||
|
||||
def forward(self, state: Tensor, action_chunk: Tensor) -> Tensor:
|
||||
x = torch.cat([state, action_chunk], dim=-1)
|
||||
return self.net(x)
|
||||
|
||||
|
||||
class RLTAlgorithm(RLAlgorithm):
|
||||
"""RL Token: lightweight actor-critic on frozen VLA features.
|
||||
|
||||
Owns the ``RLTPolicy`` (RL-token encoder/decoder + actor), a critic
|
||||
ensemble, and target networks. All VLA-specific logic (embedding
|
||||
extraction, reference actions) lives in ``_prepare_forward_batch``.
|
||||
"""
|
||||
|
||||
def __init__(self, policy: RLTPolicy, config: RLTAlgorithmConfig):
|
||||
self.policy = policy
|
||||
self.config = config
|
||||
self.optimizers: dict[str, Optimizer] = {}
|
||||
self._optimization_step: int = 0
|
||||
self._device = get_device_from_parameters(self.policy)
|
||||
self._is_online = False
|
||||
|
||||
self._init_critics()
|
||||
self._move_to_device()
|
||||
|
||||
# ── Initialization ───────────────────────────────────────────────
|
||||
|
||||
def _init_critics(self) -> None:
|
||||
state_dim = self.policy._state_dim
|
||||
action_chunk_dim = self.policy._action_chunk_dim
|
||||
hidden_dims = self.policy.config.critic.hidden_dims
|
||||
|
||||
self.critics = torch.nn.ModuleList(
|
||||
[RLTCritic(state_dim, action_chunk_dim, hidden_dims) for _ in range(self.config.num_critics)]
|
||||
)
|
||||
self.critic_targets = torch.nn.ModuleList([copy.deepcopy(c) for c in self.critics])
|
||||
for ct in self.critic_targets:
|
||||
ct.requires_grad_(False)
|
||||
|
||||
def _move_to_device(self) -> None:
|
||||
self.critics.to(self._device)
|
||||
self.critic_targets.to(self._device)
|
||||
|
||||
# ── Offline phase (Stage 1): RL-token training ───────────────────
|
||||
|
||||
def supports_offline_phase(self) -> bool:
|
||||
return True
|
||||
|
||||
def offline_update(self, batch_iterator: Iterator[BatchType]) -> TrainingStats:
|
||||
"""Train RL-token encoder/decoder on demonstration data.
|
||||
|
||||
Paper Eq. 2: L_ro = E[ sum_i || h(d([z_rl, z_bar_{1:i-1}]))_i - z_bar_i ||^2 ]
|
||||
"""
|
||||
batch = next(batch_iterator)
|
||||
|
||||
vla_embeddings = batch["state"]["observation.vla_embeddings"].to(self._device)
|
||||
z_vla = vla_embeddings.detach() # stop-gradient on VLA embeddings
|
||||
|
||||
z_rl = self.policy.rl_token_encoder(z_vla)
|
||||
z_reconstructed = self.policy.rl_token_decoder(z_rl, z_vla)
|
||||
|
||||
loss_ro = F.mse_loss(z_reconstructed, z_vla)
|
||||
|
||||
self.optimizers["rl_token"].zero_grad()
|
||||
loss_ro.backward()
|
||||
torch.nn.utils.clip_grad_norm_(
|
||||
list(self.policy.rl_token_encoder.parameters()) + list(self.policy.rl_token_decoder.parameters()),
|
||||
max_norm=self.config.clip_grad_norm,
|
||||
)
|
||||
self.optimizers["rl_token"].step()
|
||||
|
||||
self._optimization_step += 1
|
||||
return TrainingStats(losses={"loss_rl_token": loss_ro.item()})
|
||||
|
||||
def transition_to_online(self) -> None:
|
||||
"""Freeze RL-token modules; rebuild optimizers for actor-critic only."""
|
||||
self.policy.rl_token_encoder.requires_grad_(False)
|
||||
self.policy.rl_token_decoder.requires_grad_(False)
|
||||
self._is_online = True
|
||||
|
||||
self.optimizers = {
|
||||
"actor": torch.optim.Adam(self.policy.actor.parameters(), lr=self.config.actor_lr),
|
||||
"critic": torch.optim.Adam(self.critics.parameters(), lr=self.config.critic_lr),
|
||||
}
|
||||
self._optimization_step = 0
|
||||
|
||||
# ── Online phase (Stage 2): Actor-Critic ─────────────────────────
|
||||
|
||||
def update(self, batch_iterator: Iterator[BatchType]) -> TrainingStats:
|
||||
"""One full RLT update step with UTD critic warm-up.
|
||||
|
||||
Pulls ``utd_ratio`` batches. First ``utd_ratio - 1`` are critic-only;
|
||||
the last batch also updates the actor (every ``policy_update_freq`` steps).
|
||||
"""
|
||||
for _ in range(self.config.utd_ratio - 1):
|
||||
batch = next(batch_iterator)
|
||||
fb = self._prepare_forward_batch(batch)
|
||||
self._critic_step(fb)
|
||||
self._update_target_networks()
|
||||
|
||||
batch = next(batch_iterator)
|
||||
fb = self._prepare_forward_batch(batch)
|
||||
critic_loss = self._critic_step(fb)
|
||||
|
||||
stats = TrainingStats(losses={"loss_critic": critic_loss})
|
||||
|
||||
if self._optimization_step % self.config.policy_update_freq == 0:
|
||||
actor_loss, bc_loss, q_val = self._actor_step(fb)
|
||||
stats.losses["loss_actor"] = actor_loss
|
||||
stats.extra["bc_loss"] = bc_loss
|
||||
stats.extra["q_value_mean"] = q_val
|
||||
|
||||
self._update_target_networks()
|
||||
self._optimization_step += 1
|
||||
return stats
|
||||
|
||||
def _prepare_forward_batch(self, batch: BatchType) -> dict[str, Any]:
|
||||
"""Convert a replay batch into algorithm-ready tensors.
|
||||
|
||||
Extracts RL-token from VLA embeddings, builds RL state, reads
|
||||
reference action from complementary_info.
|
||||
"""
|
||||
obs = batch["state"]
|
||||
next_obs = batch["next_state"]
|
||||
device = self._device
|
||||
|
||||
vla_emb = obs["observation.vla_embeddings"].to(device)
|
||||
next_vla_emb = next_obs["observation.vla_embeddings"].to(device)
|
||||
|
||||
with torch.no_grad():
|
||||
z_rl = self.policy.rl_token_encoder(vla_emb)
|
||||
z_rl_next = self.policy.rl_token_encoder(next_vla_emb)
|
||||
|
||||
parts = [z_rl]
|
||||
next_parts = [z_rl_next]
|
||||
if "observation.state" in obs and self.policy._proprioception_dim > 0:
|
||||
prop = obs["observation.state"].to(device)
|
||||
next_prop = next_obs["observation.state"].to(device)
|
||||
parts.append(prop)
|
||||
next_parts.append(next_prop)
|
||||
|
||||
state = torch.cat(parts, dim=-1)
|
||||
next_state = torch.cat(next_parts, dim=-1)
|
||||
|
||||
action = batch[ACTION].to(device)
|
||||
reward = batch["reward"].to(device)
|
||||
done = batch["done"].to(device)
|
||||
|
||||
ref_action = None
|
||||
comp_info = batch.get("complementary_info")
|
||||
if comp_info is not None and "reference_action" in comp_info:
|
||||
ref_action = comp_info["reference_action"].to(device)
|
||||
|
||||
return {
|
||||
"state": state,
|
||||
"next_state": next_state,
|
||||
"action": action,
|
||||
"reward": reward,
|
||||
"done": done,
|
||||
"reference_action": ref_action,
|
||||
}
|
||||
|
||||
def _critic_step(self, fb: dict[str, Any]) -> float:
|
||||
"""Paper Eq. 3: chunked TD with clipped double-Q target."""
|
||||
state = fb["state"]
|
||||
next_state = fb["next_state"]
|
||||
action = fb["action"]
|
||||
reward = fb["reward"]
|
||||
done = fb["done"]
|
||||
|
||||
with torch.no_grad():
|
||||
ref = fb.get("reference_action")
|
||||
if ref is None:
|
||||
ref = torch.zeros_like(action)
|
||||
next_action = self.policy.actor(next_state, ref)
|
||||
|
||||
target_qs = [ct(next_state, next_action) for ct in self.critic_targets]
|
||||
min_target_q = torch.min(torch.cat(target_qs, dim=-1), dim=-1, keepdim=True).values
|
||||
|
||||
discount_chunk = self.config.discount**self.config.chunk_size
|
||||
td_target = reward.unsqueeze(-1) + (1 - done.unsqueeze(-1)) * discount_chunk * min_target_q
|
||||
|
||||
q_preds = [c(state, action) for c in self.critics]
|
||||
loss = sum(F.mse_loss(q, td_target) for q in q_preds)
|
||||
|
||||
self.optimizers["critic"].zero_grad()
|
||||
loss.backward()
|
||||
torch.nn.utils.clip_grad_norm_(self.critics.parameters(), max_norm=self.config.clip_grad_norm)
|
||||
self.optimizers["critic"].step()
|
||||
return loss.item()
|
||||
|
||||
def _actor_step(self, fb: dict[str, Any]) -> tuple[float, float, float]:
|
||||
"""Paper Eq. 5: maximize Q while staying near VLA reference.
|
||||
|
||||
L_pi(theta) = E[ -Q(x, a) + beta * ||a - a_tilde||^2 ]
|
||||
With reference-action dropout applied to the actor's ref input.
|
||||
"""
|
||||
state = fb["state"]
|
||||
ref = fb.get("reference_action")
|
||||
if ref is None:
|
||||
ref = torch.zeros(state.shape[0], self.policy._action_chunk_dim, device=self._device)
|
||||
|
||||
# Reference-action dropout (paper Section IV-B)
|
||||
mask = (torch.rand(ref.shape[0], 1, device=self._device) > self.config.ref_dropout).float()
|
||||
ref_input = ref * mask
|
||||
|
||||
action = self.policy.actor(state, ref_input)
|
||||
|
||||
q_value = self.critics[0](state, action)
|
||||
|
||||
bc_loss = F.mse_loss(action, ref)
|
||||
|
||||
loss = -q_value.mean() + self.config.bc_reg_coeff * bc_loss
|
||||
|
||||
self.optimizers["actor"].zero_grad()
|
||||
loss.backward()
|
||||
torch.nn.utils.clip_grad_norm_(self.policy.actor.parameters(), max_norm=self.config.clip_grad_norm)
|
||||
self.optimizers["actor"].step()
|
||||
|
||||
return loss.item(), bc_loss.item(), q_value.mean().item()
|
||||
|
||||
def _update_target_networks(self) -> None:
|
||||
tau = self.config.tau
|
||||
for critic, target in zip(self.critics, self.critic_targets, strict=True):
|
||||
for p, tp in zip(critic.parameters(), target.parameters(), strict=True):
|
||||
tp.data.copy_(tau * p.data + (1 - tau) * tp.data)
|
||||
|
||||
# ── Optimizer management ─────────────────────────────────────────
|
||||
|
||||
def make_optimizers(self) -> dict[str, Optimizer]:
|
||||
"""Create optimizers. Initially for RL-token (Stage 1)."""
|
||||
self.optimizers = {
|
||||
"rl_token": torch.optim.Adam(
|
||||
list(self.policy.rl_token_encoder.parameters())
|
||||
+ list(self.policy.rl_token_decoder.parameters()),
|
||||
lr=self.config.rl_token_lr,
|
||||
),
|
||||
"actor": torch.optim.Adam(self.policy.actor.parameters(), lr=self.config.actor_lr),
|
||||
"critic": torch.optim.Adam(self.critics.parameters(), lr=self.config.critic_lr),
|
||||
}
|
||||
return self.optimizers
|
||||
|
||||
def get_optimizers(self) -> dict[str, Optimizer]:
|
||||
return self.optimizers
|
||||
|
||||
# ── Weight sync ──────────────────────────────────────────────────
|
||||
|
||||
def get_weights(self) -> dict[str, Any]:
|
||||
"""Push actor + RL-token encoder to actors (small footprint)."""
|
||||
weights = {
|
||||
"actor": self.policy.actor.state_dict(),
|
||||
"rl_token_encoder": self.policy.rl_token_encoder.state_dict(),
|
||||
}
|
||||
return {k: {kk: vv.cpu() for kk, vv in v.items()} for k, v in weights.items()}
|
||||
|
||||
def load_weights(self, weights: dict[str, Any], device: str | torch.device = "cpu") -> None:
|
||||
if "actor" in weights:
|
||||
self.policy.actor.load_state_dict({k: v.to(device) for k, v in weights["actor"].items()})
|
||||
if "rl_token_encoder" in weights:
|
||||
self.policy.rl_token_encoder.load_state_dict(
|
||||
{k: v.to(device) for k, v in weights["rl_token_encoder"].items()}
|
||||
)
|
||||
@@ -0,0 +1,18 @@
|
||||
# Copyright 2026 The HuggingFace Inc. team. All rights reserved.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
from lerobot.rl.algorithms.sac.configuration_sac import SACAlgorithmConfig
|
||||
from lerobot.rl.algorithms.sac.sac_algorithm import SACAlgorithm
|
||||
|
||||
__all__ = ["SACAlgorithm", "SACAlgorithmConfig"]
|
||||
@@ -0,0 +1,81 @@
|
||||
# Copyright 2026 The HuggingFace Inc. team. All rights reserved.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
"""SAC algorithm configuration."""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
from dataclasses import dataclass, field
|
||||
from typing import TYPE_CHECKING
|
||||
|
||||
import torch
|
||||
|
||||
from lerobot.policies.sac.configuration_sac import CriticNetworkConfig
|
||||
from lerobot.rl.algorithms.base import RLAlgorithmConfig
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from lerobot.rl.algorithms.sac.sac_algorithm import SACAlgorithm
|
||||
|
||||
|
||||
@RLAlgorithmConfig.register_subclass("sac")
|
||||
@dataclass
|
||||
class SACAlgorithmConfig(RLAlgorithmConfig):
|
||||
"""SAC-specific hyper-parameters that control the update loop."""
|
||||
|
||||
utd_ratio: int = 1
|
||||
policy_update_freq: int = 1
|
||||
clip_grad_norm: float = 40.0
|
||||
actor_lr: float = 3e-4
|
||||
critic_lr: float = 3e-4
|
||||
temperature_lr: float = 3e-4
|
||||
discount: float = 0.99
|
||||
temperature_init: float = 1.0
|
||||
target_entropy: float | None = None
|
||||
use_backup_entropy: bool = True
|
||||
critic_target_update_weight: float = 0.005
|
||||
num_critics: int = 2
|
||||
num_subsample_critics: int | None = None
|
||||
num_discrete_actions: int | None = None
|
||||
shared_encoder: bool = True
|
||||
critic_network_kwargs: CriticNetworkConfig = field(default_factory=CriticNetworkConfig)
|
||||
discrete_critic_network_kwargs: CriticNetworkConfig = field(default_factory=CriticNetworkConfig)
|
||||
use_torch_compile: bool = True
|
||||
|
||||
@classmethod
|
||||
def from_policy_config(cls, policy_cfg) -> SACAlgorithmConfig:
|
||||
"""Build from an existing ``SACConfig`` (cfg.policy) for backwards compat."""
|
||||
return cls(
|
||||
utd_ratio=policy_cfg.utd_ratio,
|
||||
policy_update_freq=policy_cfg.policy_update_freq,
|
||||
clip_grad_norm=policy_cfg.grad_clip_norm,
|
||||
actor_lr=policy_cfg.actor_lr,
|
||||
critic_lr=policy_cfg.critic_lr,
|
||||
temperature_lr=policy_cfg.temperature_lr,
|
||||
discount=policy_cfg.discount,
|
||||
temperature_init=policy_cfg.temperature_init,
|
||||
target_entropy=policy_cfg.target_entropy,
|
||||
use_backup_entropy=policy_cfg.use_backup_entropy,
|
||||
critic_target_update_weight=policy_cfg.critic_target_update_weight,
|
||||
num_critics=policy_cfg.num_critics,
|
||||
num_subsample_critics=policy_cfg.num_subsample_critics,
|
||||
num_discrete_actions=policy_cfg.num_discrete_actions,
|
||||
shared_encoder=policy_cfg.shared_encoder,
|
||||
critic_network_kwargs=policy_cfg.critic_network_kwargs,
|
||||
discrete_critic_network_kwargs=policy_cfg.discrete_critic_network_kwargs,
|
||||
use_torch_compile=policy_cfg.use_torch_compile,
|
||||
)
|
||||
|
||||
def build_algorithm(self, policy: torch.nn.Module) -> SACAlgorithm:
|
||||
from lerobot.rl.algorithms.sac.sac_algorithm import SACAlgorithm
|
||||
|
||||
return SACAlgorithm(policy=policy, config=self)
|
||||
@@ -0,0 +1,409 @@
|
||||
# Copyright 2026 The HuggingFace Inc. team. All rights reserved.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
"""SAC (Soft Actor-Critic) algorithm.
|
||||
|
||||
This module encapsulates all SAC-specific training logic (critic, actor,
|
||||
temperature, and discrete-critic updates) behind the ``RLAlgorithm`` interface.
|
||||
"""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import math
|
||||
from collections.abc import Iterator
|
||||
from dataclasses import asdict
|
||||
from typing import Any
|
||||
|
||||
import einops
|
||||
import numpy as np
|
||||
import torch
|
||||
import torch.nn as nn
|
||||
import torch.nn.functional as F # noqa: N812
|
||||
from torch import Tensor
|
||||
from torch.optim import Optimizer
|
||||
|
||||
from lerobot.policies.sac.modeling_sac import (
|
||||
DISCRETE_DIMENSION_INDEX,
|
||||
CriticEnsemble,
|
||||
CriticHead,
|
||||
DiscreteCritic,
|
||||
SACObservationEncoder,
|
||||
SACPolicy,
|
||||
)
|
||||
from lerobot.policies.utils import get_device_from_parameters
|
||||
from lerobot.rl.algorithms.base import (
|
||||
BatchType,
|
||||
RLAlgorithm,
|
||||
TrainingStats,
|
||||
)
|
||||
from lerobot.rl.algorithms.sac.configuration_sac import SACAlgorithmConfig
|
||||
from lerobot.utils.constants import ACTION
|
||||
from lerobot.utils.transition import move_state_dict_to_device
|
||||
|
||||
|
||||
class SACAlgorithm(RLAlgorithm):
|
||||
"""Soft Actor-Critic with optional discrete-critic head.
|
||||
|
||||
Owns the ``SACPolicy`` and its optimizers. All loss methods call
|
||||
``self.policy(batch_dict)`` rather than reaching into ``self.policy.actor``
|
||||
directly, so any policy that returns ``{"action", "log_prob"}`` from its
|
||||
``forward()`` is compatible.
|
||||
"""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
policy: SACPolicy,
|
||||
config: SACAlgorithmConfig,
|
||||
):
|
||||
self.policy = policy
|
||||
self.config = config
|
||||
self.optimizers: dict[str, Optimizer] = {}
|
||||
self._optimization_step: int = 0
|
||||
|
||||
self._device = get_device_from_parameters(self.policy)
|
||||
self._init_critic_encoder()
|
||||
self._init_critics()
|
||||
self._init_temperature()
|
||||
self._move_to_device()
|
||||
|
||||
def _init_critic_encoder(self) -> None:
|
||||
"""Build or share the encoder used by critics."""
|
||||
if self.config.shared_encoder:
|
||||
self.critic_encoder = self.policy.encoder
|
||||
self.policy.actor.encoder_is_shared = True
|
||||
else:
|
||||
self.critic_encoder = SACObservationEncoder(self.policy.config)
|
||||
|
||||
def _init_critics(self) -> None:
|
||||
"""Build critic ensemble, targets, and optional discrete critic."""
|
||||
action_dim = self.policy.config.output_features[ACTION].shape[0]
|
||||
input_dim = self.critic_encoder.output_dim + action_dim
|
||||
|
||||
heads = [
|
||||
CriticHead(input_dim=input_dim, **asdict(self.config.critic_network_kwargs))
|
||||
for _ in range(self.config.num_critics)
|
||||
]
|
||||
self.critic_ensemble = CriticEnsemble(encoder=self.critic_encoder, ensemble=heads)
|
||||
|
||||
target_heads = [
|
||||
CriticHead(input_dim=input_dim, **asdict(self.config.critic_network_kwargs))
|
||||
for _ in range(self.config.num_critics)
|
||||
]
|
||||
self.critic_target = CriticEnsemble(encoder=self.critic_encoder, ensemble=target_heads)
|
||||
self.critic_target.load_state_dict(self.critic_ensemble.state_dict())
|
||||
|
||||
if self.config.use_torch_compile:
|
||||
self.critic_ensemble = torch.compile(self.critic_ensemble)
|
||||
self.critic_target = torch.compile(self.critic_target)
|
||||
|
||||
if self.config.num_discrete_actions is not None:
|
||||
self._init_discrete_critic_target()
|
||||
|
||||
def _init_discrete_critic_target(self) -> None:
|
||||
"""Build only the target discrete critic."""
|
||||
input_dim = self.critic_encoder.output_dim
|
||||
self.discrete_critic_target = DiscreteCritic(
|
||||
encoder=self.critic_encoder,
|
||||
input_dim=input_dim,
|
||||
output_dim=self.config.num_discrete_actions,
|
||||
**asdict(self.config.discrete_critic_network_kwargs),
|
||||
)
|
||||
# TODO: (kmeftah) Compile the discrete critic
|
||||
self.discrete_critic_target.load_state_dict(self.policy.discrete_critic.state_dict())
|
||||
|
||||
def _init_temperature(self) -> None:
|
||||
"""Set up temperature parameter (log_alpha) and default target entropy."""
|
||||
temp_init = self.config.temperature_init
|
||||
self.log_alpha = nn.Parameter(torch.tensor([math.log(temp_init)]))
|
||||
|
||||
action_dim = self.policy.config.output_features[ACTION].shape[0]
|
||||
self.target_entropy = self.config.target_entropy
|
||||
if self.target_entropy is None:
|
||||
dim = action_dim + (1 if self.config.num_discrete_actions is not None else 0)
|
||||
self.target_entropy = -np.prod(dim) / 2
|
||||
|
||||
def _move_to_device(self) -> None:
|
||||
"""Move algorithm-owned modules to the policy device."""
|
||||
self.critic_ensemble.to(self._device)
|
||||
self.critic_target.to(self._device)
|
||||
self.log_alpha = nn.Parameter(self.log_alpha.data.to(self._device))
|
||||
if hasattr(self, "discrete_critic_target"):
|
||||
self.discrete_critic_target.to(self._device)
|
||||
|
||||
@property
|
||||
def temperature(self) -> float:
|
||||
return self.log_alpha.exp().item()
|
||||
|
||||
def update(self, batch_iterator: Iterator[BatchType]) -> TrainingStats:
|
||||
"""Run one full SAC update with UTD critic warm-up.
|
||||
|
||||
Pulls ``utd_ratio`` batches from ``batch_iterator``. The first
|
||||
``utd_ratio - 1`` batches are used for critic-only warm-up steps;
|
||||
the last batch drives the full update (critic + actor + temperature).
|
||||
"""
|
||||
for _ in range(self.config.utd_ratio - 1):
|
||||
batch = next(batch_iterator)
|
||||
forward_batch = self._prepare_forward_batch(batch)
|
||||
|
||||
loss_critic = self._compute_loss_critic(forward_batch)
|
||||
self.optimizers["critic"].zero_grad()
|
||||
loss_critic.backward()
|
||||
torch.nn.utils.clip_grad_norm_(
|
||||
self.critic_ensemble.parameters(),
|
||||
max_norm=self.config.clip_grad_norm,
|
||||
).item()
|
||||
self.optimizers["critic"].step()
|
||||
|
||||
if self.config.num_discrete_actions is not None:
|
||||
loss_discrete = self._compute_loss_discrete_critic(forward_batch)
|
||||
self.optimizers["discrete_critic"].zero_grad()
|
||||
loss_discrete.backward()
|
||||
torch.nn.utils.clip_grad_norm_(
|
||||
self.policy.discrete_critic.parameters(),
|
||||
max_norm=self.config.clip_grad_norm,
|
||||
).item()
|
||||
self.optimizers["discrete_critic"].step()
|
||||
self._update_target_networks()
|
||||
|
||||
batch = next(batch_iterator)
|
||||
forward_batch = self._prepare_forward_batch(batch)
|
||||
|
||||
loss_critic = self._compute_loss_critic(forward_batch)
|
||||
self.optimizers["critic"].zero_grad()
|
||||
loss_critic.backward()
|
||||
critic_grad_norm = torch.nn.utils.clip_grad_norm_(
|
||||
self.critic_ensemble.parameters(),
|
||||
max_norm=self.config.clip_grad_norm,
|
||||
).item()
|
||||
self.optimizers["critic"].step()
|
||||
|
||||
critic_loss_val = loss_critic.item()
|
||||
stats = TrainingStats(
|
||||
losses={"loss_critic": critic_loss_val},
|
||||
grad_norms={"critic": critic_grad_norm},
|
||||
)
|
||||
|
||||
if self.config.num_discrete_actions is not None:
|
||||
loss_discrete = self._compute_loss_discrete_critic(forward_batch)
|
||||
self.optimizers["discrete_critic"].zero_grad()
|
||||
loss_discrete.backward()
|
||||
dc_grad = torch.nn.utils.clip_grad_norm_(
|
||||
self.policy.discrete_critic.parameters(),
|
||||
max_norm=self.config.clip_grad_norm,
|
||||
).item()
|
||||
self.optimizers["discrete_critic"].step()
|
||||
stats.losses["loss_discrete_critic"] = loss_discrete.item()
|
||||
stats.grad_norms["discrete_critic"] = dc_grad
|
||||
|
||||
if self._optimization_step % self.config.policy_update_freq == 0:
|
||||
for _ in range(self.config.policy_update_freq):
|
||||
actor_loss = self._compute_loss_actor(forward_batch)
|
||||
self.optimizers["actor"].zero_grad()
|
||||
actor_loss.backward()
|
||||
actor_grad = torch.nn.utils.clip_grad_norm_(
|
||||
self.policy.actor.parameters(),
|
||||
max_norm=self.config.clip_grad_norm,
|
||||
).item()
|
||||
self.optimizers["actor"].step()
|
||||
|
||||
temp_loss = self._compute_loss_temperature(forward_batch)
|
||||
self.optimizers["temperature"].zero_grad()
|
||||
temp_loss.backward()
|
||||
temp_grad = torch.nn.utils.clip_grad_norm_(
|
||||
[self.log_alpha],
|
||||
max_norm=self.config.clip_grad_norm,
|
||||
).item()
|
||||
self.optimizers["temperature"].step()
|
||||
|
||||
stats.losses["loss_actor"] = actor_loss.item()
|
||||
stats.losses["loss_temperature"] = temp_loss.item()
|
||||
stats.grad_norms["actor"] = actor_grad
|
||||
stats.grad_norms["temperature"] = temp_grad
|
||||
stats.extra["temperature"] = self.temperature
|
||||
|
||||
self._update_target_networks()
|
||||
|
||||
self._optimization_step += 1
|
||||
return stats
|
||||
|
||||
def _compute_loss_critic(self, batch: dict[str, Any]) -> Tensor:
|
||||
observations = batch["state"]
|
||||
actions = batch[ACTION]
|
||||
rewards = batch["reward"]
|
||||
next_observations = batch["next_state"]
|
||||
done = batch["done"]
|
||||
obs_features = batch.get("observation_feature")
|
||||
next_obs_features = batch.get("next_observation_feature")
|
||||
|
||||
with torch.no_grad():
|
||||
next_output = self.policy({"state": next_observations, "observation_feature": next_obs_features})
|
||||
next_actions = next_output["action"]
|
||||
next_log_probs = next_output["log_prob"]
|
||||
|
||||
q_targets = self.critic_target(next_observations, next_actions, next_obs_features)
|
||||
|
||||
if self.config.num_subsample_critics is not None:
|
||||
indices = torch.randperm(self.config.num_critics)
|
||||
indices = indices[: self.config.num_subsample_critics]
|
||||
q_targets = q_targets[indices]
|
||||
|
||||
min_q, _ = q_targets.min(dim=0)
|
||||
if self.config.use_backup_entropy:
|
||||
min_q = min_q - (self.temperature * next_log_probs)
|
||||
|
||||
td_target = rewards + (1 - done) * self.config.discount * min_q
|
||||
|
||||
if self.config.num_discrete_actions is not None:
|
||||
actions = actions[:, :DISCRETE_DIMENSION_INDEX]
|
||||
|
||||
q_preds = self.critic_ensemble(observations, actions, obs_features)
|
||||
|
||||
td_target_dup = einops.repeat(td_target, "b -> e b", e=q_preds.shape[0])
|
||||
critics_loss = (F.mse_loss(input=q_preds, target=td_target_dup, reduction="none").mean(dim=1)).sum()
|
||||
return critics_loss
|
||||
|
||||
def _compute_loss_discrete_critic(self, batch: dict[str, Any]) -> Tensor:
|
||||
observations = batch["state"]
|
||||
actions = batch[ACTION]
|
||||
rewards = batch["reward"]
|
||||
next_observations = batch["next_state"]
|
||||
done = batch["done"]
|
||||
obs_features = batch.get("observation_feature")
|
||||
next_obs_features = batch.get("next_observation_feature")
|
||||
complementary_info = batch.get("complementary_info")
|
||||
|
||||
actions_discrete: Tensor = actions[:, DISCRETE_DIMENSION_INDEX:].clone()
|
||||
actions_discrete = torch.round(actions_discrete).long()
|
||||
|
||||
discrete_penalties: Tensor | None = None
|
||||
if complementary_info is not None:
|
||||
discrete_penalties = complementary_info.get("discrete_penalty")
|
||||
|
||||
with torch.no_grad():
|
||||
next_discrete_qs = self.policy.discrete_critic(next_observations, next_obs_features)
|
||||
best_next_action = torch.argmax(next_discrete_qs, dim=-1, keepdim=True)
|
||||
|
||||
target_next_qs = self.discrete_critic_target(next_observations, next_obs_features)
|
||||
target_next_q = torch.gather(target_next_qs, dim=1, index=best_next_action).squeeze(-1)
|
||||
|
||||
rewards_disc = rewards
|
||||
if discrete_penalties is not None:
|
||||
rewards_disc = rewards + discrete_penalties
|
||||
target_q = rewards_disc + (1 - done) * self.config.discount * target_next_q
|
||||
|
||||
predicted_qs = self.policy.discrete_critic(observations, obs_features)
|
||||
predicted_q = torch.gather(predicted_qs, dim=1, index=actions_discrete).squeeze(-1)
|
||||
|
||||
return F.mse_loss(input=predicted_q, target=target_q)
|
||||
|
||||
def _compute_loss_actor(self, batch: dict[str, Any]) -> Tensor:
|
||||
observations = batch["state"]
|
||||
obs_features = batch.get("observation_feature")
|
||||
|
||||
output = self.policy({"state": observations, "observation_feature": obs_features})
|
||||
actions_pi = output["action"]
|
||||
log_probs = output["log_prob"]
|
||||
|
||||
q_preds = self.critic_ensemble(observations, actions_pi, obs_features)
|
||||
min_q = q_preds.min(dim=0)[0]
|
||||
|
||||
return ((self.temperature * log_probs) - min_q).mean()
|
||||
|
||||
def _compute_loss_temperature(self, batch: dict[str, Any]) -> Tensor:
|
||||
observations = batch["state"]
|
||||
obs_features = batch.get("observation_feature")
|
||||
|
||||
with torch.no_grad():
|
||||
output = self.policy({"state": observations, "observation_feature": obs_features})
|
||||
log_probs = output["log_prob"]
|
||||
|
||||
return (-self.log_alpha.exp() * (log_probs + self.target_entropy)).mean()
|
||||
|
||||
def _update_target_networks(self) -> None:
|
||||
tau = self.config.critic_target_update_weight
|
||||
for target_p, p in zip(
|
||||
self.critic_target.parameters(), self.critic_ensemble.parameters(), strict=True
|
||||
):
|
||||
target_p.data.copy_(p.data * tau + target_p.data * (1.0 - tau))
|
||||
if self.config.num_discrete_actions is not None:
|
||||
for target_p, p in zip(
|
||||
self.discrete_critic_target.parameters(),
|
||||
self.policy.discrete_critic.parameters(),
|
||||
strict=True,
|
||||
):
|
||||
target_p.data.copy_(p.data * tau + target_p.data * (1.0 - tau))
|
||||
|
||||
def _prepare_forward_batch(self, batch: BatchType) -> dict[str, Any]:
|
||||
"""Build the dict expected by loss computation from a sampled batch."""
|
||||
observations = batch["state"]
|
||||
next_observations = batch["next_state"]
|
||||
|
||||
observation_features, next_observation_features = self.get_observation_features(
|
||||
observations, next_observations
|
||||
)
|
||||
forward_batch: dict[str, Any] = {
|
||||
ACTION: batch[ACTION],
|
||||
"reward": batch["reward"],
|
||||
"state": observations,
|
||||
"next_state": next_observations,
|
||||
"done": batch["done"],
|
||||
"observation_feature": observation_features,
|
||||
"next_observation_feature": next_observation_features,
|
||||
}
|
||||
if "complementary_info" in batch:
|
||||
forward_batch["complementary_info"] = batch["complementary_info"]
|
||||
return forward_batch
|
||||
|
||||
def make_optimizers(self) -> dict[str, Optimizer]:
|
||||
"""Create Adam optimizers for the SAC components and store them."""
|
||||
actor_params = [
|
||||
p
|
||||
for n, p in self.policy.actor.named_parameters()
|
||||
if not self.config.shared_encoder or not n.startswith("encoder")
|
||||
]
|
||||
self.optimizers = {
|
||||
"actor": torch.optim.Adam(actor_params, lr=self.config.actor_lr),
|
||||
"critic": torch.optim.Adam(self.critic_ensemble.parameters(), lr=self.config.critic_lr),
|
||||
"temperature": torch.optim.Adam([self.log_alpha], lr=self.config.temperature_lr),
|
||||
}
|
||||
if self.config.num_discrete_actions is not None:
|
||||
self.optimizers["discrete_critic"] = torch.optim.Adam(
|
||||
self.policy.discrete_critic.parameters(), lr=self.config.critic_lr
|
||||
)
|
||||
return self.optimizers
|
||||
|
||||
def get_optimizers(self) -> dict[str, Optimizer]:
|
||||
return self.optimizers
|
||||
|
||||
def get_weights(self) -> dict[str, Any]:
|
||||
"""Policy state-dict to push to actors (includes actor + discrete critic)."""
|
||||
return move_state_dict_to_device(self.policy.state_dict(), device="cpu")
|
||||
|
||||
def load_weights(self, weights: dict[str, Any], device: str | torch.device = "cpu") -> None:
|
||||
"""Load policy state-dict received from the learner."""
|
||||
state = move_state_dict_to_device(weights, device=device)
|
||||
self.policy.load_state_dict(state)
|
||||
|
||||
@torch.no_grad()
|
||||
def get_observation_features(
|
||||
self, observations: Tensor, next_observations: Tensor
|
||||
) -> tuple[Tensor | None, Tensor | None]:
|
||||
if not self.config.shared_encoder:
|
||||
return None, None
|
||||
if self.policy.config.vision_encoder_name is None or not self.policy.config.freeze_vision_encoder:
|
||||
return None, None
|
||||
if not self.policy.encoder.has_images:
|
||||
return None, None
|
||||
observation_features = self.policy.encoder.get_cached_image_features(observations)
|
||||
next_observation_features = self.policy.encoder.get_cached_image_features(next_observations)
|
||||
return observation_features, next_observation_features
|
||||
@@ -0,0 +1,17 @@
|
||||
# Copyright 2026 The HuggingFace Inc. team. All rights reserved.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
from lerobot.rl.data_sources.data_mixer import BatchType, DataMixer, OnlineOfflineMixer
|
||||
|
||||
__all__ = ["BatchType", "DataMixer", "OnlineOfflineMixer"]
|
||||
@@ -0,0 +1,94 @@
|
||||
# Copyright 2026 The HuggingFace Inc. team. All rights reserved.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import abc
|
||||
from typing import Any
|
||||
|
||||
from lerobot.rl.buffer import ReplayBuffer, concatenate_batch_transitions
|
||||
|
||||
BatchType = dict[str, Any]
|
||||
|
||||
|
||||
class DataMixer(abc.ABC):
|
||||
"""Abstract interface for all data mixing strategies.
|
||||
|
||||
Subclasses must implement ``sample(batch_size)`` and may override
|
||||
``get_iterator`` for specialised iteration.
|
||||
"""
|
||||
|
||||
@abc.abstractmethod
|
||||
def sample(self, batch_size: int) -> BatchType:
|
||||
"""Draw one batch of ``batch_size`` transitions."""
|
||||
...
|
||||
|
||||
def get_iterator(
|
||||
self,
|
||||
batch_size: int,
|
||||
async_prefetch: bool = True,
|
||||
queue_size: int = 2,
|
||||
):
|
||||
"""Infinite iterator that yields batches.
|
||||
|
||||
The default implementation repeatedly calls ``self.sample()``.
|
||||
Subclasses with underlying buffer iterators (async prefetch)
|
||||
should override this for better throughput.
|
||||
"""
|
||||
while True:
|
||||
yield self.sample(batch_size)
|
||||
|
||||
|
||||
class OnlineOfflineMixer(DataMixer):
|
||||
"""Mixes transitions from an online and an optional offline replay buffer.
|
||||
|
||||
When both buffers are present, each batch is constructed by sampling
|
||||
``ceil(batch_size * online_ratio)`` from the online buffer and the
|
||||
remainder from the offline buffer, then concatenating.
|
||||
|
||||
This mixer assumes both online and offline buffers are present.
|
||||
"""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
online_buffer: ReplayBuffer,
|
||||
offline_buffer: ReplayBuffer | None = None,
|
||||
online_ratio: float = 1.0,
|
||||
):
|
||||
if not 0.0 <= online_ratio <= 1.0:
|
||||
raise ValueError(f"online_ratio must be in [0, 1], got {online_ratio}")
|
||||
self.online_buffer = online_buffer
|
||||
self.offline_buffer = offline_buffer
|
||||
self.online_ratio = online_ratio
|
||||
|
||||
def sample(self, batch_size: int) -> BatchType:
|
||||
if self.offline_buffer is None:
|
||||
return self.online_buffer.sample(batch_size)
|
||||
|
||||
n_online = max(1, int(batch_size * self.online_ratio))
|
||||
n_offline = batch_size - n_online
|
||||
|
||||
online_batch = self.online_buffer.sample(n_online)
|
||||
offline_batch = self.offline_buffer.sample(n_offline)
|
||||
return concatenate_batch_transitions(online_batch, offline_batch)
|
||||
|
||||
def get_iterator(
|
||||
self,
|
||||
batch_size: int,
|
||||
async_prefetch: bool = True,
|
||||
queue_size: int = 2,
|
||||
):
|
||||
"""Yield batches from online/offline mixed sampling."""
|
||||
while True:
|
||||
yield self.sample(batch_size)
|
||||
+91
-283
@@ -65,9 +65,11 @@ from lerobot.configs.train import TrainRLServerPipelineConfig
|
||||
from lerobot.datasets.factory import make_dataset
|
||||
from lerobot.datasets.lerobot_dataset import LeRobotDataset
|
||||
from lerobot.policies.factory import make_policy
|
||||
from lerobot.policies.sac.modeling_sac import SACPolicy
|
||||
from lerobot.rl.buffer import ReplayBuffer, concatenate_batch_transitions
|
||||
from lerobot.rl.algorithms import make_algorithm
|
||||
from lerobot.rl.buffer import ReplayBuffer
|
||||
from lerobot.rl.data_sources import OnlineOfflineMixer
|
||||
from lerobot.rl.process import ProcessSignalHandler
|
||||
from lerobot.rl.trainer import RLTrainer
|
||||
from lerobot.rl.wandb_utils import WandBLogger
|
||||
from lerobot.robots import so_follower # noqa: F401
|
||||
from lerobot.teleoperators import gamepad, so_leader # noqa: F401
|
||||
@@ -93,7 +95,7 @@ from lerobot.utils.train_utils import (
|
||||
save_checkpoint,
|
||||
update_last_checkpoint,
|
||||
)
|
||||
from lerobot.utils.transition import move_state_dict_to_device, move_transition_to_device
|
||||
from lerobot.utils.transition import move_transition_to_device
|
||||
from lerobot.utils.utils import (
|
||||
format_big_number,
|
||||
get_safe_torch_device,
|
||||
@@ -264,8 +266,8 @@ def add_actor_information_and_train(
|
||||
- Transfers transitions from the actor to the replay buffer.
|
||||
- Logs received interaction messages.
|
||||
- Ensures training begins only when the replay buffer has a sufficient number of transitions.
|
||||
- Samples batches from the replay buffer and performs multiple critic updates.
|
||||
- Periodically updates the actor, critic, and temperature optimizers.
|
||||
- Delegates training updates to an ``RLAlgorithm`` (currently ``SACAlgorithm``).
|
||||
- Periodically pushes updated weights to actors.
|
||||
- Logs training statistics, including loss values and optimization frequency.
|
||||
|
||||
NOTE: This function doesn't have a single responsibility, it should be split into multiple functions
|
||||
@@ -284,17 +286,15 @@ def add_actor_information_and_train(
|
||||
# of 7%
|
||||
device = get_safe_torch_device(try_device=cfg.policy.device, log=True)
|
||||
storage_device = get_safe_torch_device(try_device=cfg.policy.storage_device)
|
||||
clip_grad_norm_value = cfg.policy.grad_clip_norm
|
||||
online_step_before_learning = cfg.policy.online_step_before_learning
|
||||
utd_ratio = cfg.policy.utd_ratio
|
||||
fps = cfg.env.fps
|
||||
log_freq = cfg.log_freq
|
||||
save_freq = cfg.save_freq
|
||||
policy_update_freq = cfg.policy.policy_update_freq
|
||||
policy_parameters_push_frequency = cfg.policy.actor_learner_config.policy_parameters_push_frequency
|
||||
saving_checkpoint = cfg.save_checkpoint
|
||||
online_steps = cfg.policy.online_steps
|
||||
async_prefetch = cfg.policy.async_prefetch
|
||||
async_prefetch = cfg.async_prefetch
|
||||
queue_size = cfg.queue_size
|
||||
|
||||
# Initialize logging for multiprocessing
|
||||
if not use_threads(cfg):
|
||||
@@ -306,7 +306,7 @@ def add_actor_information_and_train(
|
||||
|
||||
logging.info("Initializing policy")
|
||||
|
||||
policy: SACPolicy = make_policy(
|
||||
policy = make_policy(
|
||||
cfg=cfg.policy,
|
||||
env_cfg=cfg.env,
|
||||
)
|
||||
@@ -315,19 +315,24 @@ def add_actor_information_and_train(
|
||||
|
||||
policy.train()
|
||||
|
||||
push_actor_policy_to_queue(parameters_queue=parameters_queue, policy=policy)
|
||||
algorithm = make_algorithm(
|
||||
policy=policy,
|
||||
policy_cfg=cfg.policy,
|
||||
algorithm_name=cfg.algorithm,
|
||||
)
|
||||
|
||||
# TODO: Re-enable processor pipeline once refactoring is validated against main
|
||||
preprocessor, postprocessor = None, None
|
||||
|
||||
# Push initial policy weights to actors (same path as periodic push)
|
||||
state_bytes = state_to_bytes(algorithm.get_weights())
|
||||
parameters_queue.put(state_bytes)
|
||||
last_time_policy_pushed = time.time()
|
||||
|
||||
optimizers, lr_scheduler = make_optimizers_and_scheduler(cfg=cfg, policy=policy)
|
||||
|
||||
# If we are resuming, we need to load the training state
|
||||
resume_optimization_step, resume_interaction_step = load_training_state(cfg=cfg, optimizers=optimizers)
|
||||
|
||||
log_training_info(cfg=cfg, policy=policy)
|
||||
|
||||
replay_buffer = initialize_replay_buffer(cfg, device, storage_device)
|
||||
batch_size = cfg.batch_size
|
||||
total_batch_size = cfg.batch_size
|
||||
offline_replay_buffer = None
|
||||
|
||||
if cfg.dataset is not None:
|
||||
@@ -336,20 +341,70 @@ def add_actor_information_and_train(
|
||||
device=device,
|
||||
storage_device=storage_device,
|
||||
)
|
||||
batch_size: int = batch_size // 2 # We will sample from both replay buffer
|
||||
|
||||
# DataMixer: online-only or online/offline 50-50 mix
|
||||
data_mixer = OnlineOfflineMixer(
|
||||
online_buffer=replay_buffer,
|
||||
offline_buffer=offline_replay_buffer,
|
||||
online_ratio=cfg.online_ratio,
|
||||
)
|
||||
# RLTrainer owns the iterator, preprocessor, and creates optimizers.
|
||||
trainer = RLTrainer(
|
||||
algorithm=algorithm,
|
||||
data_mixer=data_mixer,
|
||||
batch_size=total_batch_size,
|
||||
preprocessor=preprocessor,
|
||||
action_dim=cfg.policy.output_features["action"].shape[0],
|
||||
async_prefetch=async_prefetch,
|
||||
queue_size=queue_size,
|
||||
)
|
||||
|
||||
# If we are resuming, we need to load the training state
|
||||
optimizers = algorithm.get_optimizers()
|
||||
resume_optimization_step, resume_interaction_step = load_training_state(cfg=cfg, optimizers=optimizers)
|
||||
|
||||
logging.info("Starting learner thread")
|
||||
interaction_message = None
|
||||
optimization_step = resume_optimization_step if resume_optimization_step is not None else 0
|
||||
algorithm.optimization_step = optimization_step
|
||||
interaction_step_shift = resume_interaction_step if resume_interaction_step is not None else 0
|
||||
|
||||
dataset_repo_id = None
|
||||
if cfg.dataset is not None:
|
||||
dataset_repo_id = cfg.dataset.repo_id
|
||||
|
||||
# Initialize iterators
|
||||
online_iterator = None
|
||||
offline_iterator = None
|
||||
# ── Offline phase (e.g. RLT RL-token training, ConRFT Cal-QL pretraining) ──
|
||||
offline_steps = getattr(cfg.policy, "offline_steps", 0)
|
||||
if algorithm.supports_offline_phase() and offline_steps > 0 and offline_replay_buffer is not None:
|
||||
logging.info(f"[LEARNER] Starting offline phase ({offline_steps} steps)")
|
||||
offline_mixer = OnlineOfflineMixer(
|
||||
online_buffer=offline_replay_buffer,
|
||||
offline_buffer=None,
|
||||
online_ratio=1.0,
|
||||
)
|
||||
offline_iterator = algorithm.configure_data_iterator(
|
||||
data_mixer=offline_mixer,
|
||||
batch_size=total_batch_size,
|
||||
async_prefetch=async_prefetch,
|
||||
queue_size=queue_size,
|
||||
)
|
||||
for step in range(offline_steps):
|
||||
if shutdown_event is not None and shutdown_event.is_set():
|
||||
logging.info("[LEARNER] Shutdown during offline phase. Exiting...")
|
||||
return
|
||||
|
||||
stats = algorithm.offline_update(offline_iterator)
|
||||
|
||||
if step % log_freq == 0:
|
||||
logging.info(f"[LEARNER] Offline step {step}/{offline_steps}: {stats.to_log_dict()}")
|
||||
if wandb_logger:
|
||||
log_dict = stats.to_log_dict()
|
||||
log_dict["offline_step"] = step
|
||||
wandb_logger.log_dict(d=log_dict, mode="train", custom_step_key="offline_step")
|
||||
|
||||
algorithm.transition_to_online()
|
||||
optimizers = algorithm.get_optimizers()
|
||||
logging.info("[LEARNER] Offline phase complete, transitioned to online")
|
||||
|
||||
# NOTE: THIS IS THE MAIN LOOP OF THE LEARNER
|
||||
while True:
|
||||
@@ -380,180 +435,22 @@ def add_actor_information_and_train(
|
||||
if len(replay_buffer) < online_step_before_learning:
|
||||
continue
|
||||
|
||||
if online_iterator is None:
|
||||
online_iterator = replay_buffer.get_iterator(
|
||||
batch_size=batch_size, async_prefetch=async_prefetch, queue_size=2
|
||||
)
|
||||
|
||||
if offline_replay_buffer is not None and offline_iterator is None:
|
||||
offline_iterator = offline_replay_buffer.get_iterator(
|
||||
batch_size=batch_size, async_prefetch=async_prefetch, queue_size=2
|
||||
)
|
||||
|
||||
time_for_one_optimization_step = time.time()
|
||||
for _ in range(utd_ratio - 1):
|
||||
# Sample from the iterators
|
||||
batch = next(online_iterator)
|
||||
|
||||
if dataset_repo_id is not None:
|
||||
batch_offline = next(offline_iterator)
|
||||
batch = concatenate_batch_transitions(
|
||||
left_batch_transitions=batch, right_batch_transition=batch_offline
|
||||
)
|
||||
|
||||
actions = batch[ACTION]
|
||||
rewards = batch["reward"]
|
||||
observations = batch["state"]
|
||||
next_observations = batch["next_state"]
|
||||
done = batch["done"]
|
||||
check_nan_in_transition(observations=observations, actions=actions, next_state=next_observations)
|
||||
|
||||
observation_features, next_observation_features = get_observation_features(
|
||||
policy=policy, observations=observations, next_observations=next_observations
|
||||
)
|
||||
|
||||
# Create a batch dictionary with all required elements for the forward method
|
||||
forward_batch = {
|
||||
ACTION: actions,
|
||||
"reward": rewards,
|
||||
"state": observations,
|
||||
"next_state": next_observations,
|
||||
"done": done,
|
||||
"observation_feature": observation_features,
|
||||
"next_observation_feature": next_observation_features,
|
||||
"complementary_info": batch["complementary_info"],
|
||||
}
|
||||
|
||||
# Use the forward method for critic loss
|
||||
critic_output = policy.forward(forward_batch, model="critic")
|
||||
|
||||
# Main critic optimization
|
||||
loss_critic = critic_output["loss_critic"]
|
||||
optimizers["critic"].zero_grad()
|
||||
loss_critic.backward()
|
||||
critic_grad_norm = torch.nn.utils.clip_grad_norm_(
|
||||
parameters=policy.critic_ensemble.parameters(), max_norm=clip_grad_norm_value
|
||||
)
|
||||
optimizers["critic"].step()
|
||||
|
||||
# Discrete critic optimization (if available)
|
||||
if policy.config.num_discrete_actions is not None:
|
||||
discrete_critic_output = policy.forward(forward_batch, model="discrete_critic")
|
||||
loss_discrete_critic = discrete_critic_output["loss_discrete_critic"]
|
||||
optimizers["discrete_critic"].zero_grad()
|
||||
loss_discrete_critic.backward()
|
||||
discrete_critic_grad_norm = torch.nn.utils.clip_grad_norm_(
|
||||
parameters=policy.discrete_critic.parameters(), max_norm=clip_grad_norm_value
|
||||
)
|
||||
optimizers["discrete_critic"].step()
|
||||
|
||||
# Update target networks (main and discrete)
|
||||
policy.update_target_networks()
|
||||
|
||||
# Sample for the last update in the UTD ratio
|
||||
batch = next(online_iterator)
|
||||
|
||||
if dataset_repo_id is not None:
|
||||
batch_offline = next(offline_iterator)
|
||||
batch = concatenate_batch_transitions(
|
||||
left_batch_transitions=batch, right_batch_transition=batch_offline
|
||||
)
|
||||
|
||||
actions = batch[ACTION]
|
||||
rewards = batch["reward"]
|
||||
observations = batch["state"]
|
||||
next_observations = batch["next_state"]
|
||||
done = batch["done"]
|
||||
|
||||
check_nan_in_transition(observations=observations, actions=actions, next_state=next_observations)
|
||||
|
||||
observation_features, next_observation_features = get_observation_features(
|
||||
policy=policy, observations=observations, next_observations=next_observations
|
||||
)
|
||||
|
||||
# Create a batch dictionary with all required elements for the forward method
|
||||
forward_batch = {
|
||||
ACTION: actions,
|
||||
"reward": rewards,
|
||||
"state": observations,
|
||||
"next_state": next_observations,
|
||||
"done": done,
|
||||
"observation_feature": observation_features,
|
||||
"next_observation_feature": next_observation_features,
|
||||
}
|
||||
|
||||
critic_output = policy.forward(forward_batch, model="critic")
|
||||
|
||||
loss_critic = critic_output["loss_critic"]
|
||||
optimizers["critic"].zero_grad()
|
||||
loss_critic.backward()
|
||||
critic_grad_norm = torch.nn.utils.clip_grad_norm_(
|
||||
parameters=policy.critic_ensemble.parameters(), max_norm=clip_grad_norm_value
|
||||
).item()
|
||||
optimizers["critic"].step()
|
||||
|
||||
# Initialize training info dictionary
|
||||
training_infos = {
|
||||
"loss_critic": loss_critic.item(),
|
||||
"critic_grad_norm": critic_grad_norm,
|
||||
}
|
||||
|
||||
# Discrete critic optimization (if available)
|
||||
if policy.config.num_discrete_actions is not None:
|
||||
discrete_critic_output = policy.forward(forward_batch, model="discrete_critic")
|
||||
loss_discrete_critic = discrete_critic_output["loss_discrete_critic"]
|
||||
optimizers["discrete_critic"].zero_grad()
|
||||
loss_discrete_critic.backward()
|
||||
discrete_critic_grad_norm = torch.nn.utils.clip_grad_norm_(
|
||||
parameters=policy.discrete_critic.parameters(), max_norm=clip_grad_norm_value
|
||||
).item()
|
||||
optimizers["discrete_critic"].step()
|
||||
|
||||
# Add discrete critic info to training info
|
||||
training_infos["loss_discrete_critic"] = loss_discrete_critic.item()
|
||||
training_infos["discrete_critic_grad_norm"] = discrete_critic_grad_norm
|
||||
|
||||
# Actor and temperature optimization (at specified frequency)
|
||||
if optimization_step % policy_update_freq == 0:
|
||||
for _ in range(policy_update_freq):
|
||||
# Actor optimization
|
||||
actor_output = policy.forward(forward_batch, model="actor")
|
||||
loss_actor = actor_output["loss_actor"]
|
||||
optimizers["actor"].zero_grad()
|
||||
loss_actor.backward()
|
||||
actor_grad_norm = torch.nn.utils.clip_grad_norm_(
|
||||
parameters=policy.actor.parameters(), max_norm=clip_grad_norm_value
|
||||
).item()
|
||||
optimizers["actor"].step()
|
||||
|
||||
# Add actor info to training info
|
||||
training_infos["loss_actor"] = loss_actor.item()
|
||||
training_infos["actor_grad_norm"] = actor_grad_norm
|
||||
|
||||
# Temperature optimization
|
||||
temperature_output = policy.forward(forward_batch, model="temperature")
|
||||
loss_temperature = temperature_output["loss_temperature"]
|
||||
optimizers["temperature"].zero_grad()
|
||||
loss_temperature.backward()
|
||||
temp_grad_norm = torch.nn.utils.clip_grad_norm_(
|
||||
parameters=[policy.log_alpha], max_norm=clip_grad_norm_value
|
||||
).item()
|
||||
optimizers["temperature"].step()
|
||||
|
||||
# Add temperature info to training info
|
||||
training_infos["loss_temperature"] = loss_temperature.item()
|
||||
training_infos["temperature_grad_norm"] = temp_grad_norm
|
||||
training_infos["temperature"] = policy.temperature
|
||||
# One training step (trainer owns data_mixer iterator; algorithm owns UTD loop)
|
||||
stats = trainer.training_step()
|
||||
|
||||
# Push policy to actors if needed
|
||||
if time.time() - last_time_policy_pushed > policy_parameters_push_frequency:
|
||||
push_actor_policy_to_queue(parameters_queue=parameters_queue, policy=policy)
|
||||
state_dicts = algorithm.get_weights()
|
||||
state_bytes = state_to_bytes(state_dicts)
|
||||
parameters_queue.put(state_bytes)
|
||||
last_time_policy_pushed = time.time()
|
||||
|
||||
# Update target networks (main and discrete)
|
||||
policy.update_target_networks()
|
||||
training_infos = stats.to_log_dict()
|
||||
|
||||
# Log training metrics at specified intervals
|
||||
optimization_step = algorithm.optimization_step
|
||||
if optimization_step % log_freq == 0:
|
||||
training_infos["replay_buffer_size"] = len(replay_buffer)
|
||||
if offline_replay_buffer is not None:
|
||||
@@ -581,7 +478,6 @@ def add_actor_information_and_train(
|
||||
custom_step_key="Optimization step",
|
||||
)
|
||||
|
||||
optimization_step += 1
|
||||
if optimization_step % log_freq == 0:
|
||||
logging.info(f"[LEARNER] Number of optimization step: {optimization_step}")
|
||||
|
||||
@@ -598,6 +494,8 @@ def add_actor_information_and_train(
|
||||
offline_replay_buffer=offline_replay_buffer,
|
||||
dataset_repo_id=dataset_repo_id,
|
||||
fps=fps,
|
||||
preprocessor=preprocessor,
|
||||
postprocessor=postprocessor,
|
||||
)
|
||||
|
||||
|
||||
@@ -682,6 +580,8 @@ def save_training_checkpoint(
|
||||
offline_replay_buffer: ReplayBuffer | None = None,
|
||||
dataset_repo_id: str | None = None,
|
||||
fps: int = 30,
|
||||
preprocessor=None,
|
||||
postprocessor=None,
|
||||
) -> None:
|
||||
"""
|
||||
Save training checkpoint and associated data.
|
||||
@@ -705,6 +605,8 @@ def save_training_checkpoint(
|
||||
offline_replay_buffer: Optional offline replay buffer to save
|
||||
dataset_repo_id: Repository ID for dataset
|
||||
fps: Frames per second for dataset
|
||||
preprocessor: Optional preprocessor pipeline to save
|
||||
postprocessor: Optional postprocessor pipeline to save
|
||||
"""
|
||||
logging.info(f"Checkpoint policy after step {optimization_step}")
|
||||
_num_digits = max(6, len(str(online_steps)))
|
||||
@@ -721,6 +623,8 @@ def save_training_checkpoint(
|
||||
policy=policy,
|
||||
optimizer=optimizers,
|
||||
scheduler=None,
|
||||
preprocessor=preprocessor,
|
||||
postprocessor=postprocessor,
|
||||
)
|
||||
|
||||
# Save interaction step manually
|
||||
@@ -758,58 +662,6 @@ def save_training_checkpoint(
|
||||
logging.info("Resume training")
|
||||
|
||||
|
||||
def make_optimizers_and_scheduler(cfg: TrainRLServerPipelineConfig, policy: nn.Module):
|
||||
"""
|
||||
Creates and returns optimizers for the actor, critic, and temperature components of a reinforcement learning policy.
|
||||
|
||||
This function sets up Adam optimizers for:
|
||||
- The **actor network**, ensuring that only relevant parameters are optimized.
|
||||
- The **critic ensemble**, which evaluates the value function.
|
||||
- The **temperature parameter**, which controls the entropy in soft actor-critic (SAC)-like methods.
|
||||
|
||||
It also initializes a learning rate scheduler, though currently, it is set to `None`.
|
||||
|
||||
NOTE:
|
||||
- If the encoder is shared, its parameters are excluded from the actor's optimization process.
|
||||
- The policy's log temperature (`log_alpha`) is wrapped in a list to ensure proper optimization as a standalone tensor.
|
||||
|
||||
Args:
|
||||
cfg: Configuration object containing hyperparameters.
|
||||
policy (nn.Module): The policy model containing the actor, critic, and temperature components.
|
||||
|
||||
Returns:
|
||||
Tuple[Dict[str, torch.optim.Optimizer], Optional[torch.optim.lr_scheduler._LRScheduler]]:
|
||||
A tuple containing:
|
||||
- `optimizers`: A dictionary mapping component names ("actor", "critic", "temperature") to their respective Adam optimizers.
|
||||
- `lr_scheduler`: Currently set to `None` but can be extended to support learning rate scheduling.
|
||||
|
||||
"""
|
||||
optimizer_actor = torch.optim.Adam(
|
||||
params=[
|
||||
p
|
||||
for n, p in policy.actor.named_parameters()
|
||||
if not policy.config.shared_encoder or not n.startswith("encoder")
|
||||
],
|
||||
lr=cfg.policy.actor_lr,
|
||||
)
|
||||
optimizer_critic = torch.optim.Adam(params=policy.critic_ensemble.parameters(), lr=cfg.policy.critic_lr)
|
||||
|
||||
if cfg.policy.num_discrete_actions is not None:
|
||||
optimizer_discrete_critic = torch.optim.Adam(
|
||||
params=policy.discrete_critic.parameters(), lr=cfg.policy.critic_lr
|
||||
)
|
||||
optimizer_temperature = torch.optim.Adam(params=[policy.log_alpha], lr=cfg.policy.critic_lr)
|
||||
lr_scheduler = None
|
||||
optimizers = {
|
||||
"actor": optimizer_actor,
|
||||
"critic": optimizer_critic,
|
||||
"temperature": optimizer_temperature,
|
||||
}
|
||||
if cfg.policy.num_discrete_actions is not None:
|
||||
optimizers["discrete_critic"] = optimizer_discrete_critic
|
||||
return optimizers, lr_scheduler
|
||||
|
||||
|
||||
# Training setup functions
|
||||
|
||||
|
||||
@@ -1014,33 +866,6 @@ def initialize_offline_replay_buffer(
|
||||
# Utilities/Helpers functions
|
||||
|
||||
|
||||
def get_observation_features(
|
||||
policy: SACPolicy, observations: torch.Tensor, next_observations: torch.Tensor
|
||||
) -> tuple[torch.Tensor | None, torch.Tensor | None]:
|
||||
"""
|
||||
Get observation features from the policy encoder. It act as cache for the observation features.
|
||||
when the encoder is frozen, the observation features are not updated.
|
||||
We can save compute by caching the observation features.
|
||||
|
||||
Args:
|
||||
policy: The policy model
|
||||
observations: The current observations
|
||||
next_observations: The next observations
|
||||
|
||||
Returns:
|
||||
tuple: observation_features, next_observation_features
|
||||
"""
|
||||
|
||||
if policy.config.vision_encoder_name is None or not policy.config.freeze_vision_encoder:
|
||||
return None, None
|
||||
|
||||
with torch.no_grad():
|
||||
observation_features = policy.actor.encoder.get_cached_image_features(observations)
|
||||
next_observation_features = policy.actor.encoder.get_cached_image_features(next_observations)
|
||||
|
||||
return observation_features, next_observation_features
|
||||
|
||||
|
||||
def use_threads(cfg: TrainRLServerPipelineConfig) -> bool:
|
||||
return cfg.policy.concurrency.learner == "threads"
|
||||
|
||||
@@ -1091,23 +916,6 @@ def check_nan_in_transition(
|
||||
return nan_detected
|
||||
|
||||
|
||||
def push_actor_policy_to_queue(parameters_queue: Queue, policy: nn.Module):
|
||||
logging.debug("[LEARNER] Pushing actor policy to the queue")
|
||||
|
||||
# Create a dictionary to hold all the state dicts
|
||||
state_dicts = {"policy": move_state_dict_to_device(policy.actor.state_dict(), device="cpu")}
|
||||
|
||||
# Add discrete critic if it exists
|
||||
if hasattr(policy, "discrete_critic") and policy.discrete_critic is not None:
|
||||
state_dicts["discrete_critic"] = move_state_dict_to_device(
|
||||
policy.discrete_critic.state_dict(), device="cpu"
|
||||
)
|
||||
logging.debug("[LEARNER] Including discrete critic in state dict push")
|
||||
|
||||
state_bytes = state_to_bytes(state_dicts)
|
||||
parameters_queue.put(state_bytes)
|
||||
|
||||
|
||||
def process_interaction_message(
|
||||
message, interaction_step_shift: int, wandb_logger: WandBLogger | None = None
|
||||
):
|
||||
|
||||
@@ -0,0 +1,132 @@
|
||||
# Copyright 2026 The HuggingFace Inc. team. All rights reserved.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
from collections.abc import Iterator
|
||||
from typing import Any
|
||||
|
||||
import torch
|
||||
|
||||
from lerobot.rl.algorithms.base import (
|
||||
BatchType,
|
||||
RLAlgorithm,
|
||||
TrainingStats,
|
||||
)
|
||||
from lerobot.rl.data_sources.data_mixer import DataMixer
|
||||
from lerobot.utils.constants import ACTION
|
||||
|
||||
|
||||
def preprocess_rl_batch(preprocessor: Any, batch: BatchType, *, action_dim: int | None = None) -> BatchType:
|
||||
"""Apply a policy preprocessor to an RL batch."""
|
||||
observations = batch["state"]
|
||||
next_observations = batch["next_state"]
|
||||
actions = batch[ACTION]
|
||||
|
||||
extra_action = None
|
||||
if action_dim is not None and actions.shape[-1] > action_dim:
|
||||
extra_action = actions[..., action_dim:]
|
||||
actions = actions[..., :action_dim]
|
||||
|
||||
obs_action = {**observations, ACTION: actions}
|
||||
obs_action = preprocessor(obs_action)
|
||||
batch["state"] = {k: v for k, v in obs_action.items() if k.startswith("observation.")}
|
||||
batch[ACTION] = obs_action[ACTION]
|
||||
|
||||
if extra_action is not None:
|
||||
batch[ACTION] = torch.cat([batch[ACTION], extra_action], dim=-1)
|
||||
|
||||
next_obs = {**next_observations}
|
||||
next_obs = preprocessor(next_obs)
|
||||
batch["next_state"] = {k: v for k, v in next_obs.items() if k.startswith("observation.")}
|
||||
|
||||
return batch
|
||||
|
||||
|
||||
class _PreprocessedIterator:
|
||||
"""Iterator wrapper that preprocesses each sampled RL batch."""
|
||||
|
||||
__slots__ = ("_raw", "_preprocessor", "_action_dim")
|
||||
|
||||
def __init__(
|
||||
self, raw_iterator: Iterator[BatchType], preprocessor: Any, action_dim: int | None = None
|
||||
) -> None:
|
||||
self._raw = raw_iterator
|
||||
self._preprocessor = preprocessor
|
||||
self._action_dim = action_dim
|
||||
|
||||
def __iter__(self) -> _PreprocessedIterator:
|
||||
return self
|
||||
|
||||
def __next__(self) -> BatchType:
|
||||
batch = next(self._raw)
|
||||
return preprocess_rl_batch(self._preprocessor, batch, action_dim=self._action_dim)
|
||||
|
||||
|
||||
class RLTrainer:
|
||||
"""Unified training step orchestrator.
|
||||
|
||||
Holds the algorithm, a DataMixer, and an optional preprocessor.
|
||||
"""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
algorithm: RLAlgorithm,
|
||||
data_mixer: DataMixer,
|
||||
batch_size: int,
|
||||
*,
|
||||
preprocessor: Any | None = None,
|
||||
action_dim: int | None = None,
|
||||
async_prefetch: bool = True,
|
||||
queue_size: int = 2,
|
||||
):
|
||||
self.algorithm = algorithm
|
||||
self.data_mixer = data_mixer
|
||||
self.batch_size = batch_size
|
||||
self._preprocessor = preprocessor
|
||||
self._action_dim = action_dim
|
||||
self.async_prefetch = async_prefetch
|
||||
self.queue_size = queue_size
|
||||
|
||||
self._iterator: Iterator[BatchType] | None = None
|
||||
|
||||
self.algorithm.make_optimizers()
|
||||
|
||||
def _build_data_iterator(self) -> Iterator[BatchType]:
|
||||
"""Create a fresh algorithm-configured iterator (optionally preprocessed)."""
|
||||
raw = self.algorithm.configure_data_iterator(
|
||||
data_mixer=self.data_mixer,
|
||||
batch_size=self.batch_size,
|
||||
async_prefetch=self.async_prefetch,
|
||||
queue_size=self.queue_size,
|
||||
)
|
||||
if self._preprocessor is not None:
|
||||
return _PreprocessedIterator(raw, self._preprocessor, self._action_dim)
|
||||
return raw
|
||||
|
||||
def reset_data_iterator(self) -> None:
|
||||
"""Discard the current iterator so it will be rebuilt lazily next step."""
|
||||
self._iterator = None
|
||||
|
||||
def set_data_mixer(self, data_mixer: DataMixer, *, reset: bool = True) -> None:
|
||||
"""Swap the active data mixer, optionally resetting the iterator."""
|
||||
self.data_mixer = data_mixer
|
||||
if reset:
|
||||
self.reset_data_iterator()
|
||||
|
||||
def training_step(self) -> TrainingStats:
|
||||
"""Run one training step (algorithm-agnostic)."""
|
||||
if self._iterator is None:
|
||||
self._iterator = self._build_data_iterator()
|
||||
return self.algorithm.update(self._iterator)
|
||||
@@ -104,28 +104,6 @@ Convert image dataset to video format and push to hub:
|
||||
--operation.type convert_image_to_video \
|
||||
--push_to_hub true
|
||||
|
||||
Trim single episode to keep only frames within timestamp range:
|
||||
python -m lerobot.scripts.lerobot_edit_dataset \
|
||||
--repo_id lerobot/pusht \
|
||||
--new_repo_id lerobot/pusht_trimmed \
|
||||
--operation.type trim_episode \
|
||||
--operation.episode_index 0 \
|
||||
--operation.start_timestamp 10.0 \
|
||||
--operation.end_timestamp 30.0
|
||||
|
||||
Trim multiple episodes at once (use null for no limit):
|
||||
python -m lerobot.scripts.lerobot_edit_dataset \
|
||||
--repo_id lerobot/pusht \
|
||||
--operation.type trim_episode \
|
||||
--operation.episode_trims '{"0": [10.0, 30.0], "2": [5.0, null], "3": [null, 20.0]}'
|
||||
|
||||
Trim and re-upload to same repo (overwrites original):
|
||||
python -m lerobot.scripts.lerobot_edit_dataset \
|
||||
--repo_id lerobot/pusht \
|
||||
--operation.type trim_episode \
|
||||
--operation.episode_index 0 \
|
||||
--operation.start_timestamp 10.0 \
|
||||
--push_to_hub true
|
||||
Show dataset information:
|
||||
lerobot-edit-dataset \
|
||||
--repo_id lerobot/pusht_image \
|
||||
@@ -226,32 +204,9 @@ class InfoConfig(OperationConfig):
|
||||
show_features: bool = False
|
||||
|
||||
|
||||
@dataclass
|
||||
class TrimEpisodeConfig:
|
||||
"""Trim episodes to keep only frames within timestamp ranges.
|
||||
|
||||
Supports multiple episodes via episode_trims dict:
|
||||
--operation.episode_trims '{"0": [10.0, 30.0], "2": [5.0, 20.0]}'
|
||||
|
||||
Or single episode via legacy parameters:
|
||||
--operation.episode_index 0 --operation.start_timestamp 10.0 --operation.end_timestamp 30.0
|
||||
"""
|
||||
type: str = "trim_episode"
|
||||
# Multi-episode support: dict mapping episode_index -> [start_timestamp, end_timestamp]
|
||||
# Use null for no limit, e.g. {"0": [10.0, null], "2": [null, 30.0]}
|
||||
episode_trims: dict[str, list[float | None]] | None = None
|
||||
# Legacy single-episode parameters (used if episode_trims is None)
|
||||
episode_index: int | None = None
|
||||
start_timestamp: float | None = None # Keep frames from this timestamp (inclusive)
|
||||
end_timestamp: float | None = None # Keep frames until this timestamp (inclusive)
|
||||
|
||||
|
||||
@dataclass
|
||||
class EditDatasetConfig:
|
||||
repo_id: str
|
||||
operation: (
|
||||
DeleteEpisodesConfig | SplitConfig | MergeConfig | RemoveFeatureConfig | ConvertImageToVideoConfig | TrimEpisodeConfig
|
||||
)
|
||||
operation: OperationConfig
|
||||
root: str | None = None
|
||||
new_repo_id: str | None = None
|
||||
@@ -396,92 +351,6 @@ def handle_remove_feature(cfg: EditDatasetConfig) -> None:
|
||||
LeRobotDataset(output_repo_id, root=output_dir).push_to_hub()
|
||||
|
||||
|
||||
def handle_trim_episode(cfg: EditDatasetConfig) -> None:
|
||||
"""Trim episodes to keep only frames within timestamp ranges."""
|
||||
if not isinstance(cfg.operation, TrimEpisodeConfig):
|
||||
raise ValueError("Operation config must be TrimEpisodeConfig")
|
||||
|
||||
# Parse episode trims - support both multi-episode dict and legacy single episode
|
||||
episode_trims: dict[int, tuple[float | None, float | None]] = {}
|
||||
|
||||
if cfg.operation.episode_trims is not None:
|
||||
# Multi-episode mode
|
||||
for ep_str, ts_range in cfg.operation.episode_trims.items():
|
||||
ep_idx = int(ep_str)
|
||||
start_ts = ts_range[0] if len(ts_range) > 0 else None
|
||||
end_ts = ts_range[1] if len(ts_range) > 1 else None
|
||||
episode_trims[ep_idx] = (start_ts, end_ts)
|
||||
elif cfg.operation.episode_index is not None:
|
||||
# Legacy single-episode mode
|
||||
if cfg.operation.start_timestamp is None and cfg.operation.end_timestamp is None:
|
||||
raise ValueError("At least one of start_timestamp or end_timestamp must be specified")
|
||||
episode_trims[cfg.operation.episode_index] = (
|
||||
cfg.operation.start_timestamp,
|
||||
cfg.operation.end_timestamp,
|
||||
)
|
||||
else:
|
||||
raise ValueError("Either episode_trims or episode_index must be specified")
|
||||
|
||||
dataset = LeRobotDataset(cfg.repo_id, root=cfg.root)
|
||||
output_repo_id, output_dir = get_output_path(
|
||||
cfg.repo_id, cfg.new_repo_id, Path(cfg.root) if cfg.root else None
|
||||
)
|
||||
|
||||
if cfg.new_repo_id is None:
|
||||
dataset.root = Path(str(dataset.root) + "_old")
|
||||
|
||||
logging.info(f"Trimming {len(episode_trims)} episode(s) from {cfg.repo_id}")
|
||||
|
||||
# Get episode boundaries and find frames to keep for each episode
|
||||
episodes_info = dataset.meta.episodes
|
||||
all_frames_to_keep: dict[int, list[int]] = {}
|
||||
|
||||
for ep_idx, (start_ts, end_ts) in episode_trims.items():
|
||||
if ep_idx >= len(episodes_info["episode_index"]):
|
||||
raise ValueError(f"Episode {ep_idx} does not exist (dataset has {len(episodes_info['episode_index'])} episodes)")
|
||||
|
||||
from_frame = episodes_info["dataset_from_index"][ep_idx]
|
||||
to_frame = episodes_info["dataset_to_index"][ep_idx]
|
||||
|
||||
logging.info(f"Episode {ep_idx}: trimming to [{start_ts}, {end_ts}]")
|
||||
logging.info(f" Original frames: {from_frame} to {to_frame} ({to_frame - from_frame} frames)")
|
||||
|
||||
# Find frames within timestamp range
|
||||
frames_to_keep = []
|
||||
for frame_idx in range(from_frame, to_frame):
|
||||
frame = dataset.hf_dataset[frame_idx]
|
||||
ts = frame["timestamp"]
|
||||
|
||||
in_range = True
|
||||
if start_ts is not None and ts < start_ts:
|
||||
in_range = False
|
||||
if end_ts is not None and ts > end_ts:
|
||||
in_range = False
|
||||
|
||||
if in_range:
|
||||
frames_to_keep.append(frame_idx)
|
||||
|
||||
if not frames_to_keep:
|
||||
raise ValueError(f"Episode {ep_idx}: No frames found in timestamp range [{start_ts}, {end_ts}]")
|
||||
|
||||
logging.info(f" Keeping {len(frames_to_keep)} frames (indices {frames_to_keep[0]} to {frames_to_keep[-1]})")
|
||||
all_frames_to_keep[ep_idx] = frames_to_keep
|
||||
|
||||
from lerobot.datasets.dataset_tools import trim_episodes_by_frames
|
||||
|
||||
new_dataset = trim_episodes_by_frames(
|
||||
dataset,
|
||||
episode_frames_to_keep=all_frames_to_keep,
|
||||
output_dir=output_dir,
|
||||
repo_id=output_repo_id,
|
||||
)
|
||||
|
||||
logging.info(f"Dataset saved to {output_dir}")
|
||||
logging.info(f"Episodes: {new_dataset.meta.total_episodes}, Frames: {new_dataset.meta.total_frames}")
|
||||
|
||||
if cfg.push_to_hub:
|
||||
logging.info(f"Pushing to hub as {output_repo_id}")
|
||||
LeRobotDataset(output_repo_id, root=output_dir).push_to_hub()
|
||||
def handle_modify_tasks(cfg: EditDatasetConfig) -> None:
|
||||
if not isinstance(cfg.operation, ModifyTasksConfig):
|
||||
raise ValueError("Operation config must be ModifyTasksConfig")
|
||||
@@ -646,8 +515,6 @@ def edit_dataset(cfg: EditDatasetConfig) -> None:
|
||||
handle_modify_tasks(cfg)
|
||||
elif operation_type == "convert_image_to_video":
|
||||
handle_convert_image_to_video(cfg)
|
||||
elif operation_type == "trim_episode":
|
||||
handle_trim_episode(cfg)
|
||||
elif operation_type == "info":
|
||||
handle_info(cfg)
|
||||
else:
|
||||
|
||||
@@ -26,10 +26,8 @@ lerobot-record \
|
||||
--dataset.repo_id=<my_username>/<my_dataset_name> \
|
||||
--dataset.num_episodes=2 \
|
||||
--dataset.single_task="Grab the cube" \
|
||||
--dataset.streaming_encoding=true \
|
||||
--dataset.encoder_threads=2 \
|
||||
--display_data=true
|
||||
# <- Optional: specify video codec (auto, h264, hevc, libsvtav1). Default is libsvtav1. \
|
||||
# <- Optional: specify video codec (h264, hevc, libsvtav1). Default is libsvtav1. \
|
||||
# --dataset.vcodec=h264 \
|
||||
# <- Teleop optional if you want to teleoperate to record or in between episodes with a policy \
|
||||
# --teleop.type=so100_leader \
|
||||
@@ -60,10 +58,7 @@ lerobot-record \
|
||||
--display_data=true \
|
||||
--dataset.repo_id=${HF_USER}/bimanual-so-handover-cube \
|
||||
--dataset.num_episodes=25 \
|
||||
--dataset.single_task="Grab and handover the red cube to the other arm" \
|
||||
--dataset.streaming_encoding=true \
|
||||
# --dataset.vcodec=auto \
|
||||
--dataset.encoder_threads=2
|
||||
--dataset.single_task="Grab and handover the red cube to the other arm"
|
||||
```
|
||||
"""
|
||||
|
||||
@@ -184,19 +179,9 @@ class DatasetRecordConfig:
|
||||
# Number of episodes to record before batch encoding videos
|
||||
# Set to 1 for immediate encoding (default behavior), or higher for batched encoding
|
||||
video_encoding_batch_size: int = 1
|
||||
# Video codec for encoding videos. Options: 'h264', 'hevc', 'libsvtav1', 'auto',
|
||||
# or hardware-specific: 'h264_videotoolbox', 'h264_nvenc', 'h264_vaapi', 'h264_qsv'.
|
||||
# Use 'auto' to auto-detect the best available hardware encoder.
|
||||
# Video codec for encoding videos. Options: 'h264', 'hevc', 'libsvtav1'.
|
||||
# Use 'h264' for faster encoding on systems where AV1 encoding is CPU-heavy.
|
||||
vcodec: str = "libsvtav1"
|
||||
# Enable streaming video encoding: encode frames in real-time during capture instead
|
||||
# of writing PNG images first. Makes save_episode() near-instant. More info in the documentation: https://huggingface.co/docs/lerobot/streaming_video_encoding
|
||||
streaming_encoding: bool = False
|
||||
# Maximum number of frames to buffer per camera when using streaming encoding.
|
||||
# ~1s buffer at 30fps. Provides backpressure if the encoder can't keep up.
|
||||
encoder_queue_maxsize: int = 30
|
||||
# Number of threads per encoder instance. None = auto (codec default).
|
||||
# Lower values reduce CPU usage, maps to 'lp' (via svtav1-params) for libsvtav1 and 'threads' for h264/hevc..
|
||||
encoder_threads: int | None = None
|
||||
# Rename map for the observation to override the image and state keys
|
||||
rename_map: dict[str, str] = field(default_factory=dict)
|
||||
|
||||
@@ -467,9 +452,6 @@ def record(cfg: RecordConfig) -> LeRobotDataset:
|
||||
root=cfg.dataset.root,
|
||||
batch_encoding_size=cfg.dataset.video_encoding_batch_size,
|
||||
vcodec=cfg.dataset.vcodec,
|
||||
streaming_encoding=cfg.dataset.streaming_encoding,
|
||||
encoder_queue_maxsize=cfg.dataset.encoder_queue_maxsize,
|
||||
encoder_threads=cfg.dataset.encoder_threads,
|
||||
)
|
||||
|
||||
if hasattr(robot, "cameras") and len(robot.cameras) > 0:
|
||||
@@ -492,9 +474,6 @@ def record(cfg: RecordConfig) -> LeRobotDataset:
|
||||
image_writer_threads=cfg.dataset.num_image_writer_threads_per_camera * len(robot.cameras),
|
||||
batch_encoding_size=cfg.dataset.video_encoding_batch_size,
|
||||
vcodec=cfg.dataset.vcodec,
|
||||
streaming_encoding=cfg.dataset.streaming_encoding,
|
||||
encoder_queue_maxsize=cfg.dataset.encoder_queue_maxsize,
|
||||
encoder_threads=cfg.dataset.encoder_threads,
|
||||
)
|
||||
|
||||
# Load pretrained policy
|
||||
@@ -518,11 +497,6 @@ def record(cfg: RecordConfig) -> LeRobotDataset:
|
||||
|
||||
listener, events = init_keyboard_listener()
|
||||
|
||||
if not cfg.dataset.streaming_encoding:
|
||||
logging.info(
|
||||
"Streaming encoding is disabled. If you have capable hardware, consider enabling it for way faster episode saving. --dataset.streaming_encoding=true --dataset.encoder_threads=2 # --dataset.vcodec=auto. More info in the documentation: https://huggingface.co/docs/lerobot/streaming_video_encoding"
|
||||
)
|
||||
|
||||
with VideoEncodingManager(dataset):
|
||||
recorded_episodes = 0
|
||||
while recorded_episodes < cfg.dataset.num_episodes and not events["stop_recording"]:
|
||||
|
||||
@@ -152,7 +152,6 @@ def test_motor(bus, motor_id: int, timeout: float, use_fd: bool):
|
||||
)
|
||||
try:
|
||||
bus.send(disable_msg)
|
||||
bus.recv(timeout=0.1) # Clear any pending responses
|
||||
except Exception:
|
||||
print(f"Error sending message to motor 0x{motor_id:02X}")
|
||||
|
||||
|
||||
@@ -43,7 +43,6 @@ from lerobot.teleoperators import ( # noqa: F401
|
||||
koch_leader,
|
||||
make_teleoperator_from_config,
|
||||
omx_leader,
|
||||
openarm_mini,
|
||||
so_leader,
|
||||
)
|
||||
|
||||
@@ -52,7 +51,6 @@ COMPATIBLE_DEVICES = [
|
||||
"koch_leader",
|
||||
"omx_follower",
|
||||
"omx_leader",
|
||||
"openarm_mini",
|
||||
"so100_follower",
|
||||
"so100_leader",
|
||||
"so101_follower",
|
||||
|
||||
@@ -24,7 +24,6 @@ import torch
|
||||
from accelerate import Accelerator
|
||||
from termcolor import colored
|
||||
from torch.optim import Optimizer
|
||||
from tqdm import tqdm
|
||||
|
||||
from lerobot.configs import parser
|
||||
from lerobot.configs.train import TrainPipelineConfig
|
||||
@@ -52,7 +51,6 @@ from lerobot.utils.utils import (
|
||||
format_big_number,
|
||||
has_method,
|
||||
init_logging,
|
||||
inside_slurm,
|
||||
)
|
||||
|
||||
|
||||
@@ -392,14 +390,6 @@ def train(cfg: TrainPipelineConfig, accelerator: Accelerator | None = None):
|
||||
)
|
||||
|
||||
if is_main_process:
|
||||
progbar = tqdm(
|
||||
total=cfg.steps - step,
|
||||
desc="Training",
|
||||
unit="step",
|
||||
disable=inside_slurm(),
|
||||
position=0,
|
||||
leave=True,
|
||||
)
|
||||
logging.info(
|
||||
f"Start offline training on a fixed dataset, with effective batch size: {effective_batch_size}"
|
||||
)
|
||||
@@ -424,8 +414,6 @@ def train(cfg: TrainPipelineConfig, accelerator: Accelerator | None = None):
|
||||
# Note: eval and checkpoint happens *after* the `step`th training update has completed, so we
|
||||
# increment `step` here.
|
||||
step += 1
|
||||
if is_main_process:
|
||||
progbar.update(1)
|
||||
train_tracker.step()
|
||||
is_log_step = cfg.log_freq > 0 and step % cfg.log_freq == 0 and is_main_process
|
||||
is_saving_step = step % cfg.save_freq == 0 or step == cfg.steps
|
||||
@@ -519,9 +507,6 @@ def train(cfg: TrainPipelineConfig, accelerator: Accelerator | None = None):
|
||||
|
||||
accelerator.wait_for_everyone()
|
||||
|
||||
if is_main_process:
|
||||
progbar.close()
|
||||
|
||||
if eval_env:
|
||||
close_envs(eval_env)
|
||||
|
||||
|
||||
@@ -1,296 +0,0 @@
|
||||
#!/usr/bin/env python
|
||||
|
||||
# Copyright 2025 The HuggingFace Inc. team. All rights reserved.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
import logging
|
||||
import time
|
||||
from typing import Any
|
||||
|
||||
from lerobot.motors import Motor, MotorCalibration, MotorNormMode
|
||||
from lerobot.motors.feetech import (
|
||||
FeetechMotorsBus,
|
||||
OperatingMode,
|
||||
)
|
||||
from lerobot.processor import RobotAction
|
||||
from lerobot.utils.decorators import check_if_already_connected, check_if_not_connected
|
||||
|
||||
from ..teleoperator import Teleoperator
|
||||
from .config_openarm_mini import OpenArmMiniConfig
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
# Motors whose direction is inverted during readout
|
||||
RIGHT_MOTORS_TO_FLIP = ["joint_1", "joint_2", "joint_3", "joint_4", "joint_5"]
|
||||
LEFT_MOTORS_TO_FLIP = ["joint_1", "joint_3", "joint_4", "joint_5", "joint_6", "joint_7"]
|
||||
|
||||
|
||||
class OpenArmMini(Teleoperator):
|
||||
"""
|
||||
OpenArm Mini Teleoperator with dual Feetech-based arms (8 motors per arm).
|
||||
|
||||
Each arm has 7 joints plus a gripper, using Feetech STS3215 servos.
|
||||
"""
|
||||
|
||||
config_class = OpenArmMiniConfig
|
||||
name = "openarm_mini"
|
||||
|
||||
def __init__(self, config: OpenArmMiniConfig):
|
||||
super().__init__(config)
|
||||
self.config = config
|
||||
|
||||
norm_mode_body = MotorNormMode.DEGREES
|
||||
|
||||
motors_right = {
|
||||
"joint_1": Motor(1, "sts3215", norm_mode_body),
|
||||
"joint_2": Motor(2, "sts3215", norm_mode_body),
|
||||
"joint_3": Motor(3, "sts3215", norm_mode_body),
|
||||
"joint_4": Motor(4, "sts3215", norm_mode_body),
|
||||
"joint_5": Motor(5, "sts3215", norm_mode_body),
|
||||
"joint_6": Motor(6, "sts3215", norm_mode_body),
|
||||
"joint_7": Motor(7, "sts3215", norm_mode_body),
|
||||
"gripper": Motor(8, "sts3215", MotorNormMode.RANGE_0_100),
|
||||
}
|
||||
|
||||
motors_left = {
|
||||
"joint_1": Motor(1, "sts3215", norm_mode_body),
|
||||
"joint_2": Motor(2, "sts3215", norm_mode_body),
|
||||
"joint_3": Motor(3, "sts3215", norm_mode_body),
|
||||
"joint_4": Motor(4, "sts3215", norm_mode_body),
|
||||
"joint_5": Motor(5, "sts3215", norm_mode_body),
|
||||
"joint_6": Motor(6, "sts3215", norm_mode_body),
|
||||
"joint_7": Motor(7, "sts3215", norm_mode_body),
|
||||
"gripper": Motor(8, "sts3215", MotorNormMode.RANGE_0_100),
|
||||
}
|
||||
|
||||
cal_right = {
|
||||
k.replace("right_", ""): v for k, v in (self.calibration or {}).items() if k.startswith("right_")
|
||||
}
|
||||
cal_left = {
|
||||
k.replace("left_", ""): v for k, v in (self.calibration or {}).items() if k.startswith("left_")
|
||||
}
|
||||
|
||||
self.bus_right = FeetechMotorsBus(
|
||||
port=self.config.port_right,
|
||||
motors=motors_right,
|
||||
calibration=cal_right,
|
||||
)
|
||||
|
||||
self.bus_left = FeetechMotorsBus(
|
||||
port=self.config.port_left,
|
||||
motors=motors_left,
|
||||
calibration=cal_left,
|
||||
)
|
||||
|
||||
@property
|
||||
def action_features(self) -> dict[str, type]:
|
||||
features: dict[str, type] = {}
|
||||
for motor in self.bus_right.motors:
|
||||
features[f"right_{motor}.pos"] = float
|
||||
for motor in self.bus_left.motors:
|
||||
features[f"left_{motor}.pos"] = float
|
||||
return features
|
||||
|
||||
@property
|
||||
def feedback_features(self) -> dict[str, type]:
|
||||
return {}
|
||||
|
||||
@property
|
||||
def is_connected(self) -> bool:
|
||||
return self.bus_right.is_connected and self.bus_left.is_connected
|
||||
|
||||
@check_if_already_connected
|
||||
def connect(self, calibrate: bool = True) -> None:
|
||||
logger.info(f"Connecting right arm on {self.config.port_right}...")
|
||||
self.bus_right.connect()
|
||||
logger.info(f"Connecting left arm on {self.config.port_left}...")
|
||||
self.bus_left.connect()
|
||||
|
||||
if calibrate:
|
||||
self.calibrate()
|
||||
|
||||
self.configure()
|
||||
logger.info(f"{self} connected.")
|
||||
|
||||
@property
|
||||
def is_calibrated(self) -> bool:
|
||||
return self.bus_right.is_calibrated and self.bus_left.is_calibrated
|
||||
|
||||
def calibrate(self) -> None:
|
||||
"""
|
||||
Run calibration procedure for OpenArm Mini.
|
||||
|
||||
1. Disable torque
|
||||
2. Ask user to position arms in hanging position with grippers closed
|
||||
3. Set this as zero position via half-turn homing
|
||||
4. Interactive gripper calibration (open/close positions)
|
||||
5. Save calibration
|
||||
"""
|
||||
if self.calibration:
|
||||
user_input = input(
|
||||
f"Press ENTER to use existing calibration for {self.id}, "
|
||||
f"or type 'c' and press ENTER to run new calibration: "
|
||||
)
|
||||
if user_input.strip().lower() != "c":
|
||||
logger.info(f"Using existing calibration for {self.id}")
|
||||
cal_right = {
|
||||
k.replace("right_", ""): v for k, v in self.calibration.items() if k.startswith("right_")
|
||||
}
|
||||
cal_left = {
|
||||
k.replace("left_", ""): v for k, v in self.calibration.items() if k.startswith("left_")
|
||||
}
|
||||
self.bus_right.write_calibration(cal_right)
|
||||
self.bus_left.write_calibration(cal_left)
|
||||
return
|
||||
|
||||
logger.info(f"\nRunning calibration for {self}")
|
||||
|
||||
self._calibrate_arm("right", self.bus_right)
|
||||
self._calibrate_arm("left", self.bus_left)
|
||||
|
||||
self._save_calibration()
|
||||
print(f"\nCalibration complete and saved to {self.calibration_fpath}")
|
||||
|
||||
def _calibrate_arm(self, arm_name: str, bus: FeetechMotorsBus) -> None:
|
||||
"""Calibrate a single arm with Feetech motors."""
|
||||
logger.info(f"\n=== Calibrating {arm_name.upper()} arm ===")
|
||||
|
||||
bus.disable_torque()
|
||||
|
||||
logger.info(f"Setting Phase to 12 for all motors in {arm_name.upper()} arm...")
|
||||
for motor in bus.motors:
|
||||
bus.write("Phase", motor, 12)
|
||||
|
||||
for motor in bus.motors:
|
||||
bus.write("Operating_Mode", motor, OperatingMode.POSITION.value)
|
||||
|
||||
input(
|
||||
f"\nCalibration: Zero Position ({arm_name.upper()} arm)\n"
|
||||
"Position the arm in the following configuration:\n"
|
||||
" - Arm hanging straight down\n"
|
||||
" - Gripper closed\n"
|
||||
"Press ENTER when ready..."
|
||||
)
|
||||
|
||||
homing_offsets = bus.set_half_turn_homings()
|
||||
logger.info(f"{arm_name.capitalize()} arm zero position set.")
|
||||
|
||||
print(f"\nSetting motor ranges for {arm_name.upper()} arm\n")
|
||||
|
||||
if self.calibration is None:
|
||||
self.calibration = {}
|
||||
|
||||
motor_resolution = bus.model_resolution_table[list(bus.motors.values())[0].model]
|
||||
max_res = motor_resolution - 1
|
||||
|
||||
for motor_name, motor in bus.motors.items():
|
||||
prefixed_name = f"{arm_name}_{motor_name}"
|
||||
|
||||
if motor_name == "gripper":
|
||||
input(
|
||||
f"\nGripper Calibration ({arm_name.upper()} arm)\n"
|
||||
f"Step 1: CLOSE the gripper fully\n"
|
||||
f"Press ENTER when gripper is closed..."
|
||||
)
|
||||
closed_pos = bus.read("Present_Position", motor_name, normalize=False)
|
||||
logger.info(f" Gripper closed position recorded: {closed_pos}")
|
||||
|
||||
input("\nStep 2: OPEN the gripper fully\nPress ENTER when gripper is fully open...")
|
||||
open_pos = bus.read("Present_Position", motor_name, normalize=False)
|
||||
logger.info(f" Gripper open position recorded: {open_pos}")
|
||||
|
||||
if closed_pos < open_pos:
|
||||
range_min = int(closed_pos)
|
||||
range_max = int(open_pos)
|
||||
drive_mode = 0
|
||||
else:
|
||||
range_min = int(open_pos)
|
||||
range_max = int(closed_pos)
|
||||
drive_mode = 1
|
||||
|
||||
logger.info(
|
||||
f" {prefixed_name}: range set to [{range_min}, {range_max}] "
|
||||
f"(0=closed, 100=open, drive_mode={drive_mode})"
|
||||
)
|
||||
else:
|
||||
range_min = 0
|
||||
range_max = max_res
|
||||
drive_mode = 0
|
||||
logger.info(f" {prefixed_name}: range set to [0, {max_res}] (full motor range)")
|
||||
|
||||
self.calibration[prefixed_name] = MotorCalibration(
|
||||
id=motor.id,
|
||||
drive_mode=drive_mode,
|
||||
homing_offset=homing_offsets[motor_name],
|
||||
range_min=range_min,
|
||||
range_max=range_max,
|
||||
)
|
||||
|
||||
cal_for_bus = {
|
||||
k.replace(f"{arm_name}_", ""): v
|
||||
for k, v in self.calibration.items()
|
||||
if k.startswith(f"{arm_name}_")
|
||||
}
|
||||
bus.write_calibration(cal_for_bus)
|
||||
|
||||
def configure(self) -> None:
|
||||
self.bus_right.disable_torque()
|
||||
self.bus_right.configure_motors()
|
||||
for motor in self.bus_right.motors:
|
||||
self.bus_right.write("Operating_Mode", motor, OperatingMode.POSITION.value)
|
||||
|
||||
self.bus_left.disable_torque()
|
||||
self.bus_left.configure_motors()
|
||||
for motor in self.bus_left.motors:
|
||||
self.bus_left.write("Operating_Mode", motor, OperatingMode.POSITION.value)
|
||||
|
||||
def setup_motors(self) -> None:
|
||||
print("\nSetting up RIGHT arm motors...")
|
||||
for motor in reversed(self.bus_right.motors):
|
||||
input(f"Connect the controller board to the RIGHT '{motor}' motor only and press enter.")
|
||||
self.bus_right.setup_motor(motor)
|
||||
print(f"RIGHT '{motor}' motor id set to {self.bus_right.motors[motor].id}")
|
||||
|
||||
print("\nSetting up LEFT arm motors...")
|
||||
for motor in reversed(self.bus_left.motors):
|
||||
input(f"Connect the controller board to the LEFT '{motor}' motor only and press enter.")
|
||||
self.bus_left.setup_motor(motor)
|
||||
print(f"LEFT '{motor}' motor id set to {self.bus_left.motors[motor].id}")
|
||||
|
||||
@check_if_not_connected
|
||||
def get_action(self) -> RobotAction:
|
||||
"""Get current action from both arms (read positions from all motors)."""
|
||||
start = time.perf_counter()
|
||||
|
||||
right_positions = self.bus_right.sync_read("Present_Position")
|
||||
left_positions = self.bus_left.sync_read("Present_Position")
|
||||
|
||||
action: dict[str, Any] = {}
|
||||
for motor, val in right_positions.items():
|
||||
action[f"right_{motor}.pos"] = -val if motor in RIGHT_MOTORS_TO_FLIP else val
|
||||
for motor, val in left_positions.items():
|
||||
action[f"left_{motor}.pos"] = -val if motor in LEFT_MOTORS_TO_FLIP else val
|
||||
|
||||
dt_ms = (time.perf_counter() - start) * 1e3
|
||||
logger.debug(f"{self} read action: {dt_ms:.1f}ms")
|
||||
return action
|
||||
|
||||
def send_feedback(self, feedback: dict[str, float]) -> None:
|
||||
raise NotImplementedError("Feedback is not yet implemented for OpenArm Mini.")
|
||||
|
||||
@check_if_not_connected
|
||||
def disconnect(self) -> None:
|
||||
self.bus_right.disconnect()
|
||||
self.bus_left.disconnect()
|
||||
logger.info(f"{self} disconnected.")
|
||||
@@ -95,10 +95,6 @@ def make_teleoperator_from_config(config: TeleoperatorConfig) -> "Teleoperator":
|
||||
from .bi_openarm_leader import BiOpenArmLeader
|
||||
|
||||
return BiOpenArmLeader(config)
|
||||
elif config.type == "openarm_mini":
|
||||
from .openarm_mini import OpenArmMini
|
||||
|
||||
return OpenArmMini(config)
|
||||
else:
|
||||
try:
|
||||
return cast("Teleoperator", make_device_from_device_class(config))
|
||||
|
||||
@@ -189,7 +189,7 @@ def sanity_check_dataset_name(repo_id, policy_cfg):
|
||||
# Check if dataset_name starts with "eval_" but policy is missing
|
||||
if dataset_name.startswith("eval_") and policy_cfg is None:
|
||||
raise ValueError(
|
||||
f"Your dataset name begins with 'eval_' ({dataset_name}), but no policy is provided."
|
||||
f"Your dataset name begins with 'eval_' ({dataset_name}), but no policy is provided ({policy_cfg.type})."
|
||||
)
|
||||
|
||||
# Check if dataset_name does not start with "eval_" but policy is provided
|
||||
|
||||
@@ -95,6 +95,7 @@ def save_checkpoint(
|
||||
optimizer (Optimizer | None, optional): The optimizer to save the state from. Defaults to None.
|
||||
scheduler (LRScheduler | None, optional): The scheduler to save the state from. Defaults to None.
|
||||
preprocessor: The preprocessor/pipeline to save. Defaults to None.
|
||||
postprocessor: The postprocessor/pipeline to save. Defaults to None.
|
||||
"""
|
||||
pretrained_dir = checkpoint_dir / PRETRAINED_MODEL_DIR
|
||||
policy.save_pretrained(pretrained_dir)
|
||||
|
||||
@@ -31,6 +31,7 @@ from lerobot.configs.train import TrainPipelineConfig
|
||||
from lerobot.datasets.factory import make_dataset
|
||||
from lerobot.datasets.image_writer import image_array_to_pil_image
|
||||
from lerobot.datasets.lerobot_dataset import (
|
||||
VALID_VIDEO_CODECS,
|
||||
LeRobotDataset,
|
||||
MultiLeRobotDataset,
|
||||
_encode_video_worker,
|
||||
@@ -44,7 +45,6 @@ from lerobot.datasets.utils import (
|
||||
hf_transform_to_torch,
|
||||
hw_to_dataset_features,
|
||||
)
|
||||
from lerobot.datasets.video_utils import VALID_VIDEO_CODECS
|
||||
from lerobot.envs.factory import make_env_config
|
||||
from lerobot.policies.factory import make_policy_config
|
||||
from lerobot.robots import make_robot_from_config
|
||||
@@ -393,7 +393,7 @@ def test_tmp_mixed_deletion(tmp_path, empty_lerobot_dataset_factory):
|
||||
vid_key: {"dtype": "video", "shape": DUMMY_HWC, "names": ["height", "width", "channels"]},
|
||||
}
|
||||
ds_mixed = empty_lerobot_dataset_factory(
|
||||
root=tmp_path / "mixed", features=features_mixed, batch_encoding_size=2, streaming_encoding=False
|
||||
root=tmp_path / "mixed", features=features_mixed, batch_encoding_size=2
|
||||
)
|
||||
ds_mixed.add_frame(
|
||||
{
|
||||
@@ -1450,10 +1450,7 @@ def test_valid_video_codecs_constant():
|
||||
assert "h264" in VALID_VIDEO_CODECS
|
||||
assert "hevc" in VALID_VIDEO_CODECS
|
||||
assert "libsvtav1" in VALID_VIDEO_CODECS
|
||||
assert "auto" in VALID_VIDEO_CODECS
|
||||
assert "h264_videotoolbox" in VALID_VIDEO_CODECS
|
||||
assert "h264_nvenc" in VALID_VIDEO_CODECS
|
||||
assert len(VALID_VIDEO_CODECS) == 10
|
||||
assert len(VALID_VIDEO_CODECS) == 3
|
||||
|
||||
|
||||
def test_delta_timestamps_with_episodes_filter(tmp_path, empty_lerobot_dataset_factory):
|
||||
|
||||
@@ -1,730 +0,0 @@
|
||||
#!/usr/bin/env python
|
||||
|
||||
# Copyright 2026 The HuggingFace Inc. team. All rights reserved.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
"""Tests for streaming video encoding and hardware-accelerated encoding."""
|
||||
|
||||
import queue
|
||||
import threading
|
||||
from unittest.mock import patch
|
||||
|
||||
import av
|
||||
import numpy as np
|
||||
import pytest
|
||||
|
||||
from lerobot.datasets.video_utils import (
|
||||
VALID_VIDEO_CODECS,
|
||||
StreamingVideoEncoder,
|
||||
_CameraEncoderThread,
|
||||
_get_codec_options,
|
||||
detect_available_hw_encoders,
|
||||
resolve_vcodec,
|
||||
)
|
||||
from lerobot.utils.constants import OBS_IMAGES
|
||||
|
||||
# ─── _get_codec_options tests ───
|
||||
|
||||
|
||||
class TestGetCodecOptions:
|
||||
def test_libsvtav1_defaults(self):
|
||||
opts = _get_codec_options("libsvtav1")
|
||||
assert opts["g"] == "2"
|
||||
assert opts["crf"] == "30"
|
||||
assert opts["preset"] == "12"
|
||||
|
||||
def test_libsvtav1_custom_preset(self):
|
||||
opts = _get_codec_options("libsvtav1", preset=8)
|
||||
assert opts["preset"] == "8"
|
||||
|
||||
def test_h264_options(self):
|
||||
opts = _get_codec_options("h264", g=10, crf=23)
|
||||
assert opts["g"] == "10"
|
||||
assert opts["crf"] == "23"
|
||||
assert "preset" not in opts
|
||||
|
||||
def test_videotoolbox_options(self):
|
||||
opts = _get_codec_options("h264_videotoolbox", g=2, crf=30)
|
||||
assert opts["g"] == "2"
|
||||
# CRF 30 maps to quality = max(1, min(100, 100 - 30*2)) = 40
|
||||
assert opts["q:v"] == "40"
|
||||
assert "crf" not in opts
|
||||
|
||||
def test_nvenc_options(self):
|
||||
opts = _get_codec_options("h264_nvenc", g=2, crf=25)
|
||||
assert opts["rc"] == "constqp"
|
||||
assert opts["qp"] == "25"
|
||||
assert "crf" not in opts
|
||||
# NVENC doesn't support g
|
||||
assert "g" not in opts
|
||||
|
||||
def test_vaapi_options(self):
|
||||
opts = _get_codec_options("h264_vaapi", crf=28)
|
||||
assert opts["qp"] == "28"
|
||||
|
||||
def test_qsv_options(self):
|
||||
opts = _get_codec_options("h264_qsv", crf=25)
|
||||
assert opts["global_quality"] == "25"
|
||||
|
||||
def test_no_g_no_crf(self):
|
||||
opts = _get_codec_options("h264", g=None, crf=None)
|
||||
assert "g" not in opts
|
||||
assert "crf" not in opts
|
||||
|
||||
|
||||
# ─── HW encoder detection tests ───
|
||||
|
||||
|
||||
class TestHWEncoderDetection:
|
||||
def test_detect_available_hw_encoders_returns_list(self):
|
||||
result = detect_available_hw_encoders()
|
||||
assert isinstance(result, list)
|
||||
|
||||
def test_detect_available_hw_encoders_only_valid(self):
|
||||
from lerobot.datasets.video_utils import HW_ENCODERS
|
||||
|
||||
result = detect_available_hw_encoders()
|
||||
for encoder in result:
|
||||
assert encoder in HW_ENCODERS
|
||||
|
||||
def test_resolve_vcodec_passthrough(self):
|
||||
assert resolve_vcodec("libsvtav1") == "libsvtav1"
|
||||
assert resolve_vcodec("h264") == "h264"
|
||||
|
||||
def test_resolve_vcodec_auto_fallback(self):
|
||||
"""When no HW encoders are available, auto should fall back to libsvtav1."""
|
||||
with patch("lerobot.datasets.video_utils.detect_available_hw_encoders", return_value=[]):
|
||||
assert resolve_vcodec("auto") == "libsvtav1"
|
||||
|
||||
def test_resolve_vcodec_auto_picks_hw(self):
|
||||
"""When a HW encoder is available, auto should pick it."""
|
||||
with patch(
|
||||
"lerobot.datasets.video_utils.detect_available_hw_encoders",
|
||||
return_value=["h264_videotoolbox"],
|
||||
):
|
||||
assert resolve_vcodec("auto") == "h264_videotoolbox"
|
||||
|
||||
def test_resolve_vcodec_auto_returns_valid(self):
|
||||
"""Test that resolve_vcodec('auto') returns a known valid codec."""
|
||||
result = resolve_vcodec("auto")
|
||||
assert result in VALID_VIDEO_CODECS
|
||||
|
||||
def test_hw_encoder_names_accepted_in_validation(self):
|
||||
"""Test that HW encoder names pass validation in VALID_VIDEO_CODECS."""
|
||||
assert "auto" in VALID_VIDEO_CODECS
|
||||
assert "h264_videotoolbox" in VALID_VIDEO_CODECS
|
||||
assert "h264_nvenc" in VALID_VIDEO_CODECS
|
||||
|
||||
def test_resolve_vcodec_invalid_raises(self):
|
||||
"""Test that resolve_vcodec raises ValueError for invalid codecs."""
|
||||
with pytest.raises(ValueError, match="Invalid vcodec"):
|
||||
resolve_vcodec("not_a_real_codec")
|
||||
|
||||
|
||||
# ─── _CameraEncoderThread tests ───
|
||||
|
||||
|
||||
class TestCameraEncoderThread:
|
||||
def test_encodes_valid_mp4(self, tmp_path):
|
||||
"""Test that the encoder thread creates a valid MP4 file with correct frame count."""
|
||||
num_frames = 30
|
||||
height, width = 64, 96
|
||||
fps = 30
|
||||
video_path = tmp_path / "test_output" / "test.mp4"
|
||||
|
||||
frame_queue: queue.Queue = queue.Queue(maxsize=60)
|
||||
result_queue: queue.Queue = queue.Queue(maxsize=1)
|
||||
stop_event = threading.Event()
|
||||
|
||||
encoder_thread = _CameraEncoderThread(
|
||||
video_path=video_path,
|
||||
fps=fps,
|
||||
vcodec="libsvtav1",
|
||||
pix_fmt="yuv420p",
|
||||
g=2,
|
||||
crf=30,
|
||||
preset=13,
|
||||
frame_queue=frame_queue,
|
||||
result_queue=result_queue,
|
||||
stop_event=stop_event,
|
||||
)
|
||||
encoder_thread.start()
|
||||
|
||||
# Feed frames (HWC uint8)
|
||||
for _ in range(num_frames):
|
||||
frame = np.random.randint(0, 255, (height, width, 3), dtype=np.uint8)
|
||||
frame_queue.put(frame)
|
||||
|
||||
# Send sentinel
|
||||
frame_queue.put(None)
|
||||
encoder_thread.join(timeout=60)
|
||||
assert not encoder_thread.is_alive()
|
||||
|
||||
# Check result
|
||||
status, data = result_queue.get(timeout=5)
|
||||
assert status == "ok"
|
||||
assert data is not None # Stats should be returned
|
||||
assert "mean" in data
|
||||
assert "std" in data
|
||||
assert "min" in data
|
||||
assert "max" in data
|
||||
assert "count" in data
|
||||
|
||||
# Verify the MP4 file is valid
|
||||
assert video_path.exists()
|
||||
with av.open(str(video_path)) as container:
|
||||
stream = container.streams.video[0]
|
||||
# The frame count should match
|
||||
total_frames = sum(1 for _ in container.decode(stream))
|
||||
assert total_frames == num_frames
|
||||
|
||||
def test_handles_chw_input(self, tmp_path):
|
||||
"""Test that CHW format input is handled correctly."""
|
||||
num_frames = 5
|
||||
fps = 30
|
||||
video_path = tmp_path / "test_chw" / "test.mp4"
|
||||
|
||||
frame_queue: queue.Queue = queue.Queue(maxsize=60)
|
||||
result_queue: queue.Queue = queue.Queue(maxsize=1)
|
||||
stop_event = threading.Event()
|
||||
|
||||
encoder_thread = _CameraEncoderThread(
|
||||
video_path=video_path,
|
||||
fps=fps,
|
||||
vcodec="libsvtav1",
|
||||
pix_fmt="yuv420p",
|
||||
g=2,
|
||||
crf=30,
|
||||
preset=13,
|
||||
frame_queue=frame_queue,
|
||||
result_queue=result_queue,
|
||||
stop_event=stop_event,
|
||||
)
|
||||
encoder_thread.start()
|
||||
|
||||
# Feed CHW frames
|
||||
for _ in range(num_frames):
|
||||
frame = np.random.randint(0, 255, (3, 64, 96), dtype=np.uint8)
|
||||
frame_queue.put(frame)
|
||||
|
||||
frame_queue.put(None)
|
||||
encoder_thread.join(timeout=60)
|
||||
|
||||
status, _ = result_queue.get(timeout=5)
|
||||
assert status == "ok"
|
||||
assert video_path.exists()
|
||||
|
||||
def test_stop_event_cancellation(self, tmp_path):
|
||||
"""Test that setting the stop event causes the thread to exit."""
|
||||
fps = 30
|
||||
video_path = tmp_path / "test_cancel" / "test.mp4"
|
||||
|
||||
frame_queue: queue.Queue = queue.Queue(maxsize=60)
|
||||
result_queue: queue.Queue = queue.Queue(maxsize=1)
|
||||
stop_event = threading.Event()
|
||||
|
||||
encoder_thread = _CameraEncoderThread(
|
||||
video_path=video_path,
|
||||
fps=fps,
|
||||
vcodec="libsvtav1",
|
||||
pix_fmt="yuv420p",
|
||||
g=2,
|
||||
crf=30,
|
||||
preset=13,
|
||||
frame_queue=frame_queue,
|
||||
result_queue=result_queue,
|
||||
stop_event=stop_event,
|
||||
)
|
||||
encoder_thread.start()
|
||||
|
||||
# Feed a few frames
|
||||
for _ in range(3):
|
||||
frame = np.random.randint(0, 255, (64, 96, 3), dtype=np.uint8)
|
||||
frame_queue.put(frame)
|
||||
|
||||
# Signal stop instead of sending sentinel
|
||||
stop_event.set()
|
||||
encoder_thread.join(timeout=10)
|
||||
assert not encoder_thread.is_alive()
|
||||
|
||||
|
||||
# ─── StreamingVideoEncoder tests ───
|
||||
|
||||
|
||||
class TestStreamingVideoEncoder:
|
||||
def test_single_camera_episode(self, tmp_path):
|
||||
"""Test encoding a single camera episode."""
|
||||
encoder = StreamingVideoEncoder(fps=30, vcodec="libsvtav1", pix_fmt="yuv420p", g=2, crf=30, preset=13)
|
||||
|
||||
video_keys = [f"{OBS_IMAGES}.laptop"]
|
||||
encoder.start_episode(video_keys, tmp_path)
|
||||
|
||||
num_frames = 20
|
||||
for _ in range(num_frames):
|
||||
frame = np.random.randint(0, 255, (64, 96, 3), dtype=np.uint8)
|
||||
encoder.feed_frame(f"{OBS_IMAGES}.laptop", frame)
|
||||
|
||||
results = encoder.finish_episode()
|
||||
assert f"{OBS_IMAGES}.laptop" in results
|
||||
|
||||
mp4_path, stats = results[f"{OBS_IMAGES}.laptop"]
|
||||
assert mp4_path.exists()
|
||||
assert stats is not None
|
||||
|
||||
# Verify frame count
|
||||
with av.open(str(mp4_path)) as container:
|
||||
stream = container.streams.video[0]
|
||||
total_frames = sum(1 for _ in container.decode(stream))
|
||||
assert total_frames == num_frames
|
||||
|
||||
encoder.close()
|
||||
|
||||
def test_multi_camera_episode(self, tmp_path):
|
||||
"""Test encoding multiple cameras simultaneously."""
|
||||
encoder = StreamingVideoEncoder(fps=30, vcodec="libsvtav1", pix_fmt="yuv420p", g=2, crf=30)
|
||||
|
||||
video_keys = [f"{OBS_IMAGES}.laptop", f"{OBS_IMAGES}.phone"]
|
||||
encoder.start_episode(video_keys, tmp_path)
|
||||
|
||||
num_frames = 15
|
||||
for _ in range(num_frames):
|
||||
frame0 = np.random.randint(0, 255, (64, 96, 3), dtype=np.uint8)
|
||||
frame1 = np.random.randint(0, 255, (64, 96, 3), dtype=np.uint8)
|
||||
encoder.feed_frame(video_keys[0], frame0)
|
||||
encoder.feed_frame(video_keys[1], frame1)
|
||||
|
||||
results = encoder.finish_episode()
|
||||
|
||||
for key in video_keys:
|
||||
assert key in results
|
||||
mp4_path, stats = results[key]
|
||||
assert mp4_path.exists()
|
||||
assert stats is not None
|
||||
|
||||
encoder.close()
|
||||
|
||||
def test_sequential_episodes(self, tmp_path):
|
||||
"""Test that multiple sequential episodes work correctly."""
|
||||
encoder = StreamingVideoEncoder(fps=30, vcodec="libsvtav1", pix_fmt="yuv420p", g=2, crf=30)
|
||||
video_keys = [f"{OBS_IMAGES}.cam"]
|
||||
|
||||
for ep in range(3):
|
||||
encoder.start_episode(video_keys, tmp_path)
|
||||
num_frames = 10 + ep * 5
|
||||
for _ in range(num_frames):
|
||||
frame = np.random.randint(0, 255, (64, 96, 3), dtype=np.uint8)
|
||||
encoder.feed_frame(f"{OBS_IMAGES}.cam", frame)
|
||||
results = encoder.finish_episode()
|
||||
|
||||
mp4_path, stats = results[f"{OBS_IMAGES}.cam"]
|
||||
assert mp4_path.exists()
|
||||
|
||||
with av.open(str(mp4_path)) as container:
|
||||
stream = container.streams.video[0]
|
||||
total_frames = sum(1 for _ in container.decode(stream))
|
||||
assert total_frames == num_frames
|
||||
|
||||
encoder.close()
|
||||
|
||||
def test_cancel_episode(self, tmp_path):
|
||||
"""Test that canceling an episode cleans up properly."""
|
||||
encoder = StreamingVideoEncoder(fps=30, vcodec="libsvtav1", pix_fmt="yuv420p", g=2, crf=30)
|
||||
video_keys = [f"{OBS_IMAGES}.cam"]
|
||||
|
||||
encoder.start_episode(video_keys, tmp_path)
|
||||
|
||||
for _ in range(5):
|
||||
frame = np.random.randint(0, 255, (64, 96, 3), dtype=np.uint8)
|
||||
encoder.feed_frame(f"{OBS_IMAGES}.cam", frame)
|
||||
|
||||
encoder.cancel_episode()
|
||||
|
||||
# Should be able to start a new episode after cancel
|
||||
encoder.start_episode(video_keys, tmp_path)
|
||||
for _ in range(5):
|
||||
frame = np.random.randint(0, 255, (64, 96, 3), dtype=np.uint8)
|
||||
encoder.feed_frame(f"{OBS_IMAGES}.cam", frame)
|
||||
results = encoder.finish_episode()
|
||||
|
||||
assert f"{OBS_IMAGES}.cam" in results
|
||||
encoder.close()
|
||||
|
||||
def test_feed_without_start_raises(self, tmp_path):
|
||||
"""Test that feeding frames without starting an episode raises."""
|
||||
encoder = StreamingVideoEncoder(fps=30, vcodec="libsvtav1", pix_fmt="yuv420p")
|
||||
with pytest.raises(RuntimeError, match="No active episode"):
|
||||
encoder.feed_frame("cam", np.zeros((64, 96, 3), dtype=np.uint8))
|
||||
encoder.close()
|
||||
|
||||
def test_finish_without_start_raises(self, tmp_path):
|
||||
"""Test that finishing without starting raises."""
|
||||
encoder = StreamingVideoEncoder(fps=30, vcodec="libsvtav1", pix_fmt="yuv420p")
|
||||
with pytest.raises(RuntimeError, match="No active episode"):
|
||||
encoder.finish_episode()
|
||||
encoder.close()
|
||||
|
||||
def test_close_is_idempotent(self, tmp_path):
|
||||
"""Test that close() can be called multiple times safely."""
|
||||
encoder = StreamingVideoEncoder(fps=30, vcodec="libsvtav1", pix_fmt="yuv420p")
|
||||
encoder.close()
|
||||
encoder.close() # Should not raise
|
||||
|
||||
def test_video_duration_matches_frame_count(self, tmp_path):
|
||||
"""Test that encoded video duration matches num_frames / fps."""
|
||||
encoder = StreamingVideoEncoder(fps=30, vcodec="libsvtav1", pix_fmt="yuv420p", g=2, crf=30, preset=13)
|
||||
video_keys = [f"{OBS_IMAGES}.cam"]
|
||||
encoder.start_episode(video_keys, tmp_path)
|
||||
|
||||
num_frames = 90 # 3 seconds at 30fps
|
||||
for _ in range(num_frames):
|
||||
frame = np.random.randint(0, 255, (64, 96, 3), dtype=np.uint8)
|
||||
encoder.feed_frame(f"{OBS_IMAGES}.cam", frame)
|
||||
|
||||
results = encoder.finish_episode()
|
||||
mp4_path, _ = results[f"{OBS_IMAGES}.cam"]
|
||||
|
||||
expected_duration = num_frames / 30.0 # 3.0 seconds
|
||||
|
||||
with av.open(str(mp4_path)) as container:
|
||||
stream = container.streams.video[0]
|
||||
total_frames = sum(1 for _ in container.decode(stream))
|
||||
if stream.duration is not None:
|
||||
actual_duration = float(stream.duration * stream.time_base)
|
||||
else:
|
||||
actual_duration = float(container.duration / av.time_base)
|
||||
|
||||
assert total_frames == num_frames
|
||||
# Allow small tolerance for duration due to codec framing
|
||||
assert abs(actual_duration - expected_duration) < 0.5, (
|
||||
f"Video duration {actual_duration:.2f}s != expected {expected_duration:.2f}s"
|
||||
)
|
||||
|
||||
encoder.close()
|
||||
|
||||
def test_multi_camera_start_episode_called_once(self, tmp_path):
|
||||
"""Test that with multiple cameras, no frames are lost due to double start_episode."""
|
||||
encoder = StreamingVideoEncoder(fps=30, vcodec="libsvtav1", pix_fmt="yuv420p", g=2, crf=30)
|
||||
|
||||
video_keys = [f"{OBS_IMAGES}.cam1", f"{OBS_IMAGES}.cam2"]
|
||||
encoder.start_episode(video_keys, tmp_path)
|
||||
|
||||
num_frames = 30
|
||||
for _ in range(num_frames):
|
||||
frame0 = np.random.randint(0, 255, (64, 96, 3), dtype=np.uint8)
|
||||
frame1 = np.random.randint(0, 255, (64, 96, 3), dtype=np.uint8)
|
||||
encoder.feed_frame(video_keys[0], frame0)
|
||||
encoder.feed_frame(video_keys[1], frame1)
|
||||
|
||||
results = encoder.finish_episode()
|
||||
|
||||
# Both cameras should have all frames
|
||||
for key in video_keys:
|
||||
mp4_path, stats = results[key]
|
||||
assert mp4_path.exists()
|
||||
with av.open(str(mp4_path)) as container:
|
||||
stream = container.streams.video[0]
|
||||
total_frames = sum(1 for _ in container.decode(stream))
|
||||
assert total_frames == num_frames, (
|
||||
f"Camera {key}: expected {num_frames} frames, got {total_frames}"
|
||||
)
|
||||
|
||||
encoder.close()
|
||||
|
||||
def test_encoder_threads_passed_to_thread(self, tmp_path):
|
||||
"""Test that encoder_threads is stored and passed through to encoder threads."""
|
||||
encoder = StreamingVideoEncoder(
|
||||
fps=30, vcodec="libsvtav1", pix_fmt="yuv420p", g=2, crf=30, encoder_threads=2
|
||||
)
|
||||
assert encoder.encoder_threads == 2
|
||||
|
||||
video_keys = [f"{OBS_IMAGES}.cam"]
|
||||
encoder.start_episode(video_keys, tmp_path)
|
||||
|
||||
# Verify the thread received the encoder_threads value
|
||||
thread = encoder._threads[f"{OBS_IMAGES}.cam"]
|
||||
assert thread.encoder_threads == 2
|
||||
|
||||
# Feed some frames and finish to ensure it works end-to-end
|
||||
num_frames = 10
|
||||
for _ in range(num_frames):
|
||||
frame = np.random.randint(0, 255, (64, 96, 3), dtype=np.uint8)
|
||||
encoder.feed_frame(f"{OBS_IMAGES}.cam", frame)
|
||||
|
||||
results = encoder.finish_episode()
|
||||
mp4_path, stats = results[f"{OBS_IMAGES}.cam"]
|
||||
assert mp4_path.exists()
|
||||
assert stats is not None
|
||||
|
||||
with av.open(str(mp4_path)) as container:
|
||||
stream = container.streams.video[0]
|
||||
total_frames = sum(1 for _ in container.decode(stream))
|
||||
assert total_frames == num_frames
|
||||
|
||||
encoder.close()
|
||||
|
||||
def test_encoder_threads_none_by_default(self, tmp_path):
|
||||
"""Test that encoder_threads defaults to None (codec auto-detect)."""
|
||||
encoder = StreamingVideoEncoder(fps=30, vcodec="libsvtav1", pix_fmt="yuv420p")
|
||||
assert encoder.encoder_threads is None
|
||||
encoder.close()
|
||||
|
||||
def test_graceful_frame_dropping(self, tmp_path):
|
||||
"""Test that full queue drops frames instead of crashing."""
|
||||
encoder = StreamingVideoEncoder(
|
||||
fps=30, vcodec="libsvtav1", pix_fmt="yuv420p", g=2, crf=30, preset=13, queue_maxsize=1
|
||||
)
|
||||
video_keys = [f"{OBS_IMAGES}.cam"]
|
||||
encoder.start_episode(video_keys, tmp_path)
|
||||
|
||||
# Feed many frames quickly - with queue_maxsize=1, some will be dropped
|
||||
num_frames = 50
|
||||
for _ in range(num_frames):
|
||||
frame = np.random.randint(0, 255, (64, 96, 3), dtype=np.uint8)
|
||||
encoder.feed_frame(f"{OBS_IMAGES}.cam", frame)
|
||||
|
||||
# Should not raise - frames are dropped gracefully
|
||||
results = encoder.finish_episode()
|
||||
assert f"{OBS_IMAGES}.cam" in results
|
||||
|
||||
mp4_path, _ = results[f"{OBS_IMAGES}.cam"]
|
||||
assert mp4_path.exists()
|
||||
|
||||
# Some frames should have been dropped (queue was tiny)
|
||||
dropped = encoder._dropped_frames.get(f"{OBS_IMAGES}.cam", 0)
|
||||
# We can't guarantee drops but can verify no crash occurred
|
||||
assert dropped >= 0
|
||||
|
||||
encoder.close()
|
||||
|
||||
|
||||
# ─── Integration tests with LeRobotDataset ───
|
||||
|
||||
|
||||
class TestStreamingEncoderIntegration:
|
||||
def test_add_frame_save_episode_streaming(self, tmp_path):
|
||||
"""Full integration test: add_frame -> save_episode with streaming encoding."""
|
||||
from lerobot.datasets.lerobot_dataset import LeRobotDataset
|
||||
|
||||
features = {
|
||||
"observation.images.cam": {
|
||||
"dtype": "video",
|
||||
"shape": (64, 96, 3),
|
||||
"names": ["height", "width", "channels"],
|
||||
},
|
||||
"action": {"dtype": "float32", "shape": (6,), "names": ["j1", "j2", "j3", "j4", "j5", "j6"]},
|
||||
}
|
||||
|
||||
dataset = LeRobotDataset.create(
|
||||
repo_id="test/streaming",
|
||||
fps=30,
|
||||
features=features,
|
||||
root=tmp_path / "streaming_test",
|
||||
use_videos=True,
|
||||
streaming_encoding=True,
|
||||
)
|
||||
|
||||
assert dataset._streaming_encoder is not None
|
||||
|
||||
num_frames = 20
|
||||
for _ in range(num_frames):
|
||||
frame = {
|
||||
"observation.images.cam": np.random.randint(0, 255, (64, 96, 3), dtype=np.uint8),
|
||||
"action": np.random.randn(6).astype(np.float32),
|
||||
"task": "test task",
|
||||
}
|
||||
dataset.add_frame(frame)
|
||||
|
||||
dataset.save_episode()
|
||||
|
||||
# Verify dataset metadata
|
||||
assert dataset.meta.total_episodes == 1
|
||||
assert dataset.meta.total_frames == num_frames
|
||||
|
||||
# Verify stats exist for the video key
|
||||
assert dataset.meta.stats is not None
|
||||
assert "observation.images.cam" in dataset.meta.stats
|
||||
assert "action" in dataset.meta.stats
|
||||
|
||||
dataset.finalize()
|
||||
|
||||
def test_streaming_disabled_creates_pngs(self, tmp_path):
|
||||
"""Test that disabling streaming encoding falls back to PNG path."""
|
||||
from lerobot.datasets.lerobot_dataset import LeRobotDataset
|
||||
|
||||
features = {
|
||||
"observation.images.cam": {
|
||||
"dtype": "video",
|
||||
"shape": (64, 96, 3),
|
||||
"names": ["height", "width", "channels"],
|
||||
},
|
||||
"action": {"dtype": "float32", "shape": (6,), "names": ["j1", "j2", "j3", "j4", "j5", "j6"]},
|
||||
}
|
||||
|
||||
dataset = LeRobotDataset.create(
|
||||
repo_id="test/no_streaming",
|
||||
fps=30,
|
||||
features=features,
|
||||
root=tmp_path / "no_streaming_test",
|
||||
use_videos=True,
|
||||
streaming_encoding=False,
|
||||
)
|
||||
|
||||
assert dataset._streaming_encoder is None
|
||||
|
||||
num_frames = 5
|
||||
for _ in range(num_frames):
|
||||
frame = {
|
||||
"observation.images.cam": np.random.randint(0, 255, (64, 96, 3), dtype=np.uint8),
|
||||
"action": np.random.randn(6).astype(np.float32),
|
||||
"task": "test task",
|
||||
}
|
||||
dataset.add_frame(frame)
|
||||
|
||||
# With streaming disabled, PNG files should be written
|
||||
images_dir = dataset.root / "images"
|
||||
assert images_dir.exists()
|
||||
|
||||
dataset.save_episode()
|
||||
dataset.finalize()
|
||||
|
||||
def test_multi_episode_streaming(self, tmp_path):
|
||||
"""Test recording multiple episodes with streaming encoding."""
|
||||
from lerobot.datasets.lerobot_dataset import LeRobotDataset
|
||||
|
||||
features = {
|
||||
"observation.images.cam": {
|
||||
"dtype": "video",
|
||||
"shape": (64, 96, 3),
|
||||
"names": ["height", "width", "channels"],
|
||||
},
|
||||
"action": {"dtype": "float32", "shape": (2,), "names": ["j1", "j2"]},
|
||||
}
|
||||
|
||||
dataset = LeRobotDataset.create(
|
||||
repo_id="test/multi_ep",
|
||||
fps=30,
|
||||
features=features,
|
||||
root=tmp_path / "multi_ep_test",
|
||||
use_videos=True,
|
||||
streaming_encoding=True,
|
||||
)
|
||||
|
||||
for ep in range(3):
|
||||
num_frames = 10 + ep * 5
|
||||
for _ in range(num_frames):
|
||||
frame = {
|
||||
"observation.images.cam": np.random.randint(0, 255, (64, 96, 3), dtype=np.uint8),
|
||||
"action": np.random.randn(2).astype(np.float32),
|
||||
"task": f"task_{ep}",
|
||||
}
|
||||
dataset.add_frame(frame)
|
||||
dataset.save_episode()
|
||||
|
||||
assert dataset.meta.total_episodes == 3
|
||||
assert dataset.meta.total_frames == 10 + 15 + 20
|
||||
|
||||
dataset.finalize()
|
||||
|
||||
def test_clear_episode_buffer_cancels_streaming(self, tmp_path):
|
||||
"""Test that clearing episode buffer cancels streaming encoding."""
|
||||
from lerobot.datasets.lerobot_dataset import LeRobotDataset
|
||||
|
||||
features = {
|
||||
"observation.images.cam": {
|
||||
"dtype": "video",
|
||||
"shape": (64, 96, 3),
|
||||
"names": ["height", "width", "channels"],
|
||||
},
|
||||
"action": {"dtype": "float32", "shape": (2,), "names": ["j1", "j2"]},
|
||||
}
|
||||
|
||||
dataset = LeRobotDataset.create(
|
||||
repo_id="test/cancel",
|
||||
fps=30,
|
||||
features=features,
|
||||
root=tmp_path / "cancel_test",
|
||||
use_videos=True,
|
||||
streaming_encoding=True,
|
||||
)
|
||||
|
||||
# Add some frames
|
||||
for _ in range(5):
|
||||
frame = {
|
||||
"observation.images.cam": np.random.randint(0, 255, (64, 96, 3), dtype=np.uint8),
|
||||
"action": np.random.randn(2).astype(np.float32),
|
||||
"task": "task",
|
||||
}
|
||||
dataset.add_frame(frame)
|
||||
|
||||
# Cancel and re-record
|
||||
dataset.clear_episode_buffer()
|
||||
|
||||
# Record a new episode
|
||||
for _ in range(10):
|
||||
frame = {
|
||||
"observation.images.cam": np.random.randint(0, 255, (64, 96, 3), dtype=np.uint8),
|
||||
"action": np.random.randn(2).astype(np.float32),
|
||||
"task": "task",
|
||||
}
|
||||
dataset.add_frame(frame)
|
||||
dataset.save_episode()
|
||||
|
||||
assert dataset.meta.total_episodes == 1
|
||||
assert dataset.meta.total_frames == 10
|
||||
|
||||
dataset.finalize()
|
||||
|
||||
def test_multi_camera_add_frame_streaming(self, tmp_path):
|
||||
"""Test that start_episode is called once with multiple video keys."""
|
||||
from lerobot.datasets.lerobot_dataset import LeRobotDataset
|
||||
|
||||
features = {
|
||||
"observation.images.cam1": {
|
||||
"dtype": "video",
|
||||
"shape": (64, 96, 3),
|
||||
"names": ["height", "width", "channels"],
|
||||
},
|
||||
"observation.images.cam2": {
|
||||
"dtype": "video",
|
||||
"shape": (64, 96, 3),
|
||||
"names": ["height", "width", "channels"],
|
||||
},
|
||||
"action": {"dtype": "float32", "shape": (2,), "names": ["j1", "j2"]},
|
||||
}
|
||||
|
||||
dataset = LeRobotDataset.create(
|
||||
repo_id="test/multi_cam",
|
||||
fps=30,
|
||||
features=features,
|
||||
root=tmp_path / "multi_cam_test",
|
||||
use_videos=True,
|
||||
streaming_encoding=True,
|
||||
)
|
||||
|
||||
num_frames = 15
|
||||
for _ in range(num_frames):
|
||||
frame = {
|
||||
"observation.images.cam1": np.random.randint(0, 255, (64, 96, 3), dtype=np.uint8),
|
||||
"observation.images.cam2": np.random.randint(0, 255, (64, 96, 3), dtype=np.uint8),
|
||||
"action": np.random.randn(2).astype(np.float32),
|
||||
"task": "test task",
|
||||
}
|
||||
dataset.add_frame(frame)
|
||||
|
||||
dataset.save_episode()
|
||||
|
||||
assert dataset.meta.total_episodes == 1
|
||||
assert dataset.meta.total_frames == num_frames
|
||||
|
||||
dataset.finalize()
|
||||
+188
-207
@@ -14,8 +14,6 @@
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
import math
|
||||
|
||||
import pytest
|
||||
import torch
|
||||
from torch import Tensor, nn
|
||||
@@ -23,6 +21,7 @@ from torch import Tensor, nn
|
||||
from lerobot.configs.types import FeatureType, PolicyFeature
|
||||
from lerobot.policies.sac.configuration_sac import SACConfig
|
||||
from lerobot.policies.sac.modeling_sac import MLP, SACPolicy
|
||||
from lerobot.rl.algorithms.sac import SACAlgorithm, SACAlgorithmConfig
|
||||
from lerobot.utils.constants import ACTION, OBS_IMAGE, OBS_STATE
|
||||
from lerobot.utils.random_utils import seeded_context, set_seed
|
||||
|
||||
@@ -138,41 +137,6 @@ def create_observation_batch_with_visual_input(batch_size: int = 8, state_dim: i
|
||||
}
|
||||
|
||||
|
||||
def make_optimizers(policy: SACPolicy, has_discrete_action: bool = False) -> dict[str, torch.optim.Optimizer]:
|
||||
"""Create optimizers for the SAC policy."""
|
||||
optimizer_actor = torch.optim.Adam(
|
||||
# Handle the case of shared encoder where the encoder weights are not optimized with the actor gradient
|
||||
params=[
|
||||
p
|
||||
for n, p in policy.actor.named_parameters()
|
||||
if not policy.config.shared_encoder or not n.startswith("encoder")
|
||||
],
|
||||
lr=policy.config.actor_lr,
|
||||
)
|
||||
optimizer_critic = torch.optim.Adam(
|
||||
params=policy.critic_ensemble.parameters(),
|
||||
lr=policy.config.critic_lr,
|
||||
)
|
||||
optimizer_temperature = torch.optim.Adam(
|
||||
params=[policy.log_alpha],
|
||||
lr=policy.config.critic_lr,
|
||||
)
|
||||
|
||||
optimizers = {
|
||||
"actor": optimizer_actor,
|
||||
"critic": optimizer_critic,
|
||||
"temperature": optimizer_temperature,
|
||||
}
|
||||
|
||||
if has_discrete_action:
|
||||
optimizers["discrete_critic"] = torch.optim.Adam(
|
||||
params=policy.discrete_critic.parameters(),
|
||||
lr=policy.config.critic_lr,
|
||||
)
|
||||
|
||||
return optimizers
|
||||
|
||||
|
||||
def create_default_config(
|
||||
state_dim: int, continuous_action_dim: int, has_discrete_action: bool = False
|
||||
) -> SACConfig:
|
||||
@@ -212,7 +176,6 @@ def create_config_with_visual_input(
|
||||
"std": torch.randn(3, 1, 1),
|
||||
}
|
||||
|
||||
# Let make tests a little bit faster
|
||||
config.state_encoder_hidden_dim = 32
|
||||
config.latent_dim = 32
|
||||
|
||||
@@ -220,75 +183,112 @@ def create_config_with_visual_input(
|
||||
return config
|
||||
|
||||
|
||||
@pytest.mark.parametrize("batch_size,state_dim,action_dim", [(2, 6, 6), (1, 10, 10)])
|
||||
def test_sac_policy_with_default_config(batch_size: int, state_dim: int, action_dim: int):
|
||||
batch = create_default_train_batch(batch_size=batch_size, action_dim=action_dim, state_dim=state_dim)
|
||||
config = create_default_config(state_dim=state_dim, continuous_action_dim=action_dim)
|
||||
|
||||
def _make_algorithm(config: SACConfig) -> tuple[SACAlgorithm, SACPolicy]:
|
||||
"""Helper to create policy + algorithm pair for tests that need critics."""
|
||||
policy = SACPolicy(config=config)
|
||||
policy.train()
|
||||
algo_config = SACAlgorithmConfig.from_policy_config(config)
|
||||
algorithm = SACAlgorithm(policy=policy, config=algo_config)
|
||||
algorithm.make_optimizers()
|
||||
return algorithm, policy
|
||||
|
||||
optimizers = make_optimizers(policy)
|
||||
|
||||
cirtic_loss = policy.forward(batch, model="critic")["loss_critic"]
|
||||
assert cirtic_loss.item() is not None
|
||||
assert cirtic_loss.shape == ()
|
||||
cirtic_loss.backward()
|
||||
optimizers["critic"].step()
|
||||
|
||||
actor_loss = policy.forward(batch, model="actor")["loss_actor"]
|
||||
assert actor_loss.item() is not None
|
||||
assert actor_loss.shape == ()
|
||||
|
||||
actor_loss.backward()
|
||||
optimizers["actor"].step()
|
||||
|
||||
temperature_loss = policy.forward(batch, model="temperature")["loss_temperature"]
|
||||
assert temperature_loss.item() is not None
|
||||
assert temperature_loss.shape == ()
|
||||
|
||||
temperature_loss.backward()
|
||||
optimizers["temperature"].step()
|
||||
|
||||
@pytest.mark.parametrize("batch_size,state_dim,action_dim", [(2, 6, 6), (1, 10, 10)])
|
||||
def test_sac_policy_select_action(batch_size: int, state_dim: int, action_dim: int):
|
||||
config = create_default_config(state_dim=state_dim, continuous_action_dim=action_dim)
|
||||
policy = SACPolicy(config=config)
|
||||
policy.eval()
|
||||
|
||||
with torch.no_grad():
|
||||
observation_batch = create_observation_batch(batch_size=batch_size, state_dim=state_dim)
|
||||
selected_action = policy.select_action(observation_batch)
|
||||
assert selected_action.shape == (batch_size, action_dim)
|
||||
# squeeze(0) removes batch dim when batch_size==1
|
||||
assert selected_action.shape[-1] == action_dim
|
||||
|
||||
|
||||
def test_sac_policy_select_action_with_discrete():
|
||||
"""select_action should return continuous + discrete actions."""
|
||||
config = create_default_config(state_dim=10, continuous_action_dim=6)
|
||||
config.num_discrete_actions = 3
|
||||
policy = SACPolicy(config=config)
|
||||
policy.eval()
|
||||
|
||||
with torch.no_grad():
|
||||
observation_batch = create_observation_batch(batch_size=1, state_dim=10)
|
||||
# Squeeze to unbatched (single observation)
|
||||
observation_batch = {k: v.squeeze(0) for k, v in observation_batch.items()}
|
||||
selected_action = policy.select_action(observation_batch)
|
||||
assert selected_action.shape[-1] == 7 # 6 continuous + 1 discrete
|
||||
|
||||
|
||||
@pytest.mark.parametrize("batch_size,state_dim,action_dim", [(2, 6, 6), (1, 10, 10)])
|
||||
def test_sac_policy_with_visual_input(batch_size: int, state_dim: int, action_dim: int):
|
||||
config = create_config_with_visual_input(state_dim=state_dim, continuous_action_dim=action_dim)
|
||||
def test_sac_policy_forward(batch_size: int, state_dim: int, action_dim: int):
|
||||
config = create_default_config(state_dim=state_dim, continuous_action_dim=action_dim)
|
||||
policy = SACPolicy(config=config)
|
||||
policy.eval()
|
||||
|
||||
batch = create_default_train_batch(batch_size=batch_size, action_dim=action_dim, state_dim=state_dim)
|
||||
with torch.no_grad():
|
||||
output = policy.forward(batch)
|
||||
assert "action" in output
|
||||
assert "log_prob" in output
|
||||
assert "action_mean" in output
|
||||
assert output["action"].shape == (batch_size, action_dim)
|
||||
|
||||
|
||||
@pytest.mark.parametrize("batch_size,state_dim,action_dim", [(2, 6, 6), (1, 10, 10)])
|
||||
def test_sac_training_through_algorithm(batch_size: int, state_dim: int, action_dim: int):
|
||||
config = create_default_config(state_dim=state_dim, continuous_action_dim=action_dim)
|
||||
algorithm, policy = _make_algorithm(config)
|
||||
|
||||
batch = create_default_train_batch(batch_size=batch_size, action_dim=action_dim, state_dim=state_dim)
|
||||
forward_batch = algorithm._prepare_forward_batch(batch)
|
||||
|
||||
critic_loss = algorithm._compute_loss_critic(forward_batch)
|
||||
assert critic_loss.item() is not None
|
||||
assert critic_loss.shape == ()
|
||||
algorithm.optimizers["critic"].zero_grad()
|
||||
critic_loss.backward()
|
||||
algorithm.optimizers["critic"].step()
|
||||
|
||||
actor_loss = algorithm._compute_loss_actor(forward_batch)
|
||||
assert actor_loss.item() is not None
|
||||
assert actor_loss.shape == ()
|
||||
algorithm.optimizers["actor"].zero_grad()
|
||||
actor_loss.backward()
|
||||
algorithm.optimizers["actor"].step()
|
||||
|
||||
temp_loss = algorithm._compute_loss_temperature(forward_batch)
|
||||
assert temp_loss.item() is not None
|
||||
assert temp_loss.shape == ()
|
||||
algorithm.optimizers["temperature"].zero_grad()
|
||||
temp_loss.backward()
|
||||
algorithm.optimizers["temperature"].step()
|
||||
|
||||
|
||||
@pytest.mark.parametrize("batch_size,state_dim,action_dim", [(2, 6, 6), (1, 10, 10)])
|
||||
def test_sac_training_with_visual_input(batch_size: int, state_dim: int, action_dim: int):
|
||||
config = create_config_with_visual_input(state_dim=state_dim, continuous_action_dim=action_dim)
|
||||
algorithm, policy = _make_algorithm(config)
|
||||
|
||||
batch = create_train_batch_with_visual_input(
|
||||
batch_size=batch_size, state_dim=state_dim, action_dim=action_dim
|
||||
)
|
||||
forward_batch = algorithm._prepare_forward_batch(batch)
|
||||
|
||||
policy.train()
|
||||
critic_loss = algorithm._compute_loss_critic(forward_batch)
|
||||
assert critic_loss.item() is not None
|
||||
assert critic_loss.shape == ()
|
||||
algorithm.optimizers["critic"].zero_grad()
|
||||
critic_loss.backward()
|
||||
algorithm.optimizers["critic"].step()
|
||||
|
||||
optimizers = make_optimizers(policy)
|
||||
|
||||
cirtic_loss = policy.forward(batch, model="critic")["loss_critic"]
|
||||
assert cirtic_loss.item() is not None
|
||||
assert cirtic_loss.shape == ()
|
||||
cirtic_loss.backward()
|
||||
optimizers["critic"].step()
|
||||
|
||||
actor_loss = policy.forward(batch, model="actor")["loss_actor"]
|
||||
actor_loss = algorithm._compute_loss_actor(forward_batch)
|
||||
assert actor_loss.item() is not None
|
||||
assert actor_loss.shape == ()
|
||||
|
||||
algorithm.optimizers["actor"].zero_grad()
|
||||
actor_loss.backward()
|
||||
optimizers["actor"].step()
|
||||
|
||||
temperature_loss = policy.forward(batch, model="temperature")["loss_temperature"]
|
||||
assert temperature_loss.item() is not None
|
||||
assert temperature_loss.shape == ()
|
||||
|
||||
temperature_loss.backward()
|
||||
optimizers["temperature"].step()
|
||||
algorithm.optimizers["actor"].step()
|
||||
|
||||
policy.eval()
|
||||
with torch.no_grad():
|
||||
@@ -296,207 +296,181 @@ def test_sac_policy_with_visual_input(batch_size: int, state_dim: int, action_di
|
||||
batch_size=batch_size, state_dim=state_dim
|
||||
)
|
||||
selected_action = policy.select_action(observation_batch)
|
||||
assert selected_action.shape == (batch_size, action_dim)
|
||||
assert selected_action.shape[-1] == action_dim
|
||||
|
||||
|
||||
# Let's check best candidates for pretrained encoders
|
||||
@pytest.mark.parametrize(
|
||||
"batch_size,state_dim,action_dim,vision_encoder_name",
|
||||
[(1, 6, 6, "helper2424/resnet10"), (1, 6, 6, "facebook/convnext-base-224")],
|
||||
)
|
||||
@pytest.mark.skipif(not TRANSFORMERS_AVAILABLE, reason="Transformers are not installed")
|
||||
def test_sac_policy_with_pretrained_encoder(
|
||||
def test_sac_training_with_pretrained_encoder(
|
||||
batch_size: int, state_dim: int, action_dim: int, vision_encoder_name: str
|
||||
):
|
||||
config = create_config_with_visual_input(state_dim=state_dim, continuous_action_dim=action_dim)
|
||||
config.vision_encoder_name = vision_encoder_name
|
||||
policy = SACPolicy(config=config)
|
||||
policy.train()
|
||||
algorithm, policy = _make_algorithm(config)
|
||||
|
||||
batch = create_train_batch_with_visual_input(
|
||||
batch_size=batch_size, state_dim=state_dim, action_dim=action_dim
|
||||
)
|
||||
forward_batch = algorithm._prepare_forward_batch(batch)
|
||||
|
||||
optimizers = make_optimizers(policy)
|
||||
critic_loss = algorithm._compute_loss_critic(forward_batch)
|
||||
assert critic_loss.item() is not None
|
||||
assert critic_loss.shape == ()
|
||||
algorithm.optimizers["critic"].zero_grad()
|
||||
critic_loss.backward()
|
||||
algorithm.optimizers["critic"].step()
|
||||
|
||||
cirtic_loss = policy.forward(batch, model="critic")["loss_critic"]
|
||||
assert cirtic_loss.item() is not None
|
||||
assert cirtic_loss.shape == ()
|
||||
cirtic_loss.backward()
|
||||
optimizers["critic"].step()
|
||||
|
||||
actor_loss = policy.forward(batch, model="actor")["loss_actor"]
|
||||
actor_loss = algorithm._compute_loss_actor(forward_batch)
|
||||
assert actor_loss.item() is not None
|
||||
assert actor_loss.shape == ()
|
||||
|
||||
|
||||
def test_sac_policy_with_shared_encoder():
|
||||
def test_sac_training_with_shared_encoder():
|
||||
batch_size = 2
|
||||
action_dim = 10
|
||||
state_dim = 10
|
||||
config = create_config_with_visual_input(state_dim=state_dim, continuous_action_dim=action_dim)
|
||||
config.shared_encoder = True
|
||||
|
||||
policy = SACPolicy(config=config)
|
||||
policy.train()
|
||||
algorithm, policy = _make_algorithm(config)
|
||||
|
||||
batch = create_train_batch_with_visual_input(
|
||||
batch_size=batch_size, state_dim=state_dim, action_dim=action_dim
|
||||
)
|
||||
forward_batch = algorithm._prepare_forward_batch(batch)
|
||||
|
||||
policy.train()
|
||||
critic_loss = algorithm._compute_loss_critic(forward_batch)
|
||||
assert critic_loss.shape == ()
|
||||
algorithm.optimizers["critic"].zero_grad()
|
||||
critic_loss.backward()
|
||||
algorithm.optimizers["critic"].step()
|
||||
|
||||
optimizers = make_optimizers(policy)
|
||||
|
||||
cirtic_loss = policy.forward(batch, model="critic")["loss_critic"]
|
||||
assert cirtic_loss.item() is not None
|
||||
assert cirtic_loss.shape == ()
|
||||
cirtic_loss.backward()
|
||||
optimizers["critic"].step()
|
||||
|
||||
actor_loss = policy.forward(batch, model="actor")["loss_actor"]
|
||||
assert actor_loss.item() is not None
|
||||
actor_loss = algorithm._compute_loss_actor(forward_batch)
|
||||
assert actor_loss.shape == ()
|
||||
|
||||
algorithm.optimizers["actor"].zero_grad()
|
||||
actor_loss.backward()
|
||||
optimizers["actor"].step()
|
||||
algorithm.optimizers["actor"].step()
|
||||
|
||||
|
||||
def test_sac_policy_with_discrete_critic():
|
||||
def test_sac_training_with_discrete_critic():
|
||||
batch_size = 2
|
||||
continuous_action_dim = 9
|
||||
full_action_dim = continuous_action_dim + 1 # the last action is discrete
|
||||
full_action_dim = continuous_action_dim + 1
|
||||
state_dim = 10
|
||||
config = create_config_with_visual_input(
|
||||
state_dim=state_dim, continuous_action_dim=continuous_action_dim, has_discrete_action=True
|
||||
)
|
||||
config.num_discrete_actions = 5
|
||||
|
||||
num_discrete_actions = 5
|
||||
config.num_discrete_actions = num_discrete_actions
|
||||
|
||||
policy = SACPolicy(config=config)
|
||||
policy.train()
|
||||
algorithm, policy = _make_algorithm(config)
|
||||
|
||||
batch = create_train_batch_with_visual_input(
|
||||
batch_size=batch_size, state_dim=state_dim, action_dim=full_action_dim
|
||||
)
|
||||
forward_batch = algorithm._prepare_forward_batch(batch)
|
||||
|
||||
policy.train()
|
||||
critic_loss = algorithm._compute_loss_critic(forward_batch)
|
||||
assert critic_loss.shape == ()
|
||||
algorithm.optimizers["critic"].zero_grad()
|
||||
critic_loss.backward()
|
||||
algorithm.optimizers["critic"].step()
|
||||
|
||||
optimizers = make_optimizers(policy, has_discrete_action=True)
|
||||
|
||||
cirtic_loss = policy.forward(batch, model="critic")["loss_critic"]
|
||||
assert cirtic_loss.item() is not None
|
||||
assert cirtic_loss.shape == ()
|
||||
cirtic_loss.backward()
|
||||
optimizers["critic"].step()
|
||||
|
||||
discrete_critic_loss = policy.forward(batch, model="discrete_critic")["loss_discrete_critic"]
|
||||
assert discrete_critic_loss.item() is not None
|
||||
discrete_critic_loss = algorithm._compute_loss_discrete_critic(forward_batch)
|
||||
assert discrete_critic_loss.shape == ()
|
||||
algorithm.optimizers["discrete_critic"].zero_grad()
|
||||
discrete_critic_loss.backward()
|
||||
optimizers["discrete_critic"].step()
|
||||
algorithm.optimizers["discrete_critic"].step()
|
||||
|
||||
actor_loss = policy.forward(batch, model="actor")["loss_actor"]
|
||||
assert actor_loss.item() is not None
|
||||
actor_loss = algorithm._compute_loss_actor(forward_batch)
|
||||
assert actor_loss.shape == ()
|
||||
|
||||
algorithm.optimizers["actor"].zero_grad()
|
||||
actor_loss.backward()
|
||||
optimizers["actor"].step()
|
||||
algorithm.optimizers["actor"].step()
|
||||
|
||||
policy.eval()
|
||||
with torch.no_grad():
|
||||
observation_batch = create_observation_batch_with_visual_input(
|
||||
batch_size=batch_size, state_dim=state_dim
|
||||
)
|
||||
selected_action = policy.select_action(observation_batch)
|
||||
assert selected_action.shape == (batch_size, full_action_dim)
|
||||
|
||||
discrete_actions = selected_action[:, -1].long()
|
||||
discrete_action_values = set(discrete_actions.tolist())
|
||||
|
||||
assert all(action in range(num_discrete_actions) for action in discrete_action_values), (
|
||||
f"Discrete action {discrete_action_values} is not in range({num_discrete_actions})"
|
||||
)
|
||||
# Policy.select_action now handles both continuous + discrete
|
||||
selected_action = policy.select_action({k: v.squeeze(0) for k, v in observation_batch.items()})
|
||||
assert selected_action.shape[-1] == continuous_action_dim + 1
|
||||
|
||||
|
||||
def test_sac_policy_with_default_entropy():
|
||||
def test_sac_algorithm_target_entropy():
|
||||
config = create_default_config(continuous_action_dim=10, state_dim=10)
|
||||
policy = SACPolicy(config=config)
|
||||
assert policy.target_entropy == -5.0
|
||||
_, policy = _make_algorithm(config)
|
||||
algo_config = SACAlgorithmConfig.from_policy_config(config)
|
||||
algorithm = SACAlgorithm(policy=policy, config=algo_config)
|
||||
assert algorithm.target_entropy == -5.0
|
||||
|
||||
|
||||
def test_sac_policy_default_target_entropy_with_discrete_action():
|
||||
def test_sac_algorithm_target_entropy_with_discrete_action():
|
||||
config = create_config_with_visual_input(state_dim=10, continuous_action_dim=6, has_discrete_action=True)
|
||||
config.num_discrete_actions = 5
|
||||
algo_config = SACAlgorithmConfig.from_policy_config(config)
|
||||
policy = SACPolicy(config=config)
|
||||
assert policy.target_entropy == -3.0
|
||||
algorithm = SACAlgorithm(policy=policy, config=algo_config)
|
||||
assert algorithm.target_entropy == -3.5
|
||||
|
||||
|
||||
def test_sac_policy_with_predefined_entropy():
|
||||
config = create_default_config(state_dim=10, continuous_action_dim=6)
|
||||
config.target_entropy = -3.5
|
||||
def test_sac_algorithm_temperature():
|
||||
import math
|
||||
|
||||
policy = SACPolicy(config=config)
|
||||
assert policy.target_entropy == pytest.approx(-3.5)
|
||||
|
||||
|
||||
def test_sac_policy_update_temperature():
|
||||
"""Test that temperature property is always in sync with log_alpha."""
|
||||
config = create_default_config(continuous_action_dim=10, state_dim=10)
|
||||
algo_config = SACAlgorithmConfig.from_policy_config(config)
|
||||
policy = SACPolicy(config=config)
|
||||
algorithm = SACAlgorithm(policy=policy, config=algo_config)
|
||||
|
||||
assert policy.temperature == pytest.approx(1.0)
|
||||
policy.log_alpha.data = torch.tensor([math.log(0.1)])
|
||||
# Temperature property automatically reflects log_alpha changes
|
||||
assert policy.temperature == pytest.approx(0.1)
|
||||
assert algorithm.temperature == pytest.approx(1.0)
|
||||
algorithm.log_alpha.data = torch.tensor([math.log(0.1)])
|
||||
assert algorithm.temperature == pytest.approx(0.1)
|
||||
|
||||
|
||||
def test_sac_policy_update_target_network():
|
||||
def test_sac_algorithm_update_target_network():
|
||||
config = create_default_config(state_dim=10, continuous_action_dim=6)
|
||||
config.critic_target_update_weight = 1.0
|
||||
|
||||
algo_config = SACAlgorithmConfig.from_policy_config(config)
|
||||
policy = SACPolicy(config=config)
|
||||
policy.train()
|
||||
algorithm = SACAlgorithm(policy=policy, config=algo_config)
|
||||
|
||||
for p in policy.critic_ensemble.parameters():
|
||||
for p in algorithm.critic_ensemble.parameters():
|
||||
p.data = torch.ones_like(p.data)
|
||||
|
||||
policy.update_target_networks()
|
||||
for p in policy.critic_target.parameters():
|
||||
assert torch.allclose(p.data, torch.ones_like(p.data)), (
|
||||
f"Target network {p.data} is not equal to {torch.ones_like(p.data)}"
|
||||
)
|
||||
algorithm._update_target_networks()
|
||||
for p in algorithm.critic_target.parameters():
|
||||
assert torch.allclose(p.data, torch.ones_like(p.data))
|
||||
|
||||
|
||||
@pytest.mark.parametrize("num_critics", [1, 3])
|
||||
def test_sac_policy_with_critics_number_of_heads(num_critics: int):
|
||||
def test_sac_algorithm_with_critics_number_of_heads(num_critics: int):
|
||||
batch_size = 2
|
||||
action_dim = 10
|
||||
state_dim = 10
|
||||
config = create_config_with_visual_input(state_dim=state_dim, continuous_action_dim=action_dim)
|
||||
config.num_critics = num_critics
|
||||
|
||||
policy = SACPolicy(config=config)
|
||||
policy.train()
|
||||
algorithm, policy = _make_algorithm(config)
|
||||
|
||||
assert len(policy.critic_ensemble.critics) == num_critics
|
||||
assert len(algorithm.critic_ensemble.critics) == num_critics
|
||||
|
||||
batch = create_train_batch_with_visual_input(
|
||||
batch_size=batch_size, state_dim=state_dim, action_dim=action_dim
|
||||
)
|
||||
forward_batch = algorithm._prepare_forward_batch(batch)
|
||||
|
||||
policy.train()
|
||||
|
||||
optimizers = make_optimizers(policy)
|
||||
|
||||
cirtic_loss = policy.forward(batch, model="critic")["loss_critic"]
|
||||
assert cirtic_loss.item() is not None
|
||||
assert cirtic_loss.shape == ()
|
||||
cirtic_loss.backward()
|
||||
optimizers["critic"].step()
|
||||
critic_loss = algorithm._compute_loss_critic(forward_batch)
|
||||
assert critic_loss.shape == ()
|
||||
algorithm.optimizers["critic"].zero_grad()
|
||||
critic_loss.backward()
|
||||
algorithm.optimizers["critic"].step()
|
||||
|
||||
|
||||
def test_sac_policy_save_and_load(tmp_path):
|
||||
"""Test that the policy can be saved and loaded from pretrained."""
|
||||
root = tmp_path / "test_sac_save_and_load"
|
||||
|
||||
state_dim = 10
|
||||
@@ -510,34 +484,41 @@ def test_sac_policy_save_and_load(tmp_path):
|
||||
loaded_policy = SACPolicy.from_pretrained(root, config=config)
|
||||
loaded_policy.eval()
|
||||
|
||||
batch = create_default_train_batch(batch_size=1, state_dim=10, action_dim=10)
|
||||
assert policy.state_dict().keys() == loaded_policy.state_dict().keys()
|
||||
for k in policy.state_dict():
|
||||
assert torch.allclose(policy.state_dict()[k], loaded_policy.state_dict()[k], atol=1e-6)
|
||||
|
||||
with torch.no_grad():
|
||||
with seeded_context(12):
|
||||
# Collect policy values before saving
|
||||
cirtic_loss = policy.forward(batch, model="critic")["loss_critic"]
|
||||
actor_loss = policy.forward(batch, model="actor")["loss_actor"]
|
||||
temperature_loss = policy.forward(batch, model="temperature")["loss_temperature"]
|
||||
|
||||
observation_batch = create_observation_batch(batch_size=batch_size, state_dim=state_dim)
|
||||
actions = policy.select_action(observation_batch)
|
||||
|
||||
with seeded_context(12):
|
||||
# Collect policy values after loading
|
||||
loaded_cirtic_loss = loaded_policy.forward(batch, model="critic")["loss_critic"]
|
||||
loaded_actor_loss = loaded_policy.forward(batch, model="actor")["loss_actor"]
|
||||
loaded_temperature_loss = loaded_policy.forward(batch, model="temperature")["loss_temperature"]
|
||||
|
||||
loaded_observation_batch = create_observation_batch(batch_size=batch_size, state_dim=state_dim)
|
||||
loaded_actions = loaded_policy.select_action(loaded_observation_batch)
|
||||
|
||||
assert policy.state_dict().keys() == loaded_policy.state_dict().keys()
|
||||
for k in policy.state_dict():
|
||||
assert torch.allclose(policy.state_dict()[k], loaded_policy.state_dict()[k], atol=1e-6)
|
||||
|
||||
# Compare values before and after saving and loading
|
||||
# They should be the same
|
||||
assert torch.allclose(cirtic_loss, loaded_cirtic_loss)
|
||||
assert torch.allclose(actor_loss, loaded_actor_loss)
|
||||
assert torch.allclose(temperature_loss, loaded_temperature_loss)
|
||||
assert torch.allclose(actions, loaded_actions)
|
||||
|
||||
|
||||
def test_sac_policy_save_and_load_with_discrete_critic(tmp_path):
|
||||
"""Discrete critic should be saved/loaded as part of the policy."""
|
||||
root = tmp_path / "test_sac_save_and_load_discrete"
|
||||
|
||||
state_dim = 10
|
||||
action_dim = 6
|
||||
|
||||
config = create_default_config(state_dim=state_dim, continuous_action_dim=action_dim)
|
||||
config.num_discrete_actions = 3
|
||||
policy = SACPolicy(config=config)
|
||||
policy.eval()
|
||||
policy.save_pretrained(root)
|
||||
|
||||
loaded_policy = SACPolicy.from_pretrained(root, config=config)
|
||||
loaded_policy.eval()
|
||||
|
||||
assert loaded_policy.discrete_critic is not None
|
||||
dc_keys = [k for k in loaded_policy.state_dict() if k.startswith("discrete_critic.")]
|
||||
assert len(dc_keys) > 0
|
||||
|
||||
for k in policy.state_dict():
|
||||
assert torch.allclose(policy.state_dict()[k], loaded_policy.state_dict()[k], atol=1e-6)
|
||||
|
||||
@@ -23,8 +23,9 @@ import torch
|
||||
from torch.multiprocessing import Event, Queue
|
||||
|
||||
from lerobot.configs.train import TrainRLServerPipelineConfig
|
||||
from lerobot.configs.types import FeatureType, PolicyFeature
|
||||
from lerobot.policies.sac.configuration_sac import SACConfig
|
||||
from lerobot.utils.constants import OBS_STR
|
||||
from lerobot.utils.constants import ACTION, OBS_STATE, OBS_STR
|
||||
from lerobot.utils.transition import Transition
|
||||
from tests.utils import require_package
|
||||
|
||||
@@ -296,3 +297,172 @@ def test_end_to_end_parameters_flow(cfg, data_size):
|
||||
assert received_params.keys() == input_params.keys()
|
||||
for key in input_params:
|
||||
assert torch.allclose(received_params[key], input_params[key])
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Regression test: learner algorithm integration (no gRPC required)
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
|
||||
def test_learner_algorithm_wiring():
|
||||
"""Verify that make_algorithm constructs an SACAlgorithm from config,
|
||||
make_optimizers() creates the right optimizers, update() works, and
|
||||
get_weights() output is serializable."""
|
||||
from lerobot.policies.sac.modeling_sac import SACPolicy
|
||||
from lerobot.rl.algorithms import make_algorithm
|
||||
from lerobot.rl.algorithms.sac import SACAlgorithm
|
||||
from lerobot.transport.utils import state_to_bytes
|
||||
|
||||
state_dim = 10
|
||||
action_dim = 6
|
||||
|
||||
sac_cfg = SACConfig(
|
||||
input_features={OBS_STATE: PolicyFeature(type=FeatureType.STATE, shape=(state_dim,))},
|
||||
output_features={ACTION: PolicyFeature(type=FeatureType.ACTION, shape=(action_dim,))},
|
||||
dataset_stats={
|
||||
OBS_STATE: {"min": [0.0] * state_dim, "max": [1.0] * state_dim},
|
||||
ACTION: {"min": [0.0] * action_dim, "max": [1.0] * action_dim},
|
||||
},
|
||||
use_torch_compile=False,
|
||||
)
|
||||
sac_cfg.validate_features()
|
||||
|
||||
policy = SACPolicy(config=sac_cfg)
|
||||
policy.train()
|
||||
|
||||
algorithm = make_algorithm(policy=policy, policy_cfg=sac_cfg, algorithm_name="sac")
|
||||
assert isinstance(algorithm, SACAlgorithm)
|
||||
|
||||
optimizers = algorithm.make_optimizers()
|
||||
assert "actor" in optimizers
|
||||
assert "critic" in optimizers
|
||||
assert "temperature" in optimizers
|
||||
|
||||
batch_size = 4
|
||||
|
||||
def batch_iterator():
|
||||
while True:
|
||||
yield {
|
||||
ACTION: torch.randn(batch_size, action_dim),
|
||||
"reward": torch.randn(batch_size),
|
||||
"state": {OBS_STATE: torch.randn(batch_size, state_dim)},
|
||||
"next_state": {OBS_STATE: torch.randn(batch_size, state_dim)},
|
||||
"done": torch.zeros(batch_size),
|
||||
"complementary_info": {},
|
||||
}
|
||||
|
||||
stats = algorithm.update(batch_iterator())
|
||||
assert "critic" in stats.losses
|
||||
|
||||
# get_weights -> state_to_bytes round-trip
|
||||
weights = algorithm.get_weights()
|
||||
assert len(weights) > 0
|
||||
serialized = state_to_bytes(weights)
|
||||
assert isinstance(serialized, bytes)
|
||||
assert len(serialized) > 0
|
||||
|
||||
# RLTrainer with DataMixer
|
||||
from lerobot.rl.buffer import ReplayBuffer
|
||||
from lerobot.rl.data_sources import OnlineOfflineMixer
|
||||
from lerobot.rl.trainer import RLTrainer
|
||||
|
||||
replay_buffer = ReplayBuffer(
|
||||
capacity=50,
|
||||
device="cpu",
|
||||
state_keys=[OBS_STATE],
|
||||
storage_device="cpu",
|
||||
use_drq=False,
|
||||
)
|
||||
for _ in range(50):
|
||||
replay_buffer.add(
|
||||
state={OBS_STATE: torch.randn(state_dim)},
|
||||
action=torch.randn(action_dim),
|
||||
reward=1.0,
|
||||
next_state={OBS_STATE: torch.randn(state_dim)},
|
||||
done=False,
|
||||
truncated=False,
|
||||
)
|
||||
data_mixer = OnlineOfflineMixer(online_buffer=replay_buffer, offline_buffer=None)
|
||||
trainer = RLTrainer(
|
||||
algorithm=algorithm,
|
||||
data_mixer=data_mixer,
|
||||
batch_size=batch_size,
|
||||
async_prefetch=False,
|
||||
)
|
||||
trainer_stats = trainer.training_step()
|
||||
assert "critic" in trainer_stats.losses
|
||||
|
||||
|
||||
def test_initial_and_periodic_weight_push_consistency():
|
||||
"""Both initial and periodic weight pushes should use algorithm.get_weights()
|
||||
and produce identical structures."""
|
||||
from lerobot.policies.sac.modeling_sac import SACPolicy
|
||||
from lerobot.rl.algorithms import make_algorithm
|
||||
from lerobot.transport.utils import bytes_to_state_dict, state_to_bytes
|
||||
|
||||
state_dim = 10
|
||||
action_dim = 6
|
||||
sac_cfg = SACConfig(
|
||||
input_features={OBS_STATE: PolicyFeature(type=FeatureType.STATE, shape=(state_dim,))},
|
||||
output_features={ACTION: PolicyFeature(type=FeatureType.ACTION, shape=(action_dim,))},
|
||||
dataset_stats={
|
||||
OBS_STATE: {"min": [0.0] * state_dim, "max": [1.0] * state_dim},
|
||||
ACTION: {"min": [0.0] * action_dim, "max": [1.0] * action_dim},
|
||||
},
|
||||
use_torch_compile=False,
|
||||
)
|
||||
sac_cfg.validate_features()
|
||||
|
||||
policy = SACPolicy(config=sac_cfg)
|
||||
policy.train()
|
||||
algorithm = make_algorithm(policy=policy, policy_cfg=sac_cfg, algorithm_name="sac")
|
||||
algorithm.make_optimizers()
|
||||
|
||||
# Simulate initial push (same code path the learner now uses)
|
||||
initial_weights = algorithm.get_weights()
|
||||
initial_bytes = state_to_bytes(initial_weights)
|
||||
|
||||
# Simulate periodic push
|
||||
periodic_weights = algorithm.get_weights()
|
||||
periodic_bytes = state_to_bytes(periodic_weights)
|
||||
|
||||
initial_decoded = bytes_to_state_dict(initial_bytes)
|
||||
periodic_decoded = bytes_to_state_dict(periodic_bytes)
|
||||
|
||||
assert initial_decoded.keys() == periodic_decoded.keys()
|
||||
|
||||
|
||||
def test_actor_side_algorithm_select_action_and_load_weights():
|
||||
"""Simulate actor: create algorithm without optimizers, select_action, load_weights."""
|
||||
from lerobot.policies.sac.modeling_sac import SACPolicy
|
||||
from lerobot.rl.algorithms import make_algorithm
|
||||
from lerobot.rl.algorithms.sac import SACAlgorithm
|
||||
|
||||
state_dim = 10
|
||||
action_dim = 6
|
||||
sac_cfg = SACConfig(
|
||||
input_features={OBS_STATE: PolicyFeature(type=FeatureType.STATE, shape=(state_dim,))},
|
||||
output_features={ACTION: PolicyFeature(type=FeatureType.ACTION, shape=(action_dim,))},
|
||||
dataset_stats={
|
||||
OBS_STATE: {"min": [0.0] * state_dim, "max": [1.0] * state_dim},
|
||||
ACTION: {"min": [0.0] * action_dim, "max": [1.0] * action_dim},
|
||||
},
|
||||
use_torch_compile=False,
|
||||
)
|
||||
sac_cfg.validate_features()
|
||||
|
||||
# Actor side: no optimizers
|
||||
policy = SACPolicy(config=sac_cfg)
|
||||
policy.eval()
|
||||
algorithm = make_algorithm(policy=policy, policy_cfg=sac_cfg, algorithm_name="sac")
|
||||
assert isinstance(algorithm, SACAlgorithm)
|
||||
assert algorithm.optimizers == {}
|
||||
|
||||
# select_action should work
|
||||
obs = {OBS_STATE: torch.randn(state_dim)}
|
||||
action = policy.select_action(obs)
|
||||
assert action.shape == (action_dim,)
|
||||
|
||||
# Simulate receiving weights from learner
|
||||
fake_weights = algorithm.get_weights()
|
||||
algorithm.load_weights(fake_weights, device="cpu")
|
||||
|
||||
@@ -0,0 +1,85 @@
|
||||
# Copyright 2025 The HuggingFace Inc. team. All rights reserved.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
"""Tests for RL data mixing (DataMixer, OnlineOfflineMixer)."""
|
||||
|
||||
import torch
|
||||
|
||||
from lerobot.rl.buffer import ReplayBuffer
|
||||
from lerobot.rl.data_sources import OnlineOfflineMixer
|
||||
from lerobot.utils.constants import OBS_STATE
|
||||
|
||||
|
||||
def _make_buffer(capacity: int = 100, state_dim: int = 4) -> ReplayBuffer:
|
||||
buf = ReplayBuffer(
|
||||
capacity=capacity,
|
||||
device="cpu",
|
||||
state_keys=[OBS_STATE],
|
||||
storage_device="cpu",
|
||||
use_drq=False,
|
||||
)
|
||||
for i in range(capacity):
|
||||
buf.add(
|
||||
state={OBS_STATE: torch.randn(state_dim)},
|
||||
action=torch.randn(2),
|
||||
reward=1.0,
|
||||
next_state={OBS_STATE: torch.randn(state_dim)},
|
||||
done=bool(i % 10 == 9),
|
||||
truncated=False,
|
||||
)
|
||||
return buf
|
||||
|
||||
|
||||
def test_online_only_mixer_sample():
|
||||
"""OnlineOfflineMixer with no offline buffer returns online-only batches."""
|
||||
buf = _make_buffer(capacity=50)
|
||||
mixer = OnlineOfflineMixer(online_buffer=buf, offline_buffer=None, online_ratio=0.5)
|
||||
batch = mixer.sample(batch_size=8)
|
||||
assert batch["state"][OBS_STATE].shape[0] == 8
|
||||
assert batch["action"].shape[0] == 8
|
||||
assert batch["reward"].shape[0] == 8
|
||||
|
||||
|
||||
def test_online_only_mixer_ratio_one():
|
||||
"""OnlineOfflineMixer with online_ratio=1.0 and no offline is equivalent to online-only."""
|
||||
buf = _make_buffer(capacity=50)
|
||||
mixer = OnlineOfflineMixer(online_buffer=buf, offline_buffer=None, online_ratio=1.0)
|
||||
batch = mixer.sample(batch_size=10)
|
||||
assert batch["state"][OBS_STATE].shape[0] == 10
|
||||
|
||||
|
||||
def test_online_offline_mixer_sample():
|
||||
"""OnlineOfflineMixer with two buffers returns concatenated batches."""
|
||||
online = _make_buffer(capacity=50)
|
||||
offline = _make_buffer(capacity=50)
|
||||
mixer = OnlineOfflineMixer(
|
||||
online_buffer=online,
|
||||
offline_buffer=offline,
|
||||
online_ratio=0.5,
|
||||
)
|
||||
batch = mixer.sample(batch_size=10)
|
||||
assert batch["state"][OBS_STATE].shape[0] == 10
|
||||
assert batch["action"].shape[0] == 10
|
||||
# 5 from online, 5 from offline (approx)
|
||||
assert batch["reward"].shape[0] == 10
|
||||
|
||||
|
||||
def test_online_offline_mixer_iterator():
|
||||
"""get_iterator yields batches of the requested size."""
|
||||
buf = _make_buffer(capacity=50)
|
||||
mixer = OnlineOfflineMixer(online_buffer=buf, offline_buffer=None)
|
||||
it = mixer.get_iterator(batch_size=4, async_prefetch=False)
|
||||
batch1 = next(it)
|
||||
batch2 = next(it)
|
||||
assert batch1["state"][OBS_STATE].shape[0] == 4
|
||||
assert batch2["state"][OBS_STATE].shape[0] == 4
|
||||
@@ -0,0 +1,477 @@
|
||||
#!/usr/bin/env python
|
||||
|
||||
# Copyright 2025 The HuggingFace Inc. team. All rights reserved.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
"""Tests for the RL algorithm abstraction and SACAlgorithm implementation."""
|
||||
|
||||
import pytest
|
||||
import torch
|
||||
|
||||
from lerobot.configs.types import FeatureType, PolicyFeature
|
||||
from lerobot.policies.sac.configuration_sac import SACConfig
|
||||
from lerobot.policies.sac.modeling_sac import SACPolicy
|
||||
from lerobot.rl.algorithms import make_algorithm
|
||||
from lerobot.rl.algorithms.base import RLAlgorithmConfig, TrainingStats
|
||||
from lerobot.rl.algorithms.sac import SACAlgorithm, SACAlgorithmConfig
|
||||
from lerobot.utils.constants import ACTION, OBS_IMAGE, OBS_STATE
|
||||
from lerobot.utils.random_utils import set_seed
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Helpers (reuse patterns from tests/policies/test_sac_policy.py)
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
|
||||
@pytest.fixture(autouse=True)
|
||||
def set_random_seed():
|
||||
set_seed(42)
|
||||
|
||||
|
||||
def _make_sac_config(
|
||||
state_dim: int = 10,
|
||||
action_dim: int = 6,
|
||||
num_discrete_actions: int | None = None,
|
||||
utd_ratio: int = 1,
|
||||
policy_update_freq: int = 1,
|
||||
with_images: bool = False,
|
||||
) -> SACConfig:
|
||||
config = SACConfig(
|
||||
input_features={OBS_STATE: PolicyFeature(type=FeatureType.STATE, shape=(state_dim,))},
|
||||
output_features={ACTION: PolicyFeature(type=FeatureType.ACTION, shape=(action_dim,))},
|
||||
dataset_stats={
|
||||
OBS_STATE: {"min": [0.0] * state_dim, "max": [1.0] * state_dim},
|
||||
ACTION: {"min": [0.0] * action_dim, "max": [1.0] * action_dim},
|
||||
},
|
||||
utd_ratio=utd_ratio,
|
||||
policy_update_freq=policy_update_freq,
|
||||
num_discrete_actions=num_discrete_actions,
|
||||
use_torch_compile=False,
|
||||
)
|
||||
if with_images:
|
||||
config.input_features[OBS_IMAGE] = PolicyFeature(type=FeatureType.VISUAL, shape=(3, 84, 84))
|
||||
config.dataset_stats[OBS_IMAGE] = {
|
||||
"mean": torch.randn(3, 1, 1).tolist(),
|
||||
"std": torch.randn(3, 1, 1).abs().tolist(),
|
||||
}
|
||||
config.latent_dim = 32
|
||||
config.state_encoder_hidden_dim = 32
|
||||
config.validate_features()
|
||||
return config
|
||||
|
||||
|
||||
def _make_algorithm(
|
||||
state_dim: int = 10,
|
||||
action_dim: int = 6,
|
||||
utd_ratio: int = 1,
|
||||
policy_update_freq: int = 1,
|
||||
num_discrete_actions: int | None = None,
|
||||
with_images: bool = False,
|
||||
) -> tuple[SACAlgorithm, SACPolicy]:
|
||||
sac_cfg = _make_sac_config(
|
||||
state_dim=state_dim,
|
||||
action_dim=action_dim,
|
||||
utd_ratio=utd_ratio,
|
||||
policy_update_freq=policy_update_freq,
|
||||
num_discrete_actions=num_discrete_actions,
|
||||
with_images=with_images,
|
||||
)
|
||||
policy = SACPolicy(config=sac_cfg)
|
||||
policy.train()
|
||||
algo_config = SACAlgorithmConfig.from_policy_config(sac_cfg)
|
||||
algorithm = SACAlgorithm(policy=policy, config=algo_config)
|
||||
algorithm.make_optimizers()
|
||||
return algorithm, policy
|
||||
|
||||
|
||||
def _make_batch(
|
||||
batch_size: int = 4,
|
||||
state_dim: int = 10,
|
||||
action_dim: int = 6,
|
||||
with_images: bool = False,
|
||||
) -> dict:
|
||||
obs = {OBS_STATE: torch.randn(batch_size, state_dim)}
|
||||
next_obs = {OBS_STATE: torch.randn(batch_size, state_dim)}
|
||||
if with_images:
|
||||
obs[OBS_IMAGE] = torch.randn(batch_size, 3, 84, 84)
|
||||
next_obs[OBS_IMAGE] = torch.randn(batch_size, 3, 84, 84)
|
||||
return {
|
||||
ACTION: torch.randn(batch_size, action_dim),
|
||||
"reward": torch.randn(batch_size),
|
||||
"state": obs,
|
||||
"next_state": next_obs,
|
||||
"done": torch.zeros(batch_size),
|
||||
"complementary_info": {},
|
||||
}
|
||||
|
||||
|
||||
def _batch_iterator(**batch_kwargs):
|
||||
"""Infinite iterator that yields fresh batches (mirrors a real DataMixer iterator)."""
|
||||
while True:
|
||||
yield _make_batch(**batch_kwargs)
|
||||
|
||||
|
||||
# ===========================================================================
|
||||
# Registry / config tests
|
||||
# ===========================================================================
|
||||
|
||||
|
||||
def test_sac_algorithm_config_registered():
|
||||
"""SACAlgorithmConfig should be discoverable through the registry."""
|
||||
assert "sac" in RLAlgorithmConfig.get_known_choices()
|
||||
cls = RLAlgorithmConfig.get_choice_class("sac")
|
||||
assert cls is SACAlgorithmConfig
|
||||
|
||||
|
||||
def test_sac_algorithm_config_from_policy_config():
|
||||
"""from_policy_config should copy relevant fields."""
|
||||
sac_cfg = _make_sac_config(utd_ratio=4, policy_update_freq=2)
|
||||
algo_cfg = SACAlgorithmConfig.from_policy_config(sac_cfg)
|
||||
assert algo_cfg.utd_ratio == 4
|
||||
assert algo_cfg.policy_update_freq == 2
|
||||
assert algo_cfg.clip_grad_norm == sac_cfg.grad_clip_norm
|
||||
|
||||
|
||||
# ===========================================================================
|
||||
# TrainingStats tests
|
||||
# ===========================================================================
|
||||
|
||||
|
||||
def test_training_stats_defaults():
|
||||
stats = TrainingStats()
|
||||
assert stats.losses == {}
|
||||
assert stats.grad_norms == {}
|
||||
assert stats.extra == {}
|
||||
|
||||
|
||||
# ===========================================================================
|
||||
# get_weights
|
||||
# ===========================================================================
|
||||
|
||||
|
||||
def test_get_weights_returns_policy_state_dict():
|
||||
algorithm, policy = _make_algorithm()
|
||||
weights = algorithm.get_weights()
|
||||
for key in policy.state_dict():
|
||||
assert key in weights
|
||||
assert torch.equal(weights[key].cpu(), policy.state_dict()[key].cpu())
|
||||
|
||||
|
||||
def test_get_weights_includes_discrete_critic_when_present():
|
||||
algorithm, policy = _make_algorithm(num_discrete_actions=3, action_dim=6)
|
||||
weights = algorithm.get_weights()
|
||||
dc_keys = [k for k in weights if k.startswith("discrete_critic.")]
|
||||
assert len(dc_keys) > 0
|
||||
|
||||
|
||||
def test_get_weights_excludes_discrete_critic_when_absent():
|
||||
algorithm, _ = _make_algorithm()
|
||||
weights = algorithm.get_weights()
|
||||
dc_keys = [k for k in weights if k.startswith("discrete_critic.")]
|
||||
assert len(dc_keys) == 0
|
||||
|
||||
|
||||
def test_get_weights_are_on_cpu():
|
||||
algorithm, _ = _make_algorithm()
|
||||
weights = algorithm.get_weights()
|
||||
for key, tensor in weights.items():
|
||||
assert tensor.device == torch.device("cpu"), f"{key} is not on CPU"
|
||||
|
||||
|
||||
# ===========================================================================
|
||||
# select_action (lives on the policy, not the algorithm)
|
||||
# ===========================================================================
|
||||
|
||||
|
||||
def test_select_action_returns_correct_shape():
|
||||
action_dim = 6
|
||||
_, policy = _make_algorithm(state_dim=10, action_dim=action_dim)
|
||||
policy.eval()
|
||||
obs = {OBS_STATE: torch.randn(10)}
|
||||
action = policy.select_action(obs)
|
||||
assert action.shape == (action_dim,)
|
||||
|
||||
|
||||
def test_select_action_with_discrete_critic():
|
||||
continuous_dim = 5
|
||||
_, policy = _make_algorithm(state_dim=10, action_dim=continuous_dim, num_discrete_actions=3)
|
||||
policy.eval()
|
||||
obs = {OBS_STATE: torch.randn(10)}
|
||||
action = policy.select_action(obs)
|
||||
assert action.shape == (continuous_dim + 1,)
|
||||
|
||||
|
||||
# ===========================================================================
|
||||
# update (single batch, utd_ratio=1)
|
||||
# ===========================================================================
|
||||
|
||||
|
||||
def test_update_returns_training_stats():
|
||||
algorithm, _ = _make_algorithm()
|
||||
stats = algorithm.update(_batch_iterator())
|
||||
assert isinstance(stats, TrainingStats)
|
||||
assert "critic" in stats.losses
|
||||
assert isinstance(stats.losses["critic"], float)
|
||||
|
||||
|
||||
def test_update_populates_actor_and_temperature_losses():
|
||||
"""With policy_update_freq=1 and step 0, actor/temperature should be updated."""
|
||||
algorithm, _ = _make_algorithm(policy_update_freq=1)
|
||||
stats = algorithm.update(_batch_iterator())
|
||||
assert "actor" in stats.losses
|
||||
assert "temperature" in stats.losses
|
||||
assert "temperature" in stats.extra
|
||||
|
||||
|
||||
@pytest.mark.parametrize("policy_update_freq", [2, 3])
|
||||
def test_update_skips_actor_at_non_update_steps(policy_update_freq):
|
||||
"""Actor/temperature should only update when optimization_step % freq == 0."""
|
||||
algorithm, _ = _make_algorithm(policy_update_freq=policy_update_freq)
|
||||
it = _batch_iterator()
|
||||
|
||||
# Step 0: should update actor
|
||||
stats_0 = algorithm.update(it)
|
||||
assert "actor" in stats_0.losses
|
||||
|
||||
# Step 1: should NOT update actor
|
||||
stats_1 = algorithm.update(it)
|
||||
assert "actor" not in stats_1.losses
|
||||
|
||||
|
||||
def test_update_increments_optimization_step():
|
||||
algorithm, _ = _make_algorithm()
|
||||
it = _batch_iterator()
|
||||
assert algorithm.optimization_step == 0
|
||||
algorithm.update(it)
|
||||
assert algorithm.optimization_step == 1
|
||||
algorithm.update(it)
|
||||
assert algorithm.optimization_step == 2
|
||||
|
||||
|
||||
def test_update_with_discrete_critic():
|
||||
algorithm, _ = _make_algorithm(num_discrete_actions=3, action_dim=6)
|
||||
stats = algorithm.update(_batch_iterator(action_dim=7)) # continuous + 1 discrete
|
||||
assert "discrete_critic" in stats.losses
|
||||
assert "discrete_critic" in stats.grad_norms
|
||||
|
||||
|
||||
# ===========================================================================
|
||||
# update with UTD ratio > 1
|
||||
# ===========================================================================
|
||||
|
||||
|
||||
@pytest.mark.parametrize("utd_ratio", [2, 4])
|
||||
def test_update_with_utd_ratio(utd_ratio):
|
||||
algorithm, _ = _make_algorithm(utd_ratio=utd_ratio)
|
||||
stats = algorithm.update(_batch_iterator())
|
||||
assert isinstance(stats, TrainingStats)
|
||||
assert "critic" in stats.losses
|
||||
assert algorithm.optimization_step == 1
|
||||
|
||||
|
||||
def test_update_utd_ratio_pulls_utd_batches():
|
||||
"""next(batch_iterator) should be called exactly utd_ratio times."""
|
||||
utd_ratio = 3
|
||||
algorithm, _ = _make_algorithm(utd_ratio=utd_ratio)
|
||||
|
||||
call_count = 0
|
||||
|
||||
def counting_iterator():
|
||||
nonlocal call_count
|
||||
while True:
|
||||
call_count += 1
|
||||
yield _make_batch()
|
||||
|
||||
algorithm.update(counting_iterator())
|
||||
assert call_count == utd_ratio
|
||||
|
||||
|
||||
def test_update_utd_ratio_3_critic_warmup_changes_weights():
|
||||
"""With utd_ratio=3, critic weights should change after update (3 critic steps)."""
|
||||
algorithm, policy = _make_algorithm(utd_ratio=3)
|
||||
|
||||
critic_params_before = {n: p.clone() for n, p in algorithm.critic_ensemble.named_parameters()}
|
||||
|
||||
algorithm.update(_batch_iterator())
|
||||
|
||||
changed = False
|
||||
for n, p in algorithm.critic_ensemble.named_parameters():
|
||||
if not torch.equal(p, critic_params_before[n]):
|
||||
changed = True
|
||||
break
|
||||
assert changed, "Critic weights should have changed after UTD update"
|
||||
|
||||
|
||||
# ===========================================================================
|
||||
# get_observation_features
|
||||
# ===========================================================================
|
||||
|
||||
|
||||
def test_get_observation_features_returns_none_without_frozen_encoder():
|
||||
algorithm, _ = _make_algorithm(with_images=False)
|
||||
obs = {OBS_STATE: torch.randn(4, 10)}
|
||||
next_obs = {OBS_STATE: torch.randn(4, 10)}
|
||||
feat, next_feat = algorithm.get_observation_features(obs, next_obs)
|
||||
assert feat is None
|
||||
assert next_feat is None
|
||||
|
||||
|
||||
# ===========================================================================
|
||||
# optimization_step setter
|
||||
# ===========================================================================
|
||||
|
||||
|
||||
def test_optimization_step_can_be_set_for_resume():
|
||||
algorithm, _ = _make_algorithm()
|
||||
algorithm.optimization_step = 100
|
||||
assert algorithm.optimization_step == 100
|
||||
|
||||
|
||||
# ===========================================================================
|
||||
# make_algorithm factory
|
||||
# ===========================================================================
|
||||
|
||||
|
||||
def test_make_algorithm_returns_sac_for_sac_policy():
|
||||
sac_cfg = _make_sac_config()
|
||||
policy = SACPolicy(config=sac_cfg)
|
||||
algorithm = make_algorithm(policy=policy, policy_cfg=sac_cfg, algorithm_name="sac")
|
||||
assert isinstance(algorithm, SACAlgorithm)
|
||||
assert algorithm.optimizers == {}
|
||||
|
||||
|
||||
def test_make_optimizers_creates_expected_keys():
|
||||
"""make_optimizers() should populate the algorithm with Adam optimizers."""
|
||||
sac_cfg = _make_sac_config()
|
||||
policy = SACPolicy(config=sac_cfg)
|
||||
algorithm = make_algorithm(policy=policy, policy_cfg=sac_cfg, algorithm_name="sac")
|
||||
optimizers = algorithm.make_optimizers()
|
||||
assert "actor" in optimizers
|
||||
assert "critic" in optimizers
|
||||
assert "temperature" in optimizers
|
||||
assert all(isinstance(v, torch.optim.Adam) for v in optimizers.values())
|
||||
assert algorithm.get_optimizers() is optimizers
|
||||
|
||||
|
||||
def test_actor_side_no_optimizers():
|
||||
"""Actor-side usage: no optimizers needed, make_optimizers is not called."""
|
||||
sac_cfg = _make_sac_config()
|
||||
policy = SACPolicy(config=sac_cfg)
|
||||
algorithm = make_algorithm(policy=policy, policy_cfg=sac_cfg, algorithm_name="sac")
|
||||
assert isinstance(algorithm, SACAlgorithm)
|
||||
assert algorithm.optimizers == {}
|
||||
|
||||
|
||||
def test_make_algorithm_copies_config_fields():
|
||||
sac_cfg = _make_sac_config(utd_ratio=5, policy_update_freq=3)
|
||||
policy = SACPolicy(config=sac_cfg)
|
||||
algorithm = make_algorithm(policy=policy, policy_cfg=sac_cfg, algorithm_name="sac")
|
||||
assert algorithm.config.utd_ratio == 5
|
||||
assert algorithm.config.policy_update_freq == 3
|
||||
|
||||
|
||||
def test_make_algorithm_raises_for_unknown_type():
|
||||
class FakeConfig:
|
||||
type = "unknown_algo"
|
||||
|
||||
with pytest.raises(ValueError, match="No RLAlgorithmConfig"):
|
||||
make_algorithm(policy=None, policy_cfg=FakeConfig(), algorithm_name="unknown_algo")
|
||||
|
||||
|
||||
# ===========================================================================
|
||||
# load_weights (round-trip with get_weights)
|
||||
# ===========================================================================
|
||||
|
||||
|
||||
def test_load_weights_round_trip():
|
||||
"""get_weights -> load_weights should restore identical parameters on a fresh policy."""
|
||||
algo_src, _ = _make_algorithm(state_dim=10, action_dim=6)
|
||||
algo_src.update(_batch_iterator())
|
||||
|
||||
sac_cfg = _make_sac_config(state_dim=10, action_dim=6)
|
||||
policy_dst = SACPolicy(config=sac_cfg)
|
||||
algo_dst = SACAlgorithm(policy=policy_dst, config=algo_src.config)
|
||||
|
||||
weights = algo_src.get_weights()
|
||||
algo_dst.load_weights(weights, device="cpu")
|
||||
|
||||
for key in weights:
|
||||
assert torch.equal(
|
||||
algo_dst.policy.state_dict()[key].cpu(),
|
||||
weights[key].cpu(),
|
||||
), f"Policy param '{key}' mismatch after load_weights"
|
||||
|
||||
|
||||
def test_load_weights_round_trip_with_discrete_critic():
|
||||
algo_src, _ = _make_algorithm(num_discrete_actions=3, action_dim=6)
|
||||
algo_src.update(_batch_iterator(action_dim=7))
|
||||
|
||||
sac_cfg = _make_sac_config(num_discrete_actions=3, action_dim=6)
|
||||
policy_dst = SACPolicy(config=sac_cfg)
|
||||
algo_dst = SACAlgorithm(policy=policy_dst, config=algo_src.config)
|
||||
|
||||
weights = algo_src.get_weights()
|
||||
algo_dst.load_weights(weights, device="cpu")
|
||||
|
||||
dc_keys = [k for k in weights if k.startswith("discrete_critic.")]
|
||||
assert len(dc_keys) > 0
|
||||
for key in dc_keys:
|
||||
assert torch.equal(
|
||||
algo_dst.policy.state_dict()[key].cpu(),
|
||||
weights[key].cpu(),
|
||||
), f"Discrete critic param '{key}' mismatch after load_weights"
|
||||
|
||||
|
||||
def test_load_weights_ignores_missing_discrete_critic():
|
||||
"""load_weights should not fail when weights lack discrete_critic on a non-discrete policy."""
|
||||
algorithm, _ = _make_algorithm()
|
||||
weights = algorithm.get_weights()
|
||||
algorithm.load_weights(weights, device="cpu")
|
||||
|
||||
|
||||
# ===========================================================================
|
||||
# TrainingStats generic losses dict
|
||||
# ===========================================================================
|
||||
|
||||
|
||||
def test_training_stats_generic_losses():
|
||||
stats = TrainingStats(
|
||||
losses={"loss_bc": 0.5, "loss_q": 1.2},
|
||||
extra={"temperature": 0.1},
|
||||
)
|
||||
assert stats.losses["loss_bc"] == 0.5
|
||||
assert stats.losses["loss_q"] == 1.2
|
||||
assert stats.extra["temperature"] == 0.1
|
||||
|
||||
|
||||
# ===========================================================================
|
||||
# Registry-driven build_algorithm
|
||||
# ===========================================================================
|
||||
|
||||
|
||||
def test_build_algorithm_via_config():
|
||||
"""SACAlgorithmConfig.build_algorithm should produce a working SACAlgorithm."""
|
||||
sac_cfg = _make_sac_config(utd_ratio=2)
|
||||
algo_config = SACAlgorithmConfig.from_policy_config(sac_cfg)
|
||||
policy = SACPolicy(config=sac_cfg)
|
||||
|
||||
algorithm = algo_config.build_algorithm(policy)
|
||||
assert isinstance(algorithm, SACAlgorithm)
|
||||
assert algorithm.config.utd_ratio == 2
|
||||
|
||||
|
||||
def test_make_algorithm_uses_build_algorithm():
|
||||
"""make_algorithm should delegate to config.build_algorithm (no hardcoded if/else)."""
|
||||
sac_cfg = _make_sac_config()
|
||||
policy = SACPolicy(config=sac_cfg)
|
||||
algorithm = make_algorithm(policy=policy, policy_cfg=sac_cfg, algorithm_name="sac")
|
||||
assert isinstance(algorithm, SACAlgorithm)
|
||||
@@ -0,0 +1,115 @@
|
||||
#!/usr/bin/env python
|
||||
|
||||
# Copyright 2026 The HuggingFace Inc. team. All rights reserved.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
import torch
|
||||
from torch import Tensor
|
||||
|
||||
from lerobot.rl.algorithms.base import RLAlgorithm, TrainingStats
|
||||
from lerobot.rl.trainer import RLTrainer
|
||||
from lerobot.utils.constants import ACTION, OBS_STATE
|
||||
|
||||
|
||||
class _CountingAlgorithm(RLAlgorithm):
|
||||
def __init__(self):
|
||||
self.configure_calls = 0
|
||||
self.update_calls = 0
|
||||
|
||||
def select_action(self, observation: dict[str, Tensor]) -> Tensor:
|
||||
return torch.zeros(1)
|
||||
|
||||
def configure_data_iterator(
|
||||
self,
|
||||
data_mixer,
|
||||
batch_size: int,
|
||||
*,
|
||||
async_prefetch: bool = True,
|
||||
queue_size: int = 2,
|
||||
):
|
||||
self.configure_calls += 1
|
||||
return data_mixer.get_iterator(
|
||||
batch_size=batch_size,
|
||||
async_prefetch=async_prefetch,
|
||||
queue_size=queue_size,
|
||||
)
|
||||
|
||||
def make_optimizers(self):
|
||||
return {}
|
||||
|
||||
def update(self, batch_iterator):
|
||||
self.update_calls += 1
|
||||
_ = next(batch_iterator)
|
||||
return TrainingStats(losses={"dummy": 1.0})
|
||||
|
||||
def load_weights(self, weights, device="cpu") -> None:
|
||||
_ = (weights, device)
|
||||
|
||||
|
||||
class _SimpleMixer:
|
||||
def get_iterator(self, batch_size: int, async_prefetch: bool = True, queue_size: int = 2):
|
||||
_ = (async_prefetch, queue_size)
|
||||
while True:
|
||||
yield {
|
||||
"state": {OBS_STATE: torch.randn(batch_size, 3)},
|
||||
ACTION: torch.randn(batch_size, 2),
|
||||
"reward": torch.randn(batch_size),
|
||||
"next_state": {OBS_STATE: torch.randn(batch_size, 3)},
|
||||
"done": torch.zeros(batch_size),
|
||||
"truncated": torch.zeros(batch_size),
|
||||
"complementary_info": None,
|
||||
}
|
||||
|
||||
|
||||
def test_trainer_lazy_iterator_lifecycle_and_reset():
|
||||
algo = _CountingAlgorithm()
|
||||
mixer = _SimpleMixer()
|
||||
trainer = RLTrainer(algorithm=algo, data_mixer=mixer, batch_size=4, async_prefetch=False)
|
||||
|
||||
# First call builds iterator once.
|
||||
trainer.training_step()
|
||||
assert algo.configure_calls == 1
|
||||
assert algo.update_calls == 1
|
||||
|
||||
# Second call reuses existing iterator.
|
||||
trainer.training_step()
|
||||
assert algo.configure_calls == 1
|
||||
assert algo.update_calls == 2
|
||||
|
||||
# Explicit reset forces lazy rebuild on next step.
|
||||
trainer.reset_data_iterator()
|
||||
trainer.training_step()
|
||||
assert algo.configure_calls == 2
|
||||
assert algo.update_calls == 3
|
||||
|
||||
|
||||
def test_trainer_set_data_mixer_resets_by_default():
|
||||
algo = _CountingAlgorithm()
|
||||
mixer_a = _SimpleMixer()
|
||||
mixer_b = _SimpleMixer()
|
||||
trainer = RLTrainer(algorithm=algo, data_mixer=mixer_a, batch_size=2, async_prefetch=False)
|
||||
|
||||
trainer.training_step()
|
||||
assert algo.configure_calls == 1
|
||||
|
||||
trainer.set_data_mixer(mixer_b, reset=True)
|
||||
trainer.training_step()
|
||||
assert algo.configure_calls == 2
|
||||
|
||||
|
||||
def test_algorithm_optimization_step_contract_defaults():
|
||||
algo = _CountingAlgorithm()
|
||||
assert algo.optimization_step == 0
|
||||
algo.optimization_step = 11
|
||||
assert algo.optimization_step == 11
|
||||
Reference in New Issue
Block a user