mirror of
https://github.com/huggingface/lerobot.git
synced 2026-05-15 16:49:55 +00:00
fixup! Pi0 eval dataset
This commit is contained in:
@@ -82,7 +82,8 @@ from lerobot.configs import parser
|
||||
from lerobot.configs.default import DatasetConfig
|
||||
from lerobot.configs.policies import PreTrainedConfig
|
||||
from lerobot.configs.types import RTCAttentionSchedule
|
||||
from lerobot.datasets.lerobot_dataset import LeRobotDataset
|
||||
from lerobot.datasets.factory import resolve_delta_timestamps
|
||||
from lerobot.datasets.lerobot_dataset import LeRobotDataset, LeRobotDatasetMetadata
|
||||
from lerobot.policies.factory import get_policy_class, make_pre_post_processors
|
||||
from lerobot.policies.rtc.configuration_rtc import RTCConfig
|
||||
from lerobot.policies.rtc.debug_visualizer import RTCDebugVisualizer
|
||||
@@ -206,9 +207,21 @@ class RTCEvaluator:
|
||||
self.cfg = cfg
|
||||
self.device = cfg.device
|
||||
|
||||
# Load dataset first (needed for preprocessor)
|
||||
# Load dataset with proper delta_timestamps based on policy configuration
|
||||
# Calculate delta_timestamps using the same logic as make_dataset factory
|
||||
logging.info(f"Loading dataset: {cfg.dataset.repo_id}")
|
||||
self.dataset = LeRobotDataset(cfg.dataset.repo_id, delta_timestamps={"action": np.arange(50) / 30})
|
||||
|
||||
# Get dataset metadata to extract FPS
|
||||
ds_meta = LeRobotDatasetMetadata(cfg.dataset.repo_id)
|
||||
|
||||
# Calculate delta_timestamps from policy's delta_indices
|
||||
delta_timestamps = resolve_delta_timestamps(cfg.policy, ds_meta)
|
||||
|
||||
# Create dataset with calculated delta_timestamps
|
||||
self.dataset = LeRobotDataset(
|
||||
cfg.dataset.repo_id,
|
||||
delta_timestamps=delta_timestamps,
|
||||
)
|
||||
logging.info(f"Dataset loaded: {len(self.dataset)} samples, {self.dataset.num_episodes} episodes")
|
||||
|
||||
# Create preprocessor/postprocessor
|
||||
@@ -220,17 +233,12 @@ class RTCEvaluator:
|
||||
},
|
||||
)
|
||||
|
||||
# Initialize three separate policy instances
|
||||
# Note: These policies are initialized here but will be freed sequentially during
|
||||
# evaluation to manage memory. Large models (e.g., VLAs with billions of parameters)
|
||||
# cannot fit three instances in memory simultaneously. Each policy is deleted and
|
||||
# memory is freed (via torch.cuda.empty_cache()) immediately after its use.
|
||||
logging.info("=" * 80)
|
||||
logging.info("Initializing three policy instances:")
|
||||
logging.info(" 1. policy_prev_chunk (for generating previous chunk)")
|
||||
logging.info(" 2. policy_no_rtc (for non-RTC inference)")
|
||||
logging.info(" 3. policy_rtc (for RTC inference)")
|
||||
logging.info(" Note: Policies will be freed sequentially during evaluation to manage memory")
|
||||
logging.info("Ready to run evaluation with sequential policy loading:")
|
||||
logging.info(" 1. policy_prev_chunk - Generate reference chunk, then destroy")
|
||||
logging.info(" 2. policy_no_rtc - Generate without RTC, then destroy")
|
||||
logging.info(" 3. policy_rtc - Generate with RTC, then destroy")
|
||||
logging.info(" Note: Only one policy in memory at a time for efficient memory usage")
|
||||
logging.info("=" * 80)
|
||||
|
||||
def _init_policy(self, name: str, rtc_enabled: bool, rtc_debug: bool):
|
||||
@@ -383,30 +391,20 @@ class RTCEvaluator:
|
||||
preprocessed_first_sample = self.preprocessor(first_sample)
|
||||
preprocessed_second_sample = self.preprocessor(second_sample)
|
||||
|
||||
# Policy 1: For generating previous chunk (RTC disabled, no debug)
|
||||
# ============================================================================
|
||||
# Step 1: Generate previous chunk using policy_prev_chunk
|
||||
# ============================================================================
|
||||
# This policy is only used to generate the reference chunk and then freed
|
||||
logging.info("=" * 80)
|
||||
logging.info("Step 1: Generating previous chunk with policy_prev_chunk")
|
||||
logging.info("=" * 80)
|
||||
|
||||
# Initialize policy 1
|
||||
policy_prev_chunk_policy = self._init_policy(
|
||||
name="policy_prev_chunk",
|
||||
rtc_enabled=False,
|
||||
rtc_debug=False,
|
||||
)
|
||||
|
||||
# Policy 2: For non-RTC inference (RTC disabled, debug enabled)
|
||||
policy_no_rtc_policy = self._init_policy(
|
||||
name="policy_no_rtc",
|
||||
rtc_enabled=False,
|
||||
rtc_debug=True,
|
||||
)
|
||||
|
||||
# Policy 3: For RTC inference (RTC enabled, debug enabled)
|
||||
policy_rtc_policy = self._init_policy(
|
||||
name="policy_rtc",
|
||||
rtc_enabled=True,
|
||||
rtc_debug=True,
|
||||
)
|
||||
|
||||
# Step 1: Generate previous chunk using policy_prev_chunk
|
||||
# This policy is only used to generate the reference chunk and then freed
|
||||
logging.info("Step 1: Generating previous chunk with policy_prev_chunk")
|
||||
with torch.no_grad():
|
||||
prev_chunk_left_over = policy_prev_chunk_policy.predict_action_chunk(
|
||||
preprocessed_first_sample,
|
||||
@@ -416,13 +414,24 @@ class RTCEvaluator:
|
||||
# Destroy policy_prev_chunk to free memory for large models
|
||||
self._destroy_policy(policy_prev_chunk_policy, "policy_prev_chunk")
|
||||
|
||||
# ============================================================================
|
||||
# Step 2: Generate actions WITHOUT RTC using policy_no_rtc
|
||||
# ============================================================================
|
||||
logging.info("=" * 80)
|
||||
logging.info("Step 2: Generating actions WITHOUT RTC with policy_no_rtc")
|
||||
logging.info("=" * 80)
|
||||
|
||||
# Initialize policy 2
|
||||
policy_no_rtc_policy = self._init_policy(
|
||||
name="policy_no_rtc",
|
||||
rtc_enabled=False,
|
||||
rtc_debug=True,
|
||||
)
|
||||
|
||||
# Sample noise (use same noise for both RTC and non-RTC for fair comparison)
|
||||
noise_size = (1, policy_no_rtc_policy.config.chunk_size, policy_no_rtc_policy.config.max_action_dim)
|
||||
noise = policy_no_rtc_policy.model.sample_noise(noise_size, self.device)
|
||||
noise_clone = noise.clone()
|
||||
|
||||
# Step 2: Generate actions WITHOUT RTC using policy_no_rtc
|
||||
logging.info("Step 2: Generating actions WITHOUT RTC with policy_no_rtc")
|
||||
policy_no_rtc_policy.rtc_processor.reset_tracker()
|
||||
with torch.no_grad():
|
||||
_ = policy_no_rtc_policy.predict_action_chunk(
|
||||
@@ -435,8 +444,19 @@ class RTCEvaluator:
|
||||
# Destroy policy_no_rtc to free memory before loading policy_rtc
|
||||
self._destroy_policy(policy_no_rtc_policy, "policy_no_rtc")
|
||||
|
||||
# ============================================================================
|
||||
# Step 3: Generate actions WITH RTC using policy_rtc
|
||||
# ============================================================================
|
||||
logging.info("=" * 80)
|
||||
logging.info("Step 3: Generating actions WITH RTC with policy_rtc")
|
||||
logging.info("=" * 80)
|
||||
|
||||
# Initialize policy 3
|
||||
policy_rtc_policy = self._init_policy(
|
||||
name="policy_rtc",
|
||||
rtc_enabled=True,
|
||||
rtc_debug=True,
|
||||
)
|
||||
policy_rtc_policy.rtc_processor.reset_tracker()
|
||||
with torch.no_grad():
|
||||
_ = policy_rtc_policy.predict_action_chunk(
|
||||
|
||||
Reference in New Issue
Block a user