Mirror of https://github.com/huggingface/lerobot.git, synced 2026-05-11 14:49:43 +00:00.

Compare commits (42 commits)
| SHA1 |
|---|
| 8770c011b0 |
| ddcda8f1ca |
| 4f8ebe41b3 |
| 066976e078 |
| b3c2592ace |
| b97ea8999f |
| 69aeda68f5 |
| a9e355bd03 |
| aae68e3448 |
| 4b9f6c4aed |
| 6057638fc1 |
| e52e7e644a |
| 8633608d26 |
| 900e6b59c8 |
| f844fe683c |
| 4403675b31 |
| d18be0c3f4 |
| 866f8adf11 |
| 3d6310c03d |
| c3b26382e7 |
| e54e582a6f |
| 418791ebba |
| ee3354a885 |
| 2cd06fe95b |
| 7be84cb545 |
| c35af1ae6a |
| 6fc024704e |
| c3b7a18f01 |
| 7fc0cdf68a |
| 23bf69ebab |
| 3d5d8fa88a |
| e80c9e6270 |
| 39cf11d5dc |
| 285c500aef |
| f60d163588 |
| 4a04465bb8 |
| 464532ec37 |
| 89f9bd78ab |
| c9cfc88602 |
| 7bef12a461 |
| db5c26f07d |
| 8904768db4 |

.dockerignore

@@ -12,6 +12,11 @@
# See the License for the specific language governing permissions and
# limitations under the License.

# Python virtual environments — never copy into Docker images
.venv
venv
env/

# Misc
.git
tmp

docker/Dockerfile.benchmark

@@ -0,0 +1,200 @@
# Copyright 2025 The HuggingFace Inc. team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

# Benchmark evaluation container — one image per benchmark, built via BENCHMARK arg.
#
# Supported values for BENCHMARK:
#   libero      — LIBERO suite (spatial / object / goal / 10 / 90)
#   libero_plus — LIBERO-plus extended benchmark (requires robosuite, bddl, robomimic)
#   robomme     — RoboMME memory-augmented manipulation benchmark
#   robocasa    — RoboCasa kitchen composite-task benchmark
#
# Build:
#   docker build --build-arg BENCHMARK=libero -f docker/Dockerfile.benchmark \
#     -t lerobot-benchmark-libero .
#
# Run (interactive):
#   docker run --gpus all --rm -it lerobot-benchmark-libero
# Run eval:
#   docker run --gpus all --rm lerobot-benchmark-libero lerobot-eval --help

ARG CUDA_VERSION=12.4.1
ARG OS_VERSION=22.04
FROM nvidia/cuda:${CUDA_VERSION}-base-ubuntu${OS_VERSION}

ARG PYTHON_VERSION=3.12
ARG BENCHMARK=libero

ENV DEBIAN_FRONTEND=noninteractive \
    MUJOCO_GL=egl \
    PYOPENGL_PLATFORM=egl \
    EGL_PLATFORM=device \
    NVIDIA_DRIVER_CAPABILITIES=all \
    NVIDIA_VISIBLE_DEVICES=all \
    PATH=/lerobot/.venv/bin:$PATH \
    CMAKE_POLICY_VERSION_MINIMUM=3.5 \
    CUDA_VISIBLE_DEVICES=0 \
    DEVICE=cuda \
    BENCHMARK=${BENCHMARK}

# ── Base system deps (shared across all benchmarks) ───────────────────────────
RUN apt-get update && apt-get install -y --no-install-recommends \
    software-properties-common build-essential git curl \
    libglib2.0-0 libgl1 libgl1-mesa-glx libgles2 \
    libegl1 libegl1-mesa libegl1-mesa-dev \
    libglew-dev libglfw3 libglfw3-dev libgl1-mesa-dri \
    libglvnd-dev libosmesa6 libosmesa6-dev \
    libvulkan1 mesa-vulkan-drivers \
    libsm6 libxext6 libxrender-dev \
    ffmpeg libusb-1.0-0-dev speech-dispatcher libgeos-dev portaudio19-dev \
    cmake pkg-config ninja-build \
    && add-apt-repository -y ppa:deadsnakes/ppa \
    && apt-get update \
    && apt-get install -y --no-install-recommends \
    python${PYTHON_VERSION} \
    python${PYTHON_VERSION}-venv \
    python${PYTHON_VERSION}-dev \
    && curl -LsSf https://astral.sh/uv/install.sh | sh \
    && mv /root/.local/bin/uv /usr/local/bin/uv \
    && useradd --create-home --shell /bin/bash user_lerobot \
    && usermod -aG sudo user_lerobot \
    && apt-get clean && rm -rf /var/lib/apt/lists/*

# ── NVIDIA EGL + Vulkan vendor ICDs (lets GLVND find the GPU driver) ──────────
RUN mkdir -p /usr/share/vulkan/icd.d /usr/share/glvnd/egl_vendor.d \
    && printf '{"file_format_version":"1.0.0","ICD":{"library_path":"libGLX_nvidia.so.0","api_version":"1.2.155"}}\n' \
    > /usr/share/vulkan/icd.d/nvidia_icd.json \
    && printf '{"file_format_version":"1.0.0","ICD":{"library_path":"libEGL_nvidia.so.0"}}\n' \
    > /usr/share/glvnd/egl_vendor.d/10_nvidia.json

# ── Benchmark-specific system deps ────────────────────────────────────────────
# libero_plus: the `wand` Python package requires ImageMagick headers.
RUN case "${BENCHMARK}" in \
    libero_plus) \
    apt-get update && apt-get install -y --no-install-recommends \
    libmagickwand-dev \
    && apt-get clean && rm -rf /var/lib/apt/lists/* ;; \
    esac

WORKDIR /lerobot
RUN chown -R user_lerobot:user_lerobot /lerobot

USER user_lerobot

ENV HOME=/home/user_lerobot \
    HF_HOME=/home/user_lerobot/.cache/huggingface \
    HF_LEROBOT_HOME=/home/user_lerobot/.cache/huggingface/lerobot \
    TORCH_HOME=/home/user_lerobot/.cache/torch \
    TRITON_CACHE_DIR=/home/user_lerobot/.cache/triton

RUN uv venv --seed --python python${PYTHON_VERSION}

# Copy only the dependency manifests first so Docker can cache this layer
# independently of source-code changes.
COPY --chown=user_lerobot:user_lerobot setup.py pyproject.toml README.md MANIFEST.in ./
COPY --chown=user_lerobot:user_lerobot src/ src/

ARG UNBOUND_DEPS=false
RUN if [ "$UNBOUND_DEPS" = "true" ]; then \
    sed -i 's/,[[:space:]]*<[0-9\.]*//g' pyproject.toml; \
    echo "Dependencies unbound:" && cat pyproject.toml; \
    fi

# Install lerobot core + the selected benchmark extra.
# LIBERO-plus needs a dedicated install path because the upstream package is
# import-broken when installed via the extras chain alone.
RUN case "${BENCHMARK}" in \
    libero_plus) \
    PATH=/usr/bin:/bin:/lerobot/.venv/bin:$PATH /lerobot/.venv/bin/python -m pip install --no-cache-dir \
    "hf-libero>=0.1.3,<0.2.0" \
    "hf-egl-probe>=1.0.1" \
    "transformers>=5.3.0,<6.0.0" \
    "scipy>=1.14.0,<2.0.0" \
    "bddl>=1.0.1,<2.0.0" \
    "future" \
    "easydict>=1.9" \
    "wand" \
    "scikit-image>=0.20.0" \
    "gym>=0.25.0,<0.27.0" \
    && git clone --depth 1 https://github.com/sylvestf/LIBERO-plus.git /tmp/LIBERO-plus \
    && PATH=/usr/bin:/bin:/lerobot/.venv/bin:$PATH /lerobot/.venv/bin/python -m pip install --no-cache-dir --no-deps /tmp/LIBERO-plus \
    && /lerobot/.venv/bin/python -c "import pathlib, site; pathlib.Path(site.getsitepackages()[0], 'libero_plus_repo.pth').write_text('/tmp/LIBERO-plus\n')" \
    && /lerobot/.venv/bin/python -m pip install --no-cache-dir . \
    && /lerobot/.venv/bin/python -c "\
import os, yaml, importlib.util; \
root = os.path.dirname(importlib.util.find_spec('libero.libero').origin); \
d = dict(benchmark_root=root, bddl_files=os.path.join(root,'bddl_files'), \
init_states=os.path.join(root,'init_files'), datasets=os.path.join(root,'..','datasets'), \
assets=os.path.join(root,'assets')); \
cfg_dir = os.path.expanduser('~/.libero'); os.makedirs(cfg_dir, exist_ok=True); \
yaml.dump(d, open(os.path.join(cfg_dir,'config.yaml'),'w')); print('libero config created')" \
    && /lerobot/.venv/bin/python -c "from libero.libero import benchmark, get_libero_path; print('libero OK')" ;; \
    libero) \
    uv pip install --no-cache ".[libero]" \
    && /lerobot/.venv/bin/python -c "\
import os, yaml, importlib.util; \
root = os.path.dirname(importlib.util.find_spec('libero.libero').origin); \
d = dict(benchmark_root=root, bddl_files=os.path.join(root,'bddl_files'), \
init_states=os.path.join(root,'init_files'), datasets=os.path.join(root,'..','datasets'), \
assets=os.path.join(root,'assets')); \
cfg_dir = os.path.expanduser('~/.libero'); os.makedirs(cfg_dir, exist_ok=True); \
yaml.dump(d, open(os.path.join(cfg_dir,'config.yaml'),'w')); print('libero config created')" \
    && /lerobot/.venv/bin/python -c "from libero.libero import benchmark, get_libero_path; print('libero OK')" ;; \
    *) \
    uv pip install --no-cache ".[${BENCHMARK}]" ;; \
    esac

# LIBERO-plus requires ~6 GB of scene/texture/object assets from HuggingFace.
# Download at build time so containers don't need network access at runtime.
USER root
COPY <<'FETCH_ASSETS' /tmp/fetch_assets.py
from huggingface_hub import hf_hub_download
hf_hub_download("Sylvest/LIBERO-plus", "assets.zip",
                repo_type="dataset", local_dir="/tmp/libero-plus-assets")
FETCH_ASSETS
COPY <<'VERIFY_ASSETS' /tmp/verify_assets.py
from pathlib import Path
from libero.libero import get_libero_path
d = Path(get_libero_path("benchmark_root")) / "assets" / "scenes"
assert d.is_dir(), f"assets missing at {d}"
print("assets OK:", d)
VERIFY_ASSETS
RUN if [ "${BENCHMARK}" = "libero_plus" ]; then \
    apt-get update && apt-get install -y --no-install-recommends unzip \
    && apt-get clean && rm -rf /var/lib/apt/lists/* \
    && /lerobot/.venv/bin/python /tmp/fetch_assets.py \
    && unzip -q /tmp/libero-plus-assets/assets.zip -d /tmp/libero-plus-unzipped \
    && ASSETS_DIR=$(/lerobot/.venv/bin/python -c "from libero.libero import get_libero_path; print(get_libero_path('benchmark_root'))") \
    && SRC=$(find /tmp/libero-plus-unzipped -type d -name assets | head -1) \
    && mv "$SRC" "$ASSETS_DIR/assets" \
    && chown -R user_lerobot:user_lerobot "$ASSETS_DIR/assets" \
    && rm -rf /tmp/libero-plus-assets /tmp/libero-plus-unzipped /tmp/fetch_assets.py \
    && /lerobot/.venv/bin/python /tmp/verify_assets.py \
    && rm /tmp/verify_assets.py; \
    fi
USER user_lerobot

# Triton requires its ptxas binary to be executable (NVIDIA-specific).
RUN if [ -f /lerobot/.venv/lib/python${PYTHON_VERSION}/site-packages/triton/backends/nvidia/bin/ptxas ]; then \
    chmod +x /lerobot/.venv/lib/python${PYTHON_VERSION}/site-packages/triton/backends/nvidia/bin/ptxas; \
    fi

# Verify EGL probe is importable (runtime GPU check requires NVIDIA drivers at container start).
RUN /lerobot/.venv/bin/python -c "import egl_probe; print('egl_probe OK')" \
    2>/dev/null || echo 'NOTE: egl_probe not installed (non-libero build), skipping'

# Copy full source (tests, examples, configs, etc.)
COPY --chown=user_lerobot:user_lerobot . .

CMD ["/bin/bash"]

docker/Dockerfile.eval-base

@@ -0,0 +1,78 @@
# Copyright 2026 The HuggingFace Inc. team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

ARG CUDA_VERSION=12.4.1
ARG OS_VERSION=22.04
FROM nvidia/cuda:${CUDA_VERSION}-base-ubuntu${OS_VERSION}

ARG PYTHON_VERSION=3.12

ENV DEBIAN_FRONTEND=noninteractive \
    MUJOCO_GL=egl \
    PYOPENGL_PLATFORM=egl \
    EGL_PLATFORM=device \
    NVIDIA_DRIVER_CAPABILITIES=all \
    NVIDIA_VISIBLE_DEVICES=all \
    PATH=/lerobot/.venv/bin:$PATH \
    # cmake 4.x removed backward compat with cmake_minimum_required < 3.5.
    # This env var re-enables it so packages like egl-probe can compile.
    CMAKE_POLICY_VERSION_MINIMUM=3.5

RUN apt-get update && apt-get install -y --no-install-recommends \
    software-properties-common build-essential git curl \
    libglib2.0-0 libgl1 libgl1-mesa-glx libgles2 \
    libegl1 libegl1-mesa libegl1-mesa-dev \
    libglew-dev libglfw3 libglvnd-dev \
    libosmesa6 libosmesa6-dev \
    libvulkan1 mesa-vulkan-drivers \
    libsm6 libxext6 libxrender-dev \
    ffmpeg libusb-1.0-0-dev speech-dispatcher libgeos-dev portaudio19-dev \
    cmake pkg-config ninja-build \
    && add-apt-repository -y ppa:deadsnakes/ppa \
    && apt-get update \
    && apt-get install -y --no-install-recommends \
    python${PYTHON_VERSION} \
    python${PYTHON_VERSION}-venv \
    python${PYTHON_VERSION}-dev \
    && curl -LsSf https://astral.sh/uv/install.sh | sh \
    && mv /root/.local/bin/uv /usr/local/bin/uv \
    && useradd --create-home --shell /bin/bash user_lerobot \
    && apt-get clean && rm -rf /var/lib/apt/lists/*

# NVIDIA EGL + Vulkan vendor ICDs (lets GLVND find the GPU driver)
RUN mkdir -p /usr/share/vulkan/icd.d /usr/share/glvnd/egl_vendor.d \
    && printf '{"file_format_version":"1.0.0","ICD":{"library_path":"libGLX_nvidia.so.0","api_version":"1.2.155"}}\n' \
    > /usr/share/vulkan/icd.d/nvidia_icd.json \
    && printf '{"file_format_version":"1.0.0","ICD":{"library_path":"libEGL_nvidia.so.0"}}\n' \
    > /usr/share/glvnd/egl_vendor.d/10_nvidia.json

WORKDIR /lerobot
RUN chown -R user_lerobot:user_lerobot /lerobot
USER user_lerobot

ENV HOME=/home/user_lerobot \
    HF_HOME=/home/user_lerobot/.cache/huggingface \
    HF_LEROBOT_HOME=/home/user_lerobot/.cache/huggingface/lerobot \
    TORCH_HOME=/home/user_lerobot/.cache/torch \
    TRITON_CACHE_DIR=/home/user_lerobot/.cache/triton

RUN uv venv --seed --python python${PYTHON_VERSION}

COPY --chown=user_lerobot:user_lerobot setup.py pyproject.toml README.md MANIFEST.in ./
COPY --chown=user_lerobot:user_lerobot src/ src/
RUN uv pip install --no-cache .

COPY --chown=user_lerobot:user_lerobot . .

CMD ["/bin/bash"]

docker/Dockerfile.eval-libero

@@ -0,0 +1,20 @@
# Copyright 2026 The HuggingFace Inc. team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

FROM lerobot-eval-base:latest

RUN uv pip install --no-cache ".[libero]" \
    && python -c "import libero"

CMD ["/bin/bash"]

docker/Dockerfile.eval-libero-plus

@@ -0,0 +1,47 @@
# Copyright 2026 The HuggingFace Inc. team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

FROM lerobot-eval-base:latest

# Install libero_plus deps explicitly rather than via ".[libero_plus]" extras chain.
# uv has a bug where it considers packages "already resolved" when coming through
# a nested lerobot[libero] → lerobot[libero_plus] extras chain, silently skipping them.
RUN uv pip install --no-cache \
    "hf-libero>=0.1.3,<0.2.0" \
    "hf-egl-probe>=1.0.1" \
    "transformers>=5.3.0,<6.0.0" \
    "scipy>=1.14.0,<2.0.0" \
    "bddl>=1.0.1,<2.0.0" \
    "future" \
    "easydict>=1.9" \
    "wand" \
    "scikit-image>=0.20.0" \
    "gym>=0.25.0,<0.27.0"

# Clone LIBERO-plus; install with --no-deps (runtime deps declared above via hf-libero).
# Add .pth so the libero module can locate its data files at runtime.
RUN git clone --depth 1 https://github.com/sylvestf/LIBERO-plus.git /tmp/LIBERO-plus \
    && uv pip install --no-cache --no-deps /tmp/LIBERO-plus \
    && python -c "import pathlib, site; pathlib.Path(site.getsitepackages()[0], 'libero_plus_repo.pth').write_text('/tmp/LIBERO-plus\n')" \
    && python -c "\
import os, yaml, importlib.util; \
root = os.path.dirname(importlib.util.find_spec('libero.libero').origin); \
d = dict(benchmark_root=root, bddl_files=os.path.join(root,'bddl_files'), \
init_states=os.path.join(root,'init_files'), datasets=os.path.join(root,'..','datasets'), \
assets=os.path.join(root,'assets')); \
cfg_dir = os.path.expanduser('~/.libero'); os.makedirs(cfg_dir, exist_ok=True); \
yaml.dump(d, open(os.path.join(cfg_dir,'config.yaml'),'w')); print('libero config created')" \
    && python -c "from libero.libero import benchmark, get_libero_path; print('libero OK')"

CMD ["/bin/bash"]

docker/Dockerfile.eval-metaworld

@@ -0,0 +1,20 @@
# Copyright 2026 The HuggingFace Inc. team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

FROM lerobot-eval-base:latest

RUN uv pip install --no-cache ".[metaworld]" \
    && python -c "import metaworld"

CMD ["/bin/bash"]

docker/Dockerfile.eval-robocasa

@@ -0,0 +1,40 @@
# Copyright 2026 The HuggingFace Inc. team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

FROM lerobot-eval-base:latest

# robocasa README says to use the master branch of ARISE-Initiative/robosuite.
# Install it with deps (robosuite from master has modern dep declarations).
RUN git clone --depth 1 https://github.com/ARISE-Initiative/robosuite.git /tmp/robosuite \
    && uv pip install --no-cache /tmp/robosuite

# Clone robocasa and install with --no-deps to skip its lerobot==0.3.3 pin.
# Install robocasa's actual runtime deps explicitly instead.
RUN git clone --depth 1 https://github.com/robocasa/robocasa.git /tmp/robocasa \
    && uv pip install --no-cache --no-deps /tmp/robocasa \
    && uv pip install --no-cache \
    "scikit-image>=0.20.0" \
    "numba>=0.61.0,<0.62.0" \
    "mujoco==3.3.1" \
    "h5py" \
    "lxml" \
    "tianshou==0.4.10" \
    "easydict>=1.9"

# robocasa/__init__.py asserts numpy.__version__ in ["2.2.5"] — pin it last
# so no subsequent package can bump it away.
RUN uv pip install --no-cache "numpy==2.2.5" \
    && python -c "import robocasa"

CMD ["/bin/bash"]

docker/Dockerfile.eval-robomme

@@ -0,0 +1,26 @@
# Copyright 2026 The HuggingFace Inc. team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

FROM lerobot-eval-base:latest

# mani-skill==3.0.0b21 (robomme dep) pins gymnasium==0.29.1 and numpy<2.0.0,
# conflicting with lerobot's gymnasium>=1.1.1 and numpy>=2.0.0.
# Both overrides are safe at runtime:
#   - gymnasium 0.29.x has the same 5-tuple step() API as 1.x (since gym 0.26)
#   - numpy 1.26.4 is API-compatible with lerobot's actual usage (no 2.x-only APIs used)
RUN printf 'gymnasium==0.29.1\nnumpy==1.26.4\n' > /tmp/robomme_override.txt \
    && uv pip install --no-cache --override /tmp/robomme_override.txt ".[robomme]" \
    && python -c "import robomme"

CMD ["/bin/bash"]

docker/build_benchmark_images.sh (executable, +120)

@@ -0,0 +1,120 @@
#!/usr/bin/env bash
# Build (and optionally push) all lerobot benchmark eval images.
#
# Usage:
#   # Build locally only (for testing on this machine)
#   bash docker/build_benchmark_images.sh
#
#   # Build and push to Docker Hub under your org
#   bash docker/build_benchmark_images.sh --push --hub_org=pepijn223
#
#   # Force-rebuild base image (e.g. after Dockerfile.eval-base changes)
#   bash docker/build_benchmark_images.sh --no-cache-base --push --hub_org=pepijn223
#
#   # Build only specific benchmarks
#   bash docker/build_benchmark_images.sh --benchmarks="libero_plus robomme"
#
# After building, run eval with:
#   lerobot-eval --eval.runtime=docker --eval.docker.pull=false \
#     --eval.docker.image=<hub_org>/lerobot-benchmark-<benchmark>:latest ...
# OR (if run locally with the default tag):
#   lerobot-eval --eval.runtime=docker --eval.docker.pull=false \
#     --env.type=<benchmark> ...  # auto-resolves to lerobot-benchmark-<benchmark>

set -euo pipefail

PUSH=false
HUB_ORG=""
BENCHMARKS="libero libero_plus robomme robocasa metaworld"
NO_CACHE_BASE=false
PROGRESS="auto"

for arg in "$@"; do
  case "$arg" in
    --push) PUSH=true ;;
    --hub_org=*) HUB_ORG="${arg#*=}" ;;
    --benchmarks=*) BENCHMARKS="${arg#*=}" ;;
    --no-cache-base) NO_CACHE_BASE=true ;;
    --plain) PROGRESS="plain" ;;
    *) echo "Unknown arg: $arg"; exit 1 ;;
  esac
done

SCRIPT_DIR="$(cd "$(dirname "${BASH_SOURCE[0]}")" && pwd)"
REPO_ROOT="$(cd "${SCRIPT_DIR}/.." && pwd)"

if [[ "$PUSH" == "true" && -z "$HUB_ORG" ]]; then
  echo "ERROR: --push requires --hub_org=<your-dockerhub-org>"
  exit 1
fi

ok() { echo "[OK] $*"; }
fail() { echo "[FAIL] $*"; exit 1; }

BASE_CACHE_FLAG=""
if [[ "$NO_CACHE_BASE" == "true" ]]; then
  BASE_CACHE_FLAG="--no-cache"
fi

echo "=== Building lerobot-eval-base ==="
docker build \
  ${BASE_CACHE_FLAG} \
  --progress="${PROGRESS}" \
  -f "${SCRIPT_DIR}/Dockerfile.eval-base" \
  -t lerobot-eval-base:latest \
  "${REPO_ROOT}" || fail "lerobot-eval-base build failed"
ok "lerobot-eval-base"

for BENCHMARK in $BENCHMARKS; do
  LOCAL_TAG="lerobot-benchmark-${BENCHMARK}:latest"

  # Handle underscore → hyphen mapping for filename lookup
  DOCKERFILE_HYPHEN="${SCRIPT_DIR}/Dockerfile.eval-${BENCHMARK//_/-}"
  DOCKERFILE_UNDERSCORE="${SCRIPT_DIR}/Dockerfile.eval-${BENCHMARK}"
  if [[ -f "$DOCKERFILE_HYPHEN" ]]; then
    DOCKERFILE="$DOCKERFILE_HYPHEN"
  elif [[ -f "$DOCKERFILE_UNDERSCORE" ]]; then
    DOCKERFILE="$DOCKERFILE_UNDERSCORE"
  else
    fail "No Dockerfile found for benchmark '${BENCHMARK}' (tried ${DOCKERFILE_HYPHEN} and ${DOCKERFILE_UNDERSCORE})"
  fi

  echo ""
  echo "=== Building ${LOCAL_TAG} from $(basename "${DOCKERFILE}") ==="
  docker build \
    --progress="${PROGRESS}" \
    -f "${DOCKERFILE}" \
    -t "${LOCAL_TAG}" \
    "${REPO_ROOT}" || fail "${LOCAL_TAG} build failed"
  ok "${LOCAL_TAG}"

  if [[ "$PUSH" == "true" ]]; then
    HUB_TAG="${HUB_ORG}/lerobot-benchmark-${BENCHMARK}:latest"
    docker tag "${LOCAL_TAG}" "${HUB_TAG}"
    docker push "${HUB_TAG}" || fail "push ${HUB_TAG} failed"
    ok "Pushed ${HUB_TAG}"
  fi
done

echo ""
echo "=== Smoke-testing images ==="
for BENCHMARK in $BENCHMARKS; do
  LOCAL_TAG="lerobot-benchmark-${BENCHMARK}:latest"
  echo "  Smoke test: ${LOCAL_TAG}"
  docker run --rm -e BENCHMARK="${BENCHMARK}" \
    "${LOCAL_TAG}" bash docker/smoke_test_benchmark.sh \
    && ok "smoke test ${BENCHMARK}" \
    || echo "[WARN] smoke test failed for ${BENCHMARK} (may need GPU)"
done

echo ""
echo "All benchmark images built successfully."
if [[ "$PUSH" == "true" ]]; then
  echo "Pushed to Docker Hub under: ${HUB_ORG}/"
  echo ""
  echo "To use Hub images in eval, pass:"
  for BENCHMARK in $BENCHMARKS; do
    echo "  --eval.docker.image=${HUB_ORG}/lerobot-benchmark-${BENCHMARK}:latest"
  done
fi

docker/smoke_test_benchmark.sh (executable, +115)

@@ -0,0 +1,115 @@
#!/usr/bin/env bash
# Smoke-test a benchmark container: verifies imports and CLI entry-points.
#
# Build and run for a specific benchmark:
#   docker build --build-arg BENCHMARK=libero -f docker/Dockerfile.benchmark -t lerobot-benchmark-libero .
#   docker run --gpus all --rm -e BENCHMARK=libero lerobot-benchmark-libero bash docker/smoke_test_benchmark.sh
#
# Test all benchmarks individually:
#   for b in libero libero_plus robomme robocasa; do
#     docker build --build-arg BENCHMARK=$b -f docker/Dockerfile.benchmark -t lerobot-benchmark-$b .
#     docker run --gpus all --rm -e BENCHMARK=$b lerobot-benchmark-$b bash docker/smoke_test_benchmark.sh
#   done

set -euo pipefail

BENCHMARK="${BENCHMARK:-libero}"
PASS=0
FAIL=0

ok() { echo "[PASS] $*"; PASS=$((PASS + 1)); }
fail() { echo "[FAIL] $*"; FAIL=$((FAIL + 1)); }

python_import() {
  local module="$1"
  if python -c "import ${module}" 2>/dev/null; then
    ok "import ${module}"
  else
    fail "import ${module}"
  fi
}

cli_help() {
  local cmd="$1"
  if "${cmd}" --help > /dev/null 2>&1; then
    ok "${cmd} --help"
  else
    fail "${cmd} --help"
  fi
}

echo "=== Smoke test: benchmark=${BENCHMARK} ==="

# ── lerobot core ──────────────────────────────────────────────────────────────
python_import "lerobot"
python_import "lerobot.envs"
python_import "lerobot.configs.eval"
cli_help "lerobot-eval"

# ── Benchmark-specific env import ─────────────────────────────────────────────
case "${BENCHMARK}" in
  libero)
    python_import "lerobot.envs.libero"
    python -c "
from lerobot.envs.configs import LiberoEnv
cfg = LiberoEnv(task='libero_spatial/KITCHEN_SCENE1_open_the_bottom_drawer_of_the_cabinet')
print('  LiberoEnv config OK:', cfg.type)
" && ok "LiberoEnv config instantiation" || fail "LiberoEnv config instantiation"
    ;;

  libero_plus)
    python_import "lerobot.envs.libero"
    python -c "
from lerobot.envs.configs import LiberoPlusEnv
cfg = LiberoPlusEnv()
print('  LiberoPlusEnv config OK:', cfg.type)
" && ok "LiberoPlusEnv config instantiation" || fail "LiberoPlusEnv config instantiation"
    # Verify the LIBERO-plus package itself is importable
    python_import "libero"
    python_import "robosuite"
    ;;

  robomme)
    python_import "lerobot.envs.robomme"
    python -c "
from lerobot.envs.robomme import ROBOMME_TASKS, RoboMMEGymEnv
assert len(ROBOMME_TASKS) == 16, f'Expected 16 tasks, got {len(ROBOMME_TASKS)}'
print('  ROBOMME_TASKS OK:', ROBOMME_TASKS[:3], '...')
" && ok "RoboMME task list" || fail "RoboMME task list"
    python -c "
from lerobot.envs.configs import RoboMMEEnv
cfg = RoboMMEEnv(task='PickXtimes')
print('  RoboMMEEnv config OK:', cfg.type)
" && ok "RoboMMEEnv config instantiation" || fail "RoboMMEEnv config instantiation"
    python_import "robomme"
    ;;

  robocasa)
    python_import "lerobot.envs.robocasa"
    python -c "
from lerobot.envs.robocasa import ACTION_DIM, STATE_DIM
assert ACTION_DIM == 12, f'Expected ACTION_DIM=12, got {ACTION_DIM}'
assert STATE_DIM == 16, f'Expected STATE_DIM=16, got {STATE_DIM}'
print('  ACTION_DIM:', ACTION_DIM, '  STATE_DIM:', STATE_DIM)
" && ok "RoboCasa constants" || fail "RoboCasa constants"
    python -c "
from lerobot.envs.configs import RoboCasaEnv
cfg = RoboCasaEnv(task='PickPlaceCounterToCabinet')
print('  RoboCasaEnv config OK:', cfg.type)
" && ok "RoboCasaEnv config instantiation" || fail "RoboCasaEnv config instantiation"
    python_import "robocasa"
    python_import "robosuite"
    ;;

  *)
    echo "Unknown BENCHMARK='${BENCHMARK}'. Valid values: libero, libero_plus, robomme, robocasa"
    exit 1
    ;;
esac

# ── Summary ───────────────────────────────────────────────────────────────────
echo ""
echo "=== Results: ${PASS} passed, ${FAIL} failed ==="
if [ "${FAIL}" -gt 0 ]; then
  exit 1
fi

docs/source/_toctree.yml

@@ -19,6 +19,8 @@
    title: Multi GPU training
  - local: peft_training
    title: Training with PEFT (e.g., LoRA)
  - local: benchmark_training
    title: Benchmark Training & Evaluation
  title: "Tutorials"
- sections:
  - local: lerobot-dataset-v3

docs/source/benchmark_training.md

@@ -0,0 +1,398 @@
# Benchmark Training & Evaluation

This guide explains how to train and evaluate policies on the simulation benchmarks
integrated in LeRobot: **LIBERO**, **LIBERO-plus**, **MetaWorld**, **RoboCasa**, and **RoboMME**.

The workflow is:

1. Pick one or more benchmarks.
2. For each benchmark, train a policy on its combined dataset (multi-GPU).
3. Upload the trained policy to the Hugging Face Hub.
4. Evaluate the policy on every task suite within that benchmark.

## Prerequisites

Install the benchmark-specific dependencies for the environments you want to evaluate on:

```bash
# LIBERO (original)
pip install -e ".[libero]"

# LIBERO-plus
pip install -e ".[libero_plus]"

# MetaWorld
pip install -e ".[metaworld]"

# RoboCasa
pip install -e ".[robocasa]"

# RoboMME
pip install -e ".[robomme]"
```

`libero_plus` includes the same EGL probe dependencies as `libero` so headless
renderer setup is consistent between both installs.
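
Before launching a long eval on a headless machine, it can be worth confirming that EGL actually sees a GPU. Below is a minimal probe, assuming the `egl_probe` module from `hf-egl-probe` is installed and exposes `get_available_devices()` the way the upstream egl-probe package does:

```python
# Headless-rendering sanity check (a sketch; assumes hf-egl-probe mirrors upstream egl-probe's API).
import os

os.environ.setdefault("MUJOCO_GL", "egl")  # same renderer setting the eval images use

import egl_probe

devices = egl_probe.get_available_devices()
print("EGL-capable devices:", devices)
assert devices, "no EGL device found; headless rendering will fail"
```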

If your environment has CMake build-isolation issues, use the same fallback as
standard LIBERO installs:

```bash
PATH=/usr/bin:/bin:$PATH pip install --no-build-isolation -e ".[libero-plus]"
```

For multi-GPU training you also need [Accelerate](https://huggingface.co/docs/accelerate):

```bash
pip install accelerate
```

## Docker-isolated evaluation (EnvHub)

LeRobot eval now supports running the full eval worker in a Docker container
while keeping policy loading compatible with local checkpoints and local code changes.

Use `lerobot-eval` with `--eval.runtime=docker`:

```bash
lerobot-eval \
  --policy.path=outputs/train/my_policy/checkpoints/050000/pretrained_model \
  --env.type=libero_plus \
  --eval.runtime=docker \
  --eval.docker.envhub_ref=envhub://lerobot/libero_plus@v1 \
  --eval.n_episodes=10 \
  --eval.batch_size=10
```

`eval.docker.envhub_ref` is optional. If omitted, LeRobot resolves a default
image from `env.type`. You can also override the image directly:

```bash
--eval.docker.image=docker://ghcr.io/huggingface/lerobot-eval-libero-plus:latest
```

By default (`eval.docker.use_local_code=true`), the local repository is mounted
in the container and added to `PYTHONPATH`, so edited policy/env code and local
checkpoints continue to work without rebuilding the image for each change.

Common Docker runtime options:

```bash
--eval.docker.pull=true \
--eval.docker.gpus=all \
--eval.docker.shm_size=8g \
--eval.docker.use_local_code=true
```

The benchmark runner supports the same Docker eval path (extra args are
forwarded to each generated `lerobot-eval` call):

```bash
lerobot-benchmark eval \
  --benchmarks libero_plus,robocasa \
  --hub-user $HF_USER \
  --n-episodes 50 \
  --eval.runtime=docker \
  --eval.docker.pull=true
```

Build benchmark images locally:

```bash
make build-eval-images
```

## Fast single-machine eval tuning

`lerobot-eval` now has three orthogonal throughput knobs:

- `eval.batch_size`: number of sub-envs per task (inside one vector env).
- `env.max_parallel_tasks`: number of tasks scheduled concurrently.
- `eval.instance_count`: number of full eval instances (process-level sharding).

Use them in this order:

1. Increase `eval.batch_size` first for per-task throughput.
2. Then increase `env.max_parallel_tasks` to overlap tasks, while monitoring RAM/VRAM.
3. Optionally increase `eval.instance_count` for process-level parallelism (best with enough CPU/RAM and small models).

The eval logs print the active scheduler mode (`sequential`, `threaded`, or `batched_lazy`) so you can verify the effective concurrency path.
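
To see how these compose before committing to a long run, multiply the knobs out. The numbers below are made up purely to illustrate the arithmetic, not recommended defaults:

```python
# Back-of-the-envelope capacity estimate for the three knobs (illustrative numbers only).
eval_batch_size = 8     # eval.batch_size: sub-envs per task inside one vector env
max_parallel_tasks = 2  # env.max_parallel_tasks: tasks scheduled concurrently
instance_count = 2      # eval.instance_count: process-level eval instances

peak_envs = eval_batch_size * max_parallel_tasks * instance_count
print(f"up to ~{peak_envs} simulator envs alive at peak")  # 32 here; budget RAM/VRAM accordingly
```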

### Suggested starting points

| Benchmark | Conservative | Faster (single GPU) | Notes |
|---|---|---|---|
| `libero` / `libero_plus` | `eval.batch_size=1`, `env.max_parallel_tasks=4` | `eval.batch_size=1`, `env.max_parallel_tasks=16` | For large suite sweeps, increase `max_parallel_tasks` before `batch_size` to avoid MuJoCo memory spikes. |
| `metaworld` | `eval.batch_size=8`, `env.max_parallel_tasks=1` | `eval.batch_size=16`, `env.max_parallel_tasks=2` | Prefer larger per-task vectorization first. |
| `robocasa` | `eval.batch_size=4`, `env.max_parallel_tasks=1` | `eval.batch_size=8`, `env.max_parallel_tasks=2` | Rendering/memory can dominate at high image resolution. |
| `robomme` | `eval.batch_size=4`, `env.max_parallel_tasks=1` | `eval.batch_size=8`, `env.max_parallel_tasks=2` | Start small and scale gradually with task count. |

### Local fast eval recipe

```bash
lerobot-eval \
  --policy.path=$HF_USER/smolvla_libero_plus \
  --env.type=libero_plus \
  --eval.n_episodes=1 \
  --eval.batch_size=1 \
  --env.max_parallel_tasks=16 \
  --eval.instance_count=2 \
  --rename_map='{"observation.images.image":"observation.images.camera1","observation.images.image2":"observation.images.camera2"}' \
  --output_dir=outputs/eval/smolvla_libero_plus \
  --push_to_hub=true
```

### Docker fast eval recipe

```bash
lerobot-eval \
  --policy.path=$HF_USER/smolvla_libero_plus \
  --env.type=libero_plus \
  --eval.runtime=docker \
  --eval.docker.envhub_ref=envhub://lerobot/libero_plus@v1 \
  --eval.docker.gpus=all \
  --eval.docker.shm_size=16g \
  --eval.n_episodes=1 \
  --eval.batch_size=1 \
  --env.max_parallel_tasks=16
```

## Quick start — single benchmark

Train SmolVLA on LIBERO-plus with 4 GPUs for 50,000 steps:

```bash
lerobot-benchmark train \
  --benchmarks libero_plus \
  --policy-path lerobot/smolvla_base \
  --hub-user $HF_USER \
  --num-gpus 4 \
  --steps 50000 \
  --batch-size 32 \
  --wandb
```

This trains on the combined LIBERO-plus dataset and pushes the checkpoint to
`$HF_USER/smolvla_libero_plus` on the Hub.

Then evaluate on **all four** LIBERO suites (spatial, object, goal, 10):

```bash
lerobot-benchmark eval \
  --benchmarks libero_plus \
  --hub-user $HF_USER \
  --n-episodes 50
```

This automatically runs a separate `lerobot-eval` for each suite.

## Full sweep — multiple benchmarks

Run training **and** evaluation across all benchmarks:

```bash
lerobot-benchmark all \
  --benchmarks libero,libero_plus,metaworld,robocasa,robomme \
  --policy-path lerobot/smolvla_base \
  --hub-user $HF_USER \
  --num-gpus 4 \
  --steps 50000 \
  --batch-size 32 \
  --wandb \
  --push-eval-to-hub
```

For each benchmark the runner:

1. Trains a policy on its dataset.
2. Evaluates on every eval task in the benchmark (e.g. 4 suites for LIBERO).
3. Pushes HF-native `.eval_results` rows (and optional artifacts) to the Hub.

<Tip>

Use `--dry-run` to print the exact `lerobot-train` / `lerobot-eval` commands without executing them, so you can inspect or modify them before running.

</Tip>

## Using the CLI directly (without the benchmark runner)

You can also compose the commands yourself. The benchmark runner is a thin wrapper; here is what it does under the hood.

### Training

```bash
accelerate launch \
  --multi_gpu \
  --num_processes=4 \
  $(which lerobot-train) \
  --policy.path=lerobot/smolvla_base \
  --dataset.repo_id=$HF_USER/libero_plus \
  --policy.repo_id=$HF_USER/smolvla_libero_plus \
  --env.type=libero_plus \
  --env.task=libero_spatial \
  --steps=50000 \
  --batch_size=32 \
  --eval_freq=10000 \
  --save_freq=10000 \
  --output_dir=outputs/train/smolvla_libero_plus \
  --job_name=smolvla_libero_plus \
  --policy.push_to_hub=true \
  --wandb.enable=true
```

### Evaluation (run once per suite)

```bash
for SUITE in libero_spatial libero_object libero_goal libero_10; do
  lerobot-eval \
    --policy.path=$HF_USER/smolvla_libero_plus \
    --env.type=libero_plus \
    --env.task=$SUITE \
    --eval.n_episodes=50 \
    --eval.batch_size=10 \
    --output_dir=outputs/eval/smolvla_libero_plus/$SUITE \
    --policy.device=cuda \
    --push_to_hub=true \
    --benchmark_dataset_id=lerobot/sim-benchmarks
done
```

## Available benchmarks

| Benchmark | Env type | Dataset | Eval tasks | Action dim |
|---|---|---|---|---|
| `libero` | `libero` | `{hub_user}/libero` | spatial, object, goal, 10 | 7 |
| `libero_plus` | `libero_plus` | `{hub_user}/libero_plus` | spatial, object, goal, 10 | 7 |
| `metaworld` | `metaworld` | `{hub_user}/metaworld` | push-v2 | 4 |
| `robocasa` | `robocasa` | `{hub_user}/robocasa` | PickPlaceCounterToCabinet | 12 |
| `robomme` | `robomme` | `{hub_user}/robomme` | PickXtimes | 8 |

Run `lerobot-benchmark list` to see the full registry with all eval tasks.

## Policy naming convention

The benchmark runner stores trained policies under:

```
{hub_user}/{policy_name}_{benchmark}
```

The default `--policy-name` is `smolvla`. So training on `libero_plus` as user `alice` produces `alice/smolvla_libero_plus`.

You can override this, e.g. `--policy-name pi05` if training π₀.₅ instead.

## Multi-GPU considerations

The effective batch size is `batch_size × num_gpus`. With `--batch-size=32` and
`--num-gpus=4`, you train with an effective batch of 128 per step. LeRobot does **not**
auto-scale the learning rate; see the [Multi-GPU Training guide](./multi_gpu_training) for
details on when and how to adjust it.
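
If you do decide to adjust it, the common linear-scaling heuristic is one starting point. This is a rule of thumb, not something the trainer applies for you, and the base values below are hypothetical:

```python
# Linear-scaling heuristic for the learning rate (a rule of thumb, not a LeRobot feature).
base_lr = 1e-4   # hypothetical LR tuned at the reference batch size
base_batch = 32  # per-GPU batch size the base LR was tuned for
num_gpus = 4

effective_batch = base_batch * num_gpus               # 32 * 4 = 128
scaled_lr = base_lr * (effective_batch / base_batch)  # 1e-4 -> 4e-4
print(f"effective batch {effective_batch}, scaled lr {scaled_lr:.0e}")
```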

## Custom benchmarks

To add a new benchmark, edit the `BENCHMARK_REGISTRY` in
`src/lerobot/scripts/lerobot_benchmark.py`:

```python
from lerobot.scripts.lerobot_benchmark import BenchmarkEntry, BENCHMARK_REGISTRY

BENCHMARK_REGISTRY["my_benchmark"] = BenchmarkEntry(
    dataset_repo_id="{hub_user}/my_dataset",
    env_type="my_env",
    env_task="MyDefaultTask",
    eval_tasks=["TaskA", "TaskB", "TaskC"],
)
```

Then use `--benchmarks my_benchmark` as usual. The runner will train once and
evaluate separately on TaskA, TaskB, and TaskC.

## Outputs

After training and evaluation, your outputs directory looks like:

```
outputs/
├── train/
│   ├── smolvla_libero/
│   │   ├── checkpoints/
│   │   └── ...
│   ├── smolvla_libero_plus/
│   ├── smolvla_robocasa/
│   └── smolvla_robomme/
└── eval/
    ├── smolvla_libero/
    │   ├── libero_spatial/
    │   │   ├── eval_info.json
    │   │   └── videos/
    │   ├── libero_object/
    │   ├── libero_goal/
    │   └── libero_10/
    ├── smolvla_libero_plus/
    │   ├── libero_spatial/
    │   ├── libero_object/
    │   ├── libero_goal/
    │   └── libero_10/
    ├── smolvla_robocasa/
    └── smolvla_robomme/
```

Each `eval_info.json` contains per-episode rewards, success rates, and aggregate metrics.
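
To collect the headline numbers across suites after a sweep, a small script can walk the eval tree. This is a sketch: the exact key names inside `eval_info.json` may differ between LeRobot versions, so treat the `aggregated` lookup as an assumption and adapt it to your files:

```python
# Summarize eval_info.json across suites (key names are assumptions; adapt to your files).
import json
from pathlib import Path

eval_root = Path("outputs/eval/smolvla_libero_plus")  # one sub-directory per suite, as in the tree above
for info_path in sorted(eval_root.glob("*/eval_info.json")):
    info = json.loads(info_path.read_text())
    metrics = info.get("aggregated", info)  # fall back to the top level if there is no aggregate block
    summary = {k: v for k, v in metrics.items() if "success" in k or "reward" in k}
    print(f"{info_path.parent.name}: {summary}")
```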

## HF Eval Results + Leaderboard

LeRobot publishes benchmark scores using Hugging Face's native
`/.eval_results/*.yaml` format, which powers model-page eval cards and
benchmark leaderboards.

Add `--push-eval-to-hub` to push results after each eval run:

```bash
lerobot-benchmark eval \
  --benchmarks libero_plus,robocasa \
  --hub-user $HF_USER \
  --benchmark-dataset-id lerobot/sim-benchmarks \
  --push-eval-to-hub
```

This writes one or more files under `.eval_results/` in the model repo, for example:

```yaml
- dataset:
    id: lerobot/sim-benchmarks
    task_id: libero_plus/spatial
  value: 82.4
  notes: lerobot-eval
```

Notes:

- `--benchmark-dataset-id` points to your consolidated benchmark dataset repo.
- `task_id` values are derived from `env.type` and evaluated suite/task names.
- Eval artifacts (`eval_info.json`, `eval_config.json`, videos) are still uploaded
  for provenance, but leaderboard ranking comes from `.eval_results`.

## Passing extra arguments

Any arguments after the recognized flags are forwarded to `lerobot-train` or
`lerobot-eval`.

Example (training): use PEFT/LoRA during training.

```bash
lerobot-benchmark train \
  --benchmarks libero_plus \
  --policy-path lerobot/smolvla_base \
  --hub-user $HF_USER \
  --num-gpus 4 \
  --steps 50000 \
  --peft.method_type=LORA --peft.r=16
```

Example (evaluation): forward Docker runtime flags to each `lerobot-eval` call.

```bash
lerobot-benchmark eval \
  --benchmarks libero_plus \
  --hub-user $HF_USER \
  --eval.runtime=docker \
  --eval.docker.envhub_ref=envhub://lerobot/libero_plus@v1
```

pyproject.toml (+38 -1)

@@ -174,7 +174,41 @@ video_benchmark = ["scikit-image>=0.23.2,<0.26.0", "pandas>=2.2.2,<2.4.0"]
# NOTE: Explicitly listing scipy helps flatten the dependency tree.
aloha = ["gym-aloha>=0.1.2,<0.2.0", "lerobot[scipy-dep]"]
pusht = ["gym-pusht>=0.1.5,<0.2.0", "pymunk>=6.6.0,<7.0.0"] # TODO: Fix pymunk version in gym-pusht instead
libero = ["lerobot[transformers-dep]", "hf-libero>=0.1.3,<0.2.0; sys_platform == 'linux'", "lerobot[scipy-dep]"]
libero = [
    "lerobot[transformers-dep]",
    "hf-libero>=0.1.3,<0.2.0; sys_platform == 'linux'",
    # hf-egl-probe is the fixed fork of egl-probe (robomimic transitive dep).
    # egl-probe's CMakeLists.txt requires cmake_minimum_required < 3.5 which
    # modern cmake rejects. Installing hf-egl-probe first satisfies the egl_probe
    # import without source compilation.
    "hf-egl-probe>=1.0.1; sys_platform == 'linux'",
    "lerobot[scipy-dep]",
]
libero_plus = [
    # Inherit all of libero's deps (hf-libero → robosuite/robomimic/egl-probe/scipy/transformers).
    # LIBERO-plus extends LIBERO with extra task suites; its Python module is installed
    # from the git clone in Dockerfile.eval-libero-plus (overrides hf-libero via .pth).
    "lerobot[libero]",
    # Additional runtime deps declared by LIBERO-plus but absent from its setup.py:
    "bddl>=1.0.1,<2.0.0; sys_platform == 'linux'",
    "future; sys_platform == 'linux'",  # bddl transitive dep not declared in its metadata
    "easydict>=1.9; sys_platform == 'linux'",
    "wand; sys_platform == 'linux'",
    "scikit-image>=0.20.0; sys_platform == 'linux'",
    "gym>=0.25.0,<0.27.0; sys_platform == 'linux'",
]
libero-plus = ["lerobot[libero_plus]"]
robomme = [
    "robomme @ git+https://github.com/RoboMME/robomme_benchmark.git@main ; sys_platform == 'linux'",
]
robocasa = [
    # robocasa and its robosuite fork are not on PyPI; both installed from source
    # in Dockerfile.eval-robocasa (requires ARISE-Initiative/robosuite@robocasa_v1.4.1
    # for PandaOmron and other robocasa-specific robots).
    "easydict>=1.9; sys_platform == 'linux'",
    "scikit-image>=0.20.0; sys_platform == 'linux'",
    "lerobot[scipy-dep]",
]
metaworld = ["metaworld==3.0.0", "lerobot[scipy-dep]"]

# All

@@ -220,6 +254,7 @@ lerobot-replay="lerobot.scripts.lerobot_replay:main"
lerobot-setup-motors="lerobot.scripts.lerobot_setup_motors:main"
lerobot-teleoperate="lerobot.scripts.lerobot_teleoperate:main"
lerobot-eval="lerobot.scripts.lerobot_eval:main"
lerobot-eval-worker="lerobot.scripts.lerobot_eval_worker:main"
lerobot-train="lerobot.scripts.lerobot_train:main"
lerobot-train-tokenizer="lerobot.scripts.lerobot_train_tokenizer:main"
lerobot-dataset-viz="lerobot.scripts.lerobot_dataset_viz:main"

@@ -227,7 +262,9 @@ lerobot-info="lerobot.scripts.lerobot_info:main"
lerobot-find-joint-limits="lerobot.scripts.lerobot_find_joint_limits:main"
lerobot-imgtransform-viz="lerobot.scripts.lerobot_imgtransform_viz:main"
lerobot-edit-dataset="lerobot.scripts.lerobot_edit_dataset:main"
lerobot-leaderboard="lerobot.scripts.lerobot_leaderboard:main"
lerobot-setup-can="lerobot.scripts.lerobot_setup_can:main"
lerobot-benchmark="lerobot.scripts.lerobot_benchmark:main"

# ---------------- Tool Configurations ----------------
[tool.setuptools.package-data]
@@ -0,0 +1,689 @@
|
||||
#!/usr/bin/env python
|
||||
|
||||
# Copyright 2025 The HuggingFace Inc. team. All rights reserved.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
"""
|
||||
Chunk-level multi-modality analysis for comparing full/mixed vs curated datasets.
|
||||
|
||||
Treats each action chunk (sliding window of CHUNK_SIZE consecutive frames) as the
|
||||
atomic unit, tagged by the SARM progress score at its start frame. For each
|
||||
progress band, compares the full vs HQ dataset on:
|
||||
|
||||
1. Intra-band action variance
|
||||
2. Progress delta per chunk
|
||||
3. GMM + BIC optimal K (number of distinct strategies)
|
||||
4. PCA embedding (visual cluster inspection)
|
||||
|
||||
Usage:
|
||||
python chunk_multimodality_analysis.py \\
|
||||
--full-dataset lerobot-data-collection/level12_rac_2_2026-02-08_1 \\
|
||||
--hq-dataset lerobot-data-collection/level2_final_quality3 \\
|
||||
--output-dir ./chunk_analysis
|
||||
"""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import argparse
|
||||
import logging
|
||||
from collections import defaultdict
|
||||
from pathlib import Path
|
||||
|
||||
import matplotlib.pyplot as plt
|
||||
import numpy as np
|
||||
import pandas as pd
|
||||
from huggingface_hub import hf_hub_download
|
||||
from scipy.stats import gaussian_kde
|
||||
from sklearn.decomposition import PCA
|
||||
from sklearn.mixture import GaussianMixture
|
||||
from sklearn.preprocessing import StandardScaler
|
||||
|
||||
from lerobot.datasets.lerobot_dataset import LeRobotDataset
|
||||
|
||||
logging.basicConfig(level=logging.INFO, format="%(asctime)s %(levelname)s %(message)s")
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
# ── Visual style ──────────────────────────────────────────────────────────
|
||||
|
||||
BG = "#0e1117"
|
||||
CARD = "#1a1d27"
|
||||
BORDER = "#2a2d3a"
|
||||
SUB = "#8b8fa8"
|
||||
TEXT = "#e8eaf0"
|
||||
C_FULL = "#f7934f"
|
||||
C_HQ = "#4dc98a"
|
||||
|
||||
|
||||
def _style_ax(ax: plt.Axes) -> None:
|
||||
ax.set_facecolor(CARD)
|
||||
ax.tick_params(colors=SUB, labelsize=8)
|
||||
for spine in ax.spines.values():
|
||||
spine.set_color(BORDER)
|
||||
|
||||
|
||||
def _save(fig: plt.Figure, path: Path) -> None:
|
||||
fig.savefig(path, dpi=150, bbox_inches="tight", facecolor=BG)
|
||||
plt.close(fig)
|
||||
logger.info("Saved %s", path)
|
||||
|
||||
|
||||
# ── Step 0: Load episodes ────────────────────────────────────────────────
|
||||
|
||||
def _load_sarm_progress(repo_id: str) -> pd.DataFrame | None:
|
||||
"""Try to download sarm_progress.parquet from the Hub."""
|
||||
try:
|
||||
path = hf_hub_download(
|
||||
repo_id=repo_id, filename="sarm_progress.parquet",
|
||||
repo_type="dataset",
|
||||
)
|
||||
df = pd.read_parquet(path)
|
||||
col = "progress_sparse" if "progress_sparse" in df.columns else "progress_dense"
|
||||
if col not in df.columns:
|
||||
logger.warning("sarm_progress.parquet has no progress columns — ignoring")
|
||||
return None
|
||||
logger.info("Loaded SARM progress (%s) for %s (%d rows)", col, repo_id, len(df))
|
||||
return df.rename(columns={col: "progress"})[["episode_index", "frame_index", "progress"]]
|
||||
except Exception as exc:
|
||||
logger.warning("Could not load sarm_progress.parquet for %s: %s", repo_id, exc)
|
||||
return None
|
||||
|
||||
|
||||
def load_episodes(
|
||||
repo_id: str,
|
||||
n_joints: int = 16,
|
||||
max_episodes: int | None = None,
|
||||
) -> list[dict]:
|
||||
dataset = LeRobotDataset(repo_id, download_videos=False)
|
||||
raw = dataset.hf_dataset
|
||||
|
||||
sarm_df = _load_sarm_progress(repo_id)
|
||||
# Build per-episode progress arrays from SARM parquet (indexed by frame_index)
|
||||
sarm_by_ep: dict[int, dict[int, float]] = {}
|
||||
if sarm_df is not None:
|
||||
if max_episodes is not None:
|
||||
sarm_df = sarm_df[sarm_df["episode_index"] < max_episodes]
|
||||
for ep_id, grp in sarm_df.groupby("episode_index"):
|
||||
sarm_by_ep[int(ep_id)] = dict(
|
||||
zip(grp["frame_index"].astype(int), grp["progress"].astype(float))
|
||||
)
|
||||
|
||||
episodes: dict[int, dict] = defaultdict(lambda: {"actions": [], "progress": []})
|
||||
for row in raw:
|
||||
ep = int(row["episode_index"])
|
||||
if max_episodes is not None and ep >= max_episodes:
|
||||
continue
|
||||
action = np.array(row["action"], dtype=np.float32)[:n_joints]
|
||||
episodes[ep]["actions"].append(action)
|
||||
fi = int(row["frame_index"])
|
||||
ep_prog = sarm_by_ep.get(ep, {})
|
||||
episodes[ep]["progress"].append(ep_prog.get(fi, float("nan")))
|
||||
|
||||
has_sarm = len(sarm_lookup) > 0
|
||||
result = []
|
||||
for ep_id, d in sorted(episodes.items()):
|
||||
actions = np.stack(d["actions"])
|
||||
T = len(actions)
|
||||
if has_sarm:
|
||||
prog = np.array(d["progress"], dtype=np.float32)
|
||||
prog = np.clip(np.nan_to_num(prog, nan=0.0), 0.0, 1.0)
|
||||
prog = np.maximum.accumulate(prog)
|
||||
else:
|
||||
prog = np.linspace(0.0, 1.0, T, dtype=np.float32)
|
||||
result.append({"episode": ep_id, "actions": actions, "progress": prog})
|
||||
|
||||
src = "SARM" if has_sarm else "time-based"
|
||||
logger.info("Progress source: %s", src)
|
||||
return result
|
||||
|
||||
|
||||
# ── Step 1: Filter short episodes ────────────────────────────────────────
|
||||
|
||||
def auto_length_threshold(
|
||||
episodes_full: list[dict], episodes_hq: list[dict]
|
||||
) -> int:
|
||||
all_lengths = np.array(
|
||||
[e["actions"].shape[0] for e in episodes_full + episodes_hq]
|
||||
)
|
||||
kde = gaussian_kde(all_lengths, bw_method=0.25)
|
||||
xs = np.linspace(all_lengths.min(), np.percentile(all_lengths, 40), 300)
|
||||
return int(xs[np.argmin(kde(xs))])


def plot_length_distribution(
    episodes_full: list[dict],
    episodes_hq: list[dict],
    threshold: int,
    out_path: Path,
) -> None:
    lens_full = np.array([e["actions"].shape[0] for e in episodes_full])
    lens_hq = np.array([e["actions"].shape[0] for e in episodes_hq])
    all_lens = np.concatenate([lens_full, lens_hq])

    fig, ax = plt.subplots(figsize=(10, 5))
    fig.patch.set_facecolor(BG)
    _style_ax(ax)

    bins = np.linspace(all_lens.min(), all_lens.max(), 50)
    ax.hist(lens_full, bins=bins, alpha=0.5, color=C_FULL, label="Full/Mixed")
    ax.hist(lens_hq, bins=bins, alpha=0.5, color=C_HQ, label="HQ")

    xs = np.linspace(all_lens.min(), all_lens.max(), 300)
    kde = gaussian_kde(all_lens, bw_method=0.25)
    ax.plot(xs, kde(xs) * len(all_lens) * (bins[1] - bins[0]), color=TEXT, lw=1.5, label="KDE (combined)")

    ax.axvline(threshold, color="#ff4b4b", ls="--", lw=1.5, label=f"Threshold = {threshold}")
    ax.set_xlabel("Episode length (frames)", color=SUB)
    ax.set_ylabel("Count", color=SUB)
    ax.set_title("Episode Length Distribution", color=TEXT, fontsize=13)
    ax.legend(facecolor=CARD, edgecolor=BORDER, labelcolor=TEXT, fontsize=8)
    _save(fig, out_path)


def filter_episodes(episodes: list[dict], min_length: int) -> list[dict]:
    kept = [e for e in episodes if e["actions"].shape[0] >= min_length]
    logger.info("Kept %d / %d episodes (min_length=%d)", len(kept), len(episodes), min_length)
    return kept


# ── Step 2: Extract chunks ───────────────────────────────────────────────

def extract_chunks(
    episodes: list[dict],
    chunk_size: int = 30,
    chunk_stride: int = 15,
) -> list[dict]:
    chunks = []
    for ep in episodes:
        actions = ep["actions"]
        T = len(actions)
        prog = ep["progress"]

        for t in range(0, T - chunk_size, chunk_stride):
            chunk = actions[t : t + chunk_size]
            p_start = float(prog[t])
            p_end = float(prog[min(t + chunk_size, T - 1)])

            chunks.append({
                "action_mean": chunk.mean(axis=0).astype(np.float32),
                "action_flat": chunk.flatten().astype(np.float32),
                "progress_start": p_start,
                "progress_delta": p_end - p_start,
                "episode": ep["episode"],
            })
    return chunks
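
# Window-count sanity check: range(0, T - chunk_size, chunk_stride) yields
# ceil((T - chunk_size) / chunk_stride) windows. With the defaults
# (chunk_size=30, chunk_stride=15), a 300-frame episode gives
# range(0, 270, 15) -> 18 overlapping chunks.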


# ── Step 3: Adaptive progress bands ─────────────────────────────────────

def make_bands(n_bands: int = 5) -> list[tuple[float, float]]:
    edges = np.linspace(0.0, 1.0, n_bands + 1)
    return [(float(edges[i]), float(edges[i + 1])) for i in range(n_bands)]
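
# For example, make_bands(5) returns equal-width edges:
#   [(0.0, 0.2), (0.2, 0.4), (0.4, 0.6), (0.6, 0.8), (0.8, 1.0)]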


def assign_bands(
    chunks: list[dict], band_edges: list[tuple[float, float]]
) -> list[dict]:
    n = len(band_edges)
    for c in chunks:
        p = c["progress_start"]
        c["band"] = next(
            (bi for bi, (lo, hi) in enumerate(band_edges) if p < hi),
            n - 1,
        )
    return chunks
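
# Band assignment picks the first band whose upper edge exceeds progress_start,
# falling back to the last band when no edge is larger. E.g. with 5 bands,
# p = 0.37 -> band 1 (0.2-0.4) and p = 1.0 -> band 4 (0.8-1.0).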


def split_by_band(chunks: list[dict], n_bands: int) -> dict[int, list[dict]]:
    out: dict[int, list[dict]] = {b: [] for b in range(n_bands)}
    for c in chunks:
        out[c["band"]].append(c)
    return out


# ── Step 4: Intra-band action variance ──────────────────────────────────

def band_variance_matrix(
    bands: dict[int, list[dict]], n_bands: int, n_joints: int
) -> np.ndarray:
    var_mat = np.full((n_bands, n_joints), np.nan)
    for b, clist in bands.items():
        if len(clist) < 3:
            continue
        means = np.stack([c["action_mean"] for c in clist])
        var_mat[b] = np.var(means, axis=0)
    return var_mat


def plot_variance_heatmap(
    var_full: np.ndarray,
    var_hq: np.ndarray,
    band_edges: list[tuple[float, float]],
    out_path: Path,
) -> None:
    n_bands = var_full.shape[0]
    vmin = 0.0
    vmax = max(np.nanmax(var_full), np.nanmax(var_hq))

    band_labels = [f"{lo:.0%}–{hi:.0%}" for lo, hi in band_edges]
    joint_labels = [f"J{j}" for j in range(var_full.shape[1])]

    fig, axes = plt.subplots(3, 1, figsize=(12, 10), gridspec_kw={"height_ratios": [3, 3, 2]})
    fig.patch.set_facecolor(BG)
    fig.suptitle("Intra-Band Action Variance", color=TEXT, fontsize=14, y=0.98)

    for ax_idx, (mat, label) in enumerate([(var_full, "Full/Mixed"), (var_hq, "HQ")]):
        ax = axes[ax_idx]
        _style_ax(ax)
        im = ax.imshow(mat, aspect="auto", cmap="YlOrRd", vmin=vmin, vmax=vmax)
        ax.set_yticks(range(n_bands))
        ax.set_yticklabels(band_labels, fontsize=7, color=SUB)
        ax.set_xticks(range(var_full.shape[1]))
        ax.set_xticklabels(joint_labels, fontsize=7, color=SUB)
        ax.set_title(f"Panel {'A' if ax_idx == 0 else 'B'}: {label}", color=TEXT, fontsize=11)
        fig.colorbar(im, ax=ax, fraction=0.02, pad=0.02)

    with np.errstate(invalid="ignore"):
        mean_full = np.nanmean(var_full, axis=1)
        mean_hq = np.nanmean(var_hq, axis=1)
        ratio = np.where(np.isnan(mean_full) | np.isnan(mean_hq), np.nan,
                         mean_full / (mean_hq + 1e-8))
    ax_bar = axes[2]
    _style_ax(ax_bar)
    colors = [
        "#ff4b4b" if r > 2.0 else "#ffaa33" if r > 1.2 else C_HQ
        for r in ratio
    ]
    ax_bar.bar(range(n_bands), ratio, color=colors, edgecolor=BORDER)
    ax_bar.axhline(1.0, color=SUB, ls="--", lw=0.8)
    ax_bar.set_xticks(range(n_bands))
    ax_bar.set_xticklabels(band_labels, fontsize=7, color=SUB)
    ax_bar.set_ylabel("Variance ratio\n(Full / HQ)", color=SUB, fontsize=9)
    ax_bar.set_title("Panel C: Variance Ratio per Band", color=TEXT, fontsize=11)

    fig.tight_layout(rect=[0, 0, 1, 0.96])
    _save(fig, out_path)


# ── Step 5: Progress delta per band ──────────────────────────────────────

def plot_progress_delta(
    bands_full: dict[int, list[dict]],
    bands_hq: dict[int, list[dict]],
    band_edges: list[tuple[float, float]],
    out_path: Path,
) -> None:
    n_bands = len(band_edges)
    band_labels = [f"{lo:.0%}–{hi:.0%}" for lo, hi in band_edges]
    x = np.arange(n_bands)
    w = 0.35

    means_full, stds_full = [], []
    means_hq, stds_hq = [], []
    all_deltas_full, all_deltas_hq = [], []

    for b in range(n_bands):
        df = np.array([c["progress_delta"] for c in bands_full.get(b, [])])
        dh = np.array([c["progress_delta"] for c in bands_hq.get(b, [])])
        means_full.append(np.mean(df) if len(df) > 0 else 0)
        stds_full.append(np.std(df) if len(df) > 0 else 0)
        means_hq.append(np.mean(dh) if len(dh) > 0 else 0)
        stds_hq.append(np.std(dh) if len(dh) > 0 else 0)
        all_deltas_full.extend(df.tolist())
        all_deltas_hq.extend(dh.tolist())

    fig, (ax_bar, ax_viol) = plt.subplots(1, 2, figsize=(14, 5), gridspec_kw={"width_ratios": [3, 1]})
    fig.patch.set_facecolor(BG)
    fig.suptitle("Progress Delta per Chunk", color=TEXT, fontsize=14)

    _style_ax(ax_bar)
    ax_bar.bar(x - w / 2, means_full, w, yerr=stds_full, color=C_FULL, edgecolor=BORDER,
               capsize=3, label="Full/Mixed", error_kw={"ecolor": SUB})
    ax_bar.bar(x + w / 2, means_hq, w, yerr=stds_hq, color=C_HQ, edgecolor=BORDER,
               capsize=3, label="HQ", error_kw={"ecolor": SUB})
    ax_bar.set_xticks(x)
    ax_bar.set_xticklabels(band_labels, fontsize=7, color=SUB, rotation=30)
    ax_bar.set_ylabel("Mean progress Δ", color=SUB)
    ax_bar.legend(facecolor=CARD, edgecolor=BORDER, labelcolor=TEXT, fontsize=8)

    _style_ax(ax_viol)
    data_viol = [np.array(all_deltas_full), np.array(all_deltas_hq)]
    if all(len(d) > 0 for d in data_viol):
        parts = ax_viol.violinplot(data_viol, positions=[0, 1], showmeans=True, showmedians=True)
        for pc, c in zip(parts["bodies"], [C_FULL, C_HQ]):
            pc.set_facecolor(c)
            pc.set_alpha(0.7)
        for key in ("cmeans", "cmedians", "cbars", "cmins", "cmaxes"):
            if key in parts:
                parts[key].set_color(SUB)
    ax_viol.set_xticks([0, 1])
    ax_viol.set_xticklabels(["Full", "HQ"], color=SUB)
    ax_viol.set_ylabel("Progress Δ", color=SUB)
    ax_viol.set_title("Overall Distribution", color=TEXT, fontsize=10)

    fig.tight_layout()
    _save(fig, out_path)


# ── Step 6: GMM + BIC per band ──────────────────────────────────────────

def gmm_optimal_k(
    band_chunks: list[dict],
    pca_components: int = 15,
    max_k: int = 12,
    seed: int = 42,
) -> int | None:
    if len(band_chunks) < 20:
        return None
    X = np.stack([c["action_flat"] for c in band_chunks])
    X = StandardScaler().fit_transform(X)
    n = min(pca_components, X.shape[1], X.shape[0] - 1)
    X_r = PCA(n_components=n, random_state=seed).fit_transform(X)
    bics = []
    for k in range(1, min(max_k + 1, len(X_r) // 6)):
        gmm = GaussianMixture(
            n_components=k, covariance_type="full",
            n_init=5, max_iter=300, random_state=seed,
        )
        gmm.fit(X_r)
        bics.append((k, gmm.bic(X_r)))
    if not bics:
        return None
    return min(bics, key=lambda x: x[1])[0]
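
# BIC trades off fit against parameter count, so the argmin K is a proxy for how
# many distinct action "strategies" a band contains. A standalone sketch of the
# same selection loop on synthetic data (illustrative, smaller scale):
#   import numpy as np
#   from sklearn.mixture import GaussianMixture
#   X = np.vstack([np.random.randn(60, 5), np.random.randn(60, 5) + 4.0])
#   bics = [(k, GaussianMixture(n_components=k, random_state=0).fit(X).bic(X))
#           for k in range(1, 6)]
#   best_k = min(bics, key=lambda kb: kb[1])[0]  # typically 2 for two clusters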


def plot_gmm_bic(
    bands_full: dict[int, list[dict]],
    bands_hq: dict[int, list[dict]],
    band_edges: list[tuple[float, float]],
    seed: int,
    out_path: Path,
) -> tuple[list[int | None], list[int | None]]:
    n_bands = len(band_edges)
    ks_full = [gmm_optimal_k(bands_full.get(b, []), seed=seed) for b in range(n_bands)]
    ks_hq = [gmm_optimal_k(bands_hq.get(b, []), seed=seed) for b in range(n_bands)]

    band_labels = [f"{lo:.0%}–{hi:.0%}" for lo, hi in band_edges]

    fig, ax = plt.subplots(figsize=(10, 5))
    fig.patch.set_facecolor(BG)
    _style_ax(ax)

    xs = np.arange(n_bands)
    valid_full = [(i, k) for i, k in enumerate(ks_full) if k is not None]
    valid_hq = [(i, k) for i, k in enumerate(ks_hq) if k is not None]

    if valid_full:
        xi, yi = zip(*valid_full)
        ax.plot(xi, yi, "o-", color=C_FULL, label="Full/Mixed", lw=2, markersize=7)
    if valid_hq:
        xi, yi = zip(*valid_hq)
        ax.plot(xi, yi, "o-", color=C_HQ, label="HQ", lw=2, markersize=7)

    if valid_full and valid_hq:
        all_x = sorted(set([i for i, _ in valid_full]) & set([i for i, _ in valid_hq]))
        if len(all_x) >= 2:
            kf_interp = {i: k for i, k in valid_full}
            kh_interp = {i: k for i, k in valid_hq}
            shared_x = [i for i in all_x if i in kf_interp and i in kh_interp]
            yf = [kf_interp[i] for i in shared_x]
            yh = [kh_interp[i] for i in shared_x]
            ax.fill_between(shared_x, yf, yh, alpha=0.15, color=TEXT)

    ax.set_xticks(xs)
    ax.set_xticklabels(band_labels, fontsize=7, color=SUB, rotation=30)
    ax.set_ylabel("Optimal K (GMM-BIC)", color=SUB)
    ax.set_title("Number of Distinct Strategies per Band", color=TEXT, fontsize=13)
    ax.legend(facecolor=CARD, edgecolor=BORDER, labelcolor=TEXT, fontsize=9)
    ax.yaxis.set_major_locator(plt.MaxNLocator(integer=True))
    fig.tight_layout()
    _save(fig, out_path)
    return ks_full, ks_hq


# ── Step 7: PCA scatter per band ────────────────────────────────────────

def plot_pca_scatter(
    bands_full: dict[int, list[dict]],
    bands_hq: dict[int, list[dict]],
    band_edges: list[tuple[float, float]],
    out_path: Path,
) -> None:
    n_plot = min(4, len(band_edges))
    fig, axes = plt.subplots(2, n_plot, figsize=(4 * n_plot, 7))
    fig.patch.set_facecolor(BG)
    fig.suptitle("PCA of Action Chunks per Band", color=TEXT, fontsize=14)

    if n_plot == 1:
        axes = axes.reshape(2, 1)

    for col, b in enumerate(range(n_plot)):
        cf = bands_full.get(b, [])
        ch = bands_hq.get(b, [])
        lo, hi = band_edges[b]

        for row, (clist, color, label) in enumerate([
            (cf, C_FULL, "Full/Mixed"), (ch, C_HQ, "HQ")
        ]):
            ax = axes[row, col]
            _style_ax(ax)
            if row == 0:
                ax.set_title(f"{lo:.0%}–{hi:.0%}", color=TEXT, fontsize=10)
            if col == 0:
                ax.set_ylabel(label, color=SUB, fontsize=9)

            if len(cf) < 3 or len(ch) < 3:
                ax.text(0.5, 0.5, "Too few\nchunks", transform=ax.transAxes,
                        ha="center", va="center", color=SUB, fontsize=9)
                continue

            X_full_b = np.stack([c["action_flat"] for c in cf])
            X_hq_b = np.stack([c["action_flat"] for c in ch])
            X_all = np.vstack([X_full_b, X_hq_b])
            X_all = StandardScaler().fit_transform(X_all)
            X_2d = PCA(n_components=2, random_state=42).fit_transform(X_all)

            X_2d_full = X_2d[: len(cf)]
            X_2d_hq = X_2d[len(cf) :]

            pts = X_2d_full if row == 0 else X_2d_hq
            ax.scatter(pts[:, 0], pts[:, 1], s=8, alpha=0.5, color=color, edgecolors="none")

    fig.tight_layout(rect=[0, 0, 1, 0.95])
    _save(fig, out_path)


# ── Plot 1: Chunk counts per band ───────────────────────────────────────

def plot_chunk_counts(
    bands_full: dict[int, list[dict]],
    bands_hq: dict[int, list[dict]],
    band_edges: list[tuple[float, float]],
    out_path: Path,
) -> None:
    n_bands = len(band_edges)
    band_labels = [f"{lo:.0%}–{hi:.0%}" for lo, hi in band_edges]
    x = np.arange(n_bands)
    w = 0.35

    counts_full = [len(bands_full.get(b, [])) for b in range(n_bands)]
    counts_hq = [len(bands_hq.get(b, [])) for b in range(n_bands)]

    fig, ax = plt.subplots(figsize=(10, 5))
    fig.patch.set_facecolor(BG)
    _style_ax(ax)

    ax.bar(x - w / 2, counts_full, w, color=C_FULL, edgecolor=BORDER, label="Full/Mixed")
    ax.bar(x + w / 2, counts_hq, w, color=C_HQ, edgecolor=BORDER, label="HQ")
    ax.set_xticks(x)
    ax.set_xticklabels(band_labels, fontsize=7, color=SUB, rotation=30)
    ax.set_ylabel("Chunk count", color=SUB)
    ax.set_title("Chunk Counts per Progress Band", color=TEXT, fontsize=13)
    ax.legend(facecolor=CARD, edgecolor=BORDER, labelcolor=TEXT, fontsize=8)
    fig.tight_layout()
    _save(fig, out_path)


# ── Summary figure ───────────────────────────────────────────────────────

def plot_summary(
    var_full: np.ndarray,
    var_hq: np.ndarray,
    band_edges: list[tuple[float, float]],
    ks_full: list[int | None],
    ks_hq: list[int | None],
    bands_full: dict[int, list[dict]],
    bands_hq: dict[int, list[dict]],
    out_path: Path,
) -> None:
    with np.errstate(invalid="ignore"):
        mean_full = np.nanmean(var_full, axis=1)
        mean_hq = np.nanmean(var_hq, axis=1)
        ratio = np.where(np.isnan(mean_full) | np.isnan(mean_hq), np.nan,
                         mean_full / (mean_hq + 1e-8))
    valid_ratio = ratio[~np.isnan(ratio)]
    mean_ratio = float(np.mean(valid_ratio)) if len(valid_ratio) > 0 else float("nan")
    # Index the peak in the unfiltered ratio so it maps back to band_edges even
    # when some bands have NaN ratios (indexing argmax of the filtered array
    # into band_edges would point at the wrong band).
    peak_idx = int(np.nanargmax(ratio)) if len(valid_ratio) > 0 else 0
    peak_ratio = float(ratio[peak_idx]) if len(valid_ratio) > 0 else float("nan")
    lo, hi = band_edges[peak_idx]
    peak_band = f"{lo:.0%}–{hi:.0%}"

    valid_kf = [k for k in ks_full if k is not None]
    valid_kh = [k for k in ks_hq if k is not None]
    mean_k_full = np.mean(valid_kf) if valid_kf else float("nan")
    mean_k_hq = np.mean(valid_kh) if valid_kh else float("nan")

    n_bands = len(band_edges)
    deltas_full = [c["progress_delta"] for b in range(n_bands) for c in bands_full.get(b, [])]
    deltas_hq = [c["progress_delta"] for b in range(n_bands) for c in bands_hq.get(b, [])]
    mean_delta_full = float(np.mean(deltas_full)) if deltas_full else float("nan")
    mean_delta_hq = float(np.mean(deltas_hq)) if deltas_hq else float("nan")

    rows = [
        ("Mean variance ratio (Full / HQ)", f"{mean_ratio:.2f}x"),
        ("Peak variance ratio", f"{peak_ratio:.2f}x at {peak_band}"),
        ("Mean GMM K — Full", f"{mean_k_full:.1f}"),
        ("Mean GMM K — HQ", f"{mean_k_hq:.1f}"),
        ("Mean progress Δ — Full", f"{mean_delta_full:.4f}"),
        ("Mean progress Δ — HQ", f"{mean_delta_hq:.4f}"),
    ]

    fig, ax = plt.subplots(figsize=(8, 3))
    fig.patch.set_facecolor(BG)
    ax.set_facecolor(CARD)
    ax.axis("off")

    table = ax.table(
        cellText=[[m, v] for m, v in rows],
        colLabels=["Metric", "Value"],
        loc="center",
        cellLoc="left",
    )
    table.auto_set_font_size(False)
    table.set_fontsize(10)
    for key, cell in table.get_celld().items():
        cell.set_edgecolor(BORDER)
        cell.set_facecolor(CARD)
        cell.set_text_props(color=TEXT)
        if key[0] == 0:
            cell.set_text_props(color=TEXT, fontweight="bold")
    table.scale(1, 1.6)
    ax.set_title("Summary Statistics", color=TEXT, fontsize=13, pad=15)
    fig.tight_layout()
    _save(fig, out_path)

    for metric, value in rows:
        logger.info("  %s: %s", metric, value)


# ── Main ─────────────────────────────────────────────────────────────────

def main(args: argparse.Namespace) -> None:
    out = Path(args.output_dir)
    out.mkdir(parents=True, exist_ok=True)

    logger.info("Loading FULL dataset: %s", args.full_dataset)
    episodes_full = load_episodes(args.full_dataset, args.n_joints, args.max_episodes)
    logger.info("Loading HQ dataset: %s", args.hq_dataset)
    episodes_hq = load_episodes(args.hq_dataset, args.n_joints, args.max_episodes)
    logger.info("Loaded %d full episodes, %d HQ episodes", len(episodes_full), len(episodes_hq))

    # Step 1: length threshold + filter
    if args.min_episode_length is not None:
        threshold = args.min_episode_length
    else:
        threshold = auto_length_threshold(episodes_full, episodes_hq)
    logger.info("Episode length threshold: %d", threshold)

    plot_length_distribution(episodes_full, episodes_hq, threshold, out / "0_length_distribution.png")
    episodes_full = filter_episodes(episodes_full, threshold)
    episodes_hq = filter_episodes(episodes_hq, threshold)

    # Step 2: extract chunks
    chunks_full = extract_chunks(episodes_full, args.chunk_size, args.chunk_stride)
    chunks_hq = extract_chunks(episodes_hq, args.chunk_size, args.chunk_stride)
    logger.info("Extracted %d full chunks, %d HQ chunks", len(chunks_full), len(chunks_hq))

    # Step 3: fixed equal-width bands over episode-relative progress
    band_edges = make_bands(args.n_bands)
    n_bands = len(band_edges)
    logger.info("Progress bands (%d): %s", n_bands,
                [f"{lo:.0%}–{hi:.0%}" for lo, hi in band_edges])

    chunks_full = assign_bands(chunks_full, band_edges)
    chunks_hq = assign_bands(chunks_hq, band_edges)
    bands_full = split_by_band(chunks_full, n_bands)
    bands_hq = split_by_band(chunks_hq, n_bands)

    # Plot 1: chunk counts
    plot_chunk_counts(bands_full, bands_hq, band_edges, out / "1_chunk_counts_per_band.png")

    # Step 4: variance heatmap
    var_full = band_variance_matrix(bands_full, n_bands, args.n_joints)
    var_hq = band_variance_matrix(bands_hq, n_bands, args.n_joints)
    plot_variance_heatmap(var_full, var_hq, band_edges, out / "2_variance_heatmap.png")

    # Step 5: progress delta
    plot_progress_delta(bands_full, bands_hq, band_edges, out / "3_progress_delta_per_band.png")

    # Step 6: GMM BIC
    ks_full, ks_hq = plot_gmm_bic(bands_full, bands_hq, band_edges, args.seed, out / "4_gmm_bic_per_band.png")

    # Step 7: PCA scatter
    plot_pca_scatter(bands_full, bands_hq, band_edges, out / "5_pca_per_band.png")

    # Summary
    plot_summary(var_full, var_hq, band_edges, ks_full, ks_hq,
                 bands_full, bands_hq, out / "6_summary.png")

    logger.info("All figures saved to %s", out)


if __name__ == "__main__":
    p = argparse.ArgumentParser(
        description="Chunk-level multi-modality analysis: Full/Mixed vs HQ dataset.",
        formatter_class=argparse.RawDescriptionHelpFormatter,
    )
    p.add_argument("--full-dataset", default="lerobot-data-collection/level12_rac_2_2026-02-08_1")
    p.add_argument("--hq-dataset", default="lerobot-data-collection/level2_final_quality3_trim_0_hil_data")
    p.add_argument("--output-dir", default="./chunk_analysis")
    p.add_argument("--chunk-size", type=int, default=30)
    p.add_argument("--chunk-stride", type=int, default=15)
    p.add_argument("--n-bands", type=int, default=5, help="Number of equal-width progress bands")
    p.add_argument("--max-episodes", type=int, default=500)
    p.add_argument("--n-joints", type=int, default=16)
    p.add_argument("--min-episode-length", type=int, default=None,
                   help="Override auto-detected length filter threshold")
    p.add_argument("--seed", type=int, default=42)
    args = p.parse_args()
    main(args)

@@ -0,0 +1,29 @@
#!/bin/bash
#SBATCH --job-name=smolvla_libero_plus
#SBATCH --partition=hopper-prod
#SBATCH --nodes=1
#SBATCH --ntasks-per-node=1
#SBATCH --gpus-per-node=4
#SBATCH --cpus-per-task=48
#SBATCH --mem=200G
#SBATCH --time=12:00:00
#SBATCH --output=logs/smolvla_libero_plus_%j.out
#SBATCH --error=logs/smolvla_libero_plus_%j.err

set -euo pipefail

eval "$(conda shell.bash hook 2>/dev/null)"
conda activate lerobot312

cd /admin/home/pepijn/lerobot_wt_robocasa

lerobot-benchmark train \
    --benchmarks libero_plus \
    --policy-path lerobot/smolvla_base \
    --hub-user pepijn223 \
    --num-gpus 4 \
    --steps 30000 \
    --batch-size 32 \
    --eval-freq 0 \
    --wandb \
    --dataset.repo_id=pepijn223/libero_plus_lerobot

@@ -49,15 +49,64 @@ class WandBConfig:
    mode: str | None = None  # Allowed values: 'online', 'offline', 'disabled'. Defaults to 'online'


@dataclass
class EvalDockerConfig:
    # Docker image to use for evaluation (e.g., "ghcr.io/org/lerobot-eval-libero:latest").
    # Takes precedence over eval.envhub_ref.
    image: str | None = None
    # Optional EnvHub reference to resolve an image, e.g. "envhub://lerobot/libero_plus@v1".
    envhub_ref: str | None = None
    # If true, mount the local repository and prefer local source code in the container.
    use_local_code: bool = True
    # Pull the image before running.
    pull: bool = True
    # Docker --gpus value. Set to None to disable GPU flags and run CPU-only.
    gpus: str | None = "all"
    # Docker --shm-size value (increase when using larger eval.batch_size values).
    shm_size: str = "8g"
    # Port on which the host HTTP policy inference server listens.
    port: int = 50051
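
# Sketch of how these fields are typically set from the CLI (draccus-style dotted
# flags inferred from the dataclass fields above; exact flag names depend on the
# surrounding parser, so treat this as an assumption, not a reference):
#   lerobot-eval \
#     --eval.runtime=docker \
#     --eval.docker.image=lerobot-benchmark-libero_plus \
#     --eval.instance_count=4 \
#     --eval.policy_servers=2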


@dataclass
class EvalConfig:
    n_episodes: int = 50
    # Number of sub-envs per task inside one gym.vector.VectorEnv. Increase to improve
    # per-task inference throughput until GPU or simulator memory saturates.
    batch_size: int = 50
    # Whether to use AsyncVectorEnv (multiprocessing). Prefer SyncVectorEnv unless your
    # environment spends significant time in Python-side stepping and can benefit from
    # process parallelism.
    use_async_envs: bool = False
    # Runtime where evaluation executes: "local", "docker", or "multiprocess".
    # "multiprocess" spawns local worker processes + policy servers.
    runtime: str = "local"
    docker: EvalDockerConfig = field(default_factory=EvalDockerConfig)
    # Number of parallel eval script instances to launch for one run.
    # instance_count > 1 enables multi-instance task sharding.
    instance_count: int = 1
    # 0-indexed shard id for this process. Users usually leave this at 0.
    # Additional shards are launched automatically by `lerobot-eval` when instance_count > 1.
    instance_id: int = 0
    # Number of policy inference servers to run in parallel (docker/multiprocess runtimes).
    # Each server loads a copy of the model and listens on consecutive ports
    # starting from eval.docker.port. Workers are round-robin assigned.
    policy_servers: int = 1
    # Base port for policy servers in multiprocess mode.
    port: int = 50051

    def __post_init__(self) -> None:
        if self.runtime not in {"local", "docker", "multiprocess"}:
            raise ValueError(
                f"Unsupported eval.runtime '{self.runtime}'. Expected one of: local, docker, multiprocess."
            )
        if self.instance_count < 1:
            raise ValueError("eval.instance_count must be >= 1.")
        if self.instance_id < 0 or self.instance_id >= self.instance_count:
            raise ValueError(
                f"eval.instance_id must be in [0, {self.instance_count - 1}] (got {self.instance_id})."
            )
        if self.policy_servers < 1:
            raise ValueError("eval.policy_servers must be >= 1.")
        if self.batch_size > self.n_episodes:
            raise ValueError(
                "The eval batch size is greater than the number of eval episodes "

@@ -40,6 +40,8 @@ class EvalPipelineConfig:
    rename_map: dict[str, str] = field(default_factory=dict)
    # Explicit consent to execute remote code from the Hub (required for hub environments).
    trust_remote_code: bool = False
    # Push eval results (metrics JSON, rollout videos, model card update) to the model's Hub repo.
    push_to_hub: bool = False

    def __post_init__(self) -> None:
        # HACK: We parse again the cli args here to get the pretrained path if there was one.

@@ -126,7 +126,11 @@ def make_dataset(cfg: TrainPipelineConfig) -> LeRobotDataset | MultiLeRobotDatas
    )

    if cfg.dataset.use_imagenet_stats:
        if dataset.meta.stats is None:
            dataset.meta.stats = {}
        for key in dataset.meta.camera_keys:
            if key not in dataset.meta.stats:
                dataset.meta.stats[key] = {}
            for stats_type, stats in IMAGENET_STATS.items():
                dataset.meta.stats[key][stats_type] = torch.tensor(stats, dtype=torch.float32)

@@ -45,6 +45,10 @@ class EnvConfig(draccus.ChoiceRegistry, abc.ABC):
    fps: int = 30
    features: dict[str, PolicyFeature] = field(default_factory=dict)
    features_map: dict[str, str] = field(default_factory=dict)
    # Upper bound on concurrent task evaluation in `lerobot-eval`.
    # - For lazy wrappers (e.g. LIBERO/LIBERO-plus), values >1 can enable chunked
    #   task batching with one policy forward pass over multiple tasks.
    # - For other envs, values >1 use a threaded task scheduler fallback.
    max_parallel_tasks: int = 1
    disable_env_checker: bool = True

@@ -346,6 +350,105 @@ class LiberoEnv(EnvConfig):
        return kwargs


@EnvConfig.register_subclass("libero_plus")
@dataclass
class LiberoPlusEnv(LiberoEnv):
    """Alias config for LIBERO-plus benchmarks.

    LIBERO-plus keeps the same Python package/module names as LIBERO, so this
    config reuses the existing LIBERO env implementation while making intent explicit
    in experiment configs (`env.type=libero_plus`).
    """

    task: str = "libero_spatial,libero_object,libero_goal,libero_10"


@EnvConfig.register_subclass("robocasa")
@dataclass
class RoboCasaEnv(EnvConfig):
    """RoboCasa kitchen composite-task environments.

    Wraps ``robocasa.wrappers.gym_wrapper.RoboCasaGymEnv`` with a flat 12-D Box
    action space and a structured pixel + state observation dict.

    Selected benchmark tasks (3 short + 2 long):
        Short: PickPlaceCounterToCabinet, PrepareToast, CoffeeSetupMug
        Long: PrepareCoffee, RestockPantry
    """

    task: str = "PickPlaceCounterToCabinet"
    tasks: list[str] | None = None  # multi-task: list of task names (without robocasa/ prefix)
    fps: int = 20
    episode_length: int = 500
    image_size: int = 128
    split: str = "target"  # "pretrain" or "target"
    features: dict[str, PolicyFeature] = field(
        default_factory=lambda: {
            ACTION: PolicyFeature(type=FeatureType.ACTION, shape=(12,)),
        }
    )
    features_map: dict[str, str] = field(
        default_factory=lambda: {
            ACTION: ACTION,
            "agentview_left": f"{OBS_IMAGES}.agentview_left",
            "agentview_right": f"{OBS_IMAGES}.agentview_right",
            "eye_in_hand": f"{OBS_IMAGES}.eye_in_hand",
            "robot_state": OBS_STATE,
        }
    )

    def __post_init__(self):
        for cam in ("agentview_left", "agentview_right", "eye_in_hand"):
            self.features[cam] = PolicyFeature(
                type=FeatureType.VISUAL, shape=(self.image_size, self.image_size, 3)
            )
        self.features["robot_state"] = PolicyFeature(type=FeatureType.STATE, shape=(16,))

    @property
    def gym_kwargs(self) -> dict:
        return {"split": self.split}


@EnvConfig.register_subclass("robomme")
@dataclass
class RoboMMEEnv(EnvConfig):
    """RoboMME memory-augmented manipulation benchmark (ManiSkill/SAPIEN).

    16 tasks across 4 suites: Counting, Permanence, Reference, Imitation.
    Uses BenchmarkEnvBuilder from the robomme package.
    """

    task: str = "PickXtimes"
    fps: int = 10
    episode_length: int = 300
    action_space: str = "joint_angle"
    dataset_split: str = "test"
    task_ids: list[int] | None = None
    features: dict[str, PolicyFeature] = field(
        default_factory=lambda: {
            ACTION: PolicyFeature(type=FeatureType.ACTION, shape=(8,)),
            "front_rgb": PolicyFeature(type=FeatureType.VISUAL, shape=(256, 256, 3)),
            "wrist_rgb": PolicyFeature(type=FeatureType.VISUAL, shape=(256, 256, 3)),
            OBS_STATE: PolicyFeature(type=FeatureType.STATE, shape=(8,)),
        }
    )
    features_map: dict[str, str] = field(
        default_factory=lambda: {
            ACTION: ACTION,
            "front_rgb": f"{OBS_IMAGES}.front",
            "wrist_rgb": f"{OBS_IMAGES}.wrist",
            OBS_STATE: OBS_STATE,
        }
    )

    @property
    def gym_kwargs(self) -> dict:
        return {
            "action_space": self.action_space,
            "dataset": self.dataset_split,
        }


@EnvConfig.register_subclass("metaworld")
@dataclass
class MetaworldEnv(EnvConfig):

@@ -0,0 +1,442 @@
#!/usr/bin/env python

# Copyright 2025 The HuggingFace Inc. team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

"""Docker runtime for lerobot-eval.

The policy stays on the host GPU; gym environments run inside Docker containers.
Each container runs `lerobot-eval-worker`, which calls back to a host HTTP inference
server for action chunks.

Architecture:
    host (GPU):
        1. Load policy + preprocessors from EvalPipelineConfig.
        2. Start ``policy_servers`` HTTP inference servers on consecutive ports.
        3. Spawn ``instance_count`` Docker containers, round-robin assigned to servers.
        4. Wait; collect per-task JSON written to the mounted output volume.
        5. Merge shards → aggregate → write eval_info.json.

    container (CPU only):
        1. make_env(cfg.env) → shard tasks by (instance_id, instance_count).
        2. For each task: run n_episodes, POST obs to /predict_chunk, step env.
        3. Write per-task JSON to /results/worker_{instance_id}.json.
"""

from __future__ import annotations

import json
import logging
import pickle  # nosec B403 — internal serialisation only
import platform
import subprocess  # nosec B404
import sys
import threading
import time
from http.server import BaseHTTPRequestHandler, HTTPServer
from pathlib import Path
from typing import TYPE_CHECKING, Any

import numpy as np
import torch

from lerobot.envs.factory import make_env_pre_post_processors
from lerobot.policies.factory import make_policy, make_pre_post_processors
from lerobot.utils.utils import get_safe_torch_device

if TYPE_CHECKING:
    from lerobot.configs.eval import EvalPipelineConfig

logger = logging.getLogger(__name__)


# ---------------------------------------------------------------------------
# HTTP inference server (host side)
# ---------------------------------------------------------------------------


class _PolicyInferenceHandler(BaseHTTPRequestHandler):
    """POST /predict_chunk → pickled numpy action chunk."""

    server: _InferenceServer

    def do_POST(self) -> None:
        if self.path != "/predict_chunk":
            self.send_error(404)
            return
        length = int(self.headers["Content-Length"])
        body = self.rfile.read(length)
        payload: dict = pickle.loads(body)  # nosec B301
        obs_t: dict = payload["obs_t"]

        with self.server._lock:
            chunk_np = self.server._predict(obs_t)

        resp = pickle.dumps(chunk_np)  # nosec B301
        self.send_response(200)
        self.send_header("Content-Type", "application/octet-stream")
        self.send_header("Content-Length", str(len(resp)))
        self.end_headers()
        self.wfile.write(resp)

    def log_message(self, fmt: str, *args: Any) -> None:  # noqa: ANN401
        pass  # suppress per-request logs


class _InferenceServer(HTTPServer):
    """Wraps the loaded policy behind a trivial HTTP interface."""

    def __init__(
        self,
        addr: tuple[str, int],
        policy: Any,
        env_preprocessor: Any,
        preprocessor: Any,
        postprocessor: Any,
    ) -> None:
        super().__init__(addr, _PolicyInferenceHandler)
        self._policy = policy
        self._env_preprocessor = env_preprocessor
        self._preprocessor = preprocessor
        self._postprocessor = postprocessor
        self._lock = threading.Lock()
        self._device = torch.device(str(policy.config.device))

    def _predict(self, obs_t: dict) -> np.ndarray:
        """Apply full preprocessing pipeline and return (n_action_steps, A) numpy chunk."""
        obs = self._env_preprocessor(obs_t)
        obs = self._preprocessor(obs)
        obs_gpu: dict = {k: v.to(self._device) if isinstance(v, torch.Tensor) else v for k, v in obs.items()}
        with torch.no_grad():
            chunk: torch.Tensor = self._policy.predict_action_chunk(obs_gpu)  # (B, T, A)

        n_action_steps = getattr(self._policy.config, "n_action_steps", chunk.shape[1])
        batch, n_steps, action_dim = chunk.shape
        chunk_2d = chunk.reshape(batch * n_steps, action_dim)  # (B*T, A)
        chunk_2d = self._postprocessor(chunk_2d)  # (B*T, A)
        # Workers send one observation per request, so B == 1 and the first
        # n_action_steps rows form the chunk for that single observation.
        return chunk_2d[:n_action_steps].cpu().numpy()  # (n_action_steps, A)


# ---------------------------------------------------------------------------
# Helpers
# ---------------------------------------------------------------------------


def _get_host_ip() -> str:
    """Return the IP that containers can use to reach the host."""
    if platform.system() in ("Darwin", "Windows"):
        return "host.docker.internal"
    return "172.17.0.1"  # Linux Docker bridge default gateway


def _resolve_image(cfg: EvalPipelineConfig) -> str:
    """Return the Docker image name to use for the env containers."""
    if cfg.eval.docker.image:
        return cfg.eval.docker.image
    return f"lerobot-benchmark-{cfg.env.type}"


def _env_argv() -> list[str]:
    """Extract --env.* args from sys.argv to forward verbatim to the worker."""
    return [arg for arg in sys.argv[1:] if arg.startswith("--env.")]


def _spawn_container(
    *,
    image: str,
    instance_id: int,
    instance_count: int,
    server_address: str,
    n_episodes: int,
    seed: int,
    output_dir: Path,
    docker_cfg: Any,
    env_argv: list[str],
) -> subprocess.Popen:
    output_dir.mkdir(parents=True, exist_ok=True)
    container_results = "/results"

    cmd: list[str] = ["docker", "run", "--rm"]
    if docker_cfg.gpus:
        cmd += [f"--gpus={docker_cfg.gpus}"]
    cmd += [f"--shm-size={docker_cfg.shm_size}"]
    cmd += ["-v", f"{output_dir.resolve()}:{container_results}"]
    # Allow containers on Linux to resolve host.docker.internal.
    cmd += ["--add-host=host.docker.internal:host-gateway"]
    cmd.append(image)

    cmd += [
        "lerobot-eval-worker",
        *env_argv,
        f"--server_address={server_address}",
        f"--n_episodes={n_episodes}",
        f"--seed={seed}",
        f"--instance_id={instance_id}",
        f"--instance_count={instance_count}",
        f"--output_path={container_results}/worker_{instance_id}.json",
    ]

    logger.info(
        "Spawning container %d/%d: %s",
        instance_id + 1,
        instance_count,
        " ".join(cmd),
    )
    return subprocess.Popen(cmd)  # nosec B603 B607


# ---------------------------------------------------------------------------
# Public entry point
# ---------------------------------------------------------------------------


def run_eval_in_docker(cfg: EvalPipelineConfig) -> None:
    """Run eval with env in Docker containers and policy on the host GPU.

    Writes ``eval_info.json`` to ``cfg.output_dir``. Called by
    ``lerobot_eval._run_eval_worker`` when ``eval.runtime == "docker"``.
    """
    # Import here to avoid circular import at module level.
    from lerobot.scripts.lerobot_eval import _aggregate_eval_from_per_task

    start_t = time.time()
    output_dir = Path(cfg.output_dir)
    output_dir.mkdir(parents=True, exist_ok=True)
    docker_cfg = cfg.eval.docker

    # Optionally pull the image before starting.
    image = _resolve_image(cfg)
    if docker_cfg.pull:
        logger.info("Pulling Docker image: %s", image)
        subprocess.run(["docker", "pull", image], check=True)  # nosec B603 B607

    # ── Load policy + all preprocessors on the host GPU ──────────────────
    device = get_safe_torch_device(cfg.policy.device, log=True)
    torch.backends.cudnn.benchmark = True
    torch.backends.cuda.matmul.allow_tf32 = True

    policy = make_policy(cfg=cfg.policy, env_cfg=cfg.env, rename_map=cfg.rename_map)
    policy.eval()

    preprocessor_overrides: dict = {
        "device_processor": {"device": str(device)},
        "rename_observations_processor": {"rename_map": cfg.rename_map},
    }
    preprocessor, postprocessor = make_pre_post_processors(
        policy_cfg=cfg.policy,
        pretrained_path=cfg.policy.pretrained_path,
        preprocessor_overrides=preprocessor_overrides,
    )
    env_preprocessor, _env_postprocessor = make_env_pre_post_processors(
        env_cfg=cfg.env,
        policy_cfg=cfg.policy,
    )

    # ── Start HTTP inference server(s) ────────────────────────────────────
    n_policy_servers = cfg.eval.policy_servers
    base_port = docker_cfg.port
    host_ip = _get_host_ip()
    instance_count = cfg.eval.instance_count
    env_argv = _env_argv()

    servers: list[_InferenceServer] = []
    for s_idx in range(n_policy_servers):
        port = base_port + s_idx
        if s_idx > 0:
            policy = make_policy(cfg=cfg.policy, env_cfg=cfg.env, rename_map=cfg.rename_map)
            policy.eval()
            preprocessor, postprocessor = make_pre_post_processors(
                policy_cfg=cfg.policy,
                pretrained_path=cfg.policy.pretrained_path,
                preprocessor_overrides=preprocessor_overrides,
            )
            env_preprocessor, _ = make_env_pre_post_processors(
                env_cfg=cfg.env, policy_cfg=cfg.policy,
            )
        srv = _InferenceServer(
            ("0.0.0.0", port),  # nosec B104
            policy=policy,
            env_preprocessor=env_preprocessor,
            preprocessor=preprocessor,
            postprocessor=postprocessor,
        )
        t = threading.Thread(target=srv.serve_forever, daemon=True)
        t.start()
        servers.append(srv)
        logger.info("Policy inference server %d/%d running on port %d", s_idx + 1, n_policy_servers, port)

    # ── Spawn containers (round-robin across policy servers) ──────────────
    container_dirs: list[Path] = []
    procs: list[subprocess.Popen] = []
    try:
        for i in range(instance_count):
            assigned_port = base_port + (i % n_policy_servers)
            server_address = f"{host_ip}:{assigned_port}"
            shard_dir = output_dir / "shards" / str(i)
            container_dirs.append(shard_dir)
            proc = _spawn_container(
                image=image,
                instance_id=i,
                instance_count=instance_count,
                server_address=server_address,
                n_episodes=cfg.eval.n_episodes,
                seed=cfg.seed,
                output_dir=shard_dir,
                docker_cfg=docker_cfg,
                env_argv=env_argv,
            )
            procs.append(proc)

        failed: list[tuple[int, int]] = []
        for i, proc in enumerate(procs):
            rc = proc.wait()
            if rc != 0:
                failed.append((i, rc))
                logger.error("Container %d/%d exited with code %d", i + 1, instance_count, rc)
        if failed:
            raise RuntimeError(f"Docker eval containers failed (instance_id, exit_code): {failed}")

    finally:
        for srv in servers:
            srv.shutdown()

    # ── Collect and merge per-task results ───────────────────────────────
    per_task: list[dict] = []
    for i, shard_dir in enumerate(container_dirs):
        result_file = shard_dir / f"worker_{i}.json"
        with open(result_file) as f:
            shard_data: dict = json.load(f)
        per_task.extend(shard_data.get("per_task", []))

    per_task.sort(key=lambda x: (x["task_group"], x["task_id"]))

    info = _aggregate_eval_from_per_task(per_task, total_eval_s=time.time() - start_t)
    with open(output_dir / "eval_info.json", "w") as f:
        json.dump(info, f, indent=2)

    logger.info("Docker eval complete. Results: %s/eval_info.json", output_dir)


def run_eval_multiprocess(cfg: EvalPipelineConfig) -> None:
    """Run eval with multiple local worker processes and policy servers (no Docker).

    Same architecture as Docker runtime but spawns `lerobot-eval-worker` as local
    subprocesses. Works on SLURM clusters and anywhere Docker is unavailable.
    """
    from lerobot.scripts.lerobot_eval import _aggregate_eval_from_per_task

    start_t = time.time()
    output_dir = Path(cfg.output_dir)
    output_dir.mkdir(parents=True, exist_ok=True)

    device = get_safe_torch_device(cfg.policy.device, log=True)
    torch.backends.cudnn.benchmark = True
    torch.backends.cuda.matmul.allow_tf32 = True

    policy = make_policy(cfg=cfg.policy, env_cfg=cfg.env, rename_map=cfg.rename_map)
    policy.eval()

    preprocessor_overrides: dict = {
        "device_processor": {"device": str(device)},
        "rename_observations_processor": {"rename_map": cfg.rename_map},
    }
    preprocessor, postprocessor = make_pre_post_processors(
        policy_cfg=cfg.policy,
        pretrained_path=cfg.policy.pretrained_path,
        preprocessor_overrides=preprocessor_overrides,
    )
    env_preprocessor, _env_postprocessor = make_env_pre_post_processors(
        env_cfg=cfg.env, policy_cfg=cfg.policy,
    )

    n_policy_servers = cfg.eval.policy_servers
    base_port = cfg.eval.port
    instance_count = cfg.eval.instance_count
    env_argv = _env_argv()

    servers: list[_InferenceServer] = []
    for s_idx in range(n_policy_servers):
        port = base_port + s_idx
        if s_idx > 0:
            policy = make_policy(cfg=cfg.policy, env_cfg=cfg.env, rename_map=cfg.rename_map)
            policy.eval()
            preprocessor, postprocessor = make_pre_post_processors(
                policy_cfg=cfg.policy,
                pretrained_path=cfg.policy.pretrained_path,
                preprocessor_overrides=preprocessor_overrides,
            )
            env_preprocessor, _ = make_env_pre_post_processors(
                env_cfg=cfg.env, policy_cfg=cfg.policy,
            )
        srv = _InferenceServer(
            ("0.0.0.0", port),  # nosec B104
            policy=policy,
            env_preprocessor=env_preprocessor,
            preprocessor=preprocessor,
            postprocessor=postprocessor,
        )
        t = threading.Thread(target=srv.serve_forever, daemon=True)
        t.start()
        servers.append(srv)
        logger.info("Policy server %d/%d on port %d", s_idx + 1, n_policy_servers, port)

    worker_dirs: list[Path] = []
    procs: list[subprocess.Popen] = []
    try:
        for i in range(instance_count):
            assigned_port = base_port + (i % n_policy_servers)
            shard_dir = output_dir / "shards" / str(i)
            shard_dir.mkdir(parents=True, exist_ok=True)
            worker_dirs.append(shard_dir)

            cmd = [
                sys.executable, "-m", "lerobot.scripts.lerobot_eval_worker",
                *env_argv,
                f"--server_address=127.0.0.1:{assigned_port}",
                f"--n_episodes={cfg.eval.n_episodes}",
                f"--seed={cfg.seed}",
                f"--instance_id={i}",
                f"--instance_count={instance_count}",
                f"--output_path={shard_dir / f'worker_{i}.json'}",
            ]
            logger.info("Spawning worker %d/%d → port %d", i + 1, instance_count, assigned_port)
            procs.append(subprocess.Popen(cmd))  # nosec B603

        failed: list[tuple[int, int]] = []
        for i, proc in enumerate(procs):
            rc = proc.wait()
            if rc != 0:
                failed.append((i, rc))
                logger.error("Worker %d/%d exited with code %d", i + 1, instance_count, rc)
        if failed:
            raise RuntimeError(f"Multiprocess eval workers failed (id, exit_code): {failed}")

    finally:
        for srv in servers:
            srv.shutdown()

    per_task: list[dict] = []
    for i, shard_dir in enumerate(worker_dirs):
        result_file = shard_dir / f"worker_{i}.json"
        with open(result_file) as f:
            shard_data: dict = json.load(f)
        per_task.extend(shard_data.get("per_task", []))

    per_task.sort(key=lambda x: (x["task_group"], x["task_id"]))

    info = _aggregate_eval_from_per_task(per_task, total_eval_s=time.time() - start_t)
    with open(output_dir / "eval_info.json", "w") as f:
        json.dump(info, f, indent=2)

    logger.info("Multiprocess eval complete. Results: %s/eval_info.json", output_dir)

@@ -20,11 +20,21 @@ import gymnasium as gym
from gymnasium.envs.registration import registry as gym_registry

from lerobot.configs.policies import PreTrainedConfig
from lerobot.envs.configs import (
    AlohaEnv,
    EnvConfig,
    HubEnvConfig,
    IsaaclabArenaEnv,
    LiberoEnv,
    LiberoPlusEnv,
    PushtEnv,
    RoboCasaEnv,
    RoboMMEEnv,
)
from lerobot.envs.utils import _call_make_env, _download_hub_file, _import_hub_module, _normalize_hub_result
from lerobot.policies.xvla.configuration_xvla import XVLAConfig
from lerobot.processor import ProcessorStep
from lerobot.processor.env_processor import IsaaclabArenaProcessorStep, LiberoProcessorStep, RoboCasaProcessorStep
from lerobot.processor.pipeline import PolicyProcessorPipeline

@@ -35,6 +45,12 @@ def make_env_config(env_type: str, **kwargs) -> EnvConfig:
        return PushtEnv(**kwargs)
    elif env_type == "libero":
        return LiberoEnv(**kwargs)
    elif env_type == "libero_plus":
        return LiberoPlusEnv(**kwargs)
    elif env_type == "robocasa":
        return RoboCasaEnv(**kwargs)
    elif env_type == "robomme":
        return RoboMMEEnv(**kwargs)
    else:
        # Message corrected: this factory dispatches env types, not policy types.
        raise ValueError(f"Env type '{env_type}' is not available.")

@@ -70,9 +86,13 @@ def make_env_pre_post_processors(
        return make_xvla_libero_pre_post_processors()

    # For LIBERO environments, add the LiberoProcessorStep to preprocessor
    if isinstance(env_cfg, (LiberoEnv, LiberoPlusEnv)) or "libero" in env_cfg.type:
        preprocessor_steps.append(LiberoProcessorStep())

    # For RoboCasa environments, add the RoboCasaProcessorStep to preprocessor
    if isinstance(env_cfg, RoboCasaEnv) or "robocasa" in env_cfg.type:
        preprocessor_steps.append(RoboCasaProcessorStep())

    # For Isaaclab Arena environments, add the IsaaclabArenaProcessorStep
    if isinstance(env_cfg, IsaaclabArenaEnv) or "isaaclab_arena" in env_cfg.type:
        # Parse comma-separated keys (handle None for state-based policies)

@@ -105,7 +125,7 @@ def make_env(
    use_async_envs: bool = False,
    hub_cache_dir: str | None = None,
    trust_remote_code: bool = False,
) -> dict[str, dict[int, Any]]:
    """Makes a gym vector environment according to the config or Hub reference.

    Args:
@@ -123,8 +143,9 @@
        ModuleNotFoundError: If the requested env package is not installed

    Returns:
        dict[str, dict[int, Any]]:
            A mapping from suite name to indexed environments. Values are either
            materialized vector envs or lazy wrappers that materialize on first use.
            - For multi-task benchmarks (e.g., LIBERO): one entry per suite, and one vec env per task_id.
            - For single-task environments: a single suite entry (cfg.type) with task_id=0.

@@ -171,6 +192,11 @@
        if cfg.task is None:
            raise ValueError("LiberoEnv requires a task to be specified")

        if cfg.type == "libero_plus":
            from lerobot.envs.libero import _check_libero_plus_assets

            _check_libero_plus_assets()

        return create_libero_envs(
            task=cfg.task,
            n_envs=n_envs,
@@ -181,6 +207,33 @@
            control_mode=cfg.control_mode,
            episode_length=cfg.episode_length,
        )
    elif "robocasa" in cfg.type:
        from lerobot.envs.robocasa import create_robocasa_envs

        tasks = cfg.tasks if cfg.tasks else [cfg.task]
        return create_robocasa_envs(
            tasks=tasks,
            n_envs=n_envs,
            image_size=cfg.image_size,
            split=cfg.split,
            episode_length=cfg.episode_length,
            gym_kwargs=cfg.gym_kwargs,
            env_cls=env_cls,
        )

    elif "robomme" in cfg.type:
        from lerobot.envs.robomme import create_robomme_envs

        return create_robomme_envs(
            task=cfg.task,
            n_envs=n_envs,
            action_space_type=cfg.action_space,
            dataset=cfg.dataset_split,
            episode_length=cfg.episode_length,
            task_ids=cfg.task_ids,
            env_cls=env_cls,
        )

    elif "metaworld" in cfg.type:
        from lerobot.envs.metaworld import create_metaworld_envs

@@ -0,0 +1,58 @@
#!/usr/bin/env python

# Copyright 2025 The HuggingFace Inc. team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from __future__ import annotations

from collections.abc import Callable, Sequence
from typing import Any


class LazyVectorEnv:
    """Defer vector-env construction until first usage.

    This is useful for benchmarks with many tasks: we can register one env object
    per task without eagerly allocating all simulator/rendering resources.
    """

    def __init__(self, env_cls: Callable[[Sequence[Callable[[], Any]]], Any], factory_fns: list[Callable]):
        self._env_cls = env_cls
        self._factory_fns = factory_fns
        self._env = None

    @property
    def env_cls(self) -> Callable[[Sequence[Callable[[], Any]]], Any]:
        return self._env_cls

    @property
    def factory_fns(self) -> list[Callable]:
        return self._factory_fns

    @property
    def num_factory_fns(self) -> int:
        return len(self._factory_fns)

    def materialize(self):
        if self._env is None:
            self._env = self._env_cls(self._factory_fns)
        return self._env

    def close(self):
        if self._env is not None:
            self._env.close()
            self._env = None

    def __getattr__(self, name):
        return getattr(self.materialize(), name)
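
# Usage sketch (hypothetical factories; mirrors how benchmark wrappers can register
# one lazy env per task without paying simulator startup costs up front):
#
#   import gymnasium as gym
#   lazy = LazyVectorEnv(gym.vector.SyncVectorEnv, [lambda: gym.make("CartPole-v1")])
#   obs, info = lazy.reset()   # __getattr__ materializes the vec env on first use
#   lazy.close()               # frees resources; a later attribute access re-materializes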
@@ -16,6 +16,7 @@

from __future__ import annotations

import os
import re
from collections import defaultdict
from collections.abc import Callable, Iterable, Mapping, Sequence
from functools import partial
@@ -26,11 +27,222 @@ import gymnasium as gym
import numpy as np
import torch
from gymnasium import spaces

try:
    import libero as _libero_pkg  # noqa: F401
except ImportError:
    raise ImportError(
        "Could not import libero. Install benchmark dependencies with one of:\n"
        " pip install -e \".[libero]\"\n"
        " pip install -e \".[libero_plus]\" (alias: \".[libero-plus]\")"
    )

# LIBERO's env_wrapper unconditionally imports wand (ImageMagick Python binding)
# which requires the system-level libMagickWand library. The wand features are only
# used for visual noise perturbations and are not needed for standard evaluation.
# Pre-install a stub so the import succeeds even without ImageMagick.
import sys
import types

if "wand" not in sys.modules:
    try:
        import wand.api  # noqa: F401
    except (ImportError, OSError):

        class _AttrSink:
            """Accepts any attribute get/set without error."""

            def __getattr__(self, _name):
                return self

            def __setattr__(self, _name, _value):
                pass

            def __call__(self, *a, **kw):
                pass

        _wand = types.ModuleType("wand")
        _wand_api = types.ModuleType("wand.api")
        _wand_api.library = _AttrSink()
        _wand_image = types.ModuleType("wand.image")
        _wand_image.Image = type("Image", (), {})
        _wand.api = _wand_api
        _wand.image = _wand_image
        sys.modules["wand"] = _wand
        sys.modules["wand.api"] = _wand_api
        sys.modules["wand.image"] = _wand_image

from libero.libero import benchmark, get_libero_path
from libero.libero.envs import OffScreenRenderEnv

from lerobot.envs.lazy_vec_env import LazyVectorEnv
from lerobot.processor import RobotObservation

_ASSET_DOWNLOAD_INSTRUCTIONS = """\
LIBERO-plus assets not found at: {assets_dir}

The LIBERO-plus benchmark requires ~6 GB of scene/texture/object assets that
are hosted separately on Hugging Face. To download and install them:

    python -c "
    from huggingface_hub import hf_hub_download
    hf_hub_download('Sylvest/LIBERO-plus', 'assets.zip',
                    repo_type='dataset', local_dir='/tmp/libero-plus-assets')
    "
    unzip /tmp/libero-plus-assets/assets.zip -d /tmp/libero-plus-assets-unzipped
    # The zip contains a deeply nested path; move the assets directory:
    mv /tmp/libero-plus-assets-unzipped/inspire/*/assets {assets_dir}
    rm -rf /tmp/libero-plus-assets /tmp/libero-plus-assets-unzipped

See https://huggingface.co/datasets/Sylvest/LIBERO-plus for details.
"""


def _check_libero_plus_assets() -> None:
    """Validate that LIBERO-plus scene assets are present."""
    assets_dir = Path(get_libero_path("benchmark_root")) / "assets"
    if not (assets_dir / "scenes").is_dir():
        raise FileNotFoundError(_ASSET_DOWNLOAD_INSTRUCTIONS.format(assets_dir=assets_dir))


# ---- Perturbation support for LIBERO-Plus -----------------------------------

PERTURBATION_DIMENSIONS = (
    "Camera Viewpoints",
    "Robot Initial States",
    "Language Instructions",
    "Light Conditions",
    "Background Textures",
    "Sensor Noise",
    "Objects Layout",
)

PERTURBATION_SHORT_KEYS = {
    "Camera Viewpoints": "camera",
    "Robot Initial States": "robot",
    "Language Instructions": "language",
    "Light Conditions": "light",
    "Background Textures": "background",
    "Sensor Noise": "noise",
    "Objects Layout": "layout",
}


def load_task_classification() -> dict:
    """Load task_classification.json shipped with LIBERO-Plus."""
    import json

    benchmark_root = Path(get_libero_path("benchmark_root"))
    candidates = [
        benchmark_root / "benchmark" / "task_classification.json",
        benchmark_root / "task_classification.json",
        benchmark_root.parent / "benchmark" / "task_classification.json",
    ]
    for p in candidates:
        if p.exists():
            with open(p) as f:
                return json.load(f)
    raise FileNotFoundError(
        f"task_classification.json not found. Tried: {[str(c) for c in candidates]}"
    )


def build_perturbation_index(suite_name: str) -> dict[int, str]:
    """Return {0-indexed task_id: perturbation_dimension} for *suite_name*."""
    tc = load_task_classification()
    suite_data = tc.get(suite_name, {})
    index: dict[int, str] = {}

    # LIBERO-Plus task_classification.json has appeared in two shapes:
    #   1) dict[suite][task_id_str] -> meta
    #   2) dict[suite] -> list[{id, category, ...}]
    if isinstance(suite_data, list):
        for item in suite_data:
            if not isinstance(item, dict):
                continue
            raw_id = item.get("id")
            if raw_id is None:
                continue
            try:
                # list-form ids are 1-indexed in current LIBERO-Plus release.
                tid = int(raw_id) - 1
            except (TypeError, ValueError):
                continue
            if tid < 0:
                continue
            dim = item.get("perturbation_type") or item.get("category", "unknown")
            index[tid] = dim
        return index

    if isinstance(suite_data, dict):
        # Handle both 0-indexed and 1-indexed key conventions.
        numeric_keys: list[int] = []
        for k in suite_data:
            try:
                numeric_keys.append(int(k))
            except (TypeError, ValueError):
                continue
        one_indexed = bool(numeric_keys) and 0 not in numeric_keys and min(numeric_keys) >= 1

        for task_id_str, meta in suite_data.items():
            try:
                tid = int(task_id_str)
            except (TypeError, ValueError):
                continue
            if one_indexed:
                tid -= 1
            if tid < 0:
                continue
            if isinstance(meta, dict):
                dim = meta.get("perturbation_type") or meta.get("category", "unknown")
            else:
                dim = "unknown"
            index[tid] = dim
        return index

    return index
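
# Illustrative data sketch (editor's example, not from the diff): the two
# task_classification.json shapes handled above, and the index both produce.
# Suite and category names here are placeholders.
#
#     # Shape 1: dict[suite][task_id_str] -> meta (1-indexed keys)
#     {"libero_object": {"1": {"perturbation_type": "Camera Viewpoints"}}}
#     # Shape 2: dict[suite] -> list[{id, category, ...}] (1-indexed ids)
#     {"libero_object": [{"id": 1, "category": "Camera Viewpoints"}]}
#
#     # Either way:
#     build_perturbation_index("libero_object")  # -> {0: "Camera Viewpoints"}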


def aggregate_by_perturbation(
    per_task: list[dict], suite_indices: dict[str, dict[int, str]]
) -> dict[str, dict]:
    """Aggregate per-task eval results by perturbation dimension.

    Args:
        per_task: list of {"task_group": str, "task_id": int, "metrics": {...}}
        suite_indices: {suite_name: {task_id: dimension_name}} from build_perturbation_index

    Returns:
        {short_key: {"pc_success": float, "n_episodes": int}} for each perturbation dimension
    """
    dim_successes: dict[str, list] = defaultdict(list)
    for entry in per_task:
        suite = entry["task_group"]
        tid = entry["task_id"]
        idx = suite_indices.get(suite, {})
        dim = idx.get(tid, "unknown")
        short = PERTURBATION_SHORT_KEYS.get(dim, dim.lower().replace(" ", "_"))
        successes = entry["metrics"].get("successes", [])
        dim_successes[short].extend(successes)

    results: dict[str, dict] = {}
    all_successes: list = []
    for short_key in list(PERTURBATION_SHORT_KEYS.values()) + ["unknown"]:
        if short_key not in dim_successes:
            continue
        s = dim_successes[short_key]
        all_successes.extend(s)
        results[short_key] = {
            "pc_success": float(np.nanmean(s) * 100) if s else float("nan"),
            "n_episodes": len(s),
        }
    if all_successes:
        results["total"] = {
            "pc_success": float(np.nanmean(all_successes) * 100),
            "n_episodes": len(all_successes),
        }
    return results
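
# Illustrative sketch (editor's example, not from the diff): shape of the
# inputs and outputs above, with made-up numbers.
#
#     per_task = [{"task_group": "libero_object", "task_id": 0,
#                  "metrics": {"successes": [1.0, 0.0]}}]
#     suite_indices = {"libero_object": {0: "Camera Viewpoints"}}
#     aggregate_by_perturbation(per_task, suite_indices)
#     # -> {"camera": {"pc_success": 50.0, "n_episodes": 2},
#     #     "total":  {"pc_success": 50.0, "n_episodes": 2}}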


def _parse_camera_names(camera_name: str | Sequence[str]) -> list[str]:
    """Normalize camera_name into a non-empty list of strings."""
@@ -68,13 +280,35 @@ def _select_task_ids(total_tasks: int, task_ids: Iterable[int] | None) -> list[i


def get_task_init_states(task_suite: Any, i: int) -> np.ndarray:
-    init_states_path = (
-        Path(get_libero_path("init_states"))
-        / task_suite.tasks[i].problem_folder
-        / task_suite.tasks[i].init_states_file
-    )
-    init_states = torch.load(init_states_path, weights_only=False)  # nosec B614
-    return init_states
+    init_states_dir = Path(get_libero_path("init_states")) / task_suite.tasks[i].problem_folder
+    init_states_file = task_suite.tasks[i].init_states_file
+
+    # 1. Direct match
+    direct = init_states_dir / init_states_file
+    if direct.exists():
+        return torch.load(direct, weights_only=False)  # nosec B614
+
+    # 2. LIBERO-Plus perturbation filenames append suffixes like
+    #    _view_0_0_100_0_0_initstate_1, _language_19, _noise_45, _table_1, _tb_9, _add_16
+    #    to the base task name. Instead of regex-matching every variant, find the
+    #    longest existing base file whose stem is a prefix of the perturbation stem.
+    stem, ext = os.path.splitext(init_states_file)
+    best_match: Path | None = None
+    best_len = 0
+    for candidate in init_states_dir.glob(f"*{ext}"):
+        cstem = candidate.stem
+        if stem == cstem or (stem.startswith(cstem) and stem[len(cstem)] == "_"):
+            if len(cstem) > best_len:
+                best_len = len(cstem)
+                best_match = candidate
+
+    if best_match is not None:
+        return torch.load(best_match, weights_only=False)  # nosec B614
+
+    raise FileNotFoundError(
+        f"Could not find init states for task {i}. "
+        f"Tried '{init_states_file}' and prefix matching in '{init_states_dir}'."
+    )


def get_libero_dummy_action():
@@ -94,6 +328,29 @@ TASK_SUITE_MAX_STEPS: dict[str, int] = {
}


def _make_offscreen_env_with_renderer_fallback(env_args: dict[str, Any]) -> Any:
    """Create OffScreenRenderEnv and fall back to OSMesa if EGL is unavailable."""
    try:
        return OffScreenRenderEnv(**env_args)
    except ImportError as exc:
        msg = str(exc)
        if "EGL" not in msg and "PLATFORM_DEVICE" not in msg:
            raise

        # Headless clusters often miss EGL PLATFORM_DEVICE support. Retry with
        # software rendering to keep evaluation working.
        os.environ["MUJOCO_GL"] = "osmesa"
        os.environ["PYOPENGL_PLATFORM"] = "osmesa"
        try:
            return OffScreenRenderEnv(**env_args)
        except Exception as fallback_exc:
            raise ImportError(
                "Failed to initialize robosuite offscreen renderer with both EGL and "
                "OSMesa backends. Set up EGL-capable drivers or install OSMesa (e.g. "
                "`conda install -c conda-forge mesalib`) and retry."
            ) from fallback_exc


class LiberoEnv(gym.Env):
    metadata = {"render_modes": ["rgb_array"], "render_fps": 80}

@@ -147,6 +404,7 @@ class LiberoEnv(gym.Env):
        # Load once and keep
        self._init_states = get_task_init_states(task_suite, self.task_id) if self.init_states else None
        self._reset_stride = n_envs  # when performing a reset, append `_reset_stride` to `init_state_id`.
+        self._init_state_error_warned = False

        self.init_state_id = self.episode_index  # tie each sub-env to a fixed init state

@@ -238,7 +496,7 @@ class LiberoEnv(gym.Env):
            "camera_heights": self.observation_height,
            "camera_widths": self.observation_width,
        }
-        env = OffScreenRenderEnv(**env_args)
+        env = _make_offscreen_env_with_renderer_fallback(env_args)
        env.reset()
        return env

@@ -298,8 +556,21 @@ class LiberoEnv(gym.Env):
        self._env.seed(seed)
        raw_obs = self._env.reset()
        if self.init_states and self._init_states is not None:
-            raw_obs = self._env.set_init_state(self._init_states[self.init_state_id % len(self._init_states)])
-            self.init_state_id += self._reset_stride  # Change init_state_id when reset
+            try:
+                raw_obs = self._env.set_init_state(self._init_states[self.init_state_id % len(self._init_states)])
+                self.init_state_id += self._reset_stride  # Change init_state_id when reset
+            except Exception as exc:
+                # Some LIBERO-Plus perturbation tasks (notably object-layout variants)
+                # can have different simulator state dimensions than their base init files.
+                # Fall back to plain env.reset() instead of aborting the whole evaluation.
+                self.init_states = False
+                if not self._init_state_error_warned:
+                    print(
+                        "WARNING: Failed to apply init state for "
+                        f"task_id={self.task_id} ({self.task}). "
+                        f"Falling back to plain reset. Error: {exc}"
+                    )
+                    self._init_state_error_warned = True

        # After reset, objects may be unstable (slightly floating, intersecting, etc.).
        # Step the simulator with a no-op action for a few frames so everything settles.
@@ -325,7 +596,17 @@ class LiberoEnv(gym.Env):
                f"Expected action to be 1-D (shape (action_dim,)), "
                f"but got shape {action.shape} with ndim={action.ndim}"
            )
-        raw_obs, reward, done, info = self._env.step(action)
+
+        try:
+            raw_obs, reward, done, info = self._env.step(action)
+        except ValueError as e:
+            if "terminated episode" not in str(e):
+                raise
+            # Robosuite's internal done flag is stale (e.g. from a previous
+            # termination that wasn't properly cleared by SyncVectorEnv).
+            # Signal termination so the caller resets us.
+            obs, reset_info = self.reset()
+            return obs, 0.0, True, False, {"is_success": False, **reset_info}

        is_success = self._env.check_success()
        terminated = done or is_success
@@ -345,7 +626,6 @@ class LiberoEnv(gym.Env):
            "done": bool(done),
            "is_success": bool(is_success),
        }
-        self.reset()
        truncated = False
        return observation, reward, terminated, truncated, info

@@ -388,6 +668,9 @@ def _make_env_fns(
    return fns


+_LazyVecEnv = LazyVectorEnv
+

# ---- Main API ----------------------------------------------------------------

@@ -431,12 +714,23 @@ def create_libero_envs(
    print(f"Restricting to task_ids={task_ids_filter}")

    out: dict[str, dict[int, Any]] = defaultdict(dict)
+    total_tasks = 0
    for suite_name in suite_names:
        suite = _get_suite(suite_name)
        total = len(suite.tasks)
        selected = _select_task_ids(total, task_ids_filter)
        if not selected:
            raise ValueError(f"No tasks selected for suite '{suite_name}' (available: {total}).")
+        total_tasks += len(selected)
+
+    lazy = total_tasks > 1
+    if lazy:
+        print(f"Using lazy env creation for {total_tasks} tasks (envs created on demand)")
+
+    for suite_name in suite_names:
+        suite = _get_suite(suite_name)
+        total = len(suite.tasks)
+        selected = _select_task_ids(total, task_ids_filter)

        for tid in selected:
            fns = _make_env_fns(
@@ -450,8 +744,11 @@ def create_libero_envs(
                gym_kwargs=gym_kwargs,
                control_mode=control_mode,
            )
-            out[suite_name][tid] = env_cls(fns)
-            print(f"Built vec env | suite={suite_name} | task_id={tid} | n_envs={n_envs}")
+            if lazy:
+                out[suite_name][tid] = LazyVectorEnv(env_cls, fns)
+            else:
+                out[suite_name][tid] = env_cls(fns)
+                print(f"Built vec env | suite={suite_name} | task_id={tid} | n_envs={n_envs}")

    # return plain dicts for predictability
    return {suite: dict(task_map) for suite, task_map in out.items()}
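
# Illustrative usage sketch (editor's example, not from the diff; the full
# create_libero_envs signature is not shown in this hunk, so keyword names
# here are assumptions):
#
#     import gymnasium as gym
#
#     envs = create_libero_envs(
#         task="libero_spatial,libero_object",   # assumed suite-list argument
#         n_envs=2,
#         env_cls=gym.vector.SyncVectorEnv,
#     )
#     # With more than one task selected, values are LazyVectorEnv placeholders;
#     # each task's simulators are built on first attribute access.
#     first = envs["libero_spatial"][0].materialize()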

@@ -25,6 +25,7 @@ import metaworld.policies as policies
import numpy as np
from gymnasium import spaces

+from lerobot.envs.lazy_vec_env import LazyVectorEnv
from lerobot.processor import RobotObservation

# ---- Load configuration data from the external JSON file ----
@@ -297,19 +298,24 @@ def create_metaworld_envs(

    print(f"Creating Meta-World envs | task_groups={task_groups} | n_envs(per task)={n_envs}")

+    group_to_tasks = {group: DIFFICULTY_TO_TASKS.get(group, [group]) for group in task_groups}
+    total_tasks = sum(len(tasks) for tasks in group_to_tasks.values())
+    lazy = total_tasks > 50
+    if lazy:
+        print(f"Using lazy env creation for {total_tasks} tasks (envs created on demand)")
+
    out: dict[str, dict[int, Any]] = defaultdict(dict)

    for group in task_groups:
        # if not in difficulty presets, treat it as a single custom task
-        tasks = DIFFICULTY_TO_TASKS.get(group, [group])
+        tasks = group_to_tasks[group]

        for tid, task_name in enumerate(tasks):
-            print(f"Building vec env | group={group} | task_id={tid} | task={task_name}")
+            if not lazy:
+                print(f"Building vec env | group={group} | task_id={tid} | task={task_name}")

            # build n_envs factories
            fns = [(lambda tn=task_name: MetaworldEnv(task=tn, **gym_kwargs)) for _ in range(n_envs)]

-            out[group][tid] = env_cls(fns)
+            out[group][tid] = LazyVectorEnv(env_cls, fns) if lazy else env_cls(fns)

    # return a plain dict for consistency
    return {group: dict(task_map) for group, task_map in out.items()}
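
# Editor's note (illustrative, not from the diff): the `tn=task_name` default
# in the factory lambdas above is what pins each factory to its own task.
# Python closures capture variables late, so without the default every factory
# would see the loop's final task_name:
#
#     fns_bad  = [lambda: task_name for task_name in ["a", "b"]]
#     fns_good = [lambda tn=task_name: tn for task_name in ["a", "b"]]
#     [f() for f in fns_bad]   # ['b', 'b'] — late binding
#     [f() for f in fns_good]  # ['a', 'b'] — bound at definition time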

@@ -0,0 +1,279 @@
#!/usr/bin/env python

# Copyright 2025 The HuggingFace Inc. team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from __future__ import annotations

from collections import defaultdict
from collections.abc import Callable, Sequence
from functools import partial
from typing import Any

import gymnasium as gym
import numpy as np
from gymnasium import spaces

from lerobot.envs.lazy_vec_env import LazyVectorEnv

# Action layout (flat 12D, normalized to [-1, 1]):
#   [0:3]   end_effector_position (delta x, y, z)
#   [3:6]   end_effector_rotation (delta roll, pitch, yaw)
#   [6:7]   gripper_close (open=-1, close=+1)
#   [7:11]  base_motion (x, y, theta, torso_height)
#   [11:12] control_mode (arm=-1, base=+1)
ACTION_DIM = 12
ACTION_LOW = -1.0
ACTION_HIGH = 1.0

# Proprioceptive state layout (flat 16D):
#   [0:2]   gripper_qpos
#   [2:5]   base_position
#   [5:9]   base_rotation (quaternion)
#   [9:12]  end_effector_position_relative
#   [12:16] end_effector_rotation_relative (quaternion)
STATE_DIM = 16

# Obs dict keys from RoboCasaGymEnv.get_observation()
_CAM_KEYS = (
    "video.robot0_agentview_left",
    "video.robot0_agentview_right",
    "video.robot0_eye_in_hand",
)
_STATE_KEYS_ORDERED = (
    "state.gripper_qpos",                    # (2,)
    "state.base_position",                   # (3,)
    "state.base_rotation",                   # (4,)
    "state.end_effector_position_relative",  # (3,)
    "state.end_effector_rotation_relative",  # (4,)
)

# Mapping from video.* key → short image name used in features_map
CAM_KEY_TO_NAME = {
    "video.robot0_agentview_left": "agentview_left",
    "video.robot0_agentview_right": "agentview_right",
    "video.robot0_eye_in_hand": "eye_in_hand",
}


def _flat_to_action_dict(flat: np.ndarray) -> dict[str, np.ndarray]:
    """Convert a 12D flat action array to the Dict format expected by RoboCasaGymEnv."""
    return {
        "action.end_effector_position": flat[0:3],
        "action.end_effector_rotation": flat[3:6],
        "action.gripper_close": flat[6:7],
        "action.base_motion": flat[7:11],
        "action.control_mode": flat[11:12],
    }
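
# Illustrative sketch (editor's example, not from the file): slicing a flat
# action according to the layout above and checking the pieces re-assemble.
#
#     flat = np.arange(12, dtype=np.float32)
#     parts = _flat_to_action_dict(flat)
#     assert parts["action.base_motion"].shape == (4,)
#     rejoined = np.concatenate([parts[k] for k in (
#         "action.end_effector_position", "action.end_effector_rotation",
#         "action.gripper_close", "action.base_motion", "action.control_mode")])
#     assert np.array_equal(rejoined, flat)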


class RoboCasaEnv(gym.Env):
    """Thin wrapper around RoboCasaGymEnv that provides a flat Box action space
    and a structured observation dict compatible with LeRobot policies.

    Observations returned by step/reset:
        {
            "pixels": {
                "agentview_left": (H, W, 3) uint8,
                "agentview_right": (H, W, 3) uint8,
                "eye_in_hand": (H, W, 3) uint8,
            },
            "robot_state": (16,) float32,
        }

    Actions: flat float32 ndarray of shape (12,), normalized to [-1, 1].
    """

    metadata = {"render_modes": ["rgb_array"], "render_fps": 20}

    def __init__(
        self,
        task: str,
        split: str = "target",
        image_size: int = 128,
        render_mode: str = "rgb_array",
        episode_length: int = 500,
        **gym_kwargs: Any,
    ):
        super().__init__()
        # Lazy import — robocasa is optional
        import robocasa.environments  # noqa: F401 — registers all gym envs

        self.task = task
        self.render_mode = render_mode
        self.image_size = image_size
        self._max_episode_steps = episode_length
        self._step_count = 0

        self._env = gym.make(
            f"robocasa/{task}",
            split=split,
            camera_widths=image_size,
            camera_heights=image_size,
            **gym_kwargs,
        )

        # Flat 12D Box action space
        self.action_space = spaces.Box(
            low=ACTION_LOW,
            high=ACTION_HIGH,
            shape=(ACTION_DIM,),
            dtype=np.float32,
        )

        images = {
            name: spaces.Box(low=0, high=255, shape=(image_size, image_size, 3), dtype=np.uint8)
            for name in CAM_KEY_TO_NAME.values()
        }
        self.observation_space = spaces.Dict(
            {
                "pixels": spaces.Dict(images),
                "robot_state": spaces.Box(
                    low=-np.inf, high=np.inf, shape=(STATE_DIM,), dtype=np.float32
                ),
            }
        )

    def _format_obs(self, raw_obs: dict) -> dict:
        pixels = {
            CAM_KEY_TO_NAME[k]: raw_obs[k]
            for k in _CAM_KEYS
            if k in raw_obs
        }
        state_parts = [
            np.asarray(raw_obs[k], dtype=np.float32)
            for k in _STATE_KEYS_ORDERED
            if k in raw_obs
        ]
        robot_state = np.concatenate(state_parts) if state_parts else np.zeros(STATE_DIM, dtype=np.float32)
        return {"pixels": pixels, "robot_state": robot_state}

    def reset(self, seed: int | None = None, **kwargs) -> tuple[dict, dict]:
        super().reset(seed=seed)
        self._step_count = 0
        raw_obs, info = self._env.reset(seed=seed)
        info.setdefault("is_success", False)
        info["task"] = self.task
        return self._format_obs(raw_obs), info

    def step(self, action: np.ndarray) -> tuple[dict, float, bool, bool, dict]:
        if action.ndim != 1 or action.shape[0] != ACTION_DIM:
            raise ValueError(
                f"Expected 1-D action of shape ({ACTION_DIM},), got {action.shape}"
            )
        action_dict = _flat_to_action_dict(action)
        raw_obs, reward, terminated, truncated, info = self._env.step(action_dict)
        self._step_count += 1

        is_success = bool(info.get("success", False))
        terminated = terminated or is_success
        if self._step_count >= self._max_episode_steps:
            truncated = True

        info.update({"task": self.task, "is_success": is_success})
        obs = self._format_obs(raw_obs)

        if terminated or truncated:
            info["final_info"] = {"task": self.task, "is_success": is_success}

        return obs, reward, terminated, truncated, info

    def render(self) -> np.ndarray | None:
        if self.render_mode == "rgb_array":
            return self._env.render()
        return None

    def close(self) -> None:
        self._env.close()


def _make_env_fns(
    *,
    task: str,
    n_envs: int,
    image_size: int,
    split: str,
    episode_length: int,
    gym_kwargs: dict[str, Any],
) -> list[Callable[[], RoboCasaEnv]]:
    """Build n_envs factory callables for a single task."""

    def _make(episode_index: int) -> RoboCasaEnv:  # noqa: ARG001
        return RoboCasaEnv(
            task=task,
            split=split,
            image_size=image_size,
            episode_length=episode_length,
            **gym_kwargs,
        )

    return [partial(_make, i) for i in range(n_envs)]


def create_robocasa_envs(
    tasks: str | Sequence[str],
    n_envs: int,
    image_size: int = 128,
    split: str = "target",
    episode_length: int = 500,
    gym_kwargs: dict[str, Any] | None = None,
    env_cls: Callable[[Sequence[Callable[[], Any]]], Any] | None = None,
) -> dict[str, dict[int, Any]]:
    """Create vectorized RoboCasa environments.

    Args:
        tasks: A single task name or list of task names (without "robocasa/" prefix),
            e.g. "PickPlaceCounterToCabinet" or ["BoilPot", "PrepareCoffee"].
        n_envs: Number of parallel envs per task.
        image_size: Square image resolution for all cameras.
        split: RoboCasa dataset split — "pretrain" or "target".
        episode_length: Max steps per episode before truncation.
        gym_kwargs: Extra kwargs forwarded to each RoboCasaEnv.
        env_cls: Callable to wrap a list of factory fns (SyncVectorEnv or AsyncVectorEnv).

    Returns:
        dict["robocasa"][task_index] -> vec_env (one entry per task, indexed in order)
    """
    if env_cls is None or not callable(env_cls):
        raise ValueError("env_cls must be a callable wrapping a list of env factory callables.")
    if not isinstance(n_envs, int) or n_envs <= 0:
        raise ValueError(f"n_envs must be a positive int; got {n_envs}.")

    if isinstance(tasks, str):
        task_list = [t.strip() for t in tasks.split(",") if t.strip()]
    else:
        task_list = [str(t).strip() for t in tasks if str(t).strip()]
    if not task_list:
        raise ValueError("`tasks` must contain at least one task name.")

    gym_kwargs = dict(gym_kwargs or {})
    out: dict[str, dict[int, Any]] = defaultdict(dict)
    total_tasks = len(task_list)
    lazy = total_tasks > 50

    print(f"Creating RoboCasa envs | tasks={task_list} | n_envs(per task)={n_envs} | split={split}")
    if lazy:
        print(f"Using lazy env creation for {total_tasks} tasks (envs created on demand)")
    for task in task_list:
        fns = _make_env_fns(
            task=task,
            n_envs=n_envs,
            image_size=image_size,
            split=split,
            episode_length=episode_length,
            gym_kwargs=gym_kwargs,
        )
        out["robocasa"][len(out["robocasa"])] = LazyVectorEnv(env_cls, fns) if lazy else env_cls(fns)
        if not lazy:
            print(f"  Built vec env | task={task} | n_envs={n_envs}")

    return {suite: dict(task_map) for suite, task_map in out.items()}
@@ -0,0 +1,181 @@
"""RoboMME environment wrapper for LeRobot evaluation.

Wraps the RoboMME ``BenchmarkEnvBuilder`` into a Gymnasium-compatible
``VectorEnv`` suitable for ``lerobot_eval``.

RoboMME tasks:
    Counting:   BinFill, PickXtimes, SwingXtimes, StopCube
    Permanence: VideoUnmask, VideoUnmaskSwap, ButtonUnmask, ButtonUnmaskSwap
    Reference:  PickHighlight, VideoRepick, VideoPlaceButton, VideoPlaceOrder
    Imitation:  MoveCube, InsertPeg, PatternLock, RouteStick

Install: pip install robomme (or from source: https://github.com/RoboMME/robomme_benchmark)
"""

from __future__ import annotations

from collections.abc import Callable, Sequence
from functools import partial
from typing import Any

import gymnasium as gym
import numpy as np
from gymnasium import spaces

from lerobot.envs.lazy_vec_env import LazyVectorEnv

ROBOMME_TASKS = [
    "BinFill", "PickXtimes", "SwingXtimes", "StopCube",
    "VideoUnmask", "VideoUnmaskSwap", "ButtonUnmask", "ButtonUnmaskSwap",
    "PickHighlight", "VideoRepick", "VideoPlaceButton", "VideoPlaceOrder",
    "MoveCube", "InsertPeg", "PatternLock", "RouteStick",
]


class RoboMMEGymEnv(gym.Env):
    """Thin Gymnasium wrapper around a single RoboMME episode env."""

    metadata = {"render_modes": ["rgb_array"]}

    def __init__(
        self,
        task: str = "PickXtimes",
        action_space_type: str = "joint_angle",
        dataset: str = "test",
        episode_idx: int = 0,
        max_steps: int = 300,
    ):
        super().__init__()
        from robomme.env_record_wrapper import BenchmarkEnvBuilder

        self._task = task
        self._action_space_type = action_space_type
        self._dataset = dataset
        self._episode_idx = episode_idx
        self._max_steps = max_steps

        self._builder = BenchmarkEnvBuilder(
            env_id=task,
            dataset=dataset,
            action_space=action_space_type,
            gui_render=False,
            max_steps=max_steps,
        )
        self._env = None

        action_dim = 8 if action_space_type == "joint_angle" else 7
        self.action_space = spaces.Box(low=-1.0, high=1.0, shape=(action_dim,), dtype=np.float32)
        self.observation_space = spaces.Dict({
            "front_rgb": spaces.Box(0, 255, shape=(256, 256, 3), dtype=np.uint8),
            "wrist_rgb": spaces.Box(0, 255, shape=(256, 256, 3), dtype=np.uint8),
            "state": spaces.Box(-np.inf, np.inf, shape=(8,), dtype=np.float32),
        })

    def reset(self, *, seed=None, options=None):
        super().reset(seed=seed)
        self._env = self._builder.make_env_for_episode(
            episode_idx=self._episode_idx, max_steps=self._max_steps,
        )
        obs, info = self._env.reset()
        return self._convert_obs(obs), self._convert_info(info)

    def step(self, action):
        obs, reward, terminated, truncated, info = self._env.step(action)

        terminated_bool = bool(terminated.item()) if hasattr(terminated, "item") else bool(terminated)
        truncated_bool = bool(truncated.item()) if hasattr(truncated, "item") else bool(truncated)

        status = info.get("status", "ongoing")
        is_success = status == "success"
        conv_info = self._convert_info(info)
        conv_info["is_success"] = is_success

        return self._convert_obs(obs), float(reward), terminated_bool, truncated_bool, conv_info

    def _convert_obs(self, obs: dict) -> dict:
        # Each *_list entry may hold an observation history; keep the latest frame.
        front_rgb = obs["front_rgb_list"][-1] if isinstance(obs["front_rgb_list"], list) else obs["front_rgb_list"]
        wrist_rgb = obs["wrist_rgb_list"][-1] if isinstance(obs["wrist_rgb_list"], list) else obs["wrist_rgb_list"]
        joint_state = obs["joint_state_list"][-1] if isinstance(obs["joint_state_list"], list) else obs["joint_state_list"]
        gripper_state = obs["gripper_state_list"][-1] if isinstance(obs["gripper_state_list"], list) else obs["gripper_state_list"]

        front_rgb = np.asarray(front_rgb, dtype=np.uint8)
        wrist_rgb = np.asarray(wrist_rgb, dtype=np.uint8)
        joint = np.asarray(joint_state, dtype=np.float32).flatten()[:7]
        gripper = np.asarray(gripper_state, dtype=np.float32).flatten()[:1]
        state = np.concatenate([joint, gripper])

        return {
            "front_rgb": front_rgb,
            "wrist_rgb": wrist_rgb,
            "state": state,
        }

    def _convert_info(self, info: dict) -> dict:
        return {
            "status": info.get("status", "ongoing"),
            "task_goal": info.get("task_goal", ""),
        }


def _make_env_fns(
    *,
    task: str,
    n_envs: int,
    action_space_type: str,
    dataset: str,
    episode_length: int,
    task_id: int,
) -> list[Callable[[], RoboMMEGymEnv]]:
    """Build n_envs factory callables for one RoboMME task id."""

    def _make_one(episode_index: int) -> RoboMMEGymEnv:
        return RoboMMEGymEnv(
            task=task,
            action_space_type=action_space_type,
            dataset=dataset,
            episode_idx=episode_index,
            max_steps=episode_length,
        )

    return [partial(_make_one, task_id + i) for i in range(n_envs)]


def create_robomme_envs(
    task: str,
    n_envs: int = 1,
    action_space_type: str = "joint_angle",
    dataset: str = "test",
    episode_length: int = 300,
    task_ids: list[int] | None = None,
    env_cls: Callable[[Sequence[Callable[[], Any]]], Any] | None = None,
) -> dict[str, dict[int, gym.vector.VectorEnv]]:
    """Create vectorized RoboMME environments for evaluation.

    Returns {suite_name: {task_id: VectorEnv}} matching lerobot's expected format.
    """
    if env_cls is None or not callable(env_cls):
        raise ValueError("env_cls must be a callable that wraps a list of env factory callables.")
    if not isinstance(n_envs, int) or n_envs <= 0:
        raise ValueError(f"n_envs must be a positive int; got {n_envs}.")

    if task_ids is None:
        task_ids = [0]

    suite_name = "robomme"
    envs_by_task = {}
    lazy = len(task_ids) > 50
    if lazy:
        print(f"Using lazy env creation for {len(task_ids)} tasks (envs created on demand)")

    for task_id in task_ids:
        fns = _make_env_fns(
            task=task,
            n_envs=n_envs,
            action_space_type=action_space_type,
            dataset=dataset,
            episode_length=episode_length,
            task_id=task_id,
        )
        envs_by_task[task_id] = LazyVectorEnv(env_cls, fns) if lazy else env_cls(fns)

    return {suite_name: envs_by_task}
@@ -153,6 +153,44 @@ class LiberoProcessorStep(ObservationProcessorStep):
        return result


@dataclass
@ProcessorStepRegistry.register(name="robocasa_processor")
class RoboCasaProcessorStep(ObservationProcessorStep):
    """
    Processes RoboCasa observations into LeRobot format.

    The RoboCasaEnv wrapper returns:
      - ``pixels.<cam_name>``: (B, C, H, W) float32 images (already converted by the vector env)
      - ``observation.robot_state``: (B, 16) float32 proprioception

    This step remaps them to:
      - ``observation.images.<cam_name>`` (unchanged tensor)
      - ``observation.state`` (robot_state renamed)
    """

    def _process_observation(self, observation: dict) -> dict:
        processed = {}
        obs_prefix = OBS_PREFIX  # "observation."

        for key, value in observation.items():
            if key.startswith(f"{OBS_IMAGES}."):
                # Already in the right place; pass through
                processed[key] = value
            elif key == OBS_STATE or key == f"{obs_prefix}robot_state":
                # Rename robot_state → observation.state
                processed[OBS_STATE] = value.float() if hasattr(value, "float") else value

        return processed

    def transform_features(
        self, features: dict[PipelineFeatureType, dict[str, PolicyFeature]]
    ) -> dict[PipelineFeatureType, dict[str, PolicyFeature]]:
        return features

    def observation(self, observation: dict) -> dict:
        return self._process_observation(observation)
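
# Illustrative sketch (editor's example, not from the diff): the remap this
# step performs, with tensor values elided.
#
#     {"observation.images.agentview_left": <img>,   # passed through unchanged
#      "observation.robot_state": <state>}
#     # -> _process_observation(...) ->
#     {"observation.images.agentview_left": <img>,
#      "observation.state": <state.float()>}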


@dataclass
@ProcessorStepRegistry.register(name="isaaclab_arena_processor")
class IsaaclabArenaProcessorStep(ObservationProcessorStep):

@@ -0,0 +1,462 @@
#!/usr/bin/env python

# Copyright 2025 The HuggingFace Inc. team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Benchmark runner: train and evaluate policies across simulation benchmarks.

Orchestrates per-benchmark training and evaluation using the existing
``lerobot-train`` and ``lerobot-eval`` CLI tools.

Typical usage::

    # Train SmolVLA on LIBERO-plus (4 GPUs, 50k steps):
    lerobot-benchmark train \\
        --benchmarks libero_plus \\
        --policy-path lerobot/smolvla_base \\
        --hub-user $HF_USER \\
        --num-gpus 4 --steps 50000

    # Evaluate the trained policies:
    lerobot-benchmark eval \\
        --benchmarks libero_plus \\
        --hub-user $HF_USER

    # Full pipeline (train → upload → eval) for multiple benchmarks:
    lerobot-benchmark all \\
        --benchmarks libero_plus,robocasa,robomme \\
        --policy-path lerobot/smolvla_base \\
        --hub-user $HF_USER \\
        --num-gpus 4 --steps 50000
"""

from __future__ import annotations

import argparse
import logging
import shutil
import subprocess
import sys
from dataclasses import dataclass, field
from pathlib import Path

logging.basicConfig(level=logging.INFO, format="%(asctime)s %(levelname)s %(message)s")
log = logging.getLogger(__name__)


@dataclass
class BenchmarkEntry:
    """Training + evaluation settings for a single benchmark.

    When ``eval_tasks`` is set, evaluation runs once per task in the list
    (e.g. libero_spatial, libero_object, …). ``env_task`` is still used as
    the task for mid-training evaluation during ``lerobot-train``.
    """

    dataset_repo_id: str
    env_type: str
    env_task: str
    eval_tasks: list[str] | None = None
    train_overrides: dict[str, str] = field(default_factory=dict)
    eval_overrides: dict[str, str] = field(default_factory=dict)


LIBERO_SUITES = ["libero_spatial", "libero_object", "libero_goal", "libero_10"]

# Each benchmark maps a human-readable name to its dataset and eval env.
# ``dataset_repo_id`` can contain ``{hub_user}``, which is interpolated at
# runtime from ``--hub-user``.
BENCHMARK_REGISTRY: dict[str, BenchmarkEntry] = {
    "libero": BenchmarkEntry(
        dataset_repo_id="{hub_user}/libero",
        env_type="libero",
        env_task="libero_spatial",
        eval_tasks=LIBERO_SUITES,
    ),
    "libero_plus": BenchmarkEntry(
        dataset_repo_id="{hub_user}/libero_plus",
        env_type="libero_plus",
        env_task="libero_spatial",
        eval_tasks=LIBERO_SUITES,
    ),
    "metaworld": BenchmarkEntry(
        dataset_repo_id="{hub_user}/metaworld",
        env_type="metaworld",
        env_task="metaworld-push-v2",
    ),
    "robocasa": BenchmarkEntry(
        dataset_repo_id="{hub_user}/robocasa",
        env_type="robocasa",
        env_task="PickPlaceCounterToCabinet",
    ),
    "robomme": BenchmarkEntry(
        dataset_repo_id="{hub_user}/robomme",
        env_type="robomme",
        env_task="PickXtimes",
    ),
}
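
# Illustrative sketch (editor's example, not from the file): registering an
# extra benchmark is just another entry; every name below is a placeholder.
#
#     BENCHMARK_REGISTRY["my_bench"] = BenchmarkEntry(
#         dataset_repo_id="{hub_user}/my_bench_dataset",
#         env_type="my_bench",
#         env_task="my_default_task",
#         train_overrides={"policy.optimizer_lr": "1e-5"},  # hypothetical key
#     )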


def _policy_repo_id(hub_user: str, policy_name: str, benchmark: str) -> str:
    return f"{hub_user}/{policy_name}_{benchmark}"


def _extra_keys(extra_args: list[str]) -> set[str]:
    """Extract ``--key`` prefixes from extra CLI args for override detection."""
    keys: set[str] = set()
    for arg in extra_args:
        if arg.startswith("--") and "=" in arg:
            keys.add(arg.split("=", 1)[0])
    return keys
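
# Illustrative sketch (editor's example, not from the file): how pass-through
# CLI args suppress the matching defaults in the command builders below.
#
#     _extra_keys(["--steps=100", "--wandb.enable=false", "-v"])
#     # -> {"--steps", "--wandb.enable"}   ("-v" has no "--key=value" form)
#     # _build_train_cmd then skips its own "--steps=..." default and lets
#     # the user's "--steps=100" win.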


def _build_train_cmd(
    benchmark: BenchmarkEntry,
    *,
    policy_path: str,
    hub_user: str,
    policy_name: str,
    benchmark_name: str,
    num_gpus: int,
    steps: int,
    batch_size: int,
    eval_freq: int,
    save_freq: int,
    wandb: bool,
    extra_args: list[str],
) -> list[str]:
    """Build the ``accelerate launch lerobot-train`` command list."""
    lerobot_train = shutil.which("lerobot-train")
    if lerobot_train is None:
        raise RuntimeError("lerobot-train not found on PATH. Is lerobot installed?")

    # Strip bare "--" separators that argparse may pass through
    cleaned_extra = [a for a in extra_args if a != "--"]
    overridden = _extra_keys(cleaned_extra)

    repo_id = _policy_repo_id(hub_user, policy_name, benchmark_name)
    dataset_id = benchmark.dataset_repo_id.format(hub_user=hub_user)

    defaults: list[tuple[str, str]] = [
        ("--policy.path", policy_path),
        ("--dataset.repo_id", dataset_id),
        ("--policy.repo_id", repo_id),
        ("--env.type", benchmark.env_type),
        ("--env.task", benchmark.env_task),
        ("--steps", str(steps)),
        ("--batch_size", str(batch_size)),
        ("--eval_freq", str(eval_freq)),
        ("--save_freq", str(save_freq)),
        ("--output_dir", f"outputs/train/{policy_name}_{benchmark_name}"),
        ("--job_name", f"{policy_name}_{benchmark_name}"),
        ("--policy.push_to_hub", "true"),
    ]
    if wandb:
        defaults.append(("--wandb.enable", "true"))
    for k, v in benchmark.train_overrides.items():
        defaults.append((f"--{k}", v))

    cmd: list[str] = [
        "accelerate", "launch",
        "--multi_gpu",
        f"--num_processes={num_gpus}",
        lerobot_train,
    ]
    for key, val in defaults:
        if key not in overridden:
            cmd.append(f"{key}={val}")
    cmd.extend(cleaned_extra)
    return cmd


def _build_eval_cmd(
    benchmark: BenchmarkEntry,
    *,
    hub_user: str,
    policy_name: str,
    benchmark_name: str,
    eval_task: str | None = None,
    n_episodes: int,
    batch_size_eval: int,
    extra_args: list[str],
) -> list[str]:
    """Build the ``lerobot-eval`` command list.

    ``eval_task`` overrides the benchmark's ``env_task`` so the same
    benchmark can be evaluated on multiple suites (e.g. LIBERO).
    """
    lerobot_eval = shutil.which("lerobot-eval")
    if lerobot_eval is None:
        raise RuntimeError("lerobot-eval not found on PATH. Is lerobot installed?")

    task = eval_task or benchmark.env_task
    repo_id = _policy_repo_id(hub_user, policy_name, benchmark_name)
    out_dir = _eval_output_dir(policy_name, benchmark_name, eval_task=task)

    cleaned_extra = [a for a in extra_args if a != "--"]
    overridden = _extra_keys(cleaned_extra)

    defaults: list[tuple[str, str]] = [
        ("--policy.path", repo_id),
        ("--env.type", benchmark.env_type),
        ("--env.task", task),
        ("--eval.n_episodes", str(n_episodes)),
        ("--eval.batch_size", str(batch_size_eval)),
        ("--output_dir", out_dir),
        ("--policy.device", "cuda"),
    ]
    for k, v in benchmark.eval_overrides.items():
        defaults.append((f"--{k}", v))

    cmd: list[str] = [lerobot_eval]
    for key, val in defaults:
        if key not in overridden:
            cmd.append(f"{key}={val}")
    cmd.extend(cleaned_extra)
    return cmd


def _eval_output_dir(policy_name: str, benchmark_name: str, eval_task: str | None = None) -> Path:
    if eval_task:
        return Path(f"outputs/eval/{policy_name}_{benchmark_name}/{eval_task}")
    return Path(f"outputs/eval/{policy_name}_{benchmark_name}")


def _run(cmd: list[str], *, dry_run: bool) -> None:
    log.info("Command: %s", " \\\n ".join(cmd))
    if dry_run:
        log.info("[dry-run] Skipping execution.")
        return
    result = subprocess.run(cmd, check=False)
    if result.returncode != 0:
        log.error("Command failed with exit code %d", result.returncode)
        sys.exit(result.returncode)


def _push_eval_to_hub(
    *,
    hub_user: str,
    policy_name: str,
    benchmark_name: str,
    eval_task: str | None = None,
    dry_run: bool,
) -> None:
    """Upload eval results (metrics + videos) to the policy repo on the Hub."""
    from huggingface_hub import HfApi

    repo_id = _policy_repo_id(hub_user, policy_name, benchmark_name)
    local_dir = _eval_output_dir(policy_name, benchmark_name, eval_task=eval_task)
    hub_path = f"eval/{eval_task}" if eval_task else f"eval/{benchmark_name}"

    if not local_dir.exists():
        log.warning("Eval output dir %s does not exist, skipping hub upload.", local_dir)
        return

    log.info("Uploading eval results from %s to %s (path_in_repo=%s)", local_dir, repo_id, hub_path)
    if dry_run:
        log.info("[dry-run] Skipping upload.")
        return

    api = HfApi()
    api.upload_folder(
        folder_path=str(local_dir),
        repo_id=repo_id,
        path_in_repo=hub_path,
        repo_type="model",
        commit_message=f"Upload eval results for {eval_task or benchmark_name}",
    )


def _resolve_benchmarks(names: str) -> list[tuple[str, BenchmarkEntry]]:
    out = []
    for name in names.split(","):
        name = name.strip()
        if name not in BENCHMARK_REGISTRY:
            available = ", ".join(BENCHMARK_REGISTRY)
            raise ValueError(f"Unknown benchmark '{name}'. Available: {available}")
        out.append((name, BENCHMARK_REGISTRY[name]))
    return out


def cmd_train(args: argparse.Namespace) -> None:
    benchmarks = _resolve_benchmarks(args.benchmarks)
    for bname, bentry in benchmarks:
        log.info("=== Training on benchmark: %s ===", bname)
        cmd = _build_train_cmd(
            bentry,
            policy_path=args.policy_path,
            hub_user=args.hub_user,
            policy_name=args.policy_name,
            benchmark_name=bname,
            num_gpus=args.num_gpus,
            steps=args.steps,
            batch_size=args.batch_size,
            eval_freq=args.eval_freq,
            save_freq=args.save_freq,
            wandb=args.wandb,
            extra_args=args.extra,
        )
        _run(cmd, dry_run=args.dry_run)


def _run_eval_for_benchmark(
    bname: str,
    bentry: BenchmarkEntry,
    args: argparse.Namespace,
) -> None:
    """Run evaluation for a single benchmark, iterating over all its eval_tasks."""
    tasks = bentry.eval_tasks or [bentry.env_task]
    for task in tasks:
        log.info("=== Evaluating %s / %s ===", bname, task)
        cmd = _build_eval_cmd(
            bentry,
            hub_user=args.hub_user,
            policy_name=args.policy_name,
            benchmark_name=bname,
            eval_task=task if bentry.eval_tasks else None,
            n_episodes=args.n_episodes,
            batch_size_eval=args.batch_size_eval,
            extra_args=args.extra,
        )
        _run(cmd, dry_run=args.dry_run)
        if args.push_eval_to_hub:
            _push_eval_to_hub(
                hub_user=args.hub_user,
                policy_name=args.policy_name,
                benchmark_name=bname,
                eval_task=task if bentry.eval_tasks else None,
                dry_run=args.dry_run,
            )


def cmd_eval(args: argparse.Namespace) -> None:
    benchmarks = _resolve_benchmarks(args.benchmarks)
    for bname, bentry in benchmarks:
        _run_eval_for_benchmark(bname, bentry, args)


def cmd_all(args: argparse.Namespace) -> None:
    """Train on each benchmark, then evaluate each."""
    benchmarks = _resolve_benchmarks(args.benchmarks)

    log.info("Phase 1: Training on %d benchmark(s)", len(benchmarks))
    for bname, bentry in benchmarks:
        log.info("=== Training on benchmark: %s ===", bname)
        cmd = _build_train_cmd(
            bentry,
            policy_path=args.policy_path,
            hub_user=args.hub_user,
            policy_name=args.policy_name,
            benchmark_name=bname,
            num_gpus=args.num_gpus,
            steps=args.steps,
            batch_size=args.batch_size,
            eval_freq=args.eval_freq,
            save_freq=args.save_freq,
            wandb=args.wandb,
            extra_args=args.extra,
        )
        _run(cmd, dry_run=args.dry_run)

    log.info("Phase 2: Evaluating %d benchmark(s)", len(benchmarks))
    for bname, bentry in benchmarks:
        _run_eval_for_benchmark(bname, bentry, args)


def _add_common_args(p: argparse.ArgumentParser) -> None:
    p.add_argument(
        "--benchmarks", required=True,
        help="Comma-separated benchmark names (e.g. libero_plus,robocasa,robomme).",
    )
    p.add_argument("--hub-user", required=True, help="HuggingFace Hub username.")
    p.add_argument(
        "--policy-name", default="smolvla",
        help="Short policy name used in repo IDs and output dirs (default: smolvla).",
    )
    p.add_argument("--dry-run", action="store_true", help="Print commands without executing.")


def _add_train_args(p: argparse.ArgumentParser) -> None:
    p.add_argument("--policy-path", default="lerobot/smolvla_base", help="Pretrained policy path.")
    p.add_argument("--num-gpus", type=int, default=4, help="Number of GPUs.")
    p.add_argument("--steps", type=int, default=50_000, help="Total training steps.")
    p.add_argument("--batch-size", type=int, default=32, help="Per-GPU batch size.")
    p.add_argument("--eval-freq", type=int, default=10_000, help="Eval every N steps (0 to disable).")
    p.add_argument("--save-freq", type=int, default=10_000, help="Save checkpoint every N steps.")
    p.add_argument("--wandb", action="store_true", help="Enable Weights & Biases logging.")


def _add_eval_args(p: argparse.ArgumentParser) -> None:
    p.add_argument("--n-episodes", type=int, default=50, help="Number of eval episodes.")
    p.add_argument("--batch-size-eval", type=int, default=10, help="Eval batch size (parallel envs).")
    p.add_argument(
        "--push-eval-to-hub", action="store_true",
        help="Upload eval results (metrics + videos) to the policy repo on the Hub.",
    )


def build_parser() -> argparse.ArgumentParser:
    parser = argparse.ArgumentParser(
        prog="lerobot-benchmark",
        description="Train and evaluate policies across simulation benchmarks.",
    )
    sub = parser.add_subparsers(dest="command", required=True)

    # train
    p_train = sub.add_parser("train", help="Train a policy on each selected benchmark.")
    _add_common_args(p_train)
    _add_train_args(p_train)
    p_train.set_defaults(func=cmd_train)

    # eval
    p_eval = sub.add_parser("eval", help="Evaluate trained policies on each benchmark.")
    _add_common_args(p_eval)
    _add_eval_args(p_eval)
    p_eval.set_defaults(func=cmd_eval)

    # all (train + eval)
    p_all = sub.add_parser("all", help="Train then evaluate on each benchmark.")
    _add_common_args(p_all)
    _add_train_args(p_all)
    _add_eval_args(p_all)
    p_all.set_defaults(func=cmd_all)

    # list
    p_list = sub.add_parser("list", help="List available benchmarks.")
    p_list.set_defaults(func=lambda _args: _list_benchmarks())

    return parser


def _list_benchmarks() -> None:
    print("Available benchmarks:\n")
    for name, entry in BENCHMARK_REGISTRY.items():
        print(f"  {name}")
        print(f"    dataset: {entry.dataset_repo_id}")
        print(f"    env:     {entry.env_type}")
        if entry.eval_tasks:
            print(f"    eval on: {', '.join(entry.eval_tasks)}")
        else:
            print(f"    eval on: {entry.env_task}")
        print()


def main() -> None:
    parser = build_parser()
    args, extra = parser.parse_known_args()
    args.extra = extra
    args.func(args)


if __name__ == "__main__":
    main()
@@ -49,6 +49,9 @@ You can learn about the CLI options for this script in the `EvalPipelineConfig`
import concurrent.futures as cf
import json
import logging
+import shutil
+import subprocess
import sys
+import threading
import time
from collections import defaultdict
@@ -71,7 +74,9 @@ from tqdm import trange

from lerobot.configs import parser
from lerobot.configs.eval import EvalPipelineConfig

from lerobot.envs.factory import make_env, make_env_pre_post_processors
+from lerobot.envs.lazy_vec_env import LazyVectorEnv
from lerobot.envs.utils import (
    add_envs_task,
    check_env_attributes_and_types,
@@ -82,6 +87,11 @@ from lerobot.policies.factory import make_policy, make_pre_post_processors
from lerobot.policies.pretrained import PreTrainedPolicy
from lerobot.processor import PolicyAction, PolicyProcessorPipeline
from lerobot.utils.constants import ACTION, DONE, OBS_STR, REWARD
+from lerobot.utils.hf_eval_results import (
+    build_eval_results_rows,
+    default_eval_date,
+    upload_eval_results_yaml,
+)
from lerobot.utils.import_utils import register_third_party_plugins
from lerobot.utils.io_utils import write_video
from lerobot.utils.random_utils import set_seed
@@ -502,9 +512,178 @@ def _compile_episode_data(
    return data_dict


def _serializable_config(obj: Any) -> Any:
    """Recursively convert a config dict so it is JSON-serializable."""
    if isinstance(obj, dict):
        return {k: _serializable_config(v) for k, v in obj.items()}
    if isinstance(obj, (list, tuple)):
        return [_serializable_config(v) for v in obj]
    if isinstance(obj, Path):
        return str(obj)
    if isinstance(obj, (int, float, str, bool, type(None))):
        return obj
    return str(obj)
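
# Illustrative sketch (editor's example, not from the diff): Paths and other
# non-JSON types collapse to strings so the json.dump calls below cannot fail.
#
#     _serializable_config({"out": Path("outputs/eval"), "dims": (1, 2)})
#     # -> {"out": "outputs/eval", "dims": [1, 2]}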
||||
|
||||
def push_eval_to_hub(
|
||||
repo_id: str,
|
||||
output_dir: Path,
|
||||
info: dict,
|
||||
env_type: str,
|
||||
env_task: str | None,
|
||||
benchmark_dataset_id: str,
|
||||
source_url: str | None = None,
|
||||
notes: str | None = None,
|
||||
) -> str:
|
||||
"""Upload eval artifacts and `.eval_results` rows to the Hub.
|
||||
|
||||
Args:
|
||||
repo_id: HF model repo (e.g. "user/my_policy").
|
||||
output_dir: Local directory containing eval_info.json and videos/.
|
||||
info: The eval results dict (as returned by eval_policy_all).
|
||||
env_type: Environment type string (e.g. "libero_plus", "pusht").
|
||||
env_task: The env task string from eval config.
|
||||
benchmark_dataset_id: HF dataset id of the consolidated benchmark dataset.
|
||||
source_url: Optional source URL for `.eval_results` attribution.
|
||||
notes: Optional setup notes to include in `.eval_results`.
|
||||
|
||||
Returns:
|
||||
URL of the last Hub commit.
|
||||
"""
|
||||
from huggingface_hub import HfApi
|
||||
|
||||
api = HfApi()
|
||||
api.create_repo(repo_id=repo_id, exist_ok=True)
|
||||
|
||||
commit_url = ""
|
||||
|
||||
# 1. Upload eval_info.json
|
||||
eval_json_path = output_dir / "eval_info.json"
|
||||
if eval_json_path.exists():
|
||||
commit_url = api.upload_file(
|
||||
path_or_fileobj=str(eval_json_path),
|
||||
path_in_repo=f"eval/{env_type}/eval_info.json",
|
||||
repo_id=repo_id,
|
||||
commit_message=f"Upload eval results for {env_type}",
|
||||
)
|
||||
|
||||
# 2. Upload eval_config.json (policy, env, and eval settings used)
|
||||
eval_config_path = output_dir / "eval_config.json"
|
||||
if eval_config_path.exists():
|
||||
api.upload_file(
|
||||
path_or_fileobj=str(eval_config_path),
|
||||
path_in_repo=f"eval/{env_type}/eval_config.json",
|
||||
repo_id=repo_id,
|
||||
commit_message=f"Upload eval config for {env_type}",
|
||||
)
|
||||
|
||||
# 3. Upload rollout videos
|
||||
videos_dir = output_dir / "videos"
|
||||
if videos_dir.is_dir():
|
||||
api.upload_folder(
|
||||
folder_path=str(videos_dir),
|
||||
path_in_repo=f"eval/{env_type}/videos",
|
||||
repo_id=repo_id,
|
||||
commit_message=f"Upload eval rollout videos for {env_type}",
|
||||
)
|
||||
|
||||
# 4. Upload HF-native `.eval_results` rows (canonical leaderboard surface).
|
||||
rows = build_eval_results_rows(
|
||||
info=info,
|
||||
env_type=env_type,
|
||||
env_task=env_task,
|
||||
benchmark_dataset_id=benchmark_dataset_id,
|
||||
source_url=source_url,
|
||||
notes=notes,
|
||||
eval_date=default_eval_date(),
|
||||
)
|
||||
commit_url = upload_eval_results_yaml(
|
||||
api=api,
|
||||
repo_id=repo_id,
|
||||
rows=rows,
|
||||
env_type=env_type,
|
||||
env_task=env_task,
|
||||
output_dir=output_dir,
|
||||
)
|
||||
|
||||
logging.info(f"Eval results pushed to https://huggingface.co/{repo_id}")
|
||||
return commit_url
|
||||
|
||||
|
||||
@parser.wrap()
|
||||
def eval_main(cfg: EvalPipelineConfig):
|
||||
logging.info(pformat(asdict(cfg)))
|
||||
# Multi-instance orchestration only applies to local runtime.
|
||||
# For docker runtime, instance_count controls the number of env containers
|
||||
# spawned directly by run_eval_in_docker — no extra lerobot-eval processes needed.
|
||||
if cfg.eval.runtime == "local" and cfg.eval.instance_count > 1 and cfg.eval.instance_id == 0:
|
||||
_orchestrate_multi_instance_eval(cfg)
|
||||
else:
|
||||
_run_eval_worker(cfg)
|
||||
|
||||
|
||||
def _maybe_add_libero_plus_perturbation(info: dict, cfg: EvalPipelineConfig) -> None:
    if cfg.env.type != "libero_plus":
        return
    try:
        from lerobot.envs.libero import aggregate_by_perturbation, build_perturbation_index

        suite_names = [s.strip() for s in cfg.env.task.split(",") if s.strip()]
        suite_indices = {s: build_perturbation_index(s) for s in suite_names}
        perturbation_results = aggregate_by_perturbation(info["per_task"], suite_indices)
        info["perturbation_results"] = perturbation_results
        print("\n=== Perturbation Results ===")
        for dim, stats in perturbation_results.items():
            print(f" {dim}: {stats['pc_success']:.1f}% ({stats['n_episodes']} episodes)")
    except Exception as exc:
        # Never fail a finished long-running eval on post-processing.
        print(f"WARNING: Failed to compute LIBERO-Plus perturbation breakdown: {exc}")
        print("Continuing with per-suite + overall metrics only.")


def _save_eval_outputs(cfg: EvalPipelineConfig, info: dict) -> None:
    output_dir = Path(cfg.output_dir)
    output_dir.mkdir(parents=True, exist_ok=True)
    with open(output_dir / "eval_info.json", "w") as f:
        json.dump(info, f, indent=2)

    eval_cfg_dict = _serializable_config(asdict(cfg))
    with open(output_dir / "eval_config.json", "w") as f:
        json.dump(eval_cfg_dict, f, indent=2)


def _maybe_push_eval_outputs(cfg: EvalPipelineConfig, info: dict) -> None:
    if not cfg.push_to_hub:
        return
    repo_id = str(cfg.policy.pretrained_path)
    try:
        push_eval_to_hub(
            repo_id=repo_id,
            output_dir=Path(cfg.output_dir),
            info=info,
            env_type=cfg.env.type,
            env_task=cfg.env.task,
            benchmark_dataset_id=cfg.benchmark_dataset_id,
            source_url=cfg.eval_result_source_url,
            notes=cfg.eval_result_notes,
        )
    except Exception as exc:
        logging.warning("Failed to push eval artifacts/results to Hub: %s", exc)


def _run_eval_worker(cfg: EvalPipelineConfig) -> dict:
    logging.info(pformat(asdict(cfg)))

    if cfg.eval.runtime in ("docker", "multiprocess"):
        from lerobot.envs.docker_runtime import run_eval_in_docker, run_eval_multiprocess

        if cfg.eval.runtime == "docker":
            run_eval_in_docker(cfg)
        else:
            run_eval_multiprocess(cfg)
        output_dir = Path(cfg.output_dir)
        with open(output_dir / "eval_info.json") as f:
            return json.load(f)

    # Check device is available
    device = get_safe_torch_device(cfg.policy.device, log=True)
@@ -548,34 +727,123 @@ def eval_main(cfg: EvalPipelineConfig):
    # Create environment-specific preprocessor and postprocessor (e.g., for LIBERO environments)
    env_preprocessor, env_postprocessor = make_env_pre_post_processors(env_cfg=cfg.env, policy_cfg=cfg.policy)

    with torch.no_grad(), torch.autocast(device_type=device.type) if cfg.policy.use_amp else nullcontext():
        info = eval_policy_all(
            envs=envs,
            policy=policy,
            env_preprocessor=env_preprocessor,
            env_postprocessor=env_postprocessor,
            preprocessor=preprocessor,
            postprocessor=postprocessor,
            n_episodes=cfg.eval.n_episodes,
            max_episodes_rendered=10,
            videos_dir=Path(cfg.output_dir) / "videos",
            start_seed=cfg.seed,
            max_parallel_tasks=cfg.env.max_parallel_tasks,
        )
    print("Overall Aggregated Metrics:")
    print(info["overall"])
    try:
        with (
            torch.no_grad(),
            torch.autocast(device_type=device.type) if cfg.policy.use_amp else nullcontext(),
        ):
            info = eval_policy_all(
                envs=envs,
                policy=policy,
                env_preprocessor=env_preprocessor,
                env_postprocessor=env_postprocessor,
                preprocessor=preprocessor,
                postprocessor=postprocessor,
                n_episodes=cfg.eval.n_episodes,
                max_episodes_rendered=10,
                videos_dir=Path(cfg.output_dir) / "videos",
                start_seed=cfg.seed,
                max_parallel_tasks=cfg.env.max_parallel_tasks,
                instance_count=cfg.eval.instance_count,
                instance_id=cfg.eval.instance_id,
            )
        print("Overall Aggregated Metrics:")
        print(info["overall"])

        # Print per-suite stats
        for task_group, task_group_info in info.items():
            print(f"\nAggregated Metrics for {task_group}:")
            print(task_group_info)
        # Close all vec envs
        close_envs(envs)
        for key, val in info.get("per_group", {}).items():
            print(f"\nAggregated Metrics for {key}:")
            print(val)

        # Save info
        with open(Path(cfg.output_dir) / "eval_info.json", "w") as f:
            json.dump(info, f, indent=2)
        _maybe_add_libero_plus_perturbation(info, cfg)
    finally:
        close_envs(envs)

    _save_eval_outputs(cfg, info)
    _maybe_push_eval_outputs(cfg, info)

    logging.info("End of eval")
    return info

def _orchestrate_multi_instance_eval(cfg: EvalPipelineConfig) -> None:
    start_t = time.time()
    root_output_dir = Path(cfg.output_dir)
    instances_root = root_output_dir / "instances"
    instances_root.mkdir(parents=True, exist_ok=True)

    n_instances = cfg.eval.instance_count
    logging.info(f"Launching multi-instance eval with {n_instances} workers.")

    # Spawn workers for shard 1..N-1, run shard 0 in-process.
    child_procs: list[tuple[int, subprocess.Popen]] = []
    argv = [
        arg
        for arg in sys.argv[1:]
        if not arg.startswith("--eval.instance_id=")
        and not arg.startswith("--output_dir=")
        and not arg.startswith("--push_to_hub=")
    ]
    for i in range(1, n_instances):
        child_output_dir = instances_root / str(i)
        cmd = [
            sys.executable,
            "-m",
            "lerobot.scripts.lerobot_eval",
            *argv,
            f"--eval.instance_id={i}",
            f"--output_dir={child_output_dir}",
            "--push_to_hub=false",
        ]
        logging.info("Starting eval worker %s/%s", i + 1, n_instances)
        child_procs.append((i, subprocess.Popen(cmd)))

    cfg0 = deepcopy(cfg)
    cfg0.eval.instance_id = 0
    cfg0.push_to_hub = False
    cfg0.output_dir = instances_root / "0"
    _run_eval_worker(cfg0)

    failed = []
    for idx, proc in child_procs:
        rc = proc.wait()
        if rc != 0:
            failed.append((idx, rc))
    if failed:
        raise RuntimeError(f"Multi-instance eval failed for workers: {failed}")

    partial_infos: list[dict] = []
    for i in range(n_instances):
        info_path = instances_root / str(i) / "eval_info.json"
        with open(info_path) as f:
            partial_infos.append(json.load(f))

    merged_per_task = []
    for info in partial_infos:
        merged_per_task.extend(info.get("per_task", []))
    merged_per_task.sort(key=lambda x: (x["task_group"], x["task_id"]))

    # Merge videos from each shard into final output dir.
    merged_videos_dir = root_output_dir / "videos"
    for i in range(n_instances):
        shard_dir = instances_root / str(i)
        shard_videos = shard_dir / "videos"
        if shard_videos.is_dir():
            shutil.copytree(shard_videos, merged_videos_dir, dirs_exist_ok=True)
            old_prefix = str(shard_videos)
            new_prefix = str(merged_videos_dir)
            for entry in merged_per_task:
                paths = entry.get("metrics", {}).get("video_paths", [])
                entry["metrics"]["video_paths"] = [
                    p.replace(old_prefix, new_prefix, 1) if p.startswith(old_prefix) else p for p in paths
                ]

    merged_info = _aggregate_eval_from_per_task(merged_per_task, total_eval_s=time.time() - start_t)
    _maybe_add_libero_plus_perturbation(merged_info, cfg)
    print("Overall Aggregated Metrics:")
    print(merged_info["overall"])

    _save_eval_outputs(cfg, merged_info)
    _maybe_push_eval_outputs(cfg, merged_info)
    logging.info("End of eval")

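# Sketch of the resulting on-disk layout (paths relative to cfg.output_dir, two
# shards assumed; inferred from the orchestration code above):
#
#   instances/0/eval_info.json   # shard 0 results (run in-process)
#   instances/1/eval_info.json   # shard 1 results (subprocess worker)
#   instances/*/videos/...       # per-shard rollout videos
#   videos/...                   # merged copy; video_paths are rewritten to here
#   eval_info.json               # merged metrics written by _save_eval_outputs
#   eval_config.json
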
@@ -590,6 +858,179 @@ class TaskMetrics(TypedDict):
ACC_KEYS = ("sum_rewards", "max_rewards", "successes", "video_paths")


def _aggregate_eval_from_per_task(per_task_infos: list[dict], total_eval_s: float) -> dict:
    """Aggregate eval metrics from per-task payloads."""
    group_acc: dict[str, dict[str, list]] = defaultdict(lambda: {k: [] for k in ACC_KEYS})
    overall: dict[str, list] = {k: [] for k in ACC_KEYS}

    def _append(group: str, key: str, value: Any):
        if value is None:
            return
        if isinstance(value, list):
            group_acc[group][key].extend(value)
            overall[key].extend(value)
        else:
            group_acc[group][key].append(value)
            overall[key].append(value)

    for entry in per_task_infos:
        group = entry["task_group"]
        metrics = entry["metrics"]
        _append(group, "sum_rewards", metrics.get("sum_rewards"))
        _append(group, "max_rewards", metrics.get("max_rewards"))
        _append(group, "successes", metrics.get("successes"))
        paths = metrics.get("video_paths", [])
        if paths:
            group_acc[group]["video_paths"].extend(paths)
            overall["video_paths"].extend(paths)

    def _agg_from_list(xs: list[float]) -> float:
        if not xs:
            return float("nan")
        arr = np.array(xs, dtype=float)
        return float(np.nanmean(arr))

    groups_aggregated = {}
    for group, acc in group_acc.items():
        groups_aggregated[group] = {
            "avg_sum_reward": _agg_from_list(acc["sum_rewards"]),
            "avg_max_reward": _agg_from_list(acc["max_rewards"]),
            "pc_success": _agg_from_list(acc["successes"]) * 100 if acc["successes"] else float("nan"),
            "n_episodes": len(acc["sum_rewards"]),
            "video_paths": list(acc["video_paths"]),
        }

    overall_agg = {
        "avg_sum_reward": _agg_from_list(overall["sum_rewards"]),
        "avg_max_reward": _agg_from_list(overall["max_rewards"]),
        "pc_success": _agg_from_list(overall["successes"]) * 100 if overall["successes"] else float("nan"),
        "n_episodes": len(overall["sum_rewards"]),
        "eval_s": total_eval_s,
        "eval_ep_s": total_eval_s / max(1, len(overall["sum_rewards"])),
        "video_paths": list(overall["video_paths"]),
    }

    return {"per_task": per_task_infos, "per_group": groups_aggregated, "overall": overall_agg}

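# Minimal sketch of the payload shapes involved (hypothetical values):
#
#   per_task_infos = [
#       {"task_group": "libero_spatial", "task_id": 0,
#        "metrics": {"sum_rewards": [1.0], "max_rewards": [1.0],
#                    "successes": [True], "video_paths": []}},
#   ]
#   out = _aggregate_eval_from_per_task(per_task_infos, total_eval_s=12.3)
#   # out["per_group"]["libero_spatial"]["pc_success"] == 100.0
#   # out["overall"]["n_episodes"] == 1
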
def _eval_task_batch(
    batch: list[tuple[str, int, LazyVectorEnv]],
    policy,
    env_preprocessor: PolicyProcessorPipeline[dict[str, Any], dict[str, Any]],
    env_postprocessor: PolicyProcessorPipeline[dict[str, Any], dict[str, Any]],
    preprocessor: PolicyProcessorPipeline[dict[str, Any], dict[str, Any]],
    postprocessor: PolicyProcessorPipeline[PolicyAction, PolicyAction],
    start_seed: int | None,
    max_episodes_rendered: int = 0,
    videos_dir: Path | None = None,
) -> list[tuple[str, int, TaskMetrics]]:
    """Evaluate N tasks in a single batched rollout for GPU efficiency.

    Each task contributes one sub-env to a combined SyncVectorEnv so the policy
    processes all N observations in one forward pass per step.
    """
    all_fns: list[Callable] = []
    task_slices: list[tuple[str, int, int, int]] = []
    offset = 0
    for task_group, task_id, lazy_env in batch:
        fns = lazy_env.factory_fns
        if not fns:
            continue
        start = offset
        offset += len(fns)
        all_fns.extend(fns)
        task_slices.append((task_group, task_id, start, offset))

    if not all_fns:
        return []

    env_cls = batch[0][2].env_cls
    combined_env = env_cls(all_fns)

    try:
        seeds = None
        if start_seed is not None:
            seeds = list(range(start_seed, start_seed + combined_env.num_envs))

        ep_frames: list[np.ndarray] = []

        def render_frame(env: gym.vector.VectorEnv):
            if max_episodes_rendered <= 0:
                return
            n = min(max_episodes_rendered, env.num_envs)
            if isinstance(env, gym.vector.SyncVectorEnv):
                ep_frames.append(np.stack([env.envs[i].render() for i in range(n)]))
            elif isinstance(env, gym.vector.AsyncVectorEnv):
                ep_frames.append(np.stack(env.call("render")[:n]))

        rollout_data = rollout(
            env=combined_env,
            policy=policy,
            env_preprocessor=env_preprocessor,
            env_postprocessor=env_postprocessor,
            preprocessor=preprocessor,
            postprocessor=postprocessor,
            seeds=seeds,
            render_callback=render_frame if max_episodes_rendered > 0 else None,
        )

        n_steps = rollout_data["done"].shape[1]
        done_indices = torch.argmax(rollout_data["done"].to(int), dim=1)
        # Keep rewards/successes only up to each episode's first done step (plus one trailing step).
        mask = (torch.arange(n_steps) <= einops.repeat(done_indices + 1, "b -> b s", s=n_steps)).int()
        batch_sum_rewards = einops.reduce((rollout_data["reward"] * mask), "b n -> b", "sum")
        batch_max_rewards = einops.reduce((rollout_data["reward"] * mask), "b n -> b", "max")
        batch_successes = einops.reduce((rollout_data["success"] * mask), "b n -> b", "any")

        video_paths_per_task: dict[tuple[str, int], list[str]] = defaultdict(list)
        if max_episodes_rendered > 0 and ep_frames and videos_dir:
            stacked = np.stack(ep_frames, axis=1)  # (batch, time, h, w, c)
            rendered = 0
            threads = []
            for tg, tid, start_i, end_i in task_slices:
                if rendered >= max_episodes_rendered:
                    break
                task_dir = videos_dir / f"{tg}_{tid}"
                task_dir.mkdir(parents=True, exist_ok=True)
                for env_idx in range(start_i, end_i):
                    if rendered >= max_episodes_rendered:
                        break
                    episode_index = env_idx - start_i
                    video_path = task_dir / f"eval_episode_{episode_index}.mp4"
                    video_paths_per_task[(tg, tid)].append(str(video_path))
                    di = done_indices[env_idx].item()
                    thread = threading.Thread(
                        target=write_video,
                        args=(
                            str(video_path),
                            stacked[env_idx, : di + 1],
                            combined_env.unwrapped.metadata["render_fps"],
                        ),
                    )
                    thread.start()
                    threads.append(thread)
                    rendered += 1
            for t in threads:
                t.join()

        results: list[tuple[str, int, TaskMetrics]] = []
        for tg, tid, start_i, end_i in task_slices:
            results.append(
                (
                    tg,
                    tid,
                    TaskMetrics(
                        sum_rewards=batch_sum_rewards[start_i:end_i].tolist(),
                        max_rewards=batch_max_rewards[start_i:end_i].tolist(),
                        successes=batch_successes[start_i:end_i].tolist(),
                        video_paths=video_paths_per_task.get((tg, tid), []),
                    ),
                )
            )
        return results
    finally:
        combined_env.close()

def eval_one(
    env: gym.vector.VectorEnv,
    *,
@@ -634,7 +1075,7 @@ def eval_one(
def run_one(
    task_group: str,
    task_id: int,
    env,
    env: Any,
    *,
    policy,
    env_preprocessor,
@@ -678,7 +1119,7 @@ def run_one(


def eval_policy_all(
    envs: dict[str, dict[int, gym.vector.VectorEnv]],
    envs: dict[str, dict[int, Any]],
    policy,
    env_preprocessor: PolicyProcessorPipeline[dict[str, Any], dict[str, Any]],
    env_postprocessor: PolicyProcessorPipeline[dict[str, Any], dict[str, Any]],
@@ -691,6 +1132,8 @@ def eval_policy_all(
    return_episode_data: bool = False,
    start_seed: int | None = None,
    max_parallel_tasks: int = 1,
    instance_count: int = 1,
    instance_id: int = 0,
) -> dict:
    """
    Evaluate a nested `envs` dict: {task_group: {task_id: vec_env}}.
@@ -703,6 +1146,12 @@ def eval_policy_all(

    # Flatten envs into list of (task_group, task_id, env)
    tasks = [(tg, tid, vec) for tg, group in envs.items() for tid, vec in group.items()]
    if instance_count > 1:
        total_tasks = len(tasks)
        tasks = [task for idx, task in enumerate(tasks) if idx % instance_count == instance_id]
        logging.info(
            f"Instance shard {instance_id + 1}/{instance_count}: {len(tasks)}/{total_tasks} tasks assigned."
        )

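    # Sharding example (hypothetical): with 5 flattened tasks and instance_count=2,
    # instance 0 keeps task indices [0, 2, 4] and instance 1 keeps [1, 3]; the
    # orchestrator merges the shards' per_task payloads afterwards.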
    # accumulators: track metrics at both per-group level and across all groups
    group_acc: dict[str, dict[str, list]] = defaultdict(lambda: {k: [] for k in ACC_KEYS})
@@ -748,59 +1197,77 @@ def eval_policy_all(
        start_seed=start_seed,
    )

    if max_parallel_tasks <= 1:
        # sequential path (single accumulator path on the main thread)
        # NOTE: keeping a single-threaded accumulator avoids concurrent list appends or locks
        for task_group, task_id, env in tasks:
            tg, tid, metrics = task_runner(task_group, task_id, env)
            _accumulate_to(tg, metrics)
            per_task_infos.append({"task_group": tg, "task_id": tid, "metrics": metrics})
    else:
        # threaded path: submit all tasks, consume completions on main thread and accumulate there
        with cf.ThreadPoolExecutor(max_workers=max_parallel_tasks) as executor:
            fut2meta = {}
            for task_group, task_id, env in tasks:
                fut = executor.submit(task_runner, task_group, task_id, env)
                fut2meta[fut] = (task_group, task_id)
            for fut in cf.as_completed(fut2meta):
                tg, tid, metrics = fut.result()
    all_lazy = all(isinstance(env, LazyVectorEnv) for _, _, env in tasks)
    single_factory_per_task = all(
        not isinstance(env, LazyVectorEnv) or env.num_factory_fns == 1 for _, _, env in tasks
    )
    can_batch = max_parallel_tasks > 1 and all_lazy and single_factory_per_task and n_episodes == 1

    if can_batch:
        # Multi-task batched path: combine N tasks into one SyncVectorEnv per chunk
        # so the policy processes all N observations in a single forward pass per step.
        chunk_size = max_parallel_tasks
        logging.info(f"Task scheduler mode: batched_lazy (chunk_size={chunk_size})")
        n_chunks = (len(tasks) + chunk_size - 1) // chunk_size
        rendered_so_far = 0
        for chunk_idx in range(n_chunks):
            chunk = tasks[chunk_idx * chunk_size : (chunk_idx + 1) * chunk_size]
            render_budget = max(0, max_episodes_rendered - rendered_so_far)
            logging.info(
                f"Batch {chunk_idx + 1}/{n_chunks}: evaluating {len(chunk)} tasks "
                f"({chunk_idx * chunk_size + 1}–{chunk_idx * chunk_size + len(chunk)}/{len(tasks)})"
            )
            batch_results = _eval_task_batch(
                chunk,
                policy=policy,
                env_preprocessor=env_preprocessor,
                env_postprocessor=env_postprocessor,
                preprocessor=preprocessor,
                postprocessor=postprocessor,
                start_seed=start_seed,
                max_episodes_rendered=render_budget,
                videos_dir=videos_dir,
            )
            for tg, tid, metrics in batch_results:
                _accumulate_to(tg, metrics)
                per_task_infos.append({"task_group": tg, "task_id": tid, "metrics": metrics})
                rendered_so_far += len(metrics.get("video_paths", []))

    # compute aggregated metrics helper (robust to lists/scalars)
    def _agg_from_list(xs):
        if not xs:
            return float("nan")
        arr = np.array(xs, dtype=float)
        return float(np.nanmean(arr))
            if overall["successes"]:
                sr = np.nanmean(overall["successes"]) * 100
                logging.info(f" running success rate: {sr:.1f}%")
    elif max_parallel_tasks <= 1:
        logging.info("Task scheduler mode: sequential")
        for task_group, task_id, env in tasks:
            try:
                tg, tid, metrics = task_runner(task_group, task_id, env)
                _accumulate_to(tg, metrics)
                per_task_infos.append({"task_group": tg, "task_id": tid, "metrics": metrics})
            finally:
                env.close()
    else:
        # Threaded fallback for cases where batched lazy mode cannot be used.
        if all_lazy and n_episodes != 1:
            logging.info("Task scheduler mode: threaded (lazy batching disabled because n_episodes != 1)")
        elif all_lazy and not single_factory_per_task:
            logging.info("Task scheduler mode: threaded (lazy batching disabled because eval.batch_size > 1)")
        else:
            logging.info(f"Task scheduler mode: threaded (max_workers={max_parallel_tasks})")
        with cf.ThreadPoolExecutor(max_workers=max_parallel_tasks) as executor:
            fut2meta: dict[cf.Future, tuple[str, int, Any]] = {}
            for task_group, task_id, env in tasks:
                fut = executor.submit(task_runner, task_group, task_id, env)
                fut2meta[fut] = (task_group, task_id, env)
            for fut in cf.as_completed(fut2meta):
                tg, tid, env = fut2meta[fut]
                try:
                    _, _, metrics = fut.result()
                    _accumulate_to(tg, metrics)
                    per_task_infos.append({"task_group": tg, "task_id": tid, "metrics": metrics})
                finally:
                    env.close()

    # compute per-group aggregates
    groups_aggregated = {}
    for group, acc in group_acc.items():
        groups_aggregated[group] = {
            "avg_sum_reward": _agg_from_list(acc["sum_rewards"]),
            "avg_max_reward": _agg_from_list(acc["max_rewards"]),
            "pc_success": _agg_from_list(acc["successes"]) * 100 if acc["successes"] else float("nan"),
            "n_episodes": len(acc["sum_rewards"]),
            "video_paths": list(acc["video_paths"]),
        }

    # overall aggregates
    overall_agg = {
        "avg_sum_reward": _agg_from_list(overall["sum_rewards"]),
        "avg_max_reward": _agg_from_list(overall["max_rewards"]),
        "pc_success": _agg_from_list(overall["successes"]) * 100 if overall["successes"] else float("nan"),
        "n_episodes": len(overall["sum_rewards"]),
        "eval_s": time.time() - start_t,
        "eval_ep_s": (time.time() - start_t) / max(1, len(overall["sum_rewards"])),
        "video_paths": list(overall["video_paths"]),
    }

    return {
        "per_task": per_task_infos,
        "per_group": groups_aggregated,
        "overall": overall_agg,
    }
    return _aggregate_eval_from_per_task(per_task_infos, total_eval_s=time.time() - start_t)


def main():

@@ -0,0 +1,218 @@
#!/usr/bin/env python

# Copyright 2025 The HuggingFace Inc. team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

"""Docker eval worker — runs inside a benchmark container.

Runs gym episodes for a sharded subset of the configured env's tasks, calling
a remote HTTP policy inference server (running on the host GPU) for action chunks.

Usage (normally invoked by docker_runtime.run_eval_in_docker, not directly):
    lerobot-eval-worker \\
        --env.type=libero_plus \\
        --server_address=host.docker.internal:50051 \\
        --n_episodes=5 \\
        --seed=1000 \\
        --instance_id=0 \\
        --instance_count=2 \\
        --output_path=/results/worker_0.json
"""

from __future__ import annotations

import json
import logging
import pickle  # nosec B403 — internal serialisation only
import time
import urllib.request
from dataclasses import dataclass, field
from pathlib import Path

import draccus
import numpy as np

from lerobot import envs  # noqa: F401 — registers all env subclasses
from lerobot.envs.configs import EnvConfig
from lerobot.envs.factory import make_env
from lerobot.envs.utils import add_envs_task, preprocess_observation
from lerobot.utils.utils import init_logging

logger = logging.getLogger(__name__)


@dataclass
class EvalWorkerConfig:
    env: EnvConfig
    # Address of the policy inference HTTP server on the host.
    server_address: str = "host.docker.internal:50051"
    # Number of episodes to run per task.
    n_episodes: int = 1
    # Starting random seed; episode i of a task uses seed + i.
    seed: int = 0
    # 0-indexed shard id for this worker.
    instance_id: int = 0
    # Total number of shards (workers).
    instance_count: int = 1
    # Path (inside the container) to write the JSON per-task results.
    output_path: Path = field(default_factory=lambda: Path("/results/worker.json"))
    # Timeout in seconds for each HTTP request to the policy server.
    server_timeout: float = 120.0


def _call_server(server_address: str, obs_t: dict, timeout: float) -> np.ndarray:
    """POST pickled obs to /predict_chunk, return numpy chunk (T, action_dim)."""
    body = pickle.dumps({"obs_t": obs_t})  # nosec B301
    req = urllib.request.Request(
        f"http://{server_address}/predict_chunk",
        data=body,
        method="POST",
        headers={"Content-Type": "application/octet-stream"},
    )
    with urllib.request.urlopen(req, timeout=timeout) as resp:  # nosec B310
        return pickle.loads(resp.read())  # nosec B301

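# Minimal sketch of a compatible host-side endpoint (assumption: only the pickle
# request/response contract is implied by _call_server; the real server lives in
# the host-side runtime and may differ):
#
#   from http.server import BaseHTTPRequestHandler, HTTPServer
#   import pickle
#
#   class Handler(BaseHTTPRequestHandler):
#       def do_POST(self):
#           if self.path != "/predict_chunk":
#               self.send_error(404)
#               return
#           payload = pickle.loads(self.rfile.read(int(self.headers["Content-Length"])))
#           chunk = predict_chunk(payload["obs_t"])  # hypothetical: np.ndarray (T, action_dim)
#           body = pickle.dumps(chunk)
#           self.send_response(200)
#           self.send_header("Content-Length", str(len(body)))
#           self.end_headers()
#           self.wfile.write(body)
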
def run_worker(cfg: EvalWorkerConfig) -> dict:
    """Run cfg.n_episodes episodes per assigned task. Returns per-task results dict."""
    # Build envs: {task_group: {task_id: vec_env}}
    envs_dict = make_env(cfg.env, n_envs=1)

    # Flatten to list of (task_group, task_id, env)
    tasks = [
        (task_group, task_id, vec)
        for task_group, group in envs_dict.items()
        for task_id, vec in group.items()
    ]

    # Shard: this worker handles tasks where index % instance_count == instance_id
    if cfg.instance_count > 1:
        total = len(tasks)
        assigned = {i for i in range(total) if i % cfg.instance_count == cfg.instance_id}
        for i, (_, _, env) in enumerate(tasks):
            if i not in assigned:
                try:
                    env.close()
                except Exception:
                    pass
        tasks = [t for i, t in enumerate(tasks) if i in assigned]
        logger.info(
            "Shard %d/%d: %d/%d tasks assigned.",
            cfg.instance_id + 1,
            cfg.instance_count,
            len(tasks),
            total,
        )

    per_task: list[dict] = []

    for task_group, task_id, env in tasks:
        sum_rewards: list[float] = []
        max_rewards: list[float] = []
        successes: list[bool] = []

        for ep_idx in range(cfg.n_episodes):
            obs, _info = env.reset(seed=[cfg.seed + ep_idx])
            obs_t = preprocess_observation(obs)
            obs_t = add_envs_task(env, obs_t)

            action_buffer: list[np.ndarray] = []  # each element: (1, action_dim)
            ep_rewards: list[float] = []
            ep_success = False
            done = np.zeros(1, dtype=bool)
            ep_steps = 0
            ep_infer_time = 0.0
            ep_env_time = 0.0
            ep_infer_calls = 0

            while not np.all(done):
                if not action_buffer:
                    t0 = time.monotonic()
                    chunk_np = _call_server(cfg.server_address, obs_t, cfg.server_timeout)
                    ep_infer_time += time.monotonic() - t0
                    ep_infer_calls += 1
                    action_buffer = [chunk_np[i : i + 1] for i in range(chunk_np.shape[0])]

                action_np = action_buffer.pop(0)
                t0 = time.monotonic()
                obs, reward, terminated, truncated, info = env.step(action_np)
                ep_env_time += time.monotonic() - t0
                ep_steps += 1

                done = terminated | truncated | done
                ep_rewards.append(float(np.mean(reward)))

                if "final_info" in info:
                    final_info = info["final_info"]
                    if isinstance(final_info, dict) and "is_success" in final_info:
                        ep_success = bool(final_info["is_success"][0])

                if not np.all(done):
                    obs_t = preprocess_observation(obs)
                    obs_t = add_envs_task(env, obs_t)

            sum_rewards.append(float(np.sum(ep_rewards)))
            max_rewards.append(float(np.max(ep_rewards)) if ep_rewards else 0.0)
            successes.append(ep_success)
            avg_env_ms = (ep_env_time / ep_steps * 1000) if ep_steps else 0
            avg_infer_ms = (ep_infer_time / ep_infer_calls * 1000) if ep_infer_calls else 0
            logger.info(
                "Task %s[%d] ep %d/%d — success=%s | %d steps, %d infer calls | "
                "env %.0fms/step, infer %.0fms/call (env %.1fs, infer %.1fs total)",
                task_group,
                task_id,
                ep_idx + 1,
                cfg.n_episodes,
                ep_success,
                ep_steps,
                ep_infer_calls,
                avg_env_ms,
                avg_infer_ms,
                ep_env_time,
                ep_infer_time,
            )

        per_task.append(
            {
                "task_group": task_group,
                "task_id": task_id,
                "metrics": {
                    "sum_rewards": sum_rewards,
                    "max_rewards": max_rewards,
                    "successes": successes,
                    "video_paths": [],
                },
            }
        )
        env.close()

    return {"per_task": per_task}


def worker_main(cfg: EvalWorkerConfig) -> None:
    results = run_worker(cfg)
    output = Path(cfg.output_path)
    output.parent.mkdir(parents=True, exist_ok=True)
    output.write_text(json.dumps(results, indent=2))
    logger.info("Worker %d wrote results to %s", cfg.instance_id, output)


def main() -> None:
    init_logging()
    cfg = draccus.parse(config_class=EvalWorkerConfig)
    worker_main(cfg)


if __name__ == "__main__":
    main()
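# Shape of the JSON written to cfg.output_path (values hypothetical):
#
#   {
#     "per_task": [
#       {"task_group": "libero_spatial", "task_id": 3,
#        "metrics": {"sum_rewards": [1.0], "max_rewards": [1.0],
#                    "successes": [true], "video_paths": []}}
#     ]
#   }
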
@@ -0,0 +1,588 @@
#!/usr/bin/env python

# Copyright 2024 The HuggingFace Inc. team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Generate an interactive eval leaderboard from Hub model repos.

Reads eval results (as pushed by ``lerobot-eval --push_to_hub``) from one or
more Hugging Face model repos and produces a self-contained HTML page with a
sortable, filterable leaderboard table.

Usage::

    # models.txt contains one HF repo ID per line (lines starting with # are ignored)
    lerobot-leaderboard models.txt --output leaderboard.html
"""

from __future__ import annotations

import argparse
import json
import logging
import sys
from dataclasses import dataclass, field
from pathlib import Path

logging.basicConfig(level=logging.INFO, format="%(levelname)s: %(message)s")
logger = logging.getLogger(__name__)


@dataclass
class ModelEntry:
    repo_id: str
    policy_type: str = "—"
    dataset: str = "—"
    training_steps: str = "—"
    batch_size: str = "—"
    # env_type -> {group_name -> pc_success}
    eval_results: dict[str, dict[str, float]] = field(default_factory=dict)
    # env_type -> overall pc_success
    eval_overall: dict[str, float] = field(default_factory=dict)
    # env_type -> n_episodes
    eval_n_episodes: dict[str, int] = field(default_factory=dict)


def _try_download(repo_id: str, filename: str) -> dict | None:
    """Download a JSON file from a Hub repo, return parsed dict or None."""
    from huggingface_hub import hf_hub_download
    from huggingface_hub.utils import EntryNotFoundError, RepositoryNotFoundError

    try:
        path = hf_hub_download(repo_id, filename, repo_type="model")
        with open(path) as f:
            return json.load(f)
    except (EntryNotFoundError, RepositoryNotFoundError, OSError):
        return None


def _list_eval_dirs(repo_id: str) -> list[str]:
    """List env_type subdirectories under eval/ in a Hub repo."""
    from huggingface_hub import HfApi
    from huggingface_hub.utils import RepositoryNotFoundError

    api = HfApi()
    try:
        files = api.list_repo_files(repo_id, repo_type="model")
    except RepositoryNotFoundError:
        return []

    env_types = set()
    for f in files:
        if f.startswith("eval/") and f.count("/") >= 2:
            env_types.add(f.split("/")[1])
    return sorted(env_types)

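# Example (hypothetical repo contents): files such as "eval/libero/eval_info.json"
# and "eval/pusht/videos/ep0.mp4" yield env_types ["libero", "pusht"].
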
def fetch_model_entry(repo_id: str) -> ModelEntry:
    """Fetch all available metadata and eval results for a single model."""
    entry = ModelEntry(repo_id=repo_id)

    # Policy config
    policy_cfg = _try_download(repo_id, "config.json")
    if policy_cfg:
        entry.policy_type = policy_cfg.get("type", "—")

    # Training config
    train_cfg = _try_download(repo_id, "train_config.json")
    if train_cfg:
        ds = train_cfg.get("dataset", {})
        entry.dataset = ds.get("repo_id", "—") if isinstance(ds, dict) else str(ds)
        entry.training_steps = str(train_cfg.get("steps", "—"))
        entry.batch_size = str(train_cfg.get("batch_size", "—"))

    # Eval results per env_type
    for env_type in _list_eval_dirs(repo_id):
        eval_info = _try_download(repo_id, f"eval/{env_type}/eval_info.json")
        if not eval_info:
            continue

        per_group = eval_info.get("per_group", {})
        group_results = {}
        for group_name, stats in per_group.items():
            group_results[group_name] = stats.get("pc_success", float("nan"))

        entry.eval_results[env_type] = group_results

        overall = eval_info.get("overall", {})
        entry.eval_overall[env_type] = overall.get("pc_success", float("nan"))
        entry.eval_n_episodes[env_type] = overall.get("n_episodes", 0)

    return entry


def fetch_all(repo_ids: list[str]) -> list[ModelEntry]:
    entries = []
    for repo_id in repo_ids:
        logger.info(f"Fetching {repo_id}...")
        try:
            entries.append(fetch_model_entry(repo_id))
        except Exception as e:
            logger.warning(f"Failed to fetch {repo_id}: {e}")
    return entries


def collect_all_env_types(entries: list[ModelEntry]) -> list[str]:
    """Collect all unique env_types across all entries, sorted."""
    env_types: set[str] = set()
    for e in entries:
        env_types.update(e.eval_overall.keys())
    return sorted(env_types)


def collect_all_groups(entries: list[ModelEntry]) -> dict[str, list[str]]:
    """Collect all unique group names per env_type."""
    groups: dict[str, set[str]] = {}
    for e in entries:
        for env_type, group_results in e.eval_results.items():
            groups.setdefault(env_type, set()).update(group_results.keys())
    return {k: sorted(v) for k, v in groups.items()}


def build_html(entries: list[ModelEntry], title: str = "LeRobot Eval Leaderboard") -> str:
    env_types = collect_all_env_types(entries)
    all_groups = collect_all_groups(entries)

    # Build column structure: fixed cols + per env_type (overall + per-group sub-columns)
    # We'll build the data as JSON and let JS handle rendering
    table_data = []
    for e in entries:
        row = {
            "repo_id": e.repo_id,
            "policy_type": e.policy_type,
            "dataset": e.dataset,
            "training_steps": e.training_steps,
            "batch_size": e.batch_size,
        }
        for env_type in env_types:
            overall = e.eval_overall.get(env_type)
            row[f"{env_type}__overall"] = round(overall, 1) if overall is not None else None
            n_ep = e.eval_n_episodes.get(env_type)
            row[f"{env_type}__n_episodes"] = n_ep if n_ep else None
            for group in all_groups.get(env_type, []):
                val = e.eval_results.get(env_type, {}).get(group)
                row[f"{env_type}__{group}"] = round(val, 1) if val is not None else None
        table_data.append(row)

    # Build column definitions for the JS table
    columns_json = json.dumps(_build_column_defs(env_types, all_groups))
    data_json = json.dumps(table_data)

    return _HTML_TEMPLATE.format(
        title=title,
        columns_json=columns_json,
        data_json=data_json,
    )


def _build_column_defs(env_types: list[str], all_groups: dict[str, list[str]]) -> list[dict]:
    cols = [
        {"key": "repo_id", "label": "Model", "group": "Model Info", "sortable": True, "type": "link"},
        {"key": "policy_type", "label": "Policy", "group": "Model Info", "sortable": True, "type": "text"},
        {"key": "dataset", "label": "Dataset", "group": "Model Info", "sortable": True, "type": "text"},
        {
            "key": "training_steps",
            "label": "Steps",
            "group": "Training",
            "sortable": True,
            "type": "number",
        },
        {
            "key": "batch_size",
            "label": "Batch",
            "group": "Training",
            "sortable": True,
            "type": "number",
        },
    ]
    for env_type in env_types:
        cols.append(
            {
                "key": f"{env_type}__overall",
                "label": "Overall %",
                "group": env_type,
                "sortable": True,
                "type": "pct",
            }
        )
        for group in all_groups.get(env_type, []):
            cols.append(
                {
                    "key": f"{env_type}__{group}",
                    "label": f"{group} %",
                    "group": env_type,
                    "sortable": True,
                    "type": "pct",
                }
            )
        cols.append(
            {
                "key": f"{env_type}__n_episodes",
                "label": "Episodes",
                "group": env_type,
                "sortable": True,
                "type": "number",
            }
        )
    return cols

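# Sketch of one flattened row produced by build_html (hypothetical values; the
# key naming follows the f-strings above):
#
#   {"repo_id": "user/my_policy", "policy_type": "act", "dataset": "user/data",
#    "training_steps": "100000", "batch_size": "8",
#    "libero__overall": 72.5, "libero__n_episodes": 50, "libero__libero_spatial": 80.0}
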
_HTML_TEMPLATE = """\
<!DOCTYPE html>
<html lang="en">
<head>
  <meta charset="utf-8" />
  <meta name="viewport" content="width=device-width, initial-scale=1" />
  <title>{title}</title>
  <style>
    :root {{
      --bg: #0d1117;
      --surface: #161b22;
      --border: #30363d;
      --text: #e6edf3;
      --text-muted: #8b949e;
      --accent: #58a6ff;
      --green: #3fb950;
      --yellow: #d29922;
      --red: #f85149;
      --header-bg: #1c2128;
    }}
    * {{ box-sizing: border-box; margin: 0; padding: 0; }}
    body {{
      font-family: -apple-system, BlinkMacSystemFont, 'Segoe UI', Helvetica, Arial, sans-serif;
      background: var(--bg);
      color: var(--text);
      padding: 24px;
      line-height: 1.5;
    }}
    h1 {{
      font-size: 1.75rem;
      font-weight: 600;
      margin-bottom: 8px;
    }}
    .subtitle {{
      color: var(--text-muted);
      font-size: 0.9rem;
      margin-bottom: 20px;
    }}
    .controls {{
      display: flex;
      gap: 12px;
      margin-bottom: 16px;
      flex-wrap: wrap;
      align-items: center;
    }}
    .controls input {{
      background: var(--surface);
      border: 1px solid var(--border);
      color: var(--text);
      padding: 8px 14px;
      border-radius: 6px;
      font-size: 0.875rem;
      width: 280px;
      outline: none;
    }}
    .controls input:focus {{ border-color: var(--accent); }}
    .controls label {{
      color: var(--text-muted);
      font-size: 0.8rem;
      display: flex;
      align-items: center;
      gap: 6px;
      cursor: pointer;
    }}
    .controls input[type="checkbox"] {{
      width: auto;
      accent-color: var(--accent);
    }}
    .table-wrap {{
      overflow-x: auto;
      border: 1px solid var(--border);
      border-radius: 8px;
    }}
    table {{
      border-collapse: collapse;
      width: 100%;
      font-size: 0.85rem;
      white-space: nowrap;
    }}
    thead th {{
      background: var(--header-bg);
      color: var(--text-muted);
      font-weight: 600;
      text-transform: uppercase;
      font-size: 0.7rem;
      letter-spacing: 0.05em;
      padding: 10px 14px;
      border-bottom: 1px solid var(--border);
      cursor: pointer;
      user-select: none;
      position: sticky;
      top: 0;
      z-index: 2;
    }}
    thead th:hover {{ color: var(--text); }}
    thead th .arrow {{ margin-left: 4px; opacity: 0.4; }}
    thead th.sorted .arrow {{ opacity: 1; color: var(--accent); }}
    thead tr.group-header th {{
      text-align: center;
      font-size: 0.75rem;
      letter-spacing: 0.08em;
      border-right: 1px solid var(--border);
      cursor: default;
    }}
    tbody tr {{ border-bottom: 1px solid var(--border); }}
    tbody tr:hover {{ background: rgba(88,166,255,0.06); }}
    td {{
      padding: 10px 14px;
      vertical-align: middle;
    }}
    td.model-cell a {{
      color: var(--accent);
      text-decoration: none;
      font-weight: 500;
    }}
    td.model-cell a:hover {{ text-decoration: underline; }}
    td.pct {{
      font-weight: 600;
      font-variant-numeric: tabular-nums;
      text-align: right;
    }}
    td.number {{
      text-align: right;
      font-variant-numeric: tabular-nums;
    }}
    .pct-high {{ color: var(--green); }}
    .pct-mid {{ color: var(--yellow); }}
    .pct-low {{ color: var(--red); }}
    .pct-na {{ color: var(--text-muted); font-weight: 400; }}
    .best-in-col {{ background: rgba(63,185,80,0.12); }}
    footer {{
      margin-top: 20px;
      color: var(--text-muted);
      font-size: 0.75rem;
    }}
    footer a {{ color: var(--accent); text-decoration: none; }}
  </style>
</head>
<body>

  <h1>🤖 {title}</h1>
  <p class="subtitle">Click any column header to sort. Filter by typing below.</p>

  <div class="controls">
    <input type="text" id="filter" placeholder="Filter models..." />
    <label><input type="checkbox" id="toggleGroups" checked /> Show sub-suite columns</label>
  </div>

  <div class="table-wrap">
    <table id="leaderboard">
      <thead></thead>
      <tbody></tbody>
    </table>
  </div>

  <footer>
    Generated by <a href="https://github.com/huggingface/lerobot">LeRobot</a> ·
    Eval results from <a href="https://huggingface.co">Hugging Face Hub</a>
  </footer>

  <script>
    const COLUMNS = {columns_json};
    const DATA = {data_json};

    let sortKey = null, sortAsc = true, showGroups = true;

    function pctClass(v) {{
      if (v == null) return 'pct-na';
      if (v >= 70) return 'pct-high';
      if (v >= 40) return 'pct-mid';
      return 'pct-low';
    }}

    function visibleCols() {{
      if (showGroups) return COLUMNS;
      return COLUMNS.filter(c => !c.key.includes('__') || c.key.endsWith('__overall') || c.key.endsWith('__n_episodes'));
    }}

    function bestPerCol(rows) {{
      const best = {{}};
      for (const c of visibleCols()) {{
        if (c.type !== 'pct') continue;
        let max = -Infinity;
        for (const r of rows) {{
          const v = r[c.key];
          if (v != null && v > max) max = v;
        }}
        best[c.key] = max === -Infinity ? null : max;
      }}
      return best;
    }}

    function render() {{
      const filter = document.getElementById('filter').value.toLowerCase();
      let rows = DATA.filter(r => r.repo_id.toLowerCase().includes(filter)
        || (r.policy_type||'').toLowerCase().includes(filter)
        || (r.dataset||'').toLowerCase().includes(filter));

      if (sortKey) {{
        rows.sort((a, b) => {{
          let va = a[sortKey], vb = b[sortKey];
          if (va == null && vb == null) return 0;
          if (va == null) return 1;
          if (vb == null) return -1;
          if (typeof va === 'string') va = va.toLowerCase();
          if (typeof vb === 'string') vb = vb.toLowerCase();
          return sortAsc ? (va < vb ? -1 : va > vb ? 1 : 0) : (va > vb ? -1 : va < vb ? 1 : 0);
        }});
      }}

      const cols = visibleCols();
      const best = bestPerCol(rows);

      // Group header row
      const groups = [];
      let lastGroup = null;
      for (const c of cols) {{
        if (c.group !== lastGroup) {{
          groups.push({{ label: c.group, span: 1 }});
          lastGroup = c.group;
        }} else {{
          groups[groups.length - 1].span++;
        }}
      }}

      const thead = document.querySelector('#leaderboard thead');
      thead.innerHTML = '';

      // Group header
      const gtr = document.createElement('tr');
      gtr.className = 'group-header';
      for (const g of groups) {{
        const th = document.createElement('th');
        th.colSpan = g.span;
        th.textContent = g.label;
        gtr.appendChild(th);
      }}
      thead.appendChild(gtr);

      // Column headers
      const htr = document.createElement('tr');
      for (const c of cols) {{
        const th = document.createElement('th');
        th.innerHTML = c.label + ' <span class="arrow">' + (sortKey === c.key ? (sortAsc ? '▲' : '▼') : '⇵') + '</span>';
        if (sortKey === c.key) th.classList.add('sorted');
        th.addEventListener('click', () => {{
          if (sortKey === c.key) {{ sortAsc = !sortAsc; }}
          else {{ sortKey = c.key; sortAsc = c.type === 'pct' ? false : true; }}
          render();
        }});
        htr.appendChild(th);
      }}
      thead.appendChild(htr);

      // Body
      const tbody = document.querySelector('#leaderboard tbody');
      tbody.innerHTML = '';
      for (const r of rows) {{
        const tr = document.createElement('tr');
        for (const c of cols) {{
          const td = document.createElement('td');
          const v = r[c.key];
          if (c.type === 'link') {{
            td.className = 'model-cell';
            const a = document.createElement('a');
            a.href = 'https://huggingface.co/' + v;
            a.target = '_blank';
            a.textContent = v;
            td.appendChild(a);
          }} else if (c.type === 'pct') {{
            td.className = 'pct ' + pctClass(v);
            td.textContent = v != null ? v.toFixed(1) : '—';
            if (v != null && best[c.key] != null && v === best[c.key] && rows.length > 1) {{
              td.classList.add('best-in-col');
            }}
          }} else if (c.type === 'number') {{
            td.className = 'number';
            td.textContent = v != null ? v : '—';
          }} else {{
            td.textContent = v || '—';
          }}
          tr.appendChild(td);
        }}
        tbody.appendChild(tr);
      }}
    }}

    document.getElementById('filter').addEventListener('input', render);
    document.getElementById('toggleGroups').addEventListener('change', (e) => {{
      showGroups = e.target.checked;
      render();
    }});

    render();
  </script>
</body>
</html>
"""


def parse_args(argv: list[str] | None = None) -> argparse.Namespace:
    p = argparse.ArgumentParser(
        description="Generate an interactive eval leaderboard from Hub model repos.",
    )
    p.add_argument(
        "repo_ids_file",
        type=str,
        help="Path to a text file with one repo ID per line.",
    )
    p.add_argument(
        "--output",
        type=str,
        default="leaderboard.html",
        help="Output HTML file path (default: leaderboard.html).",
    )
    p.add_argument(
        "--title",
        type=str,
        default="LeRobot Eval Leaderboard",
        help="Title shown in the leaderboard page.",
    )
    return p.parse_args(argv)


def main(argv: list[str] | None = None):
    args = parse_args(argv)

    path = Path(args.repo_ids_file)
    if not path.exists():
        logger.error(f"File not found: {path}")
        sys.exit(1)
    repo_ids = [
        line.strip()
        for line in path.read_text().splitlines()
        if line.strip() and not line.startswith("#")
    ]
    if not repo_ids:
        logger.error(f"No repo IDs found in {path}")
        sys.exit(1)

    entries = fetch_all(repo_ids)
    if not entries:
        logger.error("No valid entries found.")
        sys.exit(1)

    html = build_html(entries, title=args.title)
    out = Path(args.output)
    out.write_text(html)
    logger.info(f"Leaderboard written to {out.resolve()}")


if __name__ == "__main__":
    main()
@@ -0,0 +1,161 @@
#!/usr/bin/env python

# Copyright 2025 The HuggingFace Inc. team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

"""Helpers for building and uploading HF-native `.eval_results` YAML rows.

The `.eval_results` format is consumed by the HuggingFace leaderboard surface.
See https://huggingface.co/docs/evaluate/en/evaluation-on-hub for the spec.
"""

from __future__ import annotations

import math
from datetime import date
from pathlib import Path
from typing import Any

import yaml


def default_eval_date() -> str:
    """Return today's UTC date as an ISO-8601 string (YYYY-MM-DD)."""
    return date.today().isoformat()


def build_eval_results_rows(
    *,
    info: dict,
    env_type: str,
    env_task: str | None,
    benchmark_dataset_id: str,
    eval_date: str,
    source_url: str | None = None,
    notes: str | None = None,
) -> list[dict[str, Any]]:
    """Build a list of `.eval_results` rows from an eval info dict.

    Each row represents one (task_group, metric) combination. When no
    per-group breakdown is available a single overall row is emitted.

    Args:
        info: The dict returned by ``_aggregate_eval_from_per_task`` / ``eval_policy_all``.
            Expected keys: ``overall``, optionally ``per_group``.
        env_type: Environment type string (e.g. ``"libero_plus"``).
        env_task: The env task string from eval config (may be ``None``).
        benchmark_dataset_id: HF dataset repo-id of the consolidated benchmark dataset.
        eval_date: ISO-8601 date string (use ``default_eval_date()``).
        source_url: Optional URL to the evaluation run / report.
        notes: Optional free-text notes about the evaluation setup.

    Returns:
        A list of row dicts ready for serialisation with ``upload_eval_results_yaml``.
    """
    rows: list[dict[str, Any]] = []
    task_name = env_task or env_type

    def _safe(v: float) -> float:
        return 0.0 if (v is None or (isinstance(v, float) and math.isnan(v))) else float(v)

    def _make_row(config_name: str, pc_success: float, n_episodes: int) -> dict[str, Any]:
        row: dict[str, Any] = {
            "task": {
                "type": "robotics",
                "name": task_name,
            },
            "dataset": {
                "name": benchmark_dataset_id,
                "type": benchmark_dataset_id,
                "config": config_name,
                "split": "test",
            },
            "metrics": [
                {
                    "type": "success_rate",
                    "value": _safe(pc_success),
                    "name": "Success Rate (%)",
                },
                {
                    "type": "n_episodes",
                    "value": n_episodes,
                    "name": "Number of Episodes",
                },
            ],
            "evaluated_at": eval_date,
        }
        if source_url:
            row["source_url"] = source_url
        if notes:
            row["notes"] = notes
        return row

    per_group: dict = info.get("per_group", {})
    if per_group:
        for group_name, group_metrics in per_group.items():
            rows.append(
                _make_row(
                    config_name=group_name,
                    pc_success=group_metrics.get("pc_success", float("nan")),
                    n_episodes=group_metrics.get("n_episodes", 0),
                )
            )
    else:
        overall = info.get("overall", {})
        rows.append(
            _make_row(
                config_name=env_type,
                pc_success=overall.get("pc_success", float("nan")),
                n_episodes=overall.get("n_episodes", 0),
            )
        )

    return rows

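# Example of the serialised YAML for one row (hypothetical values; structure
# follows _make_row above and the yaml.dump call in upload_eval_results_yaml):
#
#   eval_results:
#   - task:
#       type: robotics
#       name: libero_spatial
#     dataset:
#       name: user/libero-benchmark
#       type: user/libero-benchmark
#       config: libero_spatial
#       split: test
#     metrics:
#     - type: success_rate
#       value: 80.0
#       name: Success Rate (%)
#     - type: n_episodes
#       value: 50
#       name: Number of Episodes
#     evaluated_at: '2025-01-01'
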
def upload_eval_results_yaml(
    *,
    api: Any,
    repo_id: str,
    rows: list[dict[str, Any]],
    env_type: str,
    env_task: str | None,
    output_dir: Path,
) -> str:
    """Serialise ``rows`` to YAML and upload to the Hub model repo.

    The file is written locally to ``output_dir/eval_results.yaml`` and
    then uploaded to ``eval/{env_type}/eval_results.yaml`` in ``repo_id``.

    Args:
        api: An instantiated ``huggingface_hub.HfApi`` object.
        repo_id: HF model repo (e.g. ``"user/my_policy"``).
        rows: Rows produced by ``build_eval_results_rows``.
        env_type: Environment type string (used for the Hub path prefix).
        env_task: The env task string (unused, kept for API symmetry).
        output_dir: Local directory to write the YAML before uploading.

    Returns:
        URL of the Hub commit containing the uploaded file.
    """
    yaml_path = Path(output_dir) / "eval_results.yaml"
    yaml_path.write_text(yaml.dump({"eval_results": rows}, sort_keys=False, allow_unicode=True))

    commit_url = api.upload_file(
        path_or_fileobj=str(yaml_path),
        path_in_repo=f"eval/{env_type}/eval_results.yaml",
        repo_id=repo_id,
        commit_message=f"Upload eval results YAML for {env_type}",
    )
    return str(commit_url)
@@ -0,0 +1,62 @@
#!/usr/bin/env python

# Copyright 2025 The HuggingFace Inc. team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

from lerobot.envs.lazy_vec_env import LazyVectorEnv


class _DummyVectorEnv:
    def __init__(self):
        self.marker = "ok"
        self.closed = False

    def close(self):
        self.closed = True


def test_lazy_vec_env_materializes_only_on_access():
    created = []

    def _make_env(fns):
        created.append(len(fns))
        return _DummyVectorEnv()

    lazy = LazyVectorEnv(_make_env, [lambda: None, lambda: None])
    assert created == []
    assert lazy.num_factory_fns == 2

    assert lazy.marker == "ok"
    assert created == [2]

    # Second access should re-use the same materialized env.
    assert lazy.marker == "ok"
    assert created == [2]


def test_lazy_vec_env_can_rematerialize_after_close():
    created = []

    def _make_env(fns):
        created.append(len(fns))
        return _DummyVectorEnv()

    lazy = LazyVectorEnv(_make_env, [lambda: None])
    lazy.materialize()
    assert created == [1]

    lazy.close()
    lazy.materialize()
    assert created == [1, 1]

@@ -0,0 +1,244 @@
#!/usr/bin/env python

# Copyright 2025 The HuggingFace Inc. team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

from lerobot.envs.lazy_vec_env import LazyVectorEnv
from lerobot.scripts import lerobot_eval


class _DummyTaskEnv:
    def __init__(self):
        self.close_calls = 0

    def close(self):
        self.close_calls += 1


class _TrackedLazyEnv(LazyVectorEnv):
    def __init__(self, n_factory_fns: int = 1):
        super().__init__(lambda fns: None, [lambda: None for _ in range(n_factory_fns)])
        self.close_calls = 0

    def close(self):
        self.close_calls += 1
        super().close()


def _fake_metrics():
    return {
        "sum_rewards": [1.0],
        "max_rewards": [1.0],
        "successes": [True],
        "video_paths": [],
    }


def test_eval_policy_all_sequential_closes_envs(monkeypatch):
    def _fake_run_one(task_group, task_id, env, **kwargs):  # noqa: ARG001
        return task_group, task_id, _fake_metrics()

    monkeypatch.setattr(lerobot_eval, "run_one", _fake_run_one)
    env_a = _DummyTaskEnv()
    env_b = _DummyTaskEnv()
    envs = {"suite": {0: env_a, 1: env_b}}

    result = lerobot_eval.eval_policy_all(
        envs=envs,
        policy=None,
        env_preprocessor=None,
        env_postprocessor=None,
        preprocessor=None,
        postprocessor=None,
        n_episodes=1,
        max_parallel_tasks=1,
    )

    assert env_a.close_calls == 1
    assert env_b.close_calls == 1
    assert result["overall"]["n_episodes"] == 2


def test_eval_policy_all_threaded_fallback_closes_envs(monkeypatch):
    def _fake_run_one(task_group, task_id, env, **kwargs):  # noqa: ARG001
        return task_group, task_id, _fake_metrics()

    monkeypatch.setattr(lerobot_eval, "run_one", _fake_run_one)
    env_a = _DummyTaskEnv()
    env_b = _DummyTaskEnv()
    env_c = _DummyTaskEnv()
    envs = {"suite": {0: env_a, 1: env_b, 2: env_c}}

    result = lerobot_eval.eval_policy_all(
        envs=envs,
        policy=None,
        env_preprocessor=None,
        env_postprocessor=None,
        preprocessor=None,
        postprocessor=None,
        n_episodes=1,
        max_parallel_tasks=2,
    )

    assert env_a.close_calls == 1
    assert env_b.close_calls == 1
    assert env_c.close_calls == 1
    assert result["overall"]["n_episodes"] == 3


def test_eval_policy_all_uses_batched_lazy_mode(monkeypatch):
    def _run_one_should_not_be_called(*args, **kwargs):
        raise AssertionError("run_one should not run in batched lazy mode")

    chunk_sizes = []

    def _fake_eval_task_batch(chunk, **kwargs):  # noqa: ARG001
        chunk_sizes.append(len(chunk))
        return [(tg, tid, _fake_metrics()) for tg, tid, _ in chunk]

    monkeypatch.setattr(lerobot_eval, "run_one", _run_one_should_not_be_called)
    monkeypatch.setattr(lerobot_eval, "_eval_task_batch", _fake_eval_task_batch)

    envs = {
        "suite": {
            0: LazyVectorEnv(lambda fns: None, [lambda: None]),
            1: LazyVectorEnv(lambda fns: None, [lambda: None]),
            2: LazyVectorEnv(lambda fns: None, [lambda: None]),
        }
    }

    result = lerobot_eval.eval_policy_all(
        envs=envs,
        policy=None,
        env_preprocessor=None,
        env_postprocessor=None,
        preprocessor=None,
        postprocessor=None,
        n_episodes=1,
        max_parallel_tasks=2,
    )

    assert chunk_sizes == [2, 1]
    assert result["overall"]["n_episodes"] == 3


def test_eval_policy_all_disables_batched_lazy_when_n_episodes_not_one(monkeypatch):
    def _fake_run_one(task_group, task_id, env, **kwargs):  # noqa: ARG001
        return task_group, task_id, _fake_metrics()

    def _batch_should_not_run(*args, **kwargs):
        raise AssertionError("_eval_task_batch should not run when n_episodes != 1")

    monkeypatch.setattr(lerobot_eval, "run_one", _fake_run_one)
    monkeypatch.setattr(lerobot_eval, "_eval_task_batch", _batch_should_not_run)

    env_a = _TrackedLazyEnv()
    env_b = _TrackedLazyEnv()
    envs = {"suite": {0: env_a, 1: env_b}}

    result = lerobot_eval.eval_policy_all(
        envs=envs,
        policy=None,
        env_preprocessor=None,
        env_postprocessor=None,
        preprocessor=None,
        postprocessor=None,
        n_episodes=2,
        max_parallel_tasks=2,
    )

    assert env_a.close_calls == 1
    assert env_b.close_calls == 1
    assert result["overall"]["n_episodes"] == 2


def test_eval_policy_all_disables_batched_lazy_when_batch_size_above_one(monkeypatch):
    def _fake_run_one(task_group, task_id, env, **kwargs):  # noqa: ARG001
        return task_group, task_id, _fake_metrics()

    def _batch_should_not_run(*args, **kwargs):
        raise AssertionError("_eval_task_batch should not run when eval.batch_size > 1")

    monkeypatch.setattr(lerobot_eval, "run_one", _fake_run_one)
    monkeypatch.setattr(lerobot_eval, "_eval_task_batch", _batch_should_not_run)

    env_a = _TrackedLazyEnv(n_factory_fns=2)
    env_b = _TrackedLazyEnv(n_factory_fns=2)
    envs = {"suite": {0: env_a, 1: env_b}}

    result = lerobot_eval.eval_policy_all(
        envs=envs,
        policy=None,
        env_preprocessor=None,
        env_postprocessor=None,
        preprocessor=None,
        postprocessor=None,
        n_episodes=1,
        max_parallel_tasks=2,
    )

    assert env_a.close_calls == 1
    assert env_b.close_calls == 1
    assert result["overall"]["n_episodes"] == 2


def test_eval_policy_all_applies_instance_sharding(monkeypatch):
    called = []

    def _fake_run_one(task_group, task_id, env, **kwargs):  # noqa: ARG001
        called.append(task_id)
        return task_group, task_id, _fake_metrics()

    monkeypatch.setattr(lerobot_eval, "run_one", _fake_run_one)
    envs = {"suite": {0: _DummyTaskEnv(), 1: _DummyTaskEnv(), 2: _DummyTaskEnv(), 3: _DummyTaskEnv()}}

    result = lerobot_eval.eval_policy_all(
        envs=envs,
        policy=None,
        env_preprocessor=None,
        env_postprocessor=None,
        preprocessor=None,
        postprocessor=None,
        n_episodes=1,
        max_parallel_tasks=1,
        instance_count=2,
        instance_id=1,
    )

    assert called == [1, 3]
    assert result["overall"]["n_episodes"] == 2


def test_aggregate_eval_from_per_task_merges_groups_and_overall():
    per_task = [
        {
            "task_group": "a",
            "task_id": 0,
            "metrics": {"sum_rewards": [1.0], "max_rewards": [2.0], "successes": [True], "video_paths": ["v0"]},
        },
        {
            "task_group": "b",
            "task_id": 1,
            "metrics": {"sum_rewards": [3.0], "max_rewards": [4.0], "successes": [False], "video_paths": []},
        },
    ]

    merged = lerobot_eval._aggregate_eval_from_per_task(per_task, total_eval_s=10.0)

    assert merged["overall"]["n_episodes"] == 2
    assert merged["overall"]["avg_sum_reward"] == 2.0
    assert merged["overall"]["pc_success"] == 50.0
    assert merged["overall"]["eval_s"] == 10.0
    assert set(merged["per_group"]) == {"a", "b"}
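
# Hedged sketches of two behaviors pinned down above (illustrative only; the
# real lerobot_eval internals may differ):
#
# 1. Instance sharding: with instance_count=2 and instance_id=1 the test sees
#    task_ids [1, 3], i.e. a round-robin split by task index.
#
#     def _shard(task_ids, instance_count, instance_id):
#         return [t for i, t in enumerate(task_ids) if i % instance_count == instance_id]
#
#     assert _shard([0, 1, 2, 3], 2, 1) == [1, 3]  # matches the test above
#
# 2. Aggregation: the overall numbers follow from plain averaging over the two
#    episodes: avg_sum_reward = (1.0 + 3.0) / 2 = 2.0, and
#    pc_success = 1 / 2 * 100 = 50.0.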
@@ -0,0 +1,34 @@
#!/usr/bin/env python

# Copyright 2025 The HuggingFace Inc. team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

from pathlib import Path

from lerobot.scripts import lerobot_eval_worker


def test_worker_main_writes_results(monkeypatch, tmp_path: Path):
    cfg = lerobot_eval_worker.EvalWorkerConfig(
        env="pusht",  # type: ignore[arg-type]
        instance_id=3,
        output_path=tmp_path / "worker.json",
    )

    monkeypatch.setattr(lerobot_eval_worker, "run_worker", lambda _cfg: {"per_task": []})

    lerobot_eval_worker.worker_main(cfg)

    assert cfg.output_path.exists()
    assert cfg.output_path.read_text().strip() == '{\n "per_task": []\n}'
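
# Hedged sketch of the worker_main contract this test relies on (illustrative;
# the real lerobot_eval_worker may do more, e.g. logging). The asserted file
# contents use one-space indentation, which json.dumps(..., indent=1) would
# reproduce — the indent width here is an assumption read off the assert.
#
#     import json
#
#     def _sketch_worker_main(cfg):
#         results = run_worker(cfg)  # monkeypatched to {"per_task": []} in the test
#         cfg.output_path.write_text(json.dumps(results, indent=1))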
@@ -0,0 +1,176 @@
#!/usr/bin/env python

# Copyright 2025 The HuggingFace Inc. team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

"""Tests for RoboCasa LeRobot integration.

Requires: robocasa installed + kitchen assets downloaded.
Tests are skipped automatically if robocasa is not available.
"""

from __future__ import annotations

import numpy as np
import pytest

# Skip entire module if robocasa is not installed or assets are missing
robocasa = pytest.importorskip("robocasa", reason="robocasa not installed")

from lerobot.envs.robocasa import ACTION_DIM, STATE_DIM, CAM_KEY_TO_NAME, RoboCasaEnv, create_robocasa_envs

# The 5 benchmark tasks (3 short + 2 long)
BENCHMARK_TASKS = [
    "PickPlaceCounterToCabinet",  # short
    "PrepareToast",  # short
    "CoffeeSetupMug",  # short
    "PrepareCoffee",  # long
    "RestockPantry",  # long
]
SHORT_TASKS = BENCHMARK_TASKS[:3]
LONG_TASKS = BENCHMARK_TASKS[3:]

IMAGE_SIZE = 64  # small for fast tests


@pytest.fixture(scope="module")
def single_env():
    """Shared env instance for lightweight tests."""
    env = RoboCasaEnv(task="PickPlaceCounterToCabinet", image_size=IMAGE_SIZE)
    yield env
    env.close()


class TestRoboCasaEnvSpaces:
    def test_action_space_is_flat_box(self, single_env):
        import gymnasium as gym

        assert isinstance(single_env.action_space, gym.spaces.Box)
        assert single_env.action_space.shape == (ACTION_DIM,)
        assert single_env.action_space.dtype == np.float32

    def test_action_bounds(self, single_env):
        assert np.all(single_env.action_space.low == -1.0)
        assert np.all(single_env.action_space.high == 1.0)

    def test_observation_space_has_pixels_and_state(self, single_env):
        import gymnasium as gym

        assert isinstance(single_env.observation_space, gym.spaces.Dict)
        assert "pixels" in single_env.observation_space.spaces
        assert "robot_state" in single_env.observation_space.spaces

    def test_observation_space_cameras(self, single_env):
        pixels_space = single_env.observation_space["pixels"]
        expected_cams = set(CAM_KEY_TO_NAME.values())
        assert set(pixels_space.spaces.keys()) == expected_cams

    def test_state_dim(self, single_env):
        state_space = single_env.observation_space["robot_state"]
        assert state_space.shape == (STATE_DIM,)


class TestRoboCasaEnvReset:
    def test_reset_returns_obs_and_info(self, single_env):
        obs, info = single_env.reset()
        assert isinstance(obs, dict)
        assert isinstance(info, dict)

    def test_reset_obs_has_pixels(self, single_env):
        obs, _ = single_env.reset()
        assert "pixels" in obs
        for cam_name in CAM_KEY_TO_NAME.values():
            assert cam_name in obs["pixels"], f"Missing camera: {cam_name}"

    def test_reset_obs_image_shape(self, single_env):
        obs, _ = single_env.reset()
        for cam_name, img in obs["pixels"].items():
            assert img.shape == (IMAGE_SIZE, IMAGE_SIZE, 3), f"Bad shape for {cam_name}: {img.shape}"
            assert img.dtype == np.uint8

    def test_reset_obs_state_shape(self, single_env):
        obs, _ = single_env.reset()
        assert obs["robot_state"].shape == (STATE_DIM,)
        assert obs["robot_state"].dtype == np.float32

    def test_reset_info_has_task(self, single_env):
        _, info = single_env.reset()
        assert "task" in info
        assert info["task"] == "PickPlaceCounterToCabinet"


class TestRoboCasaEnvStep:
    def test_step_10_random_actions(self, single_env):
        single_env.reset()
        for _ in range(10):
            action = single_env.action_space.sample()
            obs, reward, terminated, truncated, info = single_env.step(action)
            assert obs["robot_state"].shape == (STATE_DIM,)
            assert isinstance(reward, float)
            assert isinstance(terminated, bool)
            assert isinstance(truncated, bool)

    def test_step_bad_action_raises(self, single_env):
        single_env.reset()
        with pytest.raises(ValueError, match="Expected 1-D action"):
            single_env.step(np.zeros((2, ACTION_DIM)))

    def test_step_info_has_is_success(self, single_env):
        single_env.reset()
        _, _, _, _, info = single_env.step(single_env.action_space.sample())
        assert "is_success" in info


class TestRoboCasaConfig:
    def test_robocasa_env_config(self):
        from lerobot.envs.configs import RoboCasaEnv as RoboCasaEnvConfig
        from lerobot.configs.types import FeatureType

        cfg = RoboCasaEnvConfig(task="PickPlaceCounterToCabinet", image_size=IMAGE_SIZE)
        assert cfg.type == "robocasa"
        # action feature
        assert "action" in cfg.features
        assert cfg.features["action"].shape == (ACTION_DIM,)
        # camera features
        for cam in ("agentview_left", "agentview_right", "eye_in_hand"):
            assert cam in cfg.features
            assert cfg.features[cam].type == FeatureType.VISUAL
            assert cfg.features[cam].shape == (IMAGE_SIZE, IMAGE_SIZE, 3)
        # state feature
        assert "robot_state" in cfg.features
        assert cfg.features["robot_state"].shape == (STATE_DIM,)

    def test_make_env_config_robocasa(self):
        from lerobot.envs.factory import make_env_config

        cfg = make_env_config("robocasa", task="PickPlaceCounterToCabinet")
        assert cfg.type == "robocasa"


class TestRoboCasaProcessorStep:
    def test_processor_remaps_keys(self):
        import torch
        from lerobot.processor.env_processor import RoboCasaProcessorStep
        from lerobot.utils.constants import OBS_IMAGES, OBS_STATE

        step = RoboCasaProcessorStep()
        B = 2
        obs = {
            f"{OBS_IMAGES}.agentview_left": torch.zeros(B, 3, IMAGE_SIZE, IMAGE_SIZE),
            f"{OBS_IMAGES}.agentview_right": torch.zeros(B, 3, IMAGE_SIZE, IMAGE_SIZE),
            f"{OBS_IMAGES}.eye_in_hand": torch.zeros(B, 3, IMAGE_SIZE, IMAGE_SIZE),
            "observation.robot_state": torch.zeros(B, STATE_DIM),
        }
        out = step._process_observation(obs)
        assert OBS_STATE in out
        assert out[OBS_STATE].dtype == torch.float32
        for cam in ("agentview_left", "agentview_right", "eye_in_hand"):
            assert f"{OBS_IMAGES}.{cam}" in out
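
# Hedged sketch of the remapping this last test exercises (illustrative; the
# real RoboCasaProcessorStep may also normalize shapes or do more). The step
# appears to rename "observation.robot_state" to the canonical OBS_STATE key,
# cast it to float32, and pass camera tensors through under their OBS_IMAGES
# keys. The default key below is an assumption, not read from the source.
#
#     def _sketch_process_observation(obs, obs_state_key="observation.state"):
#         out = dict(obs)  # camera entries pass through unchanged
#         state = out.pop("observation.robot_state")
#         out[obs_state_key] = state.float()
#         return out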