diff --git a/src/lerobot/scripts/lerobot_dataset_viz.py b/src/lerobot/scripts/lerobot_dataset_viz.py index 2cd48eab8..29d64554f 100644 --- a/src/lerobot/scripts/lerobot_dataset_viz.py +++ b/src/lerobot/scripts/lerobot_dataset_viz.py @@ -47,16 +47,14 @@ local$ rerun lerobot_pusht_episode_0.rrd ``` - Visualize data stored on a distant machine through streaming: -(You need to forward the websocket port to the distant machine, with -`ssh -L 9087:localhost:9087 username@remote-host`) ``` distant$ lerobot-dataset-viz \ --repo-id lerobot/pusht \ --episode-index 0 \ --mode distant \ - --ws-port 9087 + --grpc-port 9876 -local$ rerun ws://localhost:9087 +local$ rerun rerun+http://IP:GRPC_PORT/proxy ``` """ @@ -75,6 +73,7 @@ import tqdm from lerobot.datasets.lerobot_dataset import LeRobotDataset from lerobot.utils.constants import ACTION, DONE, OBS_STATE, REWARD +from lerobot.utils.utils import init_logging def to_hwc_uint8_numpy(chw_float32_torch: torch.Tensor) -> np.ndarray: @@ -93,10 +92,11 @@ def visualize_dataset( num_workers: int = 0, mode: str = "local", web_port: int = 9090, - ws_port: int = 9087, + grpc_port: int = 9876, save: bool = False, output_dir: Path | None = None, display_compressed_images: bool = False, + **kwargs, ) -> Path | None: if save: assert output_dir is not None, ( @@ -126,7 +126,9 @@ def visualize_dataset( gc.collect() if mode == "distant": - rr.serve_web_viewer(open_browser=False, web_port=web_port) + server_uri = rr.serve_grpc(grpc_port=grpc_port) + logging.info(f"Connect to a Rerun Server: rerun rerun+http://IP:{grpc_port}/proxy") + rr.serve_web_viewer(open_browser=False, web_port=web_port, connect_to=server_uri) logging.info("Logging to Rerun") @@ -226,7 +228,7 @@ def main(): "Mode of viewing between 'local' or 'distant'. " "'local' requires data to be on a local machine. It spawns a viewer to visualize the data locally. " "'distant' creates a server on the distant machine where the data is stored. " - "Visualize the data by connecting to the server with `rerun ws://localhost:PORT` on the local machine." + "Visualize the data by connecting to the server with `rerun rerun+http://IP:GRPC_PORT/proxy` on the local machine." ), ) parser.add_argument( @@ -238,8 +240,13 @@ def main(): parser.add_argument( "--ws-port", type=int, - default=9087, - help="Web socket port for rerun.io when `--mode distant` is set.", + help="deprecated, please use --grpc-port instead.", + ) + parser.add_argument( + "--grpc-port", + type=int, + default=9876, + help="gRPC port for rerun.io when `--mode distant` is set.", ) parser.add_argument( "--save", @@ -265,9 +272,7 @@ def main(): parser.add_argument( "--display-compressed-images", - type=bool, - required=True, - default=False, + action="store_true", help="If set, display compressed images in Rerun instead of uncompressed ones.", ) @@ -277,6 +282,14 @@ def main(): root = kwargs.pop("root") tolerance_s = kwargs.pop("tolerance_s") + if kwargs["ws_port"] is not None: + logging.warning( + "--ws-port is deprecated and will be removed in future versions. Please use --grpc-port instead." + ) + logging.warning("Setting grpc_port to ws_port value.") + kwargs["grpc_port"] = kwargs.pop("ws_port") + + init_logging() logging.info("Loading dataset") dataset = LeRobotDataset(repo_id, episodes=[args.episode_index], root=root, tolerance_s=tolerance_s)