mirror of
https://github.com/huggingface/lerobot.git
synced 2026-05-16 09:09:48 +00:00
add eval dataset
This commit is contained in:
@@ -37,14 +37,19 @@ import torch
|
||||
from lerobot.cameras.opencv.configuration_opencv import OpenCVCameraConfig
|
||||
from lerobot.configs.policies import PreTrainedConfig
|
||||
from lerobot.configs.train import TrainPipelineConfig
|
||||
from lerobot.datasets.lerobot_dataset import LeRobotDataset
|
||||
from lerobot.datasets.pipeline_features import aggregate_pipeline_dataset_features, create_initial_features
|
||||
from lerobot.datasets.utils import build_dataset_frame, combine_feature_dicts
|
||||
from lerobot.model.kinematics import RobotKinematics
|
||||
from lerobot.policies.factory import make_policy, make_pre_post_processors
|
||||
from lerobot.processor import RobotAction, RobotObservation, RobotProcessorPipeline, make_default_processors
|
||||
from lerobot.utils.constants import ACTION, OBS_STR
|
||||
from lerobot.utils.relative_actions import (
|
||||
convert_state_to_relative,
|
||||
convert_from_relative_actions,
|
||||
PerTimestepNormalizer,
|
||||
)
|
||||
from lerobot.model.kinematics import RobotKinematics
|
||||
from lerobot.policies.factory import make_policy, make_pre_post_processors
|
||||
from lerobot.processor import RobotAction, RobotObservation, RobotProcessorPipeline
|
||||
from lerobot.utils.visualization_utils import init_rerun, log_rerun_data
|
||||
from lerobot.processor.converters import (
|
||||
robot_action_observation_to_transition,
|
||||
robot_action_to_transition,
|
||||
@@ -65,6 +70,7 @@ from lerobot.utils.utils import log_say
|
||||
|
||||
# Configuration
|
||||
HF_MODEL_ID = "lerobot-data-collection/pi0_ee" # TODO: Replace with your EE-trained model
|
||||
HF_EVAL_DATASET_ID = "your-org/your-ee-eval-dataset" # TODO: Replace with your eval dataset
|
||||
TASK_DESCRIPTION = "ee-policy-task" # TODO: Replace with your task
|
||||
|
||||
NUM_EPISODES = 1
|
||||
@@ -261,12 +267,15 @@ def run_ee_inference_loop(
|
||||
postprocessor,
|
||||
joints_to_ee,
|
||||
ee_to_joints,
|
||||
dataset: LeRobotDataset,
|
||||
fps: int,
|
||||
duration_s: float,
|
||||
events: dict,
|
||||
task: str,
|
||||
use_relative_actions: bool = False,
|
||||
use_relative_state: bool = False,
|
||||
relative_normalizer: PerTimestepNormalizer | None = None,
|
||||
display_data: bool = True,
|
||||
):
|
||||
"""Run inference loop with EE conversion and optional UMI-style relative actions."""
|
||||
dt = 1.0 / fps
|
||||
@@ -354,6 +363,17 @@ def run_ee_inference_loop(
|
||||
# 8. Send joint commands to robot
|
||||
robot.send_action(joint_action)
|
||||
|
||||
# 9. Save frame to dataset
|
||||
if dataset is not None:
|
||||
observation_frame = build_dataset_frame(dataset.features, robot_obs, prefix=OBS_STR)
|
||||
action_frame = build_dataset_frame(dataset.features, joint_action, prefix=ACTION)
|
||||
frame = {**observation_frame, **action_frame, "task": task}
|
||||
dataset.add_frame(frame)
|
||||
|
||||
# 10. Visualization
|
||||
if display_data:
|
||||
log_rerun_data(observation=robot_obs, action=joint_action)
|
||||
|
||||
# Progress logging
|
||||
step += 1
|
||||
if step % (fps * 5) == 0:
|
||||
@@ -375,6 +395,7 @@ def main():
|
||||
print("OpenArms End-Effector Policy Evaluation")
|
||||
print("=" * 70)
|
||||
print(f"\nModel: {HF_MODEL_ID}")
|
||||
print(f"Dataset: {HF_EVAL_DATASET_ID}")
|
||||
print(f"Task: {TASK_DESCRIPTION}")
|
||||
print(f"Episodes: {NUM_EPISODES}")
|
||||
print(f"Episode Duration: {EPISODE_TIME_SEC}s")
|
||||
@@ -387,14 +408,14 @@ def main():
|
||||
urdf_path = str(urdf_path)
|
||||
|
||||
# Build kinematics pipelines
|
||||
print("\n[1/4] Building kinematics pipelines...")
|
||||
print("\n[1/5] Building kinematics pipelines...")
|
||||
joints_to_ee, ee_to_joints = build_kinematics_pipelines(
|
||||
urdf_path, DEFAULT_LEFT_EE_FRAME, DEFAULT_RIGHT_EE_FRAME
|
||||
)
|
||||
print(" FK and IK pipelines ready")
|
||||
|
||||
# Initialize robot
|
||||
print("\n[2/4] Connecting to robot...")
|
||||
print("\n[2/5] Connecting to robot...")
|
||||
follower_config = OpenArmsFollowerConfig(
|
||||
port_left=FOLLOWER_LEFT_PORT,
|
||||
port_right=FOLLOWER_RIGHT_PORT,
|
||||
@@ -430,17 +451,53 @@ def main():
|
||||
leader.bus_left.enable_torque()
|
||||
print(" Leader connected with gravity compensation")
|
||||
|
||||
# Create dataset for saving evaluation data
|
||||
print(f"\n[3/5] Creating evaluation dataset...")
|
||||
teleop_action_processor, robot_action_processor, robot_observation_processor = make_default_processors()
|
||||
action_features_hw = {k: v for k, v in follower.action_features.items() if k.endswith(".pos")}
|
||||
|
||||
dataset_features = combine_feature_dicts(
|
||||
aggregate_pipeline_dataset_features(
|
||||
pipeline=teleop_action_processor,
|
||||
initial_features=create_initial_features(action=action_features_hw),
|
||||
use_videos=True,
|
||||
),
|
||||
aggregate_pipeline_dataset_features(
|
||||
pipeline=robot_observation_processor,
|
||||
initial_features=create_initial_features(observation=follower.observation_features),
|
||||
use_videos=True,
|
||||
),
|
||||
)
|
||||
|
||||
dataset_path = Path.home() / ".cache" / "huggingface" / "lerobot" / HF_EVAL_DATASET_ID
|
||||
if dataset_path.exists():
|
||||
print(f" Dataset exists at: {dataset_path}")
|
||||
if input(" Continue and overwrite? (y/n): ").strip().lower() != 'y':
|
||||
follower.disconnect()
|
||||
return
|
||||
|
||||
dataset = LeRobotDataset.create(
|
||||
repo_id=HF_EVAL_DATASET_ID,
|
||||
fps=FPS,
|
||||
features=dataset_features,
|
||||
robot_type=follower.name,
|
||||
use_videos=True,
|
||||
image_writer_processes=0,
|
||||
image_writer_threads=12,
|
||||
)
|
||||
print(" Dataset created")
|
||||
|
||||
# Load policy
|
||||
print(f"\n[3/4] Loading policy from {HF_MODEL_ID}...")
|
||||
print(f"\n[4/5] Loading policy from {HF_MODEL_ID}...")
|
||||
policy_config = PreTrainedConfig.from_pretrained(HF_MODEL_ID)
|
||||
policy_config.pretrained_path = HF_MODEL_ID
|
||||
|
||||
# Create policy without dataset meta (use config defaults)
|
||||
policy = make_policy(policy_config, ds_meta=None)
|
||||
policy = make_policy(policy_config, ds_meta=dataset.meta)
|
||||
|
||||
preprocessor, postprocessor = make_pre_post_processors(
|
||||
policy_cfg=policy.config,
|
||||
pretrained_path=HF_MODEL_ID,
|
||||
dataset_stats=dataset.meta.stats,
|
||||
preprocessor_overrides={
|
||||
"device_processor": {"device": str(policy.config.device)}
|
||||
},
|
||||
@@ -455,16 +512,17 @@ def main():
|
||||
mode = "relative actions + state" if use_relative_state else "relative actions only"
|
||||
print(f" Mode: {mode}")
|
||||
|
||||
# Initialize keyboard listener
|
||||
print("\n[4/4] Starting evaluation...")
|
||||
# Initialize keyboard listener and visualization
|
||||
print("\n[5/5] Starting evaluation...")
|
||||
listener, events = init_keyboard_listener()
|
||||
init_rerun(session_name="openarms_eval_ee")
|
||||
|
||||
print("\nControls: ESC=stop, →=next episode, ←=rerecord")
|
||||
episode_idx = 0
|
||||
|
||||
try:
|
||||
for episode_idx in range(NUM_EPISODES):
|
||||
if events.get("stop_recording"):
|
||||
break
|
||||
|
||||
log_say(f"Starting episode {episode_idx + 1} of {NUM_EPISODES}")
|
||||
while episode_idx < NUM_EPISODES and not events.get("stop_recording"):
|
||||
log_say(f"Episode {episode_idx + 1} of {NUM_EPISODES}")
|
||||
print(f"\n{'='*50}")
|
||||
print(f"Episode {episode_idx + 1}/{NUM_EPISODES}")
|
||||
print(f"{'='*50}")
|
||||
@@ -480,21 +538,38 @@ def main():
|
||||
postprocessor=postprocessor,
|
||||
joints_to_ee=joints_to_ee,
|
||||
ee_to_joints=ee_to_joints,
|
||||
dataset=dataset,
|
||||
fps=FPS,
|
||||
duration_s=EPISODE_TIME_SEC,
|
||||
events=events,
|
||||
task=TASK_DESCRIPTION,
|
||||
use_relative_actions=use_relative_actions,
|
||||
use_relative_state=use_relative_state,
|
||||
relative_normalizer=relative_normalizer,
|
||||
)
|
||||
|
||||
# Handle re-recording
|
||||
if events.get("rerecord_episode", False):
|
||||
log_say("Re-recording episode")
|
||||
events["rerecord_episode"] = False
|
||||
events["exit_early"] = False
|
||||
dataset.clear_episode_buffer()
|
||||
continue
|
||||
|
||||
# Save episode if we have data
|
||||
if dataset.episode_buffer is not None and dataset.episode_buffer.get("size", 0) > 0:
|
||||
print(f" Saving episode {episode_idx + 1}...")
|
||||
dataset.save_episode()
|
||||
episode_idx += 1
|
||||
|
||||
events["exit_early"] = False
|
||||
|
||||
# Reset between episodes
|
||||
if episode_idx < NUM_EPISODES - 1 and not events.get("stop_recording"):
|
||||
if episode_idx < NUM_EPISODES and not events.get("stop_recording"):
|
||||
if USE_LEADER_FOR_RESETS and leader and leader.is_connected:
|
||||
log_say("Reset environment using leader arms")
|
||||
print(f"\nManual reset ({RESET_TIME_SEC}s) - use leader arms...")
|
||||
|
||||
# Simple teleop reset loop
|
||||
reset_start = time.perf_counter()
|
||||
while time.perf_counter() - reset_start < RESET_TIME_SEC:
|
||||
if events.get("exit_early") or events.get("stop_recording"):
|
||||
@@ -506,10 +581,9 @@ def main():
|
||||
follower.send_action(follower_action)
|
||||
time.sleep(1/FPS)
|
||||
else:
|
||||
log_say("Manual reset required")
|
||||
input("Reset environment and press ENTER...")
|
||||
input("\nReset environment and press ENTER...")
|
||||
|
||||
print(f"\n✓ Evaluation complete! {NUM_EPISODES} episodes")
|
||||
print(f"\n✓ Evaluation complete! {episode_idx} episodes recorded")
|
||||
log_say("Evaluation complete", blocking=True)
|
||||
|
||||
except KeyboardInterrupt:
|
||||
@@ -528,6 +602,11 @@ def main():
|
||||
if listener is not None:
|
||||
listener.stop()
|
||||
|
||||
# Finalize and push dataset
|
||||
dataset.finalize()
|
||||
print("Uploading to Hub...")
|
||||
dataset.push_to_hub(private=True)
|
||||
|
||||
print("✓ Done!")
|
||||
|
||||
|
||||
|
||||
Reference in New Issue
Block a user