This commit is contained in:
Jade Choghari
2025-11-17 14:43:14 +01:00
parent 42d615b69d
commit 8591fc10b3
3 changed files with 5 additions and 9 deletions
@@ -70,7 +70,7 @@ class XVLAConfig(PreTrainedConfig):
num_domains: int = 30
len_soft_prompts: int = 32
dim_time: int = 32
max_len_seq: int = 512 # TODO: jadechoghari: change to 512 1024
max_len_seq: int = 512
use_hetero_proj: bool = False
# Action & proprioception
+4 -8
View File
@@ -34,7 +34,7 @@ from .action_hub import build_action_space
from .configuration_florence2 import Florence2Config
from .configuration_xvla import XVLAConfig, XVLAConfig as PreTrainedConfig
from .modeling_florence2 import Florence2ForConditionalGeneration
from .transformer import SoftPromptedTransformer
from .soft_transformer import SoftPromptedTransformer
class XVLAModel(nn.Module):
@@ -68,7 +68,7 @@ class XVLAModel(nn.Module):
if projection_dim is None:
raise ValueError("Florence2 config must provide `projection_dim` for multimodal fusion.")
self.transformer = SoftPromptedTransformer(
self.soft_prompted_transformer = SoftPromptedTransformer(
hidden_size=config.hidden_size,
multi_modal_input_size=projection_dim,
depth=config.depth,
@@ -95,10 +95,6 @@ class XVLAModel(nn.Module):
batch_size, num_views = pixel_values.shape[:2]
flat_mask = image_mask.view(-1).to(dtype=torch.bool)
flat_images = pixel_values.flatten(0, 1)
# TODO: jadechoghari: remove this resizing logic, and provide a way in training to do this
# target_size = (224, 224)
# flat_images = F.interpolate(flat_images, size=target_size, mode="bilinear", align_corners=False)
num_valid = int(flat_mask.sum().item())
if num_valid == 0:
raise ValueError("At least one image view must be valid per batch.")
@@ -144,7 +140,7 @@ class XVLAModel(nn.Module):
action_noisy = torch.randn_like(action) * t.view(-1, 1, 1) + action * (1 - t).view(-1, 1, 1)
proprio_m, action_noisy_m = self.action_space.preprocess(proprio, action_noisy)
pred_action = self.transformer(
pred_action = self.soft_prompted_transformer(
domain_id=domain_id,
action_with_noise=action_noisy_m,
t=t,
@@ -177,7 +173,7 @@ class XVLAModel(nn.Module):
t = torch.full((batch_size,), i / steps, device=proprio.device, dtype=proprio.dtype)
x_t = x1 * t.view(-1, 1, 1) + action * (1 - t).view(-1, 1, 1)
proprio_m, x_t_m = self.action_space.preprocess(proprio, x_t)
action = self.transformer(
action = self.soft_prompted_transformer(
domain_id=domain_id,
action_with_noise=x_t_m,
proprio=proprio_m,