diff --git a/src/lerobot/robots/reachy2/data_acquisition_server.py b/src/lerobot/robots/reachy2/data_acquisition_server.py new file mode 100644 index 000000000..39bf4dca9 --- /dev/null +++ b/src/lerobot/robots/reachy2/data_acquisition_server.py @@ -0,0 +1,447 @@ +import logging +from pathlib import Path +from typing import List +import json +import os +import shutil + +import cProfile +import pstats + +from concurrent import futures +from typing import Dict +import threading + +import grpc +import time +from logging import getLogger + +import torch.multiprocessing as mp + +from lerobot.datasets.lerobot_dataset import LeRobotDataset +from lerobot.datasets.utils import hw_to_dataset_features +from lerobot.robots.reachy2 import Reachy2Robot, Reachy2RobotConfig +from lerobot.teleoperators.reachy2_fake_teleoperator import Reachy2FakeTeleoperator, Reachy2FakeTeleoperatorConfig +from lerobot.utils.control_utils import init_keyboard_listener +from lerobot.utils.utils import log_say +from lerobot.utils.visualization_utils import _init_rerun +from lerobot.record import record_loop + +from data_acquisition_api.data_acquisition_pb2 import ( + ActionAck, + EpisodeRating, + SessionParams, + Dataset, + DatasetList, + DatasetPushState, +) +from data_acquisition_api.data_acquisition_pb2_grpc import add_DataAcquisitionServiceServicer_to_server +from google.protobuf.empty_pb2 import Empty + + +class DataAcquisitionServicer(): + def __init__( + self, + ): + self._logger = getLogger(__name__) + self.play_sound = True + + self.thread: threading.Thread = None + self.events: Dict = {} + # self.events["stop_episode"] = False + # self.events["stop_session"] = False + # self.events["skip_break_time"] = False + + self.events["exit_early"] = False + # self.events["rerecord_episode"] = False + # self.events["stop_recording"] = False + + self.setup_over = False + + self.dataset_list_path = Path("datasets.json") + self.robot: Reachy2Robot = None + self.task: str = None + self.dataset: LeRobotDataset = None + + self.fps: int + self.episode_duration: int + self.break_time_duration: int + + self.episode_recording_in_progress: bool = False + # self.break_time_in_progress: bool = False + self.episode_saved: bool = False + self.episode_recorded_in_session: bool = False + + self.run_compute_stats: bool = True + + def register_to_server(self, server): + add_DataAcquisitionServiceServicer_to_server(self, server) + + def GetDatasetList(self, request: Empty, context: grpc.ServicerContext) -> DatasetList: + dataset_list = self.read_datasets_from_json(self.dataset_list_path) + return dataset_list + + def RemoveDataset(self, request: Dataset, context: grpc.ServicerContext) -> ActionAck: + self.remove_dataset_by_name(request.dataset_name, self.dataset_list_path) + return ActionAck(success_ack=True) + + def AddDataset(self, request: Dataset, context: grpc.ServicerContext) -> ActionAck: + # Create a new dataset + dataset = self.create_dataset(request.dataset_name, DatasetPushState.PUSHED, 0) + # Add it to the JSON file + self.add_dataset_to_json_file(dataset, self.dataset_list_path) + return ActionAck(success_ack=True) + + def UpdateDataset(self, request: Dataset, context: grpc.ServicerContext) -> ActionAck: + self.update_dataset(request, self.dataset_list_path) + return ActionAck(success_ack=True) + + def ClearAllDatasets(self, request: Empty, context: grpc.ServicerContext) -> ActionAck: + # Clear all datasets from the JSON file + self.clear_all_datasets(self.dataset_list_path) + return ActionAck(success_ack=True) + + def RemoveSession(self, request: SessionParams, context: grpc.ServicerContext) -> ActionAck: + self.delete_folder(request.session_name) + return ActionAck(success_ack=True) + + def ClearAllSessions(self, request: Empty, context: grpc.ServicerContext) -> ActionAck: + pass + + # Session and Episode Management with LeRobot + + def StartSession(self, request: SessionParams, context: grpc.ServicerContext) -> ActionAck: + self._logger.error(f"Starting session with params: {request}") + try: + self.setup_recording_session(request, context) + self.setup_over = True + ack = ActionAck(success_ack=True) + except Exception as e: + self._logger.error(f"Error starting session: {e}") + ack = ActionAck(success_ack=False) + return ack + + def StopSession(self, request: Empty, context: grpc.ServicerContext) -> ActionAck: + self._logger.error("Stopping session") + try: + # self.events["stop_episode"] = True + # self.events["stop_session"] = True + if self.thread and self.thread.is_alive(): + self.thread.join() # Wait for the thread to finish + log_say("Stop recording", play_sounds=self.play_sound, blocking=True) + self.robot.disconnect() + ack = ActionAck(success_ack=True) + self._logger.error("Session stopped") + except Exception as e: + self._logger.error(f"Error stopping session: {e}") + ack = ActionAck(success_ack=False) + return ack + + def StartEpisode(self, request: Empty, context: grpc.ServicerContext) -> ActionAck: + self._logger.error("Starting episode") + try: + if self.episode_recorded_in_session and not self.episode_saved: + self._logger.error("Episode not saved. Clearing episode buffer.") + self.dataset.clear_episode_buffer() + if not self.setup_over: + raise RuntimeError("Setup not completed. Please call StartSession first.") + if self.episode_recording_in_progress: + raise RuntimeError("Episode recording already in progress. Please stop it before starting a new one.") + self.episode_saved = False + # if self.break_time_in_progress: + # self.events["skip_break_time"] = True + self.thread = threading.Thread(target=self.record_episode) + self.thread.daemon = True + self.thread.start() + self.episode_recorded_in_session = True + ack = ActionAck(success_ack=True) + self._logger.error("Episode started") + except Exception as e: + self._logger.error(f"Error starting episode: {e}") + ack = ActionAck(success_ack=False) + return ack + + def StopEpisode(self, request: Empty, context: grpc.ServicerContext) -> ActionAck: + self._logger.error("Stopping episode") + try: + self.events["exit_early"] = True + if self.thread and self.thread.is_alive(): + self.thread.join() # Wait for the thread to finish + self.events["exit_early"] = False + ack = ActionAck(success_ack=True) + self._logger.error("Episode stopped") + except Exception as e: + self._logger.error(f"Error stopping session: {e}") + ack = ActionAck(success_ack=False) + return ack + + def SaveEpisode(self, request: EpisodeRating, context: grpc.ServicerContext) -> ActionAck: + self._logger.error("Saving episode") + try: + if self.episode_recording_in_progress: + raise RuntimeError("Episode recording in progress. Please stop it before saving.") + self.dataset.save_episode() + self.episode_saved = True + ack = ActionAck(success_ack=True) + self._logger.error("Episode saved") + except Exception as e: + self._logger.error(f"Error saving episode: {e}") + ack = ActionAck(success_ack=False) + return ack + + def UploadSession(self, request: Empty, context: grpc.ServicerContext) -> ActionAck: + try: + if self.episode_recording_in_progress: + raise RuntimeError("Episode recording in progress. Please stop it before uploading.") + self.dataset.push_to_hub() + ack = ActionAck(success_ack=True) + self._logger.error("Session uploaded") + except Exception as e: + self._logger.error(f"Error uploading session: {e}") + ack = ActionAck(success_ack=False) + return ack + + def setup_recording_session( + self, + request: SessionParams, + context: grpc.ServicerContext, + ): + # Create the robot and teleoperator configurations + self.robot_config = Reachy2RobotConfig( + ip_address=request.robot.ip_address, + id=request.robot.robot_id, + ) + self.teleop_config = Reachy2FakeTeleoperatorConfig( + ip_address=request.robot.ip_address, + ) + + # Initialize the robot and teleoperator + self.robot = Reachy2Robot(self.robot_config) + self.teleop = Reachy2FakeTeleoperator(self.teleop_config) + + self.fps = request.fps if request.HasField("fps") else 30 + self.task = request.task_description + self.episode_duration = request.episode_duration + 1 # TO FIX + self.break_time_duration = request.break_time_duration + + # Configure the dataset features + action_features = hw_to_dataset_features(self.robot.action_features, "action") + obs_features = hw_to_dataset_features(self.robot.observation_features, "observation") + dataset_features = {**action_features, **obs_features} + + # Create the dataset + self.dataset = LeRobotDataset.create( + repo_id=request.dataset_name, + fps=self.fps, + features=dataset_features, + robot_type=self.robot.name, + use_videos=True, + image_writer_threads=4, + ) + + current_dataset = self.create_dataset(request.dataset_name, DatasetPushState.LOCAL_ONLY, 0) + self.add_dataset_to_json_file(current_dataset, self.dataset_list_path) + + # Connect the robot and teleoperator + if not self.robot.is_connected: + self.robot.connect() + if not self.teleop.is_connected: + self.teleop.connect() + + def record_episode( + self, + display_cameras: bool = False, + play_sounds: bool = True, + ): + self.episode_recording_in_progress = True + record_loop( + robot=self.robot, + events=self.events, + fps=self.fps, + teleop=self.teleop, + dataset=self.dataset, + control_time_s=self.episode_duration, + single_task=self.task, + display_data=True, + ) + self.episode_recording_in_progress = False + + def create_dataset(self, dataset_name: str, pushed: DatasetPushState, nb_episodes: int) -> Dataset: + dataset = Dataset() + dataset.dataset_name = dataset_name + dataset.pushed = pushed + dataset.nb_episodes = nb_episodes + return dataset + + def dataset_to_dict(self, dataset: Dataset) -> dict: + """Convert Dataset proto to a dict manually.""" + return { + "dataset_name": dataset.dataset_name, + "pushed": DatasetPushState.Name(dataset.pushed), + "nb_episodes": dataset.nb_episodes + } + + def add_dataset_to_json_file(self, dataset: Dataset, json_filename: str): + if os.path.exists(json_filename): + with open(json_filename, "r") as f: + try: + data = json.load(f) + if not isinstance(data, list): + print("Warning: JSON root is not a list, resetting.") + data = [] + except json.JSONDecodeError: + print("Warning: Invalid JSON file, resetting.") + data = [] + else: + data = [] + + # Add the new dataset + data.append(self.dataset_to_dict(dataset)) + + # Save back to file + with open(json_filename, "w") as f: + json.dump(data, f, indent=2) + + def update_dataset(self, dataset: Dataset, json_filename: str): + """Update the 'pushed' field of a dataset in the JSON file.""" + if not os.path.exists(json_filename): + print(f"File '{json_filename}' does not exist.") + return + + with open(json_filename, "r") as f: + try: + data = json.load(f) + if not isinstance(data, list): + print("Warning: JSON root is not a list. No update performed.") + return + except json.JSONDecodeError: + print("Warning: JSON file is invalid. No update performed.") + return + + dataset_found = False + for d in data: + if d.get("dataset_name") == dataset.dataset_name: + d["pushed"] = DatasetPushState.Name(dataset.pushed) # Convert enum to string + dataset_found = True + break + + if not dataset_found: + print(f"No dataset with name '{dataset.dataset_name}' found.") + return + + with open(json_filename, "w") as f: + json.dump(data, f, indent=2) + + print(f"Dataset '{dataset.dataset_name}' updated successfully.") + + def dict_to_dataset(self, d: Dict) -> Dataset: + """Convert a dictionary to a Dataset proto.""" + dataset = Dataset() + dataset.dataset_name = d.get("dataset_name", "") + + pushed_str = d.get("pushed", "UNKNOWN") + if isinstance(pushed_str, str): + # Convert string to enum value + dataset.pushed = DatasetPushState.Value(pushed_str) + elif isinstance(pushed_str, int): + # Already an integer value + dataset.pushed = pushed_str + else: + dataset.pushed = DatasetPushState.UNKNOWN + + dataset.nb_episodes = d.get("nb_episodes", 0) + return dataset + + def read_datasets_from_json(self, json_filename: str) -> DatasetList: + dataset_list = DatasetList() + + if not os.path.exists(json_filename): + print(f"File {json_filename} does not exist.") + return dataset_list # Empty list + + with open(json_filename, "r") as f: + try: + data = json.load(f) + if not isinstance(data, list): + print("Warning: JSON root is not a list. Ignored.") + return dataset_list + except json.JSONDecodeError: + print("Warning: Invalid JSON file. Ignored.") + return dataset_list + + for d in data: + dataset = self.dict_to_dataset(d) + dataset_list.datasets.append(dataset) + + return dataset_list + + def remove_dataset_by_name(self, dataset_name: str, json_filename: str): + """Remove a dataset with a given name from the JSON file.""" + if not os.path.exists(json_filename): + print(f"File {json_filename} does not exist.") + return + + with open(json_filename, "r") as f: + try: + data = json.load(f) + if not isinstance(data, list): + print("Warning: JSON root is not a list. No action taken.") + return + except json.JSONDecodeError: + print("Warning: Invalid JSON file. No action taken.") + return + + original_length = len(data) + # Filter out datasets with the matching name + data = [d for d in data if d.get("dataset_name") != dataset_name] + + if len(data) == original_length: + print(f"No dataset found with name '{dataset_name}'. No action taken.") + return + + # Save the modified list back to the file + with open(json_filename, "w") as f: + json.dump(data, f, indent=2) + print(f"Dataset '{dataset_name}' removed successfully.") + + def clear_all_datasets(self, json_filename: str): + """Clear all datasets from the JSON file.""" + if not os.path.exists(json_filename): + print(f"File {json_filename} does not exist. Nothing to clear.") + return + + # Overwrite the file with an empty list + with open(json_filename, "w") as f: + json.dump([], f, indent=2) + print(f"All datasets have been cleared from '{json_filename}'.") + + def delete_folder(self, folder_path: str): + """Delete an entire folder and its content.""" + if not os.path.exists(folder_path): + print(f"Folder '{folder_path}' does not exist.") + return + + if not os.path.isdir(folder_path): + print(f"Path '{folder_path}' is not a folder.") + return + + try: + shutil.rmtree(folder_path) + print(f"Folder '{folder_path}' and all its contents have been deleted.") + except Exception as e: + print(f"Error while deleting folder '{folder_path}': {e}") + + +def main(): + mp.set_start_method('spawn', force=True) + server = grpc.server(futures.ThreadPoolExecutor(max_workers=10)) + + data_acquisition_servicer = DataAcquisitionServicer() + data_acquisition_servicer.register_to_server(server) + server.add_insecure_port('[::]:50062') + server.start() + server.wait_for_termination() + + +if __name__ == '__main__': + main() diff --git a/src/lerobot/robots/reachy2/test_policy.py b/src/lerobot/robots/reachy2/test_policy.py new file mode 100644 index 000000000..3005592ad --- /dev/null +++ b/src/lerobot/robots/reachy2/test_policy.py @@ -0,0 +1,93 @@ +from lerobot.cameras.opencv.configuration_opencv import OpenCVCameraConfig +from lerobot.datasets.lerobot_dataset import LeRobotDataset +from lerobot.datasets.utils import hw_to_dataset_features +from lerobot.policies.act.modeling_act import ACTPolicy +from lerobot.robots.reachy2 import Reachy2Robot, Reachy2RobotConfig +from lerobot.utils.control_utils import init_keyboard_listener +from lerobot.utils.utils import log_say +# from lerobot.utils.visualization_utils import _init_rerun +from lerobot.record import record_loop +from reachy2_sdk import ReachySDK +import numpy as np +import time + +NUM_EPISODES = 5 +FPS = 20 +EPISODE_TIME_SEC = 10 +TASK_DESCRIPTION = "Grab a cube in Mujoco simulation" + + +# Create the robot configuration +robot_config = Reachy2RobotConfig( + ip_address="localhost", + # ip_address="172.18.131.66", + id="test_reachy", +) + +# Initialize the robot +robot = Reachy2Robot(robot_config) + +reachy = ReachySDK("localhost") +reachy.turn_on() +reachy.mobile_base.goto(-0.2, -0.3, 0, wait=True) +time.sleep(2) +reachy.r_arm.goto_posture("elbow_90", wait=True) +reachy.r_arm.gripper.open() +reachy.mobile_base.goto(0, -0.3, 0) + +# Initialize the policy +policy = ACTPolicy.from_pretrained("pepijn223/grab_cube_simulation_2") + +# Configure the dataset features +action_features = hw_to_dataset_features(robot.action_features, "action") +obs_features = hw_to_dataset_features(robot.observation_features, "observation") +dataset_features = {**action_features, **obs_features} + +# Create the dataset +dataset = LeRobotDataset.create( + repo_id="glannuzel/eval_grab_cube_simulation_2", + fps=FPS, + features=dataset_features, + robot_type=robot.name, + use_videos=True, + image_writer_threads=4, +) + +# Initialize the keyboard listener and rerun visualization +_, events = init_keyboard_listener() +# _init_rerun(session_name="recording") + +# Connect the robot +robot.connect() + +M = reachy.r_arm.get_default_posture_matrix("elbow_90") +np.round(M, 3) +first_pose = M.copy() +first_pose[0, 3] += 0.05 +first_pose[1, 3] += 0.1 + +for episode_idx in range(NUM_EPISODES): + log_say(f"Running inference, recording eval episode {episode_idx + 1} of {NUM_EPISODES}") + + # Run the policy inference loop + record_loop( + robot=robot, + events=events, + fps=FPS, + policy=policy, + dataset=dataset, + control_time_s=EPISODE_TIME_SEC, + single_task=TASK_DESCRIPTION, + display_data=False, + ) + + reachy.r_arm.gripper.goto(100, percentage=True, wait=True) + reachy.head.goto_posture() + reachy.r_arm.goto(first_pose) + reachy.r_arm.goto(M, wait=True) + + dataset.save_episode() + +# Clean up +robot.disconnect() +# dataset.push_to_hub() \ No newline at end of file diff --git a/src/lerobot/robots/reachy2/test_recording.py b/src/lerobot/robots/reachy2/test_recording.py index 1da9e5f93..db35af3f1 100644 --- a/src/lerobot/robots/reachy2/test_recording.py +++ b/src/lerobot/robots/reachy2/test_recording.py @@ -7,17 +7,25 @@ from lerobot.utils.utils import log_say from lerobot.utils.visualization_utils import _init_rerun from lerobot.record import record_loop -NUM_EPISODES = 2 -FPS = 30 -EPISODE_TIME_SEC = 5 +import time + +NUM_EPISODES = 35 +FPS = 20 +EPISODE_TIME_SEC = 10 RESET_TIME_SEC = 5 -TASK_DESCRIPTION = "My task description" +TASK_DESCRIPTION = "Grab a cube in Mujoco simulation" # Create the robot and teleoperator configurations robot_config = Reachy2RobotConfig( - id="test_reachy" + # ip_address="localhost", + # ip_address="172.18.131.66", + ip_address="192.168.0.200", + id="test_reachy", +) +teleop_config = Reachy2FakeTeleoperatorConfig( + # ip_address="172.18.131.66", + ip_address="192.168.0.200", ) -teleop_config = Reachy2FakeTeleoperatorConfig() # Initialize the robot and teleoperator robot = Reachy2Robot(robot_config) @@ -30,7 +38,7 @@ dataset_features = {**action_features, **obs_features} # Create the dataset dataset = LeRobotDataset.create( - repo_id="test/repo_test", + repo_id="glannuzel/grab_cube", fps=FPS, features=dataset_features, robot_type=robot.name, @@ -40,7 +48,7 @@ dataset = LeRobotDataset.create( # Initialize the keyboard listener and rerun visualization _, events = init_keyboard_listener() -_init_rerun(session_name="recording") +# _init_rerun(session_name="recording") # Connect the robot and teleoperator robot.connect() @@ -48,8 +56,11 @@ teleop.connect() episode_idx = 0 while episode_idx < NUM_EPISODES and not events["stop_recording"]: + start_time = time.time() log_say(f"Recording episode {episode_idx + 1} of {NUM_EPISODES}") + print("########### RECORDING ###########") + record_loop( robot=robot, events=events, @@ -58,12 +69,14 @@ while episode_idx < NUM_EPISODES and not events["stop_recording"]: dataset=dataset, control_time_s=EPISODE_TIME_SEC, single_task=TASK_DESCRIPTION, - display_data=True, + display_data=False, ) # Reset the environment if not stopping or re-recording if not events["stop_recording"] and (episode_idx < NUM_EPISODES - 1 or events["rerecord_episode"]): log_say("Reset the environment") + + print("------------- RESETTING -------------") record_loop( robot=robot, events=events, @@ -71,7 +84,7 @@ while episode_idx < NUM_EPISODES and not events["stop_recording"]: teleop=teleop, control_time_s=RESET_TIME_SEC, single_task=TASK_DESCRIPTION, - display_data=True, + display_data=False, ) if events["rerecord_episode"]: @@ -81,10 +94,12 @@ while episode_idx < NUM_EPISODES and not events["stop_recording"]: dataset.clear_episode_buffer() continue + # episode_idx = NUM_EPISODES dataset.save_episode() episode_idx += 1 + print(time.time()-start_time) # Clean up log_say("Stop recording") robot.disconnect() -# dataset.push_to_hub() \ No newline at end of file +dataset.push_to_hub()