mirror of
https://github.com/huggingface/lerobot.git
synced 2026-05-23 20:50:02 +00:00
Update tests
This commit is contained in:
@@ -0,0 +1,447 @@
|
|||||||
|
import logging
|
||||||
|
from pathlib import Path
|
||||||
|
from typing import List
|
||||||
|
import json
|
||||||
|
import os
|
||||||
|
import shutil
|
||||||
|
|
||||||
|
import cProfile
|
||||||
|
import pstats
|
||||||
|
|
||||||
|
from concurrent import futures
|
||||||
|
from typing import Dict
|
||||||
|
import threading
|
||||||
|
|
||||||
|
import grpc
|
||||||
|
import time
|
||||||
|
from logging import getLogger
|
||||||
|
|
||||||
|
import torch.multiprocessing as mp
|
||||||
|
|
||||||
|
from lerobot.datasets.lerobot_dataset import LeRobotDataset
|
||||||
|
from lerobot.datasets.utils import hw_to_dataset_features
|
||||||
|
from lerobot.robots.reachy2 import Reachy2Robot, Reachy2RobotConfig
|
||||||
|
from lerobot.teleoperators.reachy2_fake_teleoperator import Reachy2FakeTeleoperator, Reachy2FakeTeleoperatorConfig
|
||||||
|
from lerobot.utils.control_utils import init_keyboard_listener
|
||||||
|
from lerobot.utils.utils import log_say
|
||||||
|
from lerobot.utils.visualization_utils import _init_rerun
|
||||||
|
from lerobot.record import record_loop
|
||||||
|
|
||||||
|
from data_acquisition_api.data_acquisition_pb2 import (
|
||||||
|
ActionAck,
|
||||||
|
EpisodeRating,
|
||||||
|
SessionParams,
|
||||||
|
Dataset,
|
||||||
|
DatasetList,
|
||||||
|
DatasetPushState,
|
||||||
|
)
|
||||||
|
from data_acquisition_api.data_acquisition_pb2_grpc import add_DataAcquisitionServiceServicer_to_server
|
||||||
|
from google.protobuf.empty_pb2 import Empty
|
||||||
|
|
||||||
|
|
||||||
|
class DataAcquisitionServicer():
|
||||||
|
def __init__(
|
||||||
|
self,
|
||||||
|
):
|
||||||
|
self._logger = getLogger(__name__)
|
||||||
|
self.play_sound = True
|
||||||
|
|
||||||
|
self.thread: threading.Thread = None
|
||||||
|
self.events: Dict = {}
|
||||||
|
# self.events["stop_episode"] = False
|
||||||
|
# self.events["stop_session"] = False
|
||||||
|
# self.events["skip_break_time"] = False
|
||||||
|
|
||||||
|
self.events["exit_early"] = False
|
||||||
|
# self.events["rerecord_episode"] = False
|
||||||
|
# self.events["stop_recording"] = False
|
||||||
|
|
||||||
|
self.setup_over = False
|
||||||
|
|
||||||
|
self.dataset_list_path = Path("datasets.json")
|
||||||
|
self.robot: Reachy2Robot = None
|
||||||
|
self.task: str = None
|
||||||
|
self.dataset: LeRobotDataset = None
|
||||||
|
|
||||||
|
self.fps: int
|
||||||
|
self.episode_duration: int
|
||||||
|
self.break_time_duration: int
|
||||||
|
|
||||||
|
self.episode_recording_in_progress: bool = False
|
||||||
|
# self.break_time_in_progress: bool = False
|
||||||
|
self.episode_saved: bool = False
|
||||||
|
self.episode_recorded_in_session: bool = False
|
||||||
|
|
||||||
|
self.run_compute_stats: bool = True
|
||||||
|
|
||||||
|
def register_to_server(self, server):
|
||||||
|
add_DataAcquisitionServiceServicer_to_server(self, server)
|
||||||
|
|
||||||
|
def GetDatasetList(self, request: Empty, context: grpc.ServicerContext) -> DatasetList:
|
||||||
|
dataset_list = self.read_datasets_from_json(self.dataset_list_path)
|
||||||
|
return dataset_list
|
||||||
|
|
||||||
|
def RemoveDataset(self, request: Dataset, context: grpc.ServicerContext) -> ActionAck:
|
||||||
|
self.remove_dataset_by_name(request.dataset_name, self.dataset_list_path)
|
||||||
|
return ActionAck(success_ack=True)
|
||||||
|
|
||||||
|
def AddDataset(self, request: Dataset, context: grpc.ServicerContext) -> ActionAck:
|
||||||
|
# Create a new dataset
|
||||||
|
dataset = self.create_dataset(request.dataset_name, DatasetPushState.PUSHED, 0)
|
||||||
|
# Add it to the JSON file
|
||||||
|
self.add_dataset_to_json_file(dataset, self.dataset_list_path)
|
||||||
|
return ActionAck(success_ack=True)
|
||||||
|
|
||||||
|
def UpdateDataset(self, request: Dataset, context: grpc.ServicerContext) -> ActionAck:
|
||||||
|
self.update_dataset(request, self.dataset_list_path)
|
||||||
|
return ActionAck(success_ack=True)
|
||||||
|
|
||||||
|
def ClearAllDatasets(self, request: Empty, context: grpc.ServicerContext) -> ActionAck:
|
||||||
|
# Clear all datasets from the JSON file
|
||||||
|
self.clear_all_datasets(self.dataset_list_path)
|
||||||
|
return ActionAck(success_ack=True)
|
||||||
|
|
||||||
|
def RemoveSession(self, request: SessionParams, context: grpc.ServicerContext) -> ActionAck:
|
||||||
|
self.delete_folder(request.session_name)
|
||||||
|
return ActionAck(success_ack=True)
|
||||||
|
|
||||||
|
def ClearAllSessions(self, request: Empty, context: grpc.ServicerContext) -> ActionAck:
|
||||||
|
pass
|
||||||
|
|
||||||
|
# Session and Episode Management with LeRobot
|
||||||
|
|
||||||
|
def StartSession(self, request: SessionParams, context: grpc.ServicerContext) -> ActionAck:
|
||||||
|
self._logger.error(f"Starting session with params: {request}")
|
||||||
|
try:
|
||||||
|
self.setup_recording_session(request, context)
|
||||||
|
self.setup_over = True
|
||||||
|
ack = ActionAck(success_ack=True)
|
||||||
|
except Exception as e:
|
||||||
|
self._logger.error(f"Error starting session: {e}")
|
||||||
|
ack = ActionAck(success_ack=False)
|
||||||
|
return ack
|
||||||
|
|
||||||
|
def StopSession(self, request: Empty, context: grpc.ServicerContext) -> ActionAck:
|
||||||
|
self._logger.error("Stopping session")
|
||||||
|
try:
|
||||||
|
# self.events["stop_episode"] = True
|
||||||
|
# self.events["stop_session"] = True
|
||||||
|
if self.thread and self.thread.is_alive():
|
||||||
|
self.thread.join() # Wait for the thread to finish
|
||||||
|
log_say("Stop recording", play_sounds=self.play_sound, blocking=True)
|
||||||
|
self.robot.disconnect()
|
||||||
|
ack = ActionAck(success_ack=True)
|
||||||
|
self._logger.error("Session stopped")
|
||||||
|
except Exception as e:
|
||||||
|
self._logger.error(f"Error stopping session: {e}")
|
||||||
|
ack = ActionAck(success_ack=False)
|
||||||
|
return ack
|
||||||
|
|
||||||
|
def StartEpisode(self, request: Empty, context: grpc.ServicerContext) -> ActionAck:
|
||||||
|
self._logger.error("Starting episode")
|
||||||
|
try:
|
||||||
|
if self.episode_recorded_in_session and not self.episode_saved:
|
||||||
|
self._logger.error("Episode not saved. Clearing episode buffer.")
|
||||||
|
self.dataset.clear_episode_buffer()
|
||||||
|
if not self.setup_over:
|
||||||
|
raise RuntimeError("Setup not completed. Please call StartSession first.")
|
||||||
|
if self.episode_recording_in_progress:
|
||||||
|
raise RuntimeError("Episode recording already in progress. Please stop it before starting a new one.")
|
||||||
|
self.episode_saved = False
|
||||||
|
# if self.break_time_in_progress:
|
||||||
|
# self.events["skip_break_time"] = True
|
||||||
|
self.thread = threading.Thread(target=self.record_episode)
|
||||||
|
self.thread.daemon = True
|
||||||
|
self.thread.start()
|
||||||
|
self.episode_recorded_in_session = True
|
||||||
|
ack = ActionAck(success_ack=True)
|
||||||
|
self._logger.error("Episode started")
|
||||||
|
except Exception as e:
|
||||||
|
self._logger.error(f"Error starting episode: {e}")
|
||||||
|
ack = ActionAck(success_ack=False)
|
||||||
|
return ack
|
||||||
|
|
||||||
|
def StopEpisode(self, request: Empty, context: grpc.ServicerContext) -> ActionAck:
|
||||||
|
self._logger.error("Stopping episode")
|
||||||
|
try:
|
||||||
|
self.events["exit_early"] = True
|
||||||
|
if self.thread and self.thread.is_alive():
|
||||||
|
self.thread.join() # Wait for the thread to finish
|
||||||
|
self.events["exit_early"] = False
|
||||||
|
ack = ActionAck(success_ack=True)
|
||||||
|
self._logger.error("Episode stopped")
|
||||||
|
except Exception as e:
|
||||||
|
self._logger.error(f"Error stopping session: {e}")
|
||||||
|
ack = ActionAck(success_ack=False)
|
||||||
|
return ack
|
||||||
|
|
||||||
|
def SaveEpisode(self, request: EpisodeRating, context: grpc.ServicerContext) -> ActionAck:
|
||||||
|
self._logger.error("Saving episode")
|
||||||
|
try:
|
||||||
|
if self.episode_recording_in_progress:
|
||||||
|
raise RuntimeError("Episode recording in progress. Please stop it before saving.")
|
||||||
|
self.dataset.save_episode()
|
||||||
|
self.episode_saved = True
|
||||||
|
ack = ActionAck(success_ack=True)
|
||||||
|
self._logger.error("Episode saved")
|
||||||
|
except Exception as e:
|
||||||
|
self._logger.error(f"Error saving episode: {e}")
|
||||||
|
ack = ActionAck(success_ack=False)
|
||||||
|
return ack
|
||||||
|
|
||||||
|
def UploadSession(self, request: Empty, context: grpc.ServicerContext) -> ActionAck:
|
||||||
|
try:
|
||||||
|
if self.episode_recording_in_progress:
|
||||||
|
raise RuntimeError("Episode recording in progress. Please stop it before uploading.")
|
||||||
|
self.dataset.push_to_hub()
|
||||||
|
ack = ActionAck(success_ack=True)
|
||||||
|
self._logger.error("Session uploaded")
|
||||||
|
except Exception as e:
|
||||||
|
self._logger.error(f"Error uploading session: {e}")
|
||||||
|
ack = ActionAck(success_ack=False)
|
||||||
|
return ack
|
||||||
|
|
||||||
|
def setup_recording_session(
|
||||||
|
self,
|
||||||
|
request: SessionParams,
|
||||||
|
context: grpc.ServicerContext,
|
||||||
|
):
|
||||||
|
# Create the robot and teleoperator configurations
|
||||||
|
self.robot_config = Reachy2RobotConfig(
|
||||||
|
ip_address=request.robot.ip_address,
|
||||||
|
id=request.robot.robot_id,
|
||||||
|
)
|
||||||
|
self.teleop_config = Reachy2FakeTeleoperatorConfig(
|
||||||
|
ip_address=request.robot.ip_address,
|
||||||
|
)
|
||||||
|
|
||||||
|
# Initialize the robot and teleoperator
|
||||||
|
self.robot = Reachy2Robot(self.robot_config)
|
||||||
|
self.teleop = Reachy2FakeTeleoperator(self.teleop_config)
|
||||||
|
|
||||||
|
self.fps = request.fps if request.HasField("fps") else 30
|
||||||
|
self.task = request.task_description
|
||||||
|
self.episode_duration = request.episode_duration + 1 # TO FIX
|
||||||
|
self.break_time_duration = request.break_time_duration
|
||||||
|
|
||||||
|
# Configure the dataset features
|
||||||
|
action_features = hw_to_dataset_features(self.robot.action_features, "action")
|
||||||
|
obs_features = hw_to_dataset_features(self.robot.observation_features, "observation")
|
||||||
|
dataset_features = {**action_features, **obs_features}
|
||||||
|
|
||||||
|
# Create the dataset
|
||||||
|
self.dataset = LeRobotDataset.create(
|
||||||
|
repo_id=request.dataset_name,
|
||||||
|
fps=self.fps,
|
||||||
|
features=dataset_features,
|
||||||
|
robot_type=self.robot.name,
|
||||||
|
use_videos=True,
|
||||||
|
image_writer_threads=4,
|
||||||
|
)
|
||||||
|
|
||||||
|
current_dataset = self.create_dataset(request.dataset_name, DatasetPushState.LOCAL_ONLY, 0)
|
||||||
|
self.add_dataset_to_json_file(current_dataset, self.dataset_list_path)
|
||||||
|
|
||||||
|
# Connect the robot and teleoperator
|
||||||
|
if not self.robot.is_connected:
|
||||||
|
self.robot.connect()
|
||||||
|
if not self.teleop.is_connected:
|
||||||
|
self.teleop.connect()
|
||||||
|
|
||||||
|
def record_episode(
|
||||||
|
self,
|
||||||
|
display_cameras: bool = False,
|
||||||
|
play_sounds: bool = True,
|
||||||
|
):
|
||||||
|
self.episode_recording_in_progress = True
|
||||||
|
record_loop(
|
||||||
|
robot=self.robot,
|
||||||
|
events=self.events,
|
||||||
|
fps=self.fps,
|
||||||
|
teleop=self.teleop,
|
||||||
|
dataset=self.dataset,
|
||||||
|
control_time_s=self.episode_duration,
|
||||||
|
single_task=self.task,
|
||||||
|
display_data=True,
|
||||||
|
)
|
||||||
|
self.episode_recording_in_progress = False
|
||||||
|
|
||||||
|
def create_dataset(self, dataset_name: str, pushed: DatasetPushState, nb_episodes: int) -> Dataset:
|
||||||
|
dataset = Dataset()
|
||||||
|
dataset.dataset_name = dataset_name
|
||||||
|
dataset.pushed = pushed
|
||||||
|
dataset.nb_episodes = nb_episodes
|
||||||
|
return dataset
|
||||||
|
|
||||||
|
def dataset_to_dict(self, dataset: Dataset) -> dict:
|
||||||
|
"""Convert Dataset proto to a dict manually."""
|
||||||
|
return {
|
||||||
|
"dataset_name": dataset.dataset_name,
|
||||||
|
"pushed": DatasetPushState.Name(dataset.pushed),
|
||||||
|
"nb_episodes": dataset.nb_episodes
|
||||||
|
}
|
||||||
|
|
||||||
|
def add_dataset_to_json_file(self, dataset: Dataset, json_filename: str):
|
||||||
|
if os.path.exists(json_filename):
|
||||||
|
with open(json_filename, "r") as f:
|
||||||
|
try:
|
||||||
|
data = json.load(f)
|
||||||
|
if not isinstance(data, list):
|
||||||
|
print("Warning: JSON root is not a list, resetting.")
|
||||||
|
data = []
|
||||||
|
except json.JSONDecodeError:
|
||||||
|
print("Warning: Invalid JSON file, resetting.")
|
||||||
|
data = []
|
||||||
|
else:
|
||||||
|
data = []
|
||||||
|
|
||||||
|
# Add the new dataset
|
||||||
|
data.append(self.dataset_to_dict(dataset))
|
||||||
|
|
||||||
|
# Save back to file
|
||||||
|
with open(json_filename, "w") as f:
|
||||||
|
json.dump(data, f, indent=2)
|
||||||
|
|
||||||
|
def update_dataset(self, dataset: Dataset, json_filename: str):
|
||||||
|
"""Update the 'pushed' field of a dataset in the JSON file."""
|
||||||
|
if not os.path.exists(json_filename):
|
||||||
|
print(f"File '{json_filename}' does not exist.")
|
||||||
|
return
|
||||||
|
|
||||||
|
with open(json_filename, "r") as f:
|
||||||
|
try:
|
||||||
|
data = json.load(f)
|
||||||
|
if not isinstance(data, list):
|
||||||
|
print("Warning: JSON root is not a list. No update performed.")
|
||||||
|
return
|
||||||
|
except json.JSONDecodeError:
|
||||||
|
print("Warning: JSON file is invalid. No update performed.")
|
||||||
|
return
|
||||||
|
|
||||||
|
dataset_found = False
|
||||||
|
for d in data:
|
||||||
|
if d.get("dataset_name") == dataset.dataset_name:
|
||||||
|
d["pushed"] = DatasetPushState.Name(dataset.pushed) # Convert enum to string
|
||||||
|
dataset_found = True
|
||||||
|
break
|
||||||
|
|
||||||
|
if not dataset_found:
|
||||||
|
print(f"No dataset with name '{dataset.dataset_name}' found.")
|
||||||
|
return
|
||||||
|
|
||||||
|
with open(json_filename, "w") as f:
|
||||||
|
json.dump(data, f, indent=2)
|
||||||
|
|
||||||
|
print(f"Dataset '{dataset.dataset_name}' updated successfully.")
|
||||||
|
|
||||||
|
def dict_to_dataset(self, d: Dict) -> Dataset:
|
||||||
|
"""Convert a dictionary to a Dataset proto."""
|
||||||
|
dataset = Dataset()
|
||||||
|
dataset.dataset_name = d.get("dataset_name", "")
|
||||||
|
|
||||||
|
pushed_str = d.get("pushed", "UNKNOWN")
|
||||||
|
if isinstance(pushed_str, str):
|
||||||
|
# Convert string to enum value
|
||||||
|
dataset.pushed = DatasetPushState.Value(pushed_str)
|
||||||
|
elif isinstance(pushed_str, int):
|
||||||
|
# Already an integer value
|
||||||
|
dataset.pushed = pushed_str
|
||||||
|
else:
|
||||||
|
dataset.pushed = DatasetPushState.UNKNOWN
|
||||||
|
|
||||||
|
dataset.nb_episodes = d.get("nb_episodes", 0)
|
||||||
|
return dataset
|
||||||
|
|
||||||
|
def read_datasets_from_json(self, json_filename: str) -> DatasetList:
|
||||||
|
dataset_list = DatasetList()
|
||||||
|
|
||||||
|
if not os.path.exists(json_filename):
|
||||||
|
print(f"File {json_filename} does not exist.")
|
||||||
|
return dataset_list # Empty list
|
||||||
|
|
||||||
|
with open(json_filename, "r") as f:
|
||||||
|
try:
|
||||||
|
data = json.load(f)
|
||||||
|
if not isinstance(data, list):
|
||||||
|
print("Warning: JSON root is not a list. Ignored.")
|
||||||
|
return dataset_list
|
||||||
|
except json.JSONDecodeError:
|
||||||
|
print("Warning: Invalid JSON file. Ignored.")
|
||||||
|
return dataset_list
|
||||||
|
|
||||||
|
for d in data:
|
||||||
|
dataset = self.dict_to_dataset(d)
|
||||||
|
dataset_list.datasets.append(dataset)
|
||||||
|
|
||||||
|
return dataset_list
|
||||||
|
|
||||||
|
def remove_dataset_by_name(self, dataset_name: str, json_filename: str):
|
||||||
|
"""Remove a dataset with a given name from the JSON file."""
|
||||||
|
if not os.path.exists(json_filename):
|
||||||
|
print(f"File {json_filename} does not exist.")
|
||||||
|
return
|
||||||
|
|
||||||
|
with open(json_filename, "r") as f:
|
||||||
|
try:
|
||||||
|
data = json.load(f)
|
||||||
|
if not isinstance(data, list):
|
||||||
|
print("Warning: JSON root is not a list. No action taken.")
|
||||||
|
return
|
||||||
|
except json.JSONDecodeError:
|
||||||
|
print("Warning: Invalid JSON file. No action taken.")
|
||||||
|
return
|
||||||
|
|
||||||
|
original_length = len(data)
|
||||||
|
# Filter out datasets with the matching name
|
||||||
|
data = [d for d in data if d.get("dataset_name") != dataset_name]
|
||||||
|
|
||||||
|
if len(data) == original_length:
|
||||||
|
print(f"No dataset found with name '{dataset_name}'. No action taken.")
|
||||||
|
return
|
||||||
|
|
||||||
|
# Save the modified list back to the file
|
||||||
|
with open(json_filename, "w") as f:
|
||||||
|
json.dump(data, f, indent=2)
|
||||||
|
print(f"Dataset '{dataset_name}' removed successfully.")
|
||||||
|
|
||||||
|
def clear_all_datasets(self, json_filename: str):
|
||||||
|
"""Clear all datasets from the JSON file."""
|
||||||
|
if not os.path.exists(json_filename):
|
||||||
|
print(f"File {json_filename} does not exist. Nothing to clear.")
|
||||||
|
return
|
||||||
|
|
||||||
|
# Overwrite the file with an empty list
|
||||||
|
with open(json_filename, "w") as f:
|
||||||
|
json.dump([], f, indent=2)
|
||||||
|
print(f"All datasets have been cleared from '{json_filename}'.")
|
||||||
|
|
||||||
|
def delete_folder(self, folder_path: str):
|
||||||
|
"""Delete an entire folder and its content."""
|
||||||
|
if not os.path.exists(folder_path):
|
||||||
|
print(f"Folder '{folder_path}' does not exist.")
|
||||||
|
return
|
||||||
|
|
||||||
|
if not os.path.isdir(folder_path):
|
||||||
|
print(f"Path '{folder_path}' is not a folder.")
|
||||||
|
return
|
||||||
|
|
||||||
|
try:
|
||||||
|
shutil.rmtree(folder_path)
|
||||||
|
print(f"Folder '{folder_path}' and all its contents have been deleted.")
|
||||||
|
except Exception as e:
|
||||||
|
print(f"Error while deleting folder '{folder_path}': {e}")
|
||||||
|
|
||||||
|
|
||||||
|
def main():
|
||||||
|
mp.set_start_method('spawn', force=True)
|
||||||
|
server = grpc.server(futures.ThreadPoolExecutor(max_workers=10))
|
||||||
|
|
||||||
|
data_acquisition_servicer = DataAcquisitionServicer()
|
||||||
|
data_acquisition_servicer.register_to_server(server)
|
||||||
|
server.add_insecure_port('[::]:50062')
|
||||||
|
server.start()
|
||||||
|
server.wait_for_termination()
|
||||||
|
|
||||||
|
|
||||||
|
if __name__ == '__main__':
|
||||||
|
main()
|
||||||
@@ -0,0 +1,93 @@
|
|||||||
|
from lerobot.cameras.opencv.configuration_opencv import OpenCVCameraConfig
|
||||||
|
from lerobot.datasets.lerobot_dataset import LeRobotDataset
|
||||||
|
from lerobot.datasets.utils import hw_to_dataset_features
|
||||||
|
from lerobot.policies.act.modeling_act import ACTPolicy
|
||||||
|
from lerobot.robots.reachy2 import Reachy2Robot, Reachy2RobotConfig
|
||||||
|
from lerobot.utils.control_utils import init_keyboard_listener
|
||||||
|
from lerobot.utils.utils import log_say
|
||||||
|
# from lerobot.utils.visualization_utils import _init_rerun
|
||||||
|
from lerobot.record import record_loop
|
||||||
|
from reachy2_sdk import ReachySDK
|
||||||
|
import numpy as np
|
||||||
|
import time
|
||||||
|
|
||||||
|
NUM_EPISODES = 5
|
||||||
|
FPS = 20
|
||||||
|
EPISODE_TIME_SEC = 10
|
||||||
|
TASK_DESCRIPTION = "Grab a cube in Mujoco simulation"
|
||||||
|
|
||||||
|
|
||||||
|
# Create the robot configuration
|
||||||
|
robot_config = Reachy2RobotConfig(
|
||||||
|
ip_address="localhost",
|
||||||
|
# ip_address="172.18.131.66",
|
||||||
|
id="test_reachy",
|
||||||
|
)
|
||||||
|
|
||||||
|
# Initialize the robot
|
||||||
|
robot = Reachy2Robot(robot_config)
|
||||||
|
|
||||||
|
reachy = ReachySDK("localhost")
|
||||||
|
reachy.turn_on()
|
||||||
|
reachy.mobile_base.goto(-0.2, -0.3, 0, wait=True)
|
||||||
|
time.sleep(2)
|
||||||
|
reachy.r_arm.goto_posture("elbow_90", wait=True)
|
||||||
|
reachy.r_arm.gripper.open()
|
||||||
|
reachy.mobile_base.goto(0, -0.3, 0)
|
||||||
|
|
||||||
|
# Initialize the policy
|
||||||
|
policy = ACTPolicy.from_pretrained("pepijn223/grab_cube_simulation_2")
|
||||||
|
|
||||||
|
# Configure the dataset features
|
||||||
|
action_features = hw_to_dataset_features(robot.action_features, "action")
|
||||||
|
obs_features = hw_to_dataset_features(robot.observation_features, "observation")
|
||||||
|
dataset_features = {**action_features, **obs_features}
|
||||||
|
|
||||||
|
# Create the dataset
|
||||||
|
dataset = LeRobotDataset.create(
|
||||||
|
repo_id="glannuzel/eval_grab_cube_simulation_2",
|
||||||
|
fps=FPS,
|
||||||
|
features=dataset_features,
|
||||||
|
robot_type=robot.name,
|
||||||
|
use_videos=True,
|
||||||
|
image_writer_threads=4,
|
||||||
|
)
|
||||||
|
|
||||||
|
# Initialize the keyboard listener and rerun visualization
|
||||||
|
_, events = init_keyboard_listener()
|
||||||
|
# _init_rerun(session_name="recording")
|
||||||
|
|
||||||
|
# Connect the robot
|
||||||
|
robot.connect()
|
||||||
|
|
||||||
|
M = reachy.r_arm.get_default_posture_matrix("elbow_90")
|
||||||
|
np.round(M, 3)
|
||||||
|
first_pose = M.copy()
|
||||||
|
first_pose[0, 3] += 0.05
|
||||||
|
first_pose[1, 3] += 0.1
|
||||||
|
|
||||||
|
for episode_idx in range(NUM_EPISODES):
|
||||||
|
log_say(f"Running inference, recording eval episode {episode_idx + 1} of {NUM_EPISODES}")
|
||||||
|
|
||||||
|
# Run the policy inference loop
|
||||||
|
record_loop(
|
||||||
|
robot=robot,
|
||||||
|
events=events,
|
||||||
|
fps=FPS,
|
||||||
|
policy=policy,
|
||||||
|
dataset=dataset,
|
||||||
|
control_time_s=EPISODE_TIME_SEC,
|
||||||
|
single_task=TASK_DESCRIPTION,
|
||||||
|
display_data=False,
|
||||||
|
)
|
||||||
|
|
||||||
|
reachy.r_arm.gripper.goto(100, percentage=True, wait=True)
|
||||||
|
reachy.head.goto_posture()
|
||||||
|
reachy.r_arm.goto(first_pose)
|
||||||
|
reachy.r_arm.goto(M, wait=True)
|
||||||
|
|
||||||
|
dataset.save_episode()
|
||||||
|
|
||||||
|
# Clean up
|
||||||
|
robot.disconnect()
|
||||||
|
# dataset.push_to_hub()
|
||||||
@@ -7,17 +7,25 @@ from lerobot.utils.utils import log_say
|
|||||||
from lerobot.utils.visualization_utils import _init_rerun
|
from lerobot.utils.visualization_utils import _init_rerun
|
||||||
from lerobot.record import record_loop
|
from lerobot.record import record_loop
|
||||||
|
|
||||||
NUM_EPISODES = 2
|
import time
|
||||||
FPS = 30
|
|
||||||
EPISODE_TIME_SEC = 5
|
NUM_EPISODES = 35
|
||||||
|
FPS = 20
|
||||||
|
EPISODE_TIME_SEC = 10
|
||||||
RESET_TIME_SEC = 5
|
RESET_TIME_SEC = 5
|
||||||
TASK_DESCRIPTION = "My task description"
|
TASK_DESCRIPTION = "Grab a cube in Mujoco simulation"
|
||||||
|
|
||||||
# Create the robot and teleoperator configurations
|
# Create the robot and teleoperator configurations
|
||||||
robot_config = Reachy2RobotConfig(
|
robot_config = Reachy2RobotConfig(
|
||||||
id="test_reachy"
|
# ip_address="localhost",
|
||||||
|
# ip_address="172.18.131.66",
|
||||||
|
ip_address="192.168.0.200",
|
||||||
|
id="test_reachy",
|
||||||
|
)
|
||||||
|
teleop_config = Reachy2FakeTeleoperatorConfig(
|
||||||
|
# ip_address="172.18.131.66",
|
||||||
|
ip_address="192.168.0.200",
|
||||||
)
|
)
|
||||||
teleop_config = Reachy2FakeTeleoperatorConfig()
|
|
||||||
|
|
||||||
# Initialize the robot and teleoperator
|
# Initialize the robot and teleoperator
|
||||||
robot = Reachy2Robot(robot_config)
|
robot = Reachy2Robot(robot_config)
|
||||||
@@ -30,7 +38,7 @@ dataset_features = {**action_features, **obs_features}
|
|||||||
|
|
||||||
# Create the dataset
|
# Create the dataset
|
||||||
dataset = LeRobotDataset.create(
|
dataset = LeRobotDataset.create(
|
||||||
repo_id="test/repo_test",
|
repo_id="glannuzel/grab_cube",
|
||||||
fps=FPS,
|
fps=FPS,
|
||||||
features=dataset_features,
|
features=dataset_features,
|
||||||
robot_type=robot.name,
|
robot_type=robot.name,
|
||||||
@@ -40,7 +48,7 @@ dataset = LeRobotDataset.create(
|
|||||||
|
|
||||||
# Initialize the keyboard listener and rerun visualization
|
# Initialize the keyboard listener and rerun visualization
|
||||||
_, events = init_keyboard_listener()
|
_, events = init_keyboard_listener()
|
||||||
_init_rerun(session_name="recording")
|
# _init_rerun(session_name="recording")
|
||||||
|
|
||||||
# Connect the robot and teleoperator
|
# Connect the robot and teleoperator
|
||||||
robot.connect()
|
robot.connect()
|
||||||
@@ -48,8 +56,11 @@ teleop.connect()
|
|||||||
|
|
||||||
episode_idx = 0
|
episode_idx = 0
|
||||||
while episode_idx < NUM_EPISODES and not events["stop_recording"]:
|
while episode_idx < NUM_EPISODES and not events["stop_recording"]:
|
||||||
|
start_time = time.time()
|
||||||
log_say(f"Recording episode {episode_idx + 1} of {NUM_EPISODES}")
|
log_say(f"Recording episode {episode_idx + 1} of {NUM_EPISODES}")
|
||||||
|
|
||||||
|
print("########### RECORDING ###########")
|
||||||
|
|
||||||
record_loop(
|
record_loop(
|
||||||
robot=robot,
|
robot=robot,
|
||||||
events=events,
|
events=events,
|
||||||
@@ -58,12 +69,14 @@ while episode_idx < NUM_EPISODES and not events["stop_recording"]:
|
|||||||
dataset=dataset,
|
dataset=dataset,
|
||||||
control_time_s=EPISODE_TIME_SEC,
|
control_time_s=EPISODE_TIME_SEC,
|
||||||
single_task=TASK_DESCRIPTION,
|
single_task=TASK_DESCRIPTION,
|
||||||
display_data=True,
|
display_data=False,
|
||||||
)
|
)
|
||||||
|
|
||||||
# Reset the environment if not stopping or re-recording
|
# Reset the environment if not stopping or re-recording
|
||||||
if not events["stop_recording"] and (episode_idx < NUM_EPISODES - 1 or events["rerecord_episode"]):
|
if not events["stop_recording"] and (episode_idx < NUM_EPISODES - 1 or events["rerecord_episode"]):
|
||||||
log_say("Reset the environment")
|
log_say("Reset the environment")
|
||||||
|
|
||||||
|
print("------------- RESETTING -------------")
|
||||||
record_loop(
|
record_loop(
|
||||||
robot=robot,
|
robot=robot,
|
||||||
events=events,
|
events=events,
|
||||||
@@ -71,7 +84,7 @@ while episode_idx < NUM_EPISODES and not events["stop_recording"]:
|
|||||||
teleop=teleop,
|
teleop=teleop,
|
||||||
control_time_s=RESET_TIME_SEC,
|
control_time_s=RESET_TIME_SEC,
|
||||||
single_task=TASK_DESCRIPTION,
|
single_task=TASK_DESCRIPTION,
|
||||||
display_data=True,
|
display_data=False,
|
||||||
)
|
)
|
||||||
|
|
||||||
if events["rerecord_episode"]:
|
if events["rerecord_episode"]:
|
||||||
@@ -81,10 +94,12 @@ while episode_idx < NUM_EPISODES and not events["stop_recording"]:
|
|||||||
dataset.clear_episode_buffer()
|
dataset.clear_episode_buffer()
|
||||||
continue
|
continue
|
||||||
|
|
||||||
|
# episode_idx = NUM_EPISODES
|
||||||
dataset.save_episode()
|
dataset.save_episode()
|
||||||
episode_idx += 1
|
episode_idx += 1
|
||||||
|
print(time.time()-start_time)
|
||||||
|
|
||||||
# Clean up
|
# Clean up
|
||||||
log_say("Stop recording")
|
log_say("Stop recording")
|
||||||
robot.disconnect()
|
robot.disconnect()
|
||||||
# dataset.push_to_hub()
|
dataset.push_to_hub()
|
||||||
|
|||||||
Reference in New Issue
Block a user