Small fixes

Eugene Mironov
2025-11-08 16:38:13 +07:00
parent ab0a9c3d7a
commit ac33f20e51
3 changed files with 305 additions and 232 deletions
@@ -51,6 +51,7 @@ from lerobot.datasets.utils import build_dataset_frame, hw_to_dataset_features
from lerobot.envs.configs import EnvConfig # noqa: F401
from lerobot.envs.factory import make_env
from lerobot.policies.factory import get_policy_class, make_pre_post_processors
from lerobot.policies.rtc.action_queue import ActionQueue
from lerobot.policies.rtc.configuration_rtc import RTCConfig
from lerobot.policies.rtc.latency_tracker import LatencyTracker
from lerobot.processor.factory import (
@@ -287,103 +288,6 @@ class EnvWrapper:
return self._action_features
class ActionQueue:
def __init__(self, cfg: RTCConfig):
self.queue = None # Processed actions for robot rollout
self.original_queue = None # Original actions for RTC
self.lock = Lock()
self.last_index = 0
self.cfg = cfg
def get(self) -> Tensor | None:
with self.lock:
if self.queue is None or self.last_index >= len(self.queue):
return None
action = self.queue[self.last_index]
self.last_index += 1
return action.clone()
def qsize(self) -> int:
# with self.lock:
if self.queue is None:
return 0
length = len(self.queue)
return length - self.last_index
def empty(self) -> bool:
# with self.lock:
if self.queue is None:
return True
length = len(self.queue)
return length - self.last_index + 1 <= 0
def get_action_index(self) -> int:
# with self.lock:
return self.last_index
def get_left_over(self) -> Tensor | None:
"""Get leftover ORIGINAL actions for RTC prev_chunk_left_over."""
with self.lock:
if self.original_queue is None:
return None
return self.original_queue[self.last_index :]
def merge(
self,
original_actions: Tensor,
processed_actions: Tensor,
real_delay: int,
action_index_before_inference: int | None = 0,
):
with self.lock:
self._check_delays(real_delay, action_index_before_inference)
if self.cfg.enabled:
self._replace_actions_queue(original_actions, processed_actions, real_delay)
return
self._append_actions_queue(original_actions, processed_actions)
def _replace_actions_queue(self, original_actions: Tensor, processed_actions: Tensor, real_delay: int):
self.original_queue = original_actions[real_delay:].clone()
self.queue = processed_actions[real_delay:].clone()
logger.info(f"original_actions shape: {self.original_queue.shape}")
logger.info(f"processed_actions shape: {self.queue.shape}")
logger.info(f"real_delay: {real_delay}")
self.last_index = 0
def _append_actions_queue(self, original_actions: Tensor, processed_actions: Tensor):
if self.queue is None:
self.original_queue = original_actions.clone()
self.queue = processed_actions.clone()
return
self.original_queue = torch.cat([self.original_queue, original_actions.clone()])
self.original_queue = self.original_queue[self.last_index :]
self.queue = torch.cat([self.queue, processed_actions.clone()])
self.queue = self.queue[self.last_index :]
self.last_index = 0
def _check_delays(self, real_delay: int, action_index_before_inference: int | None = None):
if action_index_before_inference is None:
return
indexes_diff = self.last_index - action_index_before_inference
if indexes_diff != real_delay:
# Check that the action index difference (the real delay computed from the
# action queue) matches the delay computed from inference latency
logger.warning(
f"[ACTION_QUEUE] Indexes diff is not equal to real delay. Indexes diff: {indexes_diff}, real delay: {real_delay}"
)
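An editor's usage sketch (not part of this commit): the queue is shared between an inference thread that refills it via merge() and a control thread that drains it via get(). `policy_infer`, `obs`, and `robot.step` below are hypothetical placeholders.

queue = ActionQueue(RTCConfig(enabled=True))

# Inference thread: snapshot the index, fetch leftover originals, merge the new chunk.
idx_before = queue.get_action_index()
prev_chunk = queue.get_left_over()
original, processed = policy_infer(obs, prev_chunk)  # hypothetical policy call
queue.merge(original, processed, real_delay=2, action_index_before_inference=idx_before)

# Control thread: pop one processed action per step until the queue drains.
while (action := queue.get()) is not None:
    robot.step(action)  # hypothetical robot command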
@dataclass
class RTCDemoConfig(HubMixin):
"""Configuration for RTC demo with action chunking policies."""
@@ -413,16 +317,6 @@ class RTCDemoConfig(HubMixin):
# Compute device
device: str | None = None # Device to run on (cuda, cpu, auto)
# Compilation options
compile_policy: bool = (
False # Compile policy with torch.compile() for faster inference (not supported on MPS)
)
compile_mode: str = "default" # Compilation mode: default, reduce-overhead, max-autotune
# Alternative optimization options (work on all devices including MPS)
use_channels_last: bool = False # Use channels_last memory format for images (faster on some devices)
enable_cudnn_benchmark: bool = True # Enable cuDNN benchmarking (CUDA only)
# New-actions horizon: the number of executed steps after which new actions are requested.
# It should be greater than the inference delay plus the execution horizon.
action_queue_size_to_get_new_actions: int = 30
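As an illustration with assumed numbers (not from this commit): with an inference delay of about 5 steps and an execution horizon of 20 steps, a refill must start while at least 25 queued steps remain, so the default of 30 leaves a small safety margin.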
@@ -430,8 +324,29 @@ class RTCDemoConfig(HubMixin):
# Task to execute
task: str = field(default="", metadata={"help": "Task to execute"})
# Debug options
verbose_rtc_comparison: bool = True # Enable detailed RTC comparison output
# Torch compile configuration
use_torch_compile: bool = field(
default=False,
metadata={"help": "Use torch.compile for faster inference (PyTorch 2.0+)"},
)
torch_compile_backend: str = field(
default="inductor",
metadata={"help": "Backend for torch.compile (inductor, aot_eager, cudagraphs)"},
)
torch_compile_mode: str = field(
default="default",
metadata={"help": "Compilation mode (default, reduce-overhead, max-autotune)"},
)
torch_compile_disable_cudagraphs: bool = field(
default=True,
metadata={
"help": "Disable CUDA graphs in torch.compile. Required due to in-place tensor "
"operations in denoising loop (x_t += dt * v_t) which cause tensor aliasing issues."
},
)
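A minimal sketch of the aliasing hazard this option guards against (editor's illustration of a plain Euler denoising loop; `model`, `obs`, `noise`, `dt`, and `num_steps` are assumed placeholders):

x_t = noise
for _ in range(num_steps):
    v_t = model(x_t, obs)  # hypothetical denoiser call
    # In-place update: under CUDA graph capture the buffer address is baked
    # into the graph, so replays can read stale aliased memory unless
    # cudagraphs are disabled.
    x_t += dt * v_t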
def __post_init__(self):
# HACK: We parse the CLI args again here to get the pretrained path if one was given.
@@ -544,25 +459,9 @@ def get_actions(
preproceseded_obs = preprocessor(obs_with_policy_features)
noise_size = (1, policy.config.chunk_size, policy.config.max_action_dim)
noise = policy.model.sample_noise(noise_size, policy_device)
noise_clone = noise.clone()
# Generate actions WITHOUT RTC for comparison (if verbose mode enabled)
if cfg.verbose_rtc_comparison:
policy.config.rtc_config.enabled = False
not_rtc_actions = policy.predict_action_chunk(
preproceseded_obs,
noise=noise,
inference_delay=inference_delay,
prev_chunk_left_over=prev_actions,
)
policy.config.rtc_config.enabled = True
# Generate actions WITH RTC
actions = policy.predict_action_chunk(
preproceseded_obs,
noise=noise_clone if cfg.verbose_rtc_comparison else noise,
inference_delay=inference_delay,
prev_chunk_left_over=prev_actions,
)
@@ -570,34 +469,6 @@ def get_actions(
# Store original actions (before postprocessing) for RTC
original_actions = actions.squeeze(0).clone()
# Detailed comparison output (if verbose mode enabled)
if cfg.verbose_rtc_comparison:
logger.info("=" * 80)
logger.info("RTC ACTION COMPARISON")
logger.info("=" * 80)
# Print detailed statistics
logger.info("\n" + tensor_stats_str(not_rtc_actions, "not_rtc_actions (without RTC)"))
logger.info("\n" + tensor_stats_str(actions, "actions (with RTC)"))
logger.info(
"\n" + tensor_stats_str(prev_actions, "prev_actions (leftover from previous chunk)")
)
# Compare RTC vs non-RTC actions
logger.info(
compare_tensors(actions, not_rtc_actions, "actions (RTC)", "not_rtc_actions (no RTC)")
)
to_non_rtc_diff = actions - not_rtc_actions
print("to_non_rtc_diff", to_non_rtc_diff)
if prev_actions is not None:
prev_padded = torch.zeros_like(actions)
prev_padded[:, : prev_actions.shape[1], :] = prev_actions
to_prev_diff = actions - prev_padded
print("to_prev_diff", to_prev_diff)
print("=" * 80)
postprocessed_actions = postprocessor(actions)
postprocessed_actions = postprocessed_actions.squeeze(0)
@@ -611,11 +482,6 @@ def get_actions(
"[GET_ACTIONS] cfg.action_queue_size_to_get_new_actions Too small, It should be higher than inference delay + execution horizon."
)
logger.debug(f"[GET_ACTIONS] new_delay: {new_delay}")
logger.debug(f"[GET_ACTIONS] original_actions shape: {original_actions.shape}")
logger.debug(f"[GET_ACTIONS] postprocessed_actions shape: {postprocessed_actions.shape}")
logger.debug(f"[GET_ACTIONS] action_index_before_inference: {action_index_before_inference}")
action_queue.merge(
original_actions, postprocessed_actions, new_delay, action_index_before_inference
)
@@ -675,10 +541,56 @@ def actor_control(
sys.exit(1)
def stop_by_duration(shutdown_event: Event, cfg: RTCDemoConfig):
"""Stop the demo by duration."""
time.sleep(cfg.duration)
shutdown_event.set()
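A minimal sketch of the assumed wiring (not shown in this diff; assumes `from threading import Event, Thread`): the timer runs on a daemon thread so it cannot block process exit.

shutdown_event = Event()
Thread(target=stop_by_duration, args=(shutdown_event, cfg), daemon=True).start()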
def _apply_torch_compile(policy, cfg: RTCDemoConfig):
"""Apply torch.compile to the policy's predict_action_chunk method.
Args:
policy: Policy instance to compile
cfg: Configuration containing torch compile settings
Returns:
Policy with compiled predict_action_chunk method
"""
# PI models handle their own compilation
if policy.type in ("pi05", "pi0"):
return policy
try:
# Check if torch.compile is available (PyTorch 2.0+)
if not hasattr(torch, "compile"):
logger.warning(
f"torch.compile is not available. Requires PyTorch 2.0+. "
f"Current version: {torch.__version__}. Skipping compilation."
)
return policy
logger.info("Applying torch.compile to predict_action_chunk...")
logger.info(f" Backend: {cfg.torch_compile_backend}")
logger.info(f" Mode: {cfg.torch_compile_mode}")
logger.info(f" Disable CUDA graphs: {cfg.torch_compile_disable_cudagraphs}")
# Compile the predict_action_chunk method
# - CUDA graphs disabled to prevent tensor aliasing from in-place ops (x_t += dt * v_t)
compile_kwargs = {
"backend": cfg.torch_compile_backend,
"mode": cfg.torch_compile_mode,
}
# Disable CUDA graphs if requested (prevents tensor aliasing issues)
if cfg.torch_compile_disable_cudagraphs:
compile_kwargs["options"] = {"triton.cudagraphs": False}
original_method = policy.predict_action_chunk
compiled_method = torch.compile(original_method, **compile_kwargs)
policy.predict_action_chunk = compiled_method
logger.info("✓ Successfully compiled predict_action_chunk")
except Exception as e:
logger.error(f"Failed to apply torch.compile: {e}")
logger.warning("Continuing without torch.compile")
return policy
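Because torch.compile traces lazily, the first call pays the compilation cost; a hedged usage sketch (the warmup input `warmup_obs` is hypothetical):

policy = _apply_torch_compile(policy, cfg)
with torch.no_grad():
    policy.predict_action_chunk(warmup_obs)  # first call triggers tracing; time later calls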
@parser.wrap()
@@ -701,61 +613,30 @@ def demo_cli(cfg: RTCDemoConfig):
actor_thread = None
policy_class = get_policy_class(cfg.policy.type)
policy = policy_class.from_pretrained(cfg.policy.pretrained_path)
# Load config and set compile_model for pi0/pi05 models
config = PreTrainedConfig.from_pretrained(cfg.policy.pretrained_path)
if cfg.policy.type in ("pi05", "pi0"):
config.compile_model = cfg.use_torch_compile
policy = policy_class.from_pretrained(cfg.policy.pretrained_path, config=config)
# Turn on RTC
policy.config.rtc_config = cfg.rtc
# Initialize the RTC processor explicitly: by default it is not created
# when RTC is disabled in the config
policy.init_rtc_processor(verbose=cfg.verbose_rtc_comparison)
policy.init_rtc_processor()
assert policy.name in ["smolvla"], "Only smolvla is supported for RTC"
policy = policy.to(cfg.device)
policy.eval()
# Apply memory format optimizations
if cfg.use_channels_last:
logger.info("Converting model to channels_last memory format")
try:
# Convert vision encoder to channels_last for better performance
if hasattr(policy, "vision_encoder"):
policy.vision_encoder = policy.vision_encoder.to(memory_format=torch.channels_last)
logger.info("Successfully converted to channels_last format")
except Exception as e:
logger.warning(f"Failed to convert to channels_last: {e}")
# Enable cuDNN benchmarking for CUDA
if cfg.enable_cudnn_benchmark and cfg.device == "cuda":
torch.backends.cudnn.benchmark = True
logger.info("Enabled cuDNN benchmarking")
# Compile policy if requested
if cfg.compile_policy:
# Check if device is MPS - torch.compile has issues with MPS backend
if cfg.device == "mps":
logger.warning("torch.compile() is not stable with MPS backend (Apple Silicon)")
logger.warning("Skipping compilation. For better performance on MPS:")
logger.warning(" 1. Use torch.float32 instead of bfloat16")
logger.warning(" 2. Ensure model uses contiguous memory layouts")
logger.warning(" 3. Consider using CUDA if available")
else:
logger.info(f"Compiling policy with mode: {cfg.compile_mode}")
logger.info("First inference will be slower due to compilation, subsequent calls will be faster")
try:
# Compile the predict_action_chunk method
policy.predict_action_chunk = torch.compile(
policy.predict_action_chunk,
mode=cfg.compile_mode,
fullgraph=False, # Allow graph breaks for flexibility
backend="inductor", # Use inductor backend
)
logger.info("Policy compiled successfully")
except Exception as e:
logger.warning(f"Failed to compile policy: {e}")
logger.warning("Continuing without compilation")
# Apply torch.compile to predict_action_chunk method if enabled
if cfg.use_torch_compile:
policy = _apply_torch_compile(policy, cfg)
# Create robot or environment
if cfg.robot is not None: