mirror of
https://github.com/huggingface/lerobot.git
synced 2026-05-30 16:09:44 +00:00
sync recent changes
This commit is contained in:
@@ -0,0 +1,224 @@
|
||||
#!/usr/bin/env python3
|
||||
"""
|
||||
Live camera viewer for MuJoCo simulator using matplotlib
|
||||
Works without X11/GTK - suitable for SSH sessions with X forwarding
|
||||
"""
|
||||
import argparse
|
||||
import sys
|
||||
import time
|
||||
from pathlib import Path
|
||||
|
||||
# Add sim module to path
|
||||
sys.path.insert(0, str(Path(__file__).parent))
|
||||
|
||||
import cv2
|
||||
import numpy as np
|
||||
import matplotlib.pyplot as plt
|
||||
from matplotlib.animation import FuncAnimation
|
||||
from sim.sensor_utils import SensorClient, ImageUtils
|
||||
|
||||
|
||||
class CameraViewer:
|
||||
def __init__(self, host, port):
|
||||
self.client = SensorClient()
|
||||
self.client.start_client(server_ip=host, port=port)
|
||||
|
||||
self.fig = None
|
||||
self.axes = {}
|
||||
self.images = {}
|
||||
self.text_objs = {}
|
||||
|
||||
self.frame_count = 0
|
||||
self.last_time = time.time()
|
||||
self.fps = 0
|
||||
|
||||
def init_plot(self):
|
||||
"""Initialize matplotlib figure and axes"""
|
||||
# Wait for first frame to know how many cameras we have
|
||||
print("Waiting for first frame to detect cameras...")
|
||||
data = self.client.receive_message()
|
||||
|
||||
# Parse camera names - handle nested 'images' dict
|
||||
camera_names = []
|
||||
if "images" in data and isinstance(data["images"], dict):
|
||||
# Nested structure: data["images"]["camera_name"]
|
||||
camera_names = list(data["images"].keys())
|
||||
else:
|
||||
# Flat structure: data["camera_name"] directly
|
||||
camera_names = [k for k in data.keys() if k not in ["timestamps", "images"]]
|
||||
|
||||
num_cameras = len(camera_names)
|
||||
|
||||
if num_cameras == 0:
|
||||
print("No cameras found in stream!")
|
||||
return False
|
||||
|
||||
print(f"Found {num_cameras} camera(s): {', '.join(camera_names)}")
|
||||
|
||||
# Create subplots
|
||||
if num_cameras == 1:
|
||||
self.fig, ax = plt.subplots(1, 1, figsize=(10, 8))
|
||||
axes_list = [ax]
|
||||
elif num_cameras == 2:
|
||||
self.fig, axes_list = plt.subplots(1, 2, figsize=(16, 6))
|
||||
else:
|
||||
rows = (num_cameras + 1) // 2
|
||||
self.fig, axes_list = plt.subplots(rows, 2, figsize=(16, 6 * rows))
|
||||
axes_list = axes_list.flatten()
|
||||
|
||||
# Initialize each camera subplot
|
||||
for i, cam_name in enumerate(camera_names):
|
||||
ax = axes_list[i]
|
||||
ax.set_title(f"{cam_name}", fontsize=12, fontweight='bold')
|
||||
ax.axis('off')
|
||||
|
||||
# Get image data from nested or flat structure
|
||||
if "images" in data and cam_name in data["images"]:
|
||||
img_data = data["images"][cam_name]
|
||||
elif cam_name in data:
|
||||
img_data = data[cam_name]
|
||||
else:
|
||||
img_data = cam_name # Use the actual data if it's the value
|
||||
|
||||
# Decode first image
|
||||
if isinstance(img_data, str):
|
||||
img = ImageUtils.decode_image(img_data)
|
||||
elif isinstance(img_data, np.ndarray):
|
||||
img = img_data
|
||||
else:
|
||||
print(f"Warning: Unknown image format for {cam_name}: {type(img_data)}")
|
||||
continue
|
||||
|
||||
# Check if image is valid
|
||||
if img is None or not isinstance(img, np.ndarray):
|
||||
print(f"Warning: Invalid image data for {cam_name}")
|
||||
continue
|
||||
|
||||
# Convert BGR to RGB for matplotlib
|
||||
img_rgb = cv2.cvtColor(img, cv2.COLOR_BGR2RGB)
|
||||
|
||||
# Display image
|
||||
im = ax.imshow(img_rgb)
|
||||
self.images[cam_name] = im
|
||||
self.axes[cam_name] = ax
|
||||
|
||||
# Add FPS text
|
||||
text = ax.text(0.02, 0.98, 'FPS: 0.0',
|
||||
transform=ax.transAxes,
|
||||
fontsize=10,
|
||||
verticalalignment='top',
|
||||
bbox=dict(boxstyle='round', facecolor='black', alpha=0.7),
|
||||
color='lime',
|
||||
fontweight='bold')
|
||||
self.text_objs[cam_name] = text
|
||||
|
||||
# Hide unused subplots
|
||||
if num_cameras < len(axes_list):
|
||||
for i in range(num_cameras, len(axes_list)):
|
||||
axes_list[i].axis('off')
|
||||
|
||||
self.fig.tight_layout()
|
||||
return True
|
||||
|
||||
def update_frame(self, frame_num):
|
||||
"""Update function for animation"""
|
||||
try:
|
||||
# Receive new frame
|
||||
data = self.client.receive_message()
|
||||
|
||||
# Calculate FPS
|
||||
self.frame_count += 1
|
||||
current_time = time.time()
|
||||
if current_time - self.last_time >= 1.0:
|
||||
self.fps = self.frame_count / (current_time - self.last_time)
|
||||
self.frame_count = 0
|
||||
self.last_time = current_time
|
||||
|
||||
# Update each camera
|
||||
for cam_name in self.images.keys():
|
||||
# Get image data from nested or flat structure
|
||||
if "images" in data and cam_name in data["images"]:
|
||||
img_data = data["images"][cam_name]
|
||||
elif cam_name in data:
|
||||
img_data = data[cam_name]
|
||||
else:
|
||||
continue
|
||||
|
||||
# Decode image
|
||||
if isinstance(img_data, str):
|
||||
img = ImageUtils.decode_image(img_data)
|
||||
elif isinstance(img_data, np.ndarray):
|
||||
img = img_data
|
||||
else:
|
||||
continue
|
||||
|
||||
# Check if image is valid
|
||||
if img is None or not isinstance(img, np.ndarray):
|
||||
continue
|
||||
|
||||
# Convert BGR to RGB for matplotlib
|
||||
img_rgb = cv2.cvtColor(img, cv2.COLOR_BGR2RGB)
|
||||
|
||||
# Update image
|
||||
self.images[cam_name].set_data(img_rgb)
|
||||
|
||||
# Update FPS text
|
||||
self.text_objs[cam_name].set_text(f'FPS: {self.fps:.1f}')
|
||||
|
||||
except Exception as e:
|
||||
print(f"Error updating frame: {e}")
|
||||
|
||||
return list(self.images.values()) + list(self.text_objs.values())
|
||||
|
||||
def start(self, interval=33):
|
||||
"""Start the live viewer"""
|
||||
if not self.init_plot():
|
||||
return
|
||||
|
||||
print(f"\n{'='*60}")
|
||||
print("📹 Live camera viewer started!")
|
||||
print("Close the window or press Ctrl+C to exit")
|
||||
print(f"{'='*60}\n")
|
||||
|
||||
# Create animation
|
||||
anim = FuncAnimation(
|
||||
self.fig,
|
||||
self.update_frame,
|
||||
interval=interval, # ms between frames
|
||||
blit=True,
|
||||
cache_frame_data=False
|
||||
)
|
||||
|
||||
try:
|
||||
plt.show()
|
||||
except KeyboardInterrupt:
|
||||
print("\nStopping viewer...")
|
||||
finally:
|
||||
self.client.stop_client()
|
||||
plt.close('all')
|
||||
|
||||
|
||||
def main():
|
||||
parser = argparse.ArgumentParser(description="Live camera viewer for MuJoCo simulator")
|
||||
parser.add_argument("--host", type=str, default="localhost",
|
||||
help="Simulator host address (default: localhost)")
|
||||
parser.add_argument("--port", type=int, default=5555,
|
||||
help="ZMQ port (default: 5555)")
|
||||
parser.add_argument("--interval", type=int, default=33,
|
||||
help="Update interval in ms (default: 33 = ~30fps)")
|
||||
args = parser.parse_args()
|
||||
|
||||
print("="*60)
|
||||
print("📷 MuJoCo Live Camera Viewer (matplotlib)")
|
||||
print("="*60)
|
||||
print(f"🌐 Connecting to: tcp://{args.host}:{args.port}")
|
||||
print(f"⏱️ Update interval: {args.interval}ms (~{1000/args.interval:.0f} fps)")
|
||||
print("="*60)
|
||||
|
||||
viewer = CameraViewer(host=args.host, port=args.port)
|
||||
viewer.start(interval=args.interval)
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
main()
|
||||
|
||||
Reference in New Issue
Block a user