Compare commits

...

6 Commits

Author SHA1 Message Date
Steven Palma bbcf66bd82 chore: enable simplify in ruff lint (#2085) 2025-09-29 15:06:56 +02:00
Steven Palma c378a325f0 chore: enable pyugrade ruff lint (#2084) 2025-09-29 13:28:53 +02:00
Qizhi Chen 90684a9690 Improve V3 aggregate implementation (#2077)
* fix return type

* improve apply with vertorize op

* Update src/lerobot/datasets/aggregate.py

Co-authored-by: Michel Aractingi <michel.aractingi@huggingface.co>
2025-09-29 11:18:54 +02:00
Steven Palma f59eb54f5c chore: remove unused code (#2062) 2025-09-29 10:49:36 +02:00
Qizhi Chen 62e9849ffd use abs path when concatenating (#2076) 2025-09-28 14:18:22 +02:00
Francesco Capuano e3b572992e Save Cropped Dataset to Hub (#2071)
* fix: cast fps argument from dataset to int

* fix: typo

* fix: specify repo-id
2025-09-27 16:07:53 +02:00
62 changed files with 110 additions and 372 deletions
-3
View File
@@ -95,7 +95,6 @@ class HILSerlProcessorConfig:
class ObservationConfig:
add_joint_velocity_to_observation: bool = False # Add joint velocities to state
add_current_to_observation: bool = False # Add motor currents to state
add_ee_pose_to_observation: bool = False # Add end-effector pose to state
display_cameras: bool = False # Display camera feeds during execution
class ImagePreprocessingConfig:
@@ -105,7 +104,6 @@ class ImagePreprocessingConfig:
class GripperConfig:
use_gripper: bool = True # Enable gripper control
gripper_penalty: float = 0.0 # Penalty for inappropriate gripper usage
gripper_penalty_in_reward: bool = False # Include gripper penalty in reward
class ResetConfig:
fixed_reset_joint_positions: Any | None = None # Joint positions for reset
@@ -288,7 +286,6 @@ You can enable multiple observation processing features simultaneously:
"observation": {
"add_joint_velocity_to_observation": true,
"add_current_to_observation": true,
"add_ee_pose_to_observation": false,
"display_cameras": false
}
}
+1 -2
View File
@@ -136,13 +136,12 @@ Additionally you can customize mapping or safety limits by editing the processor
),
```
- The `EEBoundsAndSafety` step clamps EE motion to a workspace and checks for large ee step jumps to ensure safety. The `end_effector_bounds` are the bounds for the EE pose and can be modified to change the workspace. The `max_ee_step_m` and `max_ee_twist_step_rad` are the step limits for the EE pose and can be modified to change the safety limits.
- The `EEBoundsAndSafety` step clamps EE motion to a workspace and checks for large ee step jumps to ensure safety. The `end_effector_bounds` are the bounds for the EE pose and can be modified to change the workspace. The `max_ee_step_m` are the step limits for the EE pose and can be modified to change the safety limits.
```examples/phone_to_so100/teleoperate.py
EEBoundsAndSafety(
end_effector_bounds={"min": [-1.0, -1.0, -1.0], "max": [1.0, 1.0, 1.0]},
max_ee_step_m=0.10,
max_ee_twist_step_rad=0.50,
)
```
+1 -1
View File
@@ -38,7 +38,7 @@ phone_to_robot_ee_pose_processor = RobotProcessorPipeline[RobotAction, RobotActi
kinematics=kinematics_solver, end_effector_step_sizes={"x": 0.5, "y": 0.5, "z": 0.5}, motor_names=list(robot.bus.motors.keys()),
),
EEBoundsAndSafety(
end_effector_bounds={"min": [-1.0, -1.0, -1.0], "max": [1.0, 1.0, 1.0]}, max_ee_step_m=0.20, max_ee_twist_step_rad=0.50,
end_effector_bounds={"min": [-1.0, -1.0, -1.0], "max": [1.0, 1.0, 1.0]}, max_ee_step_m=0.20,
),
GripperVelocityToJoint(),
],
-1
View File
@@ -84,7 +84,6 @@ phone_to_robot_ee_pose_processor = RobotProcessorPipeline[tuple[RobotAction, Rob
EEBoundsAndSafety(
end_effector_bounds={"min": [-1.0, -1.0, -1.0], "max": [1.0, 1.0, 1.0]},
max_ee_step_m=0.20,
max_ee_twist_step_rad=0.50,
),
GripperVelocityToJoint(speed_factor=20.0),
],
-1
View File
@@ -67,7 +67,6 @@ phone_to_robot_joints_processor = RobotProcessorPipeline[tuple[RobotAction, Robo
EEBoundsAndSafety(
end_effector_bounds={"min": [-1.0, -1.0, -1.0], "max": [1.0, 1.0, 1.0]},
max_ee_step_m=0.10,
max_ee_twist_step_rad=0.50,
),
GripperVelocityToJoint(
speed_factor=20.0,
-1
View File
@@ -101,7 +101,6 @@ ee_to_follower_joints = RobotProcessorPipeline[tuple[RobotAction, RobotObservati
EEBoundsAndSafety(
end_effector_bounds={"min": [-1.0, -1.0, -1.0], "max": [1.0, 1.0, 1.0]},
max_ee_step_m=0.10,
max_ee_twist_step_rad=0.50,
),
InverseKinematicsEEToJoints(
kinematics=follower_kinematics_solver,
@@ -78,7 +78,6 @@ ee_to_follower_joints = RobotProcessorPipeline[tuple[RobotAction, RobotObservati
EEBoundsAndSafety(
end_effector_bounds={"min": [-1.0, -1.0, -1.0], "max": [1.0, 1.0, 1.0]},
max_ee_step_m=0.10,
max_ee_twist_step_rad=0.50,
),
InverseKinematicsEEToJoints(
kinematics=follower_kinematics_solver,
+1 -1
View File
@@ -201,7 +201,7 @@ exclude = ["tests/artifacts/**/*.safetensors", "*_pb2.py", "*_pb2_grpc.py"]
# N: pep8-naming
# TODO: Uncomment rules when ready to use
select = [
"E", "W", "F", "I", "B", "C4", "T20", "N" # "SIM", "A", "S", "D", "RUF", "UP"
"E", "W", "F", "I", "B", "C4", "T20", "N", "UP", "SIM" #, "A", "S", "D", "RUF"
]
ignore = [
"E501", # Line too long
+1 -2
View File
@@ -31,7 +31,6 @@ from lerobot.utils.constants import OBS_IMAGES, OBS_STATE, OBS_STR
from lerobot.utils.utils import init_logging
Action = torch.Tensor
ActionChunk = torch.Tensor
# observation as received from the robot
RawObservation = dict[str, torch.Tensor]
@@ -46,7 +45,7 @@ Observation = dict[str, torch.Tensor]
def visualize_action_queue_size(action_queue_size: list[int]) -> None:
import matplotlib.pyplot as plt
fig, ax = plt.subplots()
_, ax = plt.subplots()
ax.set_title("Action Queue Size Over Time")
ax.set_xlabel("Environment steps")
ax.set_ylabel("Action Queue Size")
-4
View File
@@ -15,14 +15,10 @@
# limitations under the License.
import platform
from pathlib import Path
from typing import TypeAlias
from .camera import Camera
from .configs import CameraConfig, Cv2Rotation
IndexOrPath: TypeAlias = int | Path
def make_cameras_from_configs(camera_configs: dict[str, CameraConfig]) -> dict[str, Camera]:
cameras = {}
-3
View File
@@ -16,9 +16,6 @@
from dataclasses import dataclass, field
from lerobot import (
policies, # noqa: F401
)
from lerobot.datasets.transforms import ImageTransformsConfig
from lerobot.datasets.video_utils import get_safe_default_codec
-5
View File
@@ -15,7 +15,6 @@
# https://stackoverflow.com/questions/24481852/serialising-an-enum-member-to-json
from dataclasses import dataclass
from enum import Enum
from typing import Any, Protocol
class FeatureType(str, Enum):
@@ -38,10 +37,6 @@ class NormalizationMode(str, Enum):
IDENTITY = "IDENTITY"
class DictLike(Protocol):
def __getitem__(self, key: Any) -> Any: ...
@dataclass
class PolicyFeature:
type: FeatureType
+19 -26
View File
@@ -93,14 +93,13 @@ def update_data_df(df, src_meta, dst_meta):
pd.DataFrame: Updated DataFrame with adjusted indices.
"""
def _update(row):
row["episode_index"] = row["episode_index"] + dst_meta.info["total_episodes"]
row["index"] = row["index"] + dst_meta.info["total_frames"]
task = src_meta.tasks.iloc[row["task_index"]].name
row["task_index"] = dst_meta.tasks.loc[task].task_index.item()
return row
df["episode_index"] = df["episode_index"] + dst_meta.info["total_episodes"]
df["index"] = df["index"] + dst_meta.info["total_frames"]
return df.apply(_update, axis=1)
src_task_names = src_meta.tasks.index.take(df["task_index"].to_numpy())
df["task_index"] = dst_meta.tasks.loc[src_task_names, "task_index"].to_numpy()
return df
def update_meta_data(
@@ -126,27 +125,21 @@ def update_meta_data(
pd.DataFrame: Updated DataFrame with adjusted indices and timestamps.
"""
def _update(row):
row["meta/episodes/chunk_index"] = row["meta/episodes/chunk_index"] + meta_idx["chunk"]
row["meta/episodes/file_index"] = row["meta/episodes/file_index"] + meta_idx["file"]
row["data/chunk_index"] = row["data/chunk_index"] + data_idx["chunk"]
row["data/file_index"] = row["data/file_index"] + data_idx["file"]
for key, video_idx in videos_idx.items():
row[f"videos/{key}/chunk_index"] = row[f"videos/{key}/chunk_index"] + video_idx["chunk"]
row[f"videos/{key}/file_index"] = row[f"videos/{key}/file_index"] + video_idx["file"]
row[f"videos/{key}/from_timestamp"] = (
row[f"videos/{key}/from_timestamp"] + video_idx["latest_duration"]
)
row[f"videos/{key}/to_timestamp"] = (
row[f"videos/{key}/to_timestamp"] + video_idx["latest_duration"]
)
df["meta/episodes/chunk_index"] = df["meta/episodes/chunk_index"] + meta_idx["chunk"]
df["meta/episodes/file_index"] = df["meta/episodes/file_index"] + meta_idx["file"]
df["data/chunk_index"] = df["data/chunk_index"] + data_idx["chunk"]
df["data/file_index"] = df["data/file_index"] + data_idx["file"]
for key, video_idx in videos_idx.items():
df[f"videos/{key}/chunk_index"] = df[f"videos/{key}/chunk_index"] + video_idx["chunk"]
df[f"videos/{key}/file_index"] = df[f"videos/{key}/file_index"] + video_idx["file"]
df[f"videos/{key}/from_timestamp"] = df[f"videos/{key}/from_timestamp"] + video_idx["latest_duration"]
df[f"videos/{key}/to_timestamp"] = df[f"videos/{key}/to_timestamp"] + video_idx["latest_duration"]
row["dataset_from_index"] = row["dataset_from_index"] + dst_meta.info["total_frames"]
row["dataset_to_index"] = row["dataset_to_index"] + dst_meta.info["total_frames"]
row["episode_index"] = row["episode_index"] + dst_meta.info["total_episodes"]
return row
df["dataset_from_index"] = df["dataset_from_index"] + dst_meta.info["total_frames"]
df["dataset_to_index"] = df["dataset_to_index"] + dst_meta.info["total_frames"]
df["episode_index"] = df["episode_index"] + dst_meta.info["total_episodes"]
return df.apply(_update, axis=1)
return df
def aggregate_datasets(
+4 -14
View File
@@ -848,11 +848,6 @@ class LeRobotDataset(torch.utils.data.Dataset):
return item
def _add_padding_keys(self, item: dict, padding: dict[str, list[bool]]) -> dict:
for key, val in padding.items():
item[key] = torch.BoolTensor(val)
return item
def __len__(self):
return self.num_frames
@@ -1032,7 +1027,7 @@ class LeRobotDataset(torch.utils.data.Dataset):
# Reset episode buffer and clean up temporary images (if not already deleted during video encoding)
self.clear_episode_buffer(delete_images=len(self.meta.image_keys) > 0)
def _batch_save_episode_video(self, start_episode: int, end_episode: int | None = None):
def _batch_save_episode_video(self, start_episode: int, end_episode: int | None = None) -> None:
"""
Batch save videos for multiple episodes.
@@ -1158,7 +1153,7 @@ class LeRobotDataset(torch.utils.data.Dataset):
}
return metadata
def _save_episode_video(self, video_key: str, episode_index: int):
def _save_episode_video(self, video_key: str, episode_index: int) -> dict:
# Encode episode frames into a temporary video
ep_path = self._encode_temporary_episode_video(video_key, episode_index)
ep_size_in_mb = get_video_size_in_mb(ep_path)
@@ -1263,7 +1258,7 @@ class LeRobotDataset(torch.utils.data.Dataset):
if self.image_writer is not None:
self.image_writer.wait_until_done()
def _encode_temporary_episode_video(self, video_key: str, episode_index: int) -> dict:
def _encode_temporary_episode_video(self, video_key: str, episode_index: int) -> Path:
"""
Use ffmpeg to convert frames stored as png into mp4 videos.
Note: `encode_video_frames` is a blocking call. Making it asynchronous shouldn't speedup encoding,
@@ -1396,11 +1391,6 @@ class MultiLeRobotDataset(torch.utils.data.Dataset):
"""
return {repo_id: i for i, repo_id in enumerate(self.repo_ids)}
@property
def repo_index_to_id(self):
"""Return the inverse mapping if repo_id_to_index."""
return {v: k for k, v in self.repo_id_to_index}
@property
def fps(self) -> int:
"""Frames per second used during data collection.
@@ -1431,7 +1421,7 @@ class MultiLeRobotDataset(torch.utils.data.Dataset):
"""Keys to access image and video stream from cameras."""
keys = []
for key, feats in self.features.items():
if isinstance(feats, (datasets.Image, VideoFrame)):
if isinstance(feats, (datasets.Image | VideoFrame)):
keys.append(key)
return keys
@@ -13,67 +13,10 @@
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import inspect
from concurrent.futures import ThreadPoolExecutor
from pathlib import Path
import datasets
import numpy
import PIL
import torch
from lerobot.datasets.video_utils import encode_video_frames
def concatenate_episodes(ep_dicts):
data_dict = {}
keys = ep_dicts[0].keys()
for key in keys:
if torch.is_tensor(ep_dicts[0][key][0]):
data_dict[key] = torch.cat([ep_dict[key] for ep_dict in ep_dicts])
else:
if key not in data_dict:
data_dict[key] = []
for ep_dict in ep_dicts:
for x in ep_dict[key]:
data_dict[key].append(x)
total_frames = data_dict["frame_index"].shape[0]
data_dict["index"] = torch.arange(0, total_frames, 1)
return data_dict
def save_images_concurrently(imgs_array: numpy.array, out_dir: Path, max_workers: int = 4):
out_dir = Path(out_dir)
out_dir.mkdir(parents=True, exist_ok=True)
def save_image(img_array, i, out_dir):
img = PIL.Image.fromarray(img_array)
img.save(str(out_dir / f"frame_{i:06d}.png"), quality=100)
num_images = len(imgs_array)
with ThreadPoolExecutor(max_workers=max_workers) as executor:
[executor.submit(save_image, imgs_array[i], i, out_dir) for i in range(num_images)]
def get_default_encoding() -> dict:
"""Returns the default ffmpeg encoding parameters used by `encode_video_frames`."""
signature = inspect.signature(encode_video_frames)
return {
k: v.default
for k, v in signature.parameters.items()
if v.default is not inspect.Parameter.empty and k in ["vcodec", "pix_fmt", "g", "crf"]
}
def check_repo_id(repo_id: str) -> None:
if len(repo_id.split("/")) != 2:
raise ValueError(
f"""`repo_id` is expected to contain a community or user id `/` the name of the dataset
(e.g. 'lerobot/pusht'), but contains '{repo_id}'."""
)
# TODO(aliberts): remove
def calculate_episode_data_index(hf_dataset: datasets.Dataset) -> dict[str, torch.Tensor]:
+1 -3
View File
@@ -298,9 +298,7 @@ class StreamingLeRobotDataset(torch.utils.data.IterableDataset):
return padding_mask
def make_frame(
self, dataset_iterator: Backtrackable, previous_dataset_iterator: Backtrackable | None = None
) -> Generator:
def make_frame(self, dataset_iterator: Backtrackable) -> Generator:
"""Makes a frame starting from a dataset iterator"""
item = next(dataset_iterator)
item = item_to_torch(item)
+1 -1
View File
@@ -120,7 +120,7 @@ class SharpnessJitter(Transform):
self.sharpness = self._check_input(sharpness)
def _check_input(self, sharpness):
if isinstance(sharpness, (int, float)):
if isinstance(sharpness, (int | float)):
if sharpness < 0:
raise ValueError("If sharpness is a single number, it must be non negative.")
sharpness = [1.0 - sharpness, 1.0 + sharpness]
+7 -56
View File
@@ -21,7 +21,7 @@ from collections import deque
from collections.abc import Iterable, Iterator
from pathlib import Path
from pprint import pformat
from typing import Any, Deque, Generic, TypeVar
from typing import Any, Generic, TypeVar
import datasets
import numpy as np
@@ -67,18 +67,6 @@ DEFAULT_IMAGE_PATH = "images/{image_key}/episode-{episode_index:06d}/frame-{fram
LEGACY_EPISODES_PATH = "meta/episodes.jsonl"
LEGACY_EPISODES_STATS_PATH = "meta/episodes_stats.jsonl"
LEGACY_TASKS_PATH = "meta/tasks.jsonl"
LEGACY_DEFAULT_VIDEO_PATH = "videos/chunk-{episode_chunk:03d}/{video_key}/episode_{episode_index:06d}.mp4"
LEGACY_DEFAULT_PARQUET_PATH = "data/chunk-{episode_chunk:03d}/episode_{episode_index:06d}.parquet"
DATASET_CARD_TEMPLATE = """
---
# Metadata will go there
---
This dataset was created using [LeRobot](https://github.com/huggingface/lerobot).
## {}
"""
DEFAULT_FEATURES = {
"timestamp": {"dtype": "float32", "shape": (1,), "names": None},
@@ -219,13 +207,13 @@ def serialize_dict(stats: dict[str, torch.Tensor | np.ndarray | dict]) -> dict:
"""
serialized_dict = {}
for key, value in flatten_dict(stats).items():
if isinstance(value, (torch.Tensor, np.ndarray)):
if isinstance(value, (torch.Tensor | np.ndarray)):
serialized_dict[key] = value.tolist()
elif isinstance(value, list) and isinstance(value[0], (int, float, list)):
elif isinstance(value, list) and isinstance(value[0], (int | float | list)):
serialized_dict[key] = value
elif isinstance(value, np.generic):
serialized_dict[key] = value.item()
elif isinstance(value, (int, float)):
elif isinstance(value, (int | float)):
serialized_dict[key] = value
else:
raise NotImplementedError(f"The value '{value}' of type '{type(value)}' is not supported.")
@@ -383,12 +371,6 @@ def load_episodes(local_dir: Path) -> datasets.Dataset:
return episodes
def backward_compatible_episodes_stats(
stats: dict[str, dict[str, np.ndarray]], episodes: list[int]
) -> dict[int, dict[str, dict[str, np.ndarray]]]:
return dict.fromkeys(episodes, stats)
def load_image_as_numpy(
fpath: str | Path, dtype: np.dtype = np.float32, channel_first: bool = True
) -> np.ndarray:
@@ -1197,7 +1179,7 @@ def item_to_torch(item: dict) -> dict:
dict: Dictionary with all tensor-like items converted to torch.Tensor.
"""
for key, val in item.items():
if isinstance(val, (np.ndarray, list)) and key not in ["task"]:
if isinstance(val, (np.ndarray | list)) and key not in ["task"]:
# Convert numpy arrays and lists to torch tensors
item[key] = torch.tensor(val)
return item
@@ -1271,8 +1253,8 @@ class Backtrackable(Generic[T]):
raise ValueError("lookahead must be > 0")
self._source: Iterator[T] = iter(iterable)
self._back_buf: Deque[T] = deque(maxlen=history)
self._ahead_buf: Deque[T] = deque(maxlen=lookahead) if lookahead > 0 else deque()
self._back_buf: deque[T] = deque(maxlen=history)
self._ahead_buf: deque[T] = deque(maxlen=lookahead) if lookahead > 0 else deque()
self._cursor: int = 0
self._history = history
self._lookahead = lookahead
@@ -1346,12 +1328,6 @@ class Backtrackable(Generic[T]):
# When cursor<0, slice so the order remains chronological
return list(self._back_buf)[: self._cursor or None]
def lookahead_buffer(self) -> list[T]:
"""
Return a copy of the current lookahead buffer.
"""
return list(self._ahead_buf)
def can_peek_back(self, steps: int = 1) -> bool:
"""
Check if we can go back `steps` items without raising an IndexError.
@@ -1377,31 +1353,6 @@ class Backtrackable(Generic[T]):
except StopIteration:
return False
def reset_cursor(self) -> None:
"""
Reset cursor to the most recent position (equivalent to calling next()
until you're back to the latest item).
"""
self._cursor = 0
def clear_ahead_buffer(self) -> None:
"""
Clear the ahead buffer, discarding any pre-fetched items.
"""
self._ahead_buf.clear()
def switch_source_iterable(self, new_source: Iterable[T]) -> None:
"""
Switch the source of the backtrackable to a new iterable, keeping the history.
This is useful when iterating over a sequence of datasets. The history from the
previous source is kept, but the lookahead buffer is cleared. The cursor is reset
to the present.
"""
self._source = iter(new_source)
self.clear_ahead_buffer()
self.reset_cursor()
def safe_shard(dataset: datasets.IterableDataset, index: int, num_shards: int) -> datasets.Dataset:
"""
+4 -15
View File
@@ -428,7 +428,7 @@ def concatenate_video_files(
with tempfile.NamedTemporaryFile(mode="w", suffix=".ffconcat", delete=False) as tmp_concatenate_file:
tmp_concatenate_file.write("ffconcat version 1.0\n")
for input_path in input_video_paths:
tmp_concatenate_file.write(f"file '{str(input_path)}'\n")
tmp_concatenate_file.write(f"file '{str(input_path.resolve())}'\n")
tmp_concatenate_file.flush()
tmp_concatenate_path = tmp_concatenate_file.name
@@ -437,7 +437,9 @@ def concatenate_video_files(
tmp_concatenate_path, mode="r", format="concat", options={"safe": "0"}
) # safe = 0 allows absolute paths as well as relative paths
tmp_output_video_path = tempfile.NamedTemporaryFile(suffix=".mp4", delete=False).name
with tempfile.NamedTemporaryFile(suffix=".mp4", delete=False) as tmp_named_file:
tmp_output_video_path = tmp_named_file.name
output_container = av.open(
tmp_output_video_path, mode="w", options={"movflags": "faststart"}
) # faststart is to move the metadata to the beginning of the file to speed up loading
@@ -585,19 +587,6 @@ def get_video_pixel_channels(pix_fmt: str) -> int:
raise ValueError("Unknown format")
def get_image_pixel_channels(image: Image):
if image.mode == "L":
return 1 # Grayscale
elif image.mode == "LA":
return 2 # Grayscale + Alpha
elif image.mode == "RGB":
return 3 # RGB
elif image.mode == "RGBA":
return 4 # RGBA
else:
raise ValueError("Unknown format")
def get_video_duration_in_s(video_path: Path | str) -> float:
"""
Get the duration of a video file in seconds using PyAV.
-2
View File
@@ -193,7 +193,6 @@ class ObservationConfig:
add_joint_velocity_to_observation: bool = False
add_current_to_observation: bool = False
add_ee_pose_to_observation: bool = False
display_cameras: bool = False
@@ -203,7 +202,6 @@ class GripperConfig:
use_gripper: bool = True
gripper_penalty: float = 0.0
gripper_penalty_in_reward: bool = False
@dataclass
+1 -1
View File
@@ -35,7 +35,7 @@ def _parse_camera_names(camera_name: str | Sequence[str]) -> list[str]:
"""Normalize camera_name into a non-empty list of strings."""
if isinstance(camera_name, str):
cams = [c.strip() for c in camera_name.split(",") if c.strip()]
elif isinstance(camera_name, (list, tuple)):
elif isinstance(camera_name, (list | tuple)):
cams = [str(c).strip() for c in camera_name if str(c).strip()]
else:
raise TypeError(f"camera_name must be str or sequence[str], got {type(camera_name).__name__}")
+2 -2
View File
@@ -183,10 +183,10 @@ def _(env: Mapping) -> None:
@close_envs.register
def _(envs: Sequence) -> None:
if isinstance(envs, (str, bytes)):
if isinstance(envs, (str | bytes)):
return
for v in envs:
if isinstance(v, Mapping) or isinstance(v, Sequence) and not isinstance(v, (str, bytes)):
if isinstance(v, Mapping) or isinstance(v, Sequence) and not isinstance(v, (str | bytes)):
close_envs(v)
elif hasattr(v, "close"):
_close_single_env(v)
+4 -10
View File
@@ -99,12 +99,6 @@ class Motor:
norm_mode: MotorNormMode
class JointOutOfRangeError(Exception):
def __init__(self, message="Joint is out of range"):
self.message = message
super().__init__(self.message)
class PortHandler(Protocol):
def __init__(self, port_name):
self.is_open: bool
@@ -348,7 +342,7 @@ class MotorsBus(abc.ABC):
raise TypeError(motors)
def _get_ids_values_dict(self, values: Value | dict[str, Value] | None) -> list[str]:
if isinstance(values, (int, float)):
if isinstance(values, (int | float)):
return dict.fromkeys(self.ids, values)
elif isinstance(values, dict):
return {self.motors[motor].id: val for motor, val in values.items()}
@@ -675,7 +669,7 @@ class MotorsBus(abc.ABC):
"""
if motors is None:
motors = list(self.motors)
elif isinstance(motors, (str, int)):
elif isinstance(motors, (str | int)):
motors = [motors]
elif not isinstance(motors, list):
raise TypeError(motors)
@@ -703,7 +697,7 @@ class MotorsBus(abc.ABC):
"""
if motors is None:
motors = list(self.motors)
elif isinstance(motors, (str, int)):
elif isinstance(motors, (str | int)):
motors = [motors]
elif not isinstance(motors, list):
raise TypeError(motors)
@@ -739,7 +733,7 @@ class MotorsBus(abc.ABC):
"""
if motors is None:
motors = list(self.motors)
elif isinstance(motors, (str, int)):
elif isinstance(motors, (str | int)):
motors = [motors]
elif not isinstance(motors, list):
raise TypeError(motors)
+1 -4
View File
@@ -398,10 +398,7 @@ class ACT(nn.Module):
"actions must be provided when using the variational objective in training mode."
)
if OBS_IMAGES in batch:
batch_size = batch[OBS_IMAGES][0].shape[0]
else:
batch_size = batch[OBS_ENV_STATE].shape[0]
batch_size = batch[OBS_IMAGES][0].shape[0] if OBS_IMAGES in batch else batch[OBS_ENV_STATE].shape[0]
# Prepare the latent for input to the transformer encoder.
if self.config.use_vae and ACTION in batch and self.training:
@@ -139,8 +139,6 @@ class SACConfig(PreTrainedConfig):
# Training parameter
# Number of steps for online training
online_steps: int = 1000000
# Seed for the online environment
online_env_seed: int = 10000
# Capacity of the online replay buffer
online_buffer_capacity: int = 100000
# Capacity of the offline replay buffer
-12
View File
@@ -1061,15 +1061,3 @@ class TanhMultivariateNormalDiag(TransformedDistribution):
x = transform(x)
return x
def _convert_normalization_params_to_tensor(normalization_params: dict) -> dict:
converted_params = {}
for outer_key, inner_dict in normalization_params.items():
converted_params[outer_key] = {}
for key, value in inner_dict.items():
converted_params[outer_key][key] = torch.tensor(value)
if "image" in outer_key:
converted_params[outer_key][key] = converted_params[outer_key][key].view(3, 1, 1)
return converted_params
@@ -82,7 +82,6 @@ class VQBeTConfig(PreTrainedConfig):
gpt_n_head: Number of headers of GPT
gpt_hidden_dim: Size of hidden dimensions of GPT
dropout: Dropout rate for GPT
mlp_hidden_dim: Size of hidden dimensions of offset header / bin prediction headers parts of VQ-BeT
offset_loss_weight: A constant that is multiplied to the offset loss
primary_code_loss_weight: A constant that is multiplied to the primary code prediction loss
secondary_code_loss_weight: A constant that is multiplied to the secondary code prediction loss
@@ -125,7 +124,6 @@ class VQBeTConfig(PreTrainedConfig):
gpt_n_head: int = 8
gpt_hidden_dim: int = 512
dropout: float = 0.1
mlp_hidden_dim: int = 1024
offset_loss_weight: float = 10000.0
primary_code_loss_weight: float = 5.0
secondary_code_loss_weight: float = 0.5
+3 -15
View File
@@ -231,16 +231,6 @@ class GPT(nn.Module):
torch.nn.init.zeros_(module.bias)
torch.nn.init.ones_(module.weight)
def crop_block_size(self, gpt_block_size):
# model surgery to decrease the block size if necessary
# e.g. we may load the GPT2 pretrained model checkpoint (block size 1024)
# but want to use a smaller block size for some smaller, simpler model
assert gpt_block_size <= self.config.gpt_block_size
self.config.gpt_block_size = gpt_block_size
self.transformer.wpe.weight = nn.Parameter(self.transformer.wpe.weight[:gpt_block_size])
for block in self.transformer.h:
block.attn.bias = block.attn.bias[:, :, :gpt_block_size, :gpt_block_size]
def configure_parameters(self):
"""
This long function is unfortunately doing something very simple and is being very defensive:
@@ -270,13 +260,11 @@ class GPT(nn.Module):
param_dict = dict(self.named_parameters())
inter_params = decay & no_decay
union_params = decay | no_decay
assert len(inter_params) == 0, "parameters {} made it into both decay/no_decay sets!".format(
str(inter_params)
assert len(inter_params) == 0, (
f"parameters {str(inter_params)} made it into both decay/no_decay sets!"
)
assert len(param_dict.keys() - union_params) == 0, (
"parameters {} were not separated into either decay/no_decay set!".format(
str(param_dict.keys() - union_params),
)
f"parameters {str(param_dict.keys() - union_params)} were not separated into either decay/no_decay set!"
)
decay = [param_dict[pn] for pn in sorted(decay)]
@@ -83,14 +83,12 @@ class MapDeltaActionToRobotActionStep(RobotActionProcessorStep):
Attributes:
position_scale: A factor to scale the delta position inputs.
rotation_scale: A factor to scale the delta rotation inputs (currently unused).
noise_threshold: The magnitude below which delta inputs are considered noise
and do not trigger an "enabled" state.
"""
# Scale factors for delta movements
position_scale: float = 1.0
rotation_scale: float = 0.0 # No rotation deltas for gamepad/keyboard
noise_threshold: float = 1e-3 # 1 mm threshold to filter out noise
def action(self, action: RobotAction) -> RobotAction:
+1 -1
View File
@@ -340,7 +340,7 @@ class GripperPenaltyProcessorStep(ComplementaryDataProcessorStep):
"""
action = self.transition.get(TransitionKey.ACTION)
raw_joint_positions = complementary_data.get("raw_joint_positions", None)
raw_joint_positions = complementary_data.get("raw_joint_positions")
if raw_joint_positions is None:
return complementary_data
+6 -7
View File
@@ -119,13 +119,12 @@ class _NormalizationMixin:
)
self.features = reconstructed
if self.norm_map:
# if keys are strings (JSON), rebuild enum map
if all(isinstance(k, str) for k in self.norm_map.keys()):
reconstructed = {}
for ft_type_str, norm_mode_str in self.norm_map.items():
reconstructed[FeatureType(ft_type_str)] = NormalizationMode(norm_mode_str)
self.norm_map = reconstructed
# if keys are strings (JSON), rebuild enum map
if self.norm_map and all(isinstance(k, str) for k in self.norm_map):
reconstructed = {}
for ft_type_str, norm_mode_str in self.norm_map.items():
reconstructed[FeatureType(ft_type_str)] = NormalizationMode(norm_mode_str)
self.norm_map = reconstructed
# Convert stats to tensors and move to the target device once during initialization.
self.stats = self.stats or {}
@@ -152,7 +152,7 @@ class VanillaObservationProcessorStep(ObservationProcessorStep):
"""
# Build a new features mapping keyed by the same FeatureType buckets
# We assume callers already placed features in the correct FeatureType.
new_features: dict[PipelineFeatureType, dict[str, PolicyFeature]] = {ft: {} for ft in features.keys()}
new_features: dict[PipelineFeatureType, dict[str, PolicyFeature]] = {ft: {} for ft in features}
exact_pairs = {
"pixels": OBS_IMAGE,
-2
View File
@@ -97,8 +97,6 @@ from .gym_manipulator import (
step_env_and_process_transition,
)
ACTOR_SHUTDOWN_TIMEOUT = 30
# Main entry point
+2 -2
View File
@@ -176,7 +176,7 @@ class ReplayBuffer:
self.complementary_info[key] = torch.empty(
(self.capacity, *value_shape), device=self.storage_device
)
elif isinstance(value, (int, float)):
elif isinstance(value, (int | float)):
# Handle scalar values similar to reward
self.complementary_info[key] = torch.empty((self.capacity,), device=self.storage_device)
else:
@@ -223,7 +223,7 @@ class ReplayBuffer:
value = complementary_info[key]
if isinstance(value, torch.Tensor):
self.complementary_info[key][self.position].copy_(value.squeeze(dim=0))
elif isinstance(value, (int, float)):
elif isinstance(value, (int | float)):
self.complementary_info[key][self.position] = value
self.position = (self.position + 1) % self.capacity
+17 -5
View File
@@ -160,7 +160,7 @@ def get_image_from_lerobot_dataset(dataset: LeRobotDataset):
return image_dict
def convert_lerobot_dataset_to_cropper_lerobot_dataset(
def convert_lerobot_dataset_to_cropped_lerobot_dataset(
original_dataset: LeRobotDataset,
crop_params_dict: dict[str, tuple[int, int, int, int]],
new_repo_id: str,
@@ -190,7 +190,7 @@ def convert_lerobot_dataset_to_cropper_lerobot_dataset(
# 1. Create a new (empty) LeRobotDataset for writing.
new_dataset = LeRobotDataset.create(
repo_id=new_repo_id,
fps=original_dataset.fps,
fps=int(original_dataset.fps),
root=new_dataset_root,
robot_type=original_dataset.meta.robot_type,
features=original_dataset.meta.info["features"],
@@ -275,6 +275,12 @@ if __name__ == "__main__":
default="",
help="The natural language task to describe the dataset.",
)
parser.add_argument(
"--new-repo-id",
type=str,
default=None,
help="The repository id for the new cropped and resized dataset. If not provided, it defaults to `repo_id` + '_cropped_resized'.",
)
args = parser.parse_args()
dataset = LeRobotDataset(repo_id=args.repo_id, root=args.root)
@@ -294,10 +300,16 @@ if __name__ == "__main__":
for key, roi in rois.items():
print(f"{key}: {roi}")
new_repo_id = args.repo_id + "_cropped_resized"
new_dataset_root = Path(str(dataset.root) + "_cropped_resized")
new_repo_id = args.new_repo_id if args.new_repo_id else args.repo_id + "_cropped_resized"
cropped_resized_dataset = convert_lerobot_dataset_to_cropper_lerobot_dataset(
if args.new_repo_id:
new_dataset_name = args.new_repo_id.split("/")[-1]
# Parent 1: HF user, Parent 2: HF LeRobot Home
new_dataset_root = dataset.root.parent.parent / new_dataset_name
else:
new_dataset_root = Path(str(dataset.root) + "_cropped_resized")
cropped_resized_dataset = convert_lerobot_dataset_to_cropped_lerobot_dataset(
original_dataset=dataset,
crop_params_dict=rois,
new_repo_id=new_repo_id,
-2
View File
@@ -102,8 +102,6 @@ from lerobot.utils.utils import (
from .learner_service import MAX_WORKERS, SHUTDOWN_TIMEOUT, LearnerService
LOG_PREFIX = "[LEARNER]"
@parser.wrap()
def train_cli(cfg: TrainRLServerPipelineConfig):
+1 -1
View File
@@ -137,7 +137,7 @@ class WandBLogger:
self._wandb.define_metric(new_custom_key, hidden=True)
for k, v in d.items():
if not isinstance(v, (int, float, str)):
if not isinstance(v, (int | float | str)):
logging.warning(
f'WandB logging of key "{k}" was ignored as its type "{type(v)}" is not handled by this wrapper.'
)
+1 -1
View File
@@ -105,7 +105,7 @@ class HopeJrArm(Robot):
def is_calibrated(self) -> bool:
return self.bus.is_calibrated
def calibrate(self, limb_name: str = None) -> None:
def calibrate(self) -> None:
groups = {
"all": list(self.bus.motors.keys()),
"shoulder": ["shoulder_pitch", "shoulder_yaw", "shoulder_roll"],
@@ -193,16 +193,12 @@ class EEBoundsAndSafety(RobotActionProcessorStep):
Attributes:
end_effector_bounds: A dictionary with "min" and "max" keys for position clipping.
max_ee_step_m: The maximum allowed change in position (in meters) between steps.
max_ee_twist_step_rad: The maximum allowed change in orientation (in radians) between steps.
_last_pos: Internal state storing the last commanded position.
_last_twist: Internal state storing the last commanded orientation.
"""
end_effector_bounds: dict
max_ee_step_m: float = 0.05
max_ee_twist_step_rad: float = 0.20
_last_pos: np.ndarray | None = field(default=None, init=False, repr=False)
_last_twist: np.ndarray | None = field(default=None, init=False, repr=False)
def action(self, action: RobotAction) -> RobotAction:
x = action["ee.x"]
@@ -233,7 +229,6 @@ class EEBoundsAndSafety(RobotActionProcessorStep):
raise ValueError(f"EE jump {n:.3f}m > {self.max_ee_step_m}m")
self._last_pos = pos
self._last_twist = twist
action["ee.x"] = float(pos[0])
action["ee.y"] = float(pos[1])
@@ -246,7 +241,6 @@ class EEBoundsAndSafety(RobotActionProcessorStep):
def reset(self):
"""Resets the last known position and orientation."""
self._last_pos = None
self._last_twist = None
def transform_features(
self, features: dict[PipelineFeatureType, dict[str, PolicyFeature]]
@@ -49,5 +49,3 @@ class Stretch3RobotConfig(RobotConfig):
),
}
)
mock: bool = False
@@ -164,10 +164,6 @@ class Stretch3Robot(Robot):
# TODO(aliberts): return action_sent when motion is limited
return action
def print_logs(self) -> None:
pass
# TODO(aliberts): move robot-specific logs logic here
def teleop_safety_stop(self) -> None:
if self.teleop is not None:
self.teleop._safety_stop(robot=self)
+1 -5
View File
@@ -267,11 +267,7 @@ def record_loop(
for t in teleop
if isinstance(
t,
(
so100_leader.SO100Leader,
so101_leader.SO101Leader,
koch_leader.KochLeader,
),
(so100_leader.SO100Leader | so101_leader.SO101Leader | koch_leader.KochLeader),
)
),
None,
@@ -52,10 +52,6 @@ class InputController:
"""Get the current movement deltas (dx, dy, dz) in meters."""
return 0.0, 0.0, 0.0
def should_quit(self):
"""Return True if the user has requested to quit."""
return not self.running
def update(self):
"""Update controller state - call this once per frame."""
pass
@@ -198,14 +194,6 @@ class KeyboardController(InputController):
return delta_x, delta_y, delta_z
def should_quit(self):
"""Return True if ESC was pressed."""
return self.key_states["quit"]
def should_save(self):
"""Return True if Enter was pressed (save episode)."""
return self.key_states["success"] or self.key_states["failure"]
class GamepadController(InputController):
"""Generate motion deltas from gamepad input."""
@@ -351,8 +339,6 @@ class GamepadControllerHID(InputController):
# Button states
self.buttons = {}
self.quit_requested = False
self.save_requested = False
def find_device(self):
"""Look for the gamepad device by vendor and product ID."""
@@ -472,11 +458,3 @@ class GamepadControllerHID(InputController):
delta_z = -self.right_y * self.z_step_size # Up/down
return delta_x, delta_y, delta_z
def should_quit(self):
"""Return True if quit button was pressed."""
return self.quit_requested
def should_save(self):
"""Return True if save button was pressed."""
return self.save_requested
@@ -18,7 +18,6 @@ import logging
import threading
from collections import deque
from pprint import pformat
from typing import Deque
import serial
@@ -60,7 +59,7 @@ class HomunculusArm(Teleoperator):
self.n: int = n
self.alpha: float = 2 / (n + 1)
# one deque *per joint* so we can inspect raw history if needed
self._buffers: dict[str, Deque[int]] = {
self._buffers: dict[str, deque[int]] = {
joint: deque(maxlen=n)
for joint in (
"shoulder_pitch",
@@ -18,7 +18,6 @@ import logging
import threading
from collections import deque
from pprint import pformat
from typing import Deque
import serial
@@ -97,7 +96,7 @@ class HomunculusGlove(Teleoperator):
self.n: int = n
self.alpha: float = 2 / (n + 1)
# one deque *per joint* so we can inspect raw history if needed
self._buffers: dict[str, Deque[int]] = {joint: deque(maxlen=n) for joint in self.joints}
self._buffers: dict[str, deque[int]] = {joint: deque(maxlen=n) for joint in self.joints}
# running EMA value per joint lazily initialised on first read
self._ema: dict[str, float | None] = dict.fromkeys(self._buffers)
@@ -22,8 +22,9 @@ from ..config import TeleoperatorConfig
@TeleoperatorConfig.register_subclass("keyboard")
@dataclass
class KeyboardTeleopConfig(TeleoperatorConfig):
"""KeyboardTeleopConfig"""
# TODO(Steven): Consider setting in here the keys that we want to capture/listen
mock: bool = False
@TeleoperatorConfig.register_subclass("keyboard_ee")
@@ -22,4 +22,4 @@ from ..config import TeleoperatorConfig
@TeleoperatorConfig.register_subclass("stretch3")
@dataclass
class Stretch3GamePadConfig(TeleoperatorConfig):
mock: bool = False
"""Stretch3GamePadConfig"""
@@ -112,10 +112,6 @@ class Stretch3GamePad(Teleoperator):
def send_feedback(self, feedback: np.ndarray) -> None:
pass
def print_logs(self) -> None:
pass
# TODO(aliberts): move robot-specific logs logic here
def disconnect(self) -> None:
self.api.stop()
self.is_connected = False
-1
View File
@@ -33,7 +33,6 @@ TRUNCATED = "next.truncated"
DONE = "next.done"
ROBOTS = "robots"
ROBOT_TYPE = "robot_type"
TELEOPERATORS = "teleoperators"
# files & directories
-11
View File
@@ -30,14 +30,3 @@ class DeviceAlreadyConnectedError(ConnectionError):
):
self.message = message
super().__init__(self.message)
class InvalidActionError(ValueError):
"""Exception raised when an action is already invalid."""
def __init__(
self,
message="The action is invalid. Check the value follows what it is expected from the action space.",
):
self.message = message
super().__init__(self.message)
+1 -1
View File
@@ -63,7 +63,7 @@ def move_transition_to_device(transition: Transition, device: str = "cpu") -> Tr
for key, val in transition["complementary_info"].items():
if isinstance(val, torch.Tensor):
transition["complementary_info"][key] = val.to(device, non_blocking=non_blocking)
elif isinstance(val, (int, float, bool)):
elif isinstance(val, (int | float | bool)):
transition["complementary_info"][key] = torch.tensor(val, device=device)
else:
raise ValueError(f"Unsupported type {type(val)} for complementary_info[{key}]")
-4
View File
@@ -330,10 +330,6 @@ class TimerManager:
def history(self) -> list[float]:
return deepcopy(self._history)
@property
def fps_history(self) -> list[float]:
return [1.0 / t for t in self._history]
@property
def fps_last(self) -> float:
return 0.0 if self.last == 0 else 1.0 / self.last
+2 -5
View File
@@ -32,11 +32,8 @@ def init_rerun(session_name: str = "lerobot_control_loop") -> None:
def _is_scalar(x):
return (
isinstance(x, float)
or isinstance(x, numbers.Real)
or isinstance(x, (np.integer, np.floating))
or (isinstance(x, np.ndarray) and x.ndim == 0)
return isinstance(x, (float | numbers.Real | np.integer | np.floating)) or (
isinstance(x, np.ndarray) and x.ndim == 0
)
@@ -66,15 +66,13 @@ def get_policy_stats(ds_repo_id: str, policy_name: str, policy_kwargs: dict):
for key, param in policy.named_parameters():
if param.requires_grad:
grad_stats[f"{key}_mean"] = param.grad.mean()
grad_stats[f"{key}_std"] = (
param.grad.std() if param.grad.numel() > 1 else torch.tensor(float(0.0))
)
grad_stats[f"{key}_std"] = param.grad.std() if param.grad.numel() > 1 else torch.tensor(0.0)
optimizer.step()
param_stats = {}
for key, param in policy.named_parameters():
param_stats[f"{key}_mean"] = param.mean()
param_stats[f"{key}_std"] = param.std() if param.numel() > 1 else torch.tensor(float(0.0))
param_stats[f"{key}_std"] = param.std() if param.numel() > 1 else torch.tensor(0.0)
optimizer.zero_grad()
policy.reset()
+2 -2
View File
@@ -85,7 +85,7 @@ def policy_feature_factory():
def assert_contract_is_typed(features: dict[PipelineFeatureType, dict[str, PolicyFeature]]) -> None:
assert isinstance(features, dict)
assert all(isinstance(k, PipelineFeatureType) for k in features.keys())
assert all(isinstance(k, PipelineFeatureType) for k in features)
assert all(isinstance(v, dict) for v in features.values())
assert all(all(isinstance(nk, str) for nk in v.keys()) for v in features.values())
assert all(all(isinstance(nk, str) for nk in v) for v in features.values())
assert all(all(isinstance(nv, PolicyFeature) for nv in v.values()) for v in features.values())
+1 -1
View File
@@ -949,7 +949,7 @@ def test_statistics_metadata_validation(tmp_path, empty_lerobot_dataset_factory)
# Check that statistics exist for all features
assert loaded_dataset.meta.stats is not None, "No statistics found"
for feature_name in features.keys():
for feature_name in features:
assert feature_name in loaded_dataset.meta.stats, f"No statistics for feature '{feature_name}'"
feature_stats = loaded_dataset.meta.stats[feature_name]
-1
View File
@@ -69,7 +69,6 @@ def test_sac_config_default_initialization():
# Training parameters
assert config.online_steps == 1000000
assert config.online_env_seed == 10000
assert config.online_buffer_capacity == 100000
assert config.offline_buffer_capacity == 100000
assert config.async_prefetch is False
+5 -8
View File
@@ -246,7 +246,7 @@ def test_step_through():
# Ensure all results are dicts (same format as input)
for result in results:
assert isinstance(result, dict)
assert all(isinstance(k, TransitionKey) for k in result.keys())
assert all(isinstance(k, TransitionKey) for k in result)
def test_step_through_with_dict():
@@ -770,7 +770,7 @@ class MockStepWithNonSerializableParam(ProcessorStep):
# Add type validation for multiplier
if isinstance(multiplier, str):
raise ValueError(f"multiplier must be a number, got string '{multiplier}'")
if not isinstance(multiplier, (int, float)):
if not isinstance(multiplier, (int | float)):
raise TypeError(f"multiplier must be a number, got {type(multiplier).__name__}")
self.multiplier = float(multiplier)
self.env = env # Non-serializable parameter (like gym.Env)
@@ -1623,9 +1623,7 @@ def test_override_with_callables():
# Define a transform function
def double_values(x):
if isinstance(x, (int, float)):
return x * 2
elif isinstance(x, torch.Tensor):
if isinstance(x, (int | float | torch.Tensor)):
return x * 2
return x
@@ -1797,10 +1795,9 @@ def test_from_pretrained_nonexistent_path():
)
# Test with a local directory that exists but has no config files
with tempfile.TemporaryDirectory() as tmp_dir:
with tempfile.TemporaryDirectory() as tmp_dir, pytest.raises(FileNotFoundError):
# Since the directory exists but has no config, it will raise FileNotFoundError
with pytest.raises(FileNotFoundError):
DataProcessorPipeline.from_pretrained(tmp_dir, config_filename="processor.json")
DataProcessorPipeline.from_pretrained(tmp_dir, config_filename="processor.json")
def test_save_load_with_custom_converter_functions():
+1 -4
View File
@@ -32,10 +32,7 @@ class MockTokenizer:
**kwargs,
) -> dict[str, torch.Tensor]:
"""Mock tokenization that returns deterministic tokens based on text."""
if isinstance(text, str):
texts = [text]
else:
texts = text
texts = [text] if isinstance(text, str) else text
batch_size = len(texts)
+5 -5
View File
@@ -245,14 +245,14 @@ def test_get_observation(reachy2):
obs = reachy2.get_observation()
expected_keys = set(reachy2.joints_dict)
expected_keys.update(f"{v}" for v in REACHY2_VEL.keys() if reachy2.config.with_mobile_base)
expected_keys.update(f"{v}" for v in REACHY2_VEL if reachy2.config.with_mobile_base)
expected_keys.update(reachy2.cameras.keys())
assert set(obs.keys()) == expected_keys
for motor in reachy2.joints_dict.keys():
for motor in reachy2.joints_dict:
assert obs[motor] == reachy2.reachy.joints[REACHY2_JOINTS[motor]].present_position
if reachy2.config.with_mobile_base:
for vel in REACHY2_VEL.keys():
for vel in REACHY2_VEL:
assert obs[vel] == reachy2.reachy.mobile_base.odometry[REACHY2_VEL[vel]]
if reachy2.config.with_left_teleop_camera:
assert obs["teleop_left"].shape == (
@@ -282,7 +282,7 @@ def test_send_action(reachy2):
action.update({k: i * 0.1 for i, k in enumerate(REACHY2_VEL.keys(), start=1)})
previous_present_position = {
k: reachy2.reachy.joints[REACHY2_JOINTS[k]].present_position for k in reachy2.joints_dict.keys()
k: reachy2.reachy.joints[REACHY2_JOINTS[k]].present_position for k in reachy2.joints_dict
}
returned = reachy2.send_action(action)
@@ -290,7 +290,7 @@ def test_send_action(reachy2):
assert returned == action
assert reachy2.reachy._goal_position_set_total == len(reachy2.joints_dict)
for motor in reachy2.joints_dict.keys():
for motor in reachy2.joints_dict:
expected_pos = action[motor]
real_pos = reachy2.reachy.joints[REACHY2_JOINTS[motor]].goal_position
if reachy2.config.max_relative_target is None:
@@ -121,20 +121,20 @@ def test_get_action(reachy2):
action = reachy2.get_action()
expected_keys = set(reachy2.joints_dict)
expected_keys.update(f"{v}" for v in REACHY2_VEL.keys() if reachy2.config.with_mobile_base)
expected_keys.update(f"{v}" for v in REACHY2_VEL if reachy2.config.with_mobile_base)
assert set(action.keys()) == expected_keys
for motor in reachy2.joints_dict.keys():
for motor in reachy2.joints_dict:
if reachy2.config.use_present_position:
assert action[motor] == reachy2.reachy.joints[REACHY2_JOINTS[motor]].present_position
else:
assert action[motor] == reachy2.reachy.joints[REACHY2_JOINTS[motor]].goal_position
if reachy2.config.with_mobile_base:
if reachy2.config.use_present_position:
for vel in REACHY2_VEL.keys():
for vel in REACHY2_VEL:
assert action[vel] == reachy2.reachy.mobile_base.odometry[REACHY2_VEL[vel]]
else:
for vel in REACHY2_VEL.keys():
for vel in REACHY2_VEL:
assert action[vel] == reachy2.reachy.mobile_base.last_cmd_vel[REACHY2_VEL[vel]]
+1 -1
View File
@@ -121,7 +121,7 @@ def get_tensors_memory_consumption(obj, visited_addresses):
if isinstance(obj, torch.Tensor):
return get_tensor_memory_consumption(obj)
elif isinstance(obj, (list, tuple)):
elif isinstance(obj, (list | tuple)):
for item in obj:
total_size += get_tensors_memory_consumption(item, visited_addresses)
elif isinstance(obj, dict):