mirror of
https://github.com/huggingface/lerobot.git
synced 2026-05-22 03:59:42 +00:00
Compare commits
5 Commits
| Author | SHA1 | Date | |
|---|---|---|---|
| c0a2e9814d | |||
| bac4f61eae | |||
| f4b834844e | |||
| dfdc48a7f1 | |||
| 6a8878a639 |
+6
-10
@@ -79,17 +79,13 @@ If your local computer doesn't have a powerful GPU, you can utilize Google Colab
|
|||||||
Once training is complete, you can evaluate your ACT policy using the `lerobot-record` command with your trained policy. This will run inference and record evaluation episodes:
|
Once training is complete, you can evaluate your ACT policy using the `lerobot-record` command with your trained policy. This will run inference and record evaluation episodes:
|
||||||
|
|
||||||
```bash
|
```bash
|
||||||
lerobot-record \
|
lerobot-rollout \
|
||||||
--robot.type=so100_follower \
|
--strategy.type=base \
|
||||||
|
--policy.path=${HF_USER}/act_policy \
|
||||||
|
--robot.type=so101_follower \
|
||||||
--robot.port=/dev/ttyACM0 \
|
--robot.port=/dev/ttyACM0 \
|
||||||
--robot.id=my_robot \
|
|
||||||
--robot.cameras="{ front: {type: opencv, index_or_path: 0, width: 640, height: 480, fps: 30}}" \
|
--robot.cameras="{ front: {type: opencv, index_or_path: 0, width: 640, height: 480, fps: 30}}" \
|
||||||
--display_data=true \
|
--display_data=true \
|
||||||
--dataset.repo_id=${HF_USER}/eval_act_your_dataset \
|
--task="Your task description" \ # can be skipped for ACT
|
||||||
--dataset.num_episodes=10 \
|
--duration=60
|
||||||
--dataset.single_task="Your task description" \
|
|
||||||
--dataset.streaming_encoding=true \
|
|
||||||
--dataset.encoder_threads=2 \
|
|
||||||
# --dataset.camera_encoder.vcodec=auto \
|
|
||||||
--policy.path=${HF_USER}/act_policy
|
|
||||||
```
|
```
|
||||||
|
|||||||
@@ -105,10 +105,12 @@ These results demonstrate GR00T's strong generalization capabilities across dive
|
|||||||
|
|
||||||
### Evaluate in your hardware setup
|
### Evaluate in your hardware setup
|
||||||
|
|
||||||
Once you have trained your model using your parameters you can run inference in your downstream task. Follow the instructions in [Imitation Learning for Robots](./il_robots). For example:
|
Once you have trained your model using your parameters you can run inference in your downstream task. Follow the instructions in [Policy Deployment (lerobot-rollout)](./inference). For example:
|
||||||
|
|
||||||
```bash
|
```bash
|
||||||
lerobot-record \
|
lerobot-rollout\
|
||||||
|
--strategy.type=sentry \
|
||||||
|
--strategy.upload_every_n_episodes=5 \
|
||||||
--robot.type=bi_so_follower \
|
--robot.type=bi_so_follower \
|
||||||
--robot.left_arm_port=/dev/ttyACM1 \
|
--robot.left_arm_port=/dev/ttyACM1 \
|
||||||
--robot.right_arm_port=/dev/ttyACM0 \
|
--robot.right_arm_port=/dev/ttyACM0 \
|
||||||
@@ -119,14 +121,12 @@ lerobot-record \
|
|||||||
}' \
|
}' \
|
||||||
--display_data=true \
|
--display_data=true \
|
||||||
--dataset.repo_id=<user>/eval_groot-bimanual \
|
--dataset.repo_id=<user>/eval_groot-bimanual \
|
||||||
--dataset.num_episodes=10 \
|
|
||||||
--dataset.single_task="Grab and handover the red cube to the other arm" \
|
--dataset.single_task="Grab and handover the red cube to the other arm" \
|
||||||
--dataset.streaming_encoding=true \
|
--dataset.streaming_encoding=true \
|
||||||
--dataset.encoder_threads=2 \
|
--dataset.encoder_threads=2 \
|
||||||
# --dataset.camera_encoder.vcodec=auto \
|
# --dataset.camera_encoder.vcodec=auto \
|
||||||
--policy.path=<user>/groot-bimanual \ # your trained model
|
--policy.path=<user>/groot-bimanual \ # your trained model
|
||||||
--dataset.episode_time_s=30 \
|
--duration=600
|
||||||
--dataset.reset_time_s=10
|
|
||||||
```
|
```
|
||||||
|
|
||||||
## License
|
## License
|
||||||
|
|||||||
+168
-67
@@ -68,13 +68,13 @@ from lerobot.teleoperators.so_leader import SO101Leader, SO101LeaderConfig
|
|||||||
from lerobot.robots.so_follower import SO101Follower, SO101FollowerConfig
|
from lerobot.robots.so_follower import SO101Follower, SO101FollowerConfig
|
||||||
|
|
||||||
robot_config = SO101FollowerConfig(
|
robot_config = SO101FollowerConfig(
|
||||||
port="/dev/tty.usbmodem58760431541",
|
port="/dev/tty.usbmodem5AB90687491",
|
||||||
id="my_red_robot_arm",
|
id="my_follower_arm",
|
||||||
)
|
)
|
||||||
|
|
||||||
teleop_config = SO101LeaderConfig(
|
teleop_config = SO101LeaderConfig(
|
||||||
port="/dev/tty.usbmodem58760431551",
|
port="/dev/tty.usbmodem5AB90689011",
|
||||||
id="my_blue_leader_arm",
|
id="my_leader_arm",
|
||||||
)
|
)
|
||||||
|
|
||||||
robot = SO101Follower(robot_config)
|
robot = SO101Follower(robot_config)
|
||||||
@@ -108,13 +108,13 @@ With `rerun`, you can teleoperate again while simultaneously visualizing the cam
|
|||||||
<hfoption id="Command">
|
<hfoption id="Command">
|
||||||
```bash
|
```bash
|
||||||
lerobot-teleoperate \
|
lerobot-teleoperate \
|
||||||
--robot.type=koch_follower \
|
--robot.type=so101_follower \
|
||||||
--robot.port=/dev/tty.usbmodem58760431541 \
|
--robot.port=/dev/tty.usbmodem5AB90687491 \
|
||||||
--robot.id=my_awesome_follower_arm \
|
--robot.id=my_follower_arm \
|
||||||
--robot.cameras="{ front: {type: opencv, index_or_path: 0, width: 1920, height: 1080, fps: 30}}" \
|
--robot.cameras="{front: {type: opencv, index_or_path: 0, width: 640, height: 480, fps: 30}}" \
|
||||||
--teleop.type=koch_leader \
|
--teleop.type=so101_leader \
|
||||||
--teleop.port=/dev/tty.usbmodem58760431551 \
|
--teleop.port=/dev/tty.usbmodem5AB90689011 \
|
||||||
--teleop.id=my_awesome_leader_arm \
|
--teleop.id=my_leader_arm \
|
||||||
--display_data=true
|
--display_data=true
|
||||||
```
|
```
|
||||||
</hfoption>
|
</hfoption>
|
||||||
@@ -122,34 +122,48 @@ lerobot-teleoperate \
|
|||||||
|
|
||||||
<!-- prettier-ignore-start -->
|
<!-- prettier-ignore-start -->
|
||||||
```python
|
```python
|
||||||
|
import time
|
||||||
|
from lerobot.teleoperators.so_leader import SO101Leader, SO101LeaderConfig
|
||||||
|
from lerobot.robots.so_follower import SO101Follower, SO101FollowerConfig
|
||||||
from lerobot.cameras.opencv import OpenCVCameraConfig
|
from lerobot.cameras.opencv import OpenCVCameraConfig
|
||||||
from lerobot.teleoperators.koch_leader import KochLeader, KochLeaderConfig
|
from lerobot.utils.visualization_utils import init_rerun, log_rerun_data, shutdown_rerun
|
||||||
from lerobot.robots.koch_follower import KochFollower, KochFollowerConfig
|
|
||||||
|
|
||||||
camera_config = {
|
robot_config = SO101FollowerConfig(
|
||||||
"front": OpenCVCameraConfig(index_or_path=0, width=1920, height=1080, fps=30)
|
port="/dev/tty.usbmodem5AB90687491",
|
||||||
}
|
id="my_follower_arm",
|
||||||
|
cameras={
|
||||||
robot_config = KochFollowerConfig(
|
"wrist": OpenCVCameraConfig(index_or_path=0, width=640, height=480, fps=30),
|
||||||
port="/dev/tty.usbmodem585A0076841",
|
"top": OpenCVCameraConfig(index_or_path=1, width=640, height=480, fps=30)
|
||||||
id="my_red_robot_arm",
|
}
|
||||||
cameras=camera_config
|
|
||||||
)
|
)
|
||||||
|
|
||||||
teleop_config = KochLeaderConfig(
|
teleop_config = SO101LeaderConfig(
|
||||||
port="/dev/tty.usbmodem58760431551",
|
port="/dev/tty.usbmodem5AB90689011",
|
||||||
id="my_blue_leader_arm",
|
id="my_leader_arm",
|
||||||
)
|
)
|
||||||
|
|
||||||
robot = KochFollower(robot_config)
|
init_rerun(session_name="teleoperation")
|
||||||
teleop_device = KochLeader(teleop_config)
|
|
||||||
|
robot = SO101Follower(robot_config)
|
||||||
|
teleop_device = SO101Leader(teleop_config)
|
||||||
robot.connect()
|
robot.connect()
|
||||||
teleop_device.connect()
|
teleop_device.connect()
|
||||||
|
|
||||||
|
TARGET_HZ = 30
|
||||||
|
TIME_PER_FRAME = 1.0 / TARGET_HZ
|
||||||
|
|
||||||
while True:
|
while True:
|
||||||
|
start_time = time.perf_counter()
|
||||||
|
|
||||||
observation = robot.get_observation()
|
observation = robot.get_observation()
|
||||||
action = teleop_device.get_action()
|
action = teleop_device.get_action()
|
||||||
robot.send_action(action)
|
robot.send_action(action)
|
||||||
|
log_rerun_data(observation=observation, action=action)
|
||||||
|
|
||||||
|
elapsed_time = time.perf_counter() - start_time
|
||||||
|
sleep_time = TIME_PER_FRAME - elapsed_time
|
||||||
|
if sleep_time > 0:
|
||||||
|
time.sleep(sleep_time)
|
||||||
```
|
```
|
||||||
<!-- prettier-ignore-end -->
|
<!-- prettier-ignore-end -->
|
||||||
|
|
||||||
@@ -202,10 +216,11 @@ lerobot-record \
|
|||||||
<!-- prettier-ignore-start -->
|
<!-- prettier-ignore-start -->
|
||||||
```python
|
```python
|
||||||
from lerobot.cameras.opencv import OpenCVCameraConfig
|
from lerobot.cameras.opencv import OpenCVCameraConfig
|
||||||
from lerobot.datasets import LeRobotDataset
|
from lerobot.datasets.lerobot_dataset import LeRobotDataset
|
||||||
from lerobot.utils.feature_utils import hw_to_dataset_features
|
from lerobot.utils.feature_utils import hw_to_dataset_features
|
||||||
from lerobot.robots.so_follower import SO100Follower, SO100FollowerConfig
|
from lerobot.robots.so_follower import SO101Follower, SO101FollowerConfig
|
||||||
from lerobot.teleoperators.so_leader import SO100Leader, SO100LeaderConfig
|
from lerobot.teleoperators.so_leader.config_so_leader import SO101LeaderConfig
|
||||||
|
from lerobot.teleoperators.so_leader.so_leader import SO101Leader
|
||||||
from lerobot.common.control_utils import init_keyboard_listener
|
from lerobot.common.control_utils import init_keyboard_listener
|
||||||
from lerobot.utils.utils import log_say
|
from lerobot.utils.utils import log_say
|
||||||
from lerobot.utils.visualization_utils import init_rerun
|
from lerobot.utils.visualization_utils import init_rerun
|
||||||
@@ -218,52 +233,54 @@ EPISODE_TIME_SEC = 60
|
|||||||
RESET_TIME_SEC = 10
|
RESET_TIME_SEC = 10
|
||||||
TASK_DESCRIPTION = "My task description"
|
TASK_DESCRIPTION = "My task description"
|
||||||
|
|
||||||
# Create robot configuration
|
def main():
|
||||||
robot_config = SO100FollowerConfig(
|
# Create robot configuration
|
||||||
id="my_awesome_follower_arm",
|
robot_config = SO101FollowerConfig(
|
||||||
|
port="/dev/tty.usbmodem5AB90687491",
|
||||||
|
id="my_follower_arm",
|
||||||
cameras={
|
cameras={
|
||||||
"front": OpenCVCameraConfig(index_or_path=0, width=640, height=480, fps=FPS) # Optional: fourcc="MJPG" for troubleshooting OpenCV async error.
|
"wrist": OpenCVCameraConfig(index_or_path=0, width=640, height=480, fps=30),
|
||||||
},
|
"top": OpenCVCameraConfig(index_or_path=1, width=640, height=480, fps=30)
|
||||||
port="/dev/tty.usbmodem58760434471",
|
}
|
||||||
)
|
)
|
||||||
|
|
||||||
teleop_config = SO100LeaderConfig(
|
teleop_config = SO101LeaderConfig(
|
||||||
id="my_awesome_leader_arm",
|
port="/dev/tty.usbmodem5AB90689011",
|
||||||
port="/dev/tty.usbmodem585A0077581",
|
id="my_leader_arm",
|
||||||
)
|
)
|
||||||
|
|
||||||
# Initialize the robot and teleoperator
|
# Initialize the robot and teleoperator
|
||||||
robot = SO100Follower(robot_config)
|
robot = SO101Follower(robot_config)
|
||||||
teleop = SO100Leader(teleop_config)
|
teleop = SO101Leader(teleop_config)
|
||||||
|
|
||||||
# Configure the dataset features
|
# Configure the dataset features
|
||||||
action_features = hw_to_dataset_features(robot.action_features, "action")
|
action_features = hw_to_dataset_features(robot.action_features, "action")
|
||||||
obs_features = hw_to_dataset_features(robot.observation_features, "observation")
|
obs_features = hw_to_dataset_features(robot.observation_features, "observation")
|
||||||
dataset_features = {**action_features, **obs_features}
|
dataset_features = {**action_features, **obs_features}
|
||||||
|
|
||||||
# Create the dataset
|
# Create the dataset
|
||||||
dataset = LeRobotDataset.create(
|
dataset = LeRobotDataset.create(
|
||||||
repo_id="<hf_username>/<dataset_repo_id>",
|
repo_id="<hf_username>/<dataset_repo_id>",
|
||||||
fps=FPS,
|
fps=FPS,
|
||||||
features=dataset_features,
|
features=dataset_features,
|
||||||
robot_type=robot.name,
|
robot_type=robot.name,
|
||||||
use_videos=True,
|
use_videos=True,
|
||||||
image_writer_threads=4,
|
image_writer_threads=4,
|
||||||
)
|
)
|
||||||
|
|
||||||
# Initialize the keyboard listener and rerun visualization
|
# Initialize the keyboard listener and rerun visualization
|
||||||
_, events = init_keyboard_listener()
|
_, events = init_keyboard_listener()
|
||||||
init_rerun(session_name="recording")
|
init_rerun(session_name="recording")
|
||||||
|
|
||||||
# Connect the robot and teleoperator
|
# Connect the robot and teleoperator
|
||||||
robot.connect()
|
robot.connect()
|
||||||
teleop.connect()
|
teleop.connect()
|
||||||
|
|
||||||
# Create the required processors
|
# Create the required processors
|
||||||
teleop_action_processor, robot_action_processor, robot_observation_processor = make_default_processors()
|
teleop_action_processor, robot_action_processor, robot_observation_processor = make_default_processors()
|
||||||
|
|
||||||
episode_idx = 0
|
episode_idx = 0
|
||||||
while episode_idx < NUM_EPISODES and not events["stop_recording"]:
|
while episode_idx < NUM_EPISODES and not events["stop_recording"]:
|
||||||
log_say(f"Recording episode {episode_idx + 1} of {NUM_EPISODES}")
|
log_say(f"Recording episode {episode_idx + 1} of {NUM_EPISODES}")
|
||||||
|
|
||||||
record_loop(
|
record_loop(
|
||||||
@@ -306,11 +323,18 @@ while episode_idx < NUM_EPISODES and not events["stop_recording"]:
|
|||||||
dataset.save_episode()
|
dataset.save_episode()
|
||||||
episode_idx += 1
|
episode_idx += 1
|
||||||
|
|
||||||
# Clean up
|
# finalize dataset
|
||||||
log_say("Stop recording")
|
log_say("Finalizing dataset...")
|
||||||
robot.disconnect()
|
dataset.finalize()
|
||||||
teleop.disconnect()
|
# Clean up
|
||||||
dataset.push_to_hub()
|
log_say("Stop recording")
|
||||||
|
robot.disconnect()
|
||||||
|
teleop.disconnect()
|
||||||
|
dataset.push_to_hub()
|
||||||
|
|
||||||
|
|
||||||
|
if __name__ == "__main__":
|
||||||
|
main()
|
||||||
```
|
```
|
||||||
<!-- prettier-ignore-end -->
|
<!-- prettier-ignore-end -->
|
||||||
|
|
||||||
@@ -348,7 +372,7 @@ The `record` function provides a suite of tools for capturing and managing data
|
|||||||
##### 2. Checkpointing and Resuming
|
##### 2. Checkpointing and Resuming
|
||||||
|
|
||||||
- Checkpoints are automatically created during recording.
|
- Checkpoints are automatically created during recording.
|
||||||
- If an issue occurs, you can resume by re-running the same command with `--resume=true`. When resuming a recording, `--dataset.num_episodes` must be set to the **number of additional episodes to be recorded**, and not to the targeted total number of episodes in the dataset !
|
- If an issue occurs or you want to record additional episodes in the same dataset, you can resume by re-running the same command with `--resume=true`. When resuming a recording, `--dataset.num_episodes` must be set to the **number of additional episodes to be recorded**, and not to the targeted total number of episodes in the dataset! Make sure that you also set `--dataset.root="local_path"`, it's a local path to save the new part of the dataset and is required to resume.
|
||||||
- To start recording from scratch, **manually delete** the dataset directory.
|
- To start recording from scratch, **manually delete** the dataset directory.
|
||||||
|
|
||||||
##### 3. Recording Parameters
|
##### 3. Recording Parameters
|
||||||
@@ -422,7 +446,7 @@ from lerobot.utils.utils import log_say
|
|||||||
|
|
||||||
episode_idx = 0
|
episode_idx = 0
|
||||||
|
|
||||||
robot_config = SO100FollowerConfig(port="/dev/tty.usbmodem58760434471", id="my_awesome_follower_arm")
|
robot_config = SO100FollowerConfig(port="/dev/tty.usbmodem5AB90687491", id="my_follower_arm")
|
||||||
|
|
||||||
robot = SO100Follower(robot_config)
|
robot = SO100Follower(robot_config)
|
||||||
robot.connect()
|
robot.connect()
|
||||||
@@ -490,6 +514,83 @@ Additionally you can provide extra `tags` or specify a `license` for your model
|
|||||||
|
|
||||||
If your local computer doesn't have a powerful GPU you could utilize Google Colab to train your model by following the [ACT training notebook](./notebooks#training-act).
|
If your local computer doesn't have a powerful GPU you could utilize Google Colab to train your model by following the [ACT training notebook](./notebooks#training-act).
|
||||||
|
|
||||||
|
#### Train using Hugging Face Jobs
|
||||||
|
|
||||||
|
Hugging Face jobs let's you easily select hardware and run the training in the cloud. So if you don't have a powerful GPU or you need more VRAM or just want to train a model much faster use HF Jobs! It's pay as you go and you simply pay for each second of use, you can see the pricing and additional information [here](https://huggingface.co/docs/hub/jobs).
|
||||||
|
|
||||||
|
To run the training use this command:
|
||||||
|
|
||||||
|
<hfoptions id="train_with_hf_jobs">
|
||||||
|
<hfoption id="Command">
|
||||||
|
```bash
|
||||||
|
hf jobs run \
|
||||||
|
--flavor a10g-small \
|
||||||
|
--timeout 4h \
|
||||||
|
--secrets HF_TOKEN \
|
||||||
|
huggingface/lerobot-gpu:latest \
|
||||||
|
-- \
|
||||||
|
python -m lerobot.scripts.lerobot_train \
|
||||||
|
--dataset.repo_id=username/dataset \
|
||||||
|
--policy.type=act \
|
||||||
|
--steps=5000 \
|
||||||
|
--batch_size=16 \
|
||||||
|
--policy.device=cuda \
|
||||||
|
--policy.repo_id=username/your_policy \
|
||||||
|
--log_freq=100
|
||||||
|
```
|
||||||
|
</hfoption>
|
||||||
|
<hfoption id="API example">
|
||||||
|
|
||||||
|
<!-- prettier-ignore-start -->
|
||||||
|
```python
|
||||||
|
from huggingface_hub import run_job, get_token
|
||||||
|
|
||||||
|
run_name = "act_so101_hf_jobs"
|
||||||
|
dataset_id = "username/dataset"
|
||||||
|
user_hub_id = "username"
|
||||||
|
|
||||||
|
command_args = [
|
||||||
|
"python", "-m", "lerobot.scripts.lerobot_train",
|
||||||
|
"--dataset.repo_id", dataset_id,
|
||||||
|
"--policy.type", "act",
|
||||||
|
"--steps", "5000",
|
||||||
|
"--batch_size", "16",
|
||||||
|
"--num_workers", "4",
|
||||||
|
"--policy.device", "cuda",
|
||||||
|
"--log_freq", "100",
|
||||||
|
"--save_freq", "1000",
|
||||||
|
"--save_checkpoint", "true",
|
||||||
|
"--wandb.enable", "false",
|
||||||
|
"--policy.repo_id", f"{user_hub_id}/{run_name}"
|
||||||
|
]
|
||||||
|
|
||||||
|
print(f"Submitting job '{run_name}' to Hugging Face Infrastructure...")
|
||||||
|
|
||||||
|
job_info = run_job(
|
||||||
|
image="huggingface/lerobot-gpu:latest",
|
||||||
|
command=command_args,
|
||||||
|
flavor="a10g-small",
|
||||||
|
timeout="4h",
|
||||||
|
secrets={"HF_TOKEN": get_token()}
|
||||||
|
)
|
||||||
|
|
||||||
|
print("\n🚀 Job successfully launched!")
|
||||||
|
print(f"🔹 Job ID: {job_info.id}")
|
||||||
|
print(f"🔗 Live UI Dashboard & Logs: {job_info.url}")
|
||||||
|
```
|
||||||
|
<!-- prettier-ignore-end -->
|
||||||
|
|
||||||
|
</hfoption>
|
||||||
|
</hfoptions>
|
||||||
|
|
||||||
|
You can modify the `--flavor` to use different hardware, for example: `t4-small`, `a100-large`, `h200`. Use `hf jobs hardware` to see the full list with pricing.
|
||||||
|
Depending on the model you want to train and the hardware you selected you can also modify the `--batch_size` and `--number_of_workers`.
|
||||||
|
For longer training sessions increase the timeout.
|
||||||
|
|
||||||
|
Once the training is started you can go to [Jobs](https://huggingface.co/settings/jobs) and see if your jobs is running as well as all the outputs. Sometimes it takes a few minutes to schedule your job so be patient.
|
||||||
|
|
||||||
|
After training the model will be pushed to hub and you can use it as any other model with LeRobot.
|
||||||
|
|
||||||
#### Upload policy checkpoints
|
#### Upload policy checkpoints
|
||||||
|
|
||||||
Once training is done, upload the latest checkpoint with:
|
Once training is done, upload the latest checkpoint with:
|
||||||
|
|||||||
@@ -97,22 +97,22 @@ Similarly for when recording an episode, it is recommended that you are logged i
|
|||||||
Once you are logged in, you can run inference in your setup by doing:
|
Once you are logged in, you can run inference in your setup by doing:
|
||||||
|
|
||||||
```bash
|
```bash
|
||||||
lerobot-record \
|
lerobot-rollout \
|
||||||
|
--strategy.type=base \
|
||||||
--robot.type=so101_follower \
|
--robot.type=so101_follower \
|
||||||
--robot.port=/dev/ttyACM0 \ # <- Use your port
|
--robot.port=/dev/ttyACM0 \ # <- Use your port
|
||||||
--robot.id=my_blue_follower_arm \ # <- Use your robot id
|
--robot.id=my_blue_follower_arm \ # <- Use your robot id
|
||||||
--robot.cameras="{ front: {type: opencv, index_or_path: 8, width: 640, height: 480, fps: 30}}" \ # <- Use your cameras
|
--robot.cameras="{ front: {type: opencv, index_or_path: 8, width: 640, height: 480, fps: 30}}" \ # <- Use your cameras
|
||||||
--dataset.single_task="Grasp a lego block and put it in the bin." \ # <- Use the same task description you used in your dataset recording
|
--task="Grasp a lego block and put it in the bin." \ # <- Use the same task description you used in your dataset recording
|
||||||
--dataset.repo_id=${HF_USER}/eval_DATASET_NAME_test \ # <- This will be the dataset name on HF Hub
|
# <- RTC optional, use when running on low power hardware \
|
||||||
--dataset.episode_time_s=50 \
|
# --inference.type=rtc \
|
||||||
--dataset.num_episodes=10 \
|
# --inference.rtc.execution_horizon=10 \
|
||||||
--dataset.streaming_encoding=true \
|
# --inference.rtc.max_guidance_weight=10.0 \
|
||||||
--dataset.encoder_threads=2 \
|
|
||||||
# --dataset.camera_encoder.vcodec=auto \
|
|
||||||
# <- Teleop optional if you want to teleoperate in between episodes \
|
# <- Teleop optional if you want to teleoperate in between episodes \
|
||||||
# --teleop.type=so100_leader \
|
# --teleop.type=so100_leader \
|
||||||
# --teleop.port=/dev/ttyACM0 \
|
# --teleop.port=/dev/ttyACM0 \
|
||||||
# --teleop.id=my_red_leader_arm \
|
# --teleop.id=my_red_leader_arm \
|
||||||
|
# --display_data=true #optional use if you want to see the camera stream \
|
||||||
--policy.path=HF_USER/FINETUNE_MODEL_NAME # <- Use your fine-tuned model
|
--policy.path=HF_USER/FINETUNE_MODEL_NAME # <- Use your fine-tuned model
|
||||||
```
|
```
|
||||||
|
|
||||||
|
|||||||
@@ -15,10 +15,12 @@
|
|||||||
# limitations under the License.
|
# limitations under the License.
|
||||||
|
|
||||||
"""
|
"""
|
||||||
Create MP4 (or GIF) videos with sarm_progress overlay for specified episodes.
|
Create MP4 (or GIF) videos with per-frame progress overlay for specified episodes.
|
||||||
|
|
||||||
Downloads datasets from HuggingFace, seeks directly into the episode segment
|
Downloads datasets from HuggingFace, seeks directly into the episode segment
|
||||||
of the source video, draws a progress line on each frame, and writes the result.
|
of the source video, draws a progress line on each frame, and writes the result.
|
||||||
|
The progress data is read from a parquet file that lives alongside the dataset
|
||||||
|
(configurable via ``--progress-file``).
|
||||||
|
|
||||||
Usage:
|
Usage:
|
||||||
python examples/dataset/create_progress_videos.py \
|
python examples/dataset/create_progress_videos.py \
|
||||||
@@ -56,22 +58,26 @@ SCORE_FONT_SCALE = 0.8
|
|||||||
TASK_FONT_SCALE = 0.55
|
TASK_FONT_SCALE = 0.55
|
||||||
|
|
||||||
|
|
||||||
def download_episode_metadata(repo_id: str, episode: int) -> Path:
|
def download_episode_metadata(
|
||||||
"""Download only the metadata and sarm_progress files for a dataset.
|
repo_id: str, episode: int, progress_file: str = "sarm_progress.parquet"
|
||||||
|
) -> Path:
|
||||||
|
"""Download only the metadata and per-frame progress file for a dataset.
|
||||||
|
|
||||||
Args:
|
Args:
|
||||||
repo_id: HuggingFace dataset repository ID.
|
repo_id: HuggingFace dataset repository ID.
|
||||||
episode: Episode index (used for logging only; all meta is fetched).
|
episode: Episode index (used for logging only; all meta is fetched).
|
||||||
|
progress_file: Filename of the per-frame progress parquet inside the
|
||||||
|
dataset repo.
|
||||||
|
|
||||||
Returns:
|
Returns:
|
||||||
Local cache path for the downloaded snapshot.
|
Local cache path for the downloaded snapshot.
|
||||||
"""
|
"""
|
||||||
logging.info("[1/4] Downloading metadata for %s (episode %d) ...", repo_id, episode)
|
logging.info("[1/4] Downloading metadata + %s for %s (episode %d) ...", progress_file, repo_id, episode)
|
||||||
local_path = Path(
|
local_path = Path(
|
||||||
snapshot_download(
|
snapshot_download(
|
||||||
repo_id=repo_id,
|
repo_id=repo_id,
|
||||||
repo_type="dataset",
|
repo_type="dataset",
|
||||||
allow_patterns=["meta/**", "sarm_progress.parquet"],
|
allow_patterns=["meta/**", progress_file],
|
||||||
ignore_patterns=["*.mp4"],
|
ignore_patterns=["*.mp4"],
|
||||||
)
|
)
|
||||||
)
|
)
|
||||||
@@ -215,25 +221,28 @@ def download_video_file(repo_id: str, local_path: Path, video_rel: str) -> Path:
|
|||||||
return video_path
|
return video_path
|
||||||
|
|
||||||
|
|
||||||
def load_progress_data(local_path: Path, episode: int) -> np.ndarray | None:
|
def load_progress_data(
|
||||||
"""Load sarm_progress values for an episode.
|
local_path: Path, episode: int, progress_file: str = "sarm_progress.parquet"
|
||||||
|
) -> np.ndarray | None:
|
||||||
|
"""Load per-frame progress values for an episode.
|
||||||
|
|
||||||
Args:
|
Args:
|
||||||
local_path: Dataset cache root.
|
local_path: Dataset cache root.
|
||||||
episode: Episode index.
|
episode: Episode index.
|
||||||
|
progress_file: Filename of the per-frame progress parquet.
|
||||||
|
|
||||||
Returns:
|
Returns:
|
||||||
Sorted (N, 2) array of (frame_index, progress), or None if unavailable.
|
Sorted (N, 2) array of (frame_index, progress), or None if unavailable.
|
||||||
"""
|
"""
|
||||||
parquet_path = local_path / "sarm_progress.parquet"
|
parquet_path = local_path / progress_file
|
||||||
if not parquet_path.exists():
|
if not parquet_path.exists():
|
||||||
logging.warning("sarm_progress.parquet not found")
|
logging.warning("%s not found", progress_file)
|
||||||
return None
|
return None
|
||||||
df = pd.read_parquet(parquet_path)
|
df = pd.read_parquet(parquet_path)
|
||||||
logging.info(" sarm_progress.parquet columns: %s", list(df.columns))
|
logging.info(" %s columns: %s", progress_file, list(df.columns))
|
||||||
episode_df = df[df["episode_index"] == episode].copy()
|
episode_df = df[df["episode_index"] == episode].copy()
|
||||||
if episode_df.empty:
|
if episode_df.empty:
|
||||||
logging.warning("No sarm_progress rows for episode %d", episode)
|
logging.warning("No progress rows for episode %d in %s", episode, progress_file)
|
||||||
return None
|
return None
|
||||||
episode_df = episode_df.sort_values("frame_index")
|
episode_df = episode_df.sort_values("frame_index")
|
||||||
|
|
||||||
@@ -576,6 +585,7 @@ def process_dataset(
|
|||||||
camera_key: str | None,
|
camera_key: str | None,
|
||||||
output_dir: Path,
|
output_dir: Path,
|
||||||
create_gif: bool = False,
|
create_gif: bool = False,
|
||||||
|
progress_file: str = "sarm_progress.parquet",
|
||||||
) -> Path | None:
|
) -> Path | None:
|
||||||
"""Full pipeline: download, extract metadata, composite progress, write output.
|
"""Full pipeline: download, extract metadata, composite progress, write output.
|
||||||
|
|
||||||
@@ -585,6 +595,8 @@ def process_dataset(
|
|||||||
camera_key: Camera key to use, or None for auto-selection.
|
camera_key: Camera key to use, or None for auto-selection.
|
||||||
output_dir: Directory to write output files.
|
output_dir: Directory to write output files.
|
||||||
create_gif: If True, also generate a GIF from the MP4.
|
create_gif: If True, also generate a GIF from the MP4.
|
||||||
|
progress_file: Filename of the per-frame progress parquet inside the
|
||||||
|
dataset repo.
|
||||||
|
|
||||||
Returns:
|
Returns:
|
||||||
Path to the final output file, or None on failure.
|
Path to the final output file, or None on failure.
|
||||||
@@ -592,7 +604,7 @@ def process_dataset(
|
|||||||
safe_name = repo_id.replace("/", "_")
|
safe_name = repo_id.replace("/", "_")
|
||||||
logging.info("Processing: %s | episode %d", repo_id, episode)
|
logging.info("Processing: %s | episode %d", repo_id, episode)
|
||||||
|
|
||||||
local_path = download_episode_metadata(repo_id, episode)
|
local_path = download_episode_metadata(repo_id, episode, progress_file)
|
||||||
logging.info(" Local cache: %s", local_path)
|
logging.info(" Local cache: %s", local_path)
|
||||||
|
|
||||||
episode_meta = load_episode_meta(local_path, episode, camera_key)
|
episode_meta = load_episode_meta(local_path, episode, camera_key)
|
||||||
@@ -600,9 +612,9 @@ def process_dataset(
|
|||||||
|
|
||||||
video_path = download_video_file(repo_id, local_path, episode_meta["video_rel"])
|
video_path = download_video_file(repo_id, local_path, episode_meta["video_rel"])
|
||||||
|
|
||||||
progress_data = load_progress_data(local_path, episode)
|
progress_data = load_progress_data(local_path, episode, progress_file)
|
||||||
if progress_data is None:
|
if progress_data is None:
|
||||||
logging.error("Could not load sarm_progress data. Skipping overlay.")
|
logging.error("Could not load progress data from %s. Skipping overlay.", progress_file)
|
||||||
return None
|
return None
|
||||||
|
|
||||||
logging.info(" Progress frames: %d", len(progress_data))
|
logging.info(" Progress frames: %d", len(progress_data))
|
||||||
@@ -627,7 +639,7 @@ def process_dataset(
|
|||||||
|
|
||||||
def main() -> None:
|
def main() -> None:
|
||||||
parser = argparse.ArgumentParser(
|
parser = argparse.ArgumentParser(
|
||||||
description="Create MP4/GIF videos with sarm_progress overlay for dataset episodes."
|
description="Create MP4/GIF videos with per-frame progress overlay for dataset episodes."
|
||||||
)
|
)
|
||||||
parser.add_argument(
|
parser.add_argument(
|
||||||
"--repo-id",
|
"--repo-id",
|
||||||
@@ -658,6 +670,15 @@ def main() -> None:
|
|||||||
action="store_true",
|
action="store_true",
|
||||||
help="Also generate a GIF from the MP4 output.",
|
help="Also generate a GIF from the MP4 output.",
|
||||||
)
|
)
|
||||||
|
parser.add_argument(
|
||||||
|
"--progress-file",
|
||||||
|
type=str,
|
||||||
|
default="sarm_progress.parquet",
|
||||||
|
help=(
|
||||||
|
"Filename of the per-frame progress parquet inside the dataset repo "
|
||||||
|
"(default: 'sarm_progress.parquet')."
|
||||||
|
),
|
||||||
|
)
|
||||||
args = parser.parse_args()
|
args = parser.parse_args()
|
||||||
|
|
||||||
logging.basicConfig(level=logging.INFO, format="%(levelname)s: %(message)s")
|
logging.basicConfig(level=logging.INFO, format="%(levelname)s: %(message)s")
|
||||||
@@ -670,6 +691,7 @@ def main() -> None:
|
|||||||
camera_key=args.camera_key,
|
camera_key=args.camera_key,
|
||||||
output_dir=args.output_dir,
|
output_dir=args.output_dir,
|
||||||
create_gif=args.gif,
|
create_gif=args.gif,
|
||||||
|
progress_file=args.progress_file,
|
||||||
)
|
)
|
||||||
|
|
||||||
if result:
|
if result:
|
||||||
|
|||||||
@@ -250,7 +250,14 @@ class DatasetWriter:
|
|||||||
for key, ft in self._meta.features.items():
|
for key, ft in self._meta.features.items():
|
||||||
if key in ["index", "episode_index", "task_index"] or ft["dtype"] in ["image", "video"]:
|
if key in ["index", "episode_index", "task_index"] or ft["dtype"] in ["image", "video"]:
|
||||||
continue
|
continue
|
||||||
episode_buffer[key] = np.stack(episode_buffer[key])
|
stacked_values = np.stack(episode_buffer[key])
|
||||||
|
|
||||||
|
# `shape=(1,)` numeric features are serialized as `datasets.Value`, which expects scalars.
|
||||||
|
# Normalizing to `(N,)` keeps save semantics stable across dependency versions.
|
||||||
|
if tuple(ft["shape"]) == (1,) and ft["dtype"] != "string":
|
||||||
|
stacked_values = stacked_values.reshape(episode_length)
|
||||||
|
|
||||||
|
episode_buffer[key] = stacked_values
|
||||||
|
|
||||||
# Wait for image writer to end, so that episode stats over images can be computed
|
# Wait for image writer to end, so that episode stats over images can be computed
|
||||||
self._wait_image_writer()
|
self._wait_image_writer()
|
||||||
|
|||||||
@@ -17,11 +17,13 @@ import contextlib
|
|||||||
import glob
|
import glob
|
||||||
import importlib
|
import importlib
|
||||||
import logging
|
import logging
|
||||||
|
import os
|
||||||
import queue
|
import queue
|
||||||
import shutil
|
import shutil
|
||||||
import tempfile
|
import tempfile
|
||||||
import threading
|
import threading
|
||||||
import warnings
|
import warnings
|
||||||
|
from collections import OrderedDict
|
||||||
from dataclasses import asdict, dataclass, field
|
from dataclasses import asdict, dataclass, field
|
||||||
from fractions import Fraction
|
from fractions import Fraction
|
||||||
from pathlib import Path
|
from pathlib import Path
|
||||||
@@ -191,15 +193,70 @@ def decode_video_frames_pyav(
|
|||||||
return closest_frames
|
return closest_frames
|
||||||
|
|
||||||
|
|
||||||
class VideoDecoderCache:
|
DEFAULT_DECODER_CACHE_SIZE = 100
|
||||||
"""Thread-safe cache for video decoders to avoid expensive re-initialization."""
|
"""Default LRU capacity for :class:`VideoDecoderCache`.
|
||||||
|
|
||||||
def __init__(self):
|
Sized to comfortably hold a small rolling window of episodes worth of decoders
|
||||||
self._cache: dict[str, tuple[Any, Any]] = {}
|
(typical recipes: 2-4 cameras per episode × tens of episodes in flight) while
|
||||||
|
bounding host RAM. Each cached entry retains a torchcodec ``VideoDecoder`` plus
|
||||||
|
an open ``fsspec`` file handle — on the order of a few MB per entry. Override
|
||||||
|
via the ``LEROBOT_VIDEO_DECODER_CACHE_SIZE`` env var or by passing ``max_size``
|
||||||
|
to the constructor (``None`` restores the legacy unbounded behaviour).
|
||||||
|
"""
|
||||||
|
|
||||||
|
|
||||||
|
def _default_max_cache_size() -> int | None:
|
||||||
|
raw = os.environ.get("LEROBOT_VIDEO_DECODER_CACHE_SIZE")
|
||||||
|
if raw is None:
|
||||||
|
return DEFAULT_DECODER_CACHE_SIZE
|
||||||
|
raw = raw.strip().lower()
|
||||||
|
if raw in ("", "none", "unbounded", "-1"):
|
||||||
|
return None
|
||||||
|
try:
|
||||||
|
value = int(raw)
|
||||||
|
except ValueError as e:
|
||||||
|
raise ValueError(
|
||||||
|
f"LEROBOT_VIDEO_DECODER_CACHE_SIZE must be an integer, 'none', or '-1'; got {raw!r}"
|
||||||
|
) from e
|
||||||
|
if value <= 0:
|
||||||
|
raise ValueError(f"LEROBOT_VIDEO_DECODER_CACHE_SIZE must be positive; got {value}")
|
||||||
|
return value
|
||||||
|
|
||||||
|
|
||||||
|
class VideoDecoderCache:
|
||||||
|
"""Thread-safe LRU cache for torchcodec ``VideoDecoder`` instances.
|
||||||
|
|
||||||
|
Cached entries hold a ``VideoDecoder`` plus the open ``fsspec`` file handle
|
||||||
|
backing it. When the cache is full and a new path is requested, the
|
||||||
|
least-recently-used entry is evicted and its file handle is closed. This
|
||||||
|
bounds host-RAM growth when iterating over datasets with many distinct
|
||||||
|
video files (otherwise each ``DataLoader`` worker pins every decoder it has
|
||||||
|
ever opened until the process exits).
|
||||||
|
|
||||||
|
Args:
|
||||||
|
max_size: Maximum number of decoders to retain. ``None`` disables
|
||||||
|
eviction and restores legacy unbounded behaviour. Defaults to the
|
||||||
|
value of ``LEROBOT_VIDEO_DECODER_CACHE_SIZE`` if set, otherwise
|
||||||
|
:data:`DEFAULT_DECODER_CACHE_SIZE`.
|
||||||
|
"""
|
||||||
|
|
||||||
|
_SENTINEL: ClassVar[object] = object()
|
||||||
|
|
||||||
|
def __init__(self, max_size: int | None | object = _SENTINEL):
|
||||||
|
if max_size is VideoDecoderCache._SENTINEL:
|
||||||
|
max_size = _default_max_cache_size()
|
||||||
|
if max_size is not None and max_size <= 0:
|
||||||
|
raise ValueError(f"max_size must be positive or None; got {max_size}")
|
||||||
|
self.max_size: int | None = max_size # type: ignore[assignment]
|
||||||
|
self._cache: OrderedDict[str, tuple[Any, Any]] = OrderedDict()
|
||||||
self._lock = Lock()
|
self._lock = Lock()
|
||||||
|
|
||||||
|
def __contains__(self, video_path: object) -> bool:
|
||||||
|
with self._lock:
|
||||||
|
return str(video_path) in self._cache
|
||||||
|
|
||||||
def get_decoder(self, video_path: str):
|
def get_decoder(self, video_path: str):
|
||||||
"""Get a cached decoder or create a new one."""
|
"""Get a cached decoder or create a new one, evicting LRU if at capacity."""
|
||||||
if importlib.util.find_spec("torchcodec"):
|
if importlib.util.find_spec("torchcodec"):
|
||||||
from torchcodec.decoders import VideoDecoder
|
from torchcodec.decoders import VideoDecoder
|
||||||
else:
|
else:
|
||||||
@@ -211,7 +268,11 @@ class VideoDecoderCache:
|
|||||||
video_path = str(video_path)
|
video_path = str(video_path)
|
||||||
|
|
||||||
with self._lock:
|
with self._lock:
|
||||||
if video_path not in self._cache:
|
entry = self._cache.get(video_path)
|
||||||
|
if entry is not None:
|
||||||
|
self._cache.move_to_end(video_path)
|
||||||
|
return entry[0]
|
||||||
|
|
||||||
file_handle = fsspec.open(video_path).__enter__()
|
file_handle = fsspec.open(video_path).__enter__()
|
||||||
try:
|
try:
|
||||||
decoder = VideoDecoder(file_handle, seek_mode="approximate")
|
decoder = VideoDecoder(file_handle, seek_mode="approximate")
|
||||||
@@ -220,12 +281,22 @@ class VideoDecoderCache:
|
|||||||
raise
|
raise
|
||||||
self._cache[video_path] = (decoder, file_handle)
|
self._cache[video_path] = (decoder, file_handle)
|
||||||
|
|
||||||
return self._cache[video_path][0]
|
# Evict LRU entries until we are back under the cap. We close
|
||||||
|
# evicted file handles immediately; the associated ``VideoDecoder``
|
||||||
|
# is released to the GC when its last reference goes away.
|
||||||
|
if self.max_size is not None:
|
||||||
|
while len(self._cache) > self.max_size:
|
||||||
|
_evicted_path, (_evicted_decoder, evicted_handle) = self._cache.popitem(last=False)
|
||||||
|
with contextlib.suppress(Exception):
|
||||||
|
evicted_handle.close()
|
||||||
|
|
||||||
|
return decoder
|
||||||
|
|
||||||
def clear(self):
|
def clear(self):
|
||||||
"""Clear the cache and close file handles."""
|
"""Clear the cache and close all file handles."""
|
||||||
with self._lock:
|
with self._lock:
|
||||||
for _, file_handle in self._cache.values():
|
for _, file_handle in self._cache.values():
|
||||||
|
with contextlib.suppress(Exception):
|
||||||
file_handle.close()
|
file_handle.close()
|
||||||
self._cache.clear()
|
self._cache.clear()
|
||||||
|
|
||||||
|
|||||||
@@ -43,6 +43,7 @@ from .tables import (
|
|||||||
CAN_CMD_SET_ZERO,
|
CAN_CMD_SET_ZERO,
|
||||||
DEFAULT_BAUDRATE,
|
DEFAULT_BAUDRATE,
|
||||||
DEFAULT_TIMEOUT_MS,
|
DEFAULT_TIMEOUT_MS,
|
||||||
|
HANDSHAKE_TIMEOUT_S,
|
||||||
MODEL_RESOLUTION,
|
MODEL_RESOLUTION,
|
||||||
MOTOR_LIMIT_PARAMS,
|
MOTOR_LIMIT_PARAMS,
|
||||||
NORMALIZED_DATA,
|
NORMALIZED_DATA,
|
||||||
@@ -215,14 +216,16 @@ class RobstrideMotorsBus(MotorsBusBase):
|
|||||||
self._is_connected = False
|
self._is_connected = False
|
||||||
raise ConnectionError(f"Failed to connect to CAN bus: {e}") from e
|
raise ConnectionError(f"Failed to connect to CAN bus: {e}") from e
|
||||||
|
|
||||||
def _query_status_via_clear_fault(self, motor: NameOrID) -> tuple[bool, can.Message | None]:
|
def _query_status_via_clear_fault(
|
||||||
|
self, motor: NameOrID, timeout: float = RUNNING_TIMEOUT
|
||||||
|
) -> tuple[bool, can.Message | None]:
|
||||||
motor_name = self._get_motor_name(motor)
|
motor_name = self._get_motor_name(motor)
|
||||||
motor_id = self._get_motor_id(motor_name)
|
motor_id = self._get_motor_id(motor_name)
|
||||||
recv_id = self._get_motor_recv_id(motor_name)
|
recv_id = self._get_motor_recv_id(motor_name)
|
||||||
data = [0xFF] * 7 + [CAN_CMD_CLEAR_FAULT]
|
data = [0xFF] * 7 + [CAN_CMD_CLEAR_FAULT]
|
||||||
msg = can.Message(arbitration_id=motor_id, data=data, is_extended_id=False)
|
msg = can.Message(arbitration_id=motor_id, data=data, is_extended_id=False)
|
||||||
self._bus().send(msg)
|
self._bus().send(msg)
|
||||||
return self._recv_status_via_clear_fault(expected_recv_id=recv_id)
|
return self._recv_status_via_clear_fault(expected_recv_id=recv_id, timeout=timeout)
|
||||||
|
|
||||||
def _recv_status_via_clear_fault(
|
def _recv_status_via_clear_fault(
|
||||||
self, expected_recv_id: int | None = None, timeout: float = RUNNING_TIMEOUT
|
self, expected_recv_id: int | None = None, timeout: float = RUNNING_TIMEOUT
|
||||||
@@ -280,7 +283,7 @@ class RobstrideMotorsBus(MotorsBusBase):
|
|||||||
faulted_motors = []
|
faulted_motors = []
|
||||||
|
|
||||||
for motor_name in self.motors:
|
for motor_name in self.motors:
|
||||||
has_fault, msg = self._query_status_via_clear_fault(motor_name)
|
has_fault, msg = self._query_status_via_clear_fault(motor_name, timeout=HANDSHAKE_TIMEOUT_S)
|
||||||
if msg is None:
|
if msg is None:
|
||||||
missing_motors.append(motor_name)
|
missing_motors.append(motor_name)
|
||||||
elif has_fault:
|
elif has_fault:
|
||||||
@@ -505,6 +508,87 @@ class RobstrideMotorsBus(MotorsBusBase):
|
|||||||
|
|
||||||
return responses
|
return responses
|
||||||
|
|
||||||
|
def _recv_all_messages_until_quiet(
|
||||||
|
self,
|
||||||
|
*,
|
||||||
|
timeout: float = RUNNING_TIMEOUT,
|
||||||
|
max_messages: int = 4096,
|
||||||
|
) -> list[can.Message]:
|
||||||
|
"""
|
||||||
|
Receive frames until the bus goes quiet.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
timeout: Poll timeout used for each recv() call. Collection stops
|
||||||
|
when one recv() times out (quiet gap).
|
||||||
|
max_messages: Safety cap to prevent unbounded loops.
|
||||||
|
"""
|
||||||
|
out: list[can.Message] = []
|
||||||
|
max_messages = max(1, max_messages)
|
||||||
|
timeout = max(0.0, timeout)
|
||||||
|
|
||||||
|
try:
|
||||||
|
while len(out) < max_messages:
|
||||||
|
msg = self._bus().recv(timeout=timeout)
|
||||||
|
if msg is None:
|
||||||
|
break
|
||||||
|
out.append(msg)
|
||||||
|
except (can.CanError, OSError) as e:
|
||||||
|
logger.debug(f"Error draining CAN RX queue on {self.port}: {e}")
|
||||||
|
|
||||||
|
return out
|
||||||
|
|
||||||
|
def _process_feedback_messages(self, messages: list[can.Message]) -> set[int]:
|
||||||
|
"""
|
||||||
|
Decode all received feedback frames and update cached motor states.
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
Set of payload recv_ids that were successfully mapped to motors.
|
||||||
|
"""
|
||||||
|
processed_recv_ids: set[int] = set()
|
||||||
|
for msg in messages:
|
||||||
|
if len(msg.data) < 1:
|
||||||
|
logger.debug(
|
||||||
|
f"Dropping short CAN frame on {self.port} "
|
||||||
|
f"(arb=0x{int(msg.arbitration_id):02X}, data={bytes(msg.data).hex()})"
|
||||||
|
)
|
||||||
|
continue
|
||||||
|
|
||||||
|
recv_id = int(msg.data[0])
|
||||||
|
motor_name = self._recv_id_to_motor.get(recv_id)
|
||||||
|
if motor_name is None:
|
||||||
|
logger.debug(
|
||||||
|
f"Unmapped CAN frame on {self.port} "
|
||||||
|
f"(arb=0x{int(msg.arbitration_id):02X}, recv_id=0x{recv_id:02X}, data={bytes(msg.data).hex()})"
|
||||||
|
)
|
||||||
|
continue
|
||||||
|
|
||||||
|
self._process_response(motor_name, msg)
|
||||||
|
processed_recv_ids.add(recv_id)
|
||||||
|
|
||||||
|
return processed_recv_ids
|
||||||
|
|
||||||
|
def flush_rx_queue(self, poll_timeout_s: float = 0.0005, max_messages: int = 4096) -> int:
|
||||||
|
"""
|
||||||
|
Drain pending RX frames from the CAN interface.
|
||||||
|
|
||||||
|
This is used by higher-level controllers to drop stale feedback before issuing
|
||||||
|
a fresh read cycle, so subsequent state reads are based on most recent replies.
|
||||||
|
It should also be called once when a controller instance is created/connected,
|
||||||
|
to clear residual frames left on the interface from previous sessions.
|
||||||
|
"""
|
||||||
|
drained = 0
|
||||||
|
poll_timeout_s = max(0.0, poll_timeout_s)
|
||||||
|
max_messages = max(1, max_messages)
|
||||||
|
try:
|
||||||
|
while drained < max_messages:
|
||||||
|
msg = self._bus().recv(timeout=poll_timeout_s)
|
||||||
|
if msg is None:
|
||||||
|
break
|
||||||
|
drained += 1
|
||||||
|
except (can.CanError, OSError) as e:
|
||||||
|
logger.debug(f"Failed to flush CAN RX queue on {self.port}: {e}")
|
||||||
|
return drained
|
||||||
|
|
||||||
def _speed_control(
|
def _speed_control(
|
||||||
self,
|
self,
|
||||||
motor: NameOrID,
|
motor: NameOrID,
|
||||||
@@ -644,11 +728,14 @@ class RobstrideMotorsBus(MotorsBusBase):
|
|||||||
msg = can.Message(arbitration_id=motor_id, data=data, is_extended_id=False)
|
msg = can.Message(arbitration_id=motor_id, data=data, is_extended_id=False)
|
||||||
self._bus().send(msg)
|
self._bus().send(msg)
|
||||||
recv_id_to_motor[self._get_motor_recv_id(motor)] = motor_name
|
recv_id_to_motor[self._get_motor_recv_id(motor)] = motor_name
|
||||||
|
# Read every feedback frame until RX goes quiet, then decode all of them.
|
||||||
|
# This avoids dropping useful frames when responses from different motors interleave.
|
||||||
|
messages = self._recv_all_messages_until_quiet()
|
||||||
|
processed_recv_ids = self._process_feedback_messages(messages)
|
||||||
|
|
||||||
responses = self._recv_all_responses(list(recv_id_to_motor.keys()), timeout=RUNNING_TIMEOUT)
|
|
||||||
for recv_id, motor_name in recv_id_to_motor.items():
|
for recv_id, motor_name in recv_id_to_motor.items():
|
||||||
if msg := responses.get(recv_id):
|
if recv_id not in processed_recv_ids:
|
||||||
self._process_response(motor_name, msg)
|
logger.warning(f"Packet drop: {motor_name} (ID: 0x{recv_id:02X}). Using last known state.")
|
||||||
|
|
||||||
def _float_to_uint(self, x: float, x_min: float, x_max: float, bits: int) -> int:
|
def _float_to_uint(self, x: float, x_min: float, x_max: float, bits: int) -> int:
|
||||||
"""Convert float to unsigned integer for CAN transmission."""
|
"""Convert float to unsigned integer for CAN transmission."""
|
||||||
@@ -711,7 +798,10 @@ class RobstrideMotorsBus(MotorsBusBase):
|
|||||||
try:
|
try:
|
||||||
self._decode_motor_state(msg.data)
|
self._decode_motor_state(msg.data)
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
logger.warning(f"Failed to decode response from {motor}: {e}")
|
logger.warning(
|
||||||
|
f"Failed to decode response from {motor} "
|
||||||
|
f"(arb=0x{int(msg.arbitration_id):02X}, data={bytes(msg.data).hex()}): {e}"
|
||||||
|
)
|
||||||
|
|
||||||
def _get_cached_value(self, motor: str, data_name: str) -> Value:
|
def _get_cached_value(self, motor: str, data_name: str) -> Value:
|
||||||
"""Retrieve a specific value from the state cache."""
|
"""Retrieve a specific value from the state cache."""
|
||||||
@@ -848,20 +938,12 @@ class RobstrideMotorsBus(MotorsBusBase):
|
|||||||
self._bus().send(msg)
|
self._bus().send(msg)
|
||||||
updated_motors.append(motor)
|
updated_motors.append(motor)
|
||||||
|
|
||||||
expected_recv_ids = [self._get_motor_recv_id(motor) for motor in updated_motors]
|
messages = self._recv_all_messages_until_quiet()
|
||||||
responses = self._recv_all_responses(expected_recv_ids, timeout=RUNNING_TIMEOUT)
|
processed_recv_ids = self._process_feedback_messages(messages)
|
||||||
|
|
||||||
for response in responses.values():
|
|
||||||
payload_motor_name = self._recv_id_to_motor.get(response.data[0])
|
|
||||||
if payload_motor_name is not None:
|
|
||||||
self._process_response(payload_motor_name, response)
|
|
||||||
else:
|
|
||||||
# Fallback: still attempt to decode based on payload byte0 mapping.
|
|
||||||
self._decode_motor_state(response.data)
|
|
||||||
|
|
||||||
for motor in updated_motors:
|
for motor in updated_motors:
|
||||||
recv_id = self._get_motor_recv_id(motor)
|
recv_id = self._get_motor_recv_id(motor)
|
||||||
if recv_id not in responses:
|
if recv_id not in processed_recv_ids:
|
||||||
logger.warning(f"Packet drop: {motor} (ID: 0x{recv_id:02X}). Using last known state.")
|
logger.warning(f"Packet drop: {motor} (ID: 0x{recv_id:02X}). Using last known state.")
|
||||||
|
|
||||||
def read_calibration(self) -> dict[str, MotorCalibration]:
|
def read_calibration(self) -> dict[str, MotorCalibration]:
|
||||||
|
|||||||
@@ -114,7 +114,8 @@ CAN_CMD_SAVE_PARAM = 0xAA
|
|||||||
CAN_PARAM_ID = 0x7FF
|
CAN_PARAM_ID = 0x7FF
|
||||||
|
|
||||||
|
|
||||||
RUNNING_TIMEOUT = 0.001
|
RUNNING_TIMEOUT = 0.003
|
||||||
|
HANDSHAKE_TIMEOUT_S = 0.05
|
||||||
PARAM_TIMEOUT = 0.01
|
PARAM_TIMEOUT = 0.01
|
||||||
|
|
||||||
STATE_CACHE_TTL_S = 0.02
|
STATE_CACHE_TTL_S = 0.02
|
||||||
|
|||||||
@@ -24,6 +24,7 @@ import torch
|
|||||||
|
|
||||||
pytest.importorskip("datasets", reason="datasets is required (install lerobot[dataset])")
|
pytest.importorskip("datasets", reason="datasets is required (install lerobot[dataset])")
|
||||||
|
|
||||||
|
import datasets
|
||||||
from huggingface_hub import HfApi
|
from huggingface_hub import HfApi
|
||||||
from PIL import Image
|
from PIL import Image
|
||||||
from safetensors.torch import load_file
|
from safetensors.torch import load_file
|
||||||
@@ -360,6 +361,41 @@ def test_add_frame_image_pil(image_dataset):
|
|||||||
assert dataset[0]["image"].shape == torch.Size(DUMMY_CHW)
|
assert dataset[0]["image"].shape == torch.Size(DUMMY_CHW)
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.mark.parametrize(
|
||||||
|
"dtype,np_dtype,values,assert_fn",
|
||||||
|
[
|
||||||
|
("float32", np.float32, [1.0, 2.0], np.testing.assert_allclose),
|
||||||
|
("int64", np.int64, [1, 2], np.testing.assert_array_equal),
|
||||||
|
("bool", np.bool_, [True, False], np.testing.assert_array_equal),
|
||||||
|
],
|
||||||
|
ids=["float32", "int64", "bool"],
|
||||||
|
)
|
||||||
|
def test_save_episode_shape_1_scalar_is_scalarized_before_hf_encoding(
|
||||||
|
tmp_path, empty_lerobot_dataset_factory, monkeypatch, dtype, np_dtype, values, assert_fn
|
||||||
|
):
|
||||||
|
features = {"state": {"dtype": dtype, "shape": (1,), "names": None}}
|
||||||
|
dataset = empty_lerobot_dataset_factory(root=tmp_path / "test", features=features)
|
||||||
|
dataset.add_frame({"state": np.array([values[0]], dtype=np_dtype), "task": "Dummy task"})
|
||||||
|
dataset.add_frame({"state": np.array([values[1]], dtype=np_dtype), "task": "Dummy task"})
|
||||||
|
|
||||||
|
captured = {}
|
||||||
|
original_from_dict = datasets.Dataset.from_dict
|
||||||
|
|
||||||
|
def _from_dict_spy(cls, mapping, *args, **kwargs):
|
||||||
|
captured["state"] = mapping["state"]
|
||||||
|
return original_from_dict(mapping, *args, **kwargs)
|
||||||
|
|
||||||
|
monkeypatch.setattr(datasets.Dataset, "from_dict", classmethod(_from_dict_spy))
|
||||||
|
|
||||||
|
dataset.save_episode()
|
||||||
|
dataset.finalize()
|
||||||
|
|
||||||
|
assert "state" in captured
|
||||||
|
assert isinstance(captured["state"], np.ndarray)
|
||||||
|
assert captured["state"].shape == (2,)
|
||||||
|
assert_fn(captured["state"], np.array(values, dtype=np_dtype))
|
||||||
|
|
||||||
|
|
||||||
def test_set_image_transforms_applies_transparently(image_dataset):
|
def test_set_image_transforms_applies_transparently(image_dataset):
|
||||||
dataset = image_dataset
|
dataset = image_dataset
|
||||||
dataset.add_frame({"image": np.random.rand(*DUMMY_CHW), "task": "Dummy task"})
|
dataset.add_frame({"image": np.random.rand(*DUMMY_CHW), "task": "Dummy task"})
|
||||||
|
|||||||
@@ -0,0 +1,140 @@
|
|||||||
|
#!/usr/bin/env python
|
||||||
|
|
||||||
|
# Copyright 2026 The HuggingFace Inc. team. All rights reserved.
|
||||||
|
#
|
||||||
|
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||||
|
# you may not use this file except in compliance with the License.
|
||||||
|
# You may obtain a copy of the License at
|
||||||
|
#
|
||||||
|
# http://www.apache.org/licenses/LICENSE-2.0
|
||||||
|
#
|
||||||
|
# Unless required by applicable law or agreed to in writing, software
|
||||||
|
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||||
|
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||||
|
# See the License for the specific language governing permissions and
|
||||||
|
# limitations under the License.
|
||||||
|
|
||||||
|
"""Unit tests for ``lerobot.datasets.video_utils.VideoDecoderCache``.
|
||||||
|
|
||||||
|
These cover the LRU bounding + file-handle release behaviour added to prevent
|
||||||
|
unbounded growth when iterating over datasets with many distinct video files
|
||||||
|
(observed: ~35 GB anon-rss per DataLoader worker on an 8 k-file dataset).
|
||||||
|
"""
|
||||||
|
|
||||||
|
import shutil
|
||||||
|
from pathlib import Path
|
||||||
|
|
||||||
|
import pytest
|
||||||
|
|
||||||
|
pytest.importorskip("torchcodec", reason="torchcodec is required (install lerobot[dataset])")
|
||||||
|
|
||||||
|
from lerobot.datasets.video_utils import VideoDecoderCache # noqa: E402
|
||||||
|
|
||||||
|
TEST_ARTIFACTS_DIR = Path(__file__).resolve().parent.parent / "artifacts" / "encoded_videos"
|
||||||
|
SRC_CLIP = TEST_ARTIFACTS_DIR / "clip_4frames.mp4"
|
||||||
|
|
||||||
|
|
||||||
|
def _make_distinct_clips(tmp_path: Path, n: int) -> list[Path]:
|
||||||
|
"""Copy the small reference mp4 to ``n`` distinct paths.
|
||||||
|
|
||||||
|
The cache keys on absolute path, so distinct paths force distinct cache entries
|
||||||
|
even though the file contents are identical.
|
||||||
|
"""
|
||||||
|
assert SRC_CLIP.exists(), f"missing test artifact {SRC_CLIP}"
|
||||||
|
paths = []
|
||||||
|
for i in range(n):
|
||||||
|
dst = tmp_path / f"clip_{i:04d}.mp4"
|
||||||
|
shutil.copyfile(SRC_CLIP, dst)
|
||||||
|
paths.append(dst)
|
||||||
|
return paths
|
||||||
|
|
||||||
|
|
||||||
|
class TestVideoDecoderCacheBounded:
|
||||||
|
def test_default_cache_is_bounded(self):
|
||||||
|
"""The default cache must have a finite ``max_size`` to bound RSS growth."""
|
||||||
|
cache = VideoDecoderCache()
|
||||||
|
assert cache.max_size is not None, "default cache must be bounded"
|
||||||
|
assert cache.max_size > 0
|
||||||
|
|
||||||
|
def test_size_capped_at_max_size(self, tmp_path):
|
||||||
|
"""``get_decoder`` for >``max_size`` distinct paths must NOT grow without bound."""
|
||||||
|
paths = _make_distinct_clips(tmp_path, n=5)
|
||||||
|
cache = VideoDecoderCache(max_size=2)
|
||||||
|
for p in paths:
|
||||||
|
cache.get_decoder(p)
|
||||||
|
assert cache.size() == 2
|
||||||
|
|
||||||
|
def test_evicts_least_recently_used(self, tmp_path):
|
||||||
|
"""Re-accessing an entry must promote it; the LRU entry is the one evicted."""
|
||||||
|
paths = _make_distinct_clips(tmp_path, n=3)
|
||||||
|
cache = VideoDecoderCache(max_size=2)
|
||||||
|
|
||||||
|
cache.get_decoder(paths[0])
|
||||||
|
cache.get_decoder(paths[1])
|
||||||
|
cache.get_decoder(paths[0]) # promote paths[0] to MRU; paths[1] is now LRU
|
||||||
|
cache.get_decoder(paths[2]) # should evict paths[1]
|
||||||
|
|
||||||
|
assert str(paths[0]) in cache # MRU stays
|
||||||
|
assert str(paths[1]) not in cache # LRU evicted
|
||||||
|
assert str(paths[2]) in cache # newest stays
|
||||||
|
|
||||||
|
def test_eviction_closes_file_handle(self, tmp_path):
|
||||||
|
"""Evicting an entry must close its fsspec file handle (otherwise we leak FDs)."""
|
||||||
|
paths = _make_distinct_clips(tmp_path, n=2)
|
||||||
|
cache = VideoDecoderCache(max_size=1)
|
||||||
|
|
||||||
|
cache.get_decoder(paths[0])
|
||||||
|
# Reach into the cache to capture the handle before it is evicted. This is
|
||||||
|
# the only assertion in the suite that touches a private attribute, and it
|
||||||
|
# is the most direct way to prove the file descriptor is actually released.
|
||||||
|
evicted_handle = cache._cache[str(paths[0])][1]
|
||||||
|
assert evicted_handle.closed is False
|
||||||
|
|
||||||
|
cache.get_decoder(paths[1]) # forces eviction of paths[0]
|
||||||
|
|
||||||
|
assert evicted_handle.closed is True
|
||||||
|
|
||||||
|
def test_clear_closes_all_file_handles(self, tmp_path):
|
||||||
|
"""``clear()`` must close every cached file handle."""
|
||||||
|
paths = _make_distinct_clips(tmp_path, n=3)
|
||||||
|
cache = VideoDecoderCache(max_size=10)
|
||||||
|
|
||||||
|
for p in paths:
|
||||||
|
cache.get_decoder(p)
|
||||||
|
handles = [entry[1] for entry in cache._cache.values()]
|
||||||
|
assert all(not h.closed for h in handles)
|
||||||
|
|
||||||
|
cache.clear()
|
||||||
|
|
||||||
|
assert cache.size() == 0
|
||||||
|
assert all(h.closed for h in handles)
|
||||||
|
|
||||||
|
def test_hit_does_not_reopen_or_evict(self, tmp_path):
|
||||||
|
"""A cache hit must return the same decoder instance without touching the cap."""
|
||||||
|
paths = _make_distinct_clips(tmp_path, n=1)
|
||||||
|
cache = VideoDecoderCache(max_size=2)
|
||||||
|
|
||||||
|
first = cache.get_decoder(paths[0])
|
||||||
|
second = cache.get_decoder(paths[0])
|
||||||
|
|
||||||
|
assert first is second
|
||||||
|
assert cache.size() == 1
|
||||||
|
|
||||||
|
def test_unbounded_when_max_size_none(self, tmp_path):
|
||||||
|
"""``max_size=None`` preserves the legacy unbounded behaviour."""
|
||||||
|
paths = _make_distinct_clips(tmp_path, n=4)
|
||||||
|
cache = VideoDecoderCache(max_size=None)
|
||||||
|
for p in paths:
|
||||||
|
cache.get_decoder(p)
|
||||||
|
assert cache.size() == 4
|
||||||
|
|
||||||
|
def test_env_var_overrides_default(self, tmp_path, monkeypatch):
|
||||||
|
"""``LEROBOT_VIDEO_DECODER_CACHE_SIZE`` env var sets the default ``max_size``."""
|
||||||
|
monkeypatch.setenv("LEROBOT_VIDEO_DECODER_CACHE_SIZE", "3")
|
||||||
|
cache = VideoDecoderCache()
|
||||||
|
assert cache.max_size == 3
|
||||||
|
|
||||||
|
paths = _make_distinct_clips(tmp_path, n=5)
|
||||||
|
for p in paths:
|
||||||
|
cache.get_decoder(p)
|
||||||
|
assert cache.size() == 3
|
||||||
Reference in New Issue
Block a user