mirror of
https://github.com/huggingface/lerobot.git
synced 2026-05-13 07:39:53 +00:00
Compare commits
11 Commits
| Author | SHA1 | Date | |
|---|---|---|---|
| 012bde51cb | |||
| 975dcad918 | |||
| d0b58190da | |||
| 9a5ab8ffab | |||
| 7541d72130 | |||
| 0317a15bf1 | |||
| f138e5948a | |||
| 8fef4ddab8 | |||
| 18d9cb5ac4 | |||
| 5095ab0845 | |||
| dac1efd13d |
@@ -173,6 +173,8 @@ jobs:
|
||||
shell: bash
|
||||
working-directory: /lerobot
|
||||
steps:
|
||||
- name: Fix ptxas permissions
|
||||
run: chmod +x /lerobot/.venv/lib/python3.10/site-packages/triton/backends/nvidia/bin/ptxas
|
||||
- name: Run pytest on GPU
|
||||
run: pytest tests -vv --maxfail=10
|
||||
- name: Run end-to-end tests
|
||||
|
||||
@@ -1,2 +1,3 @@
|
||||
include src/lerobot/templates/lerobot_modelcard_template.md
|
||||
include src/lerobot/datasets/card_template.md
|
||||
include src/lerobot/envs/metaworld_config.json
|
||||
|
||||
@@ -85,6 +85,8 @@ RUN if [ "$UNBOUND_DEPS" = "true" ]; then \
|
||||
|
||||
RUN uv pip install --no-cache ".[all]"
|
||||
|
||||
RUN chmod +x /lerobot/.venv/lib/python${PYTHON_VERSION}/site-packages/triton/backends/nvidia/bin/ptxas
|
||||
|
||||
# Copy the rest of the application source code
|
||||
# Make sure to have the git-LFS files for testing
|
||||
COPY --chown=user_lerobot:user_lerobot . .
|
||||
|
||||
@@ -214,6 +214,9 @@ lerobot-edit-dataset="lerobot.scripts.lerobot_edit_dataset:main"
|
||||
lerobot-setup-can="lerobot.scripts.lerobot_setup_can:main"
|
||||
|
||||
# ---------------- Tool Configurations ----------------
|
||||
[tool.setuptools.package-data]
|
||||
lerobot = ["envs/*.json"]
|
||||
|
||||
[tool.setuptools.packages.find]
|
||||
where = ["src"]
|
||||
|
||||
|
||||
@@ -7,6 +7,13 @@
|
||||
|
||||
This dataset was created using [LeRobot](https://github.com/huggingface/lerobot).
|
||||
|
||||
{% if repo_id is defined and repo_id %}
|
||||
<a class="flex" href="https://huggingface.co/spaces/lerobot/visualize_dataset?path={{ repo_id }}">
|
||||
<img class="block dark:hidden" src="https://huggingface.co/datasets/huggingface/badges/resolve/main/visualize-this-dataset-xl.svg"/>
|
||||
<img class="hidden dark:block" src="https://huggingface.co/datasets/huggingface/badges/resolve/main/visualize-this-dataset-xl-dark.svg"/>
|
||||
</a>
|
||||
{% endif %}
|
||||
|
||||
## Dataset Description
|
||||
|
||||
{{ dataset_description | default("", true) }}
|
||||
|
||||
@@ -47,6 +47,7 @@ from lerobot.datasets.utils import (
|
||||
DEFAULT_EPISODES_PATH,
|
||||
get_parquet_file_size_in_mb,
|
||||
load_episodes,
|
||||
load_info,
|
||||
update_chunk_file_indices,
|
||||
write_info,
|
||||
write_stats,
|
||||
@@ -567,20 +568,22 @@ def _copy_and_reindex_data(
|
||||
def _keep_episodes_from_video_with_av(
|
||||
input_path: Path,
|
||||
output_path: Path,
|
||||
episodes_to_keep: list[tuple[float, float]],
|
||||
episodes_to_keep: list[tuple[int, int]],
|
||||
fps: float,
|
||||
vcodec: str = "libsvtav1",
|
||||
pix_fmt: str = "yuv420p",
|
||||
) -> None:
|
||||
"""Keep only specified episodes from a video file using PyAV.
|
||||
|
||||
This function decodes frames from specified time ranges and re-encodes them with
|
||||
This function decodes frames from specified frame ranges and re-encodes them with
|
||||
properly reset timestamps to ensure monotonic progression.
|
||||
|
||||
Args:
|
||||
input_path: Source video file path.
|
||||
output_path: Destination video file path.
|
||||
episodes_to_keep: List of (start_time, end_time) tuples for episodes to keep.
|
||||
episodes_to_keep: List of (start_frame, end_frame) tuples for episodes to keep.
|
||||
Ranges are half-open intervals: [start_frame, end_frame), where start_frame
|
||||
is inclusive and end_frame is exclusive.
|
||||
fps: Frame rate of the video.
|
||||
vcodec: Video codec to use for encoding.
|
||||
pix_fmt: Pixel format for output video.
|
||||
@@ -622,9 +625,10 @@ def _keep_episodes_from_video_with_av(
|
||||
|
||||
# Create set of (start, end) ranges for fast lookup.
|
||||
# Convert to a sorted list for efficient checking.
|
||||
time_ranges = sorted(episodes_to_keep)
|
||||
frame_ranges = sorted(episodes_to_keep)
|
||||
|
||||
# Track frame index for setting PTS and current range being processed.
|
||||
src_frame_count = 0
|
||||
frame_count = 0
|
||||
range_idx = 0
|
||||
|
||||
@@ -634,21 +638,20 @@ def _keep_episodes_from_video_with_av(
|
||||
if frame is None:
|
||||
continue
|
||||
|
||||
# Get frame timestamp.
|
||||
frame_time = float(frame.pts * frame.time_base) if frame.pts is not None else 0.0
|
||||
|
||||
# Check if frame is in any of our desired time ranges.
|
||||
# Check if frame is in any of our desired frame ranges.
|
||||
# Skip ranges that have already passed.
|
||||
while range_idx < len(time_ranges) and frame_time >= time_ranges[range_idx][1]:
|
||||
while range_idx < len(frame_ranges) and src_frame_count >= frame_ranges[range_idx][1]:
|
||||
range_idx += 1
|
||||
|
||||
# If we've passed all ranges, stop processing.
|
||||
if range_idx >= len(time_ranges):
|
||||
if range_idx >= len(frame_ranges):
|
||||
break
|
||||
|
||||
# Check if frame is in current range.
|
||||
start_ts, end_ts = time_ranges[range_idx]
|
||||
if frame_time < start_ts:
|
||||
start_frame = frame_ranges[range_idx][0]
|
||||
|
||||
if src_frame_count < start_frame:
|
||||
src_frame_count += 1
|
||||
continue
|
||||
|
||||
# Frame is in range - create a new frame with reset timestamps.
|
||||
@@ -661,6 +664,7 @@ def _keep_episodes_from_video_with_av(
|
||||
for pkt in v_out.encode(new_frame):
|
||||
out.mux(pkt)
|
||||
|
||||
src_frame_count += 1
|
||||
frame_count += 1
|
||||
|
||||
# Flush encoder.
|
||||
@@ -749,15 +753,17 @@ def _copy_and_reindex_videos(
|
||||
f"videos/{video_key}/to_timestamp"
|
||||
]
|
||||
else:
|
||||
# Build list of time ranges to keep, in sorted order.
|
||||
# Build list of frame ranges to keep, in sorted order.
|
||||
sorted_keep_episodes = sorted(episodes_in_file, key=lambda x: episode_mapping[x])
|
||||
episodes_to_keep_ranges: list[tuple[float, float]] = []
|
||||
|
||||
episodes_to_keep_ranges: list[tuple[int, int]] = []
|
||||
for old_idx in sorted_keep_episodes:
|
||||
src_ep = src_dataset.meta.episodes[old_idx]
|
||||
from_ts = src_ep[f"videos/{video_key}/from_timestamp"]
|
||||
to_ts = src_ep[f"videos/{video_key}/to_timestamp"]
|
||||
episodes_to_keep_ranges.append((from_ts, to_ts))
|
||||
from_frame = round(src_ep[f"videos/{video_key}/from_timestamp"] * src_dataset.meta.fps)
|
||||
to_frame = round(src_ep[f"videos/{video_key}/to_timestamp"] * src_dataset.meta.fps)
|
||||
assert src_ep["length"] == to_frame - from_frame, (
|
||||
f"Episode length mismatch: {src_ep['length']} vs {to_frame - from_frame}"
|
||||
)
|
||||
episodes_to_keep_ranges.append((from_frame, to_frame))
|
||||
|
||||
# Use PyAV filters to efficiently re-encode only the desired segments.
|
||||
assert src_dataset.meta.video_path is not None
|
||||
@@ -1769,3 +1775,296 @@ def convert_image_to_video_dataset(
|
||||
|
||||
# Return new dataset
|
||||
return LeRobotDataset(repo_id=repo_id, root=output_dir)
|
||||
|
||||
|
||||
def trim_episodes_by_frames(
|
||||
dataset: LeRobotDataset,
|
||||
episode_frames_to_keep: dict[int, list[int]],
|
||||
output_dir: str | Path | None = None,
|
||||
repo_id: str | None = None,
|
||||
) -> LeRobotDataset:
|
||||
"""Trim multiple episodes to keep only specific frames.
|
||||
|
||||
This function creates a new dataset where the specified episodes contain only
|
||||
the frames at the given indices. All other episodes are copied as-is.
|
||||
|
||||
Args:
|
||||
dataset: The source LeRobotDataset.
|
||||
episode_frames_to_keep: Dict mapping episode indices to lists of global frame indices to keep.
|
||||
output_dir: Directory to save the new dataset. If None, uses default location.
|
||||
repo_id: Repository ID for the new dataset. If None, appends "_trimmed" to original.
|
||||
|
||||
Returns:
|
||||
A new LeRobotDataset with the trimmed episodes.
|
||||
"""
|
||||
if not episode_frames_to_keep:
|
||||
raise ValueError("No episodes to trim")
|
||||
|
||||
for ep_idx in episode_frames_to_keep:
|
||||
if ep_idx >= dataset.meta.total_episodes:
|
||||
raise ValueError(f"Episode {ep_idx} does not exist")
|
||||
if not episode_frames_to_keep[ep_idx]:
|
||||
raise ValueError(f"No frames to keep for episode {ep_idx}")
|
||||
|
||||
if repo_id is None:
|
||||
repo_id = f"{dataset.repo_id}_trimmed"
|
||||
output_dir = Path(output_dir) if output_dir is not None else HF_LEROBOT_HOME / repo_id
|
||||
|
||||
total_trimmed = sum(len(frames) for frames in episode_frames_to_keep.values())
|
||||
logging.info(f"Trimming {len(episode_frames_to_keep)} episodes, keeping {total_trimmed} frames total")
|
||||
|
||||
# Create new metadata
|
||||
new_meta = LeRobotDatasetMetadata.create(
|
||||
repo_id=repo_id,
|
||||
fps=dataset.meta.fps,
|
||||
features=dataset.meta.features,
|
||||
robot_type=dataset.meta.robot_type,
|
||||
root=output_dir,
|
||||
use_videos=len(dataset.meta.video_keys) > 0,
|
||||
)
|
||||
|
||||
# Build set of all frames to keep (for episodes being trimmed)
|
||||
# and compute new frame counts per episode
|
||||
all_keep_frames: set[int] = set()
|
||||
trimmed_frame_counts: dict[int, int] = {}
|
||||
for ep_idx, frames in episode_frames_to_keep.items():
|
||||
all_keep_frames.update(frames)
|
||||
trimmed_frame_counts[ep_idx] = len(frames)
|
||||
|
||||
# Copy and filter data
|
||||
_copy_and_reindex_data_with_multi_frame_filter(
|
||||
dataset, new_meta, episode_frames_to_keep, all_keep_frames
|
||||
)
|
||||
|
||||
# Handle videos if present
|
||||
if dataset.meta.video_keys:
|
||||
_copy_and_reindex_videos_with_multi_frame_filter(
|
||||
dataset, new_meta, episode_frames_to_keep
|
||||
)
|
||||
|
||||
# Copy episode metadata
|
||||
_copy_and_reindex_episodes_metadata_for_multi_trim(
|
||||
dataset, new_meta, trimmed_frame_counts
|
||||
)
|
||||
|
||||
logging.info(f"Created trimmed dataset with {new_meta.total_frames} frames at {output_dir}")
|
||||
|
||||
# Return the metadata instead of trying to load as LeRobotDataset
|
||||
# This avoids Hub validation issues when the repo doesn't exist yet
|
||||
return new_meta
|
||||
|
||||
|
||||
# Keep old function for backward compatibility
|
||||
def trim_episode_by_frames(
|
||||
dataset: LeRobotDataset,
|
||||
episode_index: int,
|
||||
keep_frame_indices: list[int],
|
||||
output_dir: str | Path | None = None,
|
||||
repo_id: str | None = None,
|
||||
) -> LeRobotDataset:
|
||||
"""Trim a single episode. Wrapper around trim_episodes_by_frames."""
|
||||
return trim_episodes_by_frames(
|
||||
dataset,
|
||||
episode_frames_to_keep={episode_index: keep_frame_indices},
|
||||
output_dir=output_dir,
|
||||
repo_id=repo_id,
|
||||
)
|
||||
|
||||
|
||||
def _copy_and_reindex_data_with_multi_frame_filter(
|
||||
src_dataset: LeRobotDataset,
|
||||
dst_meta: LeRobotDatasetMetadata,
|
||||
episode_frames_to_keep: dict[int, list[int]],
|
||||
all_keep_frames: set[int],
|
||||
) -> None:
|
||||
"""Copy data files with frame-level filtering for multiple episodes."""
|
||||
if src_dataset.meta.episodes is None:
|
||||
src_dataset.meta.episodes = load_episodes(src_dataset.meta.root)
|
||||
|
||||
# Copy tasks
|
||||
if dst_meta.tasks is None and src_dataset.meta.tasks is not None:
|
||||
# Tasks are stored with task string as index
|
||||
dst_meta.save_episode_tasks(list(src_dataset.meta.tasks.index))
|
||||
|
||||
# Get all parquet files
|
||||
data_dir = src_dataset.root / "data"
|
||||
parquet_files = sorted(data_dir.glob("chunk-*/file-*.parquet"))
|
||||
|
||||
trim_episode_set = set(episode_frames_to_keep.keys())
|
||||
global_index = 0
|
||||
|
||||
for parquet_path in tqdm(parquet_files, desc="Processing data files"):
|
||||
df = pd.read_parquet(parquet_path)
|
||||
|
||||
# Filter: keep all frames from non-trimmed episodes,
|
||||
# and only specified frames from trimmed episodes
|
||||
mask = (~df["episode_index"].isin(trim_episode_set)) | (df["index"].isin(all_keep_frames))
|
||||
df = df[mask].copy().reset_index(drop=True)
|
||||
|
||||
if len(df) == 0:
|
||||
continue
|
||||
|
||||
# Reindex
|
||||
df["index"] = range(global_index, global_index + len(df))
|
||||
|
||||
# Recalculate frame_index within each episode
|
||||
for ep_idx in df["episode_index"].unique():
|
||||
ep_mask = df["episode_index"] == ep_idx
|
||||
df.loc[ep_mask, "frame_index"] = range(ep_mask.sum())
|
||||
|
||||
# Recalculate timestamps based on frame_index and fps
|
||||
df["timestamp"] = df["frame_index"] / src_dataset.meta.fps
|
||||
|
||||
# Determine output path (keep same structure)
|
||||
rel_path = parquet_path.relative_to(src_dataset.root)
|
||||
dst_path = dst_meta.root / rel_path
|
||||
dst_path.parent.mkdir(parents=True, exist_ok=True)
|
||||
|
||||
_write_parquet(df, dst_path, dst_meta)
|
||||
global_index += len(df)
|
||||
|
||||
|
||||
def _copy_and_reindex_videos_with_multi_frame_filter(
|
||||
src_dataset: LeRobotDataset,
|
||||
dst_meta: LeRobotDatasetMetadata,
|
||||
episode_frames_to_keep: dict[int, list[int]],
|
||||
) -> None:
|
||||
"""Copy video files for trimmed dataset.
|
||||
|
||||
In v3.0 datasets, multiple episodes are concatenated into single video files.
|
||||
Each episode has from_timestamp/to_timestamp indicating its portion of the video.
|
||||
|
||||
For trimming, we copy the original video files as-is and update the metadata
|
||||
timestamps in _copy_and_reindex_episodes_metadata_for_multi_trim.
|
||||
"""
|
||||
for video_key in src_dataset.meta.video_keys:
|
||||
video_dir = src_dataset.root / "videos" / video_key
|
||||
dst_video_dir = dst_meta.root / "videos" / video_key
|
||||
|
||||
if not video_dir.exists():
|
||||
logging.warning(f"Video directory not found: {video_dir}")
|
||||
continue
|
||||
|
||||
# Copy all video files (they contain concatenated episodes)
|
||||
# The metadata timestamps will handle which portions to use
|
||||
copied_files = set()
|
||||
for chunk_dir in video_dir.glob("chunk-*"):
|
||||
dst_chunk_dir = dst_video_dir / chunk_dir.name
|
||||
dst_chunk_dir.mkdir(parents=True, exist_ok=True)
|
||||
|
||||
for video_file in chunk_dir.glob("*.mp4"):
|
||||
if video_file.name not in copied_files:
|
||||
dst_path = dst_chunk_dir / video_file.name
|
||||
if not dst_path.exists():
|
||||
shutil.copy(video_file, dst_path)
|
||||
copied_files.add(video_file.name)
|
||||
|
||||
logging.info(f"Copied {len(copied_files)} video files for {video_key}")
|
||||
|
||||
|
||||
def _trim_video_frames(
|
||||
src_path: Path,
|
||||
dst_path: Path,
|
||||
keep_frame_indices: list[int],
|
||||
fps: float,
|
||||
episode_start_idx: int,
|
||||
) -> None:
|
||||
"""Trim a video to keep only specific frames using ffmpeg."""
|
||||
import subprocess
|
||||
|
||||
# Convert global indices to local indices within the episode
|
||||
local_indices = sorted([idx - episode_start_idx for idx in keep_frame_indices])
|
||||
|
||||
if not local_indices:
|
||||
logging.warning(f"No frames to keep for video {src_path}")
|
||||
return
|
||||
|
||||
# Calculate start and end times
|
||||
start_frame = local_indices[0]
|
||||
end_frame = local_indices[-1]
|
||||
|
||||
start_time = start_frame / fps
|
||||
duration = (end_frame - start_frame + 1) / fps
|
||||
|
||||
# Use ffmpeg to trim
|
||||
cmd = [
|
||||
"ffmpeg", "-y",
|
||||
"-ss", str(start_time),
|
||||
"-i", str(src_path),
|
||||
"-t", str(duration),
|
||||
"-c", "copy", # Fast copy without re-encoding
|
||||
str(dst_path)
|
||||
]
|
||||
|
||||
try:
|
||||
subprocess.run(cmd, check=True, capture_output=True)
|
||||
except subprocess.CalledProcessError as e:
|
||||
logging.error(f"Failed to trim video: {e.stderr.decode()}")
|
||||
# Fallback: copy the whole video
|
||||
shutil.copy(src_path, dst_path)
|
||||
|
||||
|
||||
def _copy_and_reindex_episodes_metadata_for_multi_trim(
|
||||
src_dataset: LeRobotDataset,
|
||||
dst_meta: LeRobotDatasetMetadata,
|
||||
trimmed_frame_counts: dict[int, int],
|
||||
) -> None:
|
||||
"""Copy and update episode metadata for trimmed dataset."""
|
||||
if src_dataset.meta.episodes is None:
|
||||
src_dataset.meta.episodes = load_episodes(src_dataset.meta.root)
|
||||
|
||||
# Calculate new frame counts and indices
|
||||
episodes_data = []
|
||||
global_idx = 0
|
||||
|
||||
for old_ep_idx in range(src_dataset.meta.total_episodes):
|
||||
src_ep = src_dataset.meta.episodes[old_ep_idx]
|
||||
|
||||
if old_ep_idx in trimmed_frame_counts:
|
||||
ep_length = trimmed_frame_counts[old_ep_idx]
|
||||
else:
|
||||
ep_length = src_ep["length"]
|
||||
|
||||
ep_data = {
|
||||
"episode_index": old_ep_idx,
|
||||
"tasks": src_ep["tasks"],
|
||||
"length": ep_length,
|
||||
"data/chunk_index": src_ep["data/chunk_index"],
|
||||
"data/file_index": src_ep["data/file_index"],
|
||||
"dataset_from_index": global_idx,
|
||||
"dataset_to_index": global_idx + ep_length,
|
||||
}
|
||||
|
||||
# Copy video metadata - preserve timestamps for concatenated videos
|
||||
for video_key in src_dataset.meta.video_keys:
|
||||
ep_data[f"videos/{video_key}/chunk_index"] = src_ep[f"videos/{video_key}/chunk_index"]
|
||||
ep_data[f"videos/{video_key}/file_index"] = src_ep[f"videos/{video_key}/file_index"]
|
||||
|
||||
# Keep original from_timestamp (start position in concatenated video)
|
||||
orig_from_ts = src_ep[f"videos/{video_key}/from_timestamp"]
|
||||
ep_data[f"videos/{video_key}/from_timestamp"] = orig_from_ts
|
||||
|
||||
# For trimmed episodes, update to_timestamp based on new length
|
||||
# For non-trimmed episodes, keep original to_timestamp
|
||||
if old_ep_idx in trimmed_frame_counts:
|
||||
ep_data[f"videos/{video_key}/to_timestamp"] = orig_from_ts + (ep_length / src_dataset.meta.fps)
|
||||
else:
|
||||
ep_data[f"videos/{video_key}/to_timestamp"] = src_ep[f"videos/{video_key}/to_timestamp"]
|
||||
|
||||
ep_data["meta/episodes/chunk_index"] = 0
|
||||
ep_data["meta/episodes/file_index"] = 0
|
||||
|
||||
episodes_data.append(ep_data)
|
||||
global_idx += ep_length
|
||||
|
||||
# Save episodes metadata
|
||||
df = pd.DataFrame(episodes_data)
|
||||
episodes_path = dst_meta.root / DEFAULT_EPISODES_PATH.format(chunk_index=0, file_index=0)
|
||||
episodes_path.parent.mkdir(parents=True, exist_ok=True)
|
||||
df.to_parquet(episodes_path)
|
||||
|
||||
# Update info.json
|
||||
info = load_info(src_dataset.root)
|
||||
info["total_episodes"] = len(episodes_data)
|
||||
info["total_frames"] = global_idx
|
||||
write_info(info, dst_meta.root)
|
||||
|
||||
@@ -747,7 +747,7 @@ class LeRobotDataset(torch.utils.data.Dataset):
|
||||
# Check if cached dataset contains all requested episodes
|
||||
if not self._check_cached_episodes_sufficient():
|
||||
raise FileNotFoundError("Cached dataset doesn't contain all requested episodes")
|
||||
except (AssertionError, FileNotFoundError, NotADirectoryError):
|
||||
except (FileNotFoundError, NotADirectoryError):
|
||||
if is_valid_version(self.revision):
|
||||
self.revision = get_safe_version(self.repo_id, self.revision)
|
||||
self.download(download_videos)
|
||||
@@ -839,7 +839,7 @@ class LeRobotDataset(torch.utils.data.Dataset):
|
||||
hub_api.upload_folder(**upload_kwargs)
|
||||
|
||||
card = create_lerobot_dataset_card(
|
||||
tags=tags, dataset_info=self.meta.info, license=license, **card_kwargs
|
||||
tags=tags, dataset_info=self.meta.info, license=license, repo_id=self.repo_id, **card_kwargs
|
||||
)
|
||||
card.push_to_hub(repo_id=self.repo_id, repo_type="dataset", revision=branch)
|
||||
|
||||
|
||||
@@ -227,16 +227,17 @@ def decode_video_frames_torchvision(
|
||||
min_, argmin_ = dist.min(1)
|
||||
|
||||
is_within_tol = min_ < tolerance_s
|
||||
assert is_within_tol.all(), (
|
||||
f"One or several query timestamps unexpectedly violate the tolerance ({min_[~is_within_tol]} > {tolerance_s=})."
|
||||
"It means that the closest frame that can be loaded from the video is too far away in time."
|
||||
"This might be due to synchronization issues with timestamps during data collection."
|
||||
"To be safe, we advise to ignore this item during training."
|
||||
f"\nqueried timestamps: {query_ts}"
|
||||
f"\nloaded timestamps: {loaded_ts}"
|
||||
f"\nvideo: {video_path}"
|
||||
f"\nbackend: {backend}"
|
||||
)
|
||||
if not is_within_tol.all():
|
||||
raise FrameTimestampError(
|
||||
f"One or several query timestamps unexpectedly violate the tolerance ({min_[~is_within_tol]} > {tolerance_s=})."
|
||||
" It means that the closest frame that can be loaded from the video is too far away in time."
|
||||
" This might be due to synchronization issues with timestamps during data collection."
|
||||
" To be safe, we advise to ignore this item during training."
|
||||
f"\nqueried timestamps: {query_ts}"
|
||||
f"\nloaded timestamps: {loaded_ts}"
|
||||
f"\nvideo: {video_path}"
|
||||
f"\nbackend: {backend}"
|
||||
)
|
||||
|
||||
# get closest frames to the query timestamps
|
||||
closest_frames = torch.stack([loaded_frames[idx] for idx in argmin_])
|
||||
@@ -248,7 +249,11 @@ def decode_video_frames_torchvision(
|
||||
# convert to the pytorch format which is float32 in [0,1] range (and channel first)
|
||||
closest_frames = closest_frames.type(torch.float32) / 255
|
||||
|
||||
assert len(timestamps) == len(closest_frames)
|
||||
if len(timestamps) != len(closest_frames):
|
||||
raise FrameTimestampError(
|
||||
f"Number of retrieved frames ({len(closest_frames)}) does not match "
|
||||
f"number of queried timestamps ({len(timestamps)})"
|
||||
)
|
||||
return closest_frames
|
||||
|
||||
|
||||
@@ -353,15 +358,16 @@ def decode_video_frames_torchcodec(
|
||||
min_, argmin_ = dist.min(1)
|
||||
|
||||
is_within_tol = min_ < tolerance_s
|
||||
assert is_within_tol.all(), (
|
||||
f"One or several query timestamps unexpectedly violate the tolerance ({min_[~is_within_tol]} > {tolerance_s=})."
|
||||
"It means that the closest frame that can be loaded from the video is too far away in time."
|
||||
"This might be due to synchronization issues with timestamps during data collection."
|
||||
"To be safe, we advise to ignore this item during training."
|
||||
f"\nqueried timestamps: {query_ts}"
|
||||
f"\nloaded timestamps: {loaded_ts}"
|
||||
f"\nvideo: {video_path}"
|
||||
)
|
||||
if not is_within_tol.all():
|
||||
raise FrameTimestampError(
|
||||
f"One or several query timestamps unexpectedly violate the tolerance ({min_[~is_within_tol]} > {tolerance_s=})."
|
||||
" It means that the closest frame that can be loaded from the video is too far away in time."
|
||||
" This might be due to synchronization issues with timestamps during data collection."
|
||||
" To be safe, we advise to ignore this item during training."
|
||||
f"\nqueried timestamps: {query_ts}"
|
||||
f"\nloaded timestamps: {loaded_ts}"
|
||||
f"\nvideo: {video_path}"
|
||||
)
|
||||
|
||||
# get closest frames to the query timestamps
|
||||
closest_frames = torch.stack([loaded_frames[idx] for idx in argmin_])
|
||||
|
||||
@@ -139,6 +139,10 @@ class DiffusionConfig(PreTrainedConfig):
|
||||
# Inference
|
||||
num_inference_steps: int | None = None
|
||||
|
||||
# Optimization
|
||||
compile_model: bool = False
|
||||
compile_mode: str = "reduce-overhead"
|
||||
|
||||
# Loss computation
|
||||
do_mask_loss_for_padding: bool = False
|
||||
|
||||
|
||||
@@ -142,6 +142,9 @@ class DiffusionPolicy(PreTrainedPolicy):
|
||||
"""Run the batch through the model and compute the loss for training or validation."""
|
||||
if self.config.image_features:
|
||||
batch = dict(batch) # shallow copy so that adding a key doesn't modify the original
|
||||
for key in self.config.image_features:
|
||||
if self.config.n_obs_steps == 1 and batch[key].ndim == 4:
|
||||
batch[key] = batch[key].unsqueeze(1)
|
||||
batch[OBS_IMAGES] = torch.stack([batch[key] for key in self.config.image_features], dim=-4)
|
||||
loss = self.diffusion.compute_loss(batch)
|
||||
# no output_dict so returning None
|
||||
@@ -182,6 +185,11 @@ class DiffusionModel(nn.Module):
|
||||
|
||||
self.unet = DiffusionConditionalUnet1d(config, global_cond_dim=global_cond_dim * config.n_obs_steps)
|
||||
|
||||
if config.compile_model:
|
||||
# Compile the U-Net. "reduce-overhead" is preferred for the small-batch repetitive loops
|
||||
# common in diffusion inference.
|
||||
self.unet = torch.compile(self.unet, mode=config.compile_mode)
|
||||
|
||||
self.noise_scheduler = _make_noise_scheduler(
|
||||
config.noise_scheduler_type,
|
||||
num_train_timesteps=config.num_train_timesteps,
|
||||
|
||||
@@ -277,9 +277,7 @@ class SARMEncodingProcessorStep(ProcessorStep):
|
||||
|
||||
# When language is perturbed, targets are zero so perturbed samples don't contribute to progress loss
|
||||
if self.dataset_meta is not None:
|
||||
episodes_df = None
|
||||
if self.sparse_subtask_names != ["task"]:
|
||||
episodes_df = self.dataset_meta.episodes.to_pandas()
|
||||
episodes_df = self.dataset_meta.episodes.to_pandas()
|
||||
|
||||
# Generate sparse targets
|
||||
if self.sparse_temporal_proportions is not None:
|
||||
|
||||
@@ -104,6 +104,28 @@ Convert image dataset to video format and push to hub:
|
||||
--operation.type convert_image_to_video \
|
||||
--push_to_hub true
|
||||
|
||||
Trim single episode to keep only frames within timestamp range:
|
||||
python -m lerobot.scripts.lerobot_edit_dataset \
|
||||
--repo_id lerobot/pusht \
|
||||
--new_repo_id lerobot/pusht_trimmed \
|
||||
--operation.type trim_episode \
|
||||
--operation.episode_index 0 \
|
||||
--operation.start_timestamp 10.0 \
|
||||
--operation.end_timestamp 30.0
|
||||
|
||||
Trim multiple episodes at once (use null for no limit):
|
||||
python -m lerobot.scripts.lerobot_edit_dataset \
|
||||
--repo_id lerobot/pusht \
|
||||
--operation.type trim_episode \
|
||||
--operation.episode_trims '{"0": [10.0, 30.0], "2": [5.0, null], "3": [null, 20.0]}'
|
||||
|
||||
Trim and re-upload to same repo (overwrites original):
|
||||
python -m lerobot.scripts.lerobot_edit_dataset \
|
||||
--repo_id lerobot/pusht \
|
||||
--operation.type trim_episode \
|
||||
--operation.episode_index 0 \
|
||||
--operation.start_timestamp 10.0 \
|
||||
--push_to_hub true
|
||||
Show dataset information:
|
||||
lerobot-edit-dataset \
|
||||
--repo_id lerobot/pusht_image \
|
||||
@@ -204,9 +226,32 @@ class InfoConfig(OperationConfig):
|
||||
show_features: bool = False
|
||||
|
||||
|
||||
@dataclass
|
||||
class TrimEpisodeConfig:
|
||||
"""Trim episodes to keep only frames within timestamp ranges.
|
||||
|
||||
Supports multiple episodes via episode_trims dict:
|
||||
--operation.episode_trims '{"0": [10.0, 30.0], "2": [5.0, 20.0]}'
|
||||
|
||||
Or single episode via legacy parameters:
|
||||
--operation.episode_index 0 --operation.start_timestamp 10.0 --operation.end_timestamp 30.0
|
||||
"""
|
||||
type: str = "trim_episode"
|
||||
# Multi-episode support: dict mapping episode_index -> [start_timestamp, end_timestamp]
|
||||
# Use null for no limit, e.g. {"0": [10.0, null], "2": [null, 30.0]}
|
||||
episode_trims: dict[str, list[float | None]] | None = None
|
||||
# Legacy single-episode parameters (used if episode_trims is None)
|
||||
episode_index: int | None = None
|
||||
start_timestamp: float | None = None # Keep frames from this timestamp (inclusive)
|
||||
end_timestamp: float | None = None # Keep frames until this timestamp (inclusive)
|
||||
|
||||
|
||||
@dataclass
|
||||
class EditDatasetConfig:
|
||||
repo_id: str
|
||||
operation: (
|
||||
DeleteEpisodesConfig | SplitConfig | MergeConfig | RemoveFeatureConfig | ConvertImageToVideoConfig | TrimEpisodeConfig
|
||||
)
|
||||
operation: OperationConfig
|
||||
root: str | None = None
|
||||
new_repo_id: str | None = None
|
||||
@@ -351,6 +396,92 @@ def handle_remove_feature(cfg: EditDatasetConfig) -> None:
|
||||
LeRobotDataset(output_repo_id, root=output_dir).push_to_hub()
|
||||
|
||||
|
||||
def handle_trim_episode(cfg: EditDatasetConfig) -> None:
|
||||
"""Trim episodes to keep only frames within timestamp ranges."""
|
||||
if not isinstance(cfg.operation, TrimEpisodeConfig):
|
||||
raise ValueError("Operation config must be TrimEpisodeConfig")
|
||||
|
||||
# Parse episode trims - support both multi-episode dict and legacy single episode
|
||||
episode_trims: dict[int, tuple[float | None, float | None]] = {}
|
||||
|
||||
if cfg.operation.episode_trims is not None:
|
||||
# Multi-episode mode
|
||||
for ep_str, ts_range in cfg.operation.episode_trims.items():
|
||||
ep_idx = int(ep_str)
|
||||
start_ts = ts_range[0] if len(ts_range) > 0 else None
|
||||
end_ts = ts_range[1] if len(ts_range) > 1 else None
|
||||
episode_trims[ep_idx] = (start_ts, end_ts)
|
||||
elif cfg.operation.episode_index is not None:
|
||||
# Legacy single-episode mode
|
||||
if cfg.operation.start_timestamp is None and cfg.operation.end_timestamp is None:
|
||||
raise ValueError("At least one of start_timestamp or end_timestamp must be specified")
|
||||
episode_trims[cfg.operation.episode_index] = (
|
||||
cfg.operation.start_timestamp,
|
||||
cfg.operation.end_timestamp,
|
||||
)
|
||||
else:
|
||||
raise ValueError("Either episode_trims or episode_index must be specified")
|
||||
|
||||
dataset = LeRobotDataset(cfg.repo_id, root=cfg.root)
|
||||
output_repo_id, output_dir = get_output_path(
|
||||
cfg.repo_id, cfg.new_repo_id, Path(cfg.root) if cfg.root else None
|
||||
)
|
||||
|
||||
if cfg.new_repo_id is None:
|
||||
dataset.root = Path(str(dataset.root) + "_old")
|
||||
|
||||
logging.info(f"Trimming {len(episode_trims)} episode(s) from {cfg.repo_id}")
|
||||
|
||||
# Get episode boundaries and find frames to keep for each episode
|
||||
episodes_info = dataset.meta.episodes
|
||||
all_frames_to_keep: dict[int, list[int]] = {}
|
||||
|
||||
for ep_idx, (start_ts, end_ts) in episode_trims.items():
|
||||
if ep_idx >= len(episodes_info["episode_index"]):
|
||||
raise ValueError(f"Episode {ep_idx} does not exist (dataset has {len(episodes_info['episode_index'])} episodes)")
|
||||
|
||||
from_frame = episodes_info["dataset_from_index"][ep_idx]
|
||||
to_frame = episodes_info["dataset_to_index"][ep_idx]
|
||||
|
||||
logging.info(f"Episode {ep_idx}: trimming to [{start_ts}, {end_ts}]")
|
||||
logging.info(f" Original frames: {from_frame} to {to_frame} ({to_frame - from_frame} frames)")
|
||||
|
||||
# Find frames within timestamp range
|
||||
frames_to_keep = []
|
||||
for frame_idx in range(from_frame, to_frame):
|
||||
frame = dataset.hf_dataset[frame_idx]
|
||||
ts = frame["timestamp"]
|
||||
|
||||
in_range = True
|
||||
if start_ts is not None and ts < start_ts:
|
||||
in_range = False
|
||||
if end_ts is not None and ts > end_ts:
|
||||
in_range = False
|
||||
|
||||
if in_range:
|
||||
frames_to_keep.append(frame_idx)
|
||||
|
||||
if not frames_to_keep:
|
||||
raise ValueError(f"Episode {ep_idx}: No frames found in timestamp range [{start_ts}, {end_ts}]")
|
||||
|
||||
logging.info(f" Keeping {len(frames_to_keep)} frames (indices {frames_to_keep[0]} to {frames_to_keep[-1]})")
|
||||
all_frames_to_keep[ep_idx] = frames_to_keep
|
||||
|
||||
from lerobot.datasets.dataset_tools import trim_episodes_by_frames
|
||||
|
||||
new_dataset = trim_episodes_by_frames(
|
||||
dataset,
|
||||
episode_frames_to_keep=all_frames_to_keep,
|
||||
output_dir=output_dir,
|
||||
repo_id=output_repo_id,
|
||||
)
|
||||
|
||||
logging.info(f"Dataset saved to {output_dir}")
|
||||
logging.info(f"Episodes: {new_dataset.meta.total_episodes}, Frames: {new_dataset.meta.total_frames}")
|
||||
|
||||
if cfg.push_to_hub:
|
||||
logging.info(f"Pushing to hub as {output_repo_id}")
|
||||
LeRobotDataset(output_repo_id, root=output_dir).push_to_hub()
|
||||
def handle_modify_tasks(cfg: EditDatasetConfig) -> None:
|
||||
if not isinstance(cfg.operation, ModifyTasksConfig):
|
||||
raise ValueError("Operation config must be ModifyTasksConfig")
|
||||
@@ -515,6 +646,8 @@ def edit_dataset(cfg: EditDatasetConfig) -> None:
|
||||
handle_modify_tasks(cfg)
|
||||
elif operation_type == "convert_image_to_video":
|
||||
handle_convert_image_to_video(cfg)
|
||||
elif operation_type == "trim_episode":
|
||||
handle_trim_episode(cfg)
|
||||
elif operation_type == "info":
|
||||
handle_info(cfg)
|
||||
else:
|
||||
|
||||
@@ -43,6 +43,7 @@ from lerobot.teleoperators import ( # noqa: F401
|
||||
koch_leader,
|
||||
make_teleoperator_from_config,
|
||||
omx_leader,
|
||||
openarm_mini,
|
||||
so_leader,
|
||||
)
|
||||
|
||||
@@ -51,6 +52,7 @@ COMPATIBLE_DEVICES = [
|
||||
"koch_leader",
|
||||
"omx_follower",
|
||||
"omx_leader",
|
||||
"openarm_mini",
|
||||
"so100_follower",
|
||||
"so100_leader",
|
||||
"so101_follower",
|
||||
|
||||
@@ -24,6 +24,7 @@ import torch
|
||||
from accelerate import Accelerator
|
||||
from termcolor import colored
|
||||
from torch.optim import Optimizer
|
||||
from tqdm import tqdm
|
||||
|
||||
from lerobot.configs import parser
|
||||
from lerobot.configs.train import TrainPipelineConfig
|
||||
@@ -51,6 +52,7 @@ from lerobot.utils.utils import (
|
||||
format_big_number,
|
||||
has_method,
|
||||
init_logging,
|
||||
inside_slurm,
|
||||
)
|
||||
|
||||
|
||||
@@ -390,6 +392,14 @@ def train(cfg: TrainPipelineConfig, accelerator: Accelerator | None = None):
|
||||
)
|
||||
|
||||
if is_main_process:
|
||||
progbar = tqdm(
|
||||
total=cfg.steps - step,
|
||||
desc="Training",
|
||||
unit="step",
|
||||
disable=inside_slurm(),
|
||||
position=0,
|
||||
leave=True,
|
||||
)
|
||||
logging.info(
|
||||
f"Start offline training on a fixed dataset, with effective batch size: {effective_batch_size}"
|
||||
)
|
||||
@@ -414,6 +424,8 @@ def train(cfg: TrainPipelineConfig, accelerator: Accelerator | None = None):
|
||||
# Note: eval and checkpoint happens *after* the `step`th training update has completed, so we
|
||||
# increment `step` here.
|
||||
step += 1
|
||||
if is_main_process:
|
||||
progbar.update(1)
|
||||
train_tracker.step()
|
||||
is_log_step = cfg.log_freq > 0 and step % cfg.log_freq == 0 and is_main_process
|
||||
is_saving_step = step % cfg.save_freq == 0 or step == cfg.steps
|
||||
@@ -507,6 +519,9 @@ def train(cfg: TrainPipelineConfig, accelerator: Accelerator | None = None):
|
||||
|
||||
accelerator.wait_for_everyone()
|
||||
|
||||
if is_main_process:
|
||||
progbar.close()
|
||||
|
||||
if eval_env:
|
||||
close_envs(eval_env)
|
||||
|
||||
|
||||
@@ -0,0 +1,20 @@
|
||||
#!/usr/bin/env python
|
||||
|
||||
# Copyright 2025 The HuggingFace Inc. team. All rights reserved.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
from .config_openarm_mini import OpenArmMiniConfig
|
||||
from .openarm_mini import OpenArmMini
|
||||
|
||||
__all__ = ["OpenArmMini", "OpenArmMiniConfig"]
|
||||
@@ -0,0 +1,30 @@
|
||||
#!/usr/bin/env python
|
||||
|
||||
# Copyright 2025 The HuggingFace Inc. team. All rights reserved.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
from dataclasses import dataclass
|
||||
|
||||
from ..config import TeleoperatorConfig
|
||||
|
||||
|
||||
@TeleoperatorConfig.register_subclass("openarm_mini")
|
||||
@dataclass
|
||||
class OpenArmMiniConfig(TeleoperatorConfig):
|
||||
"""Configuration for OpenArm Mini teleoperator with Feetech motors (dual arms)."""
|
||||
|
||||
port_right: str = "/dev/ttyUSB0"
|
||||
port_left: str = "/dev/ttyUSB1"
|
||||
|
||||
use_degrees: bool = True
|
||||
@@ -0,0 +1,296 @@
|
||||
#!/usr/bin/env python
|
||||
|
||||
# Copyright 2025 The HuggingFace Inc. team. All rights reserved.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
import logging
|
||||
import time
|
||||
from typing import Any
|
||||
|
||||
from lerobot.motors import Motor, MotorCalibration, MotorNormMode
|
||||
from lerobot.motors.feetech import (
|
||||
FeetechMotorsBus,
|
||||
OperatingMode,
|
||||
)
|
||||
from lerobot.processor import RobotAction
|
||||
from lerobot.utils.decorators import check_if_already_connected, check_if_not_connected
|
||||
|
||||
from ..teleoperator import Teleoperator
|
||||
from .config_openarm_mini import OpenArmMiniConfig
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
# Motors whose direction is inverted during readout
|
||||
RIGHT_MOTORS_TO_FLIP = ["joint_1", "joint_2", "joint_3", "joint_4", "joint_5"]
|
||||
LEFT_MOTORS_TO_FLIP = ["joint_1", "joint_3", "joint_4", "joint_5", "joint_6", "joint_7"]
|
||||
|
||||
|
||||
class OpenArmMini(Teleoperator):
|
||||
"""
|
||||
OpenArm Mini Teleoperator with dual Feetech-based arms (8 motors per arm).
|
||||
|
||||
Each arm has 7 joints plus a gripper, using Feetech STS3215 servos.
|
||||
"""
|
||||
|
||||
config_class = OpenArmMiniConfig
|
||||
name = "openarm_mini"
|
||||
|
||||
def __init__(self, config: OpenArmMiniConfig):
|
||||
super().__init__(config)
|
||||
self.config = config
|
||||
|
||||
norm_mode_body = MotorNormMode.DEGREES
|
||||
|
||||
motors_right = {
|
||||
"joint_1": Motor(1, "sts3215", norm_mode_body),
|
||||
"joint_2": Motor(2, "sts3215", norm_mode_body),
|
||||
"joint_3": Motor(3, "sts3215", norm_mode_body),
|
||||
"joint_4": Motor(4, "sts3215", norm_mode_body),
|
||||
"joint_5": Motor(5, "sts3215", norm_mode_body),
|
||||
"joint_6": Motor(6, "sts3215", norm_mode_body),
|
||||
"joint_7": Motor(7, "sts3215", norm_mode_body),
|
||||
"gripper": Motor(8, "sts3215", MotorNormMode.RANGE_0_100),
|
||||
}
|
||||
|
||||
motors_left = {
|
||||
"joint_1": Motor(1, "sts3215", norm_mode_body),
|
||||
"joint_2": Motor(2, "sts3215", norm_mode_body),
|
||||
"joint_3": Motor(3, "sts3215", norm_mode_body),
|
||||
"joint_4": Motor(4, "sts3215", norm_mode_body),
|
||||
"joint_5": Motor(5, "sts3215", norm_mode_body),
|
||||
"joint_6": Motor(6, "sts3215", norm_mode_body),
|
||||
"joint_7": Motor(7, "sts3215", norm_mode_body),
|
||||
"gripper": Motor(8, "sts3215", MotorNormMode.RANGE_0_100),
|
||||
}
|
||||
|
||||
cal_right = {
|
||||
k.replace("right_", ""): v for k, v in (self.calibration or {}).items() if k.startswith("right_")
|
||||
}
|
||||
cal_left = {
|
||||
k.replace("left_", ""): v for k, v in (self.calibration or {}).items() if k.startswith("left_")
|
||||
}
|
||||
|
||||
self.bus_right = FeetechMotorsBus(
|
||||
port=self.config.port_right,
|
||||
motors=motors_right,
|
||||
calibration=cal_right,
|
||||
)
|
||||
|
||||
self.bus_left = FeetechMotorsBus(
|
||||
port=self.config.port_left,
|
||||
motors=motors_left,
|
||||
calibration=cal_left,
|
||||
)
|
||||
|
||||
@property
|
||||
def action_features(self) -> dict[str, type]:
|
||||
features: dict[str, type] = {}
|
||||
for motor in self.bus_right.motors:
|
||||
features[f"right_{motor}.pos"] = float
|
||||
for motor in self.bus_left.motors:
|
||||
features[f"left_{motor}.pos"] = float
|
||||
return features
|
||||
|
||||
@property
|
||||
def feedback_features(self) -> dict[str, type]:
|
||||
return {}
|
||||
|
||||
@property
|
||||
def is_connected(self) -> bool:
|
||||
return self.bus_right.is_connected and self.bus_left.is_connected
|
||||
|
||||
@check_if_already_connected
|
||||
def connect(self, calibrate: bool = True) -> None:
|
||||
logger.info(f"Connecting right arm on {self.config.port_right}...")
|
||||
self.bus_right.connect()
|
||||
logger.info(f"Connecting left arm on {self.config.port_left}...")
|
||||
self.bus_left.connect()
|
||||
|
||||
if calibrate:
|
||||
self.calibrate()
|
||||
|
||||
self.configure()
|
||||
logger.info(f"{self} connected.")
|
||||
|
||||
@property
|
||||
def is_calibrated(self) -> bool:
|
||||
return self.bus_right.is_calibrated and self.bus_left.is_calibrated
|
||||
|
||||
def calibrate(self) -> None:
|
||||
"""
|
||||
Run calibration procedure for OpenArm Mini.
|
||||
|
||||
1. Disable torque
|
||||
2. Ask user to position arms in hanging position with grippers closed
|
||||
3. Set this as zero position via half-turn homing
|
||||
4. Interactive gripper calibration (open/close positions)
|
||||
5. Save calibration
|
||||
"""
|
||||
if self.calibration:
|
||||
user_input = input(
|
||||
f"Press ENTER to use existing calibration for {self.id}, "
|
||||
f"or type 'c' and press ENTER to run new calibration: "
|
||||
)
|
||||
if user_input.strip().lower() != "c":
|
||||
logger.info(f"Using existing calibration for {self.id}")
|
||||
cal_right = {
|
||||
k.replace("right_", ""): v for k, v in self.calibration.items() if k.startswith("right_")
|
||||
}
|
||||
cal_left = {
|
||||
k.replace("left_", ""): v for k, v in self.calibration.items() if k.startswith("left_")
|
||||
}
|
||||
self.bus_right.write_calibration(cal_right)
|
||||
self.bus_left.write_calibration(cal_left)
|
||||
return
|
||||
|
||||
logger.info(f"\nRunning calibration for {self}")
|
||||
|
||||
self._calibrate_arm("right", self.bus_right)
|
||||
self._calibrate_arm("left", self.bus_left)
|
||||
|
||||
self._save_calibration()
|
||||
print(f"\nCalibration complete and saved to {self.calibration_fpath}")
|
||||
|
||||
def _calibrate_arm(self, arm_name: str, bus: FeetechMotorsBus) -> None:
|
||||
"""Calibrate a single arm with Feetech motors."""
|
||||
logger.info(f"\n=== Calibrating {arm_name.upper()} arm ===")
|
||||
|
||||
bus.disable_torque()
|
||||
|
||||
logger.info(f"Setting Phase to 12 for all motors in {arm_name.upper()} arm...")
|
||||
for motor in bus.motors:
|
||||
bus.write("Phase", motor, 12)
|
||||
|
||||
for motor in bus.motors:
|
||||
bus.write("Operating_Mode", motor, OperatingMode.POSITION.value)
|
||||
|
||||
input(
|
||||
f"\nCalibration: Zero Position ({arm_name.upper()} arm)\n"
|
||||
"Position the arm in the following configuration:\n"
|
||||
" - Arm hanging straight down\n"
|
||||
" - Gripper closed\n"
|
||||
"Press ENTER when ready..."
|
||||
)
|
||||
|
||||
homing_offsets = bus.set_half_turn_homings()
|
||||
logger.info(f"{arm_name.capitalize()} arm zero position set.")
|
||||
|
||||
print(f"\nSetting motor ranges for {arm_name.upper()} arm\n")
|
||||
|
||||
if self.calibration is None:
|
||||
self.calibration = {}
|
||||
|
||||
motor_resolution = bus.model_resolution_table[list(bus.motors.values())[0].model]
|
||||
max_res = motor_resolution - 1
|
||||
|
||||
for motor_name, motor in bus.motors.items():
|
||||
prefixed_name = f"{arm_name}_{motor_name}"
|
||||
|
||||
if motor_name == "gripper":
|
||||
input(
|
||||
f"\nGripper Calibration ({arm_name.upper()} arm)\n"
|
||||
f"Step 1: CLOSE the gripper fully\n"
|
||||
f"Press ENTER when gripper is closed..."
|
||||
)
|
||||
closed_pos = bus.read("Present_Position", motor_name, normalize=False)
|
||||
logger.info(f" Gripper closed position recorded: {closed_pos}")
|
||||
|
||||
input("\nStep 2: OPEN the gripper fully\nPress ENTER when gripper is fully open...")
|
||||
open_pos = bus.read("Present_Position", motor_name, normalize=False)
|
||||
logger.info(f" Gripper open position recorded: {open_pos}")
|
||||
|
||||
if closed_pos < open_pos:
|
||||
range_min = int(closed_pos)
|
||||
range_max = int(open_pos)
|
||||
drive_mode = 0
|
||||
else:
|
||||
range_min = int(open_pos)
|
||||
range_max = int(closed_pos)
|
||||
drive_mode = 1
|
||||
|
||||
logger.info(
|
||||
f" {prefixed_name}: range set to [{range_min}, {range_max}] "
|
||||
f"(0=closed, 100=open, drive_mode={drive_mode})"
|
||||
)
|
||||
else:
|
||||
range_min = 0
|
||||
range_max = max_res
|
||||
drive_mode = 0
|
||||
logger.info(f" {prefixed_name}: range set to [0, {max_res}] (full motor range)")
|
||||
|
||||
self.calibration[prefixed_name] = MotorCalibration(
|
||||
id=motor.id,
|
||||
drive_mode=drive_mode,
|
||||
homing_offset=homing_offsets[motor_name],
|
||||
range_min=range_min,
|
||||
range_max=range_max,
|
||||
)
|
||||
|
||||
cal_for_bus = {
|
||||
k.replace(f"{arm_name}_", ""): v
|
||||
for k, v in self.calibration.items()
|
||||
if k.startswith(f"{arm_name}_")
|
||||
}
|
||||
bus.write_calibration(cal_for_bus)
|
||||
|
||||
def configure(self) -> None:
|
||||
self.bus_right.disable_torque()
|
||||
self.bus_right.configure_motors()
|
||||
for motor in self.bus_right.motors:
|
||||
self.bus_right.write("Operating_Mode", motor, OperatingMode.POSITION.value)
|
||||
|
||||
self.bus_left.disable_torque()
|
||||
self.bus_left.configure_motors()
|
||||
for motor in self.bus_left.motors:
|
||||
self.bus_left.write("Operating_Mode", motor, OperatingMode.POSITION.value)
|
||||
|
||||
def setup_motors(self) -> None:
|
||||
print("\nSetting up RIGHT arm motors...")
|
||||
for motor in reversed(self.bus_right.motors):
|
||||
input(f"Connect the controller board to the RIGHT '{motor}' motor only and press enter.")
|
||||
self.bus_right.setup_motor(motor)
|
||||
print(f"RIGHT '{motor}' motor id set to {self.bus_right.motors[motor].id}")
|
||||
|
||||
print("\nSetting up LEFT arm motors...")
|
||||
for motor in reversed(self.bus_left.motors):
|
||||
input(f"Connect the controller board to the LEFT '{motor}' motor only and press enter.")
|
||||
self.bus_left.setup_motor(motor)
|
||||
print(f"LEFT '{motor}' motor id set to {self.bus_left.motors[motor].id}")
|
||||
|
||||
@check_if_not_connected
|
||||
def get_action(self) -> RobotAction:
|
||||
"""Get current action from both arms (read positions from all motors)."""
|
||||
start = time.perf_counter()
|
||||
|
||||
right_positions = self.bus_right.sync_read("Present_Position")
|
||||
left_positions = self.bus_left.sync_read("Present_Position")
|
||||
|
||||
action: dict[str, Any] = {}
|
||||
for motor, val in right_positions.items():
|
||||
action[f"right_{motor}.pos"] = -val if motor in RIGHT_MOTORS_TO_FLIP else val
|
||||
for motor, val in left_positions.items():
|
||||
action[f"left_{motor}.pos"] = -val if motor in LEFT_MOTORS_TO_FLIP else val
|
||||
|
||||
dt_ms = (time.perf_counter() - start) * 1e3
|
||||
logger.debug(f"{self} read action: {dt_ms:.1f}ms")
|
||||
return action
|
||||
|
||||
def send_feedback(self, feedback: dict[str, float]) -> None:
|
||||
raise NotImplementedError("Feedback is not yet implemented for OpenArm Mini.")
|
||||
|
||||
@check_if_not_connected
|
||||
def disconnect(self) -> None:
|
||||
self.bus_right.disconnect()
|
||||
self.bus_left.disconnect()
|
||||
logger.info(f"{self} disconnected.")
|
||||
@@ -95,6 +95,10 @@ def make_teleoperator_from_config(config: TeleoperatorConfig) -> "Teleoperator":
|
||||
from .bi_openarm_leader import BiOpenArmLeader
|
||||
|
||||
return BiOpenArmLeader(config)
|
||||
elif config.type == "openarm_mini":
|
||||
from .openarm_mini import OpenArmMini
|
||||
|
||||
return OpenArmMini(config)
|
||||
else:
|
||||
try:
|
||||
return cast("Teleoperator", make_device_from_device_class(config))
|
||||
|
||||
@@ -189,7 +189,7 @@ def sanity_check_dataset_name(repo_id, policy_cfg):
|
||||
# Check if dataset_name starts with "eval_" but policy is missing
|
||||
if dataset_name.startswith("eval_") and policy_cfg is None:
|
||||
raise ValueError(
|
||||
f"Your dataset name begins with 'eval_' ({dataset_name}), but no policy is provided ({policy_cfg.type})."
|
||||
f"Your dataset name begins with 'eval_' ({dataset_name}), but no policy is provided."
|
||||
)
|
||||
|
||||
# Check if dataset_name does not start with "eval_" but policy is provided
|
||||
|
||||
Reference in New Issue
Block a user