Compare commits

..

5 Commits

Author SHA1 Message Date
Steven Palma 1ec9392bcb chore(style): pre-commit envs 2026-02-24 15:03:36 +01:00
Steven Palma 84b34ae75c Merge branch 'main' into envs/support-more-args
Signed-off-by: Steven Palma <imstevenpmwork@ieee.org>
2026-02-24 15:01:17 +01:00
Jade Choghari ff267c772b allow lerobot-eval to work with kwargs 2025-12-26 17:39:03 +00:00
Jade Choghari 652b1b854d Merge branch 'main' into envs/support-more-args 2025-12-23 16:07:22 +03:00
Jade Choghari 8831b3c47b add changes 2025-12-08 11:11:38 +01:00
25 changed files with 217 additions and 884 deletions
-2
View File
@@ -173,8 +173,6 @@ jobs:
shell: bash
working-directory: /lerobot
steps:
- name: Fix ptxas permissions
run: chmod +x /lerobot/.venv/lib/python3.10/site-packages/triton/backends/nvidia/bin/ptxas
- name: Run pytest on GPU
run: pytest tests -vv --maxfail=10
- name: Run end-to-end tests
-1
View File
@@ -1,3 +1,2 @@
include src/lerobot/templates/lerobot_modelcard_template.md
include src/lerobot/datasets/card_template.md
include src/lerobot/envs/metaworld_config.json
-2
View File
@@ -85,8 +85,6 @@ RUN if [ "$UNBOUND_DEPS" = "true" ]; then \
RUN uv pip install --no-cache ".[all]"
RUN chmod +x /lerobot/.venv/lib/python${PYTHON_VERSION}/site-packages/triton/backends/nvidia/bin/ptxas
# Copy the rest of the application source code
# Make sure to have the git-LFS files for testing
COPY --chown=user_lerobot:user_lerobot . .
+77 -3
View File
@@ -55,7 +55,8 @@ To make your environment loadable from the Hub, your repository must contain at
**`env.py`** (or custom Python file)
- Must expose a `make_env(n_envs: int, use_async_envs: bool)` function
- Must expose a `make_env(n_envs: int, use_async_envs: bool, **kwargs)` function
- The function should accept `**kwargs` to allow users to pass custom configurations
- This function should return one of:
- A `gym.vector.VectorEnv` (most common)
- A single `gym.Env` (will be automatically wrapped)
@@ -99,6 +100,8 @@ Create an `env.py` file with a `make_env` function:
```python
# env.py
import gymnasium as gym
from pathlib import Path
from typing import Any
def make_env(n_envs: int = 1, use_async_envs: bool = False):
"""
@@ -250,6 +253,76 @@ envs_dict = make_env(
)
```
### Custom Configuration via kwargs
Hub environments can accept custom configurations through keyword arguments. This is useful for parameterizing tasks, loading different objects, or overriding default settings:
```python
from pathlib import Path
# Pass a config file path
envs_dict = make_env(
"nvkartik/isaaclab-arena-envs:envs/microwave_g1.py",
n_envs=4,
trust_remote_code=True,
config_path=Path("/path/to/my_config.yaml"),
)
# Pass config overrides as a dictionary
envs_dict = make_env(
"nvkartik/isaaclab-arena-envs:envs/microwave_g1.py",
n_envs=4,
trust_remote_code=True,
config_overrides={
"scene.object": "microwave",
"sim.dt": 0.01,
},
)
# Combine config path with overrides
envs_dict = make_env(
"username/my-env",
n_envs=4,
trust_remote_code=True,
config_path="configs/gr1_pick_place.yaml",
config_overrides={"scene.table_objects": ["apple", "banana", "cup"]},
)
```
Any keyword arguments you pass will be forwarded to the hub environment's `make_env` function. Check the environment's documentation for supported configuration options.
### Using Custom kwargs with lerobot-eval
When evaluating policies using the `lerobot-eval` CLI, you can pass custom kwargs to hub environments using the `--env_kwargs.` prefix:
```bash
lerobot-eval \
--policy.path=user123/example-policy-checkpoint \
--env=user123/example-sim-backend \
--eval.batch_size=1 \
--eval.n_episodes=10 \
--env_kwargs.task_id=demo_task_alpha \
--env_kwargs.agent_profile=arm_v2 \
--env_kwargs.target_item=object_red \
--env_kwargs.run_mode=offscreen \
--env_kwargs.enable_sensors=true \
--env_kwargs.record_output=true \
--env_kwargs.output_horizon=10 \
--env_kwargs.output_stride=15 \
--env_kwargs.state_features=joint_angles \
--env_kwargs.visual_streams=front_camera
```
All `--env_kwargs.*` arguments will be collected into a dictionary and passed as keyword arguments to the hub environment's `make_env` function. This allows you to:
- Pass configuration file paths
- Override default settings
- Specify custom task parameters
- Control simulation options (headless mode, camera settings, etc.)
- Select different embodiments or objects
The hub environment's `make_env` function receives these as regular keyword arguments, so check the environment's documentation for the available options.
## URL Format Reference
The hub URL format supports several patterns:
@@ -266,7 +339,7 @@ The hub URL format supports several patterns:
For benchmarks with multiple tasks (like LIBERO), return a nested dictionary:
```python
def make_env(n_envs: int = 1, use_async_envs: bool = False):
def make_env(n_envs: int = 1, use_async_envs: bool = False, **kwargs):
env_cls = gym.vector.AsyncVectorEnv if use_async_envs else gym.vector.SyncVectorEnv
# Return dict: {suite_name: {task_id: VectorEnv}}
@@ -388,8 +461,9 @@ pip install gymnasium numpy
Your `env.py` must expose a `make_env` function:
```python
def make_env(n_envs: int, use_async_envs: bool):
def make_env(n_envs: int, use_async_envs: bool, **kwargs):
# Your implementation
# kwargs can include config_path, config_overrides, etc.
pass
```
-3
View File
@@ -214,9 +214,6 @@ lerobot-edit-dataset="lerobot.scripts.lerobot_edit_dataset:main"
lerobot-setup-can="lerobot.scripts.lerobot_setup_can:main"
# ---------------- Tool Configurations ----------------
[tool.setuptools.package-data]
lerobot = ["envs/*.json"]
[tool.setuptools.packages.find]
where = ["src"]
+2
View File
@@ -38,6 +38,8 @@ class EvalPipelineConfig:
seed: int | None = 1000
# Rename map for the observation to override the image and state keys
rename_map: dict[str, str] = field(default_factory=dict)
# Additional kwargs to pass to hub environments (e.g., config_path, config_overrides, custom params)
env_kwargs: dict = field(default_factory=dict)
# Explicit consent to execute remote code from the Hub (required for hub environments).
trust_remote_code: bool = False
-7
View File
@@ -7,13 +7,6 @@
This dataset was created using [LeRobot](https://github.com/huggingface/lerobot).
{% if repo_id is defined and repo_id %}
<a class="flex" href="https://huggingface.co/spaces/lerobot/visualize_dataset?path={{ repo_id }}">
<img class="block dark:hidden" src="https://huggingface.co/datasets/huggingface/badges/resolve/main/visualize-this-dataset-xl.svg"/>
<img class="hidden dark:block" src="https://huggingface.co/datasets/huggingface/badges/resolve/main/visualize-this-dataset-xl-dark.svg"/>
</a>
{% endif %}
## Dataset Description
{{ dataset_description | default("", true) }}
+18 -317
View File
@@ -47,7 +47,6 @@ from lerobot.datasets.utils import (
DEFAULT_EPISODES_PATH,
get_parquet_file_size_in_mb,
load_episodes,
load_info,
update_chunk_file_indices,
write_info,
write_stats,
@@ -568,22 +567,20 @@ def _copy_and_reindex_data(
def _keep_episodes_from_video_with_av(
input_path: Path,
output_path: Path,
episodes_to_keep: list[tuple[int, int]],
episodes_to_keep: list[tuple[float, float]],
fps: float,
vcodec: str = "libsvtav1",
pix_fmt: str = "yuv420p",
) -> None:
"""Keep only specified episodes from a video file using PyAV.
This function decodes frames from specified frame ranges and re-encodes them with
This function decodes frames from specified time ranges and re-encodes them with
properly reset timestamps to ensure monotonic progression.
Args:
input_path: Source video file path.
output_path: Destination video file path.
episodes_to_keep: List of (start_frame, end_frame) tuples for episodes to keep.
Ranges are half-open intervals: [start_frame, end_frame), where start_frame
is inclusive and end_frame is exclusive.
episodes_to_keep: List of (start_time, end_time) tuples for episodes to keep.
fps: Frame rate of the video.
vcodec: Video codec to use for encoding.
pix_fmt: Pixel format for output video.
@@ -625,10 +622,9 @@ def _keep_episodes_from_video_with_av(
# Create set of (start, end) ranges for fast lookup.
# Convert to a sorted list for efficient checking.
frame_ranges = sorted(episodes_to_keep)
time_ranges = sorted(episodes_to_keep)
# Track frame index for setting PTS and current range being processed.
src_frame_count = 0
frame_count = 0
range_idx = 0
@@ -638,20 +634,21 @@ def _keep_episodes_from_video_with_av(
if frame is None:
continue
# Check if frame is in any of our desired frame ranges.
# Get frame timestamp.
frame_time = float(frame.pts * frame.time_base) if frame.pts is not None else 0.0
# Check if frame is in any of our desired time ranges.
# Skip ranges that have already passed.
while range_idx < len(frame_ranges) and src_frame_count >= frame_ranges[range_idx][1]:
while range_idx < len(time_ranges) and frame_time >= time_ranges[range_idx][1]:
range_idx += 1
# If we've passed all ranges, stop processing.
if range_idx >= len(frame_ranges):
if range_idx >= len(time_ranges):
break
# Check if frame is in current range.
start_frame = frame_ranges[range_idx][0]
if src_frame_count < start_frame:
src_frame_count += 1
start_ts, end_ts = time_ranges[range_idx]
if frame_time < start_ts:
continue
# Frame is in range - create a new frame with reset timestamps.
@@ -664,7 +661,6 @@ def _keep_episodes_from_video_with_av(
for pkt in v_out.encode(new_frame):
out.mux(pkt)
src_frame_count += 1
frame_count += 1
# Flush encoder.
@@ -753,17 +749,15 @@ def _copy_and_reindex_videos(
f"videos/{video_key}/to_timestamp"
]
else:
# Build list of frame ranges to keep, in sorted order.
# Build list of time ranges to keep, in sorted order.
sorted_keep_episodes = sorted(episodes_in_file, key=lambda x: episode_mapping[x])
episodes_to_keep_ranges: list[tuple[int, int]] = []
episodes_to_keep_ranges: list[tuple[float, float]] = []
for old_idx in sorted_keep_episodes:
src_ep = src_dataset.meta.episodes[old_idx]
from_frame = round(src_ep[f"videos/{video_key}/from_timestamp"] * src_dataset.meta.fps)
to_frame = round(src_ep[f"videos/{video_key}/to_timestamp"] * src_dataset.meta.fps)
assert src_ep["length"] == to_frame - from_frame, (
f"Episode length mismatch: {src_ep['length']} vs {to_frame - from_frame}"
)
episodes_to_keep_ranges.append((from_frame, to_frame))
from_ts = src_ep[f"videos/{video_key}/from_timestamp"]
to_ts = src_ep[f"videos/{video_key}/to_timestamp"]
episodes_to_keep_ranges.append((from_ts, to_ts))
# Use PyAV filters to efficiently re-encode only the desired segments.
assert src_dataset.meta.video_path is not None
@@ -1775,296 +1769,3 @@ def convert_image_to_video_dataset(
# Return new dataset
return LeRobotDataset(repo_id=repo_id, root=output_dir)
def trim_episodes_by_frames(
dataset: LeRobotDataset,
episode_frames_to_keep: dict[int, list[int]],
output_dir: str | Path | None = None,
repo_id: str | None = None,
) -> LeRobotDataset:
"""Trim multiple episodes to keep only specific frames.
This function creates a new dataset where the specified episodes contain only
the frames at the given indices. All other episodes are copied as-is.
Args:
dataset: The source LeRobotDataset.
episode_frames_to_keep: Dict mapping episode indices to lists of global frame indices to keep.
output_dir: Directory to save the new dataset. If None, uses default location.
repo_id: Repository ID for the new dataset. If None, appends "_trimmed" to original.
Returns:
A new LeRobotDataset with the trimmed episodes.
"""
if not episode_frames_to_keep:
raise ValueError("No episodes to trim")
for ep_idx in episode_frames_to_keep:
if ep_idx >= dataset.meta.total_episodes:
raise ValueError(f"Episode {ep_idx} does not exist")
if not episode_frames_to_keep[ep_idx]:
raise ValueError(f"No frames to keep for episode {ep_idx}")
if repo_id is None:
repo_id = f"{dataset.repo_id}_trimmed"
output_dir = Path(output_dir) if output_dir is not None else HF_LEROBOT_HOME / repo_id
total_trimmed = sum(len(frames) for frames in episode_frames_to_keep.values())
logging.info(f"Trimming {len(episode_frames_to_keep)} episodes, keeping {total_trimmed} frames total")
# Create new metadata
new_meta = LeRobotDatasetMetadata.create(
repo_id=repo_id,
fps=dataset.meta.fps,
features=dataset.meta.features,
robot_type=dataset.meta.robot_type,
root=output_dir,
use_videos=len(dataset.meta.video_keys) > 0,
)
# Build set of all frames to keep (for episodes being trimmed)
# and compute new frame counts per episode
all_keep_frames: set[int] = set()
trimmed_frame_counts: dict[int, int] = {}
for ep_idx, frames in episode_frames_to_keep.items():
all_keep_frames.update(frames)
trimmed_frame_counts[ep_idx] = len(frames)
# Copy and filter data
_copy_and_reindex_data_with_multi_frame_filter(
dataset, new_meta, episode_frames_to_keep, all_keep_frames
)
# Handle videos if present
if dataset.meta.video_keys:
_copy_and_reindex_videos_with_multi_frame_filter(
dataset, new_meta, episode_frames_to_keep
)
# Copy episode metadata
_copy_and_reindex_episodes_metadata_for_multi_trim(
dataset, new_meta, trimmed_frame_counts
)
logging.info(f"Created trimmed dataset with {new_meta.total_frames} frames at {output_dir}")
# Return the metadata instead of trying to load as LeRobotDataset
# This avoids Hub validation issues when the repo doesn't exist yet
return new_meta
# Keep old function for backward compatibility
def trim_episode_by_frames(
dataset: LeRobotDataset,
episode_index: int,
keep_frame_indices: list[int],
output_dir: str | Path | None = None,
repo_id: str | None = None,
) -> LeRobotDataset:
"""Trim a single episode. Wrapper around trim_episodes_by_frames."""
return trim_episodes_by_frames(
dataset,
episode_frames_to_keep={episode_index: keep_frame_indices},
output_dir=output_dir,
repo_id=repo_id,
)
def _copy_and_reindex_data_with_multi_frame_filter(
src_dataset: LeRobotDataset,
dst_meta: LeRobotDatasetMetadata,
episode_frames_to_keep: dict[int, list[int]],
all_keep_frames: set[int],
) -> None:
"""Copy data files with frame-level filtering for multiple episodes."""
if src_dataset.meta.episodes is None:
src_dataset.meta.episodes = load_episodes(src_dataset.meta.root)
# Copy tasks
if dst_meta.tasks is None and src_dataset.meta.tasks is not None:
# Tasks are stored with task string as index
dst_meta.save_episode_tasks(list(src_dataset.meta.tasks.index))
# Get all parquet files
data_dir = src_dataset.root / "data"
parquet_files = sorted(data_dir.glob("chunk-*/file-*.parquet"))
trim_episode_set = set(episode_frames_to_keep.keys())
global_index = 0
for parquet_path in tqdm(parquet_files, desc="Processing data files"):
df = pd.read_parquet(parquet_path)
# Filter: keep all frames from non-trimmed episodes,
# and only specified frames from trimmed episodes
mask = (~df["episode_index"].isin(trim_episode_set)) | (df["index"].isin(all_keep_frames))
df = df[mask].copy().reset_index(drop=True)
if len(df) == 0:
continue
# Reindex
df["index"] = range(global_index, global_index + len(df))
# Recalculate frame_index within each episode
for ep_idx in df["episode_index"].unique():
ep_mask = df["episode_index"] == ep_idx
df.loc[ep_mask, "frame_index"] = range(ep_mask.sum())
# Recalculate timestamps based on frame_index and fps
df["timestamp"] = df["frame_index"] / src_dataset.meta.fps
# Determine output path (keep same structure)
rel_path = parquet_path.relative_to(src_dataset.root)
dst_path = dst_meta.root / rel_path
dst_path.parent.mkdir(parents=True, exist_ok=True)
_write_parquet(df, dst_path, dst_meta)
global_index += len(df)
def _copy_and_reindex_videos_with_multi_frame_filter(
src_dataset: LeRobotDataset,
dst_meta: LeRobotDatasetMetadata,
episode_frames_to_keep: dict[int, list[int]],
) -> None:
"""Copy video files for trimmed dataset.
In v3.0 datasets, multiple episodes are concatenated into single video files.
Each episode has from_timestamp/to_timestamp indicating its portion of the video.
For trimming, we copy the original video files as-is and update the metadata
timestamps in _copy_and_reindex_episodes_metadata_for_multi_trim.
"""
for video_key in src_dataset.meta.video_keys:
video_dir = src_dataset.root / "videos" / video_key
dst_video_dir = dst_meta.root / "videos" / video_key
if not video_dir.exists():
logging.warning(f"Video directory not found: {video_dir}")
continue
# Copy all video files (they contain concatenated episodes)
# The metadata timestamps will handle which portions to use
copied_files = set()
for chunk_dir in video_dir.glob("chunk-*"):
dst_chunk_dir = dst_video_dir / chunk_dir.name
dst_chunk_dir.mkdir(parents=True, exist_ok=True)
for video_file in chunk_dir.glob("*.mp4"):
if video_file.name not in copied_files:
dst_path = dst_chunk_dir / video_file.name
if not dst_path.exists():
shutil.copy(video_file, dst_path)
copied_files.add(video_file.name)
logging.info(f"Copied {len(copied_files)} video files for {video_key}")
def _trim_video_frames(
src_path: Path,
dst_path: Path,
keep_frame_indices: list[int],
fps: float,
episode_start_idx: int,
) -> None:
"""Trim a video to keep only specific frames using ffmpeg."""
import subprocess
# Convert global indices to local indices within the episode
local_indices = sorted([idx - episode_start_idx for idx in keep_frame_indices])
if not local_indices:
logging.warning(f"No frames to keep for video {src_path}")
return
# Calculate start and end times
start_frame = local_indices[0]
end_frame = local_indices[-1]
start_time = start_frame / fps
duration = (end_frame - start_frame + 1) / fps
# Use ffmpeg to trim
cmd = [
"ffmpeg", "-y",
"-ss", str(start_time),
"-i", str(src_path),
"-t", str(duration),
"-c", "copy", # Fast copy without re-encoding
str(dst_path)
]
try:
subprocess.run(cmd, check=True, capture_output=True)
except subprocess.CalledProcessError as e:
logging.error(f"Failed to trim video: {e.stderr.decode()}")
# Fallback: copy the whole video
shutil.copy(src_path, dst_path)
def _copy_and_reindex_episodes_metadata_for_multi_trim(
src_dataset: LeRobotDataset,
dst_meta: LeRobotDatasetMetadata,
trimmed_frame_counts: dict[int, int],
) -> None:
"""Copy and update episode metadata for trimmed dataset."""
if src_dataset.meta.episodes is None:
src_dataset.meta.episodes = load_episodes(src_dataset.meta.root)
# Calculate new frame counts and indices
episodes_data = []
global_idx = 0
for old_ep_idx in range(src_dataset.meta.total_episodes):
src_ep = src_dataset.meta.episodes[old_ep_idx]
if old_ep_idx in trimmed_frame_counts:
ep_length = trimmed_frame_counts[old_ep_idx]
else:
ep_length = src_ep["length"]
ep_data = {
"episode_index": old_ep_idx,
"tasks": src_ep["tasks"],
"length": ep_length,
"data/chunk_index": src_ep["data/chunk_index"],
"data/file_index": src_ep["data/file_index"],
"dataset_from_index": global_idx,
"dataset_to_index": global_idx + ep_length,
}
# Copy video metadata - preserve timestamps for concatenated videos
for video_key in src_dataset.meta.video_keys:
ep_data[f"videos/{video_key}/chunk_index"] = src_ep[f"videos/{video_key}/chunk_index"]
ep_data[f"videos/{video_key}/file_index"] = src_ep[f"videos/{video_key}/file_index"]
# Keep original from_timestamp (start position in concatenated video)
orig_from_ts = src_ep[f"videos/{video_key}/from_timestamp"]
ep_data[f"videos/{video_key}/from_timestamp"] = orig_from_ts
# For trimmed episodes, update to_timestamp based on new length
# For non-trimmed episodes, keep original to_timestamp
if old_ep_idx in trimmed_frame_counts:
ep_data[f"videos/{video_key}/to_timestamp"] = orig_from_ts + (ep_length / src_dataset.meta.fps)
else:
ep_data[f"videos/{video_key}/to_timestamp"] = src_ep[f"videos/{video_key}/to_timestamp"]
ep_data["meta/episodes/chunk_index"] = 0
ep_data["meta/episodes/file_index"] = 0
episodes_data.append(ep_data)
global_idx += ep_length
# Save episodes metadata
df = pd.DataFrame(episodes_data)
episodes_path = dst_meta.root / DEFAULT_EPISODES_PATH.format(chunk_index=0, file_index=0)
episodes_path.parent.mkdir(parents=True, exist_ok=True)
df.to_parquet(episodes_path)
# Update info.json
info = load_info(src_dataset.root)
info["total_episodes"] = len(episodes_data)
info["total_frames"] = global_idx
write_info(info, dst_meta.root)
+2 -2
View File
@@ -747,7 +747,7 @@ class LeRobotDataset(torch.utils.data.Dataset):
# Check if cached dataset contains all requested episodes
if not self._check_cached_episodes_sufficient():
raise FileNotFoundError("Cached dataset doesn't contain all requested episodes")
except (FileNotFoundError, NotADirectoryError):
except (AssertionError, FileNotFoundError, NotADirectoryError):
if is_valid_version(self.revision):
self.revision = get_safe_version(self.repo_id, self.revision)
self.download(download_videos)
@@ -839,7 +839,7 @@ class LeRobotDataset(torch.utils.data.Dataset):
hub_api.upload_folder(**upload_kwargs)
card = create_lerobot_dataset_card(
tags=tags, dataset_info=self.meta.info, license=license, repo_id=self.repo_id, **card_kwargs
tags=tags, dataset_info=self.meta.info, license=license, **card_kwargs
)
card.push_to_hub(repo_id=self.repo_id, repo_type="dataset", revision=branch)
+20 -26
View File
@@ -227,17 +227,16 @@ def decode_video_frames_torchvision(
min_, argmin_ = dist.min(1)
is_within_tol = min_ < tolerance_s
if not is_within_tol.all():
raise FrameTimestampError(
f"One or several query timestamps unexpectedly violate the tolerance ({min_[~is_within_tol]} > {tolerance_s=})."
" It means that the closest frame that can be loaded from the video is too far away in time."
" This might be due to synchronization issues with timestamps during data collection."
" To be safe, we advise to ignore this item during training."
f"\nqueried timestamps: {query_ts}"
f"\nloaded timestamps: {loaded_ts}"
f"\nvideo: {video_path}"
f"\nbackend: {backend}"
)
assert is_within_tol.all(), (
f"One or several query timestamps unexpectedly violate the tolerance ({min_[~is_within_tol]} > {tolerance_s=})."
"It means that the closest frame that can be loaded from the video is too far away in time."
"This might be due to synchronization issues with timestamps during data collection."
"To be safe, we advise to ignore this item during training."
f"\nqueried timestamps: {query_ts}"
f"\nloaded timestamps: {loaded_ts}"
f"\nvideo: {video_path}"
f"\nbackend: {backend}"
)
# get closest frames to the query timestamps
closest_frames = torch.stack([loaded_frames[idx] for idx in argmin_])
@@ -249,11 +248,7 @@ def decode_video_frames_torchvision(
# convert to the pytorch format which is float32 in [0,1] range (and channel first)
closest_frames = closest_frames.type(torch.float32) / 255
if len(timestamps) != len(closest_frames):
raise FrameTimestampError(
f"Number of retrieved frames ({len(closest_frames)}) does not match "
f"number of queried timestamps ({len(timestamps)})"
)
assert len(timestamps) == len(closest_frames)
return closest_frames
@@ -358,16 +353,15 @@ def decode_video_frames_torchcodec(
min_, argmin_ = dist.min(1)
is_within_tol = min_ < tolerance_s
if not is_within_tol.all():
raise FrameTimestampError(
f"One or several query timestamps unexpectedly violate the tolerance ({min_[~is_within_tol]} > {tolerance_s=})."
" It means that the closest frame that can be loaded from the video is too far away in time."
" This might be due to synchronization issues with timestamps during data collection."
" To be safe, we advise to ignore this item during training."
f"\nqueried timestamps: {query_ts}"
f"\nloaded timestamps: {loaded_ts}"
f"\nvideo: {video_path}"
)
assert is_within_tol.all(), (
f"One or several query timestamps unexpectedly violate the tolerance ({min_[~is_within_tol]} > {tolerance_s=})."
"It means that the closest frame that can be loaded from the video is too far away in time."
"This might be due to synchronization issues with timestamps during data collection."
"To be safe, we advise to ignore this item during training."
f"\nqueried timestamps: {query_ts}"
f"\nloaded timestamps: {loaded_ts}"
f"\nvideo: {video_path}"
)
# get closest frames to the query timestamps
closest_frames = torch.stack([loaded_frames[idx] for idx in argmin_])
+8 -2
View File
@@ -105,6 +105,7 @@ def make_env(
use_async_envs: bool = False,
hub_cache_dir: str | None = None,
trust_remote_code: bool = False,
**kwargs,
) -> dict[str, dict[int, gym.vector.VectorEnv]]:
"""Makes a gym vector environment according to the config or Hub reference.
@@ -118,6 +119,9 @@ def make_env(
hub_cache_dir (str | None): Optional cache path for downloaded hub files.
trust_remote_code (bool): **Explicit consent** to execute remote code from the Hub.
Default False must be set to True to import/exec hub `env.py`.
**kwargs: Additional keyword arguments passed to the hub environment's `make_env` function.
Useful for passing custom configurations like `config_path`, `config_overrides`, etc.
Raises:
ValueError: if n_envs < 1
ModuleNotFoundError: If the requested env package is not installed
@@ -149,9 +153,11 @@ def make_env(
# import and surface clear import errors
module = _import_hub_module(local_file, repo_id)
# call the hub-provided make_env
# call the hub-provided make_env with any additional kwargs
env_cfg = None if isinstance(cfg, str) else cfg
raw_result = _call_make_env(module, n_envs=n_envs, use_async_envs=use_async_envs, cfg=env_cfg)
raw_result = _call_make_env(
module, n_envs=n_envs, use_async_envs=use_async_envs, cfg=env_cfg, **kwargs
)
# normalize the return into {suite: {task_id: vec_env}}
return _normalize_hub_result(raw_result)
+12 -5
View File
@@ -311,20 +311,27 @@ def _import_hub_module(local_file: str, repo_id: str) -> Any:
return module
def _call_make_env(module: Any, n_envs: int, use_async_envs: bool, cfg: EnvConfig | None) -> Any:
def _call_make_env(module: Any, n_envs: int, use_async_envs: bool, cfg: EnvConfig | None, **kwargs) -> Any:
"""
Ensure module exposes make_env and call it.
Ensure module exposes make_env and call it with any additional kwargs.
Args:
module: The imported hub module containing make_env.
n_envs: Number of parallel environments.
use_async_envs: Whether to use AsyncVectorEnv or SyncVectorEnv.
**kwargs: Additional keyword arguments to pass to the hub's make_env function.
Common examples include config_path, config_overrides, etc.
"""
if not hasattr(module, "make_env"):
raise AttributeError(
f"The hub module {getattr(module, '__name__', 'hub_module')} must expose `make_env(n_envs=int, use_async_envs=bool)`."
f"The hub module {getattr(module, '__name__', 'hub_module')} must expose `make_env(n_envs=int, use_async_envs=bool, **kwargs)`."
)
entry_fn = module.make_env
# Only pass cfg if it's not None (i.e., when an EnvConfig was provided, not a string hub ID)
if cfg is not None:
return entry_fn(n_envs=n_envs, use_async_envs=use_async_envs, cfg=cfg)
return entry_fn(n_envs=n_envs, use_async_envs=use_async_envs, cfg=cfg, **kwargs)
else:
return entry_fn(n_envs=n_envs, use_async_envs=use_async_envs)
return entry_fn(n_envs=n_envs, use_async_envs=use_async_envs, **kwargs)
def _normalize_hub_result(result: Any) -> dict[str, dict[int, gym.vector.VectorEnv]]:
@@ -139,10 +139,6 @@ class DiffusionConfig(PreTrainedConfig):
# Inference
num_inference_steps: int | None = None
# Optimization
compile_model: bool = False
compile_mode: str = "reduce-overhead"
# Loss computation
do_mask_loss_for_padding: bool = False
@@ -142,9 +142,6 @@ class DiffusionPolicy(PreTrainedPolicy):
"""Run the batch through the model and compute the loss for training or validation."""
if self.config.image_features:
batch = dict(batch) # shallow copy so that adding a key doesn't modify the original
for key in self.config.image_features:
if self.config.n_obs_steps == 1 and batch[key].ndim == 4:
batch[key] = batch[key].unsqueeze(1)
batch[OBS_IMAGES] = torch.stack([batch[key] for key in self.config.image_features], dim=-4)
loss = self.diffusion.compute_loss(batch)
# no output_dict so returning None
@@ -185,11 +182,6 @@ class DiffusionModel(nn.Module):
self.unet = DiffusionConditionalUnet1d(config, global_cond_dim=global_cond_dim * config.n_obs_steps)
if config.compile_model:
# Compile the U-Net. "reduce-overhead" is preferred for the small-batch repetitive loops
# common in diffusion inference.
self.unet = torch.compile(self.unet, mode=config.compile_mode)
self.noise_scheduler = _make_noise_scheduler(
config.noise_scheduler_type,
num_train_timesteps=config.num_train_timesteps,
+3 -1
View File
@@ -277,7 +277,9 @@ class SARMEncodingProcessorStep(ProcessorStep):
# When language is perturbed, targets are zero so perturbed samples don't contribute to progress loss
if self.dataset_meta is not None:
episodes_df = self.dataset_meta.episodes.to_pandas()
episodes_df = None
if self.sparse_subtask_names != ["task"]:
episodes_df = self.dataset_meta.episodes.to_pandas()
# Generate sparse targets
if self.sparse_temporal_proportions is not None:
-133
View File
@@ -104,28 +104,6 @@ Convert image dataset to video format and push to hub:
--operation.type convert_image_to_video \
--push_to_hub true
Trim single episode to keep only frames within timestamp range:
python -m lerobot.scripts.lerobot_edit_dataset \
--repo_id lerobot/pusht \
--new_repo_id lerobot/pusht_trimmed \
--operation.type trim_episode \
--operation.episode_index 0 \
--operation.start_timestamp 10.0 \
--operation.end_timestamp 30.0
Trim multiple episodes at once (use null for no limit):
python -m lerobot.scripts.lerobot_edit_dataset \
--repo_id lerobot/pusht \
--operation.type trim_episode \
--operation.episode_trims '{"0": [10.0, 30.0], "2": [5.0, null], "3": [null, 20.0]}'
Trim and re-upload to same repo (overwrites original):
python -m lerobot.scripts.lerobot_edit_dataset \
--repo_id lerobot/pusht \
--operation.type trim_episode \
--operation.episode_index 0 \
--operation.start_timestamp 10.0 \
--push_to_hub true
Show dataset information:
lerobot-edit-dataset \
--repo_id lerobot/pusht_image \
@@ -226,32 +204,9 @@ class InfoConfig(OperationConfig):
show_features: bool = False
@dataclass
class TrimEpisodeConfig:
"""Trim episodes to keep only frames within timestamp ranges.
Supports multiple episodes via episode_trims dict:
--operation.episode_trims '{"0": [10.0, 30.0], "2": [5.0, 20.0]}'
Or single episode via legacy parameters:
--operation.episode_index 0 --operation.start_timestamp 10.0 --operation.end_timestamp 30.0
"""
type: str = "trim_episode"
# Multi-episode support: dict mapping episode_index -> [start_timestamp, end_timestamp]
# Use null for no limit, e.g. {"0": [10.0, null], "2": [null, 30.0]}
episode_trims: dict[str, list[float | None]] | None = None
# Legacy single-episode parameters (used if episode_trims is None)
episode_index: int | None = None
start_timestamp: float | None = None # Keep frames from this timestamp (inclusive)
end_timestamp: float | None = None # Keep frames until this timestamp (inclusive)
@dataclass
class EditDatasetConfig:
repo_id: str
operation: (
DeleteEpisodesConfig | SplitConfig | MergeConfig | RemoveFeatureConfig | ConvertImageToVideoConfig | TrimEpisodeConfig
)
operation: OperationConfig
root: str | None = None
new_repo_id: str | None = None
@@ -396,92 +351,6 @@ def handle_remove_feature(cfg: EditDatasetConfig) -> None:
LeRobotDataset(output_repo_id, root=output_dir).push_to_hub()
def handle_trim_episode(cfg: EditDatasetConfig) -> None:
"""Trim episodes to keep only frames within timestamp ranges."""
if not isinstance(cfg.operation, TrimEpisodeConfig):
raise ValueError("Operation config must be TrimEpisodeConfig")
# Parse episode trims - support both multi-episode dict and legacy single episode
episode_trims: dict[int, tuple[float | None, float | None]] = {}
if cfg.operation.episode_trims is not None:
# Multi-episode mode
for ep_str, ts_range in cfg.operation.episode_trims.items():
ep_idx = int(ep_str)
start_ts = ts_range[0] if len(ts_range) > 0 else None
end_ts = ts_range[1] if len(ts_range) > 1 else None
episode_trims[ep_idx] = (start_ts, end_ts)
elif cfg.operation.episode_index is not None:
# Legacy single-episode mode
if cfg.operation.start_timestamp is None and cfg.operation.end_timestamp is None:
raise ValueError("At least one of start_timestamp or end_timestamp must be specified")
episode_trims[cfg.operation.episode_index] = (
cfg.operation.start_timestamp,
cfg.operation.end_timestamp,
)
else:
raise ValueError("Either episode_trims or episode_index must be specified")
dataset = LeRobotDataset(cfg.repo_id, root=cfg.root)
output_repo_id, output_dir = get_output_path(
cfg.repo_id, cfg.new_repo_id, Path(cfg.root) if cfg.root else None
)
if cfg.new_repo_id is None:
dataset.root = Path(str(dataset.root) + "_old")
logging.info(f"Trimming {len(episode_trims)} episode(s) from {cfg.repo_id}")
# Get episode boundaries and find frames to keep for each episode
episodes_info = dataset.meta.episodes
all_frames_to_keep: dict[int, list[int]] = {}
for ep_idx, (start_ts, end_ts) in episode_trims.items():
if ep_idx >= len(episodes_info["episode_index"]):
raise ValueError(f"Episode {ep_idx} does not exist (dataset has {len(episodes_info['episode_index'])} episodes)")
from_frame = episodes_info["dataset_from_index"][ep_idx]
to_frame = episodes_info["dataset_to_index"][ep_idx]
logging.info(f"Episode {ep_idx}: trimming to [{start_ts}, {end_ts}]")
logging.info(f" Original frames: {from_frame} to {to_frame} ({to_frame - from_frame} frames)")
# Find frames within timestamp range
frames_to_keep = []
for frame_idx in range(from_frame, to_frame):
frame = dataset.hf_dataset[frame_idx]
ts = frame["timestamp"]
in_range = True
if start_ts is not None and ts < start_ts:
in_range = False
if end_ts is not None and ts > end_ts:
in_range = False
if in_range:
frames_to_keep.append(frame_idx)
if not frames_to_keep:
raise ValueError(f"Episode {ep_idx}: No frames found in timestamp range [{start_ts}, {end_ts}]")
logging.info(f" Keeping {len(frames_to_keep)} frames (indices {frames_to_keep[0]} to {frames_to_keep[-1]})")
all_frames_to_keep[ep_idx] = frames_to_keep
from lerobot.datasets.dataset_tools import trim_episodes_by_frames
new_dataset = trim_episodes_by_frames(
dataset,
episode_frames_to_keep=all_frames_to_keep,
output_dir=output_dir,
repo_id=output_repo_id,
)
logging.info(f"Dataset saved to {output_dir}")
logging.info(f"Episodes: {new_dataset.meta.total_episodes}, Frames: {new_dataset.meta.total_frames}")
if cfg.push_to_hub:
logging.info(f"Pushing to hub as {output_repo_id}")
LeRobotDataset(output_repo_id, root=output_dir).push_to_hub()
def handle_modify_tasks(cfg: EditDatasetConfig) -> None:
if not isinstance(cfg.operation, ModifyTasksConfig):
raise ValueError("Operation config must be ModifyTasksConfig")
@@ -646,8 +515,6 @@ def edit_dataset(cfg: EditDatasetConfig) -> None:
handle_modify_tasks(cfg)
elif operation_type == "convert_image_to_video":
handle_convert_image_to_video(cfg)
elif operation_type == "trim_episode":
handle_trim_episode(cfg)
elif operation_type == "info":
handle_info(cfg)
else:
+12
View File
@@ -43,6 +43,17 @@ lerobot-eval \
Note that in both examples, the repo/folder should contain at least `config.json` and `model.safetensors` files.
You can also evaluate a model on a Hub environment with custom kwargs:
```
lerobot-eval \
--policy.path=HF_USER/HF_REPO \
--env=HF_USER/HF_REPO \
--eval.batch_size=1 \
--eval.n_episodes=10 \
--env_kwargs.environment=env_A \
--env_kwargs.embodiment=emb_B \
```
You can learn about the CLI options for this script in the `EvalPipelineConfig` in lerobot/configs/eval.py
"""
@@ -521,6 +532,7 @@ def eval_main(cfg: EvalPipelineConfig):
n_envs=cfg.eval.batch_size,
use_async_envs=cfg.eval.use_async_envs,
trust_remote_code=cfg.trust_remote_code,
**cfg.env_kwargs,
)
logging.info("Making policy.")
@@ -43,7 +43,6 @@ from lerobot.teleoperators import ( # noqa: F401
koch_leader,
make_teleoperator_from_config,
omx_leader,
openarm_mini,
so_leader,
)
@@ -52,7 +51,6 @@ COMPATIBLE_DEVICES = [
"koch_leader",
"omx_follower",
"omx_leader",
"openarm_mini",
"so100_follower",
"so100_leader",
"so101_follower",
-15
View File
@@ -24,7 +24,6 @@ import torch
from accelerate import Accelerator
from termcolor import colored
from torch.optim import Optimizer
from tqdm import tqdm
from lerobot.configs import parser
from lerobot.configs.train import TrainPipelineConfig
@@ -52,7 +51,6 @@ from lerobot.utils.utils import (
format_big_number,
has_method,
init_logging,
inside_slurm,
)
@@ -392,14 +390,6 @@ def train(cfg: TrainPipelineConfig, accelerator: Accelerator | None = None):
)
if is_main_process:
progbar = tqdm(
total=cfg.steps - step,
desc="Training",
unit="step",
disable=inside_slurm(),
position=0,
leave=True,
)
logging.info(
f"Start offline training on a fixed dataset, with effective batch size: {effective_batch_size}"
)
@@ -424,8 +414,6 @@ def train(cfg: TrainPipelineConfig, accelerator: Accelerator | None = None):
# Note: eval and checkpoint happens *after* the `step`th training update has completed, so we
# increment `step` here.
step += 1
if is_main_process:
progbar.update(1)
train_tracker.step()
is_log_step = cfg.log_freq > 0 and step % cfg.log_freq == 0 and is_main_process
is_saving_step = step % cfg.save_freq == 0 or step == cfg.steps
@@ -519,9 +507,6 @@ def train(cfg: TrainPipelineConfig, accelerator: Accelerator | None = None):
accelerator.wait_for_everyone()
if is_main_process:
progbar.close()
if eval_env:
close_envs(eval_env)
@@ -1,20 +0,0 @@
#!/usr/bin/env python
# Copyright 2025 The HuggingFace Inc. team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from .config_openarm_mini import OpenArmMiniConfig
from .openarm_mini import OpenArmMini
__all__ = ["OpenArmMini", "OpenArmMiniConfig"]
@@ -1,30 +0,0 @@
#!/usr/bin/env python
# Copyright 2025 The HuggingFace Inc. team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from dataclasses import dataclass
from ..config import TeleoperatorConfig
@TeleoperatorConfig.register_subclass("openarm_mini")
@dataclass
class OpenArmMiniConfig(TeleoperatorConfig):
"""Configuration for OpenArm Mini teleoperator with Feetech motors (dual arms)."""
port_right: str = "/dev/ttyUSB0"
port_left: str = "/dev/ttyUSB1"
use_degrees: bool = True
@@ -1,296 +0,0 @@
#!/usr/bin/env python
# Copyright 2025 The HuggingFace Inc. team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import logging
import time
from typing import Any
from lerobot.motors import Motor, MotorCalibration, MotorNormMode
from lerobot.motors.feetech import (
FeetechMotorsBus,
OperatingMode,
)
from lerobot.processor import RobotAction
from lerobot.utils.decorators import check_if_already_connected, check_if_not_connected
from ..teleoperator import Teleoperator
from .config_openarm_mini import OpenArmMiniConfig
logger = logging.getLogger(__name__)
# Motors whose direction is inverted during readout
RIGHT_MOTORS_TO_FLIP = ["joint_1", "joint_2", "joint_3", "joint_4", "joint_5"]
LEFT_MOTORS_TO_FLIP = ["joint_1", "joint_3", "joint_4", "joint_5", "joint_6", "joint_7"]
class OpenArmMini(Teleoperator):
"""
OpenArm Mini Teleoperator with dual Feetech-based arms (8 motors per arm).
Each arm has 7 joints plus a gripper, using Feetech STS3215 servos.
"""
config_class = OpenArmMiniConfig
name = "openarm_mini"
def __init__(self, config: OpenArmMiniConfig):
super().__init__(config)
self.config = config
norm_mode_body = MotorNormMode.DEGREES
motors_right = {
"joint_1": Motor(1, "sts3215", norm_mode_body),
"joint_2": Motor(2, "sts3215", norm_mode_body),
"joint_3": Motor(3, "sts3215", norm_mode_body),
"joint_4": Motor(4, "sts3215", norm_mode_body),
"joint_5": Motor(5, "sts3215", norm_mode_body),
"joint_6": Motor(6, "sts3215", norm_mode_body),
"joint_7": Motor(7, "sts3215", norm_mode_body),
"gripper": Motor(8, "sts3215", MotorNormMode.RANGE_0_100),
}
motors_left = {
"joint_1": Motor(1, "sts3215", norm_mode_body),
"joint_2": Motor(2, "sts3215", norm_mode_body),
"joint_3": Motor(3, "sts3215", norm_mode_body),
"joint_4": Motor(4, "sts3215", norm_mode_body),
"joint_5": Motor(5, "sts3215", norm_mode_body),
"joint_6": Motor(6, "sts3215", norm_mode_body),
"joint_7": Motor(7, "sts3215", norm_mode_body),
"gripper": Motor(8, "sts3215", MotorNormMode.RANGE_0_100),
}
cal_right = {
k.replace("right_", ""): v for k, v in (self.calibration or {}).items() if k.startswith("right_")
}
cal_left = {
k.replace("left_", ""): v for k, v in (self.calibration or {}).items() if k.startswith("left_")
}
self.bus_right = FeetechMotorsBus(
port=self.config.port_right,
motors=motors_right,
calibration=cal_right,
)
self.bus_left = FeetechMotorsBus(
port=self.config.port_left,
motors=motors_left,
calibration=cal_left,
)
@property
def action_features(self) -> dict[str, type]:
features: dict[str, type] = {}
for motor in self.bus_right.motors:
features[f"right_{motor}.pos"] = float
for motor in self.bus_left.motors:
features[f"left_{motor}.pos"] = float
return features
@property
def feedback_features(self) -> dict[str, type]:
return {}
@property
def is_connected(self) -> bool:
return self.bus_right.is_connected and self.bus_left.is_connected
@check_if_already_connected
def connect(self, calibrate: bool = True) -> None:
logger.info(f"Connecting right arm on {self.config.port_right}...")
self.bus_right.connect()
logger.info(f"Connecting left arm on {self.config.port_left}...")
self.bus_left.connect()
if calibrate:
self.calibrate()
self.configure()
logger.info(f"{self} connected.")
@property
def is_calibrated(self) -> bool:
return self.bus_right.is_calibrated and self.bus_left.is_calibrated
def calibrate(self) -> None:
"""
Run calibration procedure for OpenArm Mini.
1. Disable torque
2. Ask user to position arms in hanging position with grippers closed
3. Set this as zero position via half-turn homing
4. Interactive gripper calibration (open/close positions)
5. Save calibration
"""
if self.calibration:
user_input = input(
f"Press ENTER to use existing calibration for {self.id}, "
f"or type 'c' and press ENTER to run new calibration: "
)
if user_input.strip().lower() != "c":
logger.info(f"Using existing calibration for {self.id}")
cal_right = {
k.replace("right_", ""): v for k, v in self.calibration.items() if k.startswith("right_")
}
cal_left = {
k.replace("left_", ""): v for k, v in self.calibration.items() if k.startswith("left_")
}
self.bus_right.write_calibration(cal_right)
self.bus_left.write_calibration(cal_left)
return
logger.info(f"\nRunning calibration for {self}")
self._calibrate_arm("right", self.bus_right)
self._calibrate_arm("left", self.bus_left)
self._save_calibration()
print(f"\nCalibration complete and saved to {self.calibration_fpath}")
def _calibrate_arm(self, arm_name: str, bus: FeetechMotorsBus) -> None:
"""Calibrate a single arm with Feetech motors."""
logger.info(f"\n=== Calibrating {arm_name.upper()} arm ===")
bus.disable_torque()
logger.info(f"Setting Phase to 12 for all motors in {arm_name.upper()} arm...")
for motor in bus.motors:
bus.write("Phase", motor, 12)
for motor in bus.motors:
bus.write("Operating_Mode", motor, OperatingMode.POSITION.value)
input(
f"\nCalibration: Zero Position ({arm_name.upper()} arm)\n"
"Position the arm in the following configuration:\n"
" - Arm hanging straight down\n"
" - Gripper closed\n"
"Press ENTER when ready..."
)
homing_offsets = bus.set_half_turn_homings()
logger.info(f"{arm_name.capitalize()} arm zero position set.")
print(f"\nSetting motor ranges for {arm_name.upper()} arm\n")
if self.calibration is None:
self.calibration = {}
motor_resolution = bus.model_resolution_table[list(bus.motors.values())[0].model]
max_res = motor_resolution - 1
for motor_name, motor in bus.motors.items():
prefixed_name = f"{arm_name}_{motor_name}"
if motor_name == "gripper":
input(
f"\nGripper Calibration ({arm_name.upper()} arm)\n"
f"Step 1: CLOSE the gripper fully\n"
f"Press ENTER when gripper is closed..."
)
closed_pos = bus.read("Present_Position", motor_name, normalize=False)
logger.info(f" Gripper closed position recorded: {closed_pos}")
input("\nStep 2: OPEN the gripper fully\nPress ENTER when gripper is fully open...")
open_pos = bus.read("Present_Position", motor_name, normalize=False)
logger.info(f" Gripper open position recorded: {open_pos}")
if closed_pos < open_pos:
range_min = int(closed_pos)
range_max = int(open_pos)
drive_mode = 0
else:
range_min = int(open_pos)
range_max = int(closed_pos)
drive_mode = 1
logger.info(
f" {prefixed_name}: range set to [{range_min}, {range_max}] "
f"(0=closed, 100=open, drive_mode={drive_mode})"
)
else:
range_min = 0
range_max = max_res
drive_mode = 0
logger.info(f" {prefixed_name}: range set to [0, {max_res}] (full motor range)")
self.calibration[prefixed_name] = MotorCalibration(
id=motor.id,
drive_mode=drive_mode,
homing_offset=homing_offsets[motor_name],
range_min=range_min,
range_max=range_max,
)
cal_for_bus = {
k.replace(f"{arm_name}_", ""): v
for k, v in self.calibration.items()
if k.startswith(f"{arm_name}_")
}
bus.write_calibration(cal_for_bus)
def configure(self) -> None:
self.bus_right.disable_torque()
self.bus_right.configure_motors()
for motor in self.bus_right.motors:
self.bus_right.write("Operating_Mode", motor, OperatingMode.POSITION.value)
self.bus_left.disable_torque()
self.bus_left.configure_motors()
for motor in self.bus_left.motors:
self.bus_left.write("Operating_Mode", motor, OperatingMode.POSITION.value)
def setup_motors(self) -> None:
print("\nSetting up RIGHT arm motors...")
for motor in reversed(self.bus_right.motors):
input(f"Connect the controller board to the RIGHT '{motor}' motor only and press enter.")
self.bus_right.setup_motor(motor)
print(f"RIGHT '{motor}' motor id set to {self.bus_right.motors[motor].id}")
print("\nSetting up LEFT arm motors...")
for motor in reversed(self.bus_left.motors):
input(f"Connect the controller board to the LEFT '{motor}' motor only and press enter.")
self.bus_left.setup_motor(motor)
print(f"LEFT '{motor}' motor id set to {self.bus_left.motors[motor].id}")
@check_if_not_connected
def get_action(self) -> RobotAction:
"""Get current action from both arms (read positions from all motors)."""
start = time.perf_counter()
right_positions = self.bus_right.sync_read("Present_Position")
left_positions = self.bus_left.sync_read("Present_Position")
action: dict[str, Any] = {}
for motor, val in right_positions.items():
action[f"right_{motor}.pos"] = -val if motor in RIGHT_MOTORS_TO_FLIP else val
for motor, val in left_positions.items():
action[f"left_{motor}.pos"] = -val if motor in LEFT_MOTORS_TO_FLIP else val
dt_ms = (time.perf_counter() - start) * 1e3
logger.debug(f"{self} read action: {dt_ms:.1f}ms")
return action
def send_feedback(self, feedback: dict[str, float]) -> None:
raise NotImplementedError("Feedback is not yet implemented for OpenArm Mini.")
@check_if_not_connected
def disconnect(self) -> None:
self.bus_right.disconnect()
self.bus_left.disconnect()
logger.info(f"{self} disconnected.")
-4
View File
@@ -95,10 +95,6 @@ def make_teleoperator_from_config(config: TeleoperatorConfig) -> "Teleoperator":
from .bi_openarm_leader import BiOpenArmLeader
return BiOpenArmLeader(config)
elif config.type == "openarm_mini":
from .openarm_mini import OpenArmMini
return OpenArmMini(config)
else:
try:
return cast("Teleoperator", make_device_from_device_class(config))
+1 -1
View File
@@ -189,7 +189,7 @@ def sanity_check_dataset_name(repo_id, policy_cfg):
# Check if dataset_name starts with "eval_" but policy is missing
if dataset_name.startswith("eval_") and policy_cfg is None:
raise ValueError(
f"Your dataset name begins with 'eval_' ({dataset_name}), but no policy is provided."
f"Your dataset name begins with 'eval_' ({dataset_name}), but no policy is provided ({policy_cfg.type})."
)
# Check if dataset_name does not start with "eval_" but policy is provided
+62
View File
@@ -266,3 +266,65 @@ def test_make_env_from_hub_async():
# clean up
env.close()
def test_make_env_from_hub_with_kwargs():
"""Test that kwargs are correctly passed to hub environment's make_env."""
hub_id = "lerobot/dummy-hub-env"
# Test with config_path kwarg
envs_dict = make_env(
hub_id,
n_envs=1,
trust_remote_code=True,
config_path="/path/to/config.yaml",
)
env = envs_dict["cartpole_suite"][0]
assert hasattr(env, "hub_config")
assert env.hub_config["config_path"] == "/path/to/config.yaml"
env.close()
# Test with config_overrides dict
envs_dict = make_env(
hub_id,
n_envs=1,
trust_remote_code=True,
config_overrides={"scene.object": "microwave", "sim.dt": 0.01},
)
env = envs_dict["cartpole_suite"][0]
assert env.hub_config["config_overrides"]["scene.object"] == "microwave"
assert env.hub_config["config_overrides"]["sim.dt"] == 0.01
env.close()
# Test with arbitrary extra kwargs
envs_dict = make_env(
hub_id,
n_envs=1,
trust_remote_code=True,
custom_param="value",
another_param=42,
)
env = envs_dict["cartpole_suite"][0]
assert env.hub_config["extra_kwargs"]["custom_param"] == "value"
assert env.hub_config["extra_kwargs"]["another_param"] == 42
env.close()
# Test combining config_path, config_overrides, and extra kwargs
envs_dict = make_env(
hub_id,
n_envs=2,
trust_remote_code=True,
config_path="my_config.yaml",
config_overrides={"robot": "gr1"},
task_name="pick_and_place",
)
env = envs_dict["cartpole_suite"][0]
assert env.hub_config["config_path"] == "my_config.yaml"
assert env.hub_config["config_overrides"]["robot"] == "gr1"
assert env.hub_config["extra_kwargs"]["task_name"] == "pick_and_place"
assert env.num_envs == 2
env.close()