From 6684c686127c5a1ccabb6fc2d9293e5e2e7a7bff Mon Sep 17 00:00:00 2001 From: Eugene Mironov Date: Fri, 7 Nov 2025 02:58:38 +0700 Subject: [PATCH] Pi0 --- examples/rtc/eval_dataset.py | 7 ++ src/lerobot/policies/pi0/modeling_pi0.py | 98 ++++++++++++++++++---- src/lerobot/policies/pi05/modeling_pi05.py | 97 +++++++++++++++++---- 3 files changed, 173 insertions(+), 29 deletions(-) diff --git a/examples/rtc/eval_dataset.py b/examples/rtc/eval_dataset.py index f14a00711..dcf6ed660 100644 --- a/examples/rtc/eval_dataset.py +++ b/examples/rtc/eval_dataset.py @@ -25,6 +25,13 @@ Usage: --rtc.execution_horizon=8 \ --device=mps + # Basic usage with pi0.5 policy + uv run python examples/rtc/eval_dataset.py \ + --policy.path=lerobot/pi05_libero_finetuned \ + --dataset.repo_id=HuggingFaceVLA/libero \ + --rtc.execution_horizon=8 \ + --device=mps + # With torch.compile for faster inference (PyTorch 2.0+) # Note: CUDA graphs disabled by default due to in-place ops in denoising loop uv run python examples/rtc/eval_dataset.py \ diff --git a/src/lerobot/policies/pi0/modeling_pi0.py b/src/lerobot/policies/pi0/modeling_pi0.py index 596b273d5..f6e1d5f9c 100644 --- a/src/lerobot/policies/pi0/modeling_pi0.py +++ b/src/lerobot/policies/pi0/modeling_pi0.py @@ -19,11 +19,12 @@ import logging import math from collections import deque from pathlib import Path -from typing import TYPE_CHECKING, Literal +from typing import TYPE_CHECKING, Literal, TypedDict import torch import torch.nn.functional as F # noqa: N812 from torch import Tensor, nn +from typing_extensions import Unpack from lerobot.utils.import_utils import _transformers_available @@ -42,6 +43,7 @@ else: from lerobot.configs.policies import PreTrainedConfig from lerobot.policies.pi0.configuration_pi0 import PI0Config from lerobot.policies.pretrained import PreTrainedPolicy, T +from lerobot.policies.rtc.modeling_rtc import RTCProcessor from lerobot.utils.constants import ( ACTION, OBS_LANGUAGE_ATTENTION_MASK, @@ -51,6 +53,12 @@ from lerobot.utils.constants import ( ) +class ActionSelectKwargs(TypedDict, total=False): + inference_delay: int | None + prev_chunk_left_over: Tensor | None + execution_horizon: int | None + + def get_safe_dtype(target_dtype, device_type): """Get a safe dtype for the given device type.""" if device_type == "mps" and target_dtype == torch.float64: @@ -503,9 +511,10 @@ class PaliGemmaWithExpertModel( class PI0Pytorch(nn.Module): # see openpi `PI0Pytorch` """Core PI0 PyTorch model.""" - def __init__(self, config: PI0Config): + def __init__(self, config: PI0Config, rtc_processor: RTCProcessor | None = None): super().__init__() self.config = config + self.rtc_processor = rtc_processor paligemma_config = get_gemma_config(config.paligemma_variant) action_expert_config = get_gemma_config(config.action_expert_variant) @@ -560,6 +569,9 @@ class PI0Pytorch(nn.Module): # see openpi `PI0Pytorch` self.paligemma_with_expert.gemma_expert.model.gradient_checkpointing = False logging.info("Disabled gradient checkpointing for PI0Pytorch model") + def _rtc_enabled(self): + return self.config.rtc_config is not None and self.config.rtc_config.enabled + def _apply_checkpoint(self, func, *args, **kwargs): """Helper method to apply gradient checkpointing if enabled.""" if self.gradient_checkpointing_enabled and self.training: @@ -756,7 +768,15 @@ class PI0Pytorch(nn.Module): # see openpi `PI0Pytorch` @torch.no_grad() # see openpi `sample_actions` (slightly adapted) def sample_actions( - self, images, img_masks, lang_tokens, lang_masks, state, noise=None, num_steps=None + self, + images, + img_masks, + lang_tokens, + lang_masks, + state, + noise=None, + num_steps=None, + **kwargs: Unpack[ActionSelectKwargs], ) -> Tensor: """Do a full inference forward and compute the action.""" if num_steps is None: @@ -798,14 +818,41 @@ class PI0Pytorch(nn.Module): # see openpi `PI0Pytorch` time = torch.tensor(1.0, dtype=torch.float32, device=device) while time >= -dt / 2: expanded_time = time.expand(bsize) - v_t = self.denoise_step( - state, - prefix_pad_masks, - past_key_values, - x_t, - expanded_time, - ) - x_t = x_t + dt * v_t + + # Define a closure function to properly capture expanded_time + # This avoids the lambda expression (E731) and loop variable binding (B023) issues + def denoise_step_partial_call(input_x_t, current_timestep=expanded_time): + return self.denoise_step( + state=state, + prefix_pad_masks=prefix_pad_masks, + past_key_values=past_key_values, + x_t=input_x_t, + timestep=current_timestep, + ) + + if self._rtc_enabled(): + inference_delay = kwargs.get("inference_delay") + prev_chunk_left_over = kwargs.get("prev_chunk_left_over") + execution_horizon = kwargs.get("execution_horizon") + + v_t = self.rtc_processor.denoise_step( + x_t=x_t, + prev_chunk_left_over=prev_chunk_left_over, + inference_delay=inference_delay, + time=time, + original_denoise_step_partial=denoise_step_partial_call, + execution_horizon=execution_horizon, + ) + else: + v_t = denoise_step_partial_call(x_t) + + # Euler step + x_t += dt * v_t + + # Record x_t and v_t after Euler step + if self.rtc_processor is not None and self.rtc_processor.is_debug_enabled(): + self.rtc_processor.track(time=time, x_t=x_t, v_t=v_t) + time += dt return x_t @@ -869,7 +916,8 @@ class PI0Policy(PreTrainedPolicy): self.config = config # Initialize the core PI0 model - self.model = PI0Pytorch(config) + self.init_rtc_processor() + self.model = PI0Pytorch(config, rtc_processor=self.rtc_processor) # Enable gradient checkpointing if requested if config.gradient_checkpointing: @@ -1059,6 +1107,22 @@ class PI0Policy(PreTrainedPolicy): ACTION: deque(maxlen=self.config.n_action_steps), } + def init_rtc_processor(self): + """Initialize RTC processor if RTC is enabled in config.""" + self.rtc_processor = None + + # Create processor if config provided + # If RTC is not enabled - we can still track the denoising data + if self.config.rtc_config is not None: + self.rtc_processor = RTCProcessor(self.config.rtc_config) + + # Set rtc_processor to the model if it exists + if self.model is not None: + self.model.rtc_processor = self.rtc_processor + + def _rtc_enabled(self) -> bool: + return self.config.rtc_config is not None and self.config.rtc_config.enabled + def _preprocess_images(self, batch: dict[str, Tensor]) -> tuple[list[Tensor], list[Tensor]]: """Preprocess images for the model. @@ -1137,6 +1201,10 @@ class PI0Policy(PreTrainedPolicy): @torch.no_grad() def select_action(self, batch: dict[str, Tensor]) -> Tensor: """Select a single action given environment observations.""" + assert not self._rtc_enabled(), ( + "RTC is not supported for select_action, use it with predict_action_chunk" + ) + self.eval() # Action queue logic for n_action_steps > 1 @@ -1148,7 +1216,7 @@ class PI0Policy(PreTrainedPolicy): return self._action_queue.popleft() @torch.no_grad() - def predict_action_chunk(self, batch: dict[str, Tensor]) -> Tensor: + def predict_action_chunk(self, batch: dict[str, Tensor], **kwargs: Unpack[ActionSelectKwargs]) -> Tensor: """Predict a chunk of actions given environment observations.""" self.eval() @@ -1157,8 +1225,8 @@ class PI0Policy(PreTrainedPolicy): lang_tokens, lang_masks = batch[f"{OBS_LANGUAGE_TOKENS}"], batch[f"{OBS_LANGUAGE_ATTENTION_MASK}"] state = self.prepare_state(batch) - # Sample actions using the model - actions = self.model.sample_actions(images, img_masks, lang_tokens, lang_masks, state) + # Sample actions using the model (pass through RTC kwargs) + actions = self.model.sample_actions(images, img_masks, lang_tokens, lang_masks, state, **kwargs) # Unpad actions to actual action dimension original_action_dim = self.config.output_features[ACTION].shape[0] diff --git a/src/lerobot/policies/pi05/modeling_pi05.py b/src/lerobot/policies/pi05/modeling_pi05.py index 93ca5fa82..dd002d307 100644 --- a/src/lerobot/policies/pi05/modeling_pi05.py +++ b/src/lerobot/policies/pi05/modeling_pi05.py @@ -19,11 +19,12 @@ import logging import math from collections import deque from pathlib import Path -from typing import TYPE_CHECKING, Literal +from typing import TYPE_CHECKING, Literal, TypedDict import torch import torch.nn.functional as F # noqa: N812 from torch import Tensor, nn +from typing_extensions import Unpack from lerobot.utils.import_utils import _transformers_available @@ -42,6 +43,7 @@ else: from lerobot.configs.policies import PreTrainedConfig from lerobot.policies.pi05.configuration_pi05 import PI05Config from lerobot.policies.pretrained import PreTrainedPolicy, T +from lerobot.policies.rtc.modeling_rtc import RTCProcessor from lerobot.utils.constants import ( ACTION, OBS_LANGUAGE_ATTENTION_MASK, @@ -50,6 +52,12 @@ from lerobot.utils.constants import ( ) +class ActionSelectKwargs(TypedDict, total=False): + inference_delay: int | None + prev_chunk_left_over: Tensor | None + execution_horizon: int | None + + def get_safe_dtype(target_dtype, device_type): """Get a safe dtype for the given device type.""" if device_type == "mps" and target_dtype == torch.float64: @@ -502,9 +510,10 @@ class PaliGemmaWithExpertModel( class PI05Pytorch(nn.Module): # see openpi `PI0Pytorch` """Core PI05 PyTorch model.""" - def __init__(self, config: PI05Config): + def __init__(self, config: PI05Config, rtc_processor: RTCProcessor | None = None): super().__init__() self.config = config + self.rtc_processor = rtc_processor paligemma_config = get_gemma_config(config.paligemma_variant) action_expert_config = get_gemma_config(config.action_expert_variant) @@ -556,6 +565,9 @@ class PI05Pytorch(nn.Module): # see openpi `PI0Pytorch` self.paligemma_with_expert.gemma_expert.model.gradient_checkpointing = False logging.info("Disabled gradient checkpointing for PI05Pytorch model") + def _rtc_enabled(self): + return self.config.rtc_config is not None and self.config.rtc_config.enabled + def _apply_checkpoint(self, func, *args, **kwargs): """Helper method to apply gradient checkpointing if enabled.""" if self.gradient_checkpointing_enabled and self.training: @@ -731,7 +743,16 @@ class PI05Pytorch(nn.Module): # see openpi `PI0Pytorch` return F.mse_loss(u_t, v_t, reduction="none") @torch.no_grad() # see openpi `sample_actions` (slightly adapted) - def sample_actions(self, images, img_masks, tokens, masks, noise=None, num_steps=None) -> Tensor: + def sample_actions( + self, + images, + img_masks, + tokens, + masks, + noise=None, + num_steps=None, + **kwargs: Unpack[ActionSelectKwargs], + ) -> Tensor: """Do a full inference forward and compute the action.""" if num_steps is None: num_steps = self.config.num_inference_steps @@ -770,13 +791,40 @@ class PI05Pytorch(nn.Module): # see openpi `PI0Pytorch` time = torch.tensor(1.0, dtype=torch.float32, device=device) while time >= -dt / 2: expanded_time = time.expand(bsize) - v_t = self.denoise_step( - prefix_pad_masks, - past_key_values, - x_t, - expanded_time, - ) - x_t = x_t + dt * v_t + + # Define a closure function to properly capture expanded_time + # This avoids the lambda expression (E731) and loop variable binding (B023) issues + def denoise_step_partial_call(input_x_t, current_timestep=expanded_time): + return self.denoise_step( + prefix_pad_masks=prefix_pad_masks, + past_key_values=past_key_values, + x_t=input_x_t, + timestep=current_timestep, + ) + + if self._rtc_enabled(): + inference_delay = kwargs.get("inference_delay") + prev_chunk_left_over = kwargs.get("prev_chunk_left_over") + execution_horizon = kwargs.get("execution_horizon") + + v_t = self.rtc_processor.denoise_step( + x_t=x_t, + prev_chunk_left_over=prev_chunk_left_over, + inference_delay=inference_delay, + time=time, + original_denoise_step_partial=denoise_step_partial_call, + execution_horizon=execution_horizon, + ) + else: + v_t = denoise_step_partial_call(x_t) + + # Euler step + x_t += dt * v_t + + # Record x_t and v_t after Euler step + if self.rtc_processor is not None and self.rtc_processor.is_debug_enabled(): + self.rtc_processor.track(time=time, x_t=x_t, v_t=v_t) + time += dt return x_t @@ -839,7 +887,8 @@ class PI05Policy(PreTrainedPolicy): self.config = config # Initialize the core PI05 model - self.model = PI05Pytorch(config) + self.init_rtc_processor() + self.model = PI05Pytorch(config, rtc_processor=self.rtc_processor) # Enable gradient checkpointing if requested if config.gradient_checkpointing: @@ -1035,6 +1084,22 @@ class PI05Policy(PreTrainedPolicy): ACTION: deque(maxlen=self.config.n_action_steps), } + def init_rtc_processor(self): + """Initialize RTC processor if RTC is enabled in config.""" + self.rtc_processor = None + + # Create processor if config provided + # If RTC is not enabled - we can still track the denoising data + if self.config.rtc_config is not None: + self.rtc_processor = RTCProcessor(self.config.rtc_config) + + # Set rtc_processor to the model if it exists + if self.model is not None: + self.model.rtc_processor = self.rtc_processor + + def _rtc_enabled(self) -> bool: + return self.config.rtc_config is not None and self.config.rtc_config.enabled + def _preprocess_images(self, batch: dict[str, Tensor]) -> tuple[list[Tensor], list[Tensor]]: """Preprocess images for the model. @@ -1109,6 +1174,10 @@ class PI05Policy(PreTrainedPolicy): @torch.no_grad() def select_action(self, batch: dict[str, Tensor]) -> Tensor: """Select a single action given environment observations.""" + assert not self._rtc_enabled(), ( + "RTC is not supported for select_action, use it with predict_action_chunk" + ) + self.eval() # Action queue logic for n_action_steps > 1 @@ -1120,7 +1189,7 @@ class PI05Policy(PreTrainedPolicy): return self._action_queue.popleft() @torch.no_grad() - def predict_action_chunk(self, batch: dict[str, Tensor]) -> Tensor: + def predict_action_chunk(self, batch: dict[str, Tensor], **kwargs: Unpack[ActionSelectKwargs]) -> Tensor: """Predict a chunk of actions given environment observations.""" self.eval() @@ -1128,8 +1197,8 @@ class PI05Policy(PreTrainedPolicy): images, img_masks = self._preprocess_images(batch) tokens, masks = batch[f"{OBS_LANGUAGE_TOKENS}"], batch[f"{OBS_LANGUAGE_ATTENTION_MASK}"] - # Sample actions using the model (no separate state needed for PI05) - actions = self.model.sample_actions(images, img_masks, tokens, masks) + # Sample actions using the model (pass through RTC kwargs, no separate state needed for PI05) + actions = self.model.sample_actions(images, img_masks, tokens, masks, **kwargs) # Unpad actions to actual action dimension original_action_dim = self.config.output_features[ACTION].shape[0]