Compare commits

..

12 Commits

Author SHA1 Message Date
CarolinePascal fcd8ab5800 fix(claude): claude reviews 2026-06-10 20:25:12 +02:00
CarolinePascal ee6eb745b8 chore(imports): cleaning up imports 2026-06-10 20:00:08 +02:00
CarolinePascal 27b482adf7 chore(simplification): removing no longer needed reshape 2026-06-10 19:50:26 +02:00
CarolinePascal 21d158e066 chore(colors): removing unreliable colors 2026-06-10 19:46:04 +02:00
CarolinePascal 22991ed69a test(update): update tests 2026-06-10 19:32:14 +02:00
CarolinePascal 1adc7a0309 feat(grid): Leveraging rerun's automatic grid arangement for improved layout 2026-06-10 19:23:55 +02:00
CarolinePascal f72fc3b4ba feat(blueprints): switching to blueprints for backwards (and forward) compatibiltiy 2026-06-10 19:23:55 +02:00
CarolinePascal dabf88ef9f feat(blueprints): switching to blueprints for backwards (and forward) compatibiltiy 2026-06-10 19:23:55 +02:00
CarolinePascal 2c47217825 feat(features names and color): improving features names and display colors when replaying an episode 2026-06-10 19:23:54 +02:00
CarolinePascal 9c502e204e chore(format): formatting code 2026-06-10 19:23:54 +02:00
CarolinePascal c55df19e6c chore(updae): update rerun logging to use the latest features 2026-06-10 15:24:03 +02:00
ntjohnson1 c91f345092 Update upper bound to latest rerun-sdk 2026-06-10 15:24:03 +02:00
13 changed files with 242 additions and 206 deletions
+1 -1
View File
@@ -124,7 +124,7 @@ hardware = [
"lerobot[deepdiff-dep]",
]
viz = [
"rerun-sdk>=0.24.0,<0.27.0",
"rerun-sdk>=0.24.0,<0.34.0",
]
# ── User-facing composite extras (map to CLI scripts) ─────
# lerobot-record, lerobot-replay, lerobot-calibrate, lerobot-teleoperate, etc.
+1 -7
View File
@@ -30,7 +30,6 @@ class EpisodeAwareSampler:
drop_n_first_frames: int = 0,
drop_n_last_frames: int = 0,
shuffle: bool = False,
generator: torch.Generator | None = None,
):
"""Sampler that optionally incorporates episode boundary information.
@@ -42,10 +41,6 @@ class EpisodeAwareSampler:
drop_n_first_frames: Number of frames to drop from the start of each episode.
drop_n_last_frames: Number of frames to drop from the end of each episode.
shuffle: Whether to shuffle the indices.
generator: Generator used for shuffling. Exposing this attribute (even when None) lets
`accelerate` register it as the synchronized RNG in distributed training, so
every rank draws the same permutation and batch shards stay disjoint. When
None, shuffling falls back to the global torch RNG.
"""
if drop_n_first_frames < 0:
raise ValueError(f"drop_n_first_frames must be >= 0, got {drop_n_first_frames}")
@@ -78,11 +73,10 @@ class EpisodeAwareSampler:
self.indices = indices
self.shuffle = shuffle
self.generator = generator
def __iter__(self) -> Iterator[int]:
if self.shuffle:
for i in torch.randperm(len(self.indices), generator=self.generator):
for i in torch.randperm(len(self.indices)):
yield self.indices[i]
else:
for i in self.indices:
@@ -52,13 +52,7 @@ class BiRebotB601Follower(Robot):
cameras=config.left_arm_config.cameras,
motor_can_ids=config.left_arm_config.motor_can_ids,
pos_vel_velocity=config.left_arm_config.pos_vel_velocity,
control_mode=config.left_arm_config.control_mode,
mit_kp=config.left_arm_config.mit_kp,
mit_kd=config.left_arm_config.mit_kd,
gripper_control_mode=config.left_arm_config.gripper_control_mode,
gripper_torque_ratio=config.left_arm_config.gripper_torque_ratio,
gripper_mit_kp=config.left_arm_config.gripper_mit_kp,
gripper_mit_kd=config.left_arm_config.gripper_mit_kd,
joint_limits=config.left_arm_config.joint_limits,
)
@@ -73,13 +67,7 @@ class BiRebotB601Follower(Robot):
cameras=config.right_arm_config.cameras,
motor_can_ids=config.right_arm_config.motor_can_ids,
pos_vel_velocity=config.right_arm_config.pos_vel_velocity,
control_mode=config.right_arm_config.control_mode,
mit_kp=config.right_arm_config.mit_kp,
mit_kd=config.right_arm_config.mit_kd,
gripper_control_mode=config.right_arm_config.gripper_control_mode,
gripper_torque_ratio=config.right_arm_config.gripper_torque_ratio,
gripper_mit_kp=config.right_arm_config.gripper_mit_kp,
gripper_mit_kd=config.right_arm_config.gripper_mit_kd,
joint_limits=config.right_arm_config.joint_limits,
)
@@ -65,33 +65,18 @@ class RebotB601FollowerConfig:
}
)
# Max speed (deg/s) per joint for POS_VEL arms and FORCE_POS gripper (motor order).
pos_vel_velocity: float | list[float] = field(
default_factory=lambda: [150.0, 150.0, 150.0, 150.0, 150.0, 150.0, 500.0]
)
# Target velocity for joints running in POS_VEL mode, in degrees/s. A scalar is
# applied to every joint; a list provides one value per joint (in motor order).
pos_vel_velocity: float | list[float] = field(default_factory=lambda: [150.0] * 7)
# Arm control: "mit" or "pos_vel".
control_mode: str = "mit"
# MIT kp/kd per arm joint (motor order). Unused when control_mode="pos_vel".
mit_kp: float | list[float] = field(default_factory=lambda: [45.0, 45.0, 45.0, 8.0, 9.0, 8.0, 8.0])
mit_kd: float | list[float] = field(default_factory=lambda: [12.0, 12.0, 12.0, 1.0, 1.0, 1.0, 1.0])
# Gripper control: "force_pos" or "mit".
gripper_control_mode: str = "force_pos"
# FORCE_POS only: max grip force, in [0, 1].
gripper_torque_ratio: float = 0.05
# MIT only.
gripper_mit_kp: float = 8.0
gripper_mit_kd: float = 0.3
# Torque/current ratio for the gripper's FORCE_POS mode, in range [0, 1].
gripper_torque_ratio: float = 0.1
# Soft joint limits (degrees). These are clipped against on every action.
joint_limits: dict[str, tuple[float, float]] = field(
default_factory=lambda: {
"shoulder_pan": (-150.0, 150.0),
"shoulder_lift": (-200.0, 1.0),
"shoulder_pan": (-145.0, 145.0),
"shoulder_lift": (-170.0, 1.0),
"elbow_flex": (-200.0, 1.0),
"wrist_flex": (-80.0, 90.0),
"wrist_yaw": (-90.0, 90.0),
@@ -169,25 +169,11 @@ class RebotB601Follower(Robot):
print(f"Calibration saved to {self.calibration_fpath}")
def configure(self) -> None:
if self.config.control_mode not in ("pos_vel", "mit"):
raise ValueError(
f"Unsupported control_mode '{self.config.control_mode}'. Use 'pos_vel' or 'mit'."
)
if self.config.gripper_control_mode not in ("force_pos", "mit"):
raise ValueError(
f"Unsupported gripper_control_mode '{self.config.gripper_control_mode}'. "
"Use 'force_pos' or 'mit'."
)
use_mit = self.config.control_mode == "mit"
gripper_use_mit = self.config.gripper_control_mode == "mit"
self.bus.enable_all()
for motor_name, motor in self.motors.items():
if motor_name == GRIPPER_MOTOR:
target_mode = MotorBridgeMode.MIT if gripper_use_mit else MotorBridgeMode.FORCE_POS
elif use_mit:
target_mode = MotorBridgeMode.MIT
else:
target_mode = MotorBridgeMode.POS_VEL
target_mode = (
MotorBridgeMode.FORCE_POS if motor_name == GRIPPER_MOTOR else MotorBridgeMode.POS_VEL
)
for attempt in range(_ENSURE_MODE_RETRIES + 1):
try:
motor.ensure_mode(target_mode)
@@ -266,34 +252,22 @@ class RebotB601Follower(Robot):
goal_present_pos = {key: (g, present_pos.get(key, g)) for key, g in goal_pos.items()}
goal_pos = ensure_safe_goal_position(goal_present_pos, self.config.max_relative_target)
use_mit = self.config.control_mode == "mit"
for motor_name, position_deg in goal_pos.items():
motor = self.motors.get(motor_name)
if motor is None:
continue
idx = self.motor_names.index(motor_name)
vel_deg_s = (
self.config.pos_vel_velocity[idx]
if isinstance(self.config.pos_vel_velocity, list)
else self.config.pos_vel_velocity
)
pos_rad = math.radians(position_deg)
vel_rad = math.radians(vel_deg_s)
if motor_name == GRIPPER_MOTOR:
if self.config.gripper_control_mode == "mit":
motor.send_mit(pos_rad, 0.0, self.config.gripper_mit_kp, self.config.gripper_mit_kd, 0.0)
else:
vel_deg_s = (
self.config.pos_vel_velocity[idx]
if isinstance(self.config.pos_vel_velocity, list)
else self.config.pos_vel_velocity
)
motor.send_force_pos(pos_rad, math.radians(vel_deg_s), self.config.gripper_torque_ratio)
elif use_mit:
kp = self.config.mit_kp[idx] if isinstance(self.config.mit_kp, list) else self.config.mit_kp
kd = self.config.mit_kd[idx] if isinstance(self.config.mit_kd, list) else self.config.mit_kd
motor.send_mit(pos_rad, 0.0, kp, kd, 0.0)
motor.send_force_pos(pos_rad, vel_rad, self.config.gripper_torque_ratio)
else:
vel_deg_s = (
self.config.pos_vel_velocity[idx]
if isinstance(self.config.pos_vel_velocity, list)
else self.config.pos_vel_velocity
)
motor.send_pos_vel(pos_rad, math.radians(vel_deg_s))
motor.send_pos_vel(pos_rad, vel_rad)
return {f"{motor}.pos": val for motor, val in goal_pos.items()}
+49 -13
View File
@@ -77,6 +77,21 @@ from lerobot.utils.constants import ACTION, DONE, OBS_STATE, REWARD
from lerobot.utils.utils import init_logging
def get_feature_names(dataset: LeRobotDataset, key: str) -> list[str]:
"""Return per-dimension names for a feature from the dataset metadata.
Only flat-list ``names`` metadata is used. Dict-style ``names`` and missing names fall back to ``{key}_{i}`` indices.
"""
feature = dataset.features[key]
dim = feature["shape"][-1]
names = feature.get("names")
if isinstance(names, list) and len(names) == dim:
return [str(name) for name in names]
return [f"{key}_{d}" for d in range(dim)]
def to_hwc_uint8_numpy(chw_float32_torch: torch.Tensor) -> np.ndarray:
assert chw_float32_torch.dtype == torch.float32
assert chw_float32_torch.ndim == 3
@@ -86,6 +101,31 @@ def to_hwc_uint8_numpy(chw_float32_torch: torch.Tensor) -> np.ndarray:
return hwc_uint8_numpy
def build_blueprint_from_dataset(dataset: LeRobotDataset):
"""Build a Rerun blueprint laying out camera images and time series for the given dataset.
Camera images and scalar signals (action, state, reward, done, success) are arranged in a grid.
The per-dimension series names for ``action`` and ``state`` are applied directly
via blueprint overrides.
"""
import rerun as rr
import rerun.blueprint as rrb
views = [rrb.Spatial2DView(origin=key, name=key) for key in dataset.meta.camera_keys]
# Style multi-dimensional signals (action, state) with per-dimension names.
for origin, key in ((ACTION, ACTION), ("state", OBS_STATE)):
if key in dataset.features:
names = get_feature_names(dataset, key)
styling = rr.SeriesLines(names=names)
views.append(rrb.TimeSeriesView(origin=origin, name=origin, overrides={origin: styling}))
for key in (DONE, REWARD, "next.success"):
if key in dataset.features:
views.append(rrb.TimeSeriesView(origin=key, name=key))
return rrb.Blueprint(rrb.Grid(*views))
def visualize_dataset(
dataset: LeRobotDataset,
episode_index: int,
@@ -124,7 +164,8 @@ def visualize_dataset(
import rerun as rr
spawn_local_viewer = mode == "local" and not save
rr.init(f"{repo_id}/episode_{episode_index}", spawn=spawn_local_viewer)
blueprint = build_blueprint_from_dataset(dataset)
rr.init(f"{repo_id}/episode_{episode_index}", spawn=spawn_local_viewer, default_blueprint=blueprint)
# Manually call python garbage collector after `rr.init` to avoid hanging in a blocking flush
# when iterating on a dataloader with `num_workers` > 0
@@ -142,26 +183,21 @@ def visualize_dataset(
for batch in tqdm.tqdm(dataloader, total=len(dataloader)):
if first_index is None:
first_index = batch["index"][0].item()
# iterate over the batch
for i in range(len(batch["index"])):
rr.set_time("frame_index", sequence=batch["index"][i].item() - first_index)
rr.set_time("timestamp", timestamp=batch["timestamp"][i].item())
# display each camera image
for key in dataset.meta.camera_keys:
img = to_hwc_uint8_numpy(batch[key][i])
img_entity = rr.Image(img).compress() if display_compressed_images else rr.Image(img)
rr.log(key, entity=img_entity)
# display each dimension of action space (e.g. actuators command)
if ACTION in batch:
for dim_idx, val in enumerate(batch[ACTION][i]):
rr.log(f"{ACTION}/{dim_idx}", rr.Scalars(val.item()))
rr.log(ACTION, rr.Scalars(batch[ACTION][i].numpy()))
# display each dimension of observed state space (e.g. agent position in joint space)
if OBS_STATE in batch:
for dim_idx, val in enumerate(batch[OBS_STATE][i]):
rr.log(f"state/{dim_idx}", rr.Scalars(val.item()))
rr.log("state", rr.Scalars(batch[OBS_STATE][i].numpy()))
if DONE in batch:
rr.log(DONE, rr.Scalars(batch[DONE][i].item()))
@@ -173,8 +209,6 @@ def visualize_dataset(
rr.log("next.success", rr.Scalars(batch["next.success"][i].item()))
if mode == "local" and save:
# save .rrd locally
output_dir = Path(output_dir)
output_dir.mkdir(parents=True, exist_ok=True)
repo_id_str = repo_id.replace("/", "_")
rrd_path = output_dir / f"{repo_id_str}_episode_{episode_index}.rrd"
@@ -182,7 +216,7 @@ def visualize_dataset(
return rrd_path
elif mode == "distant":
# stop the process from exiting since it is serving the websocket connection
# Keep the process alive while it serves the gRPC/web connection.
try:
while True:
time.sleep(1)
@@ -297,12 +331,14 @@ def main():
)
logging.warning("Setting grpc_port to ws_port value.")
kwargs["grpc_port"] = kwargs.pop("ws_port")
else:
kwargs.pop("ws_port") # Always remove ws_port from kwargs
init_logging()
logging.info("Loading dataset")
dataset = LeRobotDataset(repo_id, episodes=[args.episode_index], root=root, tolerance_s=tolerance_s)
visualize_dataset(dataset, **vars(args))
visualize_dataset(dataset, **kwargs)
if __name__ == "__main__":
+5 -15
View File
@@ -232,18 +232,15 @@ def train(cfg: TrainPipelineConfig, accelerator: "Accelerator | None" = None):
torch.backends.cudnn.benchmark = True
torch.backends.cuda.matmul.allow_tf32 = True
# Dataset loading synchronization: each node's local main process downloads first to avoid
# race conditions (the global main process only exists on node 0, so gating on it would let
# all ranks of the other nodes download and build the Arrow cache concurrently).
if accelerator.is_local_main_process:
if is_main_process:
logging.info("Creating dataset")
# Dataset loading synchronization: main process downloads first to avoid race conditions
if is_main_process:
logging.info("Creating dataset")
dataset = make_dataset(cfg)
accelerator.wait_for_everyone()
# Now all other processes can safely load the dataset from the local cache
if not accelerator.is_local_main_process:
# Now all other processes can safely load the dataset
if not is_main_process:
dataset = make_dataset(cfg)
# Create environment used for evaluating checkpoints during training on simulation data.
@@ -389,19 +386,12 @@ def train(cfg: TrainPipelineConfig, accelerator: "Accelerator | None" = None):
# create dataloader for offline training
if hasattr(active_cfg, "drop_n_last_frames"):
shuffle = False
# A dedicated generator (rather than the global torch RNG) lets accelerator.prepare
# synchronize the shuffle permutation across ranks, keeping batch shards disjoint even
# when ranks consume the global RNG asymmetrically (e.g. eval on the main process only).
sampler_generator = torch.Generator()
if cfg.seed is not None:
sampler_generator.manual_seed(cfg.seed)
sampler = EpisodeAwareSampler(
dataset.meta.episodes["dataset_from_index"],
dataset.meta.episodes["dataset_to_index"],
episode_indices_to_use=dataset.episodes,
drop_n_last_frames=active_cfg.drop_n_last_frames,
shuffle=True,
generator=sampler_generator,
)
else:
shuffle = True
@@ -65,7 +65,7 @@ class RebotArm102LeaderConfig:
joint_ranges: dict[str, list[int]] = field(
default_factory=lambda: {
"shoulder_pan": [-150, 150],
"shoulder_lift": [-200, 1],
"shoulder_lift": [-170, 1],
"elbow_flex": [-200, 1],
"wrist_flex": [-80, 90],
"wrist_yaw": [-90, 90],
+53 -12
View File
@@ -38,6 +38,8 @@ def init_rerun(
require_package("rerun-sdk", extra="viz", import_name="rerun")
import rerun as rr
log_rerun_data.blueprint = None # Reset blueprint cache for new session
batch_size = os.getenv("RERUN_FLUSH_NUM_BYTES", "8000")
os.environ["RERUN_FLUSH_NUM_BYTES"] = batch_size
rr.init(session_name)
@@ -63,6 +65,38 @@ def _is_scalar(x):
)
def _build_blueprint(observation_paths: set[str], action_paths: set[str], image_paths: set[str]):
"""Build a Rerun blueprint laying out camera images, observation and action scalars in separate views.
Camera images, observation and action scalars are arranged in a grid.
"""
# Safe + zero-overhead: `log_rerun_data` already ran the `require_package` guard and imported rerun.
import rerun.blueprint as rrb
views = [rrb.Spatial2DView(origin=path, name=path) for path in sorted(image_paths)]
if observation_paths:
views.append(rrb.TimeSeriesView(name="observation", contents=sorted(observation_paths)))
if action_paths:
views.append(rrb.TimeSeriesView(name="action", contents=sorted(action_paths)))
return rrb.Blueprint(rrb.Grid(*views))
def _ensure_blueprint(observation_paths: set[str], action_paths: set[str], image_paths: set[str]) -> None:
"""Build and send the blueprint once, from the first observation and action data."""
if getattr(log_rerun_data, "blueprint", None) is not None:
return
# Safe + zero-overhead: `log_rerun_data` already ran the `require_package` guard and imported rerun.
import rerun as rr
blueprint = _build_blueprint(observation_paths, action_paths, image_paths)
log_rerun_data.blueprint = blueprint
rr.send_blueprint(blueprint)
def log_rerun_data(
observation: RobotObservation | None = None,
action: RobotAction | None = None,
@@ -76,11 +110,15 @@ def log_rerun_data(
- Scalars values (floats, ints) are logged as `rr.Scalars`.
- 3D NumPy arrays that resemble images (e.g., with 1, 3, or 4 channels first) are transposed
from CHW to HWC format, (optionally) compressed to JPEG and logged as `rr.Image` or `rr.EncodedImage`.
- 1D NumPy arrays are logged as a series of individual scalars, with each element indexed.
- Other multi-dimensional arrays are flattened and logged as individual scalars.
- 1D NumPy arrays are logged as a single `rr.Scalars` batch under one entity path, so that every
dimension shares the same view instead of being split across one view per element.
- Multi-dimensional **action** arrays are flattened and logged as a single `rr.Scalars` batch.
Keys are automatically namespaced with "observation." or "action." if not already present.
On the first call, a blueprint is built and sent so observation and action scalars get separate
time-series views and each image gets its own spatial view.
Args:
observation: An optional dictionary containing observation data to log.
action: An optional dictionary containing action data to log.
@@ -90,6 +128,10 @@ def log_rerun_data(
require_package("rerun-sdk", extra="viz", import_name="rerun")
import rerun as rr
observation_paths: set[str] = set()
action_paths: set[str] = set()
image_paths: set[str] = set()
if observation:
for k, v in observation.items():
if v is None:
@@ -98,17 +140,19 @@ def log_rerun_data(
if _is_scalar(v):
rr.log(key, rr.Scalars(float(v)))
observation_paths.add(key)
elif isinstance(v, np.ndarray):
arr = v
# Convert CHW -> HWC when needed
if arr.ndim == 3 and arr.shape[0] in (1, 3, 4) and arr.shape[-1] not in (1, 3, 4):
arr = np.transpose(arr, (1, 2, 0))
if arr.ndim == 1:
for i, vi in enumerate(arr):
rr.log(f"{key}_{i}", rr.Scalars(float(vi)))
rr.log(key, rr.Scalars(arr.astype(float)))
observation_paths.add(key)
else:
img_entity = rr.Image(arr).compress() if compress_images else rr.Image(arr)
rr.log(key, entity=img_entity, static=True)
image_paths.add(key)
if action:
for k, v in action.items():
@@ -118,12 +162,9 @@ def log_rerun_data(
if _is_scalar(v):
rr.log(key, rr.Scalars(float(v)))
action_paths.add(key)
elif isinstance(v, np.ndarray):
if v.ndim == 1:
for i, vi in enumerate(v):
rr.log(f"{key}_{i}", rr.Scalars(float(vi)))
else:
# Fall back to flattening higher-dimensional arrays
flat = v.flatten()
for i, vi in enumerate(flat):
rr.log(f"{key}_{i}", rr.Scalars(float(vi)))
rr.log(key, rr.Scalars(v.reshape(-1).astype(float)))
action_paths.add(key)
_ensure_blueprint(observation_paths, action_paths, image_paths)
-24
View File
@@ -114,30 +114,6 @@ def test_shuffle():
assert set(sampler) == {0, 1, 2, 3, 4, 5}
def test_shuffle_with_generator_is_deterministic():
# Two samplers shuffling with same-seed generators must yield identical permutations.
# This is what keeps batch shards disjoint across ranks in distributed training, where
# accelerate synchronizes the sampler's generator state instead of the global torch RNG.
sampler_a = EpisodeAwareSampler([0], [6], shuffle=True, generator=torch.Generator().manual_seed(42))
sampler_b = EpisodeAwareSampler([0], [6], shuffle=True, generator=torch.Generator().manual_seed(42))
assert list(sampler_a) == list(sampler_b)
# Desyncing the global RNG must not affect the permutation.
sampler_c = EpisodeAwareSampler([0], [6], shuffle=True, generator=torch.Generator().manual_seed(42))
order_before = list(sampler_c)
sampler_c.generator.manual_seed(42)
torch.randperm(1000) # consume global RNG, as rank-asymmetric code (e.g. eval) would
assert list(sampler_c) == order_before
def test_generator_attribute_defaults_to_none():
# accelerate detects synchronizable samplers via `hasattr(sampler, "generator")`,
# so the attribute must exist even when no generator is passed.
sampler = EpisodeAwareSampler([0], [6], shuffle=True)
assert sampler.generator is None
assert set(sampler) == {0, 1, 2, 3, 4, 5}
def test_negative_drop_first_frames_raises():
with pytest.raises(ValueError, match="drop_n_first_frames must be >= 0"):
EpisodeAwareSampler([0], [10], drop_n_first_frames=-1)
+3 -20
View File
@@ -91,11 +91,10 @@ def test_get_observation_converts_to_degrees(follower):
def test_send_action_clips_to_joint_limits(follower):
# shoulder_pan limit is (-150, 150); request beyond the upper bound.
# shoulder_pan limit is (-145, 145); request beyond the upper bound.
returned = follower.send_action({"shoulder_pan.pos": 999.0})
assert returned["shoulder_pan.pos"] == 150.0
# Default control_mode is "mit", so arm joints are driven via send_mit.
follower.motors["shoulder_pan"].send_mit.assert_called_once()
assert returned["shoulder_pan.pos"] == 145.0
follower.motors["shoulder_pan"].send_pos_vel.assert_called_once()
def test_send_action_routes_gripper_to_force_pos(follower):
@@ -104,22 +103,6 @@ def test_send_action_routes_gripper_to_force_pos(follower):
follower.motors["gripper"].send_pos_vel.assert_not_called()
def test_gripper_mit_mode_routes_to_send_mit():
bus_mock = _make_bus_mock()
with (
patch(f"{_MODULE}.require_package", lambda *a, **kw: None),
patch(f"{_MODULE}.MotorBridgeController") as controller_cls,
patch(f"{_MODULE}.MotorBridgeMode", MagicMock()),
):
controller_cls.from_dm_serial.return_value = bus_mock
cfg = RebotB601FollowerRobotConfig(port="/dev/null", gripper_control_mode="mit")
robot = RebotB601Follower(cfg)
robot.connect(calibrate=False)
robot.send_action({"gripper.pos": -10.0})
robot.motors["gripper"].send_mit.assert_called_once()
robot.motors["gripper"].send_force_pos.assert_not_called()
def test_bimanual_prefixes_features():
with patch(f"{_MODULE}.require_package", lambda *a, **kw: None):
cfg = BiRebotB601FollowerConfig(
+103 -34
View File
@@ -30,25 +30,46 @@ from lerobot.utils.constants import OBS_STATE
@pytest.fixture
def mock_rerun(monkeypatch):
"""
Provide a mock `rerun` module so tests don't depend on the real library.
Also reload the module-under-test so it binds to this mock `rr`.
Provide a mock `rerun` module (and `rerun.blueprint` submodule) so tests don't
depend on the real library. Also reload the module-under-test so it binds to
this mock `rr`.
"""
calls = []
blueprints = []
class DummyScalar:
def __init__(self, value):
self.value = float(value)
# Scalars may be built from a single float or from a 1D array batch.
self.value = value
class DummyImage:
def __init__(self, arr):
self.arr = arr
def compress(self, *a, **k):
return self
def dummy_log(key, obj=None, **kwargs):
# Accept either positional `obj` or keyword `entity` and record remaining kwargs.
if obj is None and "entity" in kwargs:
obj = kwargs.pop("entity")
calls.append((key, obj, kwargs))
def dummy_send_blueprint(blueprint, *a, **k):
blueprints.append(blueprint)
# Mock the `rerun.blueprint` submodule used to build the layout.
dummy_rrb = SimpleNamespace(
Spatial2DView=lambda origin=None, name=None: SimpleNamespace(
kind="Spatial2DView", origin=origin, name=name
),
TimeSeriesView=lambda name=None, contents=None: SimpleNamespace(
kind="TimeSeriesView", name=name, contents=contents
),
Grid=lambda *views: SimpleNamespace(kind="Grid", views=list(views)),
Blueprint=lambda root: SimpleNamespace(kind="Blueprint", root=root),
)
dummy_rr = SimpleNamespace(
__name__="rerun",
__package__="rerun",
@@ -56,20 +77,23 @@ def mock_rerun(monkeypatch):
Scalars=DummyScalar,
Image=DummyImage,
log=dummy_log,
send_blueprint=dummy_send_blueprint,
init=lambda *a, **k: None,
spawn=lambda *a, **k: None,
blueprint=dummy_rrb,
)
# Inject fake module into sys.modules
# Inject fake modules into sys.modules (both `rerun` and `rerun.blueprint`).
monkeypatch.setitem(sys.modules, "rerun", dummy_rr)
monkeypatch.setitem(sys.modules, "rerun.blueprint", dummy_rrb)
# Now import and reload the module under test, to bind to our rerun mock
import lerobot.utils.visualization_utils as vu
importlib.reload(vu)
# Expose both the reloaded module and the call recorder
yield vu, calls
# Expose the reloaded module, the call recorder and the captured blueprints
yield vu, calls, blueprints
def _keys(calls):
@@ -92,8 +116,13 @@ def _kwargs_for(calls, key):
raise KeyError(f"Key {key} not found in calls: {calls}")
def _views_by_kind(blueprint, kind):
"""Return the views of a given kind from the (single) blueprint's grid."""
return [v for v in blueprint.root.views if v.kind == kind]
def test_log_rerun_data_envtransition_scalars_and_image(mock_rerun):
vu, calls = mock_rerun
vu, calls, blueprints = mock_rerun
# Build EnvTransition dict
obs = {
@@ -103,7 +132,7 @@ def test_log_rerun_data_envtransition_scalars_and_image(mock_rerun):
}
act = {
"action.throttle": 0.7,
# 1D array should log individual Scalars with suffix _i
# 1D array should be logged as a single Scalars batch under one entity path
"action.vector": np.array([1.0, 2.0], dtype=np.float32),
}
transition = {
@@ -120,31 +149,28 @@ def test_log_rerun_data_envtransition_scalars_and_image(mock_rerun):
# - observation.state.temperature -> Scalars
# - observation.camera -> Image (HWC) with static=True
# - action.throttle -> Scalars
# - action.vector_0, action.vector_1 -> Scalars
# - action.vector -> single Scalars batch (no per-element suffix)
expected_keys = {
f"{OBS_STATE}.temperature",
"observation.camera",
"action.throttle",
"action.vector_0",
"action.vector_1",
"action.vector",
}
assert set(_keys(calls)) == expected_keys
# Check scalar types and values
temp_obj = _obj_for(calls, f"{OBS_STATE}.temperature")
assert type(temp_obj).__name__ == "DummyScalar"
assert temp_obj.value == pytest.approx(25.0)
assert float(temp_obj.value) == pytest.approx(25.0)
throttle_obj = _obj_for(calls, "action.throttle")
assert type(throttle_obj).__name__ == "DummyScalar"
assert throttle_obj.value == pytest.approx(0.7)
assert float(throttle_obj.value) == pytest.approx(0.7)
v0 = _obj_for(calls, "action.vector_0")
v1 = _obj_for(calls, "action.vector_1")
assert type(v0).__name__ == "DummyScalar"
assert type(v1).__name__ == "DummyScalar"
assert v0.value == pytest.approx(1.0)
assert v1.value == pytest.approx(2.0)
# 1D vector logged as a single batched Scalars under one entity path
vec = _obj_for(calls, "action.vector")
assert type(vec).__name__ == "DummyScalar"
np.testing.assert_allclose(np.asarray(vec.value), [1.0, 2.0])
# Check image handling: CHW -> HWC
img_obj = _obj_for(calls, "observation.camera")
@@ -152,9 +178,24 @@ def test_log_rerun_data_envtransition_scalars_and_image(mock_rerun):
assert img_obj.arr.shape == (10, 20, 3) # transposed
assert _kwargs_for(calls, "observation.camera").get("static", False) is True # static=True for images
# A blueprint should have been built and sent exactly once, and cached on the function.
assert len(blueprints) == 1
assert vu.log_rerun_data.blueprint is blueprints[0]
bp = blueprints[0]
# One spatial view per image path
spatial_views = _views_by_kind(bp, "Spatial2DView")
assert {v.origin for v in spatial_views} == {"observation.camera"}
# One time-series view each for observation and action scalars
ts_views = {v.name: v for v in _views_by_kind(bp, "TimeSeriesView")}
assert set(ts_views) == {"observation", "action"}
assert ts_views["observation"].contents == [f"{OBS_STATE}.temperature"]
assert ts_views["action"].contents == ["action.throttle", "action.vector"]
def test_log_rerun_data_plain_list_ordering_and_prefixes(mock_rerun):
vu, calls = mock_rerun
vu, calls, blueprints = mock_rerun
# First dict without prefixes treated as observation
# Second dict without prefixes treated as action
@@ -173,14 +214,12 @@ def test_log_rerun_data_plain_list_ordering_and_prefixes(mock_rerun):
# First dict was treated as observation, second as action
vu.log_rerun_data(observation=obs_plain, action=act_plain)
# Expected keys with auto-prefixes
# Expected keys with auto-prefixes. The 1D vector is a single batched Scalars.
expected = {
"observation.temp",
"observation.img",
"action.throttle",
"action.vec_0",
"action.vec_1",
"action.vec_2",
"action.vec",
}
logged = set(_keys(calls))
assert logged == expected
@@ -188,11 +227,11 @@ def test_log_rerun_data_plain_list_ordering_and_prefixes(mock_rerun):
# Scalars
t = _obj_for(calls, "observation.temp")
assert type(t).__name__ == "DummyScalar"
assert t.value == pytest.approx(1.5)
assert float(t.value) == pytest.approx(1.5)
throttle = _obj_for(calls, "action.throttle")
assert type(throttle).__name__ == "DummyScalar"
assert throttle.value == pytest.approx(0.3)
assert float(throttle.value) == pytest.approx(0.3)
# Image stays HWC
img = _obj_for(calls, "observation.img")
@@ -200,15 +239,23 @@ def test_log_rerun_data_plain_list_ordering_and_prefixes(mock_rerun):
assert img.arr.shape == (5, 6, 3)
assert _kwargs_for(calls, "observation.img").get("static", False) is True
# Vectors
for i, val in enumerate([9, 8, 7]):
o = _obj_for(calls, f"action.vec_{i}")
assert type(o).__name__ == "DummyScalar"
assert o.value == pytest.approx(val)
# Vector logged as a single batched Scalars under one entity path
vec = _obj_for(calls, "action.vec")
assert type(vec).__name__ == "DummyScalar"
np.testing.assert_allclose(np.asarray(vec.value), [9, 8, 7])
# Blueprint sent once with the expected view layout
assert len(blueprints) == 1
bp = blueprints[0]
spatial_views = _views_by_kind(bp, "Spatial2DView")
assert {v.origin for v in spatial_views} == {"observation.img"}
ts_views = {v.name: v for v in _views_by_kind(bp, "TimeSeriesView")}
assert ts_views["observation"].contents == ["observation.temp"]
assert ts_views["action"].contents == ["action.throttle", "action.vec"]
def test_log_rerun_data_kwargs_only(mock_rerun):
vu, calls = mock_rerun
vu, calls, blueprints = mock_rerun
vu.log_rerun_data(
observation={"observation.temp": 10.0, "observation.gray": np.zeros((8, 8, 1), dtype=np.uint8)},
@@ -222,7 +269,7 @@ def test_log_rerun_data_kwargs_only(mock_rerun):
temp = _obj_for(calls, "observation.temp")
assert type(temp).__name__ == "DummyScalar"
assert temp.value == pytest.approx(10.0)
assert float(temp.value) == pytest.approx(10.0)
img = _obj_for(calls, "observation.gray")
assert type(img).__name__ == "DummyImage"
@@ -231,4 +278,26 @@ def test_log_rerun_data_kwargs_only(mock_rerun):
a = _obj_for(calls, "action.a")
assert type(a).__name__ == "DummyScalar"
assert a.value == pytest.approx(1.0)
assert float(a.value) == pytest.approx(1.0)
# Blueprint sent once, with a spatial view for the image and time-series views for scalars
assert len(blueprints) == 1
bp = blueprints[0]
assert {v.origin for v in _views_by_kind(bp, "Spatial2DView")} == {"observation.gray"}
ts_views = {v.name: v for v in _views_by_kind(bp, "TimeSeriesView")}
assert ts_views["observation"].contents == ["observation.temp"]
assert ts_views["action"].contents == ["action.a"]
def test_log_rerun_data_blueprint_sent_only_once(mock_rerun):
"""The blueprint is built from the first call and not resent on subsequent calls."""
vu, calls, blueprints = mock_rerun
vu.log_rerun_data(observation={"temp": 1.0}, action={"a": 2.0})
assert len(blueprints) == 1
first_blueprint = vu.log_rerun_data.blueprint
vu.log_rerun_data(observation={"temp": 3.0}, action={"a": 4.0})
# Still only one blueprint, and the cached one is unchanged.
assert len(blueprints) == 1
assert vu.log_rerun_data.blueprint is first_blueprint
Generated
+8 -8
View File
@@ -1,5 +1,5 @@
version = 1
revision = 3
revision = 2
requires-python = ">=3.12"
resolution-markers = [
"(python_full_version >= '3.15' and platform_machine == 'AMD64' and sys_platform == 'linux') or (python_full_version >= '3.15' and platform_machine == 'x86_64' and sys_platform == 'linux')",
@@ -3257,7 +3257,7 @@ requires-dist = [
{ name = "qwen-vl-utils", marker = "extra == 'qwen-vl-utils-dep'", specifier = ">=0.0.11,<0.1.0" },
{ name = "reachy2-sdk", marker = "extra == 'reachy2'", specifier = ">=1.0.15,<1.1.0" },
{ name = "requests", specifier = ">=2.32.0,<3.0.0" },
{ name = "rerun-sdk", marker = "extra == 'viz'", specifier = ">=0.24.0,<0.27.0" },
{ name = "rerun-sdk", marker = "extra == 'viz'", specifier = ">=0.24.0,<0.34.0" },
{ name = "ruff", marker = "extra == 'dev'", specifier = ">=0.14.1" },
{ name = "safetensors", specifier = ">=0.4.3,<1.0.0" },
{ name = "scikit-image", marker = "extra == 'video-benchmark'", specifier = ">=0.23.2,<0.26.0" },
@@ -5636,21 +5636,21 @@ wheels = [
[[package]]
name = "rerun-sdk"
version = "0.26.2"
version = "0.33.0"
source = { registry = "https://pypi.org/simple" }
dependencies = [
{ name = "attrs" },
{ name = "numpy" },
{ name = "pillow" },
{ name = "psutil" },
{ name = "pyarrow" },
{ name = "typing-extensions" },
]
wheels = [
{ url = "https://files.pythonhosted.org/packages/4b/4a/767c20e1529d74d9be5b5e55c6c26b63a6918ef3c1709fc422d08a460114/rerun_sdk-0.26.2-cp39-abi3-macosx_10_12_x86_64.whl", hash = "sha256:3d4151c9a3484e112b53d1df90c8fa07397dc7b8bfbb420f09e011eff20f1ef2", size = 93349439, upload-time = "2025-10-27T11:34:10.745Z" },
{ url = "https://files.pythonhosted.org/packages/2b/3d/d8dd0af9c287a85d51ec99d69406cc4b94a9feb1d6f192d3bbcaac9f0b81/rerun_sdk-0.26.2-cp39-abi3-macosx_11_0_arm64.whl", hash = "sha256:03977d2aba4966d9a70b682eca196123fda11408fecd733441ede9916c6341e2", size = 86323042, upload-time = "2025-10-27T11:34:17.995Z" },
{ url = "https://files.pythonhosted.org/packages/13/29/53d8d98799ab32418fd4ba6834d6a5749c31f56160d3c87f52a7219887e9/rerun_sdk-0.26.2-cp39-abi3-manylinux_2_28_aarch64.whl", hash = "sha256:b6128c3c4f014cae5be18e4d37657c5932d1bcdb2ce5e9d4b488a6eed47f7437", size = 92677274, upload-time = "2025-10-27T11:34:22.601Z" },
{ url = "https://files.pythonhosted.org/packages/f5/86/0b9c8f56398b4fc85f8e99279907c258413a297e5603f8f2537fe5806e51/rerun_sdk-0.26.2-cp39-abi3-manylinux_2_28_x86_64.whl", hash = "sha256:a6f97b60aaa7d4e8c6124a3f6b97ce9dbd09520050955f0e0bdacb72b0eb106a", size = 98768129, upload-time = "2025-10-27T11:34:27.36Z" },
{ url = "https://files.pythonhosted.org/packages/be/e7/99fc91c0f99f69d7d43e1db0a6f6cb8273ffc02111539bfc1fee43749bad/rerun_sdk-0.26.2-cp39-abi3-win_amd64.whl", hash = "sha256:a493ad6c8357022cba2ca6f8954a81d0faf984b0b22154eb1d976bfc7649df63", size = 84267089, upload-time = "2025-10-27T11:34:32.023Z" },
{ url = "https://files.pythonhosted.org/packages/31/17/5a521e86ac0064bd0f452e3e98e2422433511b54110423c0217d2cc1234f/rerun_sdk-0.33.0-cp310-abi3-macosx_11_0_arm64.whl", hash = "sha256:97f123e3ef6aa69b60194bc566e5435c7d4040757ed4f58297ea46c8ef320c5c", size = 125707606, upload-time = "2026-05-29T09:42:53.584Z" },
{ url = "https://files.pythonhosted.org/packages/34/2f/2ca2599aca03b69fbcac7c8391ef50376968edd7c58b96de53a4b7f20624/rerun_sdk-0.33.0-cp310-abi3-manylinux_2_28_aarch64.whl", hash = "sha256:8f734cf59419dcfbc46915bea6cec030224f16e96c3a597f0ccf7cb7b058dd43", size = 135271020, upload-time = "2026-05-29T09:43:00.106Z" },
{ url = "https://files.pythonhosted.org/packages/2e/ba/d70997b43e6db4f58c4326c29c6a6a384ddc6c2fe125f231c885ad9b3b1f/rerun_sdk-0.33.0-cp310-abi3-manylinux_2_28_x86_64.whl", hash = "sha256:53d95609f8b330026bcd041bf6d11b46ee1c18b6fbde155135f291fe86328eeb", size = 139552018, upload-time = "2026-05-29T09:43:06.275Z" },
{ url = "https://files.pythonhosted.org/packages/14/a5/0cac294d16aff6c9a2f183f838428a0380b4d2fd9e053bb37b3041999ad5/rerun_sdk-0.33.0-cp310-abi3-win_amd64.whl", hash = "sha256:b152992a72ec240062c8c285bd30ab681b464a25efbe1464c66fdac82320de1f", size = 120418186, upload-time = "2026-05-29T09:43:13.733Z" },
]
[[package]]