mirror of
https://github.com/huggingface/lerobot.git
synced 2026-05-27 06:29:47 +00:00
Merge branch 'main' into chore/bump_transformers_v5
This commit is contained in:
@@ -173,6 +173,8 @@ jobs:
|
|||||||
shell: bash
|
shell: bash
|
||||||
working-directory: /lerobot
|
working-directory: /lerobot
|
||||||
steps:
|
steps:
|
||||||
|
- name: Fix ptxas permissions
|
||||||
|
run: chmod +x /lerobot/.venv/lib/python3.10/site-packages/triton/backends/nvidia/bin/ptxas
|
||||||
- name: Run pytest on GPU
|
- name: Run pytest on GPU
|
||||||
run: pytest tests -vv --maxfail=10
|
run: pytest tests -vv --maxfail=10
|
||||||
- name: Run end-to-end tests
|
- name: Run end-to-end tests
|
||||||
|
|||||||
@@ -1,2 +1,3 @@
|
|||||||
include src/lerobot/templates/lerobot_modelcard_template.md
|
include src/lerobot/templates/lerobot_modelcard_template.md
|
||||||
include src/lerobot/datasets/card_template.md
|
include src/lerobot/datasets/card_template.md
|
||||||
|
include src/lerobot/envs/metaworld_config.json
|
||||||
|
|||||||
@@ -85,6 +85,8 @@ RUN if [ "$UNBOUND_DEPS" = "true" ]; then \
|
|||||||
|
|
||||||
RUN uv pip install --no-cache ".[all]"
|
RUN uv pip install --no-cache ".[all]"
|
||||||
|
|
||||||
|
RUN chmod +x /lerobot/.venv/lib/python${PYTHON_VERSION}/site-packages/triton/backends/nvidia/bin/ptxas
|
||||||
|
|
||||||
# Copy the rest of the application source code
|
# Copy the rest of the application source code
|
||||||
# Make sure to have the git-LFS files for testing
|
# Make sure to have the git-LFS files for testing
|
||||||
COPY --chown=user_lerobot:user_lerobot . .
|
COPY --chown=user_lerobot:user_lerobot . .
|
||||||
|
|||||||
@@ -217,6 +217,9 @@ lerobot-edit-dataset="lerobot.scripts.lerobot_edit_dataset:main"
|
|||||||
lerobot-setup-can="lerobot.scripts.lerobot_setup_can:main"
|
lerobot-setup-can="lerobot.scripts.lerobot_setup_can:main"
|
||||||
|
|
||||||
# ---------------- Tool Configurations ----------------
|
# ---------------- Tool Configurations ----------------
|
||||||
|
[tool.setuptools.package-data]
|
||||||
|
lerobot = ["envs/*.json"]
|
||||||
|
|
||||||
[tool.setuptools.packages.find]
|
[tool.setuptools.packages.find]
|
||||||
where = ["src"]
|
where = ["src"]
|
||||||
|
|
||||||
|
|||||||
@@ -49,23 +49,18 @@ import torch
|
|||||||
|
|
||||||
from lerobot.cameras.opencv.configuration_opencv import OpenCVCameraConfig # noqa: F401
|
from lerobot.cameras.opencv.configuration_opencv import OpenCVCameraConfig # noqa: F401
|
||||||
from lerobot.cameras.realsense.configuration_realsense import RealSenseCameraConfig # noqa: F401
|
from lerobot.cameras.realsense.configuration_realsense import RealSenseCameraConfig # noqa: F401
|
||||||
from lerobot.robots import ( # noqa: F401
|
from lerobot.robots import (
|
||||||
Robot,
|
RobotConfig, # noqa: F401
|
||||||
RobotConfig,
|
|
||||||
bi_so_follower,
|
|
||||||
koch_follower,
|
|
||||||
make_robot_from_config,
|
make_robot_from_config,
|
||||||
omx_follower,
|
|
||||||
so_follower,
|
|
||||||
)
|
)
|
||||||
from lerobot.transport import (
|
from lerobot.transport import (
|
||||||
services_pb2, # type: ignore
|
services_pb2, # type: ignore
|
||||||
services_pb2_grpc, # type: ignore
|
services_pb2_grpc, # type: ignore
|
||||||
)
|
)
|
||||||
from lerobot.transport.utils import grpc_channel_options, send_bytes_in_chunks
|
from lerobot.transport.utils import grpc_channel_options, send_bytes_in_chunks
|
||||||
|
from lerobot.utils.import_utils import register_third_party_plugins
|
||||||
|
|
||||||
from .configs import RobotClientConfig
|
from .configs import RobotClientConfig
|
||||||
from .constants import SUPPORTED_ROBOTS
|
|
||||||
from .helpers import (
|
from .helpers import (
|
||||||
Action,
|
Action,
|
||||||
FPSTracker,
|
FPSTracker,
|
||||||
@@ -485,8 +480,9 @@ class RobotClient:
|
|||||||
def async_client(cfg: RobotClientConfig):
|
def async_client(cfg: RobotClientConfig):
|
||||||
logging.info(pformat(asdict(cfg)))
|
logging.info(pformat(asdict(cfg)))
|
||||||
|
|
||||||
if cfg.robot.type not in SUPPORTED_ROBOTS:
|
# TODO: Assert if checking robot support is still needed with the plugin system
|
||||||
raise ValueError(f"Robot {cfg.robot.type} not yet supported!")
|
# if cfg.robot.type not in SUPPORTED_ROBOTS:
|
||||||
|
# raise ValueError(f"Robot {cfg.robot.type} not yet supported!")
|
||||||
|
|
||||||
client = RobotClient(cfg)
|
client = RobotClient(cfg)
|
||||||
|
|
||||||
@@ -512,4 +508,5 @@ def async_client(cfg: RobotClientConfig):
|
|||||||
|
|
||||||
|
|
||||||
if __name__ == "__main__":
|
if __name__ == "__main__":
|
||||||
|
register_third_party_plugins()
|
||||||
async_client() # run the client
|
async_client() # run the client
|
||||||
|
|||||||
@@ -7,6 +7,13 @@
|
|||||||
|
|
||||||
This dataset was created using [LeRobot](https://github.com/huggingface/lerobot).
|
This dataset was created using [LeRobot](https://github.com/huggingface/lerobot).
|
||||||
|
|
||||||
|
{% if repo_id is defined and repo_id %}
|
||||||
|
<a class="flex" href="https://huggingface.co/spaces/lerobot/visualize_dataset?path={{ repo_id }}">
|
||||||
|
<img class="block dark:hidden" src="https://huggingface.co/datasets/huggingface/badges/resolve/main/visualize-this-dataset-xl.svg"/>
|
||||||
|
<img class="hidden dark:block" src="https://huggingface.co/datasets/huggingface/badges/resolve/main/visualize-this-dataset-xl-dark.svg"/>
|
||||||
|
</a>
|
||||||
|
{% endif %}
|
||||||
|
|
||||||
## Dataset Description
|
## Dataset Description
|
||||||
|
|
||||||
{{ dataset_description | default("", true) }}
|
{{ dataset_description | default("", true) }}
|
||||||
|
|||||||
@@ -567,20 +567,22 @@ def _copy_and_reindex_data(
|
|||||||
def _keep_episodes_from_video_with_av(
|
def _keep_episodes_from_video_with_av(
|
||||||
input_path: Path,
|
input_path: Path,
|
||||||
output_path: Path,
|
output_path: Path,
|
||||||
episodes_to_keep: list[tuple[float, float]],
|
episodes_to_keep: list[tuple[int, int]],
|
||||||
fps: float,
|
fps: float,
|
||||||
vcodec: str = "libsvtav1",
|
vcodec: str = "libsvtav1",
|
||||||
pix_fmt: str = "yuv420p",
|
pix_fmt: str = "yuv420p",
|
||||||
) -> None:
|
) -> None:
|
||||||
"""Keep only specified episodes from a video file using PyAV.
|
"""Keep only specified episodes from a video file using PyAV.
|
||||||
|
|
||||||
This function decodes frames from specified time ranges and re-encodes them with
|
This function decodes frames from specified frame ranges and re-encodes them with
|
||||||
properly reset timestamps to ensure monotonic progression.
|
properly reset timestamps to ensure monotonic progression.
|
||||||
|
|
||||||
Args:
|
Args:
|
||||||
input_path: Source video file path.
|
input_path: Source video file path.
|
||||||
output_path: Destination video file path.
|
output_path: Destination video file path.
|
||||||
episodes_to_keep: List of (start_time, end_time) tuples for episodes to keep.
|
episodes_to_keep: List of (start_frame, end_frame) tuples for episodes to keep.
|
||||||
|
Ranges are half-open intervals: [start_frame, end_frame), where start_frame
|
||||||
|
is inclusive and end_frame is exclusive.
|
||||||
fps: Frame rate of the video.
|
fps: Frame rate of the video.
|
||||||
vcodec: Video codec to use for encoding.
|
vcodec: Video codec to use for encoding.
|
||||||
pix_fmt: Pixel format for output video.
|
pix_fmt: Pixel format for output video.
|
||||||
@@ -622,9 +624,10 @@ def _keep_episodes_from_video_with_av(
|
|||||||
|
|
||||||
# Create set of (start, end) ranges for fast lookup.
|
# Create set of (start, end) ranges for fast lookup.
|
||||||
# Convert to a sorted list for efficient checking.
|
# Convert to a sorted list for efficient checking.
|
||||||
time_ranges = sorted(episodes_to_keep)
|
frame_ranges = sorted(episodes_to_keep)
|
||||||
|
|
||||||
# Track frame index for setting PTS and current range being processed.
|
# Track frame index for setting PTS and current range being processed.
|
||||||
|
src_frame_count = 0
|
||||||
frame_count = 0
|
frame_count = 0
|
||||||
range_idx = 0
|
range_idx = 0
|
||||||
|
|
||||||
@@ -634,21 +637,20 @@ def _keep_episodes_from_video_with_av(
|
|||||||
if frame is None:
|
if frame is None:
|
||||||
continue
|
continue
|
||||||
|
|
||||||
# Get frame timestamp.
|
# Check if frame is in any of our desired frame ranges.
|
||||||
frame_time = float(frame.pts * frame.time_base) if frame.pts is not None else 0.0
|
|
||||||
|
|
||||||
# Check if frame is in any of our desired time ranges.
|
|
||||||
# Skip ranges that have already passed.
|
# Skip ranges that have already passed.
|
||||||
while range_idx < len(time_ranges) and frame_time >= time_ranges[range_idx][1]:
|
while range_idx < len(frame_ranges) and src_frame_count >= frame_ranges[range_idx][1]:
|
||||||
range_idx += 1
|
range_idx += 1
|
||||||
|
|
||||||
# If we've passed all ranges, stop processing.
|
# If we've passed all ranges, stop processing.
|
||||||
if range_idx >= len(time_ranges):
|
if range_idx >= len(frame_ranges):
|
||||||
break
|
break
|
||||||
|
|
||||||
# Check if frame is in current range.
|
# Check if frame is in current range.
|
||||||
start_ts, end_ts = time_ranges[range_idx]
|
start_frame = frame_ranges[range_idx][0]
|
||||||
if frame_time < start_ts:
|
|
||||||
|
if src_frame_count < start_frame:
|
||||||
|
src_frame_count += 1
|
||||||
continue
|
continue
|
||||||
|
|
||||||
# Frame is in range - create a new frame with reset timestamps.
|
# Frame is in range - create a new frame with reset timestamps.
|
||||||
@@ -661,6 +663,7 @@ def _keep_episodes_from_video_with_av(
|
|||||||
for pkt in v_out.encode(new_frame):
|
for pkt in v_out.encode(new_frame):
|
||||||
out.mux(pkt)
|
out.mux(pkt)
|
||||||
|
|
||||||
|
src_frame_count += 1
|
||||||
frame_count += 1
|
frame_count += 1
|
||||||
|
|
||||||
# Flush encoder.
|
# Flush encoder.
|
||||||
@@ -749,15 +752,17 @@ def _copy_and_reindex_videos(
|
|||||||
f"videos/{video_key}/to_timestamp"
|
f"videos/{video_key}/to_timestamp"
|
||||||
]
|
]
|
||||||
else:
|
else:
|
||||||
# Build list of time ranges to keep, in sorted order.
|
# Build list of frame ranges to keep, in sorted order.
|
||||||
sorted_keep_episodes = sorted(episodes_in_file, key=lambda x: episode_mapping[x])
|
sorted_keep_episodes = sorted(episodes_in_file, key=lambda x: episode_mapping[x])
|
||||||
episodes_to_keep_ranges: list[tuple[float, float]] = []
|
episodes_to_keep_ranges: list[tuple[int, int]] = []
|
||||||
|
|
||||||
for old_idx in sorted_keep_episodes:
|
for old_idx in sorted_keep_episodes:
|
||||||
src_ep = src_dataset.meta.episodes[old_idx]
|
src_ep = src_dataset.meta.episodes[old_idx]
|
||||||
from_ts = src_ep[f"videos/{video_key}/from_timestamp"]
|
from_frame = round(src_ep[f"videos/{video_key}/from_timestamp"] * src_dataset.meta.fps)
|
||||||
to_ts = src_ep[f"videos/{video_key}/to_timestamp"]
|
to_frame = round(src_ep[f"videos/{video_key}/to_timestamp"] * src_dataset.meta.fps)
|
||||||
episodes_to_keep_ranges.append((from_ts, to_ts))
|
assert src_ep["length"] == to_frame - from_frame, (
|
||||||
|
f"Episode length mismatch: {src_ep['length']} vs {to_frame - from_frame}"
|
||||||
|
)
|
||||||
|
episodes_to_keep_ranges.append((from_frame, to_frame))
|
||||||
|
|
||||||
# Use PyAV filters to efficiently re-encode only the desired segments.
|
# Use PyAV filters to efficiently re-encode only the desired segments.
|
||||||
assert src_dataset.meta.video_path is not None
|
assert src_dataset.meta.video_path is not None
|
||||||
|
|||||||
@@ -747,7 +747,7 @@ class LeRobotDataset(torch.utils.data.Dataset):
|
|||||||
# Check if cached dataset contains all requested episodes
|
# Check if cached dataset contains all requested episodes
|
||||||
if not self._check_cached_episodes_sufficient():
|
if not self._check_cached_episodes_sufficient():
|
||||||
raise FileNotFoundError("Cached dataset doesn't contain all requested episodes")
|
raise FileNotFoundError("Cached dataset doesn't contain all requested episodes")
|
||||||
except (AssertionError, FileNotFoundError, NotADirectoryError):
|
except (FileNotFoundError, NotADirectoryError):
|
||||||
if is_valid_version(self.revision):
|
if is_valid_version(self.revision):
|
||||||
self.revision = get_safe_version(self.repo_id, self.revision)
|
self.revision = get_safe_version(self.repo_id, self.revision)
|
||||||
self.download(download_videos)
|
self.download(download_videos)
|
||||||
@@ -839,7 +839,7 @@ class LeRobotDataset(torch.utils.data.Dataset):
|
|||||||
hub_api.upload_folder(**upload_kwargs)
|
hub_api.upload_folder(**upload_kwargs)
|
||||||
|
|
||||||
card = create_lerobot_dataset_card(
|
card = create_lerobot_dataset_card(
|
||||||
tags=tags, dataset_info=self.meta.info, license=license, **card_kwargs
|
tags=tags, dataset_info=self.meta.info, license=license, repo_id=self.repo_id, **card_kwargs
|
||||||
)
|
)
|
||||||
card.push_to_hub(repo_id=self.repo_id, repo_type="dataset", revision=branch)
|
card.push_to_hub(repo_id=self.repo_id, repo_type="dataset", revision=branch)
|
||||||
|
|
||||||
@@ -1771,6 +1771,7 @@ class MultiLeRobotDataset(torch.utils.data.Dataset):
|
|||||||
)
|
)
|
||||||
for repo_id, ds in zip(self.repo_ids, self._datasets, strict=True):
|
for repo_id, ds in zip(self.repo_ids, self._datasets, strict=True):
|
||||||
extra_keys = set(ds.features).difference(intersection_features)
|
extra_keys = set(ds.features).difference(intersection_features)
|
||||||
|
if extra_keys:
|
||||||
logging.warning(
|
logging.warning(
|
||||||
f"keys {extra_keys} of {repo_id} were disabled as they are not contained in all the "
|
f"keys {extra_keys} of {repo_id} were disabled as they are not contained in all the "
|
||||||
"other datasets."
|
"other datasets."
|
||||||
|
|||||||
@@ -227,7 +227,8 @@ def decode_video_frames_torchvision(
|
|||||||
min_, argmin_ = dist.min(1)
|
min_, argmin_ = dist.min(1)
|
||||||
|
|
||||||
is_within_tol = min_ < tolerance_s
|
is_within_tol = min_ < tolerance_s
|
||||||
assert is_within_tol.all(), (
|
if not is_within_tol.all():
|
||||||
|
raise FrameTimestampError(
|
||||||
f"One or several query timestamps unexpectedly violate the tolerance ({min_[~is_within_tol]} > {tolerance_s=})."
|
f"One or several query timestamps unexpectedly violate the tolerance ({min_[~is_within_tol]} > {tolerance_s=})."
|
||||||
" It means that the closest frame that can be loaded from the video is too far away in time."
|
" It means that the closest frame that can be loaded from the video is too far away in time."
|
||||||
" This might be due to synchronization issues with timestamps during data collection."
|
" This might be due to synchronization issues with timestamps during data collection."
|
||||||
@@ -248,7 +249,11 @@ def decode_video_frames_torchvision(
|
|||||||
# convert to the pytorch format which is float32 in [0,1] range (and channel first)
|
# convert to the pytorch format which is float32 in [0,1] range (and channel first)
|
||||||
closest_frames = closest_frames.type(torch.float32) / 255
|
closest_frames = closest_frames.type(torch.float32) / 255
|
||||||
|
|
||||||
assert len(timestamps) == len(closest_frames)
|
if len(timestamps) != len(closest_frames):
|
||||||
|
raise FrameTimestampError(
|
||||||
|
f"Number of retrieved frames ({len(closest_frames)}) does not match "
|
||||||
|
f"number of queried timestamps ({len(timestamps)})"
|
||||||
|
)
|
||||||
return closest_frames
|
return closest_frames
|
||||||
|
|
||||||
|
|
||||||
@@ -353,7 +358,8 @@ def decode_video_frames_torchcodec(
|
|||||||
min_, argmin_ = dist.min(1)
|
min_, argmin_ = dist.min(1)
|
||||||
|
|
||||||
is_within_tol = min_ < tolerance_s
|
is_within_tol = min_ < tolerance_s
|
||||||
assert is_within_tol.all(), (
|
if not is_within_tol.all():
|
||||||
|
raise FrameTimestampError(
|
||||||
f"One or several query timestamps unexpectedly violate the tolerance ({min_[~is_within_tol]} > {tolerance_s=})."
|
f"One or several query timestamps unexpectedly violate the tolerance ({min_[~is_within_tol]} > {tolerance_s=})."
|
||||||
" It means that the closest frame that can be loaded from the video is too far away in time."
|
" It means that the closest frame that can be loaded from the video is too far away in time."
|
||||||
" This might be due to synchronization issues with timestamps during data collection."
|
" This might be due to synchronization issues with timestamps during data collection."
|
||||||
|
|||||||
@@ -139,6 +139,10 @@ class DiffusionConfig(PreTrainedConfig):
|
|||||||
# Inference
|
# Inference
|
||||||
num_inference_steps: int | None = None
|
num_inference_steps: int | None = None
|
||||||
|
|
||||||
|
# Optimization
|
||||||
|
compile_model: bool = False
|
||||||
|
compile_mode: str = "reduce-overhead"
|
||||||
|
|
||||||
# Loss computation
|
# Loss computation
|
||||||
do_mask_loss_for_padding: bool = False
|
do_mask_loss_for_padding: bool = False
|
||||||
|
|
||||||
|
|||||||
@@ -142,6 +142,9 @@ class DiffusionPolicy(PreTrainedPolicy):
|
|||||||
"""Run the batch through the model and compute the loss for training or validation."""
|
"""Run the batch through the model and compute the loss for training or validation."""
|
||||||
if self.config.image_features:
|
if self.config.image_features:
|
||||||
batch = dict(batch) # shallow copy so that adding a key doesn't modify the original
|
batch = dict(batch) # shallow copy so that adding a key doesn't modify the original
|
||||||
|
for key in self.config.image_features:
|
||||||
|
if self.config.n_obs_steps == 1 and batch[key].ndim == 4:
|
||||||
|
batch[key] = batch[key].unsqueeze(1)
|
||||||
batch[OBS_IMAGES] = torch.stack([batch[key] for key in self.config.image_features], dim=-4)
|
batch[OBS_IMAGES] = torch.stack([batch[key] for key in self.config.image_features], dim=-4)
|
||||||
loss = self.diffusion.compute_loss(batch)
|
loss = self.diffusion.compute_loss(batch)
|
||||||
# no output_dict so returning None
|
# no output_dict so returning None
|
||||||
@@ -182,6 +185,11 @@ class DiffusionModel(nn.Module):
|
|||||||
|
|
||||||
self.unet = DiffusionConditionalUnet1d(config, global_cond_dim=global_cond_dim * config.n_obs_steps)
|
self.unet = DiffusionConditionalUnet1d(config, global_cond_dim=global_cond_dim * config.n_obs_steps)
|
||||||
|
|
||||||
|
if config.compile_model:
|
||||||
|
# Compile the U-Net. "reduce-overhead" is preferred for the small-batch repetitive loops
|
||||||
|
# common in diffusion inference.
|
||||||
|
self.unet = torch.compile(self.unet, mode=config.compile_mode)
|
||||||
|
|
||||||
self.noise_scheduler = _make_noise_scheduler(
|
self.noise_scheduler = _make_noise_scheduler(
|
||||||
config.noise_scheduler_type,
|
config.noise_scheduler_type,
|
||||||
num_train_timesteps=config.num_train_timesteps,
|
num_train_timesteps=config.num_train_timesteps,
|
||||||
|
|||||||
@@ -277,8 +277,6 @@ class SARMEncodingProcessorStep(ProcessorStep):
|
|||||||
|
|
||||||
# When language is perturbed, targets are zero so perturbed samples don't contribute to progress loss
|
# When language is perturbed, targets are zero so perturbed samples don't contribute to progress loss
|
||||||
if self.dataset_meta is not None:
|
if self.dataset_meta is not None:
|
||||||
episodes_df = None
|
|
||||||
if self.sparse_subtask_names != ["task"]:
|
|
||||||
episodes_df = self.dataset_meta.episodes.to_pandas()
|
episodes_df = self.dataset_meta.episodes.to_pandas()
|
||||||
|
|
||||||
# Generate sparse targets
|
# Generate sparse targets
|
||||||
|
|||||||
@@ -77,7 +77,6 @@ class SmolVLMWithExpertModel(nn.Module):
|
|||||||
print(f"Loading {model_id} weights ...")
|
print(f"Loading {model_id} weights ...")
|
||||||
self.vlm = AutoModelForImageTextToText.from_pretrained(
|
self.vlm = AutoModelForImageTextToText.from_pretrained(
|
||||||
model_id,
|
model_id,
|
||||||
device_map=device,
|
|
||||||
torch_dtype="bfloat16",
|
torch_dtype="bfloat16",
|
||||||
low_cpu_mem_usage=True,
|
low_cpu_mem_usage=True,
|
||||||
)
|
)
|
||||||
|
|||||||
@@ -43,6 +43,7 @@ from lerobot.teleoperators import ( # noqa: F401
|
|||||||
koch_leader,
|
koch_leader,
|
||||||
make_teleoperator_from_config,
|
make_teleoperator_from_config,
|
||||||
omx_leader,
|
omx_leader,
|
||||||
|
openarm_mini,
|
||||||
so_leader,
|
so_leader,
|
||||||
)
|
)
|
||||||
|
|
||||||
@@ -51,6 +52,7 @@ COMPATIBLE_DEVICES = [
|
|||||||
"koch_leader",
|
"koch_leader",
|
||||||
"omx_follower",
|
"omx_follower",
|
||||||
"omx_leader",
|
"omx_leader",
|
||||||
|
"openarm_mini",
|
||||||
"so100_follower",
|
"so100_follower",
|
||||||
"so100_leader",
|
"so100_leader",
|
||||||
"so101_follower",
|
"so101_follower",
|
||||||
|
|||||||
@@ -24,6 +24,7 @@ import torch
|
|||||||
from accelerate import Accelerator
|
from accelerate import Accelerator
|
||||||
from termcolor import colored
|
from termcolor import colored
|
||||||
from torch.optim import Optimizer
|
from torch.optim import Optimizer
|
||||||
|
from tqdm import tqdm
|
||||||
|
|
||||||
from lerobot.configs import parser
|
from lerobot.configs import parser
|
||||||
from lerobot.configs.train import TrainPipelineConfig
|
from lerobot.configs.train import TrainPipelineConfig
|
||||||
@@ -51,6 +52,7 @@ from lerobot.utils.utils import (
|
|||||||
format_big_number,
|
format_big_number,
|
||||||
has_method,
|
has_method,
|
||||||
init_logging,
|
init_logging,
|
||||||
|
inside_slurm,
|
||||||
)
|
)
|
||||||
|
|
||||||
|
|
||||||
@@ -390,6 +392,14 @@ def train(cfg: TrainPipelineConfig, accelerator: Accelerator | None = None):
|
|||||||
)
|
)
|
||||||
|
|
||||||
if is_main_process:
|
if is_main_process:
|
||||||
|
progbar = tqdm(
|
||||||
|
total=cfg.steps - step,
|
||||||
|
desc="Training",
|
||||||
|
unit="step",
|
||||||
|
disable=inside_slurm(),
|
||||||
|
position=0,
|
||||||
|
leave=True,
|
||||||
|
)
|
||||||
logging.info(
|
logging.info(
|
||||||
f"Start offline training on a fixed dataset, with effective batch size: {effective_batch_size}"
|
f"Start offline training on a fixed dataset, with effective batch size: {effective_batch_size}"
|
||||||
)
|
)
|
||||||
@@ -414,6 +424,8 @@ def train(cfg: TrainPipelineConfig, accelerator: Accelerator | None = None):
|
|||||||
# Note: eval and checkpoint happens *after* the `step`th training update has completed, so we
|
# Note: eval and checkpoint happens *after* the `step`th training update has completed, so we
|
||||||
# increment `step` here.
|
# increment `step` here.
|
||||||
step += 1
|
step += 1
|
||||||
|
if is_main_process:
|
||||||
|
progbar.update(1)
|
||||||
train_tracker.step()
|
train_tracker.step()
|
||||||
is_log_step = cfg.log_freq > 0 and step % cfg.log_freq == 0 and is_main_process
|
is_log_step = cfg.log_freq > 0 and step % cfg.log_freq == 0 and is_main_process
|
||||||
is_saving_step = step % cfg.save_freq == 0 or step == cfg.steps
|
is_saving_step = step % cfg.save_freq == 0 or step == cfg.steps
|
||||||
@@ -507,6 +519,9 @@ def train(cfg: TrainPipelineConfig, accelerator: Accelerator | None = None):
|
|||||||
|
|
||||||
accelerator.wait_for_everyone()
|
accelerator.wait_for_everyone()
|
||||||
|
|
||||||
|
if is_main_process:
|
||||||
|
progbar.close()
|
||||||
|
|
||||||
if eval_env:
|
if eval_env:
|
||||||
close_envs(eval_env)
|
close_envs(eval_env)
|
||||||
|
|
||||||
|
|||||||
@@ -0,0 +1,20 @@
|
|||||||
|
#!/usr/bin/env python
|
||||||
|
|
||||||
|
# Copyright 2025 The HuggingFace Inc. team. All rights reserved.
|
||||||
|
#
|
||||||
|
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||||
|
# you may not use this file except in compliance with the License.
|
||||||
|
# You may obtain a copy of the License at
|
||||||
|
#
|
||||||
|
# http://www.apache.org/licenses/LICENSE-2.0
|
||||||
|
#
|
||||||
|
# Unless required by applicable law or agreed to in writing, software
|
||||||
|
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||||
|
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||||
|
# See the License for the specific language governing permissions and
|
||||||
|
# limitations under the License.
|
||||||
|
|
||||||
|
from .config_openarm_mini import OpenArmMiniConfig
|
||||||
|
from .openarm_mini import OpenArmMini
|
||||||
|
|
||||||
|
__all__ = ["OpenArmMini", "OpenArmMiniConfig"]
|
||||||
@@ -0,0 +1,30 @@
|
|||||||
|
#!/usr/bin/env python
|
||||||
|
|
||||||
|
# Copyright 2025 The HuggingFace Inc. team. All rights reserved.
|
||||||
|
#
|
||||||
|
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||||
|
# you may not use this file except in compliance with the License.
|
||||||
|
# You may obtain a copy of the License at
|
||||||
|
#
|
||||||
|
# http://www.apache.org/licenses/LICENSE-2.0
|
||||||
|
#
|
||||||
|
# Unless required by applicable law or agreed to in writing, software
|
||||||
|
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||||
|
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||||
|
# See the License for the specific language governing permissions and
|
||||||
|
# limitations under the License.
|
||||||
|
|
||||||
|
from dataclasses import dataclass
|
||||||
|
|
||||||
|
from ..config import TeleoperatorConfig
|
||||||
|
|
||||||
|
|
||||||
|
@TeleoperatorConfig.register_subclass("openarm_mini")
|
||||||
|
@dataclass
|
||||||
|
class OpenArmMiniConfig(TeleoperatorConfig):
|
||||||
|
"""Configuration for OpenArm Mini teleoperator with Feetech motors (dual arms)."""
|
||||||
|
|
||||||
|
port_right: str = "/dev/ttyUSB0"
|
||||||
|
port_left: str = "/dev/ttyUSB1"
|
||||||
|
|
||||||
|
use_degrees: bool = True
|
||||||
@@ -0,0 +1,296 @@
|
|||||||
|
#!/usr/bin/env python
|
||||||
|
|
||||||
|
# Copyright 2025 The HuggingFace Inc. team. All rights reserved.
|
||||||
|
#
|
||||||
|
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||||
|
# you may not use this file except in compliance with the License.
|
||||||
|
# You may obtain a copy of the License at
|
||||||
|
#
|
||||||
|
# http://www.apache.org/licenses/LICENSE-2.0
|
||||||
|
#
|
||||||
|
# Unless required by applicable law or agreed to in writing, software
|
||||||
|
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||||
|
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||||
|
# See the License for the specific language governing permissions and
|
||||||
|
# limitations under the License.
|
||||||
|
|
||||||
|
import logging
|
||||||
|
import time
|
||||||
|
from typing import Any
|
||||||
|
|
||||||
|
from lerobot.motors import Motor, MotorCalibration, MotorNormMode
|
||||||
|
from lerobot.motors.feetech import (
|
||||||
|
FeetechMotorsBus,
|
||||||
|
OperatingMode,
|
||||||
|
)
|
||||||
|
from lerobot.processor import RobotAction
|
||||||
|
from lerobot.utils.decorators import check_if_already_connected, check_if_not_connected
|
||||||
|
|
||||||
|
from ..teleoperator import Teleoperator
|
||||||
|
from .config_openarm_mini import OpenArmMiniConfig
|
||||||
|
|
||||||
|
logger = logging.getLogger(__name__)
|
||||||
|
|
||||||
|
# Motors whose direction is inverted during readout
|
||||||
|
RIGHT_MOTORS_TO_FLIP = ["joint_1", "joint_2", "joint_3", "joint_4", "joint_5"]
|
||||||
|
LEFT_MOTORS_TO_FLIP = ["joint_1", "joint_3", "joint_4", "joint_5", "joint_6", "joint_7"]
|
||||||
|
|
||||||
|
|
||||||
|
class OpenArmMini(Teleoperator):
|
||||||
|
"""
|
||||||
|
OpenArm Mini Teleoperator with dual Feetech-based arms (8 motors per arm).
|
||||||
|
|
||||||
|
Each arm has 7 joints plus a gripper, using Feetech STS3215 servos.
|
||||||
|
"""
|
||||||
|
|
||||||
|
config_class = OpenArmMiniConfig
|
||||||
|
name = "openarm_mini"
|
||||||
|
|
||||||
|
def __init__(self, config: OpenArmMiniConfig):
|
||||||
|
super().__init__(config)
|
||||||
|
self.config = config
|
||||||
|
|
||||||
|
norm_mode_body = MotorNormMode.DEGREES
|
||||||
|
|
||||||
|
motors_right = {
|
||||||
|
"joint_1": Motor(1, "sts3215", norm_mode_body),
|
||||||
|
"joint_2": Motor(2, "sts3215", norm_mode_body),
|
||||||
|
"joint_3": Motor(3, "sts3215", norm_mode_body),
|
||||||
|
"joint_4": Motor(4, "sts3215", norm_mode_body),
|
||||||
|
"joint_5": Motor(5, "sts3215", norm_mode_body),
|
||||||
|
"joint_6": Motor(6, "sts3215", norm_mode_body),
|
||||||
|
"joint_7": Motor(7, "sts3215", norm_mode_body),
|
||||||
|
"gripper": Motor(8, "sts3215", MotorNormMode.RANGE_0_100),
|
||||||
|
}
|
||||||
|
|
||||||
|
motors_left = {
|
||||||
|
"joint_1": Motor(1, "sts3215", norm_mode_body),
|
||||||
|
"joint_2": Motor(2, "sts3215", norm_mode_body),
|
||||||
|
"joint_3": Motor(3, "sts3215", norm_mode_body),
|
||||||
|
"joint_4": Motor(4, "sts3215", norm_mode_body),
|
||||||
|
"joint_5": Motor(5, "sts3215", norm_mode_body),
|
||||||
|
"joint_6": Motor(6, "sts3215", norm_mode_body),
|
||||||
|
"joint_7": Motor(7, "sts3215", norm_mode_body),
|
||||||
|
"gripper": Motor(8, "sts3215", MotorNormMode.RANGE_0_100),
|
||||||
|
}
|
||||||
|
|
||||||
|
cal_right = {
|
||||||
|
k.replace("right_", ""): v for k, v in (self.calibration or {}).items() if k.startswith("right_")
|
||||||
|
}
|
||||||
|
cal_left = {
|
||||||
|
k.replace("left_", ""): v for k, v in (self.calibration or {}).items() if k.startswith("left_")
|
||||||
|
}
|
||||||
|
|
||||||
|
self.bus_right = FeetechMotorsBus(
|
||||||
|
port=self.config.port_right,
|
||||||
|
motors=motors_right,
|
||||||
|
calibration=cal_right,
|
||||||
|
)
|
||||||
|
|
||||||
|
self.bus_left = FeetechMotorsBus(
|
||||||
|
port=self.config.port_left,
|
||||||
|
motors=motors_left,
|
||||||
|
calibration=cal_left,
|
||||||
|
)
|
||||||
|
|
||||||
|
@property
|
||||||
|
def action_features(self) -> dict[str, type]:
|
||||||
|
features: dict[str, type] = {}
|
||||||
|
for motor in self.bus_right.motors:
|
||||||
|
features[f"right_{motor}.pos"] = float
|
||||||
|
for motor in self.bus_left.motors:
|
||||||
|
features[f"left_{motor}.pos"] = float
|
||||||
|
return features
|
||||||
|
|
||||||
|
@property
|
||||||
|
def feedback_features(self) -> dict[str, type]:
|
||||||
|
return {}
|
||||||
|
|
||||||
|
@property
|
||||||
|
def is_connected(self) -> bool:
|
||||||
|
return self.bus_right.is_connected and self.bus_left.is_connected
|
||||||
|
|
||||||
|
@check_if_already_connected
|
||||||
|
def connect(self, calibrate: bool = True) -> None:
|
||||||
|
logger.info(f"Connecting right arm on {self.config.port_right}...")
|
||||||
|
self.bus_right.connect()
|
||||||
|
logger.info(f"Connecting left arm on {self.config.port_left}...")
|
||||||
|
self.bus_left.connect()
|
||||||
|
|
||||||
|
if calibrate:
|
||||||
|
self.calibrate()
|
||||||
|
|
||||||
|
self.configure()
|
||||||
|
logger.info(f"{self} connected.")
|
||||||
|
|
||||||
|
@property
|
||||||
|
def is_calibrated(self) -> bool:
|
||||||
|
return self.bus_right.is_calibrated and self.bus_left.is_calibrated
|
||||||
|
|
||||||
|
def calibrate(self) -> None:
|
||||||
|
"""
|
||||||
|
Run calibration procedure for OpenArm Mini.
|
||||||
|
|
||||||
|
1. Disable torque
|
||||||
|
2. Ask user to position arms in hanging position with grippers closed
|
||||||
|
3. Set this as zero position via half-turn homing
|
||||||
|
4. Interactive gripper calibration (open/close positions)
|
||||||
|
5. Save calibration
|
||||||
|
"""
|
||||||
|
if self.calibration:
|
||||||
|
user_input = input(
|
||||||
|
f"Press ENTER to use existing calibration for {self.id}, "
|
||||||
|
f"or type 'c' and press ENTER to run new calibration: "
|
||||||
|
)
|
||||||
|
if user_input.strip().lower() != "c":
|
||||||
|
logger.info(f"Using existing calibration for {self.id}")
|
||||||
|
cal_right = {
|
||||||
|
k.replace("right_", ""): v for k, v in self.calibration.items() if k.startswith("right_")
|
||||||
|
}
|
||||||
|
cal_left = {
|
||||||
|
k.replace("left_", ""): v for k, v in self.calibration.items() if k.startswith("left_")
|
||||||
|
}
|
||||||
|
self.bus_right.write_calibration(cal_right)
|
||||||
|
self.bus_left.write_calibration(cal_left)
|
||||||
|
return
|
||||||
|
|
||||||
|
logger.info(f"\nRunning calibration for {self}")
|
||||||
|
|
||||||
|
self._calibrate_arm("right", self.bus_right)
|
||||||
|
self._calibrate_arm("left", self.bus_left)
|
||||||
|
|
||||||
|
self._save_calibration()
|
||||||
|
print(f"\nCalibration complete and saved to {self.calibration_fpath}")
|
||||||
|
|
||||||
|
def _calibrate_arm(self, arm_name: str, bus: FeetechMotorsBus) -> None:
|
||||||
|
"""Calibrate a single arm with Feetech motors."""
|
||||||
|
logger.info(f"\n=== Calibrating {arm_name.upper()} arm ===")
|
||||||
|
|
||||||
|
bus.disable_torque()
|
||||||
|
|
||||||
|
logger.info(f"Setting Phase to 12 for all motors in {arm_name.upper()} arm...")
|
||||||
|
for motor in bus.motors:
|
||||||
|
bus.write("Phase", motor, 12)
|
||||||
|
|
||||||
|
for motor in bus.motors:
|
||||||
|
bus.write("Operating_Mode", motor, OperatingMode.POSITION.value)
|
||||||
|
|
||||||
|
input(
|
||||||
|
f"\nCalibration: Zero Position ({arm_name.upper()} arm)\n"
|
||||||
|
"Position the arm in the following configuration:\n"
|
||||||
|
" - Arm hanging straight down\n"
|
||||||
|
" - Gripper closed\n"
|
||||||
|
"Press ENTER when ready..."
|
||||||
|
)
|
||||||
|
|
||||||
|
homing_offsets = bus.set_half_turn_homings()
|
||||||
|
logger.info(f"{arm_name.capitalize()} arm zero position set.")
|
||||||
|
|
||||||
|
print(f"\nSetting motor ranges for {arm_name.upper()} arm\n")
|
||||||
|
|
||||||
|
if self.calibration is None:
|
||||||
|
self.calibration = {}
|
||||||
|
|
||||||
|
motor_resolution = bus.model_resolution_table[list(bus.motors.values())[0].model]
|
||||||
|
max_res = motor_resolution - 1
|
||||||
|
|
||||||
|
for motor_name, motor in bus.motors.items():
|
||||||
|
prefixed_name = f"{arm_name}_{motor_name}"
|
||||||
|
|
||||||
|
if motor_name == "gripper":
|
||||||
|
input(
|
||||||
|
f"\nGripper Calibration ({arm_name.upper()} arm)\n"
|
||||||
|
f"Step 1: CLOSE the gripper fully\n"
|
||||||
|
f"Press ENTER when gripper is closed..."
|
||||||
|
)
|
||||||
|
closed_pos = bus.read("Present_Position", motor_name, normalize=False)
|
||||||
|
logger.info(f" Gripper closed position recorded: {closed_pos}")
|
||||||
|
|
||||||
|
input("\nStep 2: OPEN the gripper fully\nPress ENTER when gripper is fully open...")
|
||||||
|
open_pos = bus.read("Present_Position", motor_name, normalize=False)
|
||||||
|
logger.info(f" Gripper open position recorded: {open_pos}")
|
||||||
|
|
||||||
|
if closed_pos < open_pos:
|
||||||
|
range_min = int(closed_pos)
|
||||||
|
range_max = int(open_pos)
|
||||||
|
drive_mode = 0
|
||||||
|
else:
|
||||||
|
range_min = int(open_pos)
|
||||||
|
range_max = int(closed_pos)
|
||||||
|
drive_mode = 1
|
||||||
|
|
||||||
|
logger.info(
|
||||||
|
f" {prefixed_name}: range set to [{range_min}, {range_max}] "
|
||||||
|
f"(0=closed, 100=open, drive_mode={drive_mode})"
|
||||||
|
)
|
||||||
|
else:
|
||||||
|
range_min = 0
|
||||||
|
range_max = max_res
|
||||||
|
drive_mode = 0
|
||||||
|
logger.info(f" {prefixed_name}: range set to [0, {max_res}] (full motor range)")
|
||||||
|
|
||||||
|
self.calibration[prefixed_name] = MotorCalibration(
|
||||||
|
id=motor.id,
|
||||||
|
drive_mode=drive_mode,
|
||||||
|
homing_offset=homing_offsets[motor_name],
|
||||||
|
range_min=range_min,
|
||||||
|
range_max=range_max,
|
||||||
|
)
|
||||||
|
|
||||||
|
cal_for_bus = {
|
||||||
|
k.replace(f"{arm_name}_", ""): v
|
||||||
|
for k, v in self.calibration.items()
|
||||||
|
if k.startswith(f"{arm_name}_")
|
||||||
|
}
|
||||||
|
bus.write_calibration(cal_for_bus)
|
||||||
|
|
||||||
|
def configure(self) -> None:
|
||||||
|
self.bus_right.disable_torque()
|
||||||
|
self.bus_right.configure_motors()
|
||||||
|
for motor in self.bus_right.motors:
|
||||||
|
self.bus_right.write("Operating_Mode", motor, OperatingMode.POSITION.value)
|
||||||
|
|
||||||
|
self.bus_left.disable_torque()
|
||||||
|
self.bus_left.configure_motors()
|
||||||
|
for motor in self.bus_left.motors:
|
||||||
|
self.bus_left.write("Operating_Mode", motor, OperatingMode.POSITION.value)
|
||||||
|
|
||||||
|
def setup_motors(self) -> None:
|
||||||
|
print("\nSetting up RIGHT arm motors...")
|
||||||
|
for motor in reversed(self.bus_right.motors):
|
||||||
|
input(f"Connect the controller board to the RIGHT '{motor}' motor only and press enter.")
|
||||||
|
self.bus_right.setup_motor(motor)
|
||||||
|
print(f"RIGHT '{motor}' motor id set to {self.bus_right.motors[motor].id}")
|
||||||
|
|
||||||
|
print("\nSetting up LEFT arm motors...")
|
||||||
|
for motor in reversed(self.bus_left.motors):
|
||||||
|
input(f"Connect the controller board to the LEFT '{motor}' motor only and press enter.")
|
||||||
|
self.bus_left.setup_motor(motor)
|
||||||
|
print(f"LEFT '{motor}' motor id set to {self.bus_left.motors[motor].id}")
|
||||||
|
|
||||||
|
@check_if_not_connected
|
||||||
|
def get_action(self) -> RobotAction:
|
||||||
|
"""Get current action from both arms (read positions from all motors)."""
|
||||||
|
start = time.perf_counter()
|
||||||
|
|
||||||
|
right_positions = self.bus_right.sync_read("Present_Position")
|
||||||
|
left_positions = self.bus_left.sync_read("Present_Position")
|
||||||
|
|
||||||
|
action: dict[str, Any] = {}
|
||||||
|
for motor, val in right_positions.items():
|
||||||
|
action[f"right_{motor}.pos"] = -val if motor in RIGHT_MOTORS_TO_FLIP else val
|
||||||
|
for motor, val in left_positions.items():
|
||||||
|
action[f"left_{motor}.pos"] = -val if motor in LEFT_MOTORS_TO_FLIP else val
|
||||||
|
|
||||||
|
dt_ms = (time.perf_counter() - start) * 1e3
|
||||||
|
logger.debug(f"{self} read action: {dt_ms:.1f}ms")
|
||||||
|
return action
|
||||||
|
|
||||||
|
def send_feedback(self, feedback: dict[str, float]) -> None:
|
||||||
|
raise NotImplementedError("Feedback is not yet implemented for OpenArm Mini.")
|
||||||
|
|
||||||
|
@check_if_not_connected
|
||||||
|
def disconnect(self) -> None:
|
||||||
|
self.bus_right.disconnect()
|
||||||
|
self.bus_left.disconnect()
|
||||||
|
logger.info(f"{self} disconnected.")
|
||||||
@@ -95,6 +95,10 @@ def make_teleoperator_from_config(config: TeleoperatorConfig) -> "Teleoperator":
|
|||||||
from .bi_openarm_leader import BiOpenArmLeader
|
from .bi_openarm_leader import BiOpenArmLeader
|
||||||
|
|
||||||
return BiOpenArmLeader(config)
|
return BiOpenArmLeader(config)
|
||||||
|
elif config.type == "openarm_mini":
|
||||||
|
from .openarm_mini import OpenArmMini
|
||||||
|
|
||||||
|
return OpenArmMini(config)
|
||||||
else:
|
else:
|
||||||
try:
|
try:
|
||||||
return cast("Teleoperator", make_device_from_device_class(config))
|
return cast("Teleoperator", make_device_from_device_class(config))
|
||||||
|
|||||||
@@ -189,7 +189,7 @@ def sanity_check_dataset_name(repo_id, policy_cfg):
|
|||||||
# Check if dataset_name starts with "eval_" but policy is missing
|
# Check if dataset_name starts with "eval_" but policy is missing
|
||||||
if dataset_name.startswith("eval_") and policy_cfg is None:
|
if dataset_name.startswith("eval_") and policy_cfg is None:
|
||||||
raise ValueError(
|
raise ValueError(
|
||||||
f"Your dataset name begins with 'eval_' ({dataset_name}), but no policy is provided ({policy_cfg.type})."
|
f"Your dataset name begins with 'eval_' ({dataset_name}), but no policy is provided."
|
||||||
)
|
)
|
||||||
|
|
||||||
# Check if dataset_name does not start with "eval_" but policy is provided
|
# Check if dataset_name does not start with "eval_" but policy is provided
|
||||||
|
|||||||
Reference in New Issue
Block a user