mirror of
https://github.com/huggingface/lerobot.git
synced 2026-05-18 10:10:08 +00:00
add eval dataset
This commit is contained in:
@@ -37,14 +37,19 @@ import torch
|
|||||||
from lerobot.cameras.opencv.configuration_opencv import OpenCVCameraConfig
|
from lerobot.cameras.opencv.configuration_opencv import OpenCVCameraConfig
|
||||||
from lerobot.configs.policies import PreTrainedConfig
|
from lerobot.configs.policies import PreTrainedConfig
|
||||||
from lerobot.configs.train import TrainPipelineConfig
|
from lerobot.configs.train import TrainPipelineConfig
|
||||||
|
from lerobot.datasets.lerobot_dataset import LeRobotDataset
|
||||||
|
from lerobot.datasets.pipeline_features import aggregate_pipeline_dataset_features, create_initial_features
|
||||||
|
from lerobot.datasets.utils import build_dataset_frame, combine_feature_dicts
|
||||||
|
from lerobot.model.kinematics import RobotKinematics
|
||||||
|
from lerobot.policies.factory import make_policy, make_pre_post_processors
|
||||||
|
from lerobot.processor import RobotAction, RobotObservation, RobotProcessorPipeline, make_default_processors
|
||||||
|
from lerobot.utils.constants import ACTION, OBS_STR
|
||||||
from lerobot.utils.relative_actions import (
|
from lerobot.utils.relative_actions import (
|
||||||
convert_state_to_relative,
|
convert_state_to_relative,
|
||||||
convert_from_relative_actions,
|
convert_from_relative_actions,
|
||||||
PerTimestepNormalizer,
|
PerTimestepNormalizer,
|
||||||
)
|
)
|
||||||
from lerobot.model.kinematics import RobotKinematics
|
from lerobot.utils.visualization_utils import init_rerun, log_rerun_data
|
||||||
from lerobot.policies.factory import make_policy, make_pre_post_processors
|
|
||||||
from lerobot.processor import RobotAction, RobotObservation, RobotProcessorPipeline
|
|
||||||
from lerobot.processor.converters import (
|
from lerobot.processor.converters import (
|
||||||
robot_action_observation_to_transition,
|
robot_action_observation_to_transition,
|
||||||
robot_action_to_transition,
|
robot_action_to_transition,
|
||||||
@@ -65,6 +70,7 @@ from lerobot.utils.utils import log_say
|
|||||||
|
|
||||||
# Configuration
|
# Configuration
|
||||||
HF_MODEL_ID = "lerobot-data-collection/pi0_ee" # TODO: Replace with your EE-trained model
|
HF_MODEL_ID = "lerobot-data-collection/pi0_ee" # TODO: Replace with your EE-trained model
|
||||||
|
HF_EVAL_DATASET_ID = "your-org/your-ee-eval-dataset" # TODO: Replace with your eval dataset
|
||||||
TASK_DESCRIPTION = "ee-policy-task" # TODO: Replace with your task
|
TASK_DESCRIPTION = "ee-policy-task" # TODO: Replace with your task
|
||||||
|
|
||||||
NUM_EPISODES = 1
|
NUM_EPISODES = 1
|
||||||
@@ -261,12 +267,15 @@ def run_ee_inference_loop(
|
|||||||
postprocessor,
|
postprocessor,
|
||||||
joints_to_ee,
|
joints_to_ee,
|
||||||
ee_to_joints,
|
ee_to_joints,
|
||||||
|
dataset: LeRobotDataset,
|
||||||
fps: int,
|
fps: int,
|
||||||
duration_s: float,
|
duration_s: float,
|
||||||
events: dict,
|
events: dict,
|
||||||
|
task: str,
|
||||||
use_relative_actions: bool = False,
|
use_relative_actions: bool = False,
|
||||||
use_relative_state: bool = False,
|
use_relative_state: bool = False,
|
||||||
relative_normalizer: PerTimestepNormalizer | None = None,
|
relative_normalizer: PerTimestepNormalizer | None = None,
|
||||||
|
display_data: bool = True,
|
||||||
):
|
):
|
||||||
"""Run inference loop with EE conversion and optional UMI-style relative actions."""
|
"""Run inference loop with EE conversion and optional UMI-style relative actions."""
|
||||||
dt = 1.0 / fps
|
dt = 1.0 / fps
|
||||||
@@ -354,6 +363,17 @@ def run_ee_inference_loop(
|
|||||||
# 8. Send joint commands to robot
|
# 8. Send joint commands to robot
|
||||||
robot.send_action(joint_action)
|
robot.send_action(joint_action)
|
||||||
|
|
||||||
|
# 9. Save frame to dataset
|
||||||
|
if dataset is not None:
|
||||||
|
observation_frame = build_dataset_frame(dataset.features, robot_obs, prefix=OBS_STR)
|
||||||
|
action_frame = build_dataset_frame(dataset.features, joint_action, prefix=ACTION)
|
||||||
|
frame = {**observation_frame, **action_frame, "task": task}
|
||||||
|
dataset.add_frame(frame)
|
||||||
|
|
||||||
|
# 10. Visualization
|
||||||
|
if display_data:
|
||||||
|
log_rerun_data(observation=robot_obs, action=joint_action)
|
||||||
|
|
||||||
# Progress logging
|
# Progress logging
|
||||||
step += 1
|
step += 1
|
||||||
if step % (fps * 5) == 0:
|
if step % (fps * 5) == 0:
|
||||||
@@ -375,6 +395,7 @@ def main():
|
|||||||
print("OpenArms End-Effector Policy Evaluation")
|
print("OpenArms End-Effector Policy Evaluation")
|
||||||
print("=" * 70)
|
print("=" * 70)
|
||||||
print(f"\nModel: {HF_MODEL_ID}")
|
print(f"\nModel: {HF_MODEL_ID}")
|
||||||
|
print(f"Dataset: {HF_EVAL_DATASET_ID}")
|
||||||
print(f"Task: {TASK_DESCRIPTION}")
|
print(f"Task: {TASK_DESCRIPTION}")
|
||||||
print(f"Episodes: {NUM_EPISODES}")
|
print(f"Episodes: {NUM_EPISODES}")
|
||||||
print(f"Episode Duration: {EPISODE_TIME_SEC}s")
|
print(f"Episode Duration: {EPISODE_TIME_SEC}s")
|
||||||
@@ -387,14 +408,14 @@ def main():
|
|||||||
urdf_path = str(urdf_path)
|
urdf_path = str(urdf_path)
|
||||||
|
|
||||||
# Build kinematics pipelines
|
# Build kinematics pipelines
|
||||||
print("\n[1/4] Building kinematics pipelines...")
|
print("\n[1/5] Building kinematics pipelines...")
|
||||||
joints_to_ee, ee_to_joints = build_kinematics_pipelines(
|
joints_to_ee, ee_to_joints = build_kinematics_pipelines(
|
||||||
urdf_path, DEFAULT_LEFT_EE_FRAME, DEFAULT_RIGHT_EE_FRAME
|
urdf_path, DEFAULT_LEFT_EE_FRAME, DEFAULT_RIGHT_EE_FRAME
|
||||||
)
|
)
|
||||||
print(" FK and IK pipelines ready")
|
print(" FK and IK pipelines ready")
|
||||||
|
|
||||||
# Initialize robot
|
# Initialize robot
|
||||||
print("\n[2/4] Connecting to robot...")
|
print("\n[2/5] Connecting to robot...")
|
||||||
follower_config = OpenArmsFollowerConfig(
|
follower_config = OpenArmsFollowerConfig(
|
||||||
port_left=FOLLOWER_LEFT_PORT,
|
port_left=FOLLOWER_LEFT_PORT,
|
||||||
port_right=FOLLOWER_RIGHT_PORT,
|
port_right=FOLLOWER_RIGHT_PORT,
|
||||||
@@ -430,17 +451,53 @@ def main():
|
|||||||
leader.bus_left.enable_torque()
|
leader.bus_left.enable_torque()
|
||||||
print(" Leader connected with gravity compensation")
|
print(" Leader connected with gravity compensation")
|
||||||
|
|
||||||
|
# Create dataset for saving evaluation data
|
||||||
|
print(f"\n[3/5] Creating evaluation dataset...")
|
||||||
|
teleop_action_processor, robot_action_processor, robot_observation_processor = make_default_processors()
|
||||||
|
action_features_hw = {k: v for k, v in follower.action_features.items() if k.endswith(".pos")}
|
||||||
|
|
||||||
|
dataset_features = combine_feature_dicts(
|
||||||
|
aggregate_pipeline_dataset_features(
|
||||||
|
pipeline=teleop_action_processor,
|
||||||
|
initial_features=create_initial_features(action=action_features_hw),
|
||||||
|
use_videos=True,
|
||||||
|
),
|
||||||
|
aggregate_pipeline_dataset_features(
|
||||||
|
pipeline=robot_observation_processor,
|
||||||
|
initial_features=create_initial_features(observation=follower.observation_features),
|
||||||
|
use_videos=True,
|
||||||
|
),
|
||||||
|
)
|
||||||
|
|
||||||
|
dataset_path = Path.home() / ".cache" / "huggingface" / "lerobot" / HF_EVAL_DATASET_ID
|
||||||
|
if dataset_path.exists():
|
||||||
|
print(f" Dataset exists at: {dataset_path}")
|
||||||
|
if input(" Continue and overwrite? (y/n): ").strip().lower() != 'y':
|
||||||
|
follower.disconnect()
|
||||||
|
return
|
||||||
|
|
||||||
|
dataset = LeRobotDataset.create(
|
||||||
|
repo_id=HF_EVAL_DATASET_ID,
|
||||||
|
fps=FPS,
|
||||||
|
features=dataset_features,
|
||||||
|
robot_type=follower.name,
|
||||||
|
use_videos=True,
|
||||||
|
image_writer_processes=0,
|
||||||
|
image_writer_threads=12,
|
||||||
|
)
|
||||||
|
print(" Dataset created")
|
||||||
|
|
||||||
# Load policy
|
# Load policy
|
||||||
print(f"\n[3/4] Loading policy from {HF_MODEL_ID}...")
|
print(f"\n[4/5] Loading policy from {HF_MODEL_ID}...")
|
||||||
policy_config = PreTrainedConfig.from_pretrained(HF_MODEL_ID)
|
policy_config = PreTrainedConfig.from_pretrained(HF_MODEL_ID)
|
||||||
policy_config.pretrained_path = HF_MODEL_ID
|
policy_config.pretrained_path = HF_MODEL_ID
|
||||||
|
|
||||||
# Create policy without dataset meta (use config defaults)
|
policy = make_policy(policy_config, ds_meta=dataset.meta)
|
||||||
policy = make_policy(policy_config, ds_meta=None)
|
|
||||||
|
|
||||||
preprocessor, postprocessor = make_pre_post_processors(
|
preprocessor, postprocessor = make_pre_post_processors(
|
||||||
policy_cfg=policy.config,
|
policy_cfg=policy.config,
|
||||||
pretrained_path=HF_MODEL_ID,
|
pretrained_path=HF_MODEL_ID,
|
||||||
|
dataset_stats=dataset.meta.stats,
|
||||||
preprocessor_overrides={
|
preprocessor_overrides={
|
||||||
"device_processor": {"device": str(policy.config.device)}
|
"device_processor": {"device": str(policy.config.device)}
|
||||||
},
|
},
|
||||||
@@ -455,16 +512,17 @@ def main():
|
|||||||
mode = "relative actions + state" if use_relative_state else "relative actions only"
|
mode = "relative actions + state" if use_relative_state else "relative actions only"
|
||||||
print(f" Mode: {mode}")
|
print(f" Mode: {mode}")
|
||||||
|
|
||||||
# Initialize keyboard listener
|
# Initialize keyboard listener and visualization
|
||||||
print("\n[4/4] Starting evaluation...")
|
print("\n[5/5] Starting evaluation...")
|
||||||
listener, events = init_keyboard_listener()
|
listener, events = init_keyboard_listener()
|
||||||
|
init_rerun(session_name="openarms_eval_ee")
|
||||||
|
|
||||||
|
print("\nControls: ESC=stop, →=next episode, ←=rerecord")
|
||||||
|
episode_idx = 0
|
||||||
|
|
||||||
try:
|
try:
|
||||||
for episode_idx in range(NUM_EPISODES):
|
while episode_idx < NUM_EPISODES and not events.get("stop_recording"):
|
||||||
if events.get("stop_recording"):
|
log_say(f"Episode {episode_idx + 1} of {NUM_EPISODES}")
|
||||||
break
|
|
||||||
|
|
||||||
log_say(f"Starting episode {episode_idx + 1} of {NUM_EPISODES}")
|
|
||||||
print(f"\n{'='*50}")
|
print(f"\n{'='*50}")
|
||||||
print(f"Episode {episode_idx + 1}/{NUM_EPISODES}")
|
print(f"Episode {episode_idx + 1}/{NUM_EPISODES}")
|
||||||
print(f"{'='*50}")
|
print(f"{'='*50}")
|
||||||
@@ -480,21 +538,38 @@ def main():
|
|||||||
postprocessor=postprocessor,
|
postprocessor=postprocessor,
|
||||||
joints_to_ee=joints_to_ee,
|
joints_to_ee=joints_to_ee,
|
||||||
ee_to_joints=ee_to_joints,
|
ee_to_joints=ee_to_joints,
|
||||||
|
dataset=dataset,
|
||||||
fps=FPS,
|
fps=FPS,
|
||||||
duration_s=EPISODE_TIME_SEC,
|
duration_s=EPISODE_TIME_SEC,
|
||||||
events=events,
|
events=events,
|
||||||
|
task=TASK_DESCRIPTION,
|
||||||
use_relative_actions=use_relative_actions,
|
use_relative_actions=use_relative_actions,
|
||||||
use_relative_state=use_relative_state,
|
use_relative_state=use_relative_state,
|
||||||
relative_normalizer=relative_normalizer,
|
relative_normalizer=relative_normalizer,
|
||||||
)
|
)
|
||||||
|
|
||||||
|
# Handle re-recording
|
||||||
|
if events.get("rerecord_episode", False):
|
||||||
|
log_say("Re-recording episode")
|
||||||
|
events["rerecord_episode"] = False
|
||||||
|
events["exit_early"] = False
|
||||||
|
dataset.clear_episode_buffer()
|
||||||
|
continue
|
||||||
|
|
||||||
|
# Save episode if we have data
|
||||||
|
if dataset.episode_buffer is not None and dataset.episode_buffer.get("size", 0) > 0:
|
||||||
|
print(f" Saving episode {episode_idx + 1}...")
|
||||||
|
dataset.save_episode()
|
||||||
|
episode_idx += 1
|
||||||
|
|
||||||
|
events["exit_early"] = False
|
||||||
|
|
||||||
# Reset between episodes
|
# Reset between episodes
|
||||||
if episode_idx < NUM_EPISODES - 1 and not events.get("stop_recording"):
|
if episode_idx < NUM_EPISODES and not events.get("stop_recording"):
|
||||||
if USE_LEADER_FOR_RESETS and leader and leader.is_connected:
|
if USE_LEADER_FOR_RESETS and leader and leader.is_connected:
|
||||||
log_say("Reset environment using leader arms")
|
log_say("Reset environment using leader arms")
|
||||||
print(f"\nManual reset ({RESET_TIME_SEC}s) - use leader arms...")
|
print(f"\nManual reset ({RESET_TIME_SEC}s) - use leader arms...")
|
||||||
|
|
||||||
# Simple teleop reset loop
|
|
||||||
reset_start = time.perf_counter()
|
reset_start = time.perf_counter()
|
||||||
while time.perf_counter() - reset_start < RESET_TIME_SEC:
|
while time.perf_counter() - reset_start < RESET_TIME_SEC:
|
||||||
if events.get("exit_early") or events.get("stop_recording"):
|
if events.get("exit_early") or events.get("stop_recording"):
|
||||||
@@ -506,10 +581,9 @@ def main():
|
|||||||
follower.send_action(follower_action)
|
follower.send_action(follower_action)
|
||||||
time.sleep(1/FPS)
|
time.sleep(1/FPS)
|
||||||
else:
|
else:
|
||||||
log_say("Manual reset required")
|
input("\nReset environment and press ENTER...")
|
||||||
input("Reset environment and press ENTER...")
|
|
||||||
|
|
||||||
print(f"\n✓ Evaluation complete! {NUM_EPISODES} episodes")
|
print(f"\n✓ Evaluation complete! {episode_idx} episodes recorded")
|
||||||
log_say("Evaluation complete", blocking=True)
|
log_say("Evaluation complete", blocking=True)
|
||||||
|
|
||||||
except KeyboardInterrupt:
|
except KeyboardInterrupt:
|
||||||
@@ -528,6 +602,11 @@ def main():
|
|||||||
if listener is not None:
|
if listener is not None:
|
||||||
listener.stop()
|
listener.stop()
|
||||||
|
|
||||||
|
# Finalize and push dataset
|
||||||
|
dataset.finalize()
|
||||||
|
print("Uploading to Hub...")
|
||||||
|
dataset.push_to_hub(private=True)
|
||||||
|
|
||||||
print("✓ Done!")
|
print("✓ Done!")
|
||||||
|
|
||||||
|
|
||||||
|
|||||||
Reference in New Issue
Block a user