Merge branch 'main' into feat/add_pi

This commit is contained in:
Pepijn
2025-09-29 16:02:36 +02:00
28 changed files with 62 additions and 83 deletions
+1 -1
View File
@@ -201,7 +201,7 @@ exclude = ["tests/artifacts/**/*.safetensors", "*_pb2.py", "*_pb2_grpc.py"]
# N: pep8-naming
# TODO: Uncomment rules when ready to use
select = [
"E", "W", "F", "I", "B", "C4", "T20", "N" # "SIM", "A", "S", "D", "RUF", "UP"
"E", "W", "F", "I", "B", "C4", "T20", "N", "UP", "SIM" #, "A", "S", "D", "RUF"
]
ignore = [
"E501", # Line too long
+1 -1
View File
@@ -1421,7 +1421,7 @@ class MultiLeRobotDataset(torch.utils.data.Dataset):
"""Keys to access image and video stream from cameras."""
keys = []
for key, feats in self.features.items():
if isinstance(feats, (datasets.Image, VideoFrame)):
if isinstance(feats, (datasets.Image | VideoFrame)):
keys.append(key)
return keys
+1 -1
View File
@@ -120,7 +120,7 @@ class SharpnessJitter(Transform):
self.sharpness = self._check_input(sharpness)
def _check_input(self, sharpness):
if isinstance(sharpness, (int, float)):
if isinstance(sharpness, (int | float)):
if sharpness < 0:
raise ValueError("If sharpness is a single number, it must be non negative.")
sharpness = [1.0 - sharpness, 1.0 + sharpness]
+7 -7
View File
@@ -21,7 +21,7 @@ from collections import deque
from collections.abc import Iterable, Iterator
from pathlib import Path
from pprint import pformat
from typing import Any, Deque, Generic, TypeVar
from typing import Any, Generic, TypeVar
import datasets
import numpy as np
@@ -207,13 +207,13 @@ def serialize_dict(stats: dict[str, torch.Tensor | np.ndarray | dict]) -> dict:
"""
serialized_dict = {}
for key, value in flatten_dict(stats).items():
if isinstance(value, (torch.Tensor, np.ndarray)):
if isinstance(value, (torch.Tensor | np.ndarray)):
serialized_dict[key] = value.tolist()
elif isinstance(value, list) and isinstance(value[0], (int, float, list)):
elif isinstance(value, list) and isinstance(value[0], (int | float | list)):
serialized_dict[key] = value
elif isinstance(value, np.generic):
serialized_dict[key] = value.item()
elif isinstance(value, (int, float)):
elif isinstance(value, (int | float)):
serialized_dict[key] = value
else:
raise NotImplementedError(f"The value '{value}' of type '{type(value)}' is not supported.")
@@ -1179,7 +1179,7 @@ def item_to_torch(item: dict) -> dict:
dict: Dictionary with all tensor-like items converted to torch.Tensor.
"""
for key, val in item.items():
if isinstance(val, (np.ndarray, list)) and key not in ["task"]:
if isinstance(val, (np.ndarray | list)) and key not in ["task"]:
# Convert numpy arrays and lists to torch tensors
item[key] = torch.tensor(val)
return item
@@ -1253,8 +1253,8 @@ class Backtrackable(Generic[T]):
raise ValueError("lookahead must be > 0")
self._source: Iterator[T] = iter(iterable)
self._back_buf: Deque[T] = deque(maxlen=history)
self._ahead_buf: Deque[T] = deque(maxlen=lookahead) if lookahead > 0 else deque()
self._back_buf: deque[T] = deque(maxlen=history)
self._ahead_buf: deque[T] = deque(maxlen=lookahead) if lookahead > 0 else deque()
self._cursor: int = 0
self._history = history
self._lookahead = lookahead
+3 -1
View File
@@ -437,7 +437,9 @@ def concatenate_video_files(
tmp_concatenate_path, mode="r", format="concat", options={"safe": "0"}
) # safe = 0 allows absolute paths as well as relative paths
tmp_output_video_path = tempfile.NamedTemporaryFile(suffix=".mp4", delete=False).name
with tempfile.NamedTemporaryFile(suffix=".mp4", delete=False) as tmp_named_file:
tmp_output_video_path = tmp_named_file.name
output_container = av.open(
tmp_output_video_path, mode="w", options={"movflags": "faststart"}
) # faststart is to move the metadata to the beginning of the file to speed up loading
+1 -1
View File
@@ -35,7 +35,7 @@ def _parse_camera_names(camera_name: str | Sequence[str]) -> list[str]:
"""Normalize camera_name into a non-empty list of strings."""
if isinstance(camera_name, str):
cams = [c.strip() for c in camera_name.split(",") if c.strip()]
elif isinstance(camera_name, (list, tuple)):
elif isinstance(camera_name, (list | tuple)):
cams = [str(c).strip() for c in camera_name if str(c).strip()]
else:
raise TypeError(f"camera_name must be str or sequence[str], got {type(camera_name).__name__}")
+2 -2
View File
@@ -183,10 +183,10 @@ def _(env: Mapping) -> None:
@close_envs.register
def _(envs: Sequence) -> None:
if isinstance(envs, (str, bytes)):
if isinstance(envs, (str | bytes)):
return
for v in envs:
if isinstance(v, Mapping) or isinstance(v, Sequence) and not isinstance(v, (str, bytes)):
if isinstance(v, Mapping) or isinstance(v, Sequence) and not isinstance(v, (str | bytes)):
close_envs(v)
elif hasattr(v, "close"):
_close_single_env(v)
+4 -4
View File
@@ -342,7 +342,7 @@ class MotorsBus(abc.ABC):
raise TypeError(motors)
def _get_ids_values_dict(self, values: Value | dict[str, Value] | None) -> list[str]:
if isinstance(values, (int, float)):
if isinstance(values, (int | float)):
return dict.fromkeys(self.ids, values)
elif isinstance(values, dict):
return {self.motors[motor].id: val for motor, val in values.items()}
@@ -669,7 +669,7 @@ class MotorsBus(abc.ABC):
"""
if motors is None:
motors = list(self.motors)
elif isinstance(motors, (str, int)):
elif isinstance(motors, (str | int)):
motors = [motors]
elif not isinstance(motors, list):
raise TypeError(motors)
@@ -697,7 +697,7 @@ class MotorsBus(abc.ABC):
"""
if motors is None:
motors = list(self.motors)
elif isinstance(motors, (str, int)):
elif isinstance(motors, (str | int)):
motors = [motors]
elif not isinstance(motors, list):
raise TypeError(motors)
@@ -733,7 +733,7 @@ class MotorsBus(abc.ABC):
"""
if motors is None:
motors = list(self.motors)
elif isinstance(motors, (str, int)):
elif isinstance(motors, (str | int)):
motors = [motors]
elif not isinstance(motors, list):
raise TypeError(motors)
+1 -4
View File
@@ -398,10 +398,7 @@ class ACT(nn.Module):
"actions must be provided when using the variational objective in training mode."
)
if OBS_IMAGES in batch:
batch_size = batch[OBS_IMAGES][0].shape[0]
else:
batch_size = batch[OBS_ENV_STATE].shape[0]
batch_size = batch[OBS_IMAGES][0].shape[0] if OBS_IMAGES in batch else batch[OBS_ENV_STATE].shape[0]
# Prepare the latent for input to the transformer encoder.
if self.config.use_vae and ACTION in batch and self.training:
+3 -5
View File
@@ -260,13 +260,11 @@ class GPT(nn.Module):
param_dict = dict(self.named_parameters())
inter_params = decay & no_decay
union_params = decay | no_decay
assert len(inter_params) == 0, "parameters {} made it into both decay/no_decay sets!".format(
str(inter_params)
assert len(inter_params) == 0, (
f"parameters {str(inter_params)} made it into both decay/no_decay sets!"
)
assert len(param_dict.keys() - union_params) == 0, (
"parameters {} were not separated into either decay/no_decay set!".format(
str(param_dict.keys() - union_params),
)
f"parameters {str(param_dict.keys() - union_params)} were not separated into either decay/no_decay set!"
)
decay = [param_dict[pn] for pn in sorted(decay)]
+1 -1
View File
@@ -340,7 +340,7 @@ class GripperPenaltyProcessorStep(ComplementaryDataProcessorStep):
"""
action = self.transition.get(TransitionKey.ACTION)
raw_joint_positions = complementary_data.get("raw_joint_positions", None)
raw_joint_positions = complementary_data.get("raw_joint_positions")
if raw_joint_positions is None:
return complementary_data
+6 -7
View File
@@ -119,13 +119,12 @@ class _NormalizationMixin:
)
self.features = reconstructed
if self.norm_map:
# if keys are strings (JSON), rebuild enum map
if all(isinstance(k, str) for k in self.norm_map.keys()):
reconstructed = {}
for ft_type_str, norm_mode_str in self.norm_map.items():
reconstructed[FeatureType(ft_type_str)] = NormalizationMode(norm_mode_str)
self.norm_map = reconstructed
# if keys are strings (JSON), rebuild enum map
if self.norm_map and all(isinstance(k, str) for k in self.norm_map):
reconstructed = {}
for ft_type_str, norm_mode_str in self.norm_map.items():
reconstructed[FeatureType(ft_type_str)] = NormalizationMode(norm_mode_str)
self.norm_map = reconstructed
# Convert stats to tensors and move to the target device once during initialization.
self.stats = self.stats or {}
@@ -152,7 +152,7 @@ class VanillaObservationProcessorStep(ObservationProcessorStep):
"""
# Build a new features mapping keyed by the same FeatureType buckets
# We assume callers already placed features in the correct FeatureType.
new_features: dict[PipelineFeatureType, dict[str, PolicyFeature]] = {ft: {} for ft in features.keys()}
new_features: dict[PipelineFeatureType, dict[str, PolicyFeature]] = {ft: {} for ft in features}
exact_pairs = {
"pixels": OBS_IMAGE,
+2 -2
View File
@@ -176,7 +176,7 @@ class ReplayBuffer:
self.complementary_info[key] = torch.empty(
(self.capacity, *value_shape), device=self.storage_device
)
elif isinstance(value, (int, float)):
elif isinstance(value, (int | float)):
# Handle scalar values similar to reward
self.complementary_info[key] = torch.empty((self.capacity,), device=self.storage_device)
else:
@@ -223,7 +223,7 @@ class ReplayBuffer:
value = complementary_info[key]
if isinstance(value, torch.Tensor):
self.complementary_info[key][self.position].copy_(value.squeeze(dim=0))
elif isinstance(value, (int, float)):
elif isinstance(value, (int | float)):
self.complementary_info[key][self.position] = value
self.position = (self.position + 1) % self.capacity
+1 -1
View File
@@ -137,7 +137,7 @@ class WandBLogger:
self._wandb.define_metric(new_custom_key, hidden=True)
for k, v in d.items():
if not isinstance(v, (int, float, str)):
if not isinstance(v, (int | float | str)):
logging.warning(
f'WandB logging of key "{k}" was ignored as its type "{type(v)}" is not handled by this wrapper.'
)
+1 -5
View File
@@ -267,11 +267,7 @@ def record_loop(
for t in teleop
if isinstance(
t,
(
so100_leader.SO100Leader,
so101_leader.SO101Leader,
koch_leader.KochLeader,
),
(so100_leader.SO100Leader | so101_leader.SO101Leader | koch_leader.KochLeader),
)
),
None,
@@ -18,7 +18,6 @@ import logging
import threading
from collections import deque
from pprint import pformat
from typing import Deque
import serial
@@ -60,7 +59,7 @@ class HomunculusArm(Teleoperator):
self.n: int = n
self.alpha: float = 2 / (n + 1)
# one deque *per joint* so we can inspect raw history if needed
self._buffers: dict[str, Deque[int]] = {
self._buffers: dict[str, deque[int]] = {
joint: deque(maxlen=n)
for joint in (
"shoulder_pitch",
@@ -18,7 +18,6 @@ import logging
import threading
from collections import deque
from pprint import pformat
from typing import Deque
import serial
@@ -97,7 +96,7 @@ class HomunculusGlove(Teleoperator):
self.n: int = n
self.alpha: float = 2 / (n + 1)
# one deque *per joint* so we can inspect raw history if needed
self._buffers: dict[str, Deque[int]] = {joint: deque(maxlen=n) for joint in self.joints}
self._buffers: dict[str, deque[int]] = {joint: deque(maxlen=n) for joint in self.joints}
# running EMA value per joint lazily initialised on first read
self._ema: dict[str, float | None] = dict.fromkeys(self._buffers)
+1 -1
View File
@@ -63,7 +63,7 @@ def move_transition_to_device(transition: Transition, device: str = "cpu") -> Tr
for key, val in transition["complementary_info"].items():
if isinstance(val, torch.Tensor):
transition["complementary_info"][key] = val.to(device, non_blocking=non_blocking)
elif isinstance(val, (int, float, bool)):
elif isinstance(val, (int | float | bool)):
transition["complementary_info"][key] = torch.tensor(val, device=device)
else:
raise ValueError(f"Unsupported type {type(val)} for complementary_info[{key}]")
+2 -5
View File
@@ -32,11 +32,8 @@ def init_rerun(session_name: str = "lerobot_control_loop") -> None:
def _is_scalar(x):
return (
isinstance(x, float)
or isinstance(x, numbers.Real)
or isinstance(x, (np.integer, np.floating))
or (isinstance(x, np.ndarray) and x.ndim == 0)
return isinstance(x, (float | numbers.Real | np.integer | np.floating)) or (
isinstance(x, np.ndarray) and x.ndim == 0
)
@@ -66,15 +66,13 @@ def get_policy_stats(ds_repo_id: str, policy_name: str, policy_kwargs: dict):
for key, param in policy.named_parameters():
if param.requires_grad:
grad_stats[f"{key}_mean"] = param.grad.mean()
grad_stats[f"{key}_std"] = (
param.grad.std() if param.grad.numel() > 1 else torch.tensor(float(0.0))
)
grad_stats[f"{key}_std"] = param.grad.std() if param.grad.numel() > 1 else torch.tensor(0.0)
optimizer.step()
param_stats = {}
for key, param in policy.named_parameters():
param_stats[f"{key}_mean"] = param.mean()
param_stats[f"{key}_std"] = param.std() if param.numel() > 1 else torch.tensor(float(0.0))
param_stats[f"{key}_std"] = param.std() if param.numel() > 1 else torch.tensor(0.0)
optimizer.zero_grad()
policy.reset()
+2 -2
View File
@@ -85,7 +85,7 @@ def policy_feature_factory():
def assert_contract_is_typed(features: dict[PipelineFeatureType, dict[str, PolicyFeature]]) -> None:
assert isinstance(features, dict)
assert all(isinstance(k, PipelineFeatureType) for k in features.keys())
assert all(isinstance(k, PipelineFeatureType) for k in features)
assert all(isinstance(v, dict) for v in features.values())
assert all(all(isinstance(nk, str) for nk in v.keys()) for v in features.values())
assert all(all(isinstance(nk, str) for nk in v) for v in features.values())
assert all(all(isinstance(nv, PolicyFeature) for nv in v.values()) for v in features.values())
+1 -1
View File
@@ -949,7 +949,7 @@ def test_statistics_metadata_validation(tmp_path, empty_lerobot_dataset_factory)
# Check that statistics exist for all features
assert loaded_dataset.meta.stats is not None, "No statistics found"
for feature_name in features.keys():
for feature_name in features:
assert feature_name in loaded_dataset.meta.stats, f"No statistics for feature '{feature_name}'"
feature_stats = loaded_dataset.meta.stats[feature_name]
+5 -8
View File
@@ -246,7 +246,7 @@ def test_step_through():
# Ensure all results are dicts (same format as input)
for result in results:
assert isinstance(result, dict)
assert all(isinstance(k, TransitionKey) for k in result.keys())
assert all(isinstance(k, TransitionKey) for k in result)
def test_step_through_with_dict():
@@ -770,7 +770,7 @@ class MockStepWithNonSerializableParam(ProcessorStep):
# Add type validation for multiplier
if isinstance(multiplier, str):
raise ValueError(f"multiplier must be a number, got string '{multiplier}'")
if not isinstance(multiplier, (int, float)):
if not isinstance(multiplier, (int | float)):
raise TypeError(f"multiplier must be a number, got {type(multiplier).__name__}")
self.multiplier = float(multiplier)
self.env = env # Non-serializable parameter (like gym.Env)
@@ -1623,9 +1623,7 @@ def test_override_with_callables():
# Define a transform function
def double_values(x):
if isinstance(x, (int, float)):
return x * 2
elif isinstance(x, torch.Tensor):
if isinstance(x, (int | float | torch.Tensor)):
return x * 2
return x
@@ -1797,10 +1795,9 @@ def test_from_pretrained_nonexistent_path():
)
# Test with a local directory that exists but has no config files
with tempfile.TemporaryDirectory() as tmp_dir:
with tempfile.TemporaryDirectory() as tmp_dir, pytest.raises(FileNotFoundError):
# Since the directory exists but has no config, it will raise FileNotFoundError
with pytest.raises(FileNotFoundError):
DataProcessorPipeline.from_pretrained(tmp_dir, config_filename="processor.json")
DataProcessorPipeline.from_pretrained(tmp_dir, config_filename="processor.json")
def test_save_load_with_custom_converter_functions():
+1 -4
View File
@@ -32,10 +32,7 @@ class MockTokenizer:
**kwargs,
) -> dict[str, torch.Tensor]:
"""Mock tokenization that returns deterministic tokens based on text."""
if isinstance(text, str):
texts = [text]
else:
texts = text
texts = [text] if isinstance(text, str) else text
batch_size = len(texts)
+5 -5
View File
@@ -245,14 +245,14 @@ def test_get_observation(reachy2):
obs = reachy2.get_observation()
expected_keys = set(reachy2.joints_dict)
expected_keys.update(f"{v}" for v in REACHY2_VEL.keys() if reachy2.config.with_mobile_base)
expected_keys.update(f"{v}" for v in REACHY2_VEL if reachy2.config.with_mobile_base)
expected_keys.update(reachy2.cameras.keys())
assert set(obs.keys()) == expected_keys
for motor in reachy2.joints_dict.keys():
for motor in reachy2.joints_dict:
assert obs[motor] == reachy2.reachy.joints[REACHY2_JOINTS[motor]].present_position
if reachy2.config.with_mobile_base:
for vel in REACHY2_VEL.keys():
for vel in REACHY2_VEL:
assert obs[vel] == reachy2.reachy.mobile_base.odometry[REACHY2_VEL[vel]]
if reachy2.config.with_left_teleop_camera:
assert obs["teleop_left"].shape == (
@@ -282,7 +282,7 @@ def test_send_action(reachy2):
action.update({k: i * 0.1 for i, k in enumerate(REACHY2_VEL.keys(), start=1)})
previous_present_position = {
k: reachy2.reachy.joints[REACHY2_JOINTS[k]].present_position for k in reachy2.joints_dict.keys()
k: reachy2.reachy.joints[REACHY2_JOINTS[k]].present_position for k in reachy2.joints_dict
}
returned = reachy2.send_action(action)
@@ -290,7 +290,7 @@ def test_send_action(reachy2):
assert returned == action
assert reachy2.reachy._goal_position_set_total == len(reachy2.joints_dict)
for motor in reachy2.joints_dict.keys():
for motor in reachy2.joints_dict:
expected_pos = action[motor]
real_pos = reachy2.reachy.joints[REACHY2_JOINTS[motor]].goal_position
if reachy2.config.max_relative_target is None:
@@ -121,20 +121,20 @@ def test_get_action(reachy2):
action = reachy2.get_action()
expected_keys = set(reachy2.joints_dict)
expected_keys.update(f"{v}" for v in REACHY2_VEL.keys() if reachy2.config.with_mobile_base)
expected_keys.update(f"{v}" for v in REACHY2_VEL if reachy2.config.with_mobile_base)
assert set(action.keys()) == expected_keys
for motor in reachy2.joints_dict.keys():
for motor in reachy2.joints_dict:
if reachy2.config.use_present_position:
assert action[motor] == reachy2.reachy.joints[REACHY2_JOINTS[motor]].present_position
else:
assert action[motor] == reachy2.reachy.joints[REACHY2_JOINTS[motor]].goal_position
if reachy2.config.with_mobile_base:
if reachy2.config.use_present_position:
for vel in REACHY2_VEL.keys():
for vel in REACHY2_VEL:
assert action[vel] == reachy2.reachy.mobile_base.odometry[REACHY2_VEL[vel]]
else:
for vel in REACHY2_VEL.keys():
for vel in REACHY2_VEL:
assert action[vel] == reachy2.reachy.mobile_base.last_cmd_vel[REACHY2_VEL[vel]]
+1 -1
View File
@@ -121,7 +121,7 @@ def get_tensors_memory_consumption(obj, visited_addresses):
if isinstance(obj, torch.Tensor):
return get_tensor_memory_consumption(obj)
elif isinstance(obj, (list, tuple)):
elif isinstance(obj, (list | tuple)):
for item in obj:
total_size += get_tensors_memory_consumption(item, visited_addresses)
elif isinstance(obj, dict):