Update tests

This commit is contained in:
glannuzel
2025-08-04 16:28:42 +02:00
parent ddaca00801
commit 8217f44235
3 changed files with 566 additions and 11 deletions
@@ -0,0 +1,447 @@
import logging
from pathlib import Path
from typing import List
import json
import os
import shutil
import cProfile
import pstats
from concurrent import futures
from typing import Dict
import threading
import grpc
import time
from logging import getLogger
import torch.multiprocessing as mp
from lerobot.datasets.lerobot_dataset import LeRobotDataset
from lerobot.datasets.utils import hw_to_dataset_features
from lerobot.robots.reachy2 import Reachy2Robot, Reachy2RobotConfig
from lerobot.teleoperators.reachy2_fake_teleoperator import Reachy2FakeTeleoperator, Reachy2FakeTeleoperatorConfig
from lerobot.utils.control_utils import init_keyboard_listener
from lerobot.utils.utils import log_say
from lerobot.utils.visualization_utils import _init_rerun
from lerobot.record import record_loop
from data_acquisition_api.data_acquisition_pb2 import (
ActionAck,
EpisodeRating,
SessionParams,
Dataset,
DatasetList,
DatasetPushState,
)
from data_acquisition_api.data_acquisition_pb2_grpc import add_DataAcquisitionServiceServicer_to_server
from google.protobuf.empty_pb2 import Empty
class DataAcquisitionServicer():
def __init__(
self,
):
self._logger = getLogger(__name__)
self.play_sound = True
self.thread: threading.Thread = None
self.events: Dict = {}
# self.events["stop_episode"] = False
# self.events["stop_session"] = False
# self.events["skip_break_time"] = False
self.events["exit_early"] = False
# self.events["rerecord_episode"] = False
# self.events["stop_recording"] = False
self.setup_over = False
self.dataset_list_path = Path("datasets.json")
self.robot: Reachy2Robot = None
self.task: str = None
self.dataset: LeRobotDataset = None
self.fps: int
self.episode_duration: int
self.break_time_duration: int
self.episode_recording_in_progress: bool = False
# self.break_time_in_progress: bool = False
self.episode_saved: bool = False
self.episode_recorded_in_session: bool = False
self.run_compute_stats: bool = True
def register_to_server(self, server):
add_DataAcquisitionServiceServicer_to_server(self, server)
def GetDatasetList(self, request: Empty, context: grpc.ServicerContext) -> DatasetList:
dataset_list = self.read_datasets_from_json(self.dataset_list_path)
return dataset_list
def RemoveDataset(self, request: Dataset, context: grpc.ServicerContext) -> ActionAck:
self.remove_dataset_by_name(request.dataset_name, self.dataset_list_path)
return ActionAck(success_ack=True)
def AddDataset(self, request: Dataset, context: grpc.ServicerContext) -> ActionAck:
# Create a new dataset
dataset = self.create_dataset(request.dataset_name, DatasetPushState.PUSHED, 0)
# Add it to the JSON file
self.add_dataset_to_json_file(dataset, self.dataset_list_path)
return ActionAck(success_ack=True)
def UpdateDataset(self, request: Dataset, context: grpc.ServicerContext) -> ActionAck:
self.update_dataset(request, self.dataset_list_path)
return ActionAck(success_ack=True)
def ClearAllDatasets(self, request: Empty, context: grpc.ServicerContext) -> ActionAck:
# Clear all datasets from the JSON file
self.clear_all_datasets(self.dataset_list_path)
return ActionAck(success_ack=True)
def RemoveSession(self, request: SessionParams, context: grpc.ServicerContext) -> ActionAck:
self.delete_folder(request.session_name)
return ActionAck(success_ack=True)
def ClearAllSessions(self, request: Empty, context: grpc.ServicerContext) -> ActionAck:
pass
# Session and Episode Management with LeRobot
def StartSession(self, request: SessionParams, context: grpc.ServicerContext) -> ActionAck:
self._logger.error(f"Starting session with params: {request}")
try:
self.setup_recording_session(request, context)
self.setup_over = True
ack = ActionAck(success_ack=True)
except Exception as e:
self._logger.error(f"Error starting session: {e}")
ack = ActionAck(success_ack=False)
return ack
def StopSession(self, request: Empty, context: grpc.ServicerContext) -> ActionAck:
self._logger.error("Stopping session")
try:
# self.events["stop_episode"] = True
# self.events["stop_session"] = True
if self.thread and self.thread.is_alive():
self.thread.join() # Wait for the thread to finish
log_say("Stop recording", play_sounds=self.play_sound, blocking=True)
self.robot.disconnect()
ack = ActionAck(success_ack=True)
self._logger.error("Session stopped")
except Exception as e:
self._logger.error(f"Error stopping session: {e}")
ack = ActionAck(success_ack=False)
return ack
def StartEpisode(self, request: Empty, context: grpc.ServicerContext) -> ActionAck:
self._logger.error("Starting episode")
try:
if self.episode_recorded_in_session and not self.episode_saved:
self._logger.error("Episode not saved. Clearing episode buffer.")
self.dataset.clear_episode_buffer()
if not self.setup_over:
raise RuntimeError("Setup not completed. Please call StartSession first.")
if self.episode_recording_in_progress:
raise RuntimeError("Episode recording already in progress. Please stop it before starting a new one.")
self.episode_saved = False
# if self.break_time_in_progress:
# self.events["skip_break_time"] = True
self.thread = threading.Thread(target=self.record_episode)
self.thread.daemon = True
self.thread.start()
self.episode_recorded_in_session = True
ack = ActionAck(success_ack=True)
self._logger.error("Episode started")
except Exception as e:
self._logger.error(f"Error starting episode: {e}")
ack = ActionAck(success_ack=False)
return ack
def StopEpisode(self, request: Empty, context: grpc.ServicerContext) -> ActionAck:
self._logger.error("Stopping episode")
try:
self.events["exit_early"] = True
if self.thread and self.thread.is_alive():
self.thread.join() # Wait for the thread to finish
self.events["exit_early"] = False
ack = ActionAck(success_ack=True)
self._logger.error("Episode stopped")
except Exception as e:
self._logger.error(f"Error stopping session: {e}")
ack = ActionAck(success_ack=False)
return ack
def SaveEpisode(self, request: EpisodeRating, context: grpc.ServicerContext) -> ActionAck:
self._logger.error("Saving episode")
try:
if self.episode_recording_in_progress:
raise RuntimeError("Episode recording in progress. Please stop it before saving.")
self.dataset.save_episode()
self.episode_saved = True
ack = ActionAck(success_ack=True)
self._logger.error("Episode saved")
except Exception as e:
self._logger.error(f"Error saving episode: {e}")
ack = ActionAck(success_ack=False)
return ack
def UploadSession(self, request: Empty, context: grpc.ServicerContext) -> ActionAck:
try:
if self.episode_recording_in_progress:
raise RuntimeError("Episode recording in progress. Please stop it before uploading.")
self.dataset.push_to_hub()
ack = ActionAck(success_ack=True)
self._logger.error("Session uploaded")
except Exception as e:
self._logger.error(f"Error uploading session: {e}")
ack = ActionAck(success_ack=False)
return ack
def setup_recording_session(
self,
request: SessionParams,
context: grpc.ServicerContext,
):
# Create the robot and teleoperator configurations
self.robot_config = Reachy2RobotConfig(
ip_address=request.robot.ip_address,
id=request.robot.robot_id,
)
self.teleop_config = Reachy2FakeTeleoperatorConfig(
ip_address=request.robot.ip_address,
)
# Initialize the robot and teleoperator
self.robot = Reachy2Robot(self.robot_config)
self.teleop = Reachy2FakeTeleoperator(self.teleop_config)
self.fps = request.fps if request.HasField("fps") else 30
self.task = request.task_description
self.episode_duration = request.episode_duration + 1 # TO FIX
self.break_time_duration = request.break_time_duration
# Configure the dataset features
action_features = hw_to_dataset_features(self.robot.action_features, "action")
obs_features = hw_to_dataset_features(self.robot.observation_features, "observation")
dataset_features = {**action_features, **obs_features}
# Create the dataset
self.dataset = LeRobotDataset.create(
repo_id=request.dataset_name,
fps=self.fps,
features=dataset_features,
robot_type=self.robot.name,
use_videos=True,
image_writer_threads=4,
)
current_dataset = self.create_dataset(request.dataset_name, DatasetPushState.LOCAL_ONLY, 0)
self.add_dataset_to_json_file(current_dataset, self.dataset_list_path)
# Connect the robot and teleoperator
if not self.robot.is_connected:
self.robot.connect()
if not self.teleop.is_connected:
self.teleop.connect()
def record_episode(
self,
display_cameras: bool = False,
play_sounds: bool = True,
):
self.episode_recording_in_progress = True
record_loop(
robot=self.robot,
events=self.events,
fps=self.fps,
teleop=self.teleop,
dataset=self.dataset,
control_time_s=self.episode_duration,
single_task=self.task,
display_data=True,
)
self.episode_recording_in_progress = False
def create_dataset(self, dataset_name: str, pushed: DatasetPushState, nb_episodes: int) -> Dataset:
dataset = Dataset()
dataset.dataset_name = dataset_name
dataset.pushed = pushed
dataset.nb_episodes = nb_episodes
return dataset
def dataset_to_dict(self, dataset: Dataset) -> dict:
"""Convert Dataset proto to a dict manually."""
return {
"dataset_name": dataset.dataset_name,
"pushed": DatasetPushState.Name(dataset.pushed),
"nb_episodes": dataset.nb_episodes
}
def add_dataset_to_json_file(self, dataset: Dataset, json_filename: str):
if os.path.exists(json_filename):
with open(json_filename, "r") as f:
try:
data = json.load(f)
if not isinstance(data, list):
print("Warning: JSON root is not a list, resetting.")
data = []
except json.JSONDecodeError:
print("Warning: Invalid JSON file, resetting.")
data = []
else:
data = []
# Add the new dataset
data.append(self.dataset_to_dict(dataset))
# Save back to file
with open(json_filename, "w") as f:
json.dump(data, f, indent=2)
def update_dataset(self, dataset: Dataset, json_filename: str):
"""Update the 'pushed' field of a dataset in the JSON file."""
if not os.path.exists(json_filename):
print(f"File '{json_filename}' does not exist.")
return
with open(json_filename, "r") as f:
try:
data = json.load(f)
if not isinstance(data, list):
print("Warning: JSON root is not a list. No update performed.")
return
except json.JSONDecodeError:
print("Warning: JSON file is invalid. No update performed.")
return
dataset_found = False
for d in data:
if d.get("dataset_name") == dataset.dataset_name:
d["pushed"] = DatasetPushState.Name(dataset.pushed) # Convert enum to string
dataset_found = True
break
if not dataset_found:
print(f"No dataset with name '{dataset.dataset_name}' found.")
return
with open(json_filename, "w") as f:
json.dump(data, f, indent=2)
print(f"Dataset '{dataset.dataset_name}' updated successfully.")
def dict_to_dataset(self, d: Dict) -> Dataset:
"""Convert a dictionary to a Dataset proto."""
dataset = Dataset()
dataset.dataset_name = d.get("dataset_name", "")
pushed_str = d.get("pushed", "UNKNOWN")
if isinstance(pushed_str, str):
# Convert string to enum value
dataset.pushed = DatasetPushState.Value(pushed_str)
elif isinstance(pushed_str, int):
# Already an integer value
dataset.pushed = pushed_str
else:
dataset.pushed = DatasetPushState.UNKNOWN
dataset.nb_episodes = d.get("nb_episodes", 0)
return dataset
def read_datasets_from_json(self, json_filename: str) -> DatasetList:
dataset_list = DatasetList()
if not os.path.exists(json_filename):
print(f"File {json_filename} does not exist.")
return dataset_list # Empty list
with open(json_filename, "r") as f:
try:
data = json.load(f)
if not isinstance(data, list):
print("Warning: JSON root is not a list. Ignored.")
return dataset_list
except json.JSONDecodeError:
print("Warning: Invalid JSON file. Ignored.")
return dataset_list
for d in data:
dataset = self.dict_to_dataset(d)
dataset_list.datasets.append(dataset)
return dataset_list
def remove_dataset_by_name(self, dataset_name: str, json_filename: str):
"""Remove a dataset with a given name from the JSON file."""
if not os.path.exists(json_filename):
print(f"File {json_filename} does not exist.")
return
with open(json_filename, "r") as f:
try:
data = json.load(f)
if not isinstance(data, list):
print("Warning: JSON root is not a list. No action taken.")
return
except json.JSONDecodeError:
print("Warning: Invalid JSON file. No action taken.")
return
original_length = len(data)
# Filter out datasets with the matching name
data = [d for d in data if d.get("dataset_name") != dataset_name]
if len(data) == original_length:
print(f"No dataset found with name '{dataset_name}'. No action taken.")
return
# Save the modified list back to the file
with open(json_filename, "w") as f:
json.dump(data, f, indent=2)
print(f"Dataset '{dataset_name}' removed successfully.")
def clear_all_datasets(self, json_filename: str):
"""Clear all datasets from the JSON file."""
if not os.path.exists(json_filename):
print(f"File {json_filename} does not exist. Nothing to clear.")
return
# Overwrite the file with an empty list
with open(json_filename, "w") as f:
json.dump([], f, indent=2)
print(f"All datasets have been cleared from '{json_filename}'.")
def delete_folder(self, folder_path: str):
"""Delete an entire folder and its content."""
if not os.path.exists(folder_path):
print(f"Folder '{folder_path}' does not exist.")
return
if not os.path.isdir(folder_path):
print(f"Path '{folder_path}' is not a folder.")
return
try:
shutil.rmtree(folder_path)
print(f"Folder '{folder_path}' and all its contents have been deleted.")
except Exception as e:
print(f"Error while deleting folder '{folder_path}': {e}")
def main():
mp.set_start_method('spawn', force=True)
server = grpc.server(futures.ThreadPoolExecutor(max_workers=10))
data_acquisition_servicer = DataAcquisitionServicer()
data_acquisition_servicer.register_to_server(server)
server.add_insecure_port('[::]:50062')
server.start()
server.wait_for_termination()
if __name__ == '__main__':
main()
+93
View File
@@ -0,0 +1,93 @@
from lerobot.cameras.opencv.configuration_opencv import OpenCVCameraConfig
from lerobot.datasets.lerobot_dataset import LeRobotDataset
from lerobot.datasets.utils import hw_to_dataset_features
from lerobot.policies.act.modeling_act import ACTPolicy
from lerobot.robots.reachy2 import Reachy2Robot, Reachy2RobotConfig
from lerobot.utils.control_utils import init_keyboard_listener
from lerobot.utils.utils import log_say
# from lerobot.utils.visualization_utils import _init_rerun
from lerobot.record import record_loop
from reachy2_sdk import ReachySDK
import numpy as np
import time
NUM_EPISODES = 5
FPS = 20
EPISODE_TIME_SEC = 10
TASK_DESCRIPTION = "Grab a cube in Mujoco simulation"
# Create the robot configuration
robot_config = Reachy2RobotConfig(
ip_address="localhost",
# ip_address="172.18.131.66",
id="test_reachy",
)
# Initialize the robot
robot = Reachy2Robot(robot_config)
reachy = ReachySDK("localhost")
reachy.turn_on()
reachy.mobile_base.goto(-0.2, -0.3, 0, wait=True)
time.sleep(2)
reachy.r_arm.goto_posture("elbow_90", wait=True)
reachy.r_arm.gripper.open()
reachy.mobile_base.goto(0, -0.3, 0)
# Initialize the policy
policy = ACTPolicy.from_pretrained("pepijn223/grab_cube_simulation_2")
# Configure the dataset features
action_features = hw_to_dataset_features(robot.action_features, "action")
obs_features = hw_to_dataset_features(robot.observation_features, "observation")
dataset_features = {**action_features, **obs_features}
# Create the dataset
dataset = LeRobotDataset.create(
repo_id="glannuzel/eval_grab_cube_simulation_2",
fps=FPS,
features=dataset_features,
robot_type=robot.name,
use_videos=True,
image_writer_threads=4,
)
# Initialize the keyboard listener and rerun visualization
_, events = init_keyboard_listener()
# _init_rerun(session_name="recording")
# Connect the robot
robot.connect()
M = reachy.r_arm.get_default_posture_matrix("elbow_90")
np.round(M, 3)
first_pose = M.copy()
first_pose[0, 3] += 0.05
first_pose[1, 3] += 0.1
for episode_idx in range(NUM_EPISODES):
log_say(f"Running inference, recording eval episode {episode_idx + 1} of {NUM_EPISODES}")
# Run the policy inference loop
record_loop(
robot=robot,
events=events,
fps=FPS,
policy=policy,
dataset=dataset,
control_time_s=EPISODE_TIME_SEC,
single_task=TASK_DESCRIPTION,
display_data=False,
)
reachy.r_arm.gripper.goto(100, percentage=True, wait=True)
reachy.head.goto_posture()
reachy.r_arm.goto(first_pose)
reachy.r_arm.goto(M, wait=True)
dataset.save_episode()
# Clean up
robot.disconnect()
# dataset.push_to_hub()
+26 -11
View File
@@ -7,17 +7,25 @@ from lerobot.utils.utils import log_say
from lerobot.utils.visualization_utils import _init_rerun
from lerobot.record import record_loop
NUM_EPISODES = 2
FPS = 30
EPISODE_TIME_SEC = 5
import time
NUM_EPISODES = 35
FPS = 20
EPISODE_TIME_SEC = 10
RESET_TIME_SEC = 5
TASK_DESCRIPTION = "My task description"
TASK_DESCRIPTION = "Grab a cube in Mujoco simulation"
# Create the robot and teleoperator configurations
robot_config = Reachy2RobotConfig(
id="test_reachy"
# ip_address="localhost",
# ip_address="172.18.131.66",
ip_address="192.168.0.200",
id="test_reachy",
)
teleop_config = Reachy2FakeTeleoperatorConfig(
# ip_address="172.18.131.66",
ip_address="192.168.0.200",
)
teleop_config = Reachy2FakeTeleoperatorConfig()
# Initialize the robot and teleoperator
robot = Reachy2Robot(robot_config)
@@ -30,7 +38,7 @@ dataset_features = {**action_features, **obs_features}
# Create the dataset
dataset = LeRobotDataset.create(
repo_id="test/repo_test",
repo_id="glannuzel/grab_cube",
fps=FPS,
features=dataset_features,
robot_type=robot.name,
@@ -40,7 +48,7 @@ dataset = LeRobotDataset.create(
# Initialize the keyboard listener and rerun visualization
_, events = init_keyboard_listener()
_init_rerun(session_name="recording")
# _init_rerun(session_name="recording")
# Connect the robot and teleoperator
robot.connect()
@@ -48,8 +56,11 @@ teleop.connect()
episode_idx = 0
while episode_idx < NUM_EPISODES and not events["stop_recording"]:
start_time = time.time()
log_say(f"Recording episode {episode_idx + 1} of {NUM_EPISODES}")
print("########### RECORDING ###########")
record_loop(
robot=robot,
events=events,
@@ -58,12 +69,14 @@ while episode_idx < NUM_EPISODES and not events["stop_recording"]:
dataset=dataset,
control_time_s=EPISODE_TIME_SEC,
single_task=TASK_DESCRIPTION,
display_data=True,
display_data=False,
)
# Reset the environment if not stopping or re-recording
if not events["stop_recording"] and (episode_idx < NUM_EPISODES - 1 or events["rerecord_episode"]):
log_say("Reset the environment")
print("------------- RESETTING -------------")
record_loop(
robot=robot,
events=events,
@@ -71,7 +84,7 @@ while episode_idx < NUM_EPISODES and not events["stop_recording"]:
teleop=teleop,
control_time_s=RESET_TIME_SEC,
single_task=TASK_DESCRIPTION,
display_data=True,
display_data=False,
)
if events["rerecord_episode"]:
@@ -81,10 +94,12 @@ while episode_idx < NUM_EPISODES and not events["stop_recording"]:
dataset.clear_episode_buffer()
continue
# episode_idx = NUM_EPISODES
dataset.save_episode()
episode_idx += 1
print(time.time()-start_time)
# Clean up
log_say("Stop recording")
robot.disconnect()
# dataset.push_to_hub()
dataset.push_to_hub()