Compare commits

...

5 Commits

Author SHA1 Message Date
Nikodem Bartnik c0a2e9814d fix examples (#3623)
- Fixed broken API examples in the LeRobot Imitation Learning documentation
- Improved teleoperation with cameras by running the loop at a fixed frequency (without it the camera feed becomes very slow)
- Wrapped the record example script in main() to avoid problems on Mac
- Previously the teleoperation example used SO-ARM while teleoperation with cameras used Koch; all examples now use SO-ARM.
- Added section on how to train with HF Jobs - CLI and Python examples
- Replaced lerobot-record with lerobot-rollout in policies examples
2026-05-21 22:14:07 +02:00
Khalil Meftah bac4f61eae refactor: support custom progress parquet overlays (#3640) 2026-05-21 14:32:10 +02:00
Virgileboat f4b834844e Feat/clean can bus (#3526)
* change timeout for handshake

* enforce last state read when querying

* change import order

* fix(motors): flush stale robstride RX and harden feedback drain

* robstride: remove redundant timeout and max_messages casts

* bugfix + %-style

* update exception catch
2026-05-21 11:44:04 +02:00
Roham Z. Nobari dfdc48a7f1 fix(datasets): bound VideoDecoderCache to prevent OOM on large datasets (#3614)
VideoDecoderCache used an unbounded dict keyed on absolute path, with no
eviction in the standard LeRobotDataset path. With shuffled iteration over
datasets that have many distinct mp4 files, every DataLoader worker
accumulated one cached (VideoDecoder, fsspec file handle) pair per distinct
path it had ever touched. Per-entry cost is ~3-5 MB of host RAM plus one
open FD; at ~8 k entries this is roughly 30 GB per worker.
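
The estimate above can be checked with a few lines of arithmetic. This is only a back-of-the-envelope sketch: the helper name `cache_footprint_gb` is hypothetical, decimal units (1 GB = 1000 MB) are assumed, and the per-entry cost is taken from the ~3-5 MB range quoted above rather than measured.

```python
MB_PER_GB = 1000  # assumption: decimal units, to match the rough figures in the text


def cache_footprint_gb(entries: int, mb_per_entry: float) -> float:
    """Host RAM held by `entries` cached decoders, in GB (hypothetical helper)."""
    return entries * mb_per_entry / MB_PER_GB


entries = 8_000  # "~8 k entries" per worker
low = cache_footprint_gb(entries, 3)   # lower bound of the quoted per-entry cost
high = cache_footprint_gb(entries, 5)  # upper bound of the quoted per-entry cost
mid = cache_footprint_gb(entries, 4)   # midpoint of the quoted range

print(f"unbounded cache: {low:.0f}-{high:.0f} GB per worker (midpoint {mid:.0f} GB)")
```

The midpoint lands close to the "roughly 30 GB per worker" figure, which is why an unbounded per-path cache is enough on its own to account for the growth.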

This was hit in the wild during a SmolVLA training run on a 4,195-episode
SO-101 dataset (8,390 mp4s, two cameras per episode). dmesg showed
anon-rss climbing to 34.9 GB on a single pt_data_worker before the OOM
killer fired ~30 min into training; with --num_workers=8 the per-worker
peak halved to 17.9 GB, which is the expected inverse-scaling signature
when the leak is per-decode and the workload is split across workers. The
working workaround on the affected platform was --dataset.video_backend=pyav,
because the pyav path opens/closes per call and never touches this cache.

Switch the backing store to an OrderedDict and evict LRU entries when the
cap is reached, closing the evicted file handle inside the lock so we do
not leak FDs either. Default cap is DEFAULT_DECODER_CACHE_SIZE = 100,
overridable via LEROBOT_VIDEO_DECODER_CACHE_SIZE or by passing max_size=
to the constructor; max_size=None restores the legacy unbounded behaviour
for callers that need it.
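
The eviction scheme described above can be sketched as follows. This is a minimal illustration, not the code from this commit: `BoundedHandleCache`, `opener`, and `fake_opener` are hypothetical names, and an in-memory `io.BytesIO` stands in for the decoder's `fsspec` file handle.

```python
import io
from collections import OrderedDict
from threading import Lock


class BoundedHandleCache:
    """Hypothetical LRU map of path -> (decoder, file handle)."""

    def __init__(self, opener, max_size=100):
        if max_size is not None and max_size <= 0:
            raise ValueError(f"max_size must be positive or None; got {max_size}")
        self.max_size = max_size  # None means unbounded (legacy behaviour)
        self._opener = opener     # assumed callable: path -> (decoder, handle)
        self._cache = OrderedDict()
        self._lock = Lock()

    def get(self, path):
        key = str(path)
        with self._lock:
            if key in self._cache:
                self._cache.move_to_end(key)  # mark as most recently used
                return self._cache[key][0]
            decoder, handle = self._opener(key)
            self._cache[key] = (decoder, handle)
            # Evict least-recently-used entries while over the cap, closing
            # each handle inside the lock so no file descriptor is leaked.
            while self.max_size is not None and len(self._cache) > self.max_size:
                _, (_, old_handle) = self._cache.popitem(last=False)
                old_handle.close()
            return decoder

    def __len__(self):
        with self._lock:
            return len(self._cache)


opened = {}


def fake_opener(path):
    """Stand-in for opening a video: records the handle so it can be inspected."""
    handle = io.BytesIO()
    opened[path] = handle
    return f"decoder:{path}", handle


cache = BoundedHandleCache(fake_opener, max_size=2)
for path in ["a.mp4", "b.mp4", "a.mp4", "c.mp4"]:
    cache.get(path)
```

In this run the second request for `a.mp4` refreshes its position, so `b.mp4` is the entry that gets evicted and closed when `c.mp4` arrives. Closing the handle while still holding the lock keeps another thread from receiving an entry that is about to be closed.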

Validation on the original failing workload (decode_video_frames_torchcodec
called over real mp4s from the affected SO-101 dataset):

  unbounded:    300 files  ->  +1087 MB host RSS,  cache=300, still climbing
  cap=50:       500 files  ->   +266 MB host RSS,  cache=50,  stable
  cap=50:      2000 calls  ->   +312 MB host RSS,  cache=50,  stable
  cap=100:     1000 calls  ->   +470 MB host RSS,  cache=100, stable

Three independent seeded runs at cap=50 agreed to within 1% (263 / 266 /
265 MB delta), and the 2000-call multi-pass run shows RSS plateaus after
the cap is reached instead of drifting.

Tests in tests/datasets/test_video_decoder_cache.py cover:
default-is-bounded, size cap, LRU ordering, FD close on eviction, FD close
on clear(), cache-hit invariance, max_size=None fallback, and env-var
override. No regressions in test_video_encoding.py, test_streaming.py, or
test_dataset_reader.py (73 prior tests still pass alongside the 8 new ones).
2026-05-19 16:54:25 +02:00
四七 6a8878a639 fix(datasets): normalize shape=(1,) numeric values before HF encoding (#3344)
* fix(datasets): normalize shape=(1,) numeric values before save

* test(datasets): cover shape=(1,) int/bool and finalize

Co-authored-by: Copilot <copilot@github.com>
2026-05-19 16:53:19 +02:00
11 changed files with 638 additions and 182 deletions
+6 -10
@@ -79,17 +79,13 @@ If your local computer doesn't have a powerful GPU, you can utilize Google Colab
Once training is complete, you can evaluate your ACT policy using the `lerobot-record` command with your trained policy. This will run inference and record evaluation episodes: Once training is complete, you can evaluate your ACT policy using the `lerobot-record` command with your trained policy. This will run inference and record evaluation episodes:
```bash ```bash
lerobot-record \ lerobot-rollout \
--robot.type=so100_follower \ --strategy.type=base \
--policy.path=${HF_USER}/act_policy \
--robot.type=so101_follower \
--robot.port=/dev/ttyACM0 \ --robot.port=/dev/ttyACM0 \
--robot.id=my_robot \
--robot.cameras="{ front: {type: opencv, index_or_path: 0, width: 640, height: 480, fps: 30}}" \ --robot.cameras="{ front: {type: opencv, index_or_path: 0, width: 640, height: 480, fps: 30}}" \
--display_data=true \ --display_data=true \
--dataset.repo_id=${HF_USER}/eval_act_your_dataset \ --task="Your task description" \ # can be skipped for ACT
--dataset.num_episodes=10 \ --duration=60
--dataset.single_task="Your task description" \
--dataset.streaming_encoding=true \
--dataset.encoder_threads=2 \
# --dataset.camera_encoder.vcodec=auto \
--policy.path=${HF_USER}/act_policy
``` ```
+5 -5
@@ -105,10 +105,12 @@ These results demonstrate GR00T's strong generalization capabilities across dive
### Evaluate in your hardware setup ### Evaluate in your hardware setup
Once you have trained your model using your parameters you can run inference in your downstream task. Follow the instructions in [Imitation Learning for Robots](./il_robots). For example: Once you have trained your model using your parameters you can run inference in your downstream task. Follow the instructions in [Policy Deployment (lerobot-rollout)](./inference). For example:
```bash ```bash
lerobot-record \ lerobot-rollout \
--strategy.type=sentry \
--strategy.upload_every_n_episodes=5 \
--robot.type=bi_so_follower \ --robot.type=bi_so_follower \
--robot.left_arm_port=/dev/ttyACM1 \ --robot.left_arm_port=/dev/ttyACM1 \
--robot.right_arm_port=/dev/ttyACM0 \ --robot.right_arm_port=/dev/ttyACM0 \
@@ -119,14 +121,12 @@ lerobot-record \
}' \ }' \
--display_data=true \ --display_data=true \
--dataset.repo_id=<user>/eval_groot-bimanual \ --dataset.repo_id=<user>/eval_groot-bimanual \
--dataset.num_episodes=10 \
--dataset.single_task="Grab and handover the red cube to the other arm" \ --dataset.single_task="Grab and handover the red cube to the other arm" \
--dataset.streaming_encoding=true \ --dataset.streaming_encoding=true \
--dataset.encoder_threads=2 \ --dataset.encoder_threads=2 \
# --dataset.camera_encoder.vcodec=auto \ # --dataset.camera_encoder.vcodec=auto \
--policy.path=<user>/groot-bimanual \ # your trained model --policy.path=<user>/groot-bimanual \ # your trained model
--dataset.episode_time_s=30 \ --duration=600
--dataset.reset_time_s=10
``` ```
## License ## License
+168 -67
@@ -68,13 +68,13 @@ from lerobot.teleoperators.so_leader import SO101Leader, SO101LeaderConfig
from lerobot.robots.so_follower import SO101Follower, SO101FollowerConfig from lerobot.robots.so_follower import SO101Follower, SO101FollowerConfig
robot_config = SO101FollowerConfig( robot_config = SO101FollowerConfig(
port="/dev/tty.usbmodem58760431541", port="/dev/tty.usbmodem5AB90687491",
id="my_red_robot_arm", id="my_follower_arm",
) )
teleop_config = SO101LeaderConfig( teleop_config = SO101LeaderConfig(
port="/dev/tty.usbmodem58760431551", port="/dev/tty.usbmodem5AB90689011",
id="my_blue_leader_arm", id="my_leader_arm",
) )
robot = SO101Follower(robot_config) robot = SO101Follower(robot_config)
@@ -108,13 +108,13 @@ With `rerun`, you can teleoperate again while simultaneously visualizing the cam
<hfoption id="Command"> <hfoption id="Command">
```bash ```bash
lerobot-teleoperate \ lerobot-teleoperate \
--robot.type=koch_follower \ --robot.type=so101_follower \
--robot.port=/dev/tty.usbmodem58760431541 \ --robot.port=/dev/tty.usbmodem5AB90687491 \
--robot.id=my_awesome_follower_arm \ --robot.id=my_follower_arm \
--robot.cameras="{ front: {type: opencv, index_or_path: 0, width: 1920, height: 1080, fps: 30}}" \ --robot.cameras="{front: {type: opencv, index_or_path: 0, width: 640, height: 480, fps: 30}}" \
--teleop.type=koch_leader \ --teleop.type=so101_leader \
--teleop.port=/dev/tty.usbmodem58760431551 \ --teleop.port=/dev/tty.usbmodem5AB90689011 \
--teleop.id=my_awesome_leader_arm \ --teleop.id=my_leader_arm \
--display_data=true --display_data=true
``` ```
</hfoption> </hfoption>
@@ -122,34 +122,48 @@ lerobot-teleoperate \
<!-- prettier-ignore-start --> <!-- prettier-ignore-start -->
```python ```python
import time
from lerobot.teleoperators.so_leader import SO101Leader, SO101LeaderConfig
from lerobot.robots.so_follower import SO101Follower, SO101FollowerConfig
from lerobot.cameras.opencv import OpenCVCameraConfig from lerobot.cameras.opencv import OpenCVCameraConfig
from lerobot.teleoperators.koch_leader import KochLeader, KochLeaderConfig from lerobot.utils.visualization_utils import init_rerun, log_rerun_data, shutdown_rerun
from lerobot.robots.koch_follower import KochFollower, KochFollowerConfig
camera_config = { robot_config = SO101FollowerConfig(
"front": OpenCVCameraConfig(index_or_path=0, width=1920, height=1080, fps=30) port="/dev/tty.usbmodem5AB90687491",
} id="my_follower_arm",
cameras={
robot_config = KochFollowerConfig( "wrist": OpenCVCameraConfig(index_or_path=0, width=640, height=480, fps=30),
port="/dev/tty.usbmodem585A0076841", "top": OpenCVCameraConfig(index_or_path=1, width=640, height=480, fps=30)
id="my_red_robot_arm", }
cameras=camera_config
) )
teleop_config = KochLeaderConfig( teleop_config = SO101LeaderConfig(
port="/dev/tty.usbmodem58760431551", port="/dev/tty.usbmodem5AB90689011",
id="my_blue_leader_arm", id="my_leader_arm",
) )
robot = KochFollower(robot_config) init_rerun(session_name="teleoperation")
teleop_device = KochLeader(teleop_config)
robot = SO101Follower(robot_config)
teleop_device = SO101Leader(teleop_config)
robot.connect() robot.connect()
teleop_device.connect() teleop_device.connect()
TARGET_HZ = 30
TIME_PER_FRAME = 1.0 / TARGET_HZ
while True: while True:
start_time = time.perf_counter()
observation = robot.get_observation() observation = robot.get_observation()
action = teleop_device.get_action() action = teleop_device.get_action()
robot.send_action(action) robot.send_action(action)
log_rerun_data(observation=observation, action=action)
elapsed_time = time.perf_counter() - start_time
sleep_time = TIME_PER_FRAME - elapsed_time
if sleep_time > 0:
time.sleep(sleep_time)
``` ```
<!-- prettier-ignore-end --> <!-- prettier-ignore-end -->
@@ -202,10 +216,11 @@ lerobot-record \
<!-- prettier-ignore-start --> <!-- prettier-ignore-start -->
```python ```python
from lerobot.cameras.opencv import OpenCVCameraConfig from lerobot.cameras.opencv import OpenCVCameraConfig
from lerobot.datasets import LeRobotDataset from lerobot.datasets.lerobot_dataset import LeRobotDataset
from lerobot.utils.feature_utils import hw_to_dataset_features from lerobot.utils.feature_utils import hw_to_dataset_features
from lerobot.robots.so_follower import SO100Follower, SO100FollowerConfig from lerobot.robots.so_follower import SO101Follower, SO101FollowerConfig
from lerobot.teleoperators.so_leader import SO100Leader, SO100LeaderConfig from lerobot.teleoperators.so_leader.config_so_leader import SO101LeaderConfig
from lerobot.teleoperators.so_leader.so_leader import SO101Leader
from lerobot.common.control_utils import init_keyboard_listener from lerobot.common.control_utils import init_keyboard_listener
from lerobot.utils.utils import log_say from lerobot.utils.utils import log_say
from lerobot.utils.visualization_utils import init_rerun from lerobot.utils.visualization_utils import init_rerun
@@ -218,52 +233,54 @@ EPISODE_TIME_SEC = 60
RESET_TIME_SEC = 10 RESET_TIME_SEC = 10
TASK_DESCRIPTION = "My task description" TASK_DESCRIPTION = "My task description"
# Create robot configuration def main():
robot_config = SO100FollowerConfig( # Create robot configuration
id="my_awesome_follower_arm", robot_config = SO101FollowerConfig(
port="/dev/tty.usbmodem5AB90687491",
id="my_follower_arm",
cameras={ cameras={
"front": OpenCVCameraConfig(index_or_path=0, width=640, height=480, fps=FPS) # Optional: fourcc="MJPG" for troubleshooting OpenCV async error. "wrist": OpenCVCameraConfig(index_or_path=0, width=640, height=480, fps=30),
}, "top": OpenCVCameraConfig(index_or_path=1, width=640, height=480, fps=30)
port="/dev/tty.usbmodem58760434471", }
) )
teleop_config = SO100LeaderConfig( teleop_config = SO101LeaderConfig(
id="my_awesome_leader_arm", port="/dev/tty.usbmodem5AB90689011",
port="/dev/tty.usbmodem585A0077581", id="my_leader_arm",
) )
# Initialize the robot and teleoperator # Initialize the robot and teleoperator
robot = SO100Follower(robot_config) robot = SO101Follower(robot_config)
teleop = SO100Leader(teleop_config) teleop = SO101Leader(teleop_config)
# Configure the dataset features # Configure the dataset features
action_features = hw_to_dataset_features(robot.action_features, "action") action_features = hw_to_dataset_features(robot.action_features, "action")
obs_features = hw_to_dataset_features(robot.observation_features, "observation") obs_features = hw_to_dataset_features(robot.observation_features, "observation")
dataset_features = {**action_features, **obs_features} dataset_features = {**action_features, **obs_features}
# Create the dataset # Create the dataset
dataset = LeRobotDataset.create( dataset = LeRobotDataset.create(
repo_id="<hf_username>/<dataset_repo_id>", repo_id="<hf_username>/<dataset_repo_id>",
fps=FPS, fps=FPS,
features=dataset_features, features=dataset_features,
robot_type=robot.name, robot_type=robot.name,
use_videos=True, use_videos=True,
image_writer_threads=4, image_writer_threads=4,
) )
# Initialize the keyboard listener and rerun visualization # Initialize the keyboard listener and rerun visualization
_, events = init_keyboard_listener() _, events = init_keyboard_listener()
init_rerun(session_name="recording") init_rerun(session_name="recording")
# Connect the robot and teleoperator # Connect the robot and teleoperator
robot.connect() robot.connect()
teleop.connect() teleop.connect()
# Create the required processors # Create the required processors
teleop_action_processor, robot_action_processor, robot_observation_processor = make_default_processors() teleop_action_processor, robot_action_processor, robot_observation_processor = make_default_processors()
episode_idx = 0 episode_idx = 0
while episode_idx < NUM_EPISODES and not events["stop_recording"]: while episode_idx < NUM_EPISODES and not events["stop_recording"]:
log_say(f"Recording episode {episode_idx + 1} of {NUM_EPISODES}") log_say(f"Recording episode {episode_idx + 1} of {NUM_EPISODES}")
record_loop( record_loop(
@@ -306,11 +323,18 @@ while episode_idx < NUM_EPISODES and not events["stop_recording"]:
dataset.save_episode() dataset.save_episode()
episode_idx += 1 episode_idx += 1
# Clean up # finalize dataset
log_say("Stop recording") log_say("Finalizing dataset...")
robot.disconnect() dataset.finalize()
teleop.disconnect() # Clean up
dataset.push_to_hub() log_say("Stop recording")
robot.disconnect()
teleop.disconnect()
dataset.push_to_hub()
if __name__ == "__main__":
main()
``` ```
<!-- prettier-ignore-end --> <!-- prettier-ignore-end -->
@@ -348,7 +372,7 @@ The `record` function provides a suite of tools for capturing and managing data
##### 2. Checkpointing and Resuming ##### 2. Checkpointing and Resuming
- Checkpoints are automatically created during recording. - Checkpoints are automatically created during recording.
- If an issue occurs, you can resume by re-running the same command with `--resume=true`. When resuming a recording, `--dataset.num_episodes` must be set to the **number of additional episodes to be recorded**, and not to the targeted total number of episodes in the dataset ! - If an issue occurs or you want to record additional episodes in the same dataset, you can resume by re-running the same command with `--resume=true`. When resuming a recording, `--dataset.num_episodes` must be set to the **number of additional episodes to be recorded**, and not to the targeted total number of episodes in the dataset! Also set `--dataset.root="local_path"`: this is the local path where the new part of the dataset is saved, and it is required to resume.
- To start recording from scratch, **manually delete** the dataset directory. - To start recording from scratch, **manually delete** the dataset directory.
##### 3. Recording Parameters ##### 3. Recording Parameters
@@ -422,7 +446,7 @@ from lerobot.utils.utils import log_say
episode_idx = 0 episode_idx = 0
robot_config = SO100FollowerConfig(port="/dev/tty.usbmodem58760434471", id="my_awesome_follower_arm") robot_config = SO100FollowerConfig(port="/dev/tty.usbmodem5AB90687491", id="my_follower_arm")
robot = SO100Follower(robot_config) robot = SO100Follower(robot_config)
robot.connect() robot.connect()
@@ -490,6 +514,83 @@ Additionally you can provide extra `tags` or specify a `license` for your model
If your local computer doesn't have a powerful GPU you could utilize Google Colab to train your model by following the [ACT training notebook](./notebooks#training-act). If your local computer doesn't have a powerful GPU you could utilize Google Colab to train your model by following the [ACT training notebook](./notebooks#training-act).
#### Train using Hugging Face Jobs
Hugging Face Jobs lets you select hardware and run training in the cloud. If you don't have a powerful GPU, need more VRAM, or simply want to train a model faster, use HF Jobs. It is pay-as-you-go, billed per second of use; see the [Jobs documentation](https://huggingface.co/docs/hub/jobs) for pricing and additional information.
To run the training, use this command:
<hfoptions id="train_with_hf_jobs">
<hfoption id="Command">
```bash
hf jobs run \
--flavor a10g-small \
--timeout 4h \
--secrets HF_TOKEN \
huggingface/lerobot-gpu:latest \
-- \
python -m lerobot.scripts.lerobot_train \
--dataset.repo_id=username/dataset \
--policy.type=act \
--steps=5000 \
--batch_size=16 \
--policy.device=cuda \
--policy.repo_id=username/your_policy \
--log_freq=100
```
</hfoption>
<hfoption id="API example">
<!-- prettier-ignore-start -->
```python
from huggingface_hub import run_job, get_token
run_name = "act_so101_hf_jobs"
dataset_id = "username/dataset"
user_hub_id = "username"
command_args = [
"python", "-m", "lerobot.scripts.lerobot_train",
"--dataset.repo_id", dataset_id,
"--policy.type", "act",
"--steps", "5000",
"--batch_size", "16",
"--num_workers", "4",
"--policy.device", "cuda",
"--log_freq", "100",
"--save_freq", "1000",
"--save_checkpoint", "true",
"--wandb.enable", "false",
"--policy.repo_id", f"{user_hub_id}/{run_name}"
]
print(f"Submitting job '{run_name}' to Hugging Face Infrastructure...")
job_info = run_job(
image="huggingface/lerobot-gpu:latest",
command=command_args,
flavor="a10g-small",
timeout="4h",
secrets={"HF_TOKEN": get_token()}
)
print("\n🚀 Job successfully launched!")
print(f"🔹 Job ID: {job_info.id}")
print(f"🔗 Live UI Dashboard & Logs: {job_info.url}")
```
<!-- prettier-ignore-end -->
</hfoption>
</hfoptions>
You can change `--flavor` to use different hardware, for example `t4-small`, `a100-large`, or `h200`. Run `hf jobs hardware` to see the full list with pricing.
Depending on the model you want to train and the hardware you selected, you can also adjust `--batch_size` and `--num_workers`.
For longer training sessions, increase the timeout.
Once training has started, go to [Jobs](https://huggingface.co/settings/jobs) to check whether your job is running and to see its output. Scheduling a job sometimes takes a few minutes, so be patient.
After training, the model is pushed to the Hub and you can use it like any other model with LeRobot.
#### Upload policy checkpoints #### Upload policy checkpoints
Once training is done, upload the latest checkpoint with: Once training is done, upload the latest checkpoint with:
+8 -8
@@ -97,22 +97,22 @@ Similarly for when recording an episode, it is recommended that you are logged i
Once you are logged in, you can run inference in your setup by doing: Once you are logged in, you can run inference in your setup by doing:
```bash ```bash
lerobot-record \ lerobot-rollout \
--strategy.type=base \
--robot.type=so101_follower \ --robot.type=so101_follower \
--robot.port=/dev/ttyACM0 \ # <- Use your port --robot.port=/dev/ttyACM0 \ # <- Use your port
--robot.id=my_blue_follower_arm \ # <- Use your robot id --robot.id=my_blue_follower_arm \ # <- Use your robot id
--robot.cameras="{ front: {type: opencv, index_or_path: 8, width: 640, height: 480, fps: 30}}" \ # <- Use your cameras --robot.cameras="{ front: {type: opencv, index_or_path: 8, width: 640, height: 480, fps: 30}}" \ # <- Use your cameras
--dataset.single_task="Grasp a lego block and put it in the bin." \ # <- Use the same task description you used in your dataset recording --task="Grasp a lego block and put it in the bin." \ # <- Use the same task description you used in your dataset recording
--dataset.repo_id=${HF_USER}/eval_DATASET_NAME_test \ # <- This will be the dataset name on HF Hub # <- RTC optional, use when running on low power hardware \
--dataset.episode_time_s=50 \ # --inference.type=rtc \
--dataset.num_episodes=10 \ # --inference.rtc.execution_horizon=10 \
--dataset.streaming_encoding=true \ # --inference.rtc.max_guidance_weight=10.0 \
--dataset.encoder_threads=2 \
# --dataset.camera_encoder.vcodec=auto \
# <- Teleop optional if you want to teleoperate in between episodes \ # <- Teleop optional if you want to teleoperate in between episodes \
# --teleop.type=so100_leader \ # --teleop.type=so100_leader \
# --teleop.port=/dev/ttyACM0 \ # --teleop.port=/dev/ttyACM0 \
# --teleop.id=my_red_leader_arm \ # --teleop.id=my_red_leader_arm \
# --display_data=true #optional use if you want to see the camera stream \
--policy.path=HF_USER/FINETUNE_MODEL_NAME # <- Use your fine-tuned model --policy.path=HF_USER/FINETUNE_MODEL_NAME # <- Use your fine-tuned model
``` ```
+37 -15
@@ -15,10 +15,12 @@
# limitations under the License. # limitations under the License.
""" """
Create MP4 (or GIF) videos with sarm_progress overlay for specified episodes. Create MP4 (or GIF) videos with per-frame progress overlay for specified episodes.
Downloads datasets from HuggingFace, seeks directly into the episode segment Downloads datasets from HuggingFace, seeks directly into the episode segment
of the source video, draws a progress line on each frame, and writes the result. of the source video, draws a progress line on each frame, and writes the result.
The progress data is read from a parquet file that lives alongside the dataset
(configurable via ``--progress-file``).
Usage: Usage:
python examples/dataset/create_progress_videos.py \ python examples/dataset/create_progress_videos.py \
@@ -56,22 +58,26 @@ SCORE_FONT_SCALE = 0.8
TASK_FONT_SCALE = 0.55 TASK_FONT_SCALE = 0.55
def download_episode_metadata(repo_id: str, episode: int) -> Path: def download_episode_metadata(
"""Download only the metadata and sarm_progress files for a dataset. repo_id: str, episode: int, progress_file: str = "sarm_progress.parquet"
) -> Path:
"""Download only the metadata and per-frame progress file for a dataset.
Args: Args:
repo_id: HuggingFace dataset repository ID. repo_id: HuggingFace dataset repository ID.
episode: Episode index (used for logging only; all meta is fetched). episode: Episode index (used for logging only; all meta is fetched).
progress_file: Filename of the per-frame progress parquet inside the
dataset repo.
Returns: Returns:
Local cache path for the downloaded snapshot. Local cache path for the downloaded snapshot.
""" """
logging.info("[1/4] Downloading metadata for %s (episode %d) ...", repo_id, episode) logging.info("[1/4] Downloading metadata + %s for %s (episode %d) ...", progress_file, repo_id, episode)
local_path = Path( local_path = Path(
snapshot_download( snapshot_download(
repo_id=repo_id, repo_id=repo_id,
repo_type="dataset", repo_type="dataset",
allow_patterns=["meta/**", "sarm_progress.parquet"], allow_patterns=["meta/**", progress_file],
ignore_patterns=["*.mp4"], ignore_patterns=["*.mp4"],
) )
) )
@@ -215,25 +221,28 @@ def download_video_file(repo_id: str, local_path: Path, video_rel: str) -> Path:
return video_path return video_path
def load_progress_data(local_path: Path, episode: int) -> np.ndarray | None: def load_progress_data(
"""Load sarm_progress values for an episode. local_path: Path, episode: int, progress_file: str = "sarm_progress.parquet"
) -> np.ndarray | None:
"""Load per-frame progress values for an episode.
Args: Args:
local_path: Dataset cache root. local_path: Dataset cache root.
episode: Episode index. episode: Episode index.
progress_file: Filename of the per-frame progress parquet.
Returns: Returns:
Sorted (N, 2) array of (frame_index, progress), or None if unavailable. Sorted (N, 2) array of (frame_index, progress), or None if unavailable.
""" """
parquet_path = local_path / "sarm_progress.parquet" parquet_path = local_path / progress_file
if not parquet_path.exists(): if not parquet_path.exists():
logging.warning("sarm_progress.parquet not found") logging.warning("%s not found", progress_file)
return None return None
df = pd.read_parquet(parquet_path) df = pd.read_parquet(parquet_path)
logging.info(" sarm_progress.parquet columns: %s", list(df.columns)) logging.info(" %s columns: %s", progress_file, list(df.columns))
episode_df = df[df["episode_index"] == episode].copy() episode_df = df[df["episode_index"] == episode].copy()
if episode_df.empty: if episode_df.empty:
logging.warning("No sarm_progress rows for episode %d", episode) logging.warning("No progress rows for episode %d in %s", episode, progress_file)
return None return None
episode_df = episode_df.sort_values("frame_index") episode_df = episode_df.sort_values("frame_index")
@@ -576,6 +585,7 @@ def process_dataset(
camera_key: str | None, camera_key: str | None,
output_dir: Path, output_dir: Path,
create_gif: bool = False, create_gif: bool = False,
progress_file: str = "sarm_progress.parquet",
) -> Path | None: ) -> Path | None:
"""Full pipeline: download, extract metadata, composite progress, write output. """Full pipeline: download, extract metadata, composite progress, write output.
@@ -585,6 +595,8 @@ def process_dataset(
camera_key: Camera key to use, or None for auto-selection. camera_key: Camera key to use, or None for auto-selection.
output_dir: Directory to write output files. output_dir: Directory to write output files.
create_gif: If True, also generate a GIF from the MP4. create_gif: If True, also generate a GIF from the MP4.
progress_file: Filename of the per-frame progress parquet inside the
dataset repo.
Returns: Returns:
Path to the final output file, or None on failure. Path to the final output file, or None on failure.
@@ -592,7 +604,7 @@ def process_dataset(
safe_name = repo_id.replace("/", "_") safe_name = repo_id.replace("/", "_")
logging.info("Processing: %s | episode %d", repo_id, episode) logging.info("Processing: %s | episode %d", repo_id, episode)
local_path = download_episode_metadata(repo_id, episode) local_path = download_episode_metadata(repo_id, episode, progress_file)
logging.info(" Local cache: %s", local_path) logging.info(" Local cache: %s", local_path)
episode_meta = load_episode_meta(local_path, episode, camera_key) episode_meta = load_episode_meta(local_path, episode, camera_key)
@@ -600,9 +612,9 @@ def process_dataset(
video_path = download_video_file(repo_id, local_path, episode_meta["video_rel"]) video_path = download_video_file(repo_id, local_path, episode_meta["video_rel"])
progress_data = load_progress_data(local_path, episode) progress_data = load_progress_data(local_path, episode, progress_file)
if progress_data is None: if progress_data is None:
logging.error("Could not load sarm_progress data. Skipping overlay.") logging.error("Could not load progress data from %s. Skipping overlay.", progress_file)
return None return None
logging.info(" Progress frames: %d", len(progress_data)) logging.info(" Progress frames: %d", len(progress_data))
@@ -627,7 +639,7 @@ def process_dataset(
def main() -> None: def main() -> None:
parser = argparse.ArgumentParser( parser = argparse.ArgumentParser(
description="Create MP4/GIF videos with sarm_progress overlay for dataset episodes." description="Create MP4/GIF videos with per-frame progress overlay for dataset episodes."
) )
parser.add_argument( parser.add_argument(
"--repo-id", "--repo-id",
@@ -658,6 +670,15 @@ def main() -> None:
action="store_true", action="store_true",
help="Also generate a GIF from the MP4 output.", help="Also generate a GIF from the MP4 output.",
) )
parser.add_argument(
"--progress-file",
type=str,
default="sarm_progress.parquet",
help=(
"Filename of the per-frame progress parquet inside the dataset repo "
"(default: 'sarm_progress.parquet')."
),
)
args = parser.parse_args() args = parser.parse_args()
logging.basicConfig(level=logging.INFO, format="%(levelname)s: %(message)s") logging.basicConfig(level=logging.INFO, format="%(levelname)s: %(message)s")
@@ -670,6 +691,7 @@ def main() -> None:
camera_key=args.camera_key, camera_key=args.camera_key,
output_dir=args.output_dir, output_dir=args.output_dir,
create_gif=args.gif, create_gif=args.gif,
progress_file=args.progress_file,
) )
if result: if result:
+8 -1
@@ -250,7 +250,14 @@ class DatasetWriter:
for key, ft in self._meta.features.items(): for key, ft in self._meta.features.items():
if key in ["index", "episode_index", "task_index"] or ft["dtype"] in ["image", "video"]: if key in ["index", "episode_index", "task_index"] or ft["dtype"] in ["image", "video"]:
continue continue
episode_buffer[key] = np.stack(episode_buffer[key]) stacked_values = np.stack(episode_buffer[key])
# `shape=(1,)` numeric features are serialized as `datasets.Value`, which expects scalars.
# Normalizing to `(N,)` keeps save semantics stable across dependency versions.
if tuple(ft["shape"]) == (1,) and ft["dtype"] != "string":
stacked_values = stacked_values.reshape(episode_length)
episode_buffer[key] = stacked_values
# Wait for image writer to end, so that episode stats over images can be computed # Wait for image writer to end, so that episode stats over images can be computed
self._wait_image_writer() self._wait_image_writer()
+79 -8
@@ -17,11 +17,13 @@ import contextlib
import glob import glob
import importlib import importlib
import logging import logging
import os
import queue import queue
import shutil import shutil
import tempfile import tempfile
import threading import threading
import warnings import warnings
from collections import OrderedDict
from dataclasses import asdict, dataclass, field from dataclasses import asdict, dataclass, field
from fractions import Fraction from fractions import Fraction
from pathlib import Path from pathlib import Path
@@ -191,15 +193,70 @@ def decode_video_frames_pyav(
     return closest_frames


-class VideoDecoderCache:
-    """Thread-safe cache for video decoders to avoid expensive re-initialization."""
-
-    def __init__(self):
-        self._cache: dict[str, tuple[Any, Any]] = {}
+DEFAULT_DECODER_CACHE_SIZE = 100
+"""Default LRU capacity for :class:`VideoDecoderCache`.
+
+Sized to comfortably hold a small rolling window of episodes worth of decoders
+(typical recipes: 2-4 cameras per episode × tens of episodes in flight) while
+bounding host RAM. Each cached entry retains a torchcodec ``VideoDecoder`` plus
+an open ``fsspec`` file handle on the order of a few MB per entry. Override
+via the ``LEROBOT_VIDEO_DECODER_CACHE_SIZE`` env var or by passing ``max_size``
+to the constructor (``None`` restores the legacy unbounded behaviour).
+"""
+
+
+def _default_max_cache_size() -> int | None:
+    raw = os.environ.get("LEROBOT_VIDEO_DECODER_CACHE_SIZE")
+    if raw is None:
+        return DEFAULT_DECODER_CACHE_SIZE
+    raw = raw.strip().lower()
+    if raw in ("", "none", "unbounded", "-1"):
+        return None
+    try:
+        value = int(raw)
+    except ValueError as e:
+        raise ValueError(
+            f"LEROBOT_VIDEO_DECODER_CACHE_SIZE must be an integer, 'none', or '-1'; got {raw!r}"
+        ) from e
+    if value <= 0:
+        raise ValueError(f"LEROBOT_VIDEO_DECODER_CACHE_SIZE must be positive; got {value}")
+    return value
+
+
+class VideoDecoderCache:
+    """Thread-safe LRU cache for torchcodec ``VideoDecoder`` instances.
+
+    Cached entries hold a ``VideoDecoder`` plus the open ``fsspec`` file handle
+    backing it. When the cache is full and a new path is requested, the
+    least-recently-used entry is evicted and its file handle is closed. This
+    bounds host-RAM growth when iterating over datasets with many distinct
+    video files (otherwise each ``DataLoader`` worker pins every decoder it has
+    ever opened until the process exits).
+
+    Args:
+        max_size: Maximum number of decoders to retain. ``None`` disables
+            eviction and restores legacy unbounded behaviour. Defaults to the
+            value of ``LEROBOT_VIDEO_DECODER_CACHE_SIZE`` if set, otherwise
+            :data:`DEFAULT_DECODER_CACHE_SIZE`.
+    """
+
+    _SENTINEL: ClassVar[object] = object()
+
+    def __init__(self, max_size: int | None | object = _SENTINEL):
+        if max_size is VideoDecoderCache._SENTINEL:
+            max_size = _default_max_cache_size()
+        if max_size is not None and max_size <= 0:
+            raise ValueError(f"max_size must be positive or None; got {max_size}")
+        self.max_size: int | None = max_size  # type: ignore[assignment]
+        self._cache: OrderedDict[str, tuple[Any, Any]] = OrderedDict()
         self._lock = Lock()

+    def __contains__(self, video_path: object) -> bool:
+        with self._lock:
+            return str(video_path) in self._cache
+
     def get_decoder(self, video_path: str):
-        """Get a cached decoder or create a new one."""
+        """Get a cached decoder or create a new one, evicting LRU if at capacity."""
         if importlib.util.find_spec("torchcodec"):
             from torchcodec.decoders import VideoDecoder
         else:
@@ -211,7 +268,11 @@ class VideoDecoderCache:
         video_path = str(video_path)

         with self._lock:
-            if video_path not in self._cache:
+            entry = self._cache.get(video_path)
+            if entry is not None:
+                self._cache.move_to_end(video_path)
+                return entry[0]
+
             file_handle = fsspec.open(video_path).__enter__()
             try:
                 decoder = VideoDecoder(file_handle, seek_mode="approximate")
@@ -220,12 +281,22 @@ class VideoDecoderCache:
                 raise
             self._cache[video_path] = (decoder, file_handle)

-            return self._cache[video_path][0]
+            # Evict LRU entries until we are back under the cap. We close
+            # evicted file handles immediately; the associated ``VideoDecoder``
+            # is released to the GC when its last reference goes away.
+            if self.max_size is not None:
+                while len(self._cache) > self.max_size:
+                    _evicted_path, (_evicted_decoder, evicted_handle) = self._cache.popitem(last=False)
+                    with contextlib.suppress(Exception):
+                        evicted_handle.close()
+
+            return decoder

     def clear(self):
-        """Clear the cache and close file handles."""
+        """Clear the cache and close all file handles."""
         with self._lock:
             for _, file_handle in self._cache.values():
+                with contextlib.suppress(Exception):
                     file_handle.close()
             self._cache.clear()
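
The eviction policy in this file can be sketched without torchcodec or fsspec. The example below is an illustration under stated assumptions, not the project's implementation: `BoundedHandleCache` is a hypothetical class, and `io.BytesIO` objects stand in for the real file handles.

```python
import io
from collections import OrderedDict
from threading import Lock


class BoundedHandleCache:
    """Hypothetical LRU cache that closes a handle when it is evicted."""

    def __init__(self, max_size=2):
        self.max_size = max_size  # None would mean "never evict"
        self._cache = OrderedDict()
        self._lock = Lock()

    def get(self, key):
        with self._lock:
            handle = self._cache.get(key)
            if handle is not None:
                # Cache hit: mark the entry as most recently used.
                self._cache.move_to_end(key)
                return handle
            # Cache miss: "open" a new handle (a stand-in for fsspec.open).
            handle = io.BytesIO(key.encode())
            self._cache[key] = handle
            if self.max_size is not None:
                # popitem(last=False) removes the oldest (least recently used) entry.
                while len(self._cache) > self.max_size:
                    _, evicted = self._cache.popitem(last=False)
                    evicted.close()
            return handle

    def keys(self):
        with self._lock:
            return list(self._cache)


cache = BoundedHandleCache(max_size=2)
handle_a = cache.get("a.mp4")
handle_b = cache.get("b.mp4")
cache.get("a.mp4")  # promotes a.mp4, so b.mp4 becomes the LRU entry
cache.get("c.mp4")  # exceeds the cap and evicts b.mp4
```

Inserting first and then trimming in a `while` loop keeps the invariant `len(cache) <= max_size` on return, which is the same shape as the loop in the diff.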
+100 -18
@@ -43,6 +43,7 @@ from .tables import (
     CAN_CMD_SET_ZERO,
     DEFAULT_BAUDRATE,
     DEFAULT_TIMEOUT_MS,
+    HANDSHAKE_TIMEOUT_S,
     MODEL_RESOLUTION,
     MOTOR_LIMIT_PARAMS,
     NORMALIZED_DATA,
@@ -215,14 +216,16 @@ class RobstrideMotorsBus(MotorsBusBase):
             self._is_connected = False
             raise ConnectionError(f"Failed to connect to CAN bus: {e}") from e

-    def _query_status_via_clear_fault(self, motor: NameOrID) -> tuple[bool, can.Message | None]:
+    def _query_status_via_clear_fault(
+        self, motor: NameOrID, timeout: float = RUNNING_TIMEOUT
+    ) -> tuple[bool, can.Message | None]:
         motor_name = self._get_motor_name(motor)
         motor_id = self._get_motor_id(motor_name)
         recv_id = self._get_motor_recv_id(motor_name)
         data = [0xFF] * 7 + [CAN_CMD_CLEAR_FAULT]
         msg = can.Message(arbitration_id=motor_id, data=data, is_extended_id=False)
         self._bus().send(msg)
-        return self._recv_status_via_clear_fault(expected_recv_id=recv_id)
+        return self._recv_status_via_clear_fault(expected_recv_id=recv_id, timeout=timeout)

     def _recv_status_via_clear_fault(
         self, expected_recv_id: int | None = None, timeout: float = RUNNING_TIMEOUT
@@ -280,7 +283,7 @@ class RobstrideMotorsBus(MotorsBusBase):
         faulted_motors = []

         for motor_name in self.motors:
-            has_fault, msg = self._query_status_via_clear_fault(motor_name)
+            has_fault, msg = self._query_status_via_clear_fault(motor_name, timeout=HANDSHAKE_TIMEOUT_S)
             if msg is None:
                 missing_motors.append(motor_name)
             elif has_fault:
@@ -505,6 +508,87 @@ class RobstrideMotorsBus(MotorsBusBase):
         return responses

+    def _recv_all_messages_until_quiet(
+        self,
+        *,
+        timeout: float = RUNNING_TIMEOUT,
+        max_messages: int = 4096,
+    ) -> list[can.Message]:
+        """
+        Receive frames until the bus goes quiet.
+
+        Args:
+            timeout: Poll timeout used for each recv() call. Collection stops
+                when one recv() times out (quiet gap).
+            max_messages: Safety cap to prevent unbounded loops.
+        """
+        out: list[can.Message] = []
+        max_messages = max(1, max_messages)
+        timeout = max(0.0, timeout)
+        try:
+            while len(out) < max_messages:
+                msg = self._bus().recv(timeout=timeout)
+                if msg is None:
+                    break
+                out.append(msg)
+        except (can.CanError, OSError) as e:
+            logger.debug(f"Error draining CAN RX queue on {self.port}: {e}")
+        return out
+
+    def _process_feedback_messages(self, messages: list[can.Message]) -> set[int]:
+        """
+        Decode all received feedback frames and update cached motor states.
+
+        Returns:
+            Set of payload recv_ids that were successfully mapped to motors.
+        """
+        processed_recv_ids: set[int] = set()
+        for msg in messages:
+            if len(msg.data) < 1:
+                logger.debug(
+                    f"Dropping short CAN frame on {self.port} "
+                    f"(arb=0x{int(msg.arbitration_id):02X}, data={bytes(msg.data).hex()})"
+                )
+                continue
+            recv_id = int(msg.data[0])
+            motor_name = self._recv_id_to_motor.get(recv_id)
+            if motor_name is None:
+                logger.debug(
+                    f"Unmapped CAN frame on {self.port} "
+                    f"(arb=0x{int(msg.arbitration_id):02X}, recv_id=0x{recv_id:02X}, data={bytes(msg.data).hex()})"
+                )
+                continue
+            self._process_response(motor_name, msg)
+            processed_recv_ids.add(recv_id)
+        return processed_recv_ids
+
+    def flush_rx_queue(self, poll_timeout_s: float = 0.0005, max_messages: int = 4096) -> int:
+        """
+        Drain pending RX frames from the CAN interface.
+
+        This is used by higher-level controllers to drop stale feedback before issuing
+        a fresh read cycle, so subsequent state reads are based on most recent replies.
+        It should also be called once when a controller instance is created/connected,
+        to clear residual frames left on the interface from previous sessions.
+        """
+        drained = 0
+        poll_timeout_s = max(0.0, poll_timeout_s)
+        max_messages = max(1, max_messages)
+        try:
+            while drained < max_messages:
+                msg = self._bus().recv(timeout=poll_timeout_s)
+                if msg is None:
+                    break
+                drained += 1
+        except (can.CanError, OSError) as e:
+            logger.debug(f"Failed to flush CAN RX queue on {self.port}: {e}")
+        return drained
+
     def _speed_control(
         self,
         motor: NameOrID,
@@ -644,11 +728,14 @@ class RobstrideMotorsBus(MotorsBusBase):
             msg = can.Message(arbitration_id=motor_id, data=data, is_extended_id=False)
             self._bus().send(msg)
             recv_id_to_motor[self._get_motor_recv_id(motor)] = motor_name

-        responses = self._recv_all_responses(list(recv_id_to_motor.keys()), timeout=RUNNING_TIMEOUT)
+        # Read every feedback frame until RX goes quiet, then decode all of them.
+        # This avoids dropping useful frames when responses from different motors interleave.
+        messages = self._recv_all_messages_until_quiet()
+        processed_recv_ids = self._process_feedback_messages(messages)
         for recv_id, motor_name in recv_id_to_motor.items():
-            if msg := responses.get(recv_id):
-                self._process_response(motor_name, msg)
+            if recv_id not in processed_recv_ids:
+                logger.warning(f"Packet drop: {motor_name} (ID: 0x{recv_id:02X}). Using last known state.")

     def _float_to_uint(self, x: float, x_min: float, x_max: float, bits: int) -> int:
         """Convert float to unsigned integer for CAN transmission."""
@@ -711,7 +798,10 @@ class RobstrideMotorsBus(MotorsBusBase):
         try:
             self._decode_motor_state(msg.data)
         except Exception as e:
-            logger.warning(f"Failed to decode response from {motor}: {e}")
+            logger.warning(
+                f"Failed to decode response from {motor} "
+                f"(arb=0x{int(msg.arbitration_id):02X}, data={bytes(msg.data).hex()}): {e}"
+            )

     def _get_cached_value(self, motor: str, data_name: str) -> Value:
         """Retrieve a specific value from the state cache."""
@@ -848,20 +938,12 @@ class RobstrideMotorsBus(MotorsBusBase):
             self._bus().send(msg)
             updated_motors.append(motor)

-        expected_recv_ids = [self._get_motor_recv_id(motor) for motor in updated_motors]
-        responses = self._recv_all_responses(expected_recv_ids, timeout=RUNNING_TIMEOUT)
-
-        for response in responses.values():
-            payload_motor_name = self._recv_id_to_motor.get(response.data[0])
-            if payload_motor_name is not None:
-                self._process_response(payload_motor_name, response)
-            else:
-                # Fallback: still attempt to decode based on payload byte0 mapping.
-                self._decode_motor_state(response.data)
+        messages = self._recv_all_messages_until_quiet()
+        processed_recv_ids = self._process_feedback_messages(messages)

         for motor in updated_motors:
             recv_id = self._get_motor_recv_id(motor)
-            if recv_id not in responses:
+            if recv_id not in processed_recv_ids:
                 logger.warning(f"Packet drop: {motor} (ID: 0x{recv_id:02X}). Using last known state.")

     def read_calibration(self) -> dict[str, MotorCalibration]:
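
The "drain until quiet, then decode everything" pattern used in these hunks can be sketched without CAN hardware. This is a simplified illustration: `FakeBus`, `drain_until_quiet`, and the motor names are hypothetical, and the fake `recv()` only mimics the convention that a timed-out receive returns `None`.

```python
from collections import deque


class FakeBus:
    """Hypothetical stand-in for a CAN bus; recv() returns None when empty."""

    def __init__(self, frames):
        self._frames = deque(frames)

    def recv(self, timeout=0.0):
        # A real bus would block for up to `timeout` seconds; this fake
        # returns immediately so the example stays deterministic.
        return self._frames.popleft() if self._frames else None


def drain_until_quiet(bus, timeout=0.003, max_messages=4096):
    """Collect frames until one recv() times out or the safety cap is hit."""
    out = []
    max_messages = max(1, max_messages)
    timeout = max(0.0, timeout)
    while len(out) < max_messages:
        msg = bus.recv(timeout=timeout)
        if msg is None:
            break
        out.append(msg)
    return out


# Byte 0 of each payload carries the receive ID, as in the diff.
frames = [bytes([0x11, 0x00]), bytes([0x99, 0x00]), bytes([0x12, 0x00])]
recv_id_to_motor = {0x11: "shoulder", 0x12: "elbow", 0x13: "wrist"}

received = drain_until_quiet(FakeBus(frames))
processed = {frame[0] for frame in received if frame and frame[0] in recv_id_to_motor}
dropped = [name for rid, name in recv_id_to_motor.items() if rid not in processed]
```

Because every frame is decoded by its payload ID rather than matched against one expected ID at a time, interleaved replies from other motors are not discarded, and the unmapped `0x99` frame is simply skipped.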
+2 -1
@@ -114,7 +114,8 @@ CAN_CMD_SAVE_PARAM = 0xAA

 CAN_PARAM_ID = 0x7FF

-RUNNING_TIMEOUT = 0.001
+RUNNING_TIMEOUT = 0.003
+HANDSHAKE_TIMEOUT_S = 0.05
 PARAM_TIMEOUT = 0.01
 STATE_CACHE_TTL_S = 0.02
+36
@@ -24,6 +24,7 @@ import torch

 pytest.importorskip("datasets", reason="datasets is required (install lerobot[dataset])")

+import datasets
 from huggingface_hub import HfApi
 from PIL import Image
 from safetensors.torch import load_file
@@ -360,6 +361,41 @@ def test_add_frame_image_pil(image_dataset):
     assert dataset[0]["image"].shape == torch.Size(DUMMY_CHW)


+@pytest.mark.parametrize(
+    "dtype,np_dtype,values,assert_fn",
+    [
+        ("float32", np.float32, [1.0, 2.0], np.testing.assert_allclose),
+        ("int64", np.int64, [1, 2], np.testing.assert_array_equal),
+        ("bool", np.bool_, [True, False], np.testing.assert_array_equal),
+    ],
+    ids=["float32", "int64", "bool"],
+)
+def test_save_episode_shape_1_scalar_is_scalarized_before_hf_encoding(
+    tmp_path, empty_lerobot_dataset_factory, monkeypatch, dtype, np_dtype, values, assert_fn
+):
+    features = {"state": {"dtype": dtype, "shape": (1,), "names": None}}
+    dataset = empty_lerobot_dataset_factory(root=tmp_path / "test", features=features)
+    dataset.add_frame({"state": np.array([values[0]], dtype=np_dtype), "task": "Dummy task"})
+    dataset.add_frame({"state": np.array([values[1]], dtype=np_dtype), "task": "Dummy task"})
+
+    captured = {}
+    original_from_dict = datasets.Dataset.from_dict
+
+    def _from_dict_spy(cls, mapping, *args, **kwargs):
+        captured["state"] = mapping["state"]
+        return original_from_dict(mapping, *args, **kwargs)
+
+    monkeypatch.setattr(datasets.Dataset, "from_dict", classmethod(_from_dict_spy))
+    dataset.save_episode()
+    dataset.finalize()
+
+    assert "state" in captured
+    assert isinstance(captured["state"], np.ndarray)
+    assert captured["state"].shape == (2,)
+    assert_fn(captured["state"], np.array(values, dtype=np_dtype))
+
+
 def test_set_image_transforms_applies_transparently(image_dataset):
     dataset = image_dataset
     dataset.add_frame({"image": np.random.rand(*DUMMY_CHW), "task": "Dummy task"})
+140
@@ -0,0 +1,140 @@
+#!/usr/bin/env python
+
+# Copyright 2026 The HuggingFace Inc. team. All rights reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+"""Unit tests for ``lerobot.datasets.video_utils.VideoDecoderCache``.
+
+These cover the LRU bounding + file-handle release behaviour added to prevent
+unbounded growth when iterating over datasets with many distinct video files
+(observed: ~35 GB anon-rss per DataLoader worker on an 8 k-file dataset).
+"""
+
+import shutil
+from pathlib import Path
+
+import pytest
+
+pytest.importorskip("torchcodec", reason="torchcodec is required (install lerobot[dataset])")
+
+from lerobot.datasets.video_utils import VideoDecoderCache  # noqa: E402
+
+TEST_ARTIFACTS_DIR = Path(__file__).resolve().parent.parent / "artifacts" / "encoded_videos"
+SRC_CLIP = TEST_ARTIFACTS_DIR / "clip_4frames.mp4"
+
+
+def _make_distinct_clips(tmp_path: Path, n: int) -> list[Path]:
+    """Copy the small reference mp4 to ``n`` distinct paths.
+
+    The cache keys on absolute path, so distinct paths force distinct cache entries
+    even though the file contents are identical.
+    """
+    assert SRC_CLIP.exists(), f"missing test artifact {SRC_CLIP}"
+    paths = []
+    for i in range(n):
+        dst = tmp_path / f"clip_{i:04d}.mp4"
+        shutil.copyfile(SRC_CLIP, dst)
+        paths.append(dst)
+    return paths
+
+
+class TestVideoDecoderCacheBounded:
+    def test_default_cache_is_bounded(self):
+        """The default cache must have a finite ``max_size`` to bound RSS growth."""
+        cache = VideoDecoderCache()
+        assert cache.max_size is not None, "default cache must be bounded"
+        assert cache.max_size > 0
+
+    def test_size_capped_at_max_size(self, tmp_path):
+        """``get_decoder`` for >``max_size`` distinct paths must NOT grow without bound."""
+        paths = _make_distinct_clips(tmp_path, n=5)
+        cache = VideoDecoderCache(max_size=2)
+        for p in paths:
+            cache.get_decoder(p)
+        assert cache.size() == 2
+
+    def test_evicts_least_recently_used(self, tmp_path):
+        """Re-accessing an entry must promote it; the LRU entry is the one evicted."""
+        paths = _make_distinct_clips(tmp_path, n=3)
+        cache = VideoDecoderCache(max_size=2)
+        cache.get_decoder(paths[0])
+        cache.get_decoder(paths[1])
+        cache.get_decoder(paths[0])  # promote paths[0] to MRU; paths[1] is now LRU
+        cache.get_decoder(paths[2])  # should evict paths[1]
+        assert str(paths[0]) in cache  # MRU stays
+        assert str(paths[1]) not in cache  # LRU evicted
+        assert str(paths[2]) in cache  # newest stays
+
+    def test_eviction_closes_file_handle(self, tmp_path):
+        """Evicting an entry must close its fsspec file handle (otherwise we leak FDs)."""
+        paths = _make_distinct_clips(tmp_path, n=2)
+        cache = VideoDecoderCache(max_size=1)
+        cache.get_decoder(paths[0])
+        # Reach into the cache to capture the handle before it is evicted. This is
+        # the only assertion in the suite that touches a private attribute, and it
+        # is the most direct way to prove the file descriptor is actually released.
+        evicted_handle = cache._cache[str(paths[0])][1]
+        assert evicted_handle.closed is False
+        cache.get_decoder(paths[1])  # forces eviction of paths[0]
+        assert evicted_handle.closed is True
+
+    def test_clear_closes_all_file_handles(self, tmp_path):
+        """``clear()`` must close every cached file handle."""
+        paths = _make_distinct_clips(tmp_path, n=3)
+        cache = VideoDecoderCache(max_size=10)
+        for p in paths:
+            cache.get_decoder(p)
+        handles = [entry[1] for entry in cache._cache.values()]
+        assert all(not h.closed for h in handles)
+        cache.clear()
+        assert cache.size() == 0
+        assert all(h.closed for h in handles)
+
+    def test_hit_does_not_reopen_or_evict(self, tmp_path):
+        """A cache hit must return the same decoder instance without touching the cap."""
+        paths = _make_distinct_clips(tmp_path, n=1)
+        cache = VideoDecoderCache(max_size=2)
+        first = cache.get_decoder(paths[0])
+        second = cache.get_decoder(paths[0])
+        assert first is second
+        assert cache.size() == 1
+
+    def test_unbounded_when_max_size_none(self, tmp_path):
+        """``max_size=None`` preserves the legacy unbounded behaviour."""
+        paths = _make_distinct_clips(tmp_path, n=4)
+        cache = VideoDecoderCache(max_size=None)
+        for p in paths:
+            cache.get_decoder(p)
+        assert cache.size() == 4
+
+    def test_env_var_overrides_default(self, tmp_path, monkeypatch):
+        """``LEROBOT_VIDEO_DECODER_CACHE_SIZE`` env var sets the default ``max_size``."""
+        monkeypatch.setenv("LEROBOT_VIDEO_DECODER_CACHE_SIZE", "3")
+        cache = VideoDecoderCache()
+        assert cache.max_size == 3
+        paths = _make_distinct_clips(tmp_path, n=5)
+        for p in paths:
+            cache.get_decoder(p)
+        assert cache.size() == 3
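
The env-var override exercised by the last test can be tried in isolation. The sketch below follows the parsing rules shown in the `video_utils` diff, but `parse_cache_size` is a hypothetical function that takes a plain mapping instead of reading `os.environ`, so it runs without side effects.

```python
DEFAULT_SIZE = 100  # mirrors DEFAULT_DECODER_CACHE_SIZE from the diff
ENV_KEY = "LEROBOT_VIDEO_DECODER_CACHE_SIZE"


def parse_cache_size(env):
    """Return the cache cap for a mapping of environment variables.

    Hypothetical helper: None means "unbounded", an int is the LRU capacity.
    """
    raw = env.get(ENV_KEY)
    if raw is None:
        return DEFAULT_SIZE
    raw = raw.strip().lower()
    # Several spellings opt back into the legacy unbounded behaviour.
    if raw in ("", "none", "unbounded", "-1"):
        return None
    try:
        value = int(raw)
    except ValueError as e:
        raise ValueError(f"{ENV_KEY} must be an integer, 'none', or '-1'; got {raw!r}") from e
    if value <= 0:
        raise ValueError(f"{ENV_KEY} must be positive; got {value}")
    return value


default_size = parse_cache_size({})
explicit_size = parse_cache_size({ENV_KEY: "3"})
unbounded = parse_cache_size({ENV_KEY: " None "})
```

Note that `-1` is accepted as "unbounded" before the integer check, so only `0` and other negative values are rejected as non-positive.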