mirror of
https://github.com/huggingface/lerobot.git
synced 2026-05-15 08:39:49 +00:00
fixup! Add matplotliv to dev
This commit is contained in:
@@ -27,7 +27,7 @@ Usage:
|
||||
--use_torch_compile=true \
|
||||
--torch_compile_mode=max-autotune
|
||||
|
||||
# With torch.compile for faster inference (PyTorch 2.0+)
|
||||
# With torch.compile on CUDA
|
||||
uv run python examples/rtc/eval_dataset.py \
|
||||
--policy.path=helper2424/smolvla_check_rtc_last3 \
|
||||
--dataset.repo_id=helper2424/check_rtc \
|
||||
@@ -202,31 +202,6 @@ class RTCEvaluator:
|
||||
logging.info(" Note: Policies will be freed sequentially during evaluation to manage memory")
|
||||
logging.info("=" * 80)
|
||||
|
||||
# Policy 1: For generating previous chunk (RTC disabled, no debug)
|
||||
self.policy_prev_chunk = self._init_policy(
|
||||
name="policy_prev_chunk",
|
||||
rtc_enabled=False,
|
||||
rtc_debug=False,
|
||||
)
|
||||
|
||||
# Policy 2: For non-RTC inference (RTC disabled, debug enabled)
|
||||
self.policy_no_rtc = self._init_policy(
|
||||
name="policy_no_rtc",
|
||||
rtc_enabled=False,
|
||||
rtc_debug=True,
|
||||
)
|
||||
|
||||
# Policy 3: For RTC inference (RTC enabled, debug enabled)
|
||||
self.policy_rtc = self._init_policy(
|
||||
name="policy_rtc",
|
||||
rtc_enabled=True,
|
||||
rtc_debug=True,
|
||||
)
|
||||
|
||||
logging.info("=" * 80)
|
||||
logging.info("All policies initialized successfully")
|
||||
logging.info("=" * 80)
|
||||
|
||||
def _init_policy(self, name: str, rtc_enabled: bool, rtc_debug: bool):
|
||||
"""Initialize a single policy instance with specified RTC configuration.
|
||||
|
||||
@@ -290,8 +265,11 @@ class RTCEvaluator:
|
||||
logging.info(f" [{policy_name}] Applying torch.compile to predict_action_chunk...")
|
||||
logging.info(f" Backend: {self.cfg.torch_compile_backend}")
|
||||
logging.info(f" Mode: {self.cfg.torch_compile_mode}")
|
||||
logging.info(" Note: Debug tracker excluded from compilation via @torch._dynamo.disable")
|
||||
|
||||
# Compile the predict_action_chunk method
|
||||
# The debug tracker is excluded from compilation via @torch._dynamo.disable decorator
|
||||
# on the Tracker.track() method, so it won't cause graph breaks
|
||||
original_method = policy.predict_action_chunk
|
||||
compiled_method = torch.compile(
|
||||
original_method,
|
||||
@@ -368,56 +346,77 @@ class RTCEvaluator:
|
||||
preprocessed_first_sample = self.preprocessor(first_sample)
|
||||
preprocessed_second_sample = self.preprocessor(second_sample)
|
||||
|
||||
# Policy 1: For generating previous chunk (RTC disabled, no debug)
|
||||
policy_prev_chunk_policy = self._init_policy(
|
||||
name="policy_prev_chunk",
|
||||
rtc_enabled=False,
|
||||
rtc_debug=False,
|
||||
)
|
||||
|
||||
# Policy 2: For non-RTC inference (RTC disabled, debug enabled)
|
||||
policy_no_rtc_policy = self._init_policy(
|
||||
name="policy_no_rtc",
|
||||
rtc_enabled=False,
|
||||
rtc_debug=True,
|
||||
)
|
||||
|
||||
# Policy 3: For RTC inference (RTC enabled, debug enabled)
|
||||
policy_rtc_policy = self._init_policy(
|
||||
name="policy_rtc",
|
||||
rtc_enabled=True,
|
||||
rtc_debug=True,
|
||||
)
|
||||
|
||||
# Step 1: Generate previous chunk using policy_prev_chunk
|
||||
# This policy is only used to generate the reference chunk and then freed
|
||||
logging.info("Step 1: Generating previous chunk with policy_prev_chunk")
|
||||
with torch.no_grad():
|
||||
prev_chunk_left_over = self.policy_prev_chunk.predict_action_chunk(
|
||||
prev_chunk_left_over = policy_prev_chunk_policy.predict_action_chunk(
|
||||
preprocessed_first_sample,
|
||||
)[:, :25, :].squeeze(0)
|
||||
logging.info(f" Generated prev_chunk shape: {prev_chunk_left_over.shape}")
|
||||
|
||||
# Destroy policy_prev_chunk to free memory for large models
|
||||
self._destroy_policy(self.policy_prev_chunk, "policy_prev_chunk")
|
||||
self._destroy_policy(policy_prev_chunk_policy, "policy_prev_chunk")
|
||||
|
||||
# Sample noise (use same noise for both RTC and non-RTC for fair comparison)
|
||||
noise_size = (1, self.policy_no_rtc.config.chunk_size, self.policy_no_rtc.config.max_action_dim)
|
||||
noise = self.policy_no_rtc.model.sample_noise(noise_size, self.device)
|
||||
noise_size = (1, policy_no_rtc_policy.config.chunk_size, policy_no_rtc_policy.config.max_action_dim)
|
||||
noise = policy_no_rtc_policy.model.sample_noise(noise_size, self.device)
|
||||
noise_clone = noise.clone()
|
||||
|
||||
# Step 2: Generate actions WITHOUT RTC using policy_no_rtc
|
||||
logging.info("Step 2: Generating actions WITHOUT RTC with policy_no_rtc")
|
||||
self.policy_no_rtc.rtc_processor.reset_tracker()
|
||||
policy_no_rtc_policy.rtc_processor.reset_tracker()
|
||||
with torch.no_grad():
|
||||
_ = self.policy_no_rtc.predict_action_chunk(
|
||||
_ = policy_no_rtc_policy.predict_action_chunk(
|
||||
preprocessed_second_sample,
|
||||
noise=noise,
|
||||
)
|
||||
no_rtc_tracked_steps = self.policy_no_rtc.rtc_processor.tracker.get_all_steps()
|
||||
no_rtc_tracked_steps = policy_no_rtc_policy.rtc_processor.tracker.get_all_steps()
|
||||
logging.info(f" Tracked {len(no_rtc_tracked_steps)} steps without RTC")
|
||||
|
||||
# Destroy policy_no_rtc to free memory before loading policy_rtc
|
||||
self._destroy_policy(self.policy_no_rtc, "policy_no_rtc")
|
||||
self._destroy_policy(policy_no_rtc_policy, "policy_no_rtc")
|
||||
|
||||
# Step 3: Generate actions WITH RTC using policy_rtc
|
||||
logging.info("Step 3: Generating actions WITH RTC with policy_rtc")
|
||||
self.policy_rtc.rtc_processor.reset_tracker()
|
||||
policy_rtc_policy.rtc_processor.reset_tracker()
|
||||
with torch.no_grad():
|
||||
_ = self.policy_rtc.predict_action_chunk(
|
||||
_ = policy_rtc_policy.predict_action_chunk(
|
||||
preprocessed_second_sample,
|
||||
noise=noise_clone,
|
||||
inference_delay=self.cfg.inference_delay,
|
||||
prev_chunk_left_over=prev_chunk_left_over,
|
||||
execution_horizon=self.cfg.rtc.execution_horizon,
|
||||
)
|
||||
rtc_tracked_steps = self.policy_rtc.rtc_processor.get_all_debug_steps()
|
||||
rtc_tracked_steps = policy_rtc_policy.rtc_processor.get_all_debug_steps()
|
||||
logging.info(f" Tracked {len(rtc_tracked_steps)} steps with RTC")
|
||||
|
||||
# Save num_steps before destroying policy (needed for plotting)
|
||||
num_steps = self.policy_rtc.config.num_steps
|
||||
num_steps = policy_rtc_policy.config.num_steps
|
||||
|
||||
# Destroy policy_rtc after final use
|
||||
self._destroy_policy(self.policy_rtc, "policy_rtc")
|
||||
self._destroy_policy(policy_rtc_policy, "policy_rtc")
|
||||
|
||||
# Plot and save results
|
||||
logging.info("=" * 80)
|
||||
|
||||
Reference in New Issue
Block a user