diff --git a/src/lerobot/robots/so_follower/so_follower.py b/src/lerobot/robots/so_follower/so_follower.py index 0651f566c..77c5534be 100644 --- a/src/lerobot/robots/so_follower/so_follower.py +++ b/src/lerobot/robots/so_follower/so_follower.py @@ -68,9 +68,12 @@ class SOFollower(Robot): @property def _cameras_ft(self) -> dict[str, tuple]: - return { - cam: (self.config.cameras[cam].height, self.config.cameras[cam].width, 3) for cam in self.cameras - } + features: dict[str, tuple] = {} + for cam in self.cameras: + features[cam] = (self.cameras[cam].height, self.cameras[cam].width, 3) + if getattr(self.cameras[cam], "use_depth", False): + features[f"{cam}_depth"] = (self.cameras[cam].height, self.cameras[cam].width,1) + return features @cached_property def observation_features(self) -> dict[str, type | tuple]: @@ -190,6 +193,12 @@ class SOFollower(Robot): dt_ms = (time.perf_counter() - start) * 1e3 logger.debug(f"{self} read {cam_key}: {dt_ms:.1f}ms") + if getattr(cam, "use_depth", False): + start = time.perf_counter() + obs_dict[f"{cam_key}_depth"] = cam.read_latest_depth() + dt_ms = (time.perf_counter() - start) * 1e3 + logger.debug(f"{self} read {cam_key} depth: {dt_ms:.1f}ms") + return obs_dict @check_if_not_connected