mirror of
https://github.com/huggingface/lerobot.git
synced 2026-05-24 21:19:53 +00:00
Pi0
This commit is contained in:
@@ -25,6 +25,13 @@ Usage:
|
|||||||
--rtc.execution_horizon=8 \
|
--rtc.execution_horizon=8 \
|
||||||
--device=mps
|
--device=mps
|
||||||
|
|
||||||
|
# Basic usage with pi0.5 policy
|
||||||
|
uv run python examples/rtc/eval_dataset.py \
|
||||||
|
--policy.path=lerobot/pi05_libero_finetuned \
|
||||||
|
--dataset.repo_id=HuggingFaceVLA/libero \
|
||||||
|
--rtc.execution_horizon=8 \
|
||||||
|
--device=mps
|
||||||
|
|
||||||
# With torch.compile for faster inference (PyTorch 2.0+)
|
# With torch.compile for faster inference (PyTorch 2.0+)
|
||||||
# Note: CUDA graphs disabled by default due to in-place ops in denoising loop
|
# Note: CUDA graphs disabled by default due to in-place ops in denoising loop
|
||||||
uv run python examples/rtc/eval_dataset.py \
|
uv run python examples/rtc/eval_dataset.py \
|
||||||
|
|||||||
@@ -19,11 +19,12 @@ import logging
|
|||||||
import math
|
import math
|
||||||
from collections import deque
|
from collections import deque
|
||||||
from pathlib import Path
|
from pathlib import Path
|
||||||
from typing import TYPE_CHECKING, Literal
|
from typing import TYPE_CHECKING, Literal, TypedDict
|
||||||
|
|
||||||
import torch
|
import torch
|
||||||
import torch.nn.functional as F # noqa: N812
|
import torch.nn.functional as F # noqa: N812
|
||||||
from torch import Tensor, nn
|
from torch import Tensor, nn
|
||||||
|
from typing_extensions import Unpack
|
||||||
|
|
||||||
from lerobot.utils.import_utils import _transformers_available
|
from lerobot.utils.import_utils import _transformers_available
|
||||||
|
|
||||||
@@ -42,6 +43,7 @@ else:
|
|||||||
from lerobot.configs.policies import PreTrainedConfig
|
from lerobot.configs.policies import PreTrainedConfig
|
||||||
from lerobot.policies.pi0.configuration_pi0 import PI0Config
|
from lerobot.policies.pi0.configuration_pi0 import PI0Config
|
||||||
from lerobot.policies.pretrained import PreTrainedPolicy, T
|
from lerobot.policies.pretrained import PreTrainedPolicy, T
|
||||||
|
from lerobot.policies.rtc.modeling_rtc import RTCProcessor
|
||||||
from lerobot.utils.constants import (
|
from lerobot.utils.constants import (
|
||||||
ACTION,
|
ACTION,
|
||||||
OBS_LANGUAGE_ATTENTION_MASK,
|
OBS_LANGUAGE_ATTENTION_MASK,
|
||||||
@@ -51,6 +53,12 @@ from lerobot.utils.constants import (
|
|||||||
)
|
)
|
||||||
|
|
||||||
|
|
||||||
|
class ActionSelectKwargs(TypedDict, total=False):
|
||||||
|
inference_delay: int | None
|
||||||
|
prev_chunk_left_over: Tensor | None
|
||||||
|
execution_horizon: int | None
|
||||||
|
|
||||||
|
|
||||||
def get_safe_dtype(target_dtype, device_type):
|
def get_safe_dtype(target_dtype, device_type):
|
||||||
"""Get a safe dtype for the given device type."""
|
"""Get a safe dtype for the given device type."""
|
||||||
if device_type == "mps" and target_dtype == torch.float64:
|
if device_type == "mps" and target_dtype == torch.float64:
|
||||||
@@ -503,9 +511,10 @@ class PaliGemmaWithExpertModel(
|
|||||||
class PI0Pytorch(nn.Module): # see openpi `PI0Pytorch`
|
class PI0Pytorch(nn.Module): # see openpi `PI0Pytorch`
|
||||||
"""Core PI0 PyTorch model."""
|
"""Core PI0 PyTorch model."""
|
||||||
|
|
||||||
def __init__(self, config: PI0Config):
|
def __init__(self, config: PI0Config, rtc_processor: RTCProcessor | None = None):
|
||||||
super().__init__()
|
super().__init__()
|
||||||
self.config = config
|
self.config = config
|
||||||
|
self.rtc_processor = rtc_processor
|
||||||
|
|
||||||
paligemma_config = get_gemma_config(config.paligemma_variant)
|
paligemma_config = get_gemma_config(config.paligemma_variant)
|
||||||
action_expert_config = get_gemma_config(config.action_expert_variant)
|
action_expert_config = get_gemma_config(config.action_expert_variant)
|
||||||
@@ -560,6 +569,9 @@ class PI0Pytorch(nn.Module): # see openpi `PI0Pytorch`
|
|||||||
self.paligemma_with_expert.gemma_expert.model.gradient_checkpointing = False
|
self.paligemma_with_expert.gemma_expert.model.gradient_checkpointing = False
|
||||||
logging.info("Disabled gradient checkpointing for PI0Pytorch model")
|
logging.info("Disabled gradient checkpointing for PI0Pytorch model")
|
||||||
|
|
||||||
|
def _rtc_enabled(self) -> bool:
    """Tell whether RTC is switched on in this model's configuration.

    Returns ``False`` when the config has no ``rtc_config``; otherwise returns
    that config's ``enabled`` flag.
    """
    rtc_config = self.config.rtc_config
    if rtc_config is None:
        return False
    return rtc_config.enabled
|
||||||
|
|
||||||
def _apply_checkpoint(self, func, *args, **kwargs):
|
def _apply_checkpoint(self, func, *args, **kwargs):
|
||||||
"""Helper method to apply gradient checkpointing if enabled."""
|
"""Helper method to apply gradient checkpointing if enabled."""
|
||||||
if self.gradient_checkpointing_enabled and self.training:
|
if self.gradient_checkpointing_enabled and self.training:
|
||||||
@@ -756,7 +768,15 @@ class PI0Pytorch(nn.Module): # see openpi `PI0Pytorch`
|
|||||||
|
|
||||||
@torch.no_grad() # see openpi `sample_actions` (slightly adapted)
|
@torch.no_grad() # see openpi `sample_actions` (slightly adapted)
|
||||||
def sample_actions(
|
def sample_actions(
|
||||||
self, images, img_masks, lang_tokens, lang_masks, state, noise=None, num_steps=None
|
self,
|
||||||
|
images,
|
||||||
|
img_masks,
|
||||||
|
lang_tokens,
|
||||||
|
lang_masks,
|
||||||
|
state,
|
||||||
|
noise=None,
|
||||||
|
num_steps=None,
|
||||||
|
**kwargs: Unpack[ActionSelectKwargs],
|
||||||
) -> Tensor:
|
) -> Tensor:
|
||||||
"""Do a full inference forward and compute the action."""
|
"""Do a full inference forward and compute the action."""
|
||||||
if num_steps is None:
|
if num_steps is None:
|
||||||
@@ -798,14 +818,41 @@ class PI0Pytorch(nn.Module): # see openpi `PI0Pytorch`
|
|||||||
time = torch.tensor(1.0, dtype=torch.float32, device=device)
|
time = torch.tensor(1.0, dtype=torch.float32, device=device)
|
||||||
while time >= -dt / 2:
|
while time >= -dt / 2:
|
||||||
expanded_time = time.expand(bsize)
|
expanded_time = time.expand(bsize)
|
||||||
v_t = self.denoise_step(
|
|
||||||
state,
|
# Define a closure function to properly capture expanded_time
|
||||||
prefix_pad_masks,
|
# This avoids the lambda expression (E731) and loop variable binding (B023) issues
|
||||||
past_key_values,
|
def denoise_step_partial_call(input_x_t, current_timestep=expanded_time):
|
||||||
x_t,
|
return self.denoise_step(
|
||||||
expanded_time,
|
state=state,
|
||||||
)
|
prefix_pad_masks=prefix_pad_masks,
|
||||||
x_t = x_t + dt * v_t
|
past_key_values=past_key_values,
|
||||||
|
x_t=input_x_t,
|
||||||
|
timestep=current_timestep,
|
||||||
|
)
|
||||||
|
|
||||||
|
if self._rtc_enabled():
|
||||||
|
inference_delay = kwargs.get("inference_delay")
|
||||||
|
prev_chunk_left_over = kwargs.get("prev_chunk_left_over")
|
||||||
|
execution_horizon = kwargs.get("execution_horizon")
|
||||||
|
|
||||||
|
v_t = self.rtc_processor.denoise_step(
|
||||||
|
x_t=x_t,
|
||||||
|
prev_chunk_left_over=prev_chunk_left_over,
|
||||||
|
inference_delay=inference_delay,
|
||||||
|
time=time,
|
||||||
|
original_denoise_step_partial=denoise_step_partial_call,
|
||||||
|
execution_horizon=execution_horizon,
|
||||||
|
)
|
||||||
|
else:
|
||||||
|
v_t = denoise_step_partial_call(x_t)
|
||||||
|
|
||||||
|
# Euler step
|
||||||
|
x_t += dt * v_t
|
||||||
|
|
||||||
|
# Record x_t and v_t after Euler step
|
||||||
|
if self.rtc_processor is not None and self.rtc_processor.is_debug_enabled():
|
||||||
|
self.rtc_processor.track(time=time, x_t=x_t, v_t=v_t)
|
||||||
|
|
||||||
time += dt
|
time += dt
|
||||||
|
|
||||||
return x_t
|
return x_t
|
||||||
@@ -869,7 +916,8 @@ class PI0Policy(PreTrainedPolicy):
|
|||||||
self.config = config
|
self.config = config
|
||||||
|
|
||||||
# Initialize the core PI0 model
|
# Initialize the core PI0 model
|
||||||
self.model = PI0Pytorch(config)
|
self.init_rtc_processor()
|
||||||
|
self.model = PI0Pytorch(config, rtc_processor=self.rtc_processor)
|
||||||
|
|
||||||
# Enable gradient checkpointing if requested
|
# Enable gradient checkpointing if requested
|
||||||
if config.gradient_checkpointing:
|
if config.gradient_checkpointing:
|
||||||
@@ -1059,6 +1107,22 @@ class PI0Policy(PreTrainedPolicy):
|
|||||||
ACTION: deque(maxlen=self.config.n_action_steps),
|
ACTION: deque(maxlen=self.config.n_action_steps),
|
||||||
}
|
}
|
||||||
|
|
||||||
|
def init_rtc_processor(self):
    """Build the RTC processor for this policy and share it with the core model.

    ``self.rtc_processor`` is reset to ``None`` and then rebuilt from
    ``self.config.rtc_config`` whenever that config is present. The processor
    is created even if RTC itself is not enabled, so that denoising data can
    still be tracked.

    The policy invokes this method before ``self.model`` has been assigned, so
    the model is looked up defensively: a missing ``model`` attribute is
    treated as "no model yet" instead of raising ``AttributeError``.
    """
    self.rtc_processor = None

    # Create processor if config provided.
    # If RTC is not enabled - we can still track the denoising data.
    if self.config.rtc_config is not None:
        self.rtc_processor = RTCProcessor(self.config.rtc_config)

    # Keep an already-built model in sync with the policy. ``getattr`` with a
    # default is needed because the attribute does not exist yet while the
    # policy is still being constructed.
    model = getattr(self, "model", None)
    if model is not None:
        model.rtc_processor = self.rtc_processor
|
||||||
|
|
||||||
|
def _rtc_enabled(self) -> bool:
    """Tell whether RTC is switched on in this policy's configuration.

    Returns ``False`` when the config has no ``rtc_config``; otherwise returns
    that config's ``enabled`` flag.
    """
    rtc_config = self.config.rtc_config
    return False if rtc_config is None else rtc_config.enabled
|
||||||
|
|
||||||
def _preprocess_images(self, batch: dict[str, Tensor]) -> tuple[list[Tensor], list[Tensor]]:
|
def _preprocess_images(self, batch: dict[str, Tensor]) -> tuple[list[Tensor], list[Tensor]]:
|
||||||
"""Preprocess images for the model.
|
"""Preprocess images for the model.
|
||||||
|
|
||||||
@@ -1137,6 +1201,10 @@ class PI0Policy(PreTrainedPolicy):
|
|||||||
@torch.no_grad()
|
@torch.no_grad()
|
||||||
def select_action(self, batch: dict[str, Tensor]) -> Tensor:
|
def select_action(self, batch: dict[str, Tensor]) -> Tensor:
|
||||||
"""Select a single action given environment observations."""
|
"""Select a single action given environment observations."""
|
||||||
|
assert not self._rtc_enabled(), (
|
||||||
|
"RTC is not supported for select_action, use it with predict_action_chunk"
|
||||||
|
)
|
||||||
|
|
||||||
self.eval()
|
self.eval()
|
||||||
|
|
||||||
# Action queue logic for n_action_steps > 1
|
# Action queue logic for n_action_steps > 1
|
||||||
@@ -1148,7 +1216,7 @@ class PI0Policy(PreTrainedPolicy):
|
|||||||
return self._action_queue.popleft()
|
return self._action_queue.popleft()
|
||||||
|
|
||||||
@torch.no_grad()
|
@torch.no_grad()
|
||||||
def predict_action_chunk(self, batch: dict[str, Tensor]) -> Tensor:
|
def predict_action_chunk(self, batch: dict[str, Tensor], **kwargs: Unpack[ActionSelectKwargs]) -> Tensor:
|
||||||
"""Predict a chunk of actions given environment observations."""
|
"""Predict a chunk of actions given environment observations."""
|
||||||
self.eval()
|
self.eval()
|
||||||
|
|
||||||
@@ -1157,8 +1225,8 @@ class PI0Policy(PreTrainedPolicy):
|
|||||||
lang_tokens, lang_masks = batch[f"{OBS_LANGUAGE_TOKENS}"], batch[f"{OBS_LANGUAGE_ATTENTION_MASK}"]
|
lang_tokens, lang_masks = batch[f"{OBS_LANGUAGE_TOKENS}"], batch[f"{OBS_LANGUAGE_ATTENTION_MASK}"]
|
||||||
state = self.prepare_state(batch)
|
state = self.prepare_state(batch)
|
||||||
|
|
||||||
# Sample actions using the model
|
# Sample actions using the model (pass through RTC kwargs)
|
||||||
actions = self.model.sample_actions(images, img_masks, lang_tokens, lang_masks, state)
|
actions = self.model.sample_actions(images, img_masks, lang_tokens, lang_masks, state, **kwargs)
|
||||||
|
|
||||||
# Unpad actions to actual action dimension
|
# Unpad actions to actual action dimension
|
||||||
original_action_dim = self.config.output_features[ACTION].shape[0]
|
original_action_dim = self.config.output_features[ACTION].shape[0]
|
||||||
|
|||||||
@@ -19,11 +19,12 @@ import logging
|
|||||||
import math
|
import math
|
||||||
from collections import deque
|
from collections import deque
|
||||||
from pathlib import Path
|
from pathlib import Path
|
||||||
from typing import TYPE_CHECKING, Literal
|
from typing import TYPE_CHECKING, Literal, TypedDict
|
||||||
|
|
||||||
import torch
|
import torch
|
||||||
import torch.nn.functional as F # noqa: N812
|
import torch.nn.functional as F # noqa: N812
|
||||||
from torch import Tensor, nn
|
from torch import Tensor, nn
|
||||||
|
from typing_extensions import Unpack
|
||||||
|
|
||||||
from lerobot.utils.import_utils import _transformers_available
|
from lerobot.utils.import_utils import _transformers_available
|
||||||
|
|
||||||
@@ -42,6 +43,7 @@ else:
|
|||||||
from lerobot.configs.policies import PreTrainedConfig
|
from lerobot.configs.policies import PreTrainedConfig
|
||||||
from lerobot.policies.pi05.configuration_pi05 import PI05Config
|
from lerobot.policies.pi05.configuration_pi05 import PI05Config
|
||||||
from lerobot.policies.pretrained import PreTrainedPolicy, T
|
from lerobot.policies.pretrained import PreTrainedPolicy, T
|
||||||
|
from lerobot.policies.rtc.modeling_rtc import RTCProcessor
|
||||||
from lerobot.utils.constants import (
|
from lerobot.utils.constants import (
|
||||||
ACTION,
|
ACTION,
|
||||||
OBS_LANGUAGE_ATTENTION_MASK,
|
OBS_LANGUAGE_ATTENTION_MASK,
|
||||||
@@ -50,6 +52,12 @@ from lerobot.utils.constants import (
|
|||||||
)
|
)
|
||||||
|
|
||||||
|
|
||||||
|
class ActionSelectKwargs(TypedDict, total=False):
|
||||||
|
inference_delay: int | None
|
||||||
|
prev_chunk_left_over: Tensor | None
|
||||||
|
execution_horizon: int | None
|
||||||
|
|
||||||
|
|
||||||
def get_safe_dtype(target_dtype, device_type):
|
def get_safe_dtype(target_dtype, device_type):
|
||||||
"""Get a safe dtype for the given device type."""
|
"""Get a safe dtype for the given device type."""
|
||||||
if device_type == "mps" and target_dtype == torch.float64:
|
if device_type == "mps" and target_dtype == torch.float64:
|
||||||
@@ -502,9 +510,10 @@ class PaliGemmaWithExpertModel(
|
|||||||
class PI05Pytorch(nn.Module): # see openpi `PI0Pytorch`
|
class PI05Pytorch(nn.Module): # see openpi `PI0Pytorch`
|
||||||
"""Core PI05 PyTorch model."""
|
"""Core PI05 PyTorch model."""
|
||||||
|
|
||||||
def __init__(self, config: PI05Config):
|
def __init__(self, config: PI05Config, rtc_processor: RTCProcessor | None = None):
|
||||||
super().__init__()
|
super().__init__()
|
||||||
self.config = config
|
self.config = config
|
||||||
|
self.rtc_processor = rtc_processor
|
||||||
|
|
||||||
paligemma_config = get_gemma_config(config.paligemma_variant)
|
paligemma_config = get_gemma_config(config.paligemma_variant)
|
||||||
action_expert_config = get_gemma_config(config.action_expert_variant)
|
action_expert_config = get_gemma_config(config.action_expert_variant)
|
||||||
@@ -556,6 +565,9 @@ class PI05Pytorch(nn.Module): # see openpi `PI0Pytorch`
|
|||||||
self.paligemma_with_expert.gemma_expert.model.gradient_checkpointing = False
|
self.paligemma_with_expert.gemma_expert.model.gradient_checkpointing = False
|
||||||
logging.info("Disabled gradient checkpointing for PI05Pytorch model")
|
logging.info("Disabled gradient checkpointing for PI05Pytorch model")
|
||||||
|
|
||||||
|
def _rtc_enabled(self) -> bool:
    """Tell whether RTC is switched on in this model's configuration.

    Returns ``False`` when the config has no ``rtc_config``; otherwise returns
    that config's ``enabled`` flag.
    """
    rtc_config = self.config.rtc_config
    if rtc_config is None:
        return False
    return rtc_config.enabled
|
||||||
|
|
||||||
def _apply_checkpoint(self, func, *args, **kwargs):
|
def _apply_checkpoint(self, func, *args, **kwargs):
|
||||||
"""Helper method to apply gradient checkpointing if enabled."""
|
"""Helper method to apply gradient checkpointing if enabled."""
|
||||||
if self.gradient_checkpointing_enabled and self.training:
|
if self.gradient_checkpointing_enabled and self.training:
|
||||||
@@ -731,7 +743,16 @@ class PI05Pytorch(nn.Module): # see openpi `PI0Pytorch`
|
|||||||
return F.mse_loss(u_t, v_t, reduction="none")
|
return F.mse_loss(u_t, v_t, reduction="none")
|
||||||
|
|
||||||
@torch.no_grad() # see openpi `sample_actions` (slightly adapted)
|
@torch.no_grad() # see openpi `sample_actions` (slightly adapted)
|
||||||
def sample_actions(self, images, img_masks, tokens, masks, noise=None, num_steps=None) -> Tensor:
|
def sample_actions(
|
||||||
|
self,
|
||||||
|
images,
|
||||||
|
img_masks,
|
||||||
|
tokens,
|
||||||
|
masks,
|
||||||
|
noise=None,
|
||||||
|
num_steps=None,
|
||||||
|
**kwargs: Unpack[ActionSelectKwargs],
|
||||||
|
) -> Tensor:
|
||||||
"""Do a full inference forward and compute the action."""
|
"""Do a full inference forward and compute the action."""
|
||||||
if num_steps is None:
|
if num_steps is None:
|
||||||
num_steps = self.config.num_inference_steps
|
num_steps = self.config.num_inference_steps
|
||||||
@@ -770,13 +791,40 @@ class PI05Pytorch(nn.Module): # see openpi `PI0Pytorch`
|
|||||||
time = torch.tensor(1.0, dtype=torch.float32, device=device)
|
time = torch.tensor(1.0, dtype=torch.float32, device=device)
|
||||||
while time >= -dt / 2:
|
while time >= -dt / 2:
|
||||||
expanded_time = time.expand(bsize)
|
expanded_time = time.expand(bsize)
|
||||||
v_t = self.denoise_step(
|
|
||||||
prefix_pad_masks,
|
# Define a closure function to properly capture expanded_time
|
||||||
past_key_values,
|
# This avoids the lambda expression (E731) and loop variable binding (B023) issues
|
||||||
x_t,
|
def denoise_step_partial_call(input_x_t, current_timestep=expanded_time):
|
||||||
expanded_time,
|
return self.denoise_step(
|
||||||
)
|
prefix_pad_masks=prefix_pad_masks,
|
||||||
x_t = x_t + dt * v_t
|
past_key_values=past_key_values,
|
||||||
|
x_t=input_x_t,
|
||||||
|
timestep=current_timestep,
|
||||||
|
)
|
||||||
|
|
||||||
|
if self._rtc_enabled():
|
||||||
|
inference_delay = kwargs.get("inference_delay")
|
||||||
|
prev_chunk_left_over = kwargs.get("prev_chunk_left_over")
|
||||||
|
execution_horizon = kwargs.get("execution_horizon")
|
||||||
|
|
||||||
|
v_t = self.rtc_processor.denoise_step(
|
||||||
|
x_t=x_t,
|
||||||
|
prev_chunk_left_over=prev_chunk_left_over,
|
||||||
|
inference_delay=inference_delay,
|
||||||
|
time=time,
|
||||||
|
original_denoise_step_partial=denoise_step_partial_call,
|
||||||
|
execution_horizon=execution_horizon,
|
||||||
|
)
|
||||||
|
else:
|
||||||
|
v_t = denoise_step_partial_call(x_t)
|
||||||
|
|
||||||
|
# Euler step
|
||||||
|
x_t += dt * v_t
|
||||||
|
|
||||||
|
# Record x_t and v_t after Euler step
|
||||||
|
if self.rtc_processor is not None and self.rtc_processor.is_debug_enabled():
|
||||||
|
self.rtc_processor.track(time=time, x_t=x_t, v_t=v_t)
|
||||||
|
|
||||||
time += dt
|
time += dt
|
||||||
|
|
||||||
return x_t
|
return x_t
|
||||||
@@ -839,7 +887,8 @@ class PI05Policy(PreTrainedPolicy):
|
|||||||
self.config = config
|
self.config = config
|
||||||
|
|
||||||
# Initialize the core PI05 model
|
# Initialize the core PI05 model
|
||||||
self.model = PI05Pytorch(config)
|
self.init_rtc_processor()
|
||||||
|
self.model = PI05Pytorch(config, rtc_processor=self.rtc_processor)
|
||||||
|
|
||||||
# Enable gradient checkpointing if requested
|
# Enable gradient checkpointing if requested
|
||||||
if config.gradient_checkpointing:
|
if config.gradient_checkpointing:
|
||||||
@@ -1035,6 +1084,22 @@ class PI05Policy(PreTrainedPolicy):
|
|||||||
ACTION: deque(maxlen=self.config.n_action_steps),
|
ACTION: deque(maxlen=self.config.n_action_steps),
|
||||||
}
|
}
|
||||||
|
|
||||||
|
def init_rtc_processor(self):
    """Build the RTC processor for this policy and share it with the core model.

    ``self.rtc_processor`` is reset to ``None`` and then rebuilt from
    ``self.config.rtc_config`` whenever that config is present. The processor
    is created even if RTC itself is not enabled, so that denoising data can
    still be tracked.

    The policy invokes this method before ``self.model`` has been assigned, so
    the model is looked up defensively: a missing ``model`` attribute is
    treated as "no model yet" instead of raising ``AttributeError``.
    """
    self.rtc_processor = None

    # Create processor if config provided.
    # If RTC is not enabled - we can still track the denoising data.
    if self.config.rtc_config is not None:
        self.rtc_processor = RTCProcessor(self.config.rtc_config)

    # Keep an already-built model in sync with the policy. ``getattr`` with a
    # default is needed because the attribute does not exist yet while the
    # policy is still being constructed.
    model = getattr(self, "model", None)
    if model is not None:
        model.rtc_processor = self.rtc_processor
|
||||||
|
|
||||||
|
def _rtc_enabled(self) -> bool:
    """Tell whether RTC is switched on in this policy's configuration.

    Returns ``False`` when the config has no ``rtc_config``; otherwise returns
    that config's ``enabled`` flag.
    """
    rtc_config = self.config.rtc_config
    return False if rtc_config is None else rtc_config.enabled
|
||||||
|
|
||||||
def _preprocess_images(self, batch: dict[str, Tensor]) -> tuple[list[Tensor], list[Tensor]]:
|
def _preprocess_images(self, batch: dict[str, Tensor]) -> tuple[list[Tensor], list[Tensor]]:
|
||||||
"""Preprocess images for the model.
|
"""Preprocess images for the model.
|
||||||
|
|
||||||
@@ -1109,6 +1174,10 @@ class PI05Policy(PreTrainedPolicy):
|
|||||||
@torch.no_grad()
|
@torch.no_grad()
|
||||||
def select_action(self, batch: dict[str, Tensor]) -> Tensor:
|
def select_action(self, batch: dict[str, Tensor]) -> Tensor:
|
||||||
"""Select a single action given environment observations."""
|
"""Select a single action given environment observations."""
|
||||||
|
assert not self._rtc_enabled(), (
|
||||||
|
"RTC is not supported for select_action, use it with predict_action_chunk"
|
||||||
|
)
|
||||||
|
|
||||||
self.eval()
|
self.eval()
|
||||||
|
|
||||||
# Action queue logic for n_action_steps > 1
|
# Action queue logic for n_action_steps > 1
|
||||||
@@ -1120,7 +1189,7 @@ class PI05Policy(PreTrainedPolicy):
|
|||||||
return self._action_queue.popleft()
|
return self._action_queue.popleft()
|
||||||
|
|
||||||
@torch.no_grad()
|
@torch.no_grad()
|
||||||
def predict_action_chunk(self, batch: dict[str, Tensor]) -> Tensor:
|
def predict_action_chunk(self, batch: dict[str, Tensor], **kwargs: Unpack[ActionSelectKwargs]) -> Tensor:
|
||||||
"""Predict a chunk of actions given environment observations."""
|
"""Predict a chunk of actions given environment observations."""
|
||||||
self.eval()
|
self.eval()
|
||||||
|
|
||||||
@@ -1128,8 +1197,8 @@ class PI05Policy(PreTrainedPolicy):
|
|||||||
images, img_masks = self._preprocess_images(batch)
|
images, img_masks = self._preprocess_images(batch)
|
||||||
tokens, masks = batch[f"{OBS_LANGUAGE_TOKENS}"], batch[f"{OBS_LANGUAGE_ATTENTION_MASK}"]
|
tokens, masks = batch[f"{OBS_LANGUAGE_TOKENS}"], batch[f"{OBS_LANGUAGE_ATTENTION_MASK}"]
|
||||||
|
|
||||||
# Sample actions using the model (no separate state needed for PI05)
|
# Sample actions using the model (pass through RTC kwargs, no separate state needed for PI05)
|
||||||
actions = self.model.sample_actions(images, img_masks, tokens, masks)
|
actions = self.model.sample_actions(images, img_masks, tokens, masks, **kwargs)
|
||||||
|
|
||||||
# Unpad actions to actual action dimension
|
# Unpad actions to actual action dimension
|
||||||
original_action_dim = self.config.output_features[ACTION].shape[0]
|
original_action_dim = self.config.output_features[ACTION].shape[0]
|
||||||
|
|||||||
Reference in New Issue
Block a user