mirror of
https://github.com/huggingface/lerobot.git
synced 2026-05-20 19:19:56 +00:00
Small fixes
This commit is contained in:
+86
-205
@@ -51,6 +51,7 @@ from lerobot.datasets.utils import build_dataset_frame, hw_to_dataset_features
|
|||||||
from lerobot.envs.configs import EnvConfig # noqa: F401
|
from lerobot.envs.configs import EnvConfig # noqa: F401
|
||||||
from lerobot.envs.factory import make_env
|
from lerobot.envs.factory import make_env
|
||||||
from lerobot.policies.factory import get_policy_class, make_pre_post_processors
|
from lerobot.policies.factory import get_policy_class, make_pre_post_processors
|
||||||
|
from lerobot.policies.rtc.action_queue import ActionQueue
|
||||||
from lerobot.policies.rtc.configuration_rtc import RTCConfig
|
from lerobot.policies.rtc.configuration_rtc import RTCConfig
|
||||||
from lerobot.policies.rtc.latency_tracker import LatencyTracker
|
from lerobot.policies.rtc.latency_tracker import LatencyTracker
|
||||||
from lerobot.processor.factory import (
|
from lerobot.processor.factory import (
|
||||||
@@ -287,103 +288,6 @@ class EnvWrapper:
|
|||||||
return self._action_features
|
return self._action_features
|
||||||
|
|
||||||
|
|
||||||
class ActionQueue:
|
|
||||||
def __init__(self, cfg: RTCConfig):
|
|
||||||
self.queue = None # Processed actions for robot rollout
|
|
||||||
self.original_queue = None # Original actions for RTC
|
|
||||||
self.lock = Lock()
|
|
||||||
self.last_index = 0
|
|
||||||
self.cfg = cfg
|
|
||||||
|
|
||||||
def get(self) -> Tensor | None:
|
|
||||||
with self.lock:
|
|
||||||
if self.queue is None or self.last_index >= len(self.queue):
|
|
||||||
return None
|
|
||||||
|
|
||||||
action = self.queue[self.last_index]
|
|
||||||
self.last_index += 1
|
|
||||||
return action.clone()
|
|
||||||
|
|
||||||
def qsize(self) -> int:
|
|
||||||
# with self.lock:
|
|
||||||
if self.queue is None:
|
|
||||||
return 0
|
|
||||||
length = len(self.queue)
|
|
||||||
|
|
||||||
return length - self.last_index
|
|
||||||
|
|
||||||
def empty(self) -> bool:
|
|
||||||
# with self.lock:
|
|
||||||
if self.queue is None:
|
|
||||||
return True
|
|
||||||
|
|
||||||
length = len(self.queue)
|
|
||||||
return length - self.last_index + 1 <= 0
|
|
||||||
|
|
||||||
def get_action_index(self) -> int:
|
|
||||||
# with self.lock:
|
|
||||||
return self.last_index
|
|
||||||
|
|
||||||
def get_left_over(self) -> Tensor:
|
|
||||||
"""Get left over ORIGINAL actions for RTC prev_chunk_left_over."""
|
|
||||||
with self.lock:
|
|
||||||
if self.original_queue is None:
|
|
||||||
return None
|
|
||||||
return self.original_queue[self.last_index :]
|
|
||||||
|
|
||||||
def merge(
|
|
||||||
self,
|
|
||||||
original_actions: Tensor,
|
|
||||||
processed_actions: Tensor,
|
|
||||||
real_delay: int,
|
|
||||||
action_index_before_inference: int | None = 0,
|
|
||||||
):
|
|
||||||
with self.lock:
|
|
||||||
self._check_delays(real_delay, action_index_before_inference)
|
|
||||||
|
|
||||||
if self.cfg.enabled:
|
|
||||||
self._replace_actions_queue(original_actions, processed_actions, real_delay)
|
|
||||||
return
|
|
||||||
|
|
||||||
self._append_actions_queue(original_actions, processed_actions)
|
|
||||||
|
|
||||||
def _replace_actions_queue(self, original_actions: Tensor, processed_actions: Tensor, real_delay: int):
|
|
||||||
self.original_queue = original_actions[real_delay:].clone()
|
|
||||||
self.queue = processed_actions[real_delay:].clone()
|
|
||||||
|
|
||||||
logger.info(f"original_actions shape: {self.original_queue.shape}")
|
|
||||||
logger.info(f"processed_actions shape: {self.queue.shape}")
|
|
||||||
logger.info(f"real_delay: {real_delay}")
|
|
||||||
|
|
||||||
self.last_index = 0
|
|
||||||
|
|
||||||
def _append_actions_queue(self, original_actions: Tensor, processed_actions: Tensor):
|
|
||||||
if self.queue is None:
|
|
||||||
self.original_queue = original_actions.clone()
|
|
||||||
self.queue = processed_actions.clone()
|
|
||||||
return
|
|
||||||
|
|
||||||
self.original_queue = torch.cat([self.original_queue, original_actions.clone()])
|
|
||||||
self.original_queue = self.original_queue[self.last_index :]
|
|
||||||
|
|
||||||
self.queue = torch.cat([self.queue, processed_actions.clone()])
|
|
||||||
self.queue = self.queue[self.last_index :]
|
|
||||||
|
|
||||||
self.last_index = 0
|
|
||||||
|
|
||||||
def _check_delays(self, real_delay: int, action_index_before_inference: int | None = None):
|
|
||||||
if action_index_before_inference is None:
|
|
||||||
return
|
|
||||||
|
|
||||||
indexes_diff = self.last_index - action_index_before_inference
|
|
||||||
if indexes_diff != real_delay:
|
|
||||||
# Let's check that action index difference (real delay calculated based on action queue)
|
|
||||||
# is the same as dealy calculated based on inference latency
|
|
||||||
logger.warning(
|
|
||||||
f"[ACTION_QUEUE] Indexes diff is not equal to real delay. Indexes diff: {indexes_diff}, real delay: {real_delay}"
|
|
||||||
)
|
|
||||||
|
|
||||||
|
|
||||||
@dataclass
|
@dataclass
|
||||||
class RTCDemoConfig(HubMixin):
|
class RTCDemoConfig(HubMixin):
|
||||||
"""Configuration for RTC demo with action chunking policies."""
|
"""Configuration for RTC demo with action chunking policies."""
|
||||||
@@ -413,16 +317,6 @@ class RTCDemoConfig(HubMixin):
|
|||||||
# Compute device
|
# Compute device
|
||||||
device: str | None = None # Device to run on (cuda, cpu, auto)
|
device: str | None = None # Device to run on (cuda, cpu, auto)
|
||||||
|
|
||||||
# Compilation options
|
|
||||||
compile_policy: bool = (
|
|
||||||
False # Compile policy with torch.compile() for faster inference (not supported on MPS)
|
|
||||||
)
|
|
||||||
compile_mode: str = "default" # Compilation mode: default, reduce-overhead, max-autotune
|
|
||||||
|
|
||||||
# Alternative optimization options (work on all devices including MPS)
|
|
||||||
use_channels_last: bool = False # Use channels_last memory format for images (faster on some devices)
|
|
||||||
enable_cudnn_benchmark: bool = True # Enable cuDNN benchmarking (CUDA only)
|
|
||||||
|
|
||||||
# Get new actions horizon. The amount of executed steps after which will be requested new actions.
|
# Get new actions horizon. The amount of executed steps after which will be requested new actions.
|
||||||
# It should be higher than inference delay + execution horizon.
|
# It should be higher than inference delay + execution horizon.
|
||||||
action_queue_size_to_get_new_actions: int = 30
|
action_queue_size_to_get_new_actions: int = 30
|
||||||
@@ -430,8 +324,29 @@ class RTCDemoConfig(HubMixin):
|
|||||||
# Task to execute
|
# Task to execute
|
||||||
task: str = field(default="", metadata={"help": "Task to execute"})
|
task: str = field(default="", metadata={"help": "Task to execute"})
|
||||||
|
|
||||||
# Debug options
|
# Torch compile configuration
|
||||||
verbose_rtc_comparison: bool = True # Enable detailed RTC comparison output
|
use_torch_compile: bool = field(
|
||||||
|
default=False,
|
||||||
|
metadata={"help": "Use torch.compile for faster inference (PyTorch 2.0+)"},
|
||||||
|
)
|
||||||
|
|
||||||
|
torch_compile_backend: str = field(
|
||||||
|
default="inductor",
|
||||||
|
metadata={"help": "Backend for torch.compile (inductor, aot_eager, cudagraphs)"},
|
||||||
|
)
|
||||||
|
|
||||||
|
torch_compile_mode: str = field(
|
||||||
|
default="default",
|
||||||
|
metadata={"help": "Compilation mode (default, reduce-overhead, max-autotune)"},
|
||||||
|
)
|
||||||
|
|
||||||
|
torch_compile_disable_cudagraphs: bool = field(
|
||||||
|
default=True,
|
||||||
|
metadata={
|
||||||
|
"help": "Disable CUDA graphs in torch.compile. Required due to in-place tensor "
|
||||||
|
"operations in denoising loop (x_t += dt * v_t) which cause tensor aliasing issues."
|
||||||
|
},
|
||||||
|
)
|
||||||
|
|
||||||
def __post_init__(self):
|
def __post_init__(self):
|
||||||
# HACK: We parse again the cli args here to get the pretrained path if there was one.
|
# HACK: We parse again the cli args here to get the pretrained path if there was one.
|
||||||
@@ -544,25 +459,9 @@ def get_actions(
|
|||||||
|
|
||||||
preproceseded_obs = preprocessor(obs_with_policy_features)
|
preproceseded_obs = preprocessor(obs_with_policy_features)
|
||||||
|
|
||||||
noise_size = (1, policy.config.chunk_size, policy.config.max_action_dim)
|
|
||||||
noise = policy.model.sample_noise(noise_size, policy_device)
|
|
||||||
noise_clone = noise.clone()
|
|
||||||
|
|
||||||
# Generate actions WITHOUT RTC for comparison (if verbose mode enabled)
|
|
||||||
if cfg.verbose_rtc_comparison:
|
|
||||||
policy.config.rtc_config.enabled = False
|
|
||||||
not_rtc_actions = policy.predict_action_chunk(
|
|
||||||
preproceseded_obs,
|
|
||||||
noise=noise,
|
|
||||||
inference_delay=inference_delay,
|
|
||||||
prev_chunk_left_over=prev_actions,
|
|
||||||
)
|
|
||||||
policy.config.rtc_config.enabled = True
|
|
||||||
|
|
||||||
# Generate actions WITH RTC
|
# Generate actions WITH RTC
|
||||||
actions = policy.predict_action_chunk(
|
actions = policy.predict_action_chunk(
|
||||||
preproceseded_obs,
|
preproceseded_obs,
|
||||||
noise=noise_clone if cfg.verbose_rtc_comparison else noise,
|
|
||||||
inference_delay=inference_delay,
|
inference_delay=inference_delay,
|
||||||
prev_chunk_left_over=prev_actions,
|
prev_chunk_left_over=prev_actions,
|
||||||
)
|
)
|
||||||
@@ -570,34 +469,6 @@ def get_actions(
|
|||||||
# Store original actions (before postprocessing) for RTC
|
# Store original actions (before postprocessing) for RTC
|
||||||
original_actions = actions.squeeze(0).clone()
|
original_actions = actions.squeeze(0).clone()
|
||||||
|
|
||||||
# Detailed comparison output (if verbose mode enabled)
|
|
||||||
if cfg.verbose_rtc_comparison:
|
|
||||||
logger.info("=" * 80)
|
|
||||||
logger.info("RTC ACTION COMPARISON")
|
|
||||||
logger.info("=" * 80)
|
|
||||||
|
|
||||||
# Print detailed statistics
|
|
||||||
logger.info("\n" + tensor_stats_str(not_rtc_actions, "not_rtc_actions (without RTC)"))
|
|
||||||
logger.info("\n" + tensor_stats_str(actions, "actions (with RTC)"))
|
|
||||||
logger.info(
|
|
||||||
"\n" + tensor_stats_str(prev_actions, "prev_actions (leftover from previous chunk)")
|
|
||||||
)
|
|
||||||
|
|
||||||
# Compare RTC vs non-RTC actions
|
|
||||||
logger.info(
|
|
||||||
compare_tensors(actions, not_rtc_actions, "actions (RTC)", "not_rtc_actions (no RTC)")
|
|
||||||
)
|
|
||||||
|
|
||||||
to_non_rtc_diff = actions - not_rtc_actions
|
|
||||||
|
|
||||||
print("to_non_rtc_diff", to_non_rtc_diff)
|
|
||||||
if prev_actions is not None:
|
|
||||||
prev_padded = torch.zeros_like(actions)
|
|
||||||
prev_padded[:, : prev_actions.shape[1], :] = prev_actions
|
|
||||||
to_prev_diff = actions - prev_padded
|
|
||||||
print("to_prev_diff", to_prev_diff)
|
|
||||||
print("=" * 80)
|
|
||||||
|
|
||||||
postprocessed_actions = postprocessor(actions)
|
postprocessed_actions = postprocessor(actions)
|
||||||
|
|
||||||
postprocessed_actions = postprocessed_actions.squeeze(0)
|
postprocessed_actions = postprocessed_actions.squeeze(0)
|
||||||
@@ -611,11 +482,6 @@ def get_actions(
|
|||||||
"[GET_ACTIONS] cfg.action_queue_size_to_get_new_actions Too small, It should be higher than inference delay + execution horizon."
|
"[GET_ACTIONS] cfg.action_queue_size_to_get_new_actions Too small, It should be higher than inference delay + execution horizon."
|
||||||
)
|
)
|
||||||
|
|
||||||
logger.debug(f"[GET_ACTIONS] new_delay: {new_delay}")
|
|
||||||
logger.debug(f"[GET_ACTIONS] original_actions shape: {original_actions.shape}")
|
|
||||||
logger.debug(f"[GET_ACTIONS] postprocessed_actions shape: {postprocessed_actions.shape}")
|
|
||||||
logger.debug(f"[GET_ACTIONS] action_index_before_inference: {action_index_before_inference}")
|
|
||||||
|
|
||||||
action_queue.merge(
|
action_queue.merge(
|
||||||
original_actions, postprocessed_actions, new_delay, action_index_before_inference
|
original_actions, postprocessed_actions, new_delay, action_index_before_inference
|
||||||
)
|
)
|
||||||
@@ -675,10 +541,56 @@ def actor_control(
|
|||||||
sys.exit(1)
|
sys.exit(1)
|
||||||
|
|
||||||
|
|
||||||
def stop_by_duration(shutdown_event: Event, cfg: RTCDemoConfig):
|
def _apply_torch_compile(policy, cfg: RTCDemoConfig):
|
||||||
"""Stop the demo by duration."""
|
"""Apply torch.compile to the policy's predict_action_chunk method.
|
||||||
time.sleep(cfg.duration)
|
|
||||||
shutdown_event.set()
|
Args:
|
||||||
|
policy: Policy instance to compile
|
||||||
|
cfg: Configuration containing torch compile settings
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
Policy with compiled predict_action_chunk method
|
||||||
|
"""
|
||||||
|
|
||||||
|
# PI models handle their own compilation
|
||||||
|
if policy.type == "pi05" or policy.type == "pi0":
|
||||||
|
return policy
|
||||||
|
|
||||||
|
try:
|
||||||
|
# Check if torch.compile is available (PyTorch 2.0+)
|
||||||
|
if not hasattr(torch, "compile"):
|
||||||
|
logger.warning(
|
||||||
|
f"torch.compile is not available. Requires PyTorch 2.0+. "
|
||||||
|
f"Current version: {torch.__version__}. Skipping compilation."
|
||||||
|
)
|
||||||
|
return policy
|
||||||
|
|
||||||
|
logger.info("Applying torch.compile to predict_action_chunk...")
|
||||||
|
logger.info(f" Backend: {cfg.torch_compile_backend}")
|
||||||
|
logger.info(f" Mode: {cfg.torch_compile_mode}")
|
||||||
|
logger.info(f" Disable CUDA graphs: {cfg.torch_compile_disable_cudagraphs}")
|
||||||
|
|
||||||
|
# Compile the predict_action_chunk method
|
||||||
|
# - CUDA graphs disabled to prevent tensor aliasing from in-place ops (x_t += dt * v_t)
|
||||||
|
compile_kwargs = {
|
||||||
|
"backend": cfg.torch_compile_backend,
|
||||||
|
"mode": cfg.torch_compile_mode,
|
||||||
|
}
|
||||||
|
|
||||||
|
# Disable CUDA graphs if requested (prevents tensor aliasing issues)
|
||||||
|
if cfg.torch_compile_disable_cudagraphs:
|
||||||
|
compile_kwargs["options"] = {"triton.cudagraphs": False}
|
||||||
|
|
||||||
|
original_method = policy.predict_action_chunk
|
||||||
|
compiled_method = torch.compile(original_method, **compile_kwargs)
|
||||||
|
policy.predict_action_chunk = compiled_method
|
||||||
|
logger.info("✓ Successfully compiled predict_action_chunk")
|
||||||
|
|
||||||
|
except Exception as e:
|
||||||
|
logger.error(f"Failed to apply torch.compile: {e}")
|
||||||
|
logger.warning("Continuing without torch.compile")
|
||||||
|
|
||||||
|
return policy
|
||||||
|
|
||||||
|
|
||||||
@parser.wrap()
|
@parser.wrap()
|
||||||
@@ -701,61 +613,30 @@ def demo_cli(cfg: RTCDemoConfig):
|
|||||||
actor_thread = None
|
actor_thread = None
|
||||||
|
|
||||||
policy_class = get_policy_class(cfg.policy.type)
|
policy_class = get_policy_class(cfg.policy.type)
|
||||||
policy = policy_class.from_pretrained(cfg.policy.pretrained_path)
|
|
||||||
|
# Load config and set compile_model for pi0/pi05 models
|
||||||
|
config = PreTrainedConfig.from_pretrained(cfg.policy.pretrained_path)
|
||||||
|
|
||||||
|
if cfg.policy.type == "pi05" or cfg.policy.type == "pi0":
|
||||||
|
config.compile_model = cfg.use_torch_compile
|
||||||
|
|
||||||
|
policy = policy_class.from_pretrained(cfg.policy.pretrained_path, config=config)
|
||||||
|
|
||||||
# Turn on RTC
|
# Turn on RTC
|
||||||
policy.config.rtc_config = cfg.rtc
|
policy.config.rtc_config = cfg.rtc
|
||||||
|
|
||||||
# Init RTC processort, as by default if RTC disabled in the config
|
# Init RTC processort, as by default if RTC disabled in the config
|
||||||
# The processor won't be created
|
# The processor won't be created
|
||||||
policy.init_rtc_processor(verbose=cfg.verbose_rtc_comparison)
|
policy.init_rtc_processor()
|
||||||
|
|
||||||
assert policy.name in ["smolvla"], "Only smolvla are supported for RTC"
|
assert policy.name in ["smolvla"], "Only smolvla are supported for RTC"
|
||||||
|
|
||||||
policy = policy.to(cfg.device)
|
policy = policy.to(cfg.device)
|
||||||
policy.eval()
|
policy.eval()
|
||||||
|
|
||||||
# Apply memory format optimizations
|
# Apply torch.compile to predict_action_chunk method if enabled
|
||||||
if cfg.use_channels_last:
|
if cfg.use_torch_compile:
|
||||||
logger.info("Converting model to channels_last memory format")
|
policy = _apply_torch_compile(policy, cfg)
|
||||||
try:
|
|
||||||
# Convert vision encoder to channels_last for better performance
|
|
||||||
if hasattr(policy, "vision_encoder"):
|
|
||||||
policy.vision_encoder = policy.vision_encoder.to(memory_format=torch.channels_last)
|
|
||||||
logger.info("Successfully converted to channels_last format")
|
|
||||||
except Exception as e:
|
|
||||||
logger.warning(f"Failed to convert to channels_last: {e}")
|
|
||||||
|
|
||||||
# Enable cuDNN benchmarking for CUDA
|
|
||||||
if cfg.enable_cudnn_benchmark and cfg.device == "cuda":
|
|
||||||
torch.backends.cudnn.benchmark = True
|
|
||||||
logger.info("Enabled cuDNN benchmarking")
|
|
||||||
|
|
||||||
# Compile policy if requested
|
|
||||||
if cfg.compile_policy:
|
|
||||||
# Check if device is MPS - torch.compile has issues with MPS backend
|
|
||||||
if cfg.device == "mps":
|
|
||||||
logger.warning("torch.compile() is not stable with MPS backend (Apple Silicon)")
|
|
||||||
logger.warning("Skipping compilation. For better performance on MPS:")
|
|
||||||
logger.warning(" 1. Use torch.float32 instead of bfloat16")
|
|
||||||
logger.warning(" 2. Ensure model uses contiguous memory layouts")
|
|
||||||
logger.warning(" 3. Consider using CUDA if available")
|
|
||||||
else:
|
|
||||||
logger.info(f"Compiling policy with mode: {cfg.compile_mode}")
|
|
||||||
logger.info("First inference will be slower due to compilation, subsequent calls will be faster")
|
|
||||||
|
|
||||||
try:
|
|
||||||
# Compile the predict_action_chunk method
|
|
||||||
policy.predict_action_chunk = torch.compile(
|
|
||||||
policy.predict_action_chunk,
|
|
||||||
mode=cfg.compile_mode,
|
|
||||||
fullgraph=False, # Allow graph breaks for flexibility
|
|
||||||
backend="inductor", # Use inductor backend
|
|
||||||
)
|
|
||||||
logger.info("Policy compiled successfully")
|
|
||||||
except Exception as e:
|
|
||||||
logger.warning(f"Failed to compile policy: {e}")
|
|
||||||
logger.warning("Continuing without compilation")
|
|
||||||
|
|
||||||
# Create robot or environment
|
# Create robot or environment
|
||||||
if cfg.robot is not None:
|
if cfg.robot is not None:
|
||||||
@@ -0,0 +1,219 @@
|
|||||||
|
#!/usr/bin/env python
|
||||||
|
|
||||||
|
# Copyright 2025 The HuggingFace Inc. team. All rights reserved.
|
||||||
|
#
|
||||||
|
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||||
|
# you may not use this file except in compliance with the License.
|
||||||
|
# You may obtain a copy of the License at
|
||||||
|
#
|
||||||
|
# http://www.apache.org/licenses/LICENSE-2.0
|
||||||
|
#
|
||||||
|
# Unless required by applicable law or agreed to in writing, software
|
||||||
|
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||||
|
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||||
|
# See the License for the specific language governing permissions and
|
||||||
|
# limitations under the License.
|
||||||
|
|
||||||
|
"""Action queue management for Real-Time Chunking (RTC).
|
||||||
|
|
||||||
|
This module provides ActionQueue, a thread-safe queue for managing action chunks
|
||||||
|
in real-time control scenarios. It supports both RTC-enabled and non-RTC modes,
|
||||||
|
handling action merging and leftover tracking.
|
||||||
|
"""
|
||||||
|
|
||||||
|
import logging
|
||||||
|
from threading import Lock
|
||||||
|
|
||||||
|
import torch
|
||||||
|
from torch import Tensor
|
||||||
|
|
||||||
|
from lerobot.policies.rtc.configuration_rtc import RTCConfig
|
||||||
|
|
||||||
|
logger = logging.getLogger(__name__)
|
||||||
|
|
||||||
|
|
||||||
|
class ActionQueue:
|
||||||
|
"""Thread-safe queue for managing action chunks in real-time control.
|
||||||
|
|
||||||
|
This queue handles two types of action sequences:
|
||||||
|
- Original actions: Used for RTC to compute leftovers from previous chunks
|
||||||
|
- Processed actions: Post-processed actions ready for robot execution
|
||||||
|
|
||||||
|
The queue operates in two modes:
|
||||||
|
1. RTC-enabled: Replaces the entire queue with new actions, accounting for inference delay
|
||||||
|
2. RTC-disabled: Appends new actions to the queue, maintaining continuity
|
||||||
|
|
||||||
|
Args:
|
||||||
|
cfg (RTCConfig): Configuration for Real-Time Chunking behavior.
|
||||||
|
|
||||||
|
Attributes:
|
||||||
|
queue (Tensor | None): Processed actions for robot rollout (time_steps, action_dim).
|
||||||
|
original_queue (Tensor | None): Original actions for RTC computation (time_steps, action_dim).
|
||||||
|
last_index (int): Current consumption index in the queue.
|
||||||
|
"""
|
||||||
|
|
||||||
|
def __init__(self, cfg: RTCConfig):
|
||||||
|
"""Initialize the action queue.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
cfg: RTC configuration controlling queue behavior.
|
||||||
|
"""
|
||||||
|
self.queue = None # Processed actions for robot rollout
|
||||||
|
self.original_queue = None # Original actions for RTC
|
||||||
|
self.lock = Lock()
|
||||||
|
self.last_index = 0
|
||||||
|
self.cfg = cfg
|
||||||
|
|
||||||
|
def get(self) -> Tensor | None:
|
||||||
|
"""Get the next action from the queue.
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
Tensor | None: The next action (action_dim,) or None if queue is empty.
|
||||||
|
Returns a clone to prevent external modifications.
|
||||||
|
"""
|
||||||
|
with self.lock:
|
||||||
|
if self.queue is None or self.last_index >= len(self.queue):
|
||||||
|
return None
|
||||||
|
|
||||||
|
action = self.queue[self.last_index]
|
||||||
|
self.last_index += 1
|
||||||
|
return action.clone()
|
||||||
|
|
||||||
|
def qsize(self) -> int:
|
||||||
|
"""Get the number of remaining actions in the queue.
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
int: Number of unconsumed actions.
|
||||||
|
"""
|
||||||
|
if self.queue is None:
|
||||||
|
return 0
|
||||||
|
length = len(self.queue)
|
||||||
|
return length - self.last_index
|
||||||
|
|
||||||
|
def empty(self) -> bool:
|
||||||
|
"""Check if the queue is empty.
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
bool: True if no actions remain, False otherwise.
|
||||||
|
"""
|
||||||
|
if self.queue is None:
|
||||||
|
return True
|
||||||
|
|
||||||
|
length = len(self.queue)
|
||||||
|
return length - self.last_index <= 0
|
||||||
|
|
||||||
|
def get_action_index(self) -> int:
|
||||||
|
"""Get the current action consumption index.
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
int: Index of the next action to be consumed.
|
||||||
|
"""
|
||||||
|
return self.last_index
|
||||||
|
|
||||||
|
def get_left_over(self) -> Tensor | None:
|
||||||
|
"""Get leftover original actions for RTC prev_chunk_left_over.
|
||||||
|
|
||||||
|
These are the unconsumed actions from the current chunk, which will be
|
||||||
|
used by RTC to compute corrections for the next chunk.
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
Tensor | None: Remaining original actions (remaining_steps, action_dim),
|
||||||
|
or None if no original queue exists.
|
||||||
|
"""
|
||||||
|
with self.lock:
|
||||||
|
if self.original_queue is None:
|
||||||
|
return None
|
||||||
|
return self.original_queue[self.last_index :]
|
||||||
|
|
||||||
|
def merge(
|
||||||
|
self,
|
||||||
|
original_actions: Tensor,
|
||||||
|
processed_actions: Tensor,
|
||||||
|
real_delay: int,
|
||||||
|
action_index_before_inference: int | None = 0,
|
||||||
|
):
|
||||||
|
"""Merge new actions into the queue.
|
||||||
|
|
||||||
|
This method operates differently based on RTC mode:
|
||||||
|
- RTC enabled: Replaces the queue, accounting for inference delay
|
||||||
|
- RTC disabled: Appends to the queue, maintaining continuity
|
||||||
|
|
||||||
|
Args:
|
||||||
|
original_actions: Unprocessed actions from policy (time_steps, action_dim).
|
||||||
|
processed_actions: Post-processed actions for robot (time_steps, action_dim).
|
||||||
|
real_delay: Number of time steps of inference delay.
|
||||||
|
action_index_before_inference: Index before inference started, for validation.
|
||||||
|
"""
|
||||||
|
with self.lock:
|
||||||
|
self._check_delays(real_delay, action_index_before_inference)
|
||||||
|
|
||||||
|
if self.cfg.enabled:
|
||||||
|
self._replace_actions_queue(original_actions, processed_actions, real_delay)
|
||||||
|
return
|
||||||
|
|
||||||
|
self._append_actions_queue(original_actions, processed_actions)
|
||||||
|
|
||||||
|
def _replace_actions_queue(self, original_actions: Tensor, processed_actions: Tensor, real_delay: int):
|
||||||
|
"""Replace the queue with new actions (RTC mode).
|
||||||
|
|
||||||
|
Discards the first `real_delay` actions since they correspond to the time
|
||||||
|
spent during inference, when the robot was executing previous actions.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
original_actions: Unprocessed actions from policy.
|
||||||
|
processed_actions: Post-processed actions for robot.
|
||||||
|
real_delay: Number of time steps to skip due to inference delay.
|
||||||
|
"""
|
||||||
|
self.original_queue = original_actions[real_delay:].clone()
|
||||||
|
self.queue = processed_actions[real_delay:].clone()
|
||||||
|
|
||||||
|
logger.debug(f"original_actions shape: {self.original_queue.shape}")
|
||||||
|
logger.debug(f"processed_actions shape: {self.queue.shape}")
|
||||||
|
logger.debug(f"real_delay: {real_delay}")
|
||||||
|
|
||||||
|
self.last_index = 0
|
||||||
|
|
||||||
|
def _append_actions_queue(self, original_actions: Tensor, processed_actions: Tensor):
|
||||||
|
"""Append new actions to the queue (non-RTC mode).
|
||||||
|
|
||||||
|
Removes already-consumed actions and appends new ones, maintaining
|
||||||
|
queue continuity without replacement.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
original_actions: Unprocessed actions from policy.
|
||||||
|
processed_actions: Post-processed actions for robot.
|
||||||
|
"""
|
||||||
|
if self.queue is None:
|
||||||
|
self.original_queue = original_actions.clone()
|
||||||
|
self.queue = processed_actions.clone()
|
||||||
|
return
|
||||||
|
|
||||||
|
self.original_queue = torch.cat([self.original_queue, original_actions.clone()])
|
||||||
|
self.original_queue = self.original_queue[self.last_index :]
|
||||||
|
|
||||||
|
self.queue = torch.cat([self.queue, processed_actions.clone()])
|
||||||
|
self.queue = self.queue[self.last_index :]
|
||||||
|
|
||||||
|
self.last_index = 0
|
||||||
|
|
||||||
|
def _check_delays(self, real_delay: int, action_index_before_inference: int | None = None):
|
||||||
|
"""Validate that computed delays match expectations.
|
||||||
|
|
||||||
|
Compares the delay computed from inference latency with the actual
|
||||||
|
number of actions consumed during inference.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
real_delay: Delay computed from inference latency.
|
||||||
|
action_index_before_inference: Action index when inference started.
|
||||||
|
"""
|
||||||
|
if action_index_before_inference is None:
|
||||||
|
return
|
||||||
|
|
||||||
|
indexes_diff = self.last_index - action_index_before_inference
|
||||||
|
if indexes_diff != real_delay:
|
||||||
|
# Let's check that action index difference (real delay calculated based on action queue)
|
||||||
|
# is the same as delay calculated based on inference latency
|
||||||
|
logger.warning(
|
||||||
|
f"[ACTION_QUEUE] Indexes diff is not equal to real delay. "
|
||||||
|
f"Indexes diff: {indexes_diff}, real delay: {real_delay}"
|
||||||
|
)
|
||||||
@@ -87,15 +87,6 @@ class RTCProcessor:
|
|||||||
**metadata,
|
**metadata,
|
||||||
)
|
)
|
||||||
|
|
||||||
def get_tracker_stats(self) -> dict | None:
|
|
||||||
"""Get tracker statistics summary.
|
|
||||||
|
|
||||||
Returns None if tracker is disabled or None.
|
|
||||||
"""
|
|
||||||
if self.tracker is not None:
|
|
||||||
return self.tracker.get_step_stats_summary()
|
|
||||||
return None
|
|
||||||
|
|
||||||
def get_all_debug_steps(self) -> list:
|
def get_all_debug_steps(self) -> list:
|
||||||
"""Get all debug steps from tracker.
|
"""Get all debug steps from tracker.
|
||||||
|
|
||||||
@@ -105,15 +96,6 @@ class RTCProcessor:
|
|||||||
return self.tracker.get_all_steps()
|
return self.tracker.get_all_steps()
|
||||||
return []
|
return []
|
||||||
|
|
||||||
def get_recent_debug_steps(self, n: int = 1) -> list:
|
|
||||||
"""Get recent debug steps from tracker.
|
|
||||||
|
|
||||||
Returns empty list if tracker is disabled or None.
|
|
||||||
"""
|
|
||||||
if self.tracker is not None:
|
|
||||||
return self.tracker.get_recent_steps(n)
|
|
||||||
return []
|
|
||||||
|
|
||||||
def is_debug_enabled(self) -> bool:
|
def is_debug_enabled(self) -> bool:
|
||||||
"""Check if debug tracking is enabled.
|
"""Check if debug tracking is enabled.
|
||||||
|
|
||||||
@@ -129,15 +111,6 @@ class RTCProcessor:
|
|||||||
if self.tracker is not None:
|
if self.tracker is not None:
|
||||||
self.tracker.reset()
|
self.tracker.reset()
|
||||||
|
|
||||||
def get_tracker_length(self) -> int:
|
|
||||||
"""Get the number of recorded debug steps.
|
|
||||||
|
|
||||||
Returns 0 if tracker is disabled or None.
|
|
||||||
"""
|
|
||||||
if self.tracker is not None:
|
|
||||||
return len(self.tracker)
|
|
||||||
return 0
|
|
||||||
|
|
||||||
# ====================== End Tracker Proxy Methods ======================
|
# ====================== End Tracker Proxy Methods ======================
|
||||||
|
|
||||||
def denoise_step(
|
def denoise_step(
|
||||||
|
|||||||
Reference in New Issue
Block a user