mirror of
https://github.com/huggingface/lerobot.git
synced 2026-05-25 13:40:00 +00:00
Merge branch 'main' into feat/add_pi
This commit is contained in:
+1
-1
@@ -201,7 +201,7 @@ exclude = ["tests/artifacts/**/*.safetensors", "*_pb2.py", "*_pb2_grpc.py"]
|
||||
# N: pep8-naming
|
||||
# TODO: Uncomment rules when ready to use
|
||||
select = [
|
||||
"E", "W", "F", "I", "B", "C4", "T20", "N" # "SIM", "A", "S", "D", "RUF", "UP"
|
||||
"E", "W", "F", "I", "B", "C4", "T20", "N", "UP", "SIM" #, "A", "S", "D", "RUF"
|
||||
]
|
||||
ignore = [
|
||||
"E501", # Line too long
|
||||
|
||||
@@ -1421,7 +1421,7 @@ class MultiLeRobotDataset(torch.utils.data.Dataset):
|
||||
"""Keys to access image and video stream from cameras."""
|
||||
keys = []
|
||||
for key, feats in self.features.items():
|
||||
if isinstance(feats, (datasets.Image, VideoFrame)):
|
||||
if isinstance(feats, (datasets.Image | VideoFrame)):
|
||||
keys.append(key)
|
||||
return keys
|
||||
|
||||
|
||||
@@ -120,7 +120,7 @@ class SharpnessJitter(Transform):
|
||||
self.sharpness = self._check_input(sharpness)
|
||||
|
||||
def _check_input(self, sharpness):
|
||||
if isinstance(sharpness, (int, float)):
|
||||
if isinstance(sharpness, (int | float)):
|
||||
if sharpness < 0:
|
||||
raise ValueError("If sharpness is a single number, it must be non negative.")
|
||||
sharpness = [1.0 - sharpness, 1.0 + sharpness]
|
||||
|
||||
@@ -21,7 +21,7 @@ from collections import deque
|
||||
from collections.abc import Iterable, Iterator
|
||||
from pathlib import Path
|
||||
from pprint import pformat
|
||||
from typing import Any, Deque, Generic, TypeVar
|
||||
from typing import Any, Generic, TypeVar
|
||||
|
||||
import datasets
|
||||
import numpy as np
|
||||
@@ -207,13 +207,13 @@ def serialize_dict(stats: dict[str, torch.Tensor | np.ndarray | dict]) -> dict:
|
||||
"""
|
||||
serialized_dict = {}
|
||||
for key, value in flatten_dict(stats).items():
|
||||
if isinstance(value, (torch.Tensor, np.ndarray)):
|
||||
if isinstance(value, (torch.Tensor | np.ndarray)):
|
||||
serialized_dict[key] = value.tolist()
|
||||
elif isinstance(value, list) and isinstance(value[0], (int, float, list)):
|
||||
elif isinstance(value, list) and isinstance(value[0], (int | float | list)):
|
||||
serialized_dict[key] = value
|
||||
elif isinstance(value, np.generic):
|
||||
serialized_dict[key] = value.item()
|
||||
elif isinstance(value, (int, float)):
|
||||
elif isinstance(value, (int | float)):
|
||||
serialized_dict[key] = value
|
||||
else:
|
||||
raise NotImplementedError(f"The value '{value}' of type '{type(value)}' is not supported.")
|
||||
@@ -1179,7 +1179,7 @@ def item_to_torch(item: dict) -> dict:
|
||||
dict: Dictionary with all tensor-like items converted to torch.Tensor.
|
||||
"""
|
||||
for key, val in item.items():
|
||||
if isinstance(val, (np.ndarray, list)) and key not in ["task"]:
|
||||
if isinstance(val, (np.ndarray | list)) and key not in ["task"]:
|
||||
# Convert numpy arrays and lists to torch tensors
|
||||
item[key] = torch.tensor(val)
|
||||
return item
|
||||
@@ -1253,8 +1253,8 @@ class Backtrackable(Generic[T]):
|
||||
raise ValueError("lookahead must be > 0")
|
||||
|
||||
self._source: Iterator[T] = iter(iterable)
|
||||
self._back_buf: Deque[T] = deque(maxlen=history)
|
||||
self._ahead_buf: Deque[T] = deque(maxlen=lookahead) if lookahead > 0 else deque()
|
||||
self._back_buf: deque[T] = deque(maxlen=history)
|
||||
self._ahead_buf: deque[T] = deque(maxlen=lookahead) if lookahead > 0 else deque()
|
||||
self._cursor: int = 0
|
||||
self._history = history
|
||||
self._lookahead = lookahead
|
||||
|
||||
@@ -437,7 +437,9 @@ def concatenate_video_files(
|
||||
tmp_concatenate_path, mode="r", format="concat", options={"safe": "0"}
|
||||
) # safe = 0 allows absolute paths as well as relative paths
|
||||
|
||||
tmp_output_video_path = tempfile.NamedTemporaryFile(suffix=".mp4", delete=False).name
|
||||
with tempfile.NamedTemporaryFile(suffix=".mp4", delete=False) as tmp_named_file:
|
||||
tmp_output_video_path = tmp_named_file.name
|
||||
|
||||
output_container = av.open(
|
||||
tmp_output_video_path, mode="w", options={"movflags": "faststart"}
|
||||
) # faststart is to move the metadata to the beginning of the file to speed up loading
|
||||
|
||||
@@ -35,7 +35,7 @@ def _parse_camera_names(camera_name: str | Sequence[str]) -> list[str]:
|
||||
"""Normalize camera_name into a non-empty list of strings."""
|
||||
if isinstance(camera_name, str):
|
||||
cams = [c.strip() for c in camera_name.split(",") if c.strip()]
|
||||
elif isinstance(camera_name, (list, tuple)):
|
||||
elif isinstance(camera_name, (list | tuple)):
|
||||
cams = [str(c).strip() for c in camera_name if str(c).strip()]
|
||||
else:
|
||||
raise TypeError(f"camera_name must be str or sequence[str], got {type(camera_name).__name__}")
|
||||
|
||||
@@ -183,10 +183,10 @@ def _(env: Mapping) -> None:
|
||||
|
||||
@close_envs.register
|
||||
def _(envs: Sequence) -> None:
|
||||
if isinstance(envs, (str, bytes)):
|
||||
if isinstance(envs, (str | bytes)):
|
||||
return
|
||||
for v in envs:
|
||||
if isinstance(v, Mapping) or isinstance(v, Sequence) and not isinstance(v, (str, bytes)):
|
||||
if isinstance(v, Mapping) or isinstance(v, Sequence) and not isinstance(v, (str | bytes)):
|
||||
close_envs(v)
|
||||
elif hasattr(v, "close"):
|
||||
_close_single_env(v)
|
||||
|
||||
@@ -342,7 +342,7 @@ class MotorsBus(abc.ABC):
|
||||
raise TypeError(motors)
|
||||
|
||||
def _get_ids_values_dict(self, values: Value | dict[str, Value] | None) -> list[str]:
|
||||
if isinstance(values, (int, float)):
|
||||
if isinstance(values, (int | float)):
|
||||
return dict.fromkeys(self.ids, values)
|
||||
elif isinstance(values, dict):
|
||||
return {self.motors[motor].id: val for motor, val in values.items()}
|
||||
@@ -669,7 +669,7 @@ class MotorsBus(abc.ABC):
|
||||
"""
|
||||
if motors is None:
|
||||
motors = list(self.motors)
|
||||
elif isinstance(motors, (str, int)):
|
||||
elif isinstance(motors, (str | int)):
|
||||
motors = [motors]
|
||||
elif not isinstance(motors, list):
|
||||
raise TypeError(motors)
|
||||
@@ -697,7 +697,7 @@ class MotorsBus(abc.ABC):
|
||||
"""
|
||||
if motors is None:
|
||||
motors = list(self.motors)
|
||||
elif isinstance(motors, (str, int)):
|
||||
elif isinstance(motors, (str | int)):
|
||||
motors = [motors]
|
||||
elif not isinstance(motors, list):
|
||||
raise TypeError(motors)
|
||||
@@ -733,7 +733,7 @@ class MotorsBus(abc.ABC):
|
||||
"""
|
||||
if motors is None:
|
||||
motors = list(self.motors)
|
||||
elif isinstance(motors, (str, int)):
|
||||
elif isinstance(motors, (str | int)):
|
||||
motors = [motors]
|
||||
elif not isinstance(motors, list):
|
||||
raise TypeError(motors)
|
||||
|
||||
@@ -398,10 +398,7 @@ class ACT(nn.Module):
|
||||
"actions must be provided when using the variational objective in training mode."
|
||||
)
|
||||
|
||||
if OBS_IMAGES in batch:
|
||||
batch_size = batch[OBS_IMAGES][0].shape[0]
|
||||
else:
|
||||
batch_size = batch[OBS_ENV_STATE].shape[0]
|
||||
batch_size = batch[OBS_IMAGES][0].shape[0] if OBS_IMAGES in batch else batch[OBS_ENV_STATE].shape[0]
|
||||
|
||||
# Prepare the latent for input to the transformer encoder.
|
||||
if self.config.use_vae and ACTION in batch and self.training:
|
||||
|
||||
@@ -260,13 +260,11 @@ class GPT(nn.Module):
|
||||
param_dict = dict(self.named_parameters())
|
||||
inter_params = decay & no_decay
|
||||
union_params = decay | no_decay
|
||||
assert len(inter_params) == 0, "parameters {} made it into both decay/no_decay sets!".format(
|
||||
str(inter_params)
|
||||
assert len(inter_params) == 0, (
|
||||
f"parameters {str(inter_params)} made it into both decay/no_decay sets!"
|
||||
)
|
||||
assert len(param_dict.keys() - union_params) == 0, (
|
||||
"parameters {} were not separated into either decay/no_decay set!".format(
|
||||
str(param_dict.keys() - union_params),
|
||||
)
|
||||
f"parameters {str(param_dict.keys() - union_params)} were not separated into either decay/no_decay set!"
|
||||
)
|
||||
|
||||
decay = [param_dict[pn] for pn in sorted(decay)]
|
||||
|
||||
@@ -340,7 +340,7 @@ class GripperPenaltyProcessorStep(ComplementaryDataProcessorStep):
|
||||
"""
|
||||
action = self.transition.get(TransitionKey.ACTION)
|
||||
|
||||
raw_joint_positions = complementary_data.get("raw_joint_positions", None)
|
||||
raw_joint_positions = complementary_data.get("raw_joint_positions")
|
||||
if raw_joint_positions is None:
|
||||
return complementary_data
|
||||
|
||||
|
||||
@@ -119,13 +119,12 @@ class _NormalizationMixin:
|
||||
)
|
||||
self.features = reconstructed
|
||||
|
||||
if self.norm_map:
|
||||
# if keys are strings (JSON), rebuild enum map
|
||||
if all(isinstance(k, str) for k in self.norm_map.keys()):
|
||||
reconstructed = {}
|
||||
for ft_type_str, norm_mode_str in self.norm_map.items():
|
||||
reconstructed[FeatureType(ft_type_str)] = NormalizationMode(norm_mode_str)
|
||||
self.norm_map = reconstructed
|
||||
# if keys are strings (JSON), rebuild enum map
|
||||
if self.norm_map and all(isinstance(k, str) for k in self.norm_map):
|
||||
reconstructed = {}
|
||||
for ft_type_str, norm_mode_str in self.norm_map.items():
|
||||
reconstructed[FeatureType(ft_type_str)] = NormalizationMode(norm_mode_str)
|
||||
self.norm_map = reconstructed
|
||||
|
||||
# Convert stats to tensors and move to the target device once during initialization.
|
||||
self.stats = self.stats or {}
|
||||
|
||||
@@ -152,7 +152,7 @@ class VanillaObservationProcessorStep(ObservationProcessorStep):
|
||||
"""
|
||||
# Build a new features mapping keyed by the same FeatureType buckets
|
||||
# We assume callers already placed features in the correct FeatureType.
|
||||
new_features: dict[PipelineFeatureType, dict[str, PolicyFeature]] = {ft: {} for ft in features.keys()}
|
||||
new_features: dict[PipelineFeatureType, dict[str, PolicyFeature]] = {ft: {} for ft in features}
|
||||
|
||||
exact_pairs = {
|
||||
"pixels": OBS_IMAGE,
|
||||
|
||||
@@ -176,7 +176,7 @@ class ReplayBuffer:
|
||||
self.complementary_info[key] = torch.empty(
|
||||
(self.capacity, *value_shape), device=self.storage_device
|
||||
)
|
||||
elif isinstance(value, (int, float)):
|
||||
elif isinstance(value, (int | float)):
|
||||
# Handle scalar values similar to reward
|
||||
self.complementary_info[key] = torch.empty((self.capacity,), device=self.storage_device)
|
||||
else:
|
||||
@@ -223,7 +223,7 @@ class ReplayBuffer:
|
||||
value = complementary_info[key]
|
||||
if isinstance(value, torch.Tensor):
|
||||
self.complementary_info[key][self.position].copy_(value.squeeze(dim=0))
|
||||
elif isinstance(value, (int, float)):
|
||||
elif isinstance(value, (int | float)):
|
||||
self.complementary_info[key][self.position] = value
|
||||
|
||||
self.position = (self.position + 1) % self.capacity
|
||||
|
||||
@@ -137,7 +137,7 @@ class WandBLogger:
|
||||
self._wandb.define_metric(new_custom_key, hidden=True)
|
||||
|
||||
for k, v in d.items():
|
||||
if not isinstance(v, (int, float, str)):
|
||||
if not isinstance(v, (int | float | str)):
|
||||
logging.warning(
|
||||
f'WandB logging of key "{k}" was ignored as its type "{type(v)}" is not handled by this wrapper.'
|
||||
)
|
||||
|
||||
@@ -267,11 +267,7 @@ def record_loop(
|
||||
for t in teleop
|
||||
if isinstance(
|
||||
t,
|
||||
(
|
||||
so100_leader.SO100Leader,
|
||||
so101_leader.SO101Leader,
|
||||
koch_leader.KochLeader,
|
||||
),
|
||||
(so100_leader.SO100Leader | so101_leader.SO101Leader | koch_leader.KochLeader),
|
||||
)
|
||||
),
|
||||
None,
|
||||
|
||||
@@ -18,7 +18,6 @@ import logging
|
||||
import threading
|
||||
from collections import deque
|
||||
from pprint import pformat
|
||||
from typing import Deque
|
||||
|
||||
import serial
|
||||
|
||||
@@ -60,7 +59,7 @@ class HomunculusArm(Teleoperator):
|
||||
self.n: int = n
|
||||
self.alpha: float = 2 / (n + 1)
|
||||
# one deque *per joint* so we can inspect raw history if needed
|
||||
self._buffers: dict[str, Deque[int]] = {
|
||||
self._buffers: dict[str, deque[int]] = {
|
||||
joint: deque(maxlen=n)
|
||||
for joint in (
|
||||
"shoulder_pitch",
|
||||
|
||||
@@ -18,7 +18,6 @@ import logging
|
||||
import threading
|
||||
from collections import deque
|
||||
from pprint import pformat
|
||||
from typing import Deque
|
||||
|
||||
import serial
|
||||
|
||||
@@ -97,7 +96,7 @@ class HomunculusGlove(Teleoperator):
|
||||
self.n: int = n
|
||||
self.alpha: float = 2 / (n + 1)
|
||||
# one deque *per joint* so we can inspect raw history if needed
|
||||
self._buffers: dict[str, Deque[int]] = {joint: deque(maxlen=n) for joint in self.joints}
|
||||
self._buffers: dict[str, deque[int]] = {joint: deque(maxlen=n) for joint in self.joints}
|
||||
# running EMA value per joint – lazily initialised on first read
|
||||
self._ema: dict[str, float | None] = dict.fromkeys(self._buffers)
|
||||
|
||||
|
||||
@@ -63,7 +63,7 @@ def move_transition_to_device(transition: Transition, device: str = "cpu") -> Tr
|
||||
for key, val in transition["complementary_info"].items():
|
||||
if isinstance(val, torch.Tensor):
|
||||
transition["complementary_info"][key] = val.to(device, non_blocking=non_blocking)
|
||||
elif isinstance(val, (int, float, bool)):
|
||||
elif isinstance(val, (int | float | bool)):
|
||||
transition["complementary_info"][key] = torch.tensor(val, device=device)
|
||||
else:
|
||||
raise ValueError(f"Unsupported type {type(val)} for complementary_info[{key}]")
|
||||
|
||||
@@ -32,11 +32,8 @@ def init_rerun(session_name: str = "lerobot_control_loop") -> None:
|
||||
|
||||
|
||||
def _is_scalar(x):
|
||||
return (
|
||||
isinstance(x, float)
|
||||
or isinstance(x, numbers.Real)
|
||||
or isinstance(x, (np.integer, np.floating))
|
||||
or (isinstance(x, np.ndarray) and x.ndim == 0)
|
||||
return isinstance(x, (float | numbers.Real | np.integer | np.floating)) or (
|
||||
isinstance(x, np.ndarray) and x.ndim == 0
|
||||
)
|
||||
|
||||
|
||||
|
||||
@@ -66,15 +66,13 @@ def get_policy_stats(ds_repo_id: str, policy_name: str, policy_kwargs: dict):
|
||||
for key, param in policy.named_parameters():
|
||||
if param.requires_grad:
|
||||
grad_stats[f"{key}_mean"] = param.grad.mean()
|
||||
grad_stats[f"{key}_std"] = (
|
||||
param.grad.std() if param.grad.numel() > 1 else torch.tensor(float(0.0))
|
||||
)
|
||||
grad_stats[f"{key}_std"] = param.grad.std() if param.grad.numel() > 1 else torch.tensor(0.0)
|
||||
|
||||
optimizer.step()
|
||||
param_stats = {}
|
||||
for key, param in policy.named_parameters():
|
||||
param_stats[f"{key}_mean"] = param.mean()
|
||||
param_stats[f"{key}_std"] = param.std() if param.numel() > 1 else torch.tensor(float(0.0))
|
||||
param_stats[f"{key}_std"] = param.std() if param.numel() > 1 else torch.tensor(0.0)
|
||||
|
||||
optimizer.zero_grad()
|
||||
policy.reset()
|
||||
|
||||
+2
-2
@@ -85,7 +85,7 @@ def policy_feature_factory():
|
||||
|
||||
def assert_contract_is_typed(features: dict[PipelineFeatureType, dict[str, PolicyFeature]]) -> None:
|
||||
assert isinstance(features, dict)
|
||||
assert all(isinstance(k, PipelineFeatureType) for k in features.keys())
|
||||
assert all(isinstance(k, PipelineFeatureType) for k in features)
|
||||
assert all(isinstance(v, dict) for v in features.values())
|
||||
assert all(all(isinstance(nk, str) for nk in v.keys()) for v in features.values())
|
||||
assert all(all(isinstance(nk, str) for nk in v) for v in features.values())
|
||||
assert all(all(isinstance(nv, PolicyFeature) for nv in v.values()) for v in features.values())
|
||||
|
||||
@@ -949,7 +949,7 @@ def test_statistics_metadata_validation(tmp_path, empty_lerobot_dataset_factory)
|
||||
# Check that statistics exist for all features
|
||||
assert loaded_dataset.meta.stats is not None, "No statistics found"
|
||||
|
||||
for feature_name in features.keys():
|
||||
for feature_name in features:
|
||||
assert feature_name in loaded_dataset.meta.stats, f"No statistics for feature '{feature_name}'"
|
||||
|
||||
feature_stats = loaded_dataset.meta.stats[feature_name]
|
||||
|
||||
@@ -246,7 +246,7 @@ def test_step_through():
|
||||
# Ensure all results are dicts (same format as input)
|
||||
for result in results:
|
||||
assert isinstance(result, dict)
|
||||
assert all(isinstance(k, TransitionKey) for k in result.keys())
|
||||
assert all(isinstance(k, TransitionKey) for k in result)
|
||||
|
||||
|
||||
def test_step_through_with_dict():
|
||||
@@ -770,7 +770,7 @@ class MockStepWithNonSerializableParam(ProcessorStep):
|
||||
# Add type validation for multiplier
|
||||
if isinstance(multiplier, str):
|
||||
raise ValueError(f"multiplier must be a number, got string '{multiplier}'")
|
||||
if not isinstance(multiplier, (int, float)):
|
||||
if not isinstance(multiplier, (int | float)):
|
||||
raise TypeError(f"multiplier must be a number, got {type(multiplier).__name__}")
|
||||
self.multiplier = float(multiplier)
|
||||
self.env = env # Non-serializable parameter (like gym.Env)
|
||||
@@ -1623,9 +1623,7 @@ def test_override_with_callables():
|
||||
|
||||
# Define a transform function
|
||||
def double_values(x):
|
||||
if isinstance(x, (int, float)):
|
||||
return x * 2
|
||||
elif isinstance(x, torch.Tensor):
|
||||
if isinstance(x, (int | float | torch.Tensor)):
|
||||
return x * 2
|
||||
return x
|
||||
|
||||
@@ -1797,10 +1795,9 @@ def test_from_pretrained_nonexistent_path():
|
||||
)
|
||||
|
||||
# Test with a local directory that exists but has no config files
|
||||
with tempfile.TemporaryDirectory() as tmp_dir:
|
||||
with tempfile.TemporaryDirectory() as tmp_dir, pytest.raises(FileNotFoundError):
|
||||
# Since the directory exists but has no config, it will raise FileNotFoundError
|
||||
with pytest.raises(FileNotFoundError):
|
||||
DataProcessorPipeline.from_pretrained(tmp_dir, config_filename="processor.json")
|
||||
DataProcessorPipeline.from_pretrained(tmp_dir, config_filename="processor.json")
|
||||
|
||||
|
||||
def test_save_load_with_custom_converter_functions():
|
||||
|
||||
@@ -32,10 +32,7 @@ class MockTokenizer:
|
||||
**kwargs,
|
||||
) -> dict[str, torch.Tensor]:
|
||||
"""Mock tokenization that returns deterministic tokens based on text."""
|
||||
if isinstance(text, str):
|
||||
texts = [text]
|
||||
else:
|
||||
texts = text
|
||||
texts = [text] if isinstance(text, str) else text
|
||||
|
||||
batch_size = len(texts)
|
||||
|
||||
|
||||
@@ -245,14 +245,14 @@ def test_get_observation(reachy2):
|
||||
obs = reachy2.get_observation()
|
||||
|
||||
expected_keys = set(reachy2.joints_dict)
|
||||
expected_keys.update(f"{v}" for v in REACHY2_VEL.keys() if reachy2.config.with_mobile_base)
|
||||
expected_keys.update(f"{v}" for v in REACHY2_VEL if reachy2.config.with_mobile_base)
|
||||
expected_keys.update(reachy2.cameras.keys())
|
||||
assert set(obs.keys()) == expected_keys
|
||||
|
||||
for motor in reachy2.joints_dict.keys():
|
||||
for motor in reachy2.joints_dict:
|
||||
assert obs[motor] == reachy2.reachy.joints[REACHY2_JOINTS[motor]].present_position
|
||||
if reachy2.config.with_mobile_base:
|
||||
for vel in REACHY2_VEL.keys():
|
||||
for vel in REACHY2_VEL:
|
||||
assert obs[vel] == reachy2.reachy.mobile_base.odometry[REACHY2_VEL[vel]]
|
||||
if reachy2.config.with_left_teleop_camera:
|
||||
assert obs["teleop_left"].shape == (
|
||||
@@ -282,7 +282,7 @@ def test_send_action(reachy2):
|
||||
action.update({k: i * 0.1 for i, k in enumerate(REACHY2_VEL.keys(), start=1)})
|
||||
|
||||
previous_present_position = {
|
||||
k: reachy2.reachy.joints[REACHY2_JOINTS[k]].present_position for k in reachy2.joints_dict.keys()
|
||||
k: reachy2.reachy.joints[REACHY2_JOINTS[k]].present_position for k in reachy2.joints_dict
|
||||
}
|
||||
returned = reachy2.send_action(action)
|
||||
|
||||
@@ -290,7 +290,7 @@ def test_send_action(reachy2):
|
||||
assert returned == action
|
||||
|
||||
assert reachy2.reachy._goal_position_set_total == len(reachy2.joints_dict)
|
||||
for motor in reachy2.joints_dict.keys():
|
||||
for motor in reachy2.joints_dict:
|
||||
expected_pos = action[motor]
|
||||
real_pos = reachy2.reachy.joints[REACHY2_JOINTS[motor]].goal_position
|
||||
if reachy2.config.max_relative_target is None:
|
||||
|
||||
@@ -121,20 +121,20 @@ def test_get_action(reachy2):
|
||||
action = reachy2.get_action()
|
||||
|
||||
expected_keys = set(reachy2.joints_dict)
|
||||
expected_keys.update(f"{v}" for v in REACHY2_VEL.keys() if reachy2.config.with_mobile_base)
|
||||
expected_keys.update(f"{v}" for v in REACHY2_VEL if reachy2.config.with_mobile_base)
|
||||
assert set(action.keys()) == expected_keys
|
||||
|
||||
for motor in reachy2.joints_dict.keys():
|
||||
for motor in reachy2.joints_dict:
|
||||
if reachy2.config.use_present_position:
|
||||
assert action[motor] == reachy2.reachy.joints[REACHY2_JOINTS[motor]].present_position
|
||||
else:
|
||||
assert action[motor] == reachy2.reachy.joints[REACHY2_JOINTS[motor]].goal_position
|
||||
if reachy2.config.with_mobile_base:
|
||||
if reachy2.config.use_present_position:
|
||||
for vel in REACHY2_VEL.keys():
|
||||
for vel in REACHY2_VEL:
|
||||
assert action[vel] == reachy2.reachy.mobile_base.odometry[REACHY2_VEL[vel]]
|
||||
else:
|
||||
for vel in REACHY2_VEL.keys():
|
||||
for vel in REACHY2_VEL:
|
||||
assert action[vel] == reachy2.reachy.mobile_base.last_cmd_vel[REACHY2_VEL[vel]]
|
||||
|
||||
|
||||
|
||||
@@ -121,7 +121,7 @@ def get_tensors_memory_consumption(obj, visited_addresses):
|
||||
|
||||
if isinstance(obj, torch.Tensor):
|
||||
return get_tensor_memory_consumption(obj)
|
||||
elif isinstance(obj, (list, tuple)):
|
||||
elif isinstance(obj, (list | tuple)):
|
||||
for item in obj:
|
||||
total_size += get_tensors_memory_consumption(item, visited_addresses)
|
||||
elif isinstance(obj, dict):
|
||||
|
||||
Reference in New Issue
Block a user