mirror of
https://github.com/huggingface/lerobot.git
synced 2026-05-22 12:09:42 +00:00
enable variable image sizes to pi0/pi0.5 (#2609)
* enable variable image sizes to pi0/pi0.5
* add square image assertion
This commit is contained in:
@@ -23,6 +23,8 @@ from lerobot.optim.schedulers import CosineDecayWithWarmupSchedulerConfig
|
|||||||
from lerobot.policies.rtc.configuration_rtc import RTCConfig
|
from lerobot.policies.rtc.configuration_rtc import RTCConfig
|
||||||
from lerobot.utils.constants import OBS_IMAGES
|
from lerobot.utils.constants import OBS_IMAGES
|
||||||
|
|
||||||
|
DEFAULT_IMAGE_SIZE = 224
|
||||||
|
|
||||||
|
|
||||||
@PreTrainedConfig.register_subclass("pi0")
|
@PreTrainedConfig.register_subclass("pi0")
|
||||||
@dataclass
|
@dataclass
|
||||||
@@ -51,7 +53,10 @@ class PI0Config(PreTrainedConfig):
|
|||||||
# Real-Time Chunking (RTC) configuration
|
# Real-Time Chunking (RTC) configuration
|
||||||
rtc_config: RTCConfig | None = None
|
rtc_config: RTCConfig | None = None
|
||||||
|
|
||||||
image_resolution: tuple[int, int] = (224, 224) # see openpi `preprocessing_pytorch.py`
|
image_resolution: tuple[int, int] = (
|
||||||
|
DEFAULT_IMAGE_SIZE,
|
||||||
|
DEFAULT_IMAGE_SIZE,
|
||||||
|
) # see openpi `preprocessing_pytorch.py`
|
||||||
|
|
||||||
# Add empty images. Used to add empty cameras when no image features are present.
|
# Add empty images. Used to add empty cameras when no image features are present.
|
||||||
empty_cameras: int = 0
|
empty_cameras: int = 0
|
||||||
|
|||||||
@@ -41,7 +41,7 @@ else:
|
|||||||
PaliGemmaForConditionalGeneration = None
|
PaliGemmaForConditionalGeneration = None
|
||||||
|
|
||||||
from lerobot.configs.policies import PreTrainedConfig
|
from lerobot.configs.policies import PreTrainedConfig
|
||||||
from lerobot.policies.pi0.configuration_pi0 import PI0Config
|
from lerobot.policies.pi0.configuration_pi0 import DEFAULT_IMAGE_SIZE, PI0Config
|
||||||
from lerobot.policies.pretrained import PreTrainedPolicy, T
|
from lerobot.policies.pretrained import PreTrainedPolicy, T
|
||||||
from lerobot.policies.rtc.modeling_rtc import RTCProcessor
|
from lerobot.policies.rtc.modeling_rtc import RTCProcessor
|
||||||
from lerobot.utils.constants import (
|
from lerobot.utils.constants import (
|
||||||
@@ -337,6 +337,7 @@ class PaliGemmaWithExpertModel(
|
|||||||
action_expert_config,
|
action_expert_config,
|
||||||
use_adarms=None,
|
use_adarms=None,
|
||||||
precision: Literal["bfloat16", "float32"] = "bfloat16",
|
precision: Literal["bfloat16", "float32"] = "bfloat16",
|
||||||
|
image_size: int = DEFAULT_IMAGE_SIZE,
|
||||||
):
|
):
|
||||||
if use_adarms is None:
|
if use_adarms is None:
|
||||||
use_adarms = [False, False]
|
use_adarms = [False, False]
|
||||||
@@ -356,6 +357,7 @@ class PaliGemmaWithExpertModel(
|
|||||||
vlm_config_hf.text_config.vocab_size = 257152
|
vlm_config_hf.text_config.vocab_size = 257152
|
||||||
vlm_config_hf.text_config.use_adarms = use_adarms[0]
|
vlm_config_hf.text_config.use_adarms = use_adarms[0]
|
||||||
vlm_config_hf.text_config.adarms_cond_dim = vlm_config.width if use_adarms[0] else None
|
vlm_config_hf.text_config.adarms_cond_dim = vlm_config.width if use_adarms[0] else None
|
||||||
|
vlm_config_hf.vision_config.image_size = image_size
|
||||||
vlm_config_hf.vision_config.intermediate_size = 4304
|
vlm_config_hf.vision_config.intermediate_size = 4304
|
||||||
vlm_config_hf.vision_config.projection_dim = 2048
|
vlm_config_hf.vision_config.projection_dim = 2048
|
||||||
vlm_config_hf.vision_config.projector_hidden_act = "gelu_fast"
|
vlm_config_hf.vision_config.projector_hidden_act = "gelu_fast"
|
||||||
@@ -519,11 +521,17 @@ class PI0Pytorch(nn.Module): # see openpi `PI0Pytorch`
|
|||||||
paligemma_config = get_gemma_config(config.paligemma_variant)
|
paligemma_config = get_gemma_config(config.paligemma_variant)
|
||||||
action_expert_config = get_gemma_config(config.action_expert_variant)
|
action_expert_config = get_gemma_config(config.action_expert_variant)
|
||||||
|
|
||||||
|
if config.image_resolution[0] != config.image_resolution[1]:
|
||||||
|
raise ValueError(
|
||||||
|
f"PaliGemma expects square image resolution, invalid resolution: {config.image_resolution}"
|
||||||
|
)
|
||||||
|
|
||||||
self.paligemma_with_expert = PaliGemmaWithExpertModel(
|
self.paligemma_with_expert = PaliGemmaWithExpertModel(
|
||||||
paligemma_config,
|
paligemma_config,
|
||||||
action_expert_config,
|
action_expert_config,
|
||||||
use_adarms=[False, False],
|
use_adarms=[False, False],
|
||||||
precision=config.dtype,
|
precision=config.dtype,
|
||||||
|
image_size=config.image_resolution[0],
|
||||||
)
|
)
|
||||||
|
|
||||||
self.action_in_proj = nn.Linear(config.max_action_dim, action_expert_config.width)
|
self.action_in_proj = nn.Linear(config.max_action_dim, action_expert_config.width)
|
||||||
|
|||||||
@@ -22,6 +22,8 @@ from lerobot.optim.optimizers import AdamWConfig
|
|||||||
from lerobot.optim.schedulers import CosineDecayWithWarmupSchedulerConfig
|
from lerobot.optim.schedulers import CosineDecayWithWarmupSchedulerConfig
|
||||||
from lerobot.policies.rtc.configuration_rtc import RTCConfig
|
from lerobot.policies.rtc.configuration_rtc import RTCConfig
|
||||||
|
|
||||||
|
DEFAULT_IMAGE_SIZE = 224
|
||||||
|
|
||||||
|
|
||||||
@PreTrainedConfig.register_subclass("pi05")
|
@PreTrainedConfig.register_subclass("pi05")
|
||||||
@dataclass
|
@dataclass
|
||||||
@@ -50,7 +52,10 @@ class PI05Config(PreTrainedConfig):
|
|||||||
# Real-Time Chunking (RTC) configuration
|
# Real-Time Chunking (RTC) configuration
|
||||||
rtc_config: RTCConfig | None = None
|
rtc_config: RTCConfig | None = None
|
||||||
|
|
||||||
image_resolution: tuple[int, int] = (224, 224) # see openpi `preprocessing_pytorch.py`
|
image_resolution: tuple[int, int] = (
|
||||||
|
DEFAULT_IMAGE_SIZE,
|
||||||
|
DEFAULT_IMAGE_SIZE,
|
||||||
|
) # see openpi `preprocessing_pytorch.py`
|
||||||
|
|
||||||
# Add empty images. Used to add empty cameras when no image features are present.
|
# Add empty images. Used to add empty cameras when no image features are present.
|
||||||
empty_cameras: int = 0
|
empty_cameras: int = 0
|
||||||
|
|||||||
@@ -41,7 +41,7 @@ else:
|
|||||||
PaliGemmaForConditionalGeneration = None
|
PaliGemmaForConditionalGeneration = None
|
||||||
|
|
||||||
from lerobot.configs.policies import PreTrainedConfig
|
from lerobot.configs.policies import PreTrainedConfig
|
||||||
from lerobot.policies.pi05.configuration_pi05 import PI05Config
|
from lerobot.policies.pi05.configuration_pi05 import DEFAULT_IMAGE_SIZE, PI05Config
|
||||||
from lerobot.policies.pretrained import PreTrainedPolicy, T
|
from lerobot.policies.pretrained import PreTrainedPolicy, T
|
||||||
from lerobot.policies.rtc.modeling_rtc import RTCProcessor
|
from lerobot.policies.rtc.modeling_rtc import RTCProcessor
|
||||||
from lerobot.utils.constants import (
|
from lerobot.utils.constants import (
|
||||||
@@ -336,6 +336,7 @@ class PaliGemmaWithExpertModel(
|
|||||||
action_expert_config,
|
action_expert_config,
|
||||||
use_adarms=None,
|
use_adarms=None,
|
||||||
precision: Literal["bfloat16", "float32"] = "bfloat16",
|
precision: Literal["bfloat16", "float32"] = "bfloat16",
|
||||||
|
image_size: int = DEFAULT_IMAGE_SIZE,
|
||||||
):
|
):
|
||||||
if use_adarms is None:
|
if use_adarms is None:
|
||||||
use_adarms = [False, False]
|
use_adarms = [False, False]
|
||||||
@@ -355,6 +356,7 @@ class PaliGemmaWithExpertModel(
|
|||||||
vlm_config_hf.text_config.vocab_size = 257152
|
vlm_config_hf.text_config.vocab_size = 257152
|
||||||
vlm_config_hf.text_config.use_adarms = use_adarms[0]
|
vlm_config_hf.text_config.use_adarms = use_adarms[0]
|
||||||
vlm_config_hf.text_config.adarms_cond_dim = vlm_config.width if use_adarms[0] else None
|
vlm_config_hf.text_config.adarms_cond_dim = vlm_config.width if use_adarms[0] else None
|
||||||
|
vlm_config_hf.vision_config.image_size = image_size
|
||||||
vlm_config_hf.vision_config.intermediate_size = 4304
|
vlm_config_hf.vision_config.intermediate_size = 4304
|
||||||
vlm_config_hf.vision_config.projection_dim = 2048
|
vlm_config_hf.vision_config.projection_dim = 2048
|
||||||
vlm_config_hf.vision_config.projector_hidden_act = "gelu_fast"
|
vlm_config_hf.vision_config.projector_hidden_act = "gelu_fast"
|
||||||
@@ -518,11 +520,17 @@ class PI05Pytorch(nn.Module): # see openpi `PI0Pytorch`
|
|||||||
paligemma_config = get_gemma_config(config.paligemma_variant)
|
paligemma_config = get_gemma_config(config.paligemma_variant)
|
||||||
action_expert_config = get_gemma_config(config.action_expert_variant)
|
action_expert_config = get_gemma_config(config.action_expert_variant)
|
||||||
|
|
||||||
|
if config.image_resolution[0] != config.image_resolution[1]:
|
||||||
|
raise ValueError(
|
||||||
|
f"PaliGemma expects square image resolution, invalid resolution: {config.image_resolution}"
|
||||||
|
)
|
||||||
|
|
||||||
self.paligemma_with_expert = PaliGemmaWithExpertModel(
|
self.paligemma_with_expert = PaliGemmaWithExpertModel(
|
||||||
paligemma_config,
|
paligemma_config,
|
||||||
action_expert_config,
|
action_expert_config,
|
||||||
use_adarms=[False, True],
|
use_adarms=[False, True],
|
||||||
precision=config.dtype,
|
precision=config.dtype,
|
||||||
|
image_size=config.image_resolution[0],
|
||||||
)
|
)
|
||||||
|
|
||||||
self.action_in_proj = nn.Linear(config.max_action_dim, action_expert_config.width)
|
self.action_in_proj = nn.Linear(config.max_action_dim, action_expert_config.width)
|
||||||
|
|||||||
Reference in New Issue
Block a user