fixup! Pi0 eval dataset

Eugene Mironov
2025-11-07 12:31:35 +07:00
parent e86afc883e
commit d0123c4178
+55 -35
@@ -82,7 +82,8 @@ from lerobot.configs import parser
 from lerobot.configs.default import DatasetConfig
 from lerobot.configs.policies import PreTrainedConfig
 from lerobot.configs.types import RTCAttentionSchedule
-from lerobot.datasets.lerobot_dataset import LeRobotDataset
+from lerobot.datasets.factory import resolve_delta_timestamps
+from lerobot.datasets.lerobot_dataset import LeRobotDataset, LeRobotDatasetMetadata
 from lerobot.policies.factory import get_policy_class, make_pre_post_processors
 from lerobot.policies.rtc.configuration_rtc import RTCConfig
 from lerobot.policies.rtc.debug_visualizer import RTCDebugVisualizer
@@ -206,9 +207,21 @@ class RTCEvaluator:
         self.cfg = cfg
         self.device = cfg.device

-        # Load dataset first (needed for preprocessor)
+        # Load dataset with proper delta_timestamps based on policy configuration
+        # Calculate delta_timestamps using the same logic as the make_dataset factory
         logging.info(f"Loading dataset: {cfg.dataset.repo_id}")
-        self.dataset = LeRobotDataset(cfg.dataset.repo_id, delta_timestamps={"action": np.arange(50) / 30})
+
+        # Get dataset metadata to extract the FPS
+        ds_meta = LeRobotDatasetMetadata(cfg.dataset.repo_id)
+
+        # Calculate delta_timestamps from the policy's delta_indices
+        delta_timestamps = resolve_delta_timestamps(cfg.policy, ds_meta)
+
+        # Create dataset with the calculated delta_timestamps
+        self.dataset = LeRobotDataset(
+            cfg.dataset.repo_id,
+            delta_timestamps=delta_timestamps,
+        )
         logging.info(f"Dataset loaded: {len(self.dataset)} samples, {self.dataset.num_episodes} episodes")

         # Create preprocessor/postprocessor
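
For reference, the mapping that resolve_delta_timestamps produces is the per-feature list of time offsets (in seconds) obtained by dividing the policy's frame-index offsets by the dataset FPS. A minimal sketch of that conversion, assuming the policy exposes per-feature delta_indices and the metadata exposes fps (the helper name delta_timestamps_from_indices is illustrative, not lerobot API):

# Illustrative helper, not lerobot API: converts frame-index offsets to seconds.
def delta_timestamps_from_indices(delta_indices: dict[str, list[int]], fps: float) -> dict[str, list[float]]:
    # Frame offsets (in steps) become time offsets (in seconds) at the dataset FPS.
    return {key: [i / fps for i in indices] for key, indices in delta_indices.items()}

# The old hardcoded {"action": np.arange(50) / 30} is the special case of a
# 50-step action horizon at 30 FPS:
action_ts = delta_timestamps_from_indices({"action": list(range(50))}, fps=30)["action"]
assert action_ts[:3] == [0.0, 1 / 30, 2 / 30]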
@@ -220,17 +233,12 @@ class RTCEvaluator:
             },
         )

-        # Initialize three separate policy instances
-        # Note: These policies are initialized here but will be freed sequentially during
-        # evaluation to manage memory. Large models (e.g., VLAs with billions of parameters)
-        # cannot fit three instances in memory simultaneously. Each policy is deleted and
-        # memory is freed (via torch.cuda.empty_cache()) immediately after its use.
         logging.info("=" * 80)
-        logging.info("Initializing three policy instances:")
-        logging.info("  1. policy_prev_chunk (for generating previous chunk)")
-        logging.info("  2. policy_no_rtc (for non-RTC inference)")
-        logging.info("  3. policy_rtc (for RTC inference)")
-        logging.info("  Note: Policies will be freed sequentially during evaluation to manage memory")
+        logging.info("Ready to run evaluation with sequential policy loading:")
+        logging.info("  1. policy_prev_chunk - generate the reference chunk, then destroy")
+        logging.info("  2. policy_no_rtc - generate without RTC, then destroy")
+        logging.info("  3. policy_rtc - generate with RTC, then destroy")
+        logging.info("  Note: Only one policy is held in memory at a time to limit peak memory usage")
         logging.info("=" * 80)

    def _init_policy(self, name: str, rtc_enabled: bool, rtc_debug: bool):
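
The memory argument behind this reordering: with VLA policies of billions of parameters, three resident instances can exceed device memory, so each policy is loaded, used, and destroyed before the next one is created. A self-contained sketch of that lifecycle (load_policy below is a stand-in for RTCEvaluator._init_policy, not the real API):

import gc
import torch

def load_policy(name: str) -> torch.nn.Module:
    # Stand-in for RTCEvaluator._init_policy; any large module illustrates the point.
    return torch.nn.Linear(4096, 4096)

for name in ("policy_prev_chunk", "policy_no_rtc", "policy_rtc"):
    policy = load_policy(name)
    with torch.no_grad():
        _ = policy(torch.zeros(1, 4096))  # one inference pass per policy
    del policy                    # drop the last reference to the weights
    gc.collect()                  # free them now rather than at some later GC pass
    if torch.cuda.is_available():
        torch.cuda.empty_cache()  # return cached CUDA blocks to the driver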
@@ -383,30 +391,20 @@ class RTCEvaluator:
         preprocessed_first_sample = self.preprocessor(first_sample)
         preprocessed_second_sample = self.preprocessor(second_sample)

-        # Policy 1: For generating previous chunk (RTC disabled, no debug)
+        # ============================================================================
+        # Step 1: Generate previous chunk using policy_prev_chunk
+        # ============================================================================
+        # This policy is only used to generate the reference chunk and then freed
+        logging.info("=" * 80)
+        logging.info("Step 1: Generating previous chunk with policy_prev_chunk")
+        logging.info("=" * 80)
+
+        # Initialize policy 1
         policy_prev_chunk_policy = self._init_policy(
             name="policy_prev_chunk",
             rtc_enabled=False,
             rtc_debug=False,
         )
-
-        # Policy 2: For non-RTC inference (RTC disabled, debug enabled)
-        policy_no_rtc_policy = self._init_policy(
-            name="policy_no_rtc",
-            rtc_enabled=False,
-            rtc_debug=True,
-        )
-
-        # Policy 3: For RTC inference (RTC enabled, debug enabled)
-        policy_rtc_policy = self._init_policy(
-            name="policy_rtc",
-            rtc_enabled=True,
-            rtc_debug=True,
-        )
-
-        # Step 1: Generate previous chunk using policy_prev_chunk
-        # This policy is only used to generate the reference chunk and then freed
-        logging.info("Step 1: Generating previous chunk with policy_prev_chunk")
         with torch.no_grad():
             prev_chunk_left_over = policy_prev_chunk_policy.predict_action_chunk(
                 preprocessed_first_sample,
@@ -416,13 +414,24 @@ class RTCEvaluator:
         # Destroy policy_prev_chunk to free memory for large models
         self._destroy_policy(policy_prev_chunk_policy, "policy_prev_chunk")

+        # ============================================================================
+        # Step 2: Generate actions WITHOUT RTC using policy_no_rtc
+        # ============================================================================
+        logging.info("=" * 80)
+        logging.info("Step 2: Generating actions WITHOUT RTC with policy_no_rtc")
+        logging.info("=" * 80)
+        # Initialize policy 2
+        policy_no_rtc_policy = self._init_policy(
+            name="policy_no_rtc",
+            rtc_enabled=False,
+            rtc_debug=True,
+        )
+
         # Sample noise (use same noise for both RTC and non-RTC for fair comparison)
         noise_size = (1, policy_no_rtc_policy.config.chunk_size, policy_no_rtc_policy.config.max_action_dim)
         noise = policy_no_rtc_policy.model.sample_noise(noise_size, self.device)
         noise_clone = noise.clone()

-        # Step 2: Generate actions WITHOUT RTC using policy_no_rtc
-        logging.info("Step 2: Generating actions WITHOUT RTC with policy_no_rtc")
         policy_no_rtc_policy.rtc_processor.reset_tracker()
         with torch.no_grad():
             _ = policy_no_rtc_policy.predict_action_chunk(
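
Cloning the noise before the first run is what makes the RTC/no-RTC comparison controlled: both denoising passes start from bit-identical noise, so any difference in the outputs is attributable to RTC alone. A small sketch of the invariant that clone() provides (shapes are illustrative):

import torch

chunk_size, max_action_dim = 50, 32  # illustrative shapes
noise = torch.randn(1, chunk_size, max_action_dim)
noise_clone = noise.clone()          # independent copy, not a view

# In-place edits made by one inference path cannot leak into the other:
noise += 1.0
assert not torch.equal(noise, noise_clone)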
@@ -435,8 +444,19 @@ class RTCEvaluator:
         # Destroy policy_no_rtc to free memory before loading policy_rtc
         self._destroy_policy(policy_no_rtc_policy, "policy_no_rtc")

+        # ============================================================================
+        # Step 3: Generate actions WITH RTC using policy_rtc
+        # ============================================================================
+        logging.info("=" * 80)
+        logging.info("Step 3: Generating actions WITH RTC with policy_rtc")
+        logging.info("=" * 80)
+        # Initialize policy 3
+        policy_rtc_policy = self._init_policy(
+            name="policy_rtc",
+            rtc_enabled=True,
+            rtc_debug=True,
+        )
+
         policy_rtc_policy.rtc_processor.reset_tracker()
         with torch.no_grad():
             _ = policy_rtc_policy.predict_action_chunk(