This commit is contained in:
Eugene Mironov
2025-11-07 02:58:38 +07:00
parent ac1816ee9c
commit d10b7787eb
3 changed files with 173 additions and 29 deletions
+7
View File
@@ -25,6 +25,13 @@ Usage:
--rtc.execution_horizon=8 \ --rtc.execution_horizon=8 \
--device=mps --device=mps
# Basic usage with pi0.5 policy
uv run python examples/rtc/eval_dataset.py \
--policy.path=lerobot/pi05_libero_finetuned \
--dataset.repo_id=HuggingFaceVLA/libero \
--rtc.execution_horizon=8 \
--device=mps
# With torch.compile for faster inference (PyTorch 2.0+) # With torch.compile for faster inference (PyTorch 2.0+)
# Note: CUDA graphs disabled by default due to in-place ops in denoising loop # Note: CUDA graphs disabled by default due to in-place ops in denoising loop
uv run python examples/rtc/eval_dataset.py \ uv run python examples/rtc/eval_dataset.py \
+83 -15
View File
@@ -19,11 +19,12 @@ import logging
import math import math
from collections import deque from collections import deque
from pathlib import Path from pathlib import Path
from typing import TYPE_CHECKING, Literal from typing import TYPE_CHECKING, Literal, TypedDict
import torch import torch
import torch.nn.functional as F # noqa: N812 import torch.nn.functional as F # noqa: N812
from torch import Tensor, nn from torch import Tensor, nn
from typing_extensions import Unpack
from lerobot.utils.import_utils import _transformers_available from lerobot.utils.import_utils import _transformers_available
@@ -42,6 +43,7 @@ else:
from lerobot.configs.policies import PreTrainedConfig from lerobot.configs.policies import PreTrainedConfig
from lerobot.policies.pi0.configuration_pi0 import PI0Config from lerobot.policies.pi0.configuration_pi0 import PI0Config
from lerobot.policies.pretrained import PreTrainedPolicy, T from lerobot.policies.pretrained import PreTrainedPolicy, T
from lerobot.policies.rtc.modeling_rtc import RTCProcessor
from lerobot.utils.constants import ( from lerobot.utils.constants import (
ACTION, ACTION,
OBS_LANGUAGE_ATTENTION_MASK, OBS_LANGUAGE_ATTENTION_MASK,
@@ -51,6 +53,12 @@ from lerobot.utils.constants import (
) )
class ActionSelectKwargs(TypedDict, total=False):
inference_delay: int | None
prev_chunk_left_over: Tensor | None
execution_horizon: int | None
def get_safe_dtype(target_dtype, device_type): def get_safe_dtype(target_dtype, device_type):
"""Get a safe dtype for the given device type.""" """Get a safe dtype for the given device type."""
if device_type == "mps" and target_dtype == torch.float64: if device_type == "mps" and target_dtype == torch.float64:
@@ -503,9 +511,10 @@ class PaliGemmaWithExpertModel(
class PI0Pytorch(nn.Module): # see openpi `PI0Pytorch` class PI0Pytorch(nn.Module): # see openpi `PI0Pytorch`
"""Core PI0 PyTorch model.""" """Core PI0 PyTorch model."""
def __init__(self, config: PI0Config): def __init__(self, config: PI0Config, rtc_processor: RTCProcessor | None = None):
super().__init__() super().__init__()
self.config = config self.config = config
self.rtc_processor = rtc_processor
paligemma_config = get_gemma_config(config.paligemma_variant) paligemma_config = get_gemma_config(config.paligemma_variant)
action_expert_config = get_gemma_config(config.action_expert_variant) action_expert_config = get_gemma_config(config.action_expert_variant)
@@ -560,6 +569,9 @@ class PI0Pytorch(nn.Module): # see openpi `PI0Pytorch`
self.paligemma_with_expert.gemma_expert.model.gradient_checkpointing = False self.paligemma_with_expert.gemma_expert.model.gradient_checkpointing = False
logging.info("Disabled gradient checkpointing for PI0Pytorch model") logging.info("Disabled gradient checkpointing for PI0Pytorch model")
def _rtc_enabled(self):
return self.config.rtc_config is not None and self.config.rtc_config.enabled
def _apply_checkpoint(self, func, *args, **kwargs): def _apply_checkpoint(self, func, *args, **kwargs):
"""Helper method to apply gradient checkpointing if enabled.""" """Helper method to apply gradient checkpointing if enabled."""
if self.gradient_checkpointing_enabled and self.training: if self.gradient_checkpointing_enabled and self.training:
@@ -756,7 +768,15 @@ class PI0Pytorch(nn.Module): # see openpi `PI0Pytorch`
@torch.no_grad() # see openpi `sample_actions` (slightly adapted) @torch.no_grad() # see openpi `sample_actions` (slightly adapted)
def sample_actions( def sample_actions(
self, images, img_masks, lang_tokens, lang_masks, state, noise=None, num_steps=None self,
images,
img_masks,
lang_tokens,
lang_masks,
state,
noise=None,
num_steps=None,
**kwargs: Unpack[ActionSelectKwargs],
) -> Tensor: ) -> Tensor:
"""Do a full inference forward and compute the action.""" """Do a full inference forward and compute the action."""
if num_steps is None: if num_steps is None:
@@ -798,14 +818,41 @@ class PI0Pytorch(nn.Module): # see openpi `PI0Pytorch`
time = torch.tensor(1.0, dtype=torch.float32, device=device) time = torch.tensor(1.0, dtype=torch.float32, device=device)
while time >= -dt / 2: while time >= -dt / 2:
expanded_time = time.expand(bsize) expanded_time = time.expand(bsize)
v_t = self.denoise_step(
state, # Define a closure function to properly capture expanded_time
prefix_pad_masks, # This avoids the lambda expression (E731) and loop variable binding (B023) issues
past_key_values, def denoise_step_partial_call(input_x_t, current_timestep=expanded_time):
x_t, return self.denoise_step(
expanded_time, state=state,
) prefix_pad_masks=prefix_pad_masks,
x_t = x_t + dt * v_t past_key_values=past_key_values,
x_t=input_x_t,
timestep=current_timestep,
)
if self._rtc_enabled():
inference_delay = kwargs.get("inference_delay")
prev_chunk_left_over = kwargs.get("prev_chunk_left_over")
execution_horizon = kwargs.get("execution_horizon")
v_t = self.rtc_processor.denoise_step(
x_t=x_t,
prev_chunk_left_over=prev_chunk_left_over,
inference_delay=inference_delay,
time=time,
original_denoise_step_partial=denoise_step_partial_call,
execution_horizon=execution_horizon,
)
else:
v_t = denoise_step_partial_call(x_t)
# Euler step
x_t += dt * v_t
# Record x_t and v_t after Euler step
if self.rtc_processor is not None and self.rtc_processor.is_debug_enabled():
self.rtc_processor.track(time=time, x_t=x_t, v_t=v_t)
time += dt time += dt
return x_t return x_t
@@ -869,7 +916,8 @@ class PI0Policy(PreTrainedPolicy):
self.config = config self.config = config
# Initialize the core PI0 model # Initialize the core PI0 model
self.model = PI0Pytorch(config) self.init_rtc_processor()
self.model = PI0Pytorch(config, rtc_processor=self.rtc_processor)
# Enable gradient checkpointing if requested # Enable gradient checkpointing if requested
if config.gradient_checkpointing: if config.gradient_checkpointing:
@@ -1059,6 +1107,22 @@ class PI0Policy(PreTrainedPolicy):
ACTION: deque(maxlen=self.config.n_action_steps), ACTION: deque(maxlen=self.config.n_action_steps),
} }
def init_rtc_processor(self):
"""Initialize RTC processor if RTC is enabled in config."""
self.rtc_processor = None
# Create processor if config provided
# If RTC is not enabled - we can still track the denoising data
if self.config.rtc_config is not None:
self.rtc_processor = RTCProcessor(self.config.rtc_config)
# Set rtc_processor to the model if it exists
if self.model is not None:
self.model.rtc_processor = self.rtc_processor
def _rtc_enabled(self) -> bool:
return self.config.rtc_config is not None and self.config.rtc_config.enabled
def _preprocess_images(self, batch: dict[str, Tensor]) -> tuple[list[Tensor], list[Tensor]]: def _preprocess_images(self, batch: dict[str, Tensor]) -> tuple[list[Tensor], list[Tensor]]:
"""Preprocess images for the model. """Preprocess images for the model.
@@ -1137,6 +1201,10 @@ class PI0Policy(PreTrainedPolicy):
@torch.no_grad() @torch.no_grad()
def select_action(self, batch: dict[str, Tensor]) -> Tensor: def select_action(self, batch: dict[str, Tensor]) -> Tensor:
"""Select a single action given environment observations.""" """Select a single action given environment observations."""
assert not self._rtc_enabled(), (
"RTC is not supported for select_action, use it with predict_action_chunk"
)
self.eval() self.eval()
# Action queue logic for n_action_steps > 1 # Action queue logic for n_action_steps > 1
@@ -1148,7 +1216,7 @@ class PI0Policy(PreTrainedPolicy):
return self._action_queue.popleft() return self._action_queue.popleft()
@torch.no_grad() @torch.no_grad()
def predict_action_chunk(self, batch: dict[str, Tensor]) -> Tensor: def predict_action_chunk(self, batch: dict[str, Tensor], **kwargs: Unpack[ActionSelectKwargs]) -> Tensor:
"""Predict a chunk of actions given environment observations.""" """Predict a chunk of actions given environment observations."""
self.eval() self.eval()
@@ -1157,8 +1225,8 @@ class PI0Policy(PreTrainedPolicy):
lang_tokens, lang_masks = batch[f"{OBS_LANGUAGE_TOKENS}"], batch[f"{OBS_LANGUAGE_ATTENTION_MASK}"] lang_tokens, lang_masks = batch[f"{OBS_LANGUAGE_TOKENS}"], batch[f"{OBS_LANGUAGE_ATTENTION_MASK}"]
state = self.prepare_state(batch) state = self.prepare_state(batch)
# Sample actions using the model # Sample actions using the model (pass through RTC kwargs)
actions = self.model.sample_actions(images, img_masks, lang_tokens, lang_masks, state) actions = self.model.sample_actions(images, img_masks, lang_tokens, lang_masks, state, **kwargs)
# Unpad actions to actual action dimension # Unpad actions to actual action dimension
original_action_dim = self.config.output_features[ACTION].shape[0] original_action_dim = self.config.output_features[ACTION].shape[0]
+83 -14
View File
@@ -19,11 +19,12 @@ import logging
import math import math
from collections import deque from collections import deque
from pathlib import Path from pathlib import Path
from typing import TYPE_CHECKING, Literal from typing import TYPE_CHECKING, Literal, TypedDict
import torch import torch
import torch.nn.functional as F # noqa: N812 import torch.nn.functional as F # noqa: N812
from torch import Tensor, nn from torch import Tensor, nn
from typing_extensions import Unpack
from lerobot.utils.import_utils import _transformers_available from lerobot.utils.import_utils import _transformers_available
@@ -42,6 +43,7 @@ else:
from lerobot.configs.policies import PreTrainedConfig from lerobot.configs.policies import PreTrainedConfig
from lerobot.policies.pi05.configuration_pi05 import PI05Config from lerobot.policies.pi05.configuration_pi05 import PI05Config
from lerobot.policies.pretrained import PreTrainedPolicy, T from lerobot.policies.pretrained import PreTrainedPolicy, T
from lerobot.policies.rtc.modeling_rtc import RTCProcessor
from lerobot.utils.constants import ( from lerobot.utils.constants import (
ACTION, ACTION,
OBS_LANGUAGE_ATTENTION_MASK, OBS_LANGUAGE_ATTENTION_MASK,
@@ -50,6 +52,12 @@ from lerobot.utils.constants import (
) )
class ActionSelectKwargs(TypedDict, total=False):
inference_delay: int | None
prev_chunk_left_over: Tensor | None
execution_horizon: int | None
def get_safe_dtype(target_dtype, device_type): def get_safe_dtype(target_dtype, device_type):
"""Get a safe dtype for the given device type.""" """Get a safe dtype for the given device type."""
if device_type == "mps" and target_dtype == torch.float64: if device_type == "mps" and target_dtype == torch.float64:
@@ -502,9 +510,10 @@ class PaliGemmaWithExpertModel(
class PI05Pytorch(nn.Module): # see openpi `PI0Pytorch` class PI05Pytorch(nn.Module): # see openpi `PI0Pytorch`
"""Core PI05 PyTorch model.""" """Core PI05 PyTorch model."""
def __init__(self, config: PI05Config): def __init__(self, config: PI05Config, rtc_processor: RTCProcessor | None = None):
super().__init__() super().__init__()
self.config = config self.config = config
self.rtc_processor = rtc_processor
paligemma_config = get_gemma_config(config.paligemma_variant) paligemma_config = get_gemma_config(config.paligemma_variant)
action_expert_config = get_gemma_config(config.action_expert_variant) action_expert_config = get_gemma_config(config.action_expert_variant)
@@ -556,6 +565,9 @@ class PI05Pytorch(nn.Module): # see openpi `PI0Pytorch`
self.paligemma_with_expert.gemma_expert.model.gradient_checkpointing = False self.paligemma_with_expert.gemma_expert.model.gradient_checkpointing = False
logging.info("Disabled gradient checkpointing for PI05Pytorch model") logging.info("Disabled gradient checkpointing for PI05Pytorch model")
def _rtc_enabled(self):
return self.config.rtc_config is not None and self.config.rtc_config.enabled
def _apply_checkpoint(self, func, *args, **kwargs): def _apply_checkpoint(self, func, *args, **kwargs):
"""Helper method to apply gradient checkpointing if enabled.""" """Helper method to apply gradient checkpointing if enabled."""
if self.gradient_checkpointing_enabled and self.training: if self.gradient_checkpointing_enabled and self.training:
@@ -731,7 +743,16 @@ class PI05Pytorch(nn.Module): # see openpi `PI0Pytorch`
return F.mse_loss(u_t, v_t, reduction="none") return F.mse_loss(u_t, v_t, reduction="none")
@torch.no_grad() # see openpi `sample_actions` (slightly adapted) @torch.no_grad() # see openpi `sample_actions` (slightly adapted)
def sample_actions(self, images, img_masks, tokens, masks, noise=None, num_steps=None) -> Tensor: def sample_actions(
self,
images,
img_masks,
tokens,
masks,
noise=None,
num_steps=None,
**kwargs: Unpack[ActionSelectKwargs],
) -> Tensor:
"""Do a full inference forward and compute the action.""" """Do a full inference forward and compute the action."""
if num_steps is None: if num_steps is None:
num_steps = self.config.num_inference_steps num_steps = self.config.num_inference_steps
@@ -770,13 +791,40 @@ class PI05Pytorch(nn.Module): # see openpi `PI0Pytorch`
time = torch.tensor(1.0, dtype=torch.float32, device=device) time = torch.tensor(1.0, dtype=torch.float32, device=device)
while time >= -dt / 2: while time >= -dt / 2:
expanded_time = time.expand(bsize) expanded_time = time.expand(bsize)
v_t = self.denoise_step(
prefix_pad_masks, # Define a closure function to properly capture expanded_time
past_key_values, # This avoids the lambda expression (E731) and loop variable binding (B023) issues
x_t, def denoise_step_partial_call(input_x_t, current_timestep=expanded_time):
expanded_time, return self.denoise_step(
) prefix_pad_masks=prefix_pad_masks,
x_t = x_t + dt * v_t past_key_values=past_key_values,
x_t=input_x_t,
timestep=current_timestep,
)
if self._rtc_enabled():
inference_delay = kwargs.get("inference_delay")
prev_chunk_left_over = kwargs.get("prev_chunk_left_over")
execution_horizon = kwargs.get("execution_horizon")
v_t = self.rtc_processor.denoise_step(
x_t=x_t,
prev_chunk_left_over=prev_chunk_left_over,
inference_delay=inference_delay,
time=time,
original_denoise_step_partial=denoise_step_partial_call,
execution_horizon=execution_horizon,
)
else:
v_t = denoise_step_partial_call(x_t)
# Euler step
x_t += dt * v_t
# Record x_t and v_t after Euler step
if self.rtc_processor is not None and self.rtc_processor.is_debug_enabled():
self.rtc_processor.track(time=time, x_t=x_t, v_t=v_t)
time += dt time += dt
return x_t return x_t
@@ -839,7 +887,8 @@ class PI05Policy(PreTrainedPolicy):
self.config = config self.config = config
# Initialize the core PI05 model # Initialize the core PI05 model
self.model = PI05Pytorch(config) self.init_rtc_processor()
self.model = PI05Pytorch(config, rtc_processor=self.rtc_processor)
# Enable gradient checkpointing if requested # Enable gradient checkpointing if requested
if config.gradient_checkpointing: if config.gradient_checkpointing:
@@ -1035,6 +1084,22 @@ class PI05Policy(PreTrainedPolicy):
ACTION: deque(maxlen=self.config.n_action_steps), ACTION: deque(maxlen=self.config.n_action_steps),
} }
def init_rtc_processor(self):
"""Initialize RTC processor if RTC is enabled in config."""
self.rtc_processor = None
# Create processor if config provided
# If RTC is not enabled - we can still track the denoising data
if self.config.rtc_config is not None:
self.rtc_processor = RTCProcessor(self.config.rtc_config)
# Set rtc_processor to the model if it exists
if self.model is not None:
self.model.rtc_processor = self.rtc_processor
def _rtc_enabled(self) -> bool:
return self.config.rtc_config is not None and self.config.rtc_config.enabled
def _preprocess_images(self, batch: dict[str, Tensor]) -> tuple[list[Tensor], list[Tensor]]: def _preprocess_images(self, batch: dict[str, Tensor]) -> tuple[list[Tensor], list[Tensor]]:
"""Preprocess images for the model. """Preprocess images for the model.
@@ -1109,6 +1174,10 @@ class PI05Policy(PreTrainedPolicy):
@torch.no_grad() @torch.no_grad()
def select_action(self, batch: dict[str, Tensor]) -> Tensor: def select_action(self, batch: dict[str, Tensor]) -> Tensor:
"""Select a single action given environment observations.""" """Select a single action given environment observations."""
assert not self._rtc_enabled(), (
"RTC is not supported for select_action, use it with predict_action_chunk"
)
self.eval() self.eval()
# Action queue logic for n_action_steps > 1 # Action queue logic for n_action_steps > 1
@@ -1120,7 +1189,7 @@ class PI05Policy(PreTrainedPolicy):
return self._action_queue.popleft() return self._action_queue.popleft()
@torch.no_grad() @torch.no_grad()
def predict_action_chunk(self, batch: dict[str, Tensor]) -> Tensor: def predict_action_chunk(self, batch: dict[str, Tensor], **kwargs: Unpack[ActionSelectKwargs]) -> Tensor:
"""Predict a chunk of actions given environment observations.""" """Predict a chunk of actions given environment observations."""
self.eval() self.eval()
@@ -1128,8 +1197,8 @@ class PI05Policy(PreTrainedPolicy):
images, img_masks = self._preprocess_images(batch) images, img_masks = self._preprocess_images(batch)
tokens, masks = batch[f"{OBS_LANGUAGE_TOKENS}"], batch[f"{OBS_LANGUAGE_ATTENTION_MASK}"] tokens, masks = batch[f"{OBS_LANGUAGE_TOKENS}"], batch[f"{OBS_LANGUAGE_ATTENTION_MASK}"]
# Sample actions using the model (no separate state needed for PI05) # Sample actions using the model (pass through RTC kwargs, no separate state needed for PI05)
actions = self.model.sample_actions(images, img_masks, tokens, masks) actions = self.model.sample_actions(images, img_masks, tokens, masks, **kwargs)
# Unpad actions to actual action dimension # Unpad actions to actual action dimension
original_action_dim = self.config.output_features[ACTION].shape[0] original_action_dim = self.config.output_features[ACTION].shape[0]