Mirror of https://github.com/huggingface/lerobot.git (synced 2026-05-15 16:49:55 +00:00)
Commit: fixup! Add matplotlib to dev
@@ -27,7 +27,7 @@ Usage:
     --use_torch_compile=true \
     --torch_compile_mode=max-autotune

-# With torch.compile for faster inference (PyTorch 2.0+)
+# With torch.compile on CUDA
 uv run python examples/rtc/eval_dataset.py \
     --policy.path=helper2424/smolvla_check_rtc_last3 \
     --dataset.repo_id=helper2424/check_rtc \
@@ -202,31 +202,6 @@ class RTCEvaluator:
         logging.info(" Note: Policies will be freed sequentially during evaluation to manage memory")
         logging.info("=" * 80)

-        # Policy 1: For generating previous chunk (RTC disabled, no debug)
-        self.policy_prev_chunk = self._init_policy(
-            name="policy_prev_chunk",
-            rtc_enabled=False,
-            rtc_debug=False,
-        )
-
-        # Policy 2: For non-RTC inference (RTC disabled, debug enabled)
-        self.policy_no_rtc = self._init_policy(
-            name="policy_no_rtc",
-            rtc_enabled=False,
-            rtc_debug=True,
-        )
-
-        # Policy 3: For RTC inference (RTC enabled, debug enabled)
-        self.policy_rtc = self._init_policy(
-            name="policy_rtc",
-            rtc_enabled=True,
-            rtc_debug=True,
-        )
-
-        logging.info("=" * 80)
-        logging.info("All policies initialized successfully")
-        logging.info("=" * 80)
-
     def _init_policy(self, name: str, rtc_enabled: bool, rtc_debug: bool):
         """Initialize a single policy instance with specified RTC configuration.

@@ -290,8 +265,11 @@ class RTCEvaluator:
             logging.info(f" [{policy_name}] Applying torch.compile to predict_action_chunk...")
             logging.info(f" Backend: {self.cfg.torch_compile_backend}")
             logging.info(f" Mode: {self.cfg.torch_compile_mode}")
+            logging.info(" Note: Debug tracker excluded from compilation via @torch._dynamo.disable")

             # Compile the predict_action_chunk method
+            # The debug tracker is excluded from compilation via @torch._dynamo.disable decorator
+            # on the Tracker.track() method, so it won't cause graph breaks
             original_method = policy.predict_action_chunk
             compiled_method = torch.compile(
                 original_method,
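For readers unfamiliar with this pattern, the hunk above compiles a single bound method rather than the whole module. A minimal sketch of the idea, assuming a stand-in Policy class (the backend and mode values mirror the CLI flags in the usage example; none of this is the repo's code):

import torch

class Policy(torch.nn.Module):
    """Hypothetical stand-in for a real policy with a chunk-prediction method."""

    def __init__(self) -> None:
        super().__init__()
        self.net = torch.nn.Linear(8, 8)

    def predict_action_chunk(self, obs: torch.Tensor) -> torch.Tensor:
        return self.net(obs)

policy = Policy()
# Replace the bound method with a compiled wrapper, as the diff above does;
# the rest of the object stays eager, so only this hot path pays compile cost.
policy.predict_action_chunk = torch.compile(
    policy.predict_action_chunk,
    backend="inductor",   # e.g. cfg.torch_compile_backend
    mode="max-autotune",  # e.g. cfg.torch_compile_mode
)
with torch.no_grad():
    out = policy.predict_action_chunk(torch.randn(1, 8))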
@@ -368,56 +346,77 @@ class RTCEvaluator:
         preprocessed_first_sample = self.preprocessor(first_sample)
         preprocessed_second_sample = self.preprocessor(second_sample)

+        # Policy 1: For generating previous chunk (RTC disabled, no debug)
+        policy_prev_chunk_policy = self._init_policy(
+            name="policy_prev_chunk",
+            rtc_enabled=False,
+            rtc_debug=False,
+        )
+
+        # Policy 2: For non-RTC inference (RTC disabled, debug enabled)
+        policy_no_rtc_policy = self._init_policy(
+            name="policy_no_rtc",
+            rtc_enabled=False,
+            rtc_debug=True,
+        )
+
+        # Policy 3: For RTC inference (RTC enabled, debug enabled)
+        policy_rtc_policy = self._init_policy(
+            name="policy_rtc",
+            rtc_enabled=True,
+            rtc_debug=True,
+        )
+
         # Step 1: Generate previous chunk using policy_prev_chunk
         # This policy is only used to generate the reference chunk and then freed
         logging.info("Step 1: Generating previous chunk with policy_prev_chunk")
         with torch.no_grad():
-            prev_chunk_left_over = self.policy_prev_chunk.predict_action_chunk(
+            prev_chunk_left_over = policy_prev_chunk_policy.predict_action_chunk(
                 preprocessed_first_sample,
             )[:, :25, :].squeeze(0)
         logging.info(f" Generated prev_chunk shape: {prev_chunk_left_over.shape}")

         # Destroy policy_prev_chunk to free memory for large models
-        self._destroy_policy(self.policy_prev_chunk, "policy_prev_chunk")
+        self._destroy_policy(policy_prev_chunk_policy, "policy_prev_chunk")

         # Sample noise (use same noise for both RTC and non-RTC for fair comparison)
-        noise_size = (1, self.policy_no_rtc.config.chunk_size, self.policy_no_rtc.config.max_action_dim)
-        noise = self.policy_no_rtc.model.sample_noise(noise_size, self.device)
+        noise_size = (1, policy_no_rtc_policy.config.chunk_size, policy_no_rtc_policy.config.max_action_dim)
+        noise = policy_no_rtc_policy.model.sample_noise(noise_size, self.device)
         noise_clone = noise.clone()

         # Step 2: Generate actions WITHOUT RTC using policy_no_rtc
         logging.info("Step 2: Generating actions WITHOUT RTC with policy_no_rtc")
-        self.policy_no_rtc.rtc_processor.reset_tracker()
+        policy_no_rtc_policy.rtc_processor.reset_tracker()
         with torch.no_grad():
-            _ = self.policy_no_rtc.predict_action_chunk(
+            _ = policy_no_rtc_policy.predict_action_chunk(
                 preprocessed_second_sample,
                 noise=noise,
             )
-        no_rtc_tracked_steps = self.policy_no_rtc.rtc_processor.tracker.get_all_steps()
+        no_rtc_tracked_steps = policy_no_rtc_policy.rtc_processor.tracker.get_all_steps()
         logging.info(f" Tracked {len(no_rtc_tracked_steps)} steps without RTC")

         # Destroy policy_no_rtc to free memory before loading policy_rtc
-        self._destroy_policy(self.policy_no_rtc, "policy_no_rtc")
+        self._destroy_policy(policy_no_rtc_policy, "policy_no_rtc")

         # Step 3: Generate actions WITH RTC using policy_rtc
         logging.info("Step 3: Generating actions WITH RTC with policy_rtc")
-        self.policy_rtc.rtc_processor.reset_tracker()
+        policy_rtc_policy.rtc_processor.reset_tracker()
         with torch.no_grad():
-            _ = self.policy_rtc.predict_action_chunk(
+            _ = policy_rtc_policy.predict_action_chunk(
                 preprocessed_second_sample,
                 noise=noise_clone,
                 inference_delay=self.cfg.inference_delay,
                 prev_chunk_left_over=prev_chunk_left_over,
                 execution_horizon=self.cfg.rtc.execution_horizon,
             )
-        rtc_tracked_steps = self.policy_rtc.rtc_processor.get_all_debug_steps()
+        rtc_tracked_steps = policy_rtc_policy.rtc_processor.get_all_debug_steps()
         logging.info(f" Tracked {len(rtc_tracked_steps)} steps with RTC")

         # Save num_steps before destroying policy (needed for plotting)
-        num_steps = self.policy_rtc.config.num_steps
+        num_steps = policy_rtc_policy.config.num_steps

         # Destroy policy_rtc after final use
-        self._destroy_policy(self.policy_rtc, "policy_rtc")
+        self._destroy_policy(policy_rtc_policy, "policy_rtc")

         # Plot and save results
         logging.info("=" * 80)
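The hunk above moves the three policies from eager construction in __init__ to an init, use, destroy sequence, so only one model is resident at a time, and it feeds the same noise tensor (plus a clone) to the RTC and non-RTC runs so both start from identical noise. A minimal sketch of that memory lifecycle, assuming _destroy_policy amounts to moving weights off the GPU and emptying the CUDA cache (an assumption; the repo's implementation is not shown in this diff):

import gc
import torch

def destroy_policy(policy: torch.nn.Module) -> None:
    # Assumed cleanup: move weights to CPU, then reclaim cached GPU memory.
    # The caller must also drop its own reference for Python to free the module.
    policy.to("cpu")
    if torch.cuda.is_available():
        torch.cuda.empty_cache()
    gc.collect()

# Lifecycle mirrored from the hunk: init -> use -> destroy, one policy at a time.
policy = torch.nn.Linear(8, 8)  # stand-in for _init_policy(...)
with torch.no_grad():
    _ = policy(torch.randn(1, 8))
destroy_policy(policy)
del policy  # drop the last reference so the module can be garbage collected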
@@ -19,6 +19,7 @@
 from dataclasses import dataclass, field
 from typing import Any

+import torch
 from torch import Tensor


@@ -120,6 +121,7 @@ class Tracker:
         self._steps.clear()
         self._step_counter = 0

+    @torch._dynamo.disable
     def track(
         self,
         time: float | Tensor,
@@ -139,6 +141,9 @@ class Tracker:
         If a step with the given time already exists, it will be updated with the new data.
         Otherwise, a new step will be created. Only non-None fields are updated/set.

+        Note: This method is excluded from torch.compile to avoid graph breaks from
+        operations like .item() which are incompatible with compiled graphs.
+
         Args:
             time (float | Tensor): Time parameter - used as the key to identify the step.
             x_t (Tensor | None): Current latent/state tensor.
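The @torch._dynamo.disable decorator added above keeps Tracker.track() out of any compiled graph, so calls like .item() run eagerly instead of breaking compilation of the caller. A minimal sketch of the effect, with Tracker reduced to the relevant shape (not the repo's class):

import torch

class Tracker:
    def __init__(self) -> None:
        self._steps: dict[float, torch.Tensor] = {}

    @torch._dynamo.disable
    def track(self, time: float | torch.Tensor, x_t: torch.Tensor) -> None:
        # .item() would force a graph break under torch.compile; disabling
        # dynamo on this method lets the bookkeeping run eagerly instead.
        key = time.item() if isinstance(time, torch.Tensor) else time
        self._steps[key] = x_t.detach().clone()

tracker = Tracker()

@torch.compile
def step(x: torch.Tensor, t: torch.Tensor) -> torch.Tensor:
    y = x * 2.0
    tracker.track(t, y)  # executes outside the compiled graph
    return y + 1.0

out = step(torch.randn(4), torch.tensor(0.5))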