Small fixes
+86 -205
@@ -51,6 +51,7 @@ from lerobot.datasets.utils import build_dataset_frame, hw_to_dataset_features
from lerobot.envs.configs import EnvConfig  # noqa: F401
from lerobot.envs.factory import make_env
from lerobot.policies.factory import get_policy_class, make_pre_post_processors
from lerobot.policies.rtc.action_queue import ActionQueue
from lerobot.policies.rtc.configuration_rtc import RTCConfig
from lerobot.policies.rtc.latency_tracker import LatencyTracker
from lerobot.processor.factory import (
@@ -287,103 +288,6 @@ class EnvWrapper:
        return self._action_features


class ActionQueue:
    def __init__(self, cfg: RTCConfig):
        self.queue = None  # Processed actions for robot rollout
        self.original_queue = None  # Original actions for RTC
        self.lock = Lock()
        self.last_index = 0
        self.cfg = cfg

    def get(self) -> Tensor | None:
        with self.lock:
            if self.queue is None or self.last_index >= len(self.queue):
                return None

            action = self.queue[self.last_index]
            self.last_index += 1
            return action.clone()

    def qsize(self) -> int:
        # with self.lock:
        if self.queue is None:
            return 0
        length = len(self.queue)

        return length - self.last_index

    def empty(self) -> bool:
        # with self.lock:
        if self.queue is None:
            return True

        length = len(self.queue)
        return length - self.last_index <= 0

    def get_action_index(self) -> int:
        # with self.lock:
        return self.last_index

    def get_left_over(self) -> Tensor | None:
        """Get leftover ORIGINAL actions for RTC prev_chunk_left_over."""
        with self.lock:
            if self.original_queue is None:
                return None
            return self.original_queue[self.last_index :]

    def merge(
        self,
        original_actions: Tensor,
        processed_actions: Tensor,
        real_delay: int,
        action_index_before_inference: int | None = 0,
    ):
        with self.lock:
            self._check_delays(real_delay, action_index_before_inference)

            if self.cfg.enabled:
                self._replace_actions_queue(original_actions, processed_actions, real_delay)
                return

            self._append_actions_queue(original_actions, processed_actions)

    def _replace_actions_queue(self, original_actions: Tensor, processed_actions: Tensor, real_delay: int):
        self.original_queue = original_actions[real_delay:].clone()
        self.queue = processed_actions[real_delay:].clone()

        logger.info(f"original_actions shape: {self.original_queue.shape}")
        logger.info(f"processed_actions shape: {self.queue.shape}")
        logger.info(f"real_delay: {real_delay}")

        self.last_index = 0

    def _append_actions_queue(self, original_actions: Tensor, processed_actions: Tensor):
        if self.queue is None:
            self.original_queue = original_actions.clone()
            self.queue = processed_actions.clone()
            return

        self.original_queue = torch.cat([self.original_queue, original_actions.clone()])
        self.original_queue = self.original_queue[self.last_index :]

        self.queue = torch.cat([self.queue, processed_actions.clone()])
        self.queue = self.queue[self.last_index :]

        self.last_index = 0

    def _check_delays(self, real_delay: int, action_index_before_inference: int | None = None):
        if action_index_before_inference is None:
            return

        # Check that the action index difference (the real delay calculated from the
        # action queue) matches the delay calculated from inference latency.
        indexes_diff = self.last_index - action_index_before_inference
        if indexes_diff != real_delay:
            logger.warning(
                f"[ACTION_QUEUE] Indexes diff is not equal to real delay. Indexes diff: {indexes_diff}, real delay: {real_delay}"
            )


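For orientation, a minimal usage sketch of this queue, which is an assumption and not part of the commit: one thread refills it with merge() while the control loop drains it with get(). The RTCConfig(enabled=True) constructor call and the tensor shapes are illustrative.

# Hypothetical usage sketch — shapes and constructor kwargs are assumed
import torch

queue = ActionQueue(RTCConfig(enabled=True))

chunk = torch.randn(50, 7)  # (chunk_size, action_dim), illustrative
queue.merge(original_actions=chunk, processed_actions=chunk.clone(), real_delay=0)

while not queue.empty():
    action = queue.get()  # returns a clone of the next action, or None once drained
    # hand `action` to the robot control loop here
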
@dataclass
class RTCDemoConfig(HubMixin):
    """Configuration for RTC demo with action chunking policies."""
@@ -413,16 +317,6 @@ class RTCDemoConfig(HubMixin):
    # Compute device
    device: str | None = None  # Device to run on (cuda, cpu, auto)

    # Compilation options
    compile_policy: bool = (
        False  # Compile policy with torch.compile() for faster inference (not supported on MPS)
    )
    compile_mode: str = "default"  # Compilation mode: default, reduce-overhead, max-autotune

    # Alternative optimization options (work on all devices, including MPS)
    use_channels_last: bool = False  # Use channels_last memory format for images (faster on some devices)
    enable_cudnn_benchmark: bool = True  # Enable cuDNN benchmarking (CUDA only)

    # New-actions horizon: the number of executed steps after which new actions are requested.
    # It should be higher than the inference delay plus the execution horizon.
    action_queue_size_to_get_new_actions: int = 30
@@ -430,8 +324,29 @@ class RTCDemoConfig(HubMixin):
    # Task to execute
    task: str = field(default="", metadata={"help": "Task to execute"})

    # Debug options
    verbose_rtc_comparison: bool = True  # Enable detailed RTC comparison output

    # Torch compile configuration
    use_torch_compile: bool = field(
        default=False,
        metadata={"help": "Use torch.compile for faster inference (PyTorch 2.0+)"},
    )

    torch_compile_backend: str = field(
        default="inductor",
        metadata={"help": "Backend for torch.compile (inductor, aot_eager, cudagraphs)"},
    )

    torch_compile_mode: str = field(
        default="default",
        metadata={"help": "Compilation mode (default, reduce-overhead, max-autotune)"},
    )

    torch_compile_disable_cudagraphs: bool = field(
        default=True,
        metadata={
            "help": "Disable CUDA graphs in torch.compile. Required due to in-place tensor "
            "operations in the denoising loop (x_t += dt * v_t), which cause tensor aliasing issues."
        },
    )

    def __post_init__(self):
        # HACK: We parse the CLI args again here to get the pretrained path if there was one.
@@ -544,25 +459,9 @@ def get_actions(
    preprocessed_obs = preprocessor(obs_with_policy_features)

    noise_size = (1, policy.config.chunk_size, policy.config.max_action_dim)
    noise = policy.model.sample_noise(noise_size, policy_device)
    noise_clone = noise.clone()

    # Generate actions WITHOUT RTC for comparison (if verbose mode enabled)
    if cfg.verbose_rtc_comparison:
        policy.config.rtc_config.enabled = False
        not_rtc_actions = policy.predict_action_chunk(
            preprocessed_obs,
            noise=noise,
            inference_delay=inference_delay,
            prev_chunk_left_over=prev_actions,
        )
        policy.config.rtc_config.enabled = True

    # Generate actions WITH RTC
    actions = policy.predict_action_chunk(
        preprocessed_obs,
        noise=noise_clone if cfg.verbose_rtc_comparison else noise,
        inference_delay=inference_delay,
        prev_chunk_left_over=prev_actions,
    )
@@ -570,34 +469,6 @@ def get_actions(
    # Store original actions (before postprocessing) for RTC
    original_actions = actions.squeeze(0).clone()

    # Detailed comparison output (if verbose mode enabled)
    if cfg.verbose_rtc_comparison:
        logger.info("=" * 80)
        logger.info("RTC ACTION COMPARISON")
        logger.info("=" * 80)

        # Print detailed statistics
        logger.info("\n" + tensor_stats_str(not_rtc_actions, "not_rtc_actions (without RTC)"))
        logger.info("\n" + tensor_stats_str(actions, "actions (with RTC)"))
        logger.info(
            "\n" + tensor_stats_str(prev_actions, "prev_actions (leftover from previous chunk)")
        )

        # Compare RTC vs non-RTC actions
        logger.info(
            compare_tensors(actions, not_rtc_actions, "actions (RTC)", "not_rtc_actions (no RTC)")
        )

        to_non_rtc_diff = actions - not_rtc_actions

        print("to_non_rtc_diff", to_non_rtc_diff)
        if prev_actions is not None:
            prev_padded = torch.zeros_like(actions)
            prev_padded[:, : prev_actions.shape[1], :] = prev_actions
            to_prev_diff = actions - prev_padded
            print("to_prev_diff", to_prev_diff)
        print("=" * 80)

    postprocessed_actions = postprocessor(actions)

    postprocessed_actions = postprocessed_actions.squeeze(0)
@@ -611,11 +482,6 @@ def get_actions(
        "[GET_ACTIONS] cfg.action_queue_size_to_get_new_actions is too small; it should be higher than the inference delay plus the execution horizon."
    )

    logger.debug(f"[GET_ACTIONS] new_delay: {new_delay}")
    logger.debug(f"[GET_ACTIONS] original_actions shape: {original_actions.shape}")
    logger.debug(f"[GET_ACTIONS] postprocessed_actions shape: {postprocessed_actions.shape}")
    logger.debug(f"[GET_ACTIONS] action_index_before_inference: {action_index_before_inference}")

    action_queue.merge(
        original_actions, postprocessed_actions, new_delay, action_index_before_inference
    )

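To make the delay cross-check in _check_delays concrete, here is a hedged sketch of the calling pattern implied above; only the ActionQueue methods come from this file, and the new_delay value is a placeholder. The caller snapshots the queue index before inference, and merge() compares how far the index advanced against the latency-derived delay.

# Illustrative calling pattern — assumed, not shown in full in this diff
action_index_before_inference = action_queue.get_action_index()

# ... policy inference runs here while the control loop keeps consuming actions ...

new_delay = 3  # placeholder: steps elapsed during inference, derived from measured latency
action_queue.merge(original_actions, postprocessed_actions, new_delay, action_index_before_inference)
# _check_delays() warns if (current_index - snapshot) != new_delay
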
@@ -675,10 +541,56 @@ def actor_control(
        sys.exit(1)


def stop_by_duration(shutdown_event: Event, cfg: RTCDemoConfig):
    """Stop the demo after the configured duration."""
    time.sleep(cfg.duration)
    shutdown_event.set()

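A hedged sketch of how this helper is presumably wired up; the thread wiring is an assumption, not shown in this hunk. The timer runs on a daemon thread so it cannot block process shutdown.

# Hypothetical wiring — assumed, not part of this diff
from threading import Event, Thread

shutdown_event = Event()
timer_thread = Thread(
    target=stop_by_duration,
    args=(shutdown_event, cfg),  # cfg.duration seconds until shutdown
    daemon=True,  # never blocks interpreter exit
)
timer_thread.start()
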
def _apply_torch_compile(policy, cfg: RTCDemoConfig):
    """Apply torch.compile to the policy's predict_action_chunk method.

    Args:
        policy: Policy instance to compile
        cfg: Configuration containing torch compile settings

    Returns:
        Policy with compiled predict_action_chunk method
    """
    # PI models handle their own compilation
    if policy.type == "pi05" or policy.type == "pi0":
        return policy

    try:
        # Check if torch.compile is available (PyTorch 2.0+)
        if not hasattr(torch, "compile"):
            logger.warning(
                "torch.compile is not available. Requires PyTorch 2.0+. "
                f"Current version: {torch.__version__}. Skipping compilation."
            )
            return policy

        logger.info("Applying torch.compile to predict_action_chunk...")
        logger.info(f"  Backend: {cfg.torch_compile_backend}")
        logger.info(f"  Mode: {cfg.torch_compile_mode}")
        logger.info(f"  Disable CUDA graphs: {cfg.torch_compile_disable_cudagraphs}")

        # Compile the predict_action_chunk method
        # - CUDA graphs disabled to prevent tensor aliasing from in-place ops (x_t += dt * v_t)
        compile_kwargs = {
            "backend": cfg.torch_compile_backend,
            "mode": cfg.torch_compile_mode,
        }

        # Disable CUDA graphs if requested (prevents tensor aliasing issues).
        # torch.compile rejects `mode` and `options` passed together, so switch
        # to options-only in that case.
        if cfg.torch_compile_disable_cudagraphs:
            compile_kwargs.pop("mode", None)
            compile_kwargs["options"] = {"triton.cudagraphs": False}

        original_method = policy.predict_action_chunk
        compiled_method = torch.compile(original_method, **compile_kwargs)
        policy.predict_action_chunk = compiled_method
        logger.info("✓ Successfully compiled predict_action_chunk")

    except Exception as e:
        logger.error(f"Failed to apply torch.compile: {e}")
        logger.warning("Continuing without torch.compile")

    return policy

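As a standalone illustration of the options knob used above, here is a minimal sketch compiling a toy function with CUDA graphs disabled; the step function and shapes are invented for the example. It passes options only, since torch.compile rejects mode and options together.

# Minimal standalone sketch — `step` is an invented toy function
import torch

def step(x_t: torch.Tensor, v_t: torch.Tensor, dt: float) -> torch.Tensor:
    x_t += dt * v_t  # the in-place update pattern that motivates disabling CUDA graphs
    return x_t

compiled_step = torch.compile(
    step,
    backend="inductor",
    options={"triton.cudagraphs": False},  # cannot be combined with `mode`
)

out = compiled_step(torch.zeros(4), torch.ones(4), 0.1)
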
@parser.wrap()
@@ -701,61 +613,30 @@ def demo_cli(cfg: RTCDemoConfig):
    actor_thread = None

    policy_class = get_policy_class(cfg.policy.type)

    # Load config and set compile_model for pi0/pi05 models
    config = PreTrainedConfig.from_pretrained(cfg.policy.pretrained_path)

    if cfg.policy.type == "pi05" or cfg.policy.type == "pi0":
        config.compile_model = cfg.use_torch_compile

    policy = policy_class.from_pretrained(cfg.policy.pretrained_path, config=config)

    # Turn on RTC
    policy.config.rtc_config = cfg.rtc

    # Init the RTC processor explicitly; by default, if RTC is disabled in the
    # config, the processor won't be created.
    policy.init_rtc_processor()

    assert policy.name in ["smolvla"], "Only smolvla is supported for RTC"

    policy = policy.to(cfg.device)
    policy.eval()

    # Apply memory format optimizations
    if cfg.use_channels_last:
        logger.info("Converting model to channels_last memory format")
        try:
            # Convert vision encoder to channels_last for better performance
            if hasattr(policy, "vision_encoder"):
                policy.vision_encoder = policy.vision_encoder.to(memory_format=torch.channels_last)
                logger.info("Successfully converted to channels_last format")
        except Exception as e:
            logger.warning(f"Failed to convert to channels_last: {e}")

    # Enable cuDNN benchmarking for CUDA
    if cfg.enable_cudnn_benchmark and cfg.device == "cuda":
        torch.backends.cudnn.benchmark = True
        logger.info("Enabled cuDNN benchmarking")

    # Compile policy if requested
    if cfg.compile_policy:
        # Check if device is MPS - torch.compile has issues with the MPS backend
        if cfg.device == "mps":
            logger.warning("torch.compile() is not stable with MPS backend (Apple Silicon)")
            logger.warning("Skipping compilation. For better performance on MPS:")
            logger.warning(" 1. Use torch.float32 instead of bfloat16")
            logger.warning(" 2. Ensure model uses contiguous memory layouts")
            logger.warning(" 3. Consider using CUDA if available")
        else:
            logger.info(f"Compiling policy with mode: {cfg.compile_mode}")
            logger.info("First inference will be slower due to compilation, subsequent calls will be faster")

            try:
                # Compile the predict_action_chunk method
                policy.predict_action_chunk = torch.compile(
                    policy.predict_action_chunk,
                    mode=cfg.compile_mode,
                    fullgraph=False,  # Allow graph breaks for flexibility
                    backend="inductor",  # Use inductor backend
                )
                logger.info("Policy compiled successfully")
            except Exception as e:
                logger.warning(f"Failed to compile policy: {e}")
                logger.warning("Continuing without compilation")

    # Apply torch.compile to predict_action_chunk method if enabled
    if cfg.use_torch_compile:
        policy = _apply_torch_compile(policy, cfg)

# Create robot or environment
if cfg.robot is not None: