mirror of
https://github.com/huggingface/lerobot.git
synced 2026-06-29 06:07:40 +00:00
renaming
This commit is contained in:
@@ -70,7 +70,7 @@ class XVLAConfig(PreTrainedConfig):
|
||||
num_domains: int = 30
|
||||
len_soft_prompts: int = 32
|
||||
dim_time: int = 32
|
||||
max_len_seq: int = 512 # TODO: jadechoghari: change to 512 1024
|
||||
max_len_seq: int = 512
|
||||
use_hetero_proj: bool = False
|
||||
|
||||
# Action & proprioception
|
||||
|
||||
@@ -34,7 +34,7 @@ from .action_hub import build_action_space
|
||||
from .configuration_florence2 import Florence2Config
|
||||
from .configuration_xvla import XVLAConfig, XVLAConfig as PreTrainedConfig
|
||||
from .modeling_florence2 import Florence2ForConditionalGeneration
|
||||
from .transformer import SoftPromptedTransformer
|
||||
from .soft_transformer import SoftPromptedTransformer
|
||||
|
||||
|
||||
class XVLAModel(nn.Module):
|
||||
@@ -68,7 +68,7 @@ class XVLAModel(nn.Module):
|
||||
if projection_dim is None:
|
||||
raise ValueError("Florence2 config must provide `projection_dim` for multimodal fusion.")
|
||||
|
||||
self.transformer = SoftPromptedTransformer(
|
||||
self.soft_prompted_transformer = SoftPromptedTransformer(
|
||||
hidden_size=config.hidden_size,
|
||||
multi_modal_input_size=projection_dim,
|
||||
depth=config.depth,
|
||||
@@ -95,10 +95,6 @@ class XVLAModel(nn.Module):
|
||||
batch_size, num_views = pixel_values.shape[:2]
|
||||
flat_mask = image_mask.view(-1).to(dtype=torch.bool)
|
||||
flat_images = pixel_values.flatten(0, 1)
|
||||
# TODO: jadechoghari: remove this resizing logic, and provide a way in training to do this
|
||||
# target_size = (224, 224)
|
||||
# flat_images = F.interpolate(flat_images, size=target_size, mode="bilinear", align_corners=False)
|
||||
|
||||
num_valid = int(flat_mask.sum().item())
|
||||
if num_valid == 0:
|
||||
raise ValueError("At least one image view must be valid per batch.")
|
||||
@@ -144,7 +140,7 @@ class XVLAModel(nn.Module):
|
||||
action_noisy = torch.randn_like(action) * t.view(-1, 1, 1) + action * (1 - t).view(-1, 1, 1)
|
||||
proprio_m, action_noisy_m = self.action_space.preprocess(proprio, action_noisy)
|
||||
|
||||
pred_action = self.transformer(
|
||||
pred_action = self.soft_prompted_transformer(
|
||||
domain_id=domain_id,
|
||||
action_with_noise=action_noisy_m,
|
||||
t=t,
|
||||
@@ -177,7 +173,7 @@ class XVLAModel(nn.Module):
|
||||
t = torch.full((batch_size,), i / steps, device=proprio.device, dtype=proprio.dtype)
|
||||
x_t = x1 * t.view(-1, 1, 1) + action * (1 - t).view(-1, 1, 1)
|
||||
proprio_m, x_t_m = self.action_space.preprocess(proprio, x_t)
|
||||
action = self.transformer(
|
||||
action = self.soft_prompted_transformer(
|
||||
domain_id=domain_id,
|
||||
action_with_noise=x_t_m,
|
||||
proprio=proprio_m,
|
||||
|
||||
Reference in New Issue
Block a user